diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFTrunc.java ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFTrunc.java index e20ad65..0f6c035 100644 --- ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFTrunc.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFTrunc.java @@ -19,47 +19,65 @@ package org.apache.hadoop.hive.ql.udf.generic; +import java.math.BigDecimal; import java.sql.Timestamp; import java.text.ParseException; import java.text.SimpleDateFormat; import java.util.Calendar; import java.util.Date; +import org.apache.hadoop.hive.common.type.HiveDecimal; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.io.ByteWritable; import org.apache.hadoop.hive.serde2.io.DateWritable; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; +import org.apache.hadoop.hive.serde2.io.ShortWritable; import org.apache.hadoop.hive.serde2.io.TimestampWritable; import org.apache.hadoop.hive.serde2.objectinspector.ConstantObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters; -import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; -import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorConverter.TimestampConverter; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils.PrimitiveGrouping; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableConstantByteObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableConstantIntObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableConstantLongObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableConstantShortObjectInspector; +import org.apache.hadoop.io.FloatWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Text; /** * GenericUDFTrunc. * - * Returns the first day of the month which the date belongs to. - * The time part of the date will be ignored. + * Returns the first day of the month which the date belongs to. The time part of the date will be + * ignored. * */ -@Description(name = "trunc", -value = "_FUNC_(date, fmt) - Returns returns date with the time portion of the day truncated " +@Description(name = "trunc", value = "_FUNC_(date, fmt) / _FUNC_(N,D) - Returns If input is date returns date with the time portion of the day truncated " + "to the unit specified by the format model fmt. If you omit fmt, then date is truncated to " - + "the nearest day. It now only supports 'MONTH'/'MON'/'MM' and 'YEAR'/'YYYY'/'YY' as format.", -extended = "date is a string in the format 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'." 
- + " The time part of date is ignored.\n" - + "Example:\n " - + " > SELECT _FUNC_('2009-02-12', 'MM');\n" + "OK\n" + " '2009-02-01'" + "\n" - + " > SELECT _FUNC_('2015-10-27', 'YEAR');\n" + "OK\n" + " '2015-01-01'") + + "the nearest day. It now only supports 'MONTH'/'MON'/'MM' and 'YEAR'/'YYYY'/'YY' as format." + + " If input is a number group returns N truncated to D decimal places. If D is omitted, then N is truncated to 0 places." + + " D can be negative to truncate (make zero) D digits left of the decimal point." + , extended = "date is a string in the format 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'." + + " The time part of date is ignored.\n" + "Example:\n " + + " > SELECT _FUNC_('2009-02-12', 'MM');\n" + "OK\n" + " '2009-02-01'" + "\n" + + " > SELECT _FUNC_('2015-10-27', 'YEAR');\n" + "OK\n" + " '2015-01-01'" + "\n" + + " > SELECT _FUNC_(1234567891.1234567891,4);\n" + "OK\n" + " 1234567891.1234" + "\n" + + " > SELECT _FUNC_(1234567891.1234567891,-4);\n" + "OK\n" + " 1234560000" + "\n" + + " > SELECT _FUNC_(1234567891.1234567891,0);\n" + "OK\n" + " 1234567891" + "\n" + + " > SELECT _FUNC_(1234567891.1234567891);\n" + "OK\n" + " 1234567891") public class GenericUDFTrunc extends GenericUDF { private transient SimpleDateFormat formatter = new SimpleDateFormat("yyyy-MM-dd"); @@ -72,9 +90,127 @@ private final Calendar calendar = Calendar.getInstance(); private final Text output = new Text(); private transient String fmtInput; + private transient PrimitiveObjectInspector inputOI; + private int scale = 0; + private boolean dateTypeArg; @Override public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException { + if (arguments.length == 2) { + inputType1 = ((PrimitiveObjectInspector) arguments[0]).getPrimitiveCategory(); + inputType2 = ((PrimitiveObjectInspector) arguments[1]).getPrimitiveCategory(); + if (PrimitiveObjectInspectorUtils + .getPrimitiveGrouping(inputType1) == PrimitiveGrouping.NUMERIC_GROUP + && PrimitiveObjectInspectorUtils + 
.getPrimitiveGrouping(inputType2) == PrimitiveGrouping.NUMERIC_GROUP) { + dateTypeArg = false; + return initializeNumber(arguments); + } else { + dateTypeArg = true; + return initializeDate(arguments); + } + } else if (arguments.length == 1) { + inputType1 = ((PrimitiveObjectInspector) arguments[0]).getPrimitiveCategory(); + if (PrimitiveObjectInspectorUtils + .getPrimitiveGrouping(inputType1) == PrimitiveGrouping.NUMERIC_GROUP) { + dateTypeArg = false; + return initializeNumber(arguments); + } else { + throw new UDFArgumentException( + "Only primitive type arguments are accepted, when arguments length is one, got " + + arguments[0].getTypeName()); + } + } + throw new UDFArgumentException("TRUNC requires one or two arguments, got " + arguments.length); + } + + private ObjectInspector initializeNumber(ObjectInspector[] arguments) + throws UDFArgumentException { + if (arguments.length < 1 || arguments.length > 2) { + throw new UDFArgumentLengthException( + "TRUNC requires one or two arguments, got " + arguments.length); + } + + if (arguments[0].getCategory() != Category.PRIMITIVE) { + throw new UDFArgumentTypeException(0, + "TRUNC input only takes primitive types, got " + arguments[0].getTypeName()); + } + inputOI = (PrimitiveObjectInspector) arguments[0]; + + if (arguments.length == 2) { + if (arguments[1].getCategory() != Category.PRIMITIVE) { + throw new UDFArgumentTypeException(1, + "TRUNC second argument only takes primitive types, got " + arguments[1].getTypeName()); + } + PrimitiveObjectInspector scaleOI = (PrimitiveObjectInspector) arguments[1]; + switch (scaleOI.getPrimitiveCategory()) { + case VOID: + break; + case BYTE: + if (!(scaleOI instanceof WritableConstantByteObjectInspector)) { + throw new UDFArgumentTypeException(1, + getFuncName().toUpperCase() + " second argument only takes constant"); + } + scale = ((WritableConstantByteObjectInspector) scaleOI).getWritableConstantValue().get(); + break; + case SHORT: + if (!(scaleOI instanceof 
WritableConstantShortObjectInspector)) { + throw new UDFArgumentTypeException(1, + getFuncName().toUpperCase() + " second argument only takes constant"); + } + scale = ((WritableConstantShortObjectInspector) scaleOI).getWritableConstantValue().get(); + break; + case INT: + if (!(scaleOI instanceof WritableConstantIntObjectInspector)) { + throw new UDFArgumentTypeException(1, + getFuncName().toUpperCase() + " second argument only takes constant"); + } + scale = ((WritableConstantIntObjectInspector) scaleOI).getWritableConstantValue().get(); + break; + case LONG: + if (!(scaleOI instanceof WritableConstantLongObjectInspector)) { + throw new UDFArgumentTypeException(1, + getFuncName().toUpperCase() + " second argument only takes constant"); + } + long l = ((WritableConstantLongObjectInspector) scaleOI).getWritableConstantValue().get(); + if (l < Integer.MIN_VALUE || l > Integer.MAX_VALUE) { + throw new UDFArgumentException( + getFuncName().toUpperCase() + " scale argument out of allowed range"); + } + scale = (int) l; + break; + default: + throw new UDFArgumentTypeException(1, + getFuncName().toUpperCase() + " second argument only takes integer constant"); + } + } + + inputType1 = inputOI.getPrimitiveCategory(); + ObjectInspector outputOI = null; + switch (inputType1) { + case DECIMAL: + outputOI = PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(inputType1); + break; + case VOID: + case BYTE: + case SHORT: + case INT: + case LONG: + case FLOAT: + case DOUBLE: + outputOI = PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(inputType1); + break; + default: + throw new UDFArgumentTypeException(0, + "Only numeric or string group data types are allowed for TRUNC function. 
Got " + + inputType1.name()); + } + + return outputOI; + } + + private ObjectInspector initializeDate(ObjectInspector[] arguments) + throws UDFArgumentLengthException, UDFArgumentTypeException { if (arguments.length != 2) { throw new UDFArgumentLengthException("trunc() requires 2 argument, got " + arguments.length); } @@ -97,8 +233,7 @@ public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumen case CHAR: case VOID: inputType1 = PrimitiveCategory.STRING; - textConverter1 = ObjectInspectorConverters.getConverter( - arguments[0], + textConverter1 = ObjectInspectorConverters.getConverter(arguments[0], PrimitiveObjectInspectorFactory.writableStringObjectInspector); break; case TIMESTAMP: @@ -106,8 +241,7 @@ public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumen PrimitiveObjectInspectorFactory.writableTimestampObjectInspector); break; case DATE: - dateWritableConverter = ObjectInspectorConverters.getConverter( - arguments[0], + dateWritableConverter = ObjectInspectorConverters.getConverter(arguments[0], PrimitiveObjectInspectorFactory.writableDateObjectInspector); break; default: @@ -117,9 +251,10 @@ public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumen } inputType2 = ((PrimitiveObjectInspector) arguments[1]).getPrimitiveCategory(); - if (PrimitiveObjectInspectorUtils.getPrimitiveGrouping(inputType2) - != PrimitiveGrouping.STRING_GROUP && PrimitiveObjectInspectorUtils.getPrimitiveGrouping(inputType2) - != PrimitiveGrouping.VOID_GROUP) { + if (PrimitiveObjectInspectorUtils + .getPrimitiveGrouping(inputType2) != PrimitiveGrouping.STRING_GROUP + && PrimitiveObjectInspectorUtils + .getPrimitiveGrouping(inputType2) != PrimitiveGrouping.VOID_GROUP) { throw new UDFArgumentTypeException(1, "trunk() only takes STRING/CHAR/VARCHAR types as second argument, got " + inputType2); } @@ -130,16 +265,23 @@ public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumen Object obj = 
((ConstantObjectInspector) arguments[1]).getWritableConstantValue(); fmtInput = obj != null ? obj.toString() : null; } else { - textConverter2 = ObjectInspectorConverters.getConverter( - arguments[1], + textConverter2 = ObjectInspectorConverters.getConverter(arguments[1], PrimitiveObjectInspectorFactory.writableStringObjectInspector); } - return outputOI; } @Override public Object evaluate(DeferredObject[] arguments) throws HiveException { + if (dateTypeArg) { + return evaluateDate(arguments); + } else { + return evaluateNumber(arguments); + } + } + + private Object evaluateDate(DeferredObject[] arguments) throws UDFArgumentLengthException, + HiveException, UDFArgumentTypeException, UDFArgumentException { if (arguments.length != 2) { throw new UDFArgumentLengthException("trunc() requires 2 argument, got " + arguments.length); } @@ -163,8 +305,8 @@ public Object evaluate(DeferredObject[] arguments) throws HiveException { } break; case TIMESTAMP: - Timestamp ts = ((TimestampWritable) timestampConverter.convert(arguments[0].get())) - .getTimestamp(); + Timestamp ts = + ((TimestampWritable) timestampConverter.convert(arguments[0].get())).getTimestamp(); date = ts; break; case DATE: @@ -185,6 +327,72 @@ public Object evaluate(DeferredObject[] arguments) throws HiveException { return output; } + private Object evaluateNumber(DeferredObject[] arguments) + throws HiveException, UDFArgumentTypeException { + if (arguments.length == 2 && (arguments[1] == null || arguments[1].get() == null)) { + return null; + } + + if (arguments[0] == null) { + return null; + } + + Object input = arguments[0].get(); + if (input == null) { + return null; + } + + switch (inputType1) { + case VOID: + return null; + case DECIMAL: + HiveDecimalWritable decimalWritable = + (HiveDecimalWritable) inputOI.getPrimitiveWritableObject(input); + HiveDecimal dec = trunc(decimalWritable.getHiveDecimal(), scale); + if (dec == null) { + return null; + } + return new HiveDecimalWritable(dec); + case BYTE: + 
ByteWritable byteWritable = (ByteWritable) inputOI.getPrimitiveWritableObject(input); + if (scale >= 0) { + return byteWritable; + } else { + return new ByteWritable((byte) trunc(byteWritable.get(), scale)); + } + case SHORT: + ShortWritable shortWritable = (ShortWritable) inputOI.getPrimitiveWritableObject(input); + if (scale >= 0) { + return shortWritable; + } else { + return new ShortWritable((short) trunc(shortWritable.get(), scale)); + } + case INT: + IntWritable intWritable = (IntWritable) inputOI.getPrimitiveWritableObject(input); + if (scale >= 0) { + return intWritable; + } else { + return new IntWritable((int) trunc(intWritable.get(), scale)); + } + case LONG: + LongWritable longWritable = (LongWritable) inputOI.getPrimitiveWritableObject(input); + if (scale >= 0) { + return longWritable; + } else { + return new LongWritable(trunc(longWritable.get(), scale)); + } + case FLOAT: + float f = ((FloatWritable) inputOI.getPrimitiveWritableObject(input)).get(); + return new FloatWritable((float) trunc(f, scale)); + case DOUBLE: + return trunc(((DoubleWritable) inputOI.getPrimitiveWritableObject(input)), scale); + default: + throw new UDFArgumentTypeException(0, + "Only numeric or string group data types are allowed for TRUNC function. 
Got " + + inputType1.name()); + } + } + @Override public String getDisplayString(String[] children) { return getStandardDisplayString("trunc", children); @@ -203,4 +411,43 @@ private Calendar evalDate(Date d) throws UDFArgumentException { return null; } } + + protected HiveDecimal trunc(HiveDecimal input, int scale) { + BigDecimal bigDecimal = trunc(input.bigDecimalValue(), scale); + return HiveDecimal.create(bigDecimal); + } + + protected long trunc(long input, int scale) { + return trunc(BigDecimal.valueOf(input), scale).longValue(); + } + + protected double trunc(double input, int scale) { + return trunc(BigDecimal.valueOf(input), scale).doubleValue(); + } + + protected DoubleWritable trunc(DoubleWritable input, int scale) { + BigDecimal bigDecimal = new BigDecimal(input.get()); + BigDecimal trunc = trunc(bigDecimal, scale); + DoubleWritable doubleWritable = new DoubleWritable(trunc.doubleValue()); + return doubleWritable; + } + + protected BigDecimal trunc(BigDecimal input, int scale) { + BigDecimal output = new BigDecimal(0); + BigDecimal pow = BigDecimal.valueOf(Math.pow(10, Math.abs(scale))); + if (scale >= 0) { + pow = BigDecimal.valueOf(Math.pow(10, scale)); + if (scale != 0) { + long longValue = input.multiply(pow).longValue(); + output = BigDecimal.valueOf(longValue).divide(pow); + } else { + output = BigDecimal.valueOf(input.longValue()); + } + } else { + long longValue2 = input.divide(pow).longValue(); + output = BigDecimal.valueOf(longValue2).multiply(pow); + } + return output; + } + } \ No newline at end of file diff --git ql/src/test/queries/clientpositive/udf_trunc_number.q ql/src/test/queries/clientpositive/udf_trunc_number.q new file mode 100644 index 0000000..b5c4015 --- /dev/null +++ ql/src/test/queries/clientpositive/udf_trunc_number.q @@ -0,0 +1,7 @@ +set hive.fetch.task.conversion=more; + +DESCRIBE FUNCTION trunc; +DESCRIBE FUNCTION EXTENDED trunc; + +SELECT trunc(1234567891.1234567891,4), trunc(1234567891.1234567891,-4), 
trunc(1234567891.1234567891,0), trunc(1234567891.1234567891) +FROM src tablesample (1 rows); diff --git ql/src/test/results/clientpositive/udf_trunc_number.q.out ql/src/test/results/clientpositive/udf_trunc_number.q.out new file mode 100644 index 0000000..3651054 --- /dev/null +++ ql/src/test/results/clientpositive/udf_trunc_number.q.out @@ -0,0 +1,29 @@ +PREHOOK: query: DESCRIBE FUNCTION trunc +PREHOOK: type: DESCFUNCTION +POSTHOOK: query: DESCRIBE FUNCTION trunc +POSTHOOK: type: DESCFUNCTION +trunc(date, fmt) - Returns returns date with the time portion of the day truncated to the unit specified by the format model fmt. If you omit fmt, then date is truncated to the nearest day. It now only supports 'MONTH'/'MON'/'MM' and 'YEAR'/'YYYY'/'YY' as format. +PREHOOK: query: DESCRIBE FUNCTION EXTENDED trunc +PREHOOK: type: DESCFUNCTION +POSTHOOK: query: DESCRIBE FUNCTION EXTENDED trunc +POSTHOOK: type: DESCFUNCTION +trunc(date, fmt) - Returns returns date with the time portion of the day truncated to the unit specified by the format model fmt. If you omit fmt, then date is truncated to the nearest day. It now only supports 'MONTH'/'MON'/'MM' and 'YEAR'/'YYYY'/'YY' as format. +date is a string in the format 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'. The time part of date is ignored. 
+Example: + > SELECT trunc('2009-02-12', 'MM'); +OK + '2009-02-01' + > SELECT trunc('2015-10-27', 'YEAR'); +OK + '2015-01-01' +PREHOOK: query: SELECT trunc(1234567891.1234567891,4), trunc(1234567891.1234567891,-4), trunc(1234567891.1234567891,0), trunc(1234567891.1234567891) +FROM src tablesample (1 rows) +PREHOOK: type: QUERY +PREHOOK: Input: default@src +#### A masked pattern was here #### +POSTHOOK: query: SELECT trunc(1234567891.1234567891,4), trunc(1234567891.1234567891,-4), trunc(1234567891.1234567891,0), trunc(1234567891.1234567891) +FROM src tablesample (1 rows) +POSTHOOK: type: QUERY +POSTHOOK: Input: default@src +#### A masked pattern was here #### +1234567891.1234 1234560000 1234567891 1234567891