diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFBaseNumeric.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFBaseNumeric.java index ef6ef11..1734328 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFBaseNumeric.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFBaseNumeric.java @@ -288,6 +288,35 @@ protected DecimalTypeInfo deriveResultDecimalTypeInfo() { protected abstract DecimalTypeInfo deriveResultDecimalTypeInfo(int prec1, int scale1, int prec2, int scale2); + public static final int MINIMUM_ADJUSTED_SCALE = 6; + + /** + * Create DecimalTypeInfo from input precision/scale, adjusting if necessary to fit max precision + * @param precision precision value before adjustment + * @param scale scale value before adjustment + * @return DecimalTypeInfo whose precision is capped at HiveDecimal.MAX_PRECISION, with the scale reduced as needed to fit + */ + protected DecimalTypeInfo adjustPrecScale(int precision, int scale) { + // Assumptions: + // precision >= scale + // scale >= 0 + + if (precision <= HiveDecimal.MAX_PRECISION) { + // Adjustment only needed when we exceed max precision + return new DecimalTypeInfo(precision, scale); + } + + // Precision/scale exceed maximum precision. Result must be adjusted to HiveDecimal.MAX_PRECISION. 
+ // See https://blogs.msdn.microsoft.com/sqlprogrammability/2006/03/29/multiplication-and-division-with-numerics/ + int intDigits = precision - scale; + // If original scale less than 6, use original scale value; otherwise preserve at least 6 fractional digits + int minScaleValue = Math.min(scale, MINIMUM_ADJUSTED_SCALE); + int adjustedScale = HiveDecimal.MAX_PRECISION - intDigits; + adjustedScale = Math.max(adjustedScale, minScaleValue); + + return new DecimalTypeInfo(HiveDecimal.MAX_PRECISION, adjustedScale); + } + public void copyToNewInstance(Object newInstance) throws UDFArgumentException { super.copyToNewInstance(newInstance); GenericUDFBaseNumeric other = (GenericUDFBaseNumeric) newInstance; diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPDivide.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPDivide.java index 89e69be..225a529 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPDivide.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPDivide.java @@ -110,21 +110,16 @@ protected HiveDecimalWritable evaluate(HiveDecimal left, HiveDecimal right) { return decimalWritable; } - /** - * A balanced way to determine the precision/scale of decimal division result. Integer digits and - * decimal digits are computed independently. However, when the precision from above reaches above - * HiveDecimal.MAX_PRECISION, interger digit and decimal digits are shrunk equally to fit. - */ @Override protected DecimalTypeInfo deriveResultDecimalTypeInfo(int prec1, int scale1, int prec2, int scale2) { - int intDig = Math.min(HiveDecimal.MAX_SCALE, prec1 - scale1 + scale2); - int decDig = Math.min(HiveDecimal.MAX_SCALE, Math.max(6, scale1 + prec2 + 1)); - int diff = intDig + decDig - HiveDecimal.MAX_SCALE; - if (diff > 0) { - decDig -= diff/2 + 1; // Slight negative bias. 
- intDig = HiveDecimal.MAX_SCALE - decDig; - } - return TypeInfoFactory.getDecimalTypeInfo(intDig + decDig, decDig); + // From https://msdn.microsoft.com/en-us/library/ms190476.aspx + // e1 / e2 + // Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1) + // Scale: max(6, s1 + p2 + 1) + int intDig = prec1 - scale1 + scale2; + int scale = Math.max(6, scale1 + prec2 + 1); + int prec = intDig + scale; + return adjustPrecScale(prec, scale); } } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPMod.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPMod.java index 9d283bd..6d3e82e 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPMod.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPMod.java @@ -119,9 +119,13 @@ protected HiveDecimalWritable evaluate(HiveDecimal left, HiveDecimal right) { @Override protected DecimalTypeInfo deriveResultDecimalTypeInfo(int prec1, int scale1, int prec2, int scale2) { + // From https://msdn.microsoft.com/en-us/library/ms190476.aspx + // e1 % e2 + // Precision: min(p1-s1, p2 -s2) + max( s1,s2 ) + // Scale: max(s1, s2) + int prec = Math.min(prec1 - scale1, prec2 - scale2) + Math.max(scale1, scale2); int scale = Math.max(scale1, scale2); - int prec = Math.min(HiveDecimal.MAX_PRECISION, Math.min(prec1 - scale1, prec2 - scale2) + scale); - return TypeInfoFactory.getDecimalTypeInfo(prec, scale); + return adjustPrecScale(prec, scale); } } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPMultiply.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPMultiply.java index 7dc1f83..47a11f3 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPMultiply.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPMultiply.java @@ -98,9 +98,13 @@ protected HiveDecimalWritable evaluate(HiveDecimal left, HiveDecimal right) { @Override protected DecimalTypeInfo deriveResultDecimalTypeInfo(int prec1, int 
scale1, int prec2, int scale2) { - int scale = Math.min(HiveDecimal.MAX_SCALE, scale1 + scale2 ); - int prec = Math.min(HiveDecimal.MAX_PRECISION, prec1 + prec2 + 1); - return TypeInfoFactory.getDecimalTypeInfo(prec, scale); + // From https://msdn.microsoft.com/en-us/library/ms190476.aspx + // e1 * e2 + // Precision: p1 + p2 + 1 + // Scale: s1 + s2 + int scale = scale1 + scale2; + int prec = prec1 + prec2 + 1; + return adjustPrecScale(prec, scale); } } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPNumericMinus.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPNumericMinus.java index a31cf78..28f7907 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPNumericMinus.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPNumericMinus.java @@ -88,10 +88,14 @@ protected HiveDecimalWritable evaluate(HiveDecimal left, HiveDecimal right) { @Override protected DecimalTypeInfo deriveResultDecimalTypeInfo(int prec1, int scale1, int prec2, int scale2) { + // From https://msdn.microsoft.com/en-us/library/ms190476.aspx + // e1 - e2 + // Precision: max(s1, s2) + max(p1-s1, p2-s2) + 1 + // Scale: max(s1, s2) int intPart = Math.max(prec1 - scale1, prec2 - scale2); int scale = Math.max(scale1, scale2); - int prec = Math.min(intPart + scale + 1, HiveDecimal.MAX_PRECISION); - return TypeInfoFactory.getDecimalTypeInfo(prec, scale); + int prec = intPart + scale + 1; + return adjustPrecScale(prec, scale); } } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPNumericPlus.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPNumericPlus.java index b055776..b2b76f0 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPNumericPlus.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPNumericPlus.java @@ -99,10 +99,14 @@ protected HiveDecimalWritable evaluate(HiveDecimal left, HiveDecimal right) { @Override protected 
DecimalTypeInfo deriveResultDecimalTypeInfo(int prec1, int scale1, int prec2, int scale2) { + // From https://msdn.microsoft.com/en-us/library/ms190476.aspx + // e1 + e2 + // Precision: max(s1, s2) + max(p1-s1, p2-s2) + 1 + // Scale: max(s1, s2) int intPart = Math.max(prec1 - scale1, prec2 - scale2); int scale = Math.max(scale1, scale2); - int prec = Math.min(intPart + scale + 1, HiveDecimal.MAX_PRECISION); - return TypeInfoFactory.getDecimalTypeInfo(prec, scale); + int prec = intPart + scale + 1; + return adjustPrecScale(prec, scale); } } diff --git a/ql/src/test/org/apache/hadoop/hive/ql/udf/generic/TestGenericUDFOPDivide.java b/ql/src/test/org/apache/hadoop/hive/ql/udf/generic/TestGenericUDFOPDivide.java index 6fa3b3f..523a1a4 100644 --- a/ql/src/test/org/apache/hadoop/hive/ql/udf/generic/TestGenericUDFOPDivide.java +++ b/ql/src/test/org/apache/hadoop/hive/ql/udf/generic/TestGenericUDFOPDivide.java @@ -227,15 +227,15 @@ public void testDecimalDivideDecimalSameParams() throws HiveException { @Test public void testDecimalDivisionResultType() throws HiveException { testDecimalDivisionResultType(5, 2, 3, 2, 11, 6); - testDecimalDivisionResultType(38, 18, 38, 18, 38, 18); - testDecimalDivisionResultType(38, 18, 20, 0, 38, 27); + testDecimalDivisionResultType(38, 18, 38, 18, 38, 6); + testDecimalDivisionResultType(38, 18, 20, 0, 38, 18); testDecimalDivisionResultType(20, 0, 8, 5, 34, 9); testDecimalDivisionResultType(10, 0, 10, 0, 21, 11); testDecimalDivisionResultType(5, 2, 5, 5, 16, 8); testDecimalDivisionResultType(10, 10, 5, 0, 16, 16); testDecimalDivisionResultType(10, 10, 5, 5, 21, 16); - testDecimalDivisionResultType(38, 38, 38, 38, 38, 18); - testDecimalDivisionResultType(38, 0, 38, 0, 38, 18); + testDecimalDivisionResultType(38, 38, 38, 38, 38, 6); + testDecimalDivisionResultType(38, 0, 38, 0, 38, 6); } private void testDecimalDivisionResultType(int prec1, int scale1, int prec2, int scale2, int prec3, int scale3) diff --git 
a/ql/src/test/org/apache/hadoop/hive/ql/udf/generic/TestGenericUDFOPMultiply.java b/ql/src/test/org/apache/hadoop/hive/ql/udf/generic/TestGenericUDFOPMultiply.java index e342a76..9b02538 100644 --- a/ql/src/test/org/apache/hadoop/hive/ql/udf/generic/TestGenericUDFOPMultiply.java +++ b/ql/src/test/org/apache/hadoop/hive/ql/udf/generic/TestGenericUDFOPMultiply.java @@ -243,5 +243,11 @@ public void testReturnTypeAnsiSql() throws Exception { verifyReturnType(new GenericUDFOPMultiply(), "double", "decimal(10,2)", "double"); verifyReturnType(new GenericUDFOPMultiply(), "decimal(10,2)", "decimal(10,2)", "decimal(21,4)"); + + verifyReturnType(new GenericUDFOPMultiply(), "decimal(38,18)", "decimal(38,18)", "decimal(38,6)"); + verifyReturnType(new GenericUDFOPMultiply(), "decimal(38,38)", "decimal(38,38)", "decimal(38,37)"); + verifyReturnType(new GenericUDFOPMultiply(), "decimal(38,0)", "decimal(38,0)", "decimal(38,0)"); + verifyReturnType(new GenericUDFOPMultiply(), "decimal(38,38)", "decimal(38,0)", "decimal(38,6)"); + verifyReturnType(new GenericUDFOPMultiply(), "decimal(20,2)", "decimal(20,0)", "decimal(38,2)"); } }