From a339f9f4a8db59f8fa2d16ec54e775b2c757b4ce Mon Sep 17 00:00:00 2001 From: Ashutosh Chauhan Date: Tue, 27 Feb 2018 16:51:50 -0800 Subject: [PATCH] HIVE-18793 : Round udf should support variable as second argument --- .../hive/ql/udf/generic/GenericUDFRound.java | 70 +++++++++++++++------- ql/src/test/queries/clientpositive/udf_round.q | 8 ++- ql/src/test/results/clientpositive/udf_round.q.out | 23 +++++++ 3 files changed, 80 insertions(+), 21 deletions(-) diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFRound.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFRound.java index ea7466ea40..36fc27e7c6 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFRound.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFRound.java @@ -76,6 +76,8 @@ private transient PrimitiveCategory inputType; private transient Converter converterFromString; + private transient boolean constantScale = true; + private transient PrimitiveObjectInspector scaleOI; @Override public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException { @@ -95,43 +97,46 @@ public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumen throw new UDFArgumentTypeException(1, "ROUND second argument only takes primitive types, got " + arguments[1].getTypeName()); } - PrimitiveObjectInspector scaleOI = (PrimitiveObjectInspector) arguments[1]; + scaleOI = (PrimitiveObjectInspector) arguments[1]; switch (scaleOI.getPrimitiveCategory()) { case VOID: break; case BYTE: - if (!(scaleOI instanceof WritableConstantByteObjectInspector)) { - throw new UDFArgumentTypeException(1, getFuncName().toUpperCase() + " second argument only takes constant"); + if (scaleOI instanceof WritableConstantByteObjectInspector) { + scale = ((WritableConstantByteObjectInspector)scaleOI).getWritableConstantValue().get(); + } else { + constantScale = false; } - scale = ((WritableConstantByteObjectInspector)scaleOI).getWritableConstantValue().get(); break; case SHORT: - if (!(scaleOI instanceof WritableConstantShortObjectInspector)) { - throw new UDFArgumentTypeException(1, getFuncName().toUpperCase() + " second argument only takes constant"); + if (scaleOI instanceof WritableConstantShortObjectInspector) { + scale = ((WritableConstantShortObjectInspector)scaleOI).getWritableConstantValue().get(); + } else { + constantScale = false; } - scale = ((WritableConstantShortObjectInspector)scaleOI).getWritableConstantValue().get(); break; case INT: - if (!(scaleOI instanceof WritableConstantIntObjectInspector)) { - throw new UDFArgumentTypeException(1, getFuncName().toUpperCase() + " second argument only takes constant"); + if (scaleOI instanceof WritableConstantIntObjectInspector) { + scale = ((WritableConstantIntObjectInspector)scaleOI).getWritableConstantValue().get(); + } else { + constantScale = false; } - scale = ((WritableConstantIntObjectInspector)scaleOI).getWritableConstantValue().get(); break; case LONG: - if (!(scaleOI instanceof WritableConstantLongObjectInspector)) { - throw new UDFArgumentTypeException(1, getFuncName().toUpperCase() - + " second argument only takes constant"); + if (scaleOI instanceof WritableConstantLongObjectInspector) { + long l = ((WritableConstantLongObjectInspector)scaleOI).getWritableConstantValue().get(); + if (l < Integer.MIN_VALUE || l > Integer.MAX_VALUE) { + throw new UDFArgumentException(getFuncName().toUpperCase() + + " scale argument out of allowed range"); + } + scale = (int)l; + } else { + constantScale = false; } - long l = ((WritableConstantLongObjectInspector)scaleOI).getWritableConstantValue().get(); - if (l < Integer.MIN_VALUE || l > Integer.MAX_VALUE) { - throw new UDFArgumentException(getFuncName().toUpperCase() - + " scale argument out of allowed range"); - } - scale = (int)l; break; default: throw new UDFArgumentTypeException(1, getFuncName().toUpperCase() - + " second argument only takes integer constant"); + + " second argument only takes numeric type"); } } @@ -142,6 +147,10 @@ public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumen DecimalTypeInfo inputTypeInfo = (DecimalTypeInfo) inputOI.getTypeInfo(); DecimalTypeInfo typeInfo = getOutputTypeInfo(inputTypeInfo, scale); outputOI = PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(typeInfo); + if (!constantScale) { + throw new UDFArgumentTypeException(1,getFuncName().toUpperCase() + " scale argument for " + + "decimal must be constant"); + } break; case VOID: case BYTE: @@ -209,6 +218,9 @@ public Object evaluate(DeferredObject[] arguments) throws HiveException { } case BYTE: ByteWritable byteWritable = (ByteWritable)inputOI.getPrimitiveWritableObject(input); + if (!constantScale) { + scale = ((Number)scaleOI.getPrimitiveJavaObject(arguments[1].get())).intValue(); + } if (scale >= 0) { return byteWritable; } else { @@ -216,6 +228,9 @@ public Object evaluate(DeferredObject[] arguments) throws HiveException { } case SHORT: ShortWritable shortWritable = (ShortWritable)inputOI.getPrimitiveWritableObject(input); + if (!constantScale) { + scale = ((Number)scaleOI.getPrimitiveJavaObject(arguments[1].get())).intValue(); + } if (scale >= 0) { return shortWritable; } else { @@ -223,6 +238,9 @@ public Object evaluate(DeferredObject[] arguments) throws HiveException { } case INT: IntWritable intWritable = (IntWritable)inputOI.getPrimitiveWritableObject(input); + if (!constantScale) { + scale = ((Number)scaleOI.getPrimitiveJavaObject(arguments[1].get())).intValue(); + } if (scale >= 0) { return intWritable; } else { @@ -230,6 +248,9 @@ public Object evaluate(DeferredObject[] arguments) throws HiveException { } case LONG: LongWritable longWritable = (LongWritable)inputOI.getPrimitiveWritableObject(input); + if (!constantScale) { + scale = ((Number)scaleOI.getPrimitiveJavaObject(arguments[1].get())).intValue(); + } if (scale >= 0) { return longWritable; } else { @@ -237,8 +258,14 @@ public Object evaluate(DeferredObject[] arguments) throws HiveException { } case FLOAT: float f = ((FloatWritable)inputOI.getPrimitiveWritableObject(input)).get(); + if (!constantScale) { + scale = ((Number)scaleOI.getPrimitiveJavaObject(arguments[1].get())).intValue(); + } return new FloatWritable((float)round(f, scale)); case DOUBLE: + if (!constantScale) { + scale = ((Number)scaleOI.getPrimitiveJavaObject(arguments[1].get())).intValue(); + } return round(((DoubleWritable)inputOI.getPrimitiveWritableObject(input)), scale); case STRING: case VARCHAR: @@ -247,6 +274,9 @@ public Object evaluate(DeferredObject[] arguments) throws HiveException { if (doubleValue == null) { return null; } + if (!constantScale) { + scale = ((Number)scaleOI.getPrimitiveJavaObject(arguments[1].get())).intValue(); + } return round(doubleValue, scale); default: throw new UDFArgumentTypeException(0, diff --git a/ql/src/test/queries/clientpositive/udf_round.q b/ql/src/test/queries/clientpositive/udf_round.q index 88b22749a3..2441ff0302 100644 --- a/ql/src/test/queries/clientpositive/udf_round.q +++ b/ql/src/test/queries/clientpositive/udf_round.q @@ -1,5 +1,4 @@ set hive.fetch.task.conversion=more; - DESCRIBE FUNCTION round; DESCRIBE FUNCTION EXTENDED round; @@ -44,3 +43,10 @@ FROM src tablesample (1 rows); SELECT round(1809242.3151111344, 9), round(-1809242.3151111344, 9), round(1809242.3151111344BD, 9), round(-1809242.3151111344BD, 9) FROM src tablesample (1 rows); + +select round(cast(l_suppkey as bigint), l_linenumber * -1 ), + round(l_extendedprice, cast(l_orderkey % 2 as tinyint)), + round(cast(l_suppkey as smallint), (cast(l_linenumber * -1 as tinyint)) %3), + round(cast(l_discount as float), cast(l_partkey % 2 as smallint)), + round(l_suppkey, cast(l_orderkey as bigint) * -1) +from lineitem limit 5; diff --git a/ql/src/test/results/clientpositive/udf_round.q.out b/ql/src/test/results/clientpositive/udf_round.q.out index 456e6ea918..47ff06c089 100644 --- a/ql/src/test/results/clientpositive/udf_round.q.out +++ b/ql/src/test/results/clientpositive/udf_round.q.out @@ -122,3 +122,26 @@ POSTHOOK: type: QUERY POSTHOOK: Input: default@src #### A masked pattern was here #### 1809242.315111134 -1809242.315111134 1809242.315111134 -1809242.315111134 +PREHOOK: query: select round(cast(l_suppkey as bigint), l_linenumber * -1 ), + round(l_extendedprice, cast(l_orderkey % 2 as tinyint)), + round(cast(l_suppkey as smallint), (cast(l_linenumber * -1 as tinyint)) %3), + round(cast(l_discount as float), cast(l_partkey % 2 as smallint)), + round(l_suppkey, cast(l_orderkey as bigint) * -1) +from lineitem limit 5 +PREHOOK: type: QUERY +PREHOOK: Input: default@lineitem +#### A masked pattern was here #### +POSTHOOK: query: select round(cast(l_suppkey as bigint), l_linenumber * -1 ), + round(l_extendedprice, cast(l_orderkey % 2 as tinyint)), + round(cast(l_suppkey as smallint), (cast(l_linenumber * -1 as tinyint)) %3), + round(cast(l_discount as float), cast(l_partkey % 2 as smallint)), + round(l_suppkey, cast(l_orderkey as bigint) * -1) +from lineitem limit 5 +POSTHOOK: type: QUERY +POSTHOOK: Input: default@lineitem +#### A masked pattern was here #### +7710 21168.2 7710 0.0 7710 +7300 45983.2 7300 0.0 7310 +4000 13309.6 3701 0.0 3700 +0 28955.6 4630 0.0 4630 +0 22824.5 1500 0.1 1530 -- 2.14.3 (Apple Git-98)