diff --git ql/src/java/org/apache/hadoop/hive/ql/io/arrow/ArrowColumnarBatchSerDe.java ql/src/java/org/apache/hadoop/hive/ql/io/arrow/ArrowColumnarBatchSerDe.java
index ed82d2d01e..4232a36205 100644
--- ql/src/java/org/apache/hadoop/hive/ql/io/arrow/ArrowColumnarBatchSerDe.java
+++ ql/src/java/org/apache/hadoop/hive/ql/io/arrow/ArrowColumnarBatchSerDe.java
@@ -27,10 +27,6 @@
 import org.apache.arrow.vector.types.pojo.Field;
 import org.apache.arrow.vector.types.pojo.FieldType;
 import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.hive.ql.exec.vector.ColumnVector;
-import org.apache.hadoop.hive.ql.exec.vector.ListColumnVector;
-import org.apache.hadoop.hive.ql.exec.vector.MapColumnVector;
-import org.apache.hadoop.hive.ql.exec.vector.StructColumnVector;
 import org.apache.hadoop.hive.ql.exec.vector.VectorAssignRow;
 import org.apache.hadoop.hive.serde.serdeConstants;
 import org.apache.hadoop.hive.serde2.AbstractSerDe;
@@ -232,21 +228,6 @@ static ListTypeInfo toStructListTypeInfo(MapTypeInfo mapTypeInfo) {
     return structListTypeInfo;
   }
 
-  static ListColumnVector toStructListVector(MapColumnVector mapVector) {
-    final StructColumnVector structVector;
-    final ListColumnVector structListVector;
-    structVector = new StructColumnVector();
-    structVector.fields = new ColumnVector[] {mapVector.keys, mapVector.values};
-    structListVector = new ListColumnVector();
-    structListVector.child = structVector;
-    structListVector.childCount = mapVector.childCount;
-    structListVector.isRepeating = mapVector.isRepeating;
-    structListVector.noNulls = mapVector.noNulls;
-    System.arraycopy(mapVector.offsets, 0, structListVector.offsets, 0, mapVector.childCount);
-    System.arraycopy(mapVector.lengths, 0, structListVector.lengths, 0, mapVector.childCount);
-    return structListVector;
-  }
-
   @Override
   public Class getSerializedClass() {
     return ArrowWrapperWritable.class;
diff --git ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Deserializer.java ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Deserializer.java
index edc4b39922..1d9892b6c1 100644
--- ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Deserializer.java
+++ ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Deserializer.java
@@ -34,6 +34,7 @@
 import org.apache.arrow.vector.VarBinaryVector;
 import org.apache.arrow.vector.VarCharVector;
 import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.complex.ListVector;
 import org.apache.arrow.vector.holders.NullableIntervalDayHolder;
 import org.apache.arrow.vector.types.Types;
 import org.apache.hadoop.hive.common.type.HiveDecimal;
@@ -64,6 +65,7 @@
 import java.util.List;
 
 import static org.apache.hadoop.hive.ql.exec.vector.VectorizedBatchUtil.createColumnVector;
+import static org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe.MICROS_PER_MILLIS;
 import static org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe.MICROS_PER_SECOND;
 import static org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe.MILLIS_PER_SECOND;
 import static org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe.NS_PER_MICROS;
@@ -71,7 +73,6 @@
 import static org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe.NS_PER_SECOND;
 import static org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe.SECOND_PER_DAY;
 import static org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe.toStructListTypeInfo;
-import static org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe.toStructListVector;
 
 class Deserializer {
   private final ArrowColumnarBatchSerDe serDe;
@@ -265,12 +266,8 @@ private void readPrimitive(FieldVector arrowVector, ColumnVector hiveVector) {
          }
        }
        break;
-      case TIMESTAMPMILLI:
-      case TIMESTAMPMILLITZ:
       case TIMESTAMPMICRO:
       case TIMESTAMPMICROTZ:
-      case TIMESTAMPNANO:
-      case TIMESTAMPNANOTZ:
       {
         for (int i = 0; i < size; i++) {
           if (arrowVector.isNull(i)) {
@@ -278,48 +275,19 @@ private void readPrimitive(FieldVector arrowVector, ColumnVector hiveVector) {
           } else {
             hiveVector.isNull[i] = false;
 
-            // Time = second + sub-second
-            final long time = ((TimeStampVector) arrowVector).get(i);
-            long second;
-            int subSecondInNanos;
-            switch (minorType) {
-              case TIMESTAMPMILLI:
-              case TIMESTAMPMILLITZ:
-              {
-                subSecondInNanos = (int) ((time % MILLIS_PER_SECOND) * NS_PER_MILLIS);
-                second = time / MILLIS_PER_SECOND;
-              }
-              break;
-              case TIMESTAMPMICROTZ:
-              case TIMESTAMPMICRO:
-              {
-                subSecondInNanos = (int) ((time % MICROS_PER_SECOND) * NS_PER_MICROS);
-                second = time / MICROS_PER_SECOND;
-              }
-              break;
-              case TIMESTAMPNANOTZ:
-              case TIMESTAMPNANO:
-              {
-                subSecondInNanos = (int) (time % NS_PER_SECOND);
-                second = time / NS_PER_SECOND;
-              }
-              break;
-              default:
-                throw new IllegalArgumentException();
-            }
-
+            final long timeInMicros = ((TimeStampVector) arrowVector).get(i);
             final TimestampColumnVector timestampColumnVector = (TimestampColumnVector) hiveVector;
-            // A nanosecond value should not be negative
-            if (subSecondInNanos < 0) {
-
-              // So add one second to the negative nanosecond value to make it positive
-              subSecondInNanos += NS_PER_SECOND;
-
-              // Subtract one second from the second value because we added one second
-              second -= 1;
+            long timeInMillis = timeInMicros / MICROS_PER_MILLIS;
+            int nanos = (int) (timeInMicros % MICROS_PER_SECOND) * NS_PER_MICROS;
+            if (nanos < 0) {
+              nanos += NS_PER_SECOND;
+              timeInMillis--;
+              if (nanos >= NS_PER_MILLIS) {
+                timeInMillis -= MICROS_PER_MILLIS;
+              }
             }
-            timestampColumnVector.time[i] = second * MILLIS_PER_SECOND;
-            timestampColumnVector.nanos[i] = subSecondInNanos;
+            timestampColumnVector.time[i] = timeInMillis;
+            timestampColumnVector.nanos[i] = nanos;
           }
         }
       }
@@ -409,9 +377,13 @@ private void readList(FieldVector arrowVector, ListColumnVector hiveVector, List
 
   private void readMap(FieldVector arrowVector, MapColumnVector hiveVector, MapTypeInfo typeInfo) {
     final int size = arrowVector.getValueCount();
+    final int childSize = ((ListVector) arrowVector).getDataVector().getValueCount();
     final ListTypeInfo mapStructListTypeInfo = toStructListTypeInfo(typeInfo);
-    final ListColumnVector mapStructListVector = toStructListVector(hiveVector);
-    final StructColumnVector mapStructVector = (StructColumnVector) mapStructListVector.child;
+
+    final StructColumnVector mapStructVector =
+        new StructColumnVector(childSize, hiveVector.keys, hiveVector.values);
+    final ListColumnVector mapStructListVector =
+        new ListColumnVector(size, mapStructVector);
 
     read(arrowVector, mapStructListVector, mapStructListTypeInfo);
 
diff --git ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Serializer.java ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Serializer.java
index 08e0fb2b7f..e5602403dd 100644
--- ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Serializer.java
+++ ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Serializer.java
@@ -86,7 +86,6 @@
 import static org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe.NS_PER_MILLIS;
 import static org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe.SECOND_PER_DAY;
 import static org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe.toStructListTypeInfo;
-import static org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe.toStructListVector;
 import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption.WRITABLE;
 import static org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils.getTypeInfoFromObjectInspector;
 
@@ -181,6 +180,8 @@ public ArrowWrapperWritable emptyBatch() {
 
   public ArrowWrapperWritable serializeBatch(VectorizedRowBatch vectorizedRowBatch, boolean isNative) {
     rootVector.setValueCount(0);
+    final int size = isNative ? vectorizedRowBatch.size : batchSize;
+
     for (int fieldIndex = 0; fieldIndex < vectorizedRowBatch.projectionSize; fieldIndex++) {
       final int projectedColumn = vectorizedRowBatch.projectedColumns[fieldIndex];
       final ColumnVector hiveVector = vectorizedRowBatch.cols[projectedColumn];
@@ -195,20 +196,18 @@ public ArrowWrapperWritable serializeBatch(VectorizedRowBatch vectorizedRowBatch
       }
       final FieldVector arrowVector = rootVector.addOrGet(fieldName, fieldType, FieldVector.class);
       if(fieldExists) {
-        arrowVector.setValueCount(isNative ? vectorizedRowBatch.size : batchSize);
+        arrowVector.setValueCount(size);
       } else {
-        arrowVector.setInitialCapacity(isNative ? vectorizedRowBatch.size : batchSize);
+        arrowVector.setInitialCapacity(size);
         arrowVector.allocateNew();
       }
-      write(arrowVector, hiveVector, fieldTypeInfo, isNative ? vectorizedRowBatch.size : batchSize, vectorizedRowBatch, isNative);
+      write(arrowVector, hiveVector, fieldTypeInfo, size, vectorizedRowBatch, isNative);
     }
     if(!isNative) {
       //Only mutate batches that are constructed by this serde
       vectorizedRowBatch.reset();
-      rootVector.setValueCount(batchSize);
-    } else {
-      rootVector.setValueCount(vectorizedRowBatch.size);
     }
+    rootVector.setValueCount(size);
 
     batchSize = 0;
     VectorSchemaRoot vectorSchemaRoot = new VectorSchemaRoot(rootVector);
@@ -296,10 +295,21 @@ private static void write(FieldVector arrowVector, ColumnVector hiveVector, Type
     }
   }
 
-  private static void writeMap(ListVector arrowVector, MapColumnVector hiveVector, MapTypeInfo typeInfo,
-      int size, VectorizedRowBatch vectorizedRowBatch, boolean isNative) {
+  private static void writeMap(ListVector arrowVector, MapColumnVector hiveVector,
+      MapTypeInfo typeInfo, int size, VectorizedRowBatch vectorizedRowBatch, boolean isNative) {
+
     final ListTypeInfo structListTypeInfo = toStructListTypeInfo(typeInfo);
-    final ListColumnVector structListVector = toStructListVector(hiveVector);
+    final StructColumnVector structVector =
+        new StructColumnVector(hiveVector.childCount, hiveVector.keys, hiveVector.values);
+    final ListColumnVector structListVector =
+        new ListColumnVector(hiveVector.isNull.length, structVector);
+
+    structListVector.childCount = hiveVector.childCount;
+    structListVector.isRepeating = hiveVector.isRepeating;
+    structListVector.noNulls = hiveVector.noNulls;
+    System.arraycopy(hiveVector.offsets, 0, structListVector.offsets, 0, hiveVector.isNull.length);
+    System.arraycopy(hiveVector.lengths, 0, structListVector.lengths, 0, hiveVector.isNull.length);
+    System.arraycopy(hiveVector.isNull, 0, structListVector.isNull, 0, hiveVector.isNull.length);
 
     write(arrowVector, structListVector, structListTypeInfo, size, vectorizedRowBatch, isNative);
 
@@ -779,16 +789,10 @@ private static void writeGeneric(final FieldVector fieldVector, final ColumnVect
       -> {
         final TimeStampMicroTZVector timeStampMicroTZVector = (TimeStampMicroTZVector) arrowVector;
         final TimestampColumnVector timestampColumnVector = (TimestampColumnVector) hiveVector;
-        // Time = second + sub-second
-        final long secondInMillis = timestampColumnVector.getTime(j);
-        final long secondInMicros = (secondInMillis - secondInMillis % MILLIS_PER_SECOND) * MICROS_PER_MILLIS;
-        final long subSecondInMicros = timestampColumnVector.getNanos(j) / NS_PER_MICROS;
-        if ((secondInMillis > 0 && secondInMicros < 0) || (secondInMillis < 0 && secondInMicros > 0)) {
-          // If the timestamp cannot be represented in long microsecond, set it as a null value
-          timeStampMicroTZVector.setNull(i);
-        } else {
-          timeStampMicroTZVector.set(i, secondInMicros + subSecondInMicros);
-        }
+        final long timeInMicros = timestampColumnVector.time[j] / MICROS_PER_MILLIS
+            * MICROS_PER_MILLIS * MICROS_PER_MILLIS;
+        final long nanosInMicros = timestampColumnVector.nanos[j] / NS_PER_MICROS;
+        timeStampMicroTZVector.set(i, timeInMicros + nanosInMicros);
       };
 
   //binary
@@ -797,7 +801,12 @@ private static void writeGeneric(final FieldVector fieldVector, final ColumnVect
   private static final IntIntAndVectorsConsumer binaryValueSetter = (i, j, arrowVector, hiveVector)
       -> {
         BytesColumnVector bytesVector = (BytesColumnVector) hiveVector;
-        ((VarBinaryVector) arrowVector).setSafe(i, bytesVector.vector[j], bytesVector.start[j], bytesVector.length[j]);
+        VarBinaryVector varBinaryVector = (VarBinaryVector) arrowVector;
+        if (bytesVector.vector[j] == null) {
+          varBinaryVector.setNull(i);
+        } else {
+          varBinaryVector.setSafe(i, bytesVector.vector[j], bytesVector.start[j], bytesVector.length[j]);
+        }
       };
 
   //decimal and decimal64
diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomRowSource.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomRowSource.java
index b84273ade5..a7b96b36e9 100644
--- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomRowSource.java
+++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomRowSource.java
@@ -40,7 +40,6 @@
 import org.apache.hadoop.hive.common.type.Timestamp;
 import org.apache.hadoop.hive.serde2.RandomTypeUtil;
 import org.apache.hadoop.hive.serde2.io.HiveCharWritable;
-import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;
 import org.apache.hadoop.hive.serde2.io.HiveVarcharWritable;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
@@ -86,7 +85,6 @@
 import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo;
 import org.apache.hive.common.util.DateUtils;
 import org.apache.hadoop.io.Text;
-import org.apache.hadoop.io.BooleanWritable;
 import org.apache.hadoop.io.BytesWritable;
 import org.apache.hadoop.io.LongWritable;
 
@@ -287,7 +285,7 @@ public StructObjectInspector partialRowStructObjectInspector(int partialFieldCou
   }
 
   public enum SupportedTypes {
-    ALL, PRIMITIVES, ALL_EXCEPT_MAP
+    ALL, PRIMITIVES, ALL_EXCEPT_MAP, ALL_EXCEPT_UNION
   }
 
   public void init(Random r, SupportedTypes supportedTypes, int maxComplexDepth, boolean allowNull,
@@ -368,6 +366,18 @@ public void initGenerationSpecSchema(Random r, List generationSp
       "map"
   };
 
+  private static String[] possibleHiveComplexTypeNamesWithoutMap = {
+      "array",
+      "struct",
+      "uniontype"
+  };
+
+  private static String[] possibleHiveComplexTypeNamesWithoutUnion = {
+      "array",
+      "struct",
+      "map"
+  };
+
   public static String getRandomTypeName(Random random, SupportedTypes supportedTypes,
       Set allowedTypeNameSet) {
 
@@ -381,7 +391,12 @@ public static String getRandomTypeName(Random random, SupportedTypes supportedTy
       typeName = possibleHivePrimitiveTypeNames[random.nextInt(possibleHivePrimitiveTypeNames.length)];
       break;
     case ALL_EXCEPT_MAP:
-      typeName = possibleHiveComplexTypeNames[random.nextInt(possibleHiveComplexTypeNames.length - 1)];
+      typeName = possibleHiveComplexTypeNamesWithoutMap[random.nextInt(
+          possibleHiveComplexTypeNamesWithoutMap.length)];
+      break;
+    case ALL_EXCEPT_UNION:
+      typeName = possibleHiveComplexTypeNamesWithoutUnion[random.nextInt(
+          possibleHiveComplexTypeNamesWithoutUnion.length)];
       break;
     case ALL:
       typeName = possibleHiveComplexTypeNames[random.nextInt(possibleHiveComplexTypeNames.length)];
@@ -588,10 +603,16 @@ private void chooseSchema(SupportedTypes supportedTypes, Set allowedType
     if (allTypes) {
       switch (supportedTypes) {
       case ALL:
-        columnCount = possibleHivePrimitiveTypeNames.length + possibleHiveComplexTypeNames.length;
+        columnCount = possibleHivePrimitiveTypeNames.length +
+            possibleHiveComplexTypeNames.length;
        break;
       case ALL_EXCEPT_MAP:
-        columnCount = possibleHivePrimitiveTypeNames.length + possibleHiveComplexTypeNames.length - 1;
+        columnCount = possibleHivePrimitiveTypeNames.length +
+            possibleHiveComplexTypeNamesWithoutMap.length;
+        break;
+      case ALL_EXCEPT_UNION:
+        columnCount = possibleHivePrimitiveTypeNames.length +
+            possibleHiveComplexTypeNamesWithoutUnion.length;
        break;
       case PRIMITIVES:
         columnCount = possibleHivePrimitiveTypeNames.length;
@@ -632,7 +653,10 @@ private void chooseSchema(SupportedTypes supportedTypes, Set allowedType
         maxTypeNum = possibleHivePrimitiveTypeNames.length;
         break;
       case ALL_EXCEPT_MAP:
-        maxTypeNum = possibleHivePrimitiveTypeNames.length + possibleHiveComplexTypeNames.length - 1;
+        maxTypeNum = possibleHivePrimitiveTypeNames.length + possibleHiveComplexTypeNamesWithoutMap.length;
+        break;
+      case ALL_EXCEPT_UNION:
+        maxTypeNum = possibleHivePrimitiveTypeNames.length + possibleHiveComplexTypeNamesWithoutUnion.length;
         break;
       case ALL:
         maxTypeNum = possibleHivePrimitiveTypeNames.length + possibleHiveComplexTypeNames.length;
@@ -656,14 +680,31 @@ private void chooseSchema(SupportedTypes supportedTypes, Set allowedType
           if (supportedTypes == SupportedTypes.ALL_EXCEPT_MAP) {
             typeNum--;
           }
+          if (supportedTypes == SupportedTypes.ALL_EXCEPT_UNION) {
+            typeNum--;
+          }
        }
      }
 
       if (typeNum < possibleHivePrimitiveTypeNames.length) {
         typeName = possibleHivePrimitiveTypeNames[typeNum];
       } else {
-        typeName = possibleHiveComplexTypeNames[typeNum - possibleHivePrimitiveTypeNames.length];
+        switch (supportedTypes) {
+        case ALL:
+          typeName = possibleHiveComplexTypeNames[typeNum -
+              possibleHivePrimitiveTypeNames.length];
+          break;
+        case ALL_EXCEPT_MAP:
+          typeName = possibleHiveComplexTypeNamesWithoutMap[typeNum -
+              possibleHivePrimitiveTypeNames.length];
+          break;
+        case ALL_EXCEPT_UNION:
+          typeName = possibleHiveComplexTypeNamesWithoutUnion[typeNum -
+              possibleHivePrimitiveTypeNames.length];
+          break;
+        default:
+          throw new IllegalArgumentException();
+        }
       }
-
     }
     String decoratedTypeName =
diff --git ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowColumnarBatchSerDe.java ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowColumnarBatchSerDe.java
index c9a5812e47..7702bdfded 100644
--- ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowColumnarBatchSerDe.java
+++ ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowColumnarBatchSerDe.java
@@ -28,6 +28,7 @@
 import org.apache.hadoop.hive.common.type.HiveVarchar;
 import org.apache.hadoop.hive.common.type.Timestamp;
 import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.ql.exec.vector.VectorRandomRowSource;
 import org.apache.hadoop.hive.serde.serdeConstants;
 import org.apache.hadoop.hive.serde2.AbstractSerDe;
 import org.apache.hadoop.hive.serde2.SerDeException;
@@ -42,6 +43,7 @@
 import org.apache.hadoop.hive.serde2.io.HiveVarcharWritable;
 import org.apache.hadoop.hive.serde2.io.ShortWritable;
 import org.apache.hadoop.hive.serde2.io.TimestampWritableV2;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.StructField;
@@ -58,18 +60,17 @@
 import org.junit.Before;
 import org.junit.Test;
 
+import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Properties;
 import java.util.Random;
-import java.util.Set;
 import java.util.concurrent.TimeUnit;
 
-import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNull;
 
 public class TestArrowColumnarBatchSerDe {
   private Configuration conf;
@@ -204,11 +205,30 @@ private void initAndSerializeAndDeserialize(String[][] schema, Object[][] rows)
     serializeAndDeserialize(serDe, rows, rowOI);
   }
 
+  private StructObjectInspector initSerDe(AbstractSerDe serDe, TypeInfo[] typeInfos)
+      throws SerDeException {
+    List fieldNameList = new ArrayList<>();
+    List fieldTypeList = new ArrayList<>();
+
+    for (int i = 0; i < typeInfos.length; i++) {
+      fieldNameList.add("col" + i);
+      fieldTypeList.add(typeInfos[i].getTypeName());
+    }
+
+    Properties schemaProperties = new Properties();
+    schemaProperties.setProperty(serdeConstants.LIST_COLUMNS, Joiner.on(',').join(fieldNameList));
+    schemaProperties.setProperty(serdeConstants.LIST_COLUMN_TYPES,
+        Joiner.on(',').join(fieldTypeList));
+    SerDeUtils.initializeSerDe(serDe, conf, schemaProperties, null);
+    return (StructObjectInspector) TypeInfoUtils.getStandardWritableObjectInspectorFromTypeInfo(
+        TypeInfoFactory.getStructTypeInfo(fieldNameList, Arrays.asList(typeInfos)));
+  }
+
   private StructObjectInspector initSerDe(AbstractSerDe serDe, String[][] schema)
       throws SerDeException {
-    List fieldNameList = newArrayList();
-    List fieldTypeList = newArrayList();
-    List typeInfoList = newArrayList();
+    List fieldNameList = new ArrayList<>();
+    List fieldTypeList = new ArrayList<>();
+    List typeInfoList = new ArrayList<>();
 
     for (String[] nameAndType : schema) {
       String name = nameAndType[0];
@@ -218,12 +238,10 @@ private StructObjectInspector initSerDe(AbstractSerDe serDe, String[][] schema)
       typeInfoList.add(TypeInfoUtils.getTypeInfoFromTypeString(type));
     }
 
-    String fieldNames = Joiner.on(',').join(fieldNameList);
-    String fieldTypes = Joiner.on(',').join(fieldTypeList);
-
     Properties schemaProperties = new Properties();
-    schemaProperties.setProperty(serdeConstants.LIST_COLUMNS, fieldNames);
-    schemaProperties.setProperty(serdeConstants.LIST_COLUMN_TYPES, fieldTypes);
+    schemaProperties.setProperty(serdeConstants.LIST_COLUMNS, Joiner.on(',').join(fieldNameList));
+    schemaProperties.setProperty(serdeConstants.LIST_COLUMN_TYPES,
+        Joiner.on(',').join(fieldTypeList));
     SerDeUtils.initializeSerDe(serDe, conf, schemaProperties, null);
     return (StructObjectInspector) TypeInfoUtils.getStandardWritableObjectInspectorFromTypeInfo(
         TypeInfoFactory.getStructTypeInfo(fieldNameList, typeInfoList));
@@ -239,63 +257,99 @@ private void serializeAndDeserialize(ArrowColumnarBatchSerDe serDe, Object[][] r
     if (serialized == null) {
       serialized = serDe.serialize(null, rowOI);
     }
-    String s = serialized.getVectorSchemaRoot().contentToTSVString();
 
     final Object[][] deserializedRows = (Object[][]) serDe.deserialize(serialized);
 
     for (int rowIndex = 0; rowIndex < Math.min(deserializedRows.length, rows.length); rowIndex++) {
       final Object[] row = rows[rowIndex];
       final Object[] deserializedRow = deserializedRows[rowIndex];
-      assertEquals(row.length, deserializedRow.length);
-
-      final List fields = rowOI.getAllStructFieldRefs();
-      for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
-        final StructField field = fields.get(fieldIndex);
-        final ObjectInspector fieldObjInspector = field.getFieldObjectInspector();
-        switch (fieldObjInspector.getCategory()) {
-          case PRIMITIVE:
-            final PrimitiveObjectInspector primitiveObjInspector =
-                (PrimitiveObjectInspector) fieldObjInspector;
-            switch (primitiveObjInspector.getPrimitiveCategory()) {
-              case STRING:
-              case VARCHAR:
-              case CHAR:
-                assertEquals(Objects.toString(row[fieldIndex]),
-                    Objects.toString(deserializedRow[fieldIndex]));
-                break;
-              default:
-                assertEquals(row[fieldIndex], deserializedRow[fieldIndex]);
-                break;
-            }
-            break;
-          case STRUCT:
-            final Object[] rowStruct = (Object[]) row[fieldIndex];
-            final List deserializedRowStruct = (List) deserializedRow[fieldIndex];
-            if (rowStruct == null) {
-              assertNull(deserializedRowStruct);
-            } else {
-              assertArrayEquals(rowStruct, deserializedRowStruct.toArray());
-            }
-            break;
-          case LIST:
-          case UNION:
-            assertEquals(row[fieldIndex], deserializedRow[fieldIndex]);
-            break;
-          case MAP:
-            final Map rowMap = (Map) row[fieldIndex];
-            final Map deserializedRowMap = (Map) deserializedRow[fieldIndex];
-            if (rowMap == null) {
-              assertNull(deserializedRowMap);
-            } else {
-              final Set rowMapKeySet = rowMap.keySet();
-              final Set deserializedRowMapKeySet = deserializedRowMap.keySet();
-              assertEquals(rowMapKeySet, deserializedRowMapKeySet);
-              for (Object key : rowMapKeySet) {
-                assertEquals(rowMap.get(key), deserializedRowMap.get(key));
-              }
-            }
-            break;
-        }
+      compareStruct(row, Arrays.asList(deserializedRow), rowOI);
+    }
+  }
+
+  private void compareStruct(Object struct, List deserializedStruct,
+      StructObjectInspector structObjectInspector) {
+    final List fields = structObjectInspector.getAllStructFieldRefs();
+    for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
+      final StructField field = fields.get(fieldIndex);
+      final ObjectInspector fieldObjInspector = field.getFieldObjectInspector();
+      final Object value = structObjectInspector.getStructFieldData(struct, field);
+      final Object deserializedValue = deserializedStruct.get(fieldIndex);
+      compare(value, deserializedValue, fieldObjInspector);
+    }
+  }
+
+  private void compareList(Object list, List deserializedList,
+      ListObjectInspector listObjInspector) {
+    int length = listObjInspector.getListLength(list);
+    ObjectInspector elementObjInspector = listObjInspector.getListElementObjectInspector();
+    for (int i = 0; i < length; i++) {
+      Object value = listObjInspector.getListElement(list, i);
+      Object deserializedValue = deserializedList.get(i);
+      compare(value, deserializedValue, elementObjInspector);
+    }
+  }
+
+  private void compare(Object value, Object deserializedValue, ObjectInspector objInspector) {
+    if (value == null && deserializedValue == null) {
+      return;
+    }
+    switch (objInspector.getCategory()) {
+      case PRIMITIVE:
+        comparePrimitive(value, deserializedValue, (PrimitiveObjectInspector) objInspector);
+        break;
+      case STRUCT:
+        compareStruct(value, (List) deserializedValue, (StructObjectInspector) objInspector);
+        break;
+      case LIST:
+        compareList(value, (List) deserializedValue, (ListObjectInspector) objInspector);
+        break;
+    }
+  }
+
+  private void comparePrimitive(Object value, Object deserializedValue,
+      PrimitiveObjectInspector primitiveObjInspector) {
+    switch (primitiveObjInspector.getPrimitiveCategory()) {
+      case STRING:
+      case VARCHAR:
+      case CHAR:
+        assertEquals(Objects.toString(value), Objects.toString(deserializedValue));
+        break;
+      case TIMESTAMP: {
+        Timestamp source = ((TimestampWritableV2) value).getTimestamp();
+        Timestamp deserialized = ((TimestampWritableV2) deserializedValue).getTimestamp();
+        assertEquals(source.toSqlTimestamp().getTime(), deserialized.toSqlTimestamp().getTime());
+        assertEquals(source.getNanos() / ArrowColumnarBatchSerDe.NS_PER_MICROS,
+            deserialized.getNanos() / ArrowColumnarBatchSerDe.NS_PER_MICROS);
+        break;
+      }
+      case INTERVAL_DAY_TIME: {
+        HiveIntervalDayTime source =
+            ((HiveIntervalDayTimeWritable) value).getHiveIntervalDayTime();
+        HiveIntervalDayTime deserialized =
+            ((HiveIntervalDayTimeWritable) deserializedValue).getHiveIntervalDayTime();
+        assertEquals(source.getTotalSeconds(), deserialized.getTotalSeconds());
+        assertEquals(source.getNanos() / ArrowColumnarBatchSerDe.NS_PER_MILLIS,
+            deserialized.getNanos() / ArrowColumnarBatchSerDe.NS_PER_MILLIS);
+        break;
       }
+      default:
+        assertEquals(value, deserializedValue);
+        break;
+    }
+  }
+
+  @Test
+  public void testRandom() throws SerDeException {
+    Random random = new Random(3);
+    int numRows = HiveConf.getIntVar(conf, HiveConf.ConfVars.HIVE_ARROW_BATCH_SIZE);
+    for (int i = 0; i < 100; i++) {
+      VectorRandomRowSource source = new VectorRandomRowSource();
+      source.init(random, VectorRandomRowSource.SupportedTypes.ALL_EXCEPT_UNION, 4, true, true);
+      Object[][] rows = source.randomRows(numRows);
+
+      ArrowColumnarBatchSerDe serDe = new ArrowColumnarBatchSerDe();
+      StructObjectInspector structObjectInspector = initSerDe(serDe, source.typeInfos());
+      serializeAndDeserialize(serDe, rows, structObjectInspector);
     }
   }
@@ -767,21 +821,4 @@ public void testMapBinary() throws SerDeException {
 
     initAndSerializeAndDeserialize(schema, toMap(BINARY_ROWS));
   }
-
-  public void testMapDecimal() throws SerDeException {
-    String[][] schema = {
-        {"decimal_map", "map"},
-    };
-
-    initAndSerializeAndDeserialize(schema, toMap(DECIMAL_ROWS));
-  }
-
-  public void testListDecimal() throws SerDeException {
-    String[][] schema = {
-        {"decimal_list", "array"},
-    };
-
-    initAndSerializeAndDeserialize(schema, toList(DECIMAL_ROWS));
-  }
-
 }
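
For reference, outside the patch itself: the TIMESTAMPMICRO branches above turn a single signed microsecond count from an Arrow vector into the millisecond value plus sub-second nanosecond field that Hive's TimestampColumnVector keeps. Below is a minimal standalone sketch of that split, assuming Math.floorDiv/floorMod instead of the patch's manual borrow logic; the class name and constants are illustrative only and do not appear in the patch.

public class EpochMicrosExample {
  private static final long MICROS_PER_SECOND = 1_000_000L;
  private static final long MICROS_PER_MILLIS = 1_000L;
  private static final long NS_PER_MICROS = 1_000L;

  public static void main(String[] args) {
    for (long micros : new long[] {1_500_000L, 0L, -1L, -1_500L}) {
      // Whole milliseconds, rounded toward negative infinity so the
      // remainder below is never negative, even before the epoch.
      long timeInMillis = Math.floorDiv(micros, MICROS_PER_MILLIS);
      // Full sub-second part in nanoseconds, always in [0, 999_999_000].
      int nanos = (int) (Math.floorMod(micros, MICROS_PER_SECOND) * NS_PER_MICROS);
      System.out.printf("%d us -> time=%d ms, nanos=%d%n", micros, timeInMillis, nanos);
    }
  }
}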