diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPDivide.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPDivide.java index 746d87a..89e69be 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPDivide.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPDivide.java @@ -110,24 +110,19 @@ protected HiveDecimalWritable evaluate(HiveDecimal left, HiveDecimal right) { return decimalWritable; } - private final static int MIN_START_DEC_DIGITS = 6; - private final static int DEFAULT_DEC_DIGITS = 18; /** * A balanced way to determine the precision/scale of decimal division result. Integer digits and * decimal digits are computed independently. However, when the precision from above reaches above - * HiveDecimal.MAX_PRECISION, integer digit and decimal digits are shrunk equally to fit. + * HiveDecimal.MAX_PRECISION, interger digit and decimal digits are shrunk equally to fit. */ @Override protected DecimalTypeInfo deriveResultDecimalTypeInfo(int prec1, int scale1, int prec2, int scale2) { int intDig = Math.min(HiveDecimal.MAX_SCALE, prec1 - scale1 + scale2); - int decDig = Math.min(HiveDecimal.MAX_SCALE, - Math.max(MIN_START_DEC_DIGITS, scale1 + prec2 + 1)); + int decDig = Math.min(HiveDecimal.MAX_SCALE, Math.max(6, scale1 + prec2 + 1)); int diff = intDig + decDig - HiveDecimal.MAX_SCALE; if (diff > 0) { decDig -= diff/2 + 1; // Slight negative bias. intDig = HiveDecimal.MAX_SCALE - decDig; - } else if (diff < 0 && decDig < DEFAULT_DEC_DIGITS) { - decDig += Math.min(-diff, DEFAULT_DEC_DIGITS - decDig); } return TypeInfoFactory.getDecimalTypeInfo(intDig + decDig, decDig); } diff --git a/ql/src/test/org/apache/hadoop/hive/ql/udf/generic/TestGenericUDFOPDivide.java b/ql/src/test/org/apache/hadoop/hive/ql/udf/generic/TestGenericUDFOPDivide.java index 59fb7dc..6fa3b3f 100644 --- a/ql/src/test/org/apache/hadoop/hive/ql/udf/generic/TestGenericUDFOPDivide.java +++ b/ql/src/test/org/apache/hadoop/hive/ql/udf/generic/TestGenericUDFOPDivide.java @@ -45,6 +45,7 @@ @Test public void testByteDivideShort() throws HiveException { GenericUDFOPDivide udf = new GenericUDFOPDivide(); + ByteWritable left = new ByteWritable((byte) 4); ShortWritable right = new ShortWritable((short) 6); ObjectInspector[] inputOIs = { @@ -57,9 +58,9 @@ public void testByteDivideShort() throws HiveException { }; PrimitiveObjectInspector oi = (PrimitiveObjectInspector) udf.initialize(inputOIs); - Assert.assertEquals(oi.getTypeInfo(), TypeInfoFactory.getDecimalTypeInfo(21, 18)); + Assert.assertEquals(oi.getTypeInfo(), TypeInfoFactory.getDecimalTypeInfo(9, 6)); HiveDecimalWritable res = (HiveDecimalWritable) udf.evaluate(args); - Assert.assertEquals(HiveDecimal.create("0.666666666666666667"), res.getHiveDecimal()); + Assert.assertEquals(HiveDecimal.create("0.666667"), res.getHiveDecimal()); } @Test @@ -108,12 +109,12 @@ public void testDoubleDivideLong() throws HiveException { @Test public void testLongDivideDecimal() throws HiveException { GenericUDFOPDivide udf = new GenericUDFOPDivide(); + LongWritable left = new LongWritable(104); HiveDecimalWritable right = new HiveDecimalWritable(HiveDecimal.create("234.97")); ObjectInspector[] inputOIs = { PrimitiveObjectInspectorFactory.writableLongObjectInspector, - PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector( - TypeInfoFactory.getDecimalTypeInfo(38, 15)) + PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(TypeInfoFactory.getDecimalTypeInfo(9, 4)) }; DeferredObject[] args = { new DeferredJavaObject(left), @@ -121,9 +122,9 @@ public void testLongDivideDecimal() throws HiveException { }; PrimitiveObjectInspector oi = (PrimitiveObjectInspector) udf.initialize(inputOIs); - Assert.assertEquals(TypeInfoFactory.getDecimalTypeInfo(38, 20), oi.getTypeInfo()); + Assert.assertEquals(TypeInfoFactory.getDecimalTypeInfo(33, 10), oi.getTypeInfo()); HiveDecimalWritable res = (HiveDecimalWritable) udf.evaluate(args); - Assert.assertEquals(HiveDecimal.create("0.44260969485466229731"), res.getHiveDecimal()); + Assert.assertEquals(HiveDecimal.create("0.4426096949"), res.getHiveDecimal()); } @Test @@ -148,7 +149,7 @@ public void testFloatDivideFloat() throws HiveException { } @Test - public void testDoubleDivideDecimal() throws HiveException { + public void testDouleDivideDecimal() throws HiveException { GenericUDFOPDivide udf = new GenericUDFOPDivide(); DoubleWritable left = new DoubleWritable(74.52); @@ -175,25 +176,24 @@ public void testDecimalDivideDecimal() throws HiveException { HiveDecimalWritable left = new HiveDecimalWritable(HiveDecimal.create("14.5")); HiveDecimalWritable right = new HiveDecimalWritable(HiveDecimal.create("234.97")); ObjectInspector[] inputOIs = { - PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector( - TypeInfoFactory.getDecimalTypeInfo(3, 1)), - PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector( - TypeInfoFactory.getDecimalTypeInfo(5, 2)) + PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(TypeInfoFactory.getDecimalTypeInfo(3, 1)), + PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(TypeInfoFactory.getDecimalTypeInfo(5, 2)) }; DeferredObject[] args = { new DeferredJavaObject(left), new DeferredJavaObject(right), }; - + PrimitiveObjectInspector oi = (PrimitiveObjectInspector) udf.initialize(inputOIs); - Assert.assertEquals(TypeInfoFactory.getDecimalTypeInfo(22, 18), oi.getTypeInfo()); + Assert.assertEquals(TypeInfoFactory.getDecimalTypeInfo(11, 7), oi.getTypeInfo()); HiveDecimalWritable res = (HiveDecimalWritable) udf.evaluate(args); - Assert.assertEquals(HiveDecimal.create("0.061710005532621186"), res.getHiveDecimal()); + Assert.assertEquals(HiveDecimal.create("0.06171"), res.getHiveDecimal()); } @Test public void testDecimalDivideDecimal2() throws HiveException { GenericUDFOPDivide udf = new GenericUDFOPDivide(); + HiveDecimalWritable left = new HiveDecimalWritable(HiveDecimal.create("5")); HiveDecimalWritable right = new HiveDecimalWritable(HiveDecimal.create("25")); ObjectInspector[] inputOIs = { @@ -206,7 +206,7 @@ public void testDecimalDivideDecimal2() throws HiveException { }; PrimitiveObjectInspector oi = (PrimitiveObjectInspector) udf.initialize(inputOIs); - Assert.assertEquals(TypeInfoFactory.getDecimalTypeInfo(19, 18), oi.getTypeInfo()); + Assert.assertEquals(TypeInfoFactory.getDecimalTypeInfo(7, 6), oi.getTypeInfo()); HiveDecimalWritable res = (HiveDecimalWritable) udf.evaluate(args); Assert.assertEquals(HiveDecimal.create("0.2"), res.getHiveDecimal()); } @@ -221,19 +221,19 @@ public void testDecimalDivideDecimalSameParams() throws HiveException { }; PrimitiveObjectInspector oi = (PrimitiveObjectInspector) udf.initialize(inputOIs); - Assert.assertEquals(TypeInfoFactory.getDecimalTypeInfo(23, 18), oi.getTypeInfo()); + Assert.assertEquals(TypeInfoFactory.getDecimalTypeInfo(13, 8), oi.getTypeInfo()); } @Test public void testDecimalDivisionResultType() throws HiveException { - testDecimalDivisionResultType(5, 2, 3, 2, 23, 18); + testDecimalDivisionResultType(5, 2, 3, 2, 11, 6); testDecimalDivisionResultType(38, 18, 38, 18, 38, 18); testDecimalDivisionResultType(38, 18, 20, 0, 38, 27); - testDecimalDivisionResultType(20, 0, 8, 5, 38, 13); - testDecimalDivisionResultType(10, 0, 10, 0, 28, 18); - testDecimalDivisionResultType(5, 2, 5, 5, 26, 18); - testDecimalDivisionResultType(10, 10, 5, 0, 18, 18); - testDecimalDivisionResultType(10, 10, 5, 5, 23, 18); + testDecimalDivisionResultType(20, 0, 8, 5, 34, 9); + testDecimalDivisionResultType(10, 0, 10, 0, 21, 11); + testDecimalDivisionResultType(5, 2, 5, 5, 16, 8); + testDecimalDivisionResultType(10, 10, 5, 0, 16, 16); + testDecimalDivisionResultType(10, 10, 5, 5, 21, 16); testDecimalDivisionResultType(38, 38, 38, 38, 38, 18); testDecimalDivisionResultType(38, 0, 38, 0, 38, 18); } @@ -259,7 +259,7 @@ public void testReturnTypeBackwardCompat() throws Exception { verifyReturnType(new GenericUDFOPDivide(), "int", "int", "double"); // different from sql compat mode verifyReturnType(new GenericUDFOPDivide(), "int", "float", "double"); verifyReturnType(new GenericUDFOPDivide(), "int", "double", "double"); - verifyReturnType(new GenericUDFOPDivide(), "int", "decimal(10,2)", "decimal(30,18)"); + verifyReturnType(new GenericUDFOPDivide(), "int", "decimal(10,2)", "decimal(23,11)"); verifyReturnType(new GenericUDFOPDivide(), "float", "float", "double"); verifyReturnType(new GenericUDFOPDivide(), "float", "double", "double"); @@ -268,7 +268,7 @@ public void testReturnTypeBackwardCompat() throws Exception { verifyReturnType(new GenericUDFOPDivide(), "double", "double", "double"); verifyReturnType(new GenericUDFOPDivide(), "double", "decimal(10,2)", "double"); - verifyReturnType(new GenericUDFOPDivide(), "decimal(10,2)", "decimal(10,2)", "decimal(28,18)"); + verifyReturnType(new GenericUDFOPDivide(), "decimal(10,2)", "decimal(10,2)", "decimal(23,13)"); // Most tests are done with ANSI SQL mode enabled, set it back to true SessionState.get().getConf().setVar(HiveConf.ConfVars.HIVE_COMPAT, "latest"); @@ -278,10 +278,10 @@ public void testReturnTypeBackwardCompat() throws Exception { public void testReturnTypeAnsiSql() throws Exception { SessionState.get().getConf().setVar(HiveConf.ConfVars.HIVE_COMPAT, "latest"); - verifyReturnType(new GenericUDFOPDivide(), "int", "int", "decimal(28,18)"); + verifyReturnType(new GenericUDFOPDivide(), "int", "int", "decimal(21,11)"); verifyReturnType(new GenericUDFOPDivide(), "int", "float", "double"); verifyReturnType(new GenericUDFOPDivide(), "int", "double", "double"); - verifyReturnType(new GenericUDFOPDivide(), "int", "decimal(10,2)", "decimal(30,18)"); + verifyReturnType(new GenericUDFOPDivide(), "int", "decimal(10,2)", "decimal(23,11)"); verifyReturnType(new GenericUDFOPDivide(), "float", "float", "double"); verifyReturnType(new GenericUDFOPDivide(), "float", "double", "double"); @@ -290,6 +290,6 @@ public void testReturnTypeAnsiSql() throws Exception { verifyReturnType(new GenericUDFOPDivide(), "double", "double", "double"); verifyReturnType(new GenericUDFOPDivide(), "double", "decimal(10,2)", "double"); - verifyReturnType(new GenericUDFOPDivide(), "decimal(10,2)", "decimal(10,2)", "decimal(28,18)"); + verifyReturnType(new GenericUDFOPDivide(), "decimal(10,2)", "decimal(10,2)", "decimal(23,13)"); } }