diff --git ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Serializer.java ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Serializer.java index 2961050532..81bc1bea0b 100644 --- ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Serializer.java +++ ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Serializer.java @@ -38,10 +38,13 @@ import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.MapVector; import org.apache.arrow.vector.complex.NullableMapVector; +import org.apache.arrow.vector.holders.DecimalHolder; import org.apache.arrow.vector.types.TimeUnit; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.util.DecimalUtility; +import org.apache.hadoop.hive.common.type.HiveDecimal; import org.apache.hadoop.hive.conf.HiveConf; import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector; import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; @@ -58,6 +61,7 @@ import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.serde2.SerDeException; +import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; @@ -70,6 +74,8 @@ import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.UnionTypeInfo; +import java.math.BigDecimal; +import java.math.BigInteger; import java.util.ArrayList; import java.util.List; @@ -87,6 +93,7 @@ class Serializer { private final int MAX_BUFFERED_ROWS; + private final ArrowColumnarBatchSerDe serDe; // Schema private final StructTypeInfo structTypeInfo; @@ -100,6 +107,7 @@ private final NullableMapVector rootVector; 
Serializer(ArrowColumnarBatchSerDe serDe) throws SerDeException { + this.serDe = serDe; MAX_BUFFERED_ROWS = HiveConf.getIntVar(serDe.conf, HIVE_ARROW_BATCH_SIZE); ArrowColumnarBatchSerDe.LOG.info("ArrowColumnarBatchSerDe max number of buffered columns: " + MAX_BUFFERED_ROWS); @@ -470,14 +478,17 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty break; case DECIMAL: { final DecimalVector decimalVector = (DecimalVector) arrowVector; final int scale = decimalVector.getScale(); for (int i = 0; i < size; i++) { if (hiveVector.isNull[i]) { decimalVector.setNull(i); } else { - decimalVector.set(i, - ((DecimalColumnVector) hiveVector).vector[i].getHiveDecimal().bigDecimalValue().setScale(scale)); + /* Rescale the writable's unscaled value to the vector's scale and hand it over as a BigDecimal; no scratch ArrowBuf or DecimalHolder is needed because set(int, BigDecimal) is the only sink. */ + final HiveDecimalWritable writable = ((DecimalColumnVector) hiveVector).vector[i]; + final BigInteger bigInteger = new BigInteger(writable.getInternalStorage()). 
+ multiply(BigInteger.TEN.pow(scale - writable.scale())); + decimalVector.set(i, new BigDecimal(bigInteger, scale)); } } } diff --git ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowColumnarBatchSerDe.java ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowColumnarBatchSerDe.java index c9a5812e47..182aff7c7b 100644 --- ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowColumnarBatchSerDe.java +++ ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowColumnarBatchSerDe.java @@ -525,6 +525,31 @@ public void testPrimitiveDecimal() throws SerDeException { initAndSerializeAndDeserialize(schema, DECIMAL_ROWS); } + @Test + public void testRandomPrimitiveDecimal() throws SerDeException { + String[][] schema = { + {"decimal1", "decimal(38,10)"}, + }; + + int size = 1000; + Object[][] randomDecimals = new Object[size][]; + Random random = new Random(); + for (int i = 0; i < size; i++) { + StringBuilder builder = new StringBuilder(); + builder.append(random.nextBoolean() ? '+' : '-'); + for (int j = 0; j < 28 ; j++) { + builder.append(random.nextInt(10)); + } + builder.append('.'); + for (int j = 0; j < 10; j++) { + builder.append(random.nextInt(10)); + } + randomDecimals[i] = new Object[] {decimalW(HiveDecimal.create(builder.toString()))}; + } + + initAndSerializeAndDeserialize(schema, randomDecimals); + } + @Test public void testPrimitiveBoolean() throws SerDeException { String[][] schema = {