diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinToMultiJoinRule.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinToMultiJoinRule.java index 5d169a1..1426fee 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinToMultiJoinRule.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinToMultiJoinRule.java @@ -42,6 +42,8 @@ import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelOptUtil; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveJoin; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveMultiJoin; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveSortExchange; +import org.apache.hadoop.hive.ql.plan.ExprNodeDesc; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -202,6 +204,20 @@ private static RelNode mergeJoin(HiveJoin join, RelNode left, RelNode right) { LOG.trace("Failed to merge joins", e); return null; } + + if (joinKeyExprs.size() > 0) { + List baseJoinExpr = joinKeyExprs.get(0); + // If we dont join on the same key types, bail out. + for (int i = 1; i < joinKeyExprs.size(); i++) { + List currJoinExpr = joinKeyExprs.get(i); + if ((baseJoinExpr == null && currJoinExpr != null) || + (baseJoinExpr != null && currJoinExpr == null) || + (baseJoinExpr.size() != currJoinExpr.size())) { + return null; + } + } + } + ImmutableBitSet.Builder keysInInputsBuilder = ImmutableBitSet.builder(); for (int i=0; i partialCondition = joinKeyExprs.get(i);