diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java index 31d786b..7d0fc2f 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java @@ -48,7 +48,6 @@ import org.apache.hadoop.hive.ql.udf.GenericUDFEncode; import org.apache.hadoop.hive.ql.udf.SettableUDF; import org.apache.hadoop.hive.ql.udf.UDAFPercentile; -import org.apache.hadoop.hive.ql.udf.UDFAbs; import org.apache.hadoop.hive.ql.udf.UDFAcos; import org.apache.hadoop.hive.ql.udf.UDFAscii; import org.apache.hadoop.hive.ql.udf.UDFAsin; @@ -211,7 +210,7 @@ registerUDF("ceil", UDFCeil.class, false); registerUDF("ceiling", UDFCeil.class, false); registerUDF("rand", UDFRand.class, false); - registerUDF("abs", UDFAbs.class, false); + registerGenericUDF("abs", GenericUDFAbs.class); registerUDF("pmod", UDFPosMod.class, false); registerUDF("ln", UDFLn.class, false); diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/UDFAbs.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/UDFAbs.java deleted file mode 100644 index acaaa5b..0000000 --- a/ql/src/java/org/apache/hadoop/hive/ql/udf/UDFAbs.java +++ /dev/null @@ -1,82 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.hadoop.hive.ql.udf; - -import org.apache.hadoop.hive.ql.exec.Description; -import org.apache.hadoop.hive.ql.exec.UDF; -import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; -import org.apache.hadoop.hive.serde2.io.DoubleWritable; -import org.apache.hadoop.io.IntWritable; -import org.apache.hadoop.io.LongWritable; - -/** - * UDFAbs. - * - */ -@Description(name = "abs", - value = "_FUNC_(x) - returns the absolute value of x", - extended = "Example:\n" - + " > SELECT _FUNC_(0) FROM src LIMIT 1;\n" - + " 0\n" - + " > SELECT _FUNC_(-5) FROM src LIMIT 1;\n" + " 5") -public class UDFAbs extends UDF { - private final DoubleWritable resultDouble = new DoubleWritable(); - private final LongWritable resultLong = new LongWritable(); - private final IntWritable resultInt = new IntWritable(); - private final HiveDecimalWritable resultHiveDecimal = new HiveDecimalWritable(); - - public DoubleWritable evaluate(DoubleWritable n) { - if (n == null) { - return null; - } - - resultDouble.set(Math.abs(n.get())); - - return resultDouble; - } - - public LongWritable evaluate(LongWritable n) { - if (n == null) { - return null; - } - - resultLong.set(Math.abs(n.get())); - - return resultLong; - } - - public IntWritable evaluate(IntWritable n) { - if (n == null) { - return null; - } - - resultInt.set(Math.abs(n.get())); - - return resultInt; - } - - public HiveDecimalWritable evaluate(HiveDecimalWritable n) { - if (n == null) { - return null; - } - - resultHiveDecimal.set(n.getHiveDecimal().abs()); - return resultHiveDecimal; - } -} diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFAbs.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFAbs.java new file mode 100644 index 0000000..9d6fcea --- /dev/null +++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFAbs.java @@ -0,0 +1,135 @@ +package org.apache.hadoop.hive.ql.udf.generic; + +import org.apache.hadoop.hive.common.type.HiveDecimal; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; + +/** + * GenericUDFAbs. + * + */ +@Description(name = "abs", + value = "_FUNC_(x) - returns the absolute value of x", + extended = "Example:\n" + + " > SELECT _FUNC_(0) FROM src LIMIT 1;\n" + + " 0\n" + + " > SELECT _FUNC_(-5) FROM src LIMIT 1;\n" + " 5") +public class GenericUDFAbs extends GenericUDF { + private transient PrimitiveCategory inputType; + private final DoubleWritable resultDouble = new DoubleWritable(); + private final LongWritable resultLong = new LongWritable(); + private final IntWritable resultInt = new IntWritable(); + private transient PrimitiveObjectInspector argumentOI; + private transient Converter inputConverter; + + @Override + public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException { + if (arguments.length != 1) { + throw new UDFArgumentLengthException( + "ABS() requires 1 argument, got " + arguments.length); + } + + if (arguments[0].getCategory() != Category.PRIMITIVE) { + throw new UDFArgumentException( + "ABS only takes primitive types, got " + arguments[0].getTypeName()); + } + argumentOI = (PrimitiveObjectInspector) arguments[0]; + + inputType = argumentOI.getPrimitiveCategory(); + ObjectInspector outputOI = null; + switch (inputType) { + case SHORT: + case BYTE: + case INT: + inputConverter = ObjectInspectorConverters.getConverter(arguments[0], + PrimitiveObjectInspectorFactory.writableIntObjectInspector); + outputOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector; + break; + case LONG: + inputConverter = ObjectInspectorConverters.getConverter(arguments[0], + PrimitiveObjectInspectorFactory.writableLongObjectInspector); + outputOI = PrimitiveObjectInspectorFactory.writableLongObjectInspector; + break; + case FLOAT: + case STRING: + case DOUBLE: + inputConverter = ObjectInspectorConverters.getConverter(arguments[0], + PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); + outputOI = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector; + break; + case DECIMAL: + inputConverter = ObjectInspectorConverters.getConverter(arguments[0], + PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector); + outputOI = PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector; + break; + default: + throw new UDFArgumentException( + "ABS only takes SHORT/BYTE/INT/LONG/DOUBLE/FLOAT/STRING/DECIMAL types, got " + inputType); + } + return outputOI; + } + + @Override + public Object evaluate(DeferredObject[] arguments) throws HiveException { + Object valObject = arguments[0].get(); + if (valObject == null) { + return null; + } + switch (inputType) { + case SHORT: + case BYTE: + case INT: + valObject = inputConverter.convert(valObject); + resultInt.set(Math.abs(((IntWritable) valObject).get())); + return resultInt; + case LONG: + valObject = inputConverter.convert(valObject); + resultLong.set(Math.abs(((LongWritable) valObject).get())); + return resultLong; + case FLOAT: + case STRING: + case DOUBLE: + valObject = inputConverter.convert(valObject); + resultDouble.set(Math.abs(((DoubleWritable) valObject).get())); + return resultDouble; + case DECIMAL: + return PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector.set( + PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector + .create(HiveDecimal.ZERO), + PrimitiveObjectInspectorUtils.getHiveDecimal(valObject, + argumentOI).abs()); + default: + throw new UDFArgumentException( + "ABS only takes SHORT/BYTE/INT/LONG/DOUBLE/FLOAT/STRING/DECIMAL types, got " + inputType); + } + } + + @Override + public String getDisplayString(String[] children) { + StringBuilder sb = new StringBuilder(); + sb.append("abs("); + if (children.length > 0) { + sb.append(children[0]); + for (int i = 1; i < children.length; i++) { + sb.append(","); + sb.append(children[i]); + } + } + sb.append(")"); + return sb.toString(); + } + +} diff --git a/ql/src/test/org/apache/hadoop/hive/ql/udf/TestGenericUDFAbs.java b/ql/src/test/org/apache/hadoop/hive/ql/udf/TestGenericUDFAbs.java new file mode 100644 index 0000000..2de4499 --- /dev/null +++ b/ql/src/test/org/apache/hadoop/hive/ql/udf/TestGenericUDFAbs.java @@ -0,0 +1,139 @@ +package org.apache.hadoop.hive.ql.udf; + +import junit.framework.TestCase; + +import org.apache.hadoop.hive.common.type.HiveDecimal; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredJavaObject; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDFAbs; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.io.FloatWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.Text; + +public class TestGenericUDFAbs extends TestCase { + + public void testInt() throws HiveException { + GenericUDFAbs udf = new GenericUDFAbs(); + ObjectInspector valueOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector; + ObjectInspector[] arguments = {valueOI}; + + udf.initialize(arguments); + DeferredObject valueObj = new DeferredJavaObject(new IntWritable(107)); + DeferredObject[] args = {valueObj}; + IntWritable output = (IntWritable) udf.evaluate(args); + + assertEquals("abs() test for INT failed ", 107, output.get()); + + valueObj = new DeferredJavaObject(new IntWritable(-107)); + args[0] = valueObj; + output = (IntWritable) udf.evaluate(args); + + assertEquals("abs() test for INT failed ", 107, output.get()); + } + + public void testLong() throws HiveException { + GenericUDFAbs udf = new GenericUDFAbs(); + ObjectInspector valueOI = PrimitiveObjectInspectorFactory.writableLongObjectInspector; + ObjectInspector[] arguments = {valueOI}; + + udf.initialize(arguments); + DeferredObject valueObj = new DeferredJavaObject(new LongWritable(107L)); + DeferredObject[] args = {valueObj}; + LongWritable output = (LongWritable) udf.evaluate(args); + + assertEquals("abs() test for LONG failed ", 107, output.get()); + + valueObj = new DeferredJavaObject(new LongWritable(-107L)); + args[0] = valueObj; + output = (LongWritable) udf.evaluate(args); + + assertEquals("abs() test for LONG failed ", 107, output.get()); + } + + public void testDouble() throws HiveException { + GenericUDFAbs udf = new GenericUDFAbs(); + ObjectInspector valueOI = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector; + ObjectInspector[] arguments = {valueOI}; + + udf.initialize(arguments); + DeferredObject valueObj = new DeferredJavaObject(new DoubleWritable(107.78)); + DeferredObject[] args = {valueObj}; + DoubleWritable output = (DoubleWritable) udf.evaluate(args); + + assertEquals("abs() test for Double failed ", 107.78, output.get()); + + valueObj = new DeferredJavaObject(new DoubleWritable(-107.78)); + args[0] = valueObj; + output = (DoubleWritable) udf.evaluate(args); + + assertEquals("abs() test for Double failed ", 107.78, output.get()); + } + + public void testFloat() throws HiveException { + GenericUDFAbs udf = new GenericUDFAbs(); + ObjectInspector valueOI = PrimitiveObjectInspectorFactory.writableFloatObjectInspector; + ObjectInspector[] arguments = {valueOI}; + + udf.initialize(arguments); + DeferredObject valueObj = new DeferredJavaObject(new FloatWritable(107.78f)); + DeferredObject[] args = {valueObj}; + DoubleWritable output = (DoubleWritable) udf.evaluate(args); + + // Make sure flow and double equality compare works + assertTrue("abs() test for Float failed ", Math.abs(107.78 - output.get()) < 0.0001); + + valueObj = new DeferredJavaObject(new FloatWritable(-107.78f)); + args[0] = valueObj; + output = (DoubleWritable) udf.evaluate(args); + + assertTrue("abs() test for Float failed ", Math.abs(107.78 - output.get()) < 0.0001); + } + + + public void testText() throws HiveException { + GenericUDFAbs udf = new GenericUDFAbs(); + ObjectInspector valueOI = PrimitiveObjectInspectorFactory.writableStringObjectInspector; + ObjectInspector[] arguments = {valueOI}; + + udf.initialize(arguments); + DeferredObject valueObj = new DeferredJavaObject(new Text("123.45")); + DeferredObject[] args = {valueObj}; + DoubleWritable output = (DoubleWritable) udf.evaluate(args); + + assertEquals("abs() test for String failed ", "123.45", output.toString()); + + valueObj = new DeferredJavaObject(new Text("-123.45")); + args[0] = valueObj; + output = (DoubleWritable) udf.evaluate(args); + + assertEquals("abs() test for String failed ", "123.45", output.toString()); + } + + public void testHiveDecimal() throws HiveException { + GenericUDFAbs udf = new GenericUDFAbs(); + ObjectInspector valueOI = PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector; + ObjectInspector[] arguments = {valueOI}; + + udf.initialize(arguments); + DeferredObject valueObj = new DeferredJavaObject(new HiveDecimalWritable(new HiveDecimal( + "107.123456789"))); + DeferredObject[] args = {valueObj}; + HiveDecimalWritable output = (HiveDecimalWritable) udf.evaluate(args); + + assertEquals("abs() test for HiveDecimal failed ", 107.123456789, output.getHiveDecimal() + .doubleValue()); + + valueObj = new DeferredJavaObject(new HiveDecimalWritable(new HiveDecimal("-107.123456789"))); + args[0] = valueObj; + output = (HiveDecimalWritable) udf.evaluate(args); + + assertEquals("abs() test for HiveDecimal failed ", 107.123456789, output.getHiveDecimal() + .doubleValue()); + } +}