diff --git common/src/java/org/apache/hadoop/hive/conf/HiveConf.java common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
index 931533a556..f3191d0f9e 100644
--- common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
+++ common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
@@ -2632,6 +2632,7 @@ private static void populateLlapDaemonVarsSet(Set<String> llapDaemonVarsSetLocal
     HIVE_ARROW_ROOT_ALLOCATOR_LIMIT("hive.arrow.root.allocator.limit", Long.MAX_VALUE,
         "Arrow root allocator memory size limitation in bytes."),
     HIVE_ARROW_BATCH_SIZE("hive.arrow.batch.size", 1000, "The number of rows sent in one Arrow batch."),
+    HIVE_ARROW_ENCODE("hive.arrow.encode", false, "Set to true to encode repeating values."),
     // For Druid storage handler
     HIVE_DRUID_INDEXING_GRANULARITY("hive.druid.indexer.segments.granularity", "DAY",
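
The new flag defaults to false. For reference, a caller can switch it on the same way the unit test at the end of this patch does; a minimal sketch (the surrounding HiveConf instance is assumed and is not part of the patch):

    // Hypothetical caller setup: enable Arrow dictionary encoding before creating the SerDe.
    HiveConf conf = new HiveConf();
    HiveConf.setBoolVar(conf, HiveConf.ConfVars.HIVE_ARROW_ENCODE, true);
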
diff --git ql/src/java/org/apache/hadoop/hive/ql/io/arrow/ArrowWrapperWritable.java ql/src/java/org/apache/hadoop/hive/ql/io/arrow/ArrowWrapperWritable.java
index dd490b1b90..4dd3ce02bd 100644
--- ql/src/java/org/apache/hadoop/hive/ql/io/arrow/ArrowWrapperWritable.java
+++ ql/src/java/org/apache/hadoop/hive/ql/io/arrow/ArrowWrapperWritable.java
@@ -19,6 +19,7 @@
 package org.apache.hadoop.hive.ql.io.arrow;
 
 import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
 import org.apache.hadoop.io.WritableComparable;
 
 import java.io.DataInput;
@@ -27,10 +28,18 @@ public class ArrowWrapperWritable implements WritableComparable {
   private VectorSchemaRoot vectorSchemaRoot;
+  private DictionaryProvider dictionaryProvider;
 
   public ArrowWrapperWritable(VectorSchemaRoot vectorSchemaRoot) {
+    this(vectorSchemaRoot, null);
+  }
+
+  public ArrowWrapperWritable(VectorSchemaRoot vectorSchemaRoot,
+      DictionaryProvider dictionaryProvider) {
     this.vectorSchemaRoot = vectorSchemaRoot;
+    this.dictionaryProvider = dictionaryProvider;
   }
+
   public ArrowWrapperWritable() {}
 
   public VectorSchemaRoot getVectorSchemaRoot() {
@@ -41,6 +50,14 @@ public void setVectorSchemaRoot(VectorSchemaRoot vectorSchemaRoot) {
     this.vectorSchemaRoot = vectorSchemaRoot;
   }
 
+  public DictionaryProvider getDictionaryProvider() {
+    return dictionaryProvider;
+  }
+
+  public void setDictionaryProvider(DictionaryProvider dictionaryProvider) {
+    this.dictionaryProvider = dictionaryProvider;
+  }
+
   @Override
   public void write(DataOutput dataOutput) throws IOException {
     throw new UnsupportedOperationException();
diff --git ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Deserializer.java ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Deserializer.java
index fb5800b140..a6af62fed3 100644
--- ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Deserializer.java
+++ ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Deserializer.java
@@ -34,9 +34,13 @@
 import org.apache.arrow.vector.VarBinaryVector;
 import org.apache.arrow.vector.VarCharVector;
 import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.dictionary.Dictionary;
+import org.apache.arrow.vector.dictionary.DictionaryEncoder;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
 import org.apache.arrow.vector.holders.NullableIntervalDayHolder;
 import org.apache.hadoop.hive.common.type.HiveDecimal;
 import org.apache.hadoop.hive.common.type.HiveIntervalDayTime;
+import org.apache.hadoop.hive.conf.HiveConf;
 import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector;
 import org.apache.hadoop.hive.ql.exec.vector.ColumnVector;
 import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector;
@@ -73,13 +77,17 @@
 import static org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe.toStructListVector;
 
 class Deserializer {
-  private final ArrowColumnarBatchSerDe serDe;
   private final VectorExtractRow vectorExtractRow;
   private final VectorizedRowBatch vectorizedRowBatch;
+  private final StructTypeInfo rowTypeInfo;
+  private final boolean encode;
+  private DictionaryProvider dictionaryProvider;
   private Object[][] rows;
 
   Deserializer(ArrowColumnarBatchSerDe serDe) throws SerDeException {
-    this.serDe = serDe;
+    rowTypeInfo = serDe.rowTypeInfo;
+    encode = HiveConf.getBoolVar(serDe.conf, HiveConf.ConfVars.HIVE_ARROW_ENCODE);
+
     vectorExtractRow = new VectorExtractRow();
     final List<TypeInfo> fieldTypeInfoList = serDe.rowTypeInfo.getAllStructFieldTypeInfos();
     final int fieldCount = fieldTypeInfoList.size();
@@ -104,6 +112,7 @@ public Object deserialize(Writable writable) {
     final List<FieldVector> fieldVectors = vectorSchemaRoot.getFieldVectors();
     final int fieldCount = fieldVectors.size();
     final int rowCount = vectorSchemaRoot.getRowCount();
+    dictionaryProvider = arrowWrapperWritable.getDictionaryProvider();
     vectorizedRowBatch.ensureSize(rowCount);
     if (rows == null || rows.length < rowCount ) {
@@ -117,8 +126,8 @@ public Object deserialize(Writable writable) {
       final FieldVector fieldVector = fieldVectors.get(fieldIndex);
       final int projectedCol = vectorizedRowBatch.projectedColumns[fieldIndex];
       final ColumnVector columnVector = vectorizedRowBatch.cols[projectedCol];
-      final TypeInfo typeInfo = serDe.rowTypeInfo.getAllStructFieldTypeInfos().get(fieldIndex);
-      read(fieldVector, columnVector, typeInfo);
+      final TypeInfo typeInfo = rowTypeInfo.getAllStructFieldTypeInfos().get(fieldIndex);
+      read(fieldVector, columnVector, typeInfo, encode);
     }
     for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) {
       vectorExtractRow.extractRow(vectorizedRowBatch, rowIndex, rows[rowIndex]);
@@ -127,10 +136,10 @@ public Object deserialize(Writable writable) {
     return rows;
   }
 
-  private void read(FieldVector arrowVector, ColumnVector hiveVector, TypeInfo typeInfo) {
+  private void read(FieldVector arrowVector, ColumnVector hiveVector, TypeInfo typeInfo, boolean encodable) {
     switch (typeInfo.getCategory()) {
       case PRIMITIVE:
-        readPrimitive(arrowVector, hiveVector, typeInfo);
+        readPrimitive(arrowVector, hiveVector, typeInfo, encodable);
         break;
       case LIST:
         readList(arrowVector, (ListColumnVector) hiveVector, (ListTypeInfo) typeInfo);
@@ -149,7 +158,8 @@ private void read(FieldVector arrowVector, ColumnVector hiveVector, TypeInfo typ
     }
   }
 
-  private void readPrimitive(FieldVector arrowVector, ColumnVector hiveVector, TypeInfo typeInfo) {
+  private void readPrimitive(FieldVector arrowVector, ColumnVector hiveVector, TypeInfo typeInfo,
+      boolean encodable) {
     final PrimitiveObjectInspector.PrimitiveCategory primitiveCategory =
         ((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory();
 
@@ -245,12 +255,20 @@ private void readPrimitive(FieldVector arrowVector, ColumnVector hiveVector, Typ
       case VARCHAR:
      case CHAR:
       {
+        final VarCharVector varCharVector;
+        if (encodable) {
+          final long id = arrowVector.getField().getDictionary().getId();
+          final Dictionary dictionary = dictionaryProvider.lookup(id);
+          varCharVector = (VarCharVector) DictionaryEncoder.decode(arrowVector, dictionary);
+        } else {
+          varCharVector = ((VarCharVector) arrowVector);
+        }
         for (int i = 0; i < size; i++) {
-          if (arrowVector.isNull(i)) {
+          if (varCharVector.isNull(i)) {
             VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i);
           } else {
             hiveVector.isNull[i] = false;
-            ((BytesColumnVector) hiveVector).setVal(i, ((VarCharVector) arrowVector).get(i));
+            ((BytesColumnVector) hiveVector).setVal(i, varCharVector.get(i));
           }
         }
       }
@@ -369,7 +387,7 @@ private void readList(FieldVector arrowVector, ListColumnVector hiveVector, List
     read(arrowVector.getChildrenFromFields().get(0), hiveVector.child,
-        typeInfo.getListElementTypeInfo());
+        typeInfo.getListElementTypeInfo(), false);
 
     for (int i = 0; i < size; i++) {
       if (arrowVector.isNull(i)) {
@@ -389,7 +407,7 @@ private void readMap(FieldVector arrowVector, MapColumnVector hiveVector, MapTyp
     final ListColumnVector mapStructListVector = toStructListVector(hiveVector);
     final StructColumnVector mapStructVector = (StructColumnVector) mapStructListVector.child;
 
-    read(arrowVector, mapStructListVector, mapStructListTypeInfo);
+    read(arrowVector, mapStructListVector, mapStructListTypeInfo, false);
 
     hiveVector.isRepeating = mapStructListVector.isRepeating;
     hiveVector.childCount = mapStructListVector.childCount;
@@ -406,7 +424,8 @@ private void readStruct(FieldVector arrowVector, StructColumnVector hiveVector,
     final List<TypeInfo> fieldTypeInfos = typeInfo.getAllStructFieldTypeInfos();
     final int fieldSize = arrowVector.getChildrenFromFields().size();
     for (int i = 0; i < fieldSize; i++) {
-      read(arrowVector.getChildrenFromFields().get(i), hiveVector.fields[i], fieldTypeInfos.get(i));
+      read(arrowVector.getChildrenFromFields().get(i), hiveVector.fields[i], fieldTypeInfos.get(i),
+          false);
     }
 
     for (int i = 0; i < size; i++) {
diff --git ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Serializer.java ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Serializer.java
index bd23011c93..e1f0b5f9eb 100644
--- ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Serializer.java
+++ ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Serializer.java
@@ -18,6 +18,7 @@
 package org.apache.hadoop.hive.ql.io.arrow;
 
 import io.netty.buffer.ArrowBuf;
+import org.apache.arrow.memory.BufferAllocator;
 import org.apache.arrow.vector.BigIntVector;
 import org.apache.arrow.vector.BitVector;
 import org.apache.arrow.vector.BitVectorHelper;
@@ -37,10 +38,15 @@
 import org.apache.arrow.vector.VectorSchemaRoot;
 import org.apache.arrow.vector.complex.ListVector;
 import org.apache.arrow.vector.complex.MapVector;
-import org.apache.arrow.vector.complex.NullableMapVector;
+import org.apache.arrow.vector.dictionary.Dictionary;
+import org.apache.arrow.vector.dictionary.DictionaryEncoder;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
 import org.apache.arrow.vector.types.Types;
 import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
+import org.apache.arrow.vector.types.pojo.Field;
 import org.apache.arrow.vector.types.pojo.FieldType;
+import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.hive.conf.HiveConf;
 import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector;
 import org.apache.hadoop.hive.ql.exec.vector.ColumnVector;
@@ -59,7 +65,6 @@
 import org.apache.hadoop.hive.serde2.SerDeException;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
-import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
 import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
@@ -69,8 +74,12 @@
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.UnionTypeInfo;
 
+import java.nio.ByteBuffer;
 import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashSet;
 import java.util.List;
+import java.util.Set;
 
 import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.HIVE_ARROW_BATCH_SIZE;
 import static org.apache.hadoop.hive.ql.exec.vector.VectorizedBatchUtil.createColumnVector;
@@ -94,20 +103,25 @@
   private final VectorAssignRow vectorAssignRow;
   private int batchSize;
-  private final NullableMapVector rootVector;
+  private final Configuration conf;
+  private final BufferAllocator bufferAllocator;
+  private final boolean encode;
+  private long dictionaryId;
+  private DictionaryProvider.MapDictionaryProvider dictionaryProvider;
 
   Serializer(ArrowColumnarBatchSerDe serDe) throws SerDeException {
     MAX_BUFFERED_ROWS = HiveConf.getIntVar(serDe.conf, HIVE_ARROW_BATCH_SIZE);
-    ArrowColumnarBatchSerDe.LOG.info("ArrowColumnarBatchSerDe max number of buffered columns: " + MAX_BUFFERED_ROWS);
+    ArrowColumnarBatchSerDe.LOG.info("ArrowColumnarBatchSerDe max number of buffered columns: " +
+        MAX_BUFFERED_ROWS);
+    conf = serDe.conf;
+    bufferAllocator = serDe.rootAllocator;
+    encode = HiveConf.getBoolVar(conf, HiveConf.ConfVars.HIVE_ARROW_ENCODE);
 
     // Schema
     structTypeInfo = (StructTypeInfo) getTypeInfoFromObjectInspector(serDe.rowObjectInspector);
     List<TypeInfo> fieldTypeInfos = structTypeInfo.getAllStructFieldTypeInfos();
     fieldSize = fieldTypeInfos.size();
 
-    // Init Arrow stuffs
-    rootVector = NullableMapVector.empty(null, serDe.rootAllocator);
-
     // Init Hive stuffs
     vectorizedRowBatch = new VectorizedRowBatch(fieldSize);
     for (int fieldIndex = 0; fieldIndex < fieldSize; fieldIndex++) {
@@ -125,29 +139,68 @@
   }
 
   private ArrowWrapperWritable serializeBatch() {
-    rootVector.setValueCount(0);
+    dictionaryProvider = new DictionaryProvider.MapDictionaryProvider();
+    final List<Field> fieldList = new ArrayList<>();
+    final List<FieldVector> vectorList = new ArrayList<>();
 
     for (int fieldIndex = 0; fieldIndex < vectorizedRowBatch.projectionSize; fieldIndex++) {
       final int projectedColumn = vectorizedRowBatch.projectedColumns[fieldIndex];
       final ColumnVector hiveVector = vectorizedRowBatch.cols[projectedColumn];
       final TypeInfo fieldTypeInfo = structTypeInfo.getAllStructFieldTypeInfos().get(fieldIndex);
       final String fieldName = structTypeInfo.getAllStructFieldNames().get(fieldIndex);
-      final FieldType fieldType = toFieldType(fieldTypeInfo);
-      final FieldVector arrowVector = rootVector.addOrGet(fieldName, fieldType, FieldVector.class);
+      final Field field = toField(fieldName, fieldTypeInfo);
+      final FieldVector arrowVector = field.createVector(bufferAllocator);
      arrowVector.setInitialCapacity(batchSize);
       arrowVector.allocateNew();
-      write(arrowVector, hiveVector, fieldTypeInfo, batchSize);
+      vectorList.add(write(arrowVector, hiveVector, fieldTypeInfo, batchSize, encode));
+      fieldList.add(arrowVector.getField());
+      arrowVector.setValueCount(batchSize);
     }
     vectorizedRowBatch.reset();
-    rootVector.setValueCount(batchSize);
+    VectorSchemaRoot vectorSchemaRoot = new VectorSchemaRoot(fieldList, vectorList, batchSize);
     batchSize = 0;
-    VectorSchemaRoot vectorSchemaRoot = new VectorSchemaRoot(rootVector);
-    return new ArrowWrapperWritable(vectorSchemaRoot);
+    return new ArrowWrapperWritable(vectorSchemaRoot, dictionaryProvider);
+  }
+
+  private Field toField(String fieldName, TypeInfo typeInfo) {
+    switch (typeInfo.getCategory()) {
+      case PRIMITIVE:
+        return new Field(fieldName, toFieldType(typeInfo), null);
+      case MAP:
+        return toField(fieldName, toStructListTypeInfo((MapTypeInfo) typeInfo));
+      case LIST:
+      {
+        final ListTypeInfo listTypeInfo = (ListTypeInfo) typeInfo;
+        final Field elementField = toField(null, listTypeInfo.getListElementTypeInfo());
+        final List<Field> children = Collections.singletonList(elementField);
+        return new Field(fieldName, toFieldType(typeInfo), children);
+      }
+      case STRUCT:
+      {
+        final List<Field> children = new ArrayList<>();
+        final StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo;
+        final List<String> fieldNames = structTypeInfo.getAllStructFieldNames();
+        final List<TypeInfo> fieldTypeInfos = structTypeInfo.getAllStructFieldTypeInfos();
+        for (int i = 0; i < fieldNames.size(); i++) {
+          children.add(toField(fieldNames.get(i), fieldTypeInfos.get(i)));
+        }
+        return new Field(fieldName, toFieldType(typeInfo), children);
+      }
+      case UNION:
+      default:
+        throw new IllegalArgumentException();
+    }
   }
 
   private FieldType toFieldType(TypeInfo typeInfo) {
-    return new FieldType(true, toArrowType(typeInfo), null);
+    final DictionaryEncoding dictionaryEncoding;
+    if (encode) {
+      dictionaryEncoding = new DictionaryEncoding(dictionaryId++, false, null);
+    } else {
+      dictionaryEncoding = null;
+    }
+    return new FieldType(true, toArrowType(typeInfo), dictionaryEncoding);
   }
 
   private ArrowType toArrowType(TypeInfo typeInfo) {
@@ -203,34 +256,34 @@ private ArrowType toArrowType(TypeInfo typeInfo) {
     }
   }
 
-  private void write(FieldVector arrowVector, ColumnVector hiveVector, TypeInfo typeInfo, int size) {
+  @SuppressWarnings("unchecked")
+  private FieldVector write(FieldVector arrowVector, ColumnVector hiveVector, TypeInfo typeInfo,
+      int size, boolean encodable) {
     switch (typeInfo.getCategory()) {
       case PRIMITIVE:
-        writePrimitive(arrowVector, hiveVector, typeInfo, size);
-        break;
+        return writePrimitive(arrowVector, hiveVector, size, encodable);
       case LIST:
-        writeList((ListVector) arrowVector, (ListColumnVector) hiveVector, (ListTypeInfo) typeInfo, size);
-        break;
+        return writeList((ListVector) arrowVector, (ListColumnVector) hiveVector,
+            (ListTypeInfo) typeInfo, size);
      case STRUCT:
-        writeStruct((MapVector) arrowVector, (StructColumnVector) hiveVector, (StructTypeInfo) typeInfo, size);
-        break;
+        return writeStruct((MapVector) arrowVector, (StructColumnVector) hiveVector,
+            (StructTypeInfo) typeInfo, size);
       case UNION:
-        writeUnion(arrowVector, hiveVector, typeInfo, size);
-        break;
+        return writeUnion(arrowVector, hiveVector, typeInfo, size);
       case MAP:
-        writeMap((ListVector) arrowVector, (MapColumnVector) hiveVector, (MapTypeInfo) typeInfo, size);
-        break;
+        return writeMap((ListVector) arrowVector, (MapColumnVector) hiveVector,
+            (MapTypeInfo) typeInfo, size);
       default:
        throw new IllegalArgumentException();
     }
   }
 
-  private void writeMap(ListVector arrowVector, MapColumnVector hiveVector, MapTypeInfo typeInfo,
-      int size) {
+  private FieldVector writeMap(ListVector arrowVector, MapColumnVector hiveVector,
+      MapTypeInfo typeInfo, int size) {
     final ListTypeInfo structListTypeInfo = toStructListTypeInfo(typeInfo);
     final ListColumnVector structListVector = toStructListVector(hiveVector);
 
-    write(arrowVector, structListVector, structListTypeInfo, size);
+    write(arrowVector, structListVector, structListTypeInfo, size, false);
 
     final ArrowBuf validityBuffer = arrowVector.getValidityBuffer();
     for (int rowIndex = 0; rowIndex < size; rowIndex++) {
@@ -240,10 +293,12 @@ private void writeMap(ListVector arrowVector, MapColumnVector hiveVector, MapTyp
         BitVectorHelper.setValidityBitToOne(validityBuffer, rowIndex);
       }
     }
+
+    return arrowVector;
   }
 
-  private void writeUnion(FieldVector arrowVector, ColumnVector hiveVector, TypeInfo typeInfo,
-      int size) {
+  private FieldVector writeUnion(FieldVector arrowVector, ColumnVector hiveVector,
+      TypeInfo typeInfo, int size) {
     final UnionTypeInfo unionTypeInfo = (UnionTypeInfo) typeInfo;
     final List<TypeInfo> objectTypeInfos = unionTypeInfo.getAllUnionObjectTypeInfos();
     final UnionColumnVector hiveUnionVector = (UnionColumnVector) hiveVector;
@@ -253,10 +308,13 @@ private void writeUnion(FieldVector arrowVector, ColumnVector hiveVector, TypeIn
     final ColumnVector hiveObjectVector = hiveObjectVectors[tag];
     final TypeInfo objectTypeInfo = objectTypeInfos.get(tag);
 
-    write(arrowVector, hiveObjectVector, objectTypeInfo, size);
+    write(arrowVector, hiveObjectVector, objectTypeInfo, size, false);
+
+    return arrowVector;
   }
 
-  private void writeStruct(MapVector arrowVector, StructColumnVector hiveVector,
+  @SuppressWarnings("unchecked")
+  private FieldVector writeStruct(MapVector arrowVector, StructColumnVector hiveVector,
       StructTypeInfo typeInfo, int size) {
     final List<String> fieldNames = typeInfo.getAllStructFieldNames();
     final List<TypeInfo> fieldTypeInfos = typeInfo.getAllStructFieldTypeInfos();
@@ -267,12 +325,11 @@ private void writeStruct(MapVector arrowVector, StructColumnVector hiveVector,
       final TypeInfo fieldTypeInfo = fieldTypeInfos.get(fieldIndex);
       final ColumnVector hiveFieldVector = hiveFieldVectors[fieldIndex];
       final String fieldName = fieldNames.get(fieldIndex);
-      final FieldVector arrowFieldVector =
-          arrowVector.addOrGet(fieldName,
-              toFieldType(fieldTypeInfos.get(fieldIndex)), FieldVector.class);
+      final FieldVector arrowFieldVector = arrowVector.addOrGet(fieldName,
+          toFieldType(fieldTypeInfos.get(fieldIndex)), FieldVector.class);
       arrowFieldVector.setInitialCapacity(size);
       arrowFieldVector.allocateNew();
-      write(arrowFieldVector, hiveFieldVector, fieldTypeInfo, size);
+      write(arrowFieldVector, hiveFieldVector, fieldTypeInfo, size, false);
     }
 
     final ArrowBuf validityBuffer = arrowVector.getValidityBuffer();
@@ -283,10 +340,12 @@ private void writeStruct(MapVector arrowVector, StructColumnVector hiveVector,
         BitVectorHelper.setValidityBitToOne(validityBuffer, rowIndex);
       }
     }
+
+    return (FieldVector) arrowVector;
   }
 
-  private void writeList(ListVector arrowVector, ListColumnVector hiveVector, ListTypeInfo typeInfo,
-      int size) {
+  private FieldVector writeList(ListVector arrowVector, ListColumnVector hiveVector,
+      ListTypeInfo typeInfo, int size) {
     final int OFFSET_WIDTH = 4;
     final TypeInfo elementTypeInfo = typeInfo.getListElementTypeInfo();
     final ColumnVector hiveElementVector = hiveVector.child;
@@ -295,7 +354,7 @@ private void writeList(ListVector arrowVector, ListColumnVector hiveVector, List
     arrowElementVector.setInitialCapacity(hiveVector.childCount);
     arrowElementVector.allocateNew();
 
-    write(arrowElementVector, hiveElementVector, elementTypeInfo, hiveVector.childCount);
+    write(arrowElementVector, hiveElementVector, elementTypeInfo, hiveVector.childCount, false);
 
     final ArrowBuf offsetBuffer = arrowVector.getOffsetBuffer();
     int nextOffset = 0;
@@ -310,14 +369,13 @@ private void writeList(ListVector arrowVector, ListColumnVector hiveVector, List
       }
     }
     offsetBuffer.setInt(size * OFFSET_WIDTH, nextOffset);
+    return arrowVector;
   }
 
-  private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, TypeInfo typeInfo,
-      int size) {
-    final PrimitiveObjectInspector.PrimitiveCategory primitiveCategory =
-        ((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory();
-    switch (primitiveCategory) {
-      case BOOLEAN:
+  private FieldVector writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, int size,
+      boolean encodable) {
+    switch (arrowVector.getMinorType()) {
+      case BIT:
       {
         final BitVector bitVector = (BitVector) arrowVector;
         for (int i = 0; i < size; i++) {
@@ -327,9 +385,9 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
             bitVector.set(i, (int) ((LongColumnVector) hiveVector).vector[i]);
           }
         }
+        return bitVector;
       }
-      break;
-      case BYTE:
+      case TINYINT:
       {
         final TinyIntVector tinyIntVector = (TinyIntVector) arrowVector;
         for (int i = 0; i < size; i++) {
@@ -339,9 +397,9 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
             tinyIntVector.set(i, (byte) ((LongColumnVector) hiveVector).vector[i]);
           }
         }
+        return tinyIntVector;
       }
-      break;
-      case SHORT:
+      case SMALLINT:
       {
         final SmallIntVector smallIntVector = (SmallIntVector) arrowVector;
         for (int i = 0; i < size; i++) {
@@ -351,8 +409,8 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
             smallIntVector.set(i, (short) ((LongColumnVector) hiveVector).vector[i]);
           }
         }
+        return smallIntVector;
       }
-      break;
       case INT:
       {
         final IntVector intVector = (IntVector) arrowVector;
@@ -363,9 +421,9 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
             intVector.set(i, (int) ((LongColumnVector) hiveVector).vector[i]);
           }
         }
+        return intVector;
       }
-      break;
-      case LONG:
+      case BIGINT:
       {
         final BigIntVector bigIntVector = (BigIntVector) arrowVector;
         for (int i = 0; i < size; i++) {
@@ -375,9 +433,9 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
             bigIntVector.set(i, ((LongColumnVector) hiveVector).vector[i]);
           }
         }
+        return bigIntVector;
       }
-      break;
-      case FLOAT:
+      case FLOAT4:
       {
         final Float4Vector float4Vector = (Float4Vector) arrowVector;
         for (int i = 0; i < size; i++) {
@@ -387,9 +445,9 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
             float4Vector.set(i, (float) ((DoubleColumnVector) hiveVector).vector[i]);
           }
         }
+        return float4Vector;
       }
-      break;
-      case DOUBLE:
+      case FLOAT8:
      {
         final Float8Vector float8Vector = (Float8Vector) arrowVector;
         for (int i = 0; i < size; i++) {
@@ -399,11 +457,9 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
             float8Vector.set(i, ((DoubleColumnVector) hiveVector).vector[i]);
           }
         }
+        return float8Vector;
       }
-      break;
-      case STRING:
       case VARCHAR:
-      case CHAR:
       {
         final VarCharVector varCharVector = (VarCharVector) arrowVector;
         final BytesColumnVector bytesVector = (BytesColumnVector) hiveVector;
@@ -414,9 +470,41 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
             varCharVector.setSafe(i, bytesVector.vector[i], bytesVector.start[i], bytesVector.length[i]);
           }
         }
+        varCharVector.setValueCount(size);
+
+        if (encodable) {
+          int j = 0;
+          final Set<ByteBuffer> occurrences = new HashSet<>();
+          for (int i = 0; i < size; i++) {
+            if (!bytesVector.isNull[i]) {
+              final ByteBuffer byteBuffer =
+                  ByteBuffer.wrap(bytesVector.vector[j], bytesVector.start[j], bytesVector.length[j]);
+              if (!occurrences.contains(byteBuffer)) {
+                occurrences.add(byteBuffer);
+              }
+              j++;
+            }
+          }
+          final FieldType fieldType = arrowVector.getField().getFieldType();
+          final VarCharVector dictionaryVector = (VarCharVector)
+              Types.MinorType.VARCHAR.getNewVector(null, fieldType, bufferAllocator, null);
+          j = 0;
+          for (ByteBuffer occurrence : occurrences) {
+            final int start = occurrence.position();
+            final int length = occurrence.limit() - start;
+            dictionaryVector.setSafe(j++, occurrence.array(), start, length);
+          }
+          dictionaryVector.setValueCount(occurrences.size());
+          varCharVector.setValueCount(size);
+          final DictionaryEncoding dictionaryEncoding = arrowVector.getField().getDictionary();
+          final Dictionary dictionary = new Dictionary(dictionaryVector, dictionaryEncoding);
+          dictionaryProvider.put(dictionary);
+          return (FieldVector) DictionaryEncoder.encode(varCharVector, dictionary);
+        } else {
+          return varCharVector;
+        }
       }
-      break;
-      case DATE:
+      case DATEDAY:
       {
         final DateDayVector dateDayVector = (DateDayVector) arrowVector;
         for (int i = 0; i < size; i++) {
@@ -426,9 +514,9 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
             dateDayVector.set(i, (int) ((LongColumnVector) hiveVector).vector[i]);
           }
         }
+        return dateDayVector;
       }
-      break;
-      case TIMESTAMP:
+      case TIMESTAMPNANO:
       {
         final TimeStampNanoVector timeStampNanoVector = (TimeStampNanoVector) arrowVector;
         final TimestampColumnVector timestampColumnVector = (TimestampColumnVector) hiveVector;
@@ -449,9 +537,9 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
             }
           }
         }
+        return timeStampNanoVector;
       }
-      break;
-      case BINARY:
+      case VARBINARY:
       {
         final VarBinaryVector varBinaryVector = (VarBinaryVector) arrowVector;
         final BytesColumnVector bytesVector = (BytesColumnVector) hiveVector;
@@ -462,8 +550,8 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
             varBinaryVector.setSafe(i, bytesVector.vector[i], bytesVector.start[i], bytesVector.length[i]);
           }
         }
+        return varBinaryVector;
       }
-      break;
       case DECIMAL:
       {
         final DecimalVector decimalVector = (DecimalVector) arrowVector;
@@ -472,13 +560,13 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
           if (hiveVector.isNull[i]) {
             decimalVector.setNull(i);
           } else {
-            decimalVector.set(i,
-                ((DecimalColumnVector) hiveVector).vector[i].getHiveDecimal().bigDecimalValue().setScale(scale));
+            decimalVector.set(i, ((DecimalColumnVector) hiveVector).vector[i].getHiveDecimal().
+                bigDecimalValue().setScale(scale));
           }
         }
+        return decimalVector;
       }
-      break;
-      case INTERVAL_YEAR_MONTH:
+      case INTERVALYEAR:
       {
         final IntervalYearVector intervalYearVector = (IntervalYearVector) arrowVector;
         for (int i = 0; i < size; i++) {
@@ -488,9 +576,9 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
             intervalYearVector.set(i, (int) ((LongColumnVector) hiveVector).vector[i]);
           }
         }
+        return intervalYearVector;
       }
-      break;
-      case INTERVAL_DAY_TIME:
+      case INTERVALDAY:
       {
         final IntervalDayVector intervalDayVector = (IntervalDayVector) arrowVector;
         final IntervalDayTimeColumnVector intervalDayTimeColumnVector =
@@ -501,17 +589,13 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
           } else {
             final long totalSeconds = intervalDayTimeColumnVector.getTotalSeconds(i);
             final long days = totalSeconds / SECOND_PER_DAY;
-            final long millis =
-                (totalSeconds - days * SECOND_PER_DAY) * MS_PER_SECOND +
-                    intervalDayTimeColumnVector.getNanos(i) / NS_PER_MS;
+            final long millis = (totalSeconds - days * SECOND_PER_DAY) * MS_PER_SECOND +
+                intervalDayTimeColumnVector.getNanos(i) / NS_PER_MS;
             intervalDayVector.set(i, (int) days, (int) millis);
           }
         }
+        return intervalDayVector;
       }
-      break;
-      case VOID:
-      case UNKNOWN:
-      case TIMESTAMPLOCALTZ:
       default:
         throw new IllegalArgumentException();
     }
diff --git ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowColumnarBatchSerDe.java ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowColumnarBatchSerDe.java
index 74f6624597..3f2847eeaf 100644
--- ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowColumnarBatchSerDe.java
+++ ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowColumnarBatchSerDe.java
@@ -26,6 +26,7 @@
 import org.apache.hadoop.hive.common.type.HiveIntervalDayTime;
 import org.apache.hadoop.hive.common.type.HiveIntervalYearMonth;
 import org.apache.hadoop.hive.common.type.HiveVarchar;
+import org.apache.hadoop.hive.conf.HiveConf;
 import org.apache.hadoop.hive.serde.serdeConstants;
 import org.apache.hadoop.hive.serde2.AbstractSerDe;
 import org.apache.hadoop.hive.serde2.SerDeException;
@@ -494,6 +495,27 @@ public void testPrimitiveString() throws SerDeException {
     initAndSerializeAndDeserialize(schema, STRING_ROWS);
   }
 
+  @Test
+  public void testPrimitiveEncodeString() throws SerDeException {
+    String[][] schema = {
+        {"string1", "string"},
+    };
+
+    HiveConf.setBoolVar(conf, HiveConf.ConfVars.HIVE_ARROW_ENCODE, true);
+
+    final Object[][] rows = {
+        {text("")},
+        {text("Hello")},
+        {text("Hello")},
+        {text("world!")},
+        {text("Hello")},
+        {text("world!")},
+        {text("world")},
+        {null},
+    };
+    initAndSerializeAndDeserialize(schema, rows);
+  }
+
   @Test
   public void testPrimitiveDTI() throws SerDeException {
     String[][] schema = {
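
For context, the following standalone sketch (not part of the patch) shows the Arrow dictionary round trip that Serializer and Deserializer perform when hive.arrow.encode is true: build a dictionary vector of the distinct values, encode the value vector into an index vector, and decode it back on the read side. The class name and sample values are illustrative; the Arrow calls are the same ones used above (Dictionary, DictionaryEncoding, DictionaryEncoder).

    import java.nio.charset.StandardCharsets;
    import org.apache.arrow.memory.RootAllocator;
    import org.apache.arrow.vector.ValueVector;
    import org.apache.arrow.vector.VarCharVector;
    import org.apache.arrow.vector.dictionary.Dictionary;
    import org.apache.arrow.vector.dictionary.DictionaryEncoder;
    import org.apache.arrow.vector.types.pojo.DictionaryEncoding;

    public class DictionaryRoundTripSketch {
      public static void main(String[] args) {
        try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE);
             VarCharVector values = new VarCharVector("string1", allocator);
             VarCharVector distinct = new VarCharVector("string1-dictionary", allocator)) {
          // Values with repetition, as in testPrimitiveEncodeString above.
          String[] rows = {"Hello", "Hello", "world!", "Hello"};
          values.allocateNew();
          for (int i = 0; i < rows.length; i++) {
            values.setSafe(i, rows[i].getBytes(StandardCharsets.UTF_8));
          }
          values.setValueCount(rows.length);

          // Dictionary of the distinct values, registered under dictionary id 0.
          distinct.allocateNew();
          distinct.setSafe(0, "Hello".getBytes(StandardCharsets.UTF_8));
          distinct.setSafe(1, "world!".getBytes(StandardCharsets.UTF_8));
          distinct.setValueCount(2);
          Dictionary dictionary = new Dictionary(distinct, new DictionaryEncoding(0L, false, null));

          // Encode produces an integer index vector; decode restores the original strings.
          try (ValueVector encoded = DictionaryEncoder.encode(values, dictionary);
               ValueVector decoded = DictionaryEncoder.decode(encoded, dictionary)) {
            for (int i = 0; i < decoded.getValueCount(); i++) {
              System.out.println(decoded.getObject(i)); // Hello, Hello, world!, Hello
            }
          }
        }
      }
    }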