diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java index 31d786b..5aa3559 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java @@ -211,7 +211,7 @@ registerUDF("ceil", UDFCeil.class, false); registerUDF("ceiling", UDFCeil.class, false); registerUDF("rand", UDFRand.class, false); - registerUDF("abs", UDFAbs.class, false); + registerGenericUDF("abs", GenericUDFAbs.class); registerUDF("pmod", UDFPosMod.class, false); registerUDF("ln", UDFLn.class, false); diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/UDFAbs.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/UDFAbs.java deleted file mode 100644 index acaaa5b..0000000 --- a/ql/src/java/org/apache/hadoop/hive/ql/udf/UDFAbs.java +++ /dev/null @@ -1,82 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.hadoop.hive.ql.udf; - -import org.apache.hadoop.hive.ql.exec.Description; -import org.apache.hadoop.hive.ql.exec.UDF; -import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; -import org.apache.hadoop.hive.serde2.io.DoubleWritable; -import org.apache.hadoop.io.IntWritable; -import org.apache.hadoop.io.LongWritable; - -/** - * UDFAbs. - * - */ -@Description(name = "abs", - value = "_FUNC_(x) - returns the absolute value of x", - extended = "Example:\n" - + " > SELECT _FUNC_(0) FROM src LIMIT 1;\n" - + " 0\n" - + " > SELECT _FUNC_(-5) FROM src LIMIT 1;\n" + " 5") -public class UDFAbs extends UDF { - private final DoubleWritable resultDouble = new DoubleWritable(); - private final LongWritable resultLong = new LongWritable(); - private final IntWritable resultInt = new IntWritable(); - private final HiveDecimalWritable resultHiveDecimal = new HiveDecimalWritable(); - - public DoubleWritable evaluate(DoubleWritable n) { - if (n == null) { - return null; - } - - resultDouble.set(Math.abs(n.get())); - - return resultDouble; - } - - public LongWritable evaluate(LongWritable n) { - if (n == null) { - return null; - } - - resultLong.set(Math.abs(n.get())); - - return resultLong; - } - - public IntWritable evaluate(IntWritable n) { - if (n == null) { - return null; - } - - resultInt.set(Math.abs(n.get())); - - return resultInt; - } - - public HiveDecimalWritable evaluate(HiveDecimalWritable n) { - if (n == null) { - return null; - } - - resultHiveDecimal.set(n.getHiveDecimal().abs()); - return resultHiveDecimal; - } -} diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFAbs.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFAbs.java new file mode 100644 index 0000000..e5dd304 --- /dev/null +++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFAbs.java @@ -0,0 +1,110 @@ +package org.apache.hadoop.hive.ql.udf.generic; + +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; + +/** + * GenericUDFAbs. + * + */ +@Description(name = "abs", + value = "_FUNC_(x) - returns the absolute value of x", + extended = "Example:\n" + + " > SELECT _FUNC_(0) FROM src LIMIT 1;\n" + + " 0\n" + + " > SELECT _FUNC_(-5) FROM src LIMIT 1;\n" + " 5") +public class GenericUDFAbs extends GenericUDF { + private transient PrimitiveCategory inputType; + private final DoubleWritable resultDouble = new DoubleWritable(); + private final LongWritable resultLong = new LongWritable(); + private final IntWritable resultInt = new IntWritable(); + private final HiveDecimalWritable resultHiveDecimal = new HiveDecimalWritable(); + + @Override + public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException { + if (arguments.length != 1) { + throw new UDFArgumentLengthException( + "ABS() requires 1 argument, got " + arguments.length); + } + + if (arguments[0].getCategory() != Category.PRIMITIVE) { + throw new UDFArgumentException( + "ABS only takes primitive types, got " + arguments[0].getTypeName()); + } + PrimitiveObjectInspector argumentOI = (PrimitiveObjectInspector) arguments[0]; + + inputType = argumentOI.getPrimitiveCategory(); + ObjectInspector outputOI = null; + switch (inputType) { + case INT: + outputOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector; + break; + case LONG: + outputOI = PrimitiveObjectInspectorFactory.writableLongObjectInspector; + break; + case DOUBLE: + outputOI = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector; + break; + case DECIMAL: + outputOI = PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector; + break; + default: + throw new UDFArgumentException( + "ABS only takes INT/LONG/DOUBLE/DECIMAL types, got " + inputType); + } + return outputOI; + } + + @Override + public Object evaluate(DeferredObject[] arguments) throws HiveException { + Object valObject = arguments[0].get(); + if (valObject == null) { + return null; + } + switch (inputType) { + case INT: + resultInt.set(Math.abs(((IntWritable) valObject).get())); + return resultInt; + case LONG: + resultLong.set(Math.abs(((LongWritable) valObject).get())); + return resultLong; + case DOUBLE: + resultDouble.set(Math.abs(((DoubleWritable) valObject).get())); + return resultDouble; + case DECIMAL: + resultHiveDecimal.set(((HiveDecimalWritable) valObject).getHiveDecimal().abs()); + return resultHiveDecimal; + default: + throw new UDFArgumentException( + "ABS only takes INT/LONG/DOUBLE/DECIMAL types, got " + inputType); + } + + } + + @Override + public String getDisplayString(String[] children) { + StringBuilder sb = new StringBuilder(); + sb.append("abs("); + if (children.length > 0) { + sb.append(children[0]); + for (int i = 1; i < children.length; i++) { + sb.append(","); + sb.append(children[i]); + } + } + sb.append(")"); + return sb.toString(); + } + +} diff --git a/ql/src/test/org/apache/hadoop/hive/ql/udf/TestGenericUDFAbs.java b/ql/src/test/org/apache/hadoop/hive/ql/udf/TestGenericUDFAbs.java new file mode 100644 index 0000000..1ca6b99 --- /dev/null +++ b/ql/src/test/org/apache/hadoop/hive/ql/udf/TestGenericUDFAbs.java @@ -0,0 +1,27 @@ +package org.apache.hadoop.hive.ql.udf; + +import junit.framework.TestCase; + +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredJavaObject; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDFAbs; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.io.IntWritable; + +public class TestGenericUDFAbs extends TestCase { + + public void testInt() throws HiveException { + GenericUDFAbs udf = new GenericUDFAbs(); + ObjectInspector valueOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector; + ObjectInspector[] arguments = {valueOI}; + + udf.initialize(arguments); + DeferredObject valueObj = new DeferredJavaObject(new IntWritable(107)); + DeferredObject[] args = {valueObj}; + IntWritable output = (IntWritable) udf.evaluate(args); + + assertEquals("abs() test for INT failed " , 107, output.get()); + } +}