diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/stats/annotation/StatsRulesProcFactory.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/stats/annotation/StatsRulesProcFactory.java index a8ff158..ea6b9ef 100644 --- ql/src/java/org/apache/hadoop/hive/ql/optimizer/stats/annotation/StatsRulesProcFactory.java +++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/stats/annotation/StatsRulesProcFactory.java @@ -1204,9 +1261,10 @@ public Object process(Node nd, Stack stack, NodeProcessorCtx procCtx, private long inferPKFKRelationship(int numAttr, List> parents, CommonJoinOperator jop) { long newNumRows = -1; - if (numAttr == 1) { + int pos = getPrimaryKeyPos(parents); + if (pos != -1) { // If numAttr is 1, this means we join on one single key column. - Map parentsWithPK = getPrimaryKeyCandidates(parents); + Map parentsWithPK = getPrimaryKeyCandidates(pos, parents); // We only allow one single PK. if (parentsWithPK.size() != 1) { @@ -1217,7 +1275,7 @@ private long inferPKFKRelationship(int numAttr, List csFKs = getForeignKeyCandidates(parents, csPK); + Map csFKs = getForeignKeyCandidates(pos, parents, csPK); // we allow multiple foreign keys (snowflake schema) // csfKs.size() + 1 == parents.size() means we have a single PK and all @@ -1369,7 +1427,7 @@ private float getSelectivityComplexTree(Operator op) { * @param csPK - column statistics of primary key * @return - a map which contains position ids and the corresponding column statistics */ - private Map getForeignKeyCandidates(List> ops, + private Map getForeignKeyCandidates(int pos, List> ops, ColStatistics csPK) { Map result = new HashMap(); if (csPK == null || ops == null) { @@ -1381,14 +1439,12 @@ private float getSelectivityComplexTree(Operator op) { if (op != null && op instanceof ReduceSinkOperator) { ReduceSinkOperator rsOp = (ReduceSinkOperator) op; List keys = StatsUtils.getQualifedReducerKeyNames(rsOp.getConf().getOutputKeyColumnNames()); - if (keys.size() == 1) { - String joinCol = keys.get(0); - if (rsOp.getStatistics() != null) { - ColStatistics cs = rsOp.getStatistics().getColumnStatisticsFromColName(joinCol); - if (cs != null && !cs.isPrimaryKey()) { - if (StatsUtils.inferForeignKey(csPK, cs)) { - result.put(i,cs); - } + String joinCol = keys.get(pos); + if (rsOp.getStatistics() != null) { + ColStatistics cs = rsOp.getStatistics().getColumnStatisticsFromColName(joinCol); + if (cs != null && !cs.isPrimaryKey()) { + if (StatsUtils.inferForeignKey(csPK, cs)) { + result.put(i, cs); } } } @@ -1402,7 +1458,7 @@ private float getSelectivityComplexTree(Operator op) { * @param ops - operators * @return - list of primary key containing parent ids */ - private Map getPrimaryKeyCandidates(List> ops) { + private Map getPrimaryKeyCandidates(int pos, List> ops) { Map result = new HashMap(); if (ops != null && !ops.isEmpty()) { for (int i = 0; i < ops.size(); i++) { @@ -1410,8 +1466,8 @@ private float getSelectivityComplexTree(Operator op) { if (op instanceof ReduceSinkOperator) { ReduceSinkOperator rsOp = (ReduceSinkOperator) op; List keys = StatsUtils.getQualifedReducerKeyNames(rsOp.getConf().getOutputKeyColumnNames()); - if (keys.size() == 1) { - String joinCol = keys.get(0); + if (keys.size() > pos) { + String joinCol = keys.get(pos); if (rsOp.getStatistics() != null) { ColStatistics cs = rsOp.getStatistics().getColumnStatisticsFromColName(joinCol); if (cs != null && cs.isPrimaryKey()) { @@ -1424,6 +1480,38 @@ private float getSelectivityComplexTree(Operator op) { } return result; } + + /** + * Returns the index of the primary key join + * @param ops - operators + * @return - index of primary key if there's only 1 + */ + private int getPrimaryKeyPos(List> ops) { + boolean found = false; + int alias = -1; + if (ops != null && !ops.isEmpty()) { + for (int i = 0; i < ops.size(); i++) { + Operator op = ops.get(i); + if (op instanceof ReduceSinkOperator) { + ReduceSinkOperator rsOp = (ReduceSinkOperator) op; + List keys = StatsUtils.getQualifedReducerKeyNames(rsOp.getConf().getOutputKeyColumnNames()); + for (String joinCol : keys) { + if (rsOp.getStatistics() != null) { + ColStatistics cs = rsOp.getStatistics().getColumnStatisticsFromColName(joinCol); + if (cs != null && cs.isPrimaryKey()) { + if (found) { + return -1; + } + found = true; + alias = keys.indexOf(joinCol); + } + } + } + } + } + } + return alias; + } private Long getEasedOutDenominator(List distinctVals) { // Exponential back-off for NDVs.