diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorMapOperator.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorMapOperator.java index 02d7bb9..05bffd1 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorMapOperator.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorMapOperator.java @@ -96,6 +96,8 @@ new HashMap, MapOpCtx>(); private ArrayList> extraChildrenToClose = null; + private VectorizationContext vectorizationContext = null; + private boolean outputColumnsInitialized = false;; private static class MapInputPath { String path; @@ -500,9 +502,7 @@ public void setChildren(Configuration hconf) throws HiveException { Path onepath = new Path(new Path(onefile).toUri().getPath()); List aliases = conf.getPathToAliases().get(onefile); - VectorizationContext vectorizationContext = new VectorizationContext - (columnMap, - columnCount); + vectorizationContext = new VectorizationContext(columnMap, columnCount); for (String onealias : aliases) { Operator op = conf.getAliasToWork().get( @@ -785,6 +785,21 @@ public void process(Object value) throws HiveException { // So, use tblOI (and not partOI) for forwarding try { if (value instanceof VectorizedRowBatch) { + if (!outputColumnsInitialized ) { + VectorizedRowBatch vrg = (VectorizedRowBatch) value; + Map outputColumnTypes = + vectorizationContext.getOutputColumnTypeMap(); + if (!outputColumnTypes.isEmpty()) { + int origNumCols = vrg.numCols; + int newNumCols = vrg.cols.length+outputColumnTypes.keySet().size(); + vrg.cols = Arrays.copyOf(vrg.cols, newNumCols); + for (int i = origNumCols; i < newNumCols; i++) { + vrg.cols[i] = vectorizationContext.allocateColumnVector(outputColumnTypes.get(i), + VectorizedRowBatch.DEFAULT_SIZE); + } + } + outputColumnsInitialized = true; + } forward(value, null); } else { Object row = null; diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java index 807455e..ef02a66 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java @@ -20,14 +20,16 @@ import java.lang.reflect.Constructor; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.hive.ql.exec.UDF; -import org.apache.hadoop.hive.ql.exec.vector.expressions.ColumnExpression; import org.apache.hadoop.hive.ql.exec.vector.expressions.FilterExprAndExpr; import org.apache.hadoop.hive.ql.exec.vector.expressions.FilterExprOrExpr; import org.apache.hadoop.hive.ql.exec.vector.expressions.FilterNotExpr; @@ -83,8 +85,6 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPNotNull; import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPNull; import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPOr; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; /** @@ -99,24 +99,80 @@ //columnName to column position map private final Map columnMap; - //Next column to be used for intermediate output - private int nextOutputColumn; + private final int firstOutputColumnIndex; + private OperatorType opType; //Map column number to type - private final Map outputColumnTypes; + private final OutputColumnManager ocm; public VectorizationContext(Map columnMap, int initialOutputCol) { this.columnMap = columnMap; - this.nextOutputColumn = initialOutputCol; - this.outputColumnTypes = new HashMap(); + this.ocm = new OutputColumnManager(initialOutputCol); + this.firstOutputColumnIndex = initialOutputCol; } - - public int allocateOutputColumn (String columnName, String columnType) { - int newColumnIndex = nextOutputColumn++; - columnMap.put(columnName, newColumnIndex); - outputColumnTypes.put(newColumnIndex, columnType); - return newColumnIndex; + + + private class OutputColumnManager { + private final int initialOutputCol; + private int outputColCount = 0; + + OutputColumnManager(int initialOutputCol) { + this.initialOutputCol = initialOutputCol; + } + + //The complete list of output columns. These should be added to the + //Vectorized row batch for processing. The index in the row batch is + //equal to the index in this array plus initialOutputCol. + //Start with size 100 and double when needed. + private String [] outputColumnsTypes = new String[100]; + + private final Set usedOutputColumns = new HashSet(); + + int allocateOutputColumn(String columnType) { + return initialOutputCol + allocateOutputColumnInternal(columnType); + } + + private int allocateOutputColumnInternal(String columnType) { + for (int i = 0; i < outputColCount; i++) { + if (usedOutputColumns.contains(i) || + !(outputColumnsTypes)[i].equals(columnType)) { + continue; + } + //Use i + usedOutputColumns.add(i); + return i; + } + //Out of allocated columns + if (outputColCount < outputColumnsTypes.length) { + int newIndex = outputColCount; + outputColumnsTypes[outputColCount++] = columnType; + usedOutputColumns.add(newIndex); + return newIndex; + } else { + //Expand the array + outputColumnsTypes = Arrays.copyOf(outputColumnsTypes, 2*outputColCount); + int newIndex = outputColCount; + outputColumnsTypes[outputColCount++] = columnType; + usedOutputColumns.add(newIndex); + return newIndex; + } + } + + void freeOutputColumn(int index) { + int colIndex = index-initialOutputCol; + if (colIndex >= 0) { + usedOutputColumns.remove(index-initialOutputCol); + } + } + + String getOutputColumnType(int index) { + return outputColumnsTypes[index-initialOutputCol]; + } + + int getNumOfOutputColumn() { + return outputColCount; + } } public void setOperatorType(OperatorType opType) { @@ -151,23 +207,30 @@ private VectorExpression getVectorExpression(ExprNodeColumnDesc return ret; } + /** + * Returns a vector expression for a given expression + * description. + * @param exprDesc, Expression description + * @return {@link VectorExpression} + */ public VectorExpression getVectorExpression(ExprNodeDesc exprDesc) { + VectorExpression ve = null; if (exprDesc instanceof ExprNodeColumnDesc) { - return getVectorExpression((ExprNodeColumnDesc) exprDesc); + ve = getVectorExpression((ExprNodeColumnDesc) exprDesc); } else if (exprDesc instanceof ExprNodeGenericFuncDesc) { ExprNodeGenericFuncDesc expr = (ExprNodeGenericFuncDesc) exprDesc; - return getVectorExpression(expr.getGenericUDF(), + ve = getVectorExpression(expr.getGenericUDF(), expr.getChildExprs()); } - return null; + System.out.println("VectorExpression = "+ve.toString()); + return ve; } - public VectorExpression getUnaryMinusExpression(List childExprList) { + private VectorExpression getUnaryMinusExpression(List childExprList) { ExprNodeDesc childExpr = childExprList.get(0); int inputCol; String colType; VectorExpression v1 = null; - int outputCol = this.nextOutputColumn++; if (childExpr instanceof ExprNodeGenericFuncDesc) { v1 = getVectorExpression(childExpr); inputCol = v1.getOutputColumn(); @@ -179,8 +242,8 @@ public VectorExpression getUnaryMinusExpression(List childExprList } else { throw new RuntimeException("Expression not supported: "+childExpr); } + int outputCol = ocm.allocateOutputColumn(colType); String className = getNormalizedTypeName(colType) + "colUnaryMinus"; - this.nextOutputColumn = outputCol+1; VectorExpression expr; try { expr = (VectorExpression) Class.forName(className). @@ -190,11 +253,12 @@ public VectorExpression getUnaryMinusExpression(List childExprList } if (v1 != null) { expr.setChildExpressions(new VectorExpression [] {v1}); + ocm.freeOutputColumn(v1.getOutputColumn()); } return expr; } - public VectorExpression getUnaryPlusExpression(List childExprList) { + private VectorExpression getUnaryPlusExpression(List childExprList) { ExprNodeDesc childExpr = childExprList.get(0); int inputCol; String colType; @@ -274,6 +338,9 @@ private VectorExpression getBinaryArithmeticExpression(String method, ExprNodeDesc leftExpr = childExpr.get(0); ExprNodeDesc rightExpr = childExpr.get(1); + VectorExpression v1 = null; + VectorExpression v2 = null; + VectorExpression expr = null; if ( (leftExpr instanceof ExprNodeColumnDesc) && (rightExpr instanceof ExprNodeConstantDesc) ) { @@ -284,7 +351,8 @@ private VectorExpression getBinaryArithmeticExpression(String method, String scalarType = constDesc.getTypeString(); String className = getBinaryColumnScalarExpressionClassName(colType, scalarType, method); - int outputCol = this.nextOutputColumn++; + int outputCol = ocm.allocateOutputColumn(getOutputColType(colType, + scalarType, method)); try { expr = (VectorExpression) Class.forName(className). getDeclaredConstructors()[0].newInstance(inputCol, @@ -301,7 +369,8 @@ private VectorExpression getBinaryArithmeticExpression(String method, String scalarType = constDesc.getTypeString(); String className = getBinaryColumnScalarExpressionClassName(colType, scalarType, method); - int outputCol = this.nextOutputColumn++; + String outputColType = getOutputColType(colType, scalarType, method); + int outputCol = ocm.allocateOutputColumn(outputColType); try { expr = (VectorExpression) Class.forName(className). getDeclaredConstructors()[0].newInstance(inputCol, @@ -317,9 +386,10 @@ private VectorExpression getBinaryArithmeticExpression(String method, int inputCol2 = columnMap.get(leftColDesc.getColumn()); String colType1 = rightColDesc.getTypeString(); String colType2 = leftColDesc.getTypeString(); + String outputColType = getOutputColType(colType1, colType2, method); String className = getBinaryColumnColumnExpressionClassName(colType1, colType2, method); - int outputCol = this.nextOutputColumn++; + int outputCol = ocm.allocateOutputColumn(outputColType); try { expr = (VectorExpression) Class.forName(className). getDeclaredConstructors()[0].newInstance(inputCol1, inputCol2, @@ -330,15 +400,15 @@ private VectorExpression getBinaryArithmeticExpression(String method, } else if ((leftExpr instanceof ExprNodeGenericFuncDesc) && (rightExpr instanceof ExprNodeColumnDesc)) { ExprNodeColumnDesc colDesc = (ExprNodeColumnDesc) rightExpr; - int outputCol = this.nextOutputColumn++; - VectorExpression v1 = getVectorExpression(leftExpr); + v1 = getVectorExpression(leftExpr); int inputCol1 = v1.getOutputColumn(); int inputCol2 = columnMap.get(colDesc.getColumn()); String colType1 = v1.getOutputType(); String colType2 = colDesc.getTypeString(); + String outputColType = getOutputColType(colType1, colType2, method); String className = getBinaryColumnColumnExpressionClassName(colType1, colType2, method); - this.nextOutputColumn = outputCol+1; + int outputCol = ocm.allocateOutputColumn(outputColType); try { expr = (VectorExpression) Class.forName(className). getDeclaredConstructors()[0].newInstance(inputCol1, inputCol2, @@ -350,14 +420,14 @@ private VectorExpression getBinaryArithmeticExpression(String method, } else if ((leftExpr instanceof ExprNodeGenericFuncDesc) && (rightExpr instanceof ExprNodeConstantDesc)) { ExprNodeConstantDesc constDesc = (ExprNodeConstantDesc) rightExpr; - int outputCol = this.nextOutputColumn++; - VectorExpression v1 = getVectorExpression(leftExpr); + v1 = getVectorExpression(leftExpr); int inputCol1 = v1.getOutputColumn(); String colType1 = v1.getOutputType(); String scalarType = constDesc.getTypeString(); + String outputColType = getOutputColType(colType1, scalarType, method); + int outputCol = ocm.allocateOutputColumn(outputColType); String className = getBinaryColumnScalarExpressionClassName(colType1, scalarType, method); - this.nextOutputColumn = outputCol+1; try { expr = (VectorExpression) Class.forName(className). getDeclaredConstructors()[0].newInstance(inputCol1, @@ -369,15 +439,15 @@ private VectorExpression getBinaryArithmeticExpression(String method, } else if ((leftExpr instanceof ExprNodeColumnDesc) && (rightExpr instanceof ExprNodeGenericFuncDesc)) { ExprNodeColumnDesc colDesc = (ExprNodeColumnDesc) leftExpr; - int outputCol = this.nextOutputColumn++; - VectorExpression v2 = getVectorExpression(rightExpr); + v2 = getVectorExpression(rightExpr); int inputCol1 = columnMap.get(colDesc.getColumn()); int inputCol2 = v2.getOutputColumn(); String colType1 = colDesc.getTypeString(); String colType2 = v2.getOutputType(); + String outputColType = getOutputColType(colType1, colType2, method); + int outputCol = ocm.allocateOutputColumn(outputColType); String className = getBinaryColumnColumnExpressionClassName(colType1, colType2, method); - this.nextOutputColumn = outputCol+1; try { expr = (VectorExpression) Class.forName(className). getDeclaredConstructors()[0].newInstance(inputCol1, inputCol2, @@ -389,14 +459,14 @@ private VectorExpression getBinaryArithmeticExpression(String method, } else if ((leftExpr instanceof ExprNodeConstantDesc) && (rightExpr instanceof ExprNodeGenericFuncDesc)) { ExprNodeConstantDesc constDesc = (ExprNodeConstantDesc) leftExpr; - int outputCol = this.nextOutputColumn++; - VectorExpression v2 = getVectorExpression(rightExpr); + v2 = getVectorExpression(rightExpr); int inputCol2 = v2.getOutputColumn(); String colType2 = v2.getOutputType(); String scalarType = constDesc.getTypeString(); + String outputColType = getOutputColType(colType2, scalarType, method); + int outputCol = ocm.allocateOutputColumn(outputColType); String className = getBinaryScalarColumnExpressionClassName(colType2, scalarType, method); - this.nextOutputColumn = outputCol+1; try { expr = (VectorExpression) Class.forName(className). getDeclaredConstructors()[0].newInstance(inputCol2, @@ -409,17 +479,16 @@ private VectorExpression getBinaryArithmeticExpression(String method, && (rightExpr instanceof ExprNodeGenericFuncDesc)) { //For arithmetic expression, the child expressions must be materializing //columns - int outputCol = this.nextOutputColumn++; - VectorExpression v1 = getVectorExpression(leftExpr); - VectorExpression v2 = getVectorExpression(rightExpr); + v1 = getVectorExpression(leftExpr); + v2 = getVectorExpression(rightExpr); int inputCol1 = v1.getOutputColumn(); int inputCol2 = v2.getOutputColumn(); String colType1 = v1.getOutputType(); String colType2 = v2.getOutputType(); + String outputColType = getOutputColType(colType1, colType2, method); + int outputCol = ocm.allocateOutputColumn(outputColType); String className = getBinaryColumnColumnExpressionClassName(colType1, colType2, method); - //Reclaim the output columns - this.nextOutputColumn = outputCol+1; try { expr = (VectorExpression) Class.forName(className). getDeclaredConstructors()[0].newInstance(inputCol1, inputCol2, @@ -429,8 +498,14 @@ private VectorExpression getBinaryArithmeticExpression(String method, } expr.setChildExpressions(new VectorExpression [] {v1, v2}); } + //Reclaim output columns of children to be re-used later + if (v1 != null) { + ocm.freeOutputColumn(v1.getOutputColumn()); + } + if (v2 != null) { + ocm.freeOutputColumn(v2.getOutputColumn()); + } return expr; - } private VectorExpression getVectorExpression(GenericUDFOPOr udf, @@ -543,6 +618,8 @@ private VectorExpression getVectorBinaryComparisonFilterExpression(String ExprNodeDesc rightExpr = childExpr.get(1); VectorExpression expr = null; + VectorExpression v1 = null; + VectorExpression v2 = null; if ( (leftExpr instanceof ExprNodeColumnDesc) && (rightExpr instanceof ExprNodeConstantDesc) ) { ExprNodeColumnDesc leftColDesc = (ExprNodeColumnDesc) leftExpr; @@ -593,14 +670,16 @@ private VectorExpression getVectorBinaryComparisonFilterExpression(String } } else if ( (leftExpr instanceof ExprNodeGenericFuncDesc) && (rightExpr instanceof ExprNodeColumnDesc) ) { - VectorExpression v1 = getVectorExpression((ExprNodeGenericFuncDesc) leftExpr); - ExprNodeColumnDesc leftColDesc = (ExprNodeColumnDesc) leftExpr; + v1 = getVectorExpression((ExprNodeGenericFuncDesc) leftExpr); + ExprNodeColumnDesc leftColDesc = (ExprNodeColumnDesc) rightExpr; int inputCol1 = v1.getOutputColumn(); int inputCol2 = columnMap.get(leftColDesc.getColumn()); String colType1 = v1.getOutputType(); String colType2 = leftColDesc.getTypeString(); String className = getFilterColumnColumnExpressionClassName(colType1, colType2, opName); + System.out.println("In the context, Input column 1: "+inputCol1+ + ", column 2: "+inputCol2); try { expr = (VectorExpression) Class.forName(className). getDeclaredConstructors()[0].newInstance(inputCol1, inputCol2); @@ -611,7 +690,7 @@ private VectorExpression getVectorBinaryComparisonFilterExpression(String } else if ( (leftExpr instanceof ExprNodeColumnDesc) && (rightExpr instanceof ExprNodeGenericFuncDesc) ) { ExprNodeColumnDesc rightColDesc = (ExprNodeColumnDesc) leftExpr; - VectorExpression v2 = getVectorExpression((ExprNodeGenericFuncDesc) rightExpr); + v2 = getVectorExpression((ExprNodeGenericFuncDesc) rightExpr); int inputCol1 = columnMap.get(rightColDesc.getColumn()); int inputCol2 = v2.getOutputColumn(); String colType1 = rightColDesc.getTypeString(); @@ -627,8 +706,8 @@ private VectorExpression getVectorBinaryComparisonFilterExpression(String expr.setChildExpressions(new VectorExpression [] {v2}); } else if ( (leftExpr instanceof ExprNodeGenericFuncDesc) && (rightExpr instanceof ExprNodeConstantDesc) ) { - VectorExpression v1 = getVectorExpression((ExprNodeGenericFuncDesc) leftExpr); - ExprNodeConstantDesc constDesc = (ExprNodeConstantDesc) leftExpr; + v1 = getVectorExpression((ExprNodeGenericFuncDesc) leftExpr); + ExprNodeConstantDesc constDesc = (ExprNodeConstantDesc) rightExpr; int inputCol1 = v1.getOutputColumn(); String colType1 = v1.getOutputType(); String scalarType = constDesc.getTypeString(); @@ -645,7 +724,7 @@ private VectorExpression getVectorBinaryComparisonFilterExpression(String } else if ( (leftExpr instanceof ExprNodeConstantDesc) && (rightExpr instanceof ExprNodeGenericFuncDesc) ) { ExprNodeConstantDesc constDesc = (ExprNodeConstantDesc) leftExpr; - VectorExpression v2 = getVectorExpression((ExprNodeGenericFuncDesc) rightExpr); + v2 = getVectorExpression((ExprNodeGenericFuncDesc) rightExpr); int inputCol2 = v2.getOutputColumn(); String scalarType = constDesc.getTypeString(); String colType = v2.getOutputType(); @@ -662,8 +741,8 @@ private VectorExpression getVectorBinaryComparisonFilterExpression(String } else { //For comparison expression, the child expressions must be materializing //columns - VectorExpression v1 = getVectorExpression(leftExpr); - VectorExpression v2 = getVectorExpression(rightExpr); + v1 = getVectorExpression(leftExpr); + v2 = getVectorExpression(rightExpr); int inputCol1 = v1.getOutputColumn(); int inputCol2 = v2.getOutputColumn(); String colType1 = v1.getOutputType(); @@ -678,6 +757,12 @@ private VectorExpression getVectorBinaryComparisonFilterExpression(String } expr.setChildExpressions(new VectorExpression [] {v1, v2}); } + if (v1 != null) { + ocm.freeOutputColumn(v1.getOutputColumn()); + } + if (v2 != null) { + ocm.freeOutputColumn(v2.getOutputColumn()); + } return expr; } @@ -774,6 +859,29 @@ private String getBinaryColumnColumnExpressionClassName(String colType1, return b.toString(); } + private String getOutputColType(String inputType1, String inputType2, String method) { + if (method.equalsIgnoreCase("divide") || inputType1.equalsIgnoreCase("double") || + inputType2.equalsIgnoreCase("double")) { + return "double"; + } else { + if (inputType1.equalsIgnoreCase("string") || inputType2.equalsIgnoreCase("string")) { + return "string"; + } else { + return "long"; + } + } + } + + private String getOutputColType(String inputType, String method) { + if (inputType.equalsIgnoreCase("float") || inputType.equalsIgnoreCase("double")) { + return "double"; + } else if (inputType.equalsIgnoreCase("string")) { + return "string"; + } else { + return "long"; + } + } + static Object[][] aggregatesDefinition = { {"min", "Long", VectorUDAFMinLong.class}, {"min", "Double", VectorUDAFMinDouble.class}, @@ -800,8 +908,8 @@ private String getBinaryColumnColumnExpressionClassName(String colType1, {"stddev_samp","Long", VectorUDAFStdSampLong.class}, {"stddev_samp","Double",VectorUDAFStdSampDouble.class}, }; - - public VectorAggregateExpression getAggregatorExpression(AggregationDesc desc) + + public VectorAggregateExpression getAggregatorExpression(AggregationDesc desc) throws HiveException { ArrayList paramDescList = desc.getParameters(); VectorExpression[] vectorParams = new VectorExpression[paramDescList.size()]; @@ -810,7 +918,7 @@ public VectorAggregateExpression getAggregatorExpression(AggregationDesc desc) ExprNodeDesc exprDesc = paramDescList.get(i); vectorParams[i] = this.getVectorExpression(exprDesc); } - + String aggregateName = desc.getGenericUDAFName(); List params = desc.getParameters(); //TODO: handle length != 1 @@ -821,43 +929,41 @@ public VectorAggregateExpression getAggregatorExpression(AggregationDesc desc) for (Object[] aggDef : aggregatesDefinition) { if (aggDef[0].equals (aggregateName) && aggDef[1].equals(inputType)) { - Class aggClass = + Class aggClass = (Class) (aggDef[2]); try { - Constructor ctor = + Constructor ctor = aggClass.getConstructor(VectorExpression.class); VectorAggregateExpression aggExpr = ctor.newInstance(vectorParams[0]); return aggExpr; } // TODO: change to 1.7 syntax when possible - //catch (InvocationTargetException | IllegalAccessException + //catch (InvocationTargetException | IllegalAccessException // | NoSuchMethodException | InstantiationException) catch (Exception e) { - throw new HiveException("Internal exception for vector aggregate : \"" + + throw new HiveException("Internal exception for vector aggregate : \"" + aggregateName + "\" for type: \"" + inputType + "", e); } } } - throw new HiveException("Vector aggregate not implemented: \"" + aggregateName + + throw new HiveException("Vector aggregate not implemented: \"" + aggregateName + "\" for type: \"" + inputType + ""); } - + static Object[][] columnTypes = { {"Double", DoubleColumnVector.class}, {"Long", LongColumnVector.class}, {"String", BytesColumnVector.class}, }; - public VectorizedRowBatch allocateRowBatch(int rowCount) throws HiveException { - VectorizedRowBatch ret = new VectorizedRowBatch(nextOutputColumn, rowCount); - for (int i=0; i < nextOutputColumn; ++i) { - if (false == outputColumnTypes.containsKey(i)) { - continue; - } - String columnTypeName = outputColumnTypes.get(i); + private VectorizedRowBatch allocateRowBatch(int rowCount) throws HiveException { + int columnCount = firstOutputColumnIndex + ocm.getNumOfOutputColumn(); + VectorizedRowBatch ret = new VectorizedRowBatch(columnCount, rowCount); + for (int i=0; i < columnCount; ++i) { + String columnTypeName = ocm.getOutputColumnType(i); for (Object[] columnType: columnTypes) { if (columnTypeName.equalsIgnoreCase((String)columnType[0])) { Class columnTypeClass = (Class)columnType[1]; @@ -883,26 +989,23 @@ public VectorizedRowBatch allocateRowBatch(int rowCount) throws HiveException { {"long", PrimitiveObjectInspectorFactory.writableLongObjectInspector}, }; - public ObjectInspector getVectorRowObjectInspector(List columnNames) throws HiveException { - List oids = new ArrayList(); - for(String columnName: columnNames) { - int columnIndex = columnMap.get(columnName); - String outputType = outputColumnTypes.get(columnIndex); - ObjectInspector oi = null; - for(Object[] moi: mapObjectInspectors) { - if (outputType.equalsIgnoreCase((String) moi[0])) { - oi = (ObjectInspector) moi[1]; - break; - } - } - if (oi == null) { - throw new HiveException(String.format("Unsuported type: %s for column %d:%s", - outputType, columnIndex, columnName)); - } - oids.add(oi); + public Map getOutputColumnTypeMap() { + Map map = new HashMap(); + for (int i = 0; i < ocm.outputColCount; i++) { + String type = ocm.outputColumnsTypes[i]; + map.put(i+this.firstOutputColumnIndex, type); } + return map; + } - return ObjectInspectorFactory.getStandardStructObjectInspector(columnNames, oids); + public ColumnVector allocateColumnVector(String type, int defaultSize) { + if (type.equalsIgnoreCase("double")) { + return new DoubleColumnVector(defaultSize); + } else if (type.equalsIgnoreCase("string")) { + return new BytesColumnVector(defaultSize); + } else { + return new LongColumnVector(defaultSize); + } } } diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/templates/CodeGen.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/templates/CodeGen.java index 9279101..4bdd0c2 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/templates/CodeGen.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/templates/CodeGen.java @@ -260,14 +260,6 @@ private void generate() throws Exception { generateColumnArithmeticColumn(tdesc); } else if (tdesc[0].equals("ColumnUnaryMinus")) { generateColumnUnaryMinus(tdesc); - } else if (tdesc[0].equals("VectorUDAFCount")) { - generateVectorUDAFCount(tdesc); - } else if (tdesc[0].equals("VectorUDAFSum")) { - generateVectorUDAFSum(tdesc); - } else if (tdesc[0].equals("VectorUDAFAvg")) { - generateVectorUDAFAvg(tdesc); - } else if (tdesc[0].equals("VectorUDAFVar")) { - generateVectorUDAFVar(tdesc); } else if (tdesc[0].equals("FilterStringColumnCompareScalar")) { generateFilterStringColumnCompareScalar(tdesc); } else { diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorizationContext.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorizationContext.java new file mode 100644 index 0000000..0ef11e3 --- /dev/null +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorizationContext.java @@ -0,0 +1,145 @@ +package org.apache.hadoop.hive.ql.exec.vector; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.FilterStringColGreaterStringScalar; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.LongColAddLongColumn; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.LongColModuloLongColumn; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.LongColMultiplyLongColumn; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.LongColSubtractLongColumn; +import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc; +import org.apache.hadoop.hive.ql.plan.ExprNodeConstantDesc; +import org.apache.hadoop.hive.ql.plan.ExprNodeDesc; +import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc; +import org.apache.hadoop.hive.ql.plan.api.OperatorType; +import org.apache.hadoop.hive.ql.udf.UDFOPMinus; +import org.apache.hadoop.hive.ql.udf.UDFOPMod; +import org.apache.hadoop.hive.ql.udf.UDFOPMultiply; +import org.apache.hadoop.hive.ql.udf.UDFOPPlus; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDFBridge; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPGreaterThan; +import org.junit.Test; + +public class TestVectorizationContext { + + @Test + public void testArithmeticExpressionVectorization() { + /** + * Create original expression tree for following + * (plus (minus (plus col1 col2) col3) (multiply col4 (mod col5 col6)) ) + */ + GenericUDFBridge udf1 = new GenericUDFBridge("+", true, UDFOPPlus.class); + GenericUDFBridge udf2 = new GenericUDFBridge("-", true, UDFOPMinus.class); + GenericUDFBridge udf3 = new GenericUDFBridge("*", true, UDFOPMultiply.class); + GenericUDFBridge udf4 = new GenericUDFBridge("+", true, UDFOPPlus.class); + GenericUDFBridge udf5 = new GenericUDFBridge("%", true, UDFOPMod.class); + + ExprNodeGenericFuncDesc sumExpr = new ExprNodeGenericFuncDesc(); + sumExpr.setGenericUDF(udf1); + ExprNodeGenericFuncDesc minusExpr = new ExprNodeGenericFuncDesc(); + minusExpr.setGenericUDF(udf2); + ExprNodeGenericFuncDesc multiplyExpr = new ExprNodeGenericFuncDesc(); + multiplyExpr.setGenericUDF(udf3); + ExprNodeGenericFuncDesc sum2Expr = new ExprNodeGenericFuncDesc(); + sum2Expr.setGenericUDF(udf4); + ExprNodeGenericFuncDesc modExpr = new ExprNodeGenericFuncDesc(); + modExpr.setGenericUDF(udf5); + + ExprNodeColumnDesc col1Expr = new ExprNodeColumnDesc(Long.class, "col1", "table", false); + ExprNodeColumnDesc col2Expr = new ExprNodeColumnDesc(Long.class, "col2", "table", false); + ExprNodeColumnDesc col3Expr = new ExprNodeColumnDesc(Long.class, "col3", "table", false); + ExprNodeColumnDesc col4Expr = new ExprNodeColumnDesc(Long.class, "col4", "table", false); + ExprNodeColumnDesc col5Expr = new ExprNodeColumnDesc(Long.class, "col5", "table", false); + ExprNodeColumnDesc col6Expr = new ExprNodeColumnDesc(Long.class, "col6", "table", false); + + List children1 = new ArrayList(2); + List children2 = new ArrayList(2); + List children3 = new ArrayList(2); + List children4 = new ArrayList(2); + List children5 = new ArrayList(2); + + children1.add(minusExpr); + children1.add(multiplyExpr); + sumExpr.setChildExprs(children1); + + children2.add(sum2Expr); + children2.add(col3Expr); + minusExpr.setChildExprs(children2); + + children3.add(col1Expr); + children3.add(col2Expr); + sum2Expr.setChildExprs(children3); + + children4.add(col4Expr); + children4.add(modExpr); + multiplyExpr.setChildExprs(children4); + + children5.add(col5Expr); + children5.add(col6Expr); + modExpr.setChildExprs(children5); + + Map columnMap = new HashMap(); + columnMap.put("col1", 1); + columnMap.put("col2", 2); + columnMap.put("col3", 3); + columnMap.put("col4", 4); + columnMap.put("col5", 5); + columnMap.put("col6", 6); + + //Generate vectorized expression + VectorizationContext vc = new VectorizationContext(columnMap, 6); + + VectorExpression ve = vc.getVectorExpression(sumExpr); + + //Verify vectorized expression + assertTrue(ve instanceof LongColAddLongColumn); + assertEquals(2, ve.getChildExpressions().length); + VectorExpression childExpr1 = ve.getChildExpressions()[0]; + VectorExpression childExpr2 = ve.getChildExpressions()[1]; + assertEquals(6, ve.getOutputColumn()); + + assertTrue(childExpr1 instanceof LongColSubtractLongColumn); + assertEquals(1, childExpr1.getChildExpressions().length); + assertTrue(childExpr1.getChildExpressions()[0] instanceof LongColAddLongColumn); + assertEquals(7, childExpr1.getOutputColumn()); + assertEquals(6, childExpr1.getChildExpressions()[0].getOutputColumn()); + + assertTrue(childExpr2 instanceof LongColMultiplyLongColumn); + assertEquals(1, childExpr2.getChildExpressions().length); + assertTrue(childExpr2.getChildExpressions()[0] instanceof LongColModuloLongColumn); + assertEquals(8, childExpr2.getOutputColumn()); + assertEquals(6, childExpr2.getChildExpressions()[0].getOutputColumn()); + } + + @Test + public void testStringFilterExpressions() { + ExprNodeColumnDesc col1Expr = new ExprNodeColumnDesc(String.class, "col1", "table", false); + ExprNodeConstantDesc constDesc = new ExprNodeConstantDesc("Alpha"); + + GenericUDFOPGreaterThan udf = new GenericUDFOPGreaterThan(); + ExprNodeGenericFuncDesc exprDesc = new ExprNodeGenericFuncDesc(); + exprDesc.setGenericUDF(udf); + List children1 = new ArrayList(2); + children1.add(col1Expr); + children1.add(constDesc); + exprDesc.setChildExprs(children1); + + Map columnMap = new HashMap(); + columnMap.put("col1", 1); + columnMap.put("col2", 2); + + VectorizationContext vc = new VectorizationContext(columnMap, 2); + vc.setOperatorType(OperatorType.FILTER); + + VectorExpression ve = vc.getVectorExpression(exprDesc); + + assertTrue(ve instanceof FilterStringColGreaterStringScalar); + } +} diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorFilterExpressions.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorFilterExpressions.java index 1ac1378..ea7eca0 100644 --- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorFilterExpressions.java +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorFilterExpressions.java @@ -152,7 +152,7 @@ public void testFilterLongColLessLongColumn() { //Basic case lcv0.vector[0] = 10; lcv0.vector[1] = 20; - lcv0.vector[2] = 10; + lcv0.vector[2] = 9; lcv0.vector[3] = 20; lcv0.vector[4] = 10; @@ -162,16 +162,9 @@ public void testFilterLongColLessLongColumn() { lcv1.vector[3] = 10; lcv1.vector[4] = 20; - childExpr.evaluate(vrg); - - assertEquals(20, lcv2.vector[0]); - assertEquals(30, lcv2.vector[1]); - assertEquals(20, lcv2.vector[2]); - assertEquals(30, lcv2.vector[3]); - assertEquals(20, lcv2.vector[4]); - expr.evaluate(vrg); - assertEquals(0, vrg.size); + assertEquals(1, vrg.size); + assertEquals(2, vrg.selected[0]); } }