diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveCalciteUtil.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveCalciteUtil.java index 6147791..7614463 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveCalciteUtil.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveCalciteUtil.java @@ -589,13 +589,12 @@ public Void visitCall(org.apache.calcite.rex.RexCall call) { return bldr.build(); } - - public static ImmutableMap shiftVColsMap(Map hiveVCols, - int shift) { - Builder bldr = ImmutableMap. builder(); - for (Integer pos : hiveVCols.keySet()) { - bldr.put(shift + pos, hiveVCols.get(pos)); + public static ImmutableSet shiftVColsSet(Set hiveVCols, int shift) { + ImmutableSet.Builder bldr = ImmutableSet. builder(); + + for (Integer pos : hiveVCols) { + bldr.add(shift + pos); } return bldr.build(); @@ -661,13 +660,14 @@ public static ExprNodeDesc getExprNode(Integer inputRefIndx, RelNode inputRel, List exprNodes = new ArrayList(); List rexInputRefs = getInputRef(inputRefs, inputRel); // TODO: Change ExprNodeConverter to be independent of Partition Expr - ExprNodeConverter exprConv = new ExprNodeConverter(inputTabAlias, inputRel.getRowType(), false, inputRel.getCluster().getTypeFactory()); + ExprNodeConverter exprConv = new ExprNodeConverter(inputTabAlias, inputRel.getRowType(), + new HashSet(), inputRel.getCluster().getTypeFactory()); for (RexNode iRef : rexInputRefs) { exprNodes.add(iRef.accept(exprConv)); } return exprNodes; } - + public static List getFieldNames(List inputRefs, RelNode inputRel) { List fieldNames = new ArrayList(); List schemaNames = inputRel.getRowType().getFieldNames(); @@ -732,4 +732,27 @@ public Boolean visitFieldAccess(RexFieldAccess fieldAccess) { return fieldAccess.getReferenceExpr().accept(this); } } + + public static Set getInputRefs(RexNode expr) { + InputRefsCollector irefColl = new InputRefsCollector(true); + return irefColl.getInputRefSet(); + } + + private static class InputRefsCollector extends RexVisitorImpl { + + private Set inputRefSet = new HashSet(); + + private InputRefsCollector(boolean deep) { + super(deep); + } + + public Void visitInputRef(RexInputRef inputRef) { + inputRefSet.add(inputRef.getIndex()); + return null; + } + + public Set getInputRefSet() { + return inputRefSet; + } + } } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/RelOptHiveTable.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/RelOptHiveTable.java index 746b107..0de7488 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/RelOptHiveTable.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/RelOptHiveTable.java @@ -248,7 +248,7 @@ public void computePartitionList(HiveConf conf, RexNode pruneNode) { // We have valid pruning expressions, only retrieve qualifying partitions ExprNodeDesc pruneExpr = pruneNode.accept(new ExprNodeConverter(getName(), getRowType(), - true, this.getRelOptSchema().getTypeFactory())); + HiveCalciteUtil.getInputRefs(pruneNode), this.getRelOptSchema().getTypeFactory())); partitionList = PartitionPruner.prune(hiveTblMetadata, pruneExpr, conf, getName(), partitionCache); diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/ExprNodeConverter.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/ExprNodeConverter.java index de4b2bc..bcce74a 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/ExprNodeConverter.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/ExprNodeConverter.java @@ -23,6 +23,7 @@ import java.util.Calendar; import java.util.LinkedList; import java.util.List; +import java.util.Set; import org.apache.calcite.rel.RelFieldCollation; import org.apache.calcite.rel.type.RelDataType; @@ -68,34 +69,36 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; +import com.google.common.collect.ImmutableSet; + /* * convert a RexNode to an ExprNodeDesc */ public class ExprNodeConverter extends RexVisitorImpl { - String tabAlias; - String columnAlias; - RelDataType inputRowType; - RelDataType outputRowType; - boolean partitioningExpr; - WindowFunctionSpec wfs; + private final String tabAlias; + private final String columnAlias; + private final RelDataType inputRowType; + private final RelDataType outputRowType; + private final ImmutableSet inputVCols; + private WindowFunctionSpec wfs; private final RelDataTypeFactory dTFactory; protected final Log LOG = LogFactory.getLog(this.getClass().getName()); public ExprNodeConverter(String tabAlias, RelDataType inputRowType, - boolean partitioningExpr, RelDataTypeFactory dTFactory) { - this(tabAlias, null, inputRowType, null, partitioningExpr, dTFactory); + Set vCols, RelDataTypeFactory dTFactory) { + this(tabAlias, null, inputRowType, null, vCols, dTFactory); } public ExprNodeConverter(String tabAlias, String columnAlias, RelDataType inputRowType, - RelDataType outputRowType, boolean partitioningExpr, RelDataTypeFactory dTFactory) { + RelDataType outputRowType, Set inputVCols, RelDataTypeFactory dTFactory) { super(true); this.tabAlias = tabAlias; this.columnAlias = columnAlias; this.inputRowType = inputRowType; this.outputRowType = outputRowType; - this.partitioningExpr = partitioningExpr; - this.dTFactory = dTFactory; + this.inputVCols = ImmutableSet.copyOf(inputVCols); + this.dTFactory = dTFactory; } public WindowFunctionSpec getWindowFunctionSpec() { @@ -106,7 +109,7 @@ public WindowFunctionSpec getWindowFunctionSpec() { public ExprNodeDesc visitInputRef(RexInputRef inputRef) { RelDataTypeField f = inputRowType.getFieldList().get(inputRef.getIndex()); return new ExprNodeColumnDesc(TypeConverter.convert(f.getType()), f.getName(), tabAlias, - partitioningExpr); + inputVCols.contains(inputRef.getIndex())); } /** @@ -257,7 +260,7 @@ public ExprNodeDesc visitOver(RexOver over) { RelDataTypeField f = outputRowType.getField(columnAlias, false, false); return new ExprNodeColumnDesc(TypeConverter.convert(f.getType()), columnAlias, tabAlias, - partitioningExpr); + false); } private PartitioningSpec getPSpec(RexWindow window) { diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveGBOpConvUtil.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveGBOpConvUtil.java index 55f1247..d7f5fad 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveGBOpConvUtil.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveGBOpConvUtil.java @@ -154,7 +154,7 @@ private static GBInfo getGBInfo(HiveAggregate aggRel, OpAttr inputOpAf, HiveConf // 1. Collect GB Keys RelNode aggInputRel = aggRel.getInput(); ExprNodeConverter exprConv = new ExprNodeConverter(inputOpAf.tabAlias, - aggInputRel.getRowType(), false, aggRel.getCluster().getTypeFactory()); + aggInputRel.getRowType(), new HashSet(), aggRel.getCluster().getTypeFactory()); ExprNodeDesc tmpExprNodeDesc; for (int i : aggRel.getGroupSet()) { @@ -639,7 +639,7 @@ private static OpAttr genReduceGBRS(OpAttr inputOpAf, GBInfo gbInfo) throws Sema rsOp.setColumnExprMap(colExprMap); - return new OpAttr("", new HashMap(), rsOp); + return new OpAttr("", new HashSet(), rsOp); } private static OpAttr genMapSideGBRS(OpAttr inputOpAf, GBInfo gbInfo) throws SemanticException { @@ -677,7 +677,7 @@ private static OpAttr genMapSideGBRS(OpAttr inputOpAf, GBInfo gbInfo) throws Sem rsOp.setColumnExprMap(colExprMap); - return new OpAttr("", new HashMap(), rsOp); + return new OpAttr("", new HashSet(), rsOp); } private static OpAttr genMapSideRS(OpAttr inputOpAf, GBInfo gbInfo) throws SemanticException { @@ -732,7 +732,7 @@ private static OpAttr genMapSideRS(OpAttr inputOpAf, GBInfo gbInfo) throws Seman rsOp.setColumnExprMap(colExprMap); - return new OpAttr("", new HashMap(), rsOp); + return new OpAttr("", new HashSet(), rsOp); } private static OpAttr genReduceSideGB2(OpAttr inputOpAf, GBInfo gbInfo) throws SemanticException { @@ -799,7 +799,7 @@ private static OpAttr genReduceSideGB2(OpAttr inputOpAf, GBInfo gbInfo) throws S rsGBOp2.setColumnExprMap(colExprMap); // TODO: Shouldn't we propgate vc? is it vc col from tab or all vc - return new OpAttr("", new HashMap(), rsGBOp2); + return new OpAttr("", new HashSet(), rsGBOp2); } private static OpAttr genReduceSideGB1(OpAttr inputOpAf, GBInfo gbInfo, boolean computeGrpSet, @@ -935,7 +935,7 @@ private static OpAttr genReduceSideGB1(OpAttr inputOpAf, GBInfo gbInfo, boolean rsGBOp.setColumnExprMap(colExprMap); - return new OpAttr("", new HashMap(), rsGBOp); + return new OpAttr("", new HashSet(), rsGBOp); } /** @@ -1035,7 +1035,7 @@ private static OpAttr genReduceSideGB1NoMapGB(OpAttr inputOpAf, GBInfo gbInfo, false, -1, numDistinctUDFs > 0), new RowSchema(colInfoLst), rs); rsGB1.setColumnExprMap(colExprMap); - return new OpAttr("", new HashMap(), rsGB1); + return new OpAttr("", new HashSet(), rsGB1); } @SuppressWarnings("unchecked") @@ -1111,7 +1111,7 @@ private static OpAttr genMapSideGB(OpAttr inputOpAf, GBInfo gbAttrs) throws Sema // NOTE: UDAF is not included in ExprColMap gbOp.setColumnExprMap(colExprMap); - return new OpAttr("", new HashMap(), gbOp); + return new OpAttr("", new HashSet(), gbOp); } private static void addGrpSetCol(boolean createConstantExpr, String grpSetIDExprName, diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveOpConverter.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveOpConverter.java index d15520c..feb26bc 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveOpConverter.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveOpConverter.java @@ -95,6 +95,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; public class HiveOpConverter { @@ -128,16 +129,16 @@ public HiveOpConverter(SemanticAnalyzer semanticAnalyzer, HiveConf hiveConf, static class OpAttr { final String tabAlias; ImmutableList inputs; - ImmutableMap vcolMap; + ImmutableSet vcolsInCalcite; - OpAttr(String tabAlias, Map vcolMap, Operator... inputs) { + OpAttr(String tabAlias, Set vcols, Operator... inputs) { this.tabAlias = tabAlias; - this.vcolMap = ImmutableMap.copyOf(vcolMap); this.inputs = ImmutableList.copyOf(inputs); + this.vcolsInCalcite = ImmutableSet.copyOf(vcols); } private OpAttr clone(Operator... inputs) { - return new OpAttr(tabAlias, this.vcolMap, inputs); + return new OpAttr(tabAlias, vcolsInCalcite, inputs); } } @@ -192,13 +193,13 @@ OpAttr visit(HiveTableScan scanRel) { // 1. Setup TableScan Desc // 1.1 Build col details used by scan ArrayList colInfos = new ArrayList(); - List virtualCols = new ArrayList(ht.getVirtualCols()); - Map hiveScanVColMap = new HashMap(); - List partColNames = new ArrayList(); + List virtualCols = new ArrayList(); List neededColumnIDs = new ArrayList(); - List neededColumns = new ArrayList(); + List neededColumnNames = new ArrayList(); + Set vcolsInCalcite = new HashSet(); - Map posToVColMap = HiveCalciteUtil.getVColsMap(virtualCols, + List partColNames = new ArrayList(); + Map VColsMap = HiveCalciteUtil.getVColsMap(ht.getVirtualCols(), ht.getNoOfNonVirtualCols()); Map posToPartColInfo = ht.getPartColInfoMap(); Map posToNonPartColInfo = ht.getNonPartColInfoMap(); @@ -214,19 +215,20 @@ OpAttr visit(HiveTableScan scanRel) { for (int i = 0; i < neededColIndxsFrmReloptHT.size(); i++) { colName = scanColNames.get(i); posInRHT = neededColIndxsFrmReloptHT.get(i); - if (posToVColMap.containsKey(posInRHT)) { - vc = posToVColMap.get(posInRHT); + if (VColsMap.containsKey(posInRHT)) { + vc = VColsMap.get(posInRHT); virtualCols.add(vc); colInfo = new ColumnInfo(vc.getName(), vc.getTypeInfo(), tableAlias, true, vc.getIsHidden()); - hiveScanVColMap.put(i, vc); + vcolsInCalcite.add(posInRHT); } else if (posToPartColInfo.containsKey(posInRHT)) { partColNames.add(colName); colInfo = posToPartColInfo.get(posInRHT); + vcolsInCalcite.add(posInRHT); } else { colInfo = posToNonPartColInfo.get(posInRHT); } neededColumnIDs.add(posInRHT); - neededColumns.add(colName); + neededColumnNames.add(colName); colInfos.add(colInfo); } @@ -238,7 +240,7 @@ OpAttr visit(HiveTableScan scanRel) { // 1.4. Set needed cols in TSDesc tsd.setNeededColumnIDs(neededColumnIDs); - tsd.setNeededColumns(neededColumns); + tsd.setNeededColumns(neededColumnNames); // 2. Setup TableScan TableScanOperator ts = (TableScanOperator) OperatorFactory.get(tsd, new RowSchema(colInfos)); @@ -249,7 +251,7 @@ OpAttr visit(HiveTableScan scanRel) { LOG.debug("Generated " + ts + " with row schema: [" + ts.getSchema() + "]"); } - return new OpAttr(tableAlias, hiveScanVColMap, ts); + return new OpAttr(tableAlias, vcolsInCalcite, ts); } OpAttr visit(HiveProject projectRel) throws SemanticException { @@ -267,10 +269,11 @@ OpAttr visit(HiveProject projectRel) throws SemanticException { for (int pos = 0; pos < projectRel.getChildExps().size(); pos++) { ExprNodeConverter converter = new ExprNodeConverter(inputOpAf.tabAlias, projectRel .getRowType().getFieldNames().get(pos), projectRel.getInput().getRowType(), - projectRel.getRowType(), false, projectRel.getCluster().getTypeFactory()); + projectRel.getRowType(), inputOpAf.vcolsInCalcite, projectRel.getCluster().getTypeFactory()); ExprNodeDesc exprCol = projectRel.getChildExps().get(pos).accept(converter); colExprMap.put(exprNames.get(pos), exprCol); exprCols.add(exprCol); + //TODO: Cols that come through PTF should it retain (VirtualColumness)? if (converter.getWindowFunctionSpec() != null) { windowingSpec.addWindowFunction(converter.getWindowFunctionSpec()); } @@ -281,7 +284,7 @@ OpAttr visit(HiveProject projectRel) throws SemanticException { } // TODO: is this a safe assumption (name collision, external names...) SelectDesc sd = new SelectDesc(exprCols, exprNames); - Pair, Map> colInfoVColPair = createColInfos( + Pair, Set> colInfoVColPair = createColInfos( projectRel.getChildExps(), exprCols, exprNames, inputOpAf); SelectOperator selOp = (SelectOperator) OperatorFactory.getAndMakeChild(sd, new RowSchema( colInfoVColPair.getKey()), inputOpAf.inputs.get(0)); @@ -312,7 +315,7 @@ OpAttr visit(HiveJoin joinRel) throws SemanticException { JoinPredicateInfo joinPredInfo = JoinPredicateInfo.constructJoinPredicateInfo(joinRel); // 3. Extract join keys from condition - ExprNodeDesc[][] joinKeys = extractJoinKeys(joinPredInfo, joinRel.getInputs()); + ExprNodeDesc[][] joinKeys = extractJoinKeys(joinPredInfo, joinRel.getInputs(), inputs); // 4.a Generate tags for (int tag=0; tag vcolMap = new HashMap(); - vcolMap.putAll(inputs[0].vcolMap); + Set newVcolsInCalcite = new HashSet(); + newVcolsInCalcite.addAll(inputs[0].vcolsInCalcite); if (extractJoinType(joinRel) != JoinType.LEFTSEMI) { int shift = inputs[0].inputs.get(0).getSchema().getSignature().size(); for (int i = 1; i < inputs.length; i++) { - vcolMap.putAll(HiveCalciteUtil.shiftVColsMap(inputs[i].vcolMap, shift)); + newVcolsInCalcite.addAll(HiveCalciteUtil.shiftVColsSet(inputs[i].vcolsInCalcite, shift)); shift += inputs[i].inputs.get(0).getSchema().getSignature().size(); } } // 7. Return result - return new OpAttr(null, vcolMap, joinOp); + return new OpAttr(null, newVcolsInCalcite, joinOp); } OpAttr visit(HiveAggregate aggRel) throws SemanticException { @@ -433,7 +436,7 @@ OpAttr visit(HiveFilter filterRel) throws SemanticException { } ExprNodeDesc filCondExpr = filterRel.getCondition().accept( - new ExprNodeConverter(inputOpAf.tabAlias, filterRel.getInput().getRowType(), false, + new ExprNodeConverter(inputOpAf.tabAlias, filterRel.getInput().getRowType(), inputOpAf.vcolsInCalcite, filterRel.getCluster().getTypeFactory())); FilterDesc filDesc = new FilterDesc(filCondExpr, false); ArrayList cinfoLst = createColInfos(inputOpAf.inputs.get(0)); @@ -474,6 +477,7 @@ OpAttr visit(HiveUnion unionRel) throws SemanticException { LOG.debug("Generated " + unionOp + " with row schema: [" + unionOp.getSchema() + "]"); } + //TODO: Can columns retain virtualness out of union // 3. Return result return inputs[0].clone(unionOp); } @@ -570,14 +574,14 @@ private OpAttr genPTF(OpAttr inputOpAf, WindowingSpec wSpec) throws SemanticExce return inputOpAf.clone(input); } - private ExprNodeDesc[][] extractJoinKeys(JoinPredicateInfo joinPredInfo, List inputs) { + private ExprNodeDesc[][] extractJoinKeys(JoinPredicateInfo joinPredInfo, List inputs, OpAttr[] inputAttr) { ExprNodeDesc[][] joinKeys = new ExprNodeDesc[inputs.size()][]; for (int i = 0; i < inputs.size(); i++) { joinKeys[i] = new ExprNodeDesc[joinPredInfo.getEquiJoinPredicateElements().size()]; for (int j = 0; j < joinPredInfo.getEquiJoinPredicateElements().size(); j++) { JoinLeafPredicateInfo joinLeafPredInfo = joinPredInfo.getEquiJoinPredicateElements().get(j); RexNode key = joinLeafPredInfo.getJoinKeyExprs(j).get(0); - joinKeys[i][j] = convertToExprNode(key, inputs.get(j), null); + joinKeys[i][j] = convertToExprNode(key, inputs.get(j), null, inputAttr[i]); } } return joinKeys; @@ -853,8 +857,8 @@ private static JoinType extractJoinType(HiveJoin join) { return columnDescriptors; } - private static ExprNodeDesc convertToExprNode(RexNode rn, RelNode inputRel, String tabAlias) { - return rn.accept(new ExprNodeConverter(tabAlias, inputRel.getRowType(), false, + private static ExprNodeDesc convertToExprNode(RexNode rn, RelNode inputRel, String tabAlias, OpAttr inputAttr) { + return rn.accept(new ExprNodeConverter(tabAlias, inputRel.getRowType(), inputAttr.vcolsInCalcite, inputRel.getCluster().getTypeFactory())); } @@ -866,7 +870,7 @@ private static ExprNodeDesc convertToExprNode(RexNode rn, RelNode inputRel, Stri return cInfoLst; } - private static Pair, Map> createColInfos( + private static Pair, Set> createColInfos( List calciteExprs, List hiveExprs, List projNames, OpAttr inpOpAf) { if (hiveExprs.size() != projNames.size()) { @@ -876,22 +880,22 @@ private static ExprNodeDesc convertToExprNode(RexNode rn, RelNode inputRel, Stri RexNode rexN; ExprNodeDesc pe; ArrayList colInfos = new ArrayList(); - VirtualColumn vc; - Map newVColMap = new HashMap(); + boolean vc; + Set newVColSet = new HashSet(); for (int i = 0; i < hiveExprs.size(); i++) { pe = hiveExprs.get(i); rexN = calciteExprs.get(i); - vc = null; + vc = false; if (rexN instanceof RexInputRef) { - vc = inpOpAf.vcolMap.get(((RexInputRef) rexN).getIndex()); - if (vc != null) { - newVColMap.put(i, vc); + if (inpOpAf.vcolsInCalcite.contains(((RexInputRef) rexN).getIndex())) { + newVColSet.add(i); + vc = true; } } colInfos - .add(new ColumnInfo(projNames.get(i), pe.getTypeInfo(), inpOpAf.tabAlias, vc != null)); + .add(new ColumnInfo(projNames.get(i), pe.getTypeInfo(), inpOpAf.tabAlias, vc)); } - return new Pair, Map>(colInfos, newVColMap); + return new Pair, Set>(colInfos, newVColSet); } }