diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java
index 0534abb..41cb1c5 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java
@@ -124,7 +124,7 @@ public void setFileKey(String fileKey) {
     this.fileKey = fileKey;
   }
 
-  private int getInputColumnIndex(String name) {
+  protected int getInputColumnIndex(String name) {
     if (!columnMap.containsKey(name)) {
       LOG.error(String.format("The column %s is not in the vectorization context column map.", name));
     }
@@ -139,7 +139,7 @@ private int getInputColumnIndex(ExprNodeColumnDesc colExpr) {
 
     private final int initialOutputCol;
     private int outputColCount = 0;
 
-    OutputColumnManager(int initialOutputCol) {
+    protected OutputColumnManager(int initialOutputCol) {
       this.initialOutputCol = initialOutputCol;
     }
@@ -152,6 +152,10 @@ private int getInputColumnIndex(ExprNodeColumnDesc colExpr) {
     private final Set<Integer> usedOutputColumns = new HashSet<Integer>();
 
     int allocateOutputColumn(String columnType) {
+      if (initialOutputCol < 0) {
+        // This is a test
+        return 0;
+      }
       int relativeCol = allocateOutputColumnInternal(columnType);
       return initialOutputCol + relativeCol;
     }
@@ -183,6 +187,10 @@ private int allocateOutputColumnInternal(String columnType) {
     }
     void freeOutputColumn(int index) {
+      if (initialOutputCol < 0) {
+        // This is a test
+        return;
+      }
       int colIndex = index-initialOutputCol;
       if (colIndex >= 0) {
         usedOutputColumns.remove(index-initialOutputCol);
       }
@@ -245,7 +253,7 @@ public VectorExpression getVectorExpression(ExprNodeDesc exprDesc, Mode mode) th
         ve = getCustomUDFExpression(expr);
       } else {
         ve = getGenericUdfVectorExpression(expr.getGenericUDF(),
-            expr.getChildExprs(), mode);
+            expr.getChildren(), mode);
       }
     } else if (exprDesc instanceof ExprNodeConstantDesc) {
       ve = getConstantVectorExpression((ExprNodeConstantDesc) exprDesc, mode);
diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/Vectorizer.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/Vectorizer.java
index e698870..e6641a0 100644
--- ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/Vectorizer.java
+++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/Vectorizer.java
@@ -46,6 +46,7 @@
 import org.apache.hadoop.hive.ql.exec.Task;
 import org.apache.hadoop.hive.ql.exec.UDF;
 import org.apache.hadoop.hive.ql.exec.mr.MapRedTask;
+import org.apache.hadoop.hive.ql.exec.vector.VectorExpressionDescriptor;
 import org.apache.hadoop.hive.ql.exec.vector.VectorizationContext;
 import org.apache.hadoop.hive.ql.exec.vector.VectorizedInputFormatInterface;
 import org.apache.hadoop.hive.ql.lib.DefaultGraphWalker;
@@ -451,6 +452,17 @@ public Object process(Node nd, Stack stack, NodeProcessorCtx procCtx,
       }
     }
 
+  private static class ValidatorVectorizationContext extends VectorizationContext {
+    private ValidatorVectorizationContext() {
+      super(null, -1);
+    }
+
+    @Override
+    protected int getInputColumnIndex(String name) {
+      return 0;
+    }
+  }
+
   @Override
   public PhysicalContext resolve(PhysicalContext pctx) throws SemanticException {
     this.physicalContext = pctx;
@@ -509,7 +521,7 @@ private boolean validateMapJoinOperator(MapJoinOperator op) {
     List<ExprNodeDesc> filterExprs = desc.getFilters().get(posBigTable);
     List<ExprNodeDesc> keyExprs = desc.getKeys().get(posBigTable);
     List<ExprNodeDesc> valueExprs = desc.getExprs().get(posBigTable);
-    return validateExprNodeDesc(filterExprs) &&
+    return validateExprNodeDesc(filterExprs, VectorExpressionDescriptor.Mode.FILTER) &&
         validateExprNodeDesc(keyExprs) &&
         validateExprNodeDesc(valueExprs);
   }
@@ -535,7 +547,7 @@ private boolean validateSelectOperator(SelectOperator op) {
 
   private boolean validateFilterOperator(FilterOperator op) {
     ExprNodeDesc desc = op.getConf().getPredicate();
-    return validateExprNodeDesc(desc);
+    return validateExprNodeDesc(desc, VectorExpressionDescriptor.Mode.FILTER);
   }
 
   private boolean validateGroupByOperator(GroupByOperator op) {
@@ -547,6 +559,10 @@ private boolean validateGroupByOperator(GroupByOperator op) {
   }
 
   private boolean validateExprNodeDesc(List<ExprNodeDesc> descs) {
+    return validateExprNodeDesc(descs, VectorExpressionDescriptor.Mode.PROJECTION);
+  }
+
+  private boolean validateExprNodeDesc(List<ExprNodeDesc> descs, VectorExpressionDescriptor.Mode mode) {
     for (ExprNodeDesc d : descs) {
-      boolean ret = validateExprNodeDesc(d);
+      boolean ret = validateExprNodeDesc(d, mode);
       if (!ret) {
@@ -566,7 +582,7 @@ private boolean validateAggregationDesc(List descs) {
     return true;
   }
 
-  private boolean validateExprNodeDesc(ExprNodeDesc desc) {
+  private boolean validateExprNodeDescRecursive(ExprNodeDesc desc) {
     boolean ret = validateDataType(desc.getTypeInfo().getTypeName());
     if (!ret) {
       return false;
@@ -580,12 +596,34 @@ private boolean validateExprNodeDesc(ExprNodeDesc desc) {
     }
 
     if (desc.getChildren() != null) {
       for (ExprNodeDesc d: desc.getChildren()) {
-        validateExprNodeDesc(d);
+        boolean r = validateExprNodeDescRecursive(d);
+        if (!r) {
+          return false;
+        }
       }
     }
     return true;
   }
 
+  private boolean validateExprNodeDesc(ExprNodeDesc desc) {
+    return validateExprNodeDesc(desc, VectorExpressionDescriptor.Mode.PROJECTION);
+  }
+
+  private boolean validateExprNodeDesc(ExprNodeDesc desc, VectorExpressionDescriptor.Mode mode) {
+    if (!validateExprNodeDescRecursive(desc)) {
+      return false;
+    }
+    try {
+      VectorizationContext vc = new ValidatorVectorizationContext();
+      if (vc.getVectorExpression(desc, mode) == null) {
+        return false;
+      }
+    } catch (HiveException e) {
+      return false;
+    }
+    return true;
+  }
+
   private boolean validateGenericUdf(ExprNodeGenericFuncDesc genericUDFExpr) {
     if (VectorizationContext.isCustomUDF(genericUDFExpr)) {
       return true;