diff --git common/src/java/org/apache/hadoop/hive/conf/HiveConf.java common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
index 1e8a389784..66fe4bd32c 100644
--- common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
+++ common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
@@ -2639,6 +2639,7 @@ private static void populateLlapDaemonVarsSet(Set<String> llapDaemonVarsSetLocal
     HIVE_ARROW_ROOT_ALLOCATOR_LIMIT("hive.arrow.root.allocator.limit", Long.MAX_VALUE,
         "Arrow root allocator memory size limitation in bytes."),
     HIVE_ARROW_BATCH_SIZE("hive.arrow.batch.size", 1000, "The number of rows sent in one Arrow batch."),
+    HIVE_ARROW_ENCODE("hive.arrow.encode", false, "Set to true to encode repeating values."),
 
     // For Druid storage handler
     HIVE_DRUID_INDEXING_GRANULARITY("hive.druid.indexer.segments.granularity", "DAY",
diff --git ql/src/java/org/apache/hadoop/hive/ql/io/arrow/ArrowWrapperWritable.java ql/src/java/org/apache/hadoop/hive/ql/io/arrow/ArrowWrapperWritable.java
index dd490b1b90..4dd3ce02bd 100644
--- ql/src/java/org/apache/hadoop/hive/ql/io/arrow/ArrowWrapperWritable.java
+++ ql/src/java/org/apache/hadoop/hive/ql/io/arrow/ArrowWrapperWritable.java
@@ -19,6 +19,7 @@
 package org.apache.hadoop.hive.ql.io.arrow;
 
 import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
 import org.apache.hadoop.io.WritableComparable;
 
 import java.io.DataInput;
@@ -27,10 +28,18 @@
 public class ArrowWrapperWritable implements WritableComparable {
   private VectorSchemaRoot vectorSchemaRoot;
+  private DictionaryProvider dictionaryProvider;
 
   public ArrowWrapperWritable(VectorSchemaRoot vectorSchemaRoot) {
+    this(vectorSchemaRoot, null);
+  }
+
+  public ArrowWrapperWritable(VectorSchemaRoot vectorSchemaRoot,
+      DictionaryProvider dictionaryProvider) {
     this.vectorSchemaRoot = vectorSchemaRoot;
+    this.dictionaryProvider = dictionaryProvider;
   }
+
   public ArrowWrapperWritable() {}
 
   public VectorSchemaRoot getVectorSchemaRoot() {
@@ -41,6 +50,14 @@
     this.vectorSchemaRoot = vectorSchemaRoot;
   }
 
+  public DictionaryProvider getDictionaryProvider() {
+    return dictionaryProvider;
+  }
+
+  public void setDictionaryProvider(DictionaryProvider dictionaryProvider) {
+    this.dictionaryProvider = dictionaryProvider;
+  }
+
   @Override
   public void write(DataOutput dataOutput) throws IOException {
     throw new UnsupportedOperationException();
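Note on the hand-off above: ArrowWrapperWritable now carries a DictionaryProvider next to the VectorSchemaRoot, and the provider stays null unless hive.arrow.encode is set. A minimal consumer-side sketch of how the two travel together — the class and method names here are illustrative, not part of the patch:

    import org.apache.arrow.vector.VectorSchemaRoot;
    import org.apache.arrow.vector.dictionary.Dictionary;
    import org.apache.arrow.vector.dictionary.DictionaryProvider;

    public final class WritableInspectionSketch {
      // Hypothetical consumer of the writable produced by Serializer.serializeBatch().
      static void inspect(ArrowWrapperWritable writable) {
        VectorSchemaRoot root = writable.getVectorSchemaRoot();
        DictionaryProvider provider = writable.getDictionaryProvider();
        if (provider != null) {
          // An encoded string column's Field only stores a dictionary id;
          // the provider resolves the id to the vector of distinct values.
          // (Assumes column 0 is dictionary-encoded; getDictionary() is null otherwise.)
          long id = root.getFieldVectors().get(0).getField().getDictionary().getId();
          Dictionary dictionary = provider.lookup(id);
        }
      }
    }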
diff --git ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Deserializer.java ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Deserializer.java
index 6e09d3991f..140eb69739 100644
--- ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Deserializer.java
+++ ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Deserializer.java
@@ -36,10 +36,14 @@
 import org.apache.arrow.vector.VarBinaryVector;
 import org.apache.arrow.vector.VarCharVector;
 import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.dictionary.Dictionary;
+import org.apache.arrow.vector.dictionary.DictionaryEncoder;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
 import org.apache.arrow.vector.holders.NullableIntervalDayHolder;
 import org.apache.arrow.vector.types.Types;
 import org.apache.hadoop.hive.common.type.HiveDecimal;
 import org.apache.hadoop.hive.common.type.HiveIntervalDayTime;
+import org.apache.hadoop.hive.conf.HiveConf;
 import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector;
 import org.apache.hadoop.hive.ql.exec.vector.ColumnVector;
 import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector;
@@ -56,8 +60,10 @@
 import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.hive.serde2.SerDeException;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
 import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.UnionTypeInfo;
@@ -76,13 +82,17 @@
 import static org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe.toStructListVector;
 
 class Deserializer {
-  private final ArrowColumnarBatchSerDe serDe;
   private final VectorExtractRow vectorExtractRow;
   private final VectorizedRowBatch vectorizedRowBatch;
+  private final StructTypeInfo rowTypeInfo;
+  private final boolean encode;
+  private DictionaryProvider dictionaryProvider;
   private Object[][] rows;
 
   Deserializer(ArrowColumnarBatchSerDe serDe) throws SerDeException {
-    this.serDe = serDe;
+    rowTypeInfo = serDe.rowTypeInfo;
+    encode = HiveConf.getBoolVar(serDe.conf, HiveConf.ConfVars.HIVE_ARROW_ENCODE);
+
     vectorExtractRow = new VectorExtractRow();
     final List<TypeInfo> fieldTypeInfoList = serDe.rowTypeInfo.getAllStructFieldTypeInfos();
     final int fieldCount = fieldTypeInfoList.size();
@@ -107,6 +117,7 @@ public Object deserialize(Writable writable) {
     final List<FieldVector> fieldVectors = vectorSchemaRoot.getFieldVectors();
     final int fieldCount = fieldVectors.size();
     final int rowCount = vectorSchemaRoot.getRowCount();
+    dictionaryProvider = arrowWrapperWritable.getDictionaryProvider();
 
     vectorizedRowBatch.ensureSize(rowCount);
     if (rows == null || rows.length < rowCount ) {
@@ -120,8 +131,8 @@ public Object deserialize(Writable writable) {
       final FieldVector fieldVector = fieldVectors.get(fieldIndex);
       final int projectedCol = vectorizedRowBatch.projectedColumns[fieldIndex];
       final ColumnVector columnVector = vectorizedRowBatch.cols[projectedCol];
-      final TypeInfo typeInfo = serDe.rowTypeInfo.getAllStructFieldTypeInfos().get(fieldIndex);
-      read(fieldVector, columnVector, typeInfo);
+      final TypeInfo typeInfo = rowTypeInfo.getAllStructFieldTypeInfos().get(fieldIndex);
+      read(fieldVector, columnVector, typeInfo, encode);
     }
     for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) {
       vectorExtractRow.extractRow(vectorizedRowBatch, rowIndex, rows[rowIndex]);
@@ -130,19 +141,21 @@ public Object deserialize(Writable writable) {
     return rows;
   }
 
-  private void read(FieldVector arrowVector, ColumnVector hiveVector, TypeInfo typeInfo) {
+  private void read(FieldVector arrowVector, ColumnVector hiveVector, TypeInfo typeInfo,
+      boolean encode) {
     switch (typeInfo.getCategory()) {
       case PRIMITIVE:
-        readPrimitive(arrowVector, hiveVector);
+        readPrimitive(arrowVector, hiveVector, typeInfo, encode);
         break;
       case LIST:
-        readList(arrowVector, (ListColumnVector) hiveVector, (ListTypeInfo) typeInfo);
+        readList(arrowVector, (ListColumnVector) hiveVector, (ListTypeInfo) typeInfo, encode);
         break;
       case MAP:
-        readMap(arrowVector, (MapColumnVector) hiveVector, (MapTypeInfo) typeInfo);
+        readMap(arrowVector, (MapColumnVector) hiveVector, (MapTypeInfo) typeInfo, encode);
         break;
       case STRUCT:
-        readStruct(arrowVector, (StructColumnVector) hiveVector, (StructTypeInfo) typeInfo);
+        readStruct(arrowVector, (StructColumnVector) hiveVector, (StructTypeInfo) typeInfo,
+            encode);
         break;
       case UNION:
         readUnion(arrowVector, (UnionColumnVector) hiveVector, (UnionTypeInfo) typeInfo);
@@ -152,14 +165,16 @@ private void read(FieldVector arrowVector, ColumnVector hiveVector, TypeInfo typ
     }
   }
 
-  private void readPrimitive(FieldVector arrowVector, ColumnVector hiveVector) {
-    final Types.MinorType minorType = arrowVector.getMinorType();
+  private void readPrimitive(FieldVector arrowVector, ColumnVector hiveVector, TypeInfo typeInfo,
+      boolean encode) {
+    final PrimitiveObjectInspector.PrimitiveCategory primitiveCategory =
+        ((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory();
 
     final int size = arrowVector.getValueCount();
     hiveVector.ensureSize(size, false);
 
-    switch (minorType) {
-      case BIT:
+    switch (primitiveCategory) {
+      case BOOLEAN:
       {
         for (int i = 0; i < size; i++) {
           if (arrowVector.isNull(i)) {
@@ -171,7 +186,7 @@ private void readPrimitive(FieldVector arrowVector, ColumnVector hiveVector) {
           }
         }
       }
       break;
-      case TINYINT:
+      case BYTE:
       {
         for (int i = 0; i < size; i++) {
           if (arrowVector.isNull(i)) {
@@ -183,7 +198,7 @@ private void readPrimitive(FieldVector arrowVector, ColumnVector hiveVector) {
           }
         }
       }
       break;
-      case SMALLINT:
+      case SHORT:
       {
         for (int i = 0; i < size; i++) {
           if (arrowVector.isNull(i)) {
@@ -207,7 +222,7 @@ private void readPrimitive(FieldVector arrowVector, ColumnVector hiveVector) {
           }
         }
       }
       break;
-      case BIGINT:
+      case LONG:
       {
         for (int i = 0; i < size; i++) {
           if (arrowVector.isNull(i)) {
@@ -219,7 +234,7 @@ private void readPrimitive(FieldVector arrowVector, ColumnVector hiveVector) {
           }
         }
       }
       break;
-      case FLOAT4:
+      case FLOAT:
       {
         for (int i = 0; i < size; i++) {
           if (arrowVector.isNull(i)) {
@@ -231,7 +246,7 @@ private void readPrimitive(FieldVector arrowVector, ColumnVector hiveVector) {
           }
         }
       }
       break;
-      case FLOAT8:
+      case DOUBLE:
       {
         for (int i = 0; i < size; i++) {
           if (arrowVector.isNull(i)) {
@@ -243,19 +258,29 @@ private void readPrimitive(FieldVector arrowVector, ColumnVector hiveVector) {
           }
         }
       }
       break;
+      case STRING:
+      case CHAR:
       case VARCHAR:
       {
+        final VarCharVector varCharVector;
+        if (encode) {
+          final long id = arrowVector.getField().getDictionary().getId();
+          final Dictionary dictionary = dictionaryProvider.lookup(id);
+          varCharVector = (VarCharVector) DictionaryEncoder.decode(arrowVector, dictionary);
+        } else {
+          varCharVector = ((VarCharVector) arrowVector);
+        }
         for (int i = 0; i < size; i++) {
-          if (arrowVector.isNull(i)) {
+          if (varCharVector.isNull(i)) {
             VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i);
           } else {
             hiveVector.isNull[i] = false;
-            ((BytesColumnVector) hiveVector).setVal(i, ((VarCharVector) arrowVector).get(i));
+            ((BytesColumnVector) hiveVector).setVal(i, varCharVector.get(i));
           }
         }
       }
       break;
-      case DATEDAY:
+      case DATE:
       {
         for (int i = 0; i < size; i++) {
           if (arrowVector.isNull(i)) {
@@ -267,94 +292,103 @@ private void readPrimitive(FieldVector arrowVector, ColumnVector hiveVector) {
           }
         }
       }
       break;
-      case TIMESTAMPMILLI:
+      case TIMESTAMP:
       {
-        for (int i = 0; i < size; i++) {
-          if (arrowVector.isNull(i)) {
-            VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i);
-          } else {
-            hiveVector.isNull[i] = false;
-
-            // Time = second + sub-second
-            final long timeInMillis = ((TimeStampMilliVector) arrowVector).get(i);
-            final TimestampColumnVector timestampColumnVector = (TimestampColumnVector) hiveVector;
-            int subSecondInNanos = (int) ((timeInMillis % MILLIS_PER_SECOND) * NS_PER_MILLIS);
-            long second = timeInMillis / MILLIS_PER_SECOND;
-
-            // A nanosecond value should not be negative
-            if (subSecondInNanos < 0) {
-
-              // So add one second to the negative nanosecond value to make it positive
-              subSecondInNanos += NS_PER_SECOND;
-
-              // Subtract one second from the second value because we added one second
-              second -= 1;
+        final Types.MinorType minorType = arrowVector.getMinorType();
+        switch (minorType) {
+          case TIMESTAMPMILLI:
+          {
+            for (int i = 0; i < size; i++) {
+              if (arrowVector.isNull(i)) {
+                VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i);
+              } else {
+                hiveVector.isNull[i] = false;
+
+                // Time = second + sub-second
+                final long timeInMillis = ((TimeStampMilliVector) arrowVector).get(i);
+                final TimestampColumnVector timestampColumnVector = (TimestampColumnVector) hiveVector;
+                int subSecondInNanos = (int) ((timeInMillis % MILLIS_PER_SECOND) * NS_PER_MILLIS);
+                long second = timeInMillis / MILLIS_PER_SECOND;
+
+                // A nanosecond value should not be negative
+                if (subSecondInNanos < 0) {
+
+                  // So add one second to the negative nanosecond value to make it positive
+                  subSecondInNanos += NS_PER_SECOND;
+
+                  // Subtract one second from the second value because we added one second
+                  second -= 1;
+                }
+                timestampColumnVector.time[i] = second * MILLIS_PER_SECOND;
+                timestampColumnVector.nanos[i] = subSecondInNanos;
+              }
+            }
           }
-            timestampColumnVector.time[i] = second * MILLIS_PER_SECOND;
-            timestampColumnVector.nanos[i] = subSecondInNanos;
-          }
-        }
-      }
-      break;
-      case TIMESTAMPMICRO:
-      {
-        for (int i = 0; i < size; i++) {
-          if (arrowVector.isNull(i)) {
-            VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i);
-          } else {
-            hiveVector.isNull[i] = false;
-
-            // Time = second + sub-second
-            final long timeInMicros = ((TimeStampMicroVector) arrowVector).get(i);
-            final TimestampColumnVector timestampColumnVector = (TimestampColumnVector) hiveVector;
-            int subSecondInNanos = (int) ((timeInMicros % MICROS_PER_SECOND) * NS_PER_MICROS);
-            long second = timeInMicros / MICROS_PER_SECOND;
-
-            // A nanosecond value should not be negative
-            if (subSecondInNanos < 0) {
-
-              // So add one second to the negative nanosecond value to make it positive
-              subSecondInNanos += NS_PER_SECOND;
-
-              // Subtract one second from the second value because we added one second
-              second -= 1;
+          break;
+          case TIMESTAMPMICRO:
+          {
+            for (int i = 0; i < size; i++) {
+              if (arrowVector.isNull(i)) {
+                VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i);
+              } else {
+                hiveVector.isNull[i] = false;
+
+                // Time = second + sub-second
+                final long timeInMicros = ((TimeStampMicroVector) arrowVector).get(i);
+                final TimestampColumnVector timestampColumnVector = (TimestampColumnVector) hiveVector;
+                int subSecondInNanos = (int) ((timeInMicros % MICROS_PER_SECOND) * NS_PER_MICROS);
+                long second = timeInMicros / MICROS_PER_SECOND;
+
+                // A nanosecond value should not be negative
+                if (subSecondInNanos < 0) {
+
+                  // So add one second to the negative nanosecond value to make it positive
+                  subSecondInNanos += NS_PER_SECOND;
+
+                  // Subtract one second from the second value because we added one second
+                  second -= 1;
+                }
+                timestampColumnVector.time[i] = second * MILLIS_PER_SECOND;
+                timestampColumnVector.nanos[i] = subSecondInNanos;
+              }
+            }
           }
-            timestampColumnVector.time[i] = second * MILLIS_PER_SECOND;
-            timestampColumnVector.nanos[i] = subSecondInNanos;
-          }
-        }
-      }
-      break;
-      case TIMESTAMPNANO:
-      {
-        for (int i = 0; i < size; i++) {
-          if (arrowVector.isNull(i)) {
-            VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i);
-          } else {
-            hiveVector.isNull[i] = false;
-
-            // Time = second + sub-second
-            final long timeInNanos = ((TimeStampNanoVector) arrowVector).get(i);
-            final TimestampColumnVector timestampColumnVector = (TimestampColumnVector) hiveVector;
-            int subSecondInNanos = (int) (timeInNanos % NS_PER_SECOND);
-            long second = timeInNanos / NS_PER_SECOND;
-
-            // A nanosecond value should not be negative
-            if (subSecondInNanos < 0) {
-
-              // So add one second to the negative nanosecond value to make it positive
-              subSecondInNanos += NS_PER_SECOND;
-
-              // Subtract one second from the second value because we added one second
-              second -= 1;
+          break;
+          case TIMESTAMPNANO:
+          {
+            for (int i = 0; i < size; i++) {
+              if (arrowVector.isNull(i)) {
+                VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i);
+              } else {
+                hiveVector.isNull[i] = false;
+
+                // Time = second + sub-second
+                final long timeInNanos = ((TimeStampNanoVector) arrowVector).get(i);
+                final TimestampColumnVector timestampColumnVector = (TimestampColumnVector) hiveVector;
+                int subSecondInNanos = (int) (timeInNanos % NS_PER_SECOND);
+                long second = timeInNanos / NS_PER_SECOND;
+
+                // A nanosecond value should not be negative
+                if (subSecondInNanos < 0) {
+
+                  // So add one second to the negative nanosecond value to make it positive
+                  subSecondInNanos += NS_PER_SECOND;
+
+                  // Subtract one second from the second value because we added one second
+                  second -= 1;
+                }
+                timestampColumnVector.time[i] = second * MILLIS_PER_SECOND;
+                timestampColumnVector.nanos[i] = subSecondInNanos;
+              }
+            }
           }
-            timestampColumnVector.time[i] = second * MILLIS_PER_SECOND;
-            timestampColumnVector.nanos[i] = subSecondInNanos;
-          }
+          break;
+          default:
+            throw new IllegalArgumentException();
         }
       }
       break;
-      case VARBINARY:
+      case BINARY:
       {
         for (int i = 0; i < size; i++) {
           if (arrowVector.isNull(i)) {
@@ -379,7 +413,7 @@ private void readPrimitive(FieldVector arrowVector, ColumnVector hiveVector) {
           }
         }
       }
       break;
-      case INTERVALYEAR:
+      case INTERVAL_YEAR_MONTH:
       {
         for (int i = 0; i < size; i++) {
           if (arrowVector.isNull(i)) {
@@ -391,7 +425,7 @@ private void readPrimitive(FieldVector arrowVector, ColumnVector hiveVector) {
           }
         }
       }
       break;
-      case INTERVALDAY:
+      case INTERVAL_DAY_TIME:
       {
         final IntervalDayVector intervalDayVector = (IntervalDayVector) arrowVector;
         final NullableIntervalDayHolder intervalDayHolder = new NullableIntervalDayHolder();
@@ -416,14 +450,14 @@ private void readPrimitive(FieldVector arrowVector, ColumnVector hiveVector) {
     }
   }
 
-  private void readList(FieldVector arrowVector, ListColumnVector hiveVector, ListTypeInfo typeInfo) {
+  private void readList(FieldVector arrowVector, ListColumnVector hiveVector, ListTypeInfo typeInfo,
+      boolean encode) {
     final int size = arrowVector.getValueCount();
     final ArrowBuf offsets = arrowVector.getOffsetBuffer();
     final int OFFSET_WIDTH = 4;
 
-    read(arrowVector.getChildrenFromFields().get(0),
-        hiveVector.child,
-        typeInfo.getListElementTypeInfo());
+    read(arrowVector.getChildrenFromFields().get(0), hiveVector.child,
+        typeInfo.getListElementTypeInfo(), encode);
 
     for (int i = 0; i < size; i++) {
       if (arrowVector.isNull(i)) {
@@ -437,13 +471,14 @@ private void readList(FieldVector arrowVector, ListColumnVector hiveVector, List
     }
   }
 
-  private void readMap(FieldVector arrowVector, MapColumnVector hiveVector, MapTypeInfo typeInfo) {
+  private void readMap(FieldVector arrowVector, MapColumnVector hiveVector, MapTypeInfo typeInfo,
+      boolean encode) {
     final int size = arrowVector.getValueCount();
     final ListTypeInfo mapStructListTypeInfo = toStructListTypeInfo(typeInfo);
     final ListColumnVector mapStructListVector = toStructListVector(hiveVector);
     final StructColumnVector mapStructVector = (StructColumnVector) mapStructListVector.child;
 
-    read(arrowVector, mapStructListVector, mapStructListTypeInfo);
+    read(arrowVector, mapStructListVector, mapStructListTypeInfo, encode);
 
     hiveVector.isRepeating = mapStructListVector.isRepeating;
     hiveVector.childCount = mapStructListVector.childCount;
@@ -455,12 +490,14 @@ private void readMap(FieldVector arrowVector, MapColumnVector hiveVector, MapTyp
     System.arraycopy(mapStructListVector.isNull, 0, hiveVector.isNull, 0, size);
   }
 
-  private void readStruct(FieldVector arrowVector, StructColumnVector hiveVector, StructTypeInfo typeInfo) {
+  private void readStruct(FieldVector arrowVector, StructColumnVector hiveVector,
+      StructTypeInfo typeInfo, boolean encode) {
     final int size = arrowVector.getValueCount();
     final List<TypeInfo> fieldTypeInfos = typeInfo.getAllStructFieldTypeInfos();
     final int fieldSize = arrowVector.getChildrenFromFields().size();
     for (int i = 0; i < fieldSize; i++) {
-      read(arrowVector.getChildrenFromFields().get(i), hiveVector.fields[i], fieldTypeInfos.get(i));
+      read(arrowVector.getChildrenFromFields().get(i), hiveVector.fields[i], fieldTypeInfos.get(i),
+          encode);
     }
 
     for (int i = 0; i < size; i++) {
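The decode path in readPrimitive above is the standard Arrow dictionary round trip: the field's DictionaryEncoding carries only an id, the DictionaryProvider shipped in the writable resolves the id to the dictionary vector, and DictionaryEncoder.decode materializes the original strings. The same step in isolation — names are ours, not the patch's:

    import org.apache.arrow.vector.FieldVector;
    import org.apache.arrow.vector.VarCharVector;
    import org.apache.arrow.vector.dictionary.Dictionary;
    import org.apache.arrow.vector.dictionary.DictionaryEncoder;
    import org.apache.arrow.vector.dictionary.DictionaryProvider;

    public final class DictionaryDecodeSketch {
      // Mirrors the encode branch of readPrimitive for string-family columns.
      static VarCharVector decodeStrings(FieldVector encoded, DictionaryProvider provider) {
        // The Field metadata stores only the dictionary id ...
        long id = encoded.getField().getDictionary().getId();
        // ... which the provider maps to the vector of distinct values.
        Dictionary dictionary = provider.lookup(id);
        // Allocates a new vector where each index is replaced by its dictionary value.
        return (VarCharVector) DictionaryEncoder.decode(encoded, dictionary);
      }
    }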
diff --git ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Serializer.java ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Serializer.java
index e6af916ce8..f76de96e02 100644
--- ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Serializer.java
+++ ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Serializer.java
@@ -18,6 +18,7 @@
 package org.apache.hadoop.hive.ql.io.arrow;
 
 import io.netty.buffer.ArrowBuf;
+import org.apache.arrow.memory.BufferAllocator;
 import org.apache.arrow.vector.BigIntVector;
 import org.apache.arrow.vector.BitVector;
 import org.apache.arrow.vector.BitVectorHelper;
@@ -37,10 +38,15 @@
 import org.apache.arrow.vector.VectorSchemaRoot;
 import org.apache.arrow.vector.complex.ListVector;
 import org.apache.arrow.vector.complex.MapVector;
-import org.apache.arrow.vector.complex.NullableMapVector;
+import org.apache.arrow.vector.dictionary.Dictionary;
+import org.apache.arrow.vector.dictionary.DictionaryEncoder;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
 import org.apache.arrow.vector.types.Types;
 import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
+import org.apache.arrow.vector.types.pojo.Field;
 import org.apache.arrow.vector.types.pojo.FieldType;
+import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.hive.conf.HiveConf;
 import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector;
 import org.apache.hadoop.hive.ql.exec.vector.ColumnVector;
@@ -67,10 +73,15 @@
 import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
 import org.apache.hadoop.hive.serde2.typeinfo.UnionTypeInfo;
 
+import java.nio.ByteBuffer;
 import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashSet;
 import java.util.List;
+import java.util.Set;
 
 import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.HIVE_ARROW_BATCH_SIZE;
 import static org.apache.hadoop.hive.ql.exec.vector.VectorizedBatchUtil.createColumnVector;
@@ -96,20 +107,25 @@
   private final VectorAssignRow vectorAssignRow;
   private int batchSize;
 
-  private final NullableMapVector rootVector;
+  private final Configuration conf;
+  private final BufferAllocator bufferAllocator;
+  private final boolean encode;
+  private long dictionaryId;
+  private DictionaryProvider.MapDictionaryProvider dictionaryProvider;
 
   Serializer(ArrowColumnarBatchSerDe serDe) throws SerDeException {
     MAX_BUFFERED_ROWS = HiveConf.getIntVar(serDe.conf, HIVE_ARROW_BATCH_SIZE);
-    ArrowColumnarBatchSerDe.LOG.info("ArrowColumnarBatchSerDe max number of buffered columns: " + MAX_BUFFERED_ROWS);
+    ArrowColumnarBatchSerDe.LOG.info("ArrowColumnarBatchSerDe max number of buffered columns: " +
+        MAX_BUFFERED_ROWS);
+    conf = serDe.conf;
+    bufferAllocator = serDe.rootAllocator;
+    encode = HiveConf.getBoolVar(conf, HiveConf.ConfVars.HIVE_ARROW_ENCODE);
 
     // Schema
     structTypeInfo = (StructTypeInfo) getTypeInfoFromObjectInspector(serDe.rowObjectInspector);
     List<TypeInfo> fieldTypeInfos = structTypeInfo.getAllStructFieldTypeInfos();
     fieldSize = fieldTypeInfos.size();
 
-    // Init Arrow stuffs
-    rootVector = NullableMapVector.empty(null, serDe.rootAllocator);
-
     // Init Hive stuffs
     vectorizedRowBatch = new VectorizedRowBatch(fieldSize);
     for (int fieldIndex = 0; fieldIndex < fieldSize; fieldIndex++) {
@@ -127,28 +143,77 @@
   }
 
   private ArrowWrapperWritable serializeBatch() {
-    rootVector.setValueCount(0);
+    dictionaryProvider = new DictionaryProvider.MapDictionaryProvider();
+    final List<Field> fieldList = new ArrayList<>();
+    final List<FieldVector> vectorList = new ArrayList<>();
 
     for (int fieldIndex = 0; fieldIndex < vectorizedRowBatch.projectionSize; fieldIndex++) {
       final int projectedColumn = vectorizedRowBatch.projectedColumns[fieldIndex];
       final ColumnVector hiveVector = vectorizedRowBatch.cols[projectedColumn];
       final TypeInfo fieldTypeInfo = structTypeInfo.getAllStructFieldTypeInfos().get(fieldIndex);
       final String fieldName = structTypeInfo.getAllStructFieldNames().get(fieldIndex);
-      final FieldType fieldType = toFieldType(fieldTypeInfo);
-      final FieldVector arrowVector = rootVector.addOrGet(fieldName, fieldType, FieldVector.class);
+      final Field field = toField(fieldName, fieldTypeInfo);
+      final FieldVector arrowVector = field.createVector(bufferAllocator);
       arrowVector.setInitialCapacity(batchSize);
       arrowVector.allocateNew();
-      write(arrowVector, hiveVector, fieldTypeInfo, batchSize);
+      vectorList.add(write(arrowVector, hiveVector, fieldTypeInfo, batchSize, encode));
+      fieldList.add(arrowVector.getField());
+      arrowVector.setValueCount(batchSize);
     }
     vectorizedRowBatch.reset();
-    rootVector.setValueCount(batchSize);
+    VectorSchemaRoot vectorSchemaRoot = new VectorSchemaRoot(fieldList, vectorList, batchSize);
 
     batchSize = 0;
-    VectorSchemaRoot vectorSchemaRoot = new VectorSchemaRoot(rootVector);
-    return new ArrowWrapperWritable(vectorSchemaRoot);
+    return new ArrowWrapperWritable(vectorSchemaRoot, dictionaryProvider);
+  }
+
+  private Field toField(String fieldName, TypeInfo typeInfo) {
+    switch (typeInfo.getCategory()) {
+      case PRIMITIVE:
+        return new Field(fieldName, toFieldType(typeInfo), null);
+      case MAP:
+        return toField(fieldName, toStructListTypeInfo((MapTypeInfo) typeInfo));
+      case LIST:
+      {
+        final ListTypeInfo listTypeInfo = (ListTypeInfo) typeInfo;
+        final Field elementField = toField(null, listTypeInfo.getListElementTypeInfo());
+        final List<Field> children = Collections.singletonList(elementField);
+        return new Field(fieldName, toFieldType(typeInfo), children);
+      }
+      case STRUCT:
+      {
+        final List<Field> children = new ArrayList<>();
+        final StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo;
+        final List<String> fieldNames = structTypeInfo.getAllStructFieldNames();
+        final List<TypeInfo> fieldTypeInfos = structTypeInfo.getAllStructFieldTypeInfos();
+        for (int i = 0; i < fieldNames.size(); i++) {
+          children.add(toField(fieldNames.get(i), fieldTypeInfos.get(i)));
+        }
+        return new Field(fieldName, toFieldType(typeInfo), children);
+      }
+      case UNION:
+      default:
+        throw new IllegalArgumentException();
+    }
   }
 
   private FieldType toFieldType(TypeInfo typeInfo) {
+    if (encode) {
+      if (typeInfo.getCategory() == ObjectInspector.Category.PRIMITIVE) {
+        final PrimitiveTypeInfo primitiveTypeInfo = (PrimitiveTypeInfo) typeInfo;
+        final PrimitiveObjectInspector.PrimitiveCategory primitiveCategory =
+            primitiveTypeInfo.getPrimitiveCategory();
+        switch (primitiveCategory) {
+          case VARCHAR:
+          case CHAR:
+          case STRING:
+          {
+            return new FieldType(true, toArrowType(TypeInfoFactory.intTypeInfo),
+                new DictionaryEncoding(dictionaryId++, false, null));
+          }
+        }
+      }
+    }
     return new FieldType(true, toArrowType(typeInfo), null);
   }
@@ -206,34 +271,34 @@ private ArrowType toArrowType(TypeInfo typeInfo) {
     }
   }
 
-  private void write(FieldVector arrowVector, ColumnVector hiveVector, TypeInfo typeInfo, int size) {
+  @SuppressWarnings("unchecked")
+  private FieldVector write(FieldVector arrowVector, ColumnVector hiveVector, TypeInfo typeInfo,
+      int size, boolean encode) {
     switch (typeInfo.getCategory()) {
       case PRIMITIVE:
-        writePrimitive(arrowVector, hiveVector, typeInfo, size);
-        break;
+        return writePrimitive(arrowVector, hiveVector, (PrimitiveTypeInfo) typeInfo, size, encode);
       case LIST:
-        writeList((ListVector) arrowVector, (ListColumnVector) hiveVector, (ListTypeInfo) typeInfo, size);
-        break;
+        return writeList((ListVector) arrowVector, (ListColumnVector) hiveVector,
+            (ListTypeInfo) typeInfo, size, encode);
       case STRUCT:
-        writeStruct((MapVector) arrowVector, (StructColumnVector) hiveVector, (StructTypeInfo) typeInfo, size);
-        break;
+        return writeStruct((MapVector) arrowVector, (StructColumnVector) hiveVector,
+            (StructTypeInfo) typeInfo, size, encode);
       case UNION:
-        writeUnion(arrowVector, hiveVector, typeInfo, size);
-        break;
+        return writeUnion(arrowVector, hiveVector, typeInfo, size, encode);
      case MAP:
-        writeMap((ListVector) arrowVector, (MapColumnVector) hiveVector, (MapTypeInfo) typeInfo, size);
-        break;
+        return writeMap((ListVector) arrowVector, (MapColumnVector) hiveVector,
+            (MapTypeInfo) typeInfo, size, encode);
       default:
         throw new IllegalArgumentException();
     }
   }
 
-  private void writeMap(ListVector arrowVector, MapColumnVector hiveVector, MapTypeInfo typeInfo,
-      int size) {
+  private FieldVector writeMap(ListVector arrowVector, MapColumnVector hiveVector,
+      MapTypeInfo typeInfo, int size, boolean encode) {
     final ListTypeInfo structListTypeInfo = toStructListTypeInfo(typeInfo);
     final ListColumnVector structListVector = toStructListVector(hiveVector);
 
-    write(arrowVector, structListVector, structListTypeInfo, size);
+    write(arrowVector, structListVector, structListTypeInfo, size, encode);
 
     final ArrowBuf validityBuffer = arrowVector.getValidityBuffer();
     for (int rowIndex = 0; rowIndex < size; rowIndex++) {
@@ -243,10 +308,12 @@ private void writeMap(ListVector arrowVector, MapColumnVector hiveVector, MapTyp
         BitVectorHelper.setValidityBitToOne(validityBuffer, rowIndex);
       }
     }
+
+    return arrowVector;
   }
 
-  private void writeUnion(FieldVector arrowVector, ColumnVector hiveVector, TypeInfo typeInfo,
-      int size) {
+  private FieldVector writeUnion(FieldVector arrowVector, ColumnVector hiveVector,
+      TypeInfo typeInfo, int size, boolean encode) {
     final UnionTypeInfo unionTypeInfo = (UnionTypeInfo) typeInfo;
     final List<TypeInfo> objectTypeInfos = unionTypeInfo.getAllUnionObjectTypeInfos();
     final UnionColumnVector hiveUnionVector = (UnionColumnVector) hiveVector;
@@ -256,11 +323,14 @@ private void writeUnion(FieldVector arrowVector, ColumnVector hiveVector, TypeIn
     final ColumnVector hiveObjectVector = hiveObjectVectors[tag];
     final TypeInfo objectTypeInfo = objectTypeInfos.get(tag);
 
-    write(arrowVector, hiveObjectVector, objectTypeInfo, size);
+    write(arrowVector, hiveObjectVector, objectTypeInfo, size, encode);
+
+    return arrowVector;
   }
 
-  private void writeStruct(MapVector arrowVector, StructColumnVector hiveVector,
-      StructTypeInfo typeInfo, int size) {
+  @SuppressWarnings("unchecked")
+  private FieldVector writeStruct(MapVector arrowVector, StructColumnVector hiveVector,
+      StructTypeInfo typeInfo, int size, boolean encode) {
     final List<String> fieldNames = typeInfo.getAllStructFieldNames();
     final List<TypeInfo> fieldTypeInfos = typeInfo.getAllStructFieldTypeInfos();
     final ColumnVector[] hiveFieldVectors = hiveVector.fields;
@@ -270,12 +340,11 @@ private void writeStruct(MapVector arrowVector, StructColumnVector hiveVector,
       final TypeInfo fieldTypeInfo = fieldTypeInfos.get(fieldIndex);
       final ColumnVector hiveFieldVector = hiveFieldVectors[fieldIndex];
       final String fieldName = fieldNames.get(fieldIndex);
-      final FieldVector arrowFieldVector =
-          arrowVector.addOrGet(fieldName,
-              toFieldType(fieldTypeInfos.get(fieldIndex)), FieldVector.class);
+      final FieldVector arrowFieldVector = arrowVector.addOrGet(fieldName,
+          toFieldType(fieldTypeInfos.get(fieldIndex)), FieldVector.class);
       arrowFieldVector.setInitialCapacity(size);
       arrowFieldVector.allocateNew();
-      write(arrowFieldVector, hiveFieldVector, fieldTypeInfo, size);
+      write(arrowFieldVector, hiveFieldVector, fieldTypeInfo, size, encode);
     }
 
     final ArrowBuf validityBuffer = arrowVector.getValidityBuffer();
@@ -286,10 +355,12 @@ private void writeStruct(MapVector arrowVector, StructColumnVector hiveVector,
         BitVectorHelper.setValidityBitToOne(validityBuffer, rowIndex);
       }
     }
+
+    return (FieldVector) arrowVector;
   }
 
-  private void writeList(ListVector arrowVector, ListColumnVector hiveVector, ListTypeInfo typeInfo,
-      int size) {
+  private FieldVector writeList(ListVector arrowVector, ListColumnVector hiveVector,
+      ListTypeInfo typeInfo, int size, boolean encode) {
     final int OFFSET_WIDTH = 4;
     final TypeInfo elementTypeInfo = typeInfo.getListElementTypeInfo();
     final ColumnVector hiveElementVector = hiveVector.child;
@@ -298,7 +369,7 @@ private void writeList(ListVector arrowVector, ListColumnVector hiveVector, List
     arrowElementVector.setInitialCapacity(hiveVector.childCount);
     arrowElementVector.allocateNew();
 
-    write(arrowElementVector, hiveElementVector, elementTypeInfo, hiveVector.childCount);
+    write(arrowElementVector, hiveElementVector, elementTypeInfo, hiveVector.childCount, encode);
 
     final ArrowBuf offsetBuffer = arrowVector.getOffsetBuffer();
     int nextOffset = 0;
@@ -313,13 +384,12 @@ private void writeList(ListVector arrowVector, ListColumnVector hiveVector, List
       }
     }
     offsetBuffer.setInt(size * OFFSET_WIDTH, nextOffset);
+    return arrowVector;
   }
 
-  private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, TypeInfo typeInfo,
-      int size) {
-    final PrimitiveObjectInspector.PrimitiveCategory primitiveCategory =
-        ((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory();
-    switch (primitiveCategory) {
+  private FieldVector writePrimitive(FieldVector arrowVector, ColumnVector hiveVector,
+      PrimitiveTypeInfo primitiveTypeInfo, int size, boolean encode) {
+    switch (primitiveTypeInfo.getPrimitiveCategory()) {
       case BOOLEAN:
       {
         final BitVector bitVector = (BitVector) arrowVector;
@@ -330,8 +400,8 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
             bitVector.set(i, (int) ((LongColumnVector) hiveVector).vector[i]);
           }
         }
+        return bitVector;
       }
-      break;
       case BYTE:
       {
         final TinyIntVector tinyIntVector = (TinyIntVector) arrowVector;
@@ -342,8 +412,8 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
            tinyIntVector.set(i, (byte) ((LongColumnVector) hiveVector).vector[i]);
           }
         }
+        return tinyIntVector;
       }
-      break;
      case SHORT:
       {
         final SmallIntVector smallIntVector = (SmallIntVector) arrowVector;
@@ -354,8 +424,8 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
             smallIntVector.set(i, (short) ((LongColumnVector) hiveVector).vector[i]);
           }
         }
+        return smallIntVector;
       }
-      break;
       case INT:
       {
         final IntVector intVector = (IntVector) arrowVector;
@@ -366,8 +436,8 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
             intVector.set(i, (int) ((LongColumnVector) hiveVector).vector[i]);
           }
         }
+        return intVector;
       }
-      break;
       case LONG:
       {
         final BigIntVector bigIntVector = (BigIntVector) arrowVector;
@@ -378,8 +448,8 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
             bigIntVector.set(i, ((LongColumnVector) hiveVector).vector[i]);
           }
         }
+        return bigIntVector;
       }
-      break;
       case FLOAT:
       {
         final Float4Vector float4Vector = (Float4Vector) arrowVector;
@@ -390,8 +460,8 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
             float4Vector.set(i, (float) ((DoubleColumnVector) hiveVector).vector[i]);
           }
         }
+        return float4Vector;
       }
-      break;
       case DOUBLE:
       {
         final Float8Vector float8Vector = (Float8Vector) arrowVector;
@@ -402,23 +472,74 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
             float8Vector.set(i, ((DoubleColumnVector) hiveVector).vector[i]);
           }
         }
+        return float8Vector;
       }
-      break;
-      case STRING:
       case VARCHAR:
       case CHAR:
+      case STRING:
       {
-        final VarCharVector varCharVector = (VarCharVector) arrowVector;
-        final BytesColumnVector bytesVector = (BytesColumnVector) hiveVector;
-        for (int i = 0; i < size; i++) {
-          if (hiveVector.isNull[i]) {
-            varCharVector.setNull(i);
-          } else {
-            varCharVector.setSafe(i, bytesVector.vector[i], bytesVector.start[i], bytesVector.length[i]);
+        if (encode) {
+          final BytesColumnVector bytesVector = (BytesColumnVector) hiveVector;
+          final VarCharVector varCharVector = (VarCharVector)
+              Types.MinorType.VARCHAR.getNewVector(null,
+                  FieldType.nullable(Types.MinorType.VARCHAR.getType()), bufferAllocator, null);
+          for (int i = 0; i < size; i++) {
+            if (hiveVector.isNull[i]) {
+              varCharVector.setNull(i);
+            } else {
+              varCharVector.setSafe(i, bytesVector.vector[i], bytesVector.start[i],
+                  bytesVector.length[i]);
+            }
+          }
+          arrowVector.setValueCount(size);
+
+          int j = 0;
+          final Set<ByteBuffer> occurrences = new HashSet<>();
+          for (int i = 0; i < size; i++) {
+            if (!bytesVector.isNull[i]) {
+              final ByteBuffer byteBuffer =
+                  ByteBuffer.wrap(bytesVector.vector[j], bytesVector.start[j],
+                      bytesVector.length[j]);
+              if (!occurrences.contains(byteBuffer)) {
+                occurrences.add(byteBuffer);
+              }
+              j++;
+            }
           }
+          final FieldType fieldType = arrowVector.getField().getFieldType();
+          final VarCharVector dictionaryVector = (VarCharVector) Types.MinorType.VARCHAR.
+              getNewVector(null, fieldType, bufferAllocator, null);
+          j = 0;
+          for (ByteBuffer occurrence : occurrences) {
+            final int start = occurrence.position();
+            final int length = occurrence.limit() - start;
+            dictionaryVector.setSafe(j++, occurrence.array(), start, length);
+          }
+          dictionaryVector.setValueCount(occurrences.size());
+          varCharVector.setValueCount(size);
+          final DictionaryEncoding dictionaryEncoding = arrowVector.getField().getDictionary();
+          final Dictionary dictionary = new Dictionary(dictionaryVector, dictionaryEncoding);
+          dictionaryProvider.put(dictionary);
+          final IntVector encodedVector = (IntVector) DictionaryEncoder.encode(varCharVector,
+              dictionary);
+          encodedVector.makeTransferPair(arrowVector).transfer();
+          return arrowVector;
+        } else {
+          final BytesColumnVector bytesVector = (BytesColumnVector) hiveVector;
+          final VarCharVector varCharVector = (VarCharVector) arrowVector;
+          for (int i = 0; i < size; i++) {
+            if (hiveVector.isNull[i]) {
+              varCharVector.setNull(i);
+            } else {
+              varCharVector.setSafe(i, bytesVector.vector[i], bytesVector.start[i],
+                  bytesVector.length[i]);
+            }
+          }
+          varCharVector.setValueCount(size);
+
+          return varCharVector;
         }
       }
-      break;
       case DATE:
       {
         final DateDayVector dateDayVector = (DateDayVector) arrowVector;
@@ -429,8 +550,8 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
             dateDayVector.set(i, (int) ((LongColumnVector) hiveVector).vector[i]);
           }
         }
+        return dateDayVector;
       }
-      break;
       case TIMESTAMP:
       {
         final TimeStampMicroVector timeStampMicroVector = (TimeStampMicroVector) arrowVector;
@@ -452,8 +573,8 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
           }
         }
         }
+        return timeStampMicroVector;
       }
-      break;
       case BINARY:
       {
         final VarBinaryVector varBinaryVector = (VarBinaryVector) arrowVector;
@@ -465,8 +586,8 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
             varBinaryVector.setSafe(i, bytesVector.vector[i], bytesVector.start[i], bytesVector.length[i]);
           }
         }
+        return varBinaryVector;
       }
-      break;
       case DECIMAL:
       {
         final DecimalVector decimalVector = (DecimalVector) arrowVector;
@@ -475,12 +596,12 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
           if (hiveVector.isNull[i]) {
             decimalVector.setNull(i);
           } else {
-            decimalVector.set(i,
-                ((DecimalColumnVector) hiveVector).vector[i].getHiveDecimal().bigDecimalValue().setScale(scale));
+            decimalVector.set(i, ((DecimalColumnVector) hiveVector).vector[i].getHiveDecimal().
+                bigDecimalValue().setScale(scale));
           }
         }
+        return decimalVector;
       }
-      break;
       case INTERVAL_YEAR_MONTH:
       {
         final IntervalYearVector intervalYearVector = (IntervalYearVector) arrowVector;
@@ -491,8 +612,8 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
             intervalYearVector.set(i, (int) ((LongColumnVector) hiveVector).vector[i]);
           }
         }
+        return intervalYearVector;
      }
-      break;
       case INTERVAL_DAY_TIME:
       {
         final IntervalDayVector intervalDayVector = (IntervalDayVector) arrowVector;
@@ -510,11 +631,8 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
             intervalDayVector.set(i, (int) days, (int) millis);
           }
         }
+        return intervalDayVector;
       }
-      break;
-      case VOID:
-      case UNKNOWN:
-      case TIMESTAMPLOCALTZ:
       default:
         throw new IllegalArgumentException();
     }
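The encode branch of writePrimitive above builds the dictionary by hand (the Set of distinct ByteBuffers) and then delegates the index substitution to DictionaryEncoder.encode. The same mechanics in isolation, without the Hive column vectors — a sketch assuming the Arrow Java version this patch targets, where vectors take (name, allocator) constructors and the default index width is a 4-byte IntVector:

    import java.nio.charset.StandardCharsets;
    import org.apache.arrow.memory.RootAllocator;
    import org.apache.arrow.vector.IntVector;
    import org.apache.arrow.vector.VarCharVector;
    import org.apache.arrow.vector.dictionary.Dictionary;
    import org.apache.arrow.vector.dictionary.DictionaryEncoder;
    import org.apache.arrow.vector.dictionary.DictionaryProvider;
    import org.apache.arrow.vector.types.pojo.DictionaryEncoding;

    public final class DictionaryEncodeSketch {
      public static void main(String[] args) {
        RootAllocator allocator = new RootAllocator(Long.MAX_VALUE);
        byte[] hello = "Hello".getBytes(StandardCharsets.UTF_8);
        byte[] world = "world!".getBytes(StandardCharsets.UTF_8);

        // Raw values: three rows, two distinct strings.
        VarCharVector values = new VarCharVector("values", allocator);
        values.allocateNew();
        values.setSafe(0, hello, 0, hello.length);
        values.setSafe(1, world, 0, world.length);
        values.setSafe(2, hello, 0, hello.length);
        values.setValueCount(3);

        // Dictionary of distinct values; the id (1L) must match the field's DictionaryEncoding.
        VarCharVector dictVector = new VarCharVector("dict", allocator);
        dictVector.allocateNew();
        dictVector.setSafe(0, hello, 0, hello.length);
        dictVector.setSafe(1, world, 0, world.length);
        dictVector.setValueCount(2);
        Dictionary dictionary = new Dictionary(dictVector, new DictionaryEncoding(1L, false, null));

        // The provider is what ArrowWrapperWritable hands to the reader side.
        DictionaryProvider.MapDictionaryProvider provider =
            new DictionaryProvider.MapDictionaryProvider();
        provider.put(dictionary);

        // The three rows collapse to indices into the dictionary: [0, 1, 0].
        IntVector indices = (IntVector) DictionaryEncoder.encode(values, dictionary);

        indices.close();
        values.close();
        dictVector.close();
        allocator.close();
      }
    }

The payoff is the one the config description promises: repeating values are shipped once in the dictionary plus a fixed-width index per row, which is smaller than repeating the bytes whenever the column has low cardinality.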
diff --git ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowColumnarBatchSerDe.java ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowColumnarBatchSerDe.java
index ce25c3e8f9..b6815d5db1 100644
--- ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowColumnarBatchSerDe.java
+++ ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowColumnarBatchSerDe.java
@@ -485,6 +485,27 @@ public void testPrimitiveString() throws SerDeException {
     initAndSerializeAndDeserialize(schema, STRING_ROWS);
   }
 
+  @Test
+  public void testPrimitiveEncodeString() throws SerDeException {
+    String[][] schema = {
+        {"string1", "string"},
+    };
+
+    HiveConf.setBoolVar(conf, HiveConf.ConfVars.HIVE_ARROW_ENCODE, true);
+
+    final Object[][] rows = {
+        {text("")},
+        {text("Hello")},
+        {text("Hello")},
+        {text("world!")},
+        {text("Hello")},
+        {text("world!")},
+        {text("world")},
+        {null},
+    };
+    initAndSerializeAndDeserialize(schema, rows);
+  }
+
   @Test
   public void testPrimitiveDTI() throws SerDeException {
     String[][] schema = {
@@ -588,6 +609,27 @@ public void testListString() throws SerDeException {
     initAndSerializeAndDeserialize(schema, toList(STRING_ROWS));
   }
 
+  @Test
+  public void testListEncodeString() throws SerDeException {
+    String[][] schema = {
+        {"string1", "array<string>"},
+    };
+
+    HiveConf.setBoolVar(conf, HiveConf.ConfVars.HIVE_ARROW_ENCODE, true);
+
+    final Object[][] rows = {
+        {text("")},
+        {text("Hello")},
+        {text("Hello")},
+        {text("world!")},
+        {text("Hello")},
+        {text("world!")},
+        {text("world")},
+        {null},
+    };
+    initAndSerializeAndDeserialize(schema, toList(rows));
+  }
+
   @Test
   public void testListDTI() throws SerDeException {
     String[][] schema = {
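For anyone wiring this up outside the tests: the switch the tests flip is an ordinary HiveConf boolean, set the same way before the serde is initialized. A minimal sketch:

    import org.apache.hadoop.hive.conf.HiveConf;

    public final class ArrowEncodeFlagSketch {
      public static void main(String[] args) {
        HiveConf conf = new HiveConf();
        // Off by default; when enabled, only string-family columns are dictionary-encoded.
        HiveConf.setBoolVar(conf, HiveConf.ConfVars.HIVE_ARROW_ENCODE, true);
        System.out.println(HiveConf.getBoolVar(conf, HiveConf.ConfVars.HIVE_ARROW_ENCODE)); // true
      }
    }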