diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveOpConverter.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveOpConverter.java
index 4f19caf..cb432de 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveOpConverter.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveOpConverter.java
@@ -27,6 +27,7 @@
 import java.util.Map;
 import java.util.Set;
 
+import org.apache.calcite.plan.RelOptUtil;
 import org.apache.calcite.rel.RelCollations;
 import org.apache.calcite.rel.RelDistribution;
 import org.apache.calcite.rel.RelDistribution.Type;
@@ -115,6 +116,7 @@
   private final UnparseTranslator unparseTranslator;
   private final Map<String, Operator<? extends OperatorDesc>> topOps;
   private final boolean strictMode;
+  private int uniqueCounter;
 
   public HiveOpConverter(SemanticAnalyzer semanticAnalyzer, HiveConf hiveConf,
       UnparseTranslator unparseTranslator, Map<String, Operator<? extends OperatorDesc>> topOps,
@@ -124,6 +126,7 @@ public HiveOpConverter(SemanticAnalyzer semanticAnalyzer, HiveConf hiveConf,
     this.unparseTranslator = unparseTranslator;
     this.topOps = topOps;
     this.strictMode = strictMode;
+    this.uniqueCounter = 0;
   }
 
   static class OpAttr {
@@ -307,13 +310,22 @@ OpAttr visit(HiveJoin joinRel) throws SemanticException {
     return translateJoin(joinRel);
   }
 
+  private String getHiveDerivedTableAlias() {
+    return "$hdt$_" + (this.uniqueCounter++);
+  }
+
   private OpAttr translateJoin(RelNode joinRel) throws SemanticException {
+    // 0. Additional data structures needed for the join optimization
+    //    through Hive
+    String[] baseSrc = new String[joinRel.getInputs().size()];
+    String tabAlias = getHiveDerivedTableAlias();
     // 1. Convert inputs
     OpAttr[] inputs = new OpAttr[joinRel.getInputs().size()];
     List<Operator<?>> children = new ArrayList<Operator<?>>(joinRel.getInputs().size());
     for (int i = 0; i < inputs.length; i++) {
       inputs[i] = dispatch(joinRel.getInput(i));
       children.add(inputs[i].inputs.get(0));
+      baseSrc[i] = inputs[i].tabAlias;
     }
 
     if (LOG.isDebugEnabled()) {
@@ -341,7 +353,8 @@
       reduceSinkOp.getConf().setTag(tag);
     }
     // 4.b Generate Join operator
-    JoinOperator joinOp = genJoin(joinRel, joinPredInfo, children, joinExpressions);
+    JoinOperator joinOp = genJoin(joinRel, joinPredInfo, children, joinExpressions, baseSrc, tabAlias);
+    joinOp.getConf().setBaseSrc(baseSrc);
 
     // 5. TODO: Extract condition for non-equi join elements (if any) and
     //    add it
@@ -359,7 +372,7 @@
     }
 
     // 7. Return result
-    return new OpAttr(null, newVcolsInCalcite, joinOp);
+    return new OpAttr(tabAlias, newVcolsInCalcite, joinOp);
   }
 
   OpAttr visit(HiveAggregate aggRel) throws SemanticException {
@@ -495,11 +508,26 @@ OpAttr visit(HiveFilter filterRel) throws SemanticException {
     return inputOpAf.clone(filOp);
   }
 
+  // use this function to make the union "flat" for both execution and explain
+  // purpose
+  private List<RelNode> extractRelNodeFromUnion(HiveUnion unionRel) {
+    List<RelNode> ret = new ArrayList<RelNode>();
+    for (RelNode input : unionRel.getInputs()) {
+      if (input instanceof HiveUnion) {
+        ret.addAll(extractRelNodeFromUnion((HiveUnion) input));
+      } else {
+        ret.add(input);
+      }
+    }
+    return ret;
+  }
+
   OpAttr visit(HiveUnion unionRel) throws SemanticException {
     // 1. Convert inputs
-    OpAttr[] inputs = new OpAttr[unionRel.getInputs().size()];
+    List<RelNode> inputsList = extractRelNodeFromUnion(unionRel);
+    OpAttr[] inputs = new OpAttr[inputsList.size()];
     for (int i = 0; i < inputs.length; i++) {
-      inputs[i] = dispatch(unionRel.getInput(i));
+      inputs[i] = dispatch(inputsList.get(i));
     }
 
     if (LOG.isDebugEnabled()) {
@@ -510,7 +538,8 @@ OpAttr visit(HiveUnion unionRel) throws SemanticException {
     // 2. Create a new union operator
     UnionDesc unionDesc = new UnionDesc();
     unionDesc.setNumInputs(inputs.length);
-    ArrayList<ColumnInfo> cinfoLst = createColInfos(inputs[0].inputs.get(0));
+    String tableAlias = getHiveDerivedTableAlias();
+    ArrayList<ColumnInfo> cinfoLst = createColInfos(inputs[0].inputs.get(0), tableAlias);
     Operator<?>[] children = new Operator<?>[inputs.length];
     for (int i = 0; i < children.length; i++) {
       children[i] = inputs[i].inputs.get(0);
@@ -524,11 +553,16 @@
     //TODO: Can columns retain virtualness out of union
 
     // 3. Return result
-    return inputs[0].clone(unionOp);
+    return new OpAttr(tableAlias, inputs[0].vcolsInCalcite, unionOp);
+
   }
 
   OpAttr visit(HiveSortExchange exchangeRel) throws SemanticException {
     OpAttr inputOpAf = dispatch(exchangeRel.getInput());
+    String tabAlias = inputOpAf.tabAlias;
+    if (tabAlias == null || tabAlias.length() == 0) {
+      tabAlias = getHiveDerivedTableAlias();
+    }
 
     if (LOG.isDebugEnabled()) {
       LOG.debug("Translating operator rel#" + exchangeRel.getId() + ":"
@@ -542,14 +576,14 @@
     ExprNodeDesc[] expressions = new ExprNodeDesc[exchangeRel.getJoinKeys().size()];
     for (int index = 0; index < exchangeRel.getJoinKeys().size(); index++) {
       expressions[index] = convertToExprNode((RexNode) exchangeRel.getJoinKeys().get(index),
-          exchangeRel.getInput(), null, inputOpAf);
+          exchangeRel.getInput(), inputOpAf.tabAlias, inputOpAf);
     }
     exchangeRel.setJoinExpressions(expressions);
 
-    ReduceSinkOperator rsOp = genReduceSink(inputOpAf.inputs.get(0), expressions,
+    ReduceSinkOperator rsOp = genReduceSink(inputOpAf.inputs.get(0), tabAlias, expressions,
         -1, -1, Operation.NOT_ACID, strictMode);
 
-    return inputOpAf.clone(rsOp);
+    return new OpAttr(tabAlias, inputOpAf.vcolsInCalcite, rsOp);
   }
 
   private OpAttr genPTF(OpAttr inputOpAf, WindowingSpec wSpec) throws SemanticException {
@@ -618,19 +652,6 @@ private OpAttr genPTF(OpAttr inputOpAf, WindowingSpec wSpec) throws SemanticExce
     return inputOpAf.clone(input);
   }
 
-  private ExprNodeDesc[][] extractJoinKeys(JoinPredicateInfo joinPredInfo, List<RelNode> inputs, OpAttr[] inputAttr) {
-    ExprNodeDesc[][] joinKeys = new ExprNodeDesc[inputs.size()][];
-    for (int i = 0; i < inputs.size(); i++) {
-      joinKeys[i] = new ExprNodeDesc[joinPredInfo.getEquiJoinPredicateElements().size()];
-      for (int j = 0; j < joinPredInfo.getEquiJoinPredicateElements().size(); j++) {
-        JoinLeafPredicateInfo joinLeafPredInfo = joinPredInfo.getEquiJoinPredicateElements().get(j);
-        RexNode key = joinLeafPredInfo.getJoinKeyExprs(j).get(0);
-        joinKeys[i][j] = convertToExprNode(key, inputs.get(j), null, inputAttr[i]);
-      }
-    }
-    return joinKeys;
-  }
-
   private static SelectOperator genReduceSinkAndBacktrackSelect(Operator<?> input,
       ExprNodeDesc[] keys, int tag, ArrayList<ExprNodeDesc> partitionCols, String order,
       int numReducers, Operation acidOperation, boolean strictMode) throws SemanticException {
@@ -643,8 +664,13 @@ private static SelectOperator genReduceSinkAndBacktrackSelect(Operator<?> input,
       int numReducers, Operation acidOperation, boolean strictMode,
       List<String> keepColNames) throws SemanticException {
     // 1. Generate RS operator
-    ReduceSinkOperator rsOp = genReduceSink(input, keys, tag, partitionCols, order, numReducers,
-        acidOperation, strictMode);
+    if (input.getSchema().getTableNames().size() != 1) {
+      throw new SemanticException(
+          "In CBO return path, genReduceSinkAndBacktrackSelect is expecting only one SelectOp but there is "
+              + input.getSchema().getTableNames().size());
+    }
+    ReduceSinkOperator rsOp = genReduceSink(input, input.getSchema().getTableNames().iterator()
+        .next(), keys, tag, partitionCols, order, numReducers, acidOperation, strictMode);
 
     // 2. Generate backtrack Select operator
     Map<String, ExprNodeDesc> descriptors = buildBacktrackFromReduceSink(keepColNames,
@@ -664,14 +690,14 @@
     return selectOp;
   }
 
-  private static ReduceSinkOperator genReduceSink(Operator<?> input, ExprNodeDesc[] keys, int tag,
+  private static ReduceSinkOperator genReduceSink(Operator<?> input, String tableAlias, ExprNodeDesc[] keys, int tag,
      int numReducers, Operation acidOperation, boolean strictMode) throws SemanticException {
-    return genReduceSink(input, keys, tag, new ArrayList<ExprNodeDesc>(), "", numReducers,
+    return genReduceSink(input, tableAlias, keys, tag, new ArrayList<ExprNodeDesc>(), "", numReducers,
        acidOperation, strictMode);
   }
 
   @SuppressWarnings({ "rawtypes", "unchecked" })
-  private static ReduceSinkOperator genReduceSink(Operator<?> input, ExprNodeDesc[] keys, int tag,
+  private static ReduceSinkOperator genReduceSink(Operator<?> input, String tableAlias, ExprNodeDesc[] keys, int tag,
      ArrayList<ExprNodeDesc> partitionCols, String order, int numReducers,
      Operation acidOperation, boolean strictMode) throws SemanticException {
     Operator dummy = Operator.createDummy(); // dummy for backtracking
@@ -698,7 +724,7 @@ private static ReduceSinkOperator genReduceSink(Operator<?> input, ExprNodeDesc[
     for (int i = 0; i < inputColumns.size(); i++) {
       ColumnInfo colInfo = inputColumns.get(i);
       String outputColName = colInfo.getInternalName();
-      ExprNodeDesc expr = new ExprNodeColumnDesc(colInfo);
+      ExprNodeColumnDesc expr = new ExprNodeColumnDesc(colInfo);
 
       // backtrack can be null when input is script operator
       ExprNodeDesc exprBack = ExprNodeDescUtils.backtrack(expr, dummy, input);
@@ -707,7 +733,7 @@
         ColumnInfo newColInfo = new ColumnInfo(colInfo);
         newColInfo.setInternalName(Utilities.ReduceField.KEY + ".reducesinkkey" + kindex);
         newColInfo.setAlias(outputColName);
-        newColInfo.setTabAlias(colInfo.getTabAlias());
+        newColInfo.setTabAlias(tableAlias);
         outputColumns.add(newColInfo);
         index[i] = kindex;
         continue;
@@ -725,7 +751,7 @@
       ColumnInfo newColInfo = new ColumnInfo(colInfo);
       newColInfo.setInternalName(Utilities.ReduceField.VALUE + "."
          + outputColName);
      newColInfo.setAlias(outputColName);
-      newColInfo.setTabAlias(colInfo.getTabAlias());
+      newColInfo.setTabAlias(tableAlias);
      outputColumns.add(newColInfo);
      outputColumnNames.add(outputColName);
@@ -765,8 +791,8 @@
 
     rsOp.setValueIndex(index);
     rsOp.setColumnExprMap(colExprMap);
-    rsOp.setInputAliases(input.getSchema().getColumnNames()
-        .toArray(new String[input.getSchema().getColumnNames().size()]));
+    rsOp.setInputAliases(input.getSchema().getTableNames()
+        .toArray(new String[input.getSchema().getTableNames().size()]));
 
     if (LOG.isDebugEnabled()) {
       LOG.debug("Generated " + rsOp + " with row schema: [" + rsOp.getSchema() + "]");
@@ -776,7 +802,7 @@
   }
 
   private static JoinOperator genJoin(RelNode join, JoinPredicateInfo joinPredInfo,
-      List<Operator<?>> children, ExprNodeDesc[][] joinKeys) throws SemanticException {
+      List<Operator<?>> children, ExprNodeDesc[][] joinKeys, String[] baseSrc, String tabAlias) throws SemanticException {
 
     // Extract join type
     JoinType joinType;
@@ -827,12 +853,13 @@
      posToAliasMap.put(pos, new HashSet<String>(inputRS.getSchema().getTableNames()));
 
      Map<String, ExprNodeDesc> descriptors = buildBacktrackFromReduceSinkForJoin(outputPos,
-          outputColumnNames, keyColNames, valColNames, index, parent);
+          outputColumnNames, keyColNames, valColNames, index, parent, baseSrc[pos]);
 
      List<ColumnInfo> parentColumns = parent.getSchema().getSignature();
      for (int i = 0; i < index.length; i++) {
        ColumnInfo info = new ColumnInfo(parentColumns.get(i));
        info.setInternalName(outputColumnNames.get(outputPos));
+        info.setTabAlias(tabAlias);
        outputColumns.add(info);
        reversedExprs.put(outputColumnNames.get(outputPos), tag);
        outputPos++;
@@ -892,7 +919,7 @@ private static JoinType extractJoinType(HiveJoin join) {
 
   private static Map<String, ExprNodeDesc> buildBacktrackFromReduceSinkForJoin(int initialPos,
      List<String> outputColumnNames, List<String> keyColNames, List<String> valueColNames,
-      int[] index, Operator<?> inputOp) {
+      int[] index, Operator<?> inputOp, String tabAlias) {
     Map<String, ExprNodeDesc> columnDescriptors = new LinkedHashMap<String, ExprNodeDesc>();
     for (int i = 0; i < index.length; i++) {
       ColumnInfo info = new ColumnInfo(inputOp.getSchema().getSignature().get(i));
@@ -902,7 +929,7 @@
       } else {
         field = Utilities.ReduceField.VALUE + "." + valueColNames.get(-index[i] - 1);
       }
-      ExprNodeColumnDesc desc = new ExprNodeColumnDesc(info.getType(), field, info.getTabAlias(),
+      ExprNodeColumnDesc desc = new ExprNodeColumnDesc(info.getType(), field, tabAlias,
          info.getIsVirtualCol());
       columnDescriptors.put(outputColumnNames.get(initialPos + i), desc);
     }
@@ -945,6 +972,16 @@ private static ExprNodeDesc convertToExprNode(RexNode rn, RelNode inputRel, Stri
     return cInfoLst;
   }
 
+  //create column info with new tableAlias
+  private static ArrayList<ColumnInfo> createColInfos(Operator<?> input, String tableAlias) {
+    ArrayList<ColumnInfo> cInfoLst = new ArrayList<ColumnInfo>();
+    for (ColumnInfo ci : input.getSchema().getSignature()) {
+      ci.setTabAlias(tableAlias);
+      cInfoLst.add(new ColumnInfo(ci));
+    }
+    return cInfoLst;
+  }
+
   private static ArrayList<ColumnInfo> createColInfosSubset(Operator<?> input,
      List<String> keepColNames) {
     ArrayList<ColumnInfo> cInfoLst = new ArrayList<ColumnInfo>();
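
Illustration (not part of the patch): a minimal, self-contained Java sketch of the two helpers this change introduces, the "$hdt$_<n>" derived-table-alias counter and the recursive flattening of nested unions. The Node, Leaf, Union types and the flatten method are hypothetical stand-ins, not Hive classes; they only mirror the shape of getHiveDerivedTableAlias() and extractRelNodeFromUnion() above.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class ReturnPathSketch {

  private int uniqueCounter = 0;

  // Mirrors getHiveDerivedTableAlias(): each call hands out a fresh "$hdt$_<n>" alias.
  private String getHiveDerivedTableAlias() {
    return "$hdt$_" + (this.uniqueCounter++);
  }

  // Hypothetical stand-ins for RelNode / HiveUnion; not Hive classes.
  interface Node { }

  static class Leaf implements Node {
    final String name;
    Leaf(String name) { this.name = name; }
  }

  static class Union implements Node {
    final List<Node> inputs;
    Union(Node... inputs) { this.inputs = Arrays.asList(inputs); }
  }

  // Mirrors extractRelNodeFromUnion(): nested unions are inlined so that the
  // single union operator built afterwards sees one flat list of children.
  static List<Node> flatten(Union union) {
    List<Node> ret = new ArrayList<Node>();
    for (Node input : union.inputs) {
      if (input instanceof Union) {
        ret.addAll(flatten((Union) input));
      } else {
        ret.add(input);
      }
    }
    return ret;
  }

  public static void main(String[] args) {
    ReturnPathSketch s = new ReturnPathSketch();
    System.out.println(s.getHiveDerivedTableAlias()); // $hdt$_0
    System.out.println(s.getHiveDerivedTableAlias()); // $hdt$_1

    Union nested = new Union(new Leaf("a"), new Union(new Leaf("b"), new Leaf("c")));
    System.out.println(flatten(nested).size());       // 3: b and c are pulled up next to a
  }
}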