diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/UDFRound.java ql/src/java/org/apache/hadoop/hive/ql/udf/UDFRound.java index 1c807ef..f19fb91 100644 --- ql/src/java/org/apache/hadoop/hive/ql/udf/UDFRound.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/UDFRound.java @@ -24,8 +24,11 @@ import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDF; import org.apache.hadoop.hive.serde2.io.BigDecimalWritable; +import org.apache.hadoop.hive.serde2.io.ByteWritable; import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.io.ShortWritable; import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; /** * UDFRound. @@ -38,6 +41,10 @@ public class UDFRound extends UDF { private final BigDecimalWritable bigDecimalWritable = new BigDecimalWritable(); private final DoubleWritable doubleWritable = new DoubleWritable(); + private final LongWritable longWritable = new LongWritable(); + private final IntWritable intWritable = new IntWritable(); + private final ShortWritable shortWritable = new ShortWritable(); + private final ByteWritable byteWritable = new ByteWritable(); public UDFRound() { } @@ -48,7 +55,7 @@ private DoubleWritable evaluate(DoubleWritable n, int i) { doubleWritable.set(d); } else { doubleWritable.set(BigDecimal.valueOf(d).setScale(i, - RoundingMode.HALF_UP).doubleValue()); + RoundingMode.HALF_UP).doubleValue()); } return doubleWritable; } @@ -88,4 +95,41 @@ public BigDecimalWritable evaluate(BigDecimalWritable n, IntWritable i) { return evaluate(n, i.get()); } + public LongWritable evaluate(LongWritable n) { + if (n == null) { + return null; + } + // An integral value is already at scale 0, so HALF_UP rounding cannot change it. + longWritable.set(n.get()); + return longWritable; + } + + public IntWritable evaluate(IntWritable n) { + if (n == null) { + return null; + } + // An integral value is already at scale 0, so HALF_UP rounding cannot change it. + intWritable.set(n.get()); + return 
intWritable; + } + + public ShortWritable evaluate(ShortWritable n) { + if (n == null) { + return null; + } + // An integral value is already at scale 0, so HALF_UP rounding cannot change it. + shortWritable.set(n.get()); + return shortWritable; + } + + public ByteWritable evaluate(ByteWritable n) { + if (n == null) { + return null; + } + // An integral value is already at scale 0, so HALF_UP rounding cannot change it. + byteWritable.set(n.get()); + return byteWritable; + } + } + diff --git ql/src/test/queries/clientpositive/udf_round_3.q ql/src/test/queries/clientpositive/udf_round_3.q new file mode 100644 index 0000000..50a1f44 --- /dev/null +++ ql/src/test/queries/clientpositive/udf_round_3.q @@ -0,0 +1,14 @@ +-- test for TINYINT +select round(-128), round(127), round(0) from src limit 1; + +-- test for SMALLINT +select round(-32768), round(32767), round(-129), round(128) from src limit 1; + +-- test for INT +select round(cast(negative(pow(2, 31)) as INT)), round(cast((pow(2, 31) - 1) as INT)), round(-32769), round(32768) from src limit 1; + +-- test for BIGINT +select round(cast(negative(pow(2, 63)) as BIGINT)), round(cast((pow(2, 63) - 1) as BIGINT)), round(cast(negative(pow(2, 31) + 1) as BIGINT)), round(cast(pow(2, 31) as BIGINT)) from src limit 1; + +-- test for DOUBLE +select round(126.1), round(126.7), round(32766.1), round(32766.7) from src limit 1; diff --git ql/src/test/results/clientpositive/udf_round.q.out ql/src/test/results/clientpositive/udf_round.q.out index 900e91e..9ad1d91 100644 --- ql/src/test/results/clientpositive/udf_round.q.out +++ ql/src/test/results/clientpositive/udf_round.q.out @@ -40,7 +40,7 @@ FROM src LIMIT 1 POSTHOOK: type: QUERY POSTHOOK: Input: default@src #### A masked pattern was here #### -55555.0 55555.0 55555.0 55555.0 55555.0 55560.0 55600.0 56000.0 60000.0 100000.0 0.0 0.0 0.0 +55555 55555.0 55555.0 55555.0 55555.0 55560.0 55600.0 56000.0 60000.0 100000.0 0.0 0.0 0.0 PREHOOK: query: SELECT round(125.315), round(125.315, 0), round(125.315, 1), 
round(125.315, 2), round(125.315, 3), round(125.315, 4), diff --git ql/src/test/results/clientpositive/udf_round_3.q.out ql/src/test/results/clientpositive/udf_round_3.q.out new file mode 100644 index 0000000..0b00d6a --- /dev/null +++ ql/src/test/results/clientpositive/udf_round_3.q.out @@ -0,0 +1,55 @@ +PREHOOK: query: -- test for TINYINT +select round(-128), round(127), round(0) from src limit 1 +PREHOOK: type: QUERY +PREHOOK: Input: default@src +#### A masked pattern was here #### +POSTHOOK: query: -- test for TINYINT +select round(-128), round(127), round(0) from src limit 1 +POSTHOOK: type: QUERY +POSTHOOK: Input: default@src +#### A masked pattern was here #### +-128 127 0 +PREHOOK: query: -- test for SMALLINT +select round(-32768), round(32767), round(-129), round(128) from src limit 1 +PREHOOK: type: QUERY +PREHOOK: Input: default@src +#### A masked pattern was here #### +POSTHOOK: query: -- test for SMALLINT +select round(-32768), round(32767), round(-129), round(128) from src limit 1 +POSTHOOK: type: QUERY +POSTHOOK: Input: default@src +#### A masked pattern was here #### +-32768 32767 -129 128 +PREHOOK: query: -- test for INT +select round(cast(negative(pow(2, 31)) as INT)), round(cast((pow(2, 31) - 1) as INT)), round(-32769), round(32768) from src limit 1 +PREHOOK: type: QUERY +PREHOOK: Input: default@src +#### A masked pattern was here #### +POSTHOOK: query: -- test for INT +select round(cast(negative(pow(2, 31)) as INT)), round(cast((pow(2, 31) - 1) as INT)), round(-32769), round(32768) from src limit 1 +POSTHOOK: type: QUERY +POSTHOOK: Input: default@src +#### A masked pattern was here #### +-2147483648 2147483647 -32769 32768 +PREHOOK: query: -- test for BIGINT +select round(cast(negative(pow(2, 63)) as BIGINT)), round(cast((pow(2, 63) - 1) as BIGINT)), round(cast(negative(pow(2, 31) + 1) as BIGINT)), round(cast(pow(2, 31) as BIGINT)) from src limit 1 +PREHOOK: type: QUERY +PREHOOK: Input: default@src +#### A masked pattern was here #### +POSTHOOK: 
query: -- test for BIGINT +select round(cast(negative(pow(2, 63)) as BIGINT)), round(cast((pow(2, 63) - 1) as BIGINT)), round(cast(negative(pow(2, 31) + 1) as BIGINT)), round(cast(pow(2, 31) as BIGINT)) from src limit 1 +POSTHOOK: type: QUERY +POSTHOOK: Input: default@src +#### A masked pattern was here #### +-9223372036854775808 9223372036854775807 -2147483649 2147483648 +PREHOOK: query: -- test for DOUBLE +select round(126.1), round(126.7), round(32766.1), round(32766.7) from src limit 1 +PREHOOK: type: QUERY +PREHOOK: Input: default@src +#### A masked pattern was here #### +POSTHOOK: query: -- test for DOUBLE +select round(126.1), round(126.7), round(32766.1), round(32766.7) from src limit 1 +POSTHOOK: type: QUERY +POSTHOOK: Input: default@src +#### A masked pattern was here #### +126.0 127.0 32766.0 32767.0