diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveCalciteUtil.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveCalciteUtil.java index 7614463..372c93d 100644 --- ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveCalciteUtil.java +++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveCalciteUtil.java @@ -31,6 +31,7 @@ import org.apache.calcite.rel.core.Join; import org.apache.calcite.rel.core.RelFactories.ProjectFactory; import org.apache.calcite.rel.core.Sort; +import org.apache.calcite.rel.rules.MultiJoin; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.rex.RexBuilder; @@ -65,6 +66,7 @@ import com.google.common.collect.ImmutableMap.Builder; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; +import com.google.common.collect.Sets; /** * Generic utility functions needed for Calcite based Hive CBO. @@ -269,25 +271,21 @@ public static RexNode projectNonColumnEquiConditions(ProjectFactory factory, Rel public static class JoinPredicateInfo { private final ImmutableList nonEquiJoinPredicateElements; private final ImmutableList equiJoinPredicateElements; - private final ImmutableSet projsFromLeftPartOfJoinKeysInChildSchema; - private final ImmutableSet projsFromRightPartOfJoinKeysInChildSchema; - private final ImmutableSet projsFromRightPartOfJoinKeysInJoinSchema; + private final ImmutableList> projsJoinKeysInChildSchema; + private final ImmutableList> projsJoinKeysInJoinSchema; private final ImmutableMap> mapOfProjIndxInJoinSchemaToLeafPInfo; public JoinPredicateInfo(List nonEquiJoinPredicateElements, List equiJoinPredicateElements, - Set projsFromLeftPartOfJoinKeysInChildSchema, - Set projsFromRightPartOfJoinKeysInChildSchema, - Set projsFromRightPartOfJoinKeysInJoinSchema, + List> projsJoinKeysInChildSchema, + List> projsJoinKeysInJoinSchema, Map> mapOfProjIndxInJoinSchemaToLeafPInfo) { 
this.nonEquiJoinPredicateElements = ImmutableList.copyOf(nonEquiJoinPredicateElements); this.equiJoinPredicateElements = ImmutableList.copyOf(equiJoinPredicateElements); - this.projsFromLeftPartOfJoinKeysInChildSchema = ImmutableSet - .copyOf(projsFromLeftPartOfJoinKeysInChildSchema); - this.projsFromRightPartOfJoinKeysInChildSchema = ImmutableSet - .copyOf(projsFromRightPartOfJoinKeysInChildSchema); - this.projsFromRightPartOfJoinKeysInJoinSchema = ImmutableSet - .copyOf(projsFromRightPartOfJoinKeysInJoinSchema); + this.projsJoinKeysInChildSchema = ImmutableList + .copyOf(projsJoinKeysInChildSchema); + this.projsJoinKeysInJoinSchema = ImmutableList + .copyOf(projsJoinKeysInJoinSchema); this.mapOfProjIndxInJoinSchemaToLeafPInfo = ImmutableMap .copyOf(mapOfProjIndxInJoinSchemaToLeafPInfo); } @@ -301,11 +299,17 @@ public JoinPredicateInfo(List nonEquiJoinPredicateElement } public Set getProjsFromLeftPartOfJoinKeysInChildSchema() { - return this.projsFromLeftPartOfJoinKeysInChildSchema; + assert projsJoinKeysInChildSchema.size() == 2; + return this.projsJoinKeysInChildSchema.get(0); } public Set getProjsFromRightPartOfJoinKeysInChildSchema() { - return this.projsFromRightPartOfJoinKeysInChildSchema; + assert projsJoinKeysInChildSchema.size() == 2; + return this.projsJoinKeysInChildSchema.get(1); + } + + public Set getProjsJoinKeysInChildSchema(int i) { + return this.projsJoinKeysInChildSchema.get(i); } /** @@ -314,11 +318,17 @@ public JoinPredicateInfo(List nonEquiJoinPredicateElement * schema. 
*/ public Set getProjsFromLeftPartOfJoinKeysInJoinSchema() { - return this.projsFromLeftPartOfJoinKeysInChildSchema; + assert projsJoinKeysInJoinSchema.size() == 2; + return this.projsJoinKeysInJoinSchema.get(0); } public Set getProjsFromRightPartOfJoinKeysInJoinSchema() { - return this.projsFromRightPartOfJoinKeysInJoinSchema; + assert projsJoinKeysInJoinSchema.size() == 2; + return this.projsJoinKeysInJoinSchema.get(1); + } + + public Set getProjsJoinKeysInJoinSchema(int i) { + return this.projsJoinKeysInJoinSchema.get(i); } public Map> getMapOfProjIndxToLeafPInfo() { @@ -328,20 +338,39 @@ public JoinPredicateInfo(List nonEquiJoinPredicateElement public static JoinPredicateInfo constructJoinPredicateInfo(Join j) { return constructJoinPredicateInfo(j, j.getCondition()); } + + public static JoinPredicateInfo constructJoinPredicateInfo(MultiJoin mj) { + return constructJoinPredicateInfo(mj, mj.getJoinFilter()); + } public static JoinPredicateInfo constructJoinPredicateInfo(Join j, RexNode predicate) { + return constructJoinPredicateInfo(j.getInputs(), j.getSystemFieldList(), predicate); + } + + public static JoinPredicateInfo constructJoinPredicateInfo(MultiJoin mj, RexNode predicate) { + final List systemFieldList = ImmutableList.of(); + return constructJoinPredicateInfo(mj.getInputs(), systemFieldList, predicate); + } + + public static JoinPredicateInfo constructJoinPredicateInfo(List inputs, + List systemFieldList, RexNode predicate) { JoinPredicateInfo jpi = null; JoinLeafPredicateInfo jlpi = null; List equiLPIList = new ArrayList(); List nonEquiLPIList = new ArrayList(); - Set projsFromLeftPartOfJoinKeys = new HashSet(); - Set projsFromRightPartOfJoinKeys = new HashSet(); - Set projsFromRightPartOfJoinKeysInJoinSchema = new HashSet(); + List> projsJoinKeys = new ArrayList>(); + for (int i=0; i projsJoinKeysInput = Sets.newHashSet(); + projsJoinKeys.add(projsJoinKeysInput); + } + List> projsJoinKeysInJoinSchema = new ArrayList>(); + for (int i=0; i 
projsJoinKeysInJoinSchemaInput = Sets.newHashSet(); + projsJoinKeysInJoinSchema.add(projsJoinKeysInJoinSchemaInput); + } Map> tmpMapOfProjIndxInJoinSchemaToLeafPInfo = new HashMap>(); Map> mapOfProjIndxInJoinSchemaToLeafPInfo = new HashMap>(); List tmpJLPILst = null; - int rightOffSet = j.getLeft().getRowType().getFieldCount(); - int projIndxInJoin; List conjuctiveElements; // 1. Decompose Join condition to a number of leaf predicates @@ -351,7 +380,7 @@ public static JoinPredicateInfo constructJoinPredicateInfo(Join j, RexNode predi // 2. Walk through leaf predicates building up JoinLeafPredicateInfo for (RexNode ce : conjuctiveElements) { // 2.1 Construct JoinLeafPredicateInfo - jlpi = JoinLeafPredicateInfo.constructJoinLeafPredicateInfo(j, ce); + jlpi = JoinLeafPredicateInfo.constructJoinLeafPredicateInfo(inputs, systemFieldList, ce); // 2.2 Classify leaf predicate as Equi vs Non Equi if (jlpi.comparisonType.equals(SqlKind.EQUALS)) { @@ -360,34 +389,21 @@ public static JoinPredicateInfo constructJoinPredicateInfo(Join j, RexNode predi nonEquiLPIList.add(jlpi); } - // 2.3 Maintain join keys coming from left vs right (in child & - // Join Schema) - projsFromLeftPartOfJoinKeys.addAll(jlpi.getProjsFromLeftPartOfJoinKeysInChildSchema()); - projsFromRightPartOfJoinKeys.addAll(jlpi.getProjsFromRightPartOfJoinKeysInChildSchema()); - projsFromRightPartOfJoinKeysInJoinSchema.addAll(jlpi - .getProjsFromRightPartOfJoinKeysInJoinSchema()); - + // 2.3 Maintain join keys (in child & Join Schema) // 2.4 Update Join Key to JoinLeafPredicateInfo map with keys - // from left - for (Integer projIndx : jlpi.getProjsFromLeftPartOfJoinKeysInChildSchema()) { - tmpJLPILst = tmpMapOfProjIndxInJoinSchemaToLeafPInfo.get(projIndx); - if (tmpJLPILst == null) - tmpJLPILst = new ArrayList(); - tmpJLPILst.add(jlpi); - tmpMapOfProjIndxInJoinSchemaToLeafPInfo.put(projIndx, tmpJLPILst); - } - - // 2.5 Update Join Key to JoinLeafPredicateInfo map with keys - // from right - for (Integer projIndx : 
jlpi.getProjsFromRightPartOfJoinKeysInChildSchema()) { - projIndxInJoin = projIndx + rightOffSet; - tmpJLPILst = tmpMapOfProjIndxInJoinSchemaToLeafPInfo.get(projIndxInJoin); - if (tmpJLPILst == null) - tmpJLPILst = new ArrayList(); - tmpJLPILst.add(jlpi); - tmpMapOfProjIndxInJoinSchemaToLeafPInfo.put(projIndxInJoin, tmpJLPILst); + for (int i=0; i(); + } + tmpJLPILst.add(jlpi); + tmpMapOfProjIndxInJoinSchemaToLeafPInfo.put(projIndx, tmpJLPILst); + } } - } // 3. Update Update Join Key to List to use @@ -398,9 +414,8 @@ public static JoinPredicateInfo constructJoinPredicateInfo(Join j, RexNode predi } // 4. Construct JoinPredicateInfo - jpi = new JoinPredicateInfo(nonEquiLPIList, equiLPIList, projsFromLeftPartOfJoinKeys, - projsFromRightPartOfJoinKeys, projsFromRightPartOfJoinKeysInJoinSchema, - mapOfProjIndxInJoinSchemaToLeafPInfo); + jpi = new JoinPredicateInfo(nonEquiLPIList, equiLPIList, projsJoinKeys, + projsJoinKeysInJoinSchema, mapOfProjIndxInJoinSchemaToLeafPInfo); return jpi; } } @@ -416,101 +431,112 @@ public static JoinPredicateInfo constructJoinPredicateInfo(Join j, RexNode predi * of equi join keys; the indexes are both in child and Join node schema.
*/ public static class JoinLeafPredicateInfo { - private final SqlKind comparisonType; - private final ImmutableList joinKeyExprsFromLeft; - private final ImmutableList joinKeyExprsFromRight; - private final ImmutableSet projsFromLeftPartOfJoinKeysInChildSchema; - private final ImmutableSet projsFromRightPartOfJoinKeysInChildSchema; - private final ImmutableSet projsFromRightPartOfJoinKeysInJoinSchema; - - public JoinLeafPredicateInfo(SqlKind comparisonType, List joinKeyExprsFromLeft, - List joinKeyExprsFromRight, Set projsFromLeftPartOfJoinKeysInChildSchema, - Set projsFromRightPartOfJoinKeysInChildSchema, - Set projsFromRightPartOfJoinKeysInJoinSchema) { + private final SqlKind comparisonType; + private final ImmutableList> joinKeyExprs; + private final ImmutableList> projsJoinKeysInChildSchema; + private final ImmutableList> projsJoinKeysInJoinSchema; + + public JoinLeafPredicateInfo( + SqlKind comparisonType, + List> joinKeyExprs, + List> projsJoinKeysInChildSchema, + List> projsJoinKeysInJoinSchema) { this.comparisonType = comparisonType; - this.joinKeyExprsFromLeft = ImmutableList.copyOf(joinKeyExprsFromLeft); - this.joinKeyExprsFromRight = ImmutableList.copyOf(joinKeyExprsFromRight); - this.projsFromLeftPartOfJoinKeysInChildSchema = ImmutableSet - .copyOf(projsFromLeftPartOfJoinKeysInChildSchema); - this.projsFromRightPartOfJoinKeysInChildSchema = ImmutableSet - .copyOf(projsFromRightPartOfJoinKeysInChildSchema); - this.projsFromRightPartOfJoinKeysInJoinSchema = ImmutableSet - .copyOf(projsFromRightPartOfJoinKeysInJoinSchema); + ImmutableList.Builder> joinKeyExprsBuilder = + ImmutableList.builder(); + for (int i=0; i> projsJoinKeysInChildSchemaBuilder = + ImmutableList.builder(); + for (int i=0; i> projsJoinKeysInJoinSchemaBuilder = + ImmutableList.builder(); + for (int i=0; i getJoinKeyExprs(int input) { - if (input == 0) { - return this.joinKeyExprsFromLeft; - } - if (input == 1) { - return this.joinKeyExprsFromRight; - } - return null; + return 
this.joinKeyExprs.get(input); } - public List getJoinKeyExprsFromLeft() { - return this.joinKeyExprsFromLeft; + public Set getProjsFromLeftPartOfJoinKeysInChildSchema() { + assert projsJoinKeysInChildSchema.size() == 2; + return this.projsJoinKeysInChildSchema.get(0); } - public List getJoinKeyExprsFromRight() { - return this.joinKeyExprsFromRight; + public Set getProjsFromRightPartOfJoinKeysInChildSchema() { + assert projsJoinKeysInChildSchema.size() == 2; + return this.projsJoinKeysInChildSchema.get(1); } - public Set getProjsFromLeftPartOfJoinKeysInChildSchema() { - return this.projsFromLeftPartOfJoinKeysInChildSchema; + public Set getProjsJoinKeysInChildSchema(int input) { + return this.projsJoinKeysInChildSchema.get(input); } - /** - * NOTE: Join Schema = left Schema + (right Schema offset by - * left.fieldcount). Hence its ok to return projections from left in child - * schema. - */ public Set getProjsFromLeftPartOfJoinKeysInJoinSchema() { - return this.projsFromLeftPartOfJoinKeysInChildSchema; + assert projsJoinKeysInJoinSchema.size() == 2; + return this.projsJoinKeysInJoinSchema.get(0); } - public Set getProjsFromRightPartOfJoinKeysInChildSchema() { - return this.projsFromRightPartOfJoinKeysInChildSchema; + public Set getProjsFromRightPartOfJoinKeysInJoinSchema() { + assert projsJoinKeysInJoinSchema.size() == 2; + return this.projsJoinKeysInJoinSchema.get(1); } - public Set getProjsFromRightPartOfJoinKeysInJoinSchema() { - return this.projsFromRightPartOfJoinKeysInJoinSchema; + public Set getProjsJoinKeysInJoinSchema(int input) { + return this.projsJoinKeysInJoinSchema.get(input); } - private static JoinLeafPredicateInfo constructJoinLeafPredicateInfo(Join j, RexNode pe) { + private static JoinLeafPredicateInfo constructJoinLeafPredicateInfo(List inputs, + List systemFieldList, RexNode pe) { JoinLeafPredicateInfo jlpi = null; List filterNulls = new ArrayList(); - List joinKeyExprsFromLeft = new ArrayList(); - List joinKeyExprsFromRight = new ArrayList(); - 
Set projsFromLeftPartOfJoinKeysInChildSchema = new HashSet(); - Set projsFromRightPartOfJoinKeysInChildSchema = new HashSet(); - Set projsFromRightPartOfJoinKeysInJoinSchema = new HashSet(); - int rightOffSet = j.getLeft().getRowType().getFieldCount(); + List> joinKeyExprs = new ArrayList>(); + for (int i=0; i()); + } // 1. Split leaf join predicate to expressions from left, right - HiveRelOptUtil.splitJoinCondition(j.getSystemFieldList(), j.getLeft(), j.getRight(), pe, - joinKeyExprsFromLeft, joinKeyExprsFromRight, filterNulls, null); - - // 2. For left expressions, collect child projection indexes used - InputReferencedVisitor irvLeft = new InputReferencedVisitor(); - irvLeft.apply(joinKeyExprsFromLeft); - projsFromLeftPartOfJoinKeysInChildSchema.addAll(irvLeft.inputPosReferenced); - - // 3. For right expressions, collect child projection indexes used - InputReferencedVisitor irvRight = new InputReferencedVisitor(); - irvRight.apply(joinKeyExprsFromRight); - projsFromRightPartOfJoinKeysInChildSchema.addAll(irvRight.inputPosReferenced); - - // 3. Translate projection indexes from right to join schema, by adding - // offset. - for (Integer indx : projsFromRightPartOfJoinKeysInChildSchema) { - projsFromRightPartOfJoinKeysInJoinSchema.add(indx + rightOffSet); + HiveRelOptUtil.splitJoinCondition(systemFieldList, inputs, pe, + joinKeyExprs, filterNulls, null); + + // 2. Collect child projection indexes used + List> projsJoinKeysInChildSchema = + new ArrayList>(); + for (int i=0; i projsFromInputJoinKeysInChildSchema = ImmutableSet.builder(); + InputReferencedVisitor irvLeft = new InputReferencedVisitor(); + irvLeft.apply(joinKeyExprs.get(i)); + projsFromInputJoinKeysInChildSchema.addAll(irvLeft.inputPosReferenced); + projsJoinKeysInChildSchema.add(projsFromInputJoinKeysInChildSchema.build()); + } + + // 3. Translate projection indexes to join schema, by adding offset. 
+ List> projsJoinKeysInJoinSchema = + new ArrayList>(); + // The offset of the first input does not need to change. + projsJoinKeysInJoinSchema.add(projsJoinKeysInChildSchema.get(0)); + for (int i=1; i projsFromInputJoinKeysInJoinSchema = ImmutableSet.builder(); + for (Integer indx : projsJoinKeysInChildSchema.get(i)) { + projsFromInputJoinKeysInJoinSchema.add(indx + offSet); + } + projsJoinKeysInJoinSchema.add(projsFromInputJoinKeysInJoinSchema.build()); } // 4. Construct JoinLeafPredicateInfo - jlpi = new JoinLeafPredicateInfo(pe.getKind(), joinKeyExprsFromLeft, joinKeyExprsFromRight, - projsFromLeftPartOfJoinKeysInChildSchema, projsFromRightPartOfJoinKeysInChildSchema, - projsFromRightPartOfJoinKeysInJoinSchema); + jlpi = new JoinLeafPredicateInfo(pe.getKind(), joinKeyExprs, + projsJoinKeysInChildSchema, projsJoinKeysInJoinSchema); return jlpi; } diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveInsertExchange4JoinRule.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveInsertExchange4JoinRule.java index b5404a3..6cceacb 100644 --- ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveInsertExchange4JoinRule.java +++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveInsertExchange4JoinRule.java @@ -22,11 +22,13 @@ import org.apache.calcite.plan.RelOptRule; import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.hep.HepRelVertex; import org.apache.calcite.rel.RelDistribution; import org.apache.calcite.rel.RelFieldCollation; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.Exchange; import org.apache.calcite.rel.core.Join; +import org.apache.calcite.rel.rules.MultiJoin; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil; @@ -52,59 +54,77 @@ protected static transient final Log LOG = LogFactory .getLog(HiveInsertExchange4JoinRule.class); - public 
HiveInsertExchange4JoinRule() { - // match join with exactly 2 inputs - super(RelOptRule.operand(Join.class, - operand(RelNode.class, any()), - operand(RelNode.class, any()))); + /** Rule that creates Exchange operators under a MultiJoin operator. */ + public static final HiveInsertExchange4JoinRule EXCHANGE_BELOW_MULTIJOIN = + new HiveInsertExchange4JoinRule(MultiJoin.class); + + /** Rule that creates Exchange operators under a Join operator. */ + public static final HiveInsertExchange4JoinRule EXCHANGE_BELOW_JOIN = + new HiveInsertExchange4JoinRule(Join.class); + + public HiveInsertExchange4JoinRule(Class clazz) { + // match multijoin or join + super(RelOptRule.operand(clazz, any())); } @Override public void onMatch(RelOptRuleCall call) { - Join join = call.rel(0); - - if (call.rel(1) instanceof Exchange && - call.rel(2) instanceof Exchange) { + JoinPredicateInfo joinPredInfo; + if (call.rel(0) instanceof MultiJoin) { + MultiJoin multiJoin = call.rel(0); + joinPredInfo = HiveCalciteUtil.JoinPredicateInfo.constructJoinPredicateInfo(multiJoin); + } else if (call.rel(0) instanceof Join) { + Join join = call.rel(0); + joinPredInfo = HiveCalciteUtil.JoinPredicateInfo.constructJoinPredicateInfo(join); + } else { return; } - JoinPredicateInfo joinPredInfo = - HiveCalciteUtil.JoinPredicateInfo.constructJoinPredicateInfo(join); + for (RelNode child : call.rel(0).getInputs()) { + if (((HepRelVertex)child).getCurrentRel() instanceof Exchange) { + return; + } + } // get key columns from inputs. Those are the columns on which we will distribute on. // It is also the columns we will sort on. 
- List joinLeftKeyPositions = new ArrayList(); - List joinRightKeyPositions = new ArrayList(); - ImmutableList.Builder leftCollationListBuilder = - new ImmutableList.Builder(); - ImmutableList.Builder rightCollationListBuilder = - new ImmutableList.Builder(); - for (int i = 0; i < joinPredInfo.getEquiJoinPredicateElements().size(); i++) { - JoinLeafPredicateInfo joinLeafPredInfo = joinPredInfo. - getEquiJoinPredicateElements().get(i); - joinLeftKeyPositions.addAll(joinLeafPredInfo.getProjsFromLeftPartOfJoinKeysInChildSchema()); - for (int leftPos : joinLeafPredInfo.getProjsFromLeftPartOfJoinKeysInChildSchema()) { - leftCollationListBuilder.add(new RelFieldCollation(leftPos)); - } - joinRightKeyPositions.addAll(joinLeafPredInfo.getProjsFromRightPartOfJoinKeysInChildSchema()); - for (int rightPos : joinLeafPredInfo.getProjsFromRightPartOfJoinKeysInChildSchema()) { - rightCollationListBuilder.add(new RelFieldCollation(rightPos)); + List newInputs = new ArrayList(); + for (int i=0; i joinKeyPositions = new ArrayList(); + ImmutableList.Builder collationListBuilder = + new ImmutableList.Builder(); + for (int j = 0; j < joinPredInfo.getEquiJoinPredicateElements().size(); j++) { + JoinLeafPredicateInfo joinLeafPredInfo = joinPredInfo. 
+ getEquiJoinPredicateElements().get(j); + for (int pos : joinLeafPredInfo.getProjsJoinKeysInChildSchema(i)) { + if (!joinKeyPositions.contains(pos)) { + joinKeyPositions.add(pos); + collationListBuilder.add(new RelFieldCollation(pos)); + } + } } + HiveSortExchange exchange = HiveSortExchange.create(call.rel(0).getInput(i), + new HiveRelDistribution(RelDistribution.Type.HASH_DISTRIBUTED, joinKeyPositions), + new HiveRelCollation(collationListBuilder.build())); + newInputs.add(exchange); } - HiveSortExchange left = HiveSortExchange.create(join.getLeft(), - new HiveRelDistribution(RelDistribution.Type.HASH_DISTRIBUTED, joinLeftKeyPositions), - new HiveRelCollation(leftCollationListBuilder.build())); - HiveSortExchange right = HiveSortExchange.create(join.getRight(), - new HiveRelDistribution(RelDistribution.Type.HASH_DISTRIBUTED, joinRightKeyPositions), - new HiveRelCollation(rightCollationListBuilder.build())); - - Join newJoin = join.copy(join.getTraitSet(), join.getCondition(), - left, right, join.getJoinType(), join.isSemiJoinDone()); + RelNode newOp; + if (call.rel(0) instanceof MultiJoin) { + MultiJoin multiJoin = call.rel(0); + newOp = multiJoin.copy(multiJoin.getTraitSet(), newInputs); + } else if (call.rel(0) instanceof Join) { + Join join = call.rel(0); + newOp = join.copy(join.getTraitSet(), join.getCondition(), + newInputs.get(0), newInputs.get(1), join.getJoinType(), + join.isSemiJoinDone()); + } else { + return; + } - call.getPlanner().onCopy(join, newJoin); + call.getPlanner().onCopy(call.rel(0), newOp); - call.transformTo(newJoin); + call.transformTo(newOp); } } diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinToMultiJoinRule.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinToMultiJoinRule.java new file mode 100644 index 0000000..532d7d3 --- /dev/null +++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinToMultiJoinRule.java @@ -0,0 +1,333 @@ +/** + * Licensed to the Apache 
Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hive.ql.optimizer.calcite.rules; + +import java.util.List; +import java.util.Map; + +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Join; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.rules.MultiJoin; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexUtil; +import org.apache.calcite.rex.RexVisitorImpl; +import org.apache.calcite.util.ImmutableBitSet; +import org.apache.calcite.util.ImmutableIntList; +import org.apache.calcite.util.Pair; +import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil; +import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil.JoinPredicateInfo; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; + +/** + * Rule that merges a join with multijoin/join children if + * the equi compared the same set of input columns. 
+ */ +public class HiveJoinToMultiJoinRule extends RelOptRule { + + public static final HiveJoinToMultiJoinRule INSTANCE = + new HiveJoinToMultiJoinRule(Join.class); + + //~ Constructors ----------------------------------------------------------- + + /** + * Creates a JoinToMultiJoinRule. + */ + public HiveJoinToMultiJoinRule(Class clazz) { + super( + operand(clazz, + operand(RelNode.class, any()), + operand(RelNode.class, any()))); + } + + //~ Methods ---------------------------------------------------------------- + + @Override + public void onMatch(RelOptRuleCall call) { + final Join join = call.rel(0); + final RelNode left = call.rel(1); + final RelNode right = call.rel(2); + + final RexBuilder rexBuilder = join.getCluster().getRexBuilder(); + + // We do not merge outer joins currently + if (join.getJoinType() != JoinRelType.INNER) { + return; + } + + // We check whether the join can be combined with any of its children + final List newInputs = Lists.newArrayList(); + final List newJoinFilters = Lists.newArrayList(); + newJoinFilters.add(join.getCondition()); + final List> joinSpecs = Lists.newArrayList(); + final List projFields = Lists.newArrayList(); + + // Left child + if (left instanceof Join || left instanceof MultiJoin) { + final RexNode leftCondition; + if (left instanceof Join) { + leftCondition = ((Join) left).getCondition(); + } else { + leftCondition = ((MultiJoin) left).getJoinFilter(); + } + + boolean combinable = isCombinablePredicate(join, join.getCondition(), + leftCondition); + if (combinable) { + newJoinFilters.add(leftCondition); + for (RelNode input : left.getInputs()) { + projFields.add(null); + joinSpecs.add(Pair.of(JoinRelType.INNER, (RexNode) null)); + newInputs.add(input); + } + } else { + projFields.add(null); + joinSpecs.add(Pair.of(JoinRelType.INNER, (RexNode) null)); + newInputs.add(left); + } + } else { + projFields.add(null); + joinSpecs.add(Pair.of(JoinRelType.INNER, (RexNode) null)); + newInputs.add(left); + } + + // Right 
child + if (right instanceof Join || right instanceof MultiJoin) { + final RexNode rightCondition; + if (right instanceof Join) { + rightCondition = shiftRightFilter(join, left, right, + ((Join) right).getCondition()); + } else { + rightCondition = shiftRightFilter(join, left, right, + ((MultiJoin) right).getJoinFilter()); + } + + boolean combinable = isCombinablePredicate(join, join.getCondition(), + rightCondition); + if (combinable) { + newJoinFilters.add(rightCondition); + for (RelNode input : right.getInputs()) { + projFields.add(null); + joinSpecs.add(Pair.of(JoinRelType.INNER, (RexNode) null)); + newInputs.add(input); + } + } else { + projFields.add(null); + joinSpecs.add(Pair.of(JoinRelType.INNER, (RexNode) null)); + newInputs.add(right); + } + } else { + projFields.add(null); + joinSpecs.add(Pair.of(JoinRelType.INNER, (RexNode) null)); + newInputs.add(right); + } + + // If we cannot combine any of the children, we bail out + if (newJoinFilters.size() == 1) { + return; + } + + RexNode newCondition = RexUtil.flatten(rexBuilder, + RexUtil.composeConjunction(rexBuilder, newJoinFilters, false)); + final ImmutableMap newJoinFieldRefCountsMap = + addOnJoinFieldRefCounts(newInputs, + join.getRowType().getFieldCount(), + newCondition); + + List newPostJoinFilters = combinePostJoinFilters(join, left, right); + + RelNode multiJoin = + new MultiJoin( + join.getCluster(), + newInputs, + newCondition, + join.getRowType(), + false, + Pair.right(joinSpecs), + Pair.left(joinSpecs), + projFields, + newJoinFieldRefCountsMap, + RexUtil.composeConjunction(rexBuilder, newPostJoinFilters, true)); + + call.transformTo(multiJoin); + } + + private static boolean isCombinablePredicate(Join join, + RexNode condition, RexNode otherCondition) { + final JoinPredicateInfo joinPredInfo = HiveCalciteUtil.JoinPredicateInfo. + constructJoinPredicateInfo(join, condition); + final JoinPredicateInfo otherJoinPredInfo = HiveCalciteUtil.JoinPredicateInfo. 
+ constructJoinPredicateInfo(join, otherCondition); + if (joinPredInfo.getProjsFromLeftPartOfJoinKeysInJoinSchema(). + equals(otherJoinPredInfo.getProjsFromLeftPartOfJoinKeysInJoinSchema())) { + return false; + } + if (joinPredInfo.getProjsFromRightPartOfJoinKeysInJoinSchema(). + equals(otherJoinPredInfo.getProjsFromRightPartOfJoinKeysInJoinSchema())) { + return false; + } + return true; + } + + /** + * Shifts a filter originating from the right child of the LogicalJoin to the + * right, to reflect the filter now being applied on the resulting + * MultiJoin. + * + * @param joinRel the original LogicalJoin + * @param left the left child of the LogicalJoin + * @param right the right child of the LogicalJoin + * @param rightFilter the filter originating from the right child + * @return the adjusted right filter + */ + private RexNode shiftRightFilter( + Join joinRel, + RelNode left, + RelNode right, + RexNode rightFilter) { + if (rightFilter == null) { + return null; + } + + int nFieldsOnLeft = left.getRowType().getFieldList().size(); + int nFieldsOnRight = right.getRowType().getFieldList().size(); + int[] adjustments = new int[nFieldsOnRight]; + for (int i = 0; i < nFieldsOnRight; i++) { + adjustments[i] = nFieldsOnLeft; + } + rightFilter = + rightFilter.accept( + new RelOptUtil.RexInputConverter( + joinRel.getCluster().getRexBuilder(), + right.getRowType().getFieldList(), + joinRel.getRowType().getFieldList(), + adjustments)); + return rightFilter; + } + + /** + * Adds on to the existing join condition reference counts the references + * from the new join condition. 
+ * + * @param multiJoinInputs inputs into the new MultiJoin + * @param nTotalFields total number of fields in the MultiJoin + * @param joinCondition the new join condition + * @param origJoinFieldRefCounts existing join condition reference counts + * + * @return Map containing the new join condition + */ + private ImmutableMap addOnJoinFieldRefCounts( + List multiJoinInputs, + int nTotalFields, + RexNode joinCondition) { + // count the input references in the join condition + int[] joinCondRefCounts = new int[nTotalFields]; + joinCondition.accept(new InputReferenceCounter(joinCondRefCounts)); + + // add on to the counts for each input into the MultiJoin the + // reference counts computed for the current join condition + final Map refCountsMap = Maps.newHashMap(); + int nInputs = multiJoinInputs.size(); + int currInput = -1; + int startField = 0; + int nFields = 0; + for (int i = 0; i < nTotalFields; i++) { + if (joinCondRefCounts[i] == 0) { + continue; + } + while (i >= (startField + nFields)) { + startField += nFields; + currInput++; + assert currInput < nInputs; + nFields = + multiJoinInputs.get(currInput).getRowType().getFieldCount(); + } + int[] refCounts = refCountsMap.get(currInput); + if (refCounts == null) { + refCounts = new int[nFields]; + refCountsMap.put(currInput, refCounts); + } + refCounts[i - startField] += joinCondRefCounts[i]; + } + + final ImmutableMap.Builder builder = + ImmutableMap.builder(); + for (Map.Entry entry : refCountsMap.entrySet()) { + builder.put(entry.getKey(), ImmutableIntList.of(entry.getValue())); + } + return builder.build(); + } + + /** + * Combines the post-join filters from the left and right inputs (if they + * are MultiJoinRels) into a single AND'd filter. 
+ * + * @param joinRel the original LogicalJoin + * @param left left child of the LogicalJoin + * @param right right child of the LogicalJoin + * @return combined post-join filters AND'd together + */ + private List combinePostJoinFilters( + Join joinRel, + RelNode left, + RelNode right) { + final List filters = Lists.newArrayList(); + if (right instanceof MultiJoin) { + final MultiJoin multiRight = (MultiJoin) right; + filters.add( + shiftRightFilter(joinRel, left, multiRight, + multiRight.getPostJoinFilter())); + } + + if (left instanceof MultiJoin) { + filters.add(((MultiJoin) left).getPostJoinFilter()); + } + + return filters; + } + + //~ Inner Classes ---------------------------------------------------------- + + /** + * Visitor that keeps a reference count of the inputs used by an expression. + */ + private class InputReferenceCounter extends RexVisitorImpl { + private final int[] refCounts; + + public InputReferenceCounter(int[] refCounts) { + super(true); + this.refCounts = refCounts; + } + + public Void visitInputRef(RexInputRef inputRef) { + refCounts[inputRef.getIndex()]++; + return null; + } + } +} + +// End JoinToMultiJoinRule.java + diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveOpConverter.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveOpConverter.java index 85d1663..3a3bc6e 100644 --- ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveOpConverter.java +++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveOpConverter.java @@ -34,6 +34,7 @@ import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.SemiJoin; import org.apache.calcite.rel.core.SortExchange; +import org.apache.calcite.rel.rules.MultiJoin; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; @@ -151,6 +152,8 @@ OpAttr dispatch(RelNode rn) throws SemanticException { return visit((HiveTableScan) rn); } 
else if (rn instanceof HiveProject) { return visit((HiveProject) rn); + } else if (rn instanceof MultiJoin) { + return visit((MultiJoin) rn); } else if (rn instanceof HiveJoin) { return visit((HiveJoin) rn); } else if (rn instanceof SemiJoin) { @@ -296,7 +299,15 @@ OpAttr visit(HiveProject projectRel) throws SemanticException { return new OpAttr(inputOpAf.tabAlias, colInfoVColPair.getValue(), selOp); } + OpAttr visit(MultiJoin joinRel) throws SemanticException { + return translateJoin(joinRel); + } + OpAttr visit(HiveJoin joinRel) throws SemanticException { + return translateJoin(joinRel); + } + + private OpAttr translateJoin(RelNode joinRel) throws SemanticException { // 1. Convert inputs OpAttr[] inputs = new OpAttr[joinRel.getInputs().size()]; List<Operator<?>> children = new ArrayList<Operator<?>>(joinRel.getInputs().size()); @@ -311,7 +322,12 @@ OpAttr visit(HiveJoin joinRel) throws SemanticException { } // 2. Convert join condition - JoinPredicateInfo joinPredInfo = JoinPredicateInfo.constructJoinPredicateInfo(joinRel); + JoinPredicateInfo joinPredInfo; + if (joinRel instanceof HiveJoin) { + joinPredInfo = JoinPredicateInfo.constructJoinPredicateInfo((HiveJoin)joinRel); + } else { + joinPredInfo = JoinPredicateInfo.constructJoinPredicateInfo((MultiJoin)joinRel); + } // 3. Extract join keys from condition ExprNodeDesc[][] joinKeys = extractJoinKeys(joinPredInfo, joinRel.getInputs(), inputs); @@ -330,7 +346,8 @@ OpAttr visit(HiveJoin joinRel) throws SemanticException { // 6. 
Virtual columns Set newVcolsInCalcite = new HashSet(); newVcolsInCalcite.addAll(inputs[0].vcolsInCalcite); - if (extractJoinType(joinRel) != JoinType.LEFTSEMI) { + if (joinRel instanceof MultiJoin || + extractJoinType((HiveJoin)joinRel) != JoinType.LEFTSEMI) { int shift = inputs[0].inputs.get(0).getSchema().getSignature().size(); for (int i = 1; i < inputs.length; i++) { newVcolsInCalcite.addAll(HiveCalciteUtil.shiftVColsSet(inputs[i].vcolsInCalcite, shift)); @@ -752,18 +769,24 @@ private static ReduceSinkOperator genReduceSink(Operator input, ExprNodeDesc[ return rsOp; } - private static JoinOperator genJoin(HiveJoin hiveJoin, JoinPredicateInfo joinPredInfo, + private static JoinOperator genJoin(RelNode join, JoinPredicateInfo joinPredInfo, List<Operator<?>> children, ExprNodeDesc[][] joinKeys) throws SemanticException { // Extract join type - JoinType joinType = extractJoinType(hiveJoin); + JoinType joinType; + if (join instanceof MultiJoin) { + joinType = JoinType.INNER; + } else { + joinType = extractJoinType((HiveJoin)join); + } - // NOTE: Currently binary joins only - JoinCondDesc[] joinCondns = new JoinCondDesc[1]; - joinCondns[0] = new JoinCondDesc(new JoinCond(0, 1, joinType)); + JoinCondDesc[] joinCondns = new JoinCondDesc[children.size()-1]; + for (int i=1; i<children.size(); i++) { + joinCondns[i-1] = new JoinCondDesc(new JoinCond(0, i, joinType)); + } ArrayList<ColumnInfo> outputColumns = new ArrayList<ColumnInfo>(); - ArrayList outputColumnNames = new ArrayList(hiveJoin.getRowType() + ArrayList outputColumnNames = new ArrayList(join.getRowType() .getFieldNames()); Operator[] childOps = new Operator[children.size()]; diff --git ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java index 5855695..49ad6ad 100644 --- ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java +++ ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java @@ -142,6 +142,7 @@ import org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveFilterSetOpTransposeRule; import 
org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveInsertExchange4JoinRule; import org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveJoinAddNotNullRule; +import org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveJoinToMultiJoinRule; import org.apache.hadoop.hive.ql.optimizer.calcite.rules.HivePartitionPruneRule; import org.apache.hadoop.hive.ql.optimizer.calcite.translator.ASTConverter; import org.apache.hadoop.hive.ql.optimizer.calcite.translator.HiveOpConverter; @@ -854,10 +855,18 @@ public RelNode apply(RelOptCluster cluster, RelOptSchema relOptSchema, SchemaPlu if (HiveConf.getBoolVar(conf, ConfVars.HIVE_CBO_RETPATH_HIVEOP)) { // run rules to aid in translation from Optiq tree -> Hive tree - hepPgm = new HepProgramBuilder().addMatchOrder(HepMatchOrder.BOTTOM_UP) - .addRuleInstance(new HiveInsertExchange4JoinRule()).build(); - hepPlanner = new HepPlanner(hepPgm); + hepPgmBldr = new HepProgramBuilder().addMatchOrder(HepMatchOrder.BOTTOM_UP); + hepPgmBldr.addRuleInstance(HiveJoinToMultiJoinRule.INSTANCE); + hepPlanner = new HepPlanner(hepPgmBldr.build()); + hepPlanner.registerMetadataProviders(list); + cluster.setMetadataProvider(new CachingRelMetadataProvider(chainedProvider, hepPlanner)); + hepPlanner.setRoot(calciteOptimizedPlan); + calciteOptimizedPlan = hepPlanner.findBestExp(); + hepPgmBldr = new HepProgramBuilder().addMatchOrder(HepMatchOrder.BOTTOM_UP); + hepPgmBldr.addRuleInstance(HiveInsertExchange4JoinRule.EXCHANGE_BELOW_JOIN); + hepPgmBldr.addRuleInstance(HiveInsertExchange4JoinRule.EXCHANGE_BELOW_MULTIJOIN); + hepPlanner = new HepPlanner(hepPgmBldr.build()); hepPlanner.registerMetadataProviders(list); cluster.setMetadataProvider(new CachingRelMetadataProvider(chainedProvider, hepPlanner)); hepPlanner.setRoot(calciteOptimizedPlan);