diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java
index de74c3e..eb2601e 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java
@@ -206,6 +206,7 @@
     system.registerGenericUDF("round", GenericUDFRound.class);
     system.registerGenericUDF("bround", GenericUDFBRound.class);
     system.registerGenericUDF("floor", GenericUDFFloor.class);
+    system.registerGenericUDF("trunc2", GenericUDFTrunc2.class);
     system.registerUDF("sqrt", UDFSqrt.class, false);
     system.registerGenericUDF("cbrt", GenericUDFCbrt.class);
     system.registerGenericUDF("ceil", GenericUDFCeil.class);
diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFTrunc2.java ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFTrunc2.java
new file mode 100644
index 0000000..15d5c98
--- /dev/null
+++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFTrunc2.java
@@ -0,0 +1,268 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hive.ql.udf.generic;
+
+import java.math.BigDecimal;
+import java.math.RoundingMode;
+
+import org.apache.hadoop.hive.common.type.HiveDecimal;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.io.ByteWritable;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;
+import org.apache.hadoop.hive.serde2.io.ShortWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableConstantByteObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableConstantIntObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableConstantLongObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableConstantShortObjectInspector;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+
+/**
+ * trunc2(x[, d]) truncates the numeric value x toward zero, keeping d
+ * decimal digits. A negative d zeroes out |d| digits to the left of the
+ * decimal point; d defaults to 0 and must be an integer constant.
+ * String/varchar/char inputs are converted to double.
+ */
+public class GenericUDFTrunc2 extends GenericUDF {
+
+  // Inspector for the first (value) argument.
+  private transient PrimitiveObjectInspector inputOI;
+  // Number of decimal digits to keep; fixed at compile time from the
+  // constant second argument (0 when the argument is absent or NULL).
+  private int scale = 0;
+
+  private transient PrimitiveCategory inputType;
+  // Converts string-group inputs to DoubleWritable.
+  private transient Converter converterFromString;
+
+  @Override
+  public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
+    if (arguments.length < 1 || arguments.length > 2) {
+      throw new UDFArgumentLengthException(
+          getFuncName().toUpperCase() + " requires one or two arguments, got " + arguments.length);
+    }
+
+    if (arguments[0].getCategory() != Category.PRIMITIVE) {
+      throw new UDFArgumentTypeException(0, getFuncName().toUpperCase()
+          + " input only takes primitive types, got " + arguments[0].getTypeName());
+    }
+    inputOI = (PrimitiveObjectInspector) arguments[0];
+
+    if (arguments.length == 2) {
+      scale = getConstantScale(arguments[1]);
+    }
+
+    inputType = inputOI.getPrimitiveCategory();
+    ObjectInspector outputOI;
+    switch (inputType) {
+    case VOID:
+    case DECIMAL:
+    case BYTE:
+    case SHORT:
+    case INT:
+    case LONG:
+    case FLOAT:
+    case DOUBLE:
+      // Numeric inputs keep their own type.
+      outputOI = PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(inputType);
+      break;
+    case STRING:
+    case VARCHAR:
+    case CHAR:
+      // String-group inputs are converted to and returned as double.
+      outputOI = PrimitiveObjectInspectorFactory
+          .getPrimitiveWritableObjectInspector(PrimitiveCategory.DOUBLE);
+      converterFromString = ObjectInspectorConverters.getConverter(inputOI, outputOI);
+      break;
+    default:
+      throw new UDFArgumentTypeException(0,
+          "Only numeric or string group data types are allowed for "
+              + getFuncName().toUpperCase() + " function. Got " + inputType.name());
+    }
+
+    return outputOI;
+  }
+
+  /**
+   * Extracts the truncation scale from the second argument, which must be an
+   * integer-family constant whose value fits in an int.
+   */
+  private int getConstantScale(ObjectInspector argument) throws UDFArgumentException {
+    if (argument.getCategory() != Category.PRIMITIVE) {
+      throw new UDFArgumentTypeException(1, getFuncName().toUpperCase()
+          + " second argument only takes primitive types, got " + argument.getTypeName());
+    }
+    PrimitiveObjectInspector scaleOI = (PrimitiveObjectInspector) argument;
+    switch (scaleOI.getPrimitiveCategory()) {
+    case VOID:
+      return 0;
+    case BYTE:
+      if (!(scaleOI instanceof WritableConstantByteObjectInspector)) {
+        throw new UDFArgumentTypeException(1,
+            getFuncName().toUpperCase() + " second argument only takes constant");
+      }
+      return ((WritableConstantByteObjectInspector) scaleOI).getWritableConstantValue().get();
+    case SHORT:
+      if (!(scaleOI instanceof WritableConstantShortObjectInspector)) {
+        throw new UDFArgumentTypeException(1,
+            getFuncName().toUpperCase() + " second argument only takes constant");
+      }
+      return ((WritableConstantShortObjectInspector) scaleOI).getWritableConstantValue().get();
+    case INT:
+      if (!(scaleOI instanceof WritableConstantIntObjectInspector)) {
+        throw new UDFArgumentTypeException(1,
+            getFuncName().toUpperCase() + " second argument only takes constant");
+      }
+      return ((WritableConstantIntObjectInspector) scaleOI).getWritableConstantValue().get();
+    case LONG:
+      if (!(scaleOI instanceof WritableConstantLongObjectInspector)) {
+        throw new UDFArgumentTypeException(1,
+            getFuncName().toUpperCase() + " second argument only takes constant");
+      }
+      long l = ((WritableConstantLongObjectInspector) scaleOI).getWritableConstantValue().get();
+      if (l < Integer.MIN_VALUE || l > Integer.MAX_VALUE) {
+        throw new UDFArgumentException(
+            getFuncName().toUpperCase() + " scale argument out of allowed range");
+      }
+      return (int) l;
+    default:
+      throw new UDFArgumentTypeException(1,
+          getFuncName().toUpperCase() + " second argument only takes integer constant");
+    }
+  }
+
+  @Override
+  public Object evaluate(DeferredObject[] arguments) throws HiveException {
+    // A NULL scale or a NULL input value yields NULL.
+    if (arguments.length == 2 && (arguments[1] == null || arguments[1].get() == null)) {
+      return null;
+    }
+    if (arguments[0] == null) {
+      return null;
+    }
+    Object input = arguments[0].get();
+    if (input == null) {
+      return null;
+    }
+
+    switch (inputType) {
+    case VOID:
+      return null;
+    case DECIMAL:
+      HiveDecimalWritable decimalWritable =
+          (HiveDecimalWritable) inputOI.getPrimitiveWritableObject(input);
+      HiveDecimal dec = trunc(decimalWritable.getHiveDecimal(), scale);
+      if (dec == null) {
+        return null;
+      }
+      return new HiveDecimalWritable(dec);
+    case BYTE:
+      ByteWritable byteWritable = (ByteWritable) inputOI.getPrimitiveWritableObject(input);
+      // A non-negative scale never changes an integral value.
+      if (scale >= 0) {
+        return byteWritable;
+      }
+      return new ByteWritable((byte) trunc(byteWritable.get(), scale));
+    case SHORT:
+      ShortWritable shortWritable = (ShortWritable) inputOI.getPrimitiveWritableObject(input);
+      if (scale >= 0) {
+        return shortWritable;
+      }
+      return new ShortWritable((short) trunc(shortWritable.get(), scale));
+    case INT:
+      IntWritable intWritable = (IntWritable) inputOI.getPrimitiveWritableObject(input);
+      if (scale >= 0) {
+        return intWritable;
+      }
+      return new IntWritable((int) trunc(intWritable.get(), scale));
+    case LONG:
+      LongWritable longWritable = (LongWritable) inputOI.getPrimitiveWritableObject(input);
+      if (scale >= 0) {
+        return longWritable;
+      }
+      return new LongWritable(trunc(longWritable.get(), scale));
+    case FLOAT:
+      float f = ((FloatWritable) inputOI.getPrimitiveWritableObject(input)).get();
+      return new FloatWritable((float) trunc(f, scale));
+    case DOUBLE:
+      return trunc((DoubleWritable) inputOI.getPrimitiveWritableObject(input), scale);
+    case STRING:
+    case VARCHAR:
+    case CHAR:
+      DoubleWritable doubleValue = (DoubleWritable) converterFromString.convert(input);
+      if (doubleValue == null) {
+        return null;
+      }
+      return trunc(doubleValue, scale);
+    default:
+      throw new UDFArgumentTypeException(0,
+          "Only numeric or string group data types are allowed for "
+              + getFuncName().toUpperCase() + " function. Got " + inputType.name());
+    }
+  }
+
+  /** Truncates a decimal value; may return null if the result is not representable. */
+  protected HiveDecimal trunc(HiveDecimal input, int scale) {
+    return HiveDecimal.create(trunc(input.bigDecimalValue(), scale));
+  }
+
+  /** Truncates an integral value; only meaningful for negative scales. */
+  protected long trunc(long input, int scale) {
+    return trunc(BigDecimal.valueOf(input), scale).longValue();
+  }
+
+  protected double trunc(double input, int scale) {
+    return trunc(BigDecimal.valueOf(input), scale).doubleValue();
+  }
+
+  protected DoubleWritable trunc(DoubleWritable input, int scale) {
+    // BigDecimal.valueOf uses the double's canonical decimal form, avoiding
+    // binary-representation artifacts from new BigDecimal(double).
+    BigDecimal trunc = trunc(BigDecimal.valueOf(input.get()), scale);
+    return new DoubleWritable(trunc.doubleValue());
+  }
+
+  /**
+   * Truncates toward zero at the given decimal scale (RoundingMode.DOWN).
+   * Unlike the previous multiply/divide-through-long implementation, setScale
+   * cannot overflow for values whose scaled form exceeds the long range.
+   */
+  protected BigDecimal trunc(BigDecimal input, int scale) {
+    BigDecimal truncated = input.setScale(scale, RoundingMode.DOWN);
+    // Restore scale 0 for negative scales so results print as plain integers.
+    return scale < 0 ? truncated.setScale(0, RoundingMode.UNNECESSARY) : truncated;
+  }
+
+  @Override
+  public String getDisplayString(String[] children) {
+    // Display string must match the registered name ("trunc2", not "trunc").
+    return getStandardDisplayString("trunc2", children);
+  }
+}
diff --git ql/src/test/queries/clientpositive/udf_trunc2.q ql/src/test/queries/clientpositive/udf_trunc2.q
new file mode 100644
index 0000000..74388c4
--- /dev/null
+++ ql/src/test/queries/clientpositive/udf_trunc2.q
@@ -0,0 +1,7 @@
+set hive.fetch.task.conversion=more;
+
+DESCRIBE FUNCTION trunc2;
+DESCRIBE FUNCTION EXTENDED trunc2;
+
+SELECT trunc2(1234567891.1234567891,4), trunc2(1234567891.1234567891,-4), trunc2(1234567891.1234567891,0), trunc2(1234567891.1234567891)
+FROM src tablesample (1 rows);
diff --git ql/src/test/results/clientpositive/udf_trunc2.q.out ql/src/test/results/clientpositive/udf_trunc2.q.out
new file mode 100644
index 0000000..5943f9a
--- /dev/null
+++ ql/src/test/results/clientpositive/udf_trunc2.q.out
@@ -0,0 +1,21 @@
+PREHOOK: query: DESCRIBE FUNCTION trunc2
+PREHOOK: type: DESCFUNCTION
+POSTHOOK: query: DESCRIBE FUNCTION trunc2
+POSTHOOK: type: DESCFUNCTION
+There is no documentation for function 'trunc2'
+PREHOOK: query: DESCRIBE FUNCTION EXTENDED trunc2
+PREHOOK: type: DESCFUNCTION
+POSTHOOK: query: DESCRIBE FUNCTION EXTENDED trunc2
+POSTHOOK: type: DESCFUNCTION
+There is no documentation for function 'trunc2'
+PREHOOK: query: SELECT trunc2(1234567891.1234567891,4), trunc2(1234567891.1234567891,-4), trunc2(1234567891.1234567891,0), trunc2(1234567891.1234567891)
+FROM src tablesample (1 rows)
+PREHOOK: type: QUERY
+PREHOOK: Input: default@src
+#### A masked pattern was here ####
+POSTHOOK: query: SELECT trunc2(1234567891.1234567891,4), trunc2(1234567891.1234567891,-4), trunc2(1234567891.1234567891,0), trunc2(1234567891.1234567891)
+FROM src tablesample (1 rows)
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@src
+#### A masked pattern was here ####
+1234567891.1234	1234560000	1234567891	1234567891