diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/reloperators/HiveSortExchange.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/reloperators/HiveSortExchange.java
index 8cbc953..3961636 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/reloperators/HiveSortExchange.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/reloperators/HiveSortExchange.java
@@ -6,42 +6,68 @@
 import org.apache.calcite.rel.RelCollationTraitDef;
 import org.apache.calcite.rel.RelDistribution;
 import org.apache.calcite.rel.RelDistributionTraitDef;
-import org.apache.calcite.rel.RelInput;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.core.SortExchange;
+import org.apache.calcite.rex.RexNode;
+import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
+
+import com.google.common.collect.ImmutableList;
 
 public class HiveSortExchange extends SortExchange {
+  // Join key expressions handed to create(); carried over by copy().
+  private ImmutableList<RexNode> joinKeys;
+  // Not assigned by the constructor; null until setJoinExpressions is called.
+  private ExprNodeDesc[] joinExpressions;
 
   private HiveSortExchange(RelOptCluster cluster, RelTraitSet traitSet,
-      RelNode input, RelDistribution distribution, RelCollation collation) {
+      RelNode input, RelDistribution distribution, RelCollation collation, ImmutableList<RexNode> joinKeys) {
     super(cluster, traitSet, input, distribution, collation);
+    this.joinKeys = ImmutableList.copyOf(joinKeys);
   }
 
-  public HiveSortExchange(RelInput input) {
-    super(input);
-  }
-
   /**
    * Creates a HiveSortExchange.
    *
    * @param input     Input relational expression
    * @param distribution Distribution specification
    * @param collation Collation specification
+   * @param joinKeys Join key expressions to carry on this exchange; must not be null
    */
   public static HiveSortExchange create(RelNode input,
-      RelDistribution distribution, RelCollation collation) {
+      RelDistribution distribution, RelCollation collation, ImmutableList<RexNode> joinKeys) {
     RelOptCluster cluster = input.getCluster();
     distribution = RelDistributionTraitDef.INSTANCE.canonize(distribution);
     collation = RelCollationTraitDef.INSTANCE.canonize(collation);
     RelTraitSet traitSet = RelTraitSet.createEmpty().plus(distribution).plus(collation);
-    return new HiveSortExchange(cluster, traitSet, input, distribution, collation);
+    return new HiveSortExchange(cluster, traitSet, input, distribution, collation, joinKeys);
   }
 
   @Override
   public SortExchange copy(RelTraitSet traitSet, RelNode newInput, RelDistribution newDistribution,
           RelCollation newCollation) {
+    // NOTE(review): only joinKeys is passed on; joinExpressions of the copy starts out null.
     return new HiveSortExchange(getCluster(), traitSet, newInput,
-            newDistribution, newCollation);
+            newDistribution, newCollation, joinKeys);
+  }
+
+  /** Returns the join key expressions given at creation or last set via setJoinKeys. */
+  public ImmutableList<RexNode> getJoinKeys() {
+    return joinKeys;
+  }
+
+  /** Replaces the join key expressions carried by this exchange. */
+  public void setJoinKeys(ImmutableList<RexNode> joinKeys) {
+    this.joinKeys = joinKeys;
+  }
+
+  /** Returns the array last given to setJoinExpressions, or null if it was never called. */
+  public ExprNodeDesc[] getJoinExpressions() {
+    return joinExpressions;
+  }
+
+  /** Stores the given join expressions; the array is kept by reference, not copied. */
+  public void setJoinExpressions(ExprNodeDesc[] joinExpressions) {
+    this.joinExpressions = joinExpressions;
   }
 
 }
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveInsertExchange4JoinRule.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveInsertExchange4JoinRule.java
index b5404a3..ab97de1 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveInsertExchange4JoinRule.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveInsertExchange4JoinRule.java
@@ -27,6 +27,7 @@
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.core.Exchange;
 import org.apache.calcite.rel.core.Join;
+import org.apache.calcite.rex.RexNode;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil;
@@ -79,6 +80,8 @@ public void onMatch(RelOptRuleCall call) {
         new ImmutableList.Builder<RelFieldCollation>();
     ImmutableList.Builder<RelFieldCollation> rightCollationListBuilder =
         new ImmutableList.Builder<RelFieldCollation>();
+    ImmutableList.Builder<RexNode> leftKeyListBuilder = new ImmutableList.Builder<RexNode>();
+    ImmutableList.Builder<RexNode> rightKeyListBuilder = new ImmutableList.Builder<RexNode>();
     for (int i = 0; i < joinPredInfo.getEquiJoinPredicateElements().size(); i++) {
       JoinLeafPredicateInfo joinLeafPredInfo = joinPredInfo.
           getEquiJoinPredicateElements().get(i);
@@ -90,14 +93,20 @@ public void onMatch(RelOptRuleCall call) {
       for (int rightPos : joinLeafPredInfo.getProjsFromRightPartOfJoinKeysInChildSchema()) {
         rightCollationListBuilder.add(new RelFieldCollation(rightPos));
       }
+      // getJoinKeyExprs(0) holds the left-side keys, getJoinKeyExprs(1) the right-side keys.
+      // NOTE(review): only the first expression per side is kept - confirm a leaf never has more.
+      leftKeyListBuilder.add(joinLeafPredInfo.getJoinKeyExprs(0).get(0));
+      rightKeyListBuilder.add(joinLeafPredInfo.getJoinKeyExprs(1).get(0));
     }
 
     HiveSortExchange left = HiveSortExchange.create(join.getLeft(),
         new HiveRelDistribution(RelDistribution.Type.HASH_DISTRIBUTED, joinLeftKeyPositions),
-        new HiveRelCollation(leftCollationListBuilder.build()));
+        new HiveRelCollation(leftCollationListBuilder.build()),
+        leftKeyListBuilder.build());
     HiveSortExchange right = HiveSortExchange.create(join.getRight(),
         new HiveRelDistribution(RelDistribution.Type.HASH_DISTRIBUTED, joinRightKeyPositions),
-        new HiveRelCollation(rightCollationListBuilder.build()));
+        new HiveRelCollation(rightCollationListBuilder.build()),
+        rightKeyListBuilder.build());
 
     Join newJoin = join.copy(join.getTraitSet(), join.getCondition(),
         left, right, join.getJoinType(), join.isSemiJoinDone());
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveOpConverter.java
b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveOpConverter.java index 082a36c..9335d70 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveOpConverter.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveOpConverter.java @@ -27,6 +27,7 @@ import java.util.Map; import java.util.Set; +import org.apache.calcite.rel.RelCollation; import org.apache.calcite.rel.RelCollations; import org.apache.calcite.rel.RelDistribution; import org.apache.calcite.rel.RelDistribution.Type; @@ -64,6 +65,7 @@ import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveJoin; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveSort; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveSortExchange; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveTableScan; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveUnion; import org.apache.hadoop.hive.ql.parse.JoinCond; @@ -165,8 +167,8 @@ OpAttr dispatch(RelNode rn) throws SemanticException { return visit((HiveSort) rn); } else if (rn instanceof HiveUnion) { return visit((HiveUnion) rn); - } else if (rn instanceof SortExchange) { - return visit((SortExchange) rn); + } else if (rn instanceof HiveSortExchange) { + return visit((HiveSortExchange) rn); } else if (rn instanceof HiveAggregate) { return visit((HiveAggregate) rn); } @@ -314,8 +316,11 @@ OpAttr visit(HiveJoin joinRel) throws SemanticException { // 2. Convert join condition JoinPredicateInfo joinPredInfo = JoinPredicateInfo.constructJoinPredicateInfo(joinRel); - // 3. Extract join keys from condition - ExprNodeDesc[][] joinKeys = extractJoinKeys(joinPredInfo, joinRel.getInputs(), inputs); + // 3. 
Extract join key expressions from HiveSortExchange + ExprNodeDesc[][] joinExpressions = new ExprNodeDesc[inputs.length][]; + for (int i = 0; i < inputs.length; i++) { + joinExpressions[i] = ((HiveSortExchange) joinRel.getInput(i)).getJoinExpressions(); + } // 4.a Generate tags for (int tag=0; tag