diff --git data/files/round.txt data/files/round.txt new file mode 100644 index 0000000..eb269b2 --- /dev/null +++ data/files/round.txt @@ -0,0 +1,8 @@ +1809242.3151111344,9 +-1809242.3151111344,9 +3.141592653589793,16 +3.141592653589793,-2 +3.141592653589793,-15 +12345,-4 +12345,0 +12345 \ No newline at end of file diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFRound.java ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFRound.java index ae81fe3..214b62b 100644 --- ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFRound.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFRound.java @@ -33,6 +33,7 @@ import org.apache.hadoop.hive.serde2.io.DoubleWritable; import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; import org.apache.hadoop.hive.serde2.io.ShortWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ConstantObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters; @@ -76,6 +77,9 @@ private transient PrimitiveCategory inputType; private transient Converter converterFromString; + private transient Converter inputScaleConverter; + private transient PrimitiveObjectInspector inputScaleOI; + private transient boolean inputScaleConst; @Override public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException { @@ -95,43 +99,71 @@ public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumen throw new UDFArgumentTypeException(1, "ROUND second argument only takes primitive types, got " + arguments[1].getTypeName()); } - PrimitiveObjectInspector scaleOI = (PrimitiveObjectInspector) arguments[1]; - switch (scaleOI.getPrimitiveCategory()) { + inputScaleOI = (PrimitiveObjectInspector) arguments[1]; + inputScaleConst = arguments[1] instanceof ConstantObjectInspector; + switch (inputScaleOI.getPrimitiveCategory()) { case VOID: break; case BYTE: - if (!(scaleOI instanceof WritableConstantByteObjectInspector)) { - throw new UDFArgumentTypeException(1, getFuncName().toUpperCase() + " second argument only takes constant"); + if (inputScaleConst) { + if (!(inputScaleOI instanceof WritableConstantByteObjectInspector)) { + throw new UDFArgumentTypeException(1, + getFuncName().toUpperCase() + " second argument only takes constant"); + } + scale = + ((WritableConstantByteObjectInspector) inputScaleOI).getWritableConstantValue().get(); + } else { + inputScaleConverter = ObjectInspectorConverters.getConverter(arguments[1], + PrimitiveObjectInspectorFactory.writableByteObjectInspector); } - scale = ((WritableConstantByteObjectInspector)scaleOI).getWritableConstantValue().get(); break; case SHORT: - if (!(scaleOI instanceof WritableConstantShortObjectInspector)) { - throw new UDFArgumentTypeException(1, getFuncName().toUpperCase() + " second argument only takes constant"); + if (inputScaleConst) { + if (!(inputScaleOI instanceof WritableConstantShortObjectInspector)) { + throw new UDFArgumentTypeException(1, + getFuncName().toUpperCase() + " second argument only takes constant"); + } + scale = ((WritableConstantShortObjectInspector) inputScaleOI).getWritableConstantValue() + .get(); + } else { + inputScaleConverter = ObjectInspectorConverters.getConverter(arguments[1], + PrimitiveObjectInspectorFactory.writableShortObjectInspector); } - scale = ((WritableConstantShortObjectInspector)scaleOI).getWritableConstantValue().get(); break; case INT: - if (!(scaleOI instanceof WritableConstantIntObjectInspector)) { - throw new UDFArgumentTypeException(1, getFuncName().toUpperCase() + " second argument only takes constant"); + if (inputScaleConst) { + if (!(inputScaleOI instanceof WritableConstantIntObjectInspector)) { + throw new UDFArgumentTypeException(1, + getFuncName().toUpperCase() + " second argument only takes constant"); + } + scale = + ((WritableConstantIntObjectInspector) inputScaleOI).getWritableConstantValue().get(); + } else { + inputScaleConverter = ObjectInspectorConverters.getConverter(arguments[1], + PrimitiveObjectInspectorFactory.writableIntObjectInspector); } - scale = ((WritableConstantIntObjectInspector)scaleOI).getWritableConstantValue().get(); break; case LONG: - if (!(scaleOI instanceof WritableConstantLongObjectInspector)) { - throw new UDFArgumentTypeException(1, getFuncName().toUpperCase() - + " second argument only takes constant"); + if (inputScaleConst) { + if (!(inputScaleOI instanceof WritableConstantLongObjectInspector)) { + throw new UDFArgumentTypeException(1, + getFuncName().toUpperCase() + " second argument only takes constant"); + } + long l = + ((WritableConstantLongObjectInspector) inputScaleOI).getWritableConstantValue().get(); + if (l < Integer.MIN_VALUE || l > Integer.MAX_VALUE) { + throw new UDFArgumentException( + getFuncName().toUpperCase() + " scale argument out of allowed range"); + } + scale = (int) l; + } else { + inputScaleConverter = ObjectInspectorConverters.getConverter(arguments[1], + PrimitiveObjectInspectorFactory.writableLongObjectInspector); } - long l = ((WritableConstantLongObjectInspector)scaleOI).getWritableConstantValue().get(); - if (l < Integer.MIN_VALUE || l > Integer.MAX_VALUE) { - throw new UDFArgumentException(getFuncName().toUpperCase() - + " scale argument out of allowed range"); - } - scale = (int)l; break; default: - throw new UDFArgumentTypeException(1, getFuncName().toUpperCase() - + " second argument only takes integer constant"); + throw new UDFArgumentTypeException(1, + getFuncName().toUpperCase() + " second argument only takes integer constant"); } } @@ -184,7 +216,8 @@ private static DecimalTypeInfo getOutputTypeInfo(DecimalTypeInfo inputTypeInfo, @Override public Object evaluate(DeferredObject[] arguments) throws HiveException { - if (arguments.length == 2 && (arguments[1] == null || arguments[1].get() == null)) { + if (arguments.length == 2 && inputScaleConst + && (arguments[1] == null || arguments[1].get() == null)) { return null; } @@ -196,6 +229,35 @@ public Object evaluate(DeferredObject[] arguments) throws HiveException { if (input == null) { return null; } + + if (arguments.length == 2 && arguments[1] != null && arguments[1].get() != null + && !inputScaleConst) { + Object scaleObj = null; + switch (inputScaleOI.getPrimitiveCategory()) { + case BYTE: + scaleObj = inputScaleConverter.convert(arguments[1].get()); + scale = ((ByteWritable) scaleObj).get(); + break; + case SHORT: + scaleObj = inputScaleConverter.convert(arguments[1].get()); + scale = ((ShortWritable) scaleObj).get(); + break; + case INT: + scaleObj = inputScaleConverter.convert(arguments[1].get()); + scale = ((IntWritable) scaleObj).get(); + break; + case LONG: + scaleObj = inputScaleConverter.convert(arguments[1].get()); + long l = ((LongWritable) scaleObj).get(); + if (l < Integer.MIN_VALUE || l > Integer.MAX_VALUE) { + throw new UDFArgumentException( + getFuncName().toUpperCase() + " scale argument out of allowed range"); + } + scale = (int) l; + default: + break; + } + } switch (inputType) { case VOID: diff --git ql/src/test/queries/clientpositive/udf_round.q ql/src/test/queries/clientpositive/udf_round.q index 88b2274..3cc53ab 100644 --- ql/src/test/queries/clientpositive/udf_round.q +++ ql/src/test/queries/clientpositive/udf_round.q @@ -44,3 +44,14 @@ FROM src tablesample (1 rows); SELECT round(1809242.3151111344, 9), round(-1809242.3151111344, 9), round(1809242.3151111344BD, 9), round(-1809242.3151111344BD, 9) FROM src tablesample (1 rows); + +DROP TABLE sampletable; + +CREATE TABLE sampletable(c DOUBLE, d INT) +ROW FORMAT DELIMITED +FIELDS TERMINATED BY ',' +STORED AS TEXTFILE; + +LOAD DATA LOCAL INPATH '../../data/files/round.txt' INTO TABLE sampletable; + +select round(c,d) from sampletable; \ No newline at end of file diff --git ql/src/test/results/clientpositive/udf_round.q.out ql/src/test/results/clientpositive/udf_round.q.out index c80f821..ce5e895 100644 --- ql/src/test/results/clientpositive/udf_round.q.out +++ ql/src/test/results/clientpositive/udf_round.q.out @@ -120,3 +120,45 @@ POSTHOOK: type: QUERY POSTHOOK: Input: default@src #### A masked pattern was here #### 1809242.315111134 -1809242.315111134 1809242.315111134 -1809242.315111134 +PREHOOK: query: DROP TABLE sampletable +PREHOOK: type: DROPTABLE +POSTHOOK: query: DROP TABLE sampletable +POSTHOOK: type: DROPTABLE +PREHOOK: query: CREATE TABLE sampletable(c DOUBLE, d INT) +ROW FORMAT DELIMITED +FIELDS TERMINATED BY ',' +STORED AS TEXTFILE +PREHOOK: type: CREATETABLE +PREHOOK: Output: database:default +PREHOOK: Output: default@sampletable +POSTHOOK: query: CREATE TABLE sampletable(c DOUBLE, d INT) +ROW FORMAT DELIMITED +FIELDS TERMINATED BY ',' +STORED AS TEXTFILE +POSTHOOK: type: CREATETABLE +POSTHOOK: Output: database:default +POSTHOOK: Output: default@sampletable +PREHOOK: query: LOAD DATA LOCAL INPATH '../../data/files/round.txt' INTO TABLE sampletable +PREHOOK: type: LOAD +#### A masked pattern was here #### +PREHOOK: Output: default@sampletable +POSTHOOK: query: LOAD DATA LOCAL INPATH '../../data/files/round.txt' INTO TABLE sampletable +POSTHOOK: type: LOAD +#### A masked pattern was here #### +POSTHOOK: Output: default@sampletable +PREHOOK: query: select round(c,d) from sampletable +PREHOOK: type: QUERY +PREHOOK: Input: default@sampletable +#### A masked pattern was here #### +POSTHOOK: query: select round(c,d) from sampletable +POSTHOOK: type: QUERY +POSTHOOK: Input: default@sampletable +#### A masked pattern was here #### +1809242.315111134 +-1809242.315111134 +3.141592653589793 +0.0 +0.0 +10000.0 +12345.0 +12345.0