diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/gen/VectorUDAFSumDouble.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/gen/VectorUDAFSumDouble.java index a2e8fb3..e9572a2 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/gen/VectorUDAFSumDouble.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/gen/VectorUDAFSumDouble.java @@ -51,7 +51,7 @@ /** /* class for storing the current aggregate value. */ - static private final class Aggregation implements AggregationBuffer { + private static final class Aggregation implements AggregationBuffer { double sum; boolean isNull; @@ -97,9 +97,9 @@ public void aggregateInputSelection( inputExpression.evaluate(batch); - LongColumnVector inputVector = (LongColumnVector)batch. + DoubleColumnVector inputVector = (DoubleColumnVector)batch. cols[this.inputExpression.getOutputColumn()]; - long[] vector = inputVector.vector; + double[] vector = inputVector.vector; if (inputVector.noNulls) { if (inputVector.isRepeating) { @@ -145,7 +145,7 @@ public void aggregateInputSelection( private void iterateNoNullsRepeatingWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int aggregateIndex, - long value, + double value, int batchSize) { for (int i=0; i < batchSize; ++i) { @@ -160,7 +160,7 @@ private void iterateNoNullsRepeatingWithAggregationSelection( private void iterateNoNullsSelectionWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int aggregateIndex, - long[] values, + double[] values, int[] selection, int batchSize) { @@ -176,7 +176,7 @@ private void iterateNoNullsSelectionWithAggregationSelection( private void iterateNoNullsWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int aggregateIndex, - long[] values, + double[] values, int batchSize) { for (int i=0; i < batchSize; ++i) { Aggregation myagg = getCurrentAggregationBuffer( @@ -190,7 +190,7 @@ private void iterateNoNullsWithAggregationSelection( private void iterateHasNullsRepeatingSelectionWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int aggregateIndex, - long value, + double value, int batchSize, int[] selection, boolean[] isNull) { @@ -210,7 +210,7 @@ private void iterateHasNullsRepeatingSelectionWithAggregationSelection( private void iterateHasNullsRepeatingWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int aggregateIndex, - long value, + double value, int batchSize, boolean[] isNull) { @@ -228,7 +228,7 @@ private void iterateHasNullsRepeatingWithAggregationSelection( private void iterateHasNullsSelectionWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int aggregateIndex, - long[] values, + double[] values, int batchSize, int[] selection, boolean[] isNull) { @@ -248,7 +248,7 @@ private void iterateHasNullsSelectionWithAggregationSelection( private void iterateHasNullsWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int aggregateIndex, - long[] values, + double[] values, int batchSize, boolean[] isNull) { diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/templates/CodeGen.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/templates/CodeGen.java index 4d0d309..4e1fd2e 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/templates/CodeGen.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/templates/CodeGen.java @@ -275,8 +275,6 @@ private void generate() throws Exception { generateColumnArithmeticColumn(tdesc); } else if (tdesc[0].equals("ColumnUnaryMinus")) { generateColumnUnaryMinus(tdesc); - } else if (tdesc[0].equals("VectorUDAFCount")) { - generateVectorUDAFCount(tdesc); } else if (tdesc[0].equals("VectorUDAFMinMax")) { generateVectorUDAFMinMax(tdesc); } else if (tdesc[0].equals("VectorUDAFMinMaxString")) { @@ -342,22 +340,6 @@ private void generateVectorUDAFMinMaxString(String[] tdesc) throws Exception { } - - private void generateVectorUDAFCount(String[] tdesc) throws IOException { - String className = tdesc[1]; - String valueType = tdesc[2]; - String columnType = getColumnVectorType(valueType); - - String outputFile = joinPath(this.outputDirectory, className + ".java"); - String templateFile = joinPath(this.templateDirectory, tdesc[0] + ".txt"); - - String templateString = readFile(templateFile); - templateString = templateString.replaceAll("", className); - templateString = templateString.replaceAll("", valueType); - templateString = templateString.replaceAll("", columnType); - writeFile(outputFile, templateString); - } - private void generateVectorUDAFSum(String[] tdesc) throws Exception { //template, , , , String className = tdesc[1]; diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/templates/VectorUDAFSum.txt ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/templates/VectorUDAFSum.txt index aaaa6ad..096aec5 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/templates/VectorUDAFSum.txt +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/templates/VectorUDAFSum.txt @@ -97,9 +97,9 @@ public class extends VectorAggregateExpression { inputExpression.evaluate(batch); - LongColumnVector inputVector = (LongColumnVector)batch. + inputVector = ()batch. cols[this.inputExpression.getOutputColumn()]; - long[] vector = inputVector.vector; + [] vector = inputVector.vector; if (inputVector.noNulls) { if (inputVector.isRepeating) { @@ -145,7 +145,7 @@ public class extends VectorAggregateExpression { private void iterateNoNullsRepeatingWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int aggregateIndex, - long value, + value, int batchSize) { for (int i=0; i < batchSize; ++i) { @@ -160,7 +160,7 @@ public class extends VectorAggregateExpression { private void iterateNoNullsSelectionWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int aggregateIndex, - long[] values, + [] values, int[] selection, int batchSize) { @@ -176,7 +176,7 @@ public class extends VectorAggregateExpression { private void iterateNoNullsWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int aggregateIndex, - long[] values, + [] values, int batchSize) { for (int i=0; i < batchSize; ++i) { Aggregation myagg = getCurrentAggregationBuffer( @@ -190,7 +190,7 @@ public class extends VectorAggregateExpression { private void iterateHasNullsRepeatingSelectionWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int aggregateIndex, - long value, + value, int batchSize, int[] selection, boolean[] isNull) { @@ -210,7 +210,7 @@ public class extends VectorAggregateExpression { private void iterateHasNullsRepeatingWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int aggregateIndex, - long value, + value, int batchSize, boolean[] isNull) { @@ -228,7 +228,7 @@ public class extends VectorAggregateExpression { private void iterateHasNullsSelectionWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int aggregateIndex, - long[] values, + [] values, int batchSize, int[] selection, boolean[] isNull) { @@ -248,7 +248,7 @@ public class extends VectorAggregateExpression { private void iterateHasNullsWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int aggregateIndex, - long[] values, + [] values, int batchSize, boolean[] isNull) { diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java index 9e6372f..df81eaf 100644 --- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java @@ -90,33 +90,14 @@ private static AggregationDesc buildAggregationDescCountStar( } - private static GroupByDesc buildGroupByDescLong( + private static GroupByDesc buildGroupByDescType( VectorizationContext ctx, String aggregate, - String column) { - - AggregationDesc agg = buildAggregationDesc(ctx, aggregate, - column, TypeInfoFactory.longTypeInfo); - ArrayList aggs = new ArrayList(); - aggs.add(agg); - - ArrayList outputColumnNames = new ArrayList(); - outputColumnNames.add("_col0"); - - GroupByDesc desc = new GroupByDesc(); - desc.setOutputColumnNames(outputColumnNames); - desc.setAggregators(aggs); - - return desc; - } - - private static GroupByDesc buildGroupByDescString( - VectorizationContext ctx, - String aggregate, - String column) { + String column, + TypeInfo dataType) { AggregationDesc agg = buildAggregationDesc(ctx, aggregate, - column, TypeInfoFactory.stringTypeInfo); + column, dataType); ArrayList aggs = new ArrayList(); aggs.add(agg); @@ -130,7 +111,6 @@ private static GroupByDesc buildGroupByDescString( return desc; } - private static GroupByDesc buildGroupByDescCountStar( VectorizationContext ctx) { @@ -153,12 +133,13 @@ private static GroupByDesc buildKeyGroupByDesc( VectorizationContext ctx, String aggregate, String column, - TypeInfo typeInfo, - String key) { + TypeInfo dataTypeInfo, + String key, + TypeInfo keyTypeInfo) { - GroupByDesc desc = buildGroupByDescLong(ctx, aggregate, column); + GroupByDesc desc = buildGroupByDescType(ctx, aggregate, column, dataTypeInfo); - ExprNodeDesc keyExp = buildColumnDesc(ctx, key, typeInfo); + ExprNodeDesc keyExp = buildColumnDesc(ctx, key, keyTypeInfo); ArrayList keys = new ArrayList(); keys.add(keyExp); desc.setKeys(keys); @@ -645,6 +626,25 @@ public void testCountLongRepeatConcatValues () throws HiveException { } @Test + public void testSumDoubleSimple() throws HiveException { + testAggregateDouble( + "sum", + 2, + Arrays.asList(new Object[]{13.0,5.0,7.0,19.0}), + 13.0 + 5.0 + 7.0 + 19.0); + } + + @Test + public void testSumDoubleGroupByString() throws HiveException { + testAggregateDoubleStringKeyAggregate( + "sum", + 4, + Arrays.asList(new Object[]{"A", null, "A", null}), + Arrays.asList(new Object[]{13.0,5.0,7.0,19.0}), + buildHashMap("A", 20.0, null, 24.0)); + } + + @Test public void testSumLongSimple () throws HiveException { testAggregateLongAggregate( "sum", @@ -1048,9 +1048,24 @@ public void testAggregateStringKeyAggregate ( new String[] {"string", "long"}, list, values); - testAggregateStringKeyIterable (aggregateName, fdr, expected); + testAggregateStringKeyIterable (aggregateName, fdr, TypeInfoFactory.longTypeInfo, expected); } + public void testAggregateDoubleStringKeyAggregate ( + String aggregateName, + int batchSize, + Iterable list, + Iterable values, + HashMap expected) throws HiveException { + + @SuppressWarnings("unchecked") + FakeVectorRowBatchFromObjectIterables fdr = new FakeVectorRowBatchFromObjectIterables( + batchSize, + new String[] {"string", "double"}, + list, + values); + testAggregateStringKeyIterable (aggregateName, fdr, TypeInfoFactory.doubleTypeInfo, expected); + } public void testAggregateLongKeyAggregate ( String aggregateName, @@ -1076,6 +1091,18 @@ public void testAggregateString ( testAggregateStringIterable (aggregateName, fdr, expected); } + public void testAggregateDouble ( + String aggregateName, + int batchSize, + Iterable values, + Object expected) throws HiveException { + + @SuppressWarnings("unchecked") + FakeVectorRowBatchFromObjectIterables fdr = new FakeVectorRowBatchFromObjectIterables( + batchSize, new String[] {"double"}, values); + testAggregateDoubleIterable (aggregateName, fdr, expected); + } + public void testAggregateLongAggregate ( String aggregateName, @@ -1120,6 +1147,9 @@ public void validate(Object expected, Object result) { BytesWritable bw = (BytesWritable) arr[0]; String sbw = new String(bw.getBytes()); assertEquals((String) expected, sbw); + } else if (arr[0] instanceof DoubleWritable) { + DoubleWritable dw = (DoubleWritable) arr[0]; + assertEquals ((Double) expected, (Double) dw.get()); } else { Assert.fail("Unsupported result type: " + expected.getClass().getName()); } @@ -1280,7 +1310,7 @@ public void testAggregateStringIterable ( mapColumnNames.put("A", 0); VectorizationContext ctx = new VectorizationContext(mapColumnNames, 1); - GroupByDesc desc = buildGroupByDescString (ctx, aggregateName, "A"); + GroupByDesc desc = buildGroupByDescType (ctx, aggregateName, "A", TypeInfoFactory.stringTypeInfo); VectorGroupByOperator vgo = new VectorGroupByOperator(ctx, desc); @@ -1302,6 +1332,35 @@ public void testAggregateStringIterable ( validator.validate(expected, result); } + public void testAggregateDoubleIterable ( + String aggregateName, + Iterable data, + Object expected) throws HiveException { + Map mapColumnNames = new HashMap(); + mapColumnNames.put("A", 0); + VectorizationContext ctx = new VectorizationContext(mapColumnNames, 1); + + GroupByDesc desc = buildGroupByDescType (ctx, aggregateName, "A", TypeInfoFactory.doubleTypeInfo); + + VectorGroupByOperator vgo = new VectorGroupByOperator(ctx, desc); + + FakeCaptureOutputOperator out = FakeCaptureOutputOperator.addCaptureOutputChild(vgo); + vgo.initialize(null, null); + + for (VectorizedRowBatch unit: data) { + vgo.process(unit, 0); + } + vgo.close(false); + + List outBatchList = out.getCapturedRows(); + assertNotNull(outBatchList); + assertEquals(1, outBatchList.size()); + + Object result = outBatchList.get(0); + + Validator validator = getValidator(aggregateName); + validator.validate(expected, result); + } public void testAggregateLongIterable ( String aggregateName, @@ -1311,7 +1370,7 @@ public void testAggregateLongIterable ( mapColumnNames.put("A", 0); VectorizationContext ctx = new VectorizationContext(mapColumnNames, 1); - GroupByDesc desc = buildGroupByDescLong (ctx, aggregateName, "A"); + GroupByDesc desc = buildGroupByDescType(ctx, aggregateName, "A", TypeInfoFactory.longTypeInfo); VectorGroupByOperator vgo = new VectorGroupByOperator(ctx, desc); @@ -1344,7 +1403,7 @@ public void testAggregateLongKeyIterable ( Set keys = new HashSet(); GroupByDesc desc = buildKeyGroupByDesc (ctx, aggregateName, "Value", - TypeInfoFactory.longTypeInfo, "Key"); + TypeInfoFactory.longTypeInfo, "Key", TypeInfoFactory.longTypeInfo); VectorGroupByOperator vgo = new VectorGroupByOperator(ctx, desc); @@ -1400,6 +1459,7 @@ public void inspectRow(Object row, int tag) throws HiveException { public void testAggregateStringKeyIterable ( String aggregateName, Iterable data, + TypeInfo dataTypeInfo, HashMap expected) throws HiveException { Map mapColumnNames = new HashMap(); mapColumnNames.put("Key", 0); @@ -1408,7 +1468,7 @@ public void testAggregateStringKeyIterable ( Set keys = new HashSet(); GroupByDesc desc = buildKeyGroupByDesc (ctx, aggregateName, "Value", - TypeInfoFactory.stringTypeInfo, "Key"); + dataTypeInfo, "Key", TypeInfoFactory.stringTypeInfo); VectorGroupByOperator vgo = new VectorGroupByOperator(ctx, desc);