diff --git ql/src/java/org/apache/hadoop/hive/ql/io/arrow/ArrowColumnarBatchSerDe.java ql/src/java/org/apache/hadoop/hive/ql/io/arrow/ArrowColumnarBatchSerDe.java
index 330fa580e7..b093ebbd27 100644
--- ql/src/java/org/apache/hadoop/hive/ql/io/arrow/ArrowColumnarBatchSerDe.java
+++ ql/src/java/org/apache/hadoop/hive/ql/io/arrow/ArrowColumnarBatchSerDe.java
@@ -18,78 +18,26 @@
 package org.apache.hadoop.hive.ql.io.arrow;
 
 import com.google.common.collect.Lists;
-import com.google.common.collect.Maps;
-import io.netty.buffer.ArrowBuf;
 import org.apache.arrow.memory.BufferAllocator;
-import org.apache.arrow.vector.FieldVector;
-import org.apache.arrow.vector.VectorSchemaRoot;
 import org.apache.arrow.vector.complex.impl.UnionListWriter;
-import org.apache.arrow.vector.complex.impl.UnionReader;
-import org.apache.arrow.vector.complex.impl.UnionWriter;
-import org.apache.arrow.vector.complex.reader.FieldReader;
 import org.apache.arrow.vector.complex.writer.BaseWriter;
-import org.apache.arrow.vector.complex.writer.BigIntWriter;
-import org.apache.arrow.vector.complex.writer.BitWriter;
-import org.apache.arrow.vector.complex.writer.DateDayWriter;
-import org.apache.arrow.vector.complex.writer.DecimalWriter;
-import org.apache.arrow.vector.complex.writer.FieldWriter;
-import org.apache.arrow.vector.complex.writer.Float4Writer;
-import org.apache.arrow.vector.complex.writer.Float8Writer;
-import org.apache.arrow.vector.complex.writer.IntWriter;
-import org.apache.arrow.vector.complex.writer.IntervalDayWriter;
-import org.apache.arrow.vector.complex.writer.IntervalYearWriter;
-import org.apache.arrow.vector.complex.writer.SmallIntWriter;
-import org.apache.arrow.vector.complex.writer.TimeStampMilliWriter;
-import org.apache.arrow.vector.complex.writer.TinyIntWriter;
-import org.apache.arrow.vector.complex.writer.VarBinaryWriter;
-import org.apache.arrow.vector.complex.writer.VarCharWriter;
-import org.apache.arrow.vector.holders.NullableBigIntHolder;
-import org.apache.arrow.vector.holders.NullableBitHolder;
-import org.apache.arrow.vector.holders.NullableDateDayHolder;
-import org.apache.arrow.vector.holders.NullableFloat4Holder;
-import org.apache.arrow.vector.holders.NullableFloat8Holder;
-import org.apache.arrow.vector.holders.NullableIntHolder;
-import org.apache.arrow.vector.holders.NullableIntervalDayHolder;
-import org.apache.arrow.vector.holders.NullableIntervalYearHolder;
-import org.apache.arrow.vector.holders.NullableSmallIntHolder;
-import org.apache.arrow.vector.holders.NullableTimeStampMilliHolder;
-import org.apache.arrow.vector.holders.NullableTinyIntHolder;
-import org.apache.arrow.vector.holders.NullableVarBinaryHolder;
-import org.apache.arrow.vector.holders.NullableVarCharHolder;
 import org.apache.arrow.vector.types.TimeUnit;
-import org.apache.arrow.vector.types.Types;
+import org.apache.arrow.vector.types.Types.MinorType;
 import org.apache.arrow.vector.types.pojo.ArrowType;
 import org.apache.arrow.vector.types.pojo.Field;
 import org.apache.arrow.vector.types.pojo.FieldType;
-import org.apache.arrow.vector.types.pojo.Schema;
 import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.hive.common.type.HiveDecimal;
-import org.apache.hadoop.hive.common.type.HiveIntervalDayTime;
-import org.apache.hadoop.hive.conf.HiveConf;
-import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector;
 import org.apache.hadoop.hive.ql.exec.vector.ColumnVector;
-import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector;
-import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector;
-import org.apache.hadoop.hive.ql.exec.vector.IntervalDayTimeColumnVector;
 import org.apache.hadoop.hive.ql.exec.vector.ListColumnVector;
-import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector;
 import org.apache.hadoop.hive.ql.exec.vector.MapColumnVector;
 import org.apache.hadoop.hive.ql.exec.vector.StructColumnVector;
-import org.apache.hadoop.hive.ql.exec.vector.TimestampColumnVector;
-import org.apache.hadoop.hive.ql.exec.vector.UnionColumnVector;
 import org.apache.hadoop.hive.ql.exec.vector.VectorAssignRow;
-import org.apache.hadoop.hive.ql.exec.vector.VectorExtractRow;
-import org.apache.hadoop.hive.ql.exec.vector.VectorizedBatchUtil;
-import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
-import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.hive.serde.serdeConstants;
 import org.apache.hadoop.hive.serde2.AbstractSerDe;
 import org.apache.hadoop.hive.serde2.SerDeException;
 import org.apache.hadoop.hive.serde2.SerDeStats;
 import org.apache.hadoop.hive.serde2.SerDeUtils;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
-import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
 import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
@@ -107,20 +55,12 @@
 import java.io.DataInput;
 import java.io.DataOutput;
-import java.lang.reflect.Method;
-import java.sql.Timestamp;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
-import java.util.Map;
 import java.util.Properties;
-import java.util.function.IntConsumer;
 
-import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.HIVE_ARROW_BATCH_SIZE;
-import static org.apache.hadoop.hive.ql.exec.vector.VectorizedBatchUtil.createColumnVector;
-import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption.WRITABLE;
 import static org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils.getStandardWritableObjectInspectorFromTypeInfo;
-import static org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils.getTypeInfoFromObjectInspector;
 
 /**
  * ArrowColumnarBatchSerDe converts Apache Hive rows to Apache Arrow columns. Its serialized
@@ -143,17 +83,16 @@
   public static final Logger LOG = LoggerFactory.getLogger(ArrowColumnarBatchSerDe.class.getName());
 
   private static final String DEFAULT_ARROW_FIELD_NAME = "[DEFAULT]";
 
-  private static final int MS_PER_SECOND = 1_000;
-  private static final int MS_PER_MINUTE = MS_PER_SECOND * 60;
-  private static final int MS_PER_HOUR = MS_PER_MINUTE * 60;
-  private static final int MS_PER_DAY = MS_PER_HOUR * 24;
-  private static final int NS_PER_MS = 1_000_000;
+  static final int MS_PER_SECOND = 1_000;
+  static final int NS_PER_SECOND = 1_000_000_000;
+  static final int NS_PER_MS = 1_000_000;
+  static final int SECOND_PER_DAY = 24 * 60 * 60;
 
-  private BufferAllocator rootAllocator;
+  BufferAllocator rootAllocator;
+  StructTypeInfo rowTypeInfo;
+  StructObjectInspector rowObjectInspector;
+  Configuration conf;
 
-  private StructTypeInfo rowTypeInfo;
-  private StructObjectInspector rowObjectInspector;
-  private Configuration conf;
   private Serializer serializer;
   private Deserializer deserializer;
 
@@ -191,859 +130,8 @@ public void initialize(Configuration conf, Properties tbl) throws SerDeException
       fields.add(toField(columnNames.get(i), columnTypes.get(i)));
     }
 
-    serializer = new Serializer(new Schema(fields));
-    deserializer = new Deserializer();
-  }
-
-  private class Serializer {
-    private final int MAX_BUFFERED_ROWS;
-
-    // Schema
-    private final StructTypeInfo structTypeInfo;
-    private final List<TypeInfo> fieldTypeInfos;
-    private final int fieldSize;
-
-    // Hive columns
-    private final VectorizedRowBatch vectorizedRowBatch;
-    private final VectorAssignRow vectorAssignRow;
-    private int batchSize;
-
-    // Arrow columns
-    private final VectorSchemaRoot vectorSchemaRoot;
-    private final List<FieldVector> arrowVectors;
-    private final List<FieldWriter> fieldWriters;
-
-    private Serializer(Schema schema) throws SerDeException {
-      MAX_BUFFERED_ROWS = HiveConf.getIntVar(conf, HIVE_ARROW_BATCH_SIZE);
-      LOG.info("ArrowColumnarBatchSerDe max number of buffered columns: " + MAX_BUFFERED_ROWS);
-
-      // Schema
-      structTypeInfo = (StructTypeInfo) getTypeInfoFromObjectInspector(rowObjectInspector);
-      fieldTypeInfos = structTypeInfo.getAllStructFieldTypeInfos();
-      fieldSize = fieldTypeInfos.size();
-
-      // Init Arrow stuffs
-      vectorSchemaRoot = VectorSchemaRoot.create(schema, rootAllocator);
-      arrowVectors = vectorSchemaRoot.getFieldVectors();
-      fieldWriters = Lists.newArrayList();
-      for (FieldVector fieldVector : arrowVectors) {
-        final FieldWriter fieldWriter =
-            Types.getMinorTypeForArrowType(
-                fieldVector.getField().getType()).getNewFieldWriter(fieldVector);
-        fieldWriters.add(fieldWriter);
-      }
-
-      // Init Hive stuffs
-      vectorizedRowBatch = new VectorizedRowBatch(fieldSize);
-      for (int i = 0; i < fieldSize; i++) {
-        final ColumnVector columnVector = createColumnVector(fieldTypeInfos.get(i));
-        vectorizedRowBatch.cols[i] = columnVector;
-        columnVector.init();
-      }
-      vectorizedRowBatch.ensureSize(MAX_BUFFERED_ROWS);
-      vectorAssignRow = new VectorAssignRow();
-      try {
-        vectorAssignRow.init(rowObjectInspector);
-      } catch (HiveException e) {
-        throw new SerDeException(e);
-      }
-    }
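The constructor above fixes the flush threshold from hive.arrow.batch.size; serialize() then buffers assigned rows until that threshold, or a null end-of-stream row, triggers serializeBatch(). A minimal sketch of that buffer-and-flush contract, with hypothetical names (RowBatcher, not a Hive class):

```java
import java.util.ArrayList;
import java.util.List;

// Illustration only: mirrors the buffer-and-flush contract of the Serializer.
final class RowBatcher {
  private final int maxBufferedRows;
  private final List<Object> buffered = new ArrayList<>();

  RowBatcher(int maxBufferedRows) {
    this.maxBufferedRows = maxBufferedRows;
  }

  /** Returns a flushed batch when full, or on the null end-of-stream marker; else null. */
  List<Object> add(Object row) {
    if (row == null) {          // end of input: flush whatever is buffered
      return flush();
    }
    buffered.add(row);
    return buffered.size() == maxBufferedRows ? flush() : null;
  }

  private List<Object> flush() {
    final List<Object> batch = new ArrayList<>(buffered);
    buffered.clear();
    return batch;
  }
}
```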
-    private ArrowWrapperWritable serializeBatch() {
-      for (int i = 0; i < vectorizedRowBatch.projectionSize; i++) {
-        final int projectedColumn = vectorizedRowBatch.projectedColumns[i];
-        final ColumnVector hiveVector = vectorizedRowBatch.cols[projectedColumn];
-        final TypeInfo fieldTypeInfo = structTypeInfo.getAllStructFieldTypeInfos().get(i);
-        final FieldWriter fieldWriter = fieldWriters.get(i);
-        final FieldVector arrowVector = arrowVectors.get(i);
-        arrowVector.setValueCount(0);
-        fieldWriter.setPosition(0);
-        write(fieldWriter, arrowVector, hiveVector, fieldTypeInfo, 0, batchSize, true);
-      }
-      vectorizedRowBatch.reset();
-      vectorSchemaRoot.setRowCount(batchSize);
-
-      batchSize = 0;
-      return new ArrowWrapperWritable(vectorSchemaRoot);
-    }
-
-    private BaseWriter getWriter(FieldWriter writer, TypeInfo typeInfo, String name) {
-      switch (typeInfo.getCategory()) {
-        case PRIMITIVE:
-          switch (((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory()) {
-            case BOOLEAN:
-              return writer.bit(name);
-            case BYTE:
-              return writer.tinyInt(name);
-            case SHORT:
-              return writer.smallInt(name);
-            case INT:
-              return writer.integer(name);
-            case LONG:
-              return writer.bigInt(name);
-            case FLOAT:
-              return writer.float4(name);
-            case DOUBLE:
-              return writer.float8(name);
-            case STRING:
-            case VARCHAR:
-            case CHAR:
-              return writer.varChar(name);
-            case DATE:
-              return writer.dateDay(name);
-            case TIMESTAMP:
-              return writer.timeStampMilli(name);
-            case BINARY:
-              return writer.varBinary(name);
-            case DECIMAL:
-              final DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) typeInfo;
-              final int scale = decimalTypeInfo.scale();
-              final int precision = decimalTypeInfo.precision();
-              return writer.decimal(name, scale, precision);
-            case INTERVAL_YEAR_MONTH:
-              return writer.intervalYear(name);
-            case INTERVAL_DAY_TIME:
-              return writer.intervalDay(name);
-            case TIMESTAMPLOCALTZ: // VectorAssignRow doesn't support it
-            case VOID:
-            case UNKNOWN:
-            default:
-              throw new IllegalArgumentException();
-          }
-        case LIST:
-        case UNION:
-          return writer.list(name);
-        case STRUCT:
-          return writer.map(name);
-        case MAP: // The caller will convert map to array
-          return writer.list(name).map();
-        default:
-          throw new IllegalArgumentException();
-      }
-    }
-
-    private BaseWriter getWriter(FieldWriter writer, TypeInfo typeInfo) {
-      switch (typeInfo.getCategory()) {
-        case PRIMITIVE:
-          switch (((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory()) {
-            case BOOLEAN:
-              return writer.bit();
-            case BYTE:
-              return writer.tinyInt();
-            case SHORT:
-              return writer.smallInt();
-            case INT:
-              return writer.integer();
-            case LONG:
-              return writer.bigInt();
-            case FLOAT:
-              return writer.float4();
-            case DOUBLE:
-              return writer.float8();
-            case STRING:
-            case VARCHAR:
-            case CHAR:
-              return writer.varChar();
-            case DATE:
-              return writer.dateDay();
-            case TIMESTAMP:
-              return writer.timeStampMilli();
-            case BINARY:
-              return writer.varBinary();
-            case INTERVAL_YEAR_MONTH:
-              return writer.intervalDay();
-            case INTERVAL_DAY_TIME:
-              return writer.intervalYear();
-            case TIMESTAMPLOCALTZ: // VectorAssignRow doesn't support it
-            case DECIMAL: // ListVector doesn't support it
-            case VOID:
-            case UNKNOWN:
-            default:
-              throw new IllegalArgumentException();
-          }
-        case LIST:
-        case UNION:
-          return writer.list();
-        case STRUCT:
-          return writer.map();
-        case MAP: // The caller will convert map to array
-          return writer.list().map();
-        default:
-          throw new IllegalArgumentException();
-      }
-    }
-
-    private void write(BaseWriter baseWriter, FieldVector arrowVector, ColumnVector hiveVector,
-        TypeInfo typeInfo, int offset, int length, boolean incrementIndex) {
-
-      final IntConsumer writer;
-      switch (typeInfo.getCategory()) {
-        case PRIMITIVE:
-          final PrimitiveObjectInspector.PrimitiveCategory primitiveCategory =
-              ((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory();
-          switch (primitiveCategory) {
-            case BOOLEAN:
-              writer = index -> ((BitWriter) baseWriter).writeBit(
-                  (int) ((LongColumnVector) hiveVector).vector[index]);
-              break;
-            case BYTE:
-              writer = index ->
-                  ((TinyIntWriter) baseWriter).writeTinyInt(
-                      (byte) ((LongColumnVector) hiveVector).vector[index]);
-              break;
-            case SHORT:
-              writer = index -> ((SmallIntWriter) baseWriter).writeSmallInt(
-                  (short) ((LongColumnVector) hiveVector).vector[index]);
-              break;
-            case INT:
-              writer = index -> ((IntWriter) baseWriter).writeInt(
-                  (int) ((LongColumnVector) hiveVector).vector[index]);
-              break;
-            case LONG:
-              writer = index -> ((BigIntWriter) baseWriter).writeBigInt(
-                  ((LongColumnVector) hiveVector).vector[index]);
-              break;
-            case FLOAT:
-              writer = index -> ((Float4Writer) baseWriter).writeFloat4(
-                  (float) ((DoubleColumnVector) hiveVector).vector[index]);
-              break;
-            case DOUBLE:
-              writer = index -> ((Float8Writer) baseWriter).writeFloat8(
-                  ((DoubleColumnVector) hiveVector).vector[index]);
-              break;
-            case STRING:
-            case VARCHAR:
-            case CHAR:
-              writer = index -> {
-                BytesColumnVector stringVector = (BytesColumnVector) hiveVector;
-                byte[] bytes = stringVector.vector[index];
-                int start = stringVector.start[index];
-                int bytesLength = stringVector.length[index];
-                try (ArrowBuf arrowBuf = rootAllocator.buffer(bytesLength)) {
-                  arrowBuf.setBytes(0, bytes, start, bytesLength);
-                  ((VarCharWriter) baseWriter).writeVarChar(0, bytesLength, arrowBuf);
-                }
-              };
-              break;
-            case DATE:
-              writer = index -> ((DateDayWriter) baseWriter).writeDateDay(
-                  (int) ((LongColumnVector) hiveVector).vector[index]);
-              break;
-            case TIMESTAMP:
-              writer = index -> ((TimeStampMilliWriter) baseWriter).writeTimeStampMilli(
-                  ((TimestampColumnVector) hiveVector).getTime(index));
-              break;
-            case BINARY:
-              writer = index -> {
-                BytesColumnVector binaryVector = (BytesColumnVector) hiveVector;
-                final byte[] bytes = binaryVector.vector[index];
-                final int start = binaryVector.start[index];
-                final int byteLength = binaryVector.length[index];
-                try (ArrowBuf arrowBuf = rootAllocator.buffer(byteLength)) {
-                  arrowBuf.setBytes(0, bytes, start, byteLength);
-                  ((VarBinaryWriter) baseWriter).writeVarBinary(0, byteLength, arrowBuf);
-                }
-              };
-              break;
-            case DECIMAL:
-              writer = index -> {
-                DecimalColumnVector hiveDecimalVector = (DecimalColumnVector) hiveVector;
-                ((DecimalWriter) baseWriter).writeDecimal(
-                    hiveDecimalVector.vector[index].getHiveDecimal().bigDecimalValue()
-                        .setScale(hiveDecimalVector.scale));
-              };
-              break;
-            case INTERVAL_YEAR_MONTH:
-              writer = index -> ((IntervalYearWriter) baseWriter).writeIntervalYear(
-                  (int) ((LongColumnVector) hiveVector).vector[index]);
-              break;
-            case INTERVAL_DAY_TIME:
-              writer = index -> {
-                IntervalDayTimeColumnVector intervalDayTimeVector =
-                    (IntervalDayTimeColumnVector) hiveVector;
-                final long millis = (intervalDayTimeVector.getTotalSeconds(index) * 1_000) +
-                    (intervalDayTimeVector.getNanos(index) / 1_000_000);
-                final int days = (int) (millis / MS_PER_DAY);
-                ((IntervalDayWriter) baseWriter).writeIntervalDay(
-                    days, (int) (millis % MS_PER_DAY));
-              };
-              break;
-            case VOID:
-            case UNKNOWN:
-            case TIMESTAMPLOCALTZ:
-            default:
-              throw new IllegalArgumentException();
-          }
-          break;
-        case LIST:
-          final ListTypeInfo listTypeInfo = (ListTypeInfo) typeInfo;
-          final TypeInfo elementTypeInfo = listTypeInfo.getListElementTypeInfo();
-          final ListColumnVector hiveListVector = (ListColumnVector) hiveVector;
-          final ColumnVector hiveElementVector = hiveListVector.child;
-          final FieldVector arrowElementVector = arrowVector.getChildrenFromFields().get(0);
-          final BaseWriter.ListWriter listWriter = (BaseWriter.ListWriter) baseWriter;
-          final BaseWriter elementWriter = getWriter((FieldWriter) baseWriter, elementTypeInfo);
-
-          writer = index -> {
-            final int listOffset = (int) hiveListVector.offsets[index];
-            final int listLength = (int) hiveListVector.lengths[index];
-            listWriter.startList();
-            write(elementWriter, arrowElementVector, hiveElementVector, elementTypeInfo,
-                listOffset, listLength, false);
-            listWriter.endList();
-          };
-
-          incrementIndex = false;
-          break;
-        case STRUCT:
-          final StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo;
-          final List<TypeInfo> fieldTypeInfos = structTypeInfo.getAllStructFieldTypeInfos();
-          final StructColumnVector hiveStructVector = (StructColumnVector) hiveVector;
-          final List<FieldVector> arrowFieldVectors = arrowVector.getChildrenFromFields();
-          final ColumnVector[] hiveFieldVectors = hiveStructVector.fields;
-          final BaseWriter.MapWriter structWriter = (BaseWriter.MapWriter) baseWriter;
-          final int fieldSize = fieldTypeInfos.size();
-
-          writer = index -> {
-            structWriter.start();
-            for (int fieldIndex = 0; fieldIndex < fieldSize; fieldIndex++) {
-              final TypeInfo fieldTypeInfo = fieldTypeInfos.get(fieldIndex);
-              final String fieldName = structTypeInfo.getAllStructFieldNames().get(fieldIndex);
-              final ColumnVector hiveFieldVector = hiveFieldVectors[fieldIndex];
-              final BaseWriter fieldWriter = getWriter((FieldWriter) structWriter, fieldTypeInfo,
-                  fieldName);
-              final FieldVector arrowFieldVector = arrowFieldVectors.get(fieldIndex);
-              write(fieldWriter, arrowFieldVector, hiveFieldVector, fieldTypeInfo, index, 1, false);
-            }
-            structWriter.end();
-          };
-
-          incrementIndex = false;
-          break;
-        case UNION:
-          final UnionTypeInfo unionTypeInfo = (UnionTypeInfo) typeInfo;
-          final List<TypeInfo> objectTypeInfos = unionTypeInfo.getAllUnionObjectTypeInfos();
-          final UnionColumnVector hiveUnionVector = (UnionColumnVector) hiveVector;
-          final ColumnVector[] hiveObjectVectors = hiveUnionVector.fields;
-          final UnionWriter unionWriter = (UnionWriter) baseWriter;
-
-          writer = index -> {
-            final int tag = hiveUnionVector.tags[index];
-            final ColumnVector hiveObjectVector = hiveObjectVectors[tag];
-            final TypeInfo objectTypeInfo = objectTypeInfos.get(tag);
-            write(unionWriter, arrowVector, hiveObjectVector, objectTypeInfo, index, 1, false);
-          };
-          break;
-        case MAP:
-          final ListTypeInfo structListTypeInfo =
-              toStructListTypeInfo((MapTypeInfo) typeInfo);
-          final ListColumnVector structListVector =
-              toStructListVector((MapColumnVector) hiveVector);
-
-          writer = index -> write(baseWriter, arrowVector, structListVector, structListTypeInfo,
-              index, length, false);
-
-          incrementIndex = false;
-          break;
-        default:
-          throw new IllegalArgumentException();
-      }
-
-      if (hiveVector.noNulls) {
-        if (hiveVector.isRepeating) {
-          for (int i = 0; i < length; i++) {
-            writer.accept(0);
-            if (incrementIndex) {
-              baseWriter.setPosition(baseWriter.getPosition() + 1);
-            }
-          }
-        } else {
-          if (vectorizedRowBatch.selectedInUse) {
-            for (int j = 0; j < length; j++) {
-              final int i = vectorizedRowBatch.selected[j];
-              writer.accept(offset + i);
-              if (incrementIndex) {
-                baseWriter.setPosition(baseWriter.getPosition() + 1);
-              }
-            }
-          } else {
-            for (int i = 0; i < length; i++) {
-              writer.accept(offset + i);
-              if (incrementIndex) {
-                baseWriter.setPosition(baseWriter.getPosition() + 1);
-              }
-            }
-          }
-        }
-      } else {
-        if (hiveVector.isRepeating) {
-          for (int i = 0; i < length; i++) {
-            if (hiveVector.isNull[0]) {
-              writeNull(baseWriter);
-            } else {
-              writer.accept(0);
-            }
-            if (incrementIndex) {
-              baseWriter.setPosition(baseWriter.getPosition() + 1);
-            }
-          }
-        } else {
-          if (vectorizedRowBatch.selectedInUse) {
-            for (int j = 0; j < length; j++) {
-              final int i = vectorizedRowBatch.selected[j];
-              if (hiveVector.isNull[offset + i]) {
-                writeNull(baseWriter);
-              } else {
-                writer.accept(offset + i);
-              }
-              if (incrementIndex) {
-                baseWriter.setPosition(baseWriter.getPosition() + 1);
-              }
-            }
-          } else {
-            for (int i = 0; i < length; i++) {
-              if (hiveVector.isNull[offset + i]) {
-                writeNull(baseWriter);
-              } else {
-                writer.accept(offset + i);
-              }
-              if (incrementIndex) {
-                baseWriter.setPosition(baseWriter.getPosition() + 1);
-              }
-            }
-          }
-        }
-      }
-    }
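The tail of write() above multiplexes three orthogonal batch encodings: noNulls (null checks can be skipped), isRepeating (one physical value stands for every row), and selectedInUse (logical rows are indirected through a selection vector). A condensed sketch of the same traversal over plain arrays, with illustrative names rather than Hive's classes:

```java
// Illustration of the noNulls / isRepeating / selectedInUse traversal.
final class TraversalSketch {
  // Visits one value per logical row, honoring the three encodings.
  static void traverse(long[] values, boolean[] isNull, boolean noNulls,
      boolean isRepeating, int[] selected, boolean selectedInUse, int length) {
    for (int j = 0; j < length; j++) {
      final int i = isRepeating ? 0             // single physical entry for all rows
          : (selectedInUse ? selected[j]        // indirect through the selection vector
          : j);                                 // dense layout: row j is entry j
      if (!noNulls && isNull[i]) {
        System.out.println("null");
      } else {
        System.out.println(values[i]);
      }
    }
  }
}
```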
-
-    public ArrowWrapperWritable serialize(Object obj, ObjectInspector objInspector) {
-      // if row is null, it means there are no more rows (closeOp()).
-      // another case can be that the buffer is full.
-      if (obj == null) {
-        return serializeBatch();
-      }
-      List<Object> standardObjects = new ArrayList<Object>();
-      ObjectInspectorUtils.copyToStandardObject(standardObjects, obj,
-          ((StructObjectInspector) objInspector), WRITABLE);
-
-      vectorAssignRow.assignRow(vectorizedRowBatch, batchSize, standardObjects, fieldSize);
-      batchSize++;
-      if (batchSize == MAX_BUFFERED_ROWS) {
-        return serializeBatch();
-      }
-      return null;
-    }
-  }
-
-  private static void writeNull(BaseWriter baseWriter) {
-    if (baseWriter instanceof UnionListWriter) {
-      // UnionListWriter should implement AbstractFieldWriter#writeNull
-      BaseWriter.ListWriter listWriter = ((UnionListWriter) baseWriter).list();
-      listWriter.setPosition(listWriter.getPosition() + 1);
-    } else {
-      // FieldWriter should have a super method of AbstractFieldWriter#writeNull
-      try {
-        Method method = baseWriter.getClass().getMethod("writeNull");
-        method.setAccessible(true);
-        method.invoke(baseWriter);
-      } catch (Exception e) {
-        throw new RuntimeException(e);
-      }
-    }
-  }
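Per the comments in serialize() above, a null row signals closeOp() and flushes the tail of the buffer; otherwise the method returns null until the buffer fills. A hedged sketch of how a caller might drive that contract (drain and emit are illustrative, not actual Hive operator code):

```java
import java.util.function.Consumer;

import org.apache.hadoop.hive.serde2.SerDeException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.io.Writable;

// Hypothetical driver loop over the SerDe's buffer-and-flush contract.
final class ArrowBatchDriver {
  static void drain(ArrowColumnarBatchSerDe serDe, ObjectInspector rowInspector,
      Iterable<Object> rows, Consumer<Writable> emit) throws SerDeException {
    for (Object row : rows) {
      final Writable batch = serDe.serialize(row, rowInspector); // null until the buffer fills
      if (batch != null) {
        emit.accept(batch);
      }
    }
    final Writable tail = serDe.serialize(null, rowInspector);   // null row flushes the remainder
    if (tail != null) {
      emit.accept(tail);
    }
  }
}
```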
-
-  private static abstract class PrimitiveReader {
-    final void read(FieldReader reader, ColumnVector columnVector, int offset, int length) {
-      for (int i = 0; i < length; i++) {
-        final int rowIndex = offset + i;
-        if (reader.isSet()) {
-          doRead(reader, columnVector, rowIndex);
-        } else {
-          VectorizedBatchUtil.setNullColIsNullValue(columnVector, rowIndex);
-        }
-        reader.setPosition(reader.getPosition() + 1);
-      }
-    }
-
-    abstract void doRead(FieldReader reader, ColumnVector columnVector, int rowIndex);
-  }
-
-  private class Deserializer {
-    private final VectorExtractRow vectorExtractRow;
-    private final VectorizedRowBatch vectorizedRowBatch;
-    private Object[][] rows;
-
-    public Deserializer() throws SerDeException {
-      vectorExtractRow = new VectorExtractRow();
-      final List<TypeInfo> fieldTypeInfoList = rowTypeInfo.getAllStructFieldTypeInfos();
-      final int fieldCount = fieldTypeInfoList.size();
-      final TypeInfo[] typeInfos = fieldTypeInfoList.toArray(new TypeInfo[fieldCount]);
-      try {
-        vectorExtractRow.init(typeInfos);
-      } catch (HiveException e) {
-        throw new SerDeException(e);
-      }
-
-      vectorizedRowBatch = new VectorizedRowBatch(fieldCount);
-      for (int i = 0; i < fieldCount; i++) {
-        final ColumnVector columnVector = createColumnVector(typeInfos[i]);
-        columnVector.init();
-        vectorizedRowBatch.cols[i] = columnVector;
-      }
-    }
-
-    public Object deserialize(Writable writable) {
-      final ArrowWrapperWritable arrowWrapperWritable = (ArrowWrapperWritable) writable;
-      final VectorSchemaRoot vectorSchemaRoot = arrowWrapperWritable.getVectorSchemaRoot();
-      final List<FieldVector> fieldVectors = vectorSchemaRoot.getFieldVectors();
-      final int fieldCount = fieldVectors.size();
-      final int rowCount = vectorSchemaRoot.getRowCount();
-      vectorizedRowBatch.ensureSize(rowCount);
-
-      if (rows == null || rows.length < rowCount) {
-        rows = new Object[rowCount][];
-        for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) {
-          rows[rowIndex] = new Object[fieldCount];
-        }
-      }
-
-      for (int i = 0; i < fieldCount; i++) {
-        final FieldVector fieldVector = fieldVectors.get(i);
-        final FieldReader fieldReader = fieldVector.getReader();
-        fieldReader.setPosition(0);
-        final int projectedCol = vectorizedRowBatch.projectedColumns[i];
-        final ColumnVector columnVector = vectorizedRowBatch.cols[projectedCol];
-        final TypeInfo typeInfo = rowTypeInfo.getAllStructFieldTypeInfos().get(i);
-        read(fieldReader, columnVector, typeInfo, 0, rowCount);
-      }
-      for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) {
-        vectorExtractRow.extractRow(vectorizedRowBatch, rowIndex, rows[rowIndex]);
-      }
-      vectorizedRowBatch.reset();
-      return rows;
-    }
-
-    private void read(FieldReader reader, ColumnVector columnVector, TypeInfo typeInfo,
-        int rowOffset, int rowLength) {
-      switch (typeInfo.getCategory()) {
-        case PRIMITIVE:
-          final PrimitiveObjectInspector.PrimitiveCategory primitiveCategory =
-              ((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory();
-          final PrimitiveReader primitiveReader;
-          switch (primitiveCategory) {
-            case BOOLEAN:
-              primitiveReader = new PrimitiveReader() {
-                NullableBitHolder holder = new NullableBitHolder();
-
-                @Override
-                void doRead(FieldReader reader, ColumnVector columnVector, int rowIndex) {
-                  reader.read(holder);
-                  ((LongColumnVector) columnVector).vector[rowIndex] = holder.value;
-                }
-              };
-              break;
-            case BYTE:
-              primitiveReader = new PrimitiveReader() {
-                NullableTinyIntHolder holder = new NullableTinyIntHolder();
-
-                @Override
-                void doRead(FieldReader reader, ColumnVector columnVector, int rowIndex) {
-                  reader.read(holder);
-                  ((LongColumnVector) columnVector).vector[rowIndex] = holder.value;
-                }
-              };
-              break;
-            case SHORT:
-              primitiveReader = new PrimitiveReader() {
-                NullableSmallIntHolder holder = new NullableSmallIntHolder();
-
-                @Override
-                void doRead(FieldReader reader, ColumnVector columnVector, int rowIndex) {
-                  reader.read(holder);
-                  ((LongColumnVector) columnVector).vector[rowIndex] = holder.value;
-                }
-              };
-              break;
-            case INT:
-              primitiveReader = new PrimitiveReader() {
-                NullableIntHolder holder = new NullableIntHolder();
-
-                @Override
-                void doRead(FieldReader reader, ColumnVector columnVector, int rowIndex) {
-                  reader.read(holder);
-                  ((LongColumnVector) columnVector).vector[rowIndex] = holder.value;
-                }
-              };
-              break;
-            case LONG:
-              primitiveReader = new PrimitiveReader() {
-                NullableBigIntHolder holder = new NullableBigIntHolder();
-
-                @Override
-                void doRead(FieldReader reader, ColumnVector columnVector, int rowIndex) {
-                  reader.read(holder);
-                  ((LongColumnVector) columnVector).vector[rowIndex] = holder.value;
-                }
-              };
-              break;
-            case FLOAT:
-              primitiveReader = new PrimitiveReader() {
-                NullableFloat4Holder holder = new NullableFloat4Holder();
-
-                @Override
-                void doRead(FieldReader reader, ColumnVector columnVector, int rowIndex) {
-                  reader.read(holder);
-                  ((DoubleColumnVector) columnVector).vector[rowIndex] = holder.value;
-                }
-              };
-              break;
-            case DOUBLE:
-              primitiveReader = new PrimitiveReader() {
-                NullableFloat8Holder holder = new NullableFloat8Holder();
-
-                @Override
-                void doRead(FieldReader reader, ColumnVector columnVector, int rowIndex) {
-                  reader.read(holder);
-                  ((DoubleColumnVector) columnVector).vector[rowIndex] = holder.value;
-                }
-              };
-              break;
-            case STRING:
-            case VARCHAR:
-            case CHAR:
-              primitiveReader = new PrimitiveReader() {
-                NullableVarCharHolder holder = new NullableVarCharHolder();
-
-                @Override
-                void doRead(FieldReader reader, ColumnVector columnVector, int rowIndex) {
-                  reader.read(holder);
-                  int varCharSize = holder.end - holder.start;
-                  byte[] varCharBytes = new byte[varCharSize];
-                  holder.buffer.getBytes(holder.start, varCharBytes);
-                  ((BytesColumnVector) columnVector).setVal(rowIndex, varCharBytes, 0, varCharSize);
-                }
-              };
-              break;
-            case DATE:
-              primitiveReader = new PrimitiveReader() {
-                NullableDateDayHolder holder = new NullableDateDayHolder();
-
-                @Override
-                void doRead(FieldReader reader, ColumnVector columnVector, int rowIndex) {
-                  reader.read(holder);
-                  ((LongColumnVector) columnVector).vector[rowIndex] = holder.value;
-                }
-              };
-              break;
-            case TIMESTAMP:
-              primitiveReader = new PrimitiveReader() {
-                NullableTimeStampMilliHolder timeStampMilliHolder =
-                    new NullableTimeStampMilliHolder();
-
-                @Override
-                void doRead(FieldReader reader, ColumnVector columnVector, int rowIndex) {
-                  reader.read(timeStampMilliHolder);
-                  ((TimestampColumnVector) columnVector).set(rowIndex,
-                      new Timestamp(timeStampMilliHolder.value));
-                }
-              };
-              break;
-            case BINARY:
-              primitiveReader = new PrimitiveReader() {
-                NullableVarBinaryHolder holder = new NullableVarBinaryHolder();
-
-                @Override
-                void doRead(FieldReader reader, ColumnVector columnVector, int rowIndex) {
-                  reader.read(holder);
-                  final int binarySize = holder.end - holder.start;
-                  final byte[] binaryBytes = new byte[binarySize];
-                  holder.buffer.getBytes(holder.start, binaryBytes);
-                  ((BytesColumnVector) columnVector).setVal(rowIndex, binaryBytes, 0, binarySize);
-                }
-              };
-              break;
-            case DECIMAL:
-              primitiveReader = new PrimitiveReader() {
-                @Override
-                void doRead(FieldReader reader, ColumnVector columnVector, int rowIndex) {
-                  ((DecimalColumnVector) columnVector).set(rowIndex,
-                      HiveDecimal.create(reader.readBigDecimal()));
-                }
-              };
-              break;
-            case INTERVAL_YEAR_MONTH:
-              primitiveReader = new PrimitiveReader() {
-                NullableIntervalYearHolder holder = new NullableIntervalYearHolder();
-
-                @Override
-                void doRead(FieldReader reader, ColumnVector columnVector, int rowIndex) {
-                  reader.read(holder);
-                  ((LongColumnVector) columnVector).vector[rowIndex] = holder.value;
-                }
-              };
-              break;
-            case INTERVAL_DAY_TIME:
-              primitiveReader = new PrimitiveReader() {
-                NullableIntervalDayHolder holder = new NullableIntervalDayHolder();
-
-                @Override
-                void doRead(FieldReader reader, ColumnVector columnVector, int rowIndex) {
-                  IntervalDayTimeColumnVector intervalDayTimeVector =
-                      (IntervalDayTimeColumnVector) columnVector;
-                  reader.read(holder);
-                  HiveIntervalDayTime intervalDayTime = new HiveIntervalDayTime(
-                      holder.days, // days
-                      holder.milliseconds / MS_PER_HOUR, // hour
-                      (holder.milliseconds % MS_PER_HOUR) / MS_PER_MINUTE, // minute
-                      (holder.milliseconds % MS_PER_MINUTE) / MS_PER_SECOND, // second
-                      (holder.milliseconds % MS_PER_SECOND) * NS_PER_MS); // nanosecond
-                  intervalDayTimeVector.set(rowIndex, intervalDayTime);
-                }
-              };
-              break;
-            default:
-              throw new IllegalArgumentException();
-          }
-          primitiveReader.read(reader, columnVector, rowOffset, rowLength);
-          break;
-        case LIST:
-          final ListTypeInfo listTypeInfo = (ListTypeInfo) typeInfo;
-          final TypeInfo elementTypeInfo = listTypeInfo.getListElementTypeInfo();
-          final ListColumnVector listVector = (ListColumnVector) columnVector;
-          final ColumnVector elementVector = listVector.child;
-          final FieldReader elementReader = reader.reader();
-
-          int listOffset = 0;
-          for (int rowIndex = 0; rowIndex < rowLength; rowIndex++) {
-            final int adjustedRowIndex = rowOffset + rowIndex;
-            reader.setPosition(adjustedRowIndex);
-            final int listLength = reader.size();
-            listVector.offsets[adjustedRowIndex] = listOffset;
-            listVector.lengths[adjustedRowIndex] = listLength;
-            read(elementReader, elementVector, elementTypeInfo, listOffset, listLength);
-            listOffset += listLength;
-          }
-          break;
-        case STRUCT:
-          final StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo;
-          final List<TypeInfo> fieldTypeInfos = structTypeInfo.getAllStructFieldTypeInfos();
-          final List<String> fieldNames = structTypeInfo.getAllStructFieldNames();
-          final int fieldSize = fieldNames.size();
-          final StructColumnVector structVector = (StructColumnVector) columnVector;
-          final ColumnVector[] fieldVectors = structVector.fields;
-
-          for (int fieldIndex = 0; fieldIndex < fieldSize; fieldIndex++) {
-            final TypeInfo fieldTypeInfo = fieldTypeInfos.get(fieldIndex);
-            final FieldReader fieldReader = reader.reader(fieldNames.get(fieldIndex));
-            final ColumnVector fieldVector = fieldVectors[fieldIndex];
-            read(fieldReader, fieldVector, fieldTypeInfo, rowOffset, rowLength);
-          }
-          break;
-        case UNION:
-          final UnionTypeInfo unionTypeInfo = (UnionTypeInfo) typeInfo;
-          final List<TypeInfo> objectTypeInfos = unionTypeInfo.getAllUnionObjectTypeInfos();
-          final UnionColumnVector unionVector = (UnionColumnVector) columnVector;
-          final ColumnVector[] objectVectors = unionVector.fields;
-          final Map<Types.MinorType, Integer> minorTypeToTagMap = Maps.newHashMap();
-          for (int tag = 0; tag < objectTypeInfos.size(); tag++) {
-            minorTypeToTagMap.put(toMinorType(objectTypeInfos.get(tag)), tag);
-          }
-
-          final UnionReader unionReader = (UnionReader) reader;
-          for (int rowIndex = 0; rowIndex < rowLength; rowIndex++) {
-            final int adjustedRowIndex = rowIndex + rowOffset;
-            unionReader.setPosition(adjustedRowIndex);
-            final Types.MinorType minorType = unionReader.getMinorType();
-            final int tag = minorTypeToTagMap.get(minorType);
-            unionVector.tags[adjustedRowIndex] = tag;
-            read(unionReader, objectVectors[tag], objectTypeInfos.get(tag), adjustedRowIndex, 1);
-          }
-          break;
-        case MAP:
-          final MapTypeInfo mapTypeInfo = (MapTypeInfo) typeInfo;
-          final ListTypeInfo mapStructListTypeInfo = toStructListTypeInfo(mapTypeInfo);
-          final MapColumnVector hiveMapVector = (MapColumnVector) columnVector;
-          final ListColumnVector mapStructListVector = toStructListVector(hiveMapVector);
-          final StructColumnVector mapStructVector = (StructColumnVector) mapStructListVector.child;
-          read(reader, mapStructListVector, mapStructListTypeInfo, rowOffset, rowLength);
-
-          hiveMapVector.isRepeating = mapStructListVector.isRepeating;
-          hiveMapVector.childCount = mapStructListVector.childCount;
-          hiveMapVector.noNulls = mapStructListVector.noNulls;
-          System.arraycopy(mapStructListVector.offsets, 0, hiveMapVector.offsets, 0, rowLength);
-          System.arraycopy(mapStructListVector.lengths, 0, hiveMapVector.lengths, 0, rowLength);
-          hiveMapVector.keys = mapStructVector.fields[0];
-          hiveMapVector.values = mapStructVector.fields[1];
-          break;
-        default:
-          throw new IllegalArgumentException();
-      }
-    }
-  }
-
-  private static Types.MinorType toMinorType(TypeInfo typeInfo) {
-    switch (typeInfo.getCategory()) {
-      case PRIMITIVE:
-        switch (((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory()) {
-          case BOOLEAN:
-            return Types.MinorType.BIT;
-          case BYTE:
-            return Types.MinorType.TINYINT;
-          case SHORT:
-            return Types.MinorType.SMALLINT;
-          case INT:
-            return Types.MinorType.INT;
-          case LONG:
-            return Types.MinorType.BIGINT;
-          case FLOAT:
-            return Types.MinorType.FLOAT4;
-          case DOUBLE:
-            return Types.MinorType.FLOAT8;
-          case STRING:
-          case VARCHAR:
-          case CHAR:
-            return Types.MinorType.VARCHAR;
-          case DATE:
-            return Types.MinorType.DATEDAY;
-          case TIMESTAMP:
-            return Types.MinorType.TIMESTAMPMILLI;
-          case BINARY:
-            return Types.MinorType.VARBINARY;
-          case DECIMAL:
-            return Types.MinorType.DECIMAL;
-          case INTERVAL_YEAR_MONTH:
-            return Types.MinorType.INTERVALYEAR;
-          case INTERVAL_DAY_TIME:
-            return Types.MinorType.INTERVALDAY;
-          case VOID:
-          case TIMESTAMPLOCALTZ:
-          case UNKNOWN:
-          default:
-            throw new IllegalArgumentException();
-        }
-      case LIST:
-        return Types.MinorType.LIST;
-      case STRUCT:
-        return Types.MinorType.MAP;
-      case UNION:
-        return Types.MinorType.UNION;
-      case MAP:
-        // Apache Arrow doesn't have a map vector, so it's converted to a list vector of a struct
-        // vector.
-        return Types.MinorType.LIST;
-      default:
-        throw new IllegalArgumentException();
-    }
-  }
-
-  private static ListTypeInfo toStructListTypeInfo(MapTypeInfo mapTypeInfo) {
-    final StructTypeInfo structTypeInfo = new StructTypeInfo();
-    structTypeInfo.setAllStructFieldNames(Lists.newArrayList("keys", "values"));
-    structTypeInfo.setAllStructFieldTypeInfos(Lists.newArrayList(
-        mapTypeInfo.getMapKeyTypeInfo(), mapTypeInfo.getMapValueTypeInfo()));
-    final ListTypeInfo structListTypeInfo = new ListTypeInfo();
-    structListTypeInfo.setListElementTypeInfo(structTypeInfo);
-    return structListTypeInfo;
+    serializer = new Serializer(this);
+    deserializer = new Deserializer(this);
   }
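initialize() now hands the SerDe itself to the split-out Serializer and Deserializer, which read its package-private rootAllocator, rowTypeInfo, rowObjectInspector, and conf fields directly instead of receiving a prebuilt Schema. A minimal sketch of the pattern, with hypothetical Parent/Helper names:

```java
// Illustration of the back-reference composition used by new Serializer(this).
class Parent {
  String sharedState = "visible to collaborators in the same package";
  Helper helper;

  void initialize() {
    helper = new Helper(this);    // mirrors serializer = new Serializer(this)
  }
}

class Helper {
  private final Parent parent;

  Helper(Parent parent) {
    this.parent = parent;
  }

  String read() {
    return parent.sharedState;    // package-private access, no getters needed
  }
}
```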
 
   private static Field toField(String name, TypeInfo typeInfo) {
@@ -1052,52 +140,50 @@ private static Field toField(String name, TypeInfo typeInfo) {
     final PrimitiveTypeInfo primitiveTypeInfo = (PrimitiveTypeInfo) typeInfo;
     switch (primitiveTypeInfo.getPrimitiveCategory()) {
       case BOOLEAN:
-        return Field.nullable(name, Types.MinorType.BIT.getType());
+        return Field.nullable(name, MinorType.BIT.getType());
       case BYTE:
-        return Field.nullable(name, Types.MinorType.TINYINT.getType());
+        return Field.nullable(name, MinorType.TINYINT.getType());
       case SHORT:
-        return Field.nullable(name, Types.MinorType.SMALLINT.getType());
+        return Field.nullable(name, MinorType.SMALLINT.getType());
       case INT:
-        return Field.nullable(name, Types.MinorType.INT.getType());
+        return Field.nullable(name, MinorType.INT.getType());
       case LONG:
-        return Field.nullable(name, Types.MinorType.BIGINT.getType());
+        return Field.nullable(name, MinorType.BIGINT.getType());
       case FLOAT:
-        return Field.nullable(name, Types.MinorType.FLOAT4.getType());
+        return Field.nullable(name, MinorType.FLOAT4.getType());
       case DOUBLE:
-        return Field.nullable(name, Types.MinorType.FLOAT8.getType());
+        return Field.nullable(name, MinorType.FLOAT8.getType());
       case STRING:
-        return Field.nullable(name, Types.MinorType.VARCHAR.getType());
+      case VARCHAR:
+      case CHAR:
+        return Field.nullable(name, MinorType.VARCHAR.getType());
       case DATE:
-        return Field.nullable(name, Types.MinorType.DATEDAY.getType());
+        return Field.nullable(name, MinorType.DATEDAY.getType());
       case TIMESTAMP:
-        return Field.nullable(name, Types.MinorType.TIMESTAMPMILLI.getType());
+        return Field.nullable(name, MinorType.TIMESTAMPMILLI.getType());
       case TIMESTAMPLOCALTZ:
         final TimestampLocalTZTypeInfo timestampLocalTZTypeInfo =
            (TimestampLocalTZTypeInfo) typeInfo;
         final String timeZone = timestampLocalTZTypeInfo.getTimeZone().toString();
         return Field.nullable(name, new ArrowType.Timestamp(TimeUnit.MILLISECOND, timeZone));
       case BINARY:
-        return Field.nullable(name, Types.MinorType.VARBINARY.getType());
+        return Field.nullable(name, MinorType.VARBINARY.getType());
       case DECIMAL:
         final DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) typeInfo;
         final int precision = decimalTypeInfo.precision();
         final int scale = decimalTypeInfo.scale();
         return Field.nullable(name, new ArrowType.Decimal(precision, scale));
-      case VARCHAR:
-        return Field.nullable(name, Types.MinorType.VARCHAR.getType());
-      case CHAR:
-        return Field.nullable(name, Types.MinorType.VARCHAR.getType());
       case INTERVAL_YEAR_MONTH:
-        return Field.nullable(name, Types.MinorType.INTERVALYEAR.getType());
+        return Field.nullable(name, MinorType.INTERVALYEAR.getType());
       case INTERVAL_DAY_TIME:
-        return Field.nullable(name, Types.MinorType.INTERVALDAY.getType());
+        return Field.nullable(name, MinorType.INTERVALDAY.getType());
       default:
         throw new IllegalArgumentException();
     }
   case LIST:
     final ListTypeInfo listTypeInfo = (ListTypeInfo) typeInfo;
     final TypeInfo elementTypeInfo = listTypeInfo.getListElementTypeInfo();
-    return new Field(name, FieldType.nullable(Types.MinorType.LIST.getType()),
+    return new Field(name, FieldType.nullable(MinorType.LIST.getType()),
         Lists.newArrayList(toField(DEFAULT_ARROW_FIELD_NAME, elementTypeInfo)));
   case STRUCT:
     final StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo;
@@ -1108,7 +194,7 @@ private static Field toField(String name, TypeInfo typeInfo) {
       for (int i = 0; i < structSize; i++) {
         structFields.add(toField(fieldNames.get(i), fieldTypeInfos.get(i)));
       }
-      return new Field(name, FieldType.nullable(Types.MinorType.MAP.getType()), structFields);
+      return new Field(name, FieldType.nullable(MinorType.MAP.getType()), structFields);
     case UNION:
       final UnionTypeInfo unionTypeInfo = (UnionTypeInfo) typeInfo;
       final List<TypeInfo> objectTypeInfos = unionTypeInfo.getAllUnionObjectTypeInfos();
@@ -1117,17 +203,15 @@ private static Field toField(String name, TypeInfo typeInfo) {
       for (int i = 0; i < unionSize; i++) {
         unionFields.add(toField(DEFAULT_ARROW_FIELD_NAME, objectTypeInfos.get(i)));
       }
-      return new Field(name, FieldType.nullable(Types.MinorType.UNION.getType()), unionFields);
+      return new Field(name, FieldType.nullable(MinorType.UNION.getType()), unionFields);
     case MAP:
       final MapTypeInfo mapTypeInfo = (MapTypeInfo) typeInfo;
       final TypeInfo keyTypeInfo = mapTypeInfo.getMapKeyTypeInfo();
       final TypeInfo valueTypeInfo = mapTypeInfo.getMapValueTypeInfo();
-
       final StructTypeInfo mapStructTypeInfo = new StructTypeInfo();
       mapStructTypeInfo.setAllStructFieldNames(Lists.newArrayList("keys", "values"));
       mapStructTypeInfo.setAllStructFieldTypeInfos(
           Lists.newArrayList(keyTypeInfo, valueTypeInfo));
-
       final ListTypeInfo mapListStructTypeInfo = new ListTypeInfo();
       mapListStructTypeInfo.setListElementTypeInfo(mapStructTypeInfo);
 
@@ -1137,18 +221,28 @@ private static Field toField(String name, TypeInfo typeInfo) {
     }
   }
 
-  private static ListColumnVector toStructListVector(MapColumnVector mapVector) {
+  static ListTypeInfo toStructListTypeInfo(MapTypeInfo mapTypeInfo) {
+    final StructTypeInfo structTypeInfo = new StructTypeInfo();
+    structTypeInfo.setAllStructFieldNames(Lists.newArrayList("keys", "values"));
+    structTypeInfo.setAllStructFieldTypeInfos(Lists.newArrayList(
+        mapTypeInfo.getMapKeyTypeInfo(), mapTypeInfo.getMapValueTypeInfo()));
+    final ListTypeInfo structListTypeInfo = new ListTypeInfo();
+    structListTypeInfo.setListElementTypeInfo(structTypeInfo);
+    return structListTypeInfo;
+  }
+
+  static ListColumnVector toStructListVector(MapColumnVector mapVector) {
     final StructColumnVector structVector;
     final ListColumnVector structListVector;
     structVector = new StructColumnVector();
     structVector.fields = new ColumnVector[] {mapVector.keys, mapVector.values};
     structListVector = new ListColumnVector();
     structListVector.child = structVector;
-    System.arraycopy(mapVector.offsets, 0, structListVector.offsets, 0, mapVector.childCount);
-    System.arraycopy(mapVector.lengths, 0, structListVector.lengths, 0, mapVector.childCount);
     structListVector.childCount = mapVector.childCount;
     structListVector.isRepeating = mapVector.isRepeating;
     structListVector.noNulls = mapVector.noNulls;
+    System.arraycopy(mapVector.offsets, 0, structListVector.offsets, 0, mapVector.childCount);
+    System.arraycopy(mapVector.lengths, 0, structListVector.lengths, 0, mapVector.childCount);
     return structListVector;
   }
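Both helpers encode the same convention: since this Arrow version has no first-class map type in the writer path, a Hive map&lt;K,V&gt; travels as array&lt;struct&lt;keys:K,values:V&gt;&gt;. A small check of the TypeInfo side of that convention, assuming TypeInfoFactory's standard primitive singletons:

```java
import com.google.common.collect.Lists;
import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;

// Sketch: the struct-list shape produced for map<string,int>.
final class MapShapeDemo {
  public static void main(String[] args) {
    final MapTypeInfo mapTypeInfo = new MapTypeInfo();
    mapTypeInfo.setMapKeyTypeInfo(TypeInfoFactory.stringTypeInfo);
    mapTypeInfo.setMapValueTypeInfo(TypeInfoFactory.intTypeInfo);

    // Mirrors toStructListTypeInfo: map<K,V> becomes array<struct<keys:K,values:V>>
    final StructTypeInfo entry = new StructTypeInfo();
    entry.setAllStructFieldNames(Lists.newArrayList("keys", "values"));
    entry.setAllStructFieldTypeInfos(Lists.newArrayList(
        mapTypeInfo.getMapKeyTypeInfo(), mapTypeInfo.getMapValueTypeInfo()));
    final ListTypeInfo structList = new ListTypeInfo();
    structList.setListElementTypeInfo(entry);

    System.out.println(structList.getTypeName()); // array<struct<keys:string,values:int>>
  }
}
```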
diff --git ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Deserializer.java ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Deserializer.java
new file mode 100644
index 0000000000..fb5800b140
--- /dev/null
+++ ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Deserializer.java
@@ -0,0 +1,423 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hive.ql.io.arrow;
+
+import io.netty.buffer.ArrowBuf;
+import org.apache.arrow.vector.BigIntVector;
+import org.apache.arrow.vector.BitVector;
+import org.apache.arrow.vector.DateDayVector;
+import org.apache.arrow.vector.DecimalVector;
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.Float4Vector;
+import org.apache.arrow.vector.Float8Vector;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.IntervalDayVector;
+import org.apache.arrow.vector.IntervalYearVector;
+import org.apache.arrow.vector.SmallIntVector;
+import org.apache.arrow.vector.TimeStampNanoVector;
+import org.apache.arrow.vector.TinyIntVector;
+import org.apache.arrow.vector.VarBinaryVector;
+import org.apache.arrow.vector.VarCharVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.holders.NullableIntervalDayHolder;
+import org.apache.hadoop.hive.common.type.HiveDecimal;
+import org.apache.hadoop.hive.common.type.HiveIntervalDayTime;
+import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.ColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.IntervalDayTimeColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.ListColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.MapColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.StructColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.TimestampColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.UnionColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.VectorExtractRow;
+import org.apache.hadoop.hive.ql.exec.vector.VectorizedBatchUtil;
+import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.SerDeException;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.UnionTypeInfo;
+import org.apache.hadoop.io.Writable;
+
+import java.util.List;
+
+import static org.apache.hadoop.hive.ql.exec.vector.VectorizedBatchUtil.createColumnVector;
+import static org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe.MS_PER_SECOND;
+import static org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe.NS_PER_MS;
+import static org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe.NS_PER_SECOND;
+import static org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe.SECOND_PER_DAY;
+import static org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe.toStructListTypeInfo;
+import static org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe.toStructListVector;
+
+class Deserializer {
+  private final ArrowColumnarBatchSerDe serDe;
+  private final VectorExtractRow vectorExtractRow;
+  private final VectorizedRowBatch vectorizedRowBatch;
+  private Object[][] rows;
+
+  Deserializer(ArrowColumnarBatchSerDe serDe) throws SerDeException {
+    this.serDe = serDe;
+    vectorExtractRow = new VectorExtractRow();
+    final List<TypeInfo> fieldTypeInfoList = serDe.rowTypeInfo.getAllStructFieldTypeInfos();
+    final int fieldCount = fieldTypeInfoList.size();
+    final TypeInfo[] typeInfos = fieldTypeInfoList.toArray(new TypeInfo[fieldCount]);
+    try {
+      vectorExtractRow.init(typeInfos);
+    } catch (HiveException e) {
+      throw new SerDeException(e);
+    }
+
+    vectorizedRowBatch = new VectorizedRowBatch(fieldCount);
+    for (int fieldIndex = 0; fieldIndex < fieldCount; fieldIndex++) {
+      final ColumnVector columnVector = createColumnVector(typeInfos[fieldIndex]);
+      columnVector.init();
+      vectorizedRowBatch.cols[fieldIndex] = columnVector;
+    }
+  }
+
+  public Object deserialize(Writable writable) {
+    final ArrowWrapperWritable arrowWrapperWritable = (ArrowWrapperWritable) writable;
+    final VectorSchemaRoot vectorSchemaRoot = arrowWrapperWritable.getVectorSchemaRoot();
+    final List<FieldVector> fieldVectors = vectorSchemaRoot.getFieldVectors();
+    final int fieldCount = fieldVectors.size();
+    final int rowCount = vectorSchemaRoot.getRowCount();
+    vectorizedRowBatch.ensureSize(rowCount);
+
+    if (rows == null || rows.length < rowCount) {
+      rows = new Object[rowCount][];
+      for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) {
+        rows[rowIndex] = new Object[fieldCount];
+      }
+    }
+
+    for (int fieldIndex = 0; fieldIndex < fieldCount; fieldIndex++) {
+      final FieldVector fieldVector = fieldVectors.get(fieldIndex);
+      final int projectedCol = vectorizedRowBatch.projectedColumns[fieldIndex];
+      final ColumnVector columnVector = vectorizedRowBatch.cols[projectedCol];
+      final TypeInfo typeInfo = serDe.rowTypeInfo.getAllStructFieldTypeInfos().get(fieldIndex);
+      read(fieldVector, columnVector, typeInfo);
+    }
+    for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) {
+      vectorExtractRow.extractRow(vectorizedRowBatch, rowIndex, rows[rowIndex]);
+    }
+    vectorizedRowBatch.reset();
+    return rows;
+  }
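Note that deserialize() returns the whole batch at once as an Object[][], one Object[] of field values per row, and the rows array is reused across calls, so it can be longer than the current batch's row count. A hypothetical consumer:

```java
import java.util.function.Consumer;

// Illustration only; 'handle' stands in for downstream processing, and rowCount
// must come from the batch since the returned array may be oversized.
final class BatchConsumer {
  static void consume(Object deserialized, int rowCount, Consumer<Object[]> handle) {
    final Object[][] rows = (Object[][]) deserialized;
    for (int i = 0; i < rowCount; i++) {
      handle.accept(rows[i]);   // rows[i][col] = value of column 'col', per VectorExtractRow
    }
  }
}
```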
+
+  private void read(FieldVector arrowVector, ColumnVector hiveVector, TypeInfo typeInfo) {
+    switch (typeInfo.getCategory()) {
+      case PRIMITIVE:
+        readPrimitive(arrowVector, hiveVector, typeInfo);
+        break;
+      case LIST:
+        readList(arrowVector, (ListColumnVector) hiveVector, (ListTypeInfo) typeInfo);
+        break;
+      case MAP:
+        readMap(arrowVector, (MapColumnVector) hiveVector, (MapTypeInfo) typeInfo);
+        break;
+      case STRUCT:
+        readStruct(arrowVector, (StructColumnVector) hiveVector, (StructTypeInfo) typeInfo);
+        break;
+      case UNION:
+        readUnion(arrowVector, (UnionColumnVector) hiveVector, (UnionTypeInfo) typeInfo);
+        break;
+      default:
+        throw new IllegalArgumentException();
+    }
+  }
+
+  private void readPrimitive(FieldVector arrowVector, ColumnVector hiveVector, TypeInfo typeInfo) {
+    final PrimitiveObjectInspector.PrimitiveCategory primitiveCategory =
+        ((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory();
+
+    final int size = arrowVector.getValueCount();
+    hiveVector.ensureSize(size, false);
+
+    switch (primitiveCategory) {
+      case BOOLEAN:
+      {
+        for (int i = 0; i < size; i++) {
+          if (arrowVector.isNull(i)) {
+            VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i);
+          } else {
+            hiveVector.isNull[i] = false;
+            ((LongColumnVector) hiveVector).vector[i] = ((BitVector) arrowVector).get(i);
+          }
+        }
+      }
+      break;
+      case BYTE:
+      {
+        for (int i = 0; i < size; i++) {
+          if (arrowVector.isNull(i)) {
+            VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i);
+          } else {
+            hiveVector.isNull[i] = false;
+            ((LongColumnVector) hiveVector).vector[i] = ((TinyIntVector) arrowVector).get(i);
+          }
+        }
+      }
+      break;
+      case SHORT:
+      {
+        for (int i = 0; i < size; i++) {
+          if (arrowVector.isNull(i)) {
+            VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i);
+          } else {
+            hiveVector.isNull[i] = false;
+            ((LongColumnVector) hiveVector).vector[i] = ((SmallIntVector) arrowVector).get(i);
+          }
+        }
+      }
+      break;
+      case INT:
+      {
+        for (int i = 0; i < size; i++) {
+          if (arrowVector.isNull(i)) {
+            VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i);
+          } else {
+            hiveVector.isNull[i] = false;
+            ((LongColumnVector) hiveVector).vector[i] = ((IntVector) arrowVector).get(i);
+          }
+        }
+      }
+      break;
+      case LONG:
+      {
+        for (int i = 0; i < size; i++) {
+          if (arrowVector.isNull(i)) {
+            VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i);
+          } else {
+            hiveVector.isNull[i] = false;
+            ((LongColumnVector) hiveVector).vector[i] = ((BigIntVector) arrowVector).get(i);
+          }
+        }
+      }
+      break;
+      case FLOAT:
+      {
+        for (int i = 0; i < size; i++) {
+          if (arrowVector.isNull(i)) {
+            VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i);
+          } else {
+            hiveVector.isNull[i] = false;
+            ((DoubleColumnVector) hiveVector).vector[i] = ((Float4Vector) arrowVector).get(i);
+          }
+        }
+      }
+      break;
+      case DOUBLE:
+      {
+        for (int i = 0; i < size; i++) {
+          if (arrowVector.isNull(i)) {
+            VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i);
+          } else {
+            hiveVector.isNull[i] = false;
+            ((DoubleColumnVector) hiveVector).vector[i] = ((Float8Vector) arrowVector).get(i);
+          }
+        }
+      }
+      break;
+      case STRING:
+      case VARCHAR:
+      case CHAR:
+      {
+        for (int i = 0; i < size; i++) {
+          if (arrowVector.isNull(i)) {
+            VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i);
+          } else {
+            hiveVector.isNull[i] = false;
+            ((BytesColumnVector) hiveVector).setVal(i, ((VarCharVector) arrowVector).get(i));
+          }
+        }
+      }
+      break;
+      case DATE:
+      {
+        for (int i = 0; i < size; i++) {
+          if (arrowVector.isNull(i)) {
+            VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i);
+          } else {
+            hiveVector.isNull[i] = false;
+            ((LongColumnVector) hiveVector).vector[i] = ((DateDayVector) arrowVector).get(i);
+          }
+        }
+      }
+      break;
+      case TIMESTAMP:
+      {
+        for (int i = 0; i < size; i++) {
+          if (arrowVector.isNull(i)) {
+            VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i);
+          } else {
+            hiveVector.isNull[i] = false;
+
+            // Time = second + sub-second
+            final long timeInNanos = ((TimeStampNanoVector) arrowVector).get(i);
+            final TimestampColumnVector timestampColumnVector = (TimestampColumnVector) hiveVector;
+            int subSecondInNanos = (int) (timeInNanos % NS_PER_SECOND);
+            long second = timeInNanos / NS_PER_SECOND;
+
+            // A nanosecond value should not be negative
+            if (subSecondInNanos < 0) {
+
+              // So add one second to the negative nanosecond value to make it positive
+              subSecondInNanos += NS_PER_SECOND;
+
+              // Subtract one second from the second value because we added one second,
+              // then subtract one more second because of the ceiling in the division.
+              second -= 2;
+            }
+            timestampColumnVector.time[i] = second * MS_PER_SECOND;
+            timestampColumnVector.nanos[i] = subSecondInNanos;
+          }
+        }
+      }
+      break;
+      case BINARY:
+      {
+        for (int i = 0; i < size; i++) {
+          if (arrowVector.isNull(i)) {
+            VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i);
+          } else {
+            hiveVector.isNull[i] = false;
+            ((BytesColumnVector) hiveVector).setVal(i, ((VarBinaryVector) arrowVector).get(i));
+          }
+        }
+      }
+      break;
+      case DECIMAL:
+      {
+        for (int i = 0; i < size; i++) {
+          if (arrowVector.isNull(i)) {
+            VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i);
+          } else {
+            hiveVector.isNull[i] = false;
+            ((DecimalColumnVector) hiveVector).set(i,
+                HiveDecimal.create(((DecimalVector) arrowVector).getObject(i)));
+          }
+        }
+      }
+      break;
+      case INTERVAL_YEAR_MONTH:
+      {
+        for (int i = 0; i < size; i++) {
+          if (arrowVector.isNull(i)) {
+            VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i);
+          } else {
+            hiveVector.isNull[i] = false;
+            ((LongColumnVector) hiveVector).vector[i] = ((IntervalYearVector) arrowVector).get(i);
+          }
+        }
+      }
+      break;
+      case INTERVAL_DAY_TIME:
+      {
+        final IntervalDayVector intervalDayVector = (IntervalDayVector) arrowVector;
+        final NullableIntervalDayHolder intervalDayHolder = new NullableIntervalDayHolder();
+        final HiveIntervalDayTime intervalDayTime = new HiveIntervalDayTime();
+        for (int i = 0; i < size; i++) {
+          if (arrowVector.isNull(i)) {
+            VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i);
+          } else {
+            hiveVector.isNull[i] = false;
+            intervalDayVector.get(i, intervalDayHolder);
+            final long seconds = intervalDayHolder.days * SECOND_PER_DAY +
+                intervalDayHolder.milliseconds / MS_PER_SECOND;
+            final int nanos = (intervalDayHolder.milliseconds % 1_000) * NS_PER_MS;
+            intervalDayTime.set(seconds, nanos);
+            ((IntervalDayTimeColumnVector) hiveVector).set(i, intervalDayTime);
+          }
+        }
+      }
+      break;
+      case VOID:
+      case TIMESTAMPLOCALTZ:
+      case UNKNOWN:
+      default:
+        break;
+    }
+  }
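The borrow logic in the TIMESTAMP branch splits a signed nanosecond epoch value into a second part and a non-negative sub-second part. Math.floorDiv and Math.floorMod implement exactly that floor semantics directly, which makes them a handy standalone cross-check of the manual arithmetic:

```java
// Standalone check: floorMod always yields a non-negative remainder, and the
// invariant second * NS_PER_SECOND + subSecondInNanos == timeInNanos must hold.
final class NanoSplitDemo {
  static final long NS_PER_SECOND = 1_000_000_000L;

  public static void main(String[] args) {
    for (long timeInNanos : new long[] {1_500_000_000L, 0L, -1L, -1_500_000_000L}) {
      final long second = Math.floorDiv(timeInNanos, NS_PER_SECOND);
      final long subSecondInNanos = Math.floorMod(timeInNanos, NS_PER_SECOND);
      System.out.println(timeInNanos + " ns -> " + second + " s + " + subSecondInNanos + " ns");
    }
  }
}
```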
+
+  private void readList(FieldVector arrowVector, ListColumnVector hiveVector, ListTypeInfo typeInfo) {
+    final int size = arrowVector.getValueCount();
+    final ArrowBuf offsets = arrowVector.getOffsetBuffer();
+    final int OFFSET_WIDTH = 4;
+
+    read(arrowVector.getChildrenFromFields().get(0),
+        hiveVector.child,
+        typeInfo.getListElementTypeInfo());
+
+    for (int i = 0; i < size; i++) {
+      if (arrowVector.isNull(i)) {
+        VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i);
+      } else {
+        hiveVector.isNull[i] = false;
+        final int offset = offsets.getInt(i * OFFSET_WIDTH);
+        hiveVector.offsets[i] = offset;
+        hiveVector.lengths[i] = offsets.getInt((i + 1) * OFFSET_WIDTH) - offset;
+      }
+    }
+  }
+
+  private void readMap(FieldVector arrowVector, MapColumnVector hiveVector, MapTypeInfo typeInfo) {
+    final int size = arrowVector.getValueCount();
+    final ListTypeInfo mapStructListTypeInfo = toStructListTypeInfo(typeInfo);
+    final ListColumnVector mapStructListVector = toStructListVector(hiveVector);
+    final StructColumnVector mapStructVector = (StructColumnVector) mapStructListVector.child;
+
+    read(arrowVector, mapStructListVector, mapStructListTypeInfo);
+
+    hiveVector.isRepeating = mapStructListVector.isRepeating;
+    hiveVector.childCount = mapStructListVector.childCount;
+    hiveVector.noNulls = mapStructListVector.noNulls;
+    hiveVector.keys = mapStructVector.fields[0];
+    hiveVector.values = mapStructVector.fields[1];
+    System.arraycopy(mapStructListVector.offsets, 0, hiveVector.offsets, 0, size);
+    System.arraycopy(mapStructListVector.lengths, 0, hiveVector.lengths, 0, size);
+    System.arraycopy(mapStructListVector.isNull, 0, hiveVector.isNull, 0, size);
+  }
+
+  private void readStruct(FieldVector arrowVector, StructColumnVector hiveVector, StructTypeInfo typeInfo) {
+    final int size = arrowVector.getValueCount();
+    final List<TypeInfo> fieldTypeInfos = typeInfo.getAllStructFieldTypeInfos();
+    final int fieldSize = arrowVector.getChildrenFromFields().size();
+    for (int i = 0; i < fieldSize; i++) {
+      read(arrowVector.getChildrenFromFields().get(i), hiveVector.fields[i], fieldTypeInfos.get(i));
+    }
+
+    for (int i = 0; i < size; i++) {
+      if (arrowVector.isNull(i)) {
+        VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i);
+      } else {
+        hiveVector.isNull[i] = false;
+      }
+    }
+  }
+
+  private void readUnion(FieldVector arrowVector, UnionColumnVector hiveVector, UnionTypeInfo typeInfo) {
+  }
+}
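readList above decodes list boundaries from Arrow's 32-bit offset buffer (hence OFFSET_WIDTH = 4 bytes): entry i starts at offsets[i] and ends at offsets[i + 1]. The same walk over a plain int[] standing in for the buffer:

```java
// Illustration of Arrow's offset-buffer encoding for variable-length lists.
final class OffsetWalkDemo {
  public static void main(String[] args) {
    // Offsets for the lists [10, 20], [], [30]: length of list i = offsets[i+1] - offsets[i]
    final int[] offsets = {0, 2, 2, 3};
    for (int i = 0; i < offsets.length - 1; i++) {
      final int start = offsets[i];
      final int length = offsets[i + 1] - start;
      System.out.println("list " + i + ": start=" + start + " length=" + length);
    }
  }
}
```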
+ */ +package org.apache.hadoop.hive.ql.io.arrow; + +import io.netty.buffer.ArrowBuf; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.BitVectorHelper; +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.IntervalDayVector; +import org.apache.arrow.vector.IntervalYearVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TimeStampNanoVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.MapVector; +import org.apache.arrow.vector.complex.NullableMapVector; +import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.IntervalDayTimeColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.ListColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.MapColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.StructColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.TimestampColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.UnionColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.VectorAssignRow; +import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.SerDeException; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.UnionTypeInfo; + +import java.util.ArrayList; +import java.util.List; + +import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.HIVE_ARROW_BATCH_SIZE; +import static org.apache.hadoop.hive.ql.exec.vector.VectorizedBatchUtil.createColumnVector; +import static org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe.MS_PER_SECOND; +import static org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe.NS_PER_MS; +import static org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe.SECOND_PER_DAY; +import static org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe.toStructListTypeInfo; 
+import static org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe.toStructListVector;
+import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption.WRITABLE;
+import static org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils.getTypeInfoFromObjectInspector;
+
+class Serializer {
+ private final int MAX_BUFFERED_ROWS;
+
+ // Schema
+ private final StructTypeInfo structTypeInfo;
+ private final int fieldSize;
+
+ // Hive columns
+ private final VectorizedRowBatch vectorizedRowBatch;
+ private final VectorAssignRow vectorAssignRow;
+ private int batchSize;
+
+ private final NullableMapVector rootVector;
+
+ Serializer(ArrowColumnarBatchSerDe serDe) throws SerDeException {
+ MAX_BUFFERED_ROWS = HiveConf.getIntVar(serDe.conf, HIVE_ARROW_BATCH_SIZE);
+ ArrowColumnarBatchSerDe.LOG.info("ArrowColumnarBatchSerDe max number of buffered rows: " + MAX_BUFFERED_ROWS);
+
+ // Schema
+ structTypeInfo = (StructTypeInfo) getTypeInfoFromObjectInspector(serDe.rowObjectInspector);
+ List<TypeInfo> fieldTypeInfos = structTypeInfo.getAllStructFieldTypeInfos();
+ fieldSize = fieldTypeInfos.size();
+
+ // Initialize the Arrow root vector
+ rootVector = NullableMapVector.empty(null, serDe.rootAllocator);
+
+ // Initialize the Hive row batch
+ vectorizedRowBatch = new VectorizedRowBatch(fieldSize);
+ for (int fieldIndex = 0; fieldIndex < fieldSize; fieldIndex++) {
+ final ColumnVector columnVector = createColumnVector(fieldTypeInfos.get(fieldIndex));
+ vectorizedRowBatch.cols[fieldIndex] = columnVector;
+ columnVector.init();
+ }
+ vectorizedRowBatch.ensureSize(MAX_BUFFERED_ROWS);
+ vectorAssignRow = new VectorAssignRow();
+ try {
+ vectorAssignRow.init(serDe.rowObjectInspector);
+ } catch (HiveException e) {
+ throw new SerDeException(e);
+ }
+ }
+
+ private ArrowWrapperWritable serializeBatch() {
+ rootVector.setValueCount(0);
+
+ for (int fieldIndex = 0; fieldIndex < vectorizedRowBatch.projectionSize; fieldIndex++) {
+ final int projectedColumn = vectorizedRowBatch.projectedColumns[fieldIndex];
+ final ColumnVector hiveVector = vectorizedRowBatch.cols[projectedColumn];
+ final TypeInfo fieldTypeInfo = structTypeInfo.getAllStructFieldTypeInfos().get(fieldIndex);
+ final String fieldName = structTypeInfo.getAllStructFieldNames().get(fieldIndex);
+ final FieldType fieldType = toFieldType(fieldTypeInfo);
+ final FieldVector arrowVector = rootVector.addOrGet(fieldName, fieldType, FieldVector.class);
+ arrowVector.setInitialCapacity(batchSize);
+ arrowVector.allocateNew();
+ write(arrowVector, hiveVector, fieldTypeInfo, batchSize);
+ }
+ vectorizedRowBatch.reset();
+ rootVector.setValueCount(batchSize);
+
+ batchSize = 0;
+ VectorSchemaRoot vectorSchemaRoot = new VectorSchemaRoot(rootVector);
+ return new ArrowWrapperWritable(vectorSchemaRoot);
+ }
+
+ private FieldType toFieldType(TypeInfo typeInfo) {
+ return new FieldType(true, toArrowType(typeInfo), null);
+ }
+
+ private ArrowType toArrowType(TypeInfo typeInfo) {
+ switch (typeInfo.getCategory()) {
+ case PRIMITIVE:
+ switch (((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory()) {
+ case BOOLEAN:
+ return Types.MinorType.BIT.getType();
+ case BYTE:
+ return Types.MinorType.TINYINT.getType();
+ case SHORT:
+ return Types.MinorType.SMALLINT.getType();
+ case INT:
+ return Types.MinorType.INT.getType();
+ case LONG:
+ return Types.MinorType.BIGINT.getType();
+ case FLOAT:
+ return Types.MinorType.FLOAT4.getType();
+ case DOUBLE:
+ return Types.MinorType.FLOAT8.getType();
+ case STRING:
+ case VARCHAR:
+ case CHAR:
+ return Types.MinorType.VARCHAR.getType();
+ case DATE:
+ return Types.MinorType.DATEDAY.getType();
+ case TIMESTAMP:
+ return Types.MinorType.TIMESTAMPNANO.getType();
+ case BINARY:
+ return Types.MinorType.VARBINARY.getType();
+ case DECIMAL:
+ final DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) typeInfo;
+ return new ArrowType.Decimal(decimalTypeInfo.precision(), decimalTypeInfo.scale());
+ case INTERVAL_YEAR_MONTH:
+ return Types.MinorType.INTERVALYEAR.getType();
+ case INTERVAL_DAY_TIME:
+ return Types.MinorType.INTERVALDAY.getType();
+ case VOID:
+ case TIMESTAMPLOCALTZ:
+ case UNKNOWN:
+ default:
+ throw new IllegalArgumentException("Unsupported primitive type: " + typeInfo.getTypeName());
+ }
+ case LIST:
+ return ArrowType.List.INSTANCE;
+ case STRUCT:
+ return ArrowType.Struct.INSTANCE;
+ case MAP: // maps are represented as lists of key/value structs
+ return ArrowType.List.INSTANCE;
+ case UNION:
+ default:
+ throw new IllegalArgumentException("Unsupported type: " + typeInfo.getTypeName());
+ }
+ }
+
+ private void write(FieldVector arrowVector, ColumnVector hiveVector, TypeInfo typeInfo, int size) {
+ switch (typeInfo.getCategory()) {
+ case PRIMITIVE:
+ writePrimitive(arrowVector, hiveVector, typeInfo, size);
+ break;
+ case LIST:
+ writeList((ListVector) arrowVector, (ListColumnVector) hiveVector, (ListTypeInfo) typeInfo, size);
+ break;
+ case STRUCT:
+ writeStruct((MapVector) arrowVector, (StructColumnVector) hiveVector, (StructTypeInfo) typeInfo, size);
+ break;
+ case UNION:
+ writeUnion(arrowVector, hiveVector, typeInfo, size);
+ break;
+ case MAP:
+ writeMap((ListVector) arrowVector, (MapColumnVector) hiveVector, (MapTypeInfo) typeInfo, size);
+ break;
+ default:
+ throw new IllegalArgumentException("Unsupported type: " + typeInfo.getTypeName());
+ }
+ }
+
+ private void writeMap(ListVector arrowVector, MapColumnVector hiveVector, MapTypeInfo typeInfo,
+ int size) {
+ final ListTypeInfo structListTypeInfo = toStructListTypeInfo(typeInfo);
+ final ListColumnVector structListVector = toStructListVector(hiveVector);
+
+ write(arrowVector, structListVector, structListTypeInfo, size);
+
+ final ArrowBuf validityBuffer = arrowVector.getValidityBuffer();
+ for (int rowIndex = 0; rowIndex < size; rowIndex++) {
+ if (hiveVector.isNull[rowIndex]) {
+ BitVectorHelper.setValidityBit(validityBuffer, rowIndex, 0);
+ } else {
+ BitVectorHelper.setValidityBitToOne(validityBuffer, rowIndex);
+ }
+ }
+ }
+
+ private void writeUnion(FieldVector arrowVector, ColumnVector hiveVector, TypeInfo typeInfo,
+ int size) {
+ final UnionTypeInfo unionTypeInfo = (UnionTypeInfo) typeInfo;
+ final List<TypeInfo> objectTypeInfos = unionTypeInfo.getAllUnionObjectTypeInfos();
+ final UnionColumnVector hiveUnionVector = (UnionColumnVector) hiveVector;
+ final ColumnVector[] hiveObjectVectors = hiveUnionVector.fields;
+
+ // Note: only the tag of the first row is used; every row in the batch is assumed to carry the same union tag.
+ final int tag = hiveUnionVector.tags[0];
+ final ColumnVector hiveObjectVector = hiveObjectVectors[tag];
+ final TypeInfo objectTypeInfo = objectTypeInfos.get(tag);
+
+ write(arrowVector, hiveObjectVector, objectTypeInfo, size);
+ }
+
+ private void writeStruct(MapVector arrowVector, StructColumnVector hiveVector,
+ StructTypeInfo typeInfo, int size) {
+ final List<String> fieldNames = typeInfo.getAllStructFieldNames();
+ final List<TypeInfo> fieldTypeInfos = typeInfo.getAllStructFieldTypeInfos();
+ final ColumnVector[] hiveFieldVectors = hiveVector.fields;
+ final int fieldSize = fieldTypeInfos.size();
+
+ for (int fieldIndex = 0; fieldIndex < fieldSize; fieldIndex++) {
+ final TypeInfo fieldTypeInfo = fieldTypeInfos.get(fieldIndex);
+ final ColumnVector hiveFieldVector = hiveFieldVectors[fieldIndex];
+ final String fieldName = fieldNames.get(fieldIndex);
+ final FieldVector arrowFieldVector =
+ arrowVector.addOrGet(fieldName,
+ toFieldType(fieldTypeInfo), FieldVector.class);
+ arrowFieldVector.setInitialCapacity(size);
+ arrowFieldVector.allocateNew();
+ write(arrowFieldVector, hiveFieldVector, fieldTypeInfo, size);
+ }
+
+ final ArrowBuf validityBuffer = arrowVector.getValidityBuffer();
+ for (int rowIndex = 0; rowIndex < size; rowIndex++) {
+ if (hiveVector.isNull[rowIndex]) {
+ BitVectorHelper.setValidityBit(validityBuffer, rowIndex, 0);
+ } else {
+ BitVectorHelper.setValidityBitToOne(validityBuffer, rowIndex);
+ }
+ }
+ }
+
+ private void writeList(ListVector arrowVector, ListColumnVector hiveVector, ListTypeInfo typeInfo,
+ int size) {
+ final int OFFSET_WIDTH = 4;
+ final TypeInfo elementTypeInfo = typeInfo.getListElementTypeInfo();
+ final ColumnVector hiveElementVector = hiveVector.child;
+ final FieldVector arrowElementVector =
+ (FieldVector) arrowVector.addOrGetVector(toFieldType(elementTypeInfo)).getVector();
+ arrowElementVector.setInitialCapacity(hiveVector.childCount);
+ arrowElementVector.allocateNew();
+
+ write(arrowElementVector, hiveElementVector, elementTypeInfo, hiveVector.childCount);
+
+ final ArrowBuf offsetBuffer = arrowVector.getOffsetBuffer();
+ int nextOffset = 0;
+
+ for (int rowIndex = 0; rowIndex < size; rowIndex++) {
+ // Every row gets an offset; only non-null rows advance it and get a validity bit.
+ offsetBuffer.setInt(rowIndex * OFFSET_WIDTH, nextOffset);
+ if (!hiveVector.isNull[rowIndex]) {
+ nextOffset += (int) hiveVector.lengths[rowIndex];
+ arrowVector.setNotNull(rowIndex);
+ }
+ }
+ offsetBuffer.setInt(size * OFFSET_WIDTH, nextOffset);
+ }
+
+ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, TypeInfo typeInfo,
+ int size) {
+ final PrimitiveObjectInspector.PrimitiveCategory primitiveCategory =
+ ((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory();
+ switch (primitiveCategory) {
+ case BOOLEAN:
+ {
+ final BitVector bitVector = (BitVector) arrowVector;
+ for (int i = 0; i < size; i++) {
+ if (hiveVector.isNull[i]) {
+ bitVector.setNull(i);
+ } else {
+ bitVector.set(i, (int) ((LongColumnVector) hiveVector).vector[i]);
+ }
+ }
+ }
+ break;
+ case BYTE:
+ {
+ final TinyIntVector tinyIntVector = (TinyIntVector) arrowVector;
+ for (int i = 0; i < size; i++) {
+ if (hiveVector.isNull[i]) {
+ tinyIntVector.setNull(i);
+ } else {
+ tinyIntVector.set(i, (byte) ((LongColumnVector) hiveVector).vector[i]);
+ }
+ }
+ }
+ break;
+ case SHORT:
+ {
+ final SmallIntVector smallIntVector = (SmallIntVector) arrowVector;
+ for (int i = 0; i < size; i++) {
+ if (hiveVector.isNull[i]) {
+ smallIntVector.setNull(i);
+ } else {
+ smallIntVector.set(i, (short) ((LongColumnVector) hiveVector).vector[i]);
+ }
+ }
+ }
+ break;
+ case INT:
+ {
+ final IntVector intVector = (IntVector) arrowVector;
+ for (int i = 0; i < size; i++) {
+ if (hiveVector.isNull[i]) {
+ intVector.setNull(i);
+ } else {
+ intVector.set(i, (int) ((LongColumnVector) hiveVector).vector[i]);
+ }
+ }
+ }
+ break;
+ case LONG:
+ {
+ final BigIntVector bigIntVector = (BigIntVector) arrowVector;
+ for (int i = 0; i < size; i++) {
+ if (hiveVector.isNull[i]) {
+ bigIntVector.setNull(i);
+ } else {
+ bigIntVector.set(i, ((LongColumnVector) hiveVector).vector[i]);
+ }
+ }
+ }
+ break;
+ case FLOAT:
+ {
+ final Float4Vector float4Vector = (Float4Vector) arrowVector;
+ for (int i = 0; i < size; i++) {
+ if (hiveVector.isNull[i]) {
+ float4Vector.setNull(i);
+ } else {
+ float4Vector.set(i, (float) ((DoubleColumnVector) hiveVector).vector[i]);
+ }
+ }
+ }
+ break;
+ case
DOUBLE: + { + final Float8Vector float8Vector = (Float8Vector) arrowVector; + for (int i = 0; i < size; i++) { + if (hiveVector.isNull[i]) { + float8Vector.setNull(i); + } else { + float8Vector.set(i, ((DoubleColumnVector) hiveVector).vector[i]); + } + } + } + break; + case STRING: + case VARCHAR: + case CHAR: + { + final VarCharVector varCharVector = (VarCharVector) arrowVector; + final BytesColumnVector bytesVector = (BytesColumnVector) hiveVector; + for (int i = 0; i < size; i++) { + if (hiveVector.isNull[i]) { + varCharVector.setNull(i); + } else { + varCharVector.setSafe(i, bytesVector.vector[i], bytesVector.start[i], bytesVector.length[i]); + } + } + } + break; + case DATE: + { + final DateDayVector dateDayVector = (DateDayVector) arrowVector; + for (int i = 0; i < size; i++) { + if (hiveVector.isNull[i]) { + dateDayVector.setNull(i); + } else { + dateDayVector.set(i, (int) ((LongColumnVector) hiveVector).vector[i]); + } + } + } + break; + case TIMESTAMP: + { + final TimeStampNanoVector timeStampNanoVector = (TimeStampNanoVector) arrowVector; + final TimestampColumnVector timestampColumnVector = (TimestampColumnVector) hiveVector; + for (int i = 0; i < size; i++) { + if (hiveVector.isNull[i]) { + timeStampNanoVector.setNull(i); + } else { + // Time = second + sub-second + final long secondInMillis = timestampColumnVector.getTime(i); + final long secondInNanos = (secondInMillis - secondInMillis % 1000) * NS_PER_MS; // second + final long subSecondInNanos = timestampColumnVector.getNanos(i); // sub-second + + if ((secondInMillis > 0 && secondInNanos < 0) || (secondInMillis < 0 && secondInNanos > 0)) { + // If the timestamp cannot be represented in long nanosecond, set it as a null value + timeStampNanoVector.setNull(i); + } else { + timeStampNanoVector.set(i, secondInNanos + subSecondInNanos); + } + } + } + } + break; + case BINARY: + { + final VarBinaryVector varBinaryVector = (VarBinaryVector) arrowVector; + final BytesColumnVector bytesVector = (BytesColumnVector) hiveVector; + for (int i = 0; i < size; i++) { + if (hiveVector.isNull[i]) { + varBinaryVector.setNull(i); + } else { + varBinaryVector.setSafe(i, bytesVector.vector[i], bytesVector.start[i], bytesVector.length[i]); + } + } + } + break; + case DECIMAL: + { + final DecimalVector decimalVector = (DecimalVector) arrowVector; + final int scale = decimalVector.getScale(); + for (int i = 0; i < size; i++) { + if (hiveVector.isNull[i]) { + decimalVector.setNull(i); + } else { + decimalVector.set(i, + ((DecimalColumnVector) hiveVector).vector[i].getHiveDecimal().bigDecimalValue().setScale(scale)); + } + } + } + break; + case INTERVAL_YEAR_MONTH: + { + final IntervalYearVector intervalYearVector = (IntervalYearVector) arrowVector; + for (int i = 0; i < size; i++) { + if (hiveVector.isNull[i]) { + intervalYearVector.setNull(i); + } else { + intervalYearVector.set(i, (int) ((LongColumnVector) hiveVector).vector[i]); + } + } + } + break; + case INTERVAL_DAY_TIME: + { + final IntervalDayVector intervalDayVector = (IntervalDayVector) arrowVector; + final IntervalDayTimeColumnVector intervalDayTimeColumnVector = + (IntervalDayTimeColumnVector) hiveVector; + for (int i = 0; i < size; i++) { + if (hiveVector.isNull[i]) { + intervalDayVector.setNull(i); + } else { + final long totalSeconds = intervalDayTimeColumnVector.getTotalSeconds(i); + final long days = totalSeconds / SECOND_PER_DAY; + final long millis = + (totalSeconds - days * SECOND_PER_DAY) * MS_PER_SECOND + + intervalDayTimeColumnVector.getNanos(i) / NS_PER_MS; + 
intervalDayVector.set(i, (int) days, (int) millis);
+ }
+ }
+ }
+ break;
+ case VOID:
+ case UNKNOWN:
+ case TIMESTAMPLOCALTZ:
+ default:
+ throw new IllegalArgumentException("Unsupported primitive type: " + typeInfo.getTypeName());
+ }
+ }
+
+ ArrowWrapperWritable serialize(Object obj, ObjectInspector objInspector) {
+ // A null row means there are no more rows (closeOp()); flush the buffered batch.
+ if (obj == null) {
+ return serializeBatch();
+ }
+ List<Object> standardObjects = new ArrayList<Object>();
+ ObjectInspectorUtils.copyToStandardObject(standardObjects, obj,
+ ((StructObjectInspector) objInspector), WRITABLE);
+
+ vectorAssignRow.assignRow(vectorizedRowBatch, batchSize, standardObjects, fieldSize);
+ batchSize++;
+ // A full buffer also triggers a flush.
+ if (batchSize == MAX_BUFFERED_ROWS) {
+ return serializeBatch();
+ }
+ return null;
+ }
+}
diff --git ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowColumnarBatchSerDe.java ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowColumnarBatchSerDe.java
index bcb7a88258..74f6624597 100644
--- ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowColumnarBatchSerDe.java
+++ ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowColumnarBatchSerDe.java
@@ -42,7 +42,6 @@
 import org.apache.hadoop.hive.serde2.io.TimestampWritable;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.StandardUnionObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.StructField;
 import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
@@ -54,7 +53,6 @@
 import org.apache.hadoop.io.IntWritable;
 import org.apache.hadoop.io.LongWritable;
 import org.apache.hadoop.io.Text;
-import org.apache.hadoop.io.Writable;
 import org.junit.Before;
 import org.junit.Test;
@@ -66,10 +64,11 @@
 import java.util.Properties;
 import java.util.Random;
 import java.util.Set;
+import java.util.concurrent.TimeUnit;
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.assertNull;
 public class TestArrowColumnarBatchSerDe {
 private Configuration conf;
@@ -105,14 +104,39 @@
 {null, null, null},
 };
- private final static long NOW = System.currentTimeMillis();
+ private final static long TIME_IN_MS = TimeUnit.DAYS.toMillis(365 + 31 + 3);
+ private final static long NEGATIVE_TIME_IN_MS = TimeUnit.DAYS.toMillis(-9 * 365 + 31 + 3);
+ private final static Timestamp TIMESTAMP;
+ private final static Timestamp NEGATIVE_TIMESTAMP_WITHOUT_NANOS;
+ private final static Timestamp NEGATIVE_TIMESTAMP_WITH_NANOS;
+
+ static {
+ TIMESTAMP = new Timestamp(TIME_IN_MS);
+ TIMESTAMP.setNanos(123456789);
+ NEGATIVE_TIMESTAMP_WITHOUT_NANOS = new Timestamp(NEGATIVE_TIME_IN_MS);
+ NEGATIVE_TIMESTAMP_WITH_NANOS = new Timestamp(NEGATIVE_TIME_IN_MS);
+ NEGATIVE_TIMESTAMP_WITH_NANOS.setNanos(123456789);
+ }
+
 private final static Object[][] DTI_ROWS = {
 {
- new DateWritable(DateWritable.millisToDays(NOW)),
- new TimestampWritable(new Timestamp(NOW)),
+ new DateWritable(DateWritable.millisToDays(TIME_IN_MS)),
+ new TimestampWritable(TIMESTAMP),
 new HiveIntervalYearMonthWritable(new HiveIntervalYearMonth(1, 2)),
 new HiveIntervalDayTimeWritable(new HiveIntervalDayTime(1, 2, 3, 4, 5_000_000))
 },
+ {
+ new DateWritable(DateWritable.millisToDays(NEGATIVE_TIME_IN_MS)),
+ new TimestampWritable(NEGATIVE_TIMESTAMP_WITHOUT_NANOS),
+ null,
+ null
+ },
+ {
+ null,
+ new TimestampWritable(NEGATIVE_TIMESTAMP_WITH_NANOS),
+ null,
+ null
+ },
 {null, null, null, null},
 };
@@ -184,7 +208,7 @@
 private static HiveDecimalWritable decimalW(HiveDecimal value) {
 }
 private void initAndSerializeAndDeserialize(String[][] schema, Object[][] rows) throws SerDeException {
- AbstractSerDe serDe = new ArrowColumnarBatchSerDe();
+ ArrowColumnarBatchSerDe serDe = new ArrowColumnarBatchSerDe();
 StructObjectInspector rowOI = initSerDe(serDe, schema);
 serializeAndDeserialize(serDe, rows, rowOI);
 }
@@ -214,9 +238,9 @@ private StructObjectInspector initSerDe(AbstractSerDe serDe, String[][] schema)
 TypeInfoFactory.getStructTypeInfo(fieldNameList, typeInfoList));
 }
- private void serializeAndDeserialize(AbstractSerDe serDe, Object[][] rows,
- StructObjectInspector rowOI) throws SerDeException {
- Writable serialized = null;
+ private void serializeAndDeserialize(ArrowColumnarBatchSerDe serDe, Object[][] rows,
+ StructObjectInspector rowOI) {
+ ArrowWrapperWritable serialized = null;
 for (Object[] row : rows) {
 serialized = serDe.serialize(row, rowOI);
 }
@@ -224,6 +248,7 @@ private void serializeAndDeserialize(AbstractSerDe serDe, Object[][] rows,
 if (serialized == null) {
 serialized = serDe.serialize(null, rowOI);
 }
+ serialized.getVectorSchemaRoot().contentToTSVString(); // result unused; the call walks every vector in the batch
 final Object[][] deserializedRows = (Object[][]) serDe.deserialize(serialized);
 for (int rowIndex = 0; rowIndex < Math.min(deserializedRows.length, rows.length); rowIndex++) {
@@ -254,21 +279,28 @@ private void serializeAndDeserialize(AbstractSerDe serDe, Object[][] rows,
 case STRUCT:
 final Object[] rowStruct = (Object[]) row[fieldIndex];
 final List deserializedRowStruct = (List) deserializedRow[fieldIndex];
- assertArrayEquals(rowStruct, deserializedRowStruct.toArray());
+ if (rowStruct == null) {
+ assertNull(deserializedRowStruct);
+ } else {
+ assertArrayEquals(rowStruct, deserializedRowStruct.toArray());
+ }
 break;
 case LIST:
 case UNION:
 assertEquals(row[fieldIndex], deserializedRow[fieldIndex]);
 break;
 case MAP:
- Map rowMap = (Map) row[fieldIndex];
- Map deserializedRowMap = (Map) deserializedRow[fieldIndex];
- Set rowMapKeySet = rowMap.keySet();
- Set deserializedRowMapKeySet = deserializedRowMap.keySet();
- assertTrue(rowMapKeySet.containsAll(deserializedRowMapKeySet));
- assertTrue(deserializedRowMapKeySet.containsAll(rowMapKeySet));
- for (Object key : rowMapKeySet) {
- assertEquals(rowMap.get(key), deserializedRowMap.get(key));
+ final Map rowMap = (Map) row[fieldIndex];
+ final Map deserializedRowMap = (Map) deserializedRow[fieldIndex];
+ if (rowMap == null) {
+ assertNull(deserializedRowMap);
+ } else {
+ final Set rowMapKeySet = rowMap.keySet();
+ final Set deserializedRowMapKeySet = deserializedRowMap.keySet();
+ assertEquals(rowMapKeySet, deserializedRowMapKeySet);
+ for (Object key : rowMapKeySet) {
+ assertEquals(rowMap.get(key), deserializedRowMap.get(key));
+ }
 }
 break;
 }
@@ -341,14 +373,18 @@ public void testComprehensive() throws SerDeException {
 newArrayList(text("hello")), input -> text(input.toString().toUpperCase())), intW(0))), // c16:array,n:int>>
- new TimestampWritable(new Timestamp(NOW)), // c17:timestamp
+ new TimestampWritable(TIMESTAMP), // c17:timestamp
 decimalW(HiveDecimal.create(0, 0)), // c18:decimal(16,7)
 new BytesWritable("Hello".getBytes()), // c19:binary
 new DateWritable(123), // c20:date
 varcharW("x", 20), // c21:varchar(20)
 charW("y", 15), // c22:char(15)
 new BytesWritable("world!".getBytes()), // c23:binary
- },
+ }, {
+
null, null, null, null, null, null, null, null, null, null, // c1-c10 + null, null, null, null, null, null, null, null, null, null, // c11-c20 + null, null, null, // c21-c23 + } }; initAndSerializeAndDeserialize(schema, comprehensiveRows); @@ -378,7 +414,7 @@ public void testPrimitiveBigInt10000() throws SerDeException { final int batchSize = 1000; final Object[][] integerRows = new Object[batchSize][]; - final AbstractSerDe serDe = new ArrowColumnarBatchSerDe(); + final ArrowColumnarBatchSerDe serDe = new ArrowColumnarBatchSerDe(); StructObjectInspector rowOI = initSerDe(serDe, schema); for (int j = 0; j < 10; j++) { @@ -397,7 +433,7 @@ public void testPrimitiveBigIntRandom() { {"bigint1", "bigint"} }; - final AbstractSerDe serDe = new ArrowColumnarBatchSerDe(); + final ArrowColumnarBatchSerDe serDe = new ArrowColumnarBatchSerDe(); StructObjectInspector rowOI = initSerDe(serDe, schema); final Random random = new Random(); @@ -572,106 +608,6 @@ public void testListBinary() throws SerDeException { initAndSerializeAndDeserialize(schema, toList(BINARY_ROWS)); } - private StandardUnionObjectInspector.StandardUnion union(int tag, Object object) { - return new StandardUnionObjectInspector.StandardUnion((byte) tag, object); - } - - public void testUnionInteger() throws SerDeException { - String[][] schema = { - {"int_union", "uniontype"}, - }; - - StandardUnionObjectInspector.StandardUnion[][] integerUnions = { - {union(0, byteW(0))}, - {union(1, shortW(1))}, - {union(2, intW(2))}, - {union(3, longW(3))}, - }; - - initAndSerializeAndDeserialize(schema, integerUnions); - } - - public void testUnionFloat() throws SerDeException { - String[][] schema = { - {"float_union", "uniontype"}, - }; - - StandardUnionObjectInspector.StandardUnion[][] floatUnions = { - {union(0, floatW(0f))}, - {union(1, doubleW(1d))}, - }; - - initAndSerializeAndDeserialize(schema, floatUnions); - } - - public void testUnionString() throws SerDeException { - String[][] schema = { - {"string_union", "uniontype"}, - }; - - StandardUnionObjectInspector.StandardUnion[][] stringUnions = { - {union(0, text("Hello"))}, - {union(1, intW(1))}, - }; - - initAndSerializeAndDeserialize(schema, stringUnions); - } - - public void testUnionChar() throws SerDeException { - String[][] schema = { - {"char_union", "uniontype"}, - }; - - StandardUnionObjectInspector.StandardUnion[][] charUnions = { - {union(0, charW("Hello", 10))}, - {union(1, intW(1))}, - }; - - initAndSerializeAndDeserialize(schema, charUnions); - } - - public void testUnionVarchar() throws SerDeException { - String[][] schema = { - {"varchar_union", "uniontype"}, - }; - - StandardUnionObjectInspector.StandardUnion[][] varcharUnions = { - {union(0, varcharW("Hello", 10))}, - {union(1, intW(1))}, - }; - - initAndSerializeAndDeserialize(schema, varcharUnions); - } - - public void testUnionDTI() throws SerDeException { - String[][] schema = { - {"date_union", "uniontype"}, - }; - long NOW = System.currentTimeMillis(); - - StandardUnionObjectInspector.StandardUnion[][] dtiUnions = { - {union(0, new DateWritable(DateWritable.millisToDays(NOW)))}, - {union(1, new TimestampWritable(new Timestamp(NOW)))}, - {union(2, new HiveIntervalYearMonthWritable(new HiveIntervalYearMonth(1, 2)))}, - {union(3, new HiveIntervalDayTimeWritable(new HiveIntervalDayTime(1, 2, 3, 4, 5_000_000)))}, - }; - - initAndSerializeAndDeserialize(schema, dtiUnions); - } - - public void testUnionBooleanBinary() throws SerDeException { - String[][] schema = { - {"boolean_union", "uniontype"}, - }; - - 
StandardUnionObjectInspector.StandardUnion[][] booleanBinaryUnions = {
- {union(0, new BooleanWritable(true))},
- {union(1, new BytesWritable("Hello".getBytes()))},
- };
-
- initAndSerializeAndDeserialize(schema, booleanBinaryUnions);
- }
-
 private Object[][][] toStruct(Object[][] rows) {
 Object[][][] struct = new Object[rows.length][][];
 for (int rowIndex = 0; rowIndex < rows.length; rowIndex++) {
@@ -718,6 +654,15 @@ public void testStructDTI() throws SerDeException {
 initAndSerializeAndDeserialize(schema, toStruct(DTI_ROWS));
 }
+ @Test
+ public void testStructDecimal() throws SerDeException {
+ String[][] schema = {
+ {"decimal_struct", "struct"},
+ };
+
+ initAndSerializeAndDeserialize(schema, toStruct(DECIMAL_ROWS));
+ }
+
 @Test
 public void testStructBoolean() throws SerDeException {
 String[][] schema = {
@@ -812,4 +757,23 @@ public void testMapBinary() throws SerDeException {
 initAndSerializeAndDeserialize(schema, toMap(BINARY_ROWS));
 }
+
+ @Test
+ public void testMapDecimal() throws SerDeException {
+ String[][] schema = {
+ {"decimal_map", "map"},
+ };
+
+ initAndSerializeAndDeserialize(schema, toMap(DECIMAL_ROWS));
+ }
+
+ @Test
+ public void testListDecimal() throws SerDeException {
+ String[][] schema = {
+ {"decimal_list", "array"},
+ };
+
+ initAndSerializeAndDeserialize(schema, toList(DECIMAL_ROWS));
+ }
+
 }
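
Reviewer note (not part of the patch): the round trip below mirrors what serializeAndDeserialize() in the test above does, and may help when trying the new SerDe by hand. It is a minimal sketch: the two-column schema, the property values, and the row variable are made up for illustration, and SerDeException handling is omitted.

    // Minimal usage sketch of the new Arrow path, under the assumptions above.
    Properties props = new Properties();
    props.setProperty(serdeConstants.LIST_COLUMNS, "a,b");            // hypothetical column names
    props.setProperty(serdeConstants.LIST_COLUMN_TYPES, "int,string"); // hypothetical column types

    ArrowColumnarBatchSerDe serDe = new ArrowColumnarBatchSerDe();
    serDe.initialize(new Configuration(), props);
    StructObjectInspector rowOI = (StructObjectInspector) serDe.getObjectInspector();

    // serialize() buffers rows and returns null until HIVE_ARROW_BATCH_SIZE rows have
    // accumulated; serializing a null row flushes whatever is buffered as a final batch.
    ArrowWrapperWritable batch = serDe.serialize(row, rowOI); // 'row' is a list of Writables matching the schema
    if (batch == null) {
      batch = serDe.serialize(null, rowOI);
    }
    Object[][] rows = (Object[][]) serDe.deserialize(batch);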