diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveRelOptUtil.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveRelOptUtil.java index 9aa30129b6..322e925843 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveRelOptUtil.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveRelOptUtil.java @@ -17,12 +17,14 @@ */ package org.apache.hadoop.hive.ql.optimizer.calcite; +import com.google.common.collect.HashMultimap; import com.google.common.collect.Multimap; import com.google.common.collect.Sets; import java.util.AbstractList; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; +import java.util.HashSet; import java.util.LinkedHashSet; import java.util.List; @@ -62,6 +64,7 @@ import org.apache.calcite.tools.RelBuilder; import org.apache.calcite.util.ImmutableBitSet; import org.apache.calcite.util.Pair; +import org.apache.commons.lang3.tuple.Triple; import org.apache.hadoop.hive.ql.exec.FunctionRegistry; import org.apache.hadoop.hive.ql.optimizer.calcite.translator.TypeConverter; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; @@ -536,7 +539,183 @@ public static boolean isRowFilteringPlan(final RelMetadataQuery mq, RelNode oper return false; } - public static Pair> isRewritablePKFKJoin(RelBuilder builder, Join join, + /** + * Returns a triple where first value represents whether we could extract a FK-PK join + * or not, the second value is a pair with the column from left and right input that + * are used for the FK-PK join, and the third value are the predicates that are not + * part of the FK-PK condition. Currently we can only extract one FK-PK join. + */ + public static Triple, List> extractPKFKJoin( + Join join, List joinFilters, boolean leftInputPotentialFK, RelMetadataQuery mq) { + final List residualPreds = new ArrayList<>(); + final JoinRelType joinType = join.getJoinType(); + final RelNode fkInput = leftInputPotentialFK ? join.getLeft() : join.getRight(); + final Triple, List> cannotExtract = + Triple.of(false, null, null); + + if (joinType != JoinRelType.INNER) { + // If it is not an inner, we transform it as the metadata + // providers for expressions do not pull information through + // outer join (as it would not be correct) + join = join.copy(join.getTraitSet(), join.getCluster().getRexBuilder().makeLiteral(true), + join.getLeft(), join.getRight(), JoinRelType.INNER, false); + } + + // 1) Gather all tables from the FK side and the table from the + // non-FK side + final Set leftTables = mq.getTableReferences(join.getLeft()); + final Set rightTables = + Sets.difference(mq.getTableReferences(join), mq.getTableReferences(join.getLeft())); + final Set fkTables = join.getLeft() == fkInput ? leftTables : rightTables; + final Set nonFkTables = join.getLeft() == fkInput ? rightTables : leftTables; + if (nonFkTables.size() != 1) { + // More than one table in PK side, we bail out + return cannotExtract; + } + + // 2) Check whether there is a FK relationship + Set candidatePredicates = new HashSet<>(); + EquivalenceClasses ec = new EquivalenceClasses(); + for (RexNode conj : joinFilters) { + if (!conj.isA(SqlKind.EQUALS)) { + // Not an equality, continue + residualPreds.add(conj); + continue; + } + RexCall equiCond = (RexCall) conj; + RexNode eqOp1 = equiCond.getOperands().get(0); + if (!RexUtil.isReferenceOrAccess(eqOp1, true)) { + // Ignore + residualPreds.add(conj); + continue; + } + Set eqOp1ExprsLineage = mq.getExpressionLineage(join, eqOp1); + if (eqOp1ExprsLineage == null) { + // Cannot be mapped, continue + residualPreds.add(conj); + continue; + } + RexNode eqOp2 = equiCond.getOperands().get(1); + if (!RexUtil.isReferenceOrAccess(eqOp2, true)) { + // Ignore + residualPreds.add(conj); + continue; + } + Set eqOp2ExprsLineage = mq.getExpressionLineage(join, eqOp2); + if (eqOp2ExprsLineage == null) { + // Cannot be mapped, continue + residualPreds.add(conj); + continue; + } + List eqOp2ExprsFiltered = null; + for (RexNode eqOpExprLineage1 : eqOp1ExprsLineage) { + RexTableInputRef inputRef1 = extractTableInputRef(eqOpExprLineage1); + if (inputRef1 == null) { + // This condition could not be map into an input reference + continue; + } + if (eqOp2ExprsFiltered == null) { + // First iteration + eqOp2ExprsFiltered = new ArrayList<>(); + for (RexNode eqOpExprLineage2 : eqOp2ExprsLineage) { + RexTableInputRef inputRef2 = extractTableInputRef(eqOpExprLineage2); + if (inputRef2 == null) { + // Bail out as this condition could not be map into an input reference + continue; + } + // Add to list of expressions for follow-up iterations + eqOp2ExprsFiltered.add(inputRef2); + // Add to equivalence classes and backwards mapping + ec.addEquivalence(inputRef1, inputRef2, equiCond); + candidatePredicates.add(equiCond); + } + } else { + // Rest of iterations, only adding, no checking + for (RexTableInputRef inputRef2 : eqOp2ExprsFiltered) { + ec.addEquivalence(inputRef1, inputRef2, equiCond); + } + } + } + if (!candidatePredicates.contains(conj)) { + // We add it to residual already + residualPreds.add(conj); + } + } + if (ec.getEquivalenceClassesMap().isEmpty()) { + // This may be a cartesian product, we bail out + return cannotExtract; + } + + // 4) For each table, check whether there is a matching on the non-FK side. + // If there is and it is the only condition, we are ready to transform + final RelTableRef nonFkTable = nonFkTables.iterator().next(); + final List nonFkTableQName = nonFkTable.getQualifiedName(); + for (RelTableRef tRef : fkTables) { + List constraints = tRef.getTable().getReferentialConstraints(); + for (RelReferentialConstraint constraint : constraints) { + if (constraint.getTargetQualifiedName().equals(nonFkTableQName)) { + EquivalenceClasses ecT = EquivalenceClasses.copy(ec); + Set removedOriginalPredicates = new HashSet<>(); + ImmutableBitSet.Builder lBitSet = ImmutableBitSet.builder(); + ImmutableBitSet.Builder rBitSet = ImmutableBitSet.builder(); + boolean allContained = true; + for (int pos = 0; pos < constraint.getNumColumns(); pos++) { + int foreignKeyPos = constraint.getColumnPairs().get(pos).source; + RelDataType foreignKeyColumnType = + tRef.getTable().getRowType().getFieldList().get(foreignKeyPos).getType(); + RexTableInputRef foreignKeyColumnRef = + RexTableInputRef.of(tRef, foreignKeyPos, foreignKeyColumnType); + int uniqueKeyPos = constraint.getColumnPairs().get(pos).target; + RexTableInputRef uniqueKeyColumnRef = RexTableInputRef.of(nonFkTable, uniqueKeyPos, + nonFkTable.getTable().getRowType().getFieldList().get(uniqueKeyPos).getType()); + if (ecT.getEquivalenceClassesMap().containsKey(uniqueKeyColumnRef) && + ecT.getEquivalenceClassesMap().get(uniqueKeyColumnRef).contains(foreignKeyColumnRef)) { + // Remove this condition from eq classes as we have checked that it is present + // in the join condition. In turn, populate the columns that are referenced + // from the join inputs + for (RexCall originalPred : ecT.removeEquivalence(uniqueKeyColumnRef, foreignKeyColumnRef)) { + ImmutableBitSet leftCols = RelOptUtil.InputFinder.bits(originalPred.getOperands().get(0)); + ImmutableBitSet rightCols = RelOptUtil.InputFinder.bits(originalPred.getOperands().get(1)); + // Get length and flip column references if join condition specified in + // reverse order to join sources + int nFieldsLeft = join.getLeft().getRowType().getFieldList().size(); + int nFieldsRight = join.getRight().getRowType().getFieldList().size(); + int nSysFields = join.getSystemFieldList().size(); + ImmutableBitSet rightFieldsBitSet = ImmutableBitSet.range(nSysFields + nFieldsLeft, + nSysFields + nFieldsLeft + nFieldsRight); + if (rightFieldsBitSet.contains(leftCols)) { + ImmutableBitSet t = leftCols; + leftCols = rightCols; + rightCols = t; + } + lBitSet.set(leftCols.nextSetBit(0) - nSysFields); + rBitSet.set(rightCols.nextSetBit(0) - (nSysFields + nFieldsLeft)); + removedOriginalPredicates.add(originalPred); + } + } else { + // No relationship, we cannot do anything + allContained = false; + break; + } + } + if (allContained) { + // This is a PK-FK, reassign equivalence classes and remove conditions + // TODO: Support inference of multiple PK-FK relationships + + // 4.1) Add to residual whatever is remaining + candidatePredicates.removeAll(removedOriginalPredicates); + residualPreds.addAll(candidatePredicates); + // 4.2) Return result + return Triple.of(true, Pair.of(lBitSet.build(), rBitSet.build()), residualPreds); + } + } + } + } + + return cannotExtract; + } + + public static Pair> isRewritablePKFKJoin(Join join, boolean leftInputPotentialFK, RelMetadataQuery mq) { final JoinRelType joinType = join.getJoinType(); final RexNode cond = join.getCondition(); @@ -548,10 +727,9 @@ public static boolean isRowFilteringPlan(final RelMetadataQuery mq, RelNode oper // If it is not an inner, we transform it as the metadata // providers for expressions do not pull information through // outer join (as it would not be correct) - join = (Join) builder - .push(join.getLeft()).push(join.getRight()) - .join(JoinRelType.INNER, cond) - .build(); + join = join.copy(join.getTraitSet(), cond, + join.getLeft(), join.getRight(), JoinRelType.INNER, + false); } // 1) Check whether there is any filtering condition on the @@ -602,13 +780,13 @@ public static boolean isRowFilteringPlan(final RelMetadataQuery mq, RelNode oper // Add to list of expressions for follow-up iterations eqOp2ExprsFiltered.add(inputRef2); // Add to equivalence classes and backwards mapping - ec.addEquivalenceClass(inputRef1, inputRef2); + ec.addEquivalence(inputRef1, inputRef2); refToRex.put(inputRef2, eqOp2); } } else { // Rest of iterations, only adding, no checking for (RexTableInputRef inputRef2 : eqOp2ExprsFiltered) { - ec.addEquivalenceClass(inputRef1, inputRef2); + ec.addEquivalence(inputRef1, inputRef2); } } } @@ -665,14 +843,7 @@ public static boolean isRowFilteringPlan(final RelMetadataQuery mq, RelNode oper } // Remove this condition from eq classes as we have checked that it is present // in the join condition - ecT.getEquivalenceClassesMap().get(uniqueKeyColumnRef).remove(foreignKeyColumnRef); - if (ecT.getEquivalenceClassesMap().get(uniqueKeyColumnRef).size() == 1) { // self - ecT.getEquivalenceClassesMap().remove(uniqueKeyColumnRef); - } - ecT.getEquivalenceClassesMap().get(foreignKeyColumnRef).remove(uniqueKeyColumnRef); - if (ecT.getEquivalenceClassesMap().get(foreignKeyColumnRef).size() == 1) { // self - ecT.getEquivalenceClassesMap().remove(foreignKeyColumnRef); - } + ecT.removeEquivalence(uniqueKeyColumnRef, foreignKeyColumnRef); } else { // No relationship, we cannot do anything allContained = false; @@ -711,13 +882,23 @@ private static RexTableInputRef extractTableInputRef(RexNode node) { */ private static class EquivalenceClasses { + // Contains the node to equivalence class nodes private final Map> nodeToEquivalenceClass; + // Contains the pair of equivalences to original expression that they originate from + private final Multimap, RexCall> equivalenceToOriginalNode; protected EquivalenceClasses() { nodeToEquivalenceClass = new HashMap<>(); + equivalenceToOriginalNode = HashMultimap.create(); } - protected void addEquivalenceClass(RexTableInputRef p1, RexTableInputRef p2) { + protected void addEquivalence(RexTableInputRef p1, RexTableInputRef p2, RexCall originalCond) { + addEquivalence(p1, p2); + equivalenceToOriginalNode.put(Pair.of(p1, p2), originalCond); + equivalenceToOriginalNode.put(Pair.of(p2, p1), originalCond); + } + + protected void addEquivalence(RexTableInputRef p1, RexTableInputRef p2) { Set c1 = nodeToEquivalenceClass.get(p1); Set c2 = nodeToEquivalenceClass.get(p2); if (c1 != null && c2 != null) { @@ -754,11 +935,30 @@ protected void addEquivalenceClass(RexTableInputRef p1, RexTableInputRef p2) { return nodeToEquivalenceClass; } + // Returns the original nodes that the equivalences were generated from + protected Set removeEquivalence(RexTableInputRef p1, RexTableInputRef p2) { + nodeToEquivalenceClass.get(p1).remove(p2); + if (nodeToEquivalenceClass.get(p1).size() == 1) { // self + nodeToEquivalenceClass.remove(p1); + } + nodeToEquivalenceClass.get(p2).remove(p1); + if (nodeToEquivalenceClass.get(p2).size() == 1) { // self + nodeToEquivalenceClass.remove(p2); + } + Set originalNodes = new HashSet<>(); + originalNodes.addAll(equivalenceToOriginalNode.removeAll(Pair.of(p1, p2))); + originalNodes.addAll(equivalenceToOriginalNode.removeAll(Pair.of(p2, p1))); + return originalNodes; + } + protected static EquivalenceClasses copy(EquivalenceClasses ec) { final EquivalenceClasses newEc = new EquivalenceClasses(); for (Entry> e : ec.nodeToEquivalenceClass.entrySet()) { newEc.nodeToEquivalenceClass.put(e.getKey(), Sets.newLinkedHashSet(e.getValue())); } + for (Entry, Collection> e : ec.equivalenceToOriginalNode.asMap().entrySet()) { + newEc.equivalenceToOriginalNode.putAll(e.getKey(), e.getValue()); + } return newEc; } } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinConstraintsRule.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinConstraintsRule.java index 534a5c9531..802d318fcd 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinConstraintsRule.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinConstraintsRule.java @@ -220,7 +220,7 @@ public void onMatch(RelOptRuleCall call) { } // 2) Check whether this join can be rewritten or removed - Pair> r = HiveRelOptUtil.isRewritablePKFKJoin(call.builder(), + Pair> r = HiveRelOptUtil.isRewritablePKFKJoin( join, leftInput == fkInput, call.getMetadataQuery()); // 3) If it is the only condition, we can trigger the rewriting diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/HiveRelMdRowCount.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/HiveRelMdRowCount.java index 576ed34bf3..f1f9b670cc 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/HiveRelMdRowCount.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/HiveRelMdRowCount.java @@ -34,6 +34,7 @@ import org.apache.calcite.rel.core.TableScan; import org.apache.calcite.rel.metadata.ReflectiveRelMetadataProvider; import org.apache.calcite.rel.metadata.RelMdRowCount; +import org.apache.calcite.rel.metadata.RelMdUtil; import org.apache.calcite.rel.metadata.RelMetadataProvider; import org.apache.calcite.rel.metadata.RelMetadataQuery; import org.apache.calcite.rex.RexBuilder; @@ -46,6 +47,8 @@ import org.apache.calcite.util.BuiltInMethod; import org.apache.calcite.util.ImmutableBitSet; import org.apache.calcite.util.Pair; +import org.apache.commons.lang3.tuple.Triple; +import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelOptUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveTableScan; @@ -63,15 +66,33 @@ protected HiveRelMdRowCount() { } public Double getRowCount(Join join, RelMetadataQuery mq) { + // Try to infer from constraints first + Pair constraintBasedResult = + constraintsBasedAnalyzeJoinForPKFK(join, mq); + if (constraintBasedResult != null) { + // We succeeded, we calculate the selectivity based on the inferred information + // and any residual predicate + double joinSelectivity = Math.min(1.0, + constraintBasedResult.left.pkInfo.selectivity * constraintBasedResult.left.ndvScalingFactor); + double residualSelectivity = RelMdUtil.guessSelectivity(constraintBasedResult.right); + double rowCount = constraintBasedResult.left.fkInfo.rowCount * joinSelectivity * residualSelectivity; + if (LOG.isDebugEnabled()) { + LOG.debug("Identified Primary - Foreign Key relation from constraints:\n {} {} Row count for join: {}\n", + RelOptUtil.toString(join), constraintBasedResult.left, rowCount); + } + return rowCount; + } + // Otherwise, try to infer from stats PKFKRelationInfo pkfk = analyzeJoinForPKFK(join, mq); if (pkfk != null) { - double selectivity = (pkfk.pkInfo.selectivity * pkfk.ndvScalingFactor); + double selectivity = pkfk.pkInfo.selectivity * pkfk.ndvScalingFactor; selectivity = Math.min(1.0, selectivity); if (LOG.isDebugEnabled()) { LOG.debug("Identified Primary - Foreign Key relation: {} {}",RelOptUtil.toString(join), pkfk); } return pkfk.fkInfo.rowCount * selectivity; } + // If we cannot infer anything, then we just go to join.estimateRowCount(mq). // Do not call mq.getRowCount(join), will trigger CyclicMetadataException return join.estimateRowCount(mq); } @@ -80,7 +101,7 @@ public Double getRowCount(Join join, RelMetadataQuery mq) { public Double getRowCount(SemiJoin rel, RelMetadataQuery mq) { PKFKRelationInfo pkfk = analyzeJoinForPKFK(rel, mq); if (pkfk != null) { - double selectivity = (pkfk.pkInfo.selectivity * pkfk.ndvScalingFactor); + double selectivity = pkfk.pkInfo.selectivity * pkfk.ndvScalingFactor; selectivity = Math.min(1.0, selectivity); if (LOG.isDebugEnabled()) { LOG.debug("Identified Primary - Foreign Key relation: {} {}", RelOptUtil.toString(rel), pkfk); @@ -217,10 +238,10 @@ public static PKFKRelationInfo analyzeJoinForPKFK(Join joinRel, RelMetadataQuery int rightColIdx = joinCols.right; RexBuilder rexBuilder = joinRel.getCluster().getRexBuilder(); - RexNode leftPred = RexUtil - .composeConjunction(rexBuilder, leftFilters, true); - RexNode rightPred = RexUtil.composeConjunction(rexBuilder, rightFilters, - true); + RexNode leftPred = RexUtil.composeConjunction( + rexBuilder, leftFilters, true); + RexNode rightPred = RexUtil.composeConjunction( + rexBuilder, rightFilters, true); ImmutableBitSet lBitSet = ImmutableBitSet.of(leftColIdx); ImmutableBitSet rBitSet = ImmutableBitSet.of(rightColIdx); @@ -228,11 +249,10 @@ public static PKFKRelationInfo analyzeJoinForPKFK(Join joinRel, RelMetadataQuery * If the form is Dim loj F or Fact roj Dim or Dim semij Fact then return * null. */ - boolean leftIsKey = (joinRel.getJoinType() == JoinRelType.INNER || joinRel - .getJoinType() == JoinRelType.RIGHT) - && !(joinRel instanceof SemiJoin) && isKey(lBitSet, left, mq); - boolean rightIsKey = (joinRel.getJoinType() == JoinRelType.INNER || joinRel - .getJoinType() == JoinRelType.LEFT) && isKey(rBitSet, right, mq); + boolean leftIsKey = (joinRel.getJoinType() == JoinRelType.INNER || joinRel.getJoinType() == JoinRelType.RIGHT) + && isKey(lBitSet, left, mq); + boolean rightIsKey = (joinRel.getJoinType() == JoinRelType.INNER || joinRel.getJoinType() == JoinRelType.LEFT) + && isKey(rBitSet, right, mq); if (!leftIsKey && !rightIsKey) { return null; @@ -247,41 +267,37 @@ public static PKFKRelationInfo analyzeJoinForPKFK(Join joinRel, RelMetadataQuery } } - int pkSide = leftIsKey ? 0 : rightIsKey ? 1 : -1; - - boolean isPKSideSimpleTree = pkSide != -1 ? - IsSimpleTreeOnJoinKey.check( - pkSide == 0 ? left : right, - pkSide == 0 ? leftColIdx : rightColIdx, mq) : false; - - double leftNDV = isPKSideSimpleTree ? mq.getDistinctRowCount(left, lBitSet, leftPred) : -1; - double rightNDV = isPKSideSimpleTree ? mq.getDistinctRowCount(right, rBitSet, rightPred) : -1; - - /* - * If the ndv of the PK - FK side don't match, and the PK side is a filter - * on the Key column then scale the NDV on the FK side. - * - * As described by Peter Boncz: http://databasearchitects.blogspot.com/ - * in such cases we can be off by a large margin in the Join cardinality - * estimate. The e.g. he provides is on the join of StoreSales and DateDim - * on the TPCDS dataset. Since the DateDim is populated for 20 years into - * the future, while the StoreSales only has 5 years worth of data, there - * are 40 times fewer distinct dates in StoreSales. - * - * In general it is hard to infer the range for the foreign key on an - * arbitrary expression. For e.g. the NDV for DayofWeek is the same - * irrespective of NDV on the number of unique days, whereas the - * NDV of Quarters has the same ratio as the NDV on the keys. - * - * But for expressions that apply only on columns that have the same NDV - * as the key (implying that they are alternate keys) we can apply the - * ratio. So in the case of StoreSales - DateDim joins for predicate on the - * d_date column we can apply the scaling factor. - */ - double ndvScalingFactor = 1.0; - if ( isPKSideSimpleTree ) { - ndvScalingFactor = pkSide == 0 ? leftNDV/rightNDV : rightNDV / leftNDV; - } + int pkSide = leftIsKey ? 0 : 1; + boolean isPKSideSimpleTree = leftIsKey ? SimpleTreeOnJoinKey.check(false, left, lBitSet, mq) : + SimpleTreeOnJoinKey.check(false, right, rBitSet, mq); + double leftNDV = isPKSideSimpleTree ? mq.getDistinctRowCount(left, lBitSet, leftPred) : -1; + double rightNDV = isPKSideSimpleTree ? mq.getDistinctRowCount(right, rBitSet, rightPred) : -1; + + /* + * If the ndv of the PK - FK side don't match, and the PK side is a filter + * on the Key column then scale the NDV on the FK side. + * + * As described by Peter Boncz: http://databasearchitects.blogspot.com/ + * in such cases we can be off by a large margin in the Join cardinality + * estimate. The e.g. he provides is on the join of StoreSales and DateDim + * on the TPCDS dataset. Since the DateDim is populated for 20 years into + * the future, while the StoreSales only has 5 years worth of data, there + * are 40 times fewer distinct dates in StoreSales. + * + * In general it is hard to infer the range for the foreign key on an + * arbitrary expression. For e.g. the NDV for DayofWeek is the same + * irrespective of NDV on the number of unique days, whereas the + * NDV of Quarters has the same ratio as the NDV on the keys. + * + * But for expressions that apply only on columns that have the same NDV + * as the key (implying that they are alternate keys) we can apply the + * ratio. So in the case of StoreSales - DateDim joins for predicate on the + * d_date column we can apply the scaling factor. + */ + double ndvScalingFactor = 1.0; + if ( isPKSideSimpleTree ) { + ndvScalingFactor = pkSide == 0 ? leftNDV/rightNDV : rightNDV / leftNDV; + } if (pkSide == 0) { FKSideInfo fkInfo = new FKSideInfo(rightRowCount, @@ -293,9 +309,7 @@ public static PKFKRelationInfo analyzeJoinForPKFK(Join joinRel, RelMetadataQuery pkSelectivity); return new PKFKRelationInfo(1, fkInfo, pkInfo, ndvScalingFactor, isPKSideSimpleTree); - } - - if (pkSide == 1) { + } else { // pkSide == 1 FKSideInfo fkInfo = new FKSideInfo(leftRowCount, leftNDV); double pkSelectivity = pkSelectivity(joinRel, mq, false, right, rightRowCount); @@ -304,10 +318,114 @@ public static PKFKRelationInfo analyzeJoinForPKFK(Join joinRel, RelMetadataQuery joinRel.getJoinType().generatesNullsOnLeft() ? 1.0 : pkSelectivity); - return new PKFKRelationInfo(1, fkInfo, pkInfo, ndvScalingFactor, isPKSideSimpleTree); + return new PKFKRelationInfo(0, fkInfo, pkInfo, ndvScalingFactor, isPKSideSimpleTree); + } + } + + /* + * + */ + public static Pair constraintsBasedAnalyzeJoinForPKFK(Join join, RelMetadataQuery mq) { + + if (join instanceof SemiJoin) { + // TODO: Support semijoin + return null; } - return null; + final RelNode left = join.getInputs().get(0); + final RelNode right = join.getInputs().get(1); + + // 1) Split filters in conjuncts + final List condConjs = RelOptUtil.conjunctions( + join.getCondition()); + + if (condConjs.isEmpty()) { + // Bail out + return null; + } + + // 2) Classify filters depending on their provenance + final List joinFilters = new ArrayList<>(condConjs); + final List leftFilters = new ArrayList<>(); + final List rightFilters = new ArrayList<>(); + RelOptUtil.classifyFilters(join, joinFilters, join.getJoinType(),false, + !join.getJoinType().generatesNullsOnRight(), !join.getJoinType().generatesNullsOnLeft(), + joinFilters, leftFilters, rightFilters); + + // 3) Check if we are joining on PK-FK + final Triple, List> leftInputResult = + HiveRelOptUtil.extractPKFKJoin(join, joinFilters, false, mq); + final Triple, List> rightInputResult = + HiveRelOptUtil.extractPKFKJoin(join, joinFilters, true, mq); + if (leftInputResult == null && rightInputResult == null) { + // Nothing to do here, bail out + return null; + } + + boolean leftIsKey = (join.getJoinType() == JoinRelType.INNER || join.getJoinType() == JoinRelType.RIGHT) + && leftInputResult.getLeft(); + boolean rightIsKey = (join.getJoinType() == JoinRelType.INNER || join.getJoinType() == JoinRelType.LEFT) + && rightInputResult.getLeft(); + if (!leftIsKey && !rightIsKey) { + // Nothing to do here, bail out + return null; + } + final double leftRowCount = mq.getRowCount(left); + final double rightRowCount = mq.getRowCount(right); + if (leftIsKey && rightIsKey) { + if (rightRowCount < leftRowCount) { + leftIsKey = false; + } + } + final ImmutableBitSet lBitSet = leftIsKey ? leftInputResult.getMiddle().left : rightInputResult.getMiddle().left; + final ImmutableBitSet rBitSet = leftIsKey ? leftInputResult.getMiddle().right : rightInputResult.getMiddle().right; + final List residualFilters = leftIsKey ? leftInputResult.getRight() : rightInputResult.getRight(); + + // 4) Extract additional information on the PK-FK relationship + int pkSide = leftIsKey ? 0 : 1; + boolean isPKSideSimpleTree = leftIsKey ? SimpleTreeOnJoinKey.check(true, left, lBitSet, mq) : + SimpleTreeOnJoinKey.check(true, right, rBitSet, mq); + RexBuilder rexBuilder = join.getCluster().getRexBuilder(); + RexNode leftPred = RexUtil.composeConjunction( + rexBuilder, leftFilters, true); + RexNode rightPred = RexUtil.composeConjunction( + rexBuilder, rightFilters, true); + double leftNDV = isPKSideSimpleTree ? mq.getDistinctRowCount(left, lBitSet, leftPred) : -1; + double rightNDV = isPKSideSimpleTree ? mq.getDistinctRowCount(right, rBitSet, rightPred) : -1; + + // 5) Add the rest of operators back to the join filters + // and create residual condition + RexNode residualCond = residualFilters.isEmpty() ? null : + residualFilters.size() == 1 ? residualFilters.get(0) : + rexBuilder.makeCall(SqlStdOperatorTable.AND, residualFilters); + + // 6) Return result + if (pkSide == 0) { + FKSideInfo fkInfo = new FKSideInfo(rightRowCount, + rightNDV); + double pkSelectivity = pkSelectivity(join, mq, true, left, leftRowCount); + PKSideInfo pkInfo = new PKSideInfo(leftRowCount, + leftNDV, + join.getJoinType().generatesNullsOnRight() ? 1.0 : + pkSelectivity); + double ndvScalingFactor = isPKSideSimpleTree ? leftNDV/rightNDV : 1.0; + if (isPKSideSimpleTree) { + ndvScalingFactor = leftNDV/rightNDV; + } + return Pair.of(new PKFKRelationInfo(1, fkInfo, pkInfo, ndvScalingFactor, isPKSideSimpleTree), + residualCond); + } else { // pkSide == 1 + FKSideInfo fkInfo = new FKSideInfo(leftRowCount, + leftNDV); + double pkSelectivity = pkSelectivity(join, mq, false, right, rightRowCount); + PKSideInfo pkInfo = new PKSideInfo(rightRowCount, + rightNDV, + join.getJoinType().generatesNullsOnLeft() ? 1.0 : + pkSelectivity); + double ndvScalingFactor = isPKSideSimpleTree ? rightNDV/leftNDV : 1.0; + return Pair.of(new PKFKRelationInfo(0, fkInfo, pkInfo, ndvScalingFactor, isPKSideSimpleTree), + residualCond); + } } private static double pkSelectivity(Join joinRel, RelMetadataQuery mq, boolean leftChild, @@ -402,20 +520,22 @@ private static boolean isKey(ImmutableBitSet c, RelNode rel, RelMetadataQuery mq return new Pair(leftColIdx, rightColIdx); } - private static class IsSimpleTreeOnJoinKey extends RelVisitor { + private static class SimpleTreeOnJoinKey extends RelVisitor { - int joinKey; + boolean constraintsBased; + ImmutableBitSet joinKey; boolean simpleTree; RelMetadataQuery mq; - static boolean check(RelNode r, int joinKey, RelMetadataQuery mq) { - IsSimpleTreeOnJoinKey v = new IsSimpleTreeOnJoinKey(joinKey, mq); + static boolean check(boolean constraintsBased, RelNode r, ImmutableBitSet joinKey, RelMetadataQuery mq) { + SimpleTreeOnJoinKey v = new SimpleTreeOnJoinKey(constraintsBased, joinKey, mq); v.go(r); return v.simpleTree; } - IsSimpleTreeOnJoinKey(int joinKey, RelMetadataQuery mq) { + SimpleTreeOnJoinKey(boolean constraintsBased, ImmutableBitSet joinKey, RelMetadataQuery mq) { super(); + this.constraintsBased = constraintsBased; this.joinKey = joinKey; this.mq = mq; simpleTree = true; @@ -444,16 +564,23 @@ public void visit(RelNode node, int ordinal, RelNode parent) { } private boolean isSimple(Project project) { - RexNode r = project.getProjects().get(joinKey); - if (r instanceof RexInputRef) { - joinKey = ((RexInputRef) r).getIndex(); - return true; + ImmutableBitSet.Builder b = ImmutableBitSet.builder(); + for (int pos : joinKey) { + RexNode r = project.getProjects().get(pos); + if (!(r instanceof RexInputRef)) { + return false; + } + b.set(((RexInputRef) r).getIndex()); } - return false; + joinKey = b.build(); + return true; } private boolean isSimple(Filter filter, RelMetadataQuery mq) { ImmutableBitSet condBits = RelOptUtil.InputFinder.bits(filter.getCondition()); + if (constraintsBased) { + return mq.areColumnsUnique(filter, condBits); + } return isKey(condBits, filter, mq); } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/HiveRelMdSelectivity.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/HiveRelMdSelectivity.java index 575902d78d..7e9208229a 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/HiveRelMdSelectivity.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/HiveRelMdSelectivity.java @@ -98,8 +98,7 @@ private Double computeInnerJoinSelectivity(Join j, RelMetadataQuery mq, RexNode } catch (CalciteSemanticException e) { throw new RuntimeException(e); } - ImmutableMap.Builder colStatMapBuilder = ImmutableMap - .builder(); + ImmutableMap.Builder colStatMapBuilder = ImmutableMap.builder(); ImmutableMap colStatMap; int rightOffSet = j.getLeft().getRowType().getFieldCount();