diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFWidthBucket.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFWidthBucket.java index c767d35..e165274 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFWidthBucket.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFWidthBucket.java @@ -1,37 +1,41 @@ package org.apache.hadoop.hive.ql.udf.generic; import com.google.common.base.Preconditions; + +import org.apache.hadoop.hive.common.type.HiveDecimal; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; + import org.apache.hadoop.io.IntWritable; -import org.apache.hadoop.io.LongWritable; import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils.PrimitiveGrouping.NUMERIC_GROUP; import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils.PrimitiveGrouping.VOID_GROUP; @Description(name = "width_bucket", - value = "_FUNC_(expr, min_value, max_value, num_buckets) - Returns an integer between 0 and num_buckets+1 by " - + "mapping the expr into buckets defined by the range [min_value, max_value]", - extended = "Returns an integer between 0 and num_buckets+1 by " - + "mapping expr into the ith equally sized bucket. Buckets are made by dividing [min_value, max_value] into " - + "equally sized regions. If expr < min_value, return 1, if expr > max_value return num_buckets+1\n" - + "Example: expr is an integer column withs values 1, 10, 20, 30.\n" - + " > SELECT _FUNC_(expr, 5, 25, 4) FROM src;\n1\n1\n3\n5") + value = "_FUNC_(expr, min_value, max_value, num_buckets) - Returns an integer between 0 and num_buckets+1 by " + + "mapping the expr into buckets defined by the range [min_value, max_value]", + extended = "Returns an integer between 0 and num_buckets+1 by " + + "mapping expr into the ith equally sized bucket. Buckets are made by dividing [min_value, max_value] into " + + "equally sized regions. If expr < min_value, return 1, if expr > max_value return num_buckets+1\n" + + "Example: expr is an integer column withs values 1, 10, 20, 30.\n" + + " > SELECT _FUNC_(expr, 5, 25, 4) FROM src;\n1\n1\n3\n5") public class GenericUDFWidthBucket extends GenericUDF { private transient PrimitiveObjectInspector.PrimitiveCategory[] inputTypes = new PrimitiveObjectInspector.PrimitiveCategory[4]; - private transient ObjectInspectorConverters.Converter[] converters = new ObjectInspectorConverters.Converter[4]; + private transient ObjectInspector[] objectInspectors; private final IntWritable output = new IntWritable(); @Override public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException { + this.objectInspectors = arguments; + checkArgsSize(arguments, 4, 4); checkArgPrimitive(arguments, 0); @@ -44,38 +48,40 @@ public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumen checkArgGroups(arguments, 2, inputTypes, NUMERIC_GROUP, VOID_GROUP); checkArgGroups(arguments, 3, inputTypes, NUMERIC_GROUP, VOID_GROUP); - obtainLongConverter(arguments, 0, inputTypes, converters); - obtainLongConverter(arguments, 1, inputTypes, converters); - obtainLongConverter(arguments, 2, inputTypes, converters); - obtainIntConverter(arguments, 3, inputTypes, converters); - return PrimitiveObjectInspectorFactory.writableIntObjectInspector; } @Override public Object evaluate(DeferredObject[] arguments) throws HiveException { - Long exprValue = getLongValue(arguments, 0, converters); - Long minValue = getLongValue(arguments, 1, converters); - Long maxValue = getLongValue(arguments, 2, converters); - Integer numBuckets = getIntValue(arguments, 3, converters); - - if (exprValue == null || minValue == null || maxValue == null || numBuckets == null) { + if (arguments[0].get() == null || arguments[1].get() == null || arguments[2].get() == null || arguments[3].get() == null) { return null; } + HiveDecimal exprValue = PrimitiveObjectInspectorUtils.getHiveDecimal(arguments[0].get(), + (PrimitiveObjectInspector) this.objectInspectors[0]); + HiveDecimal minValue = PrimitiveObjectInspectorUtils.getHiveDecimal(arguments[1].get(), + (PrimitiveObjectInspector) this.objectInspectors[1]); + HiveDecimal maxValue = PrimitiveObjectInspectorUtils.getHiveDecimal(arguments[2].get(), + (PrimitiveObjectInspector) this.objectInspectors[2]); + Integer numBuckets = PrimitiveObjectInspectorUtils.getInt(arguments[3].get(), + (PrimitiveObjectInspector) this.objectInspectors[3]); + Preconditions.checkArgument(numBuckets > 0, "numBuckets in width_bucket function must be above 0"); - long intervalSize = (maxValue - minValue) / numBuckets; + Preconditions.checkArgument(!maxValue.equals(minValue), + "maxValue and minValue in width_bucket function cannot be equal"); + + HiveDecimal intervalSize = maxValue.subtract(minValue).divide(HiveDecimal.create(numBuckets)); - if (exprValue < minValue) { + if (exprValue.compareTo(minValue) < 0) { output.set(0); - } else if (exprValue > maxValue) { + } else if (exprValue.compareTo(maxValue) > 0) { output.set(numBuckets + 1); } else { - long diff = exprValue - minValue; - if (diff % intervalSize == 0) { - output.set((int) (diff/intervalSize + 1)); + HiveDecimal diff = exprValue.subtract(minValue); + if (diff.remainder(intervalSize).equals(HiveDecimal.ZERO)) { + output.set(diff.divide(intervalSize).add(HiveDecimal.ONE).intValue()); } else { - output.set((int) Math.ceil((double) (diff) / intervalSize)); + output.set((int) Math.ceil(diff.divide(intervalSize).doubleValue())); } } diff --git a/ql/src/test/queries/clientpositive/udf_width_bucket.q b/ql/src/test/queries/clientpositive/udf_width_bucket.q index 6ac60d6..5ccae17 100644 --- a/ql/src/test/queries/clientpositive/udf_width_bucket.q +++ b/ql/src/test/queries/clientpositive/udf_width_bucket.q @@ -27,3 +27,18 @@ width_bucket(-10, -5, 15, 4), width_bucket(0, -5, 15, 4), width_bucket(10, -5, 15, 4), width_bucket(20, -5, 15, 4); + +select +width_bucket(0.1, 0, 1, 10), +width_bucket(0.25, 0, 1, 10), +width_bucket(0.3456, 0, 1, 10), +width_bucket(0.654321, 0, 1, 10); + +select +width_bucket(-0.5, -1.5, 1.5, 10), +width_bucket(-0.3, -1.5, 1.5, 10), +width_bucket(-0.25, -1.5, 1.5, 10), +width_bucket(0, -1.5, 1.5, 10), +width_bucket(0.75, -1.5, 1.5, 10), +width_bucket(1.25, -1.5, 1.5, 10), +width_bucket(1.5, -1.5, 1.5, 10); diff --git a/ql/src/test/results/clientpositive/udf_width_bucket.q.out b/ql/src/test/results/clientpositive/udf_width_bucket.q.out index a72e977..88ab1cd 100644 --- a/ql/src/test/results/clientpositive/udf_width_bucket.q.out +++ b/ql/src/test/results/clientpositive/udf_width_bucket.q.out @@ -109,3 +109,43 @@ POSTHOOK: type: QUERY POSTHOOK: Input: _dummy_database@_dummy_table #### A masked pattern was here #### 0 2 4 5 +PREHOOK: query: select +width_bucket(0.1, 0, 1, 10), +width_bucket(0.25, 0, 1, 10), +width_bucket(0.3456, 0, 1, 10), +width_bucket(0.654321, 0, 1, 10) +PREHOOK: type: QUERY +PREHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +POSTHOOK: query: select +width_bucket(0.1, 0, 1, 10), +width_bucket(0.25, 0, 1, 10), +width_bucket(0.3456, 0, 1, 10), +width_bucket(0.654321, 0, 1, 10) +POSTHOOK: type: QUERY +POSTHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +2 3 4 7 +PREHOOK: query: select +width_bucket(-0.5, -1.5, 1.5, 10), +width_bucket(-0.3, -1.5, 1.5, 10), +width_bucket(-0.25, -1.5, 1.5, 10), +width_bucket(0, -1.5, 1.5, 10), +width_bucket(0.75, -1.5, 1.5, 10), +width_bucket(1.25, -1.5, 1.5, 10), +width_bucket(1.5, -1.5, 1.5, 10) +PREHOOK: type: QUERY +PREHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +POSTHOOK: query: select +width_bucket(-0.5, -1.5, 1.5, 10), +width_bucket(-0.3, -1.5, 1.5, 10), +width_bucket(-0.25, -1.5, 1.5, 10), +width_bucket(0, -1.5, 1.5, 10), +width_bucket(0.75, -1.5, 1.5, 10), +width_bucket(1.25, -1.5, 1.5, 10), +width_bucket(1.5, -1.5, 1.5, 10) +POSTHOOK: type: QUERY +POSTHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +4 5 5 6 8 10 11