diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveRexUtil.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveRexUtil.java index d466378..c04530c 100644 --- ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveRexUtil.java +++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveRexUtil.java @@ -256,8 +256,10 @@ public static RexNode simplifyAnd(RexBuilder rexBuilder, RexCall e, public static RexNode simplifyAnd2(RexBuilder rexBuilder, List terms, List notTerms) { - if (terms.contains(rexBuilder.makeLiteral(false))) { - return rexBuilder.makeLiteral(false); + for (RexNode term : terms) { + if (term.isAlwaysFalse()) { + return rexBuilder.makeLiteral(false); + } } if (terms.isEmpty() && notTerms.isEmpty()) { return rexBuilder.makeLiteral(true); @@ -292,8 +294,10 @@ public static RexNode simplifyAnd2(RexBuilder rexBuilder, * UNKNOWN it will be interpreted as FALSE. */ public static RexNode simplifyAnd2ForUnknownAsFalse(RexBuilder rexBuilder, List terms, List notTerms) { - if (terms.contains(rexBuilder.makeLiteral(false))) { - return rexBuilder.makeLiteral(false); + for (RexNode term : terms) { + if (term.isAlwaysFalse()) { + return rexBuilder.makeLiteral(false); + } } if (terms.isEmpty() && notTerms.isEmpty()) { return rexBuilder.makeLiteral(true); diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/HiveRelMdPredicates.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/HiveRelMdPredicates.java index 9cec6ca..7946e69 100644 --- ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/HiveRelMdPredicates.java +++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/HiveRelMdPredicates.java @@ -18,16 +18,26 @@ package org.apache.hadoop.hive.ql.optimizer.calcite.stats; import java.util.ArrayList; +import java.util.BitSet; +import java.util.HashSet; +import java.util.Iterator; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import 
java.util.Map.Entry; +import java.util.Set; +import java.util.SortedMap; +import org.apache.calcite.linq4j.Linq4j; import org.apache.calcite.linq4j.Ord; +import org.apache.calcite.linq4j.function.Predicate1; import org.apache.calcite.plan.RelOptPredicateList; import org.apache.calcite.plan.RelOptUtil; import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Join; +import org.apache.calcite.rel.core.JoinRelType; import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.core.SemiJoin; import org.apache.calcite.rel.core.Union; import org.apache.calcite.rel.metadata.ReflectiveRelMetadataProvider; import org.apache.calcite.rel.metadata.RelMdPredicates; @@ -40,7 +50,10 @@ import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexPermuteInputsShuttle; import org.apache.calcite.rex.RexUtil; +import org.apache.calcite.rex.RexVisitorImpl; +import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.util.BitSets; import org.apache.calcite.util.BuiltInMethod; import org.apache.calcite.util.ImmutableBitSet; import org.apache.calcite.util.mapping.Mapping; @@ -48,8 +61,13 @@ import org.apache.calcite.util.mapping.Mappings; import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil; +import com.google.common.base.Function; import com.google.common.collect.HashMultimap; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.common.collect.Iterators; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; //TODO: Move this to calcite @@ -131,6 +149,25 @@ public RelOptPredicateList getPredicates(Project project, RelMetadataQuery mq) { return RelOptPredicateList.of(projectPullUpPredicates); } + /** Infers predicates for a {@link org.apache.calcite.rel.core.Join}. 
*/ + @Override + public RelOptPredicateList getPredicates(Join join, RelMetadataQuery mq) { + RexBuilder rB = join.getCluster().getRexBuilder(); + RelNode left = join.getInput(0); + RelNode right = join.getInput(1); + + final RelOptPredicateList leftInfo = mq.getPulledUpPredicates(left); + final RelOptPredicateList rightInfo = mq.getPulledUpPredicates(right); /* predicates the metadata query reports for each join input */ + + JoinConditionBasedPredicateInference jI = + new JoinConditionBasedPredicateInference(join, + RexUtil.composeConjunction(rB, leftInfo.pulledUpPredicates, false), + RexUtil.composeConjunction(rB, rightInfo.pulledUpPredicates, + false)); /* each input's predicate list is AND-ed into a single RexNode before being handed to the inference helper */ + + return jI.inferPredicates(false); /* false = includeEqualityInference: predicates recorded as field = field equalities of the join condition are skipped as inference sources */ + } + /** * Infers predicates for a Union. */ @@ -176,4 +213,433 @@ public RelOptPredicateList getPredicates(Union union, RelMetadataQuery mq) { return RelOptPredicateList.of(preds); } + /** + * Utility to infer predicates from one side of the join that apply on the + * other side. + * + *

Contract is:

    + * + *
  • initialize with a {@link org.apache.calcite.rel.core.Join} and + * optional predicates applicable on its left and right subtrees. + * + *
  • you can + * then ask it for equivalentPredicate(s) given a predicate. + * + *
+ * + *

So for: + *

    + *
  1. 'R1(x) join R2(y) on x = y' a call for + * equivalentPredicates on 'x > 7' will return ' + * [y > 7]' + *
  2. 'R1(x) join R2(y) on x = y join R3(z) on y = z' a call for + * equivalentPredicates on the second join 'x > 7' will return + *
+ */ + static class JoinConditionBasedPredicateInference { + final Join joinRel; + final boolean isSemiJoin; + final int nSysFields; + final int nFieldsLeft; + final int nFieldsRight; + final ImmutableBitSet leftFieldsBitSet; /* join-level field indexes [nSysFields, nSysFields + nFieldsLeft) */ + final ImmutableBitSet rightFieldsBitSet; /* join-level field indexes that follow the left fields */ + final ImmutableBitSet allFieldsBitSet; /* every join-level field index, system fields included */ + SortedMap equivalence; /* field index -> set of field indexes equated with it; seeded with {i} for every field, filled from the join condition, then replaced by BitSets.closure */ + final Map exprFields; /* predicate string -> bit set of the fields that predicate references */ + final Set allExprsDigests; /* string forms of the child predicates registered by the constructor */ + final Set equalityPredicates; /* string forms of the field = field calls found in the join condition */ + final RexNode leftChildPredicates; /* left input predicates rewritten by the shift mapping below; null when none were supplied */ + final RexNode rightChildPredicates; /* right input predicates rewritten by the shift mapping below; null when none were supplied */ + + public JoinConditionBasedPredicateInference(Join joinRel, + RexNode lPreds, RexNode rPreds) { + this(joinRel, joinRel instanceof SemiJoin, lPreds, rPreds); /* semi-join handling is decided from the runtime type of joinRel */ + } + + private JoinConditionBasedPredicateInference(Join joinRel, boolean isSemiJoin, + RexNode lPreds, RexNode rPreds) { + super(); + this.joinRel = joinRel; + this.isSemiJoin = isSemiJoin; + nFieldsLeft = joinRel.getLeft().getRowType().getFieldList().size(); + nFieldsRight = joinRel.getRight().getRowType().getFieldList().size(); + nSysFields = joinRel.getSystemFieldList().size(); + leftFieldsBitSet = ImmutableBitSet.range(nSysFields, + nSysFields + nFieldsLeft); + rightFieldsBitSet = ImmutableBitSet.range(nSysFields + nFieldsLeft, + nSysFields + nFieldsLeft + nFieldsRight); + allFieldsBitSet = ImmutableBitSet.range(0, + nSysFields + nFieldsLeft + nFieldsRight); + + exprFields = Maps.newHashMap(); + allExprsDigests = new HashSet<>(); + + if (lPreds == null) { + leftChildPredicates = null; + } else { + Mappings.TargetMapping leftMapping = Mappings.createShiftMapping( + nSysFields + nFieldsLeft, nSysFields, 0, nFieldsLeft); /* NOTE(review): presumably shifts left-input field 0..nFieldsLeft-1 to start at nSysFields, i.e. into the join's field numbering - confirm against the Calcite Mappings API */ + leftChildPredicates = lPreds.accept( + new RexPermuteInputsShuttle(leftMapping, joinRel.getInput(0))); + + for (RexNode r : RelOptUtil.conjunctions(leftChildPredicates)) { + exprFields.put(r.toString(), RelOptUtil.InputFinder.bits(r)); + allExprsDigests.add(r.toString()); /* each conjunct is registered under its string form; mappings() later looks it up by the same key */ + } + } + if (rPreds == null) { + rightChildPredicates = null; + } else { + Mappings.TargetMapping rightMapping = 
Mappings.createShiftMapping( + nSysFields + nFieldsLeft + nFieldsRight, + nSysFields + nFieldsLeft, 0, nFieldsRight); /* same treatment for the right input, with target offset nSysFields + nFieldsLeft */ + rightChildPredicates = rPreds.accept( + new RexPermuteInputsShuttle(rightMapping, joinRel.getInput(1))); + + for (RexNode r : RelOptUtil.conjunctions(rightChildPredicates)) { + exprFields.put(r.toString(), RelOptUtil.InputFinder.bits(r)); + allExprsDigests.add(r.toString()); + } + } + + equivalence = Maps.newTreeMap(); + equalityPredicates = new HashSet<>(); + for (int i = 0; i < nSysFields + nFieldsLeft + nFieldsRight; i++) { + equivalence.put(i, BitSets.of(i)); /* every field starts out equivalent only to itself */ + } + + // Only process equivalences found in the join conditions. Processing + // Equivalences from the left or right side infer predicates that are + // already present in the Tree below the join. + RexBuilder rexBuilder = joinRel.getCluster().getRexBuilder(); + List exprs = + RelOptUtil.conjunctions( + compose(rexBuilder, ImmutableList.of(joinRel.getCondition()))); + + final EquivalenceFinder eF = new EquivalenceFinder(); + new ArrayList<>( + Lists.transform(exprs, + new Function() { + public Void apply(RexNode input) { + return input.accept(eF); + } + })); /* the ArrayList is discarded; copying the Lists.transform view is what makes eF visit every conjunct */ + + equivalence = BitSets.closure(equivalence); /* NOTE(review): presumably the transitive closure of the recorded equivalences - confirm against Calcite BitSets.closure */ + } + + /** + * The PullUp Strategy is sound but not complete. + *
    + *
  1. We only pullUp inferred predicates for now. Pulling up existing + * predicates causes an explosion of duplicates. The existing predicates are + * pushed back down as new predicates. Once we have rules to eliminate + * duplicate Filter conditions, we should pullUp all predicates. + *
  2. For Left Outer: we infer new predicates from the left and set them as + * applicable on the Right side. Only the left input's existing predicates are pulled up. + *
  3. Right Outer Joins are handled in an analogous manner. + *
  4. For Full Outer Joins no predicates are pulledUp or inferred. + *
+ */ + public RelOptPredicateList inferPredicates( + boolean includeEqualityInference) { + final List inferredPredicates = new ArrayList<>(); + final List nonFieldsPredicates = new ArrayList<>(); + final Set allExprsDigests = new HashSet<>(this.allExprsDigests); /* local copy: digests added during this call are not written back to the field */ + final JoinRelType joinType = joinRel.getJoinType(); + final List leftPreds = ImmutableList.copyOf(RelOptUtil.conjunctions(leftChildPredicates)); + final List rightPreds = ImmutableList.copyOf(RelOptUtil.conjunctions(rightChildPredicates)); + switch (joinType) { + case INNER: + case LEFT: + infer(leftPreds, allExprsDigests, inferredPredicates, + nonFieldsPredicates, includeEqualityInference, + joinType == JoinRelType.LEFT ? rightFieldsBitSet + : allFieldsBitSet); /* for LEFT joins only rewrites confined to right-side fields are accepted */ + break; + } + switch (joinType) { + case INNER: + case RIGHT: + infer(rightPreds, allExprsDigests, inferredPredicates, + nonFieldsPredicates, includeEqualityInference, + joinType == JoinRelType.RIGHT ? leftFieldsBitSet + : allFieldsBitSet); /* mirror case: for RIGHT joins only rewrites confined to left-side fields are accepted */ + break; + } + + Mappings.TargetMapping rightMapping = Mappings.createShiftMapping( + nSysFields + nFieldsLeft + nFieldsRight, + 0, nSysFields + nFieldsLeft, nFieldsRight); /* NOTE(review): presumably the inverse of the constructor's shift, taking join-level field refs back to the right input's own numbering - confirm */ + final RexPermuteInputsShuttle rightPermute = + new RexPermuteInputsShuttle(rightMapping, joinRel); + Mappings.TargetMapping leftMapping = Mappings.createShiftMapping( + nSysFields + nFieldsLeft, 0, nSysFields, nFieldsLeft); + final RexPermuteInputsShuttle leftPermute = + new RexPermuteInputsShuttle(leftMapping, joinRel); + final List leftInferredPredicates = new ArrayList<>(); + final List rightInferredPredicates = new ArrayList<>(); + + for (RexNode iP : inferredPredicates) { + ImmutableBitSet iPBitSet = RelOptUtil.InputFinder.bits(iP); + if (leftFieldsBitSet.contains(iPBitSet)) { + leftInferredPredicates.add(iP.accept(leftPermute)); + } else if (rightFieldsBitSet.contains(iPBitSet)) { + rightInferredPredicates.add(iP.accept(rightPermute)); + } /* an inferred predicate that references fields of both sides is assigned to neither input list */ + } + + if (joinType == JoinRelType.INNER && !nonFieldsPredicates.isEmpty()) { + // Predicates 
without field references can be pushed to both inputs + final Set leftPredsSet = new HashSet( + Lists.transform(leftPreds, HiveCalciteUtil.REX_STR_FN)); + final Set rightPredsSet = new HashSet( + Lists.transform(rightPreds, HiveCalciteUtil.REX_STR_FN)); + for (RexNode iP : nonFieldsPredicates) { + if (!leftPredsSet.contains(iP.toString())) { + leftInferredPredicates.add(iP); + } + if (!rightPredsSet.contains(iP.toString())) { + rightInferredPredicates.add(iP); + } + } + } + + switch (joinType) { + case INNER: + Iterable pulledUpPredicates; + if (isSemiJoin) { + pulledUpPredicates = Iterables.concat(leftPreds, leftInferredPredicates); /* a semi-join pulls up left-side predicates only */ + } else { + pulledUpPredicates = Iterables.concat(leftPreds, rightPreds, + RelOptUtil.conjunctions(joinRel.getCondition()), inferredPredicates); + } + return RelOptPredicateList.of( + pulledUpPredicates, leftInferredPredicates, rightInferredPredicates); + case LEFT: + return RelOptPredicateList.of( + leftPreds, EMPTY_LIST, rightInferredPredicates); + case RIGHT: + return RelOptPredicateList.of( + rightPreds, leftInferredPredicates, EMPTY_LIST); + default: + assert inferredPredicates.size() == 0; /* neither switch above calls infer for any other join type */ + return RelOptPredicateList.EMPTY; + } + } + + public RexNode left() { + return leftChildPredicates; + } + + public RexNode right() { + return rightChildPredicates; + } + + private void infer(List predicates, Set allExprsDigests, + List inferedPredicates, List nonFieldsPredicates, + boolean includeEqualityInference, ImmutableBitSet inferringFields) { + for (RexNode r : predicates) { + if (!includeEqualityInference + && equalityPredicates.contains(r.toString())) { + continue; + } + Iterable ms = mappings(r); + if (ms.iterator().hasNext()) { /* a throwaway iterator is created only to test for emptiness; the for-each below obtains its own */ + for (Mapping m : ms) { + RexNode tr = r.accept( + new RexPermuteInputsShuttle(m, joinRel.getInput(0), + joinRel.getInput(1))); + if (inferringFields.contains(RelOptUtil.InputFinder.bits(tr)) + && !allExprsDigests.contains(tr.toString()) + && !isAlwaysTrue(tr)) { + inferedPredicates.add(tr); + 
allExprsDigests.add(tr.toString()); /* remember the string form so the same rewrite is not added twice */ + } + } + } else { + if (!isAlwaysTrue(r)) { + nonFieldsPredicates.add(r); /* reached when r yields no mappings, which mappings() produces for a predicate referencing no fields */ + } + } + } + } + + Iterable mappings(final RexNode predicate) { + return new Iterable() { + public Iterator iterator() { + ImmutableBitSet fields = exprFields.get(predicate.toString()); /* NOTE(review): null for a predicate that the constructor did not register, which would make the next line throw; infer() only passes conjuncts of the registered child predicates */ + if (fields.cardinality() == 0) { + return Iterators.emptyIterator(); + } + return new ExprsItr(fields); + } + }; + } + + private void equivalent(int p1, int p2) { + BitSet b = equivalence.get(p1); + b.set(p2); + + b = equivalence.get(p2); + b.set(p1); /* the equivalence is recorded in both directions */ + } + + RexNode compose(RexBuilder rexBuilder, Iterable exprs) { + exprs = Linq4j.asEnumerable(exprs).where(new Predicate1() { + public boolean apply(RexNode expr) { + return expr != null; /* null entries are dropped before the conjunction is built */ + } + }); + return RexUtil.composeConjunction(rexBuilder, exprs, false); + } + + /** + * Find expressions of the form 'col_x = col_y'. + */ + class EquivalenceFinder extends RexVisitorImpl { + protected EquivalenceFinder() { + super(true); + } + + @Override public Void visitCall(RexCall call) { + if (call.getOperator().getKind() == SqlKind.EQUALS) { + int lPos = pos(call.getOperands().get(0)); + int rPos = pos(call.getOperands().get(1)); + if (lPos != -1 && rPos != -1) { + JoinConditionBasedPredicateInference.this.equivalent(lPos, rPos); + JoinConditionBasedPredicateInference.this.equalityPredicates + .add(call.toString()); + } + } + return null; /* operands are not visited here, so only the call this visitor is applied to is inspected */ + } + } + + /** + * Given an expression returns all the possible substitutions. + * + *

For example, for an expression 'a + b + c' and the following + * equivalences:

+     * a : {a, b}
+     * b : {a, b}
+     * c : {c, e}
+     * 
+ * + *

The following Mappings will be returned: + *

+     * {a → a, b → a, c → c}
+     * {a → a, b → a, c → e}
+     * {a → a, b → b, c → c}
+     * {a → a, b → b, c → e}
+     * {a → b, b → a, c → c}
+     * {a → b, b → a, c → e}
+     * {a → b, b → b, c → c}
+     * {a → b, b → b, c → e}
+     * 
+ * + *

which imply the following inferences: + *

+     * a + a + c
+     * a + a + e
+     * a + b + c
+     * a + b + e
+     * b + a + c
+     * b + a + e
+     * b + b + c
+     * b + b + e
+     * 
+ */ + class ExprsItr implements Iterator { + final int[] columns; /* field indexes referenced by the predicate, in ascending order */ + final BitSet[] columnSets; /* columnSets[j] is the equivalence set of columns[j] */ + final int[] iterationIdx; /* per column: bit index from which the next search of its equivalence set starts */ + Mapping nextMapping; /* the mapping handed out by next(); null once iteration is exhausted */ + boolean firstCall; + + ExprsItr(ImmutableBitSet fields) { + nextMapping = null; + columns = new int[fields.cardinality()]; + columnSets = new BitSet[fields.cardinality()]; + iterationIdx = new int[fields.cardinality()]; + for (int j = 0, i = fields.nextSetBit(0); i >= 0; i = fields + .nextSetBit(i + 1), j++) { + columns[j] = i; + columnSets[j] = equivalence.get(i); + iterationIdx[j] = 0; + } + firstCall = true; + } + + public boolean hasNext() { /* NOTE(review): every call advances the iteration, so hasNext() is not idempotent; it must be called exactly once before each next(), as a for-each loop does */ + if (firstCall) { + initializeMapping(); + firstCall = false; + } else { + computeNextMapping(iterationIdx.length - 1); + } + return nextMapping != null; + } + + public Mapping next() { + return nextMapping; /* the same Mapping instance is returned on every call and is mutated in place by computeNextMapping; no exception is thrown when exhausted, null is returned instead */ + } + + public void remove() { + throw new UnsupportedOperationException(); + } + + private void computeNextMapping(int level) { + int t = columnSets[level].nextSetBit(iterationIdx[level]); + if (t < 0) { + if (level == 0) { + nextMapping = null; /* the first column ran out of candidates: iteration is over */ + } else { + iterationIdx[level] = 0; + computeNextMapping(level - 1); /* this column ran out: restart its search position and advance the column to its left */ + } + } else { + nextMapping.set(columns[level], t); + iterationIdx[level] = t + 1; + } + } + + private void initializeMapping() { + nextMapping = Mappings.create(MappingType.PARTIAL_FUNCTION, + nSysFields + nFieldsLeft + nFieldsRight, + nSysFields + nFieldsLeft + nFieldsRight); + for (int i = 0; i < columnSets.length; i++) { + BitSet c = columnSets[i]; + int t = c.nextSetBit(iterationIdx[i]); + if (t < 0) { + nextMapping = null; + return; + } + nextMapping.set(columns[i], t); /* first mapping: every column goes to the lowest member of its equivalence set */ + iterationIdx[i] = t + 1; + } + } + } + + private int pos(RexNode expr) { + if (expr instanceof RexInputRef) { + return ((RexInputRef) expr).getIndex(); + } + return -1; /* -1 marks anything that is not a plain input reference */ + } + + private boolean isAlwaysTrue(RexNode predicate) { + if (predicate instanceof RexCall) { + RexCall c = (RexCall) predicate; + if (c.getOperator().getKind() == SqlKind.EQUALS) { + int lPos = 
pos(c.getOperands().get(0)); + int rPos = pos(c.getOperands().get(1)); + return lPos != -1 && lPos == rPos; /* an EQUALS call is judged here alone: true only when both operands are the same input reference */ + } + } + return predicate.isAlwaysTrue(); + } + } + } diff --git ql/src/test/results/clientpositive/constprog3.q.out ql/src/test/results/clientpositive/constprog3.q.out index cb440dc..f54168d 100644 --- ql/src/test/results/clientpositive/constprog3.q.out +++ ql/src/test/results/clientpositive/constprog3.q.out @@ -14,7 +14,7 @@ POSTHOOK: query: create temporary table table3(id int, val int, val1 int) POSTHOOK: type: CREATETABLE POSTHOOK: Output: database:default POSTHOOK: Output: default@table3 -Warning: Shuffle Join JOIN[8][tables = [$hdt$_0, $hdt$_1]] in Stage 'Stage-1:MAPRED' is a cross product +Warning: Shuffle Join JOIN[9][tables = [$hdt$_0, $hdt$_1]] in Stage 'Stage-1:MAPRED' is a cross product PREHOOK: query: explain select table1.id, table1.val, table1.val1 from table1 inner join table3 @@ -49,15 +49,15 @@ STAGE PLANS: value expressions: _col0 (type: int), _col1 (type: int), _col2 (type: int) TableScan alias: table3 - Statistics: Num rows: 1 Data size: 0 Basic stats: PARTIAL Column stats: NONE - Filter Operator - predicate: (id = 1) (type: boolean) - Statistics: Num rows: 1 Data size: 0 Basic stats: PARTIAL Column stats: NONE - Select Operator - Statistics: Num rows: 1 Data size: 0 Basic stats: PARTIAL Column stats: NONE + Statistics: Num rows: 1 Data size: 0 Basic stats: PARTIAL Column stats: COMPLETE + Select Operator + Statistics: Num rows: 1 Data size: 0 Basic stats: PARTIAL Column stats: COMPLETE + Filter Operator + predicate: false (type: boolean) + Statistics: Num rows: 1 Data size: 0 Basic stats: PARTIAL Column stats: COMPLETE Reduce Output Operator sort order: - Statistics: Num rows: 1 Data size: 0 Basic stats: PARTIAL Column stats: NONE + Statistics: Num rows: 1 Data size: 0 Basic stats: PARTIAL Column stats: COMPLETE Reduce Operator Tree: Join Operator condition map: