diff --git common/src/java/org/apache/hadoop/hive/common/type/UnsignedInt128.java common/src/java/org/apache/hadoop/hive/common/type/UnsignedInt128.java index fb3c346..16f70dc 100644 --- common/src/java/org/apache/hadoop/hive/common/type/UnsignedInt128.java +++ common/src/java/org/apache/hadoop/hive/common/type/UnsignedInt128.java @@ -53,6 +53,14 @@ /** A special value representing 10**38. */ public static final UnsignedInt128 TEN_TO_THIRTYEIGHT = new UnsignedInt128(0, 0x98a2240, 0x5a86c47a, 0x4b3b4ca8); + + /** The TLS scratch array used for computations */ + private static final ThreadLocal scratch = new ThreadLocal() { + @Override + protected int[] initialValue() { + return new int[INT_COUNT * 2]; + } + }; /** * Int32 elements as little-endian (v[0] is least significant) unsigned @@ -1799,6 +1807,45 @@ private void shiftLeftDestructive(int wordShifts, int bitShiftsInWord) { } /** + * This is the BigInteger.multiplyToLen implementation + * adapted to fixed size numbers and LittleEndian storage in the arrays. + * Even the iterations are preserved from the BigEndian implementation, + * only the array index access was reversed (eg. x[4-i]) + */ + private static int[] multiplyToLen(int[] x, int[] y, int[] z) { + assert (x.length == INT_COUNT); + assert (y.length == INT_COUNT); + assert (z.length == INT_COUNT*2); + assert (z != x); + assert (z != y); + + int xstart = INT_COUNT - 1; + int ystart = INT_COUNT - 1; + + long carry = 0; + for (int j=ystart, k=ystart+1+xstart; j>=0; j--, k--) { + long product = (y[INT_COUNT-1-j] & SqlMathUtil.LONG_MASK) * + (x[INT_COUNT-1-xstart] & SqlMathUtil.LONG_MASK) + carry; + z[INT_COUNT*2-1-k] = (int)product; + carry = product >>> 32; + } + z[INT_COUNT*2 - 1 -xstart] = (int)carry; + + for (int i = xstart-1; i >= 0; i--) { + carry = 0; + for (int j=ystart, k=ystart + 1 + i; j >= 0; j--, k--) { + long product = (y[INT_COUNT - 1 - j] & SqlMathUtil.LONG_MASK) * + (x[INT_COUNT - 1 - i] & SqlMathUtil.LONG_MASK) + + (z[INT_COUNT*2 -1 -k] & SqlMathUtil.LONG_MASK) + carry; + z[INT_COUNT*2 - 1 - k] = (int)product; + carry = product >>> 32; + } + z[INT_COUNT*2 - 1 -i] = (int)carry; + } + return z; + } + + /** * Multiplies this value with the given value. * * @param left @@ -1807,111 +1854,27 @@ private void shiftLeftDestructive(int wordShifts, int bitShiftsInWord) { * the value to multiply. in */ private static void multiplyArrays4And4To4NoOverflow(int[] left, int[] right) { - assert (left.length == 4); - assert (right.length == 4); - long product; - - product = (right[0] & SqlMathUtil.LONG_MASK) - * (left[0] & SqlMathUtil.LONG_MASK); - int z0 = (int) product; - - product = (right[0] & SqlMathUtil.LONG_MASK) - * (left[1] & SqlMathUtil.LONG_MASK) - + (right[1] & SqlMathUtil.LONG_MASK) - * (left[0] & SqlMathUtil.LONG_MASK) + (product >>> 32); - int z1 = (int) product; - - product = (right[0] & SqlMathUtil.LONG_MASK) - * (left[2] & SqlMathUtil.LONG_MASK) - + (right[1] & SqlMathUtil.LONG_MASK) - * (left[1] & SqlMathUtil.LONG_MASK) - + (right[2] & SqlMathUtil.LONG_MASK) - * (left[0] & SqlMathUtil.LONG_MASK) + (product >>> 32); - int z2 = (int) product; - - // v[3] - product = (right[0] & SqlMathUtil.LONG_MASK) - * (left[3] & SqlMathUtil.LONG_MASK) - + (right[1] & SqlMathUtil.LONG_MASK) - * (left[2] & SqlMathUtil.LONG_MASK) - + (right[2] & SqlMathUtil.LONG_MASK) - * (left[1] & SqlMathUtil.LONG_MASK) - + (right[3] & SqlMathUtil.LONG_MASK) - * (left[0] & SqlMathUtil.LONG_MASK) + (product >>> 32); - int z3 = (int) product; - if ((product >>> 32) != 0) { - SqlMathUtil.throwOverflowException(); - } - - // the combinations below definitely result in overflow - if ((right[3] != 0 && (left[3] != 0 || left[2] != 0 || left[1] != 0)) - || (right[2] != 0 && (left[3] != 0 || left[2] != 0)) - || (right[1] != 0 && left[3] != 0)) { - SqlMathUtil.throwOverflowException(); - } + assert (left.length == INT_COUNT); + assert (right.length == INT_COUNT); + + int[] z = scratch.get(); + + z = multiplyToLen(left, right, z); + + left[0] = z[0]; + left[1] = z[1]; + left[2] = z[2]; + left[3] = z[3]; - left[0] = z0; - left[1] = z1; - left[2] = z2; - left[3] = z3; } private static int[] multiplyArrays4And4To8(int[] left, int[] right) { - assert (left.length == 4); - assert (right.length == 4); - long product; - - // this method could go beyond the integer ranges until we scale back - // so, we need twice more variables. - int[] z = new int[8]; - - product = (right[0] & SqlMathUtil.LONG_MASK) - * (left[0] & SqlMathUtil.LONG_MASK); - z[0] = (int) product; - - product = (right[0] & SqlMathUtil.LONG_MASK) - * (left[1] & SqlMathUtil.LONG_MASK) - + (right[1] & SqlMathUtil.LONG_MASK) - * (left[0] & SqlMathUtil.LONG_MASK) + (product >>> 32); - z[1] = (int) product; - - product = (right[0] & SqlMathUtil.LONG_MASK) - * (left[2] & SqlMathUtil.LONG_MASK) - + (right[1] & SqlMathUtil.LONG_MASK) - * (left[1] & SqlMathUtil.LONG_MASK) - + (right[2] & SqlMathUtil.LONG_MASK) - * (left[0] & SqlMathUtil.LONG_MASK) + (product >>> 32); - z[2] = (int) product; - - product = (right[0] & SqlMathUtil.LONG_MASK) - * (left[3] & SqlMathUtil.LONG_MASK) - + (right[1] & SqlMathUtil.LONG_MASK) - * (left[2] & SqlMathUtil.LONG_MASK) - + (right[2] & SqlMathUtil.LONG_MASK) - * (left[1] & SqlMathUtil.LONG_MASK) - + (right[3] & SqlMathUtil.LONG_MASK) - * (left[0] & SqlMathUtil.LONG_MASK) + (product >>> 32); - z[3] = (int) product; - - product = (right[1] & SqlMathUtil.LONG_MASK) - * (left[3] & SqlMathUtil.LONG_MASK) - + (right[2] & SqlMathUtil.LONG_MASK) - * (left[2] & SqlMathUtil.LONG_MASK) - + (right[3] & SqlMathUtil.LONG_MASK) - * (left[1] & SqlMathUtil.LONG_MASK) + (product >>> 32); - z[4] = (int) product; - - product = (right[2] & SqlMathUtil.LONG_MASK) - * (left[3] & SqlMathUtil.LONG_MASK) - + (right[3] & SqlMathUtil.LONG_MASK) - * (left[2] & SqlMathUtil.LONG_MASK) + (product >>> 32); - z[5] = (int) product; - - // v[1], v[0] - product = (right[3] & SqlMathUtil.LONG_MASK) - * (left[3] & SqlMathUtil.LONG_MASK) + (product >>> 32); - z[6] = (int) product; - z[7] = (int) (product >>> 32); + assert (left.length == INT_COUNT); + assert (right.length == INT_COUNT); + + int[] z= scratch.get(); + + multiplyToLen(left, right, z); return z; } diff --git common/src/test/org/apache/hadoop/hive/common/type/TestDecimal128.java common/src/test/org/apache/hadoop/hive/common/type/TestDecimal128.java index 6824cd7..080509c 100644 --- common/src/test/org/apache/hadoop/hive/common/type/TestDecimal128.java +++ common/src/test/org/apache/hadoop/hive/common/type/TestDecimal128.java @@ -18,14 +18,12 @@ import static org.junit.Assert.*; import java.math.BigDecimal; -import java.math.MathContext; import java.math.RoundingMode; import java.util.Random; import org.junit.After; import org.junit.Before; import org.junit.Test; - import org.apache.hadoop.hive.common.type.UnsignedInt128; /** @@ -736,6 +734,19 @@ public void testPrecisionOverflow() { } catch (ArithmeticException ex) { } } + + @Test + public void testRegressionHive6399() { + + Decimal128 op1 = new Decimal128("-605044214913338382", (short) 0); + Decimal128 op2 = new Decimal128("55269579109718297360", (short) 0); + Decimal128 expected = new Decimal128("-33440539101030154945490585226577271520", (short) 0); + + Decimal128 result = new Decimal128(); + Decimal128.multiply(op1, op2, result, (short) 0); + + assertEquals(expected, result); + } @Test public void testToLong() {