diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java index ccfb455..8dc5f2e 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java @@ -481,6 +481,7 @@ system.registerGenericUDF("greatest", GenericUDFGreatest.class); system.registerGenericUDF("least", GenericUDFLeast.class); system.registerGenericUDF("cardinality_violation", GenericUDFCardinalityViolation.class); + system.registerGenericUDF("width_bucket", GenericUDFWidthBucket.class); system.registerGenericUDF("from_utc_timestamp", GenericUDFFromUtcTimestamp.class); system.registerGenericUDF("to_utc_timestamp", GenericUDFToUtcTimestamp.class); diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDF.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDF.java index 00a4f38..303f023 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDF.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDF.java @@ -375,7 +375,7 @@ protected void obtainLongConverter(ObjectInspector[] arguments, int i, Converter converter = ObjectInspectorConverters.getConverter( arguments[i], - PrimitiveObjectInspectorFactory.writableIntObjectInspector); + PrimitiveObjectInspectorFactory.writableLongObjectInspector); converters[i] = converter; inputTypes[i] = inputType; } @@ -566,6 +566,28 @@ protected Integer getConstantIntValue(ObjectInspector[] arguments, int i) return v; } + protected Long getConstantLongValue(ObjectInspector[] arguments, int i) + throws UDFArgumentTypeException { + Object constValue = ((ConstantObjectInspector) arguments[i]).getWritableConstantValue(); + if (constValue == null) { + return null; + } + long v; + if (constValue instanceof LongWritable) { + v = ((LongWritable) constValue).get(); + } else if (constValue instanceof IntWritable) { + v = ((IntWritable) constValue).get(); + 
} else if (constValue instanceof ShortWritable) { + v = ((ShortWritable) constValue).get(); + } else if (constValue instanceof ByteWritable) { + v = ((ByteWritable) constValue).get(); + } else { + throw new UDFArgumentTypeException(i, getFuncName() + " only takes LONG/INT/SHORT/BYTE types as " + + getArgOrder(i) + " argument, got " + constValue.getClass()); + } + return v; + } + protected String getArgOrder(int i) { i++; switch (i % 100) { diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFWidthBucket.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFWidthBucket.java new file mode 100644 index 0000000..22f0192 --- /dev/null +++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFWidthBucket.java @@ -0,0 +1,89 @@ +package org.apache.hadoop.hive.ql.udf.generic; + +import com.google.common.base.Preconditions; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; + +import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils.PrimitiveGrouping.NUMERIC_GROUP; + + +@Description(name = "width_bucket", + value = "_FUNC_(expr, min_value, max_value, num_buckets) - Returns an integer between 0 and num_buckets+1 by " + + "mapping the expr into buckets defined by the range [min_value, max_value]", + extended = "Returns an integer between 0 and num_buckets+1 by " + + "mapping expr into the ith equally sized bucket. 
Buckets are made by dividing [min_value, max_value] into " + "equally sized regions. If expr < min_value, return 1, if expr > max_value return num_buckets+1\n" + "Example: expr is an integer column withs values 1, 10, 20, 30.\n" + " > SELECT _FUNC_(expr, 5, 25, 4) FROM src;\n1\n1\n3\n5") +public class GenericUDFWidthBucket extends GenericUDF { + + private transient PrimitiveObjectInspector.PrimitiveCategory[] inputTypes = new PrimitiveObjectInspector.PrimitiveCategory[4]; + private transient ObjectInspectorConverters.Converter[] converters = new ObjectInspectorConverters.Converter[4]; + + private final IntWritable output = new IntWritable(); + + @Override + public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException { + checkArgsSize(arguments, 4, 4); + + checkArgPrimitive(arguments, 0); + checkArgPrimitive(arguments, 1); + checkArgPrimitive(arguments, 2); + checkArgPrimitive(arguments, 3); + + checkArgGroups(arguments, 0, inputTypes, NUMERIC_GROUP); + checkArgGroups(arguments, 1, inputTypes, NUMERIC_GROUP); + checkArgGroups(arguments, 2, inputTypes, NUMERIC_GROUP); + checkArgGroups(arguments, 3, inputTypes, NUMERIC_GROUP); + + obtainLongConverter(arguments, 0, inputTypes, converters); + obtainLongConverter(arguments, 1, inputTypes, converters); + obtainLongConverter(arguments, 2, inputTypes, converters); + obtainIntConverter(arguments, 3, inputTypes, converters); + + return PrimitiveObjectInspectorFactory.writableIntObjectInspector; + } + + @Override + public Object evaluate(DeferredObject[] arguments) throws HiveException { + Long exprValue = getLongValue(arguments, 0, converters); + + if (exprValue == null) { + return null; + } + + Long minValue = getLongValue(arguments, 1, converters); + Preconditions.checkNotNull(minValue, "minValue in width_bucket function cannot be null"); + + Long maxValue = getLongValue(arguments, 2, converters); + Preconditions.checkNotNull(maxValue, "maxValue in width_bucket function cannot be null");
+ + Integer numBuckets = getIntValue(arguments, 3, converters); + Preconditions.checkNotNull(numBuckets, "numBuckets in width_bucket function cannot be null"); + Preconditions.checkArgument(numBuckets > 0, "numBuckets in width_bucket function must be above 0"); + + double intervalSize = (maxValue - minValue) / (double) numBuckets; + + if (exprValue <= minValue) { + output.set(1); + } else if (exprValue > maxValue) { + output.set(numBuckets + 1); + } else { + output.set((int) Math.ceil((exprValue - minValue)/intervalSize)); + } + + return output; + } + + @Override + public String getDisplayString(String[] children) { + return getStandardDisplayString("width_bucket", children); + } +} diff --git a/ql/src/test/org/apache/hadoop/hive/ql/udf/generic/TestGenericUDFWidthBucket.java b/ql/src/test/org/apache/hadoop/hive/ql/udf/generic/TestGenericUDFWidthBucket.java new file mode 100644 index 0000000..ba48666 --- /dev/null +++ b/ql/src/test/org/apache/hadoop/hive/ql/udf/generic/TestGenericUDFWidthBucket.java @@ -0,0 +1,64 @@ +package org.apache.hadoop.hive.ql.udf.generic; + +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.io.IntWritable; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + + +public class TestGenericUDFWidthBucket { + + @Test + public void testExprLessThanMinValue() throws HiveException { + assertEquals(1, testWidthBucketWithValues(99L, 100L, 5000L, 10).get()); + } + + @Test + public void testExprEqualsMinValue() throws HiveException { + assertEquals(1, testWidthBucketWithValues(100L, 100L, 5000L, 10).get()); + } + + @Test + public void testExprEqualsBoundaryValue() throws HiveException { + assertEquals(1, testWidthBucketWithValues(490L, 100L, 5000L, 10).get()); + } + + @Test + public void
testExprEqualsMaxValue() throws HiveException { + assertEquals(10, testWidthBucketWithValues(5000L, 100L, 5000L, 10).get()); + } + + @Test + public void testExprAboveMaxValue() throws HiveException { + assertEquals(11, testWidthBucketWithValues(6000L, 100L, 5000L, 10).get()); + } + + @Test + public void testExprIsNull() throws HiveException { + assertNull(testWidthBucketWithValues(null, 100L, 5000L, 10)); + } + + private IntWritable testWidthBucketWithValues(Long expr, Long minValue, Long maxValue, Integer numBuckets) throws HiveException { + GenericUDFWidthBucket udf = new GenericUDFWidthBucket(); + ObjectInspector valueOI1 = PrimitiveObjectInspectorFactory.javaLongObjectInspector; + ObjectInspector valueOI2 = PrimitiveObjectInspectorFactory.javaLongObjectInspector; + ObjectInspector valueOI3 = PrimitiveObjectInspectorFactory.javaLongObjectInspector; + ObjectInspector valueOI4 = PrimitiveObjectInspectorFactory.javaIntObjectInspector; + ObjectInspector[] arguments = {valueOI1, valueOI2, valueOI3, valueOI4}; + + udf.initialize(arguments); + + GenericUDF.DeferredObject valueObj1 = new GenericUDF.DeferredJavaObject(expr); + GenericUDF.DeferredObject valueObj2 = new GenericUDF.DeferredJavaObject(minValue); + GenericUDF.DeferredObject valueObj3 = new GenericUDF.DeferredJavaObject(maxValue); + GenericUDF.DeferredObject valueObj4 = new GenericUDF.DeferredJavaObject(numBuckets); + GenericUDF.DeferredObject[] args = {valueObj1, valueObj2, valueObj3, valueObj4}; + + return (IntWritable) udf.evaluate(args); + } +} diff --git a/ql/src/test/queries/clientpositive/udf_width_bucket.q b/ql/src/test/queries/clientpositive/udf_width_bucket.q new file mode 100644 index 0000000..40501ea --- /dev/null +++ b/ql/src/test/queries/clientpositive/udf_width_bucket.q @@ -0,0 +1,10 @@ +describe function width_bucket; +desc function extended width_bucket; + +explain select width_bucket(10, 5, 25, 4); + +select +width_bucket(1, 5, 25, 4), +width_bucket(10, 5, 25, 4), +width_bucket(20, 5, 25, 4), 
+width_bucket(30, 5, 25, 4); diff --git a/ql/src/test/results/clientpositive/udf_width_bucket.q.out b/ql/src/test/results/clientpositive/udf_width_bucket.q.out new file mode 100644 index 0000000..9ca044a --- /dev/null +++ b/ql/src/test/results/clientpositive/udf_width_bucket.q.out @@ -0,0 +1,58 @@ +PREHOOK: query: describe function width_bucket +PREHOOK: type: DESCFUNCTION +POSTHOOK: query: describe function width_bucket +POSTHOOK: type: DESCFUNCTION +width_bucket(expr, min_value, max_value, num_buckets) - Returns an integer between 0 and num_buckets+1 by mapping the expr into buckets defined by the range [min_value, max_value] +PREHOOK: query: desc function extended width_bucket +PREHOOK: type: DESCFUNCTION +POSTHOOK: query: desc function extended width_bucket +POSTHOOK: type: DESCFUNCTION +width_bucket(expr, min_value, max_value, num_buckets) - Returns an integer between 0 and num_buckets+1 by mapping the expr into buckets defined by the range [min_value, max_value] +Returns an integer between 0 and num_buckets+1 by mapping expr into the ith equally sized bucket. Buckets are made by dividing [min_value, max_value] into equally sized regions. If expr < min_value, return 1, if expr > max_value return num_buckets+1 +Example: expr is an integer column withs values 1, 10, 20, 30. 
+ > SELECT width_bucket(expr, 5, 25, 4) FROM src; +1 +1 +3 +5 +Function class:org.apache.hadoop.hive.ql.udf.generic.GenericUDFWidthBucket +Function type:BUILTIN +PREHOOK: query: explain select width_bucket(10, 5, 25, 4) +PREHOOK: type: QUERY +POSTHOOK: query: explain select width_bucket(10, 5, 25, 4) +POSTHOOK: type: QUERY +STAGE DEPENDENCIES: + Stage-0 is a root stage + +STAGE PLANS: + Stage: Stage-0 + Fetch Operator + limit: -1 + Processor Tree: + TableScan + alias: _dummy_table + Row Limit Per Split: 1 + Statistics: Num rows: 1 Data size: 1 Basic stats: COMPLETE Column stats: COMPLETE + Select Operator + expressions: 1 (type: int) + outputColumnNames: _col0 + Statistics: Num rows: 1 Data size: 4 Basic stats: COMPLETE Column stats: COMPLETE + ListSink + +PREHOOK: query: select +width_bucket(1, 5, 25, 4), +width_bucket(10, 5, 25, 4), +width_bucket(20, 5, 25, 4), +width_bucket(30, 5, 25, 4) +PREHOOK: type: QUERY +PREHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +POSTHOOK: query: select +width_bucket(1, 5, 25, 4), +width_bucket(10, 5, 25, 4), +width_bucket(20, 5, 25, 4), +width_bucket(30, 5, 25, 4) +POSTHOOK: type: QUERY +POSTHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +1 1 3 5