diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFAbs.java ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFAbs.java index ee7b821..d655683 100644 --- ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFAbs.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFAbs.java @@ -28,12 +28,14 @@ import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.FuncAbsLongToLong; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; import org.apache.hadoop.io.IntWritable; @@ -55,6 +57,7 @@ private final DoubleWritable resultDouble = new DoubleWritable(); private final LongWritable resultLong = new LongWritable(); private final IntWritable resultInt = new IntWritable(); + private final HiveDecimalWritable resultDecimal = new HiveDecimalWritable(); private transient PrimitiveObjectInspector argumentOI; private transient Converter inputConverter; @@ -94,9 +97,10 @@ public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumen outputOI = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector; break; case DECIMAL: + outputOI = PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector( + ((PrimitiveObjectInspector) arguments[0]).getTypeInfo()); inputConverter = ObjectInspectorConverters.getConverter(arguments[0], - PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector); - outputOI = PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector; + outputOI); break; default: throw new UDFArgumentException( @@ -129,11 +133,15 @@ public Object evaluate(DeferredObject[] arguments) throws HiveException { resultDouble.set(Math.abs(((DoubleWritable) valObject).get())); return resultDouble; case DECIMAL: - return PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector.set( - PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector - .create(HiveDecimal.ZERO), - PrimitiveObjectInspectorUtils.getHiveDecimal(valObject, - argumentOI).abs()); + HiveDecimalObjectInspector decimalOI = + (HiveDecimalObjectInspector) argumentOI; + HiveDecimalWritable val = decimalOI.getPrimitiveWritableObject(valObject); + + if (val != null) { + resultDecimal.set(val.getHiveDecimal().abs()); + val = resultDecimal; + } + return val; default: throw new UDFArgumentException( "ABS only takes SHORT/BYTE/INT/LONG/DOUBLE/FLOAT/STRING/DECIMAL types, got " + inputType); diff --git ql/src/test/org/apache/hadoop/hive/ql/udf/generic/TestGenericUDFAbs.java ql/src/test/org/apache/hadoop/hive/ql/udf/generic/TestGenericUDFAbs.java index 1fe5361..8c531ea 100644 --- ql/src/test/org/apache/hadoop/hive/ql/udf/generic/TestGenericUDFAbs.java +++ ql/src/test/org/apache/hadoop/hive/ql/udf/generic/TestGenericUDFAbs.java @@ -28,7 +28,9 @@ import org.apache.hadoop.hive.serde2.io.DoubleWritable; import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; import org.apache.hadoop.io.FloatWritable; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.LongWritable; @@ -135,10 +137,17 @@ public void testText() throws HiveException { public void testHiveDecimal() throws HiveException { GenericUDFAbs udf = new GenericUDFAbs(); - ObjectInspector valueOI = PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector; + int prec = 12; + int scale = 9; + ObjectInspector valueOI = PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector( + TypeInfoFactory.getDecimalTypeInfo(prec, scale)); ObjectInspector[] arguments = {valueOI}; - udf.initialize(arguments); + PrimitiveObjectInspector outputOI = (PrimitiveObjectInspector) udf.initialize(arguments); + // Make sure result precision/scale matches the input prec/scale + assertEquals("result precision for abs()", prec, outputOI.precision()); + assertEquals("result scale for abs()", scale, outputOI.scale()); + DeferredObject valueObj = new DeferredJavaObject(new HiveDecimalWritable(HiveDecimal.create( "107.123456789"))); DeferredObject[] args = {valueObj}; @@ -153,5 +162,15 @@ public void testHiveDecimal() throws HiveException { assertEquals("abs() test for HiveDecimal failed ", 107.123456789, output.getHiveDecimal() .doubleValue()); + + // null input + args[0] = new DeferredJavaObject(null); + output = (HiveDecimalWritable) udf.evaluate(args); + assertEquals("abs(null)", null, output); + + // if value too large, should also be null + args[0] = new DeferredJavaObject(new HiveDecimalWritable(HiveDecimal.create("-1000.123456"))); + output = (HiveDecimalWritable) udf.evaluate(args); + assertEquals("abs() of too large decimal value", null, output); } } diff --git ql/src/test/results/clientpositive/decimal_udf.q.out ql/src/test/results/clientpositive/decimal_udf.q.out index 1a30346..995401a 100644 --- ql/src/test/results/clientpositive/decimal_udf.q.out +++ ql/src/test/results/clientpositive/decimal_udf.q.out @@ -1220,7 +1220,7 @@ STAGE PLANS: alias: decimal_udf Statistics: Num rows: 3 Data size: 359 Basic stats: COMPLETE Column stats: NONE Select Operator - expressions: abs(key) (type: decimal(38,18)) + expressions: abs(key) (type: decimal(20,10)) outputColumnNames: _col0 Statistics: Num rows: 3 Data size: 359 Basic stats: COMPLETE Column stats: NONE ListSink diff --git ql/src/test/results/clientpositive/tez/vector_decimal_math_funcs.q.out ql/src/test/results/clientpositive/tez/vector_decimal_math_funcs.q.out index 9e09f71..d6f0923 100644 --- ql/src/test/results/clientpositive/tez/vector_decimal_math_funcs.q.out +++ ql/src/test/results/clientpositive/tez/vector_decimal_math_funcs.q.out @@ -99,7 +99,7 @@ STAGE PLANS: Filter Operator predicate: (((cbigint % 500) = 0) and (sin(cdecimal1) >= -1.0)) (type: boolean) Select Operator - expressions: cdecimal1 (type: decimal(20,10)), round(cdecimal1, 2) (type: decimal(13,2)), round(cdecimal1) (type: decimal(11,0)), floor(cdecimal1) (type: decimal(11,0)), ceil(cdecimal1) (type: decimal(11,0)), round(exp(cdecimal1), 58) (type: double), ln(cdecimal1) (type: double), log10(cdecimal1) (type: double), log2(cdecimal1) (type: double), log2((cdecimal1 - 15601.0)) (type: double), log(2.0, cdecimal1) (type: double), power(log2(cdecimal1), 2.0) (type: double), power(log2(cdecimal1), 2.0) (type: double), sqrt(cdecimal1) (type: double), abs(cdecimal1) (type: decimal(38,18)), sin(cdecimal1) (type: double), asin(cdecimal1) (type: double), cos(cdecimal1) (type: double), acos(cdecimal1) (type: double), atan(cdecimal1) (type: double), degrees(cdecimal1) (type: double), radians(cdecimal1) (type: double), cdecimal1 (type: decimal(20,10)), (- cdecimal1) (type: decimal(20,10)), sign(cdecimal1) (type: int), cos(((- sin(log(cdecimal1))) + 3.14159)) (type: double) + expressions: cdecimal1 (type: decimal(20,10)), round(cdecimal1, 2) (type: decimal(13,2)), round(cdecimal1) (type: decimal(11,0)), floor(cdecimal1) (type: decimal(11,0)), ceil(cdecimal1) (type: decimal(11,0)), round(exp(cdecimal1), 58) (type: double), ln(cdecimal1) (type: double), log10(cdecimal1) (type: double), log2(cdecimal1) (type: double), log2((cdecimal1 - 15601.0)) (type: double), log(2.0, cdecimal1) (type: double), power(log2(cdecimal1), 2.0) (type: double), power(log2(cdecimal1), 2.0) (type: double), sqrt(cdecimal1) (type: double), abs(cdecimal1) (type: decimal(20,10)), sin(cdecimal1) (type: double), asin(cdecimal1) (type: double), cos(cdecimal1) (type: double), acos(cdecimal1) (type: double), atan(cdecimal1) (type: double), degrees(cdecimal1) (type: double), radians(cdecimal1) (type: double), cdecimal1 (type: decimal(20,10)), (- cdecimal1) (type: decimal(20,10)), sign(cdecimal1) (type: int), cos(((- sin(log(cdecimal1))) + 3.14159)) (type: double) outputColumnNames: _col0, _col1, _col2, _col3, _col4, _col5, _col6, _col7, _col8, _col9, _col10, _col11, _col12, _col13, _col14, _col15, _col16, _col17, _col18, _col19, _col20, _col21, _col22, _col23, _col24, _col25 ListSink diff --git ql/src/test/results/clientpositive/tez/vector_decimal_udf.q.out ql/src/test/results/clientpositive/tez/vector_decimal_udf.q.out index 266b4cb..d87bab7 100644 --- ql/src/test/results/clientpositive/tez/vector_decimal_udf.q.out +++ ql/src/test/results/clientpositive/tez/vector_decimal_udf.q.out @@ -1546,7 +1546,7 @@ STAGE PLANS: alias: decimal_udf Statistics: Num rows: 38 Data size: 4296 Basic stats: COMPLETE Column stats: NONE Select Operator - expressions: abs(key) (type: decimal(38,18)) + expressions: abs(key) (type: decimal(20,10)) outputColumnNames: _col0 Statistics: Num rows: 38 Data size: 4296 Basic stats: COMPLETE Column stats: NONE File Output Operator diff --git ql/src/test/results/clientpositive/vector_decimal_math_funcs.q.out ql/src/test/results/clientpositive/vector_decimal_math_funcs.q.out index 24765af..213b95f 100644 --- ql/src/test/results/clientpositive/vector_decimal_math_funcs.q.out +++ ql/src/test/results/clientpositive/vector_decimal_math_funcs.q.out @@ -101,7 +101,7 @@ STAGE PLANS: predicate: (((cbigint % 500) = 0) and (sin(cdecimal1) >= -1.0)) (type: boolean) Statistics: Num rows: 2048 Data size: 366958 Basic stats: COMPLETE Column stats: NONE Select Operator - expressions: cdecimal1 (type: decimal(20,10)), round(cdecimal1, 2) (type: decimal(13,2)), round(cdecimal1) (type: decimal(11,0)), floor(cdecimal1) (type: decimal(11,0)), ceil(cdecimal1) (type: decimal(11,0)), round(exp(cdecimal1), 58) (type: double), ln(cdecimal1) (type: double), log10(cdecimal1) (type: double), log2(cdecimal1) (type: double), log2((cdecimal1 - 15601.0)) (type: double), log(2.0, cdecimal1) (type: double), power(log2(cdecimal1), 2.0) (type: double), power(log2(cdecimal1), 2.0) (type: double), sqrt(cdecimal1) (type: double), abs(cdecimal1) (type: decimal(38,18)), sin(cdecimal1) (type: double), asin(cdecimal1) (type: double), cos(cdecimal1) (type: double), acos(cdecimal1) (type: double), atan(cdecimal1) (type: double), degrees(cdecimal1) (type: double), radians(cdecimal1) (type: double), cdecimal1 (type: decimal(20,10)), (- cdecimal1) (type: decimal(20,10)), sign(cdecimal1) (type: int), cos(((- sin(log(cdecimal1))) + 3.14159)) (type: double) + expressions: cdecimal1 (type: decimal(20,10)), round(cdecimal1, 2) (type: decimal(13,2)), round(cdecimal1) (type: decimal(11,0)), floor(cdecimal1) (type: decimal(11,0)), ceil(cdecimal1) (type: decimal(11,0)), round(exp(cdecimal1), 58) (type: double), ln(cdecimal1) (type: double), log10(cdecimal1) (type: double), log2(cdecimal1) (type: double), log2((cdecimal1 - 15601.0)) (type: double), log(2.0, cdecimal1) (type: double), power(log2(cdecimal1), 2.0) (type: double), power(log2(cdecimal1), 2.0) (type: double), sqrt(cdecimal1) (type: double), abs(cdecimal1) (type: decimal(20,10)), sin(cdecimal1) (type: double), asin(cdecimal1) (type: double), cos(cdecimal1) (type: double), acos(cdecimal1) (type: double), atan(cdecimal1) (type: double), degrees(cdecimal1) (type: double), radians(cdecimal1) (type: double), cdecimal1 (type: decimal(20,10)), (- cdecimal1) (type: decimal(20,10)), sign(cdecimal1) (type: int), cos(((- sin(log(cdecimal1))) + 3.14159)) (type: double) outputColumnNames: _col0, _col1, _col2, _col3, _col4, _col5, _col6, _col7, _col8, _col9, _col10, _col11, _col12, _col13, _col14, _col15, _col16, _col17, _col18, _col19, _col20, _col21, _col22, _col23, _col24, _col25 Statistics: Num rows: 2048 Data size: 366958 Basic stats: COMPLETE Column stats: NONE File Output Operator diff --git ql/src/test/results/clientpositive/vector_decimal_udf.q.out ql/src/test/results/clientpositive/vector_decimal_udf.q.out index 05cd5b3..cfea2bc 100644 --- ql/src/test/results/clientpositive/vector_decimal_udf.q.out +++ ql/src/test/results/clientpositive/vector_decimal_udf.q.out @@ -1486,7 +1486,7 @@ STAGE PLANS: alias: decimal_udf Statistics: Num rows: 38 Data size: 4296 Basic stats: COMPLETE Column stats: NONE Select Operator - expressions: abs(key) (type: decimal(38,18)) + expressions: abs(key) (type: decimal(20,10)) outputColumnNames: _col0 Statistics: Num rows: 38 Data size: 4296 Basic stats: COMPLETE Column stats: NONE File Output Operator