diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveCalciteUtil.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveCalciteUtil.java index f4e7c45..b691ee1 100644 --- ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveCalciteUtil.java +++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveCalciteUtil.java @@ -31,7 +31,6 @@ import org.apache.calcite.rel.core.Join; import org.apache.calcite.rel.core.RelFactories.ProjectFactory; import org.apache.calcite.rel.core.Sort; -import org.apache.calcite.rel.rules.MultiJoin; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.rex.RexBuilder; @@ -55,6 +54,7 @@ import org.apache.calcite.util.Util; import org.apache.hadoop.hive.metastore.api.FieldSchema; import org.apache.hadoop.hive.ql.metadata.VirtualColumn; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveMultiJoin; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject; import org.apache.hadoop.hive.ql.optimizer.calcite.translator.ExprNodeConverter; import org.apache.hadoop.hive.ql.parse.ASTNode; @@ -337,15 +337,15 @@ public static JoinPredicateInfo constructJoinPredicateInfo(Join j) { return constructJoinPredicateInfo(j, j.getCondition()); } - public static JoinPredicateInfo constructJoinPredicateInfo(MultiJoin mj) { - return constructJoinPredicateInfo(mj, mj.getJoinFilter()); + public static JoinPredicateInfo constructJoinPredicateInfo(HiveMultiJoin mj) { + return constructJoinPredicateInfo(mj, mj.getCondition()); } public static JoinPredicateInfo constructJoinPredicateInfo(Join j, RexNode predicate) { return constructJoinPredicateInfo(j.getInputs(), j.getSystemFieldList(), predicate); } - public static JoinPredicateInfo constructJoinPredicateInfo(MultiJoin mj, RexNode predicate) { + public static JoinPredicateInfo constructJoinPredicateInfo(HiveMultiJoin mj, RexNode predicate) { final List systemFieldList = 
ImmutableList.of(); return constructJoinPredicateInfo(mj.getInputs(), systemFieldList, predicate); } diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/reloperators/HiveMultiJoin.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/reloperators/HiveMultiJoin.java new file mode 100644 index 0000000..6265f2c --- /dev/null +++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/reloperators/HiveMultiJoin.java @@ -0,0 +1,198 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.hadoop.hive.ql.optimizer.calcite.reloperators; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.calcite.linq4j.Ord; +import org.apache.calcite.plan.Convention; +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.AbstractRelNode; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.RelWriter; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexShuttle; +import org.apache.calcite.util.Pair; +import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil; +import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil.JoinPredicateInfo; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; + +/** + * A HiveMultiJoin merges a succession of binary joins into one n-ary join operator. + */ +public final class HiveMultiJoin extends AbstractRelNode { + + private final List inputs; + private final RexNode condition; + private final RelDataType rowType; + private final ImmutableList> joinInputs; + private final ImmutableList joinTypes; + + private final boolean outerJoin; + private final JoinPredicateInfo joinPredInfo; + + + /** + * Constructs a HiveMultiJoin.
+ * + * @param cluster cluster that join belongs to + * @param inputs inputs into this multi-join + * @param joinFilter join filter applicable to this join node + * @param rowType row type of the join result of this node + * @param joinInputs pairs (left, right) of input ordinals that are joined with each other + * @param joinTypes the join type of the corresponding entry in + * joinInputs (same size and order); any entry + * other than INNER marks this multi-join as + * containing an outer join, see + * {@link #isOuterJoin()} + */ + public HiveMultiJoin( + RelOptCluster cluster, + List inputs, + RexNode joinFilter, + RelDataType rowType, + List> joinInputs, + List joinTypes) { + super(cluster, cluster.traitSetOf(Convention.NONE)); + this.inputs = Lists.newArrayList(inputs); + this.condition = joinFilter; + this.rowType = rowType; + + assert joinInputs.size() == joinTypes.size(); + this.joinInputs = ImmutableList.copyOf(joinInputs); + this.joinTypes = ImmutableList.copyOf(joinTypes); + this.outerJoin = containsOuter(); + + this.joinPredInfo = HiveCalciteUtil.JoinPredicateInfo.constructJoinPredicateInfo(this); + } + + + @Override + public void replaceInput(int ordinalInParent, RelNode p) { + inputs.set(ordinalInParent, p); + } + + @Override + public RelNode copy(RelTraitSet traitSet, List inputs) { + assert traitSet.containsIfApplicable(Convention.NONE); + return new HiveMultiJoin( + getCluster(), + inputs, + condition, + rowType, + joinInputs, + joinTypes); + } + + public RelWriter explainTerms(RelWriter pw) { + List joinsString = new ArrayList(); + for (int i = 0; i < joinInputs.size(); i++) { + final StringBuilder sb = new StringBuilder(); + sb.append(joinInputs.get(i).left).append(" - ").append(joinInputs.get(i).right) + .append(" : ").append(joinTypes.get(i).name()); + joinsString.add(sb.toString()); + } + + super.explainTerms(pw); + for (Ord ord : Ord.zip(inputs)) { + pw.input("input#" + ord.i, ord.e); + } + return pw.item("condition", condition) + .item("joinsDescription", joinsString); + } + + public RelDataType
deriveRowType() { + return rowType; + } + + public List getInputs() { + return inputs; + } + + @Override public List getChildExps() { + return ImmutableList.of(condition); + } + + public RelNode accept(RexShuttle shuttle) { + RexNode joinFilter = shuttle.apply(this.condition); + + if (joinFilter == this.condition) { + return this; + } + + return new HiveMultiJoin( + getCluster(), + inputs, + joinFilter, + rowType, + joinInputs, + joinTypes); + } + + /** + * @return the join condition (the joinFilter given at construction) of this HiveMultiJoin + */ + public RexNode getCondition() { + return condition; + } + + /** + * @return true if any join type of this HiveMultiJoin is not INNER, i.e. it contains a (partial) outer join. + */ + public boolean isOuterJoin() { + return outerJoin; + } + + /** + * @return join relationships between inputs as (left, right) pairs, aligned by index with {@link #getJoinTypes()} + */ + public List> getJoinInputs() { + return joinInputs; + } + + /** + * @return the join type of each entry returned by {@link #getJoinInputs()}, in the same order + */ + public List getJoinTypes() { + return joinTypes; + } + + /** + * @return the join predicate information computed from the condition when this node was constructed + */ + public JoinPredicateInfo getJoinPredicateInfo() { + return joinPredInfo; + } + + private boolean containsOuter() { + for (JoinRelType joinType : joinTypes) { + if (joinType != JoinRelType.INNER) { + return true; + } + } + return false; + } +} + +// End HiveMultiJoin.java diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveInsertExchange4JoinRule.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveInsertExchange4JoinRule.java index 30db8fd..107e1bd 100644 --- ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveInsertExchange4JoinRule.java +++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveInsertExchange4JoinRule.java @@ -28,7 +28,6 @@ import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.Exchange; import org.apache.calcite.rel.core.Join; -import org.apache.calcite.rel.rules.MultiJoin; import org.apache.calcite.rex.RexNode; import org.apache.commons.logging.Log; import
org.apache.commons.logging.LogFactory; @@ -37,6 +36,7 @@ import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil.JoinPredicateInfo; import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelCollation; import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelDistribution; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveMultiJoin; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveSortExchange; import com.google.common.collect.ImmutableList; @@ -57,7 +57,7 @@ /** Rule that creates Exchange operators under a MultiJoin operator. */ public static final HiveInsertExchange4JoinRule EXCHANGE_BELOW_MULTIJOIN = - new HiveInsertExchange4JoinRule(MultiJoin.class); + new HiveInsertExchange4JoinRule(HiveMultiJoin.class); /** Rule that creates Exchange operators under a Join operator. */ public static final HiveInsertExchange4JoinRule EXCHANGE_BELOW_JOIN = @@ -71,8 +71,8 @@ public HiveInsertExchange4JoinRule(Class clazz) { @Override public void onMatch(RelOptRuleCall call) { JoinPredicateInfo joinPredInfo; - if (call.rel(0) instanceof MultiJoin) { - MultiJoin multiJoin = call.rel(0); + if (call.rel(0) instanceof HiveMultiJoin) { + HiveMultiJoin multiJoin = call.rel(0); joinPredInfo = HiveCalciteUtil.JoinPredicateInfo.constructJoinPredicateInfo(multiJoin); } else if (call.rel(0) instanceof Join) { Join join = call.rel(0); @@ -114,8 +114,8 @@ public void onMatch(RelOptRuleCall call) { } RelNode newOp; - if (call.rel(0) instanceof MultiJoin) { - MultiJoin multiJoin = call.rel(0); + if (call.rel(0) instanceof HiveMultiJoin) { + HiveMultiJoin multiJoin = call.rel(0); newOp = multiJoin.copy(multiJoin.getTraitSet(), newInputs); } else if (call.rel(0) instanceof Join) { Join join = call.rel(0); diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinToMultiJoinRule.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinToMultiJoinRule.java index 532d7d3..a24a0d1 100644 --- 
ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinToMultiJoinRule.java +++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinToMultiJoinRule.java @@ -17,8 +17,8 @@ */ package org.apache.hadoop.hive.ql.optimizer.calcite.rules; +import java.util.ArrayList; import java.util.List; -import java.util.Map; import org.apache.calcite.plan.RelOptRule; import org.apache.calcite.plan.RelOptRuleCall; @@ -26,21 +26,19 @@ import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.Join; import org.apache.calcite.rel.core.JoinRelType; -import org.apache.calcite.rel.rules.MultiJoin; +import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.rex.RexBuilder; -import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexUtil; -import org.apache.calcite.rex.RexVisitorImpl; import org.apache.calcite.util.ImmutableBitSet; -import org.apache.calcite.util.ImmutableIntList; import org.apache.calcite.util.Pair; import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil; import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil.JoinPredicateInfo; +import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelOptUtil; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveMultiJoin; -import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; -import com.google.common.collect.Maps; /** * Rule that merges a join with multijoin/join children if @@ -73,104 +71,103 @@ public void onMatch(RelOptRuleCall call) { final RexBuilder rexBuilder = join.getCluster().getRexBuilder(); - // We do not merge outer joins currently - if (join.getJoinType() != JoinRelType.INNER) { - return; - } - // We check whether the join can be combined with any of its children final List newInputs = Lists.newArrayList(); final List newJoinFilters = Lists.newArrayList(); 
newJoinFilters.add(join.getCondition()); - final List> joinSpecs = Lists.newArrayList(); - final List projFields = Lists.newArrayList(); + final List, JoinRelType>> joinSpecs = Lists.newArrayList(); // Left child - if (left instanceof Join || left instanceof MultiJoin) { + if (left instanceof Join || left instanceof HiveMultiJoin) { final RexNode leftCondition; + final List> leftJoinInputs; + final List leftJoinTypes; if (left instanceof Join) { - leftCondition = ((Join) left).getCondition(); + Join hj = (Join) left; + leftCondition = hj.getCondition(); + leftJoinInputs = ImmutableList.of(Pair.of(0, 1)); + leftJoinTypes = ImmutableList.of(hj.getJoinType()); } else { - leftCondition = ((MultiJoin) left).getJoinFilter(); + HiveMultiJoin hmj = (HiveMultiJoin) left; + leftCondition = hmj.getCondition(); + leftJoinInputs = hmj.getJoinInputs(); + leftJoinTypes = hmj.getJoinTypes(); } boolean combinable = isCombinablePredicate(join, join.getCondition(), leftCondition); if (combinable) { newJoinFilters.add(leftCondition); - for (RelNode input : left.getInputs()) { - projFields.add(null); - joinSpecs.add(Pair.of(JoinRelType.INNER, (RexNode) null)); - newInputs.add(input); + for (int i = 0; i < leftJoinInputs.size(); i++) { + joinSpecs.add(Pair.of(leftJoinInputs.get(i), leftJoinTypes.get(i))); } + newInputs.addAll(left.getInputs()); } else { - projFields.add(null); - joinSpecs.add(Pair.of(JoinRelType.INNER, (RexNode) null)); newInputs.add(left); } } else { - projFields.add(null); - joinSpecs.add(Pair.of(JoinRelType.INNER, (RexNode) null)); newInputs.add(left); } + final int numberLeftInputs = newInputs.size(); // Right child - if (right instanceof Join || right instanceof MultiJoin) { - final RexNode rightCondition; - if (right instanceof Join) { - rightCondition = shiftRightFilter(join, left, right, - ((Join) right).getCondition()); - } else { - rightCondition = shiftRightFilter(join, left, right, - ((MultiJoin) right).getJoinFilter()); - } - - boolean combinable = 
isCombinablePredicate(join, join.getCondition(), - rightCondition); - if (combinable) { - newJoinFilters.add(rightCondition); - for (RelNode input : right.getInputs()) { - projFields.add(null); - joinSpecs.add(Pair.of(JoinRelType.INNER, (RexNode) null)); - newInputs.add(input); - } - } else { - projFields.add(null); - joinSpecs.add(Pair.of(JoinRelType.INNER, (RexNode) null)); - newInputs.add(right); - } - } else { - projFields.add(null); - joinSpecs.add(Pair.of(JoinRelType.INNER, (RexNode) null)); - newInputs.add(right); - } + newInputs.add(right); // If we cannot combine any of the children, we bail out if (newJoinFilters.size() == 1) { return; } + final List systemFieldList = ImmutableList.of(); + List> joinKeyExprs = new ArrayList>(); + List filterNulls = new ArrayList(); + for (int i=0; i()); + } + HiveRelOptUtil.splitJoinCondition(systemFieldList, newInputs, join.getCondition(), + joinKeyExprs, filterNulls, null); + ImmutableBitSet.Builder keysInInputsBuilder = ImmutableBitSet.builder(); + for (int i=0; i partialCondition = joinKeyExprs.get(i); + if (!partialCondition.isEmpty()) { + keysInInputsBuilder.set(i); + } + } + // If we cannot merge, we bail out + ImmutableBitSet keysInInputs = keysInInputsBuilder.build(); + ImmutableBitSet leftReferencedInputs = + keysInInputs.intersect(ImmutableBitSet.range(numberLeftInputs)); + ImmutableBitSet rightReferencedInputs = + keysInInputs.intersect(ImmutableBitSet.range(numberLeftInputs, newInputs.size())); + if (join.getJoinType() != JoinRelType.INNER && + (leftReferencedInputs.cardinality() > 1 || rightReferencedInputs.cardinality() > 1)) { + return; + } + // Otherwise, we add to the join specs + if (join.getJoinType() != JoinRelType.INNER) { + int leftInput = keysInInputs.nextSetBit(0); + int rightInput = keysInInputs.nextSetBit(numberLeftInputs); + joinSpecs.add(Pair.of(Pair.of(leftInput, rightInput), join.getJoinType())); + } else { + for (int i : leftReferencedInputs) { + for (int j : rightReferencedInputs) { + 
joinSpecs.add(Pair.of(Pair.of(i, j), join.getJoinType())); + } + } + } + + // We can now create a multijoin operator RexNode newCondition = RexUtil.flatten(rexBuilder, RexUtil.composeConjunction(rexBuilder, newJoinFilters, false)); - final ImmutableMap newJoinFieldRefCountsMap = - addOnJoinFieldRefCounts(newInputs, - join.getRowType().getFieldCount(), - newCondition); - - List newPostJoinFilters = combinePostJoinFilters(join, left, right); RelNode multiJoin = - new MultiJoin( + new HiveMultiJoin( join.getCluster(), newInputs, newCondition, join.getRowType(), - false, - Pair.right(joinSpecs), Pair.left(joinSpecs), - projFields, - newJoinFieldRefCountsMap, - RexUtil.composeConjunction(rexBuilder, newPostJoinFilters, true)); + Pair.right(joinSpecs)); call.transformTo(multiJoin); } @@ -228,106 +225,4 @@ private RexNode shiftRightFilter( return rightFilter; } - /** - * Adds on to the existing join condition reference counts the references - * from the new join condition. - * - * @param multiJoinInputs inputs into the new MultiJoin - * @param nTotalFields total number of fields in the MultiJoin - * @param joinCondition the new join condition - * @param origJoinFieldRefCounts existing join condition reference counts - * - * @return Map containing the new join condition - */ - private ImmutableMap addOnJoinFieldRefCounts( - List multiJoinInputs, - int nTotalFields, - RexNode joinCondition) { - // count the input references in the join condition - int[] joinCondRefCounts = new int[nTotalFields]; - joinCondition.accept(new InputReferenceCounter(joinCondRefCounts)); - - // add on to the counts for each input into the MultiJoin the - // reference counts computed for the current join condition - final Map refCountsMap = Maps.newHashMap(); - int nInputs = multiJoinInputs.size(); - int currInput = -1; - int startField = 0; - int nFields = 0; - for (int i = 0; i < nTotalFields; i++) { - if (joinCondRefCounts[i] == 0) { - continue; - } - while (i >= (startField + nFields)) { - 
startField += nFields; - currInput++; - assert currInput < nInputs; - nFields = - multiJoinInputs.get(currInput).getRowType().getFieldCount(); - } - int[] refCounts = refCountsMap.get(currInput); - if (refCounts == null) { - refCounts = new int[nFields]; - refCountsMap.put(currInput, refCounts); - } - refCounts[i - startField] += joinCondRefCounts[i]; - } - - final ImmutableMap.Builder builder = - ImmutableMap.builder(); - for (Map.Entry entry : refCountsMap.entrySet()) { - builder.put(entry.getKey(), ImmutableIntList.of(entry.getValue())); - } - return builder.build(); - } - - /** - * Combines the post-join filters from the left and right inputs (if they - * are MultiJoinRels) into a single AND'd filter. - * - * @param joinRel the original LogicalJoin - * @param left left child of the LogicalJoin - * @param right right child of the LogicalJoin - * @return combined post-join filters AND'd together - */ - private List combinePostJoinFilters( - Join joinRel, - RelNode left, - RelNode right) { - final List filters = Lists.newArrayList(); - if (right instanceof MultiJoin) { - final MultiJoin multiRight = (MultiJoin) right; - filters.add( - shiftRightFilter(joinRel, left, multiRight, - multiRight.getPostJoinFilter())); - } - - if (left instanceof MultiJoin) { - filters.add(((MultiJoin) left).getPostJoinFilter()); - } - - return filters; - } - - //~ Inner Classes ---------------------------------------------------------- - - /** - * Visitor that keeps a reference count of the inputs used by an expression. 
- */ - private class InputReferenceCounter extends RexVisitorImpl { - private final int[] refCounts; - - public InputReferenceCounter(int[] refCounts) { - super(true); - this.refCounts = refCounts; - } - - public Void visitInputRef(RexInputRef inputRef) { - refCounts[inputRef.getIndex()]++; - return null; - } - } } - -// End JoinToMultiJoinRule.java - diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveOpConverter.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveOpConverter.java index efc2542..0b1a37a 100644 --- ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveOpConverter.java +++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveOpConverter.java @@ -32,8 +32,8 @@ import org.apache.calcite.rel.RelDistribution.Type; import org.apache.calcite.rel.RelFieldCollation; import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; import org.apache.calcite.rel.core.SemiJoin; -import org.apache.calcite.rel.rules.MultiJoin; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; @@ -62,6 +62,7 @@ import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveFilter; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveJoin; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveMultiJoin; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveSort; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveSortExchange; @@ -152,8 +153,8 @@ OpAttr dispatch(RelNode rn) throws SemanticException { return visit((HiveTableScan) rn); } else if (rn instanceof HiveProject) { return visit((HiveProject) rn); - } else if (rn instanceof MultiJoin) { - return visit((MultiJoin) rn); + } 
else if (rn instanceof HiveMultiJoin) { + return visit((HiveMultiJoin) rn); } else if (rn instanceof HiveJoin) { return visit((HiveJoin) rn); } else if (rn instanceof SemiJoin) { @@ -299,7 +300,7 @@ OpAttr visit(HiveProject projectRel) throws SemanticException { return new OpAttr(inputOpAf.tabAlias, colInfoVColPair.getValue(), selOp); } - OpAttr visit(MultiJoin joinRel) throws SemanticException { + OpAttr visit(HiveMultiJoin joinRel) throws SemanticException { return translateJoin(joinRel); } @@ -326,7 +327,7 @@ private OpAttr translateJoin(RelNode joinRel) throws SemanticException { if (joinRel instanceof HiveJoin) { joinPredInfo = JoinPredicateInfo.constructJoinPredicateInfo((HiveJoin)joinRel); } else { - joinPredInfo = JoinPredicateInfo.constructJoinPredicateInfo((MultiJoin)joinRel); + joinPredInfo = JoinPredicateInfo.constructJoinPredicateInfo((HiveMultiJoin)joinRel); } // 3. Extract join key expressions from HiveSortExchange @@ -349,7 +350,7 @@ private OpAttr translateJoin(RelNode joinRel) throws SemanticException { // 6. 
Virtual columns Set newVcolsInCalcite = new HashSet(); newVcolsInCalcite.addAll(inputs[0].vcolsInCalcite); - if (joinRel instanceof MultiJoin || + if (joinRel instanceof HiveMultiJoin || extractJoinType((HiveJoin)joinRel) != JoinType.LEFTSEMI) { int shift = inputs[0].inputs.get(0).getSchema().getSignature().size(); for (int i = 1; i < inputs.length; i++) { @@ -775,16 +776,27 @@ private static JoinOperator genJoin(RelNode join, JoinPredicateInfo joinPredInfo List> children, ExprNodeDesc[][] joinKeys) throws SemanticException { // Extract join type - JoinType joinType; - if (join instanceof MultiJoin) { - joinType = JoinType.INNER; - } else { - joinType = extractJoinType((HiveJoin)join); - } - JoinCondDesc[] joinCondns = new JoinCondDesc[children.size()-1]; - for (int i=1; i outputColumns = new ArrayList(); @@ -811,7 +823,7 @@ private static JoinOperator genJoin(RelNode join, JoinPredicateInfo joinPredInfo Byte tag = (byte) rsDesc.getTag(); // Semijoin - if (joinType == JoinType.LEFTSEMI && pos != 0) { + if (semiJoin && pos != 0) { exprMap.put(tag, new ArrayList()); childOps[pos] = inputRS; continue; @@ -839,8 +851,6 @@ private static JoinOperator genJoin(RelNode join, JoinPredicateInfo joinPredInfo childOps[pos] = inputRS; } - boolean noOuterJoin = joinType != JoinType.FULLOUTER && joinType != JoinType.LEFTOUTER - && joinType != JoinType.RIGHTOUTER; JoinDesc desc = new JoinDesc(exprMap, outputColumnNames, noOuterJoin, joinCondns, joinKeys); desc.setReversedExprs(reversedExprs); @@ -886,6 +896,25 @@ private static JoinType extractJoinType(HiveJoin join) { return resultJoinType; } + private static JoinType transformJoinType(JoinRelType type) { + JoinType resultJoinType; + switch (type) { + case FULL: + resultJoinType = JoinType.FULLOUTER; + break; + case LEFT: + resultJoinType = JoinType.LEFTOUTER; + break; + case RIGHT: + resultJoinType = JoinType.RIGHTOUTER; + break; + default: + resultJoinType = JoinType.INNER; + break; + } + return resultJoinType; + } + private 
static Map buildBacktrackFromReduceSinkForJoin(int initialPos, List outputColumnNames, List keyColNames, List valueColNames, int[] index, Operator inputOp) {