diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java index f53554c..38269f4 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java @@ -38,6 +38,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.HiveDecimalUtils; import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; @@ -207,8 +208,6 @@ public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveExc super.init(m, parameters); result = new HiveDecimalWritable(HiveDecimal.ZERO); inputOI = (PrimitiveObjectInspector) parameters[0]; - outputOI = (PrimitiveObjectInspector) ObjectInspectorUtils.getStandardObjectInspector(inputOI, - ObjectInspectorCopyOption.JAVA); // The output precision is 10 greater than the input which should cover at least // 10b rows. The scale is the same as the input. DecimalTypeInfo outputTypeInfo = null; @@ -218,7 +217,11 @@ public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveExc } else { outputTypeInfo = (DecimalTypeInfo) inputOI.getTypeInfo(); } - return PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(outputTypeInfo); + ObjectInspector oi = PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(outputTypeInfo); + outputOI = (PrimitiveObjectInspector) ObjectInspectorUtils.getStandardObjectInspector( + oi, ObjectInspectorCopyOption.JAVA); + + return oi; } /** class for storing decimal sum value. */ @@ -287,6 +290,13 @@ public Object terminate(AggregationBuffer agg) throws HiveException { if (myagg.empty || myagg.sum == null) { return null; } + if (myagg.sum != null) { + if (HiveDecimalUtils.enforcePrecisionScale(myagg.sum, (DecimalTypeInfo)outputOI.getTypeInfo()) == null) { + LOG.warn("The sum of a column with data type HiveDecimal is out of range"); + return null; + } + } + result.set(myagg.sum); return result; } diff --git a/ql/src/test/queries/clientpositive/decimal_precision.q b/ql/src/test/queries/clientpositive/decimal_precision.q index 7d77455..e917f20 100644 --- a/ql/src/test/queries/clientpositive/decimal_precision.q +++ b/ql/src/test/queries/clientpositive/decimal_precision.q @@ -27,3 +27,10 @@ SELECT MIN(cast('12345678901234567890.12345678' as decimal(38,18))) FROM DECIMAL SELECT COUNT(cast('12345678901234567890.12345678' as decimal(38,18))) FROM DECIMAL_PRECISION; DROP TABLE DECIMAL_PRECISION; + +-- Expect overflow and return null as the value +CREATE TABLE DECIMAL_PRECISION(dec decimal(38,18)); +INSERT INTO DECIMAL_PRECISION VALUES(98765432109876543210.12345), (98765432109876543210.12345); +SELECT SUM(dec) FROM DECIMAL_PRECISION; + +DROP TABLE DECIMAL_PRECISION; diff --git a/ql/src/test/results/clientpositive/decimal_precision.q.out b/ql/src/test/results/clientpositive/decimal_precision.q.out index cb17e0d..a607d9f 100644 --- a/ql/src/test/results/clientpositive/decimal_precision.q.out +++ b/ql/src/test/results/clientpositive/decimal_precision.q.out @@ -631,3 +631,39 @@ POSTHOOK: query: DROP TABLE DECIMAL_PRECISION POSTHOOK: type: DROPTABLE POSTHOOK: Input: default@decimal_precision POSTHOOK: Output: default@decimal_precision +PREHOOK: query: -- Expect overflow and return null as the value +CREATE TABLE DECIMAL_PRECISION(dec decimal(38,18)) +PREHOOK: type: CREATETABLE +PREHOOK: Output: database:default +PREHOOK: Output: default@DECIMAL_PRECISION +POSTHOOK: query: -- Expect overflow and return null as the value +CREATE TABLE DECIMAL_PRECISION(dec decimal(38,18)) +POSTHOOK: type: CREATETABLE +POSTHOOK: Output: database:default +POSTHOOK: Output: default@DECIMAL_PRECISION +PREHOOK: query: INSERT INTO DECIMAL_PRECISION VALUES(98765432109876543210.12345), (98765432109876543210.12345) +PREHOOK: type: QUERY +PREHOOK: Input: default@values__tmp__table__1 +PREHOOK: Output: default@decimal_precision +POSTHOOK: query: INSERT INTO DECIMAL_PRECISION VALUES(98765432109876543210.12345), (98765432109876543210.12345) +POSTHOOK: type: QUERY +POSTHOOK: Input: default@values__tmp__table__1 +POSTHOOK: Output: default@decimal_precision +POSTHOOK: Lineage: decimal_precision.dec EXPRESSION [(values__tmp__table__1)values__tmp__table__1.FieldSchema(name:tmp_values_col1, type:string, comment:), ] +PREHOOK: query: SELECT SUM(dec) FROM DECIMAL_PRECISION +PREHOOK: type: QUERY +PREHOOK: Input: default@decimal_precision +#### A masked pattern was here #### +POSTHOOK: query: SELECT SUM(dec) FROM DECIMAL_PRECISION +POSTHOOK: type: QUERY +POSTHOOK: Input: default@decimal_precision +#### A masked pattern was here #### +NULL +PREHOOK: query: DROP TABLE DECIMAL_PRECISION +PREHOOK: type: DROPTABLE +PREHOOK: Input: default@decimal_precision +PREHOOK: Output: default@decimal_precision +POSTHOOK: query: DROP TABLE DECIMAL_PRECISION +POSTHOOK: type: DROPTABLE +POSTHOOK: Input: default@decimal_precision +POSTHOOK: Output: default@decimal_precision