diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinToMultiJoinRule.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinToMultiJoinRule.java index d0a29a7..ba4b917 100644 --- ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinToMultiJoinRule.java +++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinToMultiJoinRule.java @@ -19,6 +19,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Set; import org.apache.calcite.plan.RelOptRule; import org.apache.calcite.plan.RelOptRuleCall; @@ -46,6 +47,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; +import com.google.common.collect.Sets; /** * Rule that merges a join with multijoin/join children if @@ -128,6 +130,7 @@ private static RelNode mergeJoin(Join join, RelNode left, RelNode right) { final List> joinInputs = Lists.newArrayList(); final List joinTypes = Lists.newArrayList(); final List joinFilters = Lists.newArrayList(); + final List systemFieldList = ImmutableList.of(); // Left child if (left instanceof HiveJoin || left instanceof HiveMultiJoin) { @@ -151,8 +154,8 @@ private static RelNode mergeJoin(Join join, RelNode left, RelNode right) { boolean combinable; try { - combinable = isCombinablePredicate(join, join.getCondition(), - leftCondition); + combinable = isCombinablePredicate(join.getInputs(), systemFieldList, join.getCondition(), + left.getInputs(), systemFieldList, leftCondition); } catch (CalciteSemanticException e) { LOG.trace("Failed to merge joins", e); combinable = false; @@ -182,7 +185,6 @@ private static RelNode mergeJoin(Join join, RelNode left, RelNode right) { return null; } - final List systemFieldList = ImmutableList.of(); List> joinKeyExprs = new ArrayList>(); List filterNulls = new ArrayList(); for (int i=0; i joinInputs, + List systemFieldList, RexNode condition, + List childJoinInputs, List childSystemFieldList, + RexNode childCondition) throws CalciteSemanticException { final JoinPredicateInfo joinPredInfo = HiveCalciteUtil.JoinPredicateInfo. - constructJoinPredicateInfo(join, condition); - final JoinPredicateInfo otherJoinPredInfo = HiveCalciteUtil.JoinPredicateInfo. - constructJoinPredicateInfo(join, otherCondition); - if (joinPredInfo.getProjsFromLeftPartOfJoinKeysInJoinSchema(). - equals(otherJoinPredInfo.getProjsFromLeftPartOfJoinKeysInJoinSchema())) { - return false; - } - if (joinPredInfo.getProjsFromRightPartOfJoinKeysInJoinSchema(). - equals(otherJoinPredInfo.getProjsFromRightPartOfJoinKeysInJoinSchema())) { - return false; + constructJoinPredicateInfo(joinInputs, systemFieldList, condition); + final JoinPredicateInfo childJoinPredInfo = HiveCalciteUtil.JoinPredicateInfo. + constructJoinPredicateInfo(childJoinInputs, childSystemFieldList, childCondition); + Set keys = joinPredInfo.getProjsFromLeftPartOfJoinKeysInJoinSchema(); + Set childKeys = Sets.newHashSet(); + for (int i=0; i