diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFRound.java ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFRound.java index 36fc27e7c6..62e7d8b485 100644 --- ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFRound.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFRound.java @@ -44,6 +44,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableConstantIntObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableConstantLongObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableConstantShortObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableConstantStringObjectInspector; import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; import org.apache.hadoop.io.FloatWritable; @@ -164,7 +165,12 @@ public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumen case STRING: case VARCHAR: case CHAR: - outputOI = PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(PrimitiveCategory.DOUBLE); + Object obj = ((WritableConstantStringObjectInspector) inputOI).getWritableConstantValue(); + if (obj.toString().contains(".")) { + outputOI = PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(PrimitiveCategory.DOUBLE); + } else { + outputOI = PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(PrimitiveCategory.LONG); + } converterFromString = ObjectInspectorConverters.getConverter(inputOI, outputOI); break; default: @@ -270,14 +276,25 @@ public Object evaluate(DeferredObject[] arguments) throws HiveException { case STRING: case VARCHAR: case CHAR: - DoubleWritable doubleValue = (DoubleWritable) converterFromString.convert(input); - if (doubleValue == null) { - return null; + if (input.toString().contains(".")) { + DoubleWritable value = (DoubleWritable) converterFromString.convert(input); + if (value == null) { + return null; + } + if (!constantScale) { + scale = ((Number)scaleOI.getPrimitiveJavaObject(arguments[1].get())).intValue(); + } + return round(value, scale); + } else { + LongWritable value = (LongWritable) converterFromString.convert(input); + if (value == null) { + return null; + } + if (!constantScale) { + scale = ((Number)scaleOI.getPrimitiveJavaObject(arguments[1].get())).intValue(); + } + return round(value, scale); } - if (!constantScale) { - scale = ((Number)scaleOI.getPrimitiveJavaObject(arguments[1].get())).intValue(); - } - return round(doubleValue, scale); default: throw new UDFArgumentTypeException(0, "Only numeric or string group data types are allowed for ROUND function. Got " @@ -308,6 +325,11 @@ protected DoubleWritable round(DoubleWritable input, int scale) { } } + protected LongWritable round(LongWritable input, int scale) { + long d = input.get(); + return new LongWritable(d); + } + @Override public String getDisplayString(String[] children) { return getStandardDisplayString("round", children); diff --git ql/src/test/queries/clientpositive/udf_round_decimal_integer.q ql/src/test/queries/clientpositive/udf_round_decimal_integer.q new file mode 100644 index 0000000000..611f5deceb --- /dev/null +++ ql/src/test/queries/clientpositive/udf_round_decimal_integer.q @@ -0,0 +1,8 @@ +--! qt:dataset:src +--! qt:dataset:lineitem +set hive.fetch.task.conversion=more; +set hive.explain.user=false; +explain select round(3,4) as r_rstr; +explain select round('3',4) as r_rstr; +explain select round(3.1,4) as r_rstr; +explain select round('3.1',4) as r_rstr; diff --git ql/src/test/results/clientpositive/udf_round_decimal_integer.q.out ql/src/test/results/clientpositive/udf_round_decimal_integer.q.out new file mode 100644 index 0000000000..bb099b196c --- /dev/null +++ ql/src/test/results/clientpositive/udf_round_decimal_integer.q.out @@ -0,0 +1,104 @@ +PREHOOK: query: explain select round(3,4) as r_rstr +PREHOOK: type: QUERY +PREHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +POSTHOOK: query: explain select round(3,4) as r_rstr +POSTHOOK: type: QUERY +POSTHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +STAGE DEPENDENCIES: + Stage-0 is a root stage + +STAGE PLANS: + Stage: Stage-0 + Fetch Operator + limit: -1 + Processor Tree: + TableScan + alias: _dummy_table + Row Limit Per Split: 1 + Statistics: Num rows: 1 Data size: 10 Basic stats: COMPLETE Column stats: COMPLETE + Select Operator + expressions: 3 (type: int) + outputColumnNames: _col0 + Statistics: Num rows: 1 Data size: 4 Basic stats: COMPLETE Column stats: COMPLETE + ListSink + +PREHOOK: query: explain select round('3',4) as r_rstr +PREHOOK: type: QUERY +PREHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +POSTHOOK: query: explain select round('3',4) as r_rstr +POSTHOOK: type: QUERY +POSTHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +STAGE DEPENDENCIES: + Stage-0 is a root stage + +STAGE PLANS: + Stage: Stage-0 + Fetch Operator + limit: -1 + Processor Tree: + TableScan + alias: _dummy_table + Row Limit Per Split: 1 + Statistics: Num rows: 1 Data size: 10 Basic stats: COMPLETE Column stats: COMPLETE + Select Operator + expressions: 3L (type: bigint) + outputColumnNames: _col0 + Statistics: Num rows: 1 Data size: 8 Basic stats: COMPLETE Column stats: COMPLETE + ListSink + +PREHOOK: query: explain select round(3.1,4) as r_rstr +PREHOOK: type: QUERY +PREHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +POSTHOOK: query: explain select round(3.1,4) as r_rstr +POSTHOOK: type: QUERY +POSTHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +STAGE DEPENDENCIES: + Stage-0 is a root stage + +STAGE PLANS: + Stage: Stage-0 + Fetch Operator + limit: -1 + Processor Tree: + TableScan + alias: _dummy_table + Row Limit Per Split: 1 + Statistics: Num rows: 1 Data size: 10 Basic stats: COMPLETE Column stats: COMPLETE + Select Operator + expressions: 3.1 (type: decimal(5,4)) + outputColumnNames: _col0 + Statistics: Num rows: 1 Data size: 112 Basic stats: COMPLETE Column stats: COMPLETE + ListSink + +PREHOOK: query: explain select round('3.1',4) as r_rstr +PREHOOK: type: QUERY +PREHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +POSTHOOK: query: explain select round('3.1',4) as r_rstr +POSTHOOK: type: QUERY +POSTHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +STAGE DEPENDENCIES: + Stage-0 is a root stage + +STAGE PLANS: + Stage: Stage-0 + Fetch Operator + limit: -1 + Processor Tree: + TableScan + alias: _dummy_table + Row Limit Per Split: 1 + Statistics: Num rows: 1 Data size: 10 Basic stats: COMPLETE Column stats: COMPLETE + Select Operator + expressions: 3.1D (type: double) + outputColumnNames: _col0 + Statistics: Num rows: 1 Data size: 8 Basic stats: COMPLETE Column stats: COMPLETE + ListSink +