diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/gen/VectorUDAFAvgDouble.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/gen/VectorUDAFAvgDouble.java
index 38b14f1..0470b3c 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/gen/VectorUDAFAvgDouble.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/gen/VectorUDAFAvgDouble.java
@@ -117,9 +117,9 @@ public void aggregateInputSelection(
 
     inputExpression.evaluate(batch);
 
-    LongColumnVector inputVector = (LongColumnVector)batch.
+    DoubleColumnVector inputVector = ( DoubleColumnVector)batch.
       cols[this.inputExpression.getOutputColumn()];
-    long[] vector = inputVector.vector;
+    double[] vector = inputVector.vector;
 
     if (inputVector.noNulls) {
       if (inputVector.isRepeating) {
@@ -165,7 +165,7 @@ public void aggregateInputSelection(
   private void iterateNoNullsRepeatingWithAggregationSelection(
     VectorAggregationBufferRow[] aggregationBufferSets,
     int bufferIndex,
-    long value,
+    double value,
     int batchSize) {
 
     for (int i=0; i < batchSize; ++i) {
@@ -180,7 +180,7 @@ private void iterateNoNullsRepeatingWithAggregationSelection(
   private void iterateNoNullsSelectionWithAggregationSelection(
     VectorAggregationBufferRow[] aggregationBufferSets,
     int bufferIndex,
-    long[] values,
+    double[] values,
     int[] selection,
     int batchSize) {
 
@@ -196,7 +196,7 @@ private void iterateNoNullsSelectionWithAggregationSelection(
   private void iterateNoNullsWithAggregationSelection(
     VectorAggregationBufferRow[] aggregationBufferSets,
     int bufferIndex,
-    long[] values,
+    double[] values,
     int batchSize) {
     for (int i=0; i < batchSize; ++i) {
       Aggregation myagg = getCurrentAggregationBuffer(
@@ -210,7 +210,7 @@ private void iterateNoNullsWithAggregationSelection(
   private void iterateHasNullsRepeatingSelectionWithAggregationSelection(
     VectorAggregationBufferRow[] aggregationBufferSets,
     int bufferIndex,
-    long value,
+    double value,
     int batchSize,
     int[] selection,
     boolean[] isNull) {
@@ -230,7 +230,7 @@ private void iterateHasNullsRepeatingSelectionWithAggregationSelection(
   private void iterateHasNullsRepeatingWithAggregationSelection(
     VectorAggregationBufferRow[] aggregationBufferSets,
     int bufferIndex,
-    long value,
+    double value,
     int batchSize,
     boolean[] isNull) {
 
@@ -248,7 +248,7 @@ private void iterateHasNullsRepeatingWithAggregationSelection(
   private void iterateHasNullsSelectionWithAggregationSelection(
     VectorAggregationBufferRow[] aggregationBufferSets,
     int bufferIndex,
-    long[] values,
+    double[] values,
     int batchSize,
     int[] selection,
     boolean[] isNull) {
@@ -268,7 +268,7 @@ private void iterateHasNullsSelectionWithAggregationSelection(
   private void iterateHasNullsWithAggregationSelection(
     VectorAggregationBufferRow[] aggregationBufferSets,
     int bufferIndex,
-    long[] values,
+    double[] values,
     int batchSize,
     boolean[] isNull) {
 
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/gen/VectorUDAFAvgLong.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/gen/VectorUDAFAvgLong.java
index 115444d..e7db9f2 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/gen/VectorUDAFAvgLong.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/gen/VectorUDAFAvgLong.java
@@ -117,7 +117,7 @@ public void aggregateInputSelection(
 
     inputExpression.evaluate(batch);
 
-    LongColumnVector inputVector = (LongColumnVector)batch.
+    LongColumnVector inputVector = ( LongColumnVector)batch.
       cols[this.inputExpression.getOutputColumn()];
 
     long[] vector = inputVector.vector;
 
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/gen/VectorUDAFMaxDouble.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/gen/VectorUDAFMaxDouble.java
index bc7f852..987e756 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/gen/VectorUDAFMaxDouble.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/gen/VectorUDAFMaxDouble.java
@@ -31,7 +31,7 @@
 import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.io.LongWritable;
-import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.io.DoubleWritable;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/gen/VectorUDAFMaxLong.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/gen/VectorUDAFMaxLong.java
index 6ba416e..882d21e 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/gen/VectorUDAFMaxLong.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/gen/VectorUDAFMaxLong.java
@@ -31,7 +31,7 @@
 import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.io.LongWritable;
-import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.io.DoubleWritable;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/gen/VectorUDAFMinDouble.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/gen/VectorUDAFMinDouble.java
index d982fc2..97a8c38 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/gen/VectorUDAFMinDouble.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/gen/VectorUDAFMinDouble.java
@@ -31,7 +31,7 @@
 import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.io.LongWritable;
-import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.io.DoubleWritable;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/gen/VectorUDAFMinLong.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/gen/VectorUDAFMinLong.java
index a8f5531..9d039ff 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/gen/VectorUDAFMinLong.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/gen/VectorUDAFMinLong.java
@@ -31,7 +31,7 @@
 import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.io.LongWritable;
-import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.io.DoubleWritable;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/gen/VectorUDAFSumDouble.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/gen/VectorUDAFSumDouble.java
index a5dac79..fdb7eeb 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/gen/VectorUDAFSumDouble.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/gen/VectorUDAFSumDouble.java
@@ -31,7 +31,7 @@
 import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.io.LongWritable;
-import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.io.DoubleWritable;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
@@ -97,9 +97,9 @@ public void aggregateInputSelection(
 
     inputExpression.evaluate(batch);
 
-    LongColumnVector inputVector = (LongColumnVector)batch.
+    DoubleColumnVector inputVector = (DoubleColumnVector)batch.
       cols[this.inputExpression.getOutputColumn()];
-    long[] vector = inputVector.vector;
+    double[] vector = inputVector.vector;
 
     if (inputVector.noNulls) {
       if (inputVector.isRepeating) {
@@ -145,7 +145,7 @@ public void aggregateInputSelection(
   private void iterateNoNullsRepeatingWithAggregationSelection(
     VectorAggregationBufferRow[] aggregationBufferSets,
     int aggregateIndex,
-    long value,
+    double value,
     int batchSize) {
 
     for (int i=0; i < batchSize; ++i) {
@@ -160,7 +160,7 @@ private void iterateNoNullsRepeatingWithAggregationSelection(
   private void iterateNoNullsSelectionWithAggregationSelection(
     VectorAggregationBufferRow[] aggregationBufferSets,
     int aggregateIndex,
-    long[] values,
+    double[] values,
     int[] selection,
     int batchSize) {
 
@@ -176,7 +176,7 @@ private void iterateNoNullsSelectionWithAggregationSelection(
   private void iterateNoNullsWithAggregationSelection(
     VectorAggregationBufferRow[] aggregationBufferSets,
     int aggregateIndex,
-    long[] values,
+    double[] values,
     int batchSize) {
     for (int i=0; i < batchSize; ++i) {
       Aggregation myagg = getCurrentAggregationBuffer(
@@ -190,7 +190,7 @@ private void iterateNoNullsWithAggregationSelection(
   private void iterateHasNullsRepeatingSelectionWithAggregationSelection(
     VectorAggregationBufferRow[] aggregationBufferSets,
     int aggregateIndex,
-    long value,
+    double value,
     int batchSize,
     int[] selection,
     boolean[] isNull) {
@@ -210,7 +210,7 @@ private void iterateHasNullsRepeatingSelectionWithAggregationSelection(
   private void iterateHasNullsRepeatingWithAggregationSelection(
     VectorAggregationBufferRow[] aggregationBufferSets,
     int aggregateIndex,
-    long value,
+    double value,
     int batchSize,
     boolean[] isNull) {
 
@@ -228,7 +228,7 @@ private void iterateHasNullsRepeatingWithAggregationSelection(
   private void iterateHasNullsSelectionWithAggregationSelection(
     VectorAggregationBufferRow[] aggregationBufferSets,
     int aggregateIndex,
-    long[] values,
+    double[] values,
     int batchSize,
     int[] selection,
     boolean[] isNull) {
@@ -248,7 +248,7 @@ private void iterateHasNullsSelectionWithAggregationSelection(
   private void iterateHasNullsWithAggregationSelection(
     VectorAggregationBufferRow[] aggregationBufferSets,
     int aggregateIndex,
-    long[] values,
+    double[] values,
     int batchSize,
     boolean[] isNull) {
 
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/gen/VectorUDAFSumLong.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/gen/VectorUDAFSumLong.java
index 4d1db3d..8f7bac0 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/gen/VectorUDAFSumLong.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/gen/VectorUDAFSumLong.java
@@ -31,7 +31,7 @@
 import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.io.LongWritable;
-import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.io.DoubleWritable;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/templates/CodeGen.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/templates/CodeGen.java
index 888f9ca..9be0384 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/templates/CodeGen.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/templates/CodeGen.java
@@ -275,8 +275,6 @@ private void generate() throws Exception {
         generateColumnArithmeticColumn(tdesc);
       } else if (tdesc[0].equals("ColumnUnaryMinus")) {
         generateColumnUnaryMinus(tdesc);
-      } else if (tdesc[0].equals("VectorUDAFCount")) {
-        generateVectorUDAFCount(tdesc);
       } else if (tdesc[0].equals("VectorUDAFMinMax")) {
         generateVectorUDAFMinMax(tdesc);
       } else if (tdesc[0].equals("VectorUDAFMinMaxString")) {
@@ -342,22 +340,6 @@ private void generateVectorUDAFMinMaxString(String[] tdesc) throws Exception {
   }
 
-
-  private void generateVectorUDAFCount(String[] tdesc) throws IOException {
-    String className = tdesc[1];
-    String valueType = tdesc[2];
-    String columnType = getColumnVectorType(valueType);
-
-    String outputFile = joinPath(this.outputDirectory, className + ".java");
-    String templateFile = joinPath(this.templateDirectory, tdesc[0] + ".txt");
-
-    String templateString = readFile(templateFile);
-    templateString = templateString.replaceAll("<ClassName>", className);
-    templateString = templateString.replaceAll("<ValueType>", valueType);
-    templateString = templateString.replaceAll("<InputColumnVectorType>", columnType);
-    writeFile(outputFile, templateString);
-  }
-
   private void generateVectorUDAFSum(String[] tdesc) throws Exception {
     //template, <ClassName>, <ValueType>, <OutputType>, <OutputTypeInspector>
     String className = tdesc[1];
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/templates/VectorUDAFAvg.txt ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/templates/VectorUDAFAvg.txt
index 7887ceb..8bd8e76 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/templates/VectorUDAFAvg.txt
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/templates/VectorUDAFAvg.txt
@@ -117,9 +117,9 @@ public class <ClassName> extends VectorAggregateExpression {
 
     inputExpression.evaluate(batch);
 
-    LongColumnVector inputVector = (LongColumnVector)batch.
+    <InputColumnVectorType> inputVector = ( <InputColumnVectorType>)batch.
       cols[this.inputExpression.getOutputColumn()];
-    long[] vector = inputVector.vector;
+    <ValueType>[] vector = inputVector.vector;
 
     if (inputVector.noNulls) {
       if (inputVector.isRepeating) {
@@ -165,7 +165,7 @@ public class <ClassName> extends VectorAggregateExpression {
   private void iterateNoNullsRepeatingWithAggregationSelection(
     VectorAggregationBufferRow[] aggregationBufferSets,
     int bufferIndex,
-    long value,
+    <ValueType> value,
     int batchSize) {
 
     for (int i=0; i < batchSize; ++i) {
@@ -180,7 +180,7 @@ public class <ClassName> extends VectorAggregateExpression {
   private void iterateNoNullsSelectionWithAggregationSelection(
     VectorAggregationBufferRow[] aggregationBufferSets,
     int bufferIndex,
-    long[] values,
+    <ValueType>[] values,
     int[] selection,
     int batchSize) {
 
@@ -196,7 +196,7 @@ public class <ClassName> extends VectorAggregateExpression {
   private void iterateNoNullsWithAggregationSelection(
     VectorAggregationBufferRow[] aggregationBufferSets,
     int bufferIndex,
-    long[] values,
+    <ValueType>[] values,
     int batchSize) {
     for (int i=0; i < batchSize; ++i) {
       Aggregation myagg = getCurrentAggregationBuffer(
@@ -210,7 +210,7 @@ public class <ClassName> extends VectorAggregateExpression {
   private void iterateHasNullsRepeatingSelectionWithAggregationSelection(
     VectorAggregationBufferRow[] aggregationBufferSets,
     int bufferIndex,
-    long value,
+    <ValueType> value,
     int batchSize,
     int[] selection,
     boolean[] isNull) {
@@ -230,7 +230,7 @@ public class <ClassName> extends VectorAggregateExpression {
   private void iterateHasNullsRepeatingWithAggregationSelection(
     VectorAggregationBufferRow[] aggregationBufferSets,
     int bufferIndex,
-    long value,
+    <ValueType> value,
     int batchSize,
     boolean[] isNull) {
 
@@ -248,7 +248,7 @@ public class <ClassName> extends VectorAggregateExpression {
   private void iterateHasNullsSelectionWithAggregationSelection(
     VectorAggregationBufferRow[] aggregationBufferSets,
     int bufferIndex,
-    long[] values,
+    <ValueType>[] values,
     int batchSize,
     int[] selection,
     boolean[] isNull) {
@@ -268,7 +268,7 @@ public class <ClassName> extends VectorAggregateExpression {
   private void iterateHasNullsWithAggregationSelection(
     VectorAggregationBufferRow[] aggregationBufferSets,
     int bufferIndex,
-    long[] values,
+    <ValueType>[] values,
     int batchSize,
     boolean[] isNull) {
 
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/templates/VectorUDAFMinMax.txt ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/templates/VectorUDAFMinMax.txt
index d00d9ae..4c0130d 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/templates/VectorUDAFMinMax.txt
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/templates/VectorUDAFMinMax.txt
@@ -31,7 +31,7 @@ import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector;
 import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.io.LongWritable;
-import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.io.DoubleWritable;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/templates/VectorUDAFSum.txt ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/templates/VectorUDAFSum.txt
index aaaa6ad..34b555f 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/templates/VectorUDAFSum.txt
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/templates/VectorUDAFSum.txt
@@ -31,7 +31,7 @@ import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector;
 import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.io.LongWritable;
-import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.io.DoubleWritable;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
@@ -97,9 +97,9 @@ public class <ClassName> extends VectorAggregateExpression {
 
     inputExpression.evaluate(batch);
 
-    LongColumnVector inputVector = (LongColumnVector)batch.
+    <InputColumnVectorType> inputVector = (<InputColumnVectorType>)batch.
      cols[this.inputExpression.getOutputColumn()];
-    long[] vector = inputVector.vector;
+    <ValueType>[] vector = inputVector.vector;
 
     if (inputVector.noNulls) {
       if (inputVector.isRepeating) {
@@ -145,7 +145,7 @@ public class <ClassName> extends VectorAggregateExpression {
   private void iterateNoNullsRepeatingWithAggregationSelection(
     VectorAggregationBufferRow[] aggregationBufferSets,
     int aggregateIndex,
-    long value,
+    <ValueType> value,
     int batchSize) {
 
     for (int i=0; i < batchSize; ++i) {
@@ -160,7 +160,7 @@ public class <ClassName> extends VectorAggregateExpression {
   private void iterateNoNullsSelectionWithAggregationSelection(
     VectorAggregationBufferRow[] aggregationBufferSets,
     int aggregateIndex,
-    long[] values,
+    <ValueType>[] values,
     int[] selection,
     int batchSize) {
 
@@ -176,7 +176,7 @@ public class <ClassName> extends VectorAggregateExpression {
   private void iterateNoNullsWithAggregationSelection(
     VectorAggregationBufferRow[] aggregationBufferSets,
     int aggregateIndex,
-    long[] values,
+    <ValueType>[] values,
     int batchSize) {
     for (int i=0; i < batchSize; ++i) {
       Aggregation myagg = getCurrentAggregationBuffer(
@@ -190,7 +190,7 @@ public class <ClassName> extends VectorAggregateExpression {
   private void iterateHasNullsRepeatingSelectionWithAggregationSelection(
     VectorAggregationBufferRow[] aggregationBufferSets,
     int aggregateIndex,
-    long value,
+    <ValueType> value,
     int batchSize,
     int[] selection,
     boolean[] isNull) {
@@ -210,7 +210,7 @@ public class <ClassName> extends VectorAggregateExpression {
   private void iterateHasNullsRepeatingWithAggregationSelection(
     VectorAggregationBufferRow[] aggregationBufferSets,
     int aggregateIndex,
-    long value,
+    <ValueType> value,
     int batchSize,
     boolean[] isNull) {
 
@@ -228,7 +228,7 @@ public class <ClassName> extends VectorAggregateExpression {
   private void iterateHasNullsSelectionWithAggregationSelection(
     VectorAggregationBufferRow[] aggregationBufferSets,
     int aggregateIndex,
-    long[] values,
+    <ValueType>[] values,
     int batchSize,
     int[] selection,
     boolean[] isNull) {
@@ -248,7 +248,7 @@ public class <ClassName> extends VectorAggregateExpression {
   private void iterateHasNullsWithAggregationSelection(
     VectorAggregationBufferRow[] aggregationBufferSets,
     int aggregateIndex,
-    long[] values,
+    <ValueType>[] values,
     int batchSize,
     boolean[] isNull) {
 
diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java
index 42cdcf4..ef276fd 100644
--- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java
+++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java
@@ -98,33 +98,14 @@ private static AggregationDesc buildAggregationDescCountStar(
 
   }
 
-  private static GroupByDesc buildGroupByDescLong(
+  private static GroupByDesc buildGroupByDescType(
     VectorizationContext ctx,
     String aggregate,
-    String column) {
-
-    AggregationDesc agg = buildAggregationDesc(ctx, aggregate,
-        column, TypeInfoFactory.longTypeInfo);
-    ArrayList<AggregationDesc> aggs = new ArrayList<AggregationDesc>();
-    aggs.add(agg);
-
-    ArrayList<String> outputColumnNames = new ArrayList<String>();
-    outputColumnNames.add("_col0");
-
-    GroupByDesc desc = new GroupByDesc();
-    desc.setOutputColumnNames(outputColumnNames);
-    desc.setAggregators(aggs);
-
-    return desc;
-  }
-
-  private static GroupByDesc buildGroupByDescString(
-    VectorizationContext ctx,
-    String aggregate,
-    String column) {
+    String column,
+    TypeInfo dataType) {
 
     AggregationDesc agg = buildAggregationDesc(ctx, aggregate,
-        column, TypeInfoFactory.stringTypeInfo);
+        column, dataType);
     ArrayList<AggregationDesc> aggs = new ArrayList<AggregationDesc>();
     aggs.add(agg);
@@ -138,7 +119,6 @@ private static GroupByDesc buildGroupByDescString(
 
     return desc;
   }
 
-
   private static GroupByDesc buildGroupByDescCountStar(
     VectorizationContext ctx) {
@@ -161,21 +141,94 @@ private static GroupByDesc buildKeyGroupByDesc(
     VectorizationContext ctx,
     String aggregate,
     String column,
-    TypeInfo typeInfo,
-    String key) {
+    TypeInfo dataTypeInfo,
+    String key,
+    TypeInfo keyTypeInfo) {
 
-    GroupByDesc desc = buildGroupByDescLong(ctx, aggregate, column);
+    GroupByDesc desc = buildGroupByDescType(ctx, aggregate, column, dataTypeInfo);
 
-    ExprNodeDesc keyExp = buildColumnDesc(ctx, key, typeInfo);
+    ExprNodeDesc keyExp = buildColumnDesc(ctx, key, keyTypeInfo);
 
     ArrayList<ExprNodeDesc> keys = new ArrayList<ExprNodeDesc>();
     keys.add(keyExp);
     desc.setKeys(keys);
 
     return desc;
   }
+
+  @Test
+  public void testDoubleValueTypeSum() throws HiveException {
+    testKeyTypeAggregate(
+        "sum",
+        new FakeVectorRowBatchFromObjectIterables(
+            2,
+            new String[] {"tinyint", "double"},
+            Arrays.asList(new Object[]{   1, null,   1, null}),
+            Arrays.asList(new Object[]{13.0, null, 7.0, 19.0})),
+        buildHashMap((byte)1, 20.0, null, 19.0));
+  }
+
+  @Test
+  public void testDoubleValueTypeCount() throws HiveException {
+    testKeyTypeAggregate(
+        "count",
+        new FakeVectorRowBatchFromObjectIterables(
+            2,
+            new String[] {"tinyint", "double"},
+            Arrays.asList(new Object[]{   1, null,   1, null}),
+            Arrays.asList(new Object[]{13.0, null, 7.0, 19.0})),
+        buildHashMap((byte)1, 2L, null, 1L));
+  }
+
+  @Test
+  public void testDoubleValueTypeAvg() throws HiveException {
+    testKeyTypeAggregate(
+        "avg",
+        new FakeVectorRowBatchFromObjectIterables(
+            2,
+            new String[] {"tinyint", "double"},
+            Arrays.asList(new Object[]{   1, null,   1, null}),
+            Arrays.asList(new Object[]{13.0, null, 7.0, 19.0})),
+        buildHashMap((byte)1, 10.0, null, 19.0));
+  }
+
+  @Test
+  public void testDoubleValueTypeMin() throws HiveException {
+    testKeyTypeAggregate(
+        "min",
+        new FakeVectorRowBatchFromObjectIterables(
+            2,
+            new String[] {"tinyint", "double"},
+            Arrays.asList(new Object[]{   1, null,   1, null}),
+            Arrays.asList(new Object[]{13.0, null, 7.0, 19.0})),
+        buildHashMap((byte)1, 7.0, null, 19.0));
+  }
 
   @Test
-  public void testTinyintKeyTypeAggregate () throws HiveException {
+  public void testDoubleValueTypeMax() throws HiveException {
+    testKeyTypeAggregate(
+        "max",
+        new FakeVectorRowBatchFromObjectIterables(
+            2,
+            new String[] {"tinyint", "double"},
+            Arrays.asList(new Object[]{   1, null,   1, null}),
+            Arrays.asList(new Object[]{13.0, null, 7.0, 19.0})),
+        buildHashMap((byte)1, 13.0, null, 19.0));
+  }
+
+  @Test
+  public void testDoubleValueTypeVariance() throws HiveException {
+    testKeyTypeAggregate(
+        "variance",
+        new FakeVectorRowBatchFromObjectIterables(
+            2,
+            new String[] {"tinyint", "double"},
+            Arrays.asList(new Object[]{   1, null,   1, null}),
+            Arrays.asList(new Object[]{13.0, null, 7.0, 19.0})),
+        buildHashMap((byte)1, 9.0, null, 0.0));
+  }
+
+  @Test
+  public void testTinyintKeyTypeAggregate() throws HiveException {
     testKeyTypeAggregate(
         "sum",
         new FakeVectorRowBatchFromObjectIterables(
@@ -187,7 +240,7 @@ public void testTinyintKeyTypeAggregate () throws HiveException {
   }
 
   @Test
-  public void testSmallintKeyTypeAggregate () throws HiveException {
+  public void testSmallintKeyTypeAggregate() throws HiveException {
     testKeyTypeAggregate(
         "sum",
         new FakeVectorRowBatchFromObjectIterables(
@@ -199,7 +252,7 @@ public void testSmallintKeyTypeAggregate () throws HiveException {
   }
 
   @Test
-  public void testIntKeyTypeAggregate () throws HiveException {
+  public void testIntKeyTypeAggregate() throws HiveException {
     testKeyTypeAggregate(
         "sum",
         new FakeVectorRowBatchFromObjectIterables(
@@ -211,7 +264,7 @@ public void testIntKeyTypeAggregate () throws HiveException {
   }
 
   @Test
-  public void testBigintKeyTypeAggregate () throws HiveException {
+  public void testBigintKeyTypeAggregate() throws HiveException {
     testKeyTypeAggregate(
         "sum",
         new FakeVectorRowBatchFromObjectIterables(
@@ -223,7 +276,7 @@ public void testBigintKeyTypeAggregate () throws HiveException {
   }
 
   @Test
-  public void testBooleanKeyTypeAggregate () throws HiveException {
+  public void testBooleanKeyTypeAggregate() throws HiveException {
     testKeyTypeAggregate(
         "sum",
         new FakeVectorRowBatchFromObjectIterables(
@@ -235,7 +288,7 @@ public void testBooleanKeyTypeAggregate () throws HiveException {
   }
 
   @Test
-  public void testTimestampKeyTypeAggregate () throws HiveException {
+  public void testTimestampKeyTypeAggregate() throws HiveException {
     testKeyTypeAggregate(
         "sum",
         new FakeVectorRowBatchFromObjectIterables(
@@ -247,7 +300,7 @@ public void testTimestampKeyTypeAggregate () throws HiveException {
   }
 
   @Test
-  public void testFloatKeyTypeAggregate () throws HiveException {
+  public void testFloatKeyTypeAggregate() throws HiveException {
     testKeyTypeAggregate(
         "sum",
         new FakeVectorRowBatchFromObjectIterables(
@@ -259,7 +312,7 @@ public void testFloatKeyTypeAggregate () throws HiveException {
   }
 
   @Test
-  public void testDoubleKeyTypeAggregate () throws HiveException {
+  public void testDoubleKeyTypeAggregate() throws HiveException {
     testKeyTypeAggregate(
         "sum",
         new FakeVectorRowBatchFromObjectIterables(
@@ -271,7 +324,7 @@ public void testDoubleKeyTypeAggregate () throws HiveException {
   }
 
   @Test
-  public void testCountStar () throws HiveException {
+  public void testCountStar() throws HiveException {
     testAggregateCountStar(
         2,
         Arrays.asList(new Long[]{13L,null,7L,19L}),
@@ -279,7 +332,7 @@ public void testCountStar () throws HiveException {
   }
 
   @Test
-  public void testCountString () throws HiveException {
+  public void testCountString() throws HiveException {
     testAggregateString(
         "count",
         2,
@@ -288,7 +341,7 @@ public void testCountString () throws HiveException {
   }
 
   @Test
-  public void testMaxString () throws HiveException {
+  public void testMaxString() throws HiveException {
     testAggregateString(
         "max",
         2,
@@ -302,7 +355,7 @@ public void testMaxString () throws HiveException {
   }
 
   @Test
-  public void testMinString () throws HiveException {
+  public void testMinString() throws HiveException {
     testAggregateString(
         "min",
         2,
@@ -316,7 +369,7 @@ public void testMinString () throws HiveException {
   }
 
   @Test
-  public void testMaxNullString () throws HiveException {
+  public void testMaxNullString() throws HiveException {
     testAggregateString(
         "max",
         2,
@@ -330,7 +383,7 @@ public void testMaxNullString () throws HiveException {
   }
 
   @Test
-  public void testCountStringWithNull () throws HiveException {
+  public void testCountStringWithNull() throws HiveException {
     testAggregateString(
         "count",
         2,
@@ -339,7 +392,7 @@ public void testCountStringWithNull () throws HiveException {
   }
 
   @Test
-  public void testCountStringAllNull () throws HiveException {
+  public void testCountStringAllNull() throws HiveException {
     testAggregateString(
         "count",
         4,
@@ -749,6 +802,25 @@ public void testCountLongRepeatConcatValues () throws HiveException {
   }
 
   @Test
+  public void testSumDoubleSimple() throws HiveException {
+    testAggregateDouble(
+        "sum",
+        2,
+        Arrays.asList(new Object[]{13.0,5.0,7.0,19.0}),
+        13.0 + 5.0 + 7.0 + 19.0);
+  }
+
+  @Test
+  public void testSumDoubleGroupByString() throws HiveException {
+    testAggregateDoubleStringKeyAggregate(
+        "sum",
+        4,
+        Arrays.asList(new Object[]{"A", null, "A", null}),
+        Arrays.asList(new Object[]{13.0,5.0,7.0,19.0}),
+        buildHashMap("A", 20.0, null, 24.0));
+  }
+
+  @Test
   public void testSumLongSimple () throws HiveException {
     testAggregateLongAggregate(
         "sum",
@@ -1258,9 +1330,24 @@ public void testAggregateStringKeyAggregate (
         new String[] {"string", "long"},
         list,
         values);
-    testAggregateStringKeyIterable (aggregateName, fdr, expected);
+    testAggregateStringKeyIterable (aggregateName, fdr, TypeInfoFactory.longTypeInfo, expected);
   }
 
+  public void testAggregateDoubleStringKeyAggregate (
+      String aggregateName,
+      int batchSize,
+      Iterable<Object> list,
+      Iterable<Object> values,
+      HashMap<Object, Object> expected) throws HiveException {
+
+    @SuppressWarnings("unchecked")
+    FakeVectorRowBatchFromObjectIterables fdr = new FakeVectorRowBatchFromObjectIterables(
+        batchSize,
+        new String[] {"string", "double"},
+        list,
+        values);
+    testAggregateStringKeyIterable (aggregateName, fdr, TypeInfoFactory.doubleTypeInfo, expected);
+  }
 
   public void testAggregateLongKeyAggregate (
     String aggregateName,
@@ -1286,6 +1373,18 @@ public void testAggregateString (
     testAggregateStringIterable (aggregateName, fdr, expected);
   }
 
+  public void testAggregateDouble (
+      String aggregateName,
+      int batchSize,
+      Iterable<Object> values,
+      Object expected) throws HiveException {
+
+    @SuppressWarnings("unchecked")
+    FakeVectorRowBatchFromObjectIterables fdr = new FakeVectorRowBatchFromObjectIterables(
+        batchSize, new String[] {"double"}, values);
+    testAggregateDoubleIterable (aggregateName, fdr, expected);
+  }
+
   public void testAggregateLongAggregate (
     String aggregateName,
@@ -1330,8 +1429,15 @@ public void validate(Object expected, Object result) {
         BytesWritable bw = (BytesWritable) arr[0];
         String sbw = new String(bw.getBytes());
         assertEquals((String) expected, sbw);
+      } else if (arr[0] instanceof DoubleWritable) {
+        DoubleWritable dw = (DoubleWritable) arr[0];
+        assertEquals ((Double) expected, (Double) dw.get());
+      } else if (arr[0] instanceof Double) {
+        assertEquals ((Double) expected, (Double) arr[0]);
+      } else if (arr[0] instanceof Long) {
+        assertEquals ((Long) expected, (Long) arr[0]);
       } else {
-        Assert.fail("Unsupported result type: " + expected.getClass().getName());
+        Assert.fail("Unsupported result type: " + arr[0].getClass().getName());
       }
     }
   }
@@ -1490,7 +1596,7 @@ public void testAggregateStringIterable (
     mapColumnNames.put("A", 0);
     VectorizationContext ctx = new VectorizationContext(mapColumnNames, 1);
 
-    GroupByDesc desc = buildGroupByDescString (ctx, aggregateName, "A");
+    GroupByDesc desc = buildGroupByDescType (ctx, aggregateName, "A", TypeInfoFactory.stringTypeInfo);
 
     VectorGroupByOperator vgo = new VectorGroupByOperator(ctx, desc);
@@ -1512,6 +1618,35 @@ public void testAggregateStringIterable (
     validator.validate(expected, result);
   }
 
+  public void testAggregateDoubleIterable (
+      String aggregateName,
+      Iterable<VectorizedRowBatch> data,
+      Object expected) throws HiveException {
+    Map<String, Integer> mapColumnNames = new HashMap<String, Integer>();
+    mapColumnNames.put("A", 0);
+    VectorizationContext ctx = new VectorizationContext(mapColumnNames, 1);
+
+    GroupByDesc desc = buildGroupByDescType (ctx, aggregateName, "A", TypeInfoFactory.doubleTypeInfo);
+
+    VectorGroupByOperator vgo = new VectorGroupByOperator(ctx, desc);
+
+    FakeCaptureOutputOperator out = FakeCaptureOutputOperator.addCaptureOutputChild(vgo);
+    vgo.initialize(null, null);
+
+    for (VectorizedRowBatch unit: data) {
+      vgo.process(unit, 0);
+    }
+    vgo.close(false);
+
+    List<Object> outBatchList = out.getCapturedRows();
+    assertNotNull(outBatchList);
+    assertEquals(1, outBatchList.size());
+
+    Object result = outBatchList.get(0);
+
+    Validator validator = getValidator(aggregateName);
+    validator.validate(expected, result);
+  }
 
   public void testAggregateLongIterable (
     String aggregateName,
@@ -1521,7 +1656,7 @@ public void testAggregateLongIterable (
     mapColumnNames.put("A", 0);
     VectorizationContext ctx = new VectorizationContext(mapColumnNames, 1);
 
-    GroupByDesc desc = buildGroupByDescLong (ctx, aggregateName, "A");
+    GroupByDesc desc = buildGroupByDescType(ctx, aggregateName, "A", TypeInfoFactory.longTypeInfo);
 
     VectorGroupByOperator vgo = new VectorGroupByOperator(ctx, desc);
@@ -1554,7 +1689,7 @@ public void testAggregateLongKeyIterable (
     Set<Object> keys = new HashSet<Object>();
 
     GroupByDesc desc = buildKeyGroupByDesc (ctx, aggregateName, "Value",
-        TypeInfoFactory.longTypeInfo, "Key");
+        TypeInfoFactory.longTypeInfo, "Key", TypeInfoFactory.longTypeInfo);
 
     VectorGroupByOperator vgo = new VectorGroupByOperator(ctx, desc);
@@ -1610,6 +1745,7 @@ public void inspectRow(Object row, int tag) throws HiveException {
   public void testAggregateStringKeyIterable (
       String aggregateName,
       Iterable<VectorizedRowBatch> data,
+      TypeInfo dataTypeInfo,
       HashMap<Object, Object> expected) throws HiveException {
     Map<String, Integer> mapColumnNames = new HashMap<String, Integer>();
     mapColumnNames.put("Key", 0);
@@ -1618,7 +1754,7 @@ public void testAggregateStringKeyIterable (
     Set<Object> keys = new HashSet<Object>();
 
     GroupByDesc desc = buildKeyGroupByDesc (ctx, aggregateName, "Value",
-        TypeInfoFactory.stringTypeInfo, "Key");
+        dataTypeInfo, "Key", TypeInfoFactory.stringTypeInfo);
 
     VectorGroupByOperator vgo = new VectorGroupByOperator(ctx, desc);
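
Reviewer note, not part of the patch: the VectorUDAF*.java files under gen/ are emitted by CodeGen.java from the VectorUDAF*.txt templates, so every type-related change above appears twice, once in a template and once per generated variant. The sketch below is a minimal, self-contained illustration of that substitution step. The placeholder names <ClassName>, <ValueType> and <InputColumnVectorType> are reconstructed from the replaceAll calls in the removed generateVectorUDAFCount method and should be read as assumptions about the template vocabulary, not as the exact Hive sources.

// TemplateExpansionSketch.java -- illustrative only; the template text and
// placeholder names are assumptions modeled on this patch, not Hive code.
import java.util.LinkedHashMap;
import java.util.Map;

public class TemplateExpansionSketch {

  // Replace every <Placeholder> occurrence in the template with its binding.
  static String expand(String template, Map<String, String> bindings) {
    String result = template;
    for (Map.Entry<String, String> e : bindings.entrySet()) {
      result = result.replace("<" + e.getKey() + ">", e.getValue());
    }
    return result;
  }

  public static void main(String[] args) {
    // A fragment shaped like the VectorUDAFSum.txt hunk in this patch.
    String template =
        "public class <ClassName> extends VectorAggregateExpression {\n"
      + "  // ...\n"
      + "  <InputColumnVectorType> inputVector = (<InputColumnVectorType>)batch.\n"
      + "      cols[this.inputExpression.getOutputColumn()];\n"
      + "  <ValueType>[] vector = inputVector.vector;\n"
      + "}\n";

    Map<String, String> bindings = new LinkedHashMap<String, String>();
    bindings.put("ClassName", "VectorUDAFSumDouble");
    bindings.put("ValueType", "double");
    bindings.put("InputColumnVectorType", "DoubleColumnVector");

    // Prints the double-specialized variant; binding long/LongColumnVector
    // instead yields the matching VectorUDAFSumLong text.
    System.out.println(expand(template, bindings));
  }
}

Binding long and LongColumnVector instead of double and DoubleColumnVector yields the long variant, which is why VectorUDAFSumDouble.java and VectorUDAFAvgDouble.java flip long[] and LongColumnVector to double[] and DoubleColumnVector wholesale, while the long-typed files only pick up the DoubleWritable import correction that was also applied to the templates. The iterate*WithAggregationSelection method family exists because, as the hunks show, a column vector may be repeating (isRepeating), may carry nulls (noNulls and isNull[]), and may be filtered through a selection index array; generating one specialized loop per combination keeps those checks out of the per-row inner loop.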