diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFRound.java ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFRound.java index e8b0d15..7bf8a9b 100644 --- ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFRound.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFRound.java @@ -33,6 +33,7 @@ import org.apache.hadoop.hive.serde2.io.DoubleWritable; import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; import org.apache.hadoop.hive.serde2.io.ShortWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ConstantObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters; @@ -49,6 +50,8 @@ import org.apache.hadoop.io.FloatWritable; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.LongWritable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Note: rounding function permits rounding off integer digits in decimal numbers, which essentially @@ -72,10 +75,14 @@ FuncRoundWithNumDigitsDecimalToDecimal.class, FuncRoundDecimalToDecimal.class}) public class GenericUDFRound extends GenericUDF { private transient PrimitiveObjectInspector inputOI; + private static final Logger LOG = LoggerFactory.getLogger(GenericUDFRound.class); private int scale = 0; private transient PrimitiveCategory inputType; private transient Converter converterFromString; + private transient Converter inputScaleConverter; + private transient PrimitiveObjectInspector inputScaleOI; + private transient boolean inputScaleConst; @Override public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException { @@ -95,43 +102,71 @@ public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumen throw new UDFArgumentTypeException(1, "ROUND second argument only takes primitive types, got " + arguments[1].getTypeName()); } - PrimitiveObjectInspector scaleOI = (PrimitiveObjectInspector) arguments[1]; - switch (scaleOI.getPrimitiveCategory()) { + inputScaleOI = (PrimitiveObjectInspector) arguments[1]; + inputScaleConst = arguments[1] instanceof ConstantObjectInspector; + switch (inputScaleOI.getPrimitiveCategory()) { case VOID: break; case BYTE: - if (!(scaleOI instanceof WritableConstantByteObjectInspector)) { - throw new UDFArgumentTypeException(1, getFuncName().toUpperCase() + " second argument only takes constant"); + if (inputScaleConst) { + if (!(inputScaleOI instanceof WritableConstantByteObjectInspector)) { + throw new UDFArgumentTypeException(1, + getFuncName().toUpperCase() + " second argument only takes constant"); + } + scale = + ((WritableConstantByteObjectInspector) inputScaleOI).getWritableConstantValue().get(); + } else { + inputScaleConverter = ObjectInspectorConverters.getConverter(arguments[1], + PrimitiveObjectInspectorFactory.writableByteObjectInspector); } - scale = ((WritableConstantByteObjectInspector)scaleOI).getWritableConstantValue().get(); break; case SHORT: - if (!(scaleOI instanceof WritableConstantShortObjectInspector)) { - throw new UDFArgumentTypeException(1, getFuncName().toUpperCase() + " second argument only takes constant"); + if (inputScaleConst) { + if (!(inputScaleOI instanceof WritableConstantShortObjectInspector)) { + throw new UDFArgumentTypeException(1, + getFuncName().toUpperCase() + " second argument only takes constant"); + } + scale = ((WritableConstantShortObjectInspector) inputScaleOI).getWritableConstantValue() + .get(); + } else { + inputScaleConverter = ObjectInspectorConverters.getConverter(arguments[1], + PrimitiveObjectInspectorFactory.writableShortObjectInspector); } - scale = ((WritableConstantShortObjectInspector)scaleOI).getWritableConstantValue().get(); break; case INT: - if (!(scaleOI instanceof WritableConstantIntObjectInspector)) { - throw new UDFArgumentTypeException(1, getFuncName().toUpperCase() + " second argument only takes constant"); + if (inputScaleConst) { + if (!(inputScaleOI instanceof WritableConstantIntObjectInspector)) { + throw new UDFArgumentTypeException(1, + getFuncName().toUpperCase() + " second argument only takes constant"); + } + scale = + ((WritableConstantIntObjectInspector) inputScaleOI).getWritableConstantValue().get(); + } else { + inputScaleConverter = ObjectInspectorConverters.getConverter(arguments[1], + PrimitiveObjectInspectorFactory.writableIntObjectInspector); } - scale = ((WritableConstantIntObjectInspector)scaleOI).getWritableConstantValue().get(); break; case LONG: - if (!(scaleOI instanceof WritableConstantLongObjectInspector)) { - throw new UDFArgumentTypeException(1, getFuncName().toUpperCase() - + " second argument only takes constant"); + if (inputScaleConst) { + if (!(inputScaleOI instanceof WritableConstantLongObjectInspector)) { + throw new UDFArgumentTypeException(1, + getFuncName().toUpperCase() + " second argument only takes constant"); + } + long l = + ((WritableConstantLongObjectInspector) inputScaleOI).getWritableConstantValue().get(); + if (l < Integer.MIN_VALUE || l > Integer.MAX_VALUE) { + throw new UDFArgumentException( + getFuncName().toUpperCase() + " scale argument out of allowed range"); + } + scale = (int) l; + } else { + inputScaleConverter = ObjectInspectorConverters.getConverter(arguments[1], + PrimitiveObjectInspectorFactory.writableLongObjectInspector); } - long l = ((WritableConstantLongObjectInspector)scaleOI).getWritableConstantValue().get(); - if (l < Integer.MIN_VALUE || l > Integer.MAX_VALUE) { - throw new UDFArgumentException(getFuncName().toUpperCase() - + " scale argument out of allowed range"); - } - scale = (int)l; break; default: - throw new UDFArgumentTypeException(1, getFuncName().toUpperCase() - + " second argument only takes integer constant"); + throw new UDFArgumentTypeException(1, + getFuncName().toUpperCase() + " second argument only takes integer constant"); } } @@ -184,10 +219,13 @@ private static DecimalTypeInfo getOutputTypeInfo(DecimalTypeInfo inputTypeInfo, @Override public Object evaluate(DeferredObject[] arguments) throws HiveException { - if (arguments.length == 2 && (arguments[1] == null || arguments[1].get() == null)) { + if (arguments.length == 2 && inputScaleConst + && (arguments[1] == null || arguments[1].get() == null)) { return null; } + LOG.info("**is SCale greater than :"); + if (arguments[0] == null) { return null; } @@ -196,6 +234,35 @@ public Object evaluate(DeferredObject[] arguments) throws HiveException { if (input == null) { return null; } + + if (arguments.length == 2 && arguments[1] != null && arguments[1].get() != null + && !inputScaleConst) { + Object scaleObj = null; + switch (inputScaleOI.getPrimitiveCategory()) { + case BYTE: + scaleObj = inputScaleConverter.convert(arguments[1].get()); + scale = ((ByteWritable) scaleObj).get(); + break; + case SHORT: + scaleObj = inputScaleConverter.convert(arguments[1].get()); + scale = ((ShortWritable) scaleObj).get(); + break; + case INT: + scaleObj = inputScaleConverter.convert(arguments[1].get()); + scale = ((IntWritable) scaleObj).get(); + break; + case LONG: + scaleObj = inputScaleConverter.convert(arguments[1].get()); + long l = ((LongWritable) scaleObj).get(); + if (l < Integer.MIN_VALUE || l > Integer.MAX_VALUE) { + throw new UDFArgumentException( + getFuncName().toUpperCase() + " scale argument out of allowed range"); + } + scale = (int) l; + default: + break; + } + } switch (inputType) { case VOID: @@ -262,22 +329,45 @@ protected HiveDecimalWritable round(HiveDecimalWritable inputDecWritable, int sc } protected long round(long input, int scale) { - return RoundUtils.round(input, scale); + String inputStr = String.valueOf(input); + if (!isScaleGreater(inputStr, scale)) { + return RoundUtils.round(input, scale); + } else { + return input; + } } protected double round(double input, int scale) { - return RoundUtils.round(input, scale); + String inputStr = String.valueOf(input); + if (!isScaleGreater(inputStr, scale)) { + return RoundUtils.round(input, scale); + } else { + return input; + } } protected DoubleWritable round(DoubleWritable input, int scale) { double d = input.get(); - if (Double.isNaN(d) || Double.isInfinite(d)) { - return new DoubleWritable(d); + String inputStr = String.valueOf(d); + if (!isScaleGreater(inputStr, scale)) { + if (Double.isNaN(d) || Double.isInfinite(d)) { + return new DoubleWritable(d); + } else { + return new DoubleWritable(RoundUtils.round(d, scale)); + } } else { - return new DoubleWritable(RoundUtils.round(d, scale)); + return input; } } + private boolean isScaleGreater(String input, int scale) { + String[] split = input.split("\\."); + boolean isScaleGreater = false; + isScaleGreater = (scale > 0 && scale >= split[1].length()) ? true : false; + LOG.info("**is SCale greater than :"+scale+" and input is : "+input+" : isScaleGreater :"+isScaleGreater); + return isScaleGreater; + } + @Override public String getDisplayString(String[] children) { return getStandardDisplayString("round", children); diff --git ql/src/test/queries/clientpositive/udf_round_4.q ql/src/test/queries/clientpositive/udf_round_4.q new file mode 100644 index 0000000..796ee70 --- /dev/null +++ ql/src/test/queries/clientpositive/udf_round_4.q @@ -0,0 +1,24 @@ + +SELECT round(1234567891.1234567891,4), round(1234567891.1234567891,-4), round(1234567891.1234567891,0), round(1234567891.1234567891) FROM src tablesample (1 rows); + +SELECT round(1234567891.1234567891,100), round(1234567891.1234567891,-100), round(1234567891.1234567891,0), round(1234567891.1234567891) FROM src tablesample (1 rows); + +SELECT round(1234567891.1234567891,50), round(1234567891.1234567891,-50), round(1234567891.1234567891,25), round(1234567891.1234567891,20), round(1234567891.1234567891,15) FROM src tablesample (1 rows); + +SELECT round(-1234567891.1234567891,4), round(-1234567891.1234567891,-4), round(-1234567891.1234567891,0), round(-1234567891.1234567891) FROM src tablesample (1 rows); + +select round(1234567891.1234567891,9),round(1234567891.1234567896,9), round(1234567891.1234567891,10), round(1234567891.1234567891,11), round(1234567891.1234567891,12) FROM src tablesample (1 rows); + +select round(1234567891.1234567891,-9),round(1234567891.1234567896,-9), round(1234567891.1234567891,-10), round(1234567891.1234567891,-11), round(1234567891.1234567891,-12) FROM src tablesample (1 rows); + +select round(-1234567891.1234567891,9), round(-1234567891.1234567891,10), round(-1234567891.1234567891,11), round(-1234567891.1234567891,12) FROM src tablesample (1 rows); + +select round(-1234567891.1234567891,-9), round(-1234567891.1234567891,-10), round(-1234567891.1234567891,-11), round(-1234567891.1234567891,-12) FROM src tablesample (1 rows); + +select round(1234567891.1234567891,15), round(1234567891.1234567891,25), round(1234567891.1234567891,20), round(1234567891.1234567891,50) FROM src tablesample (1 rows); + +select round(1234567891.1234567891,-15), round(1234567891.1234567891,-25), round(1234567891.1234567891,-20), round(1234567891.1234567891,-50) FROM src tablesample (1 rows); + +select round(-1234567891.1234567891,15), round(-1234567891.1234567891,25), round(-1234567891.1234567891,20), round(-1234567891.1234567891,50) FROM src tablesample (1 rows); + +select round(-1234567891.1234567891,-15), round(-1234567891.1234567891,-25), round(-1234567891.1234567891,-20), round(-1234567891.1234567891,-50) FROM src tablesample (1 rows);