diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveInsertExchange4JoinRule.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveInsertExchange4JoinRule.java index d6e3915..c9cf396 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveInsertExchange4JoinRule.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveInsertExchange4JoinRule.java @@ -38,6 +38,7 @@ import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil.JoinPredicateInfo; import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelCollation; import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelDistribution; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveJoin; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveMultiJoin; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveSortExchange; @@ -76,11 +77,10 @@ public void onMatch(RelOptRuleCall call) { JoinPredicateInfo joinPredInfo; if (call.rel(0) instanceof HiveMultiJoin) { HiveMultiJoin multiJoin = call.rel(0); - try { - joinPredInfo = HiveCalciteUtil.JoinPredicateInfo.constructJoinPredicateInfo(multiJoin); - } catch (CalciteSemanticException e) { - throw new RuntimeException(e); - } + joinPredInfo = multiJoin.getJoinPredicateInfo(); + } else if (call.rel(0) instanceof HiveJoin) { + HiveJoin hiveJoin = call.rel(0); + joinPredInfo = hiveJoin.getJoinPredicateInfo(); } else if (call.rel(0) instanceof Join) { Join join = call.rel(0); try { diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinToMultiJoinRule.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinToMultiJoinRule.java index 5d169a1..bf0cd80 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinToMultiJoinRule.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinToMultiJoinRule.java @@ -37,6 +37,7 @@ import org.apache.calcite.util.Pair; import org.apache.hadoop.hive.ql.optimizer.calcite.CalciteSemanticException; import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil; +import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil.JoinLeafPredicateInfo; import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil.JoinPredicateInfo; import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories; import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelOptUtil; @@ -47,6 +48,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; +import com.google.common.collect.Sets; /** * Rule that merges a join with multijoin/join children if @@ -239,7 +241,7 @@ private static RelNode mergeJoin(HiveJoin join, RelNode left, RelNode right) { // We can now create a multijoin operator RexNode newCondition = RexUtil.flatten(rexBuilder, RexUtil.composeConjunction(rexBuilder, newJoinCondition, false)); - return new HiveMultiJoin( + HiveMultiJoin multiJoin = new HiveMultiJoin( join.getCluster(), newInputs, newCondition, @@ -247,6 +249,38 @@ private static RelNode mergeJoin(HiveJoin join, RelNode left, RelNode right) { joinInputs, joinTypes, joinFilters); + + boolean firstInput = true; + int numJoinKeys = 0; + JoinPredicateInfo joinPredInfo = multiJoin.getJoinPredicateInfo(); + + // Validate that the multi-join is a valid star join before returning it. + for (int i=0; i keySet = Sets.newHashSet(); + for (int j = 0; j < joinPredInfo.getEquiJoinPredicateElements().size(); j++) { + JoinLeafPredicateInfo joinLeafPredInfo = joinPredInfo. + getEquiJoinPredicateElements().get(j); + for (RexNode joinExprNode : joinLeafPredInfo.getJoinExprs(i)) { + if (keySet.add(joinExprNode.toString())) { + numKeys++; + } + } + } + if (firstInput) { + numJoinKeys = numKeys; + firstInput = false; + } else { + // If we join on different keys on different tables, we can no longer apply + // multi-join conversion as this is no longer a valid star join. + // Bail out if this is the case. + if (numJoinKeys != numKeys) { + return null; + } + } + } + + return multiJoin; } /*