diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/reloperators/HiveMultiJoin.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/reloperators/HiveMultiJoin.java index 660f01d..cff737c 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/reloperators/HiveMultiJoin.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/reloperators/HiveMultiJoin.java @@ -70,6 +70,7 @@ * INNER * @param filters filters associated with each join * input + * @param joinPredicateInfo join predicate information */ public HiveMultiJoin( RelOptCluster cluster, @@ -78,9 +79,10 @@ public HiveMultiJoin( RelDataType rowType, List> joinInputs, List joinTypes, - List filters) { + List filters, + JoinPredicateInfo joinPredicateInfo) { super(cluster, TraitsUtil.getDefaultTraitSet(cluster)); - this.inputs = Lists.newArrayList(inputs); + this.inputs = inputs; this.condition = condition; this.rowType = rowType; @@ -89,14 +91,27 @@ public HiveMultiJoin( this.joinTypes = ImmutableList.copyOf(joinTypes); this.filters = ImmutableList.copyOf(filters); this.outerJoin = containsOuter(); - - try { - this.joinPredInfo = HiveCalciteUtil.JoinPredicateInfo.constructJoinPredicateInfo(this); - } catch (CalciteSemanticException e) { - throw new RuntimeException(e); + if (joinPredicateInfo == null) { + try { + this.joinPredInfo = HiveCalciteUtil.JoinPredicateInfo.constructJoinPredicateInfo(this); + } catch (CalciteSemanticException e) { + throw new RuntimeException(e); + } + } else { + this.joinPredInfo = joinPredicateInfo; } } + public HiveMultiJoin( + RelOptCluster cluster, + List inputs, + RexNode condition, + RelDataType rowType, + List> joinInputs, + List joinTypes, + List filters) { + this(cluster, Lists.newArrayList(inputs), condition, rowType, joinInputs, joinTypes, filters, null); + } @Override public void replaceInput(int ordinalInParent, RelNode p) { diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveInsertExchange4JoinRule.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveInsertExchange4JoinRule.java index d6e3915..c9cf396 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveInsertExchange4JoinRule.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveInsertExchange4JoinRule.java @@ -38,6 +38,7 @@ import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil.JoinPredicateInfo; import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelCollation; import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelDistribution; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveJoin; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveMultiJoin; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveSortExchange; @@ -76,11 +77,10 @@ public void onMatch(RelOptRuleCall call) { JoinPredicateInfo joinPredInfo; if (call.rel(0) instanceof HiveMultiJoin) { HiveMultiJoin multiJoin = call.rel(0); - try { - joinPredInfo = HiveCalciteUtil.JoinPredicateInfo.constructJoinPredicateInfo(multiJoin); - } catch (CalciteSemanticException e) { - throw new RuntimeException(e); - } + joinPredInfo = multiJoin.getJoinPredicateInfo(); + } else if (call.rel(0) instanceof HiveJoin) { + HiveJoin hiveJoin = call.rel(0); + joinPredInfo = hiveJoin.getJoinPredicateInfo(); } else if (call.rel(0) instanceof Join) { Join join = call.rel(0); try { diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinToMultiJoinRule.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinToMultiJoinRule.java index 5d169a1..f329a6f 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinToMultiJoinRule.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinToMultiJoinRule.java @@ -37,6 +37,7 @@ import org.apache.calcite.util.Pair; import org.apache.hadoop.hive.ql.optimizer.calcite.CalciteSemanticException; import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil; +import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil.JoinLeafPredicateInfo; import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil.JoinPredicateInfo; import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories; import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelOptUtil; @@ -47,6 +48,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; +import com.google.common.collect.Sets; /** * Rule that merges a join with multijoin/join children if @@ -239,14 +241,47 @@ private static RelNode mergeJoin(HiveJoin join, RelNode left, RelNode right) { // We can now create a multijoin operator RexNode newCondition = RexUtil.flatten(rexBuilder, RexUtil.composeConjunction(rexBuilder, newJoinCondition, false)); + List newInputsArray = Lists.newArrayList(newInputs); + JoinPredicateInfo joinPredInfo = null; + try { + joinPredInfo = HiveCalciteUtil.JoinPredicateInfo.constructJoinPredicateInfo(newInputsArray, systemFieldList, newCondition); + } catch (CalciteSemanticException e) { + throw new RuntimeException(e); + } + + boolean firstInput = true; + int numJoinKeys = 0; + + // Validate that the multi-join is a valid star join before returning it. + for (int i=0; i