diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveOpConverter.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveOpConverter.java index d15520c..6df8e09 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveOpConverter.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveOpConverter.java @@ -114,6 +114,7 @@ private final UnparseTranslator unparseTranslator; private final Map> topOps; private final boolean strictMode; + private int uniqueCounter; public HiveOpConverter(SemanticAnalyzer semanticAnalyzer, HiveConf hiveConf, UnparseTranslator unparseTranslator, Map> topOps, @@ -123,6 +124,7 @@ public HiveOpConverter(SemanticAnalyzer semanticAnalyzer, HiveConf hiveConf, this.unparseTranslator = unparseTranslator; this.topOps = topOps; this.strictMode = strictMode; + this.uniqueCounter = 0; } static class OpAttr { @@ -147,6 +149,10 @@ public Operator convert(RelNode root) throws SemanticException { } OpAttr dispatch(RelNode rn) throws SemanticException { + return dispatch(rn, null); + } + + OpAttr dispatch(RelNode rn, ExprNodeDesc[] exprNodeDescs) throws SemanticException { if (rn instanceof HiveTableScan) { return visit((HiveTableScan) rn); } else if (rn instanceof HiveProject) { @@ -165,7 +171,7 @@ OpAttr dispatch(RelNode rn) throws SemanticException { } else if (rn instanceof HiveUnion) { return visit((HiveUnion) rn); } else if (rn instanceof SortExchange) { - return visit((SortExchange) rn); + return visit((SortExchange) rn, exprNodeDescs); } else if (rn instanceof HiveAggregate) { return visit((HiveAggregate) rn); } @@ -295,12 +301,24 @@ OpAttr visit(HiveProject projectRel) throws SemanticException { } OpAttr visit(HiveJoin joinRel) throws SemanticException { - // 1. Convert inputs + // 0. Additional data structures needed for the join optimization + // through Hive + String[] baseSrc = new String[joinRel.getInputs().size()]; + + // 1. Convert join condition + JoinPredicateInfo joinPredInfo = JoinPredicateInfo.constructJoinPredicateInfo(joinRel); + + // 2. Extract join keys from condition + ExprNodeDesc[][] joinKeys = extractJoinKeys(joinPredInfo, joinRel.getInputs()); + + // 3. Convert inputs OpAttr[] inputs = new OpAttr[joinRel.getInputs().size()]; List> children = new ArrayList>(joinRel.getInputs().size()); for (int i = 0; i < inputs.length; i++) { - inputs[i] = dispatch(joinRel.getInput(i)); + inputs[i] = dispatch(joinRel.getInput(i), joinKeys[i]); children.add(inputs[i].inputs.get(0)); + baseSrc[i] = "inp" + i + "-" + uniqueCounter; + uniqueCounter++; } if (LOG.isDebugEnabled()) { @@ -308,19 +326,14 @@ OpAttr visit(HiveJoin joinRel) throws SemanticException { + " with row type: [" + joinRel.getRowType() + "]"); } - // 2. Convert join condition - JoinPredicateInfo joinPredInfo = JoinPredicateInfo.constructJoinPredicateInfo(joinRel); - - // 3. Extract join keys from condition - ExprNodeDesc[][] joinKeys = extractJoinKeys(joinPredInfo, joinRel.getInputs()); - // 4.a Generate tags for (int tag=0; tag inputs) { ExprNodeDesc[][] joinKeys = new ExprNodeDesc[inputs.size()][]; - for (int i = 0; i < inputs.size(); i++) { - joinKeys[i] = new ExprNodeDesc[joinPredInfo.getEquiJoinPredicateElements().size()]; - for (int j = 0; j < joinPredInfo.getEquiJoinPredicateElements().size(); j++) { - JoinLeafPredicateInfo joinLeafPredInfo = joinPredInfo.getEquiJoinPredicateElements().get(j); - RexNode key = joinLeafPredInfo.getJoinKeyExprs(j).get(0); - joinKeys[i][j] = convertToExprNode(key, inputs.get(j), null); + // Assumption here is that inputs only contains 2 inputs. input=0 is left, input=1 is right + for (int input = 0; input < inputs.size(); input++) { + joinKeys[input] = new ExprNodeDesc[joinPredInfo.getEquiJoinPredicateElements().size()]; + for (int element = 0; element < joinPredInfo.getEquiJoinPredicateElements().size(); element++) { + JoinLeafPredicateInfo joinLeafPredInfo = joinPredInfo.getEquiJoinPredicateElements().get(element); + //getJoinKeyExprs(0) is left, getJoinKeyExprs(1) is right + RexNode key = joinLeafPredInfo.getJoinKeyExprs(input).get(0); + joinKeys[input][element] = convertToExprNode(key, inputs.get(input), null); } } return joinKeys; @@ -719,7 +740,7 @@ private static ReduceSinkOperator genReduceSink(Operator input, ExprNodeDesc[ } private static JoinOperator genJoin(HiveJoin hiveJoin, JoinPredicateInfo joinPredInfo, - List> children, ExprNodeDesc[][] joinKeys) throws SemanticException { + List> children, ExprNodeDesc[][] joinKeys, String[] aliases) throws SemanticException { // Extract join type JoinType joinType = extractJoinType(hiveJoin); @@ -770,6 +791,7 @@ private static JoinOperator genJoin(HiveJoin hiveJoin, JoinPredicateInfo joinPre for (int i = 0; i < index.length; i++) { ColumnInfo info = new ColumnInfo(parentColumns.get(i)); info.setInternalName(outputColumnNames.get(outputPos)); + info.setTabAlias(aliases[pos]); outputColumns.add(info); reversedExprs.put(outputColumnNames.get(outputPos), tag); outputPos++;