diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumDecimal.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumDecimal.java index 2d1b71e..8eeaf2b 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumDecimal.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumDecimal.java @@ -29,6 +29,7 @@ import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.plan.AggregationDesc; import org.apache.hadoop.hive.ql.util.JavaDataModel; +import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; @@ -74,6 +75,7 @@ public void reset() { private VectorExpression inputExpression; transient private final Decimal128 scratchDecimal; + transient private HiveDecimalWritable result; public VectorUDAFSumDecimal(VectorExpression inputExpression) { this(); @@ -83,6 +85,7 @@ public VectorUDAFSumDecimal(VectorExpression inputExpression) { public VectorUDAFSumDecimal() { super(); scratchDecimal = new Decimal128(); + result = new HiveDecimalWritable(); } private Aggregation getCurrentAggregationBuffer( @@ -427,7 +430,8 @@ public Object evaluateOutput(AggregationBuffer agg) throws HiveException { return null; } else { - return HiveDecimal.create(myagg.sum.toBigDecimal()); + result.set(HiveDecimal.create(myagg.sum.toBigDecimal())); + return result; } } diff --git ql/src/test/queries/clientpositive/vector_decimal_aggregate.q ql/src/test/queries/clientpositive/vector_decimal_aggregate.q index eb9146e..4678b45 100644 --- ql/src/test/queries/clientpositive/vector_decimal_aggregate.q +++ ql/src/test/queries/clientpositive/vector_decimal_aggregate.q @@ -1,20 +1,36 @@ CREATE TABLE decimal_vgby STORED AS ORC AS - SELECT cdouble, CAST (((cdouble*22.1)/37) AS DECIMAL(20,10)) AS cdecimal1, - CAST (((cdouble*9.3)/13) AS DECIMAL(23,14)) AS cdecimal2, - cint - FROM alltypesorc; + SELECT cdouble, CAST (((cdouble*22.1)/37) AS DECIMAL(20,10)) AS cdecimal1, + CAST (((cdouble*9.3)/13) AS DECIMAL(23,14)) AS cdecimal2, + cint + FROM alltypesorc; SET hive.vectorized.execution.enabled=true; +-- First only do simple aggregations that output primitives only EXPLAIN SELECT cint, - COUNT(cdecimal1), MAX(cdecimal1), MIN(cdecimal1), SUM(cdecimal1), AVG(cdecimal1), STDDEV_POP(cdecimal1), STDDEV_SAMP(cdecimal1), - COUNT(cdecimal2), MAX(cdecimal2), MIN(cdecimal2), SUM(cdecimal2), AVG(cdecimal2), STDDEV_POP(cdecimal2), STDDEV_SAMP(cdecimal2) - FROM decimal_vgby - GROUP BY cint - HAVING COUNT(*) > 1; + COUNT(cdecimal1), MAX(cdecimal1), MIN(cdecimal1), SUM(cdecimal1), + COUNT(cdecimal2), MAX(cdecimal2), MIN(cdecimal2), SUM(cdecimal2) + FROM decimal_vgby + GROUP BY cint + HAVING COUNT(*) > 1; + +SELECT cint, + COUNT(cdecimal1), MAX(cdecimal1), MIN(cdecimal1), SUM(cdecimal1), + COUNT(cdecimal2), MAX(cdecimal2), MIN(cdecimal2), SUM(cdecimal2) + FROM decimal_vgby + GROUP BY cint + HAVING COUNT(*) > 1; + +-- Now add the others... +EXPLAIN SELECT cint, + COUNT(cdecimal1), MAX(cdecimal1), MIN(cdecimal1), SUM(cdecimal1), AVG(cdecimal1), STDDEV_POP(cdecimal1), STDDEV_SAMP(cdecimal1), + COUNT(cdecimal2), MAX(cdecimal2), MIN(cdecimal2), SUM(cdecimal2), AVG(cdecimal2), STDDEV_POP(cdecimal2), STDDEV_SAMP(cdecimal2) + FROM decimal_vgby + GROUP BY cint + HAVING COUNT(*) > 1; SELECT cint, - COUNT(cdecimal1), MAX(cdecimal1), MIN(cdecimal1), SUM(cdecimal1), AVG(cdecimal1), STDDEV_POP(cdecimal1), STDDEV_SAMP(cdecimal1), - COUNT(cdecimal2), MAX(cdecimal2), MIN(cdecimal2), SUM(cdecimal2), AVG(cdecimal2), STDDEV_POP(cdecimal2), STDDEV_SAMP(cdecimal2) - FROM decimal_vgby - GROUP BY cint - HAVING COUNT(*) > 1; \ No newline at end of file + COUNT(cdecimal1), MAX(cdecimal1), MIN(cdecimal1), SUM(cdecimal1), AVG(cdecimal1), STDDEV_POP(cdecimal1), STDDEV_SAMP(cdecimal1), + COUNT(cdecimal2), MAX(cdecimal2), MIN(cdecimal2), SUM(cdecimal2), AVG(cdecimal2), STDDEV_POP(cdecimal2), STDDEV_SAMP(cdecimal2) + FROM decimal_vgby + GROUP BY cint + HAVING COUNT(*) > 1; \ No newline at end of file