diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFRound.java ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFRound.java index 36fc27e7c6..39b03805b6 100644 --- ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFRound.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFRound.java @@ -39,11 +39,13 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorConverter.DoubleConverter; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableConstantByteObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableConstantIntObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableConstantLongObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableConstantShortObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableConstantStringObjectInspector; import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; import org.apache.hadoop.io.FloatWritable; @@ -65,8 +67,8 @@ * */ @Description(name = "round", - value = "_FUNC_(x[, d]) - round x to d decimal places", - extended = "Example:\n" + value = "_FUNC_(x[, d]) - round x to d decimal places", + extended = "Example:\n" + " > SELECT _FUNC_(12.3456, 1) FROM src LIMIT 1;\n" + " 12.3'") @VectorizedExpressions({FuncRoundDoubleToDouble.class, RoundWithNumDigitsDoubleToDouble.class, FuncRoundWithNumDigitsDecimalToDecimal.class, FuncRoundDecimalToDecimal.class}) @@ -117,7 +119,7 @@ public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumen break; case INT: if (scaleOI instanceof WritableConstantIntObjectInspector) { - scale = ((WritableConstantIntObjectInspector)scaleOI).getWritableConstantValue().get(); + scale = ((WritableConstantIntObjectInspector)scaleOI).getWritableConstantValue().get(); } else { constantScale = false; } @@ -148,7 +150,7 @@ public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumen DecimalTypeInfo typeInfo = getOutputTypeInfo(inputTypeInfo, scale); outputOI = PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(typeInfo); if (!constantScale) { - throw new UDFArgumentTypeException(1,getFuncName().toUpperCase() + " scale argument for " + throw new UDFArgumentTypeException(1, getFuncName().toUpperCase() + " scale argument for " + "decimal must be constant"); } break; @@ -164,7 +166,16 @@ public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumen case STRING: case VARCHAR: case CHAR: - outputOI = PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(PrimitiveCategory.DOUBLE); + if (inputOI instanceof WritableConstantStringObjectInspector) { + Object obj = ((WritableConstantStringObjectInspector) inputOI).getWritableConstantValue(); + if (obj.toString().contains(".")) { + outputOI = PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(PrimitiveCategory.DOUBLE); + } else { + outputOI = PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(PrimitiveCategory.LONG); + } + } else { + outputOI = PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(PrimitiveCategory.DOUBLE); + } converterFromString = ObjectInspectorConverters.getConverter(inputOI, outputOI); break; default: @@ -210,12 +221,10 @@ public Object evaluate(DeferredObject[] arguments) throws HiveException { case VOID: return null; case DECIMAL: - { - // The getPrimitiveWritableObject method returns a new writable. - HiveDecimalWritable decimalWritable = (HiveDecimalWritable) inputOI.getPrimitiveWritableObject(input); - // Call the different round flavor. - return round(decimalWritable, scale); - } + // The getPrimitiveWritableObject method returns a new writable. + HiveDecimalWritable decimalWritable = (HiveDecimalWritable) inputOI.getPrimitiveWritableObject(input); + // Call the different round flavor. + return round(decimalWritable, scale); case BYTE: ByteWritable byteWritable = (ByteWritable)inputOI.getPrimitiveWritableObject(input); if (!constantScale) { @@ -262,23 +271,34 @@ public Object evaluate(DeferredObject[] arguments) throws HiveException { scale = ((Number)scaleOI.getPrimitiveJavaObject(arguments[1].get())).intValue(); } return new FloatWritable((float)round(f, scale)); - case DOUBLE: - if (!constantScale) { - scale = ((Number)scaleOI.getPrimitiveJavaObject(arguments[1].get())).intValue(); - } - return round(((DoubleWritable)inputOI.getPrimitiveWritableObject(input)), scale); + case DOUBLE: + if (!constantScale) { + scale = ((Number)scaleOI.getPrimitiveJavaObject(arguments[1].get())).intValue(); + } + return round(((DoubleWritable)inputOI.getPrimitiveWritableObject(input)), scale); case STRING: case VARCHAR: case CHAR: - DoubleWritable doubleValue = (DoubleWritable) converterFromString.convert(input); - if (doubleValue == null) { - return null; - } - if (!constantScale) { - scale = ((Number)scaleOI.getPrimitiveJavaObject(arguments[1].get())).intValue(); - } - return round(doubleValue, scale); - default: + if (converterFromString instanceof DoubleConverter) { + DoubleWritable value = (DoubleWritable) converterFromString.convert(input); + if (value == null) { + return null; + } + if (!constantScale) { + scale = ((Number)scaleOI.getPrimitiveJavaObject(arguments[1].get())).intValue(); + } + return round(value, scale); + } else { + LongWritable value = (LongWritable) converterFromString.convert(input); + if (value == null) { + return null; + } + if (!constantScale) { + scale = ((Number)scaleOI.getPrimitiveJavaObject(arguments[1].get())).intValue(); + } + return round(value, scale); + } + default: throw new UDFArgumentTypeException(0, "Only numeric or string group data types are allowed for ROUND function. Got " + inputType.name()); @@ -308,6 +328,11 @@ protected DoubleWritable round(DoubleWritable input, int scale) { } } + protected LongWritable round(LongWritable input, int scale) { + long d = input.get(); + return new LongWritable(d); + } + @Override public String getDisplayString(String[] children) { return getStandardDisplayString("round", children); diff --git ql/src/test/queries/clientpositive/udf_round_decimal_integer.q ql/src/test/queries/clientpositive/udf_round_decimal_integer.q new file mode 100644 index 0000000000..611f5deceb --- /dev/null +++ ql/src/test/queries/clientpositive/udf_round_decimal_integer.q @@ -0,0 +1,8 @@ +--! qt:dataset:src +--! qt:dataset:lineitem +set hive.fetch.task.conversion=more; +set hive.explain.user=false; +explain select round(3,4) as r_rstr; +explain select round('3',4) as r_rstr; +explain select round(3.1,4) as r_rstr; +explain select round('3.1',4) as r_rstr; diff --git ql/src/test/results/clientpositive/udf_round_decimal_integer.q.out ql/src/test/results/clientpositive/udf_round_decimal_integer.q.out new file mode 100644 index 0000000000..bb099b196c --- /dev/null +++ ql/src/test/results/clientpositive/udf_round_decimal_integer.q.out @@ -0,0 +1,104 @@ +PREHOOK: query: explain select round(3,4) as r_rstr +PREHOOK: type: QUERY +PREHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +POSTHOOK: query: explain select round(3,4) as r_rstr +POSTHOOK: type: QUERY +POSTHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +STAGE DEPENDENCIES: + Stage-0 is a root stage + +STAGE PLANS: + Stage: Stage-0 + Fetch Operator + limit: -1 + Processor Tree: + TableScan + alias: _dummy_table + Row Limit Per Split: 1 + Statistics: Num rows: 1 Data size: 10 Basic stats: COMPLETE Column stats: COMPLETE + Select Operator + expressions: 3 (type: int) + outputColumnNames: _col0 + Statistics: Num rows: 1 Data size: 4 Basic stats: COMPLETE Column stats: COMPLETE + ListSink + +PREHOOK: query: explain select round('3',4) as r_rstr +PREHOOK: type: QUERY +PREHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +POSTHOOK: query: explain select round('3',4) as r_rstr +POSTHOOK: type: QUERY +POSTHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +STAGE DEPENDENCIES: + Stage-0 is a root stage + +STAGE PLANS: + Stage: Stage-0 + Fetch Operator + limit: -1 + Processor Tree: + TableScan + alias: _dummy_table + Row Limit Per Split: 1 + Statistics: Num rows: 1 Data size: 10 Basic stats: COMPLETE Column stats: COMPLETE + Select Operator + expressions: 3L (type: bigint) + outputColumnNames: _col0 + Statistics: Num rows: 1 Data size: 8 Basic stats: COMPLETE Column stats: COMPLETE + ListSink + +PREHOOK: query: explain select round(3.1,4) as r_rstr +PREHOOK: type: QUERY +PREHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +POSTHOOK: query: explain select round(3.1,4) as r_rstr +POSTHOOK: type: QUERY +POSTHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +STAGE DEPENDENCIES: + Stage-0 is a root stage + +STAGE PLANS: + Stage: Stage-0 + Fetch Operator + limit: -1 + Processor Tree: + TableScan + alias: _dummy_table + Row Limit Per Split: 1 + Statistics: Num rows: 1 Data size: 10 Basic stats: COMPLETE Column stats: COMPLETE + Select Operator + expressions: 3.1 (type: decimal(5,4)) + outputColumnNames: _col0 + Statistics: Num rows: 1 Data size: 112 Basic stats: COMPLETE Column stats: COMPLETE + ListSink + +PREHOOK: query: explain select round('3.1',4) as r_rstr +PREHOOK: type: QUERY +PREHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +POSTHOOK: query: explain select round('3.1',4) as r_rstr +POSTHOOK: type: QUERY +POSTHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +STAGE DEPENDENCIES: + Stage-0 is a root stage + +STAGE PLANS: + Stage: Stage-0 + Fetch Operator + limit: -1 + Processor Tree: + TableScan + alias: _dummy_table + Row Limit Per Split: 1 + Statistics: Num rows: 1 Data size: 10 Basic stats: COMPLETE Column stats: COMPLETE + Select Operator + expressions: 3.1D (type: double) + outputColumnNames: _col0 + Statistics: Num rows: 1 Data size: 8 Basic stats: COMPLETE Column stats: COMPLETE + ListSink +