diff --git a/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java b/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
index 65ec1b9..941f30d 100644
--- a/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
+++ b/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
@@ -703,7 +703,7 @@ public void setSparkConfigUpdated(boolean isSparkConfigUpdated) {
     // CBO related
     HIVE_CBO_ENABLED("hive.cbo.enable", true, "Flag to control enabling Cost Based Optimizations using Calcite framework."),
-    HIVE_CBO_RETPATH_HIVEOP("hive.cbo.returnpath.hiveop", false, "Flag to control calcite plan to hive operator conversion"),
+    HIVE_CBO_RETPATH_HIVEOP("hive.cbo.returnpath.hiveop", true, "Flag to control calcite plan to hive operator conversion"),
     HIVE_CBO_EXTENDED_COST_MODEL("hive.cbo.costmodel.extended", false, "Flag to control enabling the extended cost model based on"
         + "CPU, IO and cardinality. Otherwise, the cost model is based on cardinality."),
     HIVE_CBO_COST_MODEL_CPU("hive.cbo.costmodel.cpu", "0.000001", "Default cost of a comparison"),
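Context for the hunk above: hive.cbo.returnpath.hiveop gates the CBO "return path", i.e. translating the optimized Calcite plan directly into Hive operators (HiveOpConverter below) instead of going back through the AST. A minimal sketch of reading the flag the way the planner does; the class and main() are illustrative only, not part of the patch:

    import org.apache.hadoop.hive.conf.HiveConf;

    public class ReturnPathFlagCheck {
      public static void main(String[] args) {
        HiveConf conf = new HiveConf();
        // With this patch applied, the flag defaults to true, so the return
        // path is taken without any per-session override.
        System.out.println(
            HiveConf.getBoolVar(conf, HiveConf.ConfVars.HIVE_CBO_RETPATH_HIVEOP));
      }
    }
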
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveCalciteUtil.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveCalciteUtil.java
index 6147791..7614463 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveCalciteUtil.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveCalciteUtil.java
@@ -589,13 +589,12 @@ public Void visitCall(org.apache.calcite.rex.RexCall call) {
     return bldr.build();
   }
-
-  public static ImmutableMap<Integer, VirtualColumn> shiftVColsMap(Map<Integer, VirtualColumn> hiveVCols,
-      int shift) {
-    Builder<Integer, VirtualColumn> bldr = ImmutableMap.<Integer, VirtualColumn> builder();
-    for (Integer pos : hiveVCols.keySet()) {
-      bldr.put(shift + pos, hiveVCols.get(pos));
+  public static ImmutableSet<Integer> shiftVColsSet(Set<Integer> hiveVCols, int shift) {
+    ImmutableSet.Builder<Integer> bldr = ImmutableSet.<Integer> builder();
+
+    for (Integer pos : hiveVCols) {
+      bldr.add(shift + pos);
     }
 
     return bldr.build();
@@ -661,13 +660,14 @@ public static ExprNodeDesc getExprNode(Integer inputRefIndx, RelNode inputRel,
     List<ExprNodeDesc> exprNodes = new ArrayList<ExprNodeDesc>();
     List<RexNode> rexInputRefs = getInputRef(inputRefs, inputRel);
     // TODO: Change ExprNodeConverter to be independent of Partition Expr
-    ExprNodeConverter exprConv = new ExprNodeConverter(inputTabAlias, inputRel.getRowType(), false, inputRel.getCluster().getTypeFactory());
+    ExprNodeConverter exprConv = new ExprNodeConverter(inputTabAlias, inputRel.getRowType(),
+        new HashSet<Integer>(), inputRel.getCluster().getTypeFactory());
     for (RexNode iRef : rexInputRefs) {
       exprNodes.add(iRef.accept(exprConv));
     }
     return exprNodes;
   }
-
+
   public static List<String> getFieldNames(List<Integer> inputRefs, RelNode inputRel) {
     List<String> fieldNames = new ArrayList<String>();
     List<String> schemaNames = inputRel.getRowType().getFieldNames();
@@ -732,4 +732,28 @@ public Boolean visitFieldAccess(RexFieldAccess fieldAccess) {
       return fieldAccess.getReferenceExpr().accept(this);
     }
   }
+
+  public static Set<Integer> getInputRefs(RexNode expr) {
+    InputRefsCollector irefColl = new InputRefsCollector(true);
+    expr.accept(irefColl);
+    return irefColl.getInputRefSet();
+  }
+
+  private static class InputRefsCollector extends RexVisitorImpl<Void> {
+
+    private Set<Integer> inputRefSet = new HashSet<Integer>();
+
+    private InputRefsCollector(boolean deep) {
+      super(deep);
+    }
+
+    public Void visitInputRef(RexInputRef inputRef) {
+      inputRefSet.add(inputRef.getIndex());
+      return null;
+    }
+
+    public Set<Integer> getInputRefSet() {
+      return inputRefSet;
+    }
+  }
 }
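Note on the two helpers added above: as posted, getInputRefs() built the collector but never visited the expression, so it always returned an empty set; the expr.accept(irefColl) call is restored above. shiftVColsSet() re-bases virtual-column positions when two input schemas are concatenated (as in the join translation later in this patch). A standalone sketch of the same shift logic; the class name and sample positions are invented for illustration:

    import java.util.Set;
    import com.google.common.collect.ImmutableSet;

    public class ShiftVColsSetDemo {
      // Same logic as HiveCalciteUtil.shiftVColsSet: offset every
      // virtual-column position by the width of the schema prepended before it.
      static ImmutableSet<Integer> shiftVColsSet(Set<Integer> hiveVCols, int shift) {
        ImmutableSet.Builder<Integer> bldr = ImmutableSet.<Integer> builder();
        for (Integer pos : hiveVCols) {
          bldr.add(shift + pos);
        }
        return bldr.build();
      }

      public static void main(String[] args) {
        // Right join input has virtual columns at positions 2 and 5; the left
        // input contributes 4 columns, so the merged positions become 6 and 9.
        System.out.println(shiftVColsSet(ImmutableSet.of(2, 5), 4));
      }
    }
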
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/RelOptHiveTable.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/RelOptHiveTable.java
index 746b107..0de7488 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/RelOptHiveTable.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/RelOptHiveTable.java
@@ -248,7 +248,7 @@ public void computePartitionList(HiveConf conf, RexNode pruneNode) {
       // We have valid pruning expressions, only retrieve qualifying partitions
       ExprNodeDesc pruneExpr = pruneNode.accept(new ExprNodeConverter(getName(), getRowType(),
-          true, this.getRelOptSchema().getTypeFactory()));
+          HiveCalciteUtil.getInputRefs(pruneNode), this.getRelOptSchema().getTypeFactory()));
       partitionList = PartitionPruner.prune(hiveTblMetadata, pruneExpr, conf, getName(),
           partitionCache);
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/ExprNodeConverter.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/ExprNodeConverter.java
index de4b2bc..bcce74a 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/ExprNodeConverter.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/ExprNodeConverter.java
@@ -23,6 +23,7 @@
 import java.util.Calendar;
 import java.util.LinkedList;
 import java.util.List;
+import java.util.Set;
 
 import org.apache.calcite.rel.RelFieldCollation;
 import org.apache.calcite.rel.type.RelDataType;
@@ -68,34 +69,36 @@
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
 
+import com.google.common.collect.ImmutableSet;
+
 /*
  * convert a RexNode to an ExprNodeDesc
  */
 public class ExprNodeConverter extends RexVisitorImpl<ExprNodeDesc> {
 
-  String tabAlias;
-  String columnAlias;
-  RelDataType inputRowType;
-  RelDataType outputRowType;
-  boolean partitioningExpr;
-  WindowFunctionSpec wfs;
+  private final String tabAlias;
+  private final String columnAlias;
+  private final RelDataType inputRowType;
+  private final RelDataType outputRowType;
+  private final ImmutableSet<Integer> inputVCols;
+  private WindowFunctionSpec wfs;
   private final RelDataTypeFactory dTFactory;
   protected final Log LOG = LogFactory.getLog(this.getClass().getName());
 
   public ExprNodeConverter(String tabAlias, RelDataType inputRowType,
-      boolean partitioningExpr, RelDataTypeFactory dTFactory) {
-    this(tabAlias, null, inputRowType, null, partitioningExpr, dTFactory);
+      Set<Integer> vCols, RelDataTypeFactory dTFactory) {
+    this(tabAlias, null, inputRowType, null, vCols, dTFactory);
   }
 
   public ExprNodeConverter(String tabAlias, String columnAlias, RelDataType inputRowType,
-      RelDataType outputRowType, boolean partitioningExpr, RelDataTypeFactory dTFactory) {
+      RelDataType outputRowType, Set<Integer> inputVCols, RelDataTypeFactory dTFactory) {
     super(true);
     this.tabAlias = tabAlias;
    this.columnAlias = columnAlias;
     this.inputRowType = inputRowType;
     this.outputRowType = outputRowType;
-    this.partitioningExpr = partitioningExpr;
-    this.dTFactory = dTFactory;
+    this.inputVCols = ImmutableSet.copyOf(inputVCols);
+    this.dTFactory = dTFactory;
   }
 
@@ -106,7 +109,7 @@ public WindowFunctionSpec getWindowFunctionSpec() {
   public ExprNodeDesc visitInputRef(RexInputRef inputRef) {
     RelDataTypeField f = inputRowType.getFieldList().get(inputRef.getIndex());
     return new ExprNodeColumnDesc(TypeConverter.convert(f.getType()), f.getName(), tabAlias,
-        partitioningExpr);
+        inputVCols.contains(inputRef.getIndex()));
   }
 
   /**
@@ -257,7 +260,7 @@ public ExprNodeDesc visitOver(RexOver over) {
     RelDataTypeField f = outputRowType.getField(columnAlias, false, false);
     return new ExprNodeColumnDesc(TypeConverter.convert(f.getType()), columnAlias, tabAlias,
-        partitioningExpr);
+        false);
   }
 
   private PartitioningSpec getPSpec(RexWindow window) {
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveGBOpConvUtil.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveGBOpConvUtil.java
index 55f1247..d7f5fad 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveGBOpConvUtil.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveGBOpConvUtil.java
@@ -154,7 +154,7 @@ private static GBInfo getGBInfo(HiveAggregate aggRel, OpAttr inputOpAf, HiveConf
     // 1. Collect GB Keys
     RelNode aggInputRel = aggRel.getInput();
     ExprNodeConverter exprConv = new ExprNodeConverter(inputOpAf.tabAlias,
-        aggInputRel.getRowType(), false, aggRel.getCluster().getTypeFactory());
+        aggInputRel.getRowType(), new HashSet<Integer>(), aggRel.getCluster().getTypeFactory());
 
     ExprNodeDesc tmpExprNodeDesc;
     for (int i : aggRel.getGroupSet()) {
@@ -639,7 +639,7 @@ private static OpAttr genReduceGBRS(OpAttr inputOpAf, GBInfo gbInfo) throws Sema
     rsOp.setColumnExprMap(colExprMap);
 
-    return new OpAttr("", new HashMap<Integer, VirtualColumn>(), rsOp);
+    return new OpAttr("", new HashSet<Integer>(), rsOp);
   }
 
   private static OpAttr genMapSideGBRS(OpAttr inputOpAf, GBInfo gbInfo) throws SemanticException {
@@ -677,7 +677,7 @@ private static OpAttr genMapSideGBRS(OpAttr inputOpAf, GBInfo gbInfo) throws Sem
     rsOp.setColumnExprMap(colExprMap);
 
-    return new OpAttr("", new HashMap<Integer, VirtualColumn>(), rsOp);
+    return new OpAttr("", new HashSet<Integer>(), rsOp);
   }
 
   private static OpAttr genMapSideRS(OpAttr inputOpAf, GBInfo gbInfo) throws SemanticException {
@@ -732,7 +732,7 @@ private static OpAttr genMapSideRS(OpAttr inputOpAf, GBInfo gbInfo) throws Seman
     rsOp.setColumnExprMap(colExprMap);
 
-    return new OpAttr("", new HashMap<Integer, VirtualColumn>(), rsOp);
+    return new OpAttr("", new HashSet<Integer>(), rsOp);
   }
 
   private static OpAttr genReduceSideGB2(OpAttr inputOpAf, GBInfo gbInfo) throws SemanticException {
@@ -799,7 +799,7 @@ private static OpAttr genReduceSideGB2(OpAttr inputOpAf, GBInfo gbInfo) throws S
     rsGBOp2.setColumnExprMap(colExprMap);
 
     // TODO: Shouldn't we propgate vc? is it vc col from tab or all vc
-    return new OpAttr("", new HashMap<Integer, VirtualColumn>(), rsGBOp2);
+    return new OpAttr("", new HashSet<Integer>(), rsGBOp2);
   }
 
   private static OpAttr genReduceSideGB1(OpAttr inputOpAf, GBInfo gbInfo, boolean computeGrpSet,
@@ -935,7 +935,7 @@ private static OpAttr genReduceSideGB1(OpAttr inputOpAf, GBInfo gbInfo, boolean
     rsGBOp.setColumnExprMap(colExprMap);
 
-    return new OpAttr("", new HashMap<Integer, VirtualColumn>(), rsGBOp);
+    return new OpAttr("", new HashSet<Integer>(), rsGBOp);
   }
 
   /**
@@ -1035,7 +1035,7 @@ private static OpAttr genReduceSideGB1NoMapGB(OpAttr inputOpAf, GBInfo gbInfo,
         false, -1, numDistinctUDFs > 0), new RowSchema(colInfoLst), rs);
     rsGB1.setColumnExprMap(colExprMap);
 
-    return new OpAttr("", new HashMap<Integer, VirtualColumn>(), rsGB1);
+    return new OpAttr("", new HashSet<Integer>(), rsGB1);
   }
 
   @SuppressWarnings("unchecked")
@@ -1111,7 +1111,7 @@ private static OpAttr genMapSideGB(OpAttr inputOpAf, GBInfo gbAttrs) throws Sema
     // NOTE: UDAF is not included in ExprColMap
     gbOp.setColumnExprMap(colExprMap);
 
-    return new OpAttr("", new HashMap<Integer, VirtualColumn>(), gbOp);
+    return new OpAttr("", new HashSet<Integer>(), gbOp);
   }
 
   private static void addGrpSetCol(boolean createConstantExpr, String grpSetIDExprName,
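A recurring theme in the three files above: the translator stops tracking which VirtualColumn sits at a Calcite position (Map<Integer, VirtualColumn>) and only tracks which positions are virtual or partition columns (Set<Integer>), because ExprNodeColumnDesc ultimately needs just a boolean. A toy illustration of the membership test that replaces the old map lookup; the positions and class are invented for the example:

    import java.util.Set;
    import com.google.common.collect.ImmutableSet;

    public class VColMembershipDemo {
      public static void main(String[] args) {
        // Hypothetical plan: Calcite positions 0 and 3 carry virtual or
        // partition columns.
        Set<Integer> vcolsInCalcite = ImmutableSet.of(0, 3);
        int inputRefIndex = 3; // a RexInputRef index seen during translation
        // This boolean is what ExprNodeConverter.visitInputRef now feeds into
        // the ExprNodeColumnDesc constructor instead of a blanket flag.
        System.out.println(vcolsInCalcite.contains(inputRefIndex)); // true
      }
    }
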
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveOpConverter.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveOpConverter.java
index d15520c..8746f95 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveOpConverter.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveOpConverter.java
@@ -27,7 +27,6 @@
 import java.util.Map;
 import java.util.Set;
 
-import org.apache.calcite.rel.RelCollation;
 import org.apache.calcite.rel.RelCollations;
 import org.apache.calcite.rel.RelDistribution;
 import org.apache.calcite.rel.RelDistribution.Type;
@@ -38,6 +37,7 @@
 import org.apache.calcite.rex.RexInputRef;
 import org.apache.calcite.rex.RexLiteral;
 import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.util.ImmutableBitSet;
 import org.apache.calcite.util.Pair;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
@@ -95,6 +95,7 @@
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
 
 public class HiveOpConverter {
@@ -128,16 +129,16 @@ public HiveOpConverter(SemanticAnalyzer semanticAnalyzer, HiveConf hiveConf,
   static class OpAttr {
     final String tabAlias;
     ImmutableList<Operator> inputs;
-    ImmutableMap<Integer, VirtualColumn> vcolMap;
+    ImmutableSet<Integer> vcolsInCalcite;
 
-    OpAttr(String tabAlias, Map<Integer, VirtualColumn> vcolMap, Operator... inputs) {
+    OpAttr(String tabAlias, Set<Integer> vcols, Operator... inputs) {
       this.tabAlias = tabAlias;
-      this.vcolMap = ImmutableMap.copyOf(vcolMap);
       this.inputs = ImmutableList.copyOf(inputs);
+      this.vcolsInCalcite = ImmutableSet.copyOf(vcols);
     }
 
     private OpAttr clone(Operator... inputs) {
-      return new OpAttr(tabAlias, this.vcolMap, inputs);
+      return new OpAttr(tabAlias, vcolsInCalcite, inputs);
     }
   }
@@ -192,13 +193,13 @@ OpAttr visit(HiveTableScan scanRel) {
     // 1. Setup TableScan Desc
     // 1.1 Build col details used by scan
     ArrayList<ColumnInfo> colInfos = new ArrayList<ColumnInfo>();
-    List<VirtualColumn> virtualCols = new ArrayList<VirtualColumn>(ht.getVirtualCols());
-    Map<Integer, VirtualColumn> hiveScanVColMap = new HashMap<Integer, VirtualColumn>();
-    List<String> partColNames = new ArrayList<String>();
+    List<VirtualColumn> virtualCols = new ArrayList<VirtualColumn>();
     List<Integer> neededColumnIDs = new ArrayList<Integer>();
-    List<String> neededColumns = new ArrayList<String>();
+    List<String> neededColumnNames = new ArrayList<String>();
+    Set<Integer> vcolsInCalcite = new HashSet<Integer>();
 
-    Map<Integer, VirtualColumn> posToVColMap = HiveCalciteUtil.getVColsMap(virtualCols,
+    List<String> partColNames = new ArrayList<String>();
+    Map<Integer, VirtualColumn> VColsMap = HiveCalciteUtil.getVColsMap(ht.getVirtualCols(),
         ht.getNoOfNonVirtualCols());
     Map<Integer, ColumnInfo> posToPartColInfo = ht.getPartColInfoMap();
     Map<Integer, ColumnInfo> posToNonPartColInfo = ht.getNonPartColInfoMap();
@@ -214,19 +215,20 @@ OpAttr visit(HiveTableScan scanRel) {
     for (int i = 0; i < neededColIndxsFrmReloptHT.size(); i++) {
       colName = scanColNames.get(i);
       posInRHT = neededColIndxsFrmReloptHT.get(i);
-      if (posToVColMap.containsKey(posInRHT)) {
-        vc = posToVColMap.get(posInRHT);
+      if (VColsMap.containsKey(posInRHT)) {
+        vc = VColsMap.get(posInRHT);
         virtualCols.add(vc);
         colInfo = new ColumnInfo(vc.getName(), vc.getTypeInfo(), tableAlias, true,
             vc.getIsHidden());
-        hiveScanVColMap.put(i, vc);
+        vcolsInCalcite.add(posInRHT);
       } else if (posToPartColInfo.containsKey(posInRHT)) {
         partColNames.add(colName);
         colInfo = posToPartColInfo.get(posInRHT);
+        vcolsInCalcite.add(posInRHT);
       } else {
         colInfo = posToNonPartColInfo.get(posInRHT);
       }
       neededColumnIDs.add(posInRHT);
-      neededColumns.add(colName);
+      neededColumnNames.add(colName);
       colInfos.add(colInfo);
     }
@@ -238,7 +240,7 @@ OpAttr visit(HiveTableScan scanRel) {
 
     // 1.4. Set needed cols in TSDesc
     tsd.setNeededColumnIDs(neededColumnIDs);
-    tsd.setNeededColumns(neededColumns);
+    tsd.setNeededColumns(neededColumnNames);
 
     // 2. Setup TableScan
     TableScanOperator ts = (TableScanOperator) OperatorFactory.get(tsd, new RowSchema(colInfos));
@@ -249,7 +251,7 @@ OpAttr visit(HiveTableScan scanRel) {
       LOG.debug("Generated " + ts + " with row schema: [" + ts.getSchema() + "]");
     }
 
-    return new OpAttr(tableAlias, hiveScanVColMap, ts);
+    return new OpAttr(tableAlias, vcolsInCalcite, ts);
   }
 
   OpAttr visit(HiveProject projectRel) throws SemanticException {
@@ -267,10 +269,11 @@ OpAttr visit(HiveProject projectRel) throws SemanticException {
     for (int pos = 0; pos < projectRel.getChildExps().size(); pos++) {
       ExprNodeConverter converter = new ExprNodeConverter(inputOpAf.tabAlias, projectRel
           .getRowType().getFieldNames().get(pos), projectRel.getInput().getRowType(),
-          projectRel.getRowType(), false, projectRel.getCluster().getTypeFactory());
+          projectRel.getRowType(), inputOpAf.vcolsInCalcite, projectRel.getCluster().getTypeFactory());
       ExprNodeDesc exprCol = projectRel.getChildExps().get(pos).accept(converter);
       colExprMap.put(exprNames.get(pos), exprCol);
       exprCols.add(exprCol);
+      //TODO: Cols that come through PTF should it retain (VirtualColumness)?
       if (converter.getWindowFunctionSpec() != null) {
         windowingSpec.addWindowFunction(converter.getWindowFunctionSpec());
       }
@@ -281,7 +284,7 @@ OpAttr visit(HiveProject projectRel) throws SemanticException {
     }
     // TODO: is this a safe assumption (name collision, external names...)
     SelectDesc sd = new SelectDesc(exprCols, exprNames);
-    Pair<ArrayList<ColumnInfo>, Map<Integer, VirtualColumn>> colInfoVColPair = createColInfos(
+    Pair<ArrayList<ColumnInfo>, Set<Integer>> colInfoVColPair = createColInfos(
         projectRel.getChildExps(), exprCols, exprNames, inputOpAf);
     SelectOperator selOp = (SelectOperator) OperatorFactory.getAndMakeChild(sd, new RowSchema(
         colInfoVColPair.getKey()), inputOpAf.inputs.get(0));
@@ -312,7 +315,7 @@ OpAttr visit(HiveJoin joinRel) throws SemanticException {
     JoinPredicateInfo joinPredInfo = JoinPredicateInfo.constructJoinPredicateInfo(joinRel);
 
     // 3. Extract join keys from condition
-    ExprNodeDesc[][] joinKeys = extractJoinKeys(joinPredInfo, joinRel.getInputs());
+    ExprNodeDesc[][] joinKeys = extractJoinKeys(joinPredInfo, joinRel.getInputs(), inputs);
 
     // 4.a Generate tags
     for (int tag=0; tag<inputs.length; tag++) {
-    Map<Integer, VirtualColumn> vcolMap = new HashMap<Integer, VirtualColumn>();
-    vcolMap.putAll(inputs[0].vcolMap);
+    Set<Integer> newVcolsInCalcite = new HashSet<Integer>();
+    newVcolsInCalcite.addAll(inputs[0].vcolsInCalcite);
     if (extractJoinType(joinRel) != JoinType.LEFTSEMI) {
       int shift = inputs[0].inputs.get(0).getSchema().getSignature().size();
       for (int i = 1; i < inputs.length; i++) {
-        vcolMap.putAll(HiveCalciteUtil.shiftVColsMap(inputs[i].vcolMap, shift));
+        newVcolsInCalcite.addAll(HiveCalciteUtil.shiftVColsSet(inputs[i].vcolsInCalcite, shift));
         shift += inputs[i].inputs.get(0).getSchema().getSignature().size();
       }
     }
 
     // 7. Return result
-    return new OpAttr(null, vcolMap, joinOp);
+    return new OpAttr(null, newVcolsInCalcite, joinOp);
   }
 
   OpAttr visit(HiveAggregate aggRel) throws SemanticException {
@@ -365,6 +368,7 @@ OpAttr visit(HiveSort sortRel) throws SemanticException {
     Operator inputOp = inputOpAf.inputs.get(0);
     Operator resultOp = inputOpAf.inputs.get(0);
+
     // 1. If we need to sort tuples based on the value of some
     // of their columns
     if (sortRel.getCollation() != RelCollations.EMPTY) {
@@ -377,30 +381,51 @@ OpAttr visit(HiveSort sortRel) throws SemanticException {
 
       // 1.a. Extract order for each column from collation
       // Generate sortCols and order
+      ImmutableBitSet.Builder sortColsPosBuilder = new ImmutableBitSet.Builder();
+      ImmutableBitSet.Builder sortOutputColsPosBuilder = new ImmutableBitSet.Builder();
+      Map<Integer, RexNode> obRefToCallMap = sortRel.getInputRefToCallMap();
       List<ExprNodeDesc> sortCols = new ArrayList<ExprNodeDesc>();
       StringBuilder order = new StringBuilder();
-      for (RelCollation collation : sortRel.getCollationList()) {
-        for (RelFieldCollation sortInfo : collation.getFieldCollations()) {
-          int sortColumnPos = sortInfo.getFieldIndex();
-          ColumnInfo columnInfo = new ColumnInfo(inputOp.getSchema().getSignature()
-              .get(sortColumnPos));
-          ExprNodeColumnDesc sortColumn = new ExprNodeColumnDesc(columnInfo.getType(),
-              columnInfo.getInternalName(), columnInfo.getTabAlias(), columnInfo.getIsVirtualCol());
-          sortCols.add(sortColumn);
-          if (sortInfo.getDirection() == RelFieldCollation.Direction.DESCENDING) {
-            order.append("-");
-          } else {
-            order.append("+");
+      for (RelFieldCollation sortInfo : sortRel.getCollation().getFieldCollations()) {
+        int sortColumnPos = sortInfo.getFieldIndex();
+        ColumnInfo columnInfo = new ColumnInfo(inputOp.getSchema().getSignature()
+            .get(sortColumnPos));
+        ExprNodeColumnDesc sortColumn = new ExprNodeColumnDesc(columnInfo.getType(),
+            columnInfo.getInternalName(), columnInfo.getTabAlias(), columnInfo.getIsVirtualCol());
+        sortCols.add(sortColumn);
+        if (sortInfo.getDirection() == RelFieldCollation.Direction.DESCENDING) {
+          order.append("-");
+        } else {
+          order.append("+");
+        }
+
+        if (obRefToCallMap != null) {
+          RexNode obExpr = obRefToCallMap.get(sortColumnPos);
+          sortColsPosBuilder.set(sortColumnPos);
+          if (obExpr == null) {
+            sortOutputColsPosBuilder.set(sortColumnPos);
           }
         }
       }
       // Use only 1 reducer for order by
       int numReducers = 1;
+
+      // We keep the columns only the columns that are part of the final output
+      List<String> keepColumns = new ArrayList<String>();
+      final ImmutableBitSet sortColsPos = sortColsPosBuilder.build();
+      final ImmutableBitSet sortOutputColsPos = sortOutputColsPosBuilder.build();
+      final ArrayList<ColumnInfo> inputSchema = inputOp.getSchema().getSignature();
+      for (int pos=0; pos<inputSchema.size(); pos++) {
+        if ((sortColsPos.get(pos) && sortOutputColsPos.get(pos)) ||
+            (!sortColsPos.get(pos))) {
+          keepColumns.add(inputSchema.get(pos).getInternalName());
+        }
+      }
+
       // 1.b. Generate reduce sink and project operator
       resultOp = genReduceSinkAndBacktrackSelect(resultOp,
           sortCols.toArray(new ExprNodeDesc[sortCols.size()]), 0, new ArrayList<ExprNodeDesc>(),
-          order.toString(), numReducers, Operation.NOT_ACID, strictMode);
+          order.toString(), numReducers, Operation.NOT_ACID, strictMode, keepColumns);
     }
 
     // 2. If we need to generate limit
@@ -433,7 +458,7 @@ OpAttr visit(HiveFilter filterRel) throws SemanticException {
     }
 
     ExprNodeDesc filCondExpr = filterRel.getCondition().accept(
-        new ExprNodeConverter(inputOpAf.tabAlias, filterRel.getInput().getRowType(), false,
+        new ExprNodeConverter(inputOpAf.tabAlias, filterRel.getInput().getRowType(), inputOpAf.vcolsInCalcite,
            filterRel.getCluster().getTypeFactory()));
     FilterDesc filDesc = new FilterDesc(filCondExpr, false);
     ArrayList<ColumnInfo> cinfoLst = createColInfos(inputOpAf.inputs.get(0));
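The visit(HiveSort) hunk above computes keepColumns: sort columns that exist only to feed ORDER BY (they appear in getInputRefToCallMap) are dropped by the backtracking Select after the ReduceSink. A runnable sketch of that keep/drop predicate in isolation; the schema, positions, and class name are hypothetical (HiveOpConverter.java continues below):

    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;

    import org.apache.calcite.util.ImmutableBitSet;

    public class SortColumnPruneDemo {
      public static void main(String[] args) {
        // Hypothetical 4-column input. Columns 1 and 2 are ORDER BY keys; only
        // column 1 also appears in the final output, so column 2 is an
        // order-by-only expression and is pruned.
        List<String> schema = Arrays.asList("_col0", "_col1", "_col2", "_col3");
        ImmutableBitSet sortColsPos = ImmutableBitSet.of(1, 2);
        ImmutableBitSet sortOutputColsPos = ImmutableBitSet.of(1);

        List<String> keepColumns = new ArrayList<String>();
        for (int pos = 0; pos < schema.size(); pos++) {
          // Equivalent to the patch's ((sort && output) || !sort) condition.
          if (!sortColsPos.get(pos) || sortOutputColsPos.get(pos)) {
            keepColumns.add(schema.get(pos));
          }
        }
        System.out.println(keepColumns); // [_col0, _col1, _col3]
      }
    }
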
@@ -474,6 +499,7 @@ OpAttr visit(HiveUnion unionRel) throws SemanticException {
       LOG.debug("Generated " + unionOp + " with row schema: [" + unionOp.getSchema() + "]");
     }
 
+    //TODO: Can columns retain virtualness out of union
     // 3. Return result
     return inputs[0].clone(unionOp);
   }
@@ -570,32 +596,41 @@ private OpAttr genPTF(OpAttr inputOpAf, WindowingSpec wSpec) throws SemanticExce
     return inputOpAf.clone(input);
   }
 
-  private ExprNodeDesc[][] extractJoinKeys(JoinPredicateInfo joinPredInfo, List<RelNode> inputs) {
+  private ExprNodeDesc[][] extractJoinKeys(JoinPredicateInfo joinPredInfo, List<RelNode> inputs, OpAttr[] inputAttr) {
     ExprNodeDesc[][] joinKeys = new ExprNodeDesc[inputs.size()][];
     for (int i = 0; i < inputs.size(); i++) {
       joinKeys[i] = new ExprNodeDesc[joinPredInfo.getEquiJoinPredicateElements().size()];
       for (int j = 0; j < joinPredInfo.getEquiJoinPredicateElements().size(); j++) {
         JoinLeafPredicateInfo joinLeafPredInfo = joinPredInfo.getEquiJoinPredicateElements().get(j);
         RexNode key = joinLeafPredInfo.getJoinKeyExprs(j).get(0);
-        joinKeys[i][j] = convertToExprNode(key, inputs.get(j), null);
+        joinKeys[i][j] = convertToExprNode(key, inputs.get(j), null, inputAttr[i]);
       }
     }
 
     return joinKeys;
   }
 
   private static SelectOperator genReduceSinkAndBacktrackSelect(Operator<?> input,
+      ExprNodeDesc[] keys, int tag, ArrayList<ExprNodeDesc> partitionCols, String order,
+      int numReducers, Operation acidOperation, boolean strictMode) throws SemanticException {
+    return genReduceSinkAndBacktrackSelect(input, keys, tag, partitionCols, order,
+        numReducers, acidOperation, strictMode, input.getSchema().getColumnNames());
+  }
+
+  private static SelectOperator genReduceSinkAndBacktrackSelect(Operator<?> input,
       ExprNodeDesc[] keys, int tag, ArrayList<ExprNodeDesc> partitionCols, String order,
-      int numReducers, Operation acidOperation, boolean strictMode) throws SemanticException {
+      int numReducers, Operation acidOperation, boolean strictMode,
+      List<String> keepColNames) throws SemanticException {
     // 1. Generate RS operator
     ReduceSinkOperator rsOp = genReduceSink(input, keys, tag, partitionCols, order, numReducers,
         acidOperation, strictMode);
 
     // 2. Generate backtrack Select operator
-    Map<String, ExprNodeDesc> descriptors = buildBacktrackFromReduceSink(rsOp,
-        input);
+    Map<String, ExprNodeDesc> descriptors = buildBacktrackFromReduceSink(keepColNames,
+        rsOp.getConf().getOutputKeyColumnNames(), rsOp.getConf().getOutputValueColumnNames(),
+        rsOp.getValueIndex(), input);
     SelectDesc selectDesc = new SelectDesc(new ArrayList<ExprNodeDesc>(descriptors.values()),
         new ArrayList<String>(descriptors.keySet()));
-    ArrayList<ColumnInfo> cinfoLst = createColInfos(input);
+    ArrayList<ColumnInfo> cinfoLst = createColInfosSubset(input, keepColNames);
     SelectOperator selectOp = (SelectOperator) OperatorFactory.getAndMakeChild(selectDesc,
         new RowSchema(cinfoLst), rsOp);
     selectOp.setColumnExprMap(descriptors);
@@ -763,7 +798,7 @@ private static JoinOperator genJoin(HiveJoin hiveJoin, JoinPredicateInfo joinPre
       posToAliasMap.put(pos, new HashSet<String>(inputRS.getSchema().getTableNames()));
 
-      Map<String, ExprNodeDesc> descriptors = buildBacktrackFromReduceSink(outputPos,
+      Map<String, ExprNodeDesc> descriptors = buildBacktrackFromReduceSinkForJoin(outputPos,
          outputColumnNames, keyColNames, valColNames, index, parent);
 
       List<ColumnInfo> parentColumns = parent.getSchema().getSignature();
@@ -827,14 +862,7 @@ private static JoinType extractJoinType(HiveJoin join) {
     return resultJoinType;
   }
 
-  private static Map<String, ExprNodeDesc> buildBacktrackFromReduceSink(ReduceSinkOperator rsOp,
-      Operator<?> inputOp) {
-    return buildBacktrackFromReduceSink(0, inputOp.getSchema().getColumnNames(), rsOp.getConf()
-        .getOutputKeyColumnNames(), rsOp.getConf().getOutputValueColumnNames(),
-        rsOp.getValueIndex(), inputOp);
-  }
-
-  private static Map<String, ExprNodeDesc> buildBacktrackFromReduceSink(int initialPos,
+  private static Map<String, ExprNodeDesc> buildBacktrackFromReduceSinkForJoin(int initialPos,
       List<String> outputColumnNames, List<String> keyColNames, List<String> valueColNames,
       int[] index, Operator<?> inputOp) {
     Map<String, ExprNodeDesc> columnDescriptors = new LinkedHashMap<String, ExprNodeDesc>();
@@ -853,8 +881,31 @@ private static JoinType extractJoinType(HiveJoin join) {
     return columnDescriptors;
   }
 
-  private static ExprNodeDesc convertToExprNode(RexNode rn, RelNode inputRel, String tabAlias) {
-    return rn.accept(new ExprNodeConverter(tabAlias, inputRel.getRowType(), false,
+  private static Map<String, ExprNodeDesc> buildBacktrackFromReduceSink(List<String> keepColNames,
+      List<String> keyColNames, List<String> valueColNames, int[] index, Operator<?> inputOp) {
+    Map<String, ExprNodeDesc> columnDescriptors = new LinkedHashMap<String, ExprNodeDesc>();
+    int pos = 0;
+    for (int i = 0; i < index.length; i++) {
+      ColumnInfo info = inputOp.getSchema().getSignature().get(i);
+      if (pos < keepColNames.size() &&
+          info.getInternalName().equals(keepColNames.get(pos))) {
+        String field;
+        if (index[i] >= 0) {
+          field = Utilities.ReduceField.KEY + "." + keyColNames.get(index[i]);
+        } else {
+          field = Utilities.ReduceField.VALUE + "."
+              + valueColNames.get(-index[i] - 1);
+        }
+        ExprNodeColumnDesc desc = new ExprNodeColumnDesc(info.getType(), field, info.getTabAlias(),
+            info.getIsVirtualCol());
+        columnDescriptors.put(keepColNames.get(pos), desc);
+        pos++;
+      }
+    }
+    return columnDescriptors;
+  }
+
+  private static ExprNodeDesc convertToExprNode(RexNode rn, RelNode inputRel, String tabAlias, OpAttr inputAttr) {
+    return rn.accept(new ExprNodeConverter(tabAlias, inputRel.getRowType(), inputAttr.vcolsInCalcite,
         inputRel.getCluster().getTypeFactory()));
   }
@@ -866,7 +917,21 @@ private static ExprNodeDesc convertToExprNode(RexNode rn, RelNode inputRel, Stri
     return cInfoLst;
   }
 
-  private static Pair<ArrayList<ColumnInfo>, Map<Integer, VirtualColumn>> createColInfos(
+  private static ArrayList<ColumnInfo> createColInfosSubset(Operator<?> input,
+      List<String> keepColNames) {
+    ArrayList<ColumnInfo> cInfoLst = new ArrayList<ColumnInfo>();
+    int pos = 0;
+    for (ColumnInfo ci : input.getSchema().getSignature()) {
+      if (pos < keepColNames.size() &&
+          ci.getInternalName().equals(keepColNames.get(pos))) {
+        cInfoLst.add(new ColumnInfo(ci));
+        pos++;
+      }
+    }
+    return cInfoLst;
+  }
+
+  private static Pair<ArrayList<ColumnInfo>, Set<Integer>> createColInfos(
       List<RexNode> calciteExprs, List<ExprNodeDesc> hiveExprs, List<String> projNames,
       OpAttr inpOpAf) {
     if (hiveExprs.size() != projNames.size()) {
@@ -876,22 +941,22 @@ private static ExprNodeDesc convertToExprNode(RexNode rn, RelNode inputRel, Stri
     RexNode rexN;
     ExprNodeDesc pe;
     ArrayList<ColumnInfo> colInfos = new ArrayList<ColumnInfo>();
-    VirtualColumn vc;
-    Map<Integer, VirtualColumn> newVColMap = new HashMap<Integer, VirtualColumn>();
+    boolean vc;
+    Set<Integer> newVColSet = new HashSet<Integer>();
     for (int i = 0; i < hiveExprs.size(); i++) {
       pe = hiveExprs.get(i);
       rexN = calciteExprs.get(i);
-      vc = null;
+      vc = false;
       if (rexN instanceof RexInputRef) {
-        vc = inpOpAf.vcolMap.get(((RexInputRef) rexN).getIndex());
-        if (vc != null) {
-          newVColMap.put(i, vc);
+        if (inpOpAf.vcolsInCalcite.contains(((RexInputRef) rexN).getIndex())) {
+          newVColSet.add(i);
+          vc = true;
        }
       }
       colInfos
-          .add(new ColumnInfo(projNames.get(i), pe.getTypeInfo(), inpOpAf.tabAlias, vc != null));
+          .add(new ColumnInfo(projNames.get(i), pe.getTypeInfo(), inpOpAf.tabAlias, vc));
     }
 
-    return new Pair<ArrayList<ColumnInfo>, Map<Integer, VirtualColumn>>(colInfos, newVColMap);
+    return new Pair<ArrayList<ColumnInfo>, Set<Integer>>(colInfos, newVColSet);
   }
 }
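The new buildBacktrackFromReduceSink above rewrites each kept column so the Select reads it back out of the ReduceSink output: a non-negative index[i] means the column travels in the shuffle key, a negative one encodes position (-index[i] - 1) in the value. A small sketch of just that naming rule; the index array and column names are invented sample data:

    import org.apache.hadoop.hive.ql.exec.Utilities;

    public class ReduceFieldNamingDemo {
      public static void main(String[] args) {
        int[] index = { 0, -1, -2 };
        String[] keyColNames = { "reducesinkkey0" };
        String[] valueColNames = { "_col0", "_col1" };
        for (int i = 0; i < index.length; i++) {
          // Mirrors the KEY./VALUE. prefixing done in the patch.
          String field = index[i] >= 0
              ? Utilities.ReduceField.KEY + "." + keyColNames[index[i]]
              : Utilities.ReduceField.VALUE + "." + valueColNames[-index[i] - 1];
          System.out.println(field);
        }
        // Prints: KEY.reducesinkkey0, VALUE._col0, VALUE._col1
      }
    }
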
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/PlanModifierForASTConv.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/PlanModifierForASTConv.java
index cba37bc..d7d8e75 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/PlanModifierForASTConv.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/PlanModifierForASTConv.java
@@ -18,14 +18,11 @@
 package org.apache.hadoop.hive.ql.optimizer.calcite.translator;
 
 import java.util.ArrayList;
-import java.util.HashSet;
 import java.util.List;
-import java.util.Set;
 
 import org.apache.calcite.plan.RelOptUtil;
 import org.apache.calcite.plan.hep.HepRelVertex;
 import org.apache.calcite.plan.volcano.RelSubset;
-import org.apache.calcite.rel.RelCollationImpl;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.SingleRel;
 import org.apache.calcite.rel.core.Aggregate;
@@ -38,19 +35,14 @@
 import org.apache.calcite.rel.rules.MultiJoin;
 import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.rel.type.RelDataTypeFactory;
-import org.apache.calcite.rex.RexBuilder;
-import org.apache.calcite.rex.RexCall;
-import org.apache.calcite.rex.RexInputRef;
 import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.sql.SqlAggFunction;
-import org.apache.calcite.sql.SqlKind;
 import org.apache.calcite.util.Pair;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 
 import org.apache.hadoop.hive.metastore.api.FieldSchema;
 import org.apache.hadoop.hive.ql.optimizer.calcite.CalciteSemanticException;
 import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil;
-import org.apache.hadoop.hive.ql.optimizer.calcite.RelOptHiveTable;
 import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate;
 import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject;
 import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveSort;
@@ -58,11 +50,12 @@
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
 
 import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
 
 public class PlanModifierForASTConv {
 
+  private static final Log LOG = LogFactory.getLog(PlanModifierForASTConv.class);
+
   public static RelNode convertOpTree(RelNode rel, List<FieldSchema> resultSchema)
       throws CalciteSemanticException {
     RelNode newTopNode = rel;
@@ -71,7 +64,7 @@ public static RelNode convertOpTree(RelNode rel, List<FieldSchema> resultSchema)
     }
 
     if (!(newTopNode instanceof Project) && !(newTopNode instanceof Sort)) {
-      newTopNode = introduceDerivedTable(newTopNode);
+      newTopNode = PlanModifierUtil.introduceDerivedTable(newTopNode);
       if (LOG.isDebugEnabled()) {
         LOG.debug("Plan after top-level introduceDerivedTable\n "
             + RelOptUtil.toString(newTopNode));
@@ -84,7 +77,7 @@ public static RelNode convertOpTree(RelNode rel, List<FieldSchema> resultSchema)
     }
 
     Pair<RelNode, RelNode> topSelparentPair = HiveCalciteUtil.getTopLevelSelect(newTopNode);
-    fixTopOBSchema(newTopNode, topSelparentPair, resultSchema);
+    PlanModifierUtil.fixTopOBSchema(newTopNode, topSelparentPair, resultSchema, true);
     if (LOG.isDebugEnabled()) {
       LOG.debug("Plan after fixTopOBSchema\n " + RelOptUtil.toString(newTopNode));
     }
@@ -176,79 +169,6 @@ private static void convertOpTree(RelNode rel, RelNode parent) {
     }
   }
 
-  private static void fixTopOBSchema(final RelNode rootRel,
-      Pair<RelNode, RelNode> topSelparentPair, List<FieldSchema> resultSchema)
-      throws CalciteSemanticException {
-    if (!(topSelparentPair.getKey() instanceof Sort)
-        || !HiveCalciteUtil.orderRelNode(topSelparentPair.getKey())) {
-      return;
-    }
-    HiveSort obRel = (HiveSort) topSelparentPair.getKey();
-    Project obChild = (Project) topSelparentPair.getValue();
-    if (obChild.getRowType().getFieldCount() <= resultSchema.size()) {
-      return;
-    }
-
-    RelDataType rt = obChild.getRowType();
-    @SuppressWarnings({ "unchecked", "rawtypes" })
-    Set<Integer> collationInputRefs = new HashSet(
-        RelCollationImpl.ordinals(obRel.getCollation()));
-    ImmutableMap.Builder<Integer, RexNode> inputRefToCallMapBldr = ImmutableMap.builder();
-    for (int i = resultSchema.size(); i < rt.getFieldCount(); i++) {
-      if (collationInputRefs.contains(i)) {
-        RexNode obyExpr = obChild.getChildExps().get(i);
-        if (obyExpr instanceof RexCall) {
-          int a = -1;
-          List<RexNode> operands = new ArrayList<>();
-          for (int k = 0; k < ((RexCall) obyExpr).operands.size(); k++) {
-            RexNode rn = ((RexCall) obyExpr).operands.get(k);
-            for (int j = 0; j < resultSchema.size(); j++) {
-              if (obChild.getChildExps().get(j).toString().equals(rn.toString())) {
-                a = j;
-                break;
-              }
-            }
-            if (a != -1) {
-              operands.add(new RexInputRef(a, rn.getType()));
-            } else {
-              operands.add(rn);
-            }
-            a = -1;
-          }
-          obyExpr = obChild.getCluster().getRexBuilder().makeCall(((RexCall) obyExpr).getOperator(), operands);
-        }
-        inputRefToCallMapBldr.put(i, obyExpr);
-      }
-    }
-    ImmutableMap<Integer, RexNode> inputRefToCallMap = inputRefToCallMapBldr.build();
-
-    if ((obChild.getRowType().getFieldCount() - inputRefToCallMap.size()) != resultSchema.size()) {
-      LOG.error(generateInvalidSchemaMessage(obChild, resultSchema, inputRefToCallMap.size()));
-      throw new CalciteSemanticException("Result Schema didn't match Optimized Op Tree Schema");
-    }
-    // This removes order-by only expressions from the projections.
-    HiveProject replacementProjectRel = HiveProject.create(obChild.getInput(), obChild
-        .getChildExps().subList(0, resultSchema.size()), obChild.getRowType().getFieldNames()
-        .subList(0, resultSchema.size()));
-    obRel.replaceInput(0, replacementProjectRel);
-    obRel.setInputRefToCallMap(inputRefToCallMap);
-  }
-
-  private static String generateInvalidSchemaMessage(Project topLevelProj,
-      List<FieldSchema> resultSchema, int fieldsForOB) {
-    String errorDesc = "Result Schema didn't match Calcite Optimized Op Tree; schema: ";
-    for (FieldSchema fs : resultSchema) {
-      errorDesc += "[" + fs.getName() + ":" + fs.getType() + "], ";
-    }
-    errorDesc += " projection fields: ";
-    for (RexNode exp : topLevelProj.getChildExps()) {
-      errorDesc += "[" + exp.toString() + ":" + exp.getType() + "], ";
-    }
-    if (fieldsForOB != 0) {
-      errorDesc += fieldsForOB + " fields removed due to ORDER BY ";
-    }
-    return errorDesc.substring(0, errorDesc.length() - 2);
-  }
-
   private static RelNode renameTopLevelSelectInResultSchema(final RelNode rootRel,
       Pair<RelNode, RelNode> topSelparentPair, List<FieldSchema> resultSchema)
       throws CalciteSemanticException {
@@ -260,7 +180,7 @@ private static RelNode renameTopLevelSelectInResultSchema(final RelNode rootRel,
     List<RexNode> rootChildExps = originalProjRel.getChildExps();
     if (resultSchema.size() != rootChildExps.size()) {
       // Safeguard against potential issues in CBO RowResolver construction. Disable CBO for now.
-      LOG.error(generateInvalidSchemaMessage(originalProjRel, resultSchema, 0));
+      LOG.error(PlanModifierUtil.generateInvalidSchemaMessage(originalProjRel, resultSchema, 0));
       throw new CalciteSemanticException("Result Schema didn't match Optimized Op Tree Schema");
     }
@@ -285,15 +205,6 @@ private static RelNode renameTopLevelSelectInResultSchema(final RelNode rootRel,
     }
   }
 
-  private static RelNode introduceDerivedTable(final RelNode rel) {
-    List<RexNode> projectList = HiveCalciteUtil.getProjsFromBelowAsInputRef(rel);
-
-    HiveProject select = HiveProject.create(rel.getCluster(), rel, projectList,
-        rel.getRowType(), rel.getCollationList());
-
-    return select;
-  }
-
   private static RelNode introduceDerivedTable(final RelNode rel, RelNode parent) {
     int i = 0;
     int pos = -1;
@@ -311,7 +222,7 @@ private static RelNode introduceDerivedTable(final RelNode rel, RelNode parent)
       throw new RuntimeException("Couldn't find child node in parent's inputs");
     }
 
-    RelNode select = introduceDerivedTable(rel);
+    RelNode select = PlanModifierUtil.introduceDerivedTable(rel);
 
     parent.replaceInput(pos, select);
 
@@ -434,7 +345,7 @@ private static void replaceEmptyGroupAggr(final RelNode rel, RelNode parent) {
     Aggregate newAggRel = oldAggRel.copy(oldAggRel.getTraitSet(), oldAggRel.getInput(),
         oldAggRel.indicator, oldAggRel.getGroupSet(), oldAggRel.getGroupSets(),
         ImmutableList.of(dummyCall));
-    RelNode select = introduceDerivedTable(newAggRel);
+    RelNode select = PlanModifierUtil.introduceDerivedTable(newAggRel);
     parent.replaceInput(0, select);
   }
 }
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java b/ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java
index 6a15bf6..7fa8a77 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java
@@ -58,10 +58,8 @@
 import org.apache.calcite.rel.core.Filter;
 import org.apache.calcite.rel.core.Join;
 import org.apache.calcite.rel.core.JoinRelType;
-import org.apache.calcite.rel.core.Project;
 import org.apache.calcite.rel.core.RelFactories;
 import org.apache.calcite.rel.core.SemiJoin;
-import org.apache.calcite.rel.core.Sort;
 import org.apache.calcite.rel.metadata.CachingRelMetadataProvider;
 import org.apache.calcite.rel.metadata.ChainedRelMetadataProvider;
 import org.apache.calcite.rel.metadata.RelMetadataProvider;
@@ -147,6 +145,7 @@
 import org.apache.hadoop.hive.ql.optimizer.calcite.translator.HiveOpConverter;
 import org.apache.hadoop.hive.ql.optimizer.calcite.translator.JoinCondTypeCheckProcFactory;
 import org.apache.hadoop.hive.ql.optimizer.calcite.translator.JoinTypeCheckCtx;
+import org.apache.hadoop.hive.ql.optimizer.calcite.translator.PlanModifierForReturnPath;
 import org.apache.hadoop.hive.ql.optimizer.calcite.translator.RexNodeConverter;
 import org.apache.hadoop.hive.ql.optimizer.calcite.translator.SqlFunctionConverter;
 import org.apache.hadoop.hive.ql.optimizer.calcite.translator.TypeConverter;
@@ -629,7 +628,8 @@ Operator getOptimizedHiveOPDag() throws SemanticException {
       throw new AssertionError("rethrowCalciteException didn't throw for " + e.getMessage());
     }
 
-    RelNode modifiedOptimizedOptiqPlan = introduceProjectIfNeeded(optimizedOptiqPlan);
+    RelNode modifiedOptimizedOptiqPlan =
+        PlanModifierForReturnPath.convertOpTree(optimizedOptiqPlan, topLevelFieldSchema);
 
     LOG.debug("Translating the following plan:\n" + RelOptUtil.toString(modifiedOptimizedOptiqPlan));
     Operator hiveRoot = new HiveOpConverter(this, conf, unparseTranslator, topOps,
@@ -639,30 +639,6 @@ Operator getOptimizedHiveOPDag() throws SemanticException {
     return genFileSinkPlan(getQB().getParseInfo().getClauseNames().iterator().next(), getQB(),
         hiveRoot);
   }
 
-  private RelNode introduceProjectIfNeeded(RelNode optimizedOptiqPlan)
-      throws CalciteSemanticException {
-    RelNode parent = null;
-    RelNode input = optimizedOptiqPlan;
-    RelNode newRoot = optimizedOptiqPlan;
-
-    while (!(input instanceof Project) && (input instanceof Sort)) {
-      parent = input;
-      input = input.getInput(0);
-    }
-
-    if (!(input instanceof Project)) {
-      HiveProject hpRel = HiveProject.create(input,
-          HiveCalciteUtil.getProjsFromBelowAsInputRef(input), input.getRowType().getFieldNames());
-      if (input == optimizedOptiqPlan) {
-        newRoot = hpRel;
-      } else {
-        parent.replaceInput(0, hpRel);
-      }
-    }
-
-    return newRoot;
-  }
-
   /***
    * Unwraps Calcite Invocation exceptions coming meta data provider chain and
    * obtains the real cause.
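Closing note: the helpers removed from PlanModifierForASTConv and CalcitePlanner above (fixTopOBSchema, generateInvalidSchemaMessage, introduceDerivedTable, introduceProjectIfNeeded) move behind PlanModifierUtil and PlanModifierForReturnPath, two classes added elsewhere in the patch and not shown in this excerpt. At minimum, PlanModifierForReturnPath must preserve the invariant the deleted introduceProjectIfNeeded provided: a Project at the top of the plan (possibly under Sorts) before HiveOpConverter runs. A rough, purely illustrative reimplementation of that removed method, under the assumption that this invariant is what the new class enforces (the real class may do more, e.g. result-schema fix-up):

    import java.util.List;
    import org.apache.calcite.rel.RelNode;
    import org.apache.calcite.rel.core.Project;
    import org.apache.calcite.rel.core.Sort;
    import org.apache.calcite.rex.RexNode;
    import org.apache.hadoop.hive.ql.optimizer.calcite.CalciteSemanticException;
    import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil;
    import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject;

    public class TopLevelProjectSketch {
      // Mirrors the removed CalcitePlanner.introduceProjectIfNeeded: skip over
      // top-level Sorts and, if no Project is found, insert an identity HiveProject.
      static RelNode introduceProjectIfNeeded(RelNode root) throws CalciteSemanticException {
        RelNode parent = null;
        RelNode input = root;
        while (!(input instanceof Project) && (input instanceof Sort)) {
          parent = input;
          input = input.getInput(0);
        }
        if (!(input instanceof Project)) {
          List<RexNode> projs = HiveCalciteUtil.getProjsFromBelowAsInputRef(input);
          HiveProject identity = HiveProject.create(input, projs,
              input.getRowType().getFieldNames());
          if (parent == null) {
            return identity; // the plan root itself was not a Project
          }
          parent.replaceInput(0, identity);
        }
        return root;
      }
    }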