diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/reloperators/HiveSortExchange.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/reloperators/HiveSortExchange.java index 8cbc953..3961636 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/reloperators/HiveSortExchange.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/reloperators/HiveSortExchange.java @@ -6,42 +6,61 @@ import org.apache.calcite.rel.RelCollationTraitDef; import org.apache.calcite.rel.RelDistribution; import org.apache.calcite.rel.RelDistributionTraitDef; -import org.apache.calcite.rel.RelInput; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.SortExchange; +import org.apache.calcite.rex.RexNode; +import org.apache.hadoop.hive.ql.plan.ExprNodeDesc; + +import com.google.common.collect.ImmutableList; public class HiveSortExchange extends SortExchange { + private ImmutableList joinKeys; + private ExprNodeDesc[] joinExpressions; private HiveSortExchange(RelOptCluster cluster, RelTraitSet traitSet, - RelNode input, RelDistribution distribution, RelCollation collation) { + RelNode input, RelDistribution distribution, RelCollation collation, ImmutableList joinKeys) { super(cluster, traitSet, input, distribution, collation); + this.joinKeys = new ImmutableList.Builder().addAll(joinKeys).build(); } - public HiveSortExchange(RelInput input) { - super(input); - } - /** * Creates a HiveSortExchange. * * @param input Input relational expression * @param distribution Distribution specification * @param collation Collation specification + * @param joinKeys Join Keys specification */ public static HiveSortExchange create(RelNode input, - RelDistribution distribution, RelCollation collation) { + RelDistribution distribution, RelCollation collation, ImmutableList joinKeys) { RelOptCluster cluster = input.getCluster(); distribution = RelDistributionTraitDef.INSTANCE.canonize(distribution); collation = RelCollationTraitDef.INSTANCE.canonize(collation); RelTraitSet traitSet = RelTraitSet.createEmpty().plus(distribution).plus(collation); - return new HiveSortExchange(cluster, traitSet, input, distribution, collation); + return new HiveSortExchange(cluster, traitSet, input, distribution, collation, joinKeys); } @Override public SortExchange copy(RelTraitSet traitSet, RelNode newInput, RelDistribution newDistribution, RelCollation newCollation) { return new HiveSortExchange(getCluster(), traitSet, newInput, - newDistribution, newCollation); + newDistribution, newCollation, joinKeys); + } + + public ImmutableList getJoinKeys() { + return joinKeys; + } + + public void setJoinKeys(ImmutableList joinKeys) { + this.joinKeys = joinKeys; + } + + public ExprNodeDesc[] getJoinExpressions() { + return joinExpressions; + } + + public void setJoinExpressions(ExprNodeDesc[] joinExpressions) { + this.joinExpressions = joinExpressions; } } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveInsertExchange4JoinRule.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveInsertExchange4JoinRule.java index 6cceacb..30db8fd 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveInsertExchange4JoinRule.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveInsertExchange4JoinRule.java @@ -29,6 +29,7 @@ import org.apache.calcite.rel.core.Exchange; import org.apache.calcite.rel.core.Join; import org.apache.calcite.rel.rules.MultiJoin; +import org.apache.calcite.rex.RexNode; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil; @@ -91,6 +92,7 @@ public void onMatch(RelOptRuleCall call) { List newInputs = new ArrayList(); for (int i=0; i joinKeyPositions = new ArrayList(); + ImmutableList.Builder keyListBuilder = new ImmutableList.Builder(); ImmutableList.Builder collationListBuilder = new ImmutableList.Builder(); for (int j = 0; j < joinPredInfo.getEquiJoinPredicateElements().size(); j++) { @@ -100,12 +102,14 @@ public void onMatch(RelOptRuleCall call) { if (!joinKeyPositions.contains(pos)) { joinKeyPositions.add(pos); collationListBuilder.add(new RelFieldCollation(pos)); + keyListBuilder.add(joinLeafPredInfo.getJoinKeyExprs(i).get(0)); } } } HiveSortExchange exchange = HiveSortExchange.create(call.rel(0).getInput(i), new HiveRelDistribution(RelDistribution.Type.HASH_DISTRIBUTED, joinKeyPositions), - new HiveRelCollation(collationListBuilder.build())); + new HiveRelCollation(collationListBuilder.build()), + keyListBuilder.build()); newInputs.add(exchange); } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveOpConverter.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveOpConverter.java index 94528e2..b8ae604 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveOpConverter.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveOpConverter.java @@ -27,6 +27,7 @@ import java.util.Map; import java.util.Set; +import org.apache.calcite.rel.RelCollation; import org.apache.calcite.rel.RelCollations; import org.apache.calcite.rel.RelDistribution; import org.apache.calcite.rel.RelDistribution.Type; @@ -65,6 +66,7 @@ import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveJoin; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveSort; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveSortExchange; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveTableScan; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveUnion; import org.apache.hadoop.hive.ql.parse.JoinCond; @@ -167,8 +169,8 @@ OpAttr dispatch(RelNode rn) throws SemanticException { return visit((HiveSort) rn); } else if (rn instanceof HiveUnion) { return visit((HiveUnion) rn); - } else if (rn instanceof SortExchange) { - return visit((SortExchange) rn); + } else if (rn instanceof HiveSortExchange) { + return visit((HiveSortExchange) rn); } else if (rn instanceof HiveAggregate) { return visit((HiveAggregate) rn); } @@ -329,8 +331,11 @@ private OpAttr translateJoin(RelNode joinRel) throws SemanticException { joinPredInfo = JoinPredicateInfo.constructJoinPredicateInfo((MultiJoin)joinRel); } - // 3. Extract join keys from condition - ExprNodeDesc[][] joinKeys = extractJoinKeys(joinPredInfo, joinRel.getInputs(), inputs); + // 3. Extract join key expressions from HiveSortExchange + ExprNodeDesc[][] joinExpressions = new ExprNodeDesc[inputs.length][]; + for (int i = 0; i < inputs.length; i++) { + joinExpressions[i] = ((HiveSortExchange) joinRel.getInput(i)).getJoinExpressions(); + } // 4.a Generate tags for (int tag=0; tag