diff --git common/src/java/org/apache/hadoop/hive/conf/HiveConf.java common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
index 3295d1dbc5..834afa1643 100644
--- common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
+++ common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
@@ -2634,6 +2634,7 @@ private static void populateLlapDaemonVarsSet(Set<String> llapDaemonVarsSetLocal
     HIVE_ARROW_ROOT_ALLOCATOR_LIMIT("hive.arrow.root.allocator.limit", Long.MAX_VALUE,
         "Arrow root allocator memory size limitation in bytes."),
     HIVE_ARROW_BATCH_SIZE("hive.arrow.batch.size", 1000, "The number of rows sent in one Arrow batch."),
+    HIVE_ARROW_ENCODE("hive.arrow.encode", false, "Set to true to dictionary-encode repeating string values."),

     // For Druid storage handler
     HIVE_DRUID_INDEXING_GRANULARITY("hive.druid.indexer.segments.granularity", "DAY",
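The new flag is an ordinary boolean ConfVar, so it can be toggled per session or per job. A minimal sketch of reading and writing it (variable names illustrative; the static accessors are the same ones the serde code below uses):

    HiveConf conf = new HiveConf();
    // Enable Arrow dictionary encoding for repeating string values.
    HiveConf.setBoolVar(conf, HiveConf.ConfVars.HIVE_ARROW_ENCODE, true);
    // Both Serializer and Deserializer read the flag the same way.
    boolean encode = HiveConf.getBoolVar(conf, HiveConf.ConfVars.HIVE_ARROW_ENCODE);

From SQL the equivalent is simply `SET hive.arrow.encode=true;`.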
diff --git ql/src/java/org/apache/hadoop/hive/ql/io/arrow/ArrowWrapperWritable.java ql/src/java/org/apache/hadoop/hive/ql/io/arrow/ArrowWrapperWritable.java
index dd490b1b90..4dd3ce02bd 100644
--- ql/src/java/org/apache/hadoop/hive/ql/io/arrow/ArrowWrapperWritable.java
+++ ql/src/java/org/apache/hadoop/hive/ql/io/arrow/ArrowWrapperWritable.java
@@ -19,6 +19,7 @@
 package org.apache.hadoop.hive.ql.io.arrow;

 import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
 import org.apache.hadoop.io.WritableComparable;

 import java.io.DataInput;
@@ -27,10 +28,18 @@
 public class ArrowWrapperWritable implements WritableComparable {
   private VectorSchemaRoot vectorSchemaRoot;
+  private DictionaryProvider dictionaryProvider;

   public ArrowWrapperWritable(VectorSchemaRoot vectorSchemaRoot) {
+    this(vectorSchemaRoot, null);
+  }
+
+  public ArrowWrapperWritable(VectorSchemaRoot vectorSchemaRoot,
+      DictionaryProvider dictionaryProvider) {
     this.vectorSchemaRoot = vectorSchemaRoot;
+    this.dictionaryProvider = dictionaryProvider;
   }
+
   public ArrowWrapperWritable() {}

   public VectorSchemaRoot getVectorSchemaRoot() {
@@ -41,6 +50,14 @@ public void setVectorSchemaRoot(VectorSchemaRoot vectorSchemaRoot) {
     this.vectorSchemaRoot = vectorSchemaRoot;
   }

+  public DictionaryProvider getDictionaryProvider() {
+    return dictionaryProvider;
+  }
+
+  public void setDictionaryProvider(DictionaryProvider dictionaryProvider) {
+    this.dictionaryProvider = dictionaryProvider;
+  }
+
   @Override
   public void write(DataOutput dataOutput) throws IOException {
     throw new UnsupportedOperationException();
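With this change the writable carries the dictionaries alongside the batch, which any consumer of dictionary-encoded data needs. A hedged sketch of a downstream consumer; Arrow's ArrowStreamWriter has a constructor taking a DictionaryProvider (it lives under org.apache.arrow.vector.ipc in recent Arrow releases; the channel variable is illustrative):

    VectorSchemaRoot root = writable.getVectorSchemaRoot();
    DictionaryProvider provider = writable.getDictionaryProvider(); // null when encoding is off
    try (ArrowStreamWriter writer = new ArrowStreamWriter(root, provider, channel)) {
      writer.start();
      writer.writeBatch(); // dictionary batches are emitted before the record batch
      writer.end();
    }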
diff --git ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Deserializer.java ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Deserializer.java
index fb5800b140..0739321047 100644
--- ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Deserializer.java
+++ ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Deserializer.java
@@ -34,9 +34,13 @@
 import org.apache.arrow.vector.VarBinaryVector;
 import org.apache.arrow.vector.VarCharVector;
 import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.dictionary.Dictionary;
+import org.apache.arrow.vector.dictionary.DictionaryEncoder;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
 import org.apache.arrow.vector.holders.NullableIntervalDayHolder;
 import org.apache.hadoop.hive.common.type.HiveDecimal;
 import org.apache.hadoop.hive.common.type.HiveIntervalDayTime;
+import org.apache.hadoop.hive.conf.HiveConf;
 import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector;
 import org.apache.hadoop.hive.ql.exec.vector.ColumnVector;
 import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector;
@@ -73,13 +77,17 @@
 import static org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe.toStructListVector;

 class Deserializer {
-  private final ArrowColumnarBatchSerDe serDe;
   private final VectorExtractRow vectorExtractRow;
   private final VectorizedRowBatch vectorizedRowBatch;
+  private final StructTypeInfo rowTypeInfo;
+  private final boolean encode;
+  private DictionaryProvider dictionaryProvider;
   private Object[][] rows;

   Deserializer(ArrowColumnarBatchSerDe serDe) throws SerDeException {
-    this.serDe = serDe;
+    rowTypeInfo = serDe.rowTypeInfo;
+    encode = HiveConf.getBoolVar(serDe.conf, HiveConf.ConfVars.HIVE_ARROW_ENCODE);
+
     vectorExtractRow = new VectorExtractRow();
     final List<TypeInfo> fieldTypeInfoList = serDe.rowTypeInfo.getAllStructFieldTypeInfos();
     final int fieldCount = fieldTypeInfoList.size();
@@ -104,6 +112,7 @@ public Object deserialize(Writable writable) {
     final List<FieldVector> fieldVectors = vectorSchemaRoot.getFieldVectors();
     final int fieldCount = fieldVectors.size();
     final int rowCount = vectorSchemaRoot.getRowCount();
+    dictionaryProvider = arrowWrapperWritable.getDictionaryProvider();

     vectorizedRowBatch.ensureSize(rowCount);
     if (rows == null || rows.length < rowCount) {
@@ -117,8 +126,8 @@ public Object deserialize(Writable writable) {
       final FieldVector fieldVector = fieldVectors.get(fieldIndex);
       final int projectedCol = vectorizedRowBatch.projectedColumns[fieldIndex];
       final ColumnVector columnVector = vectorizedRowBatch.cols[projectedCol];
-      final TypeInfo typeInfo = serDe.rowTypeInfo.getAllStructFieldTypeInfos().get(fieldIndex);
-      read(fieldVector, columnVector, typeInfo);
+      final TypeInfo typeInfo = rowTypeInfo.getAllStructFieldTypeInfos().get(fieldIndex);
+      read(fieldVector, columnVector, typeInfo, encode);
     }
     for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) {
       vectorExtractRow.extractRow(vectorizedRowBatch, rowIndex, rows[rowIndex]);
@@ -127,19 +136,21 @@ public Object deserialize(Writable writable) {
     return rows;
   }

-  private void read(FieldVector arrowVector, ColumnVector hiveVector, TypeInfo typeInfo) {
+  private void read(FieldVector arrowVector, ColumnVector hiveVector, TypeInfo typeInfo,
+      boolean encode) {
     switch (typeInfo.getCategory()) {
       case PRIMITIVE:
-        readPrimitive(arrowVector, hiveVector, typeInfo);
+        readPrimitive(arrowVector, hiveVector, typeInfo, encode);
         break;
       case LIST:
-        readList(arrowVector, (ListColumnVector) hiveVector, (ListTypeInfo) typeInfo);
+        readList(arrowVector, (ListColumnVector) hiveVector, (ListTypeInfo) typeInfo, encode);
         break;
       case MAP:
-        readMap(arrowVector, (MapColumnVector) hiveVector, (MapTypeInfo) typeInfo);
+        readMap(arrowVector, (MapColumnVector) hiveVector, (MapTypeInfo) typeInfo, encode);
         break;
       case STRUCT:
-        readStruct(arrowVector, (StructColumnVector) hiveVector, (StructTypeInfo) typeInfo);
+        readStruct(arrowVector, (StructColumnVector) hiveVector, (StructTypeInfo) typeInfo,
+            encode);
         break;
       case UNION:
         readUnion(arrowVector, (UnionColumnVector) hiveVector, (UnionTypeInfo) typeInfo);
@@ -149,7 +160,8 @@ private void read(FieldVector arrowVector, ColumnVector hiveVector, TypeInfo typ
     }
   }

-  private void readPrimitive(FieldVector arrowVector, ColumnVector hiveVector, TypeInfo typeInfo) {
+  private void readPrimitive(FieldVector arrowVector, ColumnVector hiveVector, TypeInfo typeInfo,
+      boolean encode) {
     final PrimitiveObjectInspector.PrimitiveCategory primitiveCategory =
         ((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory();
@@ -245,12 +257,20 @@ private void readPrimitive(FieldVector arrowVector, ColumnVector hiveVector, Typ
       case VARCHAR:
       case CHAR:
       {
+        final VarCharVector varCharVector;
+        if (encode) {
+          final long id = arrowVector.getField().getDictionary().getId();
+          final Dictionary dictionary = dictionaryProvider.lookup(id);
+          varCharVector = (VarCharVector) DictionaryEncoder.decode(arrowVector, dictionary);
+        } else {
+          varCharVector = ((VarCharVector) arrowVector);
+        }
         for (int i = 0; i < size; i++) {
-          if (arrowVector.isNull(i)) {
+          if (varCharVector.isNull(i)) {
             VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i);
           } else {
             hiveVector.isNull[i] = false;
-            ((BytesColumnVector) hiveVector).setVal(i, ((VarCharVector) arrowVector).get(i));
+            ((BytesColumnVector) hiveVector).setVal(i, varCharVector.get(i));
           }
         }
       }
@@ -362,14 +382,14 @@ private void readPrimitive(FieldVector arrowVector, ColumnVector hiveVector, Typ
     }
   }

-  private void readList(FieldVector arrowVector, ListColumnVector hiveVector, ListTypeInfo typeInfo) {
+  private void readList(FieldVector arrowVector, ListColumnVector hiveVector, ListTypeInfo typeInfo,
+      boolean encode) {
     final int size = arrowVector.getValueCount();
     final ArrowBuf offsets = arrowVector.getOffsetBuffer();
     final int OFFSET_WIDTH = 4;

-    read(arrowVector.getChildrenFromFields().get(0),
-        hiveVector.child,
-        typeInfo.getListElementTypeInfo());
+    read(arrowVector.getChildrenFromFields().get(0), hiveVector.child,
+        typeInfo.getListElementTypeInfo(), encode);

     for (int i = 0; i < size; i++) {
       if (arrowVector.isNull(i)) {
@@ -383,13 +403,14 @@ private void readList(FieldVector arrowVector, ListColumnVector hiveVector, List
     }
   }

-  private void readMap(FieldVector arrowVector, MapColumnVector hiveVector, MapTypeInfo typeInfo) {
+  private void readMap(FieldVector arrowVector, MapColumnVector hiveVector, MapTypeInfo typeInfo,
+      boolean encode) {
     final int size = arrowVector.getValueCount();
     final ListTypeInfo mapStructListTypeInfo = toStructListTypeInfo(typeInfo);
     final ListColumnVector mapStructListVector = toStructListVector(hiveVector);
     final StructColumnVector mapStructVector = (StructColumnVector) mapStructListVector.child;

-    read(arrowVector, mapStructListVector, mapStructListTypeInfo);
+    read(arrowVector, mapStructListVector, mapStructListTypeInfo, encode);

     hiveVector.isRepeating = mapStructListVector.isRepeating;
     hiveVector.childCount = mapStructListVector.childCount;
@@ -401,12 +422,14 @@ private void readMap(FieldVector arrowVector, MapColumnVector hiveVector, MapTyp
     System.arraycopy(mapStructListVector.isNull, 0, hiveVector.isNull, 0, size);
   }

-  private void readStruct(FieldVector arrowVector, StructColumnVector hiveVector, StructTypeInfo typeInfo) {
+  private void readStruct(FieldVector arrowVector, StructColumnVector hiveVector,
+      StructTypeInfo typeInfo, boolean encode) {
     final int size = arrowVector.getValueCount();
     final List<TypeInfo> fieldTypeInfos = typeInfo.getAllStructFieldTypeInfos();
     final int fieldSize = arrowVector.getChildrenFromFields().size();
     for (int i = 0; i < fieldSize; i++) {
-      read(arrowVector.getChildrenFromFields().get(i), hiveVector.fields[i], fieldTypeInfos.get(i));
+      read(arrowVector.getChildrenFromFields().get(i), hiveVector.fields[i], fieldTypeInfos.get(i),
+          encode);
     }
     for (int i = 0; i < size; i++) {
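On the decode side, the dictionary id recorded in the field's metadata is the key into the provider shipped inside the ArrowWrapperWritable. A minimal standalone sketch of that lookup, assuming `encoded` is the index vector received for a dictionary-encoded column and `dictionaryProvider` came from the writable:

    // The field metadata carries the DictionaryEncoding assigned at write time.
    long id = encoded.getField().getDictionary().getId();
    Dictionary dictionary = dictionaryProvider.lookup(id);
    // decode() allocates and fills a plain VarCharVector from indices + dictionary;
    // the caller owns the returned vector and should eventually close() it.
    VarCharVector decoded = (VarCharVector) DictionaryEncoder.decode(encoded, dictionary);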
diff --git ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Serializer.java ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Serializer.java
index bd23011c93..9a310a6dc0 100644
--- ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Serializer.java
+++ ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Serializer.java
@@ -18,6 +18,7 @@
 package org.apache.hadoop.hive.ql.io.arrow;

 import io.netty.buffer.ArrowBuf;
+import org.apache.arrow.memory.BufferAllocator;
 import org.apache.arrow.vector.BigIntVector;
 import org.apache.arrow.vector.BitVector;
 import org.apache.arrow.vector.BitVectorHelper;
@@ -37,10 +38,15 @@
 import org.apache.arrow.vector.VectorSchemaRoot;
 import org.apache.arrow.vector.complex.ListVector;
 import org.apache.arrow.vector.complex.MapVector;
-import org.apache.arrow.vector.complex.NullableMapVector;
+import org.apache.arrow.vector.dictionary.Dictionary;
+import org.apache.arrow.vector.dictionary.DictionaryEncoder;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
 import org.apache.arrow.vector.types.Types;
 import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
+import org.apache.arrow.vector.types.pojo.Field;
 import org.apache.arrow.vector.types.pojo.FieldType;
+import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.hive.conf.HiveConf;
 import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector;
 import org.apache.hadoop.hive.ql.exec.vector.ColumnVector;
@@ -67,10 +73,15 @@
 import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
 import org.apache.hadoop.hive.serde2.typeinfo.UnionTypeInfo;

+import java.nio.ByteBuffer;
 import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashSet;
 import java.util.List;
+import java.util.Set;

 import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.HIVE_ARROW_BATCH_SIZE;
 import static org.apache.hadoop.hive.ql.exec.vector.VectorizedBatchUtil.createColumnVector;
@@ -94,20 +105,25 @@
   private final VectorAssignRow vectorAssignRow;
   private int batchSize;
-  private final NullableMapVector rootVector;
+  private final Configuration conf;
+  private final BufferAllocator bufferAllocator;
+  private final boolean encode;
+  private long dictionaryId;
+  private DictionaryProvider.MapDictionaryProvider dictionaryProvider;

   Serializer(ArrowColumnarBatchSerDe serDe) throws SerDeException {
     MAX_BUFFERED_ROWS = HiveConf.getIntVar(serDe.conf, HIVE_ARROW_BATCH_SIZE);
-    ArrowColumnarBatchSerDe.LOG.info("ArrowColumnarBatchSerDe max number of buffered columns: " + MAX_BUFFERED_ROWS);
+    ArrowColumnarBatchSerDe.LOG.info("ArrowColumnarBatchSerDe max number of buffered rows: " +
+        MAX_BUFFERED_ROWS);
+    conf = serDe.conf;
+    bufferAllocator = serDe.rootAllocator;
+    encode = HiveConf.getBoolVar(conf, HiveConf.ConfVars.HIVE_ARROW_ENCODE);

     // Schema
     structTypeInfo = (StructTypeInfo) getTypeInfoFromObjectInspector(serDe.rowObjectInspector);
     List<TypeInfo> fieldTypeInfos = structTypeInfo.getAllStructFieldTypeInfos();
     fieldSize = fieldTypeInfos.size();

-    // Init Arrow stuffs
-    rootVector = NullableMapVector.empty(null, serDe.rootAllocator);
-
     // Init Hive stuffs
     vectorizedRowBatch = new VectorizedRowBatch(fieldSize);
     for (int fieldIndex = 0; fieldIndex < fieldSize; fieldIndex++) {
@@ -125,28 +141,77 @@
   }

   private ArrowWrapperWritable serializeBatch() {
-    rootVector.setValueCount(0);
+    dictionaryProvider = new DictionaryProvider.MapDictionaryProvider();
+    final List<Field> fieldList = new ArrayList<>();
+    final List<FieldVector> vectorList = new ArrayList<>();

     for (int fieldIndex = 0; fieldIndex < vectorizedRowBatch.projectionSize; fieldIndex++) {
       final int projectedColumn = vectorizedRowBatch.projectedColumns[fieldIndex];
       final ColumnVector hiveVector = vectorizedRowBatch.cols[projectedColumn];
       final TypeInfo fieldTypeInfo = structTypeInfo.getAllStructFieldTypeInfos().get(fieldIndex);
       final String fieldName = structTypeInfo.getAllStructFieldNames().get(fieldIndex);
-      final FieldType fieldType = toFieldType(fieldTypeInfo);
-      final FieldVector arrowVector = rootVector.addOrGet(fieldName, fieldType, FieldVector.class);
+      final Field field = toField(fieldName, fieldTypeInfo);
+      final FieldVector arrowVector = field.createVector(bufferAllocator);
       arrowVector.setInitialCapacity(batchSize);
       arrowVector.allocateNew();
-      write(arrowVector, hiveVector, fieldTypeInfo, batchSize);
+      vectorList.add(write(arrowVector, hiveVector, fieldTypeInfo, batchSize, encode));
+      fieldList.add(arrowVector.getField());
+      arrowVector.setValueCount(batchSize);
     }
     vectorizedRowBatch.reset();
-    rootVector.setValueCount(batchSize);
+    VectorSchemaRoot vectorSchemaRoot = new VectorSchemaRoot(fieldList, vectorList, batchSize);

     batchSize = 0;
-    VectorSchemaRoot vectorSchemaRoot = new VectorSchemaRoot(rootVector);
-    return new ArrowWrapperWritable(vectorSchemaRoot);
+    return new ArrowWrapperWritable(vectorSchemaRoot, dictionaryProvider);
+  }
+
+  private Field toField(String fieldName, TypeInfo typeInfo) {
+    switch (typeInfo.getCategory()) {
+      case PRIMITIVE:
+        return new Field(fieldName, toFieldType(typeInfo), null);
+      case MAP:
+        return toField(fieldName, toStructListTypeInfo((MapTypeInfo) typeInfo));
+      case LIST:
+      {
+        final ListTypeInfo listTypeInfo = (ListTypeInfo) typeInfo;
+        final Field elementField = toField(null, listTypeInfo.getListElementTypeInfo());
+        final List<Field> children = Collections.singletonList(elementField);
+        return new Field(fieldName, toFieldType(typeInfo), children);
+      }
+      case STRUCT:
+      {
+        final List<Field> children = new ArrayList<>();
+        final StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo;
+        final List<String> fieldNames = structTypeInfo.getAllStructFieldNames();
+        final List<TypeInfo> fieldTypeInfos = structTypeInfo.getAllStructFieldTypeInfos();
+        for (int i = 0; i < fieldNames.size(); i++) {
+          children.add(toField(fieldNames.get(i), fieldTypeInfos.get(i)));
+        }
+        return new Field(fieldName, toFieldType(typeInfo), children);
+      }
+      case UNION:
+      default:
+        throw new IllegalArgumentException();
+    }
   }

   private FieldType toFieldType(TypeInfo typeInfo) {
+    if (encode) {
+      if (typeInfo.getCategory() == ObjectInspector.Category.PRIMITIVE) {
+        final PrimitiveTypeInfo primitiveTypeInfo = (PrimitiveTypeInfo) typeInfo;
+        final PrimitiveObjectInspector.PrimitiveCategory primitiveCategory =
+            primitiveTypeInfo.getPrimitiveCategory();
+        switch (primitiveCategory) {
+          case VARCHAR:
+          case CHAR:
+          case STRING:
+          {
+            return new FieldType(true, toArrowType(TypeInfoFactory.intTypeInfo),
+                new DictionaryEncoding(dictionaryId++, false, null));
+          }
+        }
+      }
+    }
     return new FieldType(true, toArrowType(typeInfo), null);
   }
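toFieldType above is where a string column's physical representation changes under encoding: the declared storage type becomes 32-bit indices, and a DictionaryEncoding carrying a per-column id is attached. For reference, a sketch of what such a declaration looks like when built directly against the Arrow POJO types (the name and id are illustrative):

    // Index storage type: int32. The VARCHAR values live in the dictionary
    // vector registered under the same id (0 here). A null indexType in
    // DictionaryEncoding defaults to a signed 32-bit index.
    DictionaryEncoding encoding = new DictionaryEncoding(0L, /* ordered */ false, /* indexType */ null);
    FieldType fieldType = new FieldType(/* nullable */ true, new ArrowType.Int(32, true), encoding);
    Field stringField = new Field("string1", fieldType, null);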

@@ -203,34 +268,34 @@ private ArrowType toArrowType(TypeInfo typeInfo) {
     }
   }

-  private void write(FieldVector arrowVector, ColumnVector hiveVector, TypeInfo typeInfo, int size) {
+  @SuppressWarnings("unchecked")
+  private FieldVector write(FieldVector arrowVector, ColumnVector hiveVector, TypeInfo typeInfo,
+      int size, boolean encode) {
     switch (typeInfo.getCategory()) {
       case PRIMITIVE:
-        writePrimitive(arrowVector, hiveVector, typeInfo, size);
-        break;
+        return writePrimitive(arrowVector, hiveVector, (PrimitiveTypeInfo) typeInfo, size, encode);
       case LIST:
-        writeList((ListVector) arrowVector, (ListColumnVector) hiveVector, (ListTypeInfo) typeInfo, size);
-        break;
+        return writeList((ListVector) arrowVector, (ListColumnVector) hiveVector,
+            (ListTypeInfo) typeInfo, size, encode);
       case STRUCT:
-        writeStruct((MapVector) arrowVector, (StructColumnVector) hiveVector, (StructTypeInfo) typeInfo, size);
-        break;
+        return writeStruct((MapVector) arrowVector, (StructColumnVector) hiveVector,
+            (StructTypeInfo) typeInfo, size, encode);
       case UNION:
-        writeUnion(arrowVector, hiveVector, typeInfo, size);
-        break;
+        return writeUnion(arrowVector, hiveVector, typeInfo, size, encode);
       case MAP:
-        writeMap((ListVector) arrowVector, (MapColumnVector) hiveVector, (MapTypeInfo) typeInfo, size);
-        break;
+        return writeMap((ListVector) arrowVector, (MapColumnVector) hiveVector,
+            (MapTypeInfo) typeInfo, size, encode);
       default:
         throw new IllegalArgumentException();
     }
   }

-  private void writeMap(ListVector arrowVector, MapColumnVector hiveVector, MapTypeInfo typeInfo,
-      int size) {
+  private FieldVector writeMap(ListVector arrowVector, MapColumnVector hiveVector,
+      MapTypeInfo typeInfo, int size, boolean encode) {
     final ListTypeInfo structListTypeInfo = toStructListTypeInfo(typeInfo);
     final ListColumnVector structListVector = toStructListVector(hiveVector);

-    write(arrowVector, structListVector, structListTypeInfo, size);
+    write(arrowVector, structListVector, structListTypeInfo, size, encode);

     final ArrowBuf validityBuffer = arrowVector.getValidityBuffer();
     for (int rowIndex = 0; rowIndex < size; rowIndex++) {
@@ -240,10 +305,12 @@ private void writeMap(ListVector arrowVector, MapColumnVector hiveVector, MapTyp
         BitVectorHelper.setValidityBitToOne(validityBuffer, rowIndex);
       }
     }
+
+    return arrowVector;
   }

-  private void writeUnion(FieldVector arrowVector, ColumnVector hiveVector, TypeInfo typeInfo,
-      int size) {
+  private FieldVector writeUnion(FieldVector arrowVector, ColumnVector hiveVector,
+      TypeInfo typeInfo, int size, boolean encode) {
     final UnionTypeInfo unionTypeInfo = (UnionTypeInfo) typeInfo;
     final List<TypeInfo> objectTypeInfos = unionTypeInfo.getAllUnionObjectTypeInfos();
     final UnionColumnVector hiveUnionVector = (UnionColumnVector) hiveVector;
@@ -253,11 +320,14 @@ private void writeUnion(FieldVector arrowVector, ColumnVector hiveVector, TypeIn
     final ColumnVector hiveObjectVector = hiveObjectVectors[tag];
     final TypeInfo objectTypeInfo = objectTypeInfos.get(tag);

-    write(arrowVector, hiveObjectVector, objectTypeInfo, size);
+    write(arrowVector, hiveObjectVector, objectTypeInfo, size, encode);
+
+    return arrowVector;
   }

-  private void writeStruct(MapVector arrowVector, StructColumnVector hiveVector,
-      StructTypeInfo typeInfo, int size) {
+  @SuppressWarnings("unchecked")
+  private FieldVector writeStruct(MapVector arrowVector, StructColumnVector hiveVector,
+      StructTypeInfo typeInfo, int size, boolean encode) {
     final List<String> fieldNames = typeInfo.getAllStructFieldNames();
     final List<TypeInfo> fieldTypeInfos = typeInfo.getAllStructFieldTypeInfos();
     final ColumnVector[] hiveFieldVectors = hiveVector.fields;
@@ -267,12 +337,11 @@ private void writeStruct(MapVector arrowVector, StructColumnVector hiveVector,
       final TypeInfo fieldTypeInfo = fieldTypeInfos.get(fieldIndex);
       final ColumnVector hiveFieldVector = hiveFieldVectors[fieldIndex];
       final String fieldName = fieldNames.get(fieldIndex);
-      final FieldVector arrowFieldVector =
-          arrowVector.addOrGet(fieldName,
-              toFieldType(fieldTypeInfos.get(fieldIndex)), FieldVector.class);
+      final FieldVector arrowFieldVector = arrowVector.addOrGet(fieldName,
+          toFieldType(fieldTypeInfos.get(fieldIndex)), FieldVector.class);
       arrowFieldVector.setInitialCapacity(size);
       arrowFieldVector.allocateNew();
-      write(arrowFieldVector, hiveFieldVector, fieldTypeInfo, size);
+      write(arrowFieldVector, hiveFieldVector, fieldTypeInfo, size, encode);
     }

     final ArrowBuf validityBuffer = arrowVector.getValidityBuffer();
@@ -283,10 +352,12 @@ private void writeStruct(MapVector arrowVector, StructColumnVector hiveVector,
         BitVectorHelper.setValidityBitToOne(validityBuffer, rowIndex);
       }
     }
+
+    return (FieldVector) arrowVector;
   }

-  private void writeList(ListVector arrowVector, ListColumnVector hiveVector, ListTypeInfo typeInfo,
-      int size) {
+  private FieldVector writeList(ListVector arrowVector, ListColumnVector hiveVector,
+      ListTypeInfo typeInfo, int size, boolean encode) {
     final int OFFSET_WIDTH = 4;
     final TypeInfo elementTypeInfo = typeInfo.getListElementTypeInfo();
     final ColumnVector hiveElementVector = hiveVector.child;
@@ -295,7 +366,7 @@ private void writeList(ListVector arrowVector, ListColumnVector hiveVector, List
     arrowElementVector.setInitialCapacity(hiveVector.childCount);
     arrowElementVector.allocateNew();

-    write(arrowElementVector, hiveElementVector, elementTypeInfo, hiveVector.childCount);
+    write(arrowElementVector, hiveElementVector, elementTypeInfo, hiveVector.childCount, encode);

     final ArrowBuf offsetBuffer = arrowVector.getOffsetBuffer();
     int nextOffset = 0;
@@ -310,13 +381,12 @@ private void writeList(ListVector arrowVector, ListColumnVector hiveVector, List
       }
     }
     offsetBuffer.setInt(size * OFFSET_WIDTH, nextOffset);
+    return arrowVector;
   }

-  private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, TypeInfo typeInfo,
-      int size) {
-    final PrimitiveObjectInspector.PrimitiveCategory primitiveCategory =
-        ((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory();
-    switch (primitiveCategory) {
+  private FieldVector writePrimitive(FieldVector arrowVector, ColumnVector hiveVector,
+      PrimitiveTypeInfo primitiveTypeInfo, int size, boolean encode) {
+    switch (primitiveTypeInfo.getPrimitiveCategory()) {
       case BOOLEAN:
       {
         final BitVector bitVector = (BitVector) arrowVector;
@@ -327,8 +397,8 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
             bitVector.set(i, (int) ((LongColumnVector) hiveVector).vector[i]);
           }
         }
+        return bitVector;
       }
-      break;
       case BYTE:
       {
         final TinyIntVector tinyIntVector = (TinyIntVector) arrowVector;
@@ -339,8 +409,8 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
             tinyIntVector.set(i, (byte) ((LongColumnVector) hiveVector).vector[i]);
           }
         }
+        return tinyIntVector;
       }
-      break;
       case SHORT:
       {
         final SmallIntVector smallIntVector = (SmallIntVector) arrowVector;
@@ -351,8 +421,8 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
             smallIntVector.set(i, (short) ((LongColumnVector) hiveVector).vector[i]);
           }
         }
+        return smallIntVector;
       }
-      break;
       case INT:
       {
         final IntVector intVector = (IntVector) arrowVector;
@@ -363,8 +433,8 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
             intVector.set(i, (int) ((LongColumnVector) hiveVector).vector[i]);
           }
         }
+        return intVector;
       }
-      break;
       case LONG:
       {
         final BigIntVector bigIntVector = (BigIntVector) arrowVector;
@@ -375,8 +445,8 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
            bigIntVector.set(i, ((LongColumnVector) hiveVector).vector[i]);
           }
         }
+        return bigIntVector;
       }
-      break;
       case FLOAT:
       {
         final Float4Vector float4Vector = (Float4Vector) arrowVector;
@@ -387,8 +457,8 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
             float4Vector.set(i, (float) ((DoubleColumnVector) hiveVector).vector[i]);
           }
         }
+        return float4Vector;
       }
-      break;
       case DOUBLE:
       {
         final Float8Vector float8Vector = (Float8Vector) arrowVector;
@@ -399,23 +469,74 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
             float8Vector.set(i, ((DoubleColumnVector) hiveVector).vector[i]);
           }
         }
+        return float8Vector;
       }
-      break;
-      case STRING:
       case VARCHAR:
       case CHAR:
+      case STRING:
       {
-        final VarCharVector varCharVector = (VarCharVector) arrowVector;
-        final BytesColumnVector bytesVector = (BytesColumnVector) hiveVector;
-        for (int i = 0; i < size; i++) {
-          if (hiveVector.isNull[i]) {
-            varCharVector.setNull(i);
-          } else {
-            varCharVector.setSafe(i, bytesVector.vector[i], bytesVector.start[i], bytesVector.length[i]);
+        if (encode) {
+          final BytesColumnVector bytesVector = (BytesColumnVector) hiveVector;
+          final VarCharVector varCharVector = (VarCharVector)
+              Types.MinorType.VARCHAR.getNewVector(null,
+                  FieldType.nullable(Types.MinorType.VARCHAR.getType()), bufferAllocator, null);
+          for (int i = 0; i < size; i++) {
+            if (hiveVector.isNull[i]) {
+              varCharVector.setNull(i);
+            } else {
+              varCharVector.setSafe(i, bytesVector.vector[i], bytesVector.start[i],
+                  bytesVector.length[i]);
+            }
+          }
+          arrowVector.setValueCount(size);
+
+          int j = 0;
+          final Set<ByteBuffer> occurrences = new HashSet<>();
+          for (int i = 0; i < size; i++) {
+            if (!bytesVector.isNull[i]) {
+              final ByteBuffer byteBuffer =
+                  ByteBuffer.wrap(bytesVector.vector[j], bytesVector.start[j],
+                      bytesVector.length[j]);
+              if (!occurrences.contains(byteBuffer)) {
+                occurrences.add(byteBuffer);
+              }
+              j++;
+            }
           }
+          final FieldType fieldType = arrowVector.getField().getFieldType();
+          final VarCharVector dictionaryVector = (VarCharVector) Types.MinorType.VARCHAR.
+              getNewVector(null, fieldType, bufferAllocator, null);
+          j = 0;
+          for (ByteBuffer occurrence : occurrences) {
+            final int start = occurrence.position();
+            final int length = occurrence.limit() - start;
+            dictionaryVector.setSafe(j++, occurrence.array(), start, length);
+          }
+          dictionaryVector.setValueCount(occurrences.size());
+          varCharVector.setValueCount(size);
+          final DictionaryEncoding dictionaryEncoding = arrowVector.getField().getDictionary();
+          final Dictionary dictionary = new Dictionary(dictionaryVector, dictionaryEncoding);
+          dictionaryProvider.put(dictionary);
+          final IntVector encodedVector = (IntVector) DictionaryEncoder.encode(varCharVector,
+              dictionary);
+          encodedVector.makeTransferPair(arrowVector).transfer();
+          return arrowVector;
+        } else {
+          final BytesColumnVector bytesVector = (BytesColumnVector) hiveVector;
+          final VarCharVector varCharVector = (VarCharVector) arrowVector;
+          for (int i = 0; i < size; i++) {
+            if (hiveVector.isNull[i]) {
+              varCharVector.setNull(i);
+            } else {
+              varCharVector.setSafe(i, bytesVector.vector[i], bytesVector.start[i],
+                  bytesVector.length[i]);
+            }
+          }
+          varCharVector.setValueCount(size);
+
+          return varCharVector;
         }
       }
-      break;
       case DATE:
       {
         final DateDayVector dateDayVector = (DateDayVector) arrowVector;
@@ -426,8 +547,8 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
             dateDayVector.set(i, (int) ((LongColumnVector) hiveVector).vector[i]);
           }
         }
+        return dateDayVector;
       }
-      break;
       case TIMESTAMP:
       {
         final TimeStampNanoVector timeStampNanoVector = (TimeStampNanoVector) arrowVector;
@@ -449,8 +570,8 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
             }
           }
         }
+        return timeStampNanoVector;
       }
-      break;
       case BINARY:
       {
         final VarBinaryVector varBinaryVector = (VarBinaryVector) arrowVector;
@@ -462,8 +583,8 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
             varBinaryVector.setSafe(i, bytesVector.vector[i], bytesVector.start[i], bytesVector.length[i]);
           }
         }
+        return varBinaryVector;
       }
-      break;
       case DECIMAL:
       {
         final DecimalVector decimalVector = (DecimalVector) arrowVector;
@@ -472,12 +593,12 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
           if (hiveVector.isNull[i]) {
             decimalVector.setNull(i);
           } else {
-            decimalVector.set(i,
-                ((DecimalColumnVector) hiveVector).vector[i].getHiveDecimal().bigDecimalValue().setScale(scale));
+            decimalVector.set(i, ((DecimalColumnVector) hiveVector).vector[i].getHiveDecimal().
+                bigDecimalValue().setScale(scale));
           }
         }
+        return decimalVector;
       }
-      break;
       case INTERVAL_YEAR_MONTH:
       {
         final IntervalYearVector intervalYearVector = (IntervalYearVector) arrowVector;
@@ -488,8 +609,8 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
             intervalYearVector.set(i, (int) ((LongColumnVector) hiveVector).vector[i]);
           }
         }
+        return intervalYearVector;
       }
-      break;
       case INTERVAL_DAY_TIME:
       {
         final IntervalDayVector intervalDayVector = (IntervalDayVector) arrowVector;
@@ -501,17 +622,13 @@ private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, Ty
           } else {
             final long totalSeconds = intervalDayTimeColumnVector.getTotalSeconds(i);
             final long days = totalSeconds / SECOND_PER_DAY;
-            final long millis =
-                (totalSeconds - days * SECOND_PER_DAY) * MS_PER_SECOND +
-                intervalDayTimeColumnVector.getNanos(i) / NS_PER_MS;
+            final long millis = (totalSeconds - days * SECOND_PER_DAY) * MS_PER_SECOND +
+                intervalDayTimeColumnVector.getNanos(i) / NS_PER_MS;
             intervalDayVector.set(i, (int) days, (int) millis);
           }
         }
+        return intervalDayVector;
       }
-      break;
-      case VOID:
-      case UNKNOWN:
-      case TIMESTAMPLOCALTZ:
       default:
         throw new IllegalArgumentException();
     }
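Stepping back from the diff: on the write path the Serializer builds a plain VarCharVector, derives a dictionary of distinct values, registers it with a fresh MapDictionaryProvider (each string column gets its own id via dictionaryId++), encodes, and transfers the index vector into the field's storage vector; the Deserializer reverses this with decode(). A minimal, self-contained sketch of that round trip against the Arrow dictionary API (class name and literal values illustrative; exact vector constructors vary slightly across Arrow releases):

    import java.nio.charset.StandardCharsets;

    import org.apache.arrow.memory.RootAllocator;
    import org.apache.arrow.vector.IntVector;
    import org.apache.arrow.vector.VarCharVector;
    import org.apache.arrow.vector.dictionary.Dictionary;
    import org.apache.arrow.vector.dictionary.DictionaryEncoder;
    import org.apache.arrow.vector.types.pojo.DictionaryEncoding;

    public class DictionaryRoundTrip {
      public static void main(String[] args) {
        final byte[] hello = "Hello".getBytes(StandardCharsets.UTF_8);
        final byte[] world = "world!".getBytes(StandardCharsets.UTF_8);
        try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE);
             VarCharVector values = new VarCharVector("string1", allocator);
             VarCharVector dictionaryValues = new VarCharVector("dictionary", allocator)) {
          // The raw column: "Hello", "world!", "Hello".
          values.allocateNew();
          values.setSafe(0, hello, 0, hello.length);
          values.setSafe(1, world, 0, world.length);
          values.setSafe(2, hello, 0, hello.length);
          values.setValueCount(3);
          // The distinct values, registered under dictionary id 0.
          dictionaryValues.allocateNew();
          dictionaryValues.setSafe(0, hello, 0, hello.length);
          dictionaryValues.setSafe(1, world, 0, world.length);
          dictionaryValues.setValueCount(2);
          final Dictionary dictionary =
              new Dictionary(dictionaryValues, new DictionaryEncoding(0L, false, null));
          // encode() yields an index vector, [0, 1, 0], as an IntVector
          // for the default int32 index type.
          try (IntVector indices = (IntVector) DictionaryEncoder.encode(values, dictionary);
               VarCharVector decoded =
                   (VarCharVector) DictionaryEncoder.decode(indices, dictionary)) {
            // decoded again holds "Hello", "world!", "Hello".
          }
        }
      }
    }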
diff --git ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowColumnarBatchSerDe.java ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowColumnarBatchSerDe.java
index 74f6624597..2622ed42fb 100644
--- ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowColumnarBatchSerDe.java
+++ ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowColumnarBatchSerDe.java
@@ -26,6 +26,7 @@
 import org.apache.hadoop.hive.common.type.HiveIntervalDayTime;
 import org.apache.hadoop.hive.common.type.HiveIntervalYearMonth;
 import org.apache.hadoop.hive.common.type.HiveVarchar;
+import org.apache.hadoop.hive.conf.HiveConf;
 import org.apache.hadoop.hive.serde.serdeConstants;
 import org.apache.hadoop.hive.serde2.AbstractSerDe;
 import org.apache.hadoop.hive.serde2.SerDeException;
@@ -494,6 +495,27 @@ public void testPrimitiveString() throws SerDeException {
     initAndSerializeAndDeserialize(schema, STRING_ROWS);
   }

+  @Test
+  public void testPrimitiveEncodeString() throws SerDeException {
+    String[][] schema = {
+        {"string1", "string"},
+    };
+
+    HiveConf.setBoolVar(conf, HiveConf.ConfVars.HIVE_ARROW_ENCODE, true);
+
+    final Object[][] rows = {
+        {text("")},
+        {text("Hello")},
+        {text("Hello")},
+        {text("world!")},
+        {text("Hello")},
+        {text("world!")},
+        {text("world")},
+        {null},
+    };
+    initAndSerializeAndDeserialize(schema, rows);
+  }
+
   @Test
   public void testPrimitiveDTI() throws SerDeException {
     String[][] schema = {
@@ -578,6 +600,27 @@ public void testListString() throws SerDeException {
     initAndSerializeAndDeserialize(schema, toList(STRING_ROWS));
   }

+  @Test
+  public void testListEncodeString() throws SerDeException {
+    String[][] schema = {
+        {"string1", "array<string>"},
+    };
+
+    HiveConf.setBoolVar(conf, HiveConf.ConfVars.HIVE_ARROW_ENCODE, true);
+
+    final Object[][] rows = {
+        {text("")},
+        {text("Hello")},
+        {text("Hello")},
+        {text("world!")},
+        {text("Hello")},
+        {text("world!")},
+        {text("world")},
+        {null},
+    };
+    initAndSerializeAndDeserialize(schema, toList(rows));
+  }
+
   @Test
   public void testListDTI() throws SerDeException {
     String[][] schema = {