diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java index 4a9b870..14cd565 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java @@ -124,14 +124,14 @@ public void setFileKey(String fileKey) { this.fileKey = fileKey; } - private int getInputColumnIndex(String name) { - if (!columnMap.containsKey(name)) { - LOG.error(String.format("The column %s is not in the vectorization context column map.", name)); - } - return columnMap.get(name); + protected int getInputColumnIndex(String name) { + if (!columnMap.containsKey(name)) { + LOG.error(String.format("The column %s is not in the vectorization context column map.", name)); + } + return columnMap.get(name); } - private int getInputColumnIndex(ExprNodeColumnDesc colExpr) { + protected int getInputColumnIndex(ExprNodeColumnDesc colExpr) { return columnMap.get(colExpr.getColumn()); } @@ -139,7 +139,7 @@ private int getInputColumnIndex(ExprNodeColumnDesc colExpr) { private final int initialOutputCol; private int outputColCount = 0; - OutputColumnManager(int initialOutputCol) { + protected OutputColumnManager(int initialOutputCol) { this.initialOutputCol = initialOutputCol; } @@ -152,6 +152,10 @@ private int getInputColumnIndex(ExprNodeColumnDesc colExpr) { private final Set usedOutputColumns = new HashSet(); int allocateOutputColumn(String columnType) { + if (initialOutputCol < 0) { + // Negative initialOutputCol: presumably the validation-only context (created with -1); return dummy column 0 without allocating. + return 0; + } int relativeCol = allocateOutputColumnInternal(columnType); return initialOutputCol + relativeCol; } @@ -183,6 +187,10 @@ private int allocateOutputColumnInternal(String columnType) { } void freeOutputColumn(int index) { + if (initialOutputCol < 0) { + // Negative initialOutputCol: presumably the validation-only context (created with -1); nothing was allocated, so nothing to free. + return; + } int colIndex = index-initialOutputCol; if (colIndex >= 0) { usedOutputColumns.remove(index-initialOutputCol);
diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/Vectorizer.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/Vectorizer.java index e698870..c9919a4 100644 --- ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/Vectorizer.java +++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/Vectorizer.java @@ -46,6 +46,7 @@ import org.apache.hadoop.hive.ql.exec.Task; import org.apache.hadoop.hive.ql.exec.UDF; import org.apache.hadoop.hive.ql.exec.mr.MapRedTask; +import org.apache.hadoop.hive.ql.exec.vector.VectorExpressionDescriptor; import org.apache.hadoop.hive.ql.exec.vector.VectorizationContext; import org.apache.hadoop.hive.ql.exec.vector.VectorizedInputFormatInterface; import org.apache.hadoop.hive.ql.lib.DefaultGraphWalker; @@ -63,14 +64,7 @@ import org.apache.hadoop.hive.ql.metadata.Table; import org.apache.hadoop.hive.ql.parse.RowResolver; import org.apache.hadoop.hive.ql.parse.SemanticException; -import org.apache.hadoop.hive.ql.plan.AbstractOperatorDesc; -import org.apache.hadoop.hive.ql.plan.AggregationDesc; -import org.apache.hadoop.hive.ql.plan.ExprNodeDesc; -import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc; -import org.apache.hadoop.hive.ql.plan.MapJoinDesc; -import org.apache.hadoop.hive.ql.plan.MapWork; -import org.apache.hadoop.hive.ql.plan.OperatorDesc; -import org.apache.hadoop.hive.ql.plan.PartitionDesc; +import org.apache.hadoop.hive.ql.plan.*; import org.apache.hadoop.hive.ql.plan.api.OperatorType; import org.apache.hadoop.hive.ql.udf.UDFAcos; import org.apache.hadoop.hive.ql.udf.UDFAsin; @@ -352,7 +346,7 @@ public Object process(Node nd, Stack stack, NodeProcessorCtx procCtx, } boolean ret = validateOperator(op); if (!ret) { - LOG.info("Operator: "+op.getName()+" could not be vectorized."); + LOG.info("Operator: " + op.getName() + " could not be vectorized."); return new Boolean(false); } } @@ -451,6 +445,22 @@ public Object process(Node nd, Stack stack, NodeProcessorCtx procCtx, } } + private 
static class ValidatorVectorizationContext extends VectorizationContext { + private ValidatorVectorizationContext() { + super(null, -1); + } + + @Override + protected int getInputColumnIndex(String name) { + return 0; + } + + @Override + protected int getInputColumnIndex(ExprNodeColumnDesc colExpr) { + return 0; + } + } + @Override public PhysicalContext resolve(PhysicalContext pctx) throws SemanticException { this.physicalContext = pctx; @@ -509,7 +519,7 @@ private boolean validateMapJoinOperator(MapJoinOperator op) { List filterExprs = desc.getFilters().get(posBigTable); List keyExprs = desc.getKeys().get(posBigTable); List valueExprs = desc.getExprs().get(posBigTable); - return validateExprNodeDesc(filterExprs) && + return validateExprNodeDesc(filterExprs, VectorExpressionDescriptor.Mode.FILTER) && validateExprNodeDesc(keyExprs) && validateExprNodeDesc(valueExprs); } @@ -535,7 +545,7 @@ private boolean validateSelectOperator(SelectOperator op) { private boolean validateFilterOperator(FilterOperator op) { ExprNodeDesc desc = op.getConf().getPredicate(); - return validateExprNodeDesc(desc); + return validateExprNodeDesc(desc, VectorExpressionDescriptor.Mode.FILTER); } private boolean validateGroupByOperator(GroupByOperator op) { @@ -547,8 +557,12 @@ private boolean validateGroupByOperator(GroupByOperator op) { } private boolean validateExprNodeDesc(List descs) { + return validateExprNodeDesc(descs, VectorExpressionDescriptor.Mode.PROJECTION); + } + + private boolean validateExprNodeDesc(List descs, VectorExpressionDescriptor.Mode mode) { for (ExprNodeDesc d : descs) { - boolean ret = validateExprNodeDesc(d); + boolean ret = validateExprNodeDesc(d, mode); if (!ret) { return false; } @@ -566,7 +580,7 @@ private boolean validateAggregationDesc(List descs) { return true; } - private boolean validateExprNodeDesc(ExprNodeDesc desc) { + private boolean validateExprNodeDescRecursive(ExprNodeDesc desc) { boolean ret = validateDataType(desc.getTypeInfo().getTypeName()); if 
(!ret) { return false; @@ -580,12 +594,37 @@ private boolean validateExprNodeDesc(ExprNodeDesc desc) { } if (desc.getChildren() != null) { for (ExprNodeDesc d: desc.getChildren()) { - validateExprNodeDesc(d); + boolean r = validateExprNodeDescRecursive(d); + if (!r) { + return false; + } } } return true; } + private boolean validateExprNodeDesc(ExprNodeDesc desc) { + return validateExprNodeDesc(desc, VectorExpressionDescriptor.Mode.PROJECTION); + } + + boolean validateExprNodeDesc(ExprNodeDesc desc, VectorExpressionDescriptor.Mode mode) { + if (!validateExprNodeDescRecursive(desc)) { + return false; + } + try { + VectorizationContext vc = new ValidatorVectorizationContext(); + if (vc.getVectorExpression(desc, mode) == null) { + return false; + } + } catch (HiveException e) { + if (LOG.isDebugEnabled()) { + LOG.debug("Failed to vectorize", e); + } + return false; + } + return true; + } + private boolean validateGenericUdf(ExprNodeGenericFuncDesc genericUDFExpr) { if (VectorizationContext.isCustomUDF(genericUDFExpr)) { return true; diff --git ql/src/test/org/apache/hadoop/hive/ql/optimizer/physical/TestVectorizer.java ql/src/test/org/apache/hadoop/hive/ql/optimizer/physical/TestVectorizer.java index 51bc09a..5234969 100644 --- ql/src/test/org/apache/hadoop/hive/ql/optimizer/physical/TestVectorizer.java +++ ql/src/test/org/apache/hadoop/hive/ql/optimizer/physical/TestVectorizer.java @@ -25,18 +25,18 @@ import junit.framework.Assert; +import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.GroupByOperator; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.vector.VectorExpressionDescriptor; import org.apache.hadoop.hive.ql.exec.vector.VectorGroupByOperator; import org.apache.hadoop.hive.ql.exec.vector.VectorizationContext; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFSumLong; import 
org.apache.hadoop.hive.ql.exec.vector.expressions.gen.FuncAbsLongToLong; import org.apache.hadoop.hive.ql.metadata.HiveException; -import org.apache.hadoop.hive.ql.plan.AggregationDesc; -import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc; -import org.apache.hadoop.hive.ql.plan.ExprNodeDesc; -import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc; -import org.apache.hadoop.hive.ql.plan.GroupByDesc; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFAbs; +import org.apache.hadoop.hive.ql.plan.*; +import org.apache.hadoop.hive.ql.udf.generic.*; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; import org.junit.Before; import org.junit.Test; @@ -50,9 +50,28 @@ public void setUp() { Map columnMap = new HashMap(); columnMap.put("col1", 0); columnMap.put("col2", 1); + columnMap.put("col3", 2); //Generate vectorized expression - vContext = new VectorizationContext(columnMap, 2); + vContext = new VectorizationContext(columnMap, 3); + } + + @Description(name = "fake", value = "FAKE") + static class FakeGenericUDF extends GenericUDF { + @Override + public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException { + return null; + } + + @Override + public Object evaluate(DeferredObject[] arguments) throws HiveException { + return null; + } + + @Override + public String getDisplayString(String[] children) { + return "fake"; + } } @Test @@ -96,4 +115,37 @@ public void testAggregateOnUDF() throws HiveException { VectorUDAFSumLong udaf = (VectorUDAFSumLong) vectorOp.getAggregators()[0]; Assert.assertEquals(FuncAbsLongToLong.class, udaf.getInputExpression().getClass()); } + + @Test + public void testValidateNestedExpressions() { + ExprNodeColumnDesc col1Expr = new ExprNodeColumnDesc(Integer.class, "col1", "table", false); + ExprNodeConstantDesc constDesc = new ExprNodeConstantDesc(new Integer(10)); + + GenericUDFOPGreaterThan udf = new 
GenericUDFOPGreaterThan(); + ExprNodeGenericFuncDesc greaterExprDesc = new ExprNodeGenericFuncDesc(); + greaterExprDesc.setTypeInfo(TypeInfoFactory.booleanTypeInfo); + greaterExprDesc.setGenericUDF(udf); + List children1 = new ArrayList(2); + children1.add(col1Expr); + children1.add(constDesc); + greaterExprDesc.setChildren(children1); + + FakeGenericUDF udf2 = new FakeGenericUDF(); + ExprNodeGenericFuncDesc nonSupportedExpr = new ExprNodeGenericFuncDesc(); + nonSupportedExpr.setTypeInfo(TypeInfoFactory.booleanTypeInfo); + nonSupportedExpr.setGenericUDF(udf2); + + GenericUDFOPAnd andUdf = new GenericUDFOPAnd(); + ExprNodeGenericFuncDesc andExprDesc = new ExprNodeGenericFuncDesc(); + andExprDesc.setTypeInfo(TypeInfoFactory.booleanTypeInfo); + andExprDesc.setGenericUDF(andUdf); + List children3 = new ArrayList(2); + children3.add(greaterExprDesc); + children3.add(nonSupportedExpr); + andExprDesc.setChildren(children3); + + Vectorizer v = new Vectorizer(); + Assert.assertFalse(v.validateExprNodeDesc(andExprDesc, VectorExpressionDescriptor.Mode.FILTER)); + Assert.assertFalse(v.validateExprNodeDesc(andExprDesc, VectorExpressionDescriptor.Mode.PROJECTION)); + } }