diff --git common/src/java/org/apache/hadoop/hive/conf/HiveConf.java common/src/java/org/apache/hadoop/hive/conf/HiveConf.java index f40c60606c..8b5e69ffa4 100644 --- common/src/java/org/apache/hadoop/hive/conf/HiveConf.java +++ common/src/java/org/apache/hadoop/hive/conf/HiveConf.java @@ -2581,6 +2581,11 @@ private static void populateLlapDaemonVarsSet(Set llapDaemonVarsSetLocal "Set to true to ensure that each SQL Merge statement ensures that for each row in the target\n" + "table there is at most 1 matching row in the source table per SQL Specification."), + // For Arrow SerDe + HIVE_ARROW_ROOT_ALLOCATOR_LIMIT("hive.arrow.root.allocator.limit", Long.MAX_VALUE, + "Arrow root allocator memory size limitation in bytes."), + HIVE_ARROW_BATCH_SIZE("hive.arrow.batch.size", 1000, "The number of rows sent in one Arrow batch."), + // For Druid storage handler HIVE_DRUID_INDEXING_GRANULARITY("hive.druid.indexer.segments.granularity", "DAY", new PatternSet("YEAR", "MONTH", "WEEK", "DAY", "HOUR", "MINUTE", "SECOND"), diff --git pom.xml pom.xml index afcf76e855..34f969aa8b 100644 --- pom.xml +++ pom.xml @@ -119,6 +119,7 @@ 3.5.2 1.5.6 0.1 + 0.8.0 1.11.0 1.7.7 0.8.0.RELEASE diff --git ql/src/java/org/apache/hadoop/hive/ql/io/arrow/ArrowColumnarBatchSerDe.java ql/src/java/org/apache/hadoop/hive/ql/io/arrow/ArrowColumnarBatchSerDe.java new file mode 100644 index 0000000000..cde39cc216 --- /dev/null +++ ql/src/java/org/apache/hadoop/hive/ql/io/arrow/ArrowColumnarBatchSerDe.java @@ -0,0 +1,1163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.hadoop.hive.ql.io.arrow; + +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import io.netty.buffer.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.BitVectorHelper; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.impl.UnionListWriter; +import org.apache.arrow.vector.complex.impl.UnionReader; +import org.apache.arrow.vector.complex.impl.UnionWriter; +import org.apache.arrow.vector.complex.reader.FieldReader; +import org.apache.arrow.vector.complex.writer.BaseWriter; +import org.apache.arrow.vector.complex.writer.BigIntWriter; +import org.apache.arrow.vector.complex.writer.BitWriter; +import org.apache.arrow.vector.complex.writer.DateDayWriter; +import org.apache.arrow.vector.complex.writer.DecimalWriter; +import org.apache.arrow.vector.complex.writer.FieldWriter; +import org.apache.arrow.vector.complex.writer.Float4Writer; +import org.apache.arrow.vector.complex.writer.Float8Writer; +import org.apache.arrow.vector.complex.writer.IntWriter; +import org.apache.arrow.vector.complex.writer.IntervalDayWriter; +import org.apache.arrow.vector.complex.writer.IntervalYearWriter; +import org.apache.arrow.vector.complex.writer.SmallIntWriter; +import org.apache.arrow.vector.complex.writer.TimeStampMilliWriter; +import org.apache.arrow.vector.complex.writer.TinyIntWriter; +import org.apache.arrow.vector.complex.writer.VarBinaryWriter; +import org.apache.arrow.vector.complex.writer.VarCharWriter; +import org.apache.arrow.vector.holders.NullableBigIntHolder; +import org.apache.arrow.vector.holders.NullableBitHolder; +import org.apache.arrow.vector.holders.NullableDateDayHolder; +import org.apache.arrow.vector.holders.NullableFloat4Holder; +import org.apache.arrow.vector.holders.NullableFloat8Holder; +import org.apache.arrow.vector.holders.NullableIntHolder; +import org.apache.arrow.vector.holders.NullableIntervalDayHolder; +import org.apache.arrow.vector.holders.NullableIntervalYearHolder; +import org.apache.arrow.vector.holders.NullableSmallIntHolder; +import org.apache.arrow.vector.holders.NullableTimeStampMilliHolder; +import org.apache.arrow.vector.holders.NullableTinyIntHolder; +import org.apache.arrow.vector.holders.NullableVarBinaryHolder; +import org.apache.arrow.vector.holders.NullableVarCharHolder; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hive.common.type.HiveDecimal; +import org.apache.hadoop.hive.common.type.HiveIntervalDayTime; +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.IntervalDayTimeColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.ListColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.MapColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.StructColumnVector; +import 
org.apache.hadoop.hive.ql.exec.vector.TimestampColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.UnionColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.VectorAssignRow;
+import org.apache.hadoop.hive.ql.exec.vector.VectorExtractRow;
+import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde.serdeConstants;
+import org.apache.hadoop.hive.serde2.AbstractSerDe;
+import org.apache.hadoop.hive.serde2.SerDeException;
+import org.apache.hadoop.hive.serde2.SerDeStats;
+import org.apache.hadoop.hive.serde2.SerDeUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TimestampLocalTZTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
+import org.apache.hadoop.hive.serde2.typeinfo.UnionTypeInfo;
+import org.apache.hadoop.io.Writable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.sql.Timestamp;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.Properties;
+import java.util.function.IntConsumer;
+
+import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.HIVE_ARROW_BATCH_SIZE;
+import static org.apache.hadoop.hive.ql.exec.vector.VectorizedBatchUtil.createColumnVector;
+import static org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils.getStandardWritableObjectInspectorFromTypeInfo;
+import static org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils.getTypeInfoFromObjectInspector;
+
+/**
+ * ArrowColumnarBatchSerDe converts Apache Hive rows to Apache Arrow columns. Its serialized
+ * class is {@link ArrowWrapperWritable}, which doesn't support {@link
+ * Writable#readFields(DataInput)} and {@link Writable#write(DataOutput)}.
+ *
+ * The following are known issues of the current implementation.
+ *
+ * A list column cannot contain a decimal column. {@link UnionListWriter} doesn't have an
+ * implementation for {@link BaseWriter.ListWriter#decimal()}.
+ *
+ * A union column can hold only one of the string, char, and varchar fields at a time. Apache Arrow
+ * doesn't have string and char, so {@link ArrowColumnarBatchSerDe} uses varchar to simulate
+ * string and char. They are treated as the same data type in
+ * {@link org.apache.arrow.vector.complex.UnionVector}.
+ *
+ * Timestamp with local timezone is not supported. {@link VectorAssignRow} doesn't support it.
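+ *
+ * A minimal usage sketch (illustrative only: the column names, types, input row, and the
+ * conf/rowObjectInspector variables below are placeholders; the unit tests drive the SerDe the
+ * same way through SerDeUtils and a standard struct object inspector):
+ *
+ * <pre>{@code
+ *   Properties tbl = new Properties();
+ *   tbl.setProperty(serdeConstants.LIST_COLUMNS, "c1,c2");
+ *   tbl.setProperty(serdeConstants.LIST_COLUMN_TYPES, "int,string");
+ *   AbstractSerDe serDe = new ArrowColumnarBatchSerDe();
+ *   SerDeUtils.initializeSerDe(serDe, conf, tbl, null);
+ *
+ *   // serialize() buffers rows and returns null until hive.arrow.batch.size rows are collected;
+ *   // passing a null row flushes whatever is buffered as one ArrowWrapperWritable batch.
+ *   Writable batch = serDe.serialize(row, rowObjectInspector);
+ *   if (batch == null) {
+ *     batch = serDe.serialize(null, rowObjectInspector);
+ *   }
+ *   Object[][] rows = (Object[][]) serDe.deserialize(batch);
+ * }</pre>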
+ */
+public class ArrowColumnarBatchSerDe extends AbstractSerDe {
+  public static final Logger LOG = LoggerFactory.getLogger(ArrowColumnarBatchSerDe.class.getName());
+  private static final String DEFAULT_ARROW_FIELD_NAME = "[DEFAULT]";
+
+  private static final int MS_PER_SECOND = 1_000;
+  private static final int MS_PER_MINUTE = MS_PER_SECOND * 60;
+  private static final int MS_PER_HOUR = MS_PER_MINUTE * 60;
+  private static final int MS_PER_DAY = MS_PER_HOUR * 24;
+  private static final int NS_PER_MS = 1_000_000;
+
+  private static BufferAllocator rootAllocator;
+
+  private StructTypeInfo rowTypeInfo;
+  private StructObjectInspector rowObjectInspector;
+  private Configuration conf;
+  private Serializer serializer;
+  private Deserializer deserializer;
+
+  @Override
+  public void initialize(Configuration conf, Properties tbl) throws SerDeException {
+    this.conf = conf;
+
+    rootAllocator = RootAllocatorFactory.INSTANCE.getRootAllocator(conf);
+
+    final String columnNameProperty = tbl.getProperty(serdeConstants.LIST_COLUMNS);
+    final String columnTypeProperty = tbl.getProperty(serdeConstants.LIST_COLUMN_TYPES);
+    final String columnNameDelimiter = tbl.containsKey(serdeConstants.COLUMN_NAME_DELIMITER) ? tbl
+        .getProperty(serdeConstants.COLUMN_NAME_DELIMITER) : String.valueOf(SerDeUtils.COMMA);
+
+    // Create an object inspector
+    final List columnNames;
+    if (columnNameProperty.length() == 0) {
+      columnNames = new ArrayList<>();
+    } else {
+      columnNames = Arrays.asList(columnNameProperty.split(columnNameDelimiter));
+    }
+    final List columnTypes;
+    if (columnTypeProperty.length() == 0) {
+      columnTypes = new ArrayList<>();
+    } else {
+      columnTypes = TypeInfoUtils.getTypeInfosFromTypeString(columnTypeProperty);
+    }
+    rowTypeInfo = (StructTypeInfo) TypeInfoFactory.getStructTypeInfo(columnNames, columnTypes);
+    rowObjectInspector =
+        (StructObjectInspector) getStandardWritableObjectInspectorFromTypeInfo(rowTypeInfo);
+
+    final List fields = new ArrayList<>();
+    final int size = columnNames.size();
+    for (int i = 0; i < size; i++) {
+      fields.add(toField(columnNames.get(i), columnTypes.get(i)));
+    }
+
+    serializer = new Serializer(new Schema(fields));
+    deserializer = new Deserializer();
+  }
+
+  private class Serializer {
+    private final int MAX_BUFFERED_ROWS;
+
+    // Schema
+    private final StructTypeInfo structTypeInfo;
+    private final List fieldTypeInfos;
+    private final int fieldSize;
+
+    // Hive columns
+    private final VectorizedRowBatch vectorizedRowBatch;
+    private final VectorAssignRow vectorAssignRow;
+    private int batchSize;
+
+    // Arrow columns
+    private final VectorSchemaRoot vectorSchemaRoot;
+    private final List arrowVectors;
+    private final List fieldWriters;
+
+    private Serializer(Schema schema) throws SerDeException {
+      MAX_BUFFERED_ROWS = HiveConf.getIntVar(conf, HIVE_ARROW_BATCH_SIZE);
+      LOG.info("ArrowColumnarBatchSerDe max number of buffered rows: " + MAX_BUFFERED_ROWS);
+
+      // Schema
+      structTypeInfo = (StructTypeInfo) getTypeInfoFromObjectInspector(rowObjectInspector);
+      fieldTypeInfos = structTypeInfo.getAllStructFieldTypeInfos();
+      fieldSize = fieldTypeInfos.size();
+
+      // Init Arrow stuffs
+      vectorSchemaRoot = VectorSchemaRoot.create(schema, rootAllocator);
+      arrowVectors = vectorSchemaRoot.getFieldVectors();
+      fieldWriters = Lists.newArrayList();
+      for (FieldVector fieldVector : arrowVectors) {
+        final FieldWriter fieldWriter =
+            Types.getMinorTypeForArrowType(
+                fieldVector.getField().getType()).getNewFieldWriter(fieldVector);
+        fieldWriters.add(fieldWriter);
+      }
+
+      // 
Init Hive stuffs + vectorizedRowBatch = new VectorizedRowBatch(fieldSize); + for (int i = 0; i < fieldSize; i++) { + final ColumnVector columnVector = createColumnVector(fieldTypeInfos.get(i)); + vectorizedRowBatch.cols[i] = columnVector; + columnVector.init(); + } + vectorizedRowBatch.ensureSize(MAX_BUFFERED_ROWS); + vectorAssignRow = new VectorAssignRow(); + try { + vectorAssignRow.init(rowObjectInspector); + } catch (HiveException e) { + throw new SerDeException(e); + } + } + + private ArrowWrapperWritable serializeBatch() { + for (int i = 0; i < vectorizedRowBatch.projectionSize; i++) { + final int projectedColumn = vectorizedRowBatch.projectedColumns[i]; + final ColumnVector hiveVector = vectorizedRowBatch.cols[projectedColumn]; + final TypeInfo fieldTypeInfo = structTypeInfo.getAllStructFieldTypeInfos().get(i); + final FieldWriter fieldWriter = fieldWriters.get(i); + final FieldVector arrowVector = arrowVectors.get(i); + arrowVector.setValueCount(0); + fieldWriter.setPosition(0); + write(fieldWriter, arrowVector, hiveVector, fieldTypeInfo, 0, batchSize, true); + } + vectorizedRowBatch.reset(); + vectorSchemaRoot.setRowCount(batchSize); + + batchSize = 0; + return new ArrowWrapperWritable(vectorSchemaRoot); + } + + private BaseWriter getWriter(FieldWriter writer, TypeInfo typeInfo, String name) { + switch (typeInfo.getCategory()) { + case PRIMITIVE: + switch (((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory()) { + case BOOLEAN: + return writer.bit(name); + case BYTE: + return writer.tinyInt(name); + case SHORT: + return writer.smallInt(name); + case INT: + return writer.integer(name); + case LONG: + return writer.bigInt(name); + case FLOAT: + return writer.float4(name); + case DOUBLE: + return writer.float8(name); + case STRING: + case VARCHAR: + case CHAR: + return writer.varChar(name); + case DATE: + return writer.dateDay(name); + case TIMESTAMP: + return writer.timeStampMilli(name); + case BINARY: + return writer.varBinary(name); + case DECIMAL: + final DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) typeInfo; + final int scale = decimalTypeInfo.scale(); + final int precision = decimalTypeInfo.precision(); + return writer.decimal(name, scale, precision); + case INTERVAL_YEAR_MONTH: + return writer.intervalDay(name); + case INTERVAL_DAY_TIME: + return writer.intervalYear(name); + case TIMESTAMPLOCALTZ: // VectorAssignRow doesn't support it + case VOID: + case UNKNOWN: + default: + throw new IllegalArgumentException(); + } + case LIST: + case UNION: + return writer.list(name); + case STRUCT: + return writer.map(name); + case MAP: // The caller will convert map to array + default: + throw new IllegalArgumentException(); + } + } + + private BaseWriter getWriter(FieldWriter writer, TypeInfo typeInfo) { + switch (typeInfo.getCategory()) { + case PRIMITIVE: + switch (((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory()) { + case BOOLEAN: + return writer.bit(); + case BYTE: + return writer.tinyInt(); + case SHORT: + return writer.smallInt(); + case INT: + return writer.integer(); + case LONG: + return writer.bigInt(); + case FLOAT: + return writer.float4(); + case DOUBLE: + return writer.float8(); + case STRING: + case VARCHAR: + case CHAR: + return writer.varChar(); + case DATE: + return writer.dateDay(); + case TIMESTAMP: + return writer.timeStampMilli(); + case BINARY: + return writer.varBinary(); + case INTERVAL_YEAR_MONTH: + return writer.intervalDay(); + case INTERVAL_DAY_TIME: + return writer.intervalYear(); + case TIMESTAMPLOCALTZ: // VectorAssignRow doesn't support it + 
case DECIMAL: // ListVector doesn't support it + case VOID: + case UNKNOWN: + default: + throw new IllegalArgumentException(); + } + case LIST: + case UNION: + return writer.list(); + case STRUCT: + return writer.map(); + case MAP: // The caller will convert map to array + default: + throw new IllegalArgumentException(); + } + } + + private void write(BaseWriter baseWriter, FieldVector arrowVector, ColumnVector hiveVector, + TypeInfo typeInfo, int offset, int length, boolean incrementIndex) { + + final IntConsumer writer; + switch (typeInfo.getCategory()) { + case PRIMITIVE: + final PrimitiveObjectInspector.PrimitiveCategory primitiveCategory = + ((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory(); + switch (primitiveCategory) { + case BOOLEAN: + final LongColumnVector bitVector = (LongColumnVector) hiveVector; + final BitWriter bitWriter = (BitWriter) baseWriter; + writer = i -> bitWriter.writeBit((int) bitVector.vector[i]); + break; + case BYTE: + final LongColumnVector tinyIntVector = (LongColumnVector) hiveVector; + final TinyIntWriter tinyIntWriter = (TinyIntWriter) baseWriter; + writer = i -> tinyIntWriter.writeTinyInt((byte) tinyIntVector.vector[i]); + break; + case SHORT: + final LongColumnVector smallIntVector = (LongColumnVector) hiveVector; + final SmallIntWriter smallIntWriter = (SmallIntWriter) baseWriter; + writer = i -> smallIntWriter.writeSmallInt((short) smallIntVector.vector[i]); + break; + case INT: + final LongColumnVector intColumnVector = (LongColumnVector) hiveVector; + final IntWriter intWriter = (IntWriter) baseWriter; + writer = i -> intWriter.writeInt((int) intColumnVector.vector[i]); + break; + case LONG: + final LongColumnVector bigIntVector = (LongColumnVector) hiveVector; + final BigIntWriter bigIntWriter = (BigIntWriter) baseWriter; + writer = i -> bigIntWriter.writeBigInt(bigIntVector.vector[i]); + break; + case FLOAT: + final DoubleColumnVector floatVector = (DoubleColumnVector) hiveVector; + final Float4Writer float4Writer = (Float4Writer) baseWriter; + writer = i -> float4Writer.writeFloat4((float) floatVector.vector[i]); + break; + case DOUBLE: + final DoubleColumnVector doubleVector = (DoubleColumnVector) hiveVector; + final Float8Writer float8Writer = (Float8Writer) baseWriter; + writer = i -> float8Writer.writeFloat8(doubleVector.vector[i]); + break; + case STRING: + case VARCHAR: + case CHAR: + final BytesColumnVector stringVector = (BytesColumnVector) hiveVector; + final VarCharWriter varCharWriter = (VarCharWriter) baseWriter; + writer = i -> { + final byte[] bytes = stringVector.vector[i]; + final int start = stringVector.start[i]; + final int byteLength = stringVector.length[i]; + try (ArrowBuf arrowBuf = rootAllocator.buffer(byteLength)) { + arrowBuf.setBytes(0, bytes, start, byteLength); + varCharWriter.writeVarChar(0, byteLength, arrowBuf); + } + }; + break; + case DATE: + final LongColumnVector dateVector = (LongColumnVector) hiveVector; + final DateDayWriter dateDayWriter = (DateDayWriter) baseWriter; + writer = i -> dateDayWriter.writeDateDay((int) dateVector.vector[i]); + break; + case TIMESTAMP: + final TimestampColumnVector timestampVector = (TimestampColumnVector) hiveVector; + final TimeStampMilliWriter timeStampMilliWriter = (TimeStampMilliWriter) baseWriter; + writer = i -> { + final long time = timestampVector.getTime(i); + timeStampMilliWriter.writeTimeStampMilli(time); + }; + break; + case BINARY: + final BytesColumnVector binaryVector = (BytesColumnVector) hiveVector; + final VarBinaryWriter varBinaryWriter = (VarBinaryWriter) 
baseWriter; + writer = i -> { + final byte[] bytes = binaryVector.vector[i]; + final int start = binaryVector.start[i]; + final int byteLength = binaryVector.length[i]; + try (ArrowBuf arrowBuf = rootAllocator.buffer(byteLength)) { + arrowBuf.setBytes(0, bytes, start, byteLength); + varBinaryWriter.writeVarBinary(0, byteLength, arrowBuf); + } + }; + break; + case DECIMAL: + final DecimalColumnVector decimalVector = (DecimalColumnVector) hiveVector; + final DecimalWriter decimalWriter = (DecimalWriter) baseWriter; + writer = i -> decimalWriter.writeDecimal( + decimalVector.vector[i].getHiveDecimal().bigDecimalValue() + .setScale(decimalVector.scale)); + break; + case INTERVAL_YEAR_MONTH: + final LongColumnVector intervalYearMonthVector = + (LongColumnVector) hiveVector; + final IntervalYearWriter intervalYearWriter = (IntervalYearWriter) baseWriter; + writer = i -> intervalYearWriter.writeIntervalYear( + (int) intervalYearMonthVector.vector[i]); + break; + case INTERVAL_DAY_TIME: + final IntervalDayTimeColumnVector intervalDayTimeVector = + (IntervalDayTimeColumnVector) hiveVector; + final IntervalDayWriter intervalDayWriter = (IntervalDayWriter) baseWriter; + writer = i -> { + final long millis = (intervalDayTimeVector.getTotalSeconds(i) * 1_000) + + (intervalDayTimeVector.getNanos(i) / 1_000_000); + final int days = (int) (millis / MS_PER_DAY); + intervalDayWriter.writeIntervalDay(days, (int) (millis % MS_PER_DAY)); + }; + break; + case VOID: + case UNKNOWN: + case TIMESTAMPLOCALTZ: + default: + throw new IllegalArgumentException(); + } + break; + case LIST: + final ListTypeInfo listTypeInfo = (ListTypeInfo) typeInfo; + final TypeInfo elementTypeInfo = listTypeInfo.getListElementTypeInfo(); + final ListColumnVector hiveListVector = (ListColumnVector) hiveVector; + final ColumnVector hiveElementVector = hiveListVector.child; + final FieldVector arrowElementVector = arrowVector.getChildrenFromFields().get(0); + final BaseWriter.ListWriter listWriter = (BaseWriter.ListWriter) baseWriter; + final BaseWriter elementWriter = getWriter((FieldWriter) listWriter, elementTypeInfo); + + writer = i -> { + final int listOffset = (int) hiveListVector.offsets[i]; + final int listLength = (int) hiveListVector.lengths[i]; + listWriter.startList(); + write(elementWriter, arrowElementVector, hiveElementVector, elementTypeInfo, + listOffset, listLength, false); + listWriter.endList(); + }; + + incrementIndex = false; + break; + case STRUCT: + final StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo; + final List fieldTypeInfos = structTypeInfo.getAllStructFieldTypeInfos(); + final StructColumnVector hiveStructVector = (StructColumnVector) hiveVector; + final List arrowFieldVectors = arrowVector.getChildrenFromFields(); + final ColumnVector[] hiveFieldVectors = hiveStructVector.fields; + final BaseWriter.MapWriter structWriter = (BaseWriter.MapWriter) baseWriter; + final int fieldSize = fieldTypeInfos.size(); + + writer = i -> { + structWriter.start(); + for (int fieldIndex = 0; fieldIndex < fieldSize; fieldIndex++) { + final TypeInfo fieldTypeInfo = fieldTypeInfos.get(fieldIndex); + final String fieldName = structTypeInfo.getAllStructFieldNames().get(fieldIndex); + final ColumnVector hiveFieldVector = hiveFieldVectors[fieldIndex]; + final BaseWriter fieldWriter = getWriter((FieldWriter) structWriter, fieldTypeInfo, + fieldName); + final FieldVector arrowFieldVector = arrowFieldVectors.get(fieldIndex); + write(fieldWriter, arrowFieldVector, hiveFieldVector, fieldTypeInfo, i, 1, false); + } + 
structWriter.end(); + }; + + incrementIndex = false; + break; + case UNION: + final UnionTypeInfo unionTypeInfo = (UnionTypeInfo) typeInfo; + final List objectTypeInfos = unionTypeInfo.getAllUnionObjectTypeInfos(); + final UnionColumnVector hiveUnionVector = (UnionColumnVector) hiveVector; + final ColumnVector[] hiveObjectVectors = hiveUnionVector.fields; + final UnionWriter unionWriter = (UnionWriter) baseWriter; + + writer = i -> { + final int tag = hiveUnionVector.tags[i]; + final ColumnVector hiveObjectVector = hiveObjectVectors[tag]; + final TypeInfo objectTypeInfo = objectTypeInfos.get(tag); + write(unionWriter, arrowVector, hiveObjectVector, objectTypeInfo, i, 1, false); + }; + break; + case MAP: + final ListTypeInfo mapTypeInfo = toStructListTypeInfo((MapTypeInfo) typeInfo); + final ListColumnVector structListVector = toStructListVector((MapColumnVector) hiveVector); + + writer = i -> write(baseWriter, arrowVector, structListVector, mapTypeInfo, i, length, false); + + incrementIndex = false; + break; + default: + throw new IllegalArgumentException(); + } + + if (hiveVector.noNulls) { + if (hiveVector.isRepeating) { + for (int i = 0; i < length; i++) { + writer.accept(0); + if (incrementIndex) { + baseWriter.setPosition(baseWriter.getPosition() + 1); + } + } + } else { + if (vectorizedRowBatch.selectedInUse) { + for (int j = 0; j < length; j++) { + final int i = vectorizedRowBatch.selected[j]; + writer.accept(offset + i); + if (incrementIndex) { + baseWriter.setPosition(baseWriter.getPosition() + 1); + } + } + } else { + for (int i = 0; i < length; i++) { + writer.accept(offset + i); + if (incrementIndex) { + baseWriter.setPosition(baseWriter.getPosition() + 1); + } + } + } + } + } else { + if (hiveVector.isRepeating) { + for (int i = 0; i < length; i++) { + if (hiveVector.isNull[0]) { + BitVectorHelper.setValidityBit(arrowVector.getValidityBuffer(), i, 0); + } else { + writer.accept(0); + } + if (incrementIndex) { + baseWriter.setPosition(baseWriter.getPosition() + 1); + } + } + } else { + if (vectorizedRowBatch.selectedInUse) { + for (int j = 0; j < length; j++) { + final int i = vectorizedRowBatch.selected[j]; + if (hiveVector.isNull[offset + i]) { + BitVectorHelper.setValidityBit(arrowVector.getValidityBuffer(), offset + i, 0); + } else { + writer.accept(offset + i); + } + if (incrementIndex) { + baseWriter.setPosition(baseWriter.getPosition() + 1); + } + } + } else { + for (int i = 0; i < length; i++) { + if (hiveVector.isNull[offset + i]) { + BitVectorHelper.setValidityBit(arrowVector.getValidityBuffer(), offset + i, 0); + } else { + writer.accept(offset + i); + } + if (incrementIndex) { + baseWriter.setPosition(baseWriter.getPosition() + 1); + } + } + } + } + } + } + + public ArrowWrapperWritable serialize(Object obj, ObjectInspector objInspector) { + // if row is null, it means there are no more rows (closeOp()). + // another case can be that the buffer is full. 
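+      // In other words, a null row acts as an explicit flush: the caller (the operator's
+      // closeOp(), or the test harness in AbstractTest) passes null to force the rows buffered
+      // so far out as a final ArrowWrapperWritable batch.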
+      if (obj == null) {
+        return serializeBatch();
+      }
+      List row = ((StructObjectInspector) objInspector).getStructFieldsDataAsList(obj);
+      vectorAssignRow.assignRow(vectorizedRowBatch, batchSize, row, fieldSize);
+      batchSize++;
+      if (batchSize == MAX_BUFFERED_ROWS) {
+        return serializeBatch();
+      }
+      return null;
+    }
+  }
+
+  private static abstract class PrimitiveReader {
+    final void read(FieldReader reader, ColumnVector columnVector, int offset, int length) {
+      for (int i = 0; i < length; i++) {
+        final int rowIndex = offset + i;
+        if (reader.isSet()) {
+          doRead(reader, columnVector, rowIndex);
+        } else {
+          // The Arrow value is unset, so mark the corresponding Hive entry as null
+          columnVector.noNulls = false;
+          columnVector.isNull[rowIndex] = true;
+        }
+        reader.setPosition(reader.getPosition() + 1);
+      }
+    }
+
+    abstract void doRead(FieldReader reader, ColumnVector columnVector, int rowIndex);
+  }
+
+  private class Deserializer {
+    private final VectorExtractRow vectorExtractRow;
+    private final VectorizedRowBatch vectorizedRowBatch;
+    private Object[][] rows;
+
+    public Deserializer() throws SerDeException {
+      vectorExtractRow = new VectorExtractRow();
+      final List fieldTypeInfoList = rowTypeInfo.getAllStructFieldTypeInfos();
+      final int fieldCount = fieldTypeInfoList.size();
+      final TypeInfo[] typeInfos = fieldTypeInfoList.toArray(new TypeInfo[fieldCount]);
+      try {
+        vectorExtractRow.init(typeInfos);
+      } catch (HiveException e) {
+        throw new SerDeException(e);
+      }
+
+      vectorizedRowBatch = new VectorizedRowBatch(fieldCount);
+      for (int i = 0; i < fieldCount; i++) {
+        final ColumnVector columnVector = createColumnVector(typeInfos[i]);
+        columnVector.init();
+        vectorizedRowBatch.cols[i] = columnVector;
+      }
+    }
+
+    public Object deserialize(Writable writable) {
+      final ArrowWrapperWritable arrowWrapperWritable = (ArrowWrapperWritable) writable;
+      final VectorSchemaRoot vectorSchemaRoot = arrowWrapperWritable.getVectorSchemaRoot();
+      final List fieldVectors = vectorSchemaRoot.getFieldVectors();
+      final int fieldCount = fieldVectors.size();
+      final int rowCount = vectorSchemaRoot.getRowCount();
+      vectorizedRowBatch.ensureSize(rowCount);
+
+      if (rows == null || rows.length < rowCount) {
+        rows = new Object[rowCount][];
+        for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) {
+          rows[rowIndex] = new Object[fieldCount];
+        }
+      }
+
+      for (int i = 0; i < fieldCount; i++) {
+        final FieldVector fieldVector = fieldVectors.get(i);
+        final FieldReader fieldReader = fieldVector.getReader();
+        fieldReader.setPosition(0);
+        final int projectedCol = vectorizedRowBatch.projectedColumns[i];
+        final ColumnVector columnVector = vectorizedRowBatch.cols[projectedCol];
+        final TypeInfo typeInfo = rowTypeInfo.getAllStructFieldTypeInfos().get(i);
+        read(fieldReader, columnVector, typeInfo, 0, rowCount);
+      }
+      for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) {
+        vectorExtractRow.extractRow(vectorizedRowBatch, rowIndex, rows[rowIndex]);
+      }
+      vectorizedRowBatch.reset();
+      return rows;
+    }
+
+    private void read(FieldReader reader, ColumnVector columnVector, TypeInfo typeInfo, int offset,
+        int length) {
+      switch (typeInfo.getCategory()) {
+        case PRIMITIVE:
+          final PrimitiveObjectInspector.PrimitiveCategory primitiveCategory =
+              ((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory();
+          final PrimitiveReader primitiveReader;
+          switch (primitiveCategory) {
+            case BOOLEAN:
+              primitiveReader = new PrimitiveReader() {
+                NullableBitHolder holder = new NullableBitHolder();
+
+                @Override
+                void doRead(FieldReader reader, ColumnVector columnVector, int rowIndex) {
+                  reader.read(holder);
+
((LongColumnVector) columnVector).vector[rowIndex] = holder.value; + } + }; + break; + case BYTE: + primitiveReader = new PrimitiveReader() { + NullableTinyIntHolder holder = new NullableTinyIntHolder(); + + @Override + void doRead(FieldReader reader, ColumnVector columnVector, int rowIndex) { + reader.read(holder); + ((LongColumnVector) columnVector).vector[rowIndex] = holder.value; + } + }; + break; + case SHORT: + primitiveReader = new PrimitiveReader() { + NullableSmallIntHolder holder = new NullableSmallIntHolder(); + + @Override + void doRead(FieldReader reader, ColumnVector columnVector, int rowIndex) { + reader.read(holder); + ((LongColumnVector) columnVector).vector[rowIndex] = holder.value; + } + }; + break; + case INT: + primitiveReader = new PrimitiveReader() { + NullableIntHolder holder = new NullableIntHolder(); + + @Override + void doRead(FieldReader reader, ColumnVector columnVector, int rowIndex) { + reader.read(holder); + ((LongColumnVector) columnVector).vector[rowIndex] = holder.value; + } + }; + break; + case LONG: + primitiveReader = new PrimitiveReader() { + NullableBigIntHolder holder = new NullableBigIntHolder(); + + @Override + void doRead(FieldReader reader, ColumnVector columnVector, int rowIndex) { + reader.read(holder); + ((LongColumnVector) columnVector).vector[rowIndex] = holder.value; + } + }; + break; + case FLOAT: + primitiveReader = new PrimitiveReader() { + NullableFloat4Holder holder = new NullableFloat4Holder(); + + @Override + void doRead(FieldReader reader, ColumnVector columnVector, int rowIndex) { + reader.read(holder); + ((DoubleColumnVector) columnVector).vector[rowIndex] = holder.value; + } + }; + break; + case DOUBLE: + primitiveReader = new PrimitiveReader() { + NullableFloat8Holder holder = new NullableFloat8Holder(); + + @Override + void doRead(FieldReader reader, ColumnVector columnVector, int rowIndex) { + reader.read(holder); + ((DoubleColumnVector) columnVector).vector[rowIndex] = holder.value; + } + }; + break; + case STRING: + case VARCHAR: + case CHAR: + primitiveReader = new PrimitiveReader() { + NullableVarCharHolder holder = new NullableVarCharHolder(); + + @Override + void doRead(FieldReader reader, ColumnVector columnVector, int rowIndex) { + reader.read(holder); + int varCharSize = holder.end - holder.start; + byte[] varCharBytes = new byte[varCharSize]; + holder.buffer.getBytes(holder.start, varCharBytes); + ((BytesColumnVector) columnVector).setVal(rowIndex, varCharBytes, 0, varCharSize); + } + }; + break; + case DATE: + primitiveReader = new PrimitiveReader() { + NullableDateDayHolder holder = new NullableDateDayHolder(); + + @Override + void doRead(FieldReader reader, ColumnVector columnVector, int rowIndex) { + reader.read(holder); + ((LongColumnVector) columnVector).vector[rowIndex] = holder.value; + } + }; + break; + case TIMESTAMP: + primitiveReader = new PrimitiveReader() { + NullableTimeStampMilliHolder timeStampMilliHolder = + new NullableTimeStampMilliHolder(); + + @Override + void doRead(FieldReader reader, ColumnVector columnVector, int rowIndex) { + reader.read(timeStampMilliHolder); + ((TimestampColumnVector) columnVector).set(rowIndex, + new Timestamp(timeStampMilliHolder.value)); + } + }; + break; + case BINARY: + primitiveReader = new PrimitiveReader() { + NullableVarBinaryHolder holder = new NullableVarBinaryHolder(); + + @Override + void doRead(FieldReader reader, ColumnVector columnVector, int rowIndex) { + reader.read(holder); + final int binarySize = holder.end - holder.start; + final byte[] binaryBytes 
= new byte[binarySize]; + holder.buffer.getBytes(holder.start, binaryBytes); + ((BytesColumnVector) columnVector).setVal(rowIndex, binaryBytes, 0, binarySize); + } + }; + break; + case DECIMAL: + primitiveReader = new PrimitiveReader() { + @Override + void doRead(FieldReader reader, ColumnVector columnVector, int rowIndex) { + ((DecimalColumnVector) columnVector).set(rowIndex, + HiveDecimal.create(reader.readBigDecimal())); + } + }; + break; + case INTERVAL_YEAR_MONTH: + primitiveReader = new PrimitiveReader() { + NullableIntervalYearHolder holder = new NullableIntervalYearHolder(); + + @Override + void doRead(FieldReader reader, ColumnVector columnVector, int rowIndex) { + reader.read(holder); + ((LongColumnVector) columnVector).vector[rowIndex] = holder.value; + } + }; + break; + case INTERVAL_DAY_TIME: + primitiveReader = new PrimitiveReader() { + NullableIntervalDayHolder holder = new NullableIntervalDayHolder(); + + @Override + void doRead(FieldReader reader, ColumnVector columnVector, int rowIndex) { + IntervalDayTimeColumnVector intervalDayTimeVector = + (IntervalDayTimeColumnVector) columnVector; + reader.read(holder); + HiveIntervalDayTime intervalDayTime = new HiveIntervalDayTime( + holder.days, // days + holder.milliseconds / MS_PER_HOUR, // hour + (holder.milliseconds % MS_PER_HOUR) / MS_PER_MINUTE, // minute + (holder.milliseconds % MS_PER_MINUTE) / MS_PER_SECOND, // second + (holder.milliseconds % MS_PER_SECOND) * NS_PER_MS); // nanosecond + intervalDayTimeVector.set(rowIndex, intervalDayTime); + } + }; + break; + default: + throw new IllegalArgumentException(); + } + primitiveReader.read(reader, columnVector, offset, length); + break; + case LIST: + final ListTypeInfo listTypeInfo = (ListTypeInfo) typeInfo; + final TypeInfo elementTypeInfo = listTypeInfo.getListElementTypeInfo(); + final ListColumnVector listVector = (ListColumnVector) columnVector; + final ColumnVector elementVector = listVector.child; + final FieldReader elementReader = reader.reader(); + + int listOffset = 0; + for (int rowIndex = 0; rowIndex < length; rowIndex++) { + reader.setPosition(offset + rowIndex); + final int listLength = reader.size(); + listVector.offsets[offset + rowIndex] = listOffset; + listVector.lengths[offset + rowIndex] = listLength; + read(elementReader, elementVector, elementTypeInfo, listOffset, listLength); + listOffset += listLength; + } + break; + case STRUCT: + final StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo; + final List fieldTypeInfos = structTypeInfo.getAllStructFieldTypeInfos(); + final List fieldNames = structTypeInfo.getAllStructFieldNames(); + final int fieldSize = fieldNames.size(); + final StructColumnVector structVector = (StructColumnVector) columnVector; + final ColumnVector[] fieldVectors = structVector.fields; + + for (int fieldIndex = 0; fieldIndex < fieldSize; fieldIndex++) { + final TypeInfo fieldTypeInfo = fieldTypeInfos.get(fieldIndex); + final FieldReader fieldReader = reader.reader(fieldNames.get(fieldIndex)); + final ColumnVector fieldVector = fieldVectors[fieldIndex]; + read(fieldReader, fieldVector, fieldTypeInfo, offset, length); + } + break; + case UNION: + final UnionTypeInfo unionTypeInfo = (UnionTypeInfo) typeInfo; + final List objectTypeInfos = unionTypeInfo.getAllUnionObjectTypeInfos(); + final UnionColumnVector unionVector = (UnionColumnVector) columnVector; + final ColumnVector[] objectVectors = unionVector.fields; + final Map minorTypeToTagMap = Maps.newHashMap(); + for (int tag = 0; tag < objectTypeInfos.size(); tag++) { + 
minorTypeToTagMap.put(toMinorType(objectTypeInfos.get(tag)), tag); + } + + final UnionReader unionReader = (UnionReader) reader; + for (int rowIndex = 0; rowIndex < length; rowIndex++) { + unionReader.setPosition(offset + rowIndex); + final Types.MinorType minorType = unionReader.getMinorType(); + final int tag = minorTypeToTagMap.get(minorType); + unionVector.tags[offset + rowIndex] = tag; + read(unionReader, objectVectors[tag], objectTypeInfos.get(tag), offset + rowIndex, 1); + } + break; + case MAP: + final MapTypeInfo mapTypeInfo = (MapTypeInfo) typeInfo; + final ListTypeInfo mapStructListTypeInfo = toStructListTypeInfo(mapTypeInfo); + final MapColumnVector hiveMapVector = (MapColumnVector) columnVector; + final ListColumnVector mapStructListVector = toStructListVector(hiveMapVector); + final StructColumnVector mapStructVector = (StructColumnVector) mapStructListVector.child; + read(reader, mapStructListVector, mapStructListTypeInfo, offset, length); + + hiveMapVector.isRepeating = mapStructListVector.isRepeating; + hiveMapVector.childCount = mapStructListVector.childCount; + hiveMapVector.noNulls = mapStructListVector.noNulls; + System.arraycopy(mapStructListVector.offsets, 0, hiveMapVector.offsets, 0, length); + System.arraycopy(mapStructListVector.lengths, 0, hiveMapVector.lengths, 0, length); + hiveMapVector.keys = mapStructVector.fields[0]; + hiveMapVector.values = mapStructVector.fields[1]; + break; + default: + throw new IllegalArgumentException(); + } + } + } + + private static Types.MinorType toMinorType(TypeInfo typeInfo) { + switch (typeInfo.getCategory()) { + case PRIMITIVE: + switch (((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory()) { + case BOOLEAN: + return Types.MinorType.BIT; + case BYTE: + return Types.MinorType.TINYINT; + case SHORT: + return Types.MinorType.SMALLINT; + case INT: + return Types.MinorType.INT; + case LONG: + return Types.MinorType.BIGINT; + case FLOAT: + return Types.MinorType.FLOAT4; + case DOUBLE: + return Types.MinorType.FLOAT8; + case STRING: + case VARCHAR: + case CHAR: + return Types.MinorType.VARCHAR; + case DATE: + return Types.MinorType.DATEDAY; + case TIMESTAMP: + return Types.MinorType.TIMESTAMPMILLI; + case BINARY: + return Types.MinorType.VARBINARY; + case DECIMAL: + return Types.MinorType.DECIMAL; + case INTERVAL_YEAR_MONTH: + return Types.MinorType.INTERVALYEAR; + case INTERVAL_DAY_TIME: + return Types.MinorType.INTERVALDAY; + case VOID: + case TIMESTAMPLOCALTZ: + case UNKNOWN: + default: + throw new IllegalArgumentException(); + } + case LIST: + return Types.MinorType.LIST; + case STRUCT: + return Types.MinorType.MAP; + case UNION: + return Types.MinorType.UNION; + case MAP: + // Apache Arrow doesn't have a map vector, so it's converted to a list vector of a struct + // vector. 
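+        // That is, a Hive map<K, V> is modeled as a list<struct<keys: K, values: V>>; see
+        // toStructListTypeInfo() and toStructListVector() below for the corresponding conversions.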
+ return Types.MinorType.LIST; + default: + throw new IllegalArgumentException(); + } + } + + private static ListTypeInfo toStructListTypeInfo(MapTypeInfo mapTypeInfo) { + final StructTypeInfo structTypeInfo = new StructTypeInfo(); + structTypeInfo.setAllStructFieldNames(Lists.newArrayList("keys", "values")); + structTypeInfo.setAllStructFieldTypeInfos(Lists.newArrayList( + mapTypeInfo.getMapKeyTypeInfo(), mapTypeInfo.getMapValueTypeInfo())); + final ListTypeInfo structListTypeInfo = new ListTypeInfo(); + structListTypeInfo.setListElementTypeInfo(structTypeInfo); + return structListTypeInfo; + } + + private static Field toField(String name, TypeInfo typeInfo) { + switch (typeInfo.getCategory()) { + case PRIMITIVE: + final PrimitiveTypeInfo primitiveTypeInfo = (PrimitiveTypeInfo) typeInfo; + switch (primitiveTypeInfo.getPrimitiveCategory()) { + case BOOLEAN: + return Field.nullable(name, Types.MinorType.BIT.getType()); + case BYTE: + return Field.nullable(name, Types.MinorType.TINYINT.getType()); + case SHORT: + return Field.nullable(name, Types.MinorType.SMALLINT.getType()); + case INT: + return Field.nullable(name, Types.MinorType.INT.getType()); + case LONG: + return Field.nullable(name, Types.MinorType.BIGINT.getType()); + case FLOAT: + return Field.nullable(name, Types.MinorType.FLOAT4.getType()); + case DOUBLE: + return Field.nullable(name, Types.MinorType.FLOAT8.getType()); + case STRING: + return Field.nullable(name, Types.MinorType.VARCHAR.getType()); + case DATE: + return Field.nullable(name, Types.MinorType.DATEDAY.getType()); + case TIMESTAMP: + return Field.nullable(name, Types.MinorType.TIMESTAMPMILLI.getType()); + case TIMESTAMPLOCALTZ: + final TimestampLocalTZTypeInfo timestampLocalTZTypeInfo = + (TimestampLocalTZTypeInfo) typeInfo; + final String timeZone = timestampLocalTZTypeInfo.getTimeZone().toString(); + return Field.nullable(name, new ArrowType.Timestamp(TimeUnit.MILLISECOND, timeZone)); + case BINARY: + return Field.nullable(name, Types.MinorType.VARBINARY.getType()); + case DECIMAL: + final DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) typeInfo; + final int precision = decimalTypeInfo.precision(); + final int scale = decimalTypeInfo.scale(); + return Field.nullable(name, new ArrowType.Decimal(precision, scale)); + case VARCHAR: + return Field.nullable(name, Types.MinorType.VARCHAR.getType()); + case CHAR: + return Field.nullable(name, Types.MinorType.VARCHAR.getType()); + case INTERVAL_YEAR_MONTH: + return Field.nullable(name, Types.MinorType.INTERVALYEAR.getType()); + case INTERVAL_DAY_TIME: + return Field.nullable(name, Types.MinorType.INTERVALDAY.getType()); + default: + throw new IllegalArgumentException(); + } + case LIST: + final ListTypeInfo listTypeInfo = (ListTypeInfo) typeInfo; + final TypeInfo elementTypeInfo = listTypeInfo.getListElementTypeInfo(); + return new Field(name, FieldType.nullable(Types.MinorType.LIST.getType()), + Lists.newArrayList(toField(DEFAULT_ARROW_FIELD_NAME, elementTypeInfo))); + case STRUCT: + final StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo; + final List fieldTypeInfos = structTypeInfo.getAllStructFieldTypeInfos(); + final List fieldNames = structTypeInfo.getAllStructFieldNames(); + final List structFields = Lists.newArrayList(); + final int structSize = fieldNames.size(); + for (int i = 0; i < structSize; i++) { + structFields.add(toField(fieldNames.get(i), fieldTypeInfos.get(i))); + } + return new Field(name, FieldType.nullable(Types.MinorType.MAP.getType()), structFields); + case UNION: + final UnionTypeInfo 
unionTypeInfo = (UnionTypeInfo) typeInfo; + final List objectTypeInfos = unionTypeInfo.getAllUnionObjectTypeInfos(); + final List unionFields = Lists.newArrayList(); + final int unionSize = unionFields.size(); + for (int i = 0; i < unionSize; i++) { + unionFields.add(toField(DEFAULT_ARROW_FIELD_NAME, objectTypeInfos.get(i))); + } + return new Field(name, FieldType.nullable(Types.MinorType.UNION.getType()), unionFields); + case MAP: + final MapTypeInfo mapTypeInfo = (MapTypeInfo) typeInfo; + final TypeInfo keyTypeInfo = mapTypeInfo.getMapKeyTypeInfo(); + final TypeInfo valueTypeInfo = mapTypeInfo.getMapValueTypeInfo(); + final ListTypeInfo mapListTypeInfo = new ListTypeInfo(); + final StructTypeInfo mapStructTypeInfo = new StructTypeInfo(); + mapStructTypeInfo.setAllStructFieldNames( + Lists.newArrayList("keys", "values")); + mapStructTypeInfo.setAllStructFieldTypeInfos( + Lists.newArrayList(keyTypeInfo, valueTypeInfo)); + mapListTypeInfo.setListElementTypeInfo(mapStructTypeInfo); + return toField(DEFAULT_ARROW_FIELD_NAME, mapListTypeInfo); + default: + throw new IllegalArgumentException(); + } + } + + private static ListColumnVector toStructListVector(MapColumnVector mapVector) { + final StructColumnVector structVector; + final ListColumnVector structListVector; + structVector = new StructColumnVector(); + structVector.fields = new ColumnVector[] {mapVector.keys, mapVector.values}; + structListVector = new ListColumnVector(); + structListVector.child = structVector; + System.arraycopy(mapVector.offsets, 0, structListVector.offsets, 0, mapVector.childCount); + System.arraycopy(mapVector.lengths, 0, structListVector.lengths, 0, mapVector.childCount); + structListVector.childCount = mapVector.childCount; + structListVector.isRepeating = mapVector.isRepeating; + structListVector.noNulls = mapVector.noNulls; + return structListVector; + } + + @Override + public Class getSerializedClass() { + return ArrowWrapperWritable.class; + } + + @Override + public ArrowWrapperWritable serialize(Object obj, ObjectInspector objInspector) { + return serializer.serialize(obj, objInspector); + } + + @Override + public SerDeStats getSerDeStats() { + return null; + } + + @Override + public Object deserialize(Writable writable) { + return deserializer.deserialize(writable); + } + + @Override + public ObjectInspector getObjectInspector() { + return rowObjectInspector; + } +} diff --git ql/src/java/org/apache/hadoop/hive/ql/io/arrow/ArrowWrapperWritable.java ql/src/java/org/apache/hadoop/hive/ql/io/arrow/ArrowWrapperWritable.java new file mode 100644 index 0000000000..df7b53f42a --- /dev/null +++ ql/src/java/org/apache/hadoop/hive/ql/io/arrow/ArrowWrapperWritable.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.hadoop.hive.ql.io.arrow; + +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.hadoop.io.Writable; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +public class ArrowWrapperWritable implements Writable { + private VectorSchemaRoot vectorSchemaRoot; + + public ArrowWrapperWritable(VectorSchemaRoot vectorSchemaRoot) { + this.vectorSchemaRoot = vectorSchemaRoot; + } + + public VectorSchemaRoot getVectorSchemaRoot() { + return vectorSchemaRoot; + } + + @Override + public void write(DataOutput dataOutput) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public void readFields(DataInput dataInput) throws IOException { + throw new UnsupportedOperationException(); + } +} diff --git ql/src/java/org/apache/hadoop/hive/ql/io/arrow/RootAllocatorFactory.java ql/src/java/org/apache/hadoop/hive/ql/io/arrow/RootAllocatorFactory.java new file mode 100644 index 0000000000..78cc188e65 --- /dev/null +++ ql/src/java/org/apache/hadoop/hive/ql/io/arrow/RootAllocatorFactory.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hive.ql.io.arrow; + +import org.apache.arrow.memory.RootAllocator; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hive.conf.HiveConf; + +import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.HIVE_ARROW_ROOT_ALLOCATOR_LIMIT; + +/** + * Thread-safe singleton factory for RootAllocator + */ +public enum RootAllocatorFactory { + INSTANCE; + + private RootAllocator rootAllocator; + + RootAllocatorFactory() { + } + + public synchronized RootAllocator getRootAllocator(Configuration conf) { + if (rootAllocator == null) { + final long limit = HiveConf.getLongVar(conf, HIVE_ARROW_ROOT_ALLOCATOR_LIMIT); + rootAllocator = new RootAllocator(limit); + } + return rootAllocator; + } +} diff --git ql/src/test/org/apache/hadoop/hive/ql/io/arrow/AbstractTest.java ql/src/test/org/apache/hadoop/hive/ql/io/arrow/AbstractTest.java new file mode 100644 index 0000000000..e9450e9a0d --- /dev/null +++ ql/src/test/org/apache/hadoop/hive/ql/io/arrow/AbstractTest.java @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hive.ql.io.arrow; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hive.common.type.HiveChar; +import org.apache.hadoop.hive.common.type.HiveDecimal; +import org.apache.hadoop.hive.common.type.HiveIntervalDayTime; +import org.apache.hadoop.hive.common.type.HiveIntervalYearMonth; +import org.apache.hadoop.hive.common.type.HiveVarchar; +import org.apache.hadoop.hive.serde.serdeConstants; +import org.apache.hadoop.hive.serde2.AbstractSerDe; +import org.apache.hadoop.hive.serde2.SerDeException; +import org.apache.hadoop.hive.serde2.SerDeUtils; +import org.apache.hadoop.hive.serde2.io.ByteWritable; +import org.apache.hadoop.hive.serde2.io.DateWritable; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.io.HiveCharWritable; +import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; +import org.apache.hadoop.hive.serde2.io.HiveIntervalDayTimeWritable; +import org.apache.hadoop.hive.serde2.io.HiveIntervalYearMonthWritable; +import org.apache.hadoop.hive.serde2.io.HiveVarcharWritable; +import org.apache.hadoop.hive.serde2.io.ShortWritable; +import org.apache.hadoop.hive.serde2.io.TimestampWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructField; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.io.BooleanWritable; +import org.apache.hadoop.io.BytesWritable; +import org.apache.hadoop.io.FloatWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.io.Writable; +import org.junit.Before; + +import java.sql.Timestamp; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Properties; +import java.util.Set; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +public class AbstractTest { + protected Configuration conf; + + final static Object[][] INTEGER_ROWS = { + {byteW(0), shortW(0), intW(0), longW(0)}, + {byteW(1), shortW(1), intW(1), longW(1)}, + {byteW(-1), shortW(-1), intW(-1), longW(-1)}, + {byteW(Byte.MIN_VALUE), shortW(Short.MIN_VALUE), intW(Integer.MIN_VALUE), + longW(Long.MIN_VALUE)}, + {byteW(Byte.MAX_VALUE), shortW(Short.MAX_VALUE), intW(Integer.MAX_VALUE), + longW(Long.MAX_VALUE)}, + }; + + final static Object[][] FLOAT_ROWS = { + {floatW(0f), doubleW(0d)}, + {floatW(1f), doubleW(1d)}, + {floatW(-1f), doubleW(-1d)}, + {floatW(Float.MIN_VALUE), doubleW(Double.MIN_VALUE)}, + {floatW(-Float.MIN_VALUE), doubleW(-Double.MIN_VALUE)}, + {floatW(Float.MAX_VALUE), doubleW(Double.MAX_VALUE)}, + {floatW(-Float.MAX_VALUE), doubleW(-Double.MAX_VALUE)}, + {floatW(Float.POSITIVE_INFINITY), doubleW(Double.POSITIVE_INFINITY)}, + {floatW(Float.NEGATIVE_INFINITY), doubleW(Double.NEGATIVE_INFINITY)}, + }; + + final static Object[][] STRING_ROWS = { + 
{text(""), charW(""), varcharW("")}, + {text("Hello"), charW("Hello"), varcharW("Hello")}, + {text("world!"), charW("world!"), varcharW("world!")}, + }; + + private final static long NOW = System.currentTimeMillis(); + final static Object[][] DTI_ROWS = { + { + new DateWritable(DateWritable.millisToDays(NOW)), + new TimestampWritable(new Timestamp(NOW)), + new HiveIntervalYearMonthWritable(new HiveIntervalYearMonth(1, 2)), + new HiveIntervalDayTimeWritable(new HiveIntervalDayTime(1, 2, 3, 4, 5_000_000)) + }, + }; + + final static Object[][] DECIMAL_ROWS = { + {decimalW(HiveDecimal.ZERO)}, + {decimalW(HiveDecimal.ONE)}, + {decimalW(HiveDecimal.ONE.negate())}, + {decimalW(HiveDecimal.create("0.000001"))}, + {decimalW(HiveDecimal.create("100000"))}, + }; + + final static Object[][] BOOLEAN_ROWS = { + {new BooleanWritable(true)}, + {new BooleanWritable(false)}, + }; + + final static Object[][] BINARY_ROWS = { + {new BytesWritable("".getBytes())}, + {new BytesWritable("Hello".getBytes())}, + {new BytesWritable("world!".getBytes())}, + }; + + @Before + public void setUp() { + conf = new Configuration(); + } + + static ByteWritable byteW(int value) { + return new ByteWritable((byte) value); + } + + static ShortWritable shortW(int value) { + return new ShortWritable((short) value); + } + + static IntWritable intW(int value) { + return new IntWritable(value); + } + + static LongWritable longW(long value) { + return new LongWritable(value); + } + + static FloatWritable floatW(float value) { + return new FloatWritable(value); + } + + static DoubleWritable doubleW(double value) { + return new DoubleWritable(value); + } + + static Text text(String value) { + return new Text(value); + } + + static HiveCharWritable charW(String value) { + return new HiveCharWritable(new HiveChar(value, 10)); + } + + static HiveVarcharWritable varcharW(String value) { + return new HiveVarcharWritable(new HiveVarchar(value, 10)); + } + + static HiveDecimalWritable decimalW(HiveDecimal value) { + return new HiveDecimalWritable(value); + } + + void initAndSerializeAndDeserialize(Object[][] rows, StructObjectInspector rowOI, + String fieldNames, String fieldTypes) throws SerDeException { + final AbstractSerDe serDe = new ArrowColumnarBatchSerDe(); + initSerDe(serDe, fieldNames, fieldTypes); + serializeAndDeserialize(serDe, rows, rowOI); + } + + void initSerDe(AbstractSerDe serDe, String fieldNames, String fieldTypes) throws SerDeException { + Properties schema = new Properties(); + schema.setProperty(serdeConstants.LIST_COLUMNS, fieldNames); + schema.setProperty(serdeConstants.LIST_COLUMN_TYPES, fieldTypes); + SerDeUtils.initializeSerDe(serDe, conf, schema, null); + } + + void serializeAndDeserialize(AbstractSerDe serDe, Object[][] rows, StructObjectInspector rowOI) + throws SerDeException { + Writable serialized = null; + for (Object[] row : rows) { + serialized = serDe.serialize(row, rowOI); + } + // Pass null to complete a batch + if (serialized == null) { + serialized = serDe.serialize(null, rowOI); + } + final Object[][] deserializedRows = (Object[][]) serDe.deserialize(serialized); + + for (int rowIndex = 0; rowIndex < rows.length; rowIndex++) { + final Object[] row = rows[rowIndex]; + final Object[] deserializedRow = deserializedRows[rowIndex]; + assertEquals(row.length, deserializedRow.length); + + final List fields = rowOI.getAllStructFieldRefs(); + for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) { + final StructField field = fields.get(fieldIndex); + final ObjectInspector fieldObjInspector = 
field.getFieldObjectInspector(); + switch (fieldObjInspector.getCategory()) { + case PRIMITIVE: + final PrimitiveObjectInspector primitiveObjInspector = + (PrimitiveObjectInspector) fieldObjInspector; + switch (primitiveObjInspector.getPrimitiveCategory()) { + case STRING: + case VARCHAR: + case CHAR: + assertEquals(Objects.toString(row[fieldIndex]), + Objects.toString(deserializedRow[fieldIndex])); + break; + default: + assertEquals(row[fieldIndex], deserializedRow[fieldIndex]); + break; + } + break; + case STRUCT: + final Object[] rowStruct = (Object[]) row[fieldIndex]; + final List deserializedRowStruct = (List) deserializedRow[fieldIndex]; + assertArrayEquals(rowStruct, deserializedRowStruct.toArray()); + break; + case LIST: + case UNION: + assertEquals(row[fieldIndex], deserializedRow[fieldIndex]); + break; + case MAP: + Map rowMap = (Map) row[fieldIndex]; + Map deserializedRowMap = (Map) deserializedRow[fieldIndex]; + Set rowMapKeySet = rowMap.keySet(); + Set deserializedRowMapKeySet = deserializedRowMap.keySet(); + assertTrue(rowMapKeySet.containsAll(deserializedRowMapKeySet)); + assertTrue(deserializedRowMapKeySet.containsAll(rowMapKeySet)); + for (Object key : rowMapKeySet) { + assertEquals(rowMap.get(key), deserializedRowMap.get(key)); + } + break; + } + } + } + } +} diff --git ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowList.java ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowList.java new file mode 100644 index 0000000000..a8dc153e38 --- /dev/null +++ ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowList.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.hadoop.hive.ql.io.arrow; + +import org.apache.hadoop.hive.serde2.SerDeException; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.junit.Test; + +import java.util.List; + +import static com.google.common.collect.Lists.newArrayList; +import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.getStandardListObjectInspector; +import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.getStandardStructObjectInspector; +import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.*; + +public class TestArrowList extends AbstractTest { + private List[][] toList(Object[][] rows) { + List[][] array = new List[rows.length][]; + for (int rowIndex = 0; rowIndex < rows.length; rowIndex++) { + Object[] row = rows[rowIndex]; + array[rowIndex] = new List[row.length]; + for (int fieldIndex = 0; fieldIndex < row.length; fieldIndex++) { + array[rowIndex][fieldIndex] = newArrayList(row[fieldIndex]); + } + } + return array; + } + + @Test + public void testIntegerList() throws SerDeException { + String fieldNames = "tinyints1,smallints1,ints1,bigints1"; + String fieldTypes = "array<tinyint>,array<smallint>,array<int>,array<bigint>"; + + StructObjectInspector rowOI = getStandardStructObjectInspector( + newArrayList(fieldNames.split(",")), + newArrayList( + getStandardListObjectInspector(writableByteObjectInspector), + getStandardListObjectInspector(writableShortObjectInspector), + getStandardListObjectInspector(writableIntObjectInspector), + getStandardListObjectInspector(writableLongObjectInspector))); + + initAndSerializeAndDeserialize(toList(INTEGER_ROWS), rowOI, fieldNames, fieldTypes); + } + + @Test + public void testFloatList() throws SerDeException { + String fieldNames = "floats1,doubles1"; + String fieldTypes = "array<float>,array<double>"; + + StructObjectInspector rowOI = getStandardStructObjectInspector( + newArrayList(fieldNames.split(",")), + newArrayList( + getStandardListObjectInspector(writableFloatObjectInspector), + getStandardListObjectInspector(writableDoubleObjectInspector))); + + initAndSerializeAndDeserialize(toList(FLOAT_ROWS), rowOI, fieldNames, fieldTypes); + } + + @Test + public void testStringList() throws SerDeException { + String fieldNames = "strings1,chars1,varchars1"; + String fieldTypes = "array<string>,array<char(10)>,array<varchar(10)>"; + + StructObjectInspector rowOI = getStandardStructObjectInspector( + newArrayList(fieldNames.split(",")), + newArrayList( + getStandardListObjectInspector(writableStringObjectInspector), + getStandardListObjectInspector(writableHiveCharObjectInspector), + getStandardListObjectInspector(writableHiveVarcharObjectInspector))); + + initAndSerializeAndDeserialize(toList(STRING_ROWS), rowOI, fieldNames, fieldTypes); + } + + @Test + public void testDTIList() throws SerDeException { + String fieldNames = "dates1,timestamps1,interval_year_months1,interval_day_times1"; + String fieldTypes = "array<date>,array<timestamp>,array<interval_year_month>," + + "array<interval_day_time>"; + + StructObjectInspector rowOI = getStandardStructObjectInspector( + newArrayList(fieldNames.split(",")), + newArrayList( + getStandardListObjectInspector(writableDateObjectInspector), + getStandardListObjectInspector(writableTimestampObjectInspector), + getStandardListObjectInspector(writableHiveIntervalYearMonthObjectInspector), + getStandardListObjectInspector(writableHiveIntervalDayTimeObjectInspector))); + + initAndSerializeAndDeserialize(toList(DTI_ROWS), rowOI, fieldNames, fieldTypes); + } + + @Test + public void testBooleanList() throws
SerDeException { + String fieldNames = "booleans1"; + String fieldTypes = "array<boolean>"; + + StructObjectInspector rowOI = getStandardStructObjectInspector( + newArrayList(fieldNames.split(",")), + newArrayList(getStandardListObjectInspector(writableBooleanObjectInspector))); + + initAndSerializeAndDeserialize(toList(BOOLEAN_ROWS), rowOI, fieldNames, fieldTypes); + } + + @Test + public void testBinaryList() throws SerDeException { + String fieldNames = "binaries1"; + String fieldTypes = "array<binary>"; + + StructObjectInspector rowOI = getStandardStructObjectInspector( + newArrayList(fieldNames.split(",")), + newArrayList(getStandardListObjectInspector(writableBinaryObjectInspector))); + + initAndSerializeAndDeserialize(toList(BINARY_ROWS), rowOI, fieldNames, fieldTypes); + } +} diff --git ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowMap.java ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowMap.java new file mode 100644 index 0000000000..b3d1cf4834 --- /dev/null +++ ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowMap.java @@ -0,0 +1,132 @@ +package org.apache.hadoop.hive.ql.io.arrow; + +import com.google.common.collect.Maps; +import org.apache.hadoop.hive.serde2.SerDeException; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.io.Text; +import org.junit.Test; + +import java.util.Map; + +import static com.google.common.collect.Lists.newArrayList; +import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.getStandardMapObjectInspector; +import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.getStandardStructObjectInspector; +import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.*; + +public class TestArrowMap extends AbstractTest { + private Object[][] toMap(Object[][] rows) { + Map[][] array = new Map[rows.length][]; + for (int rowIndex = 0; rowIndex < rows.length; rowIndex++) { + Object[] row = rows[rowIndex]; + array[rowIndex] = new Map[row.length]; + for (int fieldIndex = 0; fieldIndex < row.length; fieldIndex++) { + Map map = Maps.newHashMap(); + map.put(new Text(String.valueOf(row[fieldIndex])), row[fieldIndex]); + array[rowIndex][fieldIndex] = map; + } + } + return array; + } + + @Test + public void testIntegerMap() throws SerDeException { + String fieldNames = "tinyint1,smallint1,int1,bigint1"; + String fieldTypes = "map<string,tinyint>,map<string,smallint>," + + "map<string,int>,map<string,bigint>"; + + StructObjectInspector rowOI = getStandardStructObjectInspector( + newArrayList(fieldNames.split(",")), + newArrayList( + getStandardMapObjectInspector( + writableStringObjectInspector, writableByteObjectInspector), + getStandardMapObjectInspector( + writableStringObjectInspector, writableShortObjectInspector), + getStandardMapObjectInspector( + writableStringObjectInspector, writableIntObjectInspector), + getStandardMapObjectInspector( + writableStringObjectInspector, writableLongObjectInspector))); + + initAndSerializeAndDeserialize(toMap(INTEGER_ROWS), rowOI, fieldNames, fieldTypes); + } + + @Test + public void testFloatMap() throws SerDeException { + String fieldNames = "float1,double1"; + String fieldTypes = "map<string,float>,map<string,double>"; + + StructObjectInspector rowOI = getStandardStructObjectInspector( + newArrayList(fieldNames.split(",")), + newArrayList( + getStandardMapObjectInspector( + writableStringObjectInspector, writableFloatObjectInspector), + getStandardMapObjectInspector( + writableStringObjectInspector, writableDoubleObjectInspector))); + + 
initAndSerializeAndDeserialize(toMap(FLOAT_ROWS), rowOI, fieldNames, fieldTypes); + } + + @Test + public void testStringMap() throws SerDeException { + String fieldNames = "string1,char1,varchar1"; + String fieldTypes = "map<string,string>,map<string,char(10)>,map<string,varchar(10)>"; + + StructObjectInspector rowOI = getStandardStructObjectInspector( + newArrayList(fieldNames.split(",")), + newArrayList( + getStandardMapObjectInspector( + writableStringObjectInspector, writableStringObjectInspector), + getStandardMapObjectInspector( + writableStringObjectInspector, writableHiveCharObjectInspector), + getStandardMapObjectInspector( + writableStringObjectInspector, writableHiveVarcharObjectInspector))); + + initAndSerializeAndDeserialize(toMap(STRING_ROWS), rowOI, fieldNames, fieldTypes); + } + + @Test + public void testDTIMap() throws SerDeException { + String fieldNames = "date1,timestamp1,interval_year_month1,interval_day_time1"; + String fieldTypes = "map<string,date>,map<string,timestamp>,map<string,interval_year_month>," + + "map<string,interval_day_time>"; + + StructObjectInspector rowOI = getStandardStructObjectInspector( + newArrayList(fieldNames.split(",")), + newArrayList( + getStandardMapObjectInspector( + writableStringObjectInspector, writableDateObjectInspector), + getStandardMapObjectInspector( + writableStringObjectInspector, writableTimestampObjectInspector), + getStandardMapObjectInspector( + writableStringObjectInspector, writableHiveIntervalYearMonthObjectInspector), + getStandardMapObjectInspector( + writableStringObjectInspector, writableHiveIntervalDayTimeObjectInspector))); + + initAndSerializeAndDeserialize(toMap(DTI_ROWS), rowOI, fieldNames, fieldTypes); + } + + @Test + public void testBooleanMap() throws SerDeException { + String fieldNames = "boolean1"; + String fieldTypes = "map<string,boolean>"; + + StructObjectInspector rowOI = getStandardStructObjectInspector( + newArrayList(fieldNames.split(",")), + newArrayList(getStandardMapObjectInspector( + writableStringObjectInspector, writableBooleanObjectInspector))); + + initAndSerializeAndDeserialize(toMap(BOOLEAN_ROWS), rowOI, fieldNames, fieldTypes); + } + + @Test + public void testBinaryMap() throws SerDeException { + String fieldNames = "binary1"; + String fieldTypes = "map<string,binary>"; + + StructObjectInspector rowOI = getStandardStructObjectInspector( + newArrayList(fieldNames.split(",")), + newArrayList(getStandardMapObjectInspector( + writableStringObjectInspector, writableBinaryObjectInspector))); + + initAndSerializeAndDeserialize(toMap(BINARY_ROWS), rowOI, fieldNames, fieldTypes); + } +} diff --git ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowPrimitive.java ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowPrimitive.java new file mode 100644 index 0000000000..9013cee1c4 --- /dev/null +++ ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowPrimitive.java @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hive.ql.io.arrow; + +import org.apache.hadoop.hive.serde2.AbstractSerDe; +import org.apache.hadoop.hive.serde2.SerDeException; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableHiveDecimalObjectInspector; +import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; +import org.apache.hadoop.io.FloatWritable; +import org.junit.Test; + +import java.util.Random; + +import static com.google.common.collect.Lists.newArrayList; +import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.*; +import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.*; + +public class TestArrowPrimitive extends AbstractTest { + @Test + public void testInteger() throws SerDeException { + String fieldNames = "tinyint1,smallint1,int1,bigint1"; + String fieldTypes = "tinyint,smallint,int,bigint"; + + StructObjectInspector rowOI = getStandardStructObjectInspector( + newArrayList(fieldNames.split(",")), + newArrayList( + writableByteObjectInspector, + writableShortObjectInspector, + writableIntObjectInspector, + writableLongObjectInspector)); + + initAndSerializeAndDeserialize(INTEGER_ROWS, rowOI, fieldNames, fieldTypes); + } + + @Test + public void testBigInt10000() throws SerDeException { + String fieldNames = "bigint1"; + String fieldTypes = "bigint"; + + StructObjectInspector rowOI = getStandardStructObjectInspector( + newArrayList(fieldNames.split(",")), + newArrayList(writableLongObjectInspector)); + + final int batchSize = 1000; + final Object[][] integerRows = new Object[batchSize][]; + final AbstractSerDe serDe = new ArrowColumnarBatchSerDe(); + initSerDe(serDe, fieldNames, fieldTypes); + + for (int j = 0; j < 10; j++) { + for (int i = 0; i < batchSize; i++) { + integerRows[i] = new Object[] {longW(i + j * batchSize)}; + } + + serializeAndDeserialize(serDe, integerRows, rowOI); + } + } + + @Test + public void testBigIntRandom() { + try { + String fieldNames = "bigint1"; + String fieldTypes = "bigint"; + + StructObjectInspector rowOI = getStandardStructObjectInspector( + newArrayList(fieldNames.split(",")), + newArrayList(writableLongObjectInspector)); + + final AbstractSerDe serDe = new ArrowColumnarBatchSerDe(); + + initSerDe(serDe, fieldNames, fieldTypes); + + final Random random = new Random(); + for (int j = 0; j < 1000; j++) { + final int batchSize = random.nextInt(1000); + final Object[][] integerRows = new Object[batchSize][]; + for (int i = 0; i < batchSize; i++) { + integerRows[i] = new Object[] {longW(random.nextLong())}; + } + + serializeAndDeserialize(serDe, integerRows, rowOI); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Test + public void testFloat() throws SerDeException { + String fieldNames = "float1,double1"; + String fieldTypes = "float,double"; + + StructObjectInspector rowOI = getStandardStructObjectInspector( + newArrayList(fieldNames.split(",")), + newArrayList(writableFloatObjectInspector, writableDoubleObjectInspector)); + + initAndSerializeAndDeserialize(FLOAT_ROWS, rowOI, fieldNames, fieldTypes); + } + + @Test(expected = AssertionError.class) + public void testFloatNaN() throws SerDeException { + String fieldNames = "float1"; + String fieldTypes = "float"; + + StructObjectInspector rowOI = 
getStandardStructObjectInspector( + newArrayList(fieldNames.split(",")), + newArrayList(writableFloatObjectInspector)); + + Object[][] rows = {{new FloatWritable(Float.NaN)}}; + + initAndSerializeAndDeserialize(rows, rowOI, fieldNames, fieldTypes); + } + + @Test(expected = AssertionError.class) + public void testDoubleNaN() throws SerDeException { + String fieldNames = "double1"; + String fieldTypes = "double"; + + StructObjectInspector rowOI = getStandardStructObjectInspector( + newArrayList(fieldNames.split(",")), + newArrayList(writableDoubleObjectInspector)); + + Object[][] rows = {{new DoubleWritable(Double.NaN)}}; + + initAndSerializeAndDeserialize(rows, rowOI, fieldNames, fieldTypes); + } + + @Test + public void testString() throws SerDeException { + String fieldNames = "string1,char1,varchar1"; + String fieldTypes = "string,char(10),varchar(10)"; + + StructObjectInspector rowOI = getStandardStructObjectInspector( + newArrayList(fieldNames.split(",")), + newArrayList( + writableStringObjectInspector, + writableHiveCharObjectInspector, + writableHiveVarcharObjectInspector)); + + initAndSerializeAndDeserialize(STRING_ROWS, rowOI, fieldNames, fieldTypes); + } + + @Test + public void testDTI() throws SerDeException { + String fieldNames = "date1,timestamp1,interval_year_month1,interval_day_time1"; + String fieldTypes = "date,timestamp,interval_year_month," + + "interval_day_time"; + + StructObjectInspector rowOI = getStandardStructObjectInspector( + newArrayList(fieldNames.split(",")), + newArrayList( + writableDateObjectInspector, + writableTimestampObjectInspector, + writableHiveIntervalYearMonthObjectInspector, + writableHiveIntervalDayTimeObjectInspector)); + + initAndSerializeAndDeserialize(DTI_ROWS, rowOI, fieldNames, fieldTypes); + } + + @Test + public void testDecimal() throws SerDeException { + int precision = 38; + int scale = 10; + + String fieldNames = "decimal1"; + String fieldTypes = "decimal(" + precision + "," + scale + ")"; + + StructObjectInspector rowOI = getStandardStructObjectInspector( + newArrayList(fieldNames.split(",")), + newArrayList(new WritableHiveDecimalObjectInspector( + new DecimalTypeInfo(precision, scale)))); + + initAndSerializeAndDeserialize(DECIMAL_ROWS, rowOI, fieldNames, fieldTypes); + } + + @Test + public void testBoolean() throws SerDeException { + String fieldNames = "boolean1"; + String fieldTypes = "boolean"; + + StructObjectInspector rowOI = getStandardStructObjectInspector( + newArrayList(fieldNames.split(",")), + newArrayList(writableBooleanObjectInspector)); + + initAndSerializeAndDeserialize(BOOLEAN_ROWS, rowOI, fieldNames, fieldTypes); + } + + @Test + public void testBinary() throws SerDeException { + String fieldNames = "binary1"; + String fieldTypes = "binary"; + + StructObjectInspector rowOI = getStandardStructObjectInspector( + newArrayList(fieldNames.split(",")), + newArrayList(writableBinaryObjectInspector)); + + initAndSerializeAndDeserialize(BINARY_ROWS, rowOI, fieldNames, fieldTypes); + } +} \ No newline at end of file diff --git ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowStruct.java ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowStruct.java new file mode 100644 index 0000000000..e5a12414fc --- /dev/null +++ ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowStruct.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hive.ql.io.arrow; + +import org.apache.hadoop.hive.serde2.SerDeException; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.junit.Test; + +import static com.google.common.collect.Lists.newArrayList; +import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.getStandardStructObjectInspector; +import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.*; + +public class TestArrowStruct extends AbstractTest { + private Object[][][] toStruct(Object[][] rows) { + Object[][][] struct = new Object[rows.length][][]; + for (int rowIndex = 0; rowIndex < rows.length; rowIndex++) { + Object[] row = rows[rowIndex]; + struct[rowIndex] = new Object[][] {row}; + } + return struct; + } + + @Test + public void testIntegerStruct() throws SerDeException { + String fieldNames = "int_struct"; + String fieldTypes = "struct<tiny1:tinyint,smallint1:smallint,int1:int,bigint1:bigint>"; + + StructObjectInspector rowOI = getStandardStructObjectInspector( + newArrayList(fieldNames.split(",")), + newArrayList(getStandardStructObjectInspector( + newArrayList("tiny1", "smallint1", "int1", "bigint1"), + newArrayList( + writableByteObjectInspector, + writableShortObjectInspector, + writableIntObjectInspector, + writableLongObjectInspector)))); + + initAndSerializeAndDeserialize(toStruct(INTEGER_ROWS), rowOI, fieldNames, fieldTypes); + } + + @Test + public void testFloatStruct() throws SerDeException { + String fieldNames = "float_struct"; + String fieldTypes = "struct<float1:float,double1:double>"; + + StructObjectInspector rowOI = getStandardStructObjectInspector( + newArrayList(fieldNames.split(",")), + newArrayList(getStandardStructObjectInspector( + newArrayList("float1", "double1"), + newArrayList(writableFloatObjectInspector, writableDoubleObjectInspector)))); + + initAndSerializeAndDeserialize(toStruct(FLOAT_ROWS), rowOI, fieldNames, fieldTypes); + } + + @Test + public void testStringStruct() throws SerDeException { + String fieldNames = "string_struct"; + String fieldTypes = "struct<string1:string,char1:char(10),varchar1:varchar(10)>"; + + StructObjectInspector rowOI = getStandardStructObjectInspector( + newArrayList(fieldNames.split(",")), + newArrayList(getStandardStructObjectInspector( + newArrayList("string1", "char1", "varchar1"), + newArrayList( + writableStringObjectInspector, + writableHiveCharObjectInspector, + writableHiveVarcharObjectInspector)))); + + initAndSerializeAndDeserialize(toStruct(STRING_ROWS), rowOI, fieldNames, fieldTypes); + } + + @Test + public void testDTIStruct() throws SerDeException { + String fieldNames = "date_struct"; + String fieldTypes = "struct<date1:date,timestamp1:timestamp,interval_year_month1:interval_year_month,interval_day_time1:interval_day_time>"; + + StructObjectInspector rowOI = getStandardStructObjectInspector( + newArrayList(fieldNames.split(",")), + newArrayList(getStandardStructObjectInspector( + newArrayList("date1", "timestamp1", "interval_year_month1", "interval_day_time1"), + 
newArrayList( + writableDateObjectInspector, + writableTimestampObjectInspector, + writableHiveIntervalYearMonthObjectInspector, + writableHiveIntervalDayTimeObjectInspector)))); + + initAndSerializeAndDeserialize(toStruct(DTI_ROWS), rowOI, fieldNames, fieldTypes); + } + + @Test + public void testBooleanStruct() throws SerDeException { + String fieldNames = "boolean_struct"; + String fieldTypes = "struct<boolean1:boolean>"; + + StructObjectInspector rowOI = getStandardStructObjectInspector( + newArrayList(fieldNames.split(",")), + newArrayList(getStandardStructObjectInspector( + newArrayList("boolean1"), + newArrayList(writableBooleanObjectInspector)))); + + initAndSerializeAndDeserialize(toStruct(BOOLEAN_ROWS), rowOI, fieldNames, fieldTypes); + } + + @Test + public void testBinaryStruct() throws SerDeException { + String fieldNames = "binary_struct"; + String fieldTypes = "struct<binary1:binary>"; + + StructObjectInspector rowOI = getStandardStructObjectInspector( + newArrayList(fieldNames.split(",")), + newArrayList(getStandardStructObjectInspector( + newArrayList("binary1"), + newArrayList(writableBinaryObjectInspector)))); + + initAndSerializeAndDeserialize(toStruct(BINARY_ROWS), rowOI, fieldNames, fieldTypes); + } +} diff --git ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowUnion.java ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowUnion.java new file mode 100644 index 0000000000..3a88394809 --- /dev/null +++ ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowUnion.java @@ -0,0 +1,168 @@ +package org.apache.hadoop.hive.ql.io.arrow; + +import org.apache.hadoop.hive.common.type.HiveIntervalDayTime; +import org.apache.hadoop.hive.common.type.HiveIntervalYearMonth; +import org.apache.hadoop.hive.serde2.SerDeException; +import org.apache.hadoop.hive.serde2.io.DateWritable; +import org.apache.hadoop.hive.serde2.io.HiveIntervalDayTimeWritable; +import org.apache.hadoop.hive.serde2.io.HiveIntervalYearMonthWritable; +import org.apache.hadoop.hive.serde2.io.TimestampWritable; +import org.apache.hadoop.hive.serde2.objectinspector.StandardUnionObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.io.BooleanWritable; +import org.apache.hadoop.io.BytesWritable; +import org.junit.Test; + +import java.sql.Timestamp; + +import static com.google.common.collect.Lists.newArrayList; +import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.getStandardStructObjectInspector; +import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.getStandardUnionObjectInspector; +import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.*; + +public class TestArrowUnion extends AbstractTest { + private StandardUnionObjectInspector.StandardUnion union(int tag, Object object) { + return new StandardUnionObjectInspector.StandardUnion((byte) tag, object); + } + + @Test + public void testIntegerUnion() throws SerDeException { + String fieldNames = "int_union"; + String fieldTypes = "uniontype<tinyint,smallint,int,bigint>"; + + StructObjectInspector rowOI = getStandardStructObjectInspector( + newArrayList(fieldNames.split(",")), + newArrayList(getStandardUnionObjectInspector(newArrayList( + writableByteObjectInspector, + writableShortObjectInspector, + writableIntObjectInspector, + writableLongObjectInspector)))); + + StandardUnionObjectInspector.StandardUnion[][] integerUnions = { + {union(0, byteW(0))}, + {union(1, shortW(1))}, + {union(2, intW(2))}, + {union(3, longW(3))}, + }; + + 
initAndSerializeAndDeserialize(integerUnions, rowOI, fieldNames, fieldTypes); + } + + @Test + public void testFloatUnion() throws SerDeException { + String fieldNames = "float_union"; + String fieldTypes = "uniontype<float,double>"; + + StructObjectInspector rowOI = getStandardStructObjectInspector( + newArrayList(fieldNames.split(",")), + newArrayList(getStandardUnionObjectInspector(newArrayList( + writableFloatObjectInspector, writableDoubleObjectInspector)))); + + StandardUnionObjectInspector.StandardUnion[][] floatUnions = { + {union(0, floatW(0f))}, + {union(1, doubleW(1d))}, + }; + + initAndSerializeAndDeserialize(floatUnions, rowOI, fieldNames, fieldTypes); + } + + @Test + public void testStringUnion() throws SerDeException { + String fieldNames = "string_union"; + String fieldTypes = "uniontype<string,int>"; + + StructObjectInspector rowOI = getStandardStructObjectInspector( + newArrayList(fieldNames.split(",")), + newArrayList(getStandardUnionObjectInspector(newArrayList( + writableStringObjectInspector, + writableIntObjectInspector)))); + + StandardUnionObjectInspector.StandardUnion[][] stringUnions = { + {union(0, text("Hello"))}, + {union(1, intW(1))}, + }; + + initAndSerializeAndDeserialize(stringUnions, rowOI, fieldNames, fieldTypes); + } + + @Test + public void testCharUnion() throws SerDeException { + String fieldNames = "char_union"; + String fieldTypes = "uniontype<char(10),int>"; + + StructObjectInspector rowOI = getStandardStructObjectInspector( + newArrayList(fieldNames.split(",")), + newArrayList(getStandardUnionObjectInspector(newArrayList( + writableHiveCharObjectInspector, + writableIntObjectInspector)))); + + StandardUnionObjectInspector.StandardUnion[][] charUnions = { + {union(0, charW("Hello"))}, + {union(1, intW(1))}, + }; + + initAndSerializeAndDeserialize(charUnions, rowOI, fieldNames, fieldTypes); + } + + @Test + public void testVarcharUnion() throws SerDeException { + String fieldNames = "varchar_union"; + String fieldTypes = "uniontype<varchar(10),int>"; + + StructObjectInspector rowOI = getStandardStructObjectInspector( + newArrayList(fieldNames.split(",")), + newArrayList(getStandardUnionObjectInspector(newArrayList( + writableHiveVarcharObjectInspector, + writableIntObjectInspector)))); + + StandardUnionObjectInspector.StandardUnion[][] varcharUnions = { + {union(0, varcharW("Hello"))}, + {union(1, intW(1))}, + }; + + initAndSerializeAndDeserialize(varcharUnions, rowOI, fieldNames, fieldTypes); + } + + @Test + public void testDTIUnion() throws SerDeException { + String fieldNames = "date_union"; + String fieldTypes = "uniontype<date,timestamp,interval_year_month,interval_day_time>"; + + StructObjectInspector rowOI = getStandardStructObjectInspector( + newArrayList(fieldNames.split(",")), + newArrayList(getStandardUnionObjectInspector(newArrayList( + writableDateObjectInspector, + writableTimestampObjectInspector, + writableHiveIntervalYearMonthObjectInspector, + writableHiveIntervalDayTimeObjectInspector)))); + + long NOW = System.currentTimeMillis(); + + StandardUnionObjectInspector.StandardUnion[][] dtiUnions = { + {union(0, new DateWritable(DateWritable.millisToDays(NOW)))}, + {union(1, new TimestampWritable(new Timestamp(NOW)))}, + {union(2, new HiveIntervalYearMonthWritable(new HiveIntervalYearMonth(1, 2)))}, + {union(3, new HiveIntervalDayTimeWritable(new HiveIntervalDayTime(1, 2, 3, 4, 5_000_000)))}, + }; + + initAndSerializeAndDeserialize(dtiUnions, rowOI, fieldNames, fieldTypes); + } + + @Test + public void testBooleanBinaryUnion() throws SerDeException { + String fieldNames = "boolean_union"; + String fieldTypes = "uniontype<boolean,binary>"; + + 
StructObjectInspector rowOI = getStandardStructObjectInspector( + newArrayList(fieldNames.split(",")), + newArrayList(getStandardUnionObjectInspector(newArrayList( + writableBooleanObjectInspector, writableBinaryObjectInspector)))); + + StandardUnionObjectInspector.StandardUnion[][] booleanBinaryUnions = { + {union(0, new BooleanWritable(true))}, + {union(1, new BytesWritable("Hello".getBytes()))}, + }; + + initAndSerializeAndDeserialize(booleanBinaryUnions, rowOI, fieldNames, fieldTypes); + } +} diff --git serde/pom.xml serde/pom.xml index 3beeebd4df..0b8221b2e2 100644 --- serde/pom.xml +++ serde/pom.xml @@ -65,6 +65,11 @@ <artifactId>commons-lang</artifactId> <version>${commons-lang.version}</version> </dependency> + <dependency> + <groupId>org.apache.arrow</groupId> + <artifactId>arrow-vector</artifactId> + <version>${arrow.version}</version> + </dependency> <dependency> <groupId>org.apache.avro</groupId> <artifactId>avro</artifactId>
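
Usage note (not part of the patch): the round-trip pattern the new tests rely on can be exercised directly against the SerDe added above. The sketch below is only illustrative; the class name ArrowRoundTripExample and the single bigint column are made up for the example, while every call it makes (SerDeUtils.initializeSerDe, serialize with a trailing null row to complete the pending Arrow batch, and deserialize returning Object[][]) is taken from the test harness in this patch.

package org.apache.hadoop.hive.ql.io.arrow;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hive.serde.serdeConstants;
import org.apache.hadoop.hive.serde2.AbstractSerDe;
import org.apache.hadoop.hive.serde2.SerDeUtils;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Writable;

import java.util.Properties;

import static com.google.common.collect.Lists.newArrayList;
import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.getStandardStructObjectInspector;
import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.writableLongObjectInspector;

public class ArrowRoundTripExample {
  public static void main(String[] args) throws Exception {
    // Column definitions handed to the SerDe, mirroring AbstractTest.initSerDe().
    Properties schema = new Properties();
    schema.setProperty(serdeConstants.LIST_COLUMNS, "bigint1");
    schema.setProperty(serdeConstants.LIST_COLUMN_TYPES, "bigint");

    AbstractSerDe serDe = new ArrowColumnarBatchSerDe();
    SerDeUtils.initializeSerDe(serDe, new Configuration(), schema, null);

    StructObjectInspector rowOI = getStandardStructObjectInspector(
        newArrayList("bigint1"), newArrayList(writableLongObjectInspector));

    // serialize() buffers rows into the current Arrow batch; passing null completes the batch.
    Writable batch = serDe.serialize(new Object[] {new LongWritable(42L)}, rowOI);
    if (batch == null) {
      batch = serDe.serialize(null, rowOI);
    }

    Object[][] rows = (Object[][]) serDe.deserialize(batch);
    System.out.println(rows.length + " row(s) round-tripped");
  }
}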