diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/reloperators/HiveMultiJoin.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/reloperators/HiveMultiJoin.java index 660f01d..cff737c 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/reloperators/HiveMultiJoin.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/reloperators/HiveMultiJoin.java @@ -70,6 +70,7 @@ * INNER * @param filters filters associated with each join * input + * @param joinPredicateInfo join predicate information */ public HiveMultiJoin( RelOptCluster cluster, @@ -78,9 +79,10 @@ public HiveMultiJoin( RelDataType rowType, List> joinInputs, List joinTypes, - List filters) { + List filters, + JoinPredicateInfo joinPredicateInfo) { super(cluster, TraitsUtil.getDefaultTraitSet(cluster)); - this.inputs = Lists.newArrayList(inputs); + this.inputs = inputs; this.condition = condition; this.rowType = rowType; @@ -89,14 +91,27 @@ public HiveMultiJoin( this.joinTypes = ImmutableList.copyOf(joinTypes); this.filters = ImmutableList.copyOf(filters); this.outerJoin = containsOuter(); - - try { - this.joinPredInfo = HiveCalciteUtil.JoinPredicateInfo.constructJoinPredicateInfo(this); - } catch (CalciteSemanticException e) { - throw new RuntimeException(e); + if (joinPredicateInfo == null) { + try { + this.joinPredInfo = HiveCalciteUtil.JoinPredicateInfo.constructJoinPredicateInfo(this); + } catch (CalciteSemanticException e) { + throw new RuntimeException(e); + } + } else { + this.joinPredInfo = joinPredicateInfo; } } + public HiveMultiJoin( + RelOptCluster cluster, + List inputs, + RexNode condition, + RelDataType rowType, + List> joinInputs, + List joinTypes, + List filters) { + this(cluster, Lists.newArrayList(inputs), condition, rowType, joinInputs, joinTypes, filters, null); + } @Override public void replaceInput(int ordinalInParent, RelNode p) { diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveInsertExchange4JoinRule.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveInsertExchange4JoinRule.java index d6e3915..c9cf396 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveInsertExchange4JoinRule.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveInsertExchange4JoinRule.java @@ -38,6 +38,7 @@ import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil.JoinPredicateInfo; import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelCollation; import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelDistribution; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveJoin; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveMultiJoin; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveSortExchange; @@ -76,11 +77,10 @@ public void onMatch(RelOptRuleCall call) { JoinPredicateInfo joinPredInfo; if (call.rel(0) instanceof HiveMultiJoin) { HiveMultiJoin multiJoin = call.rel(0); - try { - joinPredInfo = HiveCalciteUtil.JoinPredicateInfo.constructJoinPredicateInfo(multiJoin); - } catch (CalciteSemanticException e) { - throw new RuntimeException(e); - } + joinPredInfo = multiJoin.getJoinPredicateInfo(); + } else if (call.rel(0) instanceof HiveJoin) { + HiveJoin hiveJoin = call.rel(0); + joinPredInfo = hiveJoin.getJoinPredicateInfo(); } else if (call.rel(0) instanceof Join) { Join join = call.rel(0); try { diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinToMultiJoinRule.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinToMultiJoinRule.java index 5d169a1..ce9535f 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinToMultiJoinRule.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinToMultiJoinRule.java @@ -239,14 +239,49 @@ private static RelNode mergeJoin(HiveJoin join, RelNode left, RelNode right) { // We can now create a multijoin operator RexNode newCondition = RexUtil.flatten(rexBuilder, RexUtil.composeConjunction(rexBuilder, newJoinCondition, false)); + List newInputsArray = Lists.newArrayList(newInputs); + JoinPredicateInfo joinPredInfo = null; + try { + joinPredInfo = HiveCalciteUtil.JoinPredicateInfo.constructJoinPredicateInfo(newInputsArray, systemFieldList, newCondition); + } catch (CalciteSemanticException e) { + throw new RuntimeException(e); + } + + // If the number of joins < number of input tables-1, this is not a star join. + if (joinPredInfo.getEquiJoinPredicateElements().size() < newInputs.size()-1) { + return null; + } + // Validate that the multi-join is a valid star join before returning it. + for (int i=0; i joinKeys = null; + for (int j = 0; j < joinPredInfo.getEquiJoinPredicateElements().size(); j++) { + List currJoinKeys = joinPredInfo. + getEquiJoinPredicateElements().get(j).getJoinExprs(i); + if (currJoinKeys.isEmpty()) { + continue; + } + if (joinKeys == null) { + joinKeys = currJoinKeys; + } else { + // If we join on different keys on different tables, we can no longer apply + // multi-join conversion as this is no longer a valid star join. + // Bail out if this is the case. + if (!joinKeys.containsAll(currJoinKeys) || !currJoinKeys.containsAll(joinKeys)) { + return null; + } + } + } + } + return new HiveMultiJoin( join.getCluster(), - newInputs, + newInputsArray, newCondition, join.getRowType(), joinInputs, joinTypes, - joinFilters); + joinFilters, + joinPredInfo); } /*