diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorAssignRow.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorAssignRow.java
index 9c84937031..ee4bbf1fad 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorAssignRow.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorAssignRow.java
@@ -21,7 +21,13 @@
 import java.sql.Date;
 import java.sql.Timestamp;
 import java.util.List;
+import java.util.Map;
+import org.apache.hadoop.hive.serde2.objectinspector.StandardUnionObjectInspector;
+import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.UnionTypeInfo;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.apache.hadoop.hive.common.type.HiveChar;
@@ -89,11 +95,8 @@
   // Assigning can be a subset of columns, so this is the projection --
   // the batch column numbers.
 
-  Category[] targetCategories;
-  // The data type category of each column being assigned.
-
-  PrimitiveCategory[] targetPrimitiveCategories;
-  // The data type primitive category of each column being assigned.
+  TypeInfo[] targetTypeInfos;
+  // The type info of each column being assigned.
 
   int[] maxLengths;
   // For the CHAR and VARCHAR data types, the maximum character length of
@@ -117,8 +120,7 @@ private void allocateArrays(int count) {
     isConvert = new boolean[count];
     projectionColumnNums = new int[count];
-    targetCategories = new Category[count];
-    targetPrimitiveCategories = new PrimitiveCategory[count];
+    targetTypeInfos = new TypeInfo[count];
     maxLengths = new int[count];
   }
 
@@ -136,12 +138,10 @@ private void allocateConvertArrays(int count) {
   private void initTargetEntry(int logicalColumnIndex, int projectionColumnNum, TypeInfo typeInfo) {
     isConvert[logicalColumnIndex] = false;
     projectionColumnNums[logicalColumnIndex] = projectionColumnNum;
-    Category category = typeInfo.getCategory();
-    targetCategories[logicalColumnIndex] = category;
-    if (category == Category.PRIMITIVE) {
+    targetTypeInfos[logicalColumnIndex] = typeInfo;
+    if (typeInfo.getCategory() == Category.PRIMITIVE) {
       PrimitiveTypeInfo primitiveTypeInfo = (PrimitiveTypeInfo) typeInfo;
       PrimitiveCategory primitiveCategory = primitiveTypeInfo.getPrimitiveCategory();
-      targetPrimitiveCategories[logicalColumnIndex] = primitiveCategory;
       switch (primitiveCategory) {
       case CHAR:
         maxLengths[logicalColumnIndex] = ((CharTypeInfo) primitiveTypeInfo).getLength();
         break;
@@ -170,7 +170,8 @@ private void initConvertSourceEntry(int logicalColumnIndex, TypeInfo convertSour
         convertSourcePrimitiveTypeInfo);
 
     // These need to be based on the target.
-    PrimitiveCategory targetPrimitiveCategory = targetPrimitiveCategories[logicalColumnIndex];
+    PrimitiveCategory targetPrimitiveCategory =
+        ((PrimitiveTypeInfo) targetTypeInfos[logicalColumnIndex]).getPrimitiveCategory();
     switch (targetPrimitiveCategory) {
     case DATE:
       convertTargetWritables[logicalColumnIndex] = new DateWritable();
@@ -335,73 +336,79 @@ public int initConversion(TypeInfo[] sourceTypeInfos, TypeInfo[] targetTypeInfos
    */
   public void assignRowColumn(VectorizedRowBatch batch, int batchIndex, int logicalColumnIndex,
       Object object) {
-    Category targetCategory = targetCategories[logicalColumnIndex];
-    if (targetCategory == null) {
+    final int projectionColumnNum = projectionColumnNums[logicalColumnIndex];
+    TypeInfo targetTypeInfo = targetTypeInfos[logicalColumnIndex];
+    if (targetTypeInfo == null || targetTypeInfo.getCategory() == null) {
       /*
        * This is a column that we don't want (i.e. not included) -- we are done.
        */
       return;
     }
-    final int projectionColumnNum = projectionColumnNums[logicalColumnIndex];
+    assignRowColumn(batch.cols[projectionColumnNum], batchIndex, targetTypeInfo, object);
+  }
+
+  private void assignRowColumn(
+      ColumnVector columnVector, int batchIndex, TypeInfo targetTypeInfo, Object object) {
     if (object == null) {
-      VectorizedBatchUtil.setNullColIsNullValue(batch.cols[projectionColumnNum], batchIndex);
+      VectorizedBatchUtil.setNullColIsNullValue(columnVector, batchIndex);
       return;
     }
-    switch (targetCategory) {
+    switch (targetTypeInfo.getCategory()) {
     case PRIMITIVE:
       {
-        PrimitiveCategory targetPrimitiveCategory = targetPrimitiveCategories[logicalColumnIndex];
+        PrimitiveCategory targetPrimitiveCategory =
+            ((PrimitiveTypeInfo) targetTypeInfo).getPrimitiveCategory();
         switch (targetPrimitiveCategory) {
        case VOID:
-          VectorizedBatchUtil.setNullColIsNullValue(batch.cols[projectionColumnNum], batchIndex);
+          VectorizedBatchUtil.setNullColIsNullValue(columnVector, batchIndex);
          return;
        case BOOLEAN:
-          ((LongColumnVector) batch.cols[projectionColumnNum]).vector[batchIndex] =
+          ((LongColumnVector) columnVector).vector[batchIndex] =
              (((BooleanWritable) object).get() ? 1 : 0);
          break;
        case BYTE:
-          ((LongColumnVector) batch.cols[projectionColumnNum]).vector[batchIndex] =
+          ((LongColumnVector) columnVector).vector[batchIndex] =
              ((ByteWritable) object).get();
          break;
        case SHORT:
-          ((LongColumnVector) batch.cols[projectionColumnNum]).vector[batchIndex] =
+          ((LongColumnVector) columnVector).vector[batchIndex] =
              ((ShortWritable) object).get();
          break;
        case INT:
-          ((LongColumnVector) batch.cols[projectionColumnNum]).vector[batchIndex] =
+          ((LongColumnVector) columnVector).vector[batchIndex] =
              ((IntWritable) object).get();
          break;
        case LONG:
-          ((LongColumnVector) batch.cols[projectionColumnNum]).vector[batchIndex] =
+          ((LongColumnVector) columnVector).vector[batchIndex] =
              ((LongWritable) object).get();
          break;
        case TIMESTAMP:
-          ((TimestampColumnVector) batch.cols[projectionColumnNum]).set(
+          ((TimestampColumnVector) columnVector).set(
              batchIndex, ((TimestampWritable) object).getTimestamp());
          break;
        case DATE:
-          ((LongColumnVector) batch.cols[projectionColumnNum]).vector[batchIndex] =
+          ((LongColumnVector) columnVector).vector[batchIndex] =
              ((DateWritable) object).getDays();
          break;
        case FLOAT:
-          ((DoubleColumnVector) batch.cols[projectionColumnNum]).vector[batchIndex] =
+          ((DoubleColumnVector) columnVector).vector[batchIndex] =
              ((FloatWritable) object).get();
          break;
        case DOUBLE:
-          ((DoubleColumnVector) batch.cols[projectionColumnNum]).vector[batchIndex] =
+          ((DoubleColumnVector) columnVector).vector[batchIndex] =
              ((DoubleWritable) object).get();
          break;
        case BINARY:
          {
            BytesWritable bw = (BytesWritable) object;
-            ((BytesColumnVector) batch.cols[projectionColumnNum]).setVal(
+            ((BytesColumnVector) columnVector).setVal(
                batchIndex, bw.getBytes(), 0, bw.getLength());
          }
          break;
        case STRING:
          {
            Text tw = (Text) object;
-            ((BytesColumnVector) batch.cols[projectionColumnNum]).setVal(
+            ((BytesColumnVector) columnVector).setVal(
                batchIndex, tw.getBytes(), 0, tw.getLength());
          }
          break;
@@ -420,7 +427,7 @@ public void assignRowColumn(VectorizedRowBatch batch, int batchIndex, int logica
            // TODO: HIVE-13624 Do we need maxLength checking?
            byte[] bytes = hiveVarchar.getValue().getBytes();
-            ((BytesColumnVector) batch.cols[projectionColumnNum]).setVal(
+            ((BytesColumnVector) columnVector).setVal(
                batchIndex, bytes, 0, bytes.length);
          }
          break;
@@ -440,25 +447,25 @@ public void assignRowColumn(VectorizedRowBatch batch, int batchIndex, int logica
            // We store CHAR in vector row batch with padding stripped.
            byte[] bytes = hiveChar.getStrippedValue().getBytes();
-            ((BytesColumnVector) batch.cols[projectionColumnNum]).setVal(
+            ((BytesColumnVector) columnVector).setVal(
                batchIndex, bytes, 0, bytes.length);
          }
          break;
        case DECIMAL:
          if (object instanceof HiveDecimal) {
-            ((DecimalColumnVector) batch.cols[projectionColumnNum]).set(
+            ((DecimalColumnVector) columnVector).set(
                batchIndex, (HiveDecimal) object);
          } else {
-            ((DecimalColumnVector) batch.cols[projectionColumnNum]).set(
+            ((DecimalColumnVector) columnVector).set(
                batchIndex, (HiveDecimalWritable) object);
          }
          break;
        case INTERVAL_YEAR_MONTH:
-          ((LongColumnVector) batch.cols[projectionColumnNum]).vector[batchIndex] =
+          ((LongColumnVector) columnVector).vector[batchIndex] =
              ((HiveIntervalYearMonthWritable) object).getHiveIntervalYearMonth().getTotalMonths();
          break;
        case INTERVAL_DAY_TIME:
-          ((IntervalDayTimeColumnVector) batch.cols[projectionColumnNum]).set(
+          ((IntervalDayTimeColumnVector) columnVector).set(
              batchIndex, ((HiveIntervalDayTimeWritable) object).getHiveIntervalDayTime());
          break;
        default:
@@ -467,14 +474,73 @@ public void assignRowColumn(VectorizedRowBatch batch, int batchIndex, int logica
        }
      }
      break;
+    case LIST:
+      {
+        ListColumnVector listColumnVector = (ListColumnVector) columnVector;
+        ListTypeInfo listTypeInfo = (ListTypeInfo) targetTypeInfo;
+        TypeInfo elementTypeInfo = listTypeInfo.getListElementTypeInfo();
+        List list = (List) object;
+        int size = list.size();
+        int childCount = listColumnVector.childCount;
+        listColumnVector.offsets[batchIndex] = childCount;
+        listColumnVector.lengths[batchIndex] = size;
+        listColumnVector.childCount = childCount + size;
+        listColumnVector.child.ensureSize(childCount + size, true);
+        for (int i = 0; i < size; i++) {
+          assignRowColumn(listColumnVector.child, childCount + i, elementTypeInfo, list.get(i));
+        }
+      }
+      break;
+    case MAP:
+      {
+        MapColumnVector mapColumnVector = (MapColumnVector) columnVector;
+        MapTypeInfo mapTypeInfo = (MapTypeInfo) targetTypeInfo;
+        Map<Object, Object> map = (Map<Object, Object>) object;
+        int size = map.size();
+        int childCount = mapColumnVector.childCount;
+        mapColumnVector.offsets[batchIndex] = childCount;
+        mapColumnVector.lengths[batchIndex] = size;
+        mapColumnVector.keys.ensureSize(childCount + size, true);
+        mapColumnVector.values.ensureSize(childCount + size, true);
+        for (Map.Entry<Object, Object> entry : map.entrySet()) {
+          assignRowColumn(mapColumnVector.keys, childCount, mapTypeInfo.getMapKeyTypeInfo(), entry.getKey());
+          assignRowColumn(mapColumnVector.values, childCount, mapTypeInfo.getMapValueTypeInfo(), entry.getValue());
+          childCount++;
+        }
+        mapColumnVector.childCount = childCount;
+      }
+      break;
+    case STRUCT:
+      {
+        StructColumnVector structColumnVector = (StructColumnVector) columnVector;
+        List struct = (List) object;
+        StructTypeInfo structTypeInfo = (StructTypeInfo) targetTypeInfo;
+        List<TypeInfo> fieldStructTypeInfos = structTypeInfo.getAllStructFieldTypeInfos();
+        int size = fieldStructTypeInfos.size();
+        for (int i = 0; i < size; i++) {
+          assignRowColumn(structColumnVector.fields[i], batchIndex, fieldStructTypeInfos.get(i), struct.get(i));
+        }
+      }
+      break;
+    case UNION:
+      {
+        StandardUnionObjectInspector.StandardUnion union = (StandardUnionObjectInspector.StandardUnion) object;
+        UnionColumnVector unionColumnVector = (UnionColumnVector) columnVector;
+        UnionTypeInfo unionTypeInfo = (UnionTypeInfo) targetTypeInfo;
+        List<TypeInfo> objectTypeInfos = unionTypeInfo.getAllUnionObjectTypeInfos();
+        byte tag = union.getTag();
+        unionColumnVector.tags[batchIndex] = tag;
+        assignRowColumn(unionColumnVector.fields[tag], batchIndex, objectTypeInfos.get(tag), union.getObject());
+      }
+      break;
     default:
-      throw new RuntimeException("Category " + targetCategory.name() + " not supported");
+      throw new RuntimeException("Category " + targetTypeInfo.getCategory().name() + " not supported");
     }
 
     /*
      * We always set the null flag to false when there is a value.
      */
-    batch.cols[projectionColumnNum].isNull[batchIndex] = false;
+    columnVector.isNull[batchIndex] = false;
   }
 
   /**
@@ -493,7 +559,7 @@ public void assignRowColumn(VectorizedRowBatch batch, int batchIndex, int logica
   public void assignConvertRowColumn(VectorizedRowBatch batch, int batchIndex,
       int logicalColumnIndex, Object object) {
     Preconditions.checkState(isConvert[logicalColumnIndex]);
-    Category targetCategory = targetCategories[logicalColumnIndex];
+    Category targetCategory = targetTypeInfos[logicalColumnIndex].getCategory();
     if (targetCategory == null) {
       /*
        * This is a column that we don't want (i.e. not included) -- we are done.
@@ -508,7 +574,8 @@ public void assignConvertRowColumn(VectorizedRowBatch batch, int batchIndex,
     try {
       switch (targetCategory) {
       case PRIMITIVE:
-        PrimitiveCategory targetPrimitiveCategory = targetPrimitiveCategories[logicalColumnIndex];
+        PrimitiveCategory targetPrimitiveCategory =
+            ((PrimitiveTypeInfo) targetTypeInfos[logicalColumnIndex]).getPrimitiveCategory();
         switch (targetPrimitiveCategory) {
         case VOID:
           VectorizedBatchUtil.setNullColIsNullValue(batch.cols[projectionColumnNum], batchIndex);
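Note on the complex-type branches added to assignRowColumn above: ListColumnVector and MapColumnVector keep every row's elements in one shared child vector, and the new code appends at childCount while recording a per-row (offset, length) pair. The standalone sketch below shows just that layout for a single list<bigint> row; the batch size, the row values, and the class name are invented for illustration and are not part of the patch.

import org.apache.hadoop.hive.ql.exec.vector.ListColumnVector;
import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector;

public class ListLayoutSketch {
  public static void main(String[] args) {
    // A list<bigint> column: one shared child vector holds all rows' elements.
    LongColumnVector child = new LongColumnVector(1024);
    ListColumnVector list = new ListColumnVector(1024, child);

    long[] row = {1L, 2L, 3L};
    int batchIndex = 0;

    int childCount = list.childCount;       // next free slot in the shared child vector
    list.offsets[batchIndex] = childCount;  // this row's elements start here
    list.lengths[batchIndex] = row.length;  // and span this many slots
    list.child.ensureSize(childCount + row.length, true);
    for (int i = 0; i < row.length; i++) {
      child.vector[childCount + i] = row[i];  // append, as the recursive assign does
    }
    list.childCount = childCount + row.length;
    list.isNull[batchIndex] = false;        // a value is present
  }
}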
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorDeserializeRow.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorDeserializeRow.java
index fc82cf79b6..01387c79b8 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorDeserializeRow.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorDeserializeRow.java
@@ -18,8 +18,8 @@
 package org.apache.hadoop.hive.ql.exec.vector;
 
-import java.io.EOFException;
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 
@@ -42,8 +42,12 @@
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory;
 import org.apache.hadoop.hive.serde2.typeinfo.CharTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.UnionTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo;
 import org.apache.hadoop.io.BooleanWritable;
 import org.apache.hadoop.io.BytesWritable;
@@ -91,6 +95,82 @@ public VectorDeserializeRow(T deserializeRead) {
   private VectorDeserializeRow() {
   }
 
+  private static class Field {
+
+    private Category category;
+
+    private PrimitiveCategory primitiveCategory;
+    // The data type primitive category of the column being deserialized.
+
+    private int maxLength;
+    // For the CHAR and VARCHAR data types, the maximum character length of
+    // the column. Otherwise, 0.
+
+    private boolean isConvert;
+
+    /*
+     * This member has information for data type conversion.
+     * Not defined if there is no conversion.
+     */
+    Writable conversionWritable;
+    // Conversion requires source be placed in writable so we can call upon
+    // VectorAssignRow to convert and assign the row column.
+
+    private ComplexTypeHelper complexTypeHelper;
+    // For a complex type, a helper object that describes elements, key/value pairs,
+    // or fields.
+
+    public Field(PrimitiveCategory primitiveCategory, int maxLength) {
+      this.category = Category.PRIMITIVE;
+      this.primitiveCategory = primitiveCategory;
+      this.maxLength = maxLength;
+      this.isConvert = false;
+      this.conversionWritable = null;
+      this.complexTypeHelper = null;
+    }
+
+    public Field(Category category, ComplexTypeHelper complexTypeHelper) {
+      this.category = category;
+      this.primitiveCategory = null;
+      this.maxLength = 0;
+      this.isConvert = false;
+      this.conversionWritable = null;
+      this.complexTypeHelper = complexTypeHelper;
+    }
+
+    public Category getCategory() {
+      return category;
+    }
+
+    public PrimitiveCategory getPrimitiveCategory() {
+      return primitiveCategory;
+    }
+
+    public int getMaxLength() {
+      return maxLength;
+    }
+
+    public void setIsConvert(boolean isConvert) {
+      this.isConvert = isConvert;
+    }
+
+    public boolean getIsConvert() {
+      return isConvert;
+    }
+
+    public void setConversionWritable(Writable conversionWritable) {
+      this.conversionWritable = conversionWritable;
+    }
+
+    public Writable getConversionWritable() {
+      return conversionWritable;
+    }
+
+    public ComplexTypeHelper getComplexHelper() {
+      return complexTypeHelper;
+    }
+  }
+
   /*
    * These members have information for deserializing a row into the VectorizedRowBatch
    * columns.
@@ -105,30 +185,11 @@ private VectorDeserializeRow() {
   private int[] readFieldLogicalIndices;
   // The logical indices for reading with readField.
 
-  private boolean[] isConvert;
-  // For each column, are we converting the row column?
-
   private int[] projectionColumnNums;
   // Assigning can be a subset of columns, so this is the projection --
   // the batch column numbers.
 
-  private Category[] sourceCategories;
-  // The data type category of each column being deserialized.
-
-  private PrimitiveCategory[] sourcePrimitiveCategories;
-  // The data type primitive category of each column being deserialized.
-
-  private int[] maxLengths;
-  // For the CHAR and VARCHAR data types, the maximum character length of
-  // the columns. Otherwise, 0.
-
-  /*
-   * These members have information for data type conversion.
-   * Not defined if there is no conversion.
-   */
-  Writable[] convertSourceWritables;
-  // Conversion requires source be placed in writable so we can call upon
-  // VectorAssignRow to convert and assign the row column.
+  private Field[] topLevelFields;
 
   VectorAssignRow convertVectorAssignRow;
   // Use its conversion ability.
@@ -137,62 +198,117 @@ private VectorDeserializeRow() {
    * Allocate the source deserialization related arrays.
    */
   private void allocateArrays(int count) {
-    isConvert = new boolean[count];
     projectionColumnNums = new int[count];
     Arrays.fill(projectionColumnNums, -1);
-    sourceCategories = new Category[count];
-    sourcePrimitiveCategories = new PrimitiveCategory[count];
-    maxLengths = new int[count];
+    topLevelFields = new Field[count];
   }
 
-  /*
-   * Allocate the conversion related arrays (optional).
-   */
-  private void allocateConvertArrays(int count) {
-    convertSourceWritables = new Writable[count];
+  private Field allocatePrimitiveField(TypeInfo sourceTypeInfo) {
+    PrimitiveTypeInfo sourcePrimitiveTypeInfo = (PrimitiveTypeInfo) sourceTypeInfo;
+    PrimitiveCategory sourcePrimitiveCategory = sourcePrimitiveTypeInfo.getPrimitiveCategory();
+    int maxLength;
+    switch (sourcePrimitiveCategory) {
+    case CHAR:
+      maxLength = ((CharTypeInfo) sourcePrimitiveTypeInfo).getLength();
+      break;
+    case VARCHAR:
+      maxLength = ((VarcharTypeInfo) sourcePrimitiveTypeInfo).getLength();
+      break;
+    default:
+      // No additional data type specific setting.
+      maxLength = 0;
+      break;
+    }
+    return new Field(sourcePrimitiveCategory, maxLength);
+  }
+
+  private Field allocateComplexField(TypeInfo sourceTypeInfo) {
+    Category category = sourceTypeInfo.getCategory();
+    switch (category) {
+    case LIST:
+      {
+        ListTypeInfo listTypeInfo = (ListTypeInfo) sourceTypeInfo;
+        ListComplexTypeHelper listHelper =
+            new ListComplexTypeHelper(
+                allocateField(listTypeInfo.getListElementTypeInfo()));
+        return new Field(category, listHelper);
+      }
+    case MAP:
+      {
+        MapTypeInfo mapTypeInfo = (MapTypeInfo) sourceTypeInfo;
+        MapComplexTypeHelper mapHelper =
+            new MapComplexTypeHelper(
+                allocateField(mapTypeInfo.getMapKeyTypeInfo()),
+                allocateField(mapTypeInfo.getMapValueTypeInfo()));
+        return new Field(category, mapHelper);
+      }
+    case STRUCT:
+      {
+        StructTypeInfo structTypeInfo = (StructTypeInfo) sourceTypeInfo;
+        ArrayList<TypeInfo> fieldTypeInfoList = structTypeInfo.getAllStructFieldTypeInfos();
+        final int count = fieldTypeInfoList.size();
+        Field[] fields = new Field[count];
+        for (int i = 0; i < count; i++) {
+          fields[i] = allocateField(fieldTypeInfoList.get(i));
+        }
+        StructComplexTypeHelper structHelper =
+            new StructComplexTypeHelper(fields);
+        return new Field(category, structHelper);
+      }
+    case UNION:
+      {
+        UnionTypeInfo unionTypeInfo = (UnionTypeInfo) sourceTypeInfo;
+        List<TypeInfo> fieldTypeInfoList = unionTypeInfo.getAllUnionObjectTypeInfos();
+        final int count = fieldTypeInfoList.size();
+        Field[] fields = new Field[count];
+        for (int i = 0; i < count; i++) {
+          fields[i] = allocateField(fieldTypeInfoList.get(i));
+        }
+        UnionComplexTypeHelper unionHelper =
+            new UnionComplexTypeHelper(fields);
+        return new Field(category, unionHelper);
+      }
+    default:
+      throw new RuntimeException("Category " + category + " not supported");
+    }
+  }
+
+  private Field allocateField(TypeInfo sourceTypeInfo) {
+    switch (sourceTypeInfo.getCategory()) {
+    case PRIMITIVE:
+      return allocatePrimitiveField(sourceTypeInfo);
+    case LIST:
+    case MAP:
+    case STRUCT:
+    case UNION:
+      return allocateComplexField(sourceTypeInfo);
+    default:
+      throw new RuntimeException("Category " + sourceTypeInfo.getCategory() + " not supported");
+    }
   }
 
   /*
-   * Initialize one column's source deserializtion related arrays.
+   * Initialize one column's source deserialization information.
    */
-  private void initSourceEntry(int logicalColumnIndex, int projectionColumnNum, TypeInfo sourceTypeInfo) {
-    isConvert[logicalColumnIndex] = false;
+  private void initTopLevelField(int logicalColumnIndex, int projectionColumnNum, TypeInfo sourceTypeInfo) {
+
     projectionColumnNums[logicalColumnIndex] = projectionColumnNum;
-    Category sourceCategory = sourceTypeInfo.getCategory();
-    sourceCategories[logicalColumnIndex] = sourceCategory;
-    if (sourceCategory == Category.PRIMITIVE) {
-      PrimitiveTypeInfo sourcePrimitiveTypeInfo = (PrimitiveTypeInfo) sourceTypeInfo;
-      PrimitiveCategory sourcePrimitiveCategory = sourcePrimitiveTypeInfo.getPrimitiveCategory();
-      sourcePrimitiveCategories[logicalColumnIndex] = sourcePrimitiveCategory;
-      switch (sourcePrimitiveCategory) {
-      case CHAR:
-        maxLengths[logicalColumnIndex] = ((CharTypeInfo) sourcePrimitiveTypeInfo).getLength();
-        break;
-      case VARCHAR:
-        maxLengths[logicalColumnIndex] = ((VarcharTypeInfo) sourcePrimitiveTypeInfo).getLength();
-        break;
-      default:
-        // No additional data type specific setting.
-        break;
-      }
-    } else {
-      // We don't currently support complex types.
-      Preconditions.checkState(false);
-    }
+
+    topLevelFields[logicalColumnIndex] = allocateField(sourceTypeInfo);
   }
 
   /*
-   * Initialize the conversion related arrays. Assumes initSourceEntry has already been called.
+   * Initialize the conversion related arrays. Assumes initTopLevelField has already been called.
    */
-  private void initConvertTargetEntry(int logicalColumnIndex) {
-    isConvert[logicalColumnIndex] = true;
+  private void addTopLevelConversion(int logicalColumnIndex) {
+    Field field = topLevelFields[logicalColumnIndex];
 
-    if (sourceCategories[logicalColumnIndex] == Category.PRIMITIVE) {
-      convertSourceWritables[logicalColumnIndex] =
-          VectorizedBatchUtil.getPrimitiveWritable(sourcePrimitiveCategories[logicalColumnIndex]);
-    } else {
-      // We don't currently support complex types.
-      Preconditions.checkState(false);
+    field.setIsConvert(true);
+
+    if (field.getCategory() == Category.PRIMITIVE) {
+
+      field.setConversionWritable(
+          VectorizedBatchUtil.getPrimitiveWritable(field.getPrimitiveCategory()));
     }
   }
 
@@ -206,7 +322,7 @@ public void init(int[] outputColumns) throws HiveException {
 
     for (int i = 0; i < count; i++) {
       int outputColumn = outputColumns[i];
-      initSourceEntry(i, outputColumn, sourceTypeInfos[i]);
+      initTopLevelField(i, outputColumn, sourceTypeInfos[i]);
     }
   }
 
@@ -220,7 +336,7 @@ public void init(List outputColumns) throws HiveException {
 
     for (int i = 0; i < count; i++) {
       int outputColumn = outputColumns.get(i);
-      initSourceEntry(i, outputColumn, sourceTypeInfos[i]);
+      initTopLevelField(i, outputColumn, sourceTypeInfos[i]);
     }
   }
 
@@ -234,7 +350,7 @@ public void init(int startColumn) throws HiveException {
 
     for (int i = 0; i < count; i++) {
       int outputColumn = startColumn + i;
-      initSourceEntry(i, outputColumn, sourceTypeInfos[i]);
+      initTopLevelField(i, outputColumn, sourceTypeInfos[i]);
     }
   }
 
@@ -260,7 +376,7 @@ public void init(boolean[] columnsToIncludeTruncated) throws HiveException {
 
       } else {
 
-        initSourceEntry(i, i, sourceTypeInfos[i]);
+        initTopLevelField(i, i, sourceTypeInfos[i]);
 
         includedIndices[includedCount++] = i;
       }
     }
@@ -298,7 +414,6 @@ public void initConversion(TypeInfo[] targetTypeInfos,
     final int columnCount = sourceTypeInfos.length;
     allocateArrays(columnCount);
-    allocateConvertArrays(columnCount);
 
     int includedCount = 0;
     int[] includedIndices = new int[columnCount];
@@ -320,20 +435,22 @@ public void initConversion(TypeInfo[] targetTypeInfos,
         if (VectorPartitionConversion.isImplicitVectorColumnConversion(sourceTypeInfo, targetTypeInfo)) {
 
           // Do implicit conversion from source type to target type.
-          initSourceEntry(i, i, sourceTypeInfo);
+          initTopLevelField(i, i, sourceTypeInfo);
 
         } else {
 
           // Do formal conversion...
-          initSourceEntry(i, i, sourceTypeInfo);
-          initConvertTargetEntry(i);
+          initTopLevelField(i, i, sourceTypeInfo);
+
+          // UNDONE: No for List and Map; Yes for Struct and Union when field count different...
+          addTopLevelConversion(i);
           atLeastOneConvert = true;
         }
       } else {
 
         // No conversion.
-        initSourceEntry(i, i, sourceTypeInfo);
+        initTopLevelField(i, i, sourceTypeInfo);
       }
 
@@ -360,6 +477,388 @@ public void init() throws HiveException {
     init(0);
   }
 
+  private void storePrimitiveRowColumn(ColumnVector colVector,
+      Field field, int batchIndex,
+      boolean canRetainByteRef) throws IOException {
+    switch (field.getPrimitiveCategory()) {
+    case VOID:
+      VectorizedBatchUtil.setNullColIsNullValue(colVector, batchIndex);
+      return;
+    case BOOLEAN:
+      ((LongColumnVector) colVector).vector[batchIndex] = (deserializeRead.currentBoolean ? 1 : 0);
+      break;
+    case BYTE:
+      ((LongColumnVector) colVector).vector[batchIndex] = deserializeRead.currentByte;
+      break;
+    case SHORT:
+      ((LongColumnVector) colVector).vector[batchIndex] = deserializeRead.currentShort;
+      break;
+    case INT:
+      ((LongColumnVector) colVector).vector[batchIndex] = deserializeRead.currentInt;
+      break;
+    case LONG:
+      ((LongColumnVector) colVector).vector[batchIndex] = deserializeRead.currentLong;
+      break;
+    case TIMESTAMP:
+      ((TimestampColumnVector) colVector).set(
+          batchIndex, deserializeRead.currentTimestampWritable.getTimestamp());
+      break;
+    case DATE:
+      ((LongColumnVector) colVector).vector[batchIndex] = deserializeRead.currentDateWritable.getDays();
+      break;
+    case FLOAT:
+      ((DoubleColumnVector) colVector).vector[batchIndex] = deserializeRead.currentFloat;
+      break;
+    case DOUBLE:
+      ((DoubleColumnVector) colVector).vector[batchIndex] = deserializeRead.currentDouble;
+      break;
+    case BINARY:
+    case STRING:
+      {
+        BytesColumnVector bytesColVec = ((BytesColumnVector) colVector);
+        if (deserializeRead.currentExternalBufferNeeded) {
+          bytesColVec.ensureValPreallocated(deserializeRead.currentExternalBufferNeededLen);
+          deserializeRead.copyToExternalBuffer(
+              bytesColVec.getValPreallocatedBytes(), bytesColVec.getValPreallocatedStart());
+          bytesColVec.setValPreallocated(
+              batchIndex,
+              deserializeRead.currentExternalBufferNeededLen);
+        } else if (canRetainByteRef && inputBytes == deserializeRead.currentBytes) {
+          bytesColVec.setRef(
+              batchIndex,
+              deserializeRead.currentBytes,
+              deserializeRead.currentBytesStart,
+              deserializeRead.currentBytesLength);
+        } else {
+          bytesColVec.setVal(
+              batchIndex,
+              deserializeRead.currentBytes,
+              deserializeRead.currentBytesStart,
+              deserializeRead.currentBytesLength);
+        }
+      }
+      break;
+    case VARCHAR:
+      {
+        // Use the basic STRING bytes read to get access, then use our optimal truncate/trim method
+        // that does not use Java String objects.
+        BytesColumnVector bytesColVec = ((BytesColumnVector) colVector);
+        if (deserializeRead.currentExternalBufferNeeded) {
+          // Write directly into our BytesColumnVector value buffer.
+          bytesColVec.ensureValPreallocated(deserializeRead.currentExternalBufferNeededLen);
+          byte[] convertBuffer = bytesColVec.getValPreallocatedBytes();
+          int convertBufferStart = bytesColVec.getValPreallocatedStart();
+          deserializeRead.copyToExternalBuffer(
+              convertBuffer,
+              convertBufferStart);
+          bytesColVec.setValPreallocated(
+              batchIndex,
+              StringExpr.truncate(
+                  convertBuffer,
+                  convertBufferStart,
+                  deserializeRead.currentExternalBufferNeededLen,
+                  field.getMaxLength()));
+        } else if (canRetainByteRef && inputBytes == deserializeRead.currentBytes) {
+          bytesColVec.setRef(
+              batchIndex,
+              deserializeRead.currentBytes,
+              deserializeRead.currentBytesStart,
+              StringExpr.truncate(
+                  deserializeRead.currentBytes,
+                  deserializeRead.currentBytesStart,
+                  deserializeRead.currentBytesLength,
+                  field.getMaxLength()));
+        } else {
+          bytesColVec.setVal(
+              batchIndex,
+              deserializeRead.currentBytes,
+              deserializeRead.currentBytesStart,
+              StringExpr.truncate(
+                  deserializeRead.currentBytes,
+                  deserializeRead.currentBytesStart,
+                  deserializeRead.currentBytesLength,
+                  field.getMaxLength()));
+        }
+      }
+      break;
+    case CHAR:
+      {
+        // Use the basic STRING bytes read to get access, then use our optimal truncate/trim method
+        // that does not use Java String objects.
+        BytesColumnVector bytesColVec = ((BytesColumnVector) colVector);
+        if (deserializeRead.currentExternalBufferNeeded) {
+          // Write directly into our BytesColumnVector value buffer.
+          bytesColVec.ensureValPreallocated(deserializeRead.currentExternalBufferNeededLen);
+          byte[] convertBuffer = bytesColVec.getValPreallocatedBytes();
+          int convertBufferStart = bytesColVec.getValPreallocatedStart();
+          deserializeRead.copyToExternalBuffer(
+              convertBuffer,
+              convertBufferStart);
+          bytesColVec.setValPreallocated(
+              batchIndex,
+              StringExpr.rightTrimAndTruncate(
+                  convertBuffer,
+                  convertBufferStart,
+                  deserializeRead.currentExternalBufferNeededLen,
+                  field.getMaxLength()));
+        } else if (canRetainByteRef && inputBytes == deserializeRead.currentBytes) {
+          bytesColVec.setRef(
+              batchIndex,
+              deserializeRead.currentBytes,
+              deserializeRead.currentBytesStart,
+              StringExpr.rightTrimAndTruncate(
+                  deserializeRead.currentBytes,
+                  deserializeRead.currentBytesStart,
+                  deserializeRead.currentBytesLength,
+                  field.getMaxLength()));
+        } else {
+          bytesColVec.setVal(
+              batchIndex,
+              deserializeRead.currentBytes,
+              deserializeRead.currentBytesStart,
+              StringExpr.rightTrimAndTruncate(
+                  deserializeRead.currentBytes,
+                  deserializeRead.currentBytesStart,
+                  deserializeRead.currentBytesLength,
+                  field.getMaxLength()));
+        }
+      }
+      break;
+    case DECIMAL:
+      // The DecimalColumnVector set method will quickly copy the deserialized decimal writable fields.
+      ((DecimalColumnVector) colVector).set(
+          batchIndex, deserializeRead.currentHiveDecimalWritable);
+      break;
+    case INTERVAL_YEAR_MONTH:
+      ((LongColumnVector) colVector).vector[batchIndex] =
+          deserializeRead.currentHiveIntervalYearMonthWritable.getHiveIntervalYearMonth().getTotalMonths();
+      break;
+    case INTERVAL_DAY_TIME:
+      ((IntervalDayTimeColumnVector) colVector).set(
+          batchIndex, deserializeRead.currentHiveIntervalDayTimeWritable.getHiveIntervalDayTime());
+      break;
+    default:
+      throw new RuntimeException("Primitive category " + field.getPrimitiveCategory() +
+          " not supported");
+    }
+  }
+
+  private static class ComplexTypeHelper {
+  }
+
+  private static class ListComplexTypeHelper extends ComplexTypeHelper {
+
+    private Field elementField;
+
+    public ListComplexTypeHelper(Field elementField) {
+      this.elementField = elementField;
+    }
+
+    public Field getElementField() {
+      return elementField;
+    }
+  }
+
+  private static class MapComplexTypeHelper extends ComplexTypeHelper {
+
+    private Field keyField;
+    private Field valueField;
+
+    public MapComplexTypeHelper(Field keyField, Field valueField) {
+      this.keyField = keyField;
+      this.valueField = valueField;
+    }
+
+    public Field getKeyField() {
+      return keyField;
+    }
+
+    public Field getValueField() {
+      return valueField;
+    }
+  }
+
+  private static class FieldsComplexTypeHelper extends ComplexTypeHelper {
+
+    private Field[] fields;
+
+    public FieldsComplexTypeHelper(Field[] fields) {
+      this.fields = fields;
+    }
+
+    public Field[] getFields() {
+      return fields;
+    }
+  }
+
+  private static class StructComplexTypeHelper extends FieldsComplexTypeHelper {
+
+    public StructComplexTypeHelper(Field[] fields) {
+      super(fields);
+    }
+  }
+
+  private static class UnionComplexTypeHelper extends FieldsComplexTypeHelper {
+
+    public UnionComplexTypeHelper(Field[] fields) {
+      super(fields);
+    }
+  }
+
+  // UNDONE: Presumption of *append*
+
+  private void storeComplexFieldRowColumn(ColumnVector fieldColVector,
+      Field field, int batchIndex, boolean canRetainByteRef) throws IOException {
+
+    if (!deserializeRead.readComplexField()) {
+      fieldColVector.isNull[batchIndex] = true;
+      fieldColVector.noNulls = false;
+      return;
+    }
+
+    switch (field.getCategory()) {
+    case PRIMITIVE:
+      storePrimitiveRowColumn(fieldColVector, field, batchIndex, canRetainByteRef);
+      break;
+    case LIST:
+      storeListRowColumn(fieldColVector, field, batchIndex, canRetainByteRef);
+      break;
+    case MAP:
+      storeMapRowColumn(fieldColVector, field, batchIndex, canRetainByteRef);
+      break;
+    case STRUCT:
+      storeStructRowColumn(fieldColVector, field, batchIndex, canRetainByteRef);
+      break;
+    case UNION:
+      storeUnionRowColumn(fieldColVector, field, batchIndex, canRetainByteRef);
+      break;
+    default:
+      throw new RuntimeException("Category " + field.getCategory() + " not supported");
+    }
+  }
+
+  private void storeListRowColumn(ColumnVector colVector,
+      Field field, int batchIndex, boolean canRetainByteRef) throws IOException {
+
+    ListColumnVector listColVector = (ListColumnVector) colVector;
+    listColVector.isNull[batchIndex] = false;
+    int offset = listColVector.childCount;
+    listColVector.offsets[batchIndex] = offset;
+
+    ColumnVector elementColVector = listColVector.child;
+
+    ListComplexTypeHelper listHelper = (ListComplexTypeHelper) field.getComplexHelper();
+
+    int listLength = 0;
+    while (deserializeRead.isNextComplexMultiValue()) {
+
+      // Ensure child size.
+      int childCapacity = listColVector.child.isNull.length;
+      int childCount = listColVector.childCount;
+      if (childCapacity < childCount / 0.75) {
+        listColVector.child.ensureSize(childCapacity * 2, true);
+      }
+
+      storeComplexFieldRowColumn(
+          elementColVector, listHelper.getElementField(), offset, canRetainByteRef);
+      offset++;
+      listLength++;
+    }
+
+    listColVector.childCount += listLength;
+    listColVector.lengths[batchIndex] = listLength;
+  }
+
+  private void storeMapRowColumn(ColumnVector colVector,
+      Field field, int batchIndex, boolean canRetainByteRef) throws IOException {
+
+    MapColumnVector mapColVector = (MapColumnVector) colVector;
+
+    mapColVector.isNull[batchIndex] = false;
+    int offset = mapColVector.childCount;
+    mapColVector.offsets[batchIndex] = offset;
+
+    ColumnVector keysColVector = mapColVector.keys;
+    ColumnVector valuesColVector = mapColVector.values;
+
+    MapComplexTypeHelper mapHelper = (MapComplexTypeHelper) field.getComplexHelper();
+
+    int keyValueCount = 0;
+    while (deserializeRead.isNextComplexMultiValue()) {
+
+      // Ensure child size.
+      int childCapacity = mapColVector.keys.isNull.length;
+      int childCount = mapColVector.childCount;
+      if (childCapacity < childCount / 0.75) {
+        mapColVector.keys.ensureSize(childCapacity * 2, true);
+        mapColVector.values.ensureSize(childCapacity * 2, true);
+      }
+
+      // Key.
+      storeComplexFieldRowColumn(
+          keysColVector, mapHelper.getKeyField(), offset, canRetainByteRef);
+
+      // Value.
+      storeComplexFieldRowColumn(
+          valuesColVector, mapHelper.getValueField(), offset, canRetainByteRef);
+
+      offset++;
+      keyValueCount++;
+    }
+
+    mapColVector.childCount += keyValueCount;
+    mapColVector.lengths[batchIndex] = keyValueCount;
+  }
+
+  private void storeStructRowColumn(ColumnVector colVector,
+      Field field, int batchIndex, boolean canRetainByteRef) throws IOException {
+
+    StructColumnVector structColVector = (StructColumnVector) colVector;
+
+    structColVector.isNull[batchIndex] = false;
+
+    ColumnVector[] colVectorFields = structColVector.fields;
+
+    StructComplexTypeHelper structHelper = (StructComplexTypeHelper) field.getComplexHelper();
+
+    Field[] fields = structHelper.getFields();
+    int i = 0;
+    for (ColumnVector colVectorField : colVectorFields) {
+      storeComplexFieldRowColumn(
+          colVectorField,
+          fields[i],
+          batchIndex,
+          canRetainByteRef);
+      i++;
+    }
+    deserializeRead.finishComplexVariableFieldsType();
+  }
+
+  private void storeUnionRowColumn(ColumnVector colVector,
+      Field field, int batchIndex, boolean canRetainByteRef) throws IOException {
+
+    deserializeRead.readComplexField();
+
+    // The read field of the Union gives us its tag.
+    final int tag = deserializeRead.currentInt;
+
+    UnionColumnVector unionColVector = (UnionColumnVector) colVector;
+
+    unionColVector.isNull[batchIndex] = false;
+
+    ColumnVector[] colVectorFields = unionColVector.fields;
+    unionColVector.tags[batchIndex] = tag;
+
+    UnionComplexTypeHelper unionHelper = (UnionComplexTypeHelper) field.getComplexHelper();
+
+    storeComplexFieldRowColumn(
+        colVectorFields[tag],
+        unionHelper.getFields()[tag],
+        batchIndex,
+        canRetainByteRef);
+    deserializeRead.finishComplexVariableFieldsType();
+  }
+
   /**
    * Store one row column value that is the current value in deserializeRead.
    *
@@ -374,186 +873,29 @@ public void init() throws HiveException {
    * @throws IOException
    */
   private void storeRowColumn(VectorizedRowBatch batch, int batchIndex,
-      int logicalColumnIndex, boolean canRetainByteRef) throws IOException {
+      Field field, int logicalColumnIndex, boolean canRetainByteRef) throws IOException {
 
     final int projectionColumnNum = projectionColumnNums[logicalColumnIndex];
-    switch (sourceCategories[logicalColumnIndex]) {
+    ColumnVector colVector = batch.cols[projectionColumnNum];
+
+    switch (field.getCategory()) {
     case PRIMITIVE:
-      {
-        PrimitiveCategory sourcePrimitiveCategory = sourcePrimitiveCategories[logicalColumnIndex];
-        switch (sourcePrimitiveCategory) {
-        case VOID:
-          VectorizedBatchUtil.setNullColIsNullValue(batch.cols[projectionColumnNum], batchIndex);
-          return;
-        case BOOLEAN:
-          ((LongColumnVector) batch.cols[projectionColumnNum]).vector[batchIndex] =
-              (deserializeRead.currentBoolean ? 1 : 0);
-          break;
-        case BYTE:
-          ((LongColumnVector) batch.cols[projectionColumnNum]).vector[batchIndex] =
-              deserializeRead.currentByte;
-          break;
-        case SHORT:
-          ((LongColumnVector) batch.cols[projectionColumnNum]).vector[batchIndex] =
-              deserializeRead.currentShort;
-          break;
-        case INT:
-          ((LongColumnVector) batch.cols[projectionColumnNum]).vector[batchIndex] =
-              deserializeRead.currentInt;
-          break;
-        case LONG:
-          ((LongColumnVector) batch.cols[projectionColumnNum]).vector[batchIndex] =
-              deserializeRead.currentLong;
-          break;
-        case TIMESTAMP:
-          ((TimestampColumnVector) batch.cols[projectionColumnNum]).set(
-              batchIndex, deserializeRead.currentTimestampWritable.getTimestamp());
-          break;
-        case DATE:
-          ((LongColumnVector) batch.cols[projectionColumnNum]).vector[batchIndex] =
-              deserializeRead.currentDateWritable.getDays();
-          break;
-        case FLOAT:
-          ((DoubleColumnVector) batch.cols[projectionColumnNum]).vector[batchIndex] =
-              deserializeRead.currentFloat;
-          break;
-        case DOUBLE:
-          ((DoubleColumnVector) batch.cols[projectionColumnNum]).vector[batchIndex] =
-              deserializeRead.currentDouble;
-          break;
-        case BINARY:
-        case STRING:
-          {
-            BytesColumnVector bytesColVec = ((BytesColumnVector) batch.cols[projectionColumnNum]);
-            if (deserializeRead.currentExternalBufferNeeded) {
-              bytesColVec.ensureValPreallocated(deserializeRead.currentExternalBufferNeededLen);
-              deserializeRead.copyToExternalBuffer(
-                  bytesColVec.getValPreallocatedBytes(), bytesColVec.getValPreallocatedStart());
-              bytesColVec.setValPreallocated(
-                  batchIndex,
-                  deserializeRead.currentExternalBufferNeededLen);
-            } else if (canRetainByteRef && inputBytes == deserializeRead.currentBytes) {
-              bytesColVec.setRef(
-                  batchIndex,
-                  deserializeRead.currentBytes,
-                  deserializeRead.currentBytesStart,
-                  deserializeRead.currentBytesLength);
-            } else {
-              bytesColVec.setVal(
-                  batchIndex,
-                  deserializeRead.currentBytes,
-                  deserializeRead.currentBytesStart,
-                  deserializeRead.currentBytesLength);
-            }
-          }
-          break;
-        case VARCHAR:
-          {
-            // Use the basic STRING bytes read to get access, then use our optimal truncate/trim method
-            // that does not use Java String objects.
-            BytesColumnVector bytesColVec = ((BytesColumnVector) batch.cols[projectionColumnNum]);
-            if (deserializeRead.currentExternalBufferNeeded) {
-              // Write directly into our BytesColumnVector value buffer.
-              bytesColVec.ensureValPreallocated(deserializeRead.currentExternalBufferNeededLen);
-              byte[] convertBuffer = bytesColVec.getValPreallocatedBytes();
-              int convertBufferStart = bytesColVec.getValPreallocatedStart();
-              deserializeRead.copyToExternalBuffer(
-                  convertBuffer,
-                  convertBufferStart);
-              bytesColVec.setValPreallocated(
-                  batchIndex,
-                  StringExpr.truncate(
-                      convertBuffer,
-                      convertBufferStart,
-                      deserializeRead.currentExternalBufferNeededLen,
-                      maxLengths[logicalColumnIndex]));
-            } else if (canRetainByteRef && inputBytes == deserializeRead.currentBytes) {
-              bytesColVec.setRef(
-                  batchIndex,
-                  deserializeRead.currentBytes,
-                  deserializeRead.currentBytesStart,
-                  StringExpr.truncate(
-                      deserializeRead.currentBytes,
-                      deserializeRead.currentBytesStart,
-                      deserializeRead.currentBytesLength,
-                      maxLengths[logicalColumnIndex]));
-            } else {
-              bytesColVec.setVal(
-                  batchIndex,
-                  deserializeRead.currentBytes,
-                  deserializeRead.currentBytesStart,
-                  StringExpr.truncate(
-                      deserializeRead.currentBytes,
-                      deserializeRead.currentBytesStart,
-                      deserializeRead.currentBytesLength,
-                      maxLengths[logicalColumnIndex]));
-            }
-          }
-          break;
-        case CHAR:
-          {
-            // Use the basic STRING bytes read to get access, then use our optimal truncate/trim method
-            // that does not use Java String objects.
-            BytesColumnVector bytesColVec = ((BytesColumnVector) batch.cols[projectionColumnNum]);
-            if (deserializeRead.currentExternalBufferNeeded) {
-              // Write directly into our BytesColumnVector value buffer.
-              bytesColVec.ensureValPreallocated(deserializeRead.currentExternalBufferNeededLen);
-              byte[] convertBuffer = bytesColVec.getValPreallocatedBytes();
-              int convertBufferStart = bytesColVec.getValPreallocatedStart();
-              deserializeRead.copyToExternalBuffer(
-                  convertBuffer,
-                  convertBufferStart);
-              bytesColVec.setValPreallocated(
-                  batchIndex,
-                  StringExpr.rightTrimAndTruncate(
-                      convertBuffer,
-                      convertBufferStart,
-                      deserializeRead.currentExternalBufferNeededLen,
-                      maxLengths[logicalColumnIndex]));
-            } else if (canRetainByteRef && inputBytes == deserializeRead.currentBytes) {
-              bytesColVec.setRef(
-                  batchIndex,
-                  deserializeRead.currentBytes,
-                  deserializeRead.currentBytesStart,
-                  StringExpr.rightTrimAndTruncate(
-                      deserializeRead.currentBytes,
-                      deserializeRead.currentBytesStart,
-                      deserializeRead.currentBytesLength,
-                      maxLengths[logicalColumnIndex]));
-            } else {
-              bytesColVec.setVal(
-                  batchIndex,
-                  deserializeRead.currentBytes,
-                  deserializeRead.currentBytesStart,
-                  StringExpr.rightTrimAndTruncate(
-                      deserializeRead.currentBytes,
-                      deserializeRead.currentBytesStart,
-                      deserializeRead.currentBytesLength,
-                      maxLengths[logicalColumnIndex]));
-            }
-          }
-          break;
-        case DECIMAL:
-          // The DecimalColumnVector set method will quickly copy the deserialized decimal writable fields.
-          ((DecimalColumnVector) batch.cols[projectionColumnNum]).set(
-              batchIndex, deserializeRead.currentHiveDecimalWritable);
-          break;
-        case INTERVAL_YEAR_MONTH:
-          ((LongColumnVector) batch.cols[projectionColumnNum]).vector[batchIndex] =
-              deserializeRead.currentHiveIntervalYearMonthWritable.getHiveIntervalYearMonth().getTotalMonths();
-          break;
-        case INTERVAL_DAY_TIME:
-          ((IntervalDayTimeColumnVector) batch.cols[projectionColumnNum]).set(
-              batchIndex, deserializeRead.currentHiveIntervalDayTimeWritable.getHiveIntervalDayTime());
-          break;
-        default:
-          throw new RuntimeException("Primitive category " + sourcePrimitiveCategory.name() +
-              " not supported");
-        }
-      }
+      storePrimitiveRowColumn(colVector, field, batchIndex, canRetainByteRef);
+      break;
+    case LIST:
+      storeListRowColumn(colVector, field, batchIndex, canRetainByteRef);
+      break;
+    case MAP:
+      storeMapRowColumn(colVector, field, batchIndex, canRetainByteRef);
+      break;
+    case STRUCT:
+      storeStructRowColumn(colVector, field, batchIndex, canRetainByteRef);
+      break;
+    case UNION:
+      storeUnionRowColumn(colVector, field, batchIndex, canRetainByteRef);
       break;
     default:
-      throw new RuntimeException("Category " + sourceCategories[logicalColumnIndex] + " not supported");
+      throw new RuntimeException("Category " + field.getCategory() + " not supported");
     }
 
     // We always set the null flag to false when there is a value.
@@ -572,13 +914,13 @@ private void storeRowColumn(VectorizedRowBatch batch, int batchIndex,
    * @throws IOException
    */
   private void convertRowColumn(VectorizedRowBatch batch, int batchIndex,
-      int logicalColumnIndex) throws IOException {
-    final int projectionColumnNum = projectionColumnNums[logicalColumnIndex];
-    Writable convertSourceWritable = convertSourceWritables[logicalColumnIndex];
-    switch (sourceCategories[logicalColumnIndex]) {
+      Field field, int logicalColumnIndex) throws IOException {
+
+    Writable convertSourceWritable = field.getConversionWritable();
+    switch (field.getCategory()) {
     case PRIMITIVE:
       {
-        switch (sourcePrimitiveCategories[logicalColumnIndex]) {
+        switch (field.getPrimitiveCategory()) {
         case VOID:
           convertSourceWritable = null;
          break;
@@ -611,7 +953,9 @@ private void convertRowColumn(VectorizedRowBatch batch, int batchIndex,
          break;
        case BINARY:
          if (deserializeRead.currentBytes == null) {
-            LOG.info("null binary entry: batchIndex " + batchIndex + " projection column num " + projectionColumnNum);
+            LOG.info(
+                "null binary entry: batchIndex " + batchIndex + " projection column num " +
+                projectionColumnNums[logicalColumnIndex]);
          }
 
          ((BytesWritable) convertSourceWritable).set(
@@ -622,7 +966,8 @@ private void convertRowColumn(VectorizedRowBatch batch, int batchIndex,
        case STRING:
          if (deserializeRead.currentBytes == null) {
            throw new RuntimeException(
-                "null string entry: batchIndex " + batchIndex + " projection column num " + projectionColumnNum);
+                "null string entry: batchIndex " + batchIndex + " projection column num " +
+                projectionColumnNums[logicalColumnIndex]);
          }
 
          // Use org.apache.hadoop.io.Text as our helper to go from byte[] to String.
@@ -637,14 +982,15 @@ private void convertRowColumn(VectorizedRowBatch batch, int batchIndex,
          // that does not use Java String objects.
          if (deserializeRead.currentBytes == null) {
            throw new RuntimeException(
-                "null varchar entry: batchIndex " + batchIndex + " projection column num " + projectionColumnNum);
+                "null varchar entry: batchIndex " + batchIndex + " projection column num " +
+                projectionColumnNums[logicalColumnIndex]);
          }
 
          int adjustedLength = StringExpr.truncate(
              deserializeRead.currentBytes,
              deserializeRead.currentBytesStart,
              deserializeRead.currentBytesLength,
-              maxLengths[logicalColumnIndex]);
+              field.getMaxLength());
 
          ((HiveVarcharWritable) convertSourceWritable).set(
              new String(
@@ -661,14 +1007,15 @@ private void convertRowColumn(VectorizedRowBatch batch, int batchIndex,
          // that does not use Java String objects.
          if (deserializeRead.currentBytes == null) {
            throw new RuntimeException(
-                "null char entry: batchIndex " + batchIndex + " projection column num " + projectionColumnNum);
+                "null char entry: batchIndex " + batchIndex + " projection column num " +
+                projectionColumnNums[logicalColumnIndex]);
          }
 
          int adjustedLength = StringExpr.rightTrimAndTruncate(
              deserializeRead.currentBytes,
              deserializeRead.currentBytesStart,
              deserializeRead.currentBytesLength,
-              maxLengths[logicalColumnIndex]);
+              field.getMaxLength());
 
          ((HiveCharWritable) convertSourceWritable).set(
              new String(
@@ -691,13 +1038,26 @@ private void convertRowColumn(VectorizedRowBatch batch, int batchIndex,
              deserializeRead.currentHiveIntervalDayTimeWritable);
          break;
        default:
-          throw new RuntimeException("Primitive category " + sourcePrimitiveCategories[logicalColumnIndex] +
+          throw new RuntimeException("Primitive category " + field.getPrimitiveCategory() +
              " not supported");
        }
      }
      break;
+
+    case STRUCT:
+    case UNION:
+      // The only aspect of conversion to Struct / Union themselves is add fields as NULL on the end
+      // (no removal from end? which would mean skipping fields...)
+
+      // UNDONE
+      break;
+
+    case LIST:
+    case MAP:
+      // Conversion only happens below to List elements or Map key and/or values and not to the
+      // List or Map itself.
     default:
-      throw new RuntimeException("Category " + sourceCategories[logicalColumnIndex] + " not supported");
+      throw new RuntimeException("Category " + field.getCategory() + " not supported");
     }
 
     /*
@@ -739,7 +1099,10 @@ public void deserialize(VectorizedRowBatch batch, int batchIndex) throws IOExcep
     // Pass false for canRetainByteRef since we will NOT be keeping byte references to the input
     // bytes with the BytesColumnVector.setRef method.
-    final int count = isConvert.length;
+    final int count = topLevelFields.length;
+
+    Field field;
+
     if (!useReadField) {
       for (int i = 0; i < count; i++) {
         final int projectionColumnNum = projectionColumnNums[i];
@@ -755,10 +1118,11 @@ public void deserialize(VectorizedRowBatch batch, int batchIndex) throws IOExcep
          continue;
        }
        // The current* members of deserializeRead have the field value.
-        if (isConvert[i]) {
-          convertRowColumn(batch, batchIndex, i);
+        field = topLevelFields[i];
+        if (field.getIsConvert()) {
+          convertRowColumn(batch, batchIndex, field, i);
        } else {
-          storeRowColumn(batch, batchIndex, i, /* canRetainByteRef */ false);
+          storeRowColumn(batch, batchIndex, field, i, /* canRetainByteRef */ false);
        }
      }
    } else {
@@ -773,10 +1137,11 @@ public void deserialize(VectorizedRowBatch batch, int batchIndex) throws IOExcep
          continue;
        }
        // The current* members of deserializeRead have the field value.
-        if (isConvert[logicalIndex]) {
-          convertRowColumn(batch, batchIndex, logicalIndex);
+        field = topLevelFields[logicalIndex];
+        if (field.getIsConvert()) {
+          convertRowColumn(batch, batchIndex, field, logicalIndex);
        } else {
-          storeRowColumn(batch, batchIndex, logicalIndex, /* canRetainByteRef */ false);
+          storeRowColumn(batch, batchIndex, field, logicalIndex, /* canRetainByteRef */ false);
        }
      }
    }
@@ -803,7 +1168,11 @@ public void deserialize(VectorizedRowBatch batch, int batchIndex) throws IOExcep
    * @throws IOException
    */
   public void deserializeByRef(VectorizedRowBatch batch, int batchIndex) throws IOException {
-    final int count = isConvert.length;
+
+    final int count = topLevelFields.length;
+
+    Field field;
+
     if (!useReadField) {
       for (int i = 0; i < count; i++) {
         final int projectionColumnNum = projectionColumnNums[i];
@@ -819,10 +1188,11 @@ public void deserializeByRef(VectorizedRowBatch batch, int batchIndex) throws IO
          continue;
        }
        // The current* members of deserializeRead have the field value.
-        if (isConvert[i]) {
-          convertRowColumn(batch, batchIndex, i);
+        field = topLevelFields[i];
+        if (field.getIsConvert()) {
+          convertRowColumn(batch, batchIndex, field, i);
        } else {
-          storeRowColumn(batch, batchIndex, i, /* canRetainByteRef */ true);
+          storeRowColumn(batch, batchIndex, field, i, /* canRetainByteRef */ true);
        }
      }
    } else {
@@ -837,10 +1207,11 @@ public void deserializeByRef(VectorizedRowBatch batch, int batchIndex) throws IO
          continue;
        }
        // The current* members of deserializeRead have the field value.
-        if (isConvert[logicalIndex]) {
-          convertRowColumn(batch, batchIndex, logicalIndex);
+        field = topLevelFields[logicalIndex];
+        if (field.getIsConvert()) {
+          convertRowColumn(batch, batchIndex, field, logicalIndex);
        } else {
-          storeRowColumn(batch, batchIndex, logicalIndex, /* canRetainByteRef */ true);
+          storeRowColumn(batch, batchIndex, field, logicalIndex, /* canRetainByteRef */ true);
        }
      }
    }
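The Field tree built by allocateField above mirrors the nested TypeInfo exactly: one node per type, with a ComplexTypeHelper holding the element, key/value, or field children. The sketch below walks the same recursion using only the serde2 typeinfo API; the type string and the printing are illustrative stand-ins for Field construction, not code from the patch.

import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
import org.apache.hadoop.hive.serde2.typeinfo.UnionTypeInfo;

public class FieldTreeSketch {
  // One visit per TypeInfo node, recursing into complex children --
  // the same shape as allocateField / allocateComplexField.
  static void visit(TypeInfo typeInfo, String indent) {
    System.out.println(indent + typeInfo.getCategory() + ": " + typeInfo.getTypeName());
    switch (typeInfo.getCategory()) {
    case LIST:
      visit(((ListTypeInfo) typeInfo).getListElementTypeInfo(), indent + "  ");
      break;
    case MAP:
      visit(((MapTypeInfo) typeInfo).getMapKeyTypeInfo(), indent + "  ");
      visit(((MapTypeInfo) typeInfo).getMapValueTypeInfo(), indent + "  ");
      break;
    case STRUCT:
      for (TypeInfo fieldTypeInfo : ((StructTypeInfo) typeInfo).getAllStructFieldTypeInfos()) {
        visit(fieldTypeInfo, indent + "  ");
      }
      break;
    case UNION:
      for (TypeInfo fieldTypeInfo : ((UnionTypeInfo) typeInfo).getAllUnionObjectTypeInfos()) {
        visit(fieldTypeInfo, indent + "  ");
      }
      break;
    default:
      break;  // PRIMITIVE: leaf node.
    }
  }

  public static void main(String[] args) {
    visit(TypeInfoUtils.getTypeInfoFromTypeString(
        "map<string,array<struct<a:int,b:string>>>"), "");
  }
}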
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorExtractRow.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorExtractRow.java
index defaf9082f..34db5f1de0 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorExtractRow.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorExtractRow.java
@@ -18,8 +18,18 @@
 package org.apache.hadoop.hive.ql.exec.vector;
 
+import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
-
+import java.util.Map;
+
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.StandardStructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StandardUnionObjectInspector;
+import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.UnionTypeInfo;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.StringExpr;
@@ -73,28 +83,14 @@
   // Extraction can be a subset of columns, so this is the projection --
   // the batch column numbers.
 
-  Category[] categories;
-  // The data type category of each column being extracted.
-
-  PrimitiveCategory[] primitiveCategories;
-  // The data type primitive category of each column being assigned.
-
-  int[] maxLengths;
-  // For the CHAR and VARCHAR data types, the maximum character length of
-  // the columns. Otherwise, 0.
-
-  Writable[] primitiveWritables;
-  // The extracted values will be placed in these writables.
+  TypeInfo[] typeInfos;
 
   /*
    * Allocate the various arrays.
    */
   private void allocateArrays(int count) {
     projectionColumnNums = new int[count];
-    categories = new Category[count];
-    primitiveCategories = new PrimitiveCategory[count];
-    maxLengths = new int[count];
-    primitiveWritables = new Writable[count];
+    typeInfos = new TypeInfo[count];
   }
 
   /*
@@ -102,28 +98,7 @@ private void allocateArrays(int count) {
    */
   private void initEntry(int logicalColumnIndex, int projectionColumnNum, TypeInfo typeInfo) {
     projectionColumnNums[logicalColumnIndex] = projectionColumnNum;
-    Category category = typeInfo.getCategory();
-    categories[logicalColumnIndex] = category;
-    if (category == Category.PRIMITIVE) {
-      PrimitiveTypeInfo primitiveTypeInfo = (PrimitiveTypeInfo) typeInfo;
-      PrimitiveCategory primitiveCategory = primitiveTypeInfo.getPrimitiveCategory();
-      primitiveCategories[logicalColumnIndex] = primitiveCategory;
-
-      switch (primitiveCategory) {
-      case CHAR:
-        maxLengths[logicalColumnIndex] = ((CharTypeInfo) primitiveTypeInfo).getLength();
-        break;
-      case VARCHAR:
-        maxLengths[logicalColumnIndex] = ((VarcharTypeInfo) primitiveTypeInfo).getLength();
-        break;
-      default:
-        // No additional data type specific setting.
-        break;
-      }
-
-      primitiveWritables[logicalColumnIndex] =
-          VectorizedBatchUtil.getPrimitiveWritable(primitiveCategory);
-    }
+    typeInfos[logicalColumnIndex] = typeInfo;
   }
 
   /*
@@ -186,6 +161,16 @@ public void init(List typeNames) throws HiveException {
     }
   }
 
+  public void init(TypeInfo[] typeInfos) throws HiveException {
+
+    final int count = typeInfos.length;
+    allocateArrays(count);
+
+    for (int i = 0; i < count; i++) {
+      initEntry(i, i, typeInfos[i]);
+    }
+  }
+
   public int getCount() {
     return projectionColumnNums.length;
   }
@@ -201,6 +186,12 @@ public int getCount() {
   public Object extractRowColumn(VectorizedRowBatch batch, int batchIndex, int logicalColumnIndex) {
     final int projectionColumnNum = projectionColumnNums[logicalColumnIndex];
     ColumnVector colVector = batch.cols[projectionColumnNum];
+    return extractRowColumn(colVector, typeInfos[logicalColumnIndex], batchIndex);
+  }
+
+  public Object extractRowColumn(
+      ColumnVector colVector, TypeInfo typeInfo, int batchIndex) {
+
     if (colVector == null) {
       // The planner will not include unneeded columns for reading but other parts of execution
       // may ask for them..
@@ -211,63 +202,64 @@ public Object extractRowColumn(VectorizedRowBatch batch, int batchIndex, int log
       return null;
     }
 
-    Category category = categories[logicalColumnIndex];
+    Category category = typeInfo.getCategory();
     switch (category) {
     case PRIMITIVE:
       {
+        PrimitiveTypeInfo primitiveTypeInfo = (PrimitiveTypeInfo) typeInfo;
+        PrimitiveCategory primitiveCategory = primitiveTypeInfo.getPrimitiveCategory();
         Writable primitiveWritable =
-            primitiveWritables[logicalColumnIndex];
-        PrimitiveCategory primitiveCategory = primitiveCategories[logicalColumnIndex];
+            VectorizedBatchUtil.getPrimitiveWritable(primitiveCategory);
         switch (primitiveCategory) {
        case VOID:
          return null;
        case BOOLEAN:
          ((BooleanWritable) primitiveWritable).set(
-              ((LongColumnVector) batch.cols[projectionColumnNum]).vector[adjustedIndex] == 0 ?
+              ((LongColumnVector) colVector).vector[adjustedIndex] == 0 ?
                  false : true);
          return primitiveWritable;
        case BYTE:
          ((ByteWritable) primitiveWritable).set(
-              (byte) ((LongColumnVector) batch.cols[projectionColumnNum]).vector[adjustedIndex]);
+              (byte) ((LongColumnVector) colVector).vector[adjustedIndex]);
          return primitiveWritable;
        case SHORT:
          ((ShortWritable) primitiveWritable).set(
-              (short) ((LongColumnVector) batch.cols[projectionColumnNum]).vector[adjustedIndex]);
+              (short) ((LongColumnVector) colVector).vector[adjustedIndex]);
          return primitiveWritable;
        case INT:
          ((IntWritable) primitiveWritable).set(
-              (int) ((LongColumnVector) batch.cols[projectionColumnNum]).vector[adjustedIndex]);
+              (int) ((LongColumnVector) colVector).vector[adjustedIndex]);
          return primitiveWritable;
        case LONG:
          ((LongWritable) primitiveWritable).set(
-              ((LongColumnVector) batch.cols[projectionColumnNum]).vector[adjustedIndex]);
+              ((LongColumnVector) colVector).vector[adjustedIndex]);
          return primitiveWritable;
        case TIMESTAMP:
          ((TimestampWritable) primitiveWritable).set(
-              ((TimestampColumnVector) batch.cols[projectionColumnNum]).asScratchTimestamp(adjustedIndex));
+              ((TimestampColumnVector) colVector).asScratchTimestamp(adjustedIndex));
          return primitiveWritable;
        case DATE:
          ((DateWritable) primitiveWritable).set(
-              (int) ((LongColumnVector) batch.cols[projectionColumnNum]).vector[adjustedIndex]);
+              (int) ((LongColumnVector) colVector).vector[adjustedIndex]);
          return primitiveWritable;
        case FLOAT:
          ((FloatWritable) primitiveWritable).set(
-              (float) ((DoubleColumnVector) batch.cols[projectionColumnNum]).vector[adjustedIndex]);
+              (float) ((DoubleColumnVector) colVector).vector[adjustedIndex]);
          return primitiveWritable;
        case DOUBLE:
          ((DoubleWritable) primitiveWritable).set(
-              ((DoubleColumnVector) batch.cols[projectionColumnNum]).vector[adjustedIndex]);
+              ((DoubleColumnVector) colVector).vector[adjustedIndex]);
          return primitiveWritable;
        case BINARY:
          {
            BytesColumnVector bytesColVector =
-                ((BytesColumnVector) batch.cols[projectionColumnNum]);
+                ((BytesColumnVector) colVector);
            byte[] bytes = bytesColVector.vector[adjustedIndex];
            int start = bytesColVector.start[adjustedIndex];
            int length = bytesColVector.length[adjustedIndex];
 
            if (bytes == null) {
-              LOG.info("null binary entry: batchIndex " + batchIndex + " projection column num " + projectionColumnNum);
+              LOG.info("null binary entry: batchIndex " + batchIndex);
            }
 
            BytesWritable bytesWritable = (BytesWritable) primitiveWritable;
@@ -277,13 +269,13 @@ public Object extractRowColumn(VectorizedRowBatch batch, int batchIndex, int log
        case STRING:
          {
            BytesColumnVector bytesColVector =
-                ((BytesColumnVector) batch.cols[projectionColumnNum]);
+                ((BytesColumnVector) colVector);
            byte[] bytes = bytesColVector.vector[adjustedIndex];
            int start = bytesColVector.start[adjustedIndex];
            int length = bytesColVector.length[adjustedIndex];
 
            if (bytes == null) {
-              nullBytesReadError(primitiveCategory, batchIndex, projectionColumnNum);
+              nullBytesReadError(primitiveCategory, batchIndex);
            }
 
            // Use org.apache.hadoop.io.Text as our helper to go from byte[] to String.
@@ -293,17 +285,17 @@ public Object extractRowColumn(VectorizedRowBatch batch, int batchIndex, int log case VARCHAR: { BytesColumnVector bytesColVector = - ((BytesColumnVector) batch.cols[projectionColumnNum]); + ((BytesColumnVector) colVector); byte[] bytes = bytesColVector.vector[adjustedIndex]; int start = bytesColVector.start[adjustedIndex]; int length = bytesColVector.length[adjustedIndex]; if (bytes == null) { - nullBytesReadError(primitiveCategory, batchIndex, projectionColumnNum); + nullBytesReadError(primitiveCategory, batchIndex); } int adjustedLength = StringExpr.truncate(bytes, start, length, - maxLengths[logicalColumnIndex]); + ((VarcharTypeInfo) primitiveTypeInfo).getLength()); HiveVarcharWritable hiveVarcharWritable = (HiveVarcharWritable) primitiveWritable; hiveVarcharWritable.set(new String(bytes, start, adjustedLength, Charsets.UTF_8), -1); @@ -312,41 +304,117 @@ public Object extractRowColumn(VectorizedRowBatch batch, int batchIndex, int log case CHAR: { BytesColumnVector bytesColVector = - ((BytesColumnVector) batch.cols[projectionColumnNum]); + ((BytesColumnVector) colVector); byte[] bytes = bytesColVector.vector[adjustedIndex]; int start = bytesColVector.start[adjustedIndex]; int length = bytesColVector.length[adjustedIndex]; if (bytes == null) { - nullBytesReadError(primitiveCategory, batchIndex, projectionColumnNum); + nullBytesReadError(primitiveCategory, batchIndex); } int adjustedLength = StringExpr.rightTrimAndTruncate(bytes, start, length, - maxLengths[logicalColumnIndex]); + ((CharTypeInfo) primitiveTypeInfo).getLength()); HiveCharWritable hiveCharWritable = (HiveCharWritable) primitiveWritable; hiveCharWritable.set(new String(bytes, start, adjustedLength, Charsets.UTF_8), - maxLengths[logicalColumnIndex]); + ((CharTypeInfo) primitiveTypeInfo).getLength()); return primitiveWritable; } case DECIMAL: // The HiveDecimalWritable set method will quickly copy the deserialized decimal writable fields. 
((HiveDecimalWritable) primitiveWritable).set( - ((DecimalColumnVector) batch.cols[projectionColumnNum]).vector[adjustedIndex]); + ((DecimalColumnVector) colVector).vector[adjustedIndex]); return primitiveWritable; case INTERVAL_YEAR_MONTH: ((HiveIntervalYearMonthWritable) primitiveWritable).set( - (int) ((LongColumnVector) batch.cols[projectionColumnNum]).vector[adjustedIndex]); + (int) ((LongColumnVector) colVector).vector[adjustedIndex]); return primitiveWritable; case INTERVAL_DAY_TIME: ((HiveIntervalDayTimeWritable) primitiveWritable).set( - ((IntervalDayTimeColumnVector) batch.cols[projectionColumnNum]).asScratchIntervalDayTime(adjustedIndex)); + ((IntervalDayTimeColumnVector) colVector).asScratchIntervalDayTime(adjustedIndex)); return primitiveWritable; default: throw new RuntimeException("Primitive category " + primitiveCategory.name() + " not supported"); } } + case LIST: + { + ListColumnVector listColumnVector = (ListColumnVector) colVector; + ListTypeInfo listTypeInfo = (ListTypeInfo) typeInfo; + int offset = (int) listColumnVector.offsets[adjustedIndex]; + int size = (int) listColumnVector.lengths[adjustedIndex]; + + List list = new ArrayList(); + for (int i = 0; i < size; i++) { + list.add( + extractRowColumn(listColumnVector.child, + listTypeInfo.getListElementTypeInfo(), + offset + i)); + } + return list; + } + case MAP: + { + MapColumnVector mapColumnVector = (MapColumnVector) colVector; + MapTypeInfo mapTypeInfo = (MapTypeInfo) typeInfo; + int offset = (int) mapColumnVector.offsets[adjustedIndex]; + int size = (int) mapColumnVector.lengths[adjustedIndex]; + + Map map = new HashMap(); + for (int i = 0; i < size; i++) { + Object key = extractRowColumn(mapColumnVector.keys, + mapTypeInfo.getMapKeyTypeInfo(), offset + i); + Object value = extractRowColumn(mapColumnVector.values, + mapTypeInfo.getMapValueTypeInfo(), offset + i); + map.put(key, value); + } + return map; + } + case STRUCT: + { + StructColumnVector structColumnVector = (StructColumnVector) colVector; + StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo; + List fieldNames = structTypeInfo.getAllStructFieldNames(); + List fieldTypeInfos = structTypeInfo.getAllStructFieldTypeInfos(); + int size = fieldTypeInfos.size(); + List objectInspectors = new ArrayList<>(); + for (int i = 0; i < size; i++) { + TypeInfo fieldTypeInfo = fieldTypeInfos.get(i); + objectInspectors.add( + ObjectInspectorFactory.getReflectionObjectInspector(fieldTypeInfo.getClass(), + ObjectInspectorFactory.ObjectInspectorOptions.JAVA)); + } + StandardStructObjectInspector standardStructObjectInspector = + ObjectInspectorFactory.getStandardStructObjectInspector( + fieldNames, objectInspectors); + Object struct = standardStructObjectInspector.create(); + List structFields = + standardStructObjectInspector.getAllStructFieldRefs(); + for (int i = 0; i < size; i++) { + StructField structField = structFields.get(i); + TypeInfo fieldTypeInfo = fieldTypeInfos.get(i); + Object value = extractRowColumn(structColumnVector.fields[i], + fieldTypeInfo, adjustedIndex); + standardStructObjectInspector.setStructFieldData(struct, structField, value); + } + return struct; + } + case UNION: + { + UnionTypeInfo unionTypeInfo = (UnionTypeInfo) typeInfo; + List objectTypeInfos = unionTypeInfo.getAllUnionObjectTypeInfos(); + StandardUnionObjectInspector.StandardUnion standardUnion = + new StandardUnionObjectInspector.StandardUnion(); + UnionColumnVector unionColumnVector = (UnionColumnVector) colVector; + byte tag = (byte) unionColumnVector.tags[adjustedIndex]; + 
Object object = extractRowColumn( + unionColumnVector.fields[tag], objectTypeInfos.get(tag), adjustedIndex); + standardUnion.setTag(tag); + standardUnion.setObject(object); + return standardUnion; + } default: throw new RuntimeException("Category " + category.name() + " not supported"); } @@ -365,9 +433,8 @@ public void extractRow(VectorizedRowBatch batch, int batchIndex, Object[] object } } - private void nullBytesReadError(PrimitiveCategory primitiveCategory, int batchIndex, - int projectionColumnNum) { + private void nullBytesReadError(PrimitiveCategory primitiveCategory, int batchIndex) { throw new RuntimeException("null " + primitiveCategory.name() + - " entry: batchIndex " + batchIndex + " projection column num " + projectionColumnNum); + " entry: batchIndex " + batchIndex); } } diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorSerializeRow.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorSerializeRow.java index 319b4a8a42..197ddf922c 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorSerializeRow.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorSerializeRow.java @@ -19,20 +19,22 @@ package org.apache.hadoop.hive.ql.exec.vector; import java.io.IOException; -import java.sql.Timestamp; import java.util.Arrays; import java.util.List; +import java.util.Map; -import org.apache.hadoop.hive.common.type.HiveIntervalDayTime; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.serde2.ByteStream.Output; -import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; import org.apache.hadoop.hive.serde2.fast.SerializeWrite; +import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils; +import org.apache.hadoop.hive.serde2.typeinfo.UnionTypeInfo; /** * This class serializes columns from a row in a VectorizedRowBatch into a serialization format. @@ -50,14 +52,16 @@ private T serializeWrite; - private Category[] categories; - private PrimitiveCategory[] primitiveCategories; + private TypeInfo[] typeInfos; private int[] outputColumnNums; + private VectorExtractRow vectorExtractRow; + public VectorSerializeRow(T serializeWrite) { this(); this.serializeWrite = serializeWrite; + vectorExtractRow = new VectorExtractRow(); } // Not public since we must have the serialize write object. 
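Both the extraction above and the serialization below walk complex values through the same layout: the entry for row r occupies positions [offsets[r], offsets[r] + lengths[r]) of the child vector. A minimal sketch of that indexing, using an assumed child-reader callback rather than the patch's own recursion:

    import java.util.ArrayList;
    import java.util.List;
    import java.util.function.IntFunction;

    // Read one LIST entry: element i of the list at rowIndex lives at child
    // position offsets[rowIndex] + i; lengths[rowIndex] is the element count.
    static List<Object> readListEntry(long[] offsets, long[] lengths, int rowIndex,
        IntFunction<Object> childReader) {
      int offset = (int) offsets[rowIndex];
      int size = (int) lengths[rowIndex];
      List<Object> result = new ArrayList<>(size);
      for (int i = 0; i < size; i++) {
        result.add(childReader.apply(offset + i));
      }
      return result;
    }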
@@ -67,55 +71,44 @@ private VectorSerializeRow() { public void init(List typeNames, int[] columnMap) throws HiveException { final int size = typeNames.size(); - categories = new Category[size]; - primitiveCategories = new PrimitiveCategory[size]; + typeInfos = new TypeInfo[size]; outputColumnNums = Arrays.copyOf(columnMap, size); TypeInfo typeInfo; for (int i = 0; i < size; i++) { typeInfo = TypeInfoUtils.getTypeInfoFromTypeString(typeNames.get(i)); - categories[i] = typeInfo.getCategory(); - if (categories[i] == Category.PRIMITIVE) { - primitiveCategories[i] = ((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory(); - } + typeInfos[i] = typeInfo; } + + vectorExtractRow.init(typeInfos, outputColumnNums); } public void init(List typeNames) throws HiveException { final int size = typeNames.size(); - categories = new Category[size]; - primitiveCategories = new PrimitiveCategory[size]; + typeInfos = new TypeInfo[size]; outputColumnNums = new int[size]; TypeInfo typeInfo; for (int i = 0; i < size; i++) { typeInfo = TypeInfoUtils.getTypeInfoFromTypeString(typeNames.get(i)); - categories[i] = typeInfo.getCategory(); - if (categories[i] == Category.PRIMITIVE) { - primitiveCategories[i] = ((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory(); - } + typeInfos[i] = typeInfo; outputColumnNums[i] = i; } + + vectorExtractRow.init(typeInfos); } public void init(TypeInfo[] typeInfos, int[] columnMap) throws HiveException { final int size = typeInfos.length; - categories = new Category[size]; - primitiveCategories = new PrimitiveCategory[size]; + this.typeInfos = Arrays.copyOf(typeInfos, size); outputColumnNums = Arrays.copyOf(columnMap, size); - TypeInfo typeInfo; - for (int i = 0; i < typeInfos.length; i++) { - typeInfo = typeInfos[i]; - categories[i] = typeInfo.getCategory(); - if (categories[i] == Category.PRIMITIVE) { - primitiveCategories[i] = ((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory(); - } - } + + vectorExtractRow.init(this.typeInfos, outputColumnNums); } public int getCount() { - return categories.length; + return typeInfos.length; } public void setOutput(Output output) { @@ -138,91 +131,188 @@ public void serializeWrite(VectorizedRowBatch batch, int batchIndex) throws IOEx hasAnyNulls = false; isAllNulls = true; ColumnVector colVector; + for (int i = 0; i < typeInfos.length; i++) { + colVector = batch.cols[outputColumnNums[i]]; + serializeWrite(colVector, typeInfos[i], batchIndex); + } + } + + private void serializeWrite( + ColumnVector colVector, TypeInfo typeInfo, int batchIndex) throws IOException { int adjustedBatchIndex; - final int size = categories.length; + if (colVector.isRepeating) { + adjustedBatchIndex = 0; + } else { + adjustedBatchIndex = batchIndex; + } + if (!colVector.noNulls && colVector.isNull[adjustedBatchIndex]) { + serializeWrite.writeNull(); + hasAnyNulls = true; + return; + } + isAllNulls = false; + Category category = typeInfo.getCategory(); + switch (category) { + case PRIMITIVE: + serializePrimitiveWrite(colVector, (PrimitiveTypeInfo) typeInfo, adjustedBatchIndex); + break; + case LIST: + serializeListWrite((ListColumnVector) colVector, (ListTypeInfo) typeInfo, adjustedBatchIndex); + break; + case MAP: + serializeMapWrite((MapColumnVector) colVector, (MapTypeInfo) typeInfo, adjustedBatchIndex); + break; + case STRUCT: + serializeStructWrite((StructColumnVector) colVector, (StructTypeInfo) typeInfo, adjustedBatchIndex); + break; + case UNION: + serializeUnionWrite((UnionColumnVector) colVector, (UnionTypeInfo) typeInfo, adjustedBatchIndex); + break; + default: + 
throw new RuntimeException("Unexpected category " + category); + } + } + + private void serializeUnionWrite( + UnionColumnVector colVector, UnionTypeInfo typeInfo, int adjustedBatchIndex) throws IOException { + byte tag = (byte) colVector.tags[adjustedBatchIndex]; + ColumnVector fieldColumnVector = colVector.fields[tag]; + TypeInfo objectTypeInfo = typeInfo.getAllUnionObjectTypeInfos().get(tag); + + serializeWrite.beginUnion(tag); + serializeWrite(fieldColumnVector, objectTypeInfo, adjustedBatchIndex); + serializeWrite.finishUnion(); + } + + private void serializeStructWrite( + StructColumnVector colVector, StructTypeInfo typeInfo, int adjustedBatchIndex) throws IOException { + ColumnVector[] fieldColumnVectors = colVector.fields; + List fieldTypeInfos = typeInfo.getAllStructFieldTypeInfos(); + int size = fieldTypeInfos.size(); + + List list = (List) vectorExtractRow.extractRowColumn( + colVector, typeInfo, adjustedBatchIndex); + + serializeWrite.beginStruct(list); for (int i = 0; i < size; i++) { - colVector = batch.cols[outputColumnNums[i]]; - if (colVector.isRepeating) { - adjustedBatchIndex = 0; - } else { - adjustedBatchIndex = batchIndex; + if (i > 0) { + serializeWrite.separateStruct(); + } + serializeWrite(fieldColumnVectors[i], fieldTypeInfos.get(i), adjustedBatchIndex); + } + serializeWrite.finishStruct(); + } + + private void serializeMapWrite( + MapColumnVector colVector, MapTypeInfo typeInfo, int adjustedBatchIndex) throws IOException { + ColumnVector keyColumnVector = colVector.keys; + ColumnVector valueColumnVector = colVector.values; + TypeInfo keyTypeInfo = typeInfo.getMapKeyTypeInfo(); + TypeInfo valueTypeInfo = typeInfo.getMapValueTypeInfo(); + int offset = (int) colVector.offsets[adjustedBatchIndex]; + int size = (int) colVector.lengths[adjustedBatchIndex]; + + Map map = (Map) vectorExtractRow.extractRowColumn( + colVector, typeInfo, adjustedBatchIndex); + + serializeWrite.beginMap(map); + for (int i = 0; i < size; i++) { + if (i > 0) { + serializeWrite.separateKeyValuePair(); + } + serializeWrite(keyColumnVector, keyTypeInfo, offset + i); + serializeWrite.separateKey(); + serializeWrite(valueColumnVector, valueTypeInfo, offset + i); + } + serializeWrite.finishMap(); + } + + private void serializeListWrite( + ListColumnVector colVector, ListTypeInfo typeInfo, int adjustedBatchIndex) throws IOException { + ColumnVector childColumnVector = colVector.child; + TypeInfo elementTypeInfo = typeInfo.getListElementTypeInfo(); + int offset = (int) colVector.offsets[adjustedBatchIndex]; + int size = (int) colVector.lengths[adjustedBatchIndex]; + + List list = (List) vectorExtractRow.extractRowColumn( + colVector, typeInfo, adjustedBatchIndex); + + serializeWrite.beginList(list); + for (int i = 0; i < size; i++) { + if (i > 0) { + serializeWrite.separateList(); + } + serializeWrite(childColumnVector, elementTypeInfo, offset + i); + } + serializeWrite.finishList(); + } + + private void serializePrimitiveWrite( + ColumnVector colVector, PrimitiveTypeInfo typeInfo, int adjustedBatchIndex) throws IOException { + PrimitiveCategory primitiveCategory = typeInfo.getPrimitiveCategory(); + switch (primitiveCategory) { + case BOOLEAN: + serializeWrite.writeBoolean(((LongColumnVector) colVector).vector[adjustedBatchIndex] != 0); + break; + case BYTE: + serializeWrite.writeByte((byte) ((LongColumnVector) colVector).vector[adjustedBatchIndex]); + break; + case SHORT: + serializeWrite.writeShort((short) ((LongColumnVector) colVector).vector[adjustedBatchIndex]); + break; + case INT: + 
serializeWrite.writeInt((int) ((LongColumnVector) colVector).vector[adjustedBatchIndex]); + break; + case LONG: + serializeWrite.writeLong(((LongColumnVector) colVector).vector[adjustedBatchIndex]); + break; + case DATE: + serializeWrite.writeDate((int) ((LongColumnVector) colVector).vector[adjustedBatchIndex]); + break; + case TIMESTAMP: + serializeWrite.writeTimestamp(((TimestampColumnVector) colVector).asScratchTimestamp(adjustedBatchIndex)); + break; + case FLOAT: + serializeWrite.writeFloat((float) ((DoubleColumnVector) colVector).vector[adjustedBatchIndex]); + break; + case DOUBLE: + serializeWrite.writeDouble(((DoubleColumnVector) colVector).vector[adjustedBatchIndex]); + break; + case STRING: + case CHAR: + case VARCHAR: + { + // We store CHAR and VARCHAR without pads, so write with STRING. + BytesColumnVector bytesColVector = (BytesColumnVector) colVector; + serializeWrite.writeString( + bytesColVector.vector[adjustedBatchIndex], + bytesColVector.start[adjustedBatchIndex], + bytesColVector.length[adjustedBatchIndex]); } - if (!colVector.noNulls && colVector.isNull[adjustedBatchIndex]) { - serializeWrite.writeNull(); - hasAnyNulls = true; - continue; + break; + case BINARY: + { + BytesColumnVector bytesColVector = (BytesColumnVector) colVector; + serializeWrite.writeBinary( + bytesColVector.vector[adjustedBatchIndex], + bytesColVector.start[adjustedBatchIndex], + bytesColVector.length[adjustedBatchIndex]); } - isAllNulls = false; - switch (categories[i]) { - case PRIMITIVE: - switch (primitiveCategories[i]) { - case BOOLEAN: - serializeWrite.writeBoolean(((LongColumnVector) colVector).vector[adjustedBatchIndex] != 0); - break; - case BYTE: - serializeWrite.writeByte((byte) ((LongColumnVector) colVector).vector[adjustedBatchIndex]); - break; - case SHORT: - serializeWrite.writeShort((short) ((LongColumnVector) colVector).vector[adjustedBatchIndex]); - break; - case INT: - serializeWrite.writeInt((int) ((LongColumnVector) colVector).vector[adjustedBatchIndex]); - break; - case LONG: - serializeWrite.writeLong(((LongColumnVector) colVector).vector[adjustedBatchIndex]); - break; - case DATE: - serializeWrite.writeDate((int) ((LongColumnVector) colVector).vector[adjustedBatchIndex]); - break; - case TIMESTAMP: - serializeWrite.writeTimestamp(((TimestampColumnVector) colVector).asScratchTimestamp(adjustedBatchIndex)); - break; - case FLOAT: - serializeWrite.writeFloat((float) ((DoubleColumnVector) colVector).vector[adjustedBatchIndex]); - break; - case DOUBLE: - serializeWrite.writeDouble(((DoubleColumnVector) colVector).vector[adjustedBatchIndex]); - break; - case STRING: - case CHAR: - case VARCHAR: - { - // We store CHAR and VARCHAR without pads, so write with STRING. 
- BytesColumnVector bytesColVector = (BytesColumnVector) colVector; - serializeWrite.writeString( - bytesColVector.vector[adjustedBatchIndex], - bytesColVector.start[adjustedBatchIndex], - bytesColVector.length[adjustedBatchIndex]); - } - break; - case BINARY: - { - BytesColumnVector bytesColVector = (BytesColumnVector) colVector; - serializeWrite.writeBinary( - bytesColVector.vector[adjustedBatchIndex], - bytesColVector.start[adjustedBatchIndex], - bytesColVector.length[adjustedBatchIndex]); - } - break; - case DECIMAL: - { - DecimalColumnVector decimalColVector = (DecimalColumnVector) colVector; - serializeWrite.writeHiveDecimal(decimalColVector.vector[adjustedBatchIndex], decimalColVector.scale); - } - break; - case INTERVAL_YEAR_MONTH: - serializeWrite.writeHiveIntervalYearMonth((int) ((LongColumnVector) colVector).vector[adjustedBatchIndex]); - break; - case INTERVAL_DAY_TIME: - serializeWrite.writeHiveIntervalDayTime(((IntervalDayTimeColumnVector) colVector).asScratchIntervalDayTime(adjustedBatchIndex)); - break; - default: - throw new RuntimeException("Unexpected primitive category " + primitiveCategories[i]); - } - break; - default: - throw new RuntimeException("Unexpected category " + categories[i]); + break; + case DECIMAL: + { + DecimalColumnVector decimalColVector = (DecimalColumnVector) colVector; + serializeWrite.writeHiveDecimal(decimalColVector.vector[adjustedBatchIndex], decimalColVector.scale); } + break; + case INTERVAL_YEAR_MONTH: + serializeWrite.writeHiveIntervalYearMonth((int) ((LongColumnVector) colVector).vector[adjustedBatchIndex]); + break; + case INTERVAL_DAY_TIME: + serializeWrite.writeHiveIntervalDayTime(((IntervalDayTimeColumnVector) colVector).asScratchIntervalDayTime(adjustedBatchIndex)); + break; + default: + throw new RuntimeException("Unexpected primitive category " + primitiveCategory); } } diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorRowObject.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorRowObject.java index e9ce8e8d6e..a8748d0426 100644 --- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorRowObject.java +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorRowObject.java @@ -39,7 +39,13 @@ void examineBatch(VectorizedRowBatch batch, VectorExtractRow vectorExtractRow, vectorExtractRow.extractRow(batch, i, row); Object[] expectedRow = randomRows[firstRandomRowIndex + i]; for (int c = 0; c < rowSize; c++) { - if (!row[c].equals(expectedRow[c])) { + Object actualValue = row[c]; + Object expectedValue = expectedRow[c]; + if (actualValue == null || expectedValue == null) { + if (actualValue != expectedValue) { + fail("Row " + (firstRandomRowIndex + i) + " and column " + c + " mismatch"); + } + } else if (!actualValue.equals(expectedValue)) { fail("Row " + (firstRandomRowIndex + i) + " and column " + c + " mismatch"); } } @@ -51,7 +57,8 @@ void testVectorRowObject(int caseNum, boolean sort, Random r) throws HiveExcepti String[] emptyScratchTypeNames = new String[0]; VectorRandomRowSource source = new VectorRandomRowSource(); - source.init(r); + + source.init(r, VectorRandomRowSource.SupportedTypes.ALL, 4); VectorizedRowBatchCtx batchContext = new VectorizedRowBatchCtx(); batchContext.init(source.rowStructObjectInspector(), emptyScratchTypeNames); @@ -69,7 +76,7 @@ void testVectorRowObject(int caseNum, boolean sort, Random r) throws HiveExcepti VectorExtractRow vectorExtractRow = new VectorExtractRow(); vectorExtractRow.init(source.typeNames()); - Object[][] randomRows = 
source.randomRows(10000); + Object[][] randomRows = source.randomRows(1000); if (sort) { source.sort(randomRows); } diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorSerDeRow.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorSerDeRow.java index 822fff279a..c08965d7dd 100644 --- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorSerDeRow.java +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorSerDeRow.java @@ -19,31 +19,14 @@ package org.apache.hadoop.hive.ql.exec.vector; import java.io.IOException; -import java.sql.Date; -import java.sql.Timestamp; +import java.util.ArrayList; import java.util.Arrays; import java.util.Properties; import java.util.Random; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.hive.serde.serdeConstants; -import org.apache.hadoop.hive.serde2.OpenCSVSerde; import org.apache.hadoop.hive.serde2.SerDeException; -import org.apache.hadoop.hive.serde2.io.ByteWritable; -import org.apache.hadoop.hive.serde2.io.DateWritable; -import org.apache.hadoop.hive.serde2.io.DoubleWritable; -import org.apache.hadoop.hive.serde2.io.HiveCharWritable; -import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; -import org.apache.hadoop.hive.serde2.io.HiveIntervalDayTimeWritable; -import org.apache.hadoop.hive.serde2.io.HiveIntervalYearMonthWritable; -import org.apache.hadoop.hive.serde2.io.HiveVarcharWritable; -import org.apache.hadoop.hive.serde2.io.ShortWritable; -import org.apache.hadoop.hive.serde2.io.TimestampWritable; -import org.apache.hadoop.hive.common.type.HiveChar; -import org.apache.hadoop.hive.common.type.HiveDecimal; -import org.apache.hadoop.hive.common.type.HiveIntervalDayTime; -import org.apache.hadoop.hive.common.type.HiveIntervalYearMonth; -import org.apache.hadoop.hive.common.type.HiveVarchar; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.serde2.ByteStream.Output; import org.apache.hadoop.hive.serde2.binarysortable.BinarySortableSerDe; @@ -52,26 +35,18 @@ import org.apache.hadoop.hive.serde2.fast.DeserializeRead; import org.apache.hadoop.hive.serde2.lazy.LazySerDeParameters; import org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe; +import org.apache.hadoop.hive.serde2.lazy.VerifyLazy; import org.apache.hadoop.hive.serde2.lazy.fast.LazySimpleDeserializeRead; import org.apache.hadoop.hive.serde2.lazy.fast.LazySimpleSerializeWrite; +import org.apache.hadoop.hive.serde2.lazy.fast.StringToDouble; import org.apache.hadoop.hive.serde2.lazybinary.fast.LazyBinaryDeserializeRead; import org.apache.hadoop.hive.serde2.lazybinary.fast.LazyBinarySerializeWrite; -import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; -import org.apache.hadoop.hive.serde2.typeinfo.CharTypeInfo; -import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; -import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; -import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo; +import org.apache.hadoop.hive.serde2.objectinspector.UnionObject; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; import org.apache.hadoop.hive.serde2.fast.SerializeWrite; -import org.apache.hadoop.io.BooleanWritable; -import org.apache.hadoop.io.BytesWritable; -import org.apache.hadoop.io.FloatWritable; -import 
org.apache.hadoop.io.IntWritable;
-import org.apache.hadoop.io.LongWritable;
-import org.apache.hadoop.io.Text;
-
-import com.google.common.base.Charsets;
 import junit.framework.TestCase;
@@ -87,209 +62,59 @@
     LAZY_SIMPLE
   }

-  void deserializeAndVerify(Output output, DeserializeRead deserializeRead,
-      VectorRandomRowSource source, Object[] expectedRow)
-      throws HiveException, IOException {
-    deserializeRead.set(output.getData(), 0, output.getLength());
-    PrimitiveCategory[] primitiveCategories = source.primitiveCategories();
-    for (int i = 0; i < primitiveCategories.length; i++) {
-      Object expected = expectedRow[i];
-      PrimitiveCategory primitiveCategory = primitiveCategories[i];
-      PrimitiveTypeInfo primitiveTypeInfo = source.primitiveTypeInfos()[i];
-      if (!deserializeRead.readNextField()) {
-        throw new HiveException("Unexpected NULL when reading primitiveCategory " + primitiveCategory +
-            " expected (" + expected.getClass().getName() + ", " + expected.toString() + ") " +
-            " deserializeRead " + deserializeRead.getClass().getName());
-      }
+  private void verifyRead(
+      DeserializeRead deserializeRead, TypeInfo typeInfo, Object expectedObject) throws IOException {
+
+    if (typeInfo.getCategory() == ObjectInspector.Category.PRIMITIVE) {
+      VectorVerifyFast.verifyDeserializeRead(deserializeRead, typeInfo, expectedObject);
+    } else {
-      switch (primitiveCategory) {
-      case BOOLEAN:
-        {
-          Boolean value = deserializeRead.currentBoolean;
-          BooleanWritable expectedWritable = (BooleanWritable) expected;
-          if (!value.equals(expectedWritable.get())) {
-            TestCase.fail("Boolean field mismatch (expected " + expected + " found " + value + ")");
-          }
-        }
-        break;
-      case BYTE:
-        {
-          Byte value = deserializeRead.currentByte;
-          ByteWritable expectedWritable = (ByteWritable) expected;
-          if (!value.equals(expectedWritable.get())) {
-            TestCase.fail("Byte field mismatch (expected " + (int) expected + " found " + (int) value + ")");
-          }
-        }
-        break;
-      case SHORT:
-        {
-          Short value = deserializeRead.currentShort;
-          ShortWritable expectedWritable = (ShortWritable) expected;
-          if (!value.equals(expectedWritable.get())) {
-            TestCase.fail("Short field mismatch (expected " + expected + " found " + value + ")");
-          }
-        }
-        break;
-      case INT:
-        {
-          Integer value = deserializeRead.currentInt;
-          IntWritable expectedWritable = (IntWritable) expected;
-          if (!value.equals(expectedWritable.get())) {
-            TestCase.fail("Int field mismatch (expected " + expected + " found " + value + ")");
-          }
-        }
-        break;
-      case LONG:
-        {
-          Long value = deserializeRead.currentLong;
-          LongWritable expectedWritable = (LongWritable) expected;
-          if (!value.equals(expectedWritable.get())) {
-            TestCase.fail("Long field mismatch (expected " + expected + " found " + value + ")");
-          }
-        }
-        break;
-      case DATE:
-        {
-          DateWritable value = deserializeRead.currentDateWritable;
-          DateWritable expectedWritable = (DateWritable) expected;
-          if (!value.equals(expectedWritable)) {
-            TestCase.fail("Date field mismatch (expected " + expected.toString() + " found " + value.toString() + ")");
-          }
+      Object complexFieldObj = VectorVerifyFast.deserializeReadComplexType(deserializeRead, typeInfo);
+      if (expectedObject == null) {
+        if (complexFieldObj != null) {
+          TestCase.fail("Field reports not null but object is null (class " + complexFieldObj.getClass().getName() +
+              ", " + complexFieldObj.toString() + ")");
+        }
-        break;
-      case FLOAT:
-        {
-          Float value = deserializeRead.currentFloat;
FloatWritable expectedWritable = (FloatWritable) expected; - if (!value.equals(expectedWritable.get())) { - TestCase.fail("Float field mismatch (expected " + expected + " found " + value + ")"); - } - } - break; - case DOUBLE: - { - Double value = deserializeRead.currentDouble; - DoubleWritable expectedWritable = (DoubleWritable) expected; - if (!value.equals(expectedWritable.get())) { - TestCase.fail("Double field mismatch (expected " + expected + " found " + value + ")"); - } - } - break; - case STRING: - case CHAR: - case VARCHAR: - case BINARY: - { - byte[] stringBytes = - Arrays.copyOfRange( - deserializeRead.currentBytes, - deserializeRead.currentBytesStart, - deserializeRead.currentBytesStart + deserializeRead.currentBytesLength); - - Text text = new Text(stringBytes); - String string = text.toString(); - - switch (primitiveCategory) { - case STRING: - { - Text expectedWritable = (Text) expected; - if (!string.equals(expectedWritable.toString())) { - TestCase.fail("String field mismatch (expected '" + expectedWritable.toString() + "' found '" + string + "')"); - } - } - break; - case CHAR: - { - HiveChar hiveChar = new HiveChar(string, ((CharTypeInfo) primitiveTypeInfo).getLength()); - - HiveCharWritable expectedWritable = (HiveCharWritable) expected; - if (!hiveChar.equals(expectedWritable.getHiveChar())) { - TestCase.fail("Char field mismatch (expected '" + expectedWritable.getHiveChar() + "' found '" + hiveChar + "')"); - } - } - break; - case VARCHAR: - { - HiveVarchar hiveVarchar = new HiveVarchar(string, ((VarcharTypeInfo) primitiveTypeInfo).getLength()); - HiveVarcharWritable expectedWritable = (HiveVarcharWritable) expected; - if (!hiveVarchar.equals(expectedWritable.getHiveVarchar())) { - TestCase.fail("Varchar field mismatch (expected '" + expectedWritable.getHiveVarchar() + "' found '" + hiveVarchar + "')"); - } - } - break; - case BINARY: - { - BytesWritable expectedWritable = (BytesWritable) expected; - if (stringBytes.length != expectedWritable.getLength()){ - TestCase.fail("Byte Array field mismatch (expected " + expected + " found " + stringBytes + ")"); - } - byte[] expectedBytes = expectedWritable.getBytes(); - for (int b = 0; b < stringBytes.length; b++) { - if (stringBytes[b] != expectedBytes[b]) { - TestCase.fail("Byte Array field mismatch (expected " + expected + " found " + stringBytes + ")"); - } - } + } else { + if (complexFieldObj == null) { + // It's hard to distinguish a union with null from a null union. 
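+          // (A StandardUnion wrapping a null object serializes the same bytes as a
+          // plain null field, so a null read here is accepted whenever the expected
+          // union wraps null.)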
+          if (expectedObject instanceof UnionObject) {
+            UnionObject expectedUnion = (UnionObject) expectedObject;
+            if (expectedUnion.getObject() == null) {
+              return;
+            }
+          }
-            TestCase.fail("Decimal field evaluated to NULL");
-          }
-          HiveDecimalWritable expectedWritable = (HiveDecimalWritable) expected;
-          if (!value.equals(expectedWritable.getHiveDecimal())) {
-            DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) primitiveTypeInfo;
-            int precision = decimalTypeInfo.getPrecision();
-            int scale = decimalTypeInfo.getScale();
-            TestCase.fail("Decimal field mismatch (expected " + expectedWritable.getHiveDecimal() + " found " + value.toString() + ") precision " + precision + ", scale " + scale);
-          }
-        }
-        break;
-      case TIMESTAMP:
-        {
-          Timestamp value = deserializeRead.currentTimestampWritable.getTimestamp();
-          TimestampWritable expectedWritable = (TimestampWritable) expected;
-          if (!value.equals(expectedWritable.getTimestamp())) {
-            TestCase.fail("Timestamp field mismatch (expected " + expectedWritable.getTimestamp() + " found " + value.toString() + ")");
+          TestCase.fail("Field reports null but object is not null (class " + expectedObject.getClass().getName() +
+              ", " + expectedObject.toString() + ")");
+        }
-        break;
-      case INTERVAL_YEAR_MONTH:
-        {
-          HiveIntervalYearMonth value = deserializeRead.currentHiveIntervalYearMonthWritable.getHiveIntervalYearMonth();
-          HiveIntervalYearMonthWritable expectedWritable = (HiveIntervalYearMonthWritable) expected;
-          HiveIntervalYearMonth expectedValue = expectedWritable.getHiveIntervalYearMonth();
-          if (!value.equals(expectedValue)) {
-            TestCase.fail("HiveIntervalYearMonth field mismatch (expected " + expectedValue + " found " + value.toString() + ")");
-          }
-        }
-        break;
-      case INTERVAL_DAY_TIME:
-        {
-          HiveIntervalDayTime value = deserializeRead.currentHiveIntervalDayTimeWritable.getHiveIntervalDayTime();
-          HiveIntervalDayTimeWritable expectedWritable = (HiveIntervalDayTimeWritable) expected;
-          HiveIntervalDayTime expectedValue = expectedWritable.getHiveIntervalDayTime();
-          if (!value.equals(expectedValue)) {
-            TestCase.fail("HiveIntervalDayTime field mismatch (expected " + expectedValue + " found " + value.toString() + ")");
+      if (!VerifyLazy.lazyCompare(typeInfo, complexFieldObj, expectedObject)) {
+        TestCase.fail("Comparison failed typeInfo " + typeInfo.toString());
+      }
-        break;
-
-      default:
-        throw new HiveException("Unexpected primitive category " + primitiveCategory);
-      }
+    }
+  }
+
+  void deserializeAndVerify(
+      Output output, DeserializeRead deserializeRead,
+      VectorRandomRowSource source, Object[] expectedRow)
+      throws HiveException, IOException {
+
+    deserializeRead.set(output.getData(), 0, output.getLength());
+    TypeInfo[] typeInfos = source.typeInfos();
+    for (int i = 0; i < typeInfos.length; i++) {
+      Object expected = expectedRow[i];
+      TypeInfo typeInfo = typeInfos[i];
+      verifyRead(deserializeRead, typeInfo, expected);
+    }
     TestCase.assertTrue(deserializeRead.isEndOfInputReached());
   }

-  void serializeBatch(VectorizedRowBatch batch, VectorSerializeRow vectorSerializeRow,
-      DeserializeRead deserializeRead, VectorRandomRowSource source, Object[][] randomRows,
-      int firstRandomRowIndex) throws HiveException, IOException {
+  void serializeBatch(
+      VectorizedRowBatch batch, VectorSerializeRow vectorSerializeRow,
+      DeserializeRead
deserializeRead, VectorRandomRowSource source, Object[][] randomRows, + int firstRandomRowIndex) throws HiveException, IOException { Output output = new Output(); for (int i = 0; i < batch.size; i++) { @@ -312,10 +137,20 @@ void serializeBatch(VectorizedRowBatch batch, VectorSerializeRow vectorSerialize void testVectorSerializeRow(Random r, SerializationType serializationType) throws HiveException, IOException, SerDeException { + for (int i = 0; i < 20; i++) { + innerTestVectorSerializeRow(r, serializationType); + } + } + + void innerTestVectorSerializeRow( + Random r, SerializationType serializationType) + throws HiveException, IOException, SerDeException { + String[] emptyScratchTypeNames = new String[0]; VectorRandomRowSource source = new VectorRandomRowSource(); - source.init(r); + + source.init(r, VectorRandomRowSource.SupportedTypes.ALL, 4, false); VectorizedRowBatchCtx batchContext = new VectorizedRowBatchCtx(); batchContext.init(source.rowStructObjectInspector(), emptyScratchTypeNames); @@ -329,22 +164,25 @@ void testVectorSerializeRow(Random r, SerializationType serializationType) SerializeWrite serializeWrite; switch (serializationType) { case BINARY_SORTABLE: - deserializeRead = new BinarySortableDeserializeRead(source.primitiveTypeInfos(), /* useExternalBuffer */ false); + deserializeRead = new BinarySortableDeserializeRead(source.typeInfos(), /* useExternalBuffer */ false); serializeWrite = new BinarySortableSerializeWrite(fieldCount); break; case LAZY_BINARY: - deserializeRead = new LazyBinaryDeserializeRead(source.primitiveTypeInfos(), /* useExternalBuffer */ false); + deserializeRead = new LazyBinaryDeserializeRead(source.typeInfos(), /* useExternalBuffer */ false); serializeWrite = new LazyBinarySerializeWrite(fieldCount); break; case LAZY_SIMPLE: { StructObjectInspector rowObjectInspector = source.rowStructObjectInspector(); - LazySerDeParameters lazySerDeParams = getSerDeParams(rowObjectInspector); - byte separator = (byte) '\t'; - deserializeRead = new LazySimpleDeserializeRead(source.primitiveTypeInfos(), /* useExternalBuffer */ false, - separator, lazySerDeParams); - serializeWrite = new LazySimpleSerializeWrite(fieldCount, - separator, lazySerDeParams); + // Use different separator values. 
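+      // (LazySimple consumes one separator byte per nesting level: level 0 between
+      // top-level fields, level 1 inside lists and maps, level 2 between map keys and
+      // values, and so on. Distinct bytes per level make depth mistakes visible.)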
+ byte[] separators = new byte[] {(byte) 9, (byte) 2, (byte) 3, (byte) 4, (byte) 5, (byte) 6, (byte) 7, (byte) 8}; + LazySerDeParameters lazySerDeParams = getSerDeParams(rowObjectInspector, separators); + deserializeRead = + new LazySimpleDeserializeRead( + source.typeInfos(), + /* useExternalBuffer */ false, + lazySerDeParams); + serializeWrite = new LazySimpleSerializeWrite(fieldCount, lazySerDeParams); } break; default: @@ -353,7 +191,7 @@ void testVectorSerializeRow(Random r, SerializationType serializationType) VectorSerializeRow vectorSerializeRow = new VectorSerializeRow(serializeWrite); vectorSerializeRow.init(source.typeNames()); - Object[][] randomRows = source.randomRows(100000); + Object[][] randomRows = source.randomRows(2000); int firstRandomRowIndex = 0; for (int i = 0; i < randomRows.length; i++) { Object[] row = randomRows[i]; @@ -372,7 +210,7 @@ void testVectorSerializeRow(Random r, SerializationType serializationType) } void examineBatch(VectorizedRowBatch batch, VectorExtractRow vectorExtractRow, - PrimitiveTypeInfo[] primitiveTypeInfos, Object[][] randomRows, int firstRandomRowIndex ) { + TypeInfo[] typeInfos, Object[][] randomRows, int firstRandomRowIndex ) { int rowSize = vectorExtractRow.getCount(); Object[] row = new Object[rowSize]; @@ -385,12 +223,17 @@ void examineBatch(VectorizedRowBatch batch, VectorExtractRow vectorExtractRow, Object rowObj = row[c]; Object expectedObj = expectedRow[c]; if (rowObj == null) { + if (expectedObj == null) { + continue; + } fail("Unexpected NULL from extractRow. Expected class " + - expectedObj.getClass().getName() + " value " + expectedObj.toString() + + typeInfos[c].getCategory() + " value " + expectedObj + " batch index " + i + " firstRandomRowIndex " + firstRandomRowIndex); } if (!rowObj.equals(expectedObj)) { - fail("Row " + (firstRandomRowIndex + i) + " and column " + c + " mismatch (" + primitiveTypeInfos[c].getPrimitiveCategory() + " actual value " + rowObj + " and expected value " + expectedObj + ")"); + fail("Row " + (firstRandomRowIndex + i) + " and column " + c + " mismatch (" + + typeInfos[c].getCategory() + " actual value " + rowObj + + " and expected value " + expectedObj + ")"); } } } @@ -400,126 +243,10 @@ private Output serializeRow(Object[] row, VectorRandomRowSource source, SerializeWrite serializeWrite) throws HiveException, IOException { Output output = new Output(); serializeWrite.set(output); - PrimitiveTypeInfo[] primitiveTypeInfos = source.primitiveTypeInfos(); - for (int i = 0; i < primitiveTypeInfos.length; i++) { - Object object = row[i]; - PrimitiveCategory primitiveCategory = primitiveTypeInfos[i].getPrimitiveCategory(); - switch (primitiveCategory) { - case BOOLEAN: - { - BooleanWritable expectedWritable = (BooleanWritable) object; - boolean value = expectedWritable.get(); - serializeWrite.writeBoolean(value); - } - break; - case BYTE: - { - ByteWritable expectedWritable = (ByteWritable) object; - byte value = expectedWritable.get(); - serializeWrite.writeByte(value); - } - break; - case SHORT: - { - ShortWritable expectedWritable = (ShortWritable) object; - short value = expectedWritable.get(); - serializeWrite.writeShort(value); - } - break; - case INT: - { - IntWritable expectedWritable = (IntWritable) object; - int value = expectedWritable.get(); - serializeWrite.writeInt(value); - } - break; - case LONG: - { - LongWritable expectedWritable = (LongWritable) object; - long value = expectedWritable.get(); - serializeWrite.writeLong(value); - } - break; - case DATE: - { - DateWritable 
expectedWritable = (DateWritable) object; - Date value = expectedWritable.get(); - serializeWrite.writeDate(value); - } - break; - case FLOAT: - { - FloatWritable expectedWritable = (FloatWritable) object; - float value = expectedWritable.get(); - serializeWrite.writeFloat(value); - } - break; - case DOUBLE: - { - DoubleWritable expectedWritable = (DoubleWritable) object; - double value = expectedWritable.get(); - serializeWrite.writeDouble(value); - } - break; - case STRING: - { - Text text = (Text) object; - serializeWrite.writeString(text.getBytes(), 0, text.getLength()); - } - break; - case CHAR: - { - HiveCharWritable expectedWritable = (HiveCharWritable) object; - HiveChar value = expectedWritable.getHiveChar(); - serializeWrite.writeHiveChar(value); - } - break; - case VARCHAR: - { - HiveVarcharWritable expectedWritable = (HiveVarcharWritable) object; - HiveVarchar value = expectedWritable.getHiveVarchar(); - serializeWrite.writeHiveVarchar(value); - } - break; - case BINARY: - { - BytesWritable expectedWritable = (BytesWritable) object; - byte[] bytes = expectedWritable.getBytes(); - int length = expectedWritable.getLength(); - serializeWrite.writeBinary(bytes, 0, length); - } - break; - case TIMESTAMP: - { - TimestampWritable expectedWritable = (TimestampWritable) object; - Timestamp value = expectedWritable.getTimestamp(); - serializeWrite.writeTimestamp(value); - } - break; - case INTERVAL_YEAR_MONTH: - { - HiveIntervalYearMonthWritable expectedWritable = (HiveIntervalYearMonthWritable) object; - HiveIntervalYearMonth value = expectedWritable.getHiveIntervalYearMonth(); - serializeWrite.writeHiveIntervalYearMonth(value); - } - break; - case INTERVAL_DAY_TIME: - { - HiveIntervalDayTimeWritable expectedWritable = (HiveIntervalDayTimeWritable) object; - HiveIntervalDayTime value = expectedWritable.getHiveIntervalDayTime(); - serializeWrite.writeHiveIntervalDayTime(value); - } - break; - case DECIMAL: - { - HiveDecimalWritable expectedWritable = (HiveDecimalWritable) object; - HiveDecimal value = expectedWritable.getHiveDecimal(); - serializeWrite.writeHiveDecimal(value, ((DecimalTypeInfo)primitiveTypeInfos[i]).scale()); - } - break; - default: - throw new HiveException("Unexpected primitive category " + primitiveCategory); - } + TypeInfo[] typeInfos = source.typeInfos(); + + for (int i = 0; i < typeInfos.length; i++) { + VectorVerifyFast.serializeWrite(serializeWrite, typeInfos[i], row[i]); } return output; } @@ -531,29 +258,47 @@ private void addToProperties(Properties tbl, String fieldNames, String fieldType tbl.setProperty("columns", fieldNames); tbl.setProperty("columns.types", fieldTypes); - tbl.setProperty(serdeConstants.SERIALIZATION_NULL_FORMAT, "NULL"); + tbl.setProperty(serdeConstants.SERIALIZATION_NULL_FORMAT, "\\N"); } - private LazySerDeParameters getSerDeParams( StructObjectInspector rowObjectInspector) throws SerDeException { - return getSerDeParams(new Configuration(), new Properties(), rowObjectInspector); + private LazySerDeParameters getSerDeParams( + StructObjectInspector rowObjectInspector, byte[] separators) throws SerDeException { + return getSerDeParams(new Configuration(), new Properties(), rowObjectInspector, separators); } - private LazySerDeParameters getSerDeParams(Configuration conf, Properties tbl, StructObjectInspector rowObjectInspector) throws SerDeException { + private LazySerDeParameters getSerDeParams( + Configuration conf, Properties tbl, StructObjectInspector rowObjectInspector, + byte[] separators) throws SerDeException { + String fieldNames 
= ObjectInspectorUtils.getFieldNames(rowObjectInspector); String fieldTypes = ObjectInspectorUtils.getFieldTypes(rowObjectInspector); addToProperties(tbl, fieldNames, fieldTypes); - return new LazySerDeParameters(conf, tbl, LazySimpleSerDe.class.getName()); + LazySerDeParameters lazySerDeParams = new LazySerDeParameters(conf, tbl, LazySimpleSerDe.class.getName()); + for (int i = 0; i < separators.length; i++) { + lazySerDeParams.setSeparator(i, separators[i]); + } + return lazySerDeParams; } - void testVectorDeserializeRow(Random r, SerializationType serializationType, - boolean alternate1, boolean alternate2, - boolean useExternalBuffer) - throws HiveException, IOException, SerDeException { + void testVectorDeserializeRow( + Random r, SerializationType serializationType, + boolean alternate1, boolean alternate2, boolean useExternalBuffer) + throws HiveException, IOException, SerDeException { + + for (int i = 0; i < 20; i++) { + innerTestVectorDeserializeRow(r, serializationType, alternate1, alternate2, useExternalBuffer); + } + } + + void innerTestVectorDeserializeRow( + Random r, SerializationType serializationType, + boolean alternate1, boolean alternate2, boolean useExternalBuffer) + throws HiveException, IOException, SerDeException { String[] emptyScratchTypeNames = new String[0]; VectorRandomRowSource source = new VectorRandomRowSource(); - source.init(r); + source.init(r, VectorRandomRowSource.SupportedTypes.ALL, 4, false); VectorizedRowBatchCtx batchContext = new VectorizedRowBatchCtx(); batchContext.init(source.rowStructObjectInspector(), emptyScratchTypeNames); @@ -564,7 +309,7 @@ void testVectorDeserializeRow(Random r, SerializationType serializationType, Arrays.fill(cv.isNull, true); } - PrimitiveTypeInfo[] primitiveTypeInfos = source.primitiveTypeInfos(); + TypeInfo[] typeInfos = source.typeInfos(); int fieldCount = source.typeNames().size(); DeserializeRead deserializeRead; SerializeWrite serializeWrite; @@ -572,7 +317,7 @@ void testVectorDeserializeRow(Random r, SerializationType serializationType, case BINARY_SORTABLE: boolean useColumnSortOrderIsDesc = alternate1; if (!useColumnSortOrderIsDesc) { - deserializeRead = new BinarySortableDeserializeRead(source.primitiveTypeInfos(), useExternalBuffer); + deserializeRead = new BinarySortableDeserializeRead(source.typeInfos(), useExternalBuffer); serializeWrite = new BinarySortableSerializeWrite(fieldCount); } else { boolean[] columnSortOrderIsDesc = new boolean[fieldCount]; @@ -596,7 +341,7 @@ void testVectorDeserializeRow(Random r, SerializationType serializationType, } } serializeWrite = new BinarySortableSerializeWrite(columnSortOrderIsDesc, columnNullMarker, columnNotNullMarker); - deserializeRead = new BinarySortableDeserializeRead(source.primitiveTypeInfos(), useExternalBuffer, + deserializeRead = new BinarySortableDeserializeRead(source.typeInfos(), useExternalBuffer, columnSortOrderIsDesc, columnNullMarker, columnNotNullMarker); } @@ -606,7 +351,7 @@ void testVectorDeserializeRow(Random r, SerializationType serializationType, } break; case LAZY_BINARY: - deserializeRead = new LazyBinaryDeserializeRead(source.primitiveTypeInfos(), useExternalBuffer); + deserializeRead = new LazyBinaryDeserializeRead(source.typeInfos(), useExternalBuffer); serializeWrite = new LazyBinarySerializeWrite(fieldCount); break; case LAZY_SIMPLE: @@ -624,7 +369,8 @@ void testVectorDeserializeRow(Random r, SerializationType serializationType, tbl.setProperty(serdeConstants.ESCAPE_CHAR, escapeString); } - LazySerDeParameters lazySerDeParams = 
getSerDeParams(conf, tbl, rowObjectInspector); + LazySerDeParameters lazySerDeParams = + getSerDeParams(conf, tbl, rowObjectInspector, new byte[] { separator }); if (useLazySimpleEscapes) { // LazySimple seems to throw away everything but \n and \r. @@ -646,10 +392,9 @@ void testVectorDeserializeRow(Random r, SerializationType serializationType, source.addEscapables(needsEscapeStr); } } - deserializeRead = new LazySimpleDeserializeRead(source.primitiveTypeInfos(), useExternalBuffer, - separator, lazySerDeParams); - serializeWrite = new LazySimpleSerializeWrite(fieldCount, - separator, lazySerDeParams); + deserializeRead = + new LazySimpleDeserializeRead(source.typeInfos(), useExternalBuffer, lazySerDeParams); + serializeWrite = new LazySimpleSerializeWrite(fieldCount, lazySerDeParams); } break; default: @@ -667,7 +412,7 @@ void testVectorDeserializeRow(Random r, SerializationType serializationType, VectorExtractRow vectorExtractRow = new VectorExtractRow(); vectorExtractRow.init(source.typeNames()); - Object[][] randomRows = source.randomRows(100000); + Object[][] randomRows = source.randomRows(2000); int firstRandomRowIndex = 0; for (int i = 0; i < randomRows.length; i++) { Object[] row = randomRows[i]; @@ -684,13 +429,13 @@ void testVectorDeserializeRow(Random r, SerializationType serializationType, } batch.size++; if (batch.size == batch.DEFAULT_SIZE) { - examineBatch(batch, vectorExtractRow, primitiveTypeInfos, randomRows, firstRandomRowIndex); + examineBatch(batch, vectorExtractRow, typeInfos, randomRows, firstRandomRowIndex); firstRandomRowIndex = i + 1; batch.reset(); } } if (batch.size > 0) { - examineBatch(batch, vectorExtractRow, primitiveTypeInfos, randomRows, firstRandomRowIndex); + examineBatch(batch, vectorExtractRow, typeInfos, randomRows, firstRandomRowIndex); } } diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomRowSource.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomRowSource.java index cbde6158e9..476feefb20 100644 --- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomRowSource.java +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomRowSource.java @@ -25,21 +25,29 @@ import java.util.List; import java.util.Random; -import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; -import org.apache.commons.lang.ArrayUtils; -import org.apache.commons.lang.StringUtils; import org.apache.hadoop.hive.common.type.HiveChar; import org.apache.hadoop.hive.common.type.HiveDecimal; import org.apache.hadoop.hive.common.type.HiveIntervalDayTime; import org.apache.hadoop.hive.common.type.HiveIntervalYearMonth; import org.apache.hadoop.hive.common.type.HiveVarchar; import org.apache.hadoop.hive.common.type.RandomTypeUtil; -import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.io.HiveCharWritable; +import org.apache.hadoop.hive.serde2.io.HiveVarcharWritable; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector; 
+import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StandardStructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StandardUnionObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StandardUnionObjectInspector.StandardUnion; +import org.apache.hadoop.hive.serde2.objectinspector.StructField; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.UnionObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableBooleanObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableByteObjectInspector; @@ -58,11 +66,19 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableTimestampObjectInspector; import org.apache.hadoop.hive.serde2.typeinfo.CharTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils; +import org.apache.hadoop.hive.serde2.typeinfo.UnionTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo; import org.apache.hive.common.util.DateUtils; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.io.BytesWritable; +import com.google.common.base.Preconditions; import com.google.common.base.Charsets; /** @@ -76,6 +92,14 @@ private List typeNames; + private Category[] categories; + + private TypeInfo[] typeInfos; + + private List objectInspectorList; + + // Primitive. 
+ private PrimitiveCategory[] primitiveCategories; private PrimitiveTypeInfo[] primitiveTypeInfos; @@ -86,6 +110,8 @@ private String[] alphabets; + private boolean allowNull; + private boolean addEscapables; private String needsEscapeStr; @@ -93,6 +119,14 @@ return typeNames; } + public Category[] categories() { + return categories; + } + + public TypeInfo[] typeInfos() { + return typeInfos; + } + public PrimitiveCategory[] primitiveCategories() { return primitiveCategories; } @@ -106,30 +140,37 @@ public StructObjectInspector rowStructObjectInspector() { } public StructObjectInspector partialRowStructObjectInspector(int partialFieldCount) { - ArrayList partialPrimitiveObjectInspectorList = + ArrayList partialObjectInspectorList = new ArrayList(partialFieldCount); List columnNames = new ArrayList(partialFieldCount); for (int i = 0; i < partialFieldCount; i++) { columnNames.add(String.format("partial%d", i)); - partialPrimitiveObjectInspectorList.add( - PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector( - primitiveTypeInfos[i])); + partialObjectInspectorList.add(getObjectInspector(typeInfos[i])); } return ObjectInspectorFactory.getStandardStructObjectInspector( - columnNames, primitiveObjectInspectorList); + columnNames, objectInspectorList); + } + + public enum SupportedTypes { + ALL, PRIMITIVES, ALL_EXCEPT_MAP } - public void init(Random r) { + public void init(Random r, SupportedTypes supportedTypes, int maxComplexDepth) { + init(r, supportedTypes, maxComplexDepth, true); + } + + public void init(Random r, SupportedTypes supportedTypes, int maxComplexDepth, boolean allowNull) { this.r = r; - chooseSchema(); + this.allowNull = allowNull; + chooseSchema(supportedTypes, maxComplexDepth); } /* * For now, exclude CHAR until we determine why there is a difference (blank padding) * serializing with LazyBinarySerializeWrite and the regular SerDe... 
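 * (The vectorized path stores CHAR unpadded -- the serializer above even notes
 * "We store CHAR and VARCHAR without pads" -- so the blank-padding difference
 * most likely enters there rather than in the random data itself.)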
 */
-  private static String[] possibleHiveTypeNames = {
+  private static String[] possibleHivePrimitiveTypeNames = {
     "boolean",
     "tinyint",
     "smallint",
@@ -149,7 +190,158 @@ public void init(Random r) {
     "decimal"
   };

-  private void chooseSchema() {
+  private static String[] possibleHiveComplexTypeNames = {
+      "array",
+      "struct",
+      "uniontype",
+      "map"
+  };
+
+  private String getRandomTypeName(SupportedTypes supportedTypes) {
+    String typeName = null;
+    if (r.nextInt(10) != 0) {
+      typeName = possibleHivePrimitiveTypeNames[r.nextInt(possibleHivePrimitiveTypeNames.length)];
+    } else {
+      switch (supportedTypes) {
+      case PRIMITIVES:
+        typeName = possibleHivePrimitiveTypeNames[r.nextInt(possibleHivePrimitiveTypeNames.length)];
+        break;
+      case ALL_EXCEPT_MAP:
+        typeName = possibleHiveComplexTypeNames[r.nextInt(possibleHiveComplexTypeNames.length - 1)];
+        break;
+      case ALL:
+        typeName = possibleHiveComplexTypeNames[r.nextInt(possibleHiveComplexTypeNames.length)];
+        break;
+      }
+    }
+    return typeName;
+  }
+
+  private String getDecoratedTypeName(String typeName, SupportedTypes supportedTypes, int depth, int maxDepth) {
+    depth++;
+    if (depth >= maxDepth) {
+      supportedTypes = SupportedTypes.PRIMITIVES;
+    }
+    if (typeName.equals("char")) {
+      int maxLength = 1 + r.nextInt(100);
+      typeName = String.format("char(%d)", maxLength);
+    } else if (typeName.equals("varchar")) {
+      int maxLength = 1 + r.nextInt(100);
+      typeName = String.format("varchar(%d)", maxLength);
+    } else if (typeName.equals("decimal")) {
+      typeName = String.format("decimal(%d,%d)", HiveDecimal.SYSTEM_DEFAULT_PRECISION, HiveDecimal.SYSTEM_DEFAULT_SCALE);
+    } else if (typeName.equals("array")) {
+      String elementTypeName = getRandomTypeName(supportedTypes);
+      elementTypeName = getDecoratedTypeName(elementTypeName, supportedTypes, depth, maxDepth);
+      typeName = String.format("array<%s>", elementTypeName);
+    } else if (typeName.equals("map")) {
+      String keyTypeName = getRandomTypeName(SupportedTypes.PRIMITIVES);
+      keyTypeName = getDecoratedTypeName(keyTypeName, supportedTypes, depth, maxDepth);
+      String valueTypeName = getRandomTypeName(supportedTypes);
+      valueTypeName = getDecoratedTypeName(valueTypeName, supportedTypes, depth, maxDepth);
+      typeName = String.format("map<%s,%s>", keyTypeName, valueTypeName);
+    } else if (typeName.equals("struct")) {
+      final int fieldCount = 1 + r.nextInt(10);
+      StringBuilder sb = new StringBuilder();
+      for (int i = 0; i < fieldCount; i++) {
+        String fieldTypeName = getRandomTypeName(supportedTypes);
+        fieldTypeName = getDecoratedTypeName(fieldTypeName, supportedTypes, depth, maxDepth);
+        if (i > 0) {
+          sb.append(",");
+        }
+        sb.append("col");
+        sb.append(i);
+        sb.append(":");
+        sb.append(fieldTypeName);
+      }
+      typeName = String.format("struct<%s>", sb.toString());
+    } else if (typeName.equals("uniontype")) {
+      final int fieldCount = 1 + r.nextInt(10);
+      StringBuilder sb = new StringBuilder();
+      for (int i = 0; i < fieldCount; i++) {
+        String fieldTypeName = getRandomTypeName(supportedTypes);
+        fieldTypeName = getDecoratedTypeName(fieldTypeName, supportedTypes, depth, maxDepth);
+        if (i > 0) {
+          sb.append(",");
+        }
+        sb.append(fieldTypeName);
+      }
+      typeName = String.format("uniontype<%s>", sb.toString());
+    }
+    return typeName;
+  }
+
+  private ObjectInspector getObjectInspector(TypeInfo typeInfo) {
+    ObjectInspector objectInspector;
+    switch (typeInfo.getCategory()) {
+    case PRIMITIVE:
+      {
+        PrimitiveTypeInfo primitiveType = (PrimitiveTypeInfo) typeInfo;
+        objectInspector =
+
+  private ObjectInspector getObjectInspector(TypeInfo typeInfo) {
+    ObjectInspector objectInspector;
+    switch (typeInfo.getCategory()) {
+    case PRIMITIVE:
+      {
+        PrimitiveTypeInfo primitiveType = (PrimitiveTypeInfo) typeInfo;
+        objectInspector =
+            PrimitiveObjectInspectorFactory.
+                getPrimitiveWritableObjectInspector(primitiveType);
+      }
+      break;
+    case MAP:
+      {
+        MapTypeInfo mapType = (MapTypeInfo) typeInfo;
+        MapObjectInspector mapInspector =
+            ObjectInspectorFactory.getStandardMapObjectInspector(
+                getObjectInspector(mapType.getMapKeyTypeInfo()),
+                getObjectInspector(mapType.getMapValueTypeInfo()));
+        objectInspector = mapInspector;
+      }
+      break;
+    case LIST:
+      {
+        ListTypeInfo listType = (ListTypeInfo) typeInfo;
+        ListObjectInspector listInspector =
+            ObjectInspectorFactory.getStandardListObjectInspector(
+                getObjectInspector(listType.getListElementTypeInfo()));
+        objectInspector = listInspector;
+      }
+      break;
+    case STRUCT:
+      {
+        StructTypeInfo structType = (StructTypeInfo) typeInfo;
+        List<TypeInfo> fieldTypes = structType.getAllStructFieldTypeInfos();
+
+        List<ObjectInspector> fieldInspectors = new ArrayList<ObjectInspector>();
+        for (TypeInfo fieldType : fieldTypes) {
+          fieldInspectors.add(getObjectInspector(fieldType));
+        }
+
+        StructObjectInspector structInspector =
+            ObjectInspectorFactory.getStandardStructObjectInspector(
+                structType.getAllStructFieldNames(), fieldInspectors);
+        objectInspector = structInspector;
+      }
+      break;
+    case UNION:
+      {
+        UnionTypeInfo unionType = (UnionTypeInfo) typeInfo;
+        List<TypeInfo> fieldTypes = unionType.getAllUnionObjectTypeInfos();
+
+        List<ObjectInspector> fieldInspectors = new ArrayList<ObjectInspector>();
+        for (TypeInfo fieldType : fieldTypes) {
+          fieldInspectors.add(getObjectInspector(fieldType));
+        }
+
+        UnionObjectInspector unionInspector =
+            ObjectInspectorFactory.getStandardUnionObjectInspector(
+                fieldInspectors);
+        objectInspector = unionInspector;
+      }
+      break;
+    default:
+      throw new RuntimeException("Unexpected category " + typeInfo.getCategory());
+    }
+    Preconditions.checkState(objectInspector != null);
+    return objectInspector;
+  }
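Reviewer note: the recursion above duplicates a helper that serde2 already ships. A short sketch of the equivalent one-call construction (assuming TypeInfoUtils.getStandardWritableObjectInspectorFromTypeInfo, which the main code base provides); keeping the hand-rolled version in test code is defensible since it documents the shape of the inspector tree:

    import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
    import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
    import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;

    public class StandardWritableObjectInspectorSketch {
      public static void main(String[] args) {
        TypeInfo typeInfo = TypeInfoUtils.getTypeInfoFromTypeString(
            "array<struct<col0:int,col1:string>>");
        // One call builds the same nested standard/writable inspector tree
        // that getObjectInspector(typeInfo) assembles by hand above.
        ObjectInspector objectInspector =
            TypeInfoUtils.getStandardWritableObjectInspectorFromTypeInfo(typeInfo);
        System.out.println(objectInspector.getCategory() + " / " + objectInspector.getTypeName());
      }
    }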
+
+  private void chooseSchema(SupportedTypes supportedTypes, int maxComplexDepth) {
     HashSet<Integer> hashSet = null;
     boolean allTypes;
     boolean onlyOne = (r.nextInt(100) == 7);
@@ -159,14 +351,27 @@ private void chooseSchema() {
     } else {
       allTypes = r.nextBoolean();
       if (allTypes) {
-        // One of each type.
-        columnCount = possibleHiveTypeNames.length;
+        switch (supportedTypes) {
+        case ALL:
+          columnCount = possibleHivePrimitiveTypeNames.length + possibleHiveComplexTypeNames.length;
+          break;
+        case ALL_EXCEPT_MAP:
+          columnCount = possibleHivePrimitiveTypeNames.length + possibleHiveComplexTypeNames.length - 1;
+          break;
+        case PRIMITIVES:
+          columnCount = possibleHivePrimitiveTypeNames.length;
+          break;
+        }
         hashSet = new HashSet<Integer>();
       } else {
         columnCount = 1 + r.nextInt(20);
       }
     }
     typeNames = new ArrayList<String>(columnCount);
+    categories = new Category[columnCount];
+    typeInfos = new TypeInfo[columnCount];
+    objectInspectorList = new ArrayList<ObjectInspector>(columnCount);
+
     primitiveCategories = new PrimitiveCategory[columnCount];
     primitiveTypeInfos = new PrimitiveTypeInfo[columnCount];
     primitiveObjectInspectorList = new ArrayList<ObjectInspector>(columnCount);
@@ -176,12 +381,26 @@ private void chooseSchema() {
       String typeName;
       if (onlyOne) {
-        typeName = possibleHiveTypeNames[r.nextInt(possibleHiveTypeNames.length)];
+        typeName = getRandomTypeName(supportedTypes);
       } else {
         int typeNum;
         if (allTypes) {
+          int maxTypeNum = 0;
+          switch (supportedTypes) {
+          case PRIMITIVES:
+            maxTypeNum = possibleHivePrimitiveTypeNames.length;
+            break;
+          case ALL_EXCEPT_MAP:
+            maxTypeNum = possibleHivePrimitiveTypeNames.length + possibleHiveComplexTypeNames.length - 1;
+            break;
+          case ALL:
+            maxTypeNum = possibleHivePrimitiveTypeNames.length + possibleHiveComplexTypeNames.length;
+            break;
+          }
           while (true) {
-            typeNum = r.nextInt(possibleHiveTypeNames.length);
+
+            typeNum = r.nextInt(maxTypeNum);
+
             Integer typeNumInteger = Integer.valueOf(typeNum);
             if (!hashSet.contains(typeNumInteger)) {
               hashSet.add(typeNumInteger);
@@ -189,56 +408,64 @@ private void chooseSchema() {
             }
           }
         } else {
-          typeNum = r.nextInt(possibleHiveTypeNames.length);
+          if (supportedTypes == SupportedTypes.PRIMITIVES || r.nextInt(10) != 0) {
+            typeNum = r.nextInt(possibleHivePrimitiveTypeNames.length);
+          } else {
+            typeNum = possibleHivePrimitiveTypeNames.length + r.nextInt(possibleHiveComplexTypeNames.length);
+            if (supportedTypes == SupportedTypes.ALL_EXCEPT_MAP) {
+              typeNum--;
+            }
+          }
+        }
+        if (typeNum < possibleHivePrimitiveTypeNames.length) {
+          typeName = possibleHivePrimitiveTypeNames[typeNum];
+        } else {
+          typeName = possibleHiveComplexTypeNames[typeNum - possibleHivePrimitiveTypeNames.length];
         }
-        typeName = possibleHiveTypeNames[typeNum];
+      }
-      if (typeName.equals("char")) {
-        int maxLength = 1 + r.nextInt(100);
-        typeName = String.format("char(%d)", maxLength);
-      } else if (typeName.equals("varchar")) {
-        int maxLength = 1 + r.nextInt(100);
-        typeName = String.format("varchar(%d)", maxLength);
-      } else if (typeName.equals("decimal")) {
-        typeName = String.format("decimal(%d,%d)", HiveDecimal.SYSTEM_DEFAULT_PRECISION, HiveDecimal.SYSTEM_DEFAULT_SCALE);
+
+      String decoratedTypeName = getDecoratedTypeName(typeName, supportedTypes, 0, maxComplexDepth);
+
+      TypeInfo typeInfo;
+      try {
+        typeInfo = TypeInfoUtils.getTypeInfoFromTypeString(decoratedTypeName);
+      } catch (Exception e) {
+        throw new RuntimeException("Cannot convert type name " + decoratedTypeName + " to a type: " + e);
       }
-      PrimitiveTypeInfo primitiveTypeInfo = (PrimitiveTypeInfo) TypeInfoUtils.getTypeInfoFromTypeString(typeName);
-      primitiveTypeInfos[c] = primitiveTypeInfo;
-      PrimitiveCategory primitiveCategory = primitiveTypeInfo.getPrimitiveCategory();
-      primitiveCategories[c] = primitiveCategory;
-      primitiveObjectInspectorList.add(PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(primitiveTypeInfo));
-      typeNames.add(typeName);
-    }
-    rowStructObjectInspector = ObjectInspectorFactory.getStandardStructObjectInspector(columnNames, primitiveObjectInspectorList);
-    alphabets = new String[columnCount];
-  }
-
-  public void addBinarySortableAlphabets() {
-    for (int c = 0; c < columnCount; c++) {
-      switch (primitiveCategories[c]) {
-      case STRING:
-      case CHAR:
-      case VARCHAR:
-        byte[] bytes = new byte[10 + r.nextInt(10)];
-        for (int i = 0; i < bytes.length; i++) {
-          bytes[i] = (byte) (32 + r.nextInt(96));
+      typeInfos[c] = typeInfo;
+      Category category = typeInfo.getCategory();
+      categories[c] = category;
+      ObjectInspector objectInspector = getObjectInspector(typeInfo);
+      switch (category) {
+      case PRIMITIVE:
+        {
+          PrimitiveTypeInfo primitiveTypeInfo = (PrimitiveTypeInfo) typeInfo;
+          objectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(primitiveTypeInfo);
+          primitiveTypeInfos[c] = primitiveTypeInfo;
+          PrimitiveCategory primitiveCategory = primitiveTypeInfo.getPrimitiveCategory();
+          primitiveCategories[c] = primitiveCategory;
+          primitiveObjectInspectorList.add(objectInspector);
         }
-        int alwaysIndex = r.nextInt(bytes.length);
-        bytes[alwaysIndex] = 0; // Must be escaped by BinarySortable.
-        int alwaysIndex2 = r.nextInt(bytes.length);
-        bytes[alwaysIndex2] = 1; // Must be escaped by BinarySortable.
-        alphabets[c] = new String(bytes, Charsets.UTF_8);
         break;
-      default:
-        // No alphabet needed.
+      case LIST:
+      case MAP:
+      case STRUCT:
+      case UNION:
+        primitiveObjectInspectorList.add(null);
         break;
+      default:
+        throw new RuntimeException("Unexpected category " + category);
       }
-    }
-  }
+      objectInspectorList.add(objectInspector);
-
-  public void addEscapables(String needsEscapeStr) {
-    addEscapables = true;
-    this.needsEscapeStr = needsEscapeStr;
+      typeNames.add(decoratedTypeName);
+    }
+    rowStructObjectInspector = ObjectInspectorFactory.getStandardStructObjectInspector(columnNames, objectInspectorList);
+    alphabets = new String[columnCount];
   }
 
   public Object[][] randomRows(int n) {
@@ -252,39 +479,75 @@ public void addEscapables(String needsEscapeStr) {
 
   public Object[] randomRow() {
     Object row[] = new Object[columnCount];
     for (int c = 0; c < columnCount; c++) {
-      Object object = randomObject(c);
-      if (object == null) {
-        throw new Error("Unexpected null for column " + c);
-      }
-      row[c] = getWritableObject(c, object);
-      if (row[c] == null) {
-        throw new Error("Unexpected null for writable for column " + c);
-      }
+      row[c] = randomWritable(c);
     }
     return row;
   }
 
-  public Object[] randomRow(int columnCount) {
-    return randomRow(columnCount, r, primitiveObjectInspectorList, primitiveCategories,
-        primitiveTypeInfos);
+  public Object[] randomRow(boolean allowNull) {
+
+    Object row[] = new Object[columnCount];
+    for (int c = 0; c < columnCount; c++) {
+      row[c] = randomWritable(typeInfos[c], objectInspectorList.get(c), allowNull);
+    }
+    return row;
+  }
+
+  public Object[] randomPrimitiveRow(int columnCount) {
+    return randomPrimitiveRow(columnCount, r, primitiveTypeInfos);
   }
 
-  public static Object[] randomRow(int columnCount, Random r,
-      List primitiveObjectInspectorList, PrimitiveCategory[] primitiveCategories,
+  public static Object[] randomPrimitiveRow(int columnCount, Random r,
       PrimitiveTypeInfo[] primitiveTypeInfos) {
     Object row[] = new Object[columnCount];
     for (int c = 0; c < columnCount; c++) {
-      Object object = randomObject(c, r, primitiveCategories, primitiveTypeInfos);
-      if (object == null) {
-        throw new Error("Unexpected null for column " + c);
+ row[c] = randomPrimitiveObject(r, primitiveTypeInfos[c]); + } + return row; + } + + public static Object[] randomWritablePrimitiveRow(int columnCount, Random r, + PrimitiveTypeInfo[] primitiveTypeInfos) { + Object row[] = new Object[columnCount]; + for (int c = 0; c < columnCount; c++) { + PrimitiveTypeInfo primitiveTypeInfo = primitiveTypeInfos[c]; + ObjectInspector objectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(primitiveTypeInfo); + Object object = randomPrimitiveObject(r, primitiveTypeInfo); + row[c] = getWritablePrimitiveObject(primitiveTypeInfo, objectInspector, object); + } + return row; + } + + public void addBinarySortableAlphabets() { + for (int c = 0; c < columnCount; c++) { + if (primitiveCategories[c] == null) { + continue; } - row[c] = getWritableObject(c, object, primitiveObjectInspectorList, - primitiveCategories, primitiveTypeInfos); - if (row[c] == null) { - throw new Error("Unexpected null for writable for column " + c); + switch (primitiveCategories[c]) { + case STRING: + case CHAR: + case VARCHAR: + byte[] bytes = new byte[10 + r.nextInt(10)]; + for (int i = 0; i < bytes.length; i++) { + bytes[i] = (byte) (32 + r.nextInt(96)); + } + int alwaysIndex = r.nextInt(bytes.length); + bytes[alwaysIndex] = 0; // Must be escaped by BinarySortable. + int alwaysIndex2 = r.nextInt(bytes.length); + bytes[alwaysIndex2] = 1; // Must be escaped by BinarySortable. + alphabets[c] = new String(bytes, Charsets.UTF_8); + break; + default: + // No alphabet needed. + break; } } - return row; + } + + public void addEscapables(String needsEscapeStr) { + addEscapables = true; + this.needsEscapeStr = needsEscapeStr; } public static void sort(Object[][] rows, ObjectInspector oi) { @@ -303,18 +566,9 @@ public void sort(Object[][] rows) { VectorRandomRowSource.sort(rows, rowStructObjectInspector); } - public Object getWritableObject(int column, Object object) { - return getWritableObject(column, object, primitiveObjectInspectorList, - primitiveCategories, primitiveTypeInfos); - } - - public static Object getWritableObject(int column, Object object, - List primitiveObjectInspectorList, PrimitiveCategory[] primitiveCategories, - PrimitiveTypeInfo[] primitiveTypeInfos) { - ObjectInspector objectInspector = primitiveObjectInspectorList.get(column); - PrimitiveCategory primitiveCategory = primitiveCategories[column]; - PrimitiveTypeInfo primitiveTypeInfo = primitiveTypeInfos[column]; - switch (primitiveCategory) { + public static Object getWritablePrimitiveObject(PrimitiveTypeInfo primitiveTypeInfo, + ObjectInspector objectInspector, Object object) { + switch (primitiveTypeInfo.getPrimitiveCategory()) { case BOOLEAN: return ((WritableBooleanObjectInspector) objectInspector).create((boolean) object); case BYTE: @@ -334,17 +588,17 @@ public static Object getWritableObject(int column, Object object, case STRING: return ((WritableStringObjectInspector) objectInspector).create((String) object); case CHAR: - { - WritableHiveCharObjectInspector writableCharObjectInspector = - new WritableHiveCharObjectInspector( (CharTypeInfo) primitiveTypeInfo); - return writableCharObjectInspector.create((HiveChar) object); - } + { + WritableHiveCharObjectInspector writableCharObjectInspector = + new WritableHiveCharObjectInspector( (CharTypeInfo) primitiveTypeInfo); + return writableCharObjectInspector.create((HiveChar) object); + } case VARCHAR: - { - WritableHiveVarcharObjectInspector writableVarcharObjectInspector = - new WritableHiveVarcharObjectInspector( (VarcharTypeInfo) 
primitiveTypeInfo); - return writableVarcharObjectInspector.create((HiveVarchar) object); - } + { + WritableHiveVarcharObjectInspector writableVarcharObjectInspector = + new WritableHiveVarcharObjectInspector( (VarcharTypeInfo) primitiveTypeInfo); + return writableVarcharObjectInspector.create((HiveVarchar) object); + } case BINARY: return PrimitiveObjectInspectorFactory.writableBinaryObjectInspector.create((byte[]) object); case TIMESTAMP: @@ -354,106 +608,215 @@ public static Object getWritableObject(int column, Object object, case INTERVAL_DAY_TIME: return ((WritableHiveIntervalDayTimeObjectInspector) objectInspector).create((HiveIntervalDayTime) object); case DECIMAL: - { - WritableHiveDecimalObjectInspector writableDecimalObjectInspector = - new WritableHiveDecimalObjectInspector((DecimalTypeInfo) primitiveTypeInfo); - HiveDecimalWritable result = (HiveDecimalWritable) writableDecimalObjectInspector.create((HiveDecimal) object); - return result; - } + { + WritableHiveDecimalObjectInspector writableDecimalObjectInspector = + new WritableHiveDecimalObjectInspector((DecimalTypeInfo) primitiveTypeInfo); + return writableDecimalObjectInspector.create((HiveDecimal) object); + } default: - throw new Error("Unknown primitive category " + primitiveCategory); + throw new Error("Unknown primitive category " + primitiveTypeInfo.getPrimitiveCategory()); } } - public Object randomObject(int column) { - return randomObject(column, r, primitiveCategories, primitiveTypeInfos, alphabets, addEscapables, needsEscapeStr); + public Object randomWritable(int column) { + return randomWritable(typeInfos[column], objectInspectorList.get(column)); } - public static Object randomObject(int column, Random r, PrimitiveCategory[] primitiveCategories, - PrimitiveTypeInfo[] primitiveTypeInfos) { - return randomObject(column, r, primitiveCategories, primitiveTypeInfos, null, false, ""); - } - - public static Object randomObject(int column, Random r, PrimitiveCategory[] primitiveCategories, - PrimitiveTypeInfo[] primitiveTypeInfos, String[] alphabets, boolean addEscapables, String needsEscapeStr) { - PrimitiveCategory primitiveCategory = primitiveCategories[column]; - PrimitiveTypeInfo primitiveTypeInfo = primitiveTypeInfos[column]; - try { - switch (primitiveCategory) { - case BOOLEAN: - return Boolean.valueOf(r.nextInt(1) == 1); - case BYTE: - return Byte.valueOf((byte) r.nextInt()); - case SHORT: - return Short.valueOf((short) r.nextInt()); - case INT: - return Integer.valueOf(r.nextInt()); - case LONG: - return Long.valueOf(r.nextLong()); - case DATE: - return RandomTypeUtil.getRandDate(r); - case FLOAT: - return Float.valueOf(r.nextFloat() * 10 - 5); - case DOUBLE: - return Double.valueOf(r.nextDouble() * 10 - 5); - case STRING: - case CHAR: - case VARCHAR: - { - String result; - if (alphabets != null && alphabets[column] != null) { - result = RandomTypeUtil.getRandString(r, alphabets[column], r.nextInt(10)); - } else { - result = RandomTypeUtil.getRandString(r); + public Object randomWritable(TypeInfo typeInfo, ObjectInspector objectInspector) { + return randomWritable(typeInfo, objectInspector, allowNull); + } + + public Object randomWritable(TypeInfo typeInfo, ObjectInspector objectInspector, boolean allowNull) { + switch (typeInfo.getCategory()) { + case PRIMITIVE: + { + Object object = randomPrimitiveObject(r, (PrimitiveTypeInfo) typeInfo); + return getWritablePrimitiveObject((PrimitiveTypeInfo) typeInfo, objectInspector, object); + } + case LIST: + { + if (allowNull && r.nextInt(20) == 0) { + return null; + 
} + // Always generate a list with at least 1 value? + final int elementCount = 1 + r.nextInt(100); + StandardListObjectInspector listObjectInspector = + (StandardListObjectInspector) objectInspector; + ObjectInspector elementObjectInspector = + listObjectInspector.getListElementObjectInspector(); + TypeInfo elementTypeInfo = + TypeInfoUtils.getTypeInfoFromObjectInspector( + elementObjectInspector); + boolean isStringFamily = false; + PrimitiveCategory primitiveCategory = null; + if (elementTypeInfo.getCategory() == Category.PRIMITIVE) { + primitiveCategory = ((PrimitiveTypeInfo) elementTypeInfo).getPrimitiveCategory(); + if (primitiveCategory == PrimitiveCategory.STRING || + primitiveCategory == PrimitiveCategory.BINARY || + primitiveCategory == PrimitiveCategory.CHAR || + primitiveCategory == PrimitiveCategory.VARCHAR) { + isStringFamily = true; } - if (addEscapables && result.length() > 0) { - int escapeCount = 1 + r.nextInt(2); - for (int i = 0; i < escapeCount; i++) { - int index = r.nextInt(result.length()); - String begin = result.substring(0, index); - String end = result.substring(index); - Character needsEscapeChar = needsEscapeStr.charAt(r.nextInt(needsEscapeStr.length())); - result = begin + needsEscapeChar + end; - } + } + Object listObj = listObjectInspector.create(elementCount); + for (int i = 0; i < elementCount; i++) { + Object ele = randomWritable(elementTypeInfo, elementObjectInspector, allowNull); + // UNDONE: For now, a 1-element list with a null element is a null list... + if (ele == null && elementCount == 1) { + return null; } - switch (primitiveCategory) { - case STRING: - return result; - case CHAR: - return new HiveChar(result, ((CharTypeInfo) primitiveTypeInfo).getLength()); - case VARCHAR: - return new HiveVarchar(result, ((VarcharTypeInfo) primitiveTypeInfo).getLength()); - default: - throw new Error("Unknown primitive category " + primitiveCategory); + if (isStringFamily && elementCount == 1) { + switch (primitiveCategory) { + case STRING: + if (((Text) ele).getLength() == 0) { + return null; + } + break; + case BINARY: + if (((BytesWritable) ele).getLength() == 0) { + return null; + } + break; + case CHAR: + if (((HiveCharWritable) ele).getHiveChar().getStrippedValue().isEmpty()) { + return null; + } + break; + case VARCHAR: + if (((HiveVarcharWritable) ele).getHiveVarchar().getValue().isEmpty()) { + return null; + } + break; + default: + throw new RuntimeException("Unexpected primitive category " + primitiveCategory); + } } + listObjectInspector.set(listObj, i, ele); } - case BINARY: - return getRandBinary(r, 1 + r.nextInt(100)); - case TIMESTAMP: - return RandomTypeUtil.getRandTimestamp(r); - case INTERVAL_YEAR_MONTH: - return getRandIntervalYearMonth(r); - case INTERVAL_DAY_TIME: - return getRandIntervalDayTime(r); - case DECIMAL: - return getRandHiveDecimal(r, (DecimalTypeInfo) primitiveTypeInfo); - default: - throw new Error("Unknown primitive category " + primitiveCategory); + return listObj; } - } catch (Exception e) { - throw new RuntimeException("randomObject failed on column " + column + " type " + primitiveCategory, e); + case MAP: + { + if (allowNull && r.nextInt(20) == 0) { + return null; + } + final int keyPairCount = r.nextInt(100); + StandardMapObjectInspector mapObjectInspector = + (StandardMapObjectInspector) objectInspector; + ObjectInspector keyObjectInspector = + mapObjectInspector.getMapKeyObjectInspector(); + TypeInfo keyTypeInfo = + TypeInfoUtils.getTypeInfoFromObjectInspector( + keyObjectInspector); + ObjectInspector 
valueObjectInspector =
+            mapObjectInspector.getMapValueObjectInspector();
+        TypeInfo valueTypeInfo =
+            TypeInfoUtils.getTypeInfoFromObjectInspector(
+                valueObjectInspector);
+        Object mapObj = mapObjectInspector.create();
+        for (int i = 0; i < keyPairCount; i++) {
+          Object key = randomWritable(keyTypeInfo, keyObjectInspector);
+          Object value = randomWritable(valueTypeInfo, valueObjectInspector);
+          mapObjectInspector.put(mapObj, key, value);
+        }
+        return mapObj;
+      }
+    case STRUCT:
+      {
+        if (allowNull && r.nextInt(20) == 0) {
+          return null;
+        }
+        StandardStructObjectInspector structObjectInspector =
+            (StandardStructObjectInspector) objectInspector;
+        List<? extends StructField> fieldRefsList = structObjectInspector.getAllStructFieldRefs();
+        final int fieldCount = fieldRefsList.size();
+        Object structObj = structObjectInspector.create();
+        for (int i = 0; i < fieldCount; i++) {
+          StructField fieldRef = fieldRefsList.get(i);
+          ObjectInspector fieldObjectInspector =
+              fieldRef.getFieldObjectInspector();
+          TypeInfo fieldTypeInfo =
+              TypeInfoUtils.getTypeInfoFromObjectInspector(
+                  fieldObjectInspector);
+          Object fieldObj = randomWritable(fieldTypeInfo, fieldObjectInspector);
+          structObjectInspector.setStructFieldData(structObj, fieldRef, fieldObj);
+        }
+        return structObj;
+      }
+    case UNION:
+      {
+        StandardUnionObjectInspector unionObjectInspector =
+            (StandardUnionObjectInspector) objectInspector;
+        List<ObjectInspector> objectInspectorList = unionObjectInspector.getObjectInspectors();
+        final int unionCount = objectInspectorList.size();
+        final byte tag = (byte) r.nextInt(unionCount);
+        ObjectInspector fieldObjectInspector =
+            objectInspectorList.get(tag);
+        TypeInfo fieldTypeInfo =
+            TypeInfoUtils.getTypeInfoFromObjectInspector(
+                fieldObjectInspector);
+        Object fieldObj = randomWritable(fieldTypeInfo, fieldObjectInspector, false);
+        if (fieldObj == null) {
+          throw new RuntimeException("Union field object must not be null");
+        }
+        return new StandardUnion(tag, fieldObj);
+      }
+    default:
+      throw new RuntimeException("Unexpected category " + typeInfo.getCategory());
+    }
+  }
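Reviewer note: the MAP/STRUCT/UNION branches above lean on the settable standard object inspectors. A self-contained sketch of the create/put pattern they rely on (names are illustrative, not patch code):

    import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
    import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector;
    import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
    import org.apache.hadoop.io.IntWritable;
    import org.apache.hadoop.io.Text;

    public class MapWritableSketch {
      public static void main(String[] args) {
        // A map<int,string> inspector over writable values.
        StandardMapObjectInspector mapOI =
            ObjectInspectorFactory.getStandardMapObjectInspector(
                PrimitiveObjectInspectorFactory.writableIntObjectInspector,
                PrimitiveObjectInspectorFactory.writableStringObjectInspector);
        Object map = mapOI.create();
        mapOI.put(map, new IntWritable(1), new Text("one"));
        System.out.println(mapOI.getMapSize(map)); // 1
      }
    }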
-
-  public static HiveChar getRandHiveChar(Random r, CharTypeInfo charTypeInfo, String alphabet) {
-    int maxLength = 1 + r.nextInt(charTypeInfo.getLength());
-    String randomString = RandomTypeUtil.getRandString(r, alphabet, 100);
-    HiveChar hiveChar = new HiveChar(randomString, maxLength);
-    return hiveChar;
+  public Object randomPrimitiveObject(int column) {
+    return randomPrimitiveObject(r, primitiveTypeInfos[column]);
+  }
+
+  public static Object randomPrimitiveObject(Random r, PrimitiveTypeInfo primitiveTypeInfo) {
+    switch (primitiveTypeInfo.getPrimitiveCategory()) {
+    case BOOLEAN:
+      return Boolean.valueOf(r.nextBoolean());
+    case BYTE:
+      return Byte.valueOf((byte) r.nextInt());
+    case SHORT:
+      return Short.valueOf((short) r.nextInt());
+    case INT:
+      return Integer.valueOf(r.nextInt());
+    case LONG:
+      return Long.valueOf(r.nextLong());
+    case DATE:
+      return RandomTypeUtil.getRandDate(r);
+    case FLOAT:
+      return Float.valueOf(r.nextFloat() * 10 - 5);
+    case DOUBLE:
+      return Double.valueOf(r.nextDouble() * 10 - 5);
+    case STRING:
+      return RandomTypeUtil.getRandString(r);
+    case CHAR:
+      return getRandHiveChar(r, (CharTypeInfo) primitiveTypeInfo);
+    case VARCHAR:
+      return getRandHiveVarchar(r, (VarcharTypeInfo) primitiveTypeInfo);
+    case BINARY:
+      return getRandBinary(r, 1 + r.nextInt(100));
+    case TIMESTAMP:
+      return RandomTypeUtil.getRandTimestamp(r);
+    case INTERVAL_YEAR_MONTH:
+      return getRandIntervalYearMonth(r);
+    case INTERVAL_DAY_TIME:
+      return getRandIntervalDayTime(r);
+    case DECIMAL:
+      return getRandHiveDecimal(r, (DecimalTypeInfo) primitiveTypeInfo);
+    default:
+      throw new Error("Unknown primitive category " + primitiveTypeInfo.getPrimitiveCategory());
+    }
   }
 
   public static HiveChar getRandHiveChar(Random r, CharTypeInfo charTypeInfo) {
-    return getRandHiveChar(r, charTypeInfo, "abcdefghijklmnopqrstuvwxyz");
+    int maxLength = 1 + r.nextInt(charTypeInfo.getLength());
+    String randomString = RandomTypeUtil.getRandString(r, "abcdefghijklmnopqrstuvwxyz", 100);
+    HiveChar hiveChar = new HiveChar(randomString, maxLength);
+    return hiveChar;
   }
 
   public static HiveVarchar getRandHiveVarchar(Random r, VarcharTypeInfo varcharTypeInfo, String alphabet) {
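Reviewer note: before the new verifier file below, here is how the extended row source is intended to be driven. This mirrors the test changes later in this patch; the seed is arbitrary:

    import java.util.Random;
    import org.apache.hadoop.hive.ql.exec.vector.VectorRandomRowSource;
    import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;

    public class RandomRowsSketch {
      public static void main(String[] args) {
        Random random = new Random(12345);  // hypothetical seed
        VectorRandomRowSource source = new VectorRandomRowSource();
        // ALL = primitives plus array/map/struct/uniontype; 4 = max nesting
        // depth; false = do not inject top-level NULLs.
        source.init(random, VectorRandomRowSource.SupportedTypes.ALL, 4, false);
        Object[][] rows = source.randomRows(1000);
        TypeInfo[] typeInfos = source.typeInfos();
        System.out.println(rows.length + " rows of " + typeInfos.length + " columns");
      }
    }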
diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorVerifyFast.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorVerifyFast.java
new file mode 100644
index 0000000000..dd60c9805c
--- /dev/null
+++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorVerifyFast.java
@@ -0,0 +1,703 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec.vector;
+
+import junit.framework.TestCase;
+import org.apache.hadoop.hive.common.type.HiveChar;
+import org.apache.hadoop.hive.common.type.HiveDecimal;
+import org.apache.hadoop.hive.common.type.HiveIntervalDayTime;
+import org.apache.hadoop.hive.common.type.HiveIntervalYearMonth;
+import org.apache.hadoop.hive.common.type.HiveVarchar;
+import org.apache.hadoop.hive.serde2.fast.DeserializeRead;
+import org.apache.hadoop.hive.serde2.fast.SerializeWrite;
+import org.apache.hadoop.hive.serde2.io.ByteWritable;
+import org.apache.hadoop.hive.serde2.io.DateWritable;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.io.HiveCharWritable;
+import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;
+import org.apache.hadoop.hive.serde2.io.HiveIntervalDayTimeWritable;
+import org.apache.hadoop.hive.serde2.io.HiveIntervalYearMonthWritable;
+import org.apache.hadoop.hive.serde2.io.HiveVarcharWritable;
+import org.apache.hadoop.hive.serde2.io.ShortWritable;
+import org.apache.hadoop.hive.serde2.io.TimestampWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.StandardUnionObjectInspector;
+import org.apache.hadoop.hive.serde2.typeinfo.CharTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
+import org.apache.hadoop.hive.serde2.typeinfo.UnionTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo;
+import org.apache.hadoop.io.BooleanWritable;
+import org.apache.hadoop.io.BytesWritable;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.sql.Date;
+import java.sql.Timestamp;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+public class VectorVerifyFast {
+
+  public static void verifyDeserializeRead(
+      DeserializeRead deserializeRead, TypeInfo typeInfo, Object object) throws IOException {
+
+    boolean isNull;
+
+    isNull = !deserializeRead.readNextField();
+    doVerifyDeserializeRead(deserializeRead, typeInfo, object, isNull);
+  }
+
+  public static void doVerifyDeserializeRead(
+      DeserializeRead deserializeRead, TypeInfo typeInfo, Object object, boolean isNull) throws IOException {
+    if (isNull) {
+      if (object != null) {
+        TestCase.fail("Field reports null but object is not null (class " + object.getClass().getName() + ", " + object.toString() + ")");
+      }
+      return;
+    } else if (object == null) {
+      TestCase.fail("Field reports not-null but object is null");
+    }
+    switch (typeInfo.getCategory()) {
+    case PRIMITIVE:
+      {
+        PrimitiveTypeInfo primitiveTypeInfo = (PrimitiveTypeInfo) typeInfo;
+        switch (primitiveTypeInfo.getPrimitiveCategory()) {
+        case BOOLEAN:
+          {
+            boolean value = deserializeRead.currentBoolean;
+            if (!(object instanceof BooleanWritable)) {
+              TestCase.fail("Boolean expected writable not Boolean");
+            }
+            boolean expected = ((BooleanWritable) object).get();
+            if (value != expected) {
+              TestCase.fail("Boolean field
mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case BYTE: + { + byte value = deserializeRead.currentByte; + if (!(object instanceof ByteWritable)) { + TestCase.fail("Byte expected writable not Byte"); + } + byte expected = ((ByteWritable) object).get(); + if (value != expected) { + TestCase.fail("Byte field mismatch (expected " + (int) expected + " found " + (int) value + ")"); + } + } + break; + case SHORT: + { + short value = deserializeRead.currentShort; + if (!(object instanceof ShortWritable)) { + TestCase.fail("Short expected writable not Short"); + } + short expected = ((ShortWritable) object).get(); + if (value != expected) { + TestCase.fail("Short field mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case INT: + { + int value = deserializeRead.currentInt; + if (!(object instanceof IntWritable)) { + TestCase.fail("Integer expected writable not Integer"); + } + int expected = ((IntWritable) object).get(); + if (value != expected) { + TestCase.fail("Int field mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case LONG: + { + long value = deserializeRead.currentLong; + if (!(object instanceof LongWritable)) { + TestCase.fail("Long expected writable not Long"); + } + Long expected = ((LongWritable) object).get(); + if (value != expected) { + TestCase.fail("Long field mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case FLOAT: + { + float value = deserializeRead.currentFloat; + if (!(object instanceof FloatWritable)) { + TestCase.fail("Float expected writable not Float"); + } + float expected = ((FloatWritable) object).get(); + if (value != expected) { + TestCase.fail("Float field mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case DOUBLE: + { + double value = deserializeRead.currentDouble; + if (!(object instanceof DoubleWritable)) { + TestCase.fail("Double expected writable not Double"); + } + double expected = ((DoubleWritable) object).get(); + if (value != expected) { + TestCase.fail("Double field mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case STRING: + { + byte[] stringBytes = Arrays.copyOfRange( + deserializeRead.currentBytes, + deserializeRead.currentBytesStart, + deserializeRead.currentBytesStart + deserializeRead.currentBytesLength); + Text text = new Text(stringBytes); + String string = text.toString(); + String expected = ((Text) object).toString(); + if (!string.equals(expected)) { + TestCase.fail("String field mismatch (expected '" + expected + "' found '" + string + "')"); + } + } + break; + case CHAR: + { + byte[] stringBytes = Arrays.copyOfRange( + deserializeRead.currentBytes, + deserializeRead.currentBytesStart, + deserializeRead.currentBytesStart + deserializeRead.currentBytesLength); + Text text = new Text(stringBytes); + String string = text.toString(); + + HiveChar hiveChar = new HiveChar(string, ((CharTypeInfo) primitiveTypeInfo).getLength()); + + HiveChar expected = ((HiveCharWritable) object).getHiveChar(); + if (!hiveChar.equals(expected)) { + TestCase.fail("Char field mismatch (expected '" + expected + "' found '" + hiveChar + "')"); + } + } + break; + case VARCHAR: + { + byte[] stringBytes = Arrays.copyOfRange( + deserializeRead.currentBytes, + deserializeRead.currentBytesStart, + deserializeRead.currentBytesStart + deserializeRead.currentBytesLength); + Text text = new Text(stringBytes); + String string = text.toString(); + + HiveVarchar hiveVarchar = new 
HiveVarchar(string, ((VarcharTypeInfo) primitiveTypeInfo).getLength()); + + HiveVarchar expected = ((HiveVarcharWritable) object).getHiveVarchar(); + if (!hiveVarchar.equals(expected)) { + TestCase.fail("Varchar field mismatch (expected '" + expected + "' found '" + hiveVarchar + "')"); + } + } + break; + case DECIMAL: + { + HiveDecimal value = deserializeRead.currentHiveDecimalWritable.getHiveDecimal(); + if (value == null) { + TestCase.fail("Decimal field evaluated to NULL"); + } + HiveDecimal expected = ((HiveDecimalWritable) object).getHiveDecimal(); + if (!value.equals(expected)) { + DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) primitiveTypeInfo; + int precision = decimalTypeInfo.getPrecision(); + int scale = decimalTypeInfo.getScale(); + TestCase.fail("Decimal field mismatch (expected " + expected.toString() + " found " + value.toString() + ") precision " + precision + ", scale " + scale); + } + } + break; + case DATE: + { + Date value = deserializeRead.currentDateWritable.get(); + Date expected = ((DateWritable) object).get(); + if (!value.equals(expected)) { + TestCase.fail("Date field mismatch (expected " + expected.toString() + " found " + value.toString() + ")"); + } + } + break; + case TIMESTAMP: + { + Timestamp value = deserializeRead.currentTimestampWritable.getTimestamp(); + Timestamp expected = ((TimestampWritable) object).getTimestamp(); + if (!value.equals(expected)) { + TestCase.fail("Timestamp field mismatch (expected " + expected.toString() + " found " + value.toString() + ")"); + } + } + break; + case INTERVAL_YEAR_MONTH: + { + HiveIntervalYearMonth value = deserializeRead.currentHiveIntervalYearMonthWritable.getHiveIntervalYearMonth(); + HiveIntervalYearMonth expected = ((HiveIntervalYearMonthWritable) object).getHiveIntervalYearMonth(); + if (!value.equals(expected)) { + TestCase.fail("HiveIntervalYearMonth field mismatch (expected " + expected.toString() + " found " + value.toString() + ")"); + } + } + break; + case INTERVAL_DAY_TIME: + { + HiveIntervalDayTime value = deserializeRead.currentHiveIntervalDayTimeWritable.getHiveIntervalDayTime(); + HiveIntervalDayTime expected = ((HiveIntervalDayTimeWritable) object).getHiveIntervalDayTime(); + if (!value.equals(expected)) { + TestCase.fail("HiveIntervalDayTime field mismatch (expected " + expected.toString() + " found " + value.toString() + ")"); + } + } + break; + case BINARY: + { + byte[] byteArray = Arrays.copyOfRange( + deserializeRead.currentBytes, + deserializeRead.currentBytesStart, + deserializeRead.currentBytesStart + deserializeRead.currentBytesLength); + BytesWritable bytesWritable = (BytesWritable) object; + byte[] expected = Arrays.copyOfRange(bytesWritable.getBytes(), 0, bytesWritable.getLength()); + if (byteArray.length != expected.length){ + TestCase.fail("Byte Array field mismatch (expected " + Arrays.toString(expected) + + " found " + Arrays.toString(byteArray) + ")"); + } + for (int b = 0; b < byteArray.length; b++) { + if (byteArray[b] != expected[b]) { + TestCase.fail("Byte Array field mismatch (expected " + Arrays.toString(expected) + + " found " + Arrays.toString(byteArray) + ")"); + } + } + } + break; + default: + throw new Error("Unknown primitive category " + primitiveTypeInfo.getPrimitiveCategory()); + } + } + break; + case LIST: + case MAP: + case STRUCT: + case UNION: + throw new Error("Complex types need to be handled separately"); + default: + throw new Error("Unknown category " + typeInfo.getCategory()); + } + } + + public static void serializeWrite(SerializeWrite 
serializeWrite, + TypeInfo typeInfo, Object object) throws IOException { + if (object == null) { + serializeWrite.writeNull(); + return; + } + switch (typeInfo.getCategory()) { + case PRIMITIVE: + { + PrimitiveTypeInfo primitiveTypeInfo = (PrimitiveTypeInfo) typeInfo; + switch (primitiveTypeInfo.getPrimitiveCategory()) { + case BOOLEAN: + { + boolean value = ((BooleanWritable) object).get(); + serializeWrite.writeBoolean(value); + } + break; + case BYTE: + { + byte value = ((ByteWritable) object).get(); + serializeWrite.writeByte(value); + } + break; + case SHORT: + { + short value = ((ShortWritable) object).get(); + serializeWrite.writeShort(value); + } + break; + case INT: + { + int value = ((IntWritable) object).get(); + serializeWrite.writeInt(value); + } + break; + case LONG: + { + long value = ((LongWritable) object).get(); + serializeWrite.writeLong(value); + } + break; + case FLOAT: + { + float value = ((FloatWritable) object).get(); + serializeWrite.writeFloat(value); + } + break; + case DOUBLE: + { + double value = ((DoubleWritable) object).get(); + serializeWrite.writeDouble(value); + } + break; + case STRING: + { + Text value = (Text) object; + byte[] stringBytes = value.getBytes(); + int stringLength = stringBytes.length; + serializeWrite.writeString(stringBytes, 0, stringLength); + } + break; + case CHAR: + { + HiveChar value = ((HiveCharWritable) object).getHiveChar(); + serializeWrite.writeHiveChar(value); + } + break; + case VARCHAR: + { + HiveVarchar value = ((HiveVarcharWritable) object).getHiveVarchar(); + serializeWrite.writeHiveVarchar(value); + } + break; + case DECIMAL: + { + HiveDecimal value = ((HiveDecimalWritable) object).getHiveDecimal(); + DecimalTypeInfo decTypeInfo = (DecimalTypeInfo)primitiveTypeInfo; + serializeWrite.writeHiveDecimal(value, decTypeInfo.scale()); + } + break; + case DATE: + { + Date value = ((DateWritable) object).get(); + serializeWrite.writeDate(value); + } + break; + case TIMESTAMP: + { + Timestamp value = ((TimestampWritable) object).getTimestamp(); + serializeWrite.writeTimestamp(value); + } + break; + case INTERVAL_YEAR_MONTH: + { + HiveIntervalYearMonth value = ((HiveIntervalYearMonthWritable) object).getHiveIntervalYearMonth(); + serializeWrite.writeHiveIntervalYearMonth(value); + } + break; + case INTERVAL_DAY_TIME: + { + HiveIntervalDayTime value = ((HiveIntervalDayTimeWritable) object).getHiveIntervalDayTime(); + serializeWrite.writeHiveIntervalDayTime(value); + } + break; + case BINARY: + { + BytesWritable byteWritable = (BytesWritable) object; + byte[] binaryBytes = byteWritable.getBytes(); + int length = byteWritable.getLength(); + serializeWrite.writeBinary(binaryBytes, 0, length); + } + break; + default: + throw new Error("Unknown primitive category " + primitiveTypeInfo.getPrimitiveCategory().name()); + } + } + break; + case LIST: + { + ListTypeInfo listTypeInfo = (ListTypeInfo) typeInfo; + TypeInfo elementTypeInfo = listTypeInfo.getListElementTypeInfo(); + ArrayList elements = (ArrayList) object; + serializeWrite.beginList(elements); + boolean isFirst = true; + for (Object elementObject : elements) { + if (isFirst) { + isFirst = false; + } else { + serializeWrite.separateList(); + } + if (elementObject == null) { + serializeWrite.writeNull(); + } else { + serializeWrite(serializeWrite, elementTypeInfo, elementObject); + } + } + serializeWrite.finishList(); + } + break; + case MAP: + { + MapTypeInfo mapTypeInfo = (MapTypeInfo) typeInfo; + TypeInfo keyTypeInfo = mapTypeInfo.getMapKeyTypeInfo(); + TypeInfo valueTypeInfo = 
mapTypeInfo.getMapValueTypeInfo(); + HashMap hashMap = (HashMap) object; + serializeWrite.beginMap(hashMap); + boolean isFirst = true; + for (Map.Entry entry : hashMap.entrySet()) { + if (isFirst) { + isFirst = false; + } else { + serializeWrite.separateKeyValuePair(); + } + if (entry.getKey() == null) { + serializeWrite.writeNull(); + } else { + serializeWrite(serializeWrite, keyTypeInfo, entry.getKey()); + } + serializeWrite.separateKey(); + if (entry.getValue() == null) { + serializeWrite.writeNull(); + } else { + serializeWrite(serializeWrite, valueTypeInfo, entry.getValue()); + } + } + serializeWrite.finishMap(); + } + break; + case STRUCT: + { + StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo; + ArrayList fieldTypeInfos = structTypeInfo.getAllStructFieldTypeInfos(); + ArrayList fieldValues = (ArrayList) object; + final int size = fieldValues.size(); + serializeWrite.beginStruct(fieldValues); + boolean isFirst = true; + for (int i = 0; i < size; i++) { + if (isFirst) { + isFirst = false; + } else { + serializeWrite.separateStruct(); + } + serializeWrite(serializeWrite, fieldTypeInfos.get(i), fieldValues.get(i)); + } + serializeWrite.finishStruct(); + } + break; + case UNION: + { + UnionTypeInfo unionTypeInfo = (UnionTypeInfo) typeInfo; + List fieldTypeInfos = unionTypeInfo.getAllUnionObjectTypeInfos(); + final int size = fieldTypeInfos.size(); + StandardUnionObjectInspector.StandardUnion standardUnion = (StandardUnionObjectInspector.StandardUnion) object; + byte tag = standardUnion.getTag(); + serializeWrite.beginUnion(tag); + serializeWrite(serializeWrite, fieldTypeInfos.get(tag), standardUnion.getObject()); + serializeWrite.finishUnion(); + } + break; + default: + throw new Error("Unknown category " + typeInfo.getCategory().name()); + } + } + + public Object readComplexPrimitiveField(DeserializeRead deserializeRead, + PrimitiveTypeInfo primitiveTypeInfo) throws IOException { + boolean isNull = !deserializeRead.readComplexField(); + if (isNull) { + return null; + } else { + return doReadComplexPrimitiveField(deserializeRead, primitiveTypeInfo); + } + } + + private static Object doReadComplexPrimitiveField(DeserializeRead deserializeRead, + PrimitiveTypeInfo primitiveTypeInfo) throws IOException { + switch (primitiveTypeInfo.getPrimitiveCategory()) { + case BOOLEAN: + return new BooleanWritable(deserializeRead.currentBoolean); + case BYTE: + return new ByteWritable(deserializeRead.currentByte); + case SHORT: + return new ShortWritable(deserializeRead.currentShort); + case INT: + return new IntWritable(deserializeRead.currentInt); + case LONG: + return new LongWritable(deserializeRead.currentLong); + case FLOAT: + return new FloatWritable(deserializeRead.currentFloat); + case DOUBLE: + return new DoubleWritable(deserializeRead.currentDouble); + case STRING: + return new Text(new String( + deserializeRead.currentBytes, + deserializeRead.currentBytesStart, + deserializeRead.currentBytesLength, + StandardCharsets.UTF_8)); + case CHAR: + return new HiveCharWritable(new HiveChar( + new String( + deserializeRead.currentBytes, + deserializeRead.currentBytesStart, + deserializeRead.currentBytesLength, + StandardCharsets.UTF_8), + ((CharTypeInfo) primitiveTypeInfo).getLength())); + case VARCHAR: + if (deserializeRead.currentBytes == null) { + throw new RuntimeException(); + } + return new HiveVarcharWritable(new HiveVarchar( + new String( + deserializeRead.currentBytes, + deserializeRead.currentBytesStart, + deserializeRead.currentBytesLength, + StandardCharsets.UTF_8), + 
((VarcharTypeInfo) primitiveTypeInfo).getLength()));
+    case DECIMAL:
+      return new HiveDecimalWritable(deserializeRead.currentHiveDecimalWritable);
+    case DATE:
+      return new DateWritable(deserializeRead.currentDateWritable);
+    case TIMESTAMP:
+      return new TimestampWritable(deserializeRead.currentTimestampWritable);
+    case INTERVAL_YEAR_MONTH:
+      return new HiveIntervalYearMonthWritable(deserializeRead.currentHiveIntervalYearMonthWritable);
+    case INTERVAL_DAY_TIME:
+      return new HiveIntervalDayTimeWritable(deserializeRead.currentHiveIntervalDayTimeWritable);
+    case BINARY:
+      return new BytesWritable(
+          Arrays.copyOfRange(
+              deserializeRead.currentBytes,
+              deserializeRead.currentBytesStart,
+              deserializeRead.currentBytesLength + deserializeRead.currentBytesStart));
+    default:
+      throw new Error("Unknown primitive category " + primitiveTypeInfo.getPrimitiveCategory());
+    }
+  }
+
+  public static Object deserializeReadComplexType(DeserializeRead deserializeRead,
+      TypeInfo typeInfo) throws IOException {
+
+    boolean isNull = !deserializeRead.readNextField();
+    if (isNull) {
+      return null;
+    }
+    return getComplexField(deserializeRead, typeInfo);
+  }
+
+  private static Object getComplexField(DeserializeRead deserializeRead,
+      TypeInfo typeInfo) throws IOException {
+    switch (typeInfo.getCategory()) {
+    case PRIMITIVE:
+      return doReadComplexPrimitiveField(deserializeRead, (PrimitiveTypeInfo) typeInfo);
+    case LIST:
+      {
+        ListTypeInfo listTypeInfo = (ListTypeInfo) typeInfo;
+        TypeInfo elementTypeInfo = listTypeInfo.getListElementTypeInfo();
+        ArrayList<Object> list = new ArrayList<Object>();
+        Object eleObj;
+        boolean isNull;
+        while (deserializeRead.isNextComplexMultiValue()) {
+          isNull = !deserializeRead.readComplexField();
+          if (isNull) {
+            eleObj = null;
+          } else {
+            eleObj = getComplexField(deserializeRead, elementTypeInfo);
+          }
+          list.add(eleObj);
+        }
+        return list;
+      }
+    case MAP:
+      {
+        MapTypeInfo mapTypeInfo = (MapTypeInfo) typeInfo;
+        TypeInfo keyTypeInfo = mapTypeInfo.getMapKeyTypeInfo();
+        TypeInfo valueTypeInfo = mapTypeInfo.getMapValueTypeInfo();
+        HashMap<Object, Object> hashMap = new HashMap<Object, Object>();
+        Object keyObj;
+        Object valueObj;
+        boolean isNull;
+        while (deserializeRead.isNextComplexMultiValue()) {
+          isNull = !deserializeRead.readComplexField();
+          if (isNull) {
+            keyObj = null;
+          } else {
+            keyObj = getComplexField(deserializeRead, keyTypeInfo);
+          }
+          isNull = !deserializeRead.readComplexField();
+          if (isNull) {
+            valueObj = null;
+          } else {
+            valueObj = getComplexField(deserializeRead, valueTypeInfo);
+          }
+          hashMap.put(keyObj, valueObj);
+        }
+        return hashMap;
+      }
+    case STRUCT:
+      {
+        StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo;
+        ArrayList<TypeInfo> fieldTypeInfos = structTypeInfo.getAllStructFieldTypeInfos();
+        final int size = fieldTypeInfos.size();
+        ArrayList<Object> fieldValues = new ArrayList<Object>();
+        Object fieldObj;
+        boolean isNull;
+        for (int i = 0; i < size; i++) {
+          isNull = !deserializeRead.readComplexField();
+          if (isNull) {
+            fieldObj = null;
+          } else {
+            fieldObj = getComplexField(deserializeRead, fieldTypeInfos.get(i));
+          }
+          fieldValues.add(fieldObj);
+        }
+        deserializeRead.finishComplexVariableFieldsType();
+        return fieldValues;
+      }
+    case UNION:
+      {
+        UnionTypeInfo unionTypeInfo = (UnionTypeInfo) typeInfo;
+        List<TypeInfo> unionTypeInfos = unionTypeInfo.getAllUnionObjectTypeInfos();
+        Object tagObj;
+        int tag;
+        Object unionObj;
+        boolean isNull = !deserializeRead.readComplexField();
+        if (isNull) {
+          unionObj = null;
+        } else {
+          // Get the tag value.
+          tagObj = getComplexField(deserializeRead, TypeInfoFactory.intTypeInfo);
+          tag = ((IntWritable) tagObj).get();
+
+          isNull = !deserializeRead.readComplexField();
+          if (isNull) {
+            unionObj = null;
+          } else {
+            // Get the union value.
+            unionObj = new StandardUnionObjectInspector.StandardUnion((byte) tag, getComplexField(deserializeRead, unionTypeInfos.get(tag)));
+          }
+        }
+
+        deserializeRead.finishComplexVariableFieldsType();
+        return unionObj;
+      }
+    default:
+      throw new Error("Unexpected category " + typeInfo.getCategory());
+    }
+  }
+}
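Reviewer note: VectorVerifyFast pairs with the lazy-binary fast serializer/deserializer. A minimal round trip through both halves (a sketch under this patch's API; the single int column and value are chosen for illustration):

    import java.util.Arrays;
    import org.apache.hadoop.hive.ql.exec.vector.VectorVerifyFast;
    import org.apache.hadoop.hive.serde2.ByteStream.Output;
    import org.apache.hadoop.hive.serde2.lazybinary.fast.LazyBinaryDeserializeRead;
    import org.apache.hadoop.hive.serde2.lazybinary.fast.LazyBinarySerializeWrite;
    import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
    import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
    import org.apache.hadoop.io.IntWritable;

    public class RoundTripSketch {
      public static void main(String[] args) throws Exception {
        TypeInfo typeInfo = TypeInfoUtils.getTypeInfoFromTypeString("int");
        // Serialize one field.
        LazyBinarySerializeWrite serializeWrite = new LazyBinarySerializeWrite(1);
        Output output = new Output();
        serializeWrite.set(output);
        VectorVerifyFast.serializeWrite(serializeWrite, typeInfo, new IntWritable(42));
        byte[] bytes = Arrays.copyOf(output.getData(), output.getLength());
        // Deserialize it back and verify against the expected writable.
        LazyBinaryDeserializeRead deserializeRead =
            new LazyBinaryDeserializeRead(new TypeInfo[] {typeInfo}, false);
        deserializeRead.set(bytes, 0, bytes.length);
        VectorVerifyFast.verifyDeserializeRead(deserializeRead, typeInfo, new IntWritable(42));
      }
    }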
diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/mapjoin/fast/CheckFastRowHashMap.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/mapjoin/fast/CheckFastRowHashMap.java
index 72fceb94b7..77a95c367b 100644
--- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/mapjoin/fast/CheckFastRowHashMap.java
+++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/mapjoin/fast/CheckFastRowHashMap.java
@@ -36,7 +36,10 @@
 import org.apache.hadoop.hive.serde2.WriteBuffers;
 import org.apache.hadoop.hive.serde2.io.ByteWritable;
 import org.apache.hadoop.hive.serde2.io.ShortWritable;
+import org.apache.hadoop.hive.serde2.lazy.VerifyLazy;
 import org.apache.hadoop.hive.serde2.lazybinary.fast.LazyBinaryDeserializeRead;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.UnionObject;
 import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
 import org.apache.hadoop.io.BooleanWritable;
@@ -76,8 +79,7 @@ public static void verifyHashMapRows(List rows, int[] actualToValueMap
     lazyBinaryDeserializeRead.set(bytes, offset, length);
 
     for (int index = 0; index < columnCount; index++) {
-      Writable writable = (Writable) row[index];
-      VerifyFastRow.verifyDeserializeRead(lazyBinaryDeserializeRead, (PrimitiveTypeInfo) typeInfos[index], writable);
+      verifyRead(lazyBinaryDeserializeRead, typeInfos[index], row[index]);
     }
     TestCase.assertTrue(lazyBinaryDeserializeRead.isEndOfInputReached());
@@ -132,8 +134,7 @@ public static void verifyHashMapRowsMore(List rows, int[] actualToValu
     int index = 0;
     try {
       for (index = 0; index < columnCount; index++) {
-        Writable writable = (Writable) row[index];
-        VerifyFastRow.verifyDeserializeRead(lazyBinaryDeserializeRead, (PrimitiveTypeInfo) typeInfos[index], writable);
+        verifyRead(lazyBinaryDeserializeRead, typeInfos[index], row[index]);
       }
     } catch (Exception e) {
       thrown = true;
@@ -175,6 +176,39 @@ public static void verifyHashMapRowsMore(List rows, int[] actualToValu
     }
   }
 
+  private static void verifyRead(LazyBinaryDeserializeRead lazyBinaryDeserializeRead,
+      TypeInfo typeInfo, Object expectedObject) throws IOException {
+    if (typeInfo.getCategory() == ObjectInspector.Category.PRIMITIVE) {
+      VerifyFastRow.verifyDeserializeRead(lazyBinaryDeserializeRead, typeInfo, expectedObject);
+    } else {
+      Object complexFieldObj = VerifyFastRow.deserializeReadComplexType(lazyBinaryDeserializeRead, typeInfo);
+      if (expectedObject == null) {
+        if (complexFieldObj != null) {
+          TestCase.fail("Field reports not-null but expected object is null (class " + complexFieldObj.getClass().getName() +
+              ", " + complexFieldObj.toString() + ")");
+        }
+      } else {
+        if (complexFieldObj == null) {
+          // It's hard to distinguish a union with null from a null union.
+          if (expectedObject instanceof UnionObject) {
+            UnionObject expectedUnion = (UnionObject) expectedObject;
+            if (expectedUnion.getObject() == null) {
+              return;
+            }
+          }
+          TestCase.fail("Field reports null but expected object is not null (class " + expectedObject.getClass().getName() +
+              ", " + expectedObject.toString() + ")");
+        }
+      }
+      if (!VerifyLazy.lazyCompare(typeInfo, complexFieldObj, expectedObject)) {
+        TestCase.fail("Comparison failed for typeInfo " + typeInfo.toString());
+      }
+    }
+  }
+
 /*
  * Element for Key: row and byte[] x Hash Table: HashMap
  */
@@ -283,7 +317,7 @@ public void add(byte[] key, Object[] keyRow, byte[] value, Object[] valueRow) {
 
   public void verify(VectorMapJoinFastHashTable map,
       HashTableKeyType hashTableKeyType,
-      PrimitiveTypeInfo[] valuePrimitiveTypeInfos, boolean doClipping,
+      TypeInfo[] valueTypeInfos, boolean doClipping,
       boolean useExactBytes, Random random) throws IOException {
     int mapSize = map.size();
     if (mapSize != count) {
@@ -368,10 +402,10 @@ public void verify(VectorMapJoinFastHashTable map,
         List rows = element.getValueRows();
         if (!doClipping && !useExactBytes) {
-          verifyHashMapRows(rows, actualToValueMap, hashMapResult, valuePrimitiveTypeInfos);
+          verifyHashMapRows(rows, actualToValueMap, hashMapResult, valueTypeInfos);
         } else {
           int clipIndex = random.nextInt(rows.size());
-          verifyHashMapRowsMore(rows, actualToValueMap, hashMapResult, valuePrimitiveTypeInfos,
+          verifyHashMapRowsMore(rows, actualToValueMap, hashMapResult, valueTypeInfos,
               clipIndex, useExactBytes);
         }
       }
diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/mapjoin/fast/TestVectorMapJoinFastRowHashMap.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/mapjoin/fast/TestVectorMapJoinFastRowHashMap.java
index ebb243e28c..82d9e29fce 100644
--- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/mapjoin/fast/TestVectorMapJoinFastRowHashMap.java
+++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/mapjoin/fast/TestVectorMapJoinFastRowHashMap.java
@@ -34,13 +34,12 @@
 import org.apache.hadoop.hive.serde2.fast.SerializeWrite;
 import org.apache.hadoop.hive.serde2.lazybinary.fast.LazyBinarySerializeWrite;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory;
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
 import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
 import org.apache.hadoop.io.BytesWritable;
-import org.apache.hadoop.io.Writable;
 import org.junit.Test;
 
 /*
@@ -83,8 +82,8 @@ private void addAndVerifyRows(VectorRandomRowSource valueSource, Object[][] rows
         keyColumnNullMarker, keyColumnNotNullMarker);
 
-    PrimitiveTypeInfo[] valuePrimitiveTypeInfos = valueSource.primitiveTypeInfos();
-    final int columnCount = valuePrimitiveTypeInfos.length;
+    TypeInfo[] valueTypeInfos = valueSource.typeInfos();
+    final int columnCount = valueTypeInfos.length;
 
     SerializeWrite valueSerializeWrite = new LazyBinarySerializeWrite(columnCount);
@@ -97,10 +96,7 @@ private void addAndVerifyRows(VectorRandomRowSource valueSource, Object[][] rows
       ((LazyBinarySerializeWrite) valueSerializeWrite).set(valueOutput);
 
       for (int index = 0; index < columnCount; index++) {
-
-        Writable writable = (Writable) valueRow[index];
-
-
VerifyFastRow.serializeWrite(valueSerializeWrite, valuePrimitiveTypeInfos[index], writable); + VerifyFastRow.serializeWrite(valueSerializeWrite, valueTypeInfos[index], valueRow[index]); } byte[] value = Arrays.copyOf(valueOutput.getData(), valueOutput.getLength()); @@ -109,17 +105,13 @@ private void addAndVerifyRows(VectorRandomRowSource valueSource, Object[][] rows byte[] key; if (random.nextBoolean() || verifyTable.getCount() == 0) { Object[] keyRow = - VectorRandomRowSource.randomRow(keyCount, random, keyPrimitiveObjectInspectorList, - keyPrimitiveCategories, keyPrimitiveTypeInfos); + VectorRandomRowSource.randomWritablePrimitiveRow(keyCount, random, keyPrimitiveTypeInfos); Output keyOutput = new Output(); keySerializeWrite.set(keyOutput); for (int index = 0; index < keyCount; index++) { - - Writable writable = (Writable) keyRow[index]; - - VerifyFastRow.serializeWrite(keySerializeWrite, keyPrimitiveTypeInfos[index], writable); + VerifyFastRow.serializeWrite(keySerializeWrite, keyPrimitiveTypeInfos[index], keyRow[index]); } key = Arrays.copyOf(keyOutput.getData(), keyOutput.getLength()); @@ -135,7 +127,7 @@ private void addAndVerifyRows(VectorRandomRowSource valueSource, Object[][] rows map.putRow(keyWritable, valueWritable); // verifyTable.verify(map); } - verifyTable.verify(map, hashTableKeyType, valuePrimitiveTypeInfos, + verifyTable.verify(map, hashTableKeyType, valueTypeInfos, doClipping, useExactBytes, random); } @@ -152,9 +144,10 @@ public void testBigIntRows() throws Exception { VerifyFastRowHashMap verifyTable = new VerifyFastRowHashMap(); VectorRandomRowSource valueSource = new VectorRandomRowSource(); - valueSource.init(random); + + valueSource.init(random, VectorRandomRowSource.SupportedTypes.ALL, 4, false); - int rowCount = 10000; + int rowCount = 1000; Object[][] rows = valueSource.randomRows(rowCount); addAndVerifyRows(valueSource, rows, @@ -176,9 +169,10 @@ public void testIntRows() throws Exception { VerifyFastRowHashMap verifyTable = new VerifyFastRowHashMap(); VectorRandomRowSource valueSource = new VectorRandomRowSource(); - valueSource.init(random); - int rowCount = 10000; + valueSource.init(random, VectorRandomRowSource.SupportedTypes.ALL, 4, false); + + int rowCount = 1000; Object[][] rows = valueSource.randomRows(rowCount); addAndVerifyRows(valueSource, rows, @@ -200,9 +194,10 @@ public void testStringRows() throws Exception { VerifyFastRowHashMap verifyTable = new VerifyFastRowHashMap(); VectorRandomRowSource valueSource = new VectorRandomRowSource(); - valueSource.init(random); + + valueSource.init(random, VectorRandomRowSource.SupportedTypes.ALL, 4, false); - int rowCount = 10000; + int rowCount = 1000; Object[][] rows = valueSource.randomRows(rowCount); addAndVerifyRows(valueSource, rows, @@ -224,9 +219,10 @@ public void testMultiKeyRows1() throws Exception { VerifyFastRowHashMap verifyTable = new VerifyFastRowHashMap(); VectorRandomRowSource valueSource = new VectorRandomRowSource(); - valueSource.init(random); - int rowCount = 10000; + valueSource.init(random, VectorRandomRowSource.SupportedTypes.ALL, 4, false); + + int rowCount = 1000; Object[][] rows = valueSource.randomRows(rowCount); addAndVerifyRows(valueSource, rows, @@ -248,9 +244,11 @@ public void testMultiKeyRows2() throws Exception { VerifyFastRowHashMap verifyTable = new VerifyFastRowHashMap(); VectorRandomRowSource valueSource = new VectorRandomRowSource(); - valueSource.init(random); - int rowCount = 10000; + + valueSource.init(random, VectorRandomRowSource.SupportedTypes.ALL, 4, false); + + 
int rowCount = 1000; Object[][] rows = valueSource.randomRows(rowCount); addAndVerifyRows(valueSource, rows, @@ -272,9 +270,10 @@ public void testMultiKeyRows3() throws Exception { VerifyFastRowHashMap verifyTable = new VerifyFastRowHashMap(); VectorRandomRowSource valueSource = new VectorRandomRowSource(); - valueSource.init(random); - int rowCount = 10000; + valueSource.init(random, VectorRandomRowSource.SupportedTypes.ALL, 4, false); + + int rowCount = 1000; Object[][] rows = valueSource.randomRows(rowCount); addAndVerifyRows(valueSource, rows, @@ -296,9 +295,10 @@ public void testBigIntRowsClipped() throws Exception { VerifyFastRowHashMap verifyTable = new VerifyFastRowHashMap(); VectorRandomRowSource valueSource = new VectorRandomRowSource(); - valueSource.init(random); - int rowCount = 10000; + valueSource.init(random, VectorRandomRowSource.SupportedTypes.ALL, 4, false); + + int rowCount = 1000; Object[][] rows = valueSource.randomRows(rowCount); addAndVerifyRows(valueSource, rows, @@ -320,9 +320,10 @@ public void testIntRowsClipped() throws Exception { VerifyFastRowHashMap verifyTable = new VerifyFastRowHashMap(); VectorRandomRowSource valueSource = new VectorRandomRowSource(); - valueSource.init(random); - int rowCount = 10000; + valueSource.init(random, VectorRandomRowSource.SupportedTypes.ALL, 4, false); + + int rowCount = 1000; Object[][] rows = valueSource.randomRows(rowCount); addAndVerifyRows(valueSource, rows, @@ -344,9 +345,10 @@ public void testStringRowsClipped() throws Exception { VerifyFastRowHashMap verifyTable = new VerifyFastRowHashMap(); VectorRandomRowSource valueSource = new VectorRandomRowSource(); - valueSource.init(random); - int rowCount = 10000; + valueSource.init(random, VectorRandomRowSource.SupportedTypes.ALL, 4, false); + + int rowCount = 1000; Object[][] rows = valueSource.randomRows(rowCount); addAndVerifyRows(valueSource, rows, @@ -368,9 +370,10 @@ public void testMultiKeyRowsClipped1() throws Exception { VerifyFastRowHashMap verifyTable = new VerifyFastRowHashMap(); VectorRandomRowSource valueSource = new VectorRandomRowSource(); - valueSource.init(random); - int rowCount = 10000; + valueSource.init(random, VectorRandomRowSource.SupportedTypes.ALL, 4, false); + + int rowCount = 1000; Object[][] rows = valueSource.randomRows(rowCount); addAndVerifyRows(valueSource, rows, @@ -392,9 +395,10 @@ public void testMultiKeyRowsClipped2() throws Exception { VerifyFastRowHashMap verifyTable = new VerifyFastRowHashMap(); VectorRandomRowSource valueSource = new VectorRandomRowSource(); - valueSource.init(random); - int rowCount = 10000; + valueSource.init(random, VectorRandomRowSource.SupportedTypes.ALL, 4, false); + + int rowCount = 1000; Object[][] rows = valueSource.randomRows(rowCount); addAndVerifyRows(valueSource, rows, @@ -416,9 +420,10 @@ public void testMultiKeyRowsClipped3() throws Exception { VerifyFastRowHashMap verifyTable = new VerifyFastRowHashMap(); VectorRandomRowSource valueSource = new VectorRandomRowSource(); - valueSource.init(random); - int rowCount = 10000; + valueSource.init(random, VectorRandomRowSource.SupportedTypes.ALL, 4, false); + + int rowCount = 1000; Object[][] rows = valueSource.randomRows(rowCount); addAndVerifyRows(valueSource, rows, @@ -441,9 +446,10 @@ public void testBigIntRowsExact() throws Exception { VerifyFastRowHashMap verifyTable = new VerifyFastRowHashMap(); VectorRandomRowSource valueSource = new VectorRandomRowSource(); - valueSource.init(random); - int rowCount = 10000; + valueSource.init(random, 
VectorRandomRowSource.SupportedTypes.ALL, 4, false); + + int rowCount = 1000; Object[][] rows = valueSource.randomRows(rowCount); addAndVerifyRows(valueSource, rows, @@ -465,9 +471,10 @@ public void testIntRowsExact() throws Exception { VerifyFastRowHashMap verifyTable = new VerifyFastRowHashMap(); VectorRandomRowSource valueSource = new VectorRandomRowSource(); - valueSource.init(random); - int rowCount = 10000; + valueSource.init(random, VectorRandomRowSource.SupportedTypes.ALL, 4, false); + + int rowCount = 1000; Object[][] rows = valueSource.randomRows(rowCount); addAndVerifyRows(valueSource, rows, @@ -489,9 +496,10 @@ public void testStringRowsExact() throws Exception { VerifyFastRowHashMap verifyTable = new VerifyFastRowHashMap(); VectorRandomRowSource valueSource = new VectorRandomRowSource(); - valueSource.init(random); - int rowCount = 10000; + valueSource.init(random, VectorRandomRowSource.SupportedTypes.ALL, 4, false); + + int rowCount = 1000; Object[][] rows = valueSource.randomRows(rowCount); addAndVerifyRows(valueSource, rows, @@ -513,9 +521,10 @@ public void testMultiKeyRowsExact1() throws Exception { VerifyFastRowHashMap verifyTable = new VerifyFastRowHashMap(); VectorRandomRowSource valueSource = new VectorRandomRowSource(); - valueSource.init(random); - int rowCount = 10000; + valueSource.init(random, VectorRandomRowSource.SupportedTypes.ALL, 4, false); + + int rowCount = 1000; Object[][] rows = valueSource.randomRows(rowCount); addAndVerifyRows(valueSource, rows, @@ -537,9 +546,10 @@ public void testMultiKeyRowsExact2() throws Exception { VerifyFastRowHashMap verifyTable = new VerifyFastRowHashMap(); VectorRandomRowSource valueSource = new VectorRandomRowSource(); - valueSource.init(random); - int rowCount = 10000; + valueSource.init(random, VectorRandomRowSource.SupportedTypes.ALL, 4, false); + + int rowCount = 1000; Object[][] rows = valueSource.randomRows(rowCount); addAndVerifyRows(valueSource, rows, @@ -561,9 +571,10 @@ public void testMultiKeyRowsExact3() throws Exception { VerifyFastRowHashMap verifyTable = new VerifyFastRowHashMap(); VectorRandomRowSource valueSource = new VectorRandomRowSource(); - valueSource.init(random); - int rowCount = 10000; + valueSource.init(random, VectorRandomRowSource.SupportedTypes.ALL, 4, false); + + int rowCount = 1000; Object[][] rows = valueSource.randomRows(rowCount); addAndVerifyRows(valueSource, rows, @@ -585,9 +596,10 @@ public void testBigIntRowsClippedExact() throws Exception { VerifyFastRowHashMap verifyTable = new VerifyFastRowHashMap(); VectorRandomRowSource valueSource = new VectorRandomRowSource(); - valueSource.init(random); - int rowCount = 10000; + valueSource.init(random, VectorRandomRowSource.SupportedTypes.ALL, 4, false); + + int rowCount = 1000; Object[][] rows = valueSource.randomRows(rowCount); addAndVerifyRows(valueSource, rows, @@ -609,9 +621,10 @@ public void testIntRowsClippedExact() throws Exception { VerifyFastRowHashMap verifyTable = new VerifyFastRowHashMap(); VectorRandomRowSource valueSource = new VectorRandomRowSource(); - valueSource.init(random); - int rowCount = 10000; + valueSource.init(random, VectorRandomRowSource.SupportedTypes.ALL, 4, false); + + int rowCount = 1000; Object[][] rows = valueSource.randomRows(rowCount); addAndVerifyRows(valueSource, rows, @@ -633,9 +646,10 @@ public void testStringRowsClippedExact() throws Exception { VerifyFastRowHashMap verifyTable = new VerifyFastRowHashMap(); VectorRandomRowSource valueSource = new VectorRandomRowSource(); - valueSource.init(random); - int 
rowCount = 10000; + valueSource.init(random, VectorRandomRowSource.SupportedTypes.ALL, 4, false); + + int rowCount = 1000; Object[][] rows = valueSource.randomRows(rowCount); addAndVerifyRows(valueSource, rows, @@ -657,9 +671,10 @@ public void testMultiKeyRowsClippedExact1() throws Exception { VerifyFastRowHashMap verifyTable = new VerifyFastRowHashMap(); VectorRandomRowSource valueSource = new VectorRandomRowSource(); - valueSource.init(random); - int rowCount = 10000; + valueSource.init(random, VectorRandomRowSource.SupportedTypes.ALL, 4, false); + + int rowCount = 1000; Object[][] rows = valueSource.randomRows(rowCount); addAndVerifyRows(valueSource, rows, @@ -681,9 +696,10 @@ public void testMultiKeyRowsClippedExact2() throws Exception { VerifyFastRowHashMap verifyTable = new VerifyFastRowHashMap(); VectorRandomRowSource valueSource = new VectorRandomRowSource(); - valueSource.init(random); - int rowCount = 10000; + valueSource.init(random, VectorRandomRowSource.SupportedTypes.ALL, 4, false); + + int rowCount = 1000; Object[][] rows = valueSource.randomRows(rowCount); addAndVerifyRows(valueSource, rows, @@ -705,9 +721,10 @@ public void testMultiKeyRowsClippedExact3() throws Exception { VerifyFastRowHashMap verifyTable = new VerifyFastRowHashMap(); VectorRandomRowSource valueSource = new VectorRandomRowSource(); - valueSource.init(random); - int rowCount = 10000; + valueSource.init(random, VectorRandomRowSource.SupportedTypes.ALL, 4, false); + + int rowCount = 1000; Object[][] rows = valueSource.randomRows(rowCount); addAndVerifyRows(valueSource, rows, diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/mapjoin/fast/VerifyFastRow.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/mapjoin/fast/VerifyFastRow.java index 91b3ead203..91131df9f6 100644 --- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/mapjoin/fast/VerifyFastRow.java +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/mapjoin/fast/VerifyFastRow.java @@ -18,9 +18,14 @@ package org.apache.hadoop.hive.ql.exec.vector.mapjoin.fast; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.sql.Date; import java.sql.Timestamp; +import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import junit.framework.TestCase; @@ -41,9 +46,16 @@ import org.apache.hadoop.hive.serde2.io.HiveVarcharWritable; import org.apache.hadoop.hive.serde2.io.ShortWritable; import org.apache.hadoop.hive.serde2.io.TimestampWritable; +import org.apache.hadoop.hive.serde2.objectinspector.StandardUnionObjectInspector; import org.apache.hadoop.hive.serde2.typeinfo.CharTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; +import org.apache.hadoop.hive.serde2.typeinfo.UnionTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo; import org.apache.hadoop.io.BooleanWritable; import org.apache.hadoop.io.BytesWritable; @@ -60,341 +72,638 @@ public class VerifyFastRow { public static void verifyDeserializeRead(DeserializeRead deserializeRead, - PrimitiveTypeInfo primitiveTypeInfo, Writable writable) throws IOException { + TypeInfo typeInfo, Object object) 
throws IOException { boolean isNull; isNull = !deserializeRead.readNextField(); + doVerifyDeserializeRead(deserializeRead, typeInfo, object, isNull); + } + + public static void doVerifyDeserializeRead(DeserializeRead deserializeRead, + TypeInfo typeInfo, Object object, boolean isNull) throws IOException { if (isNull) { - if (writable != null) { - TestCase.fail( - deserializeRead.getClass().getName() + - " field reports null but object is not null " + - "(class " + writable.getClass().getName() + ", " + writable.toString() + ")"); + if (object != null) { + TestCase.fail("Field reports null but object is not null (class " + object.getClass().getName() + ", " + object.toString() + ")"); } return; - } else if (writable == null) { + } else if (object == null) { TestCase.fail("Field report not null but object is null"); } - switch (primitiveTypeInfo.getPrimitiveCategory()) { - case BOOLEAN: - { - boolean value = deserializeRead.currentBoolean; - if (!(writable instanceof BooleanWritable)) { - TestCase.fail("Boolean expected writable not Boolean"); - } - boolean expected = ((BooleanWritable) writable).get(); - if (value != expected) { - TestCase.fail("Boolean field mismatch (expected " + expected + " found " + value + ")"); - } - } - break; - case BYTE: - { - byte value = deserializeRead.currentByte; - if (!(writable instanceof ByteWritable)) { - TestCase.fail("Byte expected writable not Byte"); - } - byte expected = ((ByteWritable) writable).get(); - if (value != expected) { - TestCase.fail("Byte field mismatch (expected " + (int) expected + " found " + (int) value + ")"); - } - } - break; - case SHORT: - { - short value = deserializeRead.currentShort; - if (!(writable instanceof ShortWritable)) { - TestCase.fail("Short expected writable not Short"); - } - short expected = ((ShortWritable) writable).get(); - if (value != expected) { - TestCase.fail("Short field mismatch (expected " + expected + " found " + value + ")"); - } - } - break; - case INT: - { - int value = deserializeRead.currentInt; - if (!(writable instanceof IntWritable)) { - TestCase.fail("Integer expected writable not Integer"); - } - int expected = ((IntWritable) writable).get(); - if (value != expected) { - TestCase.fail("Int field mismatch (expected " + expected + " found " + value + ")"); - } - } - break; - case LONG: - { - long value = deserializeRead.currentLong; - if (!(writable instanceof LongWritable)) { - TestCase.fail("Long expected writable not Long"); - } - Long expected = ((LongWritable) writable).get(); - if (value != expected) { - TestCase.fail("Long field mismatch (expected " + expected + " found " + value + ")"); - } - } - break; - case FLOAT: - { - float value = deserializeRead.currentFloat; - if (!(writable instanceof FloatWritable)) { - TestCase.fail("Float expected writable not Float"); - } - float expected = ((FloatWritable) writable).get(); - if (value != expected) { - TestCase.fail("Float field mismatch (expected " + expected + " found " + value + ")"); - } - } - break; - case DOUBLE: - { - double value = deserializeRead.currentDouble; - if (!(writable instanceof DoubleWritable)) { - TestCase.fail("Double expected writable not Double"); - } - double expected = ((DoubleWritable) writable).get(); - if (value != expected) { - TestCase.fail("Double field mismatch (expected " + expected + " found " + value + ")"); - } - } - break; - case STRING: - { - byte[] stringBytes = Arrays.copyOfRange( - deserializeRead.currentBytes, - deserializeRead.currentBytesStart, - deserializeRead.currentBytesStart + 
deserializeRead.currentBytesLength); - Text text = new Text(stringBytes); - String string = text.toString(); - String expected = ((Text) writable).toString(); - if (!string.equals(expected)) { - TestCase.fail("String field mismatch (expected '" + expected + "' found '" + string + "')"); - } - } - break; - case CHAR: - { - byte[] stringBytes = Arrays.copyOfRange( - deserializeRead.currentBytes, - deserializeRead.currentBytesStart, - deserializeRead.currentBytesStart + deserializeRead.currentBytesLength); - Text text = new Text(stringBytes); - String string = text.toString(); + switch (typeInfo.getCategory()) { + case PRIMITIVE: + { + PrimitiveTypeInfo primitiveTypeInfo = (PrimitiveTypeInfo) typeInfo; + switch (primitiveTypeInfo.getPrimitiveCategory()) { + case BOOLEAN: + { + boolean value = deserializeRead.currentBoolean; + if (!(object instanceof BooleanWritable)) { + TestCase.fail("Boolean expected writable not Boolean"); + } + boolean expected = ((BooleanWritable) object).get(); + if (value != expected) { + TestCase.fail("Boolean field mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case BYTE: + { + byte value = deserializeRead.currentByte; + if (!(object instanceof ByteWritable)) { + TestCase.fail("Byte expected writable not Byte"); + } + byte expected = ((ByteWritable) object).get(); + if (value != expected) { + TestCase.fail("Byte field mismatch (expected " + (int) expected + " found " + (int) value + ")"); + } + } + break; + case SHORT: + { + short value = deserializeRead.currentShort; + if (!(object instanceof ShortWritable)) { + TestCase.fail("Short expected writable not Short"); + } + short expected = ((ShortWritable) object).get(); + if (value != expected) { + TestCase.fail("Short field mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case INT: + { + int value = deserializeRead.currentInt; + if (!(object instanceof IntWritable)) { + TestCase.fail("Integer expected writable not Integer"); + } + int expected = ((IntWritable) object).get(); + if (value != expected) { + TestCase.fail("Int field mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case LONG: + { + long value = deserializeRead.currentLong; + if (!(object instanceof LongWritable)) { + TestCase.fail("Long expected writable not Long"); + } + Long expected = ((LongWritable) object).get(); + if (value != expected) { + TestCase.fail("Long field mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case FLOAT: + { + float value = deserializeRead.currentFloat; + if (!(object instanceof FloatWritable)) { + TestCase.fail("Float expected writable not Float"); + } + float expected = ((FloatWritable) object).get(); + if (value != expected) { + TestCase.fail("Float field mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case DOUBLE: + { + double value = deserializeRead.currentDouble; + if (!(object instanceof DoubleWritable)) { + TestCase.fail("Double expected writable not Double"); + } + double expected = ((DoubleWritable) object).get(); + if (value != expected) { + TestCase.fail("Double field mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case STRING: + { + byte[] stringBytes = Arrays.copyOfRange( + deserializeRead.currentBytes, + deserializeRead.currentBytesStart, + deserializeRead.currentBytesStart + deserializeRead.currentBytesLength); + Text text = new Text(stringBytes); + String string = text.toString(); + String expected = ((Text) object).toString(); + if 
(!string.equals(expected)) { + TestCase.fail("String field mismatch (expected '" + expected + "' found '" + string + "')"); + } + } + break; + case CHAR: + { + byte[] stringBytes = Arrays.copyOfRange( + deserializeRead.currentBytes, + deserializeRead.currentBytesStart, + deserializeRead.currentBytesStart + deserializeRead.currentBytesLength); + Text text = new Text(stringBytes); + String string = text.toString(); - HiveChar hiveChar = new HiveChar(string, ((CharTypeInfo) primitiveTypeInfo).getLength()); + HiveChar hiveChar = new HiveChar(string, ((CharTypeInfo) primitiveTypeInfo).getLength()); - HiveChar expected = ((HiveCharWritable) writable).getHiveChar(); - if (!hiveChar.equals(expected)) { - TestCase.fail("Char field mismatch (expected '" + expected + "' found '" + hiveChar + "')"); - } - } - break; - case VARCHAR: - { - byte[] stringBytes = Arrays.copyOfRange( - deserializeRead.currentBytes, - deserializeRead.currentBytesStart, - deserializeRead.currentBytesStart + deserializeRead.currentBytesLength); - Text text = new Text(stringBytes); - String string = text.toString(); + HiveChar expected = ((HiveCharWritable) object).getHiveChar(); + if (!hiveChar.equals(expected)) { + TestCase.fail("Char field mismatch (expected '" + expected + "' found '" + hiveChar + "')"); + } + } + break; + case VARCHAR: + { + byte[] stringBytes = Arrays.copyOfRange( + deserializeRead.currentBytes, + deserializeRead.currentBytesStart, + deserializeRead.currentBytesStart + deserializeRead.currentBytesLength); + Text text = new Text(stringBytes); + String string = text.toString(); - HiveVarchar hiveVarchar = new HiveVarchar(string, ((VarcharTypeInfo) primitiveTypeInfo).getLength()); + HiveVarchar hiveVarchar = new HiveVarchar(string, ((VarcharTypeInfo) primitiveTypeInfo).getLength()); - HiveVarchar expected = ((HiveVarcharWritable) writable).getHiveVarchar(); - if (!hiveVarchar.equals(expected)) { - TestCase.fail("Varchar field mismatch (expected '" + expected + "' found '" + hiveVarchar + "')"); - } - } - break; - case DECIMAL: - { - HiveDecimal value = deserializeRead.currentHiveDecimalWritable.getHiveDecimal(); - if (value == null) { - TestCase.fail("Decimal field evaluated to NULL"); - } - HiveDecimal expected = ((HiveDecimalWritable) writable).getHiveDecimal(); - if (!value.equals(expected)) { - DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) primitiveTypeInfo; - int precision = decimalTypeInfo.getPrecision(); - int scale = decimalTypeInfo.getScale(); - TestCase.fail("Decimal field mismatch (expected " + expected.toString() + " found " + value.toString() + ") precision " + precision + ", scale " + scale); + HiveVarchar expected = ((HiveVarcharWritable) object).getHiveVarchar(); + if (!hiveVarchar.equals(expected)) { + TestCase.fail("Varchar field mismatch (expected '" + expected + "' found '" + hiveVarchar + "')"); + } + } + break; + case DECIMAL: + { + HiveDecimal value = deserializeRead.currentHiveDecimalWritable.getHiveDecimal(); + if (value == null) { + TestCase.fail("Decimal field evaluated to NULL"); + } + HiveDecimal expected = ((HiveDecimalWritable) object).getHiveDecimal(); + if (!value.equals(expected)) { + DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) primitiveTypeInfo; + int precision = decimalTypeInfo.getPrecision(); + int scale = decimalTypeInfo.getScale(); + TestCase.fail("Decimal field mismatch (expected " + expected.toString() + " found " + value.toString() + ") precision " + precision + ", scale " + scale); + } + } + break; + case DATE: + { + Date value = 
deserializeRead.currentDateWritable.get(); + Date expected = ((DateWritable) object).get(); + if (!value.equals(expected)) { + TestCase.fail("Date field mismatch (expected " + expected.toString() + " found " + value.toString() + ")"); + } + } + break; + case TIMESTAMP: + { + Timestamp value = deserializeRead.currentTimestampWritable.getTimestamp(); + Timestamp expected = ((TimestampWritable) object).getTimestamp(); + if (!value.equals(expected)) { + TestCase.fail("Timestamp field mismatch (expected " + expected.toString() + " found " + value.toString() + ")"); + } + } + break; + case INTERVAL_YEAR_MONTH: + { + HiveIntervalYearMonth value = deserializeRead.currentHiveIntervalYearMonthWritable.getHiveIntervalYearMonth(); + HiveIntervalYearMonth expected = ((HiveIntervalYearMonthWritable) object).getHiveIntervalYearMonth(); + if (!value.equals(expected)) { + TestCase.fail("HiveIntervalYearMonth field mismatch (expected " + expected.toString() + " found " + value.toString() + ")"); + } + } + break; + case INTERVAL_DAY_TIME: + { + HiveIntervalDayTime value = deserializeRead.currentHiveIntervalDayTimeWritable.getHiveIntervalDayTime(); + HiveIntervalDayTime expected = ((HiveIntervalDayTimeWritable) object).getHiveIntervalDayTime(); + if (!value.equals(expected)) { + TestCase.fail("HiveIntervalDayTime field mismatch (expected " + expected.toString() + " found " + value.toString() + ")"); + } + } + break; + case BINARY: + { + byte[] byteArray = Arrays.copyOfRange( + deserializeRead.currentBytes, + deserializeRead.currentBytesStart, + deserializeRead.currentBytesStart + deserializeRead.currentBytesLength); + BytesWritable bytesWritable = (BytesWritable) object; + byte[] expected = Arrays.copyOfRange(bytesWritable.getBytes(), 0, bytesWritable.getLength()); + if (byteArray.length != expected.length){ + TestCase.fail("Byte Array field mismatch (expected " + Arrays.toString(expected) + + " found " + Arrays.toString(byteArray) + ")"); + } + for (int b = 0; b < byteArray.length; b++) { + if (byteArray[b] != expected[b]) { + TestCase.fail("Byte Array field mismatch (expected " + Arrays.toString(expected) + + " found " + Arrays.toString(byteArray) + ")"); + } + } + } + break; + default: + throw new Error("Unknown primitive category " + primitiveTypeInfo.getPrimitiveCategory()); } } break; - case DATE: - { - Date value = deserializeRead.currentDateWritable.get(); - Date expected = ((DateWritable) writable).get(); - if (!value.equals(expected)) { - TestCase.fail("Date field mismatch (expected " + expected.toString() + " found " + value.toString() + ")"); + case LIST: + case MAP: + case STRUCT: + case UNION: + throw new Error("Complex types need to be handled separately"); + default: + throw new Error("Unknown category " + typeInfo.getCategory()); + } + } + + public static void serializeWrite(SerializeWrite serializeWrite, + TypeInfo typeInfo, Object object) throws IOException { + if (object == null) { + serializeWrite.writeNull(); + return; + } + switch (typeInfo.getCategory()) { + case PRIMITIVE: + { + PrimitiveTypeInfo primitiveTypeInfo = (PrimitiveTypeInfo) typeInfo; + switch (primitiveTypeInfo.getPrimitiveCategory()) { + case BOOLEAN: + { + boolean value = ((BooleanWritable) object).get(); + serializeWrite.writeBoolean(value); + } + break; + case BYTE: + { + byte value = ((ByteWritable) object).get(); + serializeWrite.writeByte(value); + } + break; + case SHORT: + { + short value = ((ShortWritable) object).get(); + serializeWrite.writeShort(value); + } + break; + case INT: + { + int value = ((IntWritable) 
object).get(); + serializeWrite.writeInt(value); + } + break; + case LONG: + { + long value = ((LongWritable) object).get(); + serializeWrite.writeLong(value); + } + break; + case FLOAT: + { + float value = ((FloatWritable) object).get(); + serializeWrite.writeFloat(value); + } + break; + case DOUBLE: + { + double value = ((DoubleWritable) object).get(); + serializeWrite.writeDouble(value); + } + break; + case STRING: + { + Text value = (Text) object; + byte[] stringBytes = value.getBytes(); + int stringLength = stringBytes.length; + serializeWrite.writeString(stringBytes, 0, stringLength); + } + break; + case CHAR: + { + HiveChar value = ((HiveCharWritable) object).getHiveChar(); + serializeWrite.writeHiveChar(value); + } + break; + case VARCHAR: + { + HiveVarchar value = ((HiveVarcharWritable) object).getHiveVarchar(); + serializeWrite.writeHiveVarchar(value); + } + break; + case DECIMAL: + { + HiveDecimal value = ((HiveDecimalWritable) object).getHiveDecimal(); + DecimalTypeInfo decTypeInfo = (DecimalTypeInfo)primitiveTypeInfo; + serializeWrite.writeHiveDecimal(value, decTypeInfo.scale()); + } + break; + case DATE: + { + Date value = ((DateWritable) object).get(); + serializeWrite.writeDate(value); + } + break; + case TIMESTAMP: + { + Timestamp value = ((TimestampWritable) object).getTimestamp(); + serializeWrite.writeTimestamp(value); + } + break; + case INTERVAL_YEAR_MONTH: + { + HiveIntervalYearMonth value = ((HiveIntervalYearMonthWritable) object).getHiveIntervalYearMonth(); + serializeWrite.writeHiveIntervalYearMonth(value); + } + break; + case INTERVAL_DAY_TIME: + { + HiveIntervalDayTime value = ((HiveIntervalDayTimeWritable) object).getHiveIntervalDayTime(); + serializeWrite.writeHiveIntervalDayTime(value); + } + break; + case BINARY: + { + BytesWritable byteWritable = (BytesWritable) object; + byte[] binaryBytes = byteWritable.getBytes(); + int length = byteWritable.getLength(); + serializeWrite.writeBinary(binaryBytes, 0, length); + } + break; + default: + throw new Error("Unknown primitive category " + primitiveTypeInfo.getPrimitiveCategory().name()); } } break; - case TIMESTAMP: + case LIST: { - Timestamp value = deserializeRead.currentTimestampWritable.getTimestamp(); - Timestamp expected = ((TimestampWritable) writable).getTimestamp(); - if (!value.equals(expected)) { - TestCase.fail("Timestamp field mismatch (expected " + expected.toString() + " found " + value.toString() + ")"); + ListTypeInfo listTypeInfo = (ListTypeInfo) typeInfo; + TypeInfo elementTypeInfo = listTypeInfo.getListElementTypeInfo(); + ArrayList elements = (ArrayList) object; + serializeWrite.beginList(elements); + boolean isFirst = true; + for (Object elementObject : elements) { + if (isFirst) { + isFirst = false; + } else { + serializeWrite.separateList(); + } + if (elementObject == null) { + serializeWrite.writeNull(); + } else { + serializeWrite(serializeWrite, elementTypeInfo, elementObject); + } } - } - break; - case INTERVAL_YEAR_MONTH: - { - HiveIntervalYearMonth value = deserializeRead.currentHiveIntervalYearMonthWritable.getHiveIntervalYearMonth(); - HiveIntervalYearMonth expected = ((HiveIntervalYearMonthWritable) writable).getHiveIntervalYearMonth(); - if (!value.equals(expected)) { - TestCase.fail("HiveIntervalYearMonth field mismatch (expected " + expected.toString() + " found " + value.toString() + ")"); + serializeWrite.finishList(); + } + break; + case MAP: + { + MapTypeInfo mapTypeInfo = (MapTypeInfo) typeInfo; + TypeInfo keyTypeInfo = mapTypeInfo.getMapKeyTypeInfo(); + TypeInfo 
valueTypeInfo = mapTypeInfo.getMapValueTypeInfo(); + HashMap hashMap = (HashMap) object; + serializeWrite.beginMap(hashMap); + boolean isFirst = true; + for (Map.Entry entry : hashMap.entrySet()) { + if (isFirst) { + isFirst = false; + } else { + serializeWrite.separateKeyValuePair(); + } + if (entry.getKey() == null) { + serializeWrite.writeNull(); + } else { + serializeWrite(serializeWrite, keyTypeInfo, entry.getKey()); + } + serializeWrite.separateKey(); + if (entry.getValue() == null) { + serializeWrite.writeNull(); + } else { + serializeWrite(serializeWrite, valueTypeInfo, entry.getValue()); + } } - } - break; - case INTERVAL_DAY_TIME: - { - HiveIntervalDayTime value = deserializeRead.currentHiveIntervalDayTimeWritable.getHiveIntervalDayTime(); - HiveIntervalDayTime expected = ((HiveIntervalDayTimeWritable) writable).getHiveIntervalDayTime(); - if (!value.equals(expected)) { - TestCase.fail("HiveIntervalDayTime field mismatch (expected " + expected.toString() + " found " + value.toString() + ")"); + serializeWrite.finishMap(); + } + break; + case STRUCT: + { + StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo; + ArrayList fieldTypeInfos = structTypeInfo.getAllStructFieldTypeInfos(); + ArrayList fieldValues = (ArrayList) object; + final int size = fieldValues.size(); + serializeWrite.beginStruct(fieldValues); + boolean isFirst = true; + for (int i = 0; i < size; i++) { + if (isFirst) { + isFirst = false; + } else { + serializeWrite.separateStruct(); + } + serializeWrite(serializeWrite, fieldTypeInfos.get(i), fieldValues.get(i)); } + serializeWrite.finishStruct(); } break; - case BINARY: + case UNION: { - byte[] byteArray = Arrays.copyOfRange( - deserializeRead.currentBytes, - deserializeRead.currentBytesStart, - deserializeRead.currentBytesStart + deserializeRead.currentBytesLength); - BytesWritable bytesWritable = (BytesWritable) writable; - byte[] expected = Arrays.copyOfRange(bytesWritable.getBytes(), 0, bytesWritable.getLength()); - if (byteArray.length != expected.length){ - TestCase.fail("Byte Array field mismatch (expected " + Arrays.toString(expected) - + " found " + Arrays.toString(byteArray) + ")"); - } - for (int b = 0; b < byteArray.length; b++) { - if (byteArray[b] != expected[b]) { - TestCase.fail("Byte Array field mismatch (expected " + Arrays.toString(expected) - + " found " + Arrays.toString(byteArray) + ")"); - } - } + UnionTypeInfo unionTypeInfo = (UnionTypeInfo) typeInfo; + List fieldTypeInfos = unionTypeInfo.getAllUnionObjectTypeInfos(); + final int size = fieldTypeInfos.size(); + StandardUnionObjectInspector.StandardUnion standardUnion = (StandardUnionObjectInspector.StandardUnion) object; + byte tag = standardUnion.getTag(); + serializeWrite.beginUnion(tag); + serializeWrite(serializeWrite, fieldTypeInfos.get(tag), standardUnion.getObject()); + serializeWrite.finishUnion(); } break; default: - throw new Error("Unknown primitive category " + primitiveTypeInfo.getPrimitiveCategory()); + throw new Error("Unknown category " + typeInfo.getCategory().name()); } } - public static void serializeWrite(SerializeWrite serializeWrite, - PrimitiveTypeInfo primitiveTypeInfo, Writable writable) throws IOException { - if (writable == null) { - serializeWrite.writeNull(); - return; + public Object readComplexPrimitiveField(DeserializeRead deserializeRead, + PrimitiveTypeInfo primitiveTypeInfo) throws IOException { + boolean isNull = !deserializeRead.readComplexField(); + if (isNull) { + return null; + } else { + return doReadComplexPrimitiveField(deserializeRead, 
primitiveTypeInfo); } + } + + private static Object doReadComplexPrimitiveField(DeserializeRead deserializeRead, + PrimitiveTypeInfo primitiveTypeInfo) throws IOException { switch (primitiveTypeInfo.getPrimitiveCategory()) { - case BOOLEAN: - { - boolean value = ((BooleanWritable) writable).get(); - serializeWrite.writeBoolean(value); - } - break; + case BOOLEAN: + return new BooleanWritable(deserializeRead.currentBoolean); case BYTE: - { - byte value = ((ByteWritable) writable).get(); - serializeWrite.writeByte(value); - } - break; + return new ByteWritable(deserializeRead.currentByte); case SHORT: - { - short value = ((ShortWritable) writable).get(); - serializeWrite.writeShort(value); - } - break; + return new ShortWritable(deserializeRead.currentShort); case INT: - { - int value = ((IntWritable) writable).get(); - serializeWrite.writeInt(value); - } - break; + return new IntWritable(deserializeRead.currentInt); case LONG: - { - long value = ((LongWritable) writable).get(); - serializeWrite.writeLong(value); - } - break; + return new LongWritable(deserializeRead.currentLong); case FLOAT: - { - float value = ((FloatWritable) writable).get(); - serializeWrite.writeFloat(value); - } - break; + return new FloatWritable(deserializeRead.currentFloat); case DOUBLE: - { - double value = ((DoubleWritable) writable).get(); - serializeWrite.writeDouble(value); - } - break; + return new DoubleWritable(deserializeRead.currentDouble); case STRING: - { - Text value = (Text) writable; - byte[] stringBytes = value.getBytes(); - int stringLength = stringBytes.length; - serializeWrite.writeString(stringBytes, 0, stringLength); - } - break; + return new Text(new String( + deserializeRead.currentBytes, + deserializeRead.currentBytesStart, + deserializeRead.currentBytesLength, + StandardCharsets.UTF_8)); case CHAR: - { - HiveChar value = ((HiveCharWritable) writable).getHiveChar(); - serializeWrite.writeHiveChar(value); - } - break; + return new HiveCharWritable(new HiveChar( + new String( + deserializeRead.currentBytes, + deserializeRead.currentBytesStart, + deserializeRead.currentBytesLength, + StandardCharsets.UTF_8), + ((CharTypeInfo) primitiveTypeInfo).getLength())); case VARCHAR: - { - HiveVarchar value = ((HiveVarcharWritable) writable).getHiveVarchar(); - serializeWrite.writeHiveVarchar(value); - } - break; + if (deserializeRead.currentBytes == null) { + throw new RuntimeException(); + } + return new HiveVarcharWritable(new HiveVarchar( + new String( + deserializeRead.currentBytes, + deserializeRead.currentBytesStart, + deserializeRead.currentBytesLength, + StandardCharsets.UTF_8), + ((VarcharTypeInfo) primitiveTypeInfo).getLength())); case DECIMAL: - { - HiveDecimal value = ((HiveDecimalWritable) writable).getHiveDecimal(); - DecimalTypeInfo decTypeInfo = (DecimalTypeInfo)primitiveTypeInfo; - serializeWrite.writeHiveDecimal(value, decTypeInfo.scale()); - } - break; + return new HiveDecimalWritable(deserializeRead.currentHiveDecimalWritable); case DATE: - { - Date value = ((DateWritable) writable).get(); - serializeWrite.writeDate(value); - } - break; + return new DateWritable(deserializeRead.currentDateWritable); case TIMESTAMP: - { - Timestamp value = ((TimestampWritable) writable).getTimestamp(); - serializeWrite.writeTimestamp(value); - } - break; + return new TimestampWritable(deserializeRead.currentTimestampWritable); case INTERVAL_YEAR_MONTH: - { - HiveIntervalYearMonth value = ((HiveIntervalYearMonthWritable) writable).getHiveIntervalYearMonth(); - 
serializeWrite.writeHiveIntervalYearMonth(value); - } - break; + return new HiveIntervalYearMonthWritable(deserializeRead.currentHiveIntervalYearMonthWritable); case INTERVAL_DAY_TIME: - { - HiveIntervalDayTime value = ((HiveIntervalDayTimeWritable) writable).getHiveIntervalDayTime(); - serializeWrite.writeHiveIntervalDayTime(value); - } - break; + return new HiveIntervalDayTimeWritable(deserializeRead.currentHiveIntervalDayTimeWritable); case BINARY: - { - BytesWritable byteWritable = (BytesWritable) writable; - byte[] binaryBytes = byteWritable.getBytes(); - int length = byteWritable.getLength(); - serializeWrite.writeBinary(binaryBytes, 0, length); + return new BytesWritable( + Arrays.copyOfRange( + deserializeRead.currentBytes, + deserializeRead.currentBytesStart, + deserializeRead.currentBytesLength + deserializeRead.currentBytesStart)); + default: + throw new Error("Unknown primitive category " + primitiveTypeInfo.getPrimitiveCategory()); + } + } + + public static Object deserializeReadComplexType(DeserializeRead deserializeRead, + TypeInfo typeInfo) throws IOException { + + boolean isNull = !deserializeRead.readNextField(); + if (isNull) { + return null; + } + return getComplexField(deserializeRead, typeInfo); + } + + private static Object getComplexField(DeserializeRead deserializeRead, + TypeInfo typeInfo) throws IOException { + switch (typeInfo.getCategory()) { + case PRIMITIVE: + return doReadComplexPrimitiveField(deserializeRead, (PrimitiveTypeInfo) typeInfo); + case LIST: + { + ListTypeInfo listTypeInfo = (ListTypeInfo) typeInfo; + TypeInfo elementTypeInfo = listTypeInfo.getListElementTypeInfo(); + ArrayList list = new ArrayList(); + Object eleObj; + boolean isNull; + while (deserializeRead.isNextComplexMultiValue()) { + isNull = !deserializeRead.readComplexField(); + if (isNull) { + eleObj = null; + } else { + eleObj = getComplexField(deserializeRead, elementTypeInfo); + } + list.add(eleObj); + } + return list; + } + case MAP: + { + MapTypeInfo mapTypeInfo = (MapTypeInfo) typeInfo; + TypeInfo keyTypeInfo = mapTypeInfo.getMapKeyTypeInfo(); + TypeInfo valueTypeInfo = mapTypeInfo.getMapValueTypeInfo(); + HashMap hashMap = new HashMap(); + Object keyObj; + Object valueObj; + boolean isNull; + while (deserializeRead.isNextComplexMultiValue()) { + isNull = !deserializeRead.readComplexField(); + if (isNull) { + keyObj = null; + } else { + keyObj = getComplexField(deserializeRead, keyTypeInfo); + } + isNull = !deserializeRead.readComplexField(); + if (isNull) { + valueObj = null; + } else { + valueObj = getComplexField(deserializeRead, valueTypeInfo); + } + hashMap.put(keyObj, valueObj); + } + return hashMap; + } + case STRUCT: + { + StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo; + ArrayList fieldTypeInfos = structTypeInfo.getAllStructFieldTypeInfos(); + final int size = fieldTypeInfos.size(); + ArrayList fieldValues = new ArrayList(); + Object fieldObj; + boolean isNull; + for (int i = 0; i < size; i++) { + isNull = !deserializeRead.readComplexField(); + if (isNull) { + fieldObj = null; + } else { + fieldObj = getComplexField(deserializeRead, fieldTypeInfos.get(i)); + } + fieldValues.add(fieldObj); + } + deserializeRead.finishComplexVariableFieldsType(); + return fieldValues; + } + case UNION: + { + UnionTypeInfo unionTypeInfo = (UnionTypeInfo) typeInfo; + List unionTypeInfos = unionTypeInfo.getAllUnionObjectTypeInfos(); + final int size = unionTypeInfos.size();
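+ // A union is read as two complex fields: first an int tag that selects + // the child type, then the value deserialized with that child's TypeInfo.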
+ Object tagObj; + int tag; + Object unionObj; + boolean isNull = !deserializeRead.readComplexField(); + if (isNull) { + unionObj = null; + } else { + // Get the tag value. + tagObj = getComplexField(deserializeRead, TypeInfoFactory.intTypeInfo); + tag = ((IntWritable) tagObj).get(); + + isNull = !deserializeRead.readComplexField(); + if (isNull) { + unionObj = null; + } else { + // Get the union value. + unionObj = new StandardUnionObjectInspector.StandardUnion((byte) tag, getComplexField(deserializeRead, unionTypeInfos.get(tag))); + } + } + + deserializeRead.finishComplexVariableFieldsType(); + return unionObj; } - break; default: - throw new Error("Unknown primitive category " + primitiveTypeInfo.getPrimitiveCategory().name()); + throw new Error("Unexpected category " + typeInfo.getCategory()); } } } \ No newline at end of file diff --git serde/src/java/org/apache/hadoop/hive/serde2/binarysortable/fast/BinarySortableDeserializeRead.java serde/src/java/org/apache/hadoop/hive/serde2/binarysortable/fast/BinarySortableDeserializeRead.java index 19d4550aa9..0b23fca337 100644 --- serde/src/java/org/apache/hadoop/hive/serde2/binarysortable/fast/BinarySortableDeserializeRead.java +++ serde/src/java/org/apache/hadoop/hive/serde2/binarysortable/fast/BinarySortableDeserializeRead.java @@ -19,13 +19,20 @@ package org.apache.hadoop.hive.serde2.binarysortable.fast; import java.io.IOException; -import java.math.BigInteger; +import java.util.ArrayDeque; import java.util.Arrays; import java.nio.charset.StandardCharsets; - +import java.util.Deque; +import java.util.List; + +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.UnionTypeInfo; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.apache.hadoop.hive.common.type.FastHiveDecimal; import org.apache.hadoop.hive.serde2.binarysortable.BinarySortableSerDe; import org.apache.hadoop.hive.serde2.binarysortable.InputByteBuffer; import org.apache.hadoop.hive.serde2.fast.DeserializeRead; @@ -57,12 +64,6 @@ byte[] columnNullMarker; byte[] columnNotNullMarker; - // Which field we are on. We start with -1 so readNextField can increment once and the read - // field data methods don't increment. - private int fieldIndex; - - private int fieldCount; - private int start; private int end; private int fieldStart; @@ -78,19 +79,36 @@ private InputByteBuffer inputByteBuffer = new InputByteBuffer(); + private Field root; + private Deque stack; + + private class Field { + ObjectInspector.Category category; + PrimitiveObjectInspector.PrimitiveCategory primitiveCategory; + int index; + int count; + Field[] children; + int start; + int tag; + TypeInfo typeInfo; + } + /* * Use this constructor when only ascending sort order is used. 
*/ - public BinarySortableDeserializeRead(PrimitiveTypeInfo[] primitiveTypeInfos, - boolean useExternalBuffer) { - this(primitiveTypeInfos, useExternalBuffer, null, null, null); + public BinarySortableDeserializeRead(TypeInfo[] typeInfos, boolean useExternalBuffer) { + this(typeInfos, useExternalBuffer, null, null, null); } public BinarySortableDeserializeRead(TypeInfo[] typeInfos, boolean useExternalBuffer, boolean[] columnSortOrderIsDesc, byte[] columnNullMarker, byte[] columnNotNullMarker) { super(typeInfos, useExternalBuffer); final int count = typeInfos.length; - fieldCount = count; + root = new Field(); + root.category = ObjectInspector.Category.STRUCT; + root.children = createFields(typeInfos); + root.count = count; + stack = new ArrayDeque<>(); if (columnSortOrderIsDesc != null) { this.columnSortOrderIsDesc = columnSortOrderIsDesc; } else { @@ -131,10 +149,23 @@ private BinarySortableDeserializeRead() { */ @Override public void set(byte[] bytes, int offset, int length) { - fieldIndex = -1; start = offset; end = offset + length; inputByteBuffer.reset(bytes, start, end); + root.index = -1; + stack.clear(); + stack.push(root); + clearIndex(root); + } + + private void clearIndex(Field field) { + field.index = -1; + if (field.children == null) { + return; + } + for (Field child : field.children) { + clearIndex(child); + } } /* @@ -150,15 +181,15 @@ public String getDetailedReadPositionString() { sb.append(" for length "); sb.append(end - start); sb.append(" to read "); - sb.append(fieldCount); + sb.append(root.count); sb.append(" fields with types "); sb.append(Arrays.toString(typeInfos)); sb.append(". "); - if (fieldIndex == -1) { + if (root.index == -1) { sb.append("Before first field?"); } else { sb.append("Read field #"); - sb.append(fieldIndex); + sb.append(root.index); sb.append(" at field start position "); sb.append(fieldStart); sb.append(" current read offset "); @@ -187,31 +218,19 @@ public String getDetailedReadPositionString() { */ @Override public boolean readNextField() throws IOException { + return readComplexField(); + } - // We start with fieldIndex as -1 so we can increment once here and then the read - // field data methods don't increment. - fieldIndex++; - - if (fieldIndex >= fieldCount) { - return false; - } - if (inputByteBuffer.isEof()) { - // Also, reading beyond our byte range produces NULL. - return false; - } - - fieldStart = inputByteBuffer.tell(); + private boolean readPrimitive(Field field) throws IOException { + final int fieldIndex = root.index; + final int fieldCount = root.count; - byte isNullByte = inputByteBuffer.read(columnSortOrderIsDesc[fieldIndex]); - - if (isNullByte == columnNullMarker[fieldIndex]) { - return false; - } + field.start = inputByteBuffer.tell(); /* * We have a field and are positioned to it. Read it. */ - switch (primitiveCategories[fieldIndex]) { + switch (field.primitiveCategory) { case BOOLEAN: currentBoolean = (inputByteBuffer.read(columnSortOrderIsDesc[fieldIndex]) == 2); return true; @@ -445,7 +464,7 @@ public boolean readNextField() throws IOException { // We have a decimal. After we enforce precision and scale, will it become a NULL? 
- DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) typeInfos[fieldIndex]; + DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) field.typeInfo; int enforcePrecision = decimalTypeInfo.getPrecision(); int enforceScale = decimalTypeInfo.getScale(); @@ -461,7 +480,7 @@ public boolean readNextField() throws IOException { } return true; default: - throw new RuntimeException("Unexpected primitive type category " + primitiveCategories[fieldIndex]); + throw new RuntimeException("Unexpected primitive type category " + field.primitiveCategory); } } @@ -472,8 +491,53 @@ public boolean readNextField() throws IOException { * Designed for skipping columns that are not included. */ public void skipNextField() throws IOException { - // Not a known use case for BinarySortable -- so don't optimize. - readNextField(); + Field current = stack.peek(); + current.index++; + + if (root.index >= root.count) { + return; + } + + if (inputByteBuffer.isEof()) { + // Also, reading beyond our byte range produces NULL. + return; + } + + if (current.category == ObjectInspector.Category.UNION && current.index == 0) { + current.tag = inputByteBuffer.read(); + currentInt = current.tag; + return; + } + + Field child = getChild(current); + + if (isNull()) { + return; + } + if (child.category == ObjectInspector.Category.PRIMITIVE) { + readPrimitive(child); + } else { + stack.push(child); + switch (child.category) { + case LIST: + case MAP: + while (isNextComplexMultiValue()) { + skipNextField(); + } + break; + case STRUCT: + for (int i = 0; i < child.count; i++) { + skipNextField(); + } + finishComplexVariableFieldsType(); + break; + case UNION: + readComplexField(); + skipNextField(); + finishComplexVariableFieldsType(); + break; + } + } } @Override @@ -482,7 +546,7 @@ public void copyToExternalBuffer(byte[] externalBuffer, int externalBufferStart) } private void copyToBuffer(byte[] buffer, int bufferStart, int bufferLength) throws IOException { - final boolean invert = columnSortOrderIsDesc[fieldIndex]; + final boolean invert = columnSortOrderIsDesc[root.index]; inputByteBuffer.seek(bytesStart); // 3. Copy the data. 
for (int i = 0; i < bufferLength; i++) { @@ -516,4 +580,146 @@ private void copyToBuffer(byte[] buffer, int bufferStart, int bufferLength) thro public boolean isEndOfInputReached() { return inputByteBuffer.isEof(); } + + private Field[] createFields(TypeInfo[] typeInfos) { + Field[] children = new Field[typeInfos.length]; + for (int i = 0; i < typeInfos.length; i++) { + children[i] = createField(typeInfos[i]); + } + return children; + } + + private Field createField(TypeInfo typeInfo) { + Field field = new Field(); + ObjectInspector.Category category = typeInfo.getCategory(); + field.category = category; + field.typeInfo = typeInfo; + + switch (category) { + case PRIMITIVE: + field.primitiveCategory = ((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory(); + break; + case LIST: + field.children = new Field[1]; + field.children[0] = createField(((ListTypeInfo) typeInfo).getListElementTypeInfo()); + break; + case MAP: + field.children = new Field[2]; + field.children[0] = createField(((MapTypeInfo) typeInfo).getMapKeyTypeInfo()); + field.children[1] = createField(((MapTypeInfo) typeInfo).getMapValueTypeInfo()); + break; + case STRUCT: + StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo; + List fieldTypeInfos = structTypeInfo.getAllStructFieldTypeInfos(); + field.count = fieldTypeInfos.size(); + field.children = createFields(fieldTypeInfos.toArray(new TypeInfo[fieldTypeInfos.size()])); + break; + case UNION: + UnionTypeInfo unionTypeInfo = (UnionTypeInfo) typeInfo; + List objectTypeInfos = unionTypeInfo.getAllUnionObjectTypeInfos(); + field.count = 2; + field.children = createFields(objectTypeInfos.toArray(new TypeInfo[objectTypeInfos.size()])); + break; + default: + throw new RuntimeException(); + } + return field; + } + + private Field getChild(Field field) { + Field child; + switch (field.category) { + case LIST: + child = field.children[0]; + break; + case MAP: + child = field.children[field.index % 2]; + break; + case STRUCT: + child = field.children[field.index]; + break; + case UNION: + child = field.children[field.tag]; + break; + default: + throw new RuntimeException(); + } + return child; + } + + private boolean isNull() throws IOException { + return inputByteBuffer.read(columnSortOrderIsDesc[root.index]) == + columnNullMarker[root.index]; + } + + @Override + public boolean readComplexField() throws IOException { + Field current = stack.peek(); + current.index++; + + if (root.index >= root.count) { + return false; + } + + if (inputByteBuffer.isEof()) { + // Also, reading beyond our byte range produces NULL. 
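+ // (A value may legitimately have been serialized with fewer trailing + // fields than the schema declares; such fields read as NULL.)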
+ return false; + } + + if (current.category == ObjectInspector.Category.UNION) { + if (current.index == 0) { + current.tag = inputByteBuffer.read(columnSortOrderIsDesc[root.index]); + currentInt = current.tag; + return true; + } + } + + Field child = getChild(current); + + boolean isNull = isNull(); + + if (isNull) { + return false; + } + if (child.category == ObjectInspector.Category.PRIMITIVE) { + isNull = !readPrimitive(child); + } else { + stack.push(child); + } + return !isNull; + } + + @Override + public boolean isNextComplexMultiValue() throws IOException { + byte isNullByte = inputByteBuffer.read(columnSortOrderIsDesc[root.index]); + boolean isEnded; + + switch (isNullByte) { + case 0: + isEnded = true; + break; + + case 1: + isEnded = false; + break; + + default: + throw new RuntimeException("Unexpected list/map marker byte " + isNullByte); + } + + if (isEnded) { + stack.pop(); + } + return !isEnded; + } + + @Override + public void finishComplexVariableFieldsType() { + stack.pop(); + if (stack.peek() == null) { + throw new RuntimeException("Stack is empty after finishing a complex type"); + } + } } diff --git serde/src/java/org/apache/hadoop/hive/serde2/binarysortable/fast/BinarySortableSerializeWrite.java serde/src/java/org/apache/hadoop/hive/serde2/binarysortable/fast/BinarySortableSerializeWrite.java index a9ea7c0d19..5be7714e09 100644 --- serde/src/java/org/apache/hadoop/hive/serde2/binarysortable/fast/BinarySortableSerializeWrite.java +++ serde/src/java/org/apache/hadoop/hive/serde2/binarysortable/fast/BinarySortableSerializeWrite.java @@ -22,6 +22,8 @@ import java.sql.Date; import java.sql.Timestamp; import java.util.Arrays; +import java.util.List; +import java.util.Map; import org.apache.hadoop.hive.common.type.HiveChar; import org.apache.hadoop.hive.common.type.HiveDecimal; @@ -34,7 +36,6 @@ import org.apache.hadoop.hive.serde2.io.DateWritable; import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; import org.apache.hadoop.hive.serde2.io.TimestampWritable; -import org.apache.hive.common.util.DateUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -57,8 +58,7 @@ // Which field we are on. We start with -1 to be consistent in style with // BinarySortableDeserializeRead.
private int index; - - private int fieldCount; + private int level; private TimestampWritable tempTimestampWritable; @@ -67,7 +67,6 @@ public BinarySortableSerializeWrite(boolean[] columnSortOrderIsDesc, byte[] columnNullMarker, byte[] columnNotNullMarker) { this(); - fieldCount = columnSortOrderIsDesc.length; this.columnSortOrderIsDesc = columnSortOrderIsDesc; this.columnNullMarker = columnNullMarker; this.columnNotNullMarker = columnNotNullMarker; @@ -79,7 +78,6 @@ public BinarySortableSerializeWrite(boolean[] columnSortOrderIsDesc, */ public BinarySortableSerializeWrite(int fieldCount) { this(); - this.fieldCount = fieldCount; columnSortOrderIsDesc = new boolean[fieldCount]; Arrays.fill(columnSortOrderIsDesc, false); columnNullMarker = new byte[fieldCount]; @@ -101,6 +99,7 @@ public void set(Output output) { this.output = output; this.output.reset(); index = -1; + level = 0; } /* @@ -110,6 +109,7 @@ public void set(Output output) { public void setAppend(Output output) { this.output = output; index = -1; + level = 0; } /* @@ -119,6 +119,7 @@ public void setAppend(Output output) { public void reset() { output.reset(); index = -1; + level = 0; } /* @@ -126,23 +127,26 @@ public void reset() { */ @Override public void writeNull() throws IOException { - ++index; + if (level == 0) { + index++; + } BinarySortableSerDe.writeByte(output, columnNullMarker[index], columnSortOrderIsDesc[index]); } + private void beginElement() { + if (level == 0) { + index++; + } + BinarySortableSerDe.writeByte(output, columnNotNullMarker[index], columnSortOrderIsDesc[index]); + } + /* * BOOLEAN. */ @Override public void writeBoolean(boolean v) throws IOException { - ++index; - - final boolean invert = columnSortOrderIsDesc[index]; - - // This field is not a null. - BinarySortableSerDe.writeByte(output, columnNotNullMarker[index], invert); - - BinarySortableSerDe.writeByte(output, (byte) (v ? 2 : 1), invert); + beginElement(); + BinarySortableSerDe.writeByte(output, (byte) (v ? 2 : 1), columnSortOrderIsDesc[index]); } /* @@ -150,14 +154,8 @@ public void writeBoolean(boolean v) throws IOException { */ @Override public void writeByte(byte v) throws IOException { - ++index; - - final boolean invert = columnSortOrderIsDesc[index]; - - // This field is not a null. - BinarySortableSerDe.writeByte(output, columnNotNullMarker[index], invert); - - BinarySortableSerDe.writeByte(output, (byte) (v ^ 0x80), invert); + beginElement(); + BinarySortableSerDe.writeByte(output, (byte) (v ^ 0x80), columnSortOrderIsDesc[index]); } /* @@ -165,14 +163,8 @@ public void writeByte(byte v) throws IOException { */ @Override public void writeShort(short v) throws IOException { - ++index; - - final boolean invert = columnSortOrderIsDesc[index]; - - // This field is not a null. - BinarySortableSerDe.writeByte(output, columnNotNullMarker[index], invert); - - BinarySortableSerDe.serializeShort(output, v, invert); + beginElement(); + BinarySortableSerDe.serializeShort(output, v, columnSortOrderIsDesc[index]); } /* @@ -180,14 +172,8 @@ public void writeShort(short v) throws IOException { */ @Override public void writeInt(int v) throws IOException { - ++index; - - final boolean invert = columnSortOrderIsDesc[index]; - - // This field is not a null. 
- BinarySortableSerDe.writeByte(output, columnNotNullMarker[index], invert); - - BinarySortableSerDe.serializeInt(output, v, invert); + beginElement(); + BinarySortableSerDe.serializeInt(output, v, columnSortOrderIsDesc[index]); } /* @@ -195,14 +181,8 @@ public void writeInt(int v) throws IOException { */ @Override public void writeLong(long v) throws IOException { - ++index; - - final boolean invert = columnSortOrderIsDesc[index]; - - // This field is not a null. - BinarySortableSerDe.writeByte(output, columnNotNullMarker[index], invert); - - BinarySortableSerDe.serializeLong(output, v, invert); + beginElement(); + BinarySortableSerDe.serializeLong(output, v, columnSortOrderIsDesc[index]); } /* @@ -210,14 +190,8 @@ public void writeLong(long v) throws IOException { */ @Override public void writeFloat(float vf) throws IOException { - ++index; - - final boolean invert = columnSortOrderIsDesc[index]; - - // This field is not a null. - BinarySortableSerDe.writeByte(output, columnNotNullMarker[index], invert); - - BinarySortableSerDe.serializeFloat(output, vf, invert); + beginElement(); + BinarySortableSerDe.serializeFloat(output, vf, columnSortOrderIsDesc[index]); } /* @@ -225,14 +199,8 @@ public void writeFloat(float vf) throws IOException { */ @Override public void writeDouble(double vd) throws IOException { - ++index; - - final boolean invert = columnSortOrderIsDesc[index]; - - // This field is not a null. - BinarySortableSerDe.writeByte(output, columnNotNullMarker[index], invert); - - BinarySortableSerDe.serializeDouble(output, vd, invert); + beginElement(); + BinarySortableSerDe.serializeDouble(output, vd, columnSortOrderIsDesc[index]); } /* @@ -243,26 +211,14 @@ public void writeDouble(double vd) throws IOException { */ @Override public void writeString(byte[] v) throws IOException { - ++index; - - final boolean invert = columnSortOrderIsDesc[index]; - - // This field is not a null. - BinarySortableSerDe.writeByte(output, columnNotNullMarker[index], invert); - - BinarySortableSerDe.serializeBytes(output, v, 0, v.length, invert); + beginElement(); + BinarySortableSerDe.serializeBytes(output, v, 0, v.length, columnSortOrderIsDesc[index]); } @Override public void writeString(byte[] v, int start, int length) throws IOException { - ++index; - - final boolean invert = columnSortOrderIsDesc[index]; - - // This field is not a null. - BinarySortableSerDe.writeByte(output, columnNotNullMarker[index], invert); - - BinarySortableSerDe.serializeBytes(output, v, start, length, invert); + beginElement(); + BinarySortableSerDe.serializeBytes(output, v, start, length, columnSortOrderIsDesc[index]); } /* @@ -290,26 +246,14 @@ public void writeHiveVarchar(HiveVarchar hiveVarchar) throws IOException { */ @Override public void writeBinary(byte[] v) throws IOException { - ++index; - - final boolean invert = columnSortOrderIsDesc[index]; - - // This field is not a null. - BinarySortableSerDe.writeByte(output, columnNotNullMarker[index], invert); - - BinarySortableSerDe.serializeBytes(output, v, 0, v.length, invert); + beginElement(); + BinarySortableSerDe.serializeBytes(output, v, 0, v.length, columnSortOrderIsDesc[index]); } @Override public void writeBinary(byte[] v, int start, int length) { - ++index; - - final boolean invert = columnSortOrderIsDesc[index]; - - // This field is not a null. 
- BinarySortableSerDe.writeByte(output, columnNotNullMarker[index], invert); - - BinarySortableSerDe.serializeBytes(output, v, start, length, invert); + beginElement(); + BinarySortableSerDe.serializeBytes(output, v, start, length, columnSortOrderIsDesc[index]); } /* @@ -317,27 +261,15 @@ public void writeBinary(byte[] v, int start, int length) { */ @Override public void writeDate(Date date) throws IOException { - ++index; - - final boolean invert = columnSortOrderIsDesc[index]; - - // This field is not a null. - BinarySortableSerDe.writeByte(output, columnNotNullMarker[index], invert); - - BinarySortableSerDe.serializeInt(output, DateWritable.dateToDays(date), invert); + beginElement(); + BinarySortableSerDe.serializeInt(output, DateWritable.dateToDays(date), columnSortOrderIsDesc[index]); } // We provide a faster way to write a date without a Date object. @Override public void writeDate(int dateAsDays) throws IOException { - ++index; - - final boolean invert = columnSortOrderIsDesc[index]; - - // This field is not a null. - BinarySortableSerDe.writeByte(output, columnNotNullMarker[index], invert); - - BinarySortableSerDe.serializeInt(output, dateAsDays, invert); + beginElement(); + BinarySortableSerDe.serializeInt(output, dateAsDays, columnSortOrderIsDesc[index]); } /* @@ -345,15 +277,9 @@ public void writeDate(int dateAsDays) throws IOException { */ @Override public void writeTimestamp(Timestamp vt) throws IOException { - ++index; - - final boolean invert = columnSortOrderIsDesc[index]; - - // This field is not a null. - BinarySortableSerDe.writeByte(output, columnNotNullMarker[index], invert); - + beginElement(); tempTimestampWritable.set(vt); - BinarySortableSerDe.serializeTimestampWritable(output, tempTimestampWritable, invert); + BinarySortableSerDe.serializeTimestampWritable(output, tempTimestampWritable, columnSortOrderIsDesc[index]); } /* @@ -361,26 +287,14 @@ public void writeTimestamp(Timestamp vt) throws IOException { */ @Override public void writeHiveIntervalYearMonth(HiveIntervalYearMonth viyt) throws IOException { - ++index; - - final boolean invert = columnSortOrderIsDesc[index]; - - // This field is not a null. - BinarySortableSerDe.writeByte(output, columnNotNullMarker[index], invert); - - BinarySortableSerDe.serializeHiveIntervalYearMonth(output, viyt, invert); + beginElement(); + BinarySortableSerDe.serializeHiveIntervalYearMonth(output, viyt, columnSortOrderIsDesc[index]); } @Override public void writeHiveIntervalYearMonth(int totalMonths) throws IOException { - ++index; - - final boolean invert = columnSortOrderIsDesc[index]; - - // This field is not a null. - BinarySortableSerDe.writeByte(output, columnNotNullMarker[index], invert); - - BinarySortableSerDe.serializeInt(output, totalMonths, invert); + beginElement(); + BinarySortableSerDe.serializeInt(output, totalMonths, columnSortOrderIsDesc[index]); } /* @@ -388,14 +302,8 @@ public void writeHiveIntervalYearMonth(int totalMonths) throws IOException { */ @Override public void writeHiveIntervalDayTime(HiveIntervalDayTime vidt) throws IOException { - ++index; - - final boolean invert = columnSortOrderIsDesc[index]; - - // This field is not a null. 
- BinarySortableSerDe.writeByte(output, columnNotNullMarker[index], invert); - - BinarySortableSerDe.serializeHiveIntervalDayTime(output, vidt, invert); + beginElement(); + BinarySortableSerDe.serializeHiveIntervalDayTime(output, vidt, columnSortOrderIsDesc[index]); } /* @@ -406,31 +314,104 @@ public void writeHiveIntervalDayTime(HiveIntervalDayTime vidt) throws IOExceptio */ @Override public void writeHiveDecimal(HiveDecimal dec, int scale) throws IOException { - ++index; - - final boolean invert = columnSortOrderIsDesc[index]; - - // This field is not a null. - BinarySortableSerDe.writeByte(output, columnNotNullMarker[index], invert); - + beginElement(); if (decimalBytesScratch == null) { decimalBytesScratch = new byte[HiveDecimal.SCRATCH_BUFFER_LEN_TO_BYTES]; } - BinarySortableSerDe.serializeHiveDecimal(output, dec, invert, decimalBytesScratch); + BinarySortableSerDe.serializeHiveDecimal(output, dec, columnSortOrderIsDesc[index], decimalBytesScratch); } @Override public void writeHiveDecimal(HiveDecimalWritable decWritable, int scale) throws IOException { - ++index; + beginElement(); + if (decimalBytesScratch == null) { + decimalBytesScratch = new byte[HiveDecimal.SCRATCH_BUFFER_LEN_TO_BYTES]; + } + BinarySortableSerDe.serializeHiveDecimal(output, decWritable, columnSortOrderIsDesc[index], decimalBytesScratch); + } - final boolean invert = columnSortOrderIsDesc[index]; + /* + * List + */ + @Override + public void beginList(List list) { + beginElement(); + level++; + if (!list.isEmpty()) { + BinarySortableSerDe.writeByte(output, (byte) 1, columnSortOrderIsDesc[index]); + } + } - // This field is not a null. - BinarySortableSerDe.writeByte(output, columnNotNullMarker[index], invert); + @Override + public void separateList() { + BinarySortableSerDe.writeByte(output, (byte) 1, columnSortOrderIsDesc[index]); + } - if (decimalBytesScratch == null) { - decimalBytesScratch = new byte[HiveDecimal.SCRATCH_BUFFER_LEN_TO_BYTES]; + @Override + public void finishList() { + level--; + // and \0 to terminate + BinarySortableSerDe.writeByte(output, (byte) 0, columnSortOrderIsDesc[index]); + } + + /* + * Map + */ + @Override + public void beginMap(Map map) { + beginElement(); + level++; + if (!map.isEmpty()) { + BinarySortableSerDe.writeByte(output, (byte) 1, columnSortOrderIsDesc[index]); } - BinarySortableSerDe.serializeHiveDecimal(output, decWritable, invert, decimalBytesScratch); + } + + @Override + public void separateKey() { + } + + @Override + public void separateKeyValuePair() { + BinarySortableSerDe.writeByte(output, (byte) 1, columnSortOrderIsDesc[index]); + } + + @Override + public void finishMap() { + level--; + // and \0 to terminate + BinarySortableSerDe.writeByte(output, (byte) 0, columnSortOrderIsDesc[index]); + } + + /* + * Struct + */ + @Override + public void beginStruct(List fieldValues) { + beginElement(); + level++; + } + + @Override + public void separateStruct() { + } + + @Override + public void finishStruct() { + level--; + } + + /* + * Union + */ + @Override + public void beginUnion(int tag) throws IOException { + beginElement(); + BinarySortableSerDe.writeByte(output, (byte) tag, columnSortOrderIsDesc[index]); + level++; + } + + @Override + public void finishUnion() { + level--; } } diff --git serde/src/java/org/apache/hadoop/hive/serde2/fast/DeserializeRead.java serde/src/java/org/apache/hadoop/hive/serde2/fast/DeserializeRead.java index ac931d6d64..cb6ed182f9 100644 --- serde/src/java/org/apache/hadoop/hive/serde2/fast/DeserializeRead.java +++ 
serde/src/java/org/apache/hadoop/hive/serde2/fast/DeserializeRead.java
@@ -19,6 +19,7 @@ package org.apache.hadoop.hive.serde2.fast;
 import java.io.IOException;
+
 import org.apache.hadoop.hive.serde2.io.DateWritable;
 import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;
 import org.apache.hadoop.hive.serde2.io.HiveIntervalDayTimeWritable;
@@ -26,8 +27,12 @@ import org.apache.hadoop.hive.serde2.io.TimestampWritable;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory;
+import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.UnionTypeInfo;

 /*
  * Directly deserialize with the caller reading field-by-field a serialization format.
@@ -52,6 +57,68 @@
   protected Category[] categories;
   protected PrimitiveCategory[] primitiveCategories;
+  /*
+   * Allocates the current* writables needed to read one field at a time. Simple fields like
+   * long, double, and int are read into the primitive current* members; the non-simple field
+   * types like Date, Timestamp, etc., are read into a current writable object that this
+   * method allocates.
+   *
+   * Complex type fields are handled by calling this method recursively.
+   */
+  private void allocateCurrentWritable(TypeInfo typeInfo) {
+    switch (typeInfo.getCategory()) {
+    case PRIMITIVE:
+      switch (((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory()) {
+      case DATE:
+        if (currentDateWritable == null) {
+          currentDateWritable = new DateWritable();
+        }
+        break;
+      case TIMESTAMP:
+        if (currentTimestampWritable == null) {
+          currentTimestampWritable = new TimestampWritable();
+        }
+        break;
+      case INTERVAL_YEAR_MONTH:
+        if (currentHiveIntervalYearMonthWritable == null) {
+          currentHiveIntervalYearMonthWritable = new HiveIntervalYearMonthWritable();
+        }
+        break;
+      case INTERVAL_DAY_TIME:
+        if (currentHiveIntervalDayTimeWritable == null) {
+          currentHiveIntervalDayTimeWritable = new HiveIntervalDayTimeWritable();
+        }
+        break;
+      case DECIMAL:
+        if (currentHiveDecimalWritable == null) {
+          currentHiveDecimalWritable = new HiveDecimalWritable();
+        }
+        break;
+      default:
+        // No writable needed for this data type.
+      }
+      break;
+    case LIST:
+      allocateCurrentWritable(((ListTypeInfo) typeInfo).getListElementTypeInfo());
+      break;
+    case MAP:
+      allocateCurrentWritable(((MapTypeInfo) typeInfo).getMapKeyTypeInfo());
+      allocateCurrentWritable(((MapTypeInfo) typeInfo).getMapValueTypeInfo());
+      break;
+    case STRUCT:
+      for (TypeInfo fieldTypeInfo : ((StructTypeInfo) typeInfo).getAllStructFieldTypeInfos()) {
+        allocateCurrentWritable(fieldTypeInfo);
+      }
+      break;
+    case UNION:
+      for (TypeInfo fieldTypeInfo : ((UnionTypeInfo) typeInfo).getAllUnionObjectTypeInfos()) {
+        allocateCurrentWritable(fieldTypeInfo);
+      }
+      break;
+    default:
+      throw new RuntimeException("Unexpected category " + typeInfo.getCategory());
+    }
+  }
+
  /**
   * Constructor.
* @@ -85,37 +152,8 @@ public DeserializeRead(TypeInfo[] typeInfos, boolean useExternalBuffer) { PrimitiveTypeInfo primitiveTypeInfo = (PrimitiveTypeInfo) typeInfo; PrimitiveCategory primitiveCategory = primitiveTypeInfo.getPrimitiveCategory(); primitiveCategories[i] = primitiveCategory; - - switch (primitiveCategory) { - case DATE: - if (currentDateWritable == null) { - currentDateWritable = new DateWritable(); - } - break; - case TIMESTAMP: - if (currentTimestampWritable == null) { - currentTimestampWritable = new TimestampWritable(); - } - break; - case INTERVAL_YEAR_MONTH: - if (currentHiveIntervalYearMonthWritable == null) { - currentHiveIntervalYearMonthWritable = new HiveIntervalYearMonthWritable(); - } - break; - case INTERVAL_DAY_TIME: - if (currentHiveIntervalDayTimeWritable == null) { - currentHiveIntervalDayTimeWritable = new HiveIntervalDayTimeWritable(); - } - break; - case DECIMAL: - if (currentHiveDecimalWritable == null) { - currentHiveDecimalWritable = new HiveDecimalWritable(); - } - break; - default: - // No writable needed for this data type. - } } + allocateCurrentWritable(typeInfo); this.useExternalBuffer = useExternalBuffer; } @@ -178,6 +216,22 @@ public boolean readField(int fieldIndex) throws IOException { } /* + * Tests whether there is another List element or another Map key/value pair. + */ + public abstract boolean isNextComplexMultiValue() throws IOException; + + /* + * Read a field that is under a complex type. It may be a primitive type or deeper complex type. + */ + public abstract boolean readComplexField() throws IOException; + + /* + * Used by Struct and Union complex type readers to indicate the (final) field has been fully + * read and the current complex type is finished. + */ + public abstract void finishComplexVariableFieldsType(); + + /* * Call this method may be called after all the all fields have been read to check * for unread fields. * diff --git serde/src/java/org/apache/hadoop/hive/serde2/fast/SerializeWrite.java serde/src/java/org/apache/hadoop/hive/serde2/fast/SerializeWrite.java index 17d2385563..89bcf4fd2f 100644 --- serde/src/java/org/apache/hadoop/hive/serde2/fast/SerializeWrite.java +++ serde/src/java/org/apache/hadoop/hive/serde2/fast/SerializeWrite.java @@ -21,6 +21,8 @@ import java.io.IOException; import java.sql.Date; import java.sql.Timestamp; +import java.util.List; +import java.util.Map; import org.apache.hadoop.hive.common.type.HiveChar; import org.apache.hadoop.hive.common.type.HiveDecimal; @@ -154,4 +156,32 @@ */ void writeHiveDecimal(HiveDecimal dec, int scale) throws IOException; void writeHiveDecimal(HiveDecimalWritable decWritable, int scale) throws IOException; + + /* + * LIST. + */ + void beginList(List list); + void separateList(); + void finishList(); + + /* + * MAP. + */ + void beginMap(Map map); + void separateKey(); + void separateKeyValuePair(); + void finishMap(); + + /* + * STRUCT. + */ + void beginStruct(List fieldValues); + void separateStruct(); + void finishStruct(); + + /* + * UNION. 
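+   * A union is written as a tag that selects one of the union's object types, followed by
+   * that single value. As a sketch, a union<int,string> currently holding the int 42 would
+   * be written with: beginUnion(0); writeInt(42); finishUnion();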
+ */ + void beginUnion(int tag) throws IOException; + void finishUnion(); } diff --git serde/src/java/org/apache/hadoop/hive/serde2/io/TimestampWritable.java serde/src/java/org/apache/hadoop/hive/serde2/io/TimestampWritable.java index bbccc7fe90..64bcb11c62 100644 --- serde/src/java/org/apache/hadoop/hive/serde2/io/TimestampWritable.java +++ serde/src/java/org/apache/hadoop/hive/serde2/io/TimestampWritable.java @@ -133,7 +133,8 @@ public void set(Timestamp t) { timestamp.setNanos(0); return; } - this.timestamp = t; + timestamp.setTime(t.getTime()); + timestamp.setNanos(t.getNanos()); bytesEmpty = true; timestampEmpty = false; } diff --git serde/src/java/org/apache/hadoop/hive/serde2/lazy/VerifyLazy.java serde/src/java/org/apache/hadoop/hive/serde2/lazy/VerifyLazy.java new file mode 100644 index 0000000000..324f5b85ee --- /dev/null +++ serde/src/java/org/apache/hadoop/hive/serde2/lazy/VerifyLazy.java @@ -0,0 +1,444 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hive.serde2.lazy; + +import java.sql.Date; +import java.sql.Timestamp; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import org.apache.hadoop.hive.common.type.HiveChar; +import org.apache.hadoop.hive.common.type.HiveDecimal; +import org.apache.hadoop.hive.common.type.HiveIntervalDayTime; +import org.apache.hadoop.hive.common.type.HiveIntervalYearMonth; +import org.apache.hadoop.hive.common.type.HiveVarchar; +import org.apache.hadoop.hive.serde2.io.ByteWritable; +import org.apache.hadoop.hive.serde2.io.DateWritable; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.io.HiveCharWritable; +import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; +import org.apache.hadoop.hive.serde2.io.HiveIntervalDayTimeWritable; +import org.apache.hadoop.hive.serde2.io.HiveIntervalYearMonthWritable; +import org.apache.hadoop.hive.serde2.io.HiveVarcharWritable; +import org.apache.hadoop.hive.serde2.io.ShortWritable; +import org.apache.hadoop.hive.serde2.io.TimestampWritable; +import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray; +import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryMap; +import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryStruct; +import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryUnion; +import org.apache.hadoop.hive.serde2.objectinspector.StandardUnionObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.UnionObject; +import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; +import 
org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.UnionTypeInfo;
+import org.apache.hadoop.io.BooleanWritable;
+import org.apache.hadoop.io.BytesWritable;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+
+/**
+ * VerifyLazy.
+ *
+ * Verifies lazy-deserialized objects (LazySimple or LazyBinary) against expected writable values.
+ */
+public class VerifyLazy {
+
+  public static boolean lazyCompareList(ListTypeInfo listTypeInfo, List list, List expectedList) {
+    TypeInfo elementTypeInfo = listTypeInfo.getListElementTypeInfo();
+    final int size = list.size();
+    for (int i = 0; i < size; i++) {
+      Object lazyEleObj = list.get(i);
+      Object expectedEleObj = expectedList.get(i);
+      if (!lazyCompare(elementTypeInfo, lazyEleObj, expectedEleObj)) {
+        throw new RuntimeException("List element deserialized value does not match elementTypeInfo " + elementTypeInfo.toString());
+      }
+    }
+    return true;
+  }
+
+  public static boolean lazyCompareMap(MapTypeInfo mapTypeInfo, Map map, Map expectedMap) {
+    TypeInfo keyTypeInfo = mapTypeInfo.getMapKeyTypeInfo();
+    TypeInfo valueTypeInfo = mapTypeInfo.getMapValueTypeInfo();
+    if (map.size() != expectedMap.size()) {
+      throw new RuntimeException("Map key/value deserialized map.size() " + map.size() + " map " + map.toString() +
+          " expectedMap.size() " + expectedMap.size() + " expectedMap " + expectedMap.toString() +
+          " does not match keyTypeInfo " + keyTypeInfo.toString() + " valueTypeInfo " + valueTypeInfo.toString());
+    }
+    return true;
+  }
+
+  public static boolean lazyCompareStruct(StructTypeInfo structTypeInfo, List fields, List expectedFields) {
+    ArrayList fieldTypeInfos = structTypeInfo.getAllStructFieldTypeInfos();
+    final int size = fieldTypeInfos.size();
+    for (int i = 0; i < size; i++) {
+      Object lazyEleObj = fields.get(i);
+      Object expectedEleObj = expectedFields.get(i);
+      if (!lazyCompare(fieldTypeInfos.get(i), lazyEleObj, expectedEleObj)) {
+        throw new RuntimeException("SerDe deserialized value does not match");
+      }
+    }
+    return true;
+  }
+
+  public static boolean lazyCompareUnion(UnionTypeInfo unionTypeInfo, LazyBinaryUnion union, UnionObject expectedUnion) {
+    byte tag = union.getTag();
+    byte expectedTag = expectedUnion.getTag();
+    if (tag != expectedTag) {
+      throw new RuntimeException("Union tag does not match union.getTag() " + tag + " expectedUnion.getTag() " + expectedTag);
+    }
+    return lazyCompare(unionTypeInfo.getAllUnionObjectTypeInfos().get(tag),
+        union.getField(), expectedUnion.getObject());
+  }
+
+  public static boolean lazyCompareUnion(UnionTypeInfo unionTypeInfo, LazyUnion union, UnionObject expectedUnion) {
+    byte tag = union.getTag();
+    byte expectedTag = expectedUnion.getTag();
+    if (tag != expectedTag) {
+      throw new RuntimeException("Union tag does not match union.getTag() " + tag + " expectedUnion.getTag() " + expectedTag);
+    }
+    return lazyCompare(unionTypeInfo.getAllUnionObjectTypeInfos().get(tag),
+        union.getField(), expectedUnion.getObject());
+  }
+
+  public static boolean lazyCompareUnion(UnionTypeInfo unionTypeInfo, UnionObject union, UnionObject expectedUnion) {
+    byte tag = union.getTag();
+    byte expectedTag = expectedUnion.getTag();
+    if (tag != expectedTag) {
+      throw new RuntimeException("Union tag does not match union.getTag() " + tag +
+          " expectedUnion.getTag() " + expectedTag);
+    }
+    return
lazyCompare(unionTypeInfo.getAllUnionObjectTypeInfos().get(tag), + union.getObject(), expectedUnion.getObject()); + } + + public static boolean lazyCompare(TypeInfo typeInfo, Object lazyObject, Object expectedObject) { + if (expectedObject == null) { + if (lazyObject != null) { + throw new RuntimeException("Expected object is null but object is not null " + lazyObject.toString() + + " typeInfo " + typeInfo.toString()); + } + return true; + } else if (lazyObject == null) { + throw new RuntimeException("Expected object is not null \"" + expectedObject.toString() + + "\" typeInfo " + typeInfo.toString() + " but object is null"); + } + if (lazyObject instanceof Writable) { + if (!lazyObject.equals(expectedObject)) { + throw new RuntimeException("Expected object " + expectedObject.toString() + + " and actual object " + lazyObject.toString() + " is not equal typeInfo " + typeInfo.toString()); + } + return true; + } + if (lazyObject instanceof LazyPrimitive) { + Object primitiveObject = ((LazyPrimitive) lazyObject).getObject(); + PrimitiveTypeInfo primitiveTypeInfo = (PrimitiveTypeInfo) typeInfo; + switch (primitiveTypeInfo.getPrimitiveCategory()) { + case BOOLEAN: + { + if (!(primitiveObject instanceof LazyBoolean)) { + throw new RuntimeException("Expected LazyBoolean"); + } + boolean value = ((LazyBoolean) primitiveObject).getWritableObject().get(); + boolean expected = ((BooleanWritable) expectedObject).get(); + if (value != expected) { + throw new RuntimeException("Boolean field mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case BYTE: + { + if (!(primitiveObject instanceof LazyByte)) { + throw new RuntimeException("Expected LazyByte"); + } + byte value = ((LazyByte) primitiveObject).getWritableObject().get(); + byte expected = ((ByteWritable) expectedObject).get(); + if (value != expected) { + throw new RuntimeException("Byte field mismatch (expected " + (int) expected + " found " + (int) value + ")"); + } + } + break; + case SHORT: + { + if (!(primitiveObject instanceof LazyShort)) { + throw new RuntimeException("Expected LazyShort"); + } + short value = ((LazyShort) primitiveObject).getWritableObject().get(); + short expected = ((ShortWritable) expectedObject).get(); + if (value != expected) { + throw new RuntimeException("Short field mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case INT: + { + if (!(primitiveObject instanceof LazyInteger)) { + throw new RuntimeException("Expected LazyInteger"); + } + int value = ((LazyInteger) primitiveObject).getWritableObject().get(); + int expected = ((IntWritable) expectedObject).get(); + if (value != expected) { + throw new RuntimeException("Int field mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case LONG: + { + if (!(primitiveObject instanceof LazyLong)) { + throw new RuntimeException("Expected LazyLong"); + } + long value = ((LazyLong) primitiveObject).getWritableObject().get(); + long expected = ((LongWritable) expectedObject).get(); + if (value != expected) { + throw new RuntimeException("Long field mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case FLOAT: + { + if (!(primitiveObject instanceof LazyFloat)) { + throw new RuntimeException("Expected LazyFloat"); + } + float value = ((LazyFloat) primitiveObject).getWritableObject().get(); + float expected = ((FloatWritable) expectedObject).get(); + if (value != expected) { + throw new RuntimeException("Float field mismatch (expected " + expected + " found " + value + 
")"); + } + } + break; + case DOUBLE: + { + if (!(primitiveObject instanceof LazyDouble)) { + throw new RuntimeException("Expected LazyDouble"); + } + double value = ((LazyDouble) primitiveObject).getWritableObject().get(); + double expected = ((DoubleWritable) expectedObject).get(); + if (value != expected) { + throw new RuntimeException("Double field mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case STRING: + { + if (!(primitiveObject instanceof LazyString)) { + throw new RuntimeException("Text expected writable not Text"); + } + Text value = ((LazyString) primitiveObject).getWritableObject(); + Text expected = ((Text) expectedObject); + if (!value.equals(expected)) { + throw new RuntimeException("String field mismatch (expected '" + expected + "' found '" + value + "')"); + } + } + break; + case CHAR: + { + if (!(primitiveObject instanceof LazyHiveChar)) { + throw new RuntimeException("Expected LazyHiveChar"); + } + HiveChar value = ((LazyHiveChar) primitiveObject).getWritableObject().getHiveChar(); + HiveChar expected = ((HiveCharWritable) expectedObject).getHiveChar(); + + if (!value.equals(expected)) { + throw new RuntimeException("HiveChar field mismatch (expected '" + expected + "' found '" + value + "')"); + } + } + break; + case VARCHAR: + { + if (!(primitiveObject instanceof LazyHiveVarchar)) { + throw new RuntimeException("Expected LazyHiveVarchar"); + } + HiveVarchar value = ((LazyHiveVarchar) primitiveObject).getWritableObject().getHiveVarchar(); + HiveVarchar expected = ((HiveVarcharWritable) expectedObject).getHiveVarchar(); + + if (!value.equals(expected)) { + throw new RuntimeException("HiveVarchar field mismatch (expected '" + expected + "' found '" + value + "')"); + } + } + break; + case DECIMAL: + { + if (!(primitiveObject instanceof LazyHiveDecimal)) { + throw new RuntimeException("Expected LazyDecimal"); + } + HiveDecimal value = ((LazyHiveDecimal) primitiveObject).getWritableObject().getHiveDecimal(); + HiveDecimal expected = ((HiveDecimalWritable) expectedObject).getHiveDecimal(); + + if (!value.equals(expected)) { + DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) primitiveTypeInfo; + int precision = decimalTypeInfo.getPrecision(); + int scale = decimalTypeInfo.getScale(); + throw new RuntimeException("Decimal field mismatch (expected " + expected.toString() + + " found " + value.toString() + ") precision " + precision + ", scale " + scale); + } + } + break; + case DATE: + { + if (!(primitiveObject instanceof LazyDate)) { + throw new RuntimeException("Expected LazyDate"); + } + Date value = ((LazyDate) primitiveObject).getWritableObject().get(); + Date expected = ((DateWritable) expectedObject).get(); + if (!value.equals(expected)) { + throw new RuntimeException("Date field mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case TIMESTAMP: + { + if (!(primitiveObject instanceof LazyTimestamp)) { + throw new RuntimeException("TimestampWritable expected writable not TimestampWritable"); + } + Timestamp value = ((LazyTimestamp) primitiveObject).getWritableObject().getTimestamp(); + Timestamp expected = ((TimestampWritable) expectedObject).getTimestamp(); + if (!value.equals(expected)) { + throw new RuntimeException("Timestamp field mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case INTERVAL_YEAR_MONTH: + { + if (!(primitiveObject instanceof LazyHiveIntervalYearMonth)) { + throw new RuntimeException("Expected LazyHiveIntervalYearMonth"); + } + HiveIntervalYearMonth value 
= ((LazyHiveIntervalYearMonth) primitiveObject).getWritableObject().getHiveIntervalYearMonth(); + HiveIntervalYearMonth expected = ((HiveIntervalYearMonthWritable) expectedObject).getHiveIntervalYearMonth(); + if (!value.equals(expected)) { + throw new RuntimeException("HiveIntervalYearMonth field mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case INTERVAL_DAY_TIME: + { + if (!(primitiveObject instanceof LazyHiveIntervalDayTime)) { + throw new RuntimeException("Expected writable LazyHiveIntervalDayTime"); + } + HiveIntervalDayTime value = ((LazyHiveIntervalDayTime) primitiveObject).getWritableObject().getHiveIntervalDayTime(); + HiveIntervalDayTime expected = ((HiveIntervalDayTimeWritable) expectedObject).getHiveIntervalDayTime(); + if (!value.equals(expected)) { + throw new RuntimeException("HiveIntervalDayTime field mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case BINARY: + { + if (!(primitiveObject instanceof LazyBinary)) { + throw new RuntimeException("Expected LazyBinary"); + } + BytesWritable bytesWritable = ((LazyBinary) primitiveObject).getWritableObject(); + byte[] value = Arrays.copyOfRange(bytesWritable.getBytes(), 0, bytesWritable.getLength()); + BytesWritable bytesWritableExpected = (BytesWritable) expectedObject; + byte[] expected = Arrays.copyOfRange(bytesWritableExpected.getBytes(), 0, bytesWritableExpected.getLength()); + if (value.length != expected.length){ + throw new RuntimeException("Byte Array field mismatch (expected " + Arrays.toString(expected) + + " found " + Arrays.toString(value) + ")"); + } + for (int b = 0; b < value.length; b++) { + if (value[b] != expected[b]) { + throw new RuntimeException("Byte Array field mismatch (expected " + Arrays.toString(expected) + + " found " + Arrays.toString(value) + ")"); + } + } + } + break; + default: + throw new Error("Unknown primitive category " + primitiveTypeInfo.getPrimitiveCategory()); + } + } else if (lazyObject instanceof LazyArray) { + LazyArray lazyArray = (LazyArray) lazyObject; + List list = lazyArray.getList(); + List expectedList = (List) expectedObject; + ListTypeInfo listTypeInfo = (ListTypeInfo) typeInfo; + if (list.size() != expectedList.size()) { + throw new RuntimeException("SerDe deserialized list length does not match (list " + + list.toString() + " list.size() " + list.size() + " expectedList " + expectedList.toString() + + " expectedList.size() " + expectedList.size() + ")" + + " elementTypeInfo " + listTypeInfo.getListElementTypeInfo().toString()); + } + return lazyCompareList((ListTypeInfo) typeInfo, list, expectedList); + } else if (typeInfo instanceof ListTypeInfo) { + List list; + if (lazyObject instanceof LazyBinaryArray) { + list = ((LazyBinaryArray) lazyObject).getList(); + } else { + list = (List) lazyObject; + } + List expectedList = (List) expectedObject; + if (list.size() != expectedList.size()) { + throw new RuntimeException("SerDe deserialized list length does not match (list " + + list.toString() + " list.size() " + list.size() + " expectedList " + expectedList.toString() + + " expectedList.size() " + expectedList.size() + ")"); + } + return lazyCompareList((ListTypeInfo) typeInfo, list, expectedList); + } else if (lazyObject instanceof LazyMap) { + LazyMap lazyMap = (LazyMap) lazyObject; + Map map = lazyMap.getMap(); + Map expectedMap = (Map) expectedObject; + return lazyCompareMap((MapTypeInfo) typeInfo, map, expectedMap); + } else if (typeInfo instanceof MapTypeInfo) { + Map map; + Map expectedMap = (Map) 
expectedObject; + if (lazyObject instanceof LazyBinaryMap) { + map = ((LazyBinaryMap) lazyObject).getMap(); + } else { + map = (Map) lazyObject; + } + return lazyCompareMap((MapTypeInfo) typeInfo, map, expectedMap); + } else if (lazyObject instanceof LazyStruct) { + LazyStruct lazyStruct = (LazyStruct) lazyObject; + List fields = lazyStruct.getFieldsAsList(); + List expectedFields = (List) expectedObject; + StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo; + return lazyCompareStruct(structTypeInfo, fields, expectedFields); + } else if (typeInfo instanceof StructTypeInfo) { + ArrayList fields; + if (lazyObject instanceof LazyBinaryStruct) { + fields = ((LazyBinaryStruct) lazyObject).getFieldsAsList(); + } else { + fields = (ArrayList) lazyObject; + } + List expectedFields = (List) expectedObject; + StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo; + return lazyCompareStruct(structTypeInfo, fields, expectedFields); + } else if (lazyObject instanceof LazyUnion) { + LazyUnion union = (LazyUnion) lazyObject; + StandardUnionObjectInspector.StandardUnion expectedUnion = (StandardUnionObjectInspector.StandardUnion) expectedObject; + UnionTypeInfo unionTypeInfo = (UnionTypeInfo) typeInfo; + return lazyCompareUnion(unionTypeInfo, union, expectedUnion); + } else if (typeInfo instanceof UnionTypeInfo) { + StandardUnionObjectInspector.StandardUnion expectedUnion = (StandardUnionObjectInspector.StandardUnion) expectedObject; + UnionTypeInfo unionTypeInfo = (UnionTypeInfo) typeInfo; + if (lazyObject instanceof LazyBinaryUnion) { + return lazyCompareUnion(unionTypeInfo, (LazyBinaryUnion) lazyObject, expectedUnion); + } else { + return lazyCompareUnion(unionTypeInfo, (UnionObject) lazyObject, expectedUnion); + } + } else { + System.err.println("Not implemented " + typeInfo.getClass().getName()); + } + return true; + } +} \ No newline at end of file diff --git serde/src/java/org/apache/hadoop/hive/serde2/lazy/fast/LazySimpleDeserializeRead.java serde/src/java/org/apache/hadoop/hive/serde2/lazy/fast/LazySimpleDeserializeRead.java index 606b246b5d..21e4c56caa 100644 --- serde/src/java/org/apache/hadoop/hive/serde2/lazy/fast/LazySimpleDeserializeRead.java +++ serde/src/java/org/apache/hadoop/hive/serde2/lazy/fast/LazySimpleDeserializeRead.java @@ -19,11 +19,11 @@ package org.apache.hadoop.hive.serde2.lazy.fast; import java.io.IOException; -import java.io.UnsupportedEncodingException; import java.nio.charset.CharacterCodingException; import java.nio.charset.StandardCharsets; import java.sql.Date; import java.util.Arrays; +import java.util.List; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -37,12 +37,21 @@ import org.apache.hadoop.hive.serde2.lazy.LazySerDeParameters; import org.apache.hadoop.hive.serde2.lazy.LazyShort; import org.apache.hadoop.hive.serde2.lazy.LazyUtils; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; +import org.apache.hadoop.hive.serde2.typeinfo.UnionTypeInfo; import org.apache.hadoop.io.Text; import 
org.apache.hive.common.util.TimestampParser;
+import com.google.common.base.Preconditions;
+
 /*
  * Directly deserialize with the caller reading field-by-field the LazySimple (text)
  * serialization format.
@@ -61,9 +70,123 @@ public final class LazySimpleDeserializeRead extends DeserializeRead {
   public static final Logger LOG = LoggerFactory.getLogger(LazySimpleDeserializeRead.class.getName());

-  private int[] startPosition;
+  /*
+   * Information on a field. Made a class to allow readField to be agnostic to whether it is
+   * reading a top-level field or a field nested within a complex type.
+   */
+  private static class Field {
+
+    // Optimize for most common case -- primitive.
+    public final boolean isPrimitive;
+    public final PrimitiveCategory primitiveCategory;
+
+    public final Category complexCategory;
+
+    public final TypeInfo typeInfo;
+
+    public ComplexTypeHelper complexTypeHelper;
+
+    public Field(TypeInfo typeInfo) {
+      Category category = typeInfo.getCategory();
+      if (category == Category.PRIMITIVE) {
+        isPrimitive = true;
+        primitiveCategory = ((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory();
+        complexCategory = null;
+      } else {
+        isPrimitive = false;
+        primitiveCategory = null;
+        complexCategory = category;
+      }
+
+      this.typeInfo = typeInfo;
+
+      complexTypeHelper = null;
+    }
+  }
+
+  /*
+   * Used to keep position/length for complex type fields.
+   * NOTE: The top level uses startPositions instead.
+   */
+  private static class ComplexTypeHelper {
+
+    public final Field complexField;
+
+    public int complexFieldStart;
+    public int complexFieldLength;
+    public int complexFieldEnd;
+
+    public int fieldPosition;
+
+    public ComplexTypeHelper(Field complexField) {
+      this.complexField = complexField;
+    }
+
+    public void setCurrentFieldInfo(int complexFieldStart, int complexFieldLength) {
+      this.complexFieldStart = complexFieldStart;
+      this.complexFieldLength = complexFieldLength;
+      complexFieldEnd = complexFieldStart + complexFieldLength;
+      fieldPosition = complexFieldStart;
+    }
+  }
+
+  private static class ListComplexTypeHelper extends ComplexTypeHelper {
+
+    public Field elementField;
+
+    public ListComplexTypeHelper(Field complexField, Field elementField) {
+      super(complexField);
+      this.elementField = elementField;
+    }
+  }
+
+  private static class MapComplexTypeHelper extends ComplexTypeHelper {
+
+    public Field keyField;
+    public Field valueField;
+
+    public boolean fieldHaveParsedKey;
+
+    public MapComplexTypeHelper(Field complexField, Field keyField, Field valueField) {
+      super(complexField);
+      this.keyField = keyField;
+      this.valueField = valueField;
+      fieldHaveParsedKey = false;
+    }
+  }
+
+  private static class StructComplexTypeHelper extends ComplexTypeHelper {
+
+    public Field[] fields;
+
+    public int nextFieldIndex;
+
+    public StructComplexTypeHelper(Field complexField, Field[] fields) {
+      super(complexField);
+      this.fields = fields;
+      nextFieldIndex = 0;
+    }
+  }
+
+  private static class UnionComplexTypeHelper extends ComplexTypeHelper {
+
+    public Field tagField;
+    public Field[] fields;
+
+    public boolean fieldHaveParsedTag;
+    public int fieldTag;
-  private final byte separator;
+    public UnionComplexTypeHelper(Field complexField, Field[] fields) {
+      super(complexField);
+      this.tagField = new Field(TypeInfoFactory.intTypeInfo);
+      this.fields = fields;
+      fieldHaveParsedTag = false;
+    }
+  }
+
+  private int[] startPositions;
+
+  private final byte[] separators;
   private final boolean isEscaped;
   private final byte escapeChar;
   private final int[] escapeCounts;
@@ -71,19 +194,25 @@ private final boolean
isExtendedBooleanLiteral; private final int fieldCount; + private final Field[] fields; + private final int maxLevelDepth; private byte[] bytes; private int start; private int end; - private boolean parsed; + private boolean topLevelParsed; // Used by readNextField/skipNextField and not by readField. private int nextFieldIndex; // For getDetailedReadPositionString. - private int currentFieldIndex; + private int currentLevel; + private int currentTopLevelFieldIndex; private int currentFieldStart; private int currentFieldLength; + private int currentEscapeCount; + + private ComplexTypeHelper[] currentComplexTypeHelpers; // For string/char/varchar buffering when there are escapes. private int internalBufferLen; @@ -93,21 +222,112 @@ private boolean isEndOfInputReached; + private int addComplexFields(List fieldTypeInfoList, Field[] fields, int depth) { + Field field; + final int count = fieldTypeInfoList.size(); + for (int i = 0; i < count; i++) { + field = new Field(fieldTypeInfoList.get(i)); + if (!field.isPrimitive) { + depth = Math.max(depth, addComplexTypeHelper(field, depth)); + } + fields[i] = field; + } + return depth; + } + + private int addComplexTypeHelper(Field complexField, int depth) { + + // Assume one separator (depth) needed. + depth++; + + switch (complexField.complexCategory) { + case LIST: + { + ListTypeInfo listTypeInfo = (ListTypeInfo) complexField.typeInfo; + Field elementField = new Field(listTypeInfo.getListElementTypeInfo()); + if (!elementField.isPrimitive) { + depth = addComplexTypeHelper(elementField, depth); + } + ListComplexTypeHelper listHelper = + new ListComplexTypeHelper(complexField, elementField); + complexField.complexTypeHelper = listHelper; + } + break; + case MAP: + { + // Map needs two separators (key and key/value pair). 
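+        // For example, a top-level map column parsed with the default LazySimple delimiters
+        // (assuming none were overridden) uses separators[1] (Ctrl-B) between key/value
+        // pairs and separators[2] (Ctrl-C) between each key and its value, so the map
+        // consumes two nesting levels.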
+ depth++; + + MapTypeInfo mapTypeInfo = (MapTypeInfo) complexField.typeInfo; + Field keyField = new Field(mapTypeInfo.getMapKeyTypeInfo()); + if (!keyField.isPrimitive) { + depth = Math.max(depth, addComplexTypeHelper(keyField, depth)); + } + Field valueField = new Field(mapTypeInfo.getMapValueTypeInfo()); + if (!valueField.isPrimitive) { + depth = Math.max(depth, addComplexTypeHelper(valueField, depth)); + } + MapComplexTypeHelper mapHelper = + new MapComplexTypeHelper(complexField, keyField, valueField); + complexField.complexTypeHelper = mapHelper; + } + break; + case STRUCT: + { + StructTypeInfo structTypeInfo = (StructTypeInfo) complexField.typeInfo; + List fieldTypeInfoList = structTypeInfo.getAllStructFieldTypeInfos(); + Field[] fields = new Field[fieldTypeInfoList.size()]; + depth = addComplexFields(fieldTypeInfoList, fields, depth); + StructComplexTypeHelper structHelper = + new StructComplexTypeHelper(complexField, fields); + complexField.complexTypeHelper = structHelper; + } + break; + case UNION: + { + UnionTypeInfo unionTypeInfo = (UnionTypeInfo) complexField.typeInfo; + List fieldTypeInfoList = unionTypeInfo.getAllUnionObjectTypeInfos(); + Field[] fields = new Field[fieldTypeInfoList.size()]; + depth = addComplexFields(fieldTypeInfoList, fields, depth); + UnionComplexTypeHelper structHelper = + new UnionComplexTypeHelper(complexField, fields); + complexField.complexTypeHelper = structHelper; + } + break; + default: + throw new Error("Unexpected complex category " + complexField.complexCategory); + } + return depth; + } + public LazySimpleDeserializeRead(TypeInfo[] typeInfos, boolean useExternalBuffer, - byte separator, LazySerDeParameters lazyParams) { + LazySerDeParameters lazyParams) { super(typeInfos, useExternalBuffer); - fieldCount = typeInfos.length; + final int count = typeInfos.length; + fieldCount = count; + int depth = 0; + fields = new Field[count]; + Field field; + for (int i = 0; i < count; i++) { + field = new Field(typeInfos[i]); + if (!field.isPrimitive) { + depth = Math.max(depth, addComplexTypeHelper(field, 0)); + } + fields[i] = field; + } + maxLevelDepth = depth; + currentComplexTypeHelpers = new ComplexTypeHelper[depth]; // Field length is difference between positions hence one extra. - startPosition = new int[fieldCount + 1]; + startPositions = new int[count + 1]; - this.separator = separator; + this.separators = lazyParams.getSeparators(); isEscaped = lazyParams.isEscaped(); if (isEscaped) { escapeChar = lazyParams.getEscapeChar(); - escapeCounts = new int[fieldCount]; + escapeCounts = new int[count]; } else { escapeChar = (byte) 0; escapeCounts = null; @@ -123,11 +343,6 @@ public LazySimpleDeserializeRead(TypeInfo[] typeInfos, boolean useExternalBuffer internalBufferLen = -1; } - public LazySimpleDeserializeRead(TypeInfo[] typeInfos, boolean useExternalBuffer, - LazySerDeParameters lazyParams) { - this(typeInfos, useExternalBuffer, lazyParams.getSeparators()[0], lazyParams); - } - /* * Set the range of bytes to be deserialized. */ @@ -136,7 +351,8 @@ public void set(byte[] bytes, int offset, int length) { this.bytes = bytes; start = offset; end = offset + length; - parsed = false; + topLevelParsed = false; + currentLevel = 0; nextFieldIndex = -1; } @@ -157,14 +373,15 @@ public String getDetailedReadPositionString() { sb.append(" fields with types "); sb.append(Arrays.toString(typeInfos)); sb.append(". 
"); - if (!parsed) { + if (!topLevelParsed) { sb.append("Error during field separator parsing"); } else { sb.append("Read field #"); - sb.append(currentFieldIndex); + sb.append(currentTopLevelFieldIndex); sb.append(" at field start position "); - sb.append(startPosition[currentFieldIndex]); - int currentFieldLength = startPosition[currentFieldIndex + 1] - startPosition[currentFieldIndex] - 1; + sb.append(startPositions[currentTopLevelFieldIndex]); + int currentFieldLength = startPositions[currentTopLevelFieldIndex + 1] - + startPositions[currentTopLevelFieldIndex] - 1; sb.append(" for field length "); sb.append(currentFieldLength); } @@ -178,15 +395,15 @@ public String getDetailedReadPositionString() { * This is an adapted version of the parse method in the LazyStruct class. * They should parse things the same way. */ - private void parse() { + private void topLevelParse() { int fieldId = 0; int fieldByteBegin = start; int fieldByteEnd = start; - final byte separator = this.separator; + final byte separator = this.separators[0]; final int fieldCount = this.fieldCount; - final int[] startPosition = this.startPosition; + final int[] startPositions = this.startPositions; final byte[] bytes = this.bytes; final int end = this.end; @@ -196,7 +413,7 @@ private void parse() { if (!isEscaped) { while (fieldByteEnd < end) { if (bytes[fieldByteEnd] == separator) { - startPosition[fieldId++] = fieldByteBegin; + startPositions[fieldId++] = fieldByteBegin; if (fieldId == fieldCount) { break; } @@ -207,7 +424,7 @@ private void parse() { } // End serves as final separator. if (fieldByteEnd == end && fieldId < fieldCount) { - startPosition[fieldId++] = fieldByteBegin; + startPositions[fieldId++] = fieldByteBegin; } } else { final byte escapeChar = this.escapeChar; @@ -219,7 +436,7 @@ private void parse() { if (bytes[fieldByteEnd] == separator) { escapeCounts[fieldId] = escapeCount; escapeCount = 0; - startPosition[fieldId++] = fieldByteBegin; + startPositions[fieldId++] = fieldByteBegin; if (fieldId == fieldCount) { break; } @@ -237,7 +454,7 @@ private void parse() { if (bytes[fieldByteEnd] == separator) { escapeCounts[fieldId] = escapeCount; escapeCount = 0; - startPosition[fieldId++] = fieldByteBegin; + startPositions[fieldId++] = fieldByteBegin; if (fieldId <= fieldCount) { fieldByteBegin = ++fieldByteEnd; } @@ -248,23 +465,66 @@ private void parse() { // End serves as final separator. if (fieldByteEnd == end && fieldId < fieldCount) { escapeCounts[fieldId] = escapeCount; - startPosition[fieldId++] = fieldByteBegin; + startPositions[fieldId++] = fieldByteBegin; } } if (fieldId == fieldCount || fieldByteEnd == end) { // All fields have been parsed, or bytes have been parsed. - // We need to set the startPosition of fields.length to ensure we + // We need to set the startPositions of fields.length to ensure we // can use the same formula to calculate the length of each field. // For missing fields, their starting positions will all be the same, // which will make their lengths to be -1 and uncheckedGetField will // return these fields as NULLs. 
- Arrays.fill(startPosition, fieldId, startPosition.length, fieldByteEnd + 1); + Arrays.fill(startPositions, fieldId, startPositions.length, fieldByteEnd + 1); } isEndOfInputReached = (fieldByteEnd == end); } + private int parseComplexField(int start, int end, int level) { + + final byte separator = separators[level]; + int fieldByteEnd = start; + + byte[] bytes = this.bytes; + + currentEscapeCount = 0; + if (!isEscaped) { + while (fieldByteEnd < end) { + if (bytes[fieldByteEnd] == separator) { + return fieldByteEnd; + } + fieldByteEnd++; + } + } else { + final byte escapeChar = this.escapeChar; + final int endLessOne = end - 1; + int escapeCount = 0; + // Process the bytes that can be escaped (the last one can't be). + while (fieldByteEnd < endLessOne) { + if (bytes[fieldByteEnd] == separator) { + currentEscapeCount = escapeCount; + return fieldByteEnd; + } else if (bytes[fieldByteEnd] == escapeChar) { + // Ignore the char after escape_char + fieldByteEnd += 2; + escapeCount++; + } else { + fieldByteEnd++; + } + } + // Process the last byte. + if (fieldByteEnd == endLessOne) { + if (bytes[fieldByteEnd] != separator) { + fieldByteEnd++; + } + } + currentEscapeCount = escapeCount; + } + return fieldByteEnd; + } + /* * Reads the the next field. * @@ -291,9 +551,9 @@ public boolean readNextField() throws IOException { * Designed for skipping columns that are not included. */ public void skipNextField() throws IOException { - if (!parsed) { - parse(); - parsed = true; + if (!topLevelParsed) { + topLevelParse(); + topLevelParsed = true; } if (nextFieldIndex + 1 >= fieldCount) { // No more. @@ -341,17 +601,26 @@ private boolean checkNull(byte[] bytes, int start, int len) { */ public boolean readField(int fieldIndex) throws IOException { - if (!parsed) { - parse(); - parsed = true; + Preconditions.checkState(currentLevel == 0); + + if (!topLevelParsed) { + topLevelParse(); + topLevelParsed = true; } - currentFieldIndex = fieldIndex; + // Top level. + currentTopLevelFieldIndex = fieldIndex; + + currentFieldStart = startPositions[fieldIndex]; + currentFieldLength = startPositions[fieldIndex + 1] - startPositions[fieldIndex] - 1; + currentEscapeCount = (isEscaped ? escapeCounts[fieldIndex] : 0); + + return doReadField(fields[fieldIndex]); + } - final int fieldStart = startPosition[fieldIndex]; - currentFieldStart = fieldStart; - final int fieldLength = startPosition[fieldIndex + 1] - startPosition[fieldIndex] - 1; - currentFieldLength = fieldLength; + private boolean doReadField(Field field) { + final int fieldStart = currentFieldStart; + final int fieldLength = currentFieldLength; if (fieldLength < 0) { return false; } @@ -369,222 +638,252 @@ public boolean readField(int fieldIndex) throws IOException { /* * We have a field and are positioned to it. Read it. */ - switch (primitiveCategories[fieldIndex]) { - case BOOLEAN: - { - int i = fieldStart; - if (fieldLength == 4) { - if ((bytes[i] == 'T' || bytes[i] == 't') && - (bytes[i + 1] == 'R' || bytes[i + 1] == 'r') && - (bytes[i + 2] == 'U' || bytes[i + 2] == 'u') && - (bytes[i + 3] == 'E' || bytes[i + 3] == 'e')) { - currentBoolean = true; - } else { - // No boolean value match for 4 char field. 
- return false; - } - } else if (fieldLength == 5) { - if ((bytes[i] == 'F' || bytes[i] == 'f') && - (bytes[i + 1] == 'A' || bytes[i + 1] == 'a') && - (bytes[i + 2] == 'L' || bytes[i + 2] == 'l') && - (bytes[i + 3] == 'S' || bytes[i + 3] == 's') && - (bytes[i + 4] == 'E' || bytes[i + 4] == 'e')) { - currentBoolean = false; - } else { - // No boolean value match for 5 char field. - return false; - } - } else if (isExtendedBooleanLiteral && fieldLength == 1) { - byte b = bytes[fieldStart]; - if (b == '1' || b == 't' || b == 'T') { - currentBoolean = true; - } else if (b == '0' || b == 'f' || b == 'F') { - currentBoolean = false; + if (field.isPrimitive) { + switch (field.primitiveCategory) { + case BOOLEAN: + { + int i = fieldStart; + if (fieldLength == 4) { + if ((bytes[i] == 'T' || bytes[i] == 't') && + (bytes[i + 1] == 'R' || bytes[i + 1] == 'r') && + (bytes[i + 2] == 'U' || bytes[i + 2] == 'u') && + (bytes[i + 3] == 'E' || bytes[i + 3] == 'e')) { + currentBoolean = true; + } else { + // No boolean value match for 4 char field. + return false; + } + } else if (fieldLength == 5) { + if ((bytes[i] == 'F' || bytes[i] == 'f') && + (bytes[i + 1] == 'A' || bytes[i + 1] == 'a') && + (bytes[i + 2] == 'L' || bytes[i + 2] == 'l') && + (bytes[i + 3] == 'S' || bytes[i + 3] == 's') && + (bytes[i + 4] == 'E' || bytes[i + 4] == 'e')) { + currentBoolean = false; + } else { + // No boolean value match for 5 char field. + return false; + } + } else if (isExtendedBooleanLiteral && fieldLength == 1) { + byte b = bytes[fieldStart]; + if (b == '1' || b == 't' || b == 'T') { + currentBoolean = true; + } else if (b == '0' || b == 'f' || b == 'F') { + currentBoolean = false; + } else { + // No boolean value match for extended 1 char field. + return false; + } } else { - // No boolean value match for extended 1 char field. + // No boolean value match for other lengths. return false; } - } else { - // No boolean value match for other lengths. + } + return true; + case BYTE: + if (!LazyUtils.isNumberMaybe(bytes, fieldStart, fieldLength)) { return false; } - } - return true; - case BYTE: - if (!LazyUtils.isNumberMaybe(bytes, fieldStart, fieldLength)) { - return false; - } - currentByte = LazyByte.parseByte(bytes, fieldStart, fieldLength, 10); - return true; - case SHORT: - if (!LazyUtils.isNumberMaybe(bytes, fieldStart, fieldLength)) { - return false; - } - currentShort = LazyShort.parseShort(bytes, fieldStart, fieldLength, 10); - return true; - case INT: - if (!LazyUtils.isNumberMaybe(bytes, fieldStart, fieldLength)) { - return false; - } - currentInt = LazyInteger.parseInt(bytes, fieldStart, fieldLength, 10); - return true; - case LONG: - if (!LazyUtils.isNumberMaybe(bytes, fieldStart, fieldLength)) { - return false; - } - currentLong = LazyLong.parseLong(bytes, fieldStart, fieldLength, 10); - return true; - case FLOAT: - if (!LazyUtils.isNumberMaybe(bytes, fieldStart, fieldLength)) { - return false; - } - currentFloat = - Float.parseFloat( - new String(bytes, fieldStart, fieldLength, StandardCharsets.UTF_8)); - return true; - case DOUBLE: - if (!LazyUtils.isNumberMaybe(bytes, fieldStart, fieldLength)) { - return false; - } - currentDouble = StringToDouble.strtod(bytes, fieldStart, fieldLength); - return true; - case STRING: - case CHAR: - case VARCHAR: - { - if (isEscaped) { - if (escapeCounts[fieldIndex] == 0) { - // No escaping. 
+ currentByte = LazyByte.parseByte(bytes, fieldStart, fieldLength, 10); + return true; + case SHORT: + if (!LazyUtils.isNumberMaybe(bytes, fieldStart, fieldLength)) { + return false; + } + currentShort = LazyShort.parseShort(bytes, fieldStart, fieldLength, 10); + return true; + case INT: + if (!LazyUtils.isNumberMaybe(bytes, fieldStart, fieldLength)) { + return false; + } + currentInt = LazyInteger.parseInt(bytes, fieldStart, fieldLength, 10); + return true; + case LONG: + if (!LazyUtils.isNumberMaybe(bytes, fieldStart, fieldLength)) { + return false; + } + currentLong = LazyLong.parseLong(bytes, fieldStart, fieldLength, 10); + return true; + case FLOAT: + if (!LazyUtils.isNumberMaybe(bytes, fieldStart, fieldLength)) { + return false; + } + currentFloat = + Float.parseFloat( + new String(bytes, fieldStart, fieldLength, StandardCharsets.UTF_8)); + return true; + case DOUBLE: + if (!LazyUtils.isNumberMaybe(bytes, fieldStart, fieldLength)) { + return false; + } + currentDouble = StringToDouble.strtod(bytes, fieldStart, fieldLength); + return true; + case STRING: + case CHAR: + case VARCHAR: + { + if (isEscaped) { + if (currentEscapeCount == 0) { + // No escaping. + currentExternalBufferNeeded = false; + currentBytes = bytes; + currentBytesStart = fieldStart; + currentBytesLength = fieldLength; + } else { + final int unescapedLength = fieldLength - currentEscapeCount; + if (useExternalBuffer) { + currentExternalBufferNeeded = true; + currentExternalBufferNeededLen = unescapedLength; + } else { + // The copyToBuffer will reposition and re-read the input buffer. + currentExternalBufferNeeded = false; + if (internalBufferLen < unescapedLength) { + internalBufferLen = unescapedLength; + internalBuffer = new byte[internalBufferLen]; + } + copyToBuffer(internalBuffer, 0, unescapedLength); + currentBytes = internalBuffer; + currentBytesStart = 0; + currentBytesLength = unescapedLength; + } + } + } else { + // If the data is not escaped, reference the data directly. currentExternalBufferNeeded = false; currentBytes = bytes; currentBytesStart = fieldStart; currentBytesLength = fieldLength; - } else { - final int unescapedLength = fieldLength - escapeCounts[fieldIndex]; - if (useExternalBuffer) { - currentExternalBufferNeeded = true; - currentExternalBufferNeededLen = unescapedLength; - } else { - // The copyToBuffer will reposition and re-read the input buffer. - currentExternalBufferNeeded = false; - if (internalBufferLen < unescapedLength) { - internalBufferLen = unescapedLength; - internalBuffer = new byte[internalBufferLen]; - } - copyToBuffer(internalBuffer, 0, unescapedLength); - currentBytes = internalBuffer; - currentBytesStart = 0; - currentBytesLength = unescapedLength; - } } - } else { - // If the data is not escaped, reference the data directly. - currentExternalBufferNeeded = false; - currentBytes = bytes; - currentBytesStart = fieldStart; - currentBytesLength = fieldLength; } - } - return true; - case BINARY: - { - byte[] recv = new byte[fieldLength]; - System.arraycopy(bytes, fieldStart, recv, 0, fieldLength); - byte[] decoded = LazyBinary.decodeIfNeeded(recv); - // use the original bytes in case decoding should fail - decoded = decoded.length > 0 ? 
decoded : recv; - currentBytes = decoded; - currentBytesStart = 0; - currentBytesLength = decoded.length; - } - return true; - case DATE: - if (!LazyUtils.isDateMaybe(bytes, fieldStart, fieldLength)) { - return false; - } - currentDateWritable.set( - Date.valueOf( - new String(bytes, fieldStart, fieldLength, StandardCharsets.UTF_8))); - return true; - case TIMESTAMP: - { + return true; + case BINARY: + { + byte[] recv = new byte[fieldLength]; + System.arraycopy(bytes, fieldStart, recv, 0, fieldLength); + byte[] decoded = LazyBinary.decodeIfNeeded(recv); + // use the original bytes in case decoding should fail + decoded = decoded.length > 0 ? decoded : recv; + currentBytes = decoded; + currentBytesStart = 0; + currentBytesLength = decoded.length; + } + return true; + case DATE: if (!LazyUtils.isDateMaybe(bytes, fieldStart, fieldLength)) { return false; } - String s = new String(bytes, fieldStart, fieldLength, StandardCharsets.US_ASCII); - if (s.compareTo("NULL") == 0) { - logExceptionMessage(bytes, fieldStart, fieldLength, "TIMESTAMP"); + currentDateWritable.set( + Date.valueOf( + new String(bytes, fieldStart, fieldLength, StandardCharsets.UTF_8))); + return true; + case TIMESTAMP: + { + if (!LazyUtils.isDateMaybe(bytes, fieldStart, fieldLength)) { + return false; + } + String s = new String(bytes, fieldStart, fieldLength, StandardCharsets.US_ASCII); + if (s.compareTo("NULL") == 0) { + logExceptionMessage(bytes, fieldStart, fieldLength, "TIMESTAMP"); + return false; + } + try { + currentTimestampWritable.set(timestampParser.parseTimestamp(s)); + } catch (IllegalArgumentException e) { + logExceptionMessage(bytes, fieldStart, fieldLength, "TIMESTAMP"); + return false; + } + } + return true; + case INTERVAL_YEAR_MONTH: + if (fieldLength == 0) { return false; } try { - currentTimestampWritable.set(timestampParser.parseTimestamp(s)); - } catch (IllegalArgumentException e) { - logExceptionMessage(bytes, fieldStart, fieldLength, "TIMESTAMP"); + String s = new String(bytes, fieldStart, fieldLength, StandardCharsets.UTF_8); + currentHiveIntervalYearMonthWritable.set(HiveIntervalYearMonth.valueOf(s)); + } catch (Exception e) { + logExceptionMessage(bytes, fieldStart, fieldLength, "INTERVAL_YEAR_MONTH"); return false; } - } - return true; - case INTERVAL_YEAR_MONTH: - if (fieldLength == 0) { - return false; - } - try { - String s = new String(bytes, fieldStart, fieldLength, StandardCharsets.UTF_8); - currentHiveIntervalYearMonthWritable.set(HiveIntervalYearMonth.valueOf(s)); - } catch (Exception e) { - logExceptionMessage(bytes, fieldStart, fieldLength, "INTERVAL_YEAR_MONTH"); - return false; - } - return true; - case INTERVAL_DAY_TIME: - if (fieldLength == 0) { - return false; - } - try { - String s = new String(bytes, fieldStart, fieldLength, StandardCharsets.UTF_8); - currentHiveIntervalDayTimeWritable.set(HiveIntervalDayTime.valueOf(s)); - } catch (Exception e) { - logExceptionMessage(bytes, fieldStart, fieldLength, "INTERVAL_DAY_TIME"); - return false; - } - return true; - case DECIMAL: - { - if (!LazyUtils.isNumberMaybe(bytes, fieldStart, fieldLength)) { + return true; + case INTERVAL_DAY_TIME: + if (fieldLength == 0) { + return false; + } + try { + String s = new String(bytes, fieldStart, fieldLength, StandardCharsets.UTF_8); + currentHiveIntervalDayTimeWritable.set(HiveIntervalDayTime.valueOf(s)); + } catch (Exception e) { + logExceptionMessage(bytes, fieldStart, fieldLength, "INTERVAL_DAY_TIME"); return false; } - // Trim blanks because OldHiveDecimal did... 
- currentHiveDecimalWritable.setFromBytes(bytes, fieldStart, fieldLength, /* trimBlanks */ true); - boolean decimalIsNull = !currentHiveDecimalWritable.isSet(); - if (!decimalIsNull) { - DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) typeInfos[fieldIndex]; + return true; + case DECIMAL: + { + if (!LazyUtils.isNumberMaybe(bytes, fieldStart, fieldLength)) { + return false; + } + // Trim blanks because OldHiveDecimal did... + currentHiveDecimalWritable.setFromBytes(bytes, fieldStart, fieldLength, /* trimBlanks */ true); + boolean decimalIsNull = !currentHiveDecimalWritable.isSet(); + if (!decimalIsNull) { + DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) field.typeInfo; - int precision = decimalTypeInfo.getPrecision(); - int scale = decimalTypeInfo.getScale(); + int precision = decimalTypeInfo.getPrecision(); + int scale = decimalTypeInfo.getScale(); - decimalIsNull = !currentHiveDecimalWritable.mutateEnforcePrecisionScale(precision, scale); + decimalIsNull = !currentHiveDecimalWritable.mutateEnforcePrecisionScale(precision, scale); + } + if (decimalIsNull) { + if (LOG.isDebugEnabled()) { + LOG.debug("Data not in the HiveDecimal data type range so converted to null. Given data is :" + + new String(bytes, fieldStart, fieldLength, StandardCharsets.UTF_8)); + } + return false; + } } - if (decimalIsNull) { - if (LOG.isDebugEnabled()) { - LOG.debug("Data not in the HiveDecimal data type range so converted to null. Given data is :" - + new String(bytes, fieldStart, fieldLength, StandardCharsets.UTF_8)); + return true; + + default: + throw new Error("Unexpected primitive category " + field.primitiveCategory); + } + } else { + switch (field.complexCategory) { + case LIST: + case MAP: + case STRUCT: + case UNION: + { + if (currentLevel > 0) { + + // Check for Map which occupies 2 levels (key separator and key/value pair separator). + if (currentComplexTypeHelpers[currentLevel - 1] == null) { + Preconditions.checkState(currentLevel > 1); + Preconditions.checkState( + currentComplexTypeHelpers[currentLevel - 2] instanceof MapComplexTypeHelper); + currentLevel++; + } } - return false; + ComplexTypeHelper complexTypeHelper = field.complexTypeHelper; + currentComplexTypeHelpers[currentLevel++] = complexTypeHelper; + if (field.complexCategory == Category.MAP) { + currentComplexTypeHelpers[currentLevel] = null; + } + + // Set up context for readNextComplexField. + complexTypeHelper.setCurrentFieldInfo(currentFieldStart, currentFieldLength); } + return true; + default: + throw new Error("Unexpected complex category " + field.complexCategory); } - return true; - - default: - throw new Error("Unexpected primitive category " + primitiveCategories[fieldIndex].name()); } } catch (NumberFormatException nfe) { - // U+FFFD will throw this as well - logExceptionMessage(bytes, fieldStart, fieldLength, primitiveCategories[fieldIndex]); + logExceptionMessage(bytes, fieldStart, fieldLength, field.complexCategory, field.primitiveCategory); return false; } catch (IllegalArgumentException iae) { - // E.g. 
can be thrown by Date.valueOf - logExceptionMessage(bytes, fieldStart, fieldLength, primitiveCategories[fieldIndex]); - return false; + logExceptionMessage(bytes, fieldStart, fieldLength, field.complexCategory, field.primitiveCategory); + return false; } } @@ -616,6 +915,248 @@ private void copyToBuffer(byte[] buffer, int bufferStart, int bufferLength) { } } + @Override + public boolean isNextComplexMultiValue() { + Preconditions.checkState(currentLevel > 0); + + ComplexTypeHelper complexTypeHelper = currentComplexTypeHelpers[currentLevel - 1]; + Field complexField = complexTypeHelper.complexField; + final int fieldPosition = complexTypeHelper.fieldPosition; + final int complexFieldEnd = complexTypeHelper.complexFieldEnd; + switch (complexField.complexCategory) { + case LIST: + { + // Allow for empty string, etc. + final boolean isNext = (fieldPosition <= complexFieldEnd); + if (!isNext) { + popComplexType(); + } + return isNext; + } + case MAP: + { + final boolean isNext = (fieldPosition < complexFieldEnd); + if (!isNext) { + popComplexType(); + } + return isNext; + } + case STRUCT: + case UNION: + throw new Error("Complex category " + complexField.complexCategory + " not multi-value"); + default: + throw new Error("Unexpected complex category " + complexField.complexCategory); + } + } + + private void popComplexType() { + Preconditions.checkState(currentLevel > 0); + currentLevel--; + if (currentLevel > 0) { + + // Check for Map which occupies 2 levels (key separator and key/value pair separator). + if (currentComplexTypeHelpers[currentLevel - 1] == null) { + Preconditions.checkState(currentLevel > 1); + Preconditions.checkState( + currentComplexTypeHelpers[currentLevel - 2] instanceof MapComplexTypeHelper); + currentLevel--; + } + } + } + + /* + * NOTE: There is an expectation that all fields will be read-thru. + */ + @Override + public boolean readComplexField() throws IOException { + + Preconditions.checkState(currentLevel > 0); + + ComplexTypeHelper complexTypeHelper = currentComplexTypeHelpers[currentLevel - 1]; + Field complexField = complexTypeHelper.complexField; + switch (complexField.complexCategory) { + case LIST: + { + ListComplexTypeHelper listHelper = (ListComplexTypeHelper) complexTypeHelper; + final int fieldPosition = listHelper.fieldPosition; + final int complexFieldEnd = listHelper.complexFieldEnd; + Preconditions.checkState(fieldPosition <= complexFieldEnd); + + final int fieldEnd = parseComplexField(fieldPosition, complexFieldEnd, currentLevel); + listHelper.fieldPosition = fieldEnd + 1; // Move past separator. + + currentFieldStart = fieldPosition; + currentFieldLength = fieldEnd - fieldPosition; + + return doReadField(listHelper.elementField); + } + case MAP: + { + MapComplexTypeHelper mapHelper = (MapComplexTypeHelper) complexTypeHelper; + final int fieldPosition = mapHelper.fieldPosition; + final int complexFieldEnd = mapHelper.complexFieldEnd; + Preconditions.checkState(fieldPosition <= complexFieldEnd); + + currentFieldStart = fieldPosition; + + boolean isParentMap = isParentMap(); + if (isParentMap) { + currentLevel++; + } + int fieldEnd; + if (!mapHelper.fieldHaveParsedKey) { + + // Parse until key separator (currentLevel + 1). + fieldEnd = parseComplexField(fieldPosition, complexFieldEnd, currentLevel + 1); + + mapHelper.fieldPosition = fieldEnd + 1; // Move past key separator. 
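// Aside: in the MAP key branch above, the key is parsed with the separator one
// level deeper (currentLevel + 1) because a map consumes two separator levels in
// the text format: level N separates key/value pairs and level N + 1 separates
// each key from its value. This is also why pushing a MAP helper leaves a null
// slot in currentComplexTypeHelpers behind it, and why popComplexType() drops
// two levels when it encounters that null marker.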
+ + currentFieldLength = fieldEnd - fieldPosition; + + mapHelper.fieldHaveParsedKey = true; + boolean result = doReadField(mapHelper.keyField); + if (isParentMap) { + currentLevel--; + } + return result; + } else { + + // Parse until pair separator (currentLevel). + fieldEnd = parseComplexField(fieldPosition, complexFieldEnd, currentLevel); + + mapHelper.fieldPosition = fieldEnd + 1; // Move past pair separator. + + currentFieldLength = fieldEnd - fieldPosition; + + mapHelper.fieldHaveParsedKey = false; + boolean result = doReadField(mapHelper.valueField); + if (isParentMap) { + currentLevel--; + } + return result; + } + } + case STRUCT: + { + StructComplexTypeHelper structHelper = (StructComplexTypeHelper) complexTypeHelper; + final int fieldPosition = structHelper.fieldPosition; + final int complexFieldEnd = structHelper.complexFieldEnd; + Preconditions.checkState(fieldPosition <= complexFieldEnd); + + currentFieldStart = fieldPosition; + + final int nextFieldIndex = structHelper.nextFieldIndex; + Field[] fields = structHelper.fields; + int fieldEnd; + if (nextFieldIndex != fields.length - 1) { + + // Parse until field separator (currentLevel). + fieldEnd = parseComplexField(fieldPosition, complexFieldEnd, currentLevel); + + structHelper.fieldPosition = fieldEnd + 1; // Move past key separator. + + currentFieldLength = fieldEnd - fieldPosition; + + return doReadField(fields[structHelper.nextFieldIndex++]); + } else { + + if (!isEscaped) { + + // No parsing necessary -- the end is the parent's end. + structHelper.fieldPosition = complexFieldEnd + 1; // Move past parent field separator. + currentEscapeCount = 0; + } else { + // We must parse to get the escape count. + fieldEnd = parseComplexField(fieldPosition, complexFieldEnd, currentLevel - 1); + } + + currentFieldLength = complexFieldEnd - fieldPosition; + + structHelper.nextFieldIndex = 0; + return doReadField(fields[fields.length - 1]); + } + } + case UNION: + { + UnionComplexTypeHelper unionHelper = (UnionComplexTypeHelper) complexTypeHelper; + final int fieldPosition = unionHelper.fieldPosition; + final int complexFieldEnd = unionHelper.complexFieldEnd; + Preconditions.checkState(fieldPosition <= complexFieldEnd); + + currentFieldStart = fieldPosition; + + int fieldEnd; + if (!unionHelper.fieldHaveParsedTag) { + boolean isParentMap = isParentMap(); + if (isParentMap) { + currentLevel++; + } + + // Parse until union separator (currentLevel). + fieldEnd = parseComplexField(fieldPosition, complexFieldEnd, currentLevel); + + unionHelper.fieldPosition = fieldEnd + 1; // Move past union separator. + + currentFieldLength = fieldEnd - fieldPosition; + + unionHelper.fieldHaveParsedTag = true; + boolean successful = doReadField(unionHelper.tagField); + if (!successful) { + throw new IOException("Null union tag"); + } + unionHelper.fieldTag = currentInt; + + if (isParentMap) { + currentLevel--; + } + return true; + } else { + + if (!isEscaped) { + + // No parsing necessary -- the end is the parent's end. + unionHelper.fieldPosition = complexFieldEnd + 1; // Move past parent field separator. + currentEscapeCount = 0; + } else { + // We must parse to get the escape count. 
+ fieldEnd = parseComplexField(fieldPosition, complexFieldEnd, currentLevel - 1); + } + + currentFieldLength = complexFieldEnd - fieldPosition; + + unionHelper.fieldHaveParsedTag = false; + return doReadField(unionHelper.fields[unionHelper.fieldTag]); + } + } + default: + throw new Error("Unexpected complex category " + complexField.complexCategory); + } + } + + private boolean isParentMap() { + return currentLevel >= 2 && + currentComplexTypeHelpers[currentLevel - 2] instanceof MapComplexTypeHelper; + } + + @Override + public void finishComplexVariableFieldsType() { + Preconditions.checkState(currentLevel > 0); + + ComplexTypeHelper complexTypeHelper = currentComplexTypeHelpers[currentLevel - 1]; + Field complexField = complexTypeHelper.complexField; + switch (complexField.complexCategory) { + case LIST: + case MAP: + throw new Error("Complex category " + complexField.complexCategory + " is not variable fields type"); + case STRUCT: + case UNION: + popComplexType(); + break; + default: + throw new Error("Unexpected category " + complexField.complexCategory); + } + } + /* * Call this method may be called after all the all fields have been read to check * for unread fields. @@ -632,21 +1173,34 @@ public boolean isEndOfInputReached() { } public void logExceptionMessage(byte[] bytes, int bytesStart, int bytesLength, - PrimitiveCategory dataCategory) { + Category dataComplexCategory, PrimitiveCategory dataPrimitiveCategory) { final String dataType; - switch (dataCategory) { - case BYTE: - dataType = "TINYINT"; - break; - case LONG: - dataType = "BIGINT"; - break; - case SHORT: - dataType = "SMALLINT"; - break; - default: - dataType = dataCategory.toString(); - break; + if (dataComplexCategory == null) { + switch (dataPrimitiveCategory) { + case BYTE: + dataType = "TINYINT"; + break; + case LONG: + dataType = "BIGINT"; + break; + case SHORT: + dataType = "SMALLINT"; + break; + default: + dataType = dataPrimitiveCategory.toString(); + break; + } + } else { + switch (dataComplexCategory) { + case LIST: + case MAP: + case STRUCT: + case UNION: + dataType = dataComplexCategory.toString(); + break; + default: + throw new Error("Unexpected complex category " + dataComplexCategory); + } } logExceptionMessage(bytes, bytesStart, bytesLength, dataType); } diff --git serde/src/java/org/apache/hadoop/hive/serde2/lazy/fast/LazySimpleSerializeWrite.java serde/src/java/org/apache/hadoop/hive/serde2/lazy/fast/LazySimpleSerializeWrite.java index 1401ac3b94..ef77daf221 100644 --- serde/src/java/org/apache/hadoop/hive/serde2/lazy/fast/LazySimpleSerializeWrite.java +++ serde/src/java/org/apache/hadoop/hive/serde2/lazy/fast/LazySimpleSerializeWrite.java @@ -22,6 +22,10 @@ import java.nio.ByteBuffer; import java.sql.Date; import java.sql.Timestamp; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.List; +import java.util.Map; import org.apache.commons.codec.binary.Base64; import org.slf4j.Logger; @@ -48,7 +52,6 @@ import org.apache.hadoop.hive.serde2.lazy.LazyUtils; import org.apache.hadoop.hive.serde2.fast.SerializeWrite; import org.apache.hadoop.io.Text; -import org.apache.hive.common.util.DateUtils; /* * Directly serialize, field-by-field, the LazyBinary format. 
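/*
 * A minimal standalone sketch of the serializer pattern the LazySimpleSerializeWrite
 * changes below adopt: one separator byte per nesting level plus a stack of field
 * indexes, so entering a complex value resets the field index (suppressing a leading
 * separator) and leaving it restores the parent's position. The class name and
 * separator bytes here are illustrative assumptions, not Hive API.
 */
import java.util.ArrayDeque;
import java.util.Deque;

public class LevelSeparatorSketch {
  private final StringBuilder out = new StringBuilder();
  private final char[] separators = {',', '|', ':'}; // illustrative, not Hive's defaults
  private final Deque<Integer> indexStack = new ArrayDeque<>();
  private int currentLevel = 0;
  private int index = 0;

  void writePrimitive(String v) {
    if (index > 0) {
      out.append(separators[currentLevel]); // separator for this nesting level
    }
    out.append(v);
    index++;
  }

  void beginComplex() {
    if (index > 0) {
      out.append(separators[currentLevel]);
    }
    indexStack.push(index);
    index = 0;      // child fields start fresh, so no leading separator is written
    currentLevel++; // child fields use the next, deeper separator byte
  }

  void finishComplex() {
    currentLevel--;
    index = indexStack.pop() + 1; // the complex value counts as one parent field
  }

  public static void main(String[] args) {
    LevelSeparatorSketch w = new LevelSeparatorSketch();
    w.writePrimitive("a");
    w.beginComplex();            // e.g. a list-typed field
    w.writePrimitive("x");
    w.writePrimitive("y");
    w.finishComplex();
    w.writePrimitive("b");
    System.out.println(w.out);   // prints: a,x|y,b
  }
}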
@@ -60,7 +63,7 @@ private LazySerDeParameters lazyParams; - private byte separator; + private byte[] separators; private boolean[] needsEscape; private boolean isEscaped; private byte escapeChar; @@ -70,6 +73,8 @@ private int fieldCount; private int index; + private int currentLevel; + private Deque indexStack = new ArrayDeque(); // For thread safety, we allocate private writable objects for our use only. private DateWritable dateWritable; @@ -80,14 +85,14 @@ private byte[] decimalScratchBuffer; public LazySimpleSerializeWrite(int fieldCount, - byte separator, LazySerDeParameters lazyParams) { + LazySerDeParameters lazyParams) { this(); this.fieldCount = fieldCount; - - this.separator = separator; + this.lazyParams = lazyParams; + separators = lazyParams.getSeparators(); isEscaped = lazyParams.isEscaped(); escapeChar = lazyParams.getEscapeChar(); needsEscape = lazyParams.getNeedsEscape(); @@ -106,6 +111,7 @@ public void set(Output output) { this.output = output; output.reset(); index = 0; + currentLevel = 0; } /* @@ -115,6 +121,7 @@ public void set(Output output) { public void setAppend(Output output) { this.output = output; index = 0; + currentLevel = 0; } /* @@ -124,35 +131,19 @@ public void setAppend(Output output) { public void reset() { output.reset(); index = 0; + currentLevel = 0; } /* - * General Pattern: - * - * if (index > 0) { - * output.write(separator); - * } - * - * WHEN NOT NULL: Write value. - * OTHERWISE NULL: Write nullSequenceBytes. - * - * Increment index - * - */ - - /* * Write a NULL field. */ @Override public void writeNull() throws IOException { - - if (index > 0) { - output.write(separator); - } + beginPrimitive(); output.write(nullSequenceBytes); - index++; + finishPrimitive(); } /* @@ -160,18 +151,13 @@ public void writeNull() throws IOException { */ @Override public void writeBoolean(boolean v) throws IOException { - - if (index > 0) { - output.write(separator); - } - + beginPrimitive(); if (v) { output.write(LazyUtils.trueBytes, 0, LazyUtils.trueBytes.length); } else { output.write(LazyUtils.falseBytes, 0, LazyUtils.falseBytes.length); } - - index++; + finishPrimitive(); } /* @@ -179,14 +165,9 @@ public void writeBoolean(boolean v) throws IOException { */ @Override public void writeByte(byte v) throws IOException { - - if (index > 0) { - output.write(separator); - } - + beginPrimitive(); LazyInteger.writeUTF8(output, v); - - index++; + finishPrimitive(); } /* @@ -194,14 +175,9 @@ public void writeByte(byte v) throws IOException { */ @Override public void writeShort(short v) throws IOException { - - if (index > 0) { - output.write(separator); - } - + beginPrimitive(); LazyInteger.writeUTF8(output, v); - - index++; + finishPrimitive(); } /* @@ -209,14 +185,9 @@ public void writeShort(short v) throws IOException { */ @Override public void writeInt(int v) throws IOException { - - if (index > 0) { - output.write(separator); - } - + beginPrimitive(); LazyInteger.writeUTF8(output, v); - - index++; + finishPrimitive(); } /* @@ -224,14 +195,9 @@ public void writeInt(int v) throws IOException { */ @Override public void writeLong(long v) throws IOException { - - if (index > 0) { - output.write(separator); - } - + beginPrimitive(); LazyLong.writeUTF8(output, v); - - index++; + finishPrimitive(); } /* @@ -239,15 +205,10 @@ public void writeLong(long v) throws IOException { */ @Override public void writeFloat(float vf) throws IOException { - - if (index > 0) { - output.write(separator); - } - + beginPrimitive(); ByteBuffer b = Text.encode(String.valueOf(vf)); 
     output.write(b.array(), 0, b.limit());
-
-    index++;
+    finishPrimitive();
   }

   /*
@@ -255,15 +216,10 @@ public void writeFloat(float vf) throws IOException {
    */
   @Override
   public void writeDouble(double v) throws IOException {
-
-    if (index > 0) {
-      output.write(separator);
-    }
-
+    beginPrimitive();
     ByteBuffer b = Text.encode(String.valueOf(v));
     output.write(b.array(), 0, b.limit());
-
-    index++;
+    finishPrimitive();
   }

   /*
@@ -274,28 +230,20 @@ public void writeDouble(double v) throws IOException {
    */
   @Override
   public void writeString(byte[] v) throws IOException {
-
-    if (index > 0) {
-      output.write(separator);
-    }
-
+    beginPrimitive();
     LazyUtils.writeEscaped(output, v, 0, v.length, isEscaped, escapeChar, needsEscape);
-
-    index++;
+    finishPrimitive();
   }

   @Override
   public void writeString(byte[] v, int start, int length) throws IOException {
-
-    if (index > 0) {
-      output.write(separator);
-    }
-
+    beginPrimitive();
     LazyUtils.writeEscaped(output, v, start, length, isEscaped, escapeChar, needsEscape);
-
-    index++;
+    finishPrimitive();
   }

   /*
@@ -303,16 +251,11 @@ public void writeString(byte[] v, int start, int length) throws IOException {
    */
   @Override
   public void writeHiveChar(HiveChar hiveChar) throws IOException {
-
-    if (index > 0) {
-      output.write(separator);
-    }
-
+    beginPrimitive();
     ByteBuffer b = Text.encode(hiveChar.getPaddedValue());
     LazyUtils.writeEscaped(output, b.array(), 0, b.limit(), isEscaped, escapeChar, needsEscape);
-
-    index++;
+    finishPrimitive();
   }

   /*
@@ -320,16 +263,11 @@ public void writeHiveChar(HiveChar hiveChar) throws IOException {
    */
   @Override
   public void writeHiveVarchar(HiveVarchar hiveVarchar) throws IOException {
-
-    if (index > 0) {
-      output.write(separator);
-    }
-
+    beginPrimitive();
     ByteBuffer b = Text.encode(hiveVarchar.getValue());
     LazyUtils.writeEscaped(output, b.array(), 0, b.limit(), isEscaped, escapeChar, needsEscape);
-
-    index++;
+    finishPrimitive();
   }

   /*
@@ -337,32 +275,22 @@ public void writeHiveVarchar(HiveVarchar hiveVarchar) throws IOException {
    */
   @Override
   public void writeBinary(byte[] v) throws IOException {
-
-    if (index > 0) {
-      output.write(separator);
-    }
-
+    beginPrimitive();
     byte[] toEncode = new byte[v.length];
     System.arraycopy(v, 0, toEncode, 0, v.length);
     byte[] toWrite = Base64.encodeBase64(toEncode);
     output.write(toWrite, 0, toWrite.length);
-
-    index++;
+    finishPrimitive();
   }

   @Override
   public void writeBinary(byte[] v, int start, int length) throws IOException {
-
-    if (index > 0) {
-      output.write(separator);
-    }
-
+    beginPrimitive();
     byte[] toEncode = new byte[length];
     System.arraycopy(v, start, toEncode, 0, length);
     byte[] toWrite = Base64.encodeBase64(toEncode);
     output.write(toWrite, 0, toWrite.length);
-
-    index++;
+    finishPrimitive();
   }

   /*
@@ -370,35 +298,25 @@ public void writeBinary(byte[] v, int start, int length) throws IOException {
    */
   @Override
   public void writeDate(Date date) throws IOException {
-
-    if (index > 0) {
-      output.write(separator);
-    }
-
+    beginPrimitive();
     if (dateWritable == null) {
       dateWritable = new DateWritable();
     }
     dateWritable.set(date);
     LazyDate.writeUTF8(output, dateWritable);
-
-    index++;
+    finishPrimitive();
   }

   // We provide a faster way to write a date without a Date object.
@Override public void writeDate(int dateAsDays) throws IOException { - - if (index > 0) { - output.write(separator); - } - + beginPrimitive(); if (dateWritable == null) { dateWritable = new DateWritable(); } dateWritable.set(dateAsDays); LazyDate.writeUTF8(output, dateWritable); - - index++; + finishPrimitive(); } /* @@ -406,18 +324,13 @@ public void writeDate(int dateAsDays) throws IOException { */ @Override public void writeTimestamp(Timestamp v) throws IOException { - - if (index > 0) { - output.write(separator); - } - + beginPrimitive(); if (timestampWritable == null) { timestampWritable = new TimestampWritable(); } timestampWritable.set(v); LazyTimestamp.writeUTF8(output, timestampWritable); - - index++; + finishPrimitive(); } /* @@ -425,35 +338,25 @@ public void writeTimestamp(Timestamp v) throws IOException { */ @Override public void writeHiveIntervalYearMonth(HiveIntervalYearMonth viyt) throws IOException { - - if (index > 0) { - output.write(separator); - } - + beginPrimitive(); if (hiveIntervalYearMonthWritable == null) { hiveIntervalYearMonthWritable = new HiveIntervalYearMonthWritable(); } hiveIntervalYearMonthWritable.set(viyt); LazyHiveIntervalYearMonth.writeUTF8(output, hiveIntervalYearMonthWritable); - - index++; + finishPrimitive(); } @Override public void writeHiveIntervalYearMonth(int totalMonths) throws IOException { - - if (index > 0) { - output.write(separator); - } - + beginPrimitive(); if (hiveIntervalYearMonthWritable == null) { hiveIntervalYearMonthWritable = new HiveIntervalYearMonthWritable(); } hiveIntervalYearMonthWritable.set(totalMonths); LazyHiveIntervalYearMonth.writeUTF8(output, hiveIntervalYearMonthWritable); - - index++; + finishPrimitive(); } /* @@ -461,18 +364,13 @@ public void writeHiveIntervalYearMonth(int totalMonths) throws IOException { */ @Override public void writeHiveIntervalDayTime(HiveIntervalDayTime vidt) throws IOException { - - if (index > 0) { - output.write(separator); - } - + beginPrimitive(); if (hiveIntervalDayTimeWritable == null) { hiveIntervalDayTimeWritable = new HiveIntervalDayTimeWritable(); } hiveIntervalDayTimeWritable.set(vidt); LazyHiveIntervalDayTime.writeUTF8(output, hiveIntervalDayTimeWritable); - - index++; + finishPrimitive(); } /* @@ -483,29 +381,119 @@ public void writeHiveIntervalDayTime(HiveIntervalDayTime vidt) throws IOExceptio */ @Override public void writeHiveDecimal(HiveDecimal dec, int scale) throws IOException { - if (index > 0) { - output.write(separator); - } - + beginPrimitive(); if (decimalScratchBuffer == null) { decimalScratchBuffer = new byte[HiveDecimal.SCRATCH_BUFFER_LEN_TO_BYTES]; } LazyHiveDecimal.writeUTF8(output, dec, scale, decimalScratchBuffer); - - index++; + finishPrimitive(); } @Override public void writeHiveDecimal(HiveDecimalWritable decWritable, int scale) throws IOException { - if (index > 0) { - output.write(separator); - } - + beginPrimitive(); if (decimalScratchBuffer == null) { decimalScratchBuffer = new byte[HiveDecimal.SCRATCH_BUFFER_LEN_TO_BYTES]; } LazyHiveDecimal.writeUTF8(output, decWritable, scale, decimalScratchBuffer); + finishPrimitive(); + } + + private void beginComplex() { + if (index > 0) { + output.write(separators[currentLevel]); + } + indexStack.push(index); + + // Always use index 0 so the write methods don't write a separator. + index = 0; + + // Set "global" separator member to next level. 
+ currentLevel++; + } + + private void finishComplex() { + currentLevel--; + index = indexStack.pop(); + index++; + } + + @Override + public void beginList(List list) { + beginComplex(); + } + + @Override + public void separateList() { + } + + @Override + public void finishList() { + finishComplex(); + } + + @Override + public void beginMap(Map map) { + beginComplex(); + + // MAP requires 2 levels: key separator and key-pair separator. + currentLevel++; + } + + @Override + public void separateKey() { + index = 0; + output.write(separators[currentLevel]); + } + + @Override + public void separateKeyValuePair() { + index = 0; + output.write(separators[currentLevel - 1]); + } + + @Override + public void finishMap() { + // Remove MAP extra level. + currentLevel--; + + finishComplex(); + } + + @Override + public void beginStruct(List fieldValues) { + beginComplex(); + } + + @Override + public void separateStruct() { + } + + @Override + public void finishStruct() { + finishComplex(); + } + + @Override + public void beginUnion(int tag) throws IOException { + beginComplex(); + writeInt(tag); + output.write(separators[currentLevel]); + index = 0; + } + + @Override + public void finishUnion() { + finishComplex(); + } + + private void beginPrimitive() { + if (index > 0) { + output.write(separators[currentLevel]); + } + } + private void finishPrimitive() { index++; } } diff --git serde/src/java/org/apache/hadoop/hive/serde2/lazybinary/fast/LazyBinaryDeserializeRead.java serde/src/java/org/apache/hadoop/hive/serde2/lazybinary/fast/LazyBinaryDeserializeRead.java index e94ae999fe..1dbdd73dc3 100644 --- serde/src/java/org/apache/hadoop/hive/serde2/lazybinary/fast/LazyBinaryDeserializeRead.java +++ serde/src/java/org/apache/hadoop/hive/serde2/lazybinary/fast/LazyBinaryDeserializeRead.java @@ -20,19 +20,30 @@ import java.io.EOFException; import java.io.IOException; +import java.util.ArrayDeque; import java.util.Arrays; +import java.util.Deque; +import java.util.List; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import org.apache.hadoop.hive.serde2.fast.DeserializeRead; import org.apache.hadoop.hive.serde2.io.TimestampWritable; -import org.apache.hadoop.hive.serde2.lazybinary.LazyBinarySerDe; import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryUtils; import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryUtils.VInt; import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryUtils.VLong; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.UnionTypeInfo; import org.apache.hadoop.io.WritableUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /* * Directly deserialize with the caller reading field-by-field the LazyBinary serialization format. 
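/*
 * A sketch of the traversal idea behind the Field tree introduced below: the
 * TypeInfo hierarchy is mirrored once, up front, into a tree of small field
 * descriptors, and each row is then walked with an explicit stack instead of
 * per-row type dispatch. The types here are simplified stand-ins for Hive's
 * TypeInfo and Field classes, not the actual implementation.
 */
import java.util.ArrayDeque;
import java.util.Deque;

public class FieldTreeSketch {
  enum Category { PRIMITIVE, LIST, MAP, STRUCT, UNION }

  static final class Field {
    final Category category;
    final Field[] children;
    int index; // how many children have been consumed at this level
    Field(Category category, Field... children) {
      this.category = category;
      this.children = children;
    }
  }

  public static void main(String[] args) {
    // struct<int, map<string,int>> mirrored as a pre-built descriptor tree.
    Field root = new Field(Category.STRUCT,
        new Field(Category.PRIMITIVE),
        new Field(Category.MAP,
            new Field(Category.PRIMITIVE),   // key
            new Field(Category.PRIMITIVE))); // value

    // Depth-first walk with an explicit stack -- the shape that
    // readComplexField / finishComplexVariableFieldsType trace at read time.
    Deque<Field> stack = new ArrayDeque<>();
    stack.push(root);
    walk(stack);
  }

  static void walk(Deque<Field> stack) {
    Field current = stack.peek();
    while (current.index < current.children.length) {
      Field child = current.children[current.index++];
      if (child.category == Category.PRIMITIVE) {
        System.out.println("read primitive at depth " + stack.size());
      } else {
        stack.push(child); // descend into the complex child
        walk(stack);
        stack.pop();       // pop on completion, like finishComplexVariableFieldsType
      }
    }
  }
}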
@@ -55,26 +66,82 @@ private int start; private int offset; private int end; - private int fieldCount; - private int fieldStart; - private int fieldIndex; - private byte nullByte; + + private boolean skipLengthPrefix = false; // Object to receive results of reading a decoded variable length int or long. private VInt tempVInt; private VLong tempVLong; + private Deque stack = new ArrayDeque<>(); + private Field root; + + private class Field { + Category category; + PrimitiveCategory primitiveCategory; + TypeInfo typeInfo; + int index; + int count; + Field[] children; + int start; + int end; + int nullByteStart; + byte nullByte; + byte tag; + } + public LazyBinaryDeserializeRead(TypeInfo[] typeInfos, boolean useExternalBuffer) { super(typeInfos, useExternalBuffer); - fieldCount = typeInfos.length; tempVInt = new VInt(); tempVLong = new VLong(); currentExternalBufferNeeded = false; + + root = new Field(); + root.category = Category.STRUCT; + root.children = createFields(typeInfos); + root.count = typeInfos.length; } - // Not public since we must have the field count so every 8 fields NULL bytes can be navigated. - private LazyBinaryDeserializeRead() { - super(); + private Field[] createFields(TypeInfo[] typeInfos) { + Field[] children = new Field[typeInfos.length]; + for (int i = 0; i < typeInfos.length; i++) { + children[i] = createField(typeInfos[i]); + } + return children; + } + + private Field createField(TypeInfo typeInfo) { + Field field = new Field(); + Category category = typeInfo.getCategory(); + field.category = category; + field.typeInfo = typeInfo; + switch (category) { + case PRIMITIVE: + field.primitiveCategory = ((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory(); + break; + case LIST: + field.children = new Field[1]; + field.children[0] = createField(((ListTypeInfo) typeInfo).getListElementTypeInfo()); + break; + case MAP: + field.children = new Field[2]; + field.children[0] = createField(((MapTypeInfo) typeInfo).getMapKeyTypeInfo()); + field.children[1] = createField(((MapTypeInfo) typeInfo).getMapValueTypeInfo()); + break; + case STRUCT: + StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo; + List fieldTypeInfos = structTypeInfo.getAllStructFieldTypeInfos(); + field.children = createFields(fieldTypeInfos.toArray(new TypeInfo[fieldTypeInfos.size()])); + break; + case UNION: + UnionTypeInfo unionTypeInfo = (UnionTypeInfo) typeInfo; + List objectTypeInfos = unionTypeInfo.getAllUnionObjectTypeInfos(); + field.children = createFields(objectTypeInfos.toArray(new TypeInfo[objectTypeInfos.size()])); + break; + default: + throw new RuntimeException(); + } + return field; } /* @@ -86,7 +153,20 @@ public void set(byte[] bytes, int offset, int length) { this.offset = offset; start = offset; end = offset + length; - fieldIndex = 0; + + stack.clear(); + stack.push(root); + clearIndex(root); + } + + private void clearIndex(Field field) { + field.index = 0; + if (field.children == null) { + return; + } + for (Field child : field.children) { + clearIndex(child); + } } /* @@ -102,13 +182,13 @@ public String getDetailedReadPositionString() { sb.append(" for length "); sb.append(end - start); sb.append(" to read "); - sb.append(fieldCount); + sb.append(root.children.length); sb.append(" fields with types "); sb.append(Arrays.toString(typeInfos)); sb.append(". 
Read field #"); - sb.append(fieldIndex); + sb.append(root.index); sb.append(" at field start position "); - sb.append(fieldStart); + sb.append(root.start); sb.append(" current read offset "); sb.append(offset); @@ -127,263 +207,196 @@ public String getDetailedReadPositionString() { */ @Override public boolean readNextField() throws IOException { - if (fieldIndex >= fieldCount) { - return false; - } - - fieldStart = offset; + return readComplexField(); + } - if (fieldIndex == 0) { - // The rest of the range check for fields after the first is below after checking - // the NULL byte. - if (offset >= end) { + private boolean readPrimitive(Field field) throws IOException { + PrimitiveCategory primitiveCategory = field.primitiveCategory; + TypeInfo typeInfo = field.typeInfo; + switch (primitiveCategory) { + case BOOLEAN: + // No check needed for single byte read. + currentBoolean = (bytes[offset++] != 0); + break; + case BYTE: + // No check needed for single byte read. + currentByte = bytes[offset++]; + break; + case SHORT: + // Last item -- ok to be at end. + if (offset + 2 > end) { throw new EOFException(); } - nullByte = bytes[offset++]; - } - - // NOTE: The bit is set to 1 if a field is NOT NULL. boolean isNull; - if ((nullByte & (1 << (fieldIndex % 8))) == 0) { - - // Logically move past this field. - fieldIndex++; - - // Every 8 fields we read a new NULL byte. - if (fieldIndex < fieldCount) { - if ((fieldIndex % 8) == 0) { - // Get next null byte. - if (offset >= end) { - throw new EOFException(); - } - nullByte = bytes[offset++]; - } + currentShort = LazyBinaryUtils.byteArrayToShort(bytes, offset); + offset += 2; + break; + case INT: + // Parse the first byte of a vint/vlong to determine the number of bytes. + if (offset + WritableUtils.decodeVIntSize(bytes[offset]) > end) { + throw new EOFException(); } - return false; - } else { - - // Make sure there is at least one byte that can be read for a value. - if (offset >= end) { + LazyBinaryUtils.readVInt(bytes, offset, tempVInt); + offset += tempVInt.length; + currentInt = tempVInt.value; + break; + case LONG: + // Parse the first byte of a vint/vlong to determine the number of bytes. + if (offset + WritableUtils.decodeVIntSize(bytes[offset]) > end) { throw new EOFException(); } - - /* - * We have a field and are positioned to it. Read it. - */ - switch (primitiveCategories[fieldIndex]) { - case BOOLEAN: - // No check needed for single byte read. - currentBoolean = (bytes[offset++] != 0); - break; - case BYTE: - // No check needed for single byte read. - currentByte = bytes[offset++]; - break; - case SHORT: - // Last item -- ok to be at end. - if (offset + 2 > end) { - throw new EOFException(); - } - currentShort = LazyBinaryUtils.byteArrayToShort(bytes, offset); - offset += 2; - break; - case INT: + LazyBinaryUtils.readVLong(bytes, offset, tempVLong); + offset += tempVLong.length; + currentLong = tempVLong.value; + break; + case FLOAT: + // Last item -- ok to be at end. + if (offset + 4 > end) { + throw new EOFException(); + } + currentFloat = Float.intBitsToFloat(LazyBinaryUtils.byteArrayToInt(bytes, offset)); + offset += 4; + break; + case DOUBLE: + // Last item -- ok to be at end. + if (offset + 8 > end) { + throw new EOFException(); + } + currentDouble = Double.longBitsToDouble(LazyBinaryUtils.byteArrayToLong(bytes, offset)); + offset += 8; + break; + + case BINARY: + case STRING: + case CHAR: + case VARCHAR: + { + // using vint instead of 4 bytes // Parse the first byte of a vint/vlong to determine the number of bytes. 
if (offset + WritableUtils.decodeVIntSize(bytes[offset]) > end) { throw new EOFException(); } LazyBinaryUtils.readVInt(bytes, offset, tempVInt); offset += tempVInt.length; - currentInt = tempVInt.value; - break; - case LONG: - // Parse the first byte of a vint/vlong to determine the number of bytes. - if (offset + WritableUtils.decodeVIntSize(bytes[offset]) > end) { - throw new EOFException(); - } - LazyBinaryUtils.readVLong(bytes, offset, tempVLong); - offset += tempVLong.length; - currentLong = tempVLong.value; - break; - case FLOAT: + + int saveStart = offset; + int length = tempVInt.value; + offset += length; // Last item -- ok to be at end. - if (offset + 4 > end) { + if (offset > end) { throw new EOFException(); } - currentFloat = Float.intBitsToFloat(LazyBinaryUtils.byteArrayToInt(bytes, offset)); - offset += 4; - break; - case DOUBLE: + + currentBytes = bytes; + currentBytesStart = saveStart; + currentBytesLength = length; + } + break; + case DATE: + // Parse the first byte of a vint/vlong to determine the number of bytes. + if (offset + WritableUtils.decodeVIntSize(bytes[offset]) > end) { + throw new EOFException(); + } + LazyBinaryUtils.readVInt(bytes, offset, tempVInt); + offset += tempVInt.length; + + currentDateWritable.set(tempVInt.value); + break; + case TIMESTAMP: + { + int length = TimestampWritable.getTotalLength(bytes, offset); + int saveStart = offset; + offset += length; // Last item -- ok to be at end. - if (offset + 8 > end) { + if (offset > end) { throw new EOFException(); } - currentDouble = Double.longBitsToDouble(LazyBinaryUtils.byteArrayToLong(bytes, offset)); - offset += 8; - break; - - case BINARY: - case STRING: - case CHAR: - case VARCHAR: - { - // using vint instead of 4 bytes - // Parse the first byte of a vint/vlong to determine the number of bytes. - if (offset + WritableUtils.decodeVIntSize(bytes[offset]) > end) { - throw new EOFException(); - } - LazyBinaryUtils.readVInt(bytes, offset, tempVInt); - offset += tempVInt.length; - - int saveStart = offset; - int length = tempVInt.value; - offset += length; - // Last item -- ok to be at end. - if (offset > end) { - throw new EOFException(); - } - - currentBytes = bytes; - currentBytesStart = saveStart; - currentBytesLength = length; - } - break; - case DATE: + + currentTimestampWritable.set(bytes, saveStart); + } + break; + case INTERVAL_YEAR_MONTH: + // Parse the first byte of a vint/vlong to determine the number of bytes. + if (offset + WritableUtils.decodeVIntSize(bytes[offset]) > end) { + throw new EOFException(); + } + LazyBinaryUtils.readVInt(bytes, offset, tempVInt); + offset += tempVInt.length; + + currentHiveIntervalYearMonthWritable.set(tempVInt.value); + break; + case INTERVAL_DAY_TIME: + // The first bounds check requires at least one more byte beyond for 2nd int (hence >=). + // Parse the first byte of a vint/vlong to determine the number of bytes. + if (offset + WritableUtils.decodeVIntSize(bytes[offset]) >= end) { + throw new EOFException(); + } + LazyBinaryUtils.readVLong(bytes, offset, tempVLong); + offset += tempVLong.length; + + // Parse the first byte of a vint/vlong to determine the number of bytes. 
+ if (offset + WritableUtils.decodeVIntSize(bytes[offset]) > end) { + throw new EOFException(); + } + LazyBinaryUtils.readVInt(bytes, offset, tempVInt); + offset += tempVInt.length; + + currentHiveIntervalDayTimeWritable.set(tempVLong.value, tempVInt.value); + break; + case DECIMAL: + { + // Since enforcing precision and scale can cause a HiveDecimal to become NULL, + // we must read it, enforce it here, and either return NULL or buffer the result. + + // These calls are to see how much data there is. The setFromBytes call below will do the same + // readVInt reads but actually unpack the decimal. + + // The first bounds check requires at least one more byte beyond for 2nd int (hence >=). // Parse the first byte of a vint/vlong to determine the number of bytes. - if (offset + WritableUtils.decodeVIntSize(bytes[offset]) > end) { + if (offset + WritableUtils.decodeVIntSize(bytes[offset]) >= end) { throw new EOFException(); } LazyBinaryUtils.readVInt(bytes, offset, tempVInt); offset += tempVInt.length; + int readScale = tempVInt.value; - currentDateWritable.set(tempVInt.value); - break; - case TIMESTAMP: - { - int length = TimestampWritable.getTotalLength(bytes, offset); - int saveStart = offset; - offset += length; - // Last item -- ok to be at end. - if (offset > end) { - throw new EOFException(); - } - - currentTimestampWritable.set(bytes, saveStart); - } - break; - case INTERVAL_YEAR_MONTH: // Parse the first byte of a vint/vlong to determine the number of bytes. if (offset + WritableUtils.decodeVIntSize(bytes[offset]) > end) { throw new EOFException(); } LazyBinaryUtils.readVInt(bytes, offset, tempVInt); offset += tempVInt.length; - - currentHiveIntervalYearMonthWritable.set(tempVInt.value); - break; - case INTERVAL_DAY_TIME: - // The first bounds check requires at least one more byte beyond for 2nd int (hence >=). - // Parse the first byte of a vint/vlong to determine the number of bytes. - if (offset + WritableUtils.decodeVIntSize(bytes[offset]) >= end) { + int saveStart = offset; + offset += tempVInt.value; + // Last item -- ok to be at end. + if (offset > end) { throw new EOFException(); } - LazyBinaryUtils.readVLong(bytes, offset, tempVLong); - offset += tempVLong.length; + int length = offset - saveStart; - // Parse the first byte of a vint/vlong to determine the number of bytes. - if (offset + WritableUtils.decodeVIntSize(bytes[offset]) > end) { - throw new EOFException(); - } - LazyBinaryUtils.readVInt(bytes, offset, tempVInt); - offset += tempVInt.length; + // scale = 2, length = 6, value = -6065716379.11 + // \002\006\255\114\197\131\083\105 + // \255\114\197\131\083\105 - currentHiveIntervalDayTimeWritable.set(tempVLong.value, tempVInt.value); - break; - case DECIMAL: - { - // Since enforcing precision and scale can cause a HiveDecimal to become NULL, - // we must read it, enforce it here, and either return NULL or buffer the result. - - // These calls are to see how much data there is. The setFromBytes call below will do the same - // readVInt reads but actually unpack the decimal. - - // The first bounds check requires at least one more byte beyond for 2nd int (hence >=). - // Parse the first byte of a vint/vlong to determine the number of bytes. - if (offset + WritableUtils.decodeVIntSize(bytes[offset]) >= end) { - throw new EOFException(); - } - LazyBinaryUtils.readVInt(bytes, offset, tempVInt); - offset += tempVInt.length; - int readScale = tempVInt.value; - - // Parse the first byte of a vint/vlong to determine the number of bytes. 
- if (offset + WritableUtils.decodeVIntSize(bytes[offset]) > end) { - throw new EOFException(); - } - LazyBinaryUtils.readVInt(bytes, offset, tempVInt); - offset += tempVInt.length; - int saveStart = offset; - offset += tempVInt.value; - // Last item -- ok to be at end. - if (offset > end) { - throw new EOFException(); - } - int length = offset - saveStart; - - // scale = 2, length = 6, value = -6065716379.11 - // \002\006\255\114\197\131\083\105 - // \255\114\197\131\083\105 - - currentHiveDecimalWritable.setFromBigIntegerBytesAndScale( - bytes, saveStart, length, readScale); - boolean decimalIsNull = !currentHiveDecimalWritable.isSet(); - if (!decimalIsNull) { - - DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) typeInfos[fieldIndex]; - - int precision = decimalTypeInfo.getPrecision(); - int scale = decimalTypeInfo.getScale(); - - decimalIsNull = !currentHiveDecimalWritable.mutateEnforcePrecisionScale(precision, scale); - } - if (decimalIsNull) { - - // Logically move past this field. - fieldIndex++; - - // Every 8 fields we read a new NULL byte. - if (fieldIndex < fieldCount) { - if ((fieldIndex % 8) == 0) { - // Get next null byte. - if (offset >= end) { - throw new EOFException(); - } - nullByte = bytes[offset++]; - } - } - return false; - } - } - break; + currentHiveDecimalWritable.setFromBigIntegerBytesAndScale( + bytes, saveStart, length, readScale); + boolean decimalIsNull = !currentHiveDecimalWritable.isSet(); + if (!decimalIsNull) { - default: - throw new Error("Unexpected primitive category " + primitiveCategories[fieldIndex].name()); - } - } + DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) typeInfo; - // Logically move past this field. - fieldIndex++; + int precision = decimalTypeInfo.getPrecision(); + int scale = decimalTypeInfo.getScale(); - // Every 8 fields we read a new NULL byte. - if (fieldIndex < fieldCount) { - if ((fieldIndex % 8) == 0) { - // Get next null byte. - if (offset >= end) { - throw new EOFException(); + decimalIsNull = !currentHiveDecimalWritable.mutateEnforcePrecisionScale(precision, scale); + } + if (decimalIsNull) { + return false; } - nullByte = bytes[offset++]; } + break; + default: + throw new Error("Unexpected primitive category " + primitiveCategory.name()); } - return true; } @@ -394,8 +407,37 @@ public boolean readNextField() throws IOException { * Designed for skipping columns that are not included. */ public void skipNextField() throws IOException { - // Not a known use case for LazyBinary -- so don't optimize. 
- readNextField(); + Field current = stack.peek(); + boolean isNull = isNull(current); + + if (isNull) { + current.index++; + return; + } + + if (readUnionTag(current)) { + current.index++; + return; + } + + Field child = getChild(current); + + if (child.category == Category.PRIMITIVE) { + readPrimitive(child); + current.index++; + } else { + parseHeader(child); + stack.push(child); + + for (int i = 0; i < child.count; i++) { + skipNextField(); + } + finishComplexVariableFieldsType(); + } + + if (offset > end) { + throw new EOFException(); + } } /* @@ -412,4 +454,149 @@ public void skipNextField() throws IOException { public boolean isEndOfInputReached() { return (offset == end); } + + private boolean isNull(Field field) { + byte b = (byte) (1 << (field.index % 8)); + switch (field.category) { + case PRIMITIVE: + return false; + case LIST: + case MAP: + byte nullByte = bytes[field.nullByteStart + (field.index / 8)]; + return (nullByte & b) == 0; + case STRUCT: + if (field.index % 8 == 0) { + field.nullByte = bytes[offset++]; + } + return (field.nullByte & b) == 0; + case UNION: + return false; + default: + throw new RuntimeException(); + } + } + + private void parseHeader(Field field) { + // Init + field.index = 0; + field.start = offset; + + // Read length + if (!skipLengthPrefix) { + int length = LazyBinaryUtils.byteArrayToInt(bytes, offset); + offset += 4; + field.end = offset + length; + } + + switch (field.category) { + case LIST: + case MAP: + // Read count + LazyBinaryUtils.readVInt(bytes, offset, tempVInt); + if (field.category == Category.LIST) { + field.count = tempVInt.value; + } else { + field.count = tempVInt.value * 2; + } + offset += tempVInt.length; + + // Null byte start + field.nullByteStart = offset; + offset += ((field.count) + 7) / 8; + break; + case STRUCT: + field.count = ((StructTypeInfo) field.typeInfo).getAllStructFieldTypeInfos().size(); + break; + case UNION: + field.count = 2; + break; + } + } + + private Field getChild(Field field) { + Field child; + switch (field.category) { + case LIST: + child = field.children[0]; + break; + case MAP: + child = field.children[field.index % 2]; + break; + case STRUCT: + child = field.children[field.index]; + break; + case UNION: + child = field.children[field.tag]; + break; + default: + throw new RuntimeException(); + } + return child; + } + + private boolean readUnionTag(Field field) { + if (field.category == Category.UNION && field.index == 0) { + field.tag = bytes[offset++]; + currentInt = field.tag; + return true; + } else { + return false; + } + } + + // Push or next + @Override + public boolean readComplexField() throws IOException { + Field current = stack.peek(); + boolean isNull = isNull(current); + + if (isNull) { + current.index++; + return false; + } + + if (readUnionTag(current)) { + current.index++; + return true; + } + + Field child = getChild(current); + + if (child.category == Category.PRIMITIVE) { + isNull = !readPrimitive(child); + current.index++; + } else { + parseHeader(child); + stack.push(child); + } + + if (offset > end) { + throw new EOFException(); + } + return !isNull; + } + + // Pop (list, map) + @Override + public boolean isNextComplexMultiValue() { + Field current = stack.peek(); + boolean isNext; + + isNext = current.index < current.count; + if (!isNext) { + stack.pop(); + stack.peek().index++; + } + return isNext; + } + + // Pop (struct, union) + @Override + public void finishComplexVariableFieldsType() { + stack.pop(); + if (stack.peek() == null) { + throw new RuntimeException(); + } + 
stack.peek().index++; + } } diff --git serde/src/java/org/apache/hadoop/hive/serde2/lazybinary/fast/LazyBinarySerializeWrite.java serde/src/java/org/apache/hadoop/hive/serde2/lazybinary/fast/LazyBinarySerializeWrite.java index 085d71cc9b..8cc8647510 100644 --- serde/src/java/org/apache/hadoop/hive/serde2/lazybinary/fast/LazyBinarySerializeWrite.java +++ serde/src/java/org/apache/hadoop/hive/serde2/lazybinary/fast/LazyBinarySerializeWrite.java @@ -21,7 +21,13 @@ import java.io.IOException; import java.sql.Date; import java.sql.Timestamp; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.List; +import java.util.Map; +import org.apache.hadoop.hive.serde2.ByteStream; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.hadoop.hive.common.type.HiveChar; @@ -38,7 +44,6 @@ import org.apache.hadoop.hive.serde2.lazybinary.LazyBinarySerDe; import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryUtils; import org.apache.hadoop.hive.serde2.fast.SerializeWrite; -import org.apache.hive.common.util.DateUtils; /* * Directly serialize, field-by-field, the LazyBinary format. @@ -50,10 +55,8 @@ private Output output; - private int fieldCount; - private int fieldIndex; - private byte nullByte; - private long nullOffset; + private int rootFieldCount; + private boolean skipLengthPrefix = false; // For thread safety, we allocate private writable objects for our use only. private TimestampWritable timestampWritable; @@ -64,10 +67,58 @@ private long[] scratchLongs; private byte[] scratchBuffer; + private Deque complexTypeHelperStack = new ArrayDeque<>(); + private LazyBinarySerDe.BooleanRef warnedOnceNullMapKey; + + private static class ComplexTypeHelper { + ObjectInspector.Category type; + byte nullByte; + int fieldCount; + int fieldIndex; + int byteSizeStart; + int typeStart; + long nullOffset; + + ComplexTypeHelper(ObjectInspector.Category type) { + this.type = type; + } + } + + private static class ListComplexTypeHelper extends ComplexTypeHelper { + int listStart; + + ListComplexTypeHelper() { + super(ObjectInspector.Category.LIST); + } + } + + private static class MapComplexTypeHelper extends ComplexTypeHelper { + int mapStart; + + MapComplexTypeHelper() { + super(ObjectInspector.Category.MAP); + } + } + + private static class StructComplexTypeHelper extends ComplexTypeHelper { + StructComplexTypeHelper() { + super(ObjectInspector.Category.STRUCT); + } + } + + private static class UnionComplexTypeHelper extends ComplexTypeHelper { + UnionComplexTypeHelper() { + super(ObjectInspector.Category.UNION); + } + } + public LazyBinarySerializeWrite(int fieldCount) { this(); vLongBytes = new byte[LazyBinaryUtils.VLONG_BYTES_LEN]; - this.fieldCount = fieldCount; + ComplexTypeHelper rootComplexTypeHelper = new ComplexTypeHelper(ObjectInspector.Category.STRUCT); + complexTypeHelperStack.push(rootComplexTypeHelper); + this.rootFieldCount = fieldCount; + rootComplexTypeHelper.fieldCount = fieldCount; } // Not public since we must have the field count and other information. 
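/*
 * A standalone sketch of the LazyBinary NULL-byte bookkeeping that the
 * ComplexTypeHelper stack above now carries per nesting level: every 8 struct
 * fields share one NULL byte that is reserved up front and patched afterwards,
 * and a bit is set only when a field is NOT null. The List-based buffer below is
 * an assumed stand-in for Hive's patchable ByteStream.Output.
 */
import java.util.ArrayList;
import java.util.List;

public class NullByteSketch {
  private final List<Byte> out = new ArrayList<>(); // patchable output buffer
  private int fieldIndex = 0;
  private int nullOffset = 0;
  private byte nullByte = 0;
  private final int fieldCount;

  NullByteSketch(int fieldCount) { this.fieldCount = fieldCount; }

  void writeField(Byte valueOrNull) {
    if (fieldIndex % 8 == 0) {
      if (fieldIndex > 0) {
        out.set(nullOffset, nullByte); // patch the previous 8 fields' NULL byte
        nullByte = 0;
      }
      nullOffset = out.size();
      out.add((byte) 0);               // reserve the next NULL byte
    }
    if (valueOrNull != null) {
      nullByte |= 1 << (fieldIndex % 8); // bit set means NOT null
      out.add(valueOrNull);              // NULL fields write no payload at all
    }
    fieldIndex++;
    if (fieldIndex == fieldCount) {
      out.set(nullOffset, nullByte);   // patch the final NULL byte
    }
  }

  public static void main(String[] args) {
    NullByteSketch w = new NullByteSketch(3);
    w.writeField((byte) 1);
    w.writeField(null);
    w.writeField((byte) 7);
    // NULL byte 0b101 (fields 0 and 2 present) followed by the two payload bytes.
    System.out.println(w.out); // prints: [5, 1, 7]
  }
}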
@@ -81,9 +132,7 @@ private LazyBinarySerializeWrite() { public void set(Output output) { this.output = output; output.reset(); - fieldIndex = 0; - nullByte = 0; - nullOffset = 0; + resetWithoutOutput(); } /* @@ -92,9 +141,7 @@ public void set(Output output) { @Override public void setAppend(Output output) { this.output = output; - fieldIndex = 0; - nullByte = 0; - nullOffset = output.getLength(); + resetWithoutOutput(); } /* @@ -103,57 +150,45 @@ public void setAppend(Output output) { @Override public void reset() { output.reset(); - fieldIndex = 0; - nullByte = 0; - nullOffset = 0; + resetWithoutOutput(); } - /* - * General Pattern: - * - * // Every 8 fields we write a NULL byte. - * IF ((fieldIndex % 8) == 0), then - * IF (fieldIndex > 0), then - * Write back previous NullByte - * NullByte = 0 - * Remember write position - * Allocate room for next NULL byte. - * - * WHEN NOT NULL: Set bit in NULL byte; Write value. - * OTHERWISE NULL: We do not set a bit in the nullByte when we are writing a null. - * - * Increment fieldIndex - * - * IF (fieldIndex == fieldCount), then - * Write back final NullByte - * - */ + private void resetWithoutOutput() { + complexTypeHelperStack.clear(); + ComplexTypeHelper rootComplexTypeHelper = new ComplexTypeHelper(ObjectInspector.Category.STRUCT); + rootComplexTypeHelper.fieldCount = rootFieldCount; + complexTypeHelperStack.push(rootComplexTypeHelper); + warnedOnceNullMapKey = null; + } /* * Write a NULL field. */ @Override public void writeNull() throws IOException { - - // Every 8 fields we write a NULL byte. - if ((fieldIndex % 8) == 0) { - if (fieldIndex > 0) { - // Write back previous 8 field's NULL byte. - output.writeByte(nullOffset, nullByte); - nullByte = 0; - nullOffset = output.getLength(); + ComplexTypeHelper currentComplexTypeHelper = complexTypeHelperStack.peek(); + + if (currentComplexTypeHelper.type == ObjectInspector.Category.STRUCT) { + // Every 8 fields we write a NULL byte. + if ((currentComplexTypeHelper.fieldIndex % 8) == 0) { + if (currentComplexTypeHelper.fieldIndex > 0) { + // Write back previous 8 field's NULL byte. + output.writeByte(currentComplexTypeHelper.nullOffset, currentComplexTypeHelper.nullByte); + currentComplexTypeHelper.nullByte = 0; + currentComplexTypeHelper.nullOffset = output.getLength(); + } + // Allocate next NULL byte. + output.reserve(1); } - // Allocate next NULL byte. - output.reserve(1); - } - // We DO NOT set a bit in the NULL byte when we are writing a NULL. + // We DO NOT set a bit in the NULL byte when we are writing a NULL. - fieldIndex++; + currentComplexTypeHelper.fieldIndex++; - if (fieldIndex == fieldCount) { - // Write back the final NULL byte before the last fields. - output.writeByte(nullOffset, nullByte); + if (currentComplexTypeHelper.fieldIndex == currentComplexTypeHelper.fieldCount) { + // Write back the final NULL byte before the last fields. + output.writeByte(currentComplexTypeHelper.nullOffset, currentComplexTypeHelper.nullByte); + } } } @@ -162,30 +197,9 @@ public void writeNull() throws IOException { */ @Override public void writeBoolean(boolean v) throws IOException { - - // Every 8 fields we write a NULL byte. - if ((fieldIndex % 8) == 0) { - if (fieldIndex > 0) { - // Write back previous 8 field's NULL byte. - output.writeByte(nullOffset, nullByte); - nullByte = 0; - nullOffset = output.getLength(); - } - // Allocate next NULL byte. - output.reserve(1); - } - - // Set bit in NULL byte when a field is NOT NULL. 
- nullByte |= 1 << (fieldIndex % 8); - + beginElement(); output.write((byte) (v ? 1 : 0)); - - fieldIndex++; - - if (fieldIndex == fieldCount) { - // Write back the final NULL byte before the last fields. - output.writeByte(nullOffset, nullByte); - } + finishElement(); } /* @@ -193,30 +207,9 @@ public void writeBoolean(boolean v) throws IOException { */ @Override public void writeByte(byte v) throws IOException { - - // Every 8 fields we write a NULL byte. - if ((fieldIndex % 8) == 0) { - if (fieldIndex > 0) { - // Write back previous 8 field's NULL byte. - output.writeByte(nullOffset, nullByte); - nullByte = 0; - nullOffset = output.getLength(); - } - // Allocate next NULL byte. - output.reserve(1); - } - - // Set bit in NULL byte when a field is NOT NULL. - nullByte |= 1 << (fieldIndex % 8); - + beginElement(); output.write(v); - - fieldIndex++; - - if (fieldIndex == fieldCount) { - // Write back the final NULL byte before the last fields. - output.writeByte(nullOffset, nullByte); - } + finishElement(); } /* @@ -224,31 +217,10 @@ public void writeByte(byte v) throws IOException { */ @Override public void writeShort(short v) throws IOException { - - // Every 8 fields we write a NULL byte. - if ((fieldIndex % 8) == 0) { - if (fieldIndex > 0) { - // Write back previous 8 field's NULL byte. - output.writeByte(nullOffset, nullByte); - nullByte = 0; - nullOffset = output.getLength(); - } - // Allocate next NULL byte. - output.reserve(1); - } - - // Set bit in NULL byte when a field is NOT NULL. - nullByte |= 1 << (fieldIndex % 8); - + beginElement(); output.write((byte) (v >> 8)); output.write((byte) (v)); - - fieldIndex++; - - if (fieldIndex == fieldCount) { - // Write back the final NULL byte before the last fields. - output.writeByte(nullOffset, nullByte); - } + finishElement(); } /* @@ -256,30 +228,9 @@ public void writeShort(short v) throws IOException { */ @Override public void writeInt(int v) throws IOException { - - // Every 8 fields we write a NULL byte. - if ((fieldIndex % 8) == 0) { - if (fieldIndex > 0) { - // Write back previous 8 field's NULL byte. - output.writeByte(nullOffset, nullByte); - nullByte = 0; - nullOffset = output.getLength(); - } - // Allocate next NULL byte. - output.reserve(1); - } - - // Set bit in NULL byte when a field is NOT NULL. - nullByte |= 1 << (fieldIndex % 8); - + beginElement(); writeVInt(v); - - fieldIndex++; - - if (fieldIndex == fieldCount) { - // Write back the final NULL byte before the last fields. - output.writeByte(nullOffset, nullByte); - } + finishElement(); } /* @@ -287,30 +238,9 @@ public void writeInt(int v) throws IOException { */ @Override public void writeLong(long v) throws IOException { - - // Every 8 fields we write a NULL byte. - if ((fieldIndex % 8) == 0) { - if (fieldIndex > 0) { - // Write back previous 8 field's NULL byte. - output.writeByte(nullOffset, nullByte); - nullByte = 0; - nullOffset = output.getLength(); - } - // Allocate next NULL byte. - output.reserve(1); - } - - // Set bit in NULL byte when a field is NOT NULL. - nullByte |= 1 << (fieldIndex % 8); - + beginElement(); writeVLong(v); - - fieldIndex++; - - if (fieldIndex == fieldCount) { - // Write back the final NULL byte before the last fields. - output.writeByte(nullOffset, nullByte); - } + finishElement(); } /* @@ -318,34 +248,13 @@ public void writeLong(long v) throws IOException { */ @Override public void writeFloat(float vf) throws IOException { - - // Every 8 fields we write a NULL byte. 
- if ((fieldIndex % 8) == 0) { - if (fieldIndex > 0) { - // Write back previous 8 field's NULL byte. - output.writeByte(nullOffset, nullByte); - nullByte = 0; - nullOffset = output.getLength(); - } - // Allocate next NULL byte. - output.reserve(1); - } - - // Set bit in NULL byte when a field is NOT NULL. - nullByte |= 1 << (fieldIndex % 8); - + beginElement(); int v = Float.floatToIntBits(vf); output.write((byte) (v >> 24)); output.write((byte) (v >> 16)); output.write((byte) (v >> 8)); output.write((byte) (v)); - - fieldIndex++; - - if (fieldIndex == fieldCount) { - // Write back the final NULL byte before the last fields. - output.writeByte(nullOffset, nullByte); - } + finishElement(); } /* @@ -353,97 +262,32 @@ public void writeFloat(float vf) throws IOException { */ @Override public void writeDouble(double v) throws IOException { - - // Every 8 fields we write a NULL byte. - if ((fieldIndex % 8) == 0) { - if (fieldIndex > 0) { - // Write back previous 8 field's NULL byte. - output.writeByte(nullOffset, nullByte); - nullByte = 0; - nullOffset = output.getLength(); - } - // Allocate next NULL byte. - output.reserve(1); - } - - // Set bit in NULL byte when a field is NOT NULL. - nullByte |= 1 << (fieldIndex % 8); - + beginElement(); LazyBinaryUtils.writeDouble(output, v); - - fieldIndex++; - - if (fieldIndex == fieldCount) { - // Write back the final NULL byte before the last fields. - output.writeByte(nullOffset, nullByte); - } + finishElement(); } /* * STRING. - * + * * Can be used to write CHAR and VARCHAR when the caller takes responsibility for * truncation/padding issues. */ @Override public void writeString(byte[] v) throws IOException { - - // Every 8 fields we write a NULL byte. - if ((fieldIndex % 8) == 0) { - if (fieldIndex > 0) { - // Write back previous 8 field's NULL byte. - output.writeByte(nullOffset, nullByte); - nullByte = 0; - nullOffset = output.getLength(); - } - // Allocate next NULL byte. - output.reserve(1); - } - - // Set bit in NULL byte when a field is NOT NULL. - nullByte |= 1 << (fieldIndex % 8); - + beginElement(); int length = v.length; writeVInt(length); - output.write(v, 0, length); - - fieldIndex++; - - if (fieldIndex == fieldCount) { - // Write back the final NULL byte before the last fields. - output.writeByte(nullOffset, nullByte); - } + finishElement(); } @Override public void writeString(byte[] v, int start, int length) throws IOException { - - // Every 8 fields we write a NULL byte. - if ((fieldIndex % 8) == 0) { - if (fieldIndex > 0) { - // Write back previous 8 field's NULL byte. - output.writeByte(nullOffset, nullByte); - nullByte = 0; - nullOffset = output.getLength(); - } - // Allocate next NULL byte. - output.reserve(1); - } - - // Set bit in NULL byte when a field is NOT NULL. - nullByte |= 1 << (fieldIndex % 8); - + beginElement(); writeVInt(length); - output.write(v, start, length); - - fieldIndex++; - - if (fieldIndex == fieldCount) { - // Write back the final NULL byte before the last fields. - output.writeByte(nullOffset, nullByte); - } + finishElement(); } /* @@ -484,59 +328,17 @@ public void writeBinary(byte[] v, int start, int length) throws IOException { */ @Override public void writeDate(Date date) throws IOException { - - // Every 8 fields we write a NULL byte. - if ((fieldIndex % 8) == 0) { - if (fieldIndex > 0) { - // Write back previous 8 field's NULL byte. - output.writeByte(nullOffset, nullByte); - nullByte = 0; - nullOffset = output.getLength(); - } - // Allocate next NULL byte. 
- output.reserve(1); - } - - // Set bit in NULL byte when a field is NOT NULL. - nullByte |= 1 << (fieldIndex % 8); - + beginElement(); writeVInt(DateWritable.dateToDays(date)); - - fieldIndex++; - - if (fieldIndex == fieldCount) { - // Write back the final NULL byte before the last fields. - output.writeByte(nullOffset, nullByte); - } + finishElement(); } // We provide a faster way to write a date without a Date object. @Override public void writeDate(int dateAsDays) throws IOException { - - // Every 8 fields we write a NULL byte. - if ((fieldIndex % 8) == 0) { - if (fieldIndex > 0) { - // Write back previous 8 field's NULL byte. - output.writeByte(nullOffset, nullByte); - nullByte = 0; - nullOffset = output.getLength(); - } - // Allocate next NULL byte. - output.reserve(1); - } - - // Set bit in NULL byte when a field is NOT NULL. - nullByte |= 1 << (fieldIndex % 8); - + beginElement(); writeVInt(dateAsDays); - - fieldIndex++; - - if (fieldIndex == fieldCount) { - // Write back the final NULL byte before the last fields. - output.writeByte(nullOffset, nullByte); - } + finishElement(); } /* @@ -544,34 +346,13 @@ public void writeDate(int dateAsDays) throws IOException { */ @Override public void writeTimestamp(Timestamp v) throws IOException { - - // Every 8 fields we write a NULL byte. - if ((fieldIndex % 8) == 0) { - if (fieldIndex > 0) { - // Write back previous 8 field's NULL byte. - output.writeByte(nullOffset, nullByte); - nullByte = 0; - nullOffset = output.getLength(); - } - // Allocate next NULL byte. - output.reserve(1); - } - - // Set bit in NULL byte when a field is NOT NULL. - nullByte |= 1 << (fieldIndex % 8); - + beginElement(); if (timestampWritable == null) { timestampWritable = new TimestampWritable(); } timestampWritable.set(v); timestampWritable.writeToByteStream(output); - - fieldIndex++; - - if (fieldIndex == fieldCount) { - // Write back the final NULL byte before the last fields. - output.writeByte(nullOffset, nullByte); - } + finishElement(); } /* @@ -579,66 +360,24 @@ public void writeTimestamp(Timestamp v) throws IOException { */ @Override public void writeHiveIntervalYearMonth(HiveIntervalYearMonth viyt) throws IOException { - - // Every 8 fields we write a NULL byte. - if ((fieldIndex % 8) == 0) { - if (fieldIndex > 0) { - // Write back previous 8 field's NULL byte. - output.writeByte(nullOffset, nullByte); - nullByte = 0; - nullOffset = output.getLength(); - } - // Allocate next NULL byte. - output.reserve(1); - } - - // Set bit in NULL byte when a field is NOT NULL. - nullByte |= 1 << (fieldIndex % 8); - + beginElement(); if (hiveIntervalYearMonthWritable == null) { hiveIntervalYearMonthWritable = new HiveIntervalYearMonthWritable(); } hiveIntervalYearMonthWritable.set(viyt); hiveIntervalYearMonthWritable.writeToByteStream(output); - - fieldIndex++; - - if (fieldIndex == fieldCount) { - // Write back the final NULL byte before the last fields. - output.writeByte(nullOffset, nullByte); - } + finishElement(); } @Override public void writeHiveIntervalYearMonth(int totalMonths) throws IOException { - - // Every 8 fields we write a NULL byte. - if ((fieldIndex % 8) == 0) { - if (fieldIndex > 0) { - // Write back previous 8 field's NULL byte. - output.writeByte(nullOffset, nullByte); - nullByte = 0; - nullOffset = output.getLength(); - } - // Allocate next NULL byte. - output.reserve(1); - } - - // Set bit in NULL byte when a field is NOT NULL. 
- nullByte |= 1 << (fieldIndex % 8); - + beginElement(); if (hiveIntervalYearMonthWritable == null) { hiveIntervalYearMonthWritable = new HiveIntervalYearMonthWritable(); } hiveIntervalYearMonthWritable.set(totalMonths); hiveIntervalYearMonthWritable.writeToByteStream(output); - - fieldIndex++; - - if (fieldIndex == fieldCount) { - // Write back the final NULL byte before the last fields. - output.writeByte(nullOffset, nullByte); - } + finishElement(); } /* @@ -646,34 +385,13 @@ public void writeHiveIntervalYearMonth(int totalMonths) throws IOException { */ @Override public void writeHiveIntervalDayTime(HiveIntervalDayTime vidt) throws IOException { - - // Every 8 fields we write a NULL byte. - if ((fieldIndex % 8) == 0) { - if (fieldIndex > 0) { - // Write back previous 8 field's NULL byte. - output.writeByte(nullOffset, nullByte); - nullByte = 0; - nullOffset = output.getLength(); - } - // Allocate next NULL byte. - output.reserve(1); - } - - // Set bit in NULL byte when a field is NOT NULL. - nullByte |= 1 << (fieldIndex % 8); - + beginElement(); if (hiveIntervalDayTimeWritable == null) { hiveIntervalDayTimeWritable = new HiveIntervalDayTimeWritable(); } hiveIntervalDayTimeWritable.set(vidt); hiveIntervalDayTimeWritable.writeToByteStream(output); - - fieldIndex++; - - if (fieldIndex == fieldCount) { - // Write back the final NULL byte before the last fields. - output.writeByte(nullOffset, nullByte); - } + finishElement(); } /* @@ -684,22 +402,7 @@ public void writeHiveIntervalDayTime(HiveIntervalDayTime vidt) throws IOExceptio */ @Override public void writeHiveDecimal(HiveDecimal dec, int scale) throws IOException { - - // Every 8 fields we write a NULL byte. - if ((fieldIndex % 8) == 0) { - if (fieldIndex > 0) { - // Write back previous 8 field's NULL byte. - output.writeByte(nullOffset, nullByte); - nullByte = 0; - nullOffset = output.getLength(); - } - // Allocate next NULL byte. - output.reserve(1); - } - - // Set bit in NULL byte when a field is NOT NULL. - nullByte |= 1 << (fieldIndex % 8); - + beginElement(); if (scratchLongs == null) { scratchLongs = new long[HiveDecimal.SCRATCH_LONGS_LEN]; scratchBuffer = new byte[HiveDecimal.SCRATCH_BUFFER_LEN_BIG_INTEGER_BYTES]; @@ -709,33 +412,12 @@ public void writeHiveDecimal(HiveDecimal dec, int scale) throws IOException { dec, scratchLongs, scratchBuffer); - - fieldIndex++; - - if (fieldIndex == fieldCount) { - // Write back the final NULL byte before the last fields. - output.writeByte(nullOffset, nullByte); - } + finishElement(); } @Override public void writeHiveDecimal(HiveDecimalWritable decWritable, int scale) throws IOException { - - // Every 8 fields we write a NULL byte. - if ((fieldIndex % 8) == 0) { - if (fieldIndex > 0) { - // Write back previous 8 field's NULL byte. - output.writeByte(nullOffset, nullByte); - nullByte = 0; - nullOffset = output.getLength(); - } - // Allocate next NULL byte. - output.reserve(1); - } - - // Set bit in NULL byte when a field is NOT NULL. - nullByte |= 1 << (fieldIndex % 8); - + beginElement(); if (scratchLongs == null) { scratchLongs = new long[HiveDecimal.SCRATCH_LONGS_LEN]; scratchBuffer = new byte[HiveDecimal.SCRATCH_BUFFER_LEN_BIG_INTEGER_BYTES]; @@ -745,13 +427,7 @@ public void writeHiveDecimal(HiveDecimalWritable decWritable, int scale) throws decWritable, scratchLongs, scratchBuffer); - - fieldIndex++; - - if (fieldIndex == fieldCount) { - // Write back the final NULL byte before the last fields. 
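Note: each hunk in this run deletes the same inlined NULL-byte bookkeeping. What that code computes is per-field presence bits, eight fields per byte; a minimal illustrative sketch (not Hive code):

  // One bit per field; a set bit means the field is NOT null.
  // Ten fields therefore need two NULL bytes.
  static byte[] packNullBytes(Object[] fields) {
    byte[] nullBytes = new byte[(fields.length + 7) / 8];
    for (int i = 0; i < fields.length; i++) {
      if (fields[i] != null) {
        nullBytes[i / 8] |= 1 << (i % 8);
      }
    }
    return nullBytes;
  }

The serializer cannot pack up front because field values arrive one write call at a time, which is why every writeX method carried the reserve/set/patch logic -- exactly what beginElement()/finishElement() centralize below.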
-      output.writeByte(nullOffset, nullByte);
-    }
+    finishElement();
   }
 
   /*
@@ -767,4 +443,239 @@ private void writeVLong(long v) {
     final int len = LazyBinaryUtils.writeVLongToByteArray(vLongBytes, v);
     output.write(vLongBytes, 0, len);
   }
+
+  @Override
+  public void beginList(List list) {
+    ListComplexTypeHelper listHandler = new ListComplexTypeHelper();
+    beginComplex(listHandler);
+
+    int size = list.size();
+    listHandler.fieldCount = size;
+
+    if (!skipLengthPrefix) {
+      // 1/ reserve space for the byte size of the list,
+      // which is an integer and takes four bytes
+      listHandler.byteSizeStart = output.getLength();
+      output.reserve(4);
+      listHandler.listStart = output.getLength();
+    }
+    // 2/ write the size of the list as a VInt
+    LazyBinaryUtils.writeVInt(output, size);
+
+    // 3/ write the null bytes
+    byte nullByte = 0;
+    for (int eid = 0; eid < size; eid++) {
+      // set the bit to 1 if an element is not null
+      if (null != list.get(eid)) {
+        nullByte |= 1 << (eid % 8);
+      }
+      // store the byte every eight elements or
+      // if this is the last element
+      if (7 == eid % 8 || eid == size - 1) {
+        output.write(nullByte);
+        nullByte = 0;
+      }
+    }
+  }
+
+  @Override
+  public void separateList() {
+  }
+
+  @Override
+  public void finishList() {
+    ListComplexTypeHelper listHandler = (ListComplexTypeHelper) complexTypeHelperStack.peek();
+
+    if (!skipLengthPrefix) {
+      // 5/ update the list byte size (step 4/ is the element data,
+      // written between beginList and finishList)
+      int listEnd = output.getLength();
+      int listSize = listEnd - listHandler.listStart;
+      writeSizeAtOffset(output, listHandler.byteSizeStart, listSize);
+    }
+
+    finishComplex();
+  }
+
+  @Override
+  public void beginMap(Map map) {
+    MapComplexTypeHelper mapHelper = new MapComplexTypeHelper();
+    beginComplex(mapHelper);
+
+    if (!skipLengthPrefix) {
+      // 1/ reserve space for the byte size of the map,
+      // which is an integer and takes four bytes
+      mapHelper.byteSizeStart = output.getLength();
+      output.reserve(4);
+      mapHelper.mapStart = output.getLength();
+    }
+
+    // 2/ write the size of the map as a VInt
+    int size = map.size();
+    mapHelper.fieldCount = size;
+    LazyBinaryUtils.writeVInt(output, size);
+
+    // 3/ write the null bytes
+    int b = 0;
+    byte nullByte = 0;
+    for (Map.Entry entry : map.entrySet()) {
+      // set the bit to 1 if a key is not null
+      if (null != entry.getKey()) {
+        nullByte |= 1 << (b % 8);
+      } else if (warnedOnceNullMapKey != null) {
+        if (!warnedOnceNullMapKey.value) {
+          LOG.warn("Null map key encountered! Ignoring similar problems.");
+        }
+        warnedOnceNullMapKey.value = true;
+      }
+      b++;
+      // set the bit to 1 if a value is not null
+      if (null != entry.getValue()) {
+        nullByte |= 1 << (b % 8);
+      }
+      b++;
+      // write the byte to the stream every 4 key-value pairs
+      // or if this is the last key-value pair
+      if (0 == b % 8 || b == size * 2) {
+        output.write(nullByte);
+        nullByte = 0;
+      }
+    }
+  }
+
+  @Override
+  public void separateKey() {
+  }
+
+  @Override
+  public void separateKeyValuePair() {
+  }
+
+  @Override
+  public void finishMap() {
+    MapComplexTypeHelper mapHelper = (MapComplexTypeHelper) complexTypeHelperStack.peek();
+
+    if (!skipLengthPrefix) {
+      // 5/ update the byte size of the map
+      int mapEnd = output.getLength();
+      int mapSize = mapEnd - mapHelper.mapStart;
+      writeSizeAtOffset(output, mapHelper.byteSizeStart, mapSize);
+    }
+
+    finishComplex();
+  }
+
+  @Override
+  public void beginStruct(List fieldValues) {
+    StructComplexTypeHelper structHelper = new StructComplexTypeHelper();
+    beginComplex(structHelper);
+
+    structHelper.fieldCount = fieldValues.size();
+
+    if (!skipLengthPrefix) {
+      // 1/ reserve space for the byte size of the struct,
+      // which is an integer and takes four bytes
+      structHelper.byteSizeStart = output.getLength();
+      output.reserve(4);
+      structHelper.typeStart = output.getLength();
+    }
+    structHelper.nullOffset = output.getLength();
+  }
+
+  @Override
+  public void separateStruct() {
+  }
+
+  @Override
+  public void finishStruct() {
+    StructComplexTypeHelper structHelper = (StructComplexTypeHelper) complexTypeHelperStack.peek();
+
+    if (!skipLengthPrefix) {
+      // 3/ update the byte size of the struct
+      int typeEnd = output.getLength();
+      int typeSize = typeEnd - structHelper.typeStart;
+      writeSizeAtOffset(output, structHelper.byteSizeStart, typeSize);
+    }
+
+    finishComplex();
+  }
+
+  @Override
+  public void beginUnion(int tag) throws IOException {
+    UnionComplexTypeHelper unionHelper = new UnionComplexTypeHelper();
+    beginComplex(unionHelper);
+
+    unionHelper.fieldCount = 1;
+
+    if (!skipLengthPrefix) {
+      // 1/ reserve space for the byte size of the union,
+      // which is an integer and takes four bytes
+      unionHelper.byteSizeStart = output.getLength();
+      output.reserve(4);
+      unionHelper.typeStart = output.getLength();
+    }
+
+    // 2/ write the union tag
+    output.write(tag);
+  }
+
+  @Override
+  public void finishUnion() {
+    UnionComplexTypeHelper unionHelper = (UnionComplexTypeHelper) complexTypeHelperStack.peek();
+
+    if (!skipLengthPrefix) {
+      // 3/ update the byte size of the union
+      int typeEnd = output.getLength();
+      int typeSize = typeEnd - unionHelper.typeStart;
+      writeSizeAtOffset(output, unionHelper.byteSizeStart, typeSize);
+    }
+
+    finishComplex();
+  }
+
+  private void beginElement() {
+    ComplexTypeHelper currentComplexTypeHelper = complexTypeHelperStack.peek();
+
+    if (currentComplexTypeHelper.type == ObjectInspector.Category.STRUCT) {
+      // Every 8 fields we write a NULL byte.
+      if ((currentComplexTypeHelper.fieldIndex % 8) == 0) {
+        if (currentComplexTypeHelper.fieldIndex > 0) {
+          // Write back the previous 8 fields' NULL byte.
+          output.writeByte(currentComplexTypeHelper.nullOffset, currentComplexTypeHelper.nullByte);
+          currentComplexTypeHelper.nullByte = 0;
+          currentComplexTypeHelper.nullOffset = output.getLength();
+        }
+        // Allocate next NULL byte.
+        output.reserve(1);
+      }
+
+      // Set bit in NULL byte when a field is NOT NULL.
+      currentComplexTypeHelper.nullByte |= 1 << (currentComplexTypeHelper.fieldIndex % 8);
+    }
+  }
+
+  private void finishElement() {
+    ComplexTypeHelper currentComplexTypeHelper = complexTypeHelperStack.peek();
+
+    if (currentComplexTypeHelper.type == ObjectInspector.Category.STRUCT) {
+      currentComplexTypeHelper.fieldIndex++;
+
+      if (currentComplexTypeHelper.fieldIndex == currentComplexTypeHelper.fieldCount) {
+        // Write back the NULL byte covering the final fields.
+        output.writeByte(currentComplexTypeHelper.nullOffset, currentComplexTypeHelper.nullByte);
+      }
+    }
+  }
+
+  private void beginComplex(ComplexTypeHelper complexTypeHelper) {
+    beginElement();
+    complexTypeHelperStack.push(complexTypeHelper);
+  }
+
+  private void finishComplex() {
+    complexTypeHelperStack.pop();
+    finishElement();
+  }
+
+  private static void writeSizeAtOffset(
+      ByteStream.RandomAccessOutput byteStream, int byteSizeStart, int size) {
+    byteStream.writeInt(byteSizeStart, size);
+  }
 }
diff --git serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/StandardUnionObjectInspector.java serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/StandardUnionObjectInspector.java
index f26c9ec69b..7b2868233f 100644
--- serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/StandardUnionObjectInspector.java
+++ serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/StandardUnionObjectInspector.java
@@ -79,6 +79,31 @@ public byte getTag() {
     public String toString() {
       return tag + ":" + object;
     }
+
+    @Override
+    public int hashCode() {
+      if (object == null) {
+        return tag;
+      } else {
+        return object.hashCode() ^ tag;
+      }
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+      if (this == obj) {
+        return true;
+      }
+      if (!(obj instanceof StandardUnion)) {
+        return false;
+      }
+      StandardUnion that = (StandardUnion) obj;
+      if (this.object == null || that.object == null) {
+        // Equal only when both objects are null and the tags match.
+        return this.tag == that.tag && this.object == that.object;
+      } else {
+        return this.tag == that.tag && this.object.equals(that.object);
+      }
+    }
   }
 
 /**
diff --git serde/src/test/org/apache/hadoop/hive/serde2/SerdeRandomRowSource.java serde/src/test/org/apache/hadoop/hive/serde2/SerdeRandomRowSource.java
index 301ee8b344..a630e7a04e 100644
--- serde/src/test/org/apache/hadoop/hive/serde2/SerdeRandomRowSource.java
+++ serde/src/test/org/apache/hadoop/hive/serde2/SerdeRandomRowSource.java
@@ -25,19 +25,29 @@
 import java.util.List;
 import java.util.Random;
 
-import org.apache.commons.lang.ArrayUtils;
-import org.apache.commons.lang.StringUtils;
 import org.apache.hadoop.hive.common.type.HiveChar;
 import org.apache.hadoop.hive.common.type.HiveDecimal;
 import org.apache.hadoop.hive.common.type.HiveIntervalDayTime;
 import org.apache.hadoop.hive.common.type.HiveIntervalYearMonth;
 import org.apache.hadoop.hive.common.type.HiveVarchar;
 import org.apache.hadoop.hive.common.type.RandomTypeUtil;
+import org.apache.hadoop.hive.serde2.io.HiveCharWritable;
+import org.apache.hadoop.hive.serde2.io.HiveVarcharWritable;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
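Note: beginList/finishList, finishMap, finishStruct, and finishUnion above all share one framing trick -- reserve a four-byte size slot, write the payload, then patch the slot via writeSizeAtOffset. A self-contained sketch, with a plain ByteBuffer standing in for Hive's ByteStream.Output (step numbers mirror the 1/-5/ comments in the patch; illustrative, not Hive code):

  import java.nio.ByteBuffer;

  public class LengthPrefixDemo {
    public static void main(String[] args) {
      ByteBuffer out = ByteBuffer.allocate(64);
      int byteSizeStart = out.position();   // 1/ reserve four bytes for the size
      out.position(byteSizeStart + 4);
      int listStart = out.position();
      out.put((byte) 3);                    // 2/ element count (a VInt in Hive)
      out.put((byte) 0b00000111);           // 3/ null byte: all three elements present
      out.put(new byte[] {10, 20, 30});     // 4/ the element data itself
      int listSize = out.position() - listStart;
      out.putInt(byteSizeStart, listSize);  // 5/ patch the size slot afterwards
      System.out.println("framed " + listSize + " payload bytes");
    }
  }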
+import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StandardStructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StandardUnionObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StandardUnionObjectInspector.StandardUnion; +import org.apache.hadoop.hive.serde2.objectinspector.StructField; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.UnionObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableBooleanObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableByteObjectInspector; @@ -56,10 +66,20 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableTimestampObjectInspector; import org.apache.hadoop.hive.serde2.typeinfo.CharTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils; +import org.apache.hadoop.hive.serde2.typeinfo.UnionTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo; import org.apache.hive.common.util.DateUtils; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.io.BytesWritable; + +import com.google.common.base.Preconditions; +import com.google.common.base.Charsets; /** * Generate object inspector and random row object[]. @@ -72,6 +92,14 @@ private List typeNames; + private Category[] categories; + + private TypeInfo[] typeInfos; + + private List objectInspectorList; + + // Primitive. 
+ private PrimitiveCategory[] primitiveCategories; private PrimitiveTypeInfo[] primitiveTypeInfos; @@ -80,10 +108,25 @@ private StructObjectInspector rowStructObjectInspector; + private String[] alphabets; + + private boolean allowNull; + + private boolean addEscapables; + private String needsEscapeStr; + public List typeNames() { return typeNames; } + public Category[] categories() { + return categories; + } + + public TypeInfo[] typeInfos() { + return typeInfos; + } + public PrimitiveCategory[] primitiveCategories() { return primitiveCategories; } @@ -97,30 +140,37 @@ public StructObjectInspector rowStructObjectInspector() { } public StructObjectInspector partialRowStructObjectInspector(int partialFieldCount) { - ArrayList partialPrimitiveObjectInspectorList = + ArrayList partialObjectInspectorList = new ArrayList(partialFieldCount); List columnNames = new ArrayList(partialFieldCount); for (int i = 0; i < partialFieldCount; i++) { columnNames.add(String.format("partial%d", i)); - partialPrimitiveObjectInspectorList.add( - PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector( - primitiveTypeInfos[i])); + partialObjectInspectorList.add(getObjectInspector(typeInfos[i])); } return ObjectInspectorFactory.getStandardStructObjectInspector( - columnNames, primitiveObjectInspectorList); + columnNames, objectInspectorList); + } + + public enum SupportedTypes { + ALL, PRIMITIVE, ALL_EXCEPT_MAP + } + + public void init(Random r, SupportedTypes supportedTypes, int maxComplexDepth) { + init(r, supportedTypes, maxComplexDepth, true); } - public void init(Random r) { + public void init(Random r, SupportedTypes supportedTypes, int maxComplexDepth, boolean allowNull) { this.r = r; - chooseSchema(); + this.allowNull = allowNull; + chooseSchema(supportedTypes, maxComplexDepth); } /* * For now, exclude CHAR until we determine why there is a difference (blank padding) * serializing with LazyBinarySerializeWrite and the regular SerDe... 
   */
-  private static String[] possibleHiveTypeNames = {
+  private static String[] possibleHivePrimitiveTypeNames = {
     "boolean",
     "tinyint",
     "smallint",
@@ -140,7 +190,158 @@ public void init(Random r) {
     "decimal"
   };
 
-  private void chooseSchema() {
+  private static String[] possibleHiveComplexTypeNames = {
+    "array",
+    "struct",
+    "uniontype",
+    // NOTE: "map" must stay the last entry -- ALL_EXCEPT_MAP below excludes
+    // it by drawing from length - 1.
+    "map"
+  };
+
+  private String getRandomTypeName(SupportedTypes supportedTypes) {
+    String typeName = null;
+    if (r.nextInt(10) != 0) {
+      typeName = possibleHivePrimitiveTypeNames[r.nextInt(possibleHivePrimitiveTypeNames.length)];
+    } else {
+      switch (supportedTypes) {
+      case PRIMITIVE:
+        typeName = possibleHivePrimitiveTypeNames[r.nextInt(possibleHivePrimitiveTypeNames.length)];
+        break;
+      case ALL_EXCEPT_MAP:
+        typeName = possibleHiveComplexTypeNames[r.nextInt(possibleHiveComplexTypeNames.length - 1)];
+        break;
+      case ALL:
+        typeName = possibleHiveComplexTypeNames[r.nextInt(possibleHiveComplexTypeNames.length)];
+        break;
+      }
+    }
+    return typeName;
+  }
+
+  private String getDecoratedTypeName(String typeName, SupportedTypes supportedTypes, int depth, int maxDepth) {
+    depth++;
+    if (depth < maxDepth) {
+      supportedTypes = SupportedTypes.PRIMITIVE;
+    }
+    if (typeName.equals("char")) {
+      int maxLength = 1 + r.nextInt(100);
+      typeName = String.format("char(%d)", maxLength);
+    } else if (typeName.equals("varchar")) {
+      int maxLength = 1 + r.nextInt(100);
+      typeName = String.format("varchar(%d)", maxLength);
+    } else if (typeName.equals("decimal")) {
+      typeName = String.format("decimal(%d,%d)", HiveDecimal.SYSTEM_DEFAULT_PRECISION, HiveDecimal.SYSTEM_DEFAULT_SCALE);
+    } else if (typeName.equals("array")) {
+      String elementTypeName = getRandomTypeName(supportedTypes);
+      elementTypeName = getDecoratedTypeName(elementTypeName, supportedTypes, depth, maxDepth);
+      typeName = String.format("array<%s>", elementTypeName);
+    } else if (typeName.equals("map")) {
+      String keyTypeName = getRandomTypeName(SupportedTypes.PRIMITIVE);
+      keyTypeName = getDecoratedTypeName(keyTypeName, supportedTypes, depth, maxDepth);
+      String valueTypeName = getRandomTypeName(supportedTypes);
+      valueTypeName = getDecoratedTypeName(valueTypeName, supportedTypes, depth, maxDepth);
+      typeName = String.format("map<%s,%s>", keyTypeName, valueTypeName);
+    } else if (typeName.equals("struct")) {
+      final int fieldCount = 1 + r.nextInt(10);
+      StringBuilder sb = new StringBuilder();
+      for (int i = 0; i < fieldCount; i++) {
+        String fieldTypeName = getRandomTypeName(supportedTypes);
+        fieldTypeName = getDecoratedTypeName(fieldTypeName, supportedTypes, depth, maxDepth);
+        if (i > 0) {
+          sb.append(",");
+        }
+        sb.append("col");
+        sb.append(i);
+        sb.append(":");
+        sb.append(fieldTypeName);
+      }
+      typeName = String.format("struct<%s>", sb.toString());
+    } else if (typeName.equals("uniontype")) {
+      final int fieldCount = 1 + r.nextInt(10);
+      StringBuilder sb = new StringBuilder();
+      for (int i = 0; i < fieldCount; i++) {
+        String fieldTypeName = getRandomTypeName(supportedTypes);
+        fieldTypeName = getDecoratedTypeName(fieldTypeName, supportedTypes, depth, maxDepth);
+        if (i > 0) {
+          sb.append(",");
+        }
+        sb.append(fieldTypeName);
+      }
+      typeName = String.format("uniontype<%s>", sb.toString());
+    }
+    return typeName;
+  }
+
+  private ObjectInspector getObjectInspector(TypeInfo typeInfo) {
+    ObjectInspector objectInspector;
+    switch (typeInfo.getCategory()) {
+    case PRIMITIVE:
+      {
+        PrimitiveTypeInfo primitiveType = (PrimitiveTypeInfo) typeInfo;
+        objectInspector =
+
PrimitiveObjectInspectorFactory. + getPrimitiveWritableObjectInspector(primitiveType); + } + break; + case MAP: + { + MapTypeInfo mapType = (MapTypeInfo) typeInfo; + MapObjectInspector mapInspector = + ObjectInspectorFactory.getStandardMapObjectInspector( + getObjectInspector(mapType.getMapKeyTypeInfo()), + getObjectInspector(mapType.getMapValueTypeInfo())); + objectInspector = mapInspector; + } + break; + case LIST: + { + ListTypeInfo listType = (ListTypeInfo) typeInfo; + ListObjectInspector listInspector = + ObjectInspectorFactory.getStandardListObjectInspector( + getObjectInspector(listType.getListElementTypeInfo())); + objectInspector = listInspector; + } + break; + case STRUCT: + { + StructTypeInfo structType = (StructTypeInfo) typeInfo; + List fieldTypes = structType.getAllStructFieldTypeInfos(); + + List fieldInspectors = new ArrayList(); + for (TypeInfo fieldType : fieldTypes) { + fieldInspectors.add(getObjectInspector(fieldType)); + } + + StructObjectInspector structInspector = + ObjectInspectorFactory.getStandardStructObjectInspector( + structType.getAllStructFieldNames(), fieldInspectors); + objectInspector = structInspector; + } + break; + case UNION: + { + UnionTypeInfo unionType = (UnionTypeInfo) typeInfo; + List fieldTypes = unionType.getAllUnionObjectTypeInfos(); + + List fieldInspectors = new ArrayList(); + for (TypeInfo fieldType : fieldTypes) { + fieldInspectors.add(getObjectInspector(fieldType)); + } + + UnionObjectInspector unionInspector = + ObjectInspectorFactory.getStandardUnionObjectInspector( + fieldInspectors); + objectInspector = unionInspector; + } + break; + default: + throw new RuntimeException("Unexpected category " + typeInfo.getCategory()); + } + Preconditions.checkState(objectInspector != null); + return objectInspector; + } + + private void chooseSchema(SupportedTypes supportedTypes, int maxComplexDepth) { HashSet hashSet = null; boolean allTypes; boolean onlyOne = (r.nextInt(100) == 7); @@ -150,14 +351,27 @@ private void chooseSchema() { } else { allTypes = r.nextBoolean(); if (allTypes) { - // One of each type. 
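Note: a hypothetical in-class use of the recursive getObjectInspector just defined -- the inspector tree is built bottom-up to mirror the TypeInfo tree:

  TypeInfo typeInfo = TypeInfoUtils.getTypeInfoFromTypeString("map<string,array<int>>");
  ObjectInspector oi = getObjectInspector(typeInfo);
  // oi is a standard map inspector whose key inspector is the writable string
  // inspector and whose value inspector is a standard list inspector over the
  // writable int inspector.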
-      columnCount = possibleHiveTypeNames.length;
+      switch (supportedTypes) {
+      case ALL:
+        columnCount = possibleHivePrimitiveTypeNames.length + possibleHiveComplexTypeNames.length;
+        break;
+      case ALL_EXCEPT_MAP:
+        columnCount = possibleHivePrimitiveTypeNames.length + possibleHiveComplexTypeNames.length - 1;
+        break;
+      case PRIMITIVE:
+        columnCount = possibleHivePrimitiveTypeNames.length;
+        break;
+      }
       hashSet = new HashSet();
     } else {
       columnCount = 1 + r.nextInt(20);
     }
   }
   typeNames = new ArrayList(columnCount);
+  categories = new Category[columnCount];
+  typeInfos = new TypeInfo[columnCount];
+  objectInspectorList = new ArrayList(columnCount);
+
   primitiveCategories = new PrimitiveCategory[columnCount];
   primitiveTypeInfos = new PrimitiveTypeInfo[columnCount];
   primitiveObjectInspectorList = new ArrayList(columnCount);
@@ -167,12 +381,26 @@ private void chooseSchema() {
     String typeName;
     if (onlyOne) {
-      typeName = possibleHiveTypeNames[r.nextInt(possibleHiveTypeNames.length)];
+      typeName = getRandomTypeName(supportedTypes);
     } else {
       int typeNum;
       if (allTypes) {
+        int maxTypeNum = 0;
+        switch (supportedTypes) {
+        case PRIMITIVE:
+          maxTypeNum = possibleHivePrimitiveTypeNames.length;
+          break;
+        case ALL_EXCEPT_MAP:
+          maxTypeNum = possibleHivePrimitiveTypeNames.length + possibleHiveComplexTypeNames.length - 1;
+          break;
+        case ALL:
+          maxTypeNum = possibleHivePrimitiveTypeNames.length + possibleHiveComplexTypeNames.length;
+          break;
+        }
         while (true) {
-          typeNum = r.nextInt(possibleHiveTypeNames.length);
+          typeNum = r.nextInt(maxTypeNum);
+
           Integer typeNumInteger = new Integer(typeNum);
           if (!hashSet.contains(typeNumInteger)) {
             hashSet.add(typeNumInteger);
@@ -180,27 +408,64 @@ private void chooseSchema() {
           }
         }
       } else {
-        typeNum = r.nextInt(possibleHiveTypeNames.length);
+        if (supportedTypes == SupportedTypes.PRIMITIVE || r.nextInt(10) != 0) {
+          typeNum = r.nextInt(possibleHivePrimitiveTypeNames.length);
+        } else {
+          typeNum = possibleHivePrimitiveTypeNames.length + r.nextInt(possibleHiveComplexTypeNames.length);
+          if (supportedTypes == SupportedTypes.ALL_EXCEPT_MAP) {
+            typeNum--;
+          }
+        }
+      }
+      if (typeNum < possibleHivePrimitiveTypeNames.length) {
+        typeName = possibleHivePrimitiveTypeNames[typeNum];
+      } else {
+        typeName = possibleHiveComplexTypeNames[typeNum - possibleHivePrimitiveTypeNames.length];
+      }
+    }
+
+    String decoratedTypeName = getDecoratedTypeName(typeName, supportedTypes, 0, maxComplexDepth);
+
+    TypeInfo typeInfo;
+    try {
+      typeInfo = TypeInfoUtils.getTypeInfoFromTypeString(decoratedTypeName);
+    } catch (Exception e) {
+      throw new RuntimeException("Cannot convert type name " + decoratedTypeName + " to a type: " + e);
+    }
+
+    typeInfos[c] = typeInfo;
+    Category category = typeInfo.getCategory();
+    categories[c] = category;
+    ObjectInspector objectInspector = getObjectInspector(typeInfo);
+    switch (category) {
+    case PRIMITIVE:
+      {
+        PrimitiveTypeInfo primitiveTypeInfo = (PrimitiveTypeInfo) typeInfo;
+        objectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(primitiveTypeInfo);
+        primitiveTypeInfos[c] = primitiveTypeInfo;
+        PrimitiveCategory primitiveCategory = primitiveTypeInfo.getPrimitiveCategory();
+        primitiveCategories[c] = primitiveCategory;
+        primitiveObjectInspectorList.add(objectInspector);
+      }
-      typeName = possibleHiveTypeNames[typeNum];
+      break;
+    case LIST:
+    case MAP:
+    case STRUCT:
+    case UNION:
+      primitiveObjectInspectorList.add(null);
+      break;
+    default:
+      throw new RuntimeException("Unexpected category " + category);
+    }
-      if (typeName.equals("char")) {
-        int maxLength = 1 + r.nextInt(100);
-        typeName = String.format("char(%d)", maxLength);
-      } else if (typeName.equals("varchar")) {
-        int maxLength = 1 + r.nextInt(100);
-        typeName = String.format("varchar(%d)", maxLength);
-      } else if (typeName.equals("decimal")) {
-        typeName = String.format("decimal(%d,%d)", HiveDecimal.SYSTEM_DEFAULT_PRECISION, HiveDecimal.SYSTEM_DEFAULT_SCALE);
-      }
+      objectInspectorList.add(objectInspector);
+
-      PrimitiveTypeInfo primitiveTypeInfo = (PrimitiveTypeInfo) TypeInfoUtils.getTypeInfoFromTypeString(typeName);
-      primitiveTypeInfos[c] = primitiveTypeInfo;
-      PrimitiveCategory primitiveCategory = primitiveTypeInfo.getPrimitiveCategory();
-      primitiveCategories[c] = primitiveCategory;
-      primitiveObjectInspectorList.add(PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(primitiveTypeInfo));
-      typeNames.add(typeName);
+      typeNames.add(decoratedTypeName);
     }
-    rowStructObjectInspector = ObjectInspectorFactory.getStandardStructObjectInspector(columnNames, primitiveObjectInspectorList);
+    rowStructObjectInspector = ObjectInspectorFactory.getStandardStructObjectInspector(columnNames, objectInspectorList);
+    alphabets = new String[columnCount];
   }
 
   public Object[][] randomRows(int n) {
@@ -214,18 +479,65 @@ private void chooseSchema() {
 
   public Object[] randomRow() {
     Object row[] = new Object[columnCount];
     for (int c = 0; c < columnCount; c++) {
-      Object object = randomObject(c);
-      if (object == null) {
-        throw new Error("Unexpected null for column " + c);
-      }
-      row[c] = getWritableObject(c, object);
-      if (row[c] == null) {
-        throw new Error("Unexpected null for writable for column " + c);
-      }
+      row[c] = randomWritable(c);
+    }
+    return row;
+  }
+
+  public Object[] randomPrimitiveRow(int columnCount) {
+    return randomPrimitiveRow(columnCount, r, primitiveTypeInfos);
+  }
+
+  public static Object[] randomPrimitiveRow(int columnCount, Random r,
+      PrimitiveTypeInfo[] primitiveTypeInfos) {
+    Object row[] = new Object[columnCount];
+    for (int c = 0; c < columnCount; c++) {
+      row[c] = randomPrimitiveObject(r, primitiveTypeInfos[c]);
+    }
+    return row;
+  }
+
+  public static Object[] randomWritablePrimitiveRow(int columnCount, Random r,
+      PrimitiveTypeInfo[] primitiveTypeInfos) {
+    Object row[] = new Object[columnCount];
+    for (int c = 0; c < columnCount; c++) {
+      PrimitiveTypeInfo primitiveTypeInfo = primitiveTypeInfos[c];
+      ObjectInspector objectInspector =
+          PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(primitiveTypeInfo);
+      Object object = randomPrimitiveObject(r, primitiveTypeInfo);
+      row[c] = getWritablePrimitiveObject(primitiveTypeInfo, objectInspector, object);
     }
     return row;
   }
 
+  public void addBinarySortableAlphabets() {
+    for (int c = 0; c < columnCount; c++) {
+      switch (primitiveCategories[c]) {
+      case STRING:
+      case CHAR:
+      case VARCHAR:
+        byte[] bytes = new byte[10 + r.nextInt(10)];
+        for (int i = 0; i < bytes.length; i++) {
+          bytes[i] = (byte) (32 + r.nextInt(96));
+        }
+        int alwaysIndex = r.nextInt(bytes.length);
+        bytes[alwaysIndex] = 0;  // Must be escaped by BinarySortable.
+        int alwaysIndex2 = r.nextInt(bytes.length);
+        bytes[alwaysIndex2] = 1;  // Must be escaped by BinarySortable.
+        alphabets[c] = new String(bytes, Charsets.UTF_8);
+        break;
+      default:
+        // No alphabet needed.
+ break; + } + } + } + + public void addEscapables(String needsEscapeStr) { + addEscapables = true; + this.needsEscapeStr = needsEscapeStr; + } + public static void sort(Object[][] rows, ObjectInspector oi) { for (int i = 0; i < rows.length; i++) { for (int j = i + 1; j < rows.length; j++) { @@ -242,11 +554,9 @@ public void sort(Object[][] rows) { SerdeRandomRowSource.sort(rows, rowStructObjectInspector); } - public Object getWritableObject(int column, Object object) { - ObjectInspector objectInspector = primitiveObjectInspectorList.get(column); - PrimitiveCategory primitiveCategory = primitiveCategories[column]; - PrimitiveTypeInfo primitiveTypeInfo = primitiveTypeInfos[column]; - switch (primitiveCategory) { + public static Object getWritablePrimitiveObject(PrimitiveTypeInfo primitiveTypeInfo, + ObjectInspector objectInspector, Object object) { + switch (primitiveTypeInfo.getPrimitiveCategory()) { case BOOLEAN: return ((WritableBooleanObjectInspector) objectInspector).create((boolean) object); case BYTE: @@ -292,16 +602,166 @@ public Object getWritableObject(int column, Object object) { return writableDecimalObjectInspector.create((HiveDecimal) object); } default: - throw new Error("Unknown primitive category " + primitiveCategory); + throw new Error("Unknown primitive category " + primitiveTypeInfo.getPrimitiveCategory()); + } + } + + public Object randomWritable(int column) { + return randomWritable(typeInfos[column], objectInspectorList.get(column)); + } + + public Object randomWritable(TypeInfo typeInfo, ObjectInspector objectInspector) { + return randomWritable(typeInfo, objectInspector, allowNull); + } + + public Object randomWritable(TypeInfo typeInfo, ObjectInspector objectInspector, boolean allowNull) { + switch (typeInfo.getCategory()) { + case PRIMITIVE: + { + Object object = randomPrimitiveObject(r, (PrimitiveTypeInfo) typeInfo); + return getWritablePrimitiveObject((PrimitiveTypeInfo) typeInfo, objectInspector, object); + } + case LIST: + { + if (allowNull && r.nextInt(20) == 0) { + return null; + } + // Always generate a list with at least 1 value? + final int elementCount = 1 + r.nextInt(100); + StandardListObjectInspector listObjectInspector = + (StandardListObjectInspector) objectInspector; + ObjectInspector elementObjectInspector = + listObjectInspector.getListElementObjectInspector(); + TypeInfo elementTypeInfo = + TypeInfoUtils.getTypeInfoFromObjectInspector( + elementObjectInspector); + boolean isStringFamily = false; + PrimitiveCategory primitiveCategory = null; + if (elementTypeInfo.getCategory() == Category.PRIMITIVE) { + primitiveCategory = ((PrimitiveTypeInfo) elementTypeInfo).getPrimitiveCategory(); + if (primitiveCategory == PrimitiveCategory.STRING || + primitiveCategory == PrimitiveCategory.BINARY || + primitiveCategory == PrimitiveCategory.CHAR || + primitiveCategory == PrimitiveCategory.VARCHAR) { + isStringFamily = true; + } + } + Object listObj = listObjectInspector.create(elementCount); + for (int i = 0; i < elementCount; i++) { + Object ele = randomWritable(elementTypeInfo, elementObjectInspector, allowNull); + // UNDONE: For now, a 1-element list with a null element is a null list... 
+            if (ele == null && elementCount == 1) {
+              return null;
+            }
+            if (isStringFamily && elementCount == 1) {
+              switch (primitiveCategory) {
+              case STRING:
+                if (((Text) ele).getLength() == 0) {
+                  return null;
+                }
+                break;
+              case BINARY:
+                if (((BytesWritable) ele).getLength() == 0) {
+                  return null;
+                }
+                break;
+              case CHAR:
+                if (((HiveCharWritable) ele).getHiveChar().getStrippedValue().isEmpty()) {
+                  return null;
+                }
+                break;
+              case VARCHAR:
+                if (((HiveVarcharWritable) ele).getHiveVarchar().getValue().isEmpty()) {
+                  return null;
+                }
+                break;
+              default:
+                throw new RuntimeException("Unexpected primitive category " + primitiveCategory);
+              }
+            }
+            listObjectInspector.set(listObj, i, ele);
+          }
+          return listObj;
+        }
+      case MAP:
+        {
+          if (allowNull && r.nextInt(20) == 0) {
+            return null;
+          }
+          final int keyPairCount = r.nextInt(100);
+          StandardMapObjectInspector mapObjectInspector =
+              (StandardMapObjectInspector) objectInspector;
+          ObjectInspector keyObjectInspector =
+              mapObjectInspector.getMapKeyObjectInspector();
+          TypeInfo keyTypeInfo =
+              TypeInfoUtils.getTypeInfoFromObjectInspector(
+                  keyObjectInspector);
+          ObjectInspector valueObjectInspector =
+              mapObjectInspector.getMapValueObjectInspector();
+          TypeInfo valueTypeInfo =
+              TypeInfoUtils.getTypeInfoFromObjectInspector(
+                  valueObjectInspector);
+          Object mapObj = mapObjectInspector.create();
+          for (int i = 0; i < keyPairCount; i++) {
+            Object key = randomWritable(keyTypeInfo, keyObjectInspector);
+            Object value = randomWritable(valueTypeInfo, valueObjectInspector);
+            mapObjectInspector.put(mapObj, key, value);
+          }
+          return mapObj;
+        }
+      case STRUCT:
+        {
+          if (allowNull && r.nextInt(20) == 0) {
+            return null;
+          }
+          StandardStructObjectInspector structObjectInspector =
+              (StandardStructObjectInspector) objectInspector;
+          List fieldRefsList = structObjectInspector.getAllStructFieldRefs();
+          final int fieldCount = fieldRefsList.size();
+          Object structObj = structObjectInspector.create();
+          for (int i = 0; i < fieldCount; i++) {
+            StructField fieldRef = fieldRefsList.get(i);
+            ObjectInspector fieldObjectInspector =
+                fieldRef.getFieldObjectInspector();
+            TypeInfo fieldTypeInfo =
+                TypeInfoUtils.getTypeInfoFromObjectInspector(
+                    fieldObjectInspector);
+            Object fieldObj = randomWritable(fieldTypeInfo, fieldObjectInspector);
+            structObjectInspector.setStructFieldData(structObj, fieldRef, fieldObj);
+          }
+          return structObj;
+        }
+      case UNION:
+        {
+          StandardUnionObjectInspector unionObjectInspector =
+              (StandardUnionObjectInspector) objectInspector;
+          List objectInspectorList = unionObjectInspector.getObjectInspectors();
+          final int unionCount = objectInspectorList.size();
+          final byte tag = (byte) r.nextInt(unionCount);
+          ObjectInspector fieldObjectInspector =
+              objectInspectorList.get(tag);
+          TypeInfo fieldTypeInfo =
+              TypeInfoUtils.getTypeInfoFromObjectInspector(
+                  fieldObjectInspector);
+          Object fieldObj = randomWritable(fieldTypeInfo, fieldObjectInspector, false);
+          if (fieldObj == null) {
+            throw new RuntimeException("Encountered unexpected null union field object");
+          }
+          return new StandardUnion(tag, fieldObj);
+        }
+      default:
+        throw new RuntimeException("Unexpected category " + typeInfo.getCategory());
+      }
+  }
 
-  public Object randomObject(int column) {
-    PrimitiveCategory primitiveCategory = primitiveCategories[column];
-    PrimitiveTypeInfo primitiveTypeInfo = primitiveTypeInfos[column];
-    switch (primitiveCategory) {
+  public Object randomPrimitiveObject(int column) {
+    return randomPrimitiveObject(r, primitiveTypeInfos[column]);
+  }
+
+  public static Object randomPrimitiveObject(Random r,
+      PrimitiveTypeInfo primitiveTypeInfo) {
+    switch (primitiveTypeInfo.getPrimitiveCategory()) {
     case BOOLEAN:
-      return Boolean.valueOf(r.nextInt(1) == 1);
+      return Boolean.valueOf(r.nextBoolean());
     case BYTE:
       return Byte.valueOf((byte) r.nextInt());
     case SHORT:
@@ -336,7 +796,7 @@ public Object randomObject(int column) {
       return dec;
     }
     default:
-      throw new Error("Unknown primitive category " + primitiveCategory);
+      throw new Error("Unknown primitive category " + primitiveTypeInfo.getPrimitiveCategory());
     }
   }
 
@@ -347,13 +807,17 @@ public static HiveChar getRandHiveChar(Random r, CharTypeInfo charTypeInfo) {
     return hiveChar;
   }
 
-  public static HiveVarchar getRandHiveVarchar(Random r, VarcharTypeInfo varcharTypeInfo) {
+  public static HiveVarchar getRandHiveVarchar(Random r, VarcharTypeInfo varcharTypeInfo, String alphabet) {
     int maxLength = 1 + r.nextInt(varcharTypeInfo.getLength());
-    String randomString = RandomTypeUtil.getRandString(r, "abcdefghijklmnopqrstuvwxyz", 100);
+    String randomString = RandomTypeUtil.getRandString(r, alphabet, 100);
    HiveVarchar hiveVarchar = new HiveVarchar(randomString, maxLength);
     return hiveVarchar;
   }
 
+  public static HiveVarchar getRandHiveVarchar(Random r, VarcharTypeInfo varcharTypeInfo) {
+    return getRandHiveVarchar(r, varcharTypeInfo, "abcdefghijklmnopqrstuvwxyz");
+  }
+
   public static byte[] getRandBinary(Random r, int len){
     byte[] bytes = new byte[len];
     for (int j = 0; j < len; j++){
diff --git serde/src/test/org/apache/hadoop/hive/serde2/VerifyFast.java serde/src/test/org/apache/hadoop/hive/serde2/VerifyFast.java
index 19b04bb66e..6a69290340 100644
--- serde/src/test/org/apache/hadoop/hive/serde2/VerifyFast.java
+++ serde/src/test/org/apache/hadoop/hive/serde2/VerifyFast.java
@@ -18,9 +18,14 @@ package org.apache.hadoop.hive.serde2;
 
 import java.io.IOException;
+import java.nio.charset.StandardCharsets;
 import java.sql.Date;
 import java.sql.Timestamp;
+import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map.Entry;
 
 import junit.framework.TestCase;
 
@@ -30,7 +35,7 @@ import org.apache.hadoop.hive.common.type.HiveIntervalYearMonth;
 import org.apache.hadoop.hive.common.type.HiveVarchar;
 import org.apache.hadoop.hive.serde2.fast.DeserializeRead;
-import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory;
+import org.apache.hadoop.hive.serde2.objectinspector.StandardUnionObjectInspector.StandardUnion;
 import org.apache.hadoop.hive.serde2.fast.SerializeWrite;
 import org.apache.hadoop.hive.serde2.io.ByteWritable;
 import org.apache.hadoop.hive.serde2.io.DateWritable;
@@ -44,7 +49,13 @@ import org.apache.hadoop.hive.serde2.io.TimestampWritable;
 import org.apache.hadoop.hive.serde2.typeinfo.CharTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
+import org.apache.hadoop.hive.serde2.typeinfo.UnionTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo;
 import org.apache.hadoop.io.BooleanWritable;
 import org.apache.hadoop.io.BytesWritable;
@@ -52,7 +63,6 @@ import org.apache.hadoop.io.IntWritable;
 import org.apache.hadoop.io.LongWritable;
 import org.apache.hadoop.io.Text;
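Note: with the entry points now keyed by TypeInfo, a primitive round trip through VerifyFast looks like the hedged sketch below; the sw/dr pair is an assumption (any matching SerializeWrite/DeserializeRead implementations bound to the same buffer, e.g. the LazyBinary pair). Complex categories are serialized recursively by serializeWrite, but the verify switch later in this file deliberately routes them elsewhere.

  // Assumed: sw and dr are a configured, matching serializer/deserializer pair.
  static void roundTripInt(SerializeWrite sw, DeserializeRead dr) throws IOException {
    TypeInfo typeInfo = TypeInfoFactory.intTypeInfo;
    Object writable = new IntWritable(42);
    VerifyFast.serializeWrite(sw, typeInfo, writable);        // writes one field
    VerifyFast.verifyDeserializeRead(dr, typeInfo, writable); // TestCase.fail on mismatch
  }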
-import org.apache.hadoop.io.Writable; /** * TestBinarySortableSerDe. @@ -61,338 +71,638 @@ public class VerifyFast { public static void verifyDeserializeRead(DeserializeRead deserializeRead, - PrimitiveTypeInfo primitiveTypeInfo, Writable writable) throws IOException { + TypeInfo typeInfo, Object object) throws IOException { boolean isNull; isNull = !deserializeRead.readNextField(); + doVerifyDeserializeRead(deserializeRead, typeInfo, object, isNull); + } + + public static void doVerifyDeserializeRead(DeserializeRead deserializeRead, + TypeInfo typeInfo, Object object, boolean isNull) throws IOException { if (isNull) { - if (writable != null) { - TestCase.fail("Field reports null but object is not null (class " + writable.getClass().getName() + ", " + writable.toString() + ")"); + if (object != null) { + TestCase.fail("Field reports null but object is not null (class " + object.getClass().getName() + ", " + object.toString() + ")"); } return; - } else if (writable == null) { + } else if (object == null) { TestCase.fail("Field report not null but object is null"); } - switch (primitiveTypeInfo.getPrimitiveCategory()) { - case BOOLEAN: - { - boolean value = deserializeRead.currentBoolean; - if (!(writable instanceof BooleanWritable)) { - TestCase.fail("Boolean expected writable not Boolean"); - } - boolean expected = ((BooleanWritable) writable).get(); - if (value != expected) { - TestCase.fail("Boolean field mismatch (expected " + expected + " found " + value + ")"); - } - } - break; - case BYTE: - { - byte value = deserializeRead.currentByte; - if (!(writable instanceof ByteWritable)) { - TestCase.fail("Byte expected writable not Byte"); - } - byte expected = ((ByteWritable) writable).get(); - if (value != expected) { - TestCase.fail("Byte field mismatch (expected " + (int) expected + " found " + (int) value + ")"); - } - } - break; - case SHORT: - { - short value = deserializeRead.currentShort; - if (!(writable instanceof ShortWritable)) { - TestCase.fail("Short expected writable not Short"); - } - short expected = ((ShortWritable) writable).get(); - if (value != expected) { - TestCase.fail("Short field mismatch (expected " + expected + " found " + value + ")"); - } - } - break; - case INT: - { - int value = deserializeRead.currentInt; - if (!(writable instanceof IntWritable)) { - TestCase.fail("Integer expected writable not Integer"); - } - int expected = ((IntWritable) writable).get(); - if (value != expected) { - TestCase.fail("Int field mismatch (expected " + expected + " found " + value + ")"); - } - } - break; - case LONG: - { - long value = deserializeRead.currentLong; - if (!(writable instanceof LongWritable)) { - TestCase.fail("Long expected writable not Long"); - } - Long expected = ((LongWritable) writable).get(); - if (value != expected) { - TestCase.fail("Long field mismatch (expected " + expected + " found " + value + ")"); - } - } - break; - case FLOAT: - { - float value = deserializeRead.currentFloat; - if (!(writable instanceof FloatWritable)) { - TestCase.fail("Float expected writable not Float"); - } - float expected = ((FloatWritable) writable).get(); - if (value != expected) { - TestCase.fail("Float field mismatch (expected " + expected + " found " + value + ")"); - } - } - break; - case DOUBLE: - { - double value = deserializeRead.currentDouble; - if (!(writable instanceof DoubleWritable)) { - TestCase.fail("Double expected writable not Double"); - } - double expected = ((DoubleWritable) writable).get(); - if (value != expected) { - TestCase.fail("Double field 
mismatch (expected " + expected + " found " + value + ")"); - } - } - break; - case STRING: - { - byte[] stringBytes = Arrays.copyOfRange( - deserializeRead.currentBytes, - deserializeRead.currentBytesStart, - deserializeRead.currentBytesStart + deserializeRead.currentBytesLength); - Text text = new Text(stringBytes); - String string = text.toString(); - String expected = ((Text) writable).toString(); - if (!string.equals(expected)) { - TestCase.fail("String field mismatch (expected '" + expected + "' found '" + string + "')"); - } - } - break; - case CHAR: - { - byte[] stringBytes = Arrays.copyOfRange( - deserializeRead.currentBytes, - deserializeRead.currentBytesStart, - deserializeRead.currentBytesStart + deserializeRead.currentBytesLength); - Text text = new Text(stringBytes); - String string = text.toString(); - - HiveChar hiveChar = new HiveChar(string, ((CharTypeInfo) primitiveTypeInfo).getLength()); - - HiveChar expected = ((HiveCharWritable) writable).getHiveChar(); - if (!hiveChar.equals(expected)) { - TestCase.fail("Char field mismatch (expected '" + expected + "' found '" + hiveChar + "')"); + switch (typeInfo.getCategory()) { + case PRIMITIVE: + { + PrimitiveTypeInfo primitiveTypeInfo = (PrimitiveTypeInfo) typeInfo; + switch (primitiveTypeInfo.getPrimitiveCategory()) { + case BOOLEAN: + { + boolean value = deserializeRead.currentBoolean; + if (!(object instanceof BooleanWritable)) { + TestCase.fail("Boolean expected writable not Boolean"); + } + boolean expected = ((BooleanWritable) object).get(); + if (value != expected) { + TestCase.fail("Boolean field mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case BYTE: + { + byte value = deserializeRead.currentByte; + if (!(object instanceof ByteWritable)) { + TestCase.fail("Byte expected writable not Byte"); + } + byte expected = ((ByteWritable) object).get(); + if (value != expected) { + TestCase.fail("Byte field mismatch (expected " + (int) expected + " found " + (int) value + ")"); + } + } + break; + case SHORT: + { + short value = deserializeRead.currentShort; + if (!(object instanceof ShortWritable)) { + TestCase.fail("Short expected writable not Short"); + } + short expected = ((ShortWritable) object).get(); + if (value != expected) { + TestCase.fail("Short field mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case INT: + { + int value = deserializeRead.currentInt; + if (!(object instanceof IntWritable)) { + TestCase.fail("Integer expected writable not Integer"); + } + int expected = ((IntWritable) object).get(); + if (value != expected) { + TestCase.fail("Int field mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case LONG: + { + long value = deserializeRead.currentLong; + if (!(object instanceof LongWritable)) { + TestCase.fail("Long expected writable not Long"); + } + Long expected = ((LongWritable) object).get(); + if (value != expected) { + TestCase.fail("Long field mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case FLOAT: + { + float value = deserializeRead.currentFloat; + if (!(object instanceof FloatWritable)) { + TestCase.fail("Float expected writable not Float"); + } + float expected = ((FloatWritable) object).get(); + if (value != expected) { + TestCase.fail("Float field mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case DOUBLE: + { + double value = deserializeRead.currentDouble; + if (!(object instanceof DoubleWritable)) { + TestCase.fail("Double expected 
writable not Double"); + } + double expected = ((DoubleWritable) object).get(); + if (value != expected) { + TestCase.fail("Double field mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case STRING: + { + byte[] stringBytes = Arrays.copyOfRange( + deserializeRead.currentBytes, + deserializeRead.currentBytesStart, + deserializeRead.currentBytesStart + deserializeRead.currentBytesLength); + Text text = new Text(stringBytes); + String string = text.toString(); + String expected = ((Text) object).toString(); + if (!string.equals(expected)) { + TestCase.fail("String field mismatch (expected '" + expected + "' found '" + string + "')"); + } + } + break; + case CHAR: + { + byte[] stringBytes = Arrays.copyOfRange( + deserializeRead.currentBytes, + deserializeRead.currentBytesStart, + deserializeRead.currentBytesStart + deserializeRead.currentBytesLength); + Text text = new Text(stringBytes); + String string = text.toString(); + + HiveChar hiveChar = new HiveChar(string, ((CharTypeInfo) primitiveTypeInfo).getLength()); + + HiveChar expected = ((HiveCharWritable) object).getHiveChar(); + if (!hiveChar.equals(expected)) { + TestCase.fail("Char field mismatch (expected '" + expected + "' found '" + hiveChar + "')"); + } + } + break; + case VARCHAR: + { + byte[] stringBytes = Arrays.copyOfRange( + deserializeRead.currentBytes, + deserializeRead.currentBytesStart, + deserializeRead.currentBytesStart + deserializeRead.currentBytesLength); + Text text = new Text(stringBytes); + String string = text.toString(); + + HiveVarchar hiveVarchar = new HiveVarchar(string, ((VarcharTypeInfo) primitiveTypeInfo).getLength()); + + HiveVarchar expected = ((HiveVarcharWritable) object).getHiveVarchar(); + if (!hiveVarchar.equals(expected)) { + TestCase.fail("Varchar field mismatch (expected '" + expected + "' found '" + hiveVarchar + "')"); + } + } + break; + case DECIMAL: + { + HiveDecimal value = deserializeRead.currentHiveDecimalWritable.getHiveDecimal(); + if (value == null) { + TestCase.fail("Decimal field evaluated to NULL"); + } + HiveDecimal expected = ((HiveDecimalWritable) object).getHiveDecimal(); + if (!value.equals(expected)) { + DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) primitiveTypeInfo; + int precision = decimalTypeInfo.getPrecision(); + int scale = decimalTypeInfo.getScale(); + TestCase.fail("Decimal field mismatch (expected " + expected.toString() + " found " + value.toString() + ") precision " + precision + ", scale " + scale); + } + } + break; + case DATE: + { + Date value = deserializeRead.currentDateWritable.get(); + Date expected = ((DateWritable) object).get(); + if (!value.equals(expected)) { + TestCase.fail("Date field mismatch (expected " + expected.toString() + " found " + value.toString() + ")"); + } + } + break; + case TIMESTAMP: + { + Timestamp value = deserializeRead.currentTimestampWritable.getTimestamp(); + Timestamp expected = ((TimestampWritable) object).getTimestamp(); + if (!value.equals(expected)) { + TestCase.fail("Timestamp field mismatch (expected " + expected.toString() + " found " + value.toString() + ")"); + } + } + break; + case INTERVAL_YEAR_MONTH: + { + HiveIntervalYearMonth value = deserializeRead.currentHiveIntervalYearMonthWritable.getHiveIntervalYearMonth(); + HiveIntervalYearMonth expected = ((HiveIntervalYearMonthWritable) object).getHiveIntervalYearMonth(); + if (!value.equals(expected)) { + TestCase.fail("HiveIntervalYearMonth field mismatch (expected " + expected.toString() + " found " + value.toString() + ")"); + } + } + 
break; + case INTERVAL_DAY_TIME: + { + HiveIntervalDayTime value = deserializeRead.currentHiveIntervalDayTimeWritable.getHiveIntervalDayTime(); + HiveIntervalDayTime expected = ((HiveIntervalDayTimeWritable) object).getHiveIntervalDayTime(); + if (!value.equals(expected)) { + TestCase.fail("HiveIntervalDayTime field mismatch (expected " + expected.toString() + " found " + value.toString() + ")"); + } + } + break; + case BINARY: + { + byte[] byteArray = Arrays.copyOfRange( + deserializeRead.currentBytes, + deserializeRead.currentBytesStart, + deserializeRead.currentBytesStart + deserializeRead.currentBytesLength); + BytesWritable bytesWritable = (BytesWritable) object; + byte[] expected = Arrays.copyOfRange(bytesWritable.getBytes(), 0, bytesWritable.getLength()); + if (byteArray.length != expected.length){ + TestCase.fail("Byte Array field mismatch (expected " + Arrays.toString(expected) + + " found " + Arrays.toString(byteArray) + ")"); + } + for (int b = 0; b < byteArray.length; b++) { + if (byteArray[b] != expected[b]) { + TestCase.fail("Byte Array field mismatch (expected " + Arrays.toString(expected) + + " found " + Arrays.toString(byteArray) + ")"); + } + } + } + break; + default: + throw new Error("Unknown primitive category " + primitiveTypeInfo.getPrimitiveCategory()); } } break; - case VARCHAR: - { - byte[] stringBytes = Arrays.copyOfRange( - deserializeRead.currentBytes, - deserializeRead.currentBytesStart, - deserializeRead.currentBytesStart + deserializeRead.currentBytesLength); - Text text = new Text(stringBytes); - String string = text.toString(); - - HiveVarchar hiveVarchar = new HiveVarchar(string, ((VarcharTypeInfo) primitiveTypeInfo).getLength()); + case LIST: + case MAP: + case STRUCT: + case UNION: + throw new Error("Complex types need to be handled separately"); + default: + throw new Error("Unknown category " + typeInfo.getCategory()); + } + } - HiveVarchar expected = ((HiveVarcharWritable) writable).getHiveVarchar(); - if (!hiveVarchar.equals(expected)) { - TestCase.fail("Varchar field mismatch (expected '" + expected + "' found '" + hiveVarchar + "')"); - } - } - break; - case DECIMAL: - { - HiveDecimal value = deserializeRead.currentHiveDecimalWritable.getHiveDecimal(); - if (value == null) { - TestCase.fail("Decimal field evaluated to NULL"); - } - HiveDecimal expected = ((HiveDecimalWritable) writable).getHiveDecimal(); - if (!value.equals(expected)) { - DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) primitiveTypeInfo; - int precision = decimalTypeInfo.getPrecision(); - int scale = decimalTypeInfo.getScale(); - TestCase.fail("Decimal field mismatch (expected " + expected.toString() + " found " + value.toString() + ") precision " + precision + ", scale " + scale); - } - } - break; - case DATE: - { - Date value = deserializeRead.currentDateWritable.get(); - Date expected = ((DateWritable) writable).get(); - if (!value.equals(expected)) { - TestCase.fail("Date field mismatch (expected " + expected.toString() + " found " + value.toString() + ")"); + public static void serializeWrite(SerializeWrite serializeWrite, + TypeInfo typeInfo, Object object) throws IOException { + if (object == null) { + serializeWrite.writeNull(); + return; + } + switch (typeInfo.getCategory()) { + case PRIMITIVE: + { + PrimitiveTypeInfo primitiveTypeInfo = (PrimitiveTypeInfo) typeInfo; + switch (primitiveTypeInfo.getPrimitiveCategory()) { + case BOOLEAN: + { + boolean value = ((BooleanWritable) object).get(); + serializeWrite.writeBoolean(value); + } + break; + case BYTE: + { + byte 
value = ((ByteWritable) object).get(); + serializeWrite.writeByte(value); + } + break; + case SHORT: + { + short value = ((ShortWritable) object).get(); + serializeWrite.writeShort(value); + } + break; + case INT: + { + int value = ((IntWritable) object).get(); + serializeWrite.writeInt(value); + } + break; + case LONG: + { + long value = ((LongWritable) object).get(); + serializeWrite.writeLong(value); + } + break; + case FLOAT: + { + float value = ((FloatWritable) object).get(); + serializeWrite.writeFloat(value); + } + break; + case DOUBLE: + { + double value = ((DoubleWritable) object).get(); + serializeWrite.writeDouble(value); + } + break; + case STRING: + { + Text value = (Text) object; + byte[] stringBytes = value.getBytes(); + int stringLength = stringBytes.length; + serializeWrite.writeString(stringBytes, 0, stringLength); + } + break; + case CHAR: + { + HiveChar value = ((HiveCharWritable) object).getHiveChar(); + serializeWrite.writeHiveChar(value); + } + break; + case VARCHAR: + { + HiveVarchar value = ((HiveVarcharWritable) object).getHiveVarchar(); + serializeWrite.writeHiveVarchar(value); + } + break; + case DECIMAL: + { + HiveDecimal value = ((HiveDecimalWritable) object).getHiveDecimal(); + DecimalTypeInfo decTypeInfo = (DecimalTypeInfo)primitiveTypeInfo; + serializeWrite.writeHiveDecimal(value, decTypeInfo.scale()); + } + break; + case DATE: + { + Date value = ((DateWritable) object).get(); + serializeWrite.writeDate(value); + } + break; + case TIMESTAMP: + { + Timestamp value = ((TimestampWritable) object).getTimestamp(); + serializeWrite.writeTimestamp(value); + } + break; + case INTERVAL_YEAR_MONTH: + { + HiveIntervalYearMonth value = ((HiveIntervalYearMonthWritable) object).getHiveIntervalYearMonth(); + serializeWrite.writeHiveIntervalYearMonth(value); + } + break; + case INTERVAL_DAY_TIME: + { + HiveIntervalDayTime value = ((HiveIntervalDayTimeWritable) object).getHiveIntervalDayTime(); + serializeWrite.writeHiveIntervalDayTime(value); + } + break; + case BINARY: + { + BytesWritable byteWritable = (BytesWritable) object; + byte[] binaryBytes = byteWritable.getBytes(); + int length = byteWritable.getLength(); + serializeWrite.writeBinary(binaryBytes, 0, length); + } + break; + default: + throw new Error("Unknown primitive category " + primitiveTypeInfo.getPrimitiveCategory().name()); } } break; - case TIMESTAMP: + case LIST: { - Timestamp value = deserializeRead.currentTimestampWritable.getTimestamp(); - Timestamp expected = ((TimestampWritable) writable).getTimestamp(); - if (!value.equals(expected)) { - TestCase.fail("Timestamp field mismatch (expected " + expected.toString() + " found " + value.toString() + ")"); + ListTypeInfo listTypeInfo = (ListTypeInfo) typeInfo; + TypeInfo elementTypeInfo = listTypeInfo.getListElementTypeInfo(); + ArrayList elements = (ArrayList) object; + serializeWrite.beginList(elements); + boolean isFirst = true; + for (Object elementObject : elements) { + if (isFirst) { + isFirst = false; + } else { + serializeWrite.separateList(); + } + if (elementObject == null) { + serializeWrite.writeNull(); + } else { + serializeWrite(serializeWrite, elementTypeInfo, elementObject); + } } - } - break; - case INTERVAL_YEAR_MONTH: - { - HiveIntervalYearMonth value = deserializeRead.currentHiveIntervalYearMonthWritable.getHiveIntervalYearMonth(); - HiveIntervalYearMonth expected = ((HiveIntervalYearMonthWritable) writable).getHiveIntervalYearMonth(); - if (!value.equals(expected)) { - TestCase.fail("HiveIntervalYearMonth field mismatch (expected " + 
expected.toString() + " found " + value.toString() + ")"); + serializeWrite.finishList(); + } + break; + case MAP: + { + MapTypeInfo mapTypeInfo = (MapTypeInfo) typeInfo; + TypeInfo keyTypeInfo = mapTypeInfo.getMapKeyTypeInfo(); + TypeInfo valueTypeInfo = mapTypeInfo.getMapValueTypeInfo(); + HashMap hashMap = (HashMap) object; + serializeWrite.beginMap(hashMap); + boolean isFirst = true; + for (Entry entry : hashMap.entrySet()) { + if (isFirst) { + isFirst = false; + } else { + serializeWrite.separateKeyValuePair(); + } + if (entry.getKey() == null) { + serializeWrite.writeNull(); + } else { + serializeWrite(serializeWrite, keyTypeInfo, entry.getKey()); + } + serializeWrite.separateKey(); + if (entry.getValue() == null) { + serializeWrite.writeNull(); + } else { + serializeWrite(serializeWrite, valueTypeInfo, entry.getValue()); + } } - } - break; - case INTERVAL_DAY_TIME: - { - HiveIntervalDayTime value = deserializeRead.currentHiveIntervalDayTimeWritable.getHiveIntervalDayTime(); - HiveIntervalDayTime expected = ((HiveIntervalDayTimeWritable) writable).getHiveIntervalDayTime(); - if (!value.equals(expected)) { - TestCase.fail("HiveIntervalDayTime field mismatch (expected " + expected.toString() + " found " + value.toString() + ")"); + serializeWrite.finishMap(); + } + break; + case STRUCT: + { + StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo; + ArrayList fieldTypeInfos = structTypeInfo.getAllStructFieldTypeInfos(); + ArrayList fieldValues = (ArrayList) object; + final int size = fieldValues.size(); + serializeWrite.beginStruct(fieldValues); + boolean isFirst = true; + for (int i = 0; i < size; i++) { + if (isFirst) { + isFirst = false; + } else { + serializeWrite.separateStruct(); + } + serializeWrite(serializeWrite, fieldTypeInfos.get(i), fieldValues.get(i)); } + serializeWrite.finishStruct(); } break; - case BINARY: + case UNION: { - byte[] byteArray = Arrays.copyOfRange( - deserializeRead.currentBytes, - deserializeRead.currentBytesStart, - deserializeRead.currentBytesStart + deserializeRead.currentBytesLength); - BytesWritable bytesWritable = (BytesWritable) writable; - byte[] expected = Arrays.copyOfRange(bytesWritable.getBytes(), 0, bytesWritable.getLength()); - if (byteArray.length != expected.length){ - TestCase.fail("Byte Array field mismatch (expected " + Arrays.toString(expected) - + " found " + Arrays.toString(byteArray) + ")"); - } - for (int b = 0; b < byteArray.length; b++) { - if (byteArray[b] != expected[b]) { - TestCase.fail("Byte Array field mismatch (expected " + Arrays.toString(expected) - + " found " + Arrays.toString(byteArray) + ")"); - } - } + UnionTypeInfo unionTypeInfo = (UnionTypeInfo) typeInfo; + List fieldTypeInfos = unionTypeInfo.getAllUnionObjectTypeInfos(); + StandardUnion standardUnion = (StandardUnion) object; + byte tag = standardUnion.getTag(); + serializeWrite.beginUnion(tag); + serializeWrite(serializeWrite, fieldTypeInfos.get(tag), standardUnion.getObject()); + serializeWrite.finishUnion(); } break; default: - throw new Error("Unknown primitive category " + primitiveTypeInfo.getPrimitiveCategory()); + throw new Error("Unknown category " + typeInfo.getCategory().name()); } } - public static void serializeWrite(SerializeWrite serializeWrite, - PrimitiveTypeInfo primitiveTypeInfo, Writable writable) throws IOException { - if (writable == null) { - serializeWrite.writeNull(); - return;
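+  // Reads a single primitive that appears inside a complex value; the field has
+  // already been positioned by the caller, and null fields return Java null.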
+ public Object readComplexPrimitiveField(DeserializeRead deserializeRead, + PrimitiveTypeInfo primitiveTypeInfo) throws IOException { + boolean isNull = !deserializeRead.readComplexField(); + if (isNull) { + return null; + } else { + return doReadComplexPrimitiveField(deserializeRead, primitiveTypeInfo); } + } + + private static Object doReadComplexPrimitiveField(DeserializeRead deserializeRead, + PrimitiveTypeInfo primitiveTypeInfo) throws IOException { switch (primitiveTypeInfo.getPrimitiveCategory()) { - case BOOLEAN: - { - boolean value = ((BooleanWritable) writable).get(); - serializeWrite.writeBoolean(value); - } - break; + case BOOLEAN: + return new BooleanWritable(deserializeRead.currentBoolean); case BYTE: - { - byte value = ((ByteWritable) writable).get(); - serializeWrite.writeByte(value); - } - break; + return new ByteWritable(deserializeRead.currentByte); case SHORT: - { - short value = ((ShortWritable) writable).get(); - serializeWrite.writeShort(value); - } - break; + return new ShortWritable(deserializeRead.currentShort); case INT: - { - int value = ((IntWritable) writable).get(); - serializeWrite.writeInt(value); - } - break; + return new IntWritable(deserializeRead.currentInt); case LONG: - { - long value = ((LongWritable) writable).get(); - serializeWrite.writeLong(value); - } - break; + return new LongWritable(deserializeRead.currentLong); case FLOAT: - { - float value = ((FloatWritable) writable).get(); - serializeWrite.writeFloat(value); - } - break; + return new FloatWritable(deserializeRead.currentFloat); case DOUBLE: - { - double value = ((DoubleWritable) writable).get(); - serializeWrite.writeDouble(value); - } - break; + return new DoubleWritable(deserializeRead.currentDouble); case STRING: - { - Text value = (Text) writable; - byte[] stringBytes = value.getBytes(); - int stringLength = stringBytes.length; - serializeWrite.writeString(stringBytes, 0, stringLength); - } - break; + return new Text(new String( + deserializeRead.currentBytes, + deserializeRead.currentBytesStart, + deserializeRead.currentBytesLength, + StandardCharsets.UTF_8)); case CHAR: - { - HiveChar value = ((HiveCharWritable) writable).getHiveChar(); - serializeWrite.writeHiveChar(value); - } - break; + return new HiveCharWritable(new HiveChar( + new String( + deserializeRead.currentBytes, + deserializeRead.currentBytesStart, + deserializeRead.currentBytesLength, + StandardCharsets.UTF_8), + ((CharTypeInfo) primitiveTypeInfo).getLength())); case VARCHAR: - { - HiveVarchar value = ((HiveVarcharWritable) writable).getHiveVarchar(); - serializeWrite.writeHiveVarchar(value); + if (deserializeRead.currentBytes == null) { + throw new RuntimeException("Expected non-null currentBytes for VARCHAR field"); } - break; + return new HiveVarcharWritable(new HiveVarchar( + new String( + deserializeRead.currentBytes, + deserializeRead.currentBytesStart, + deserializeRead.currentBytesLength, + StandardCharsets.UTF_8), + ((VarcharTypeInfo) primitiveTypeInfo).getLength())); case DECIMAL: - { - HiveDecimal value = ((HiveDecimalWritable) writable).getHiveDecimal(); - DecimalTypeInfo decTypeInfo = (DecimalTypeInfo)primitiveTypeInfo; - serializeWrite.writeHiveDecimal(value, decTypeInfo.scale()); - } - break; + return new HiveDecimalWritable(deserializeRead.currentHiveDecimalWritable); case DATE: - { - Date value = ((DateWritable) writable).get(); - serializeWrite.writeDate(value); - } - break; + return new DateWritable(deserializeRead.currentDateWritable); case TIMESTAMP: - { - Timestamp value = ((TimestampWritable) writable).getTimestamp(); - serializeWrite.writeTimestamp(value); - } - break; + return new TimestampWritable(deserializeRead.currentTimestampWritable); case
INTERVAL_YEAR_MONTH: - { - HiveIntervalYearMonth value = ((HiveIntervalYearMonthWritable) writable).getHiveIntervalYearMonth(); - serializeWrite.writeHiveIntervalYearMonth(value); - } - break; + return new HiveIntervalYearMonthWritable(deserializeRead.currentHiveIntervalYearMonthWritable); case INTERVAL_DAY_TIME: - { - HiveIntervalDayTime value = ((HiveIntervalDayTimeWritable) writable).getHiveIntervalDayTime(); - serializeWrite.writeHiveIntervalDayTime(value); - } - break; + return new HiveIntervalDayTimeWritable(deserializeRead.currentHiveIntervalDayTimeWritable); case BINARY: - { - BytesWritable byteWritable = (BytesWritable) writable; - byte[] binaryBytes = byteWritable.getBytes(); - int length = byteWritable.getLength(); - serializeWrite.writeBinary(binaryBytes, 0, length); + return new BytesWritable( + Arrays.copyOfRange( + deserializeRead.currentBytes, + deserializeRead.currentBytesStart, + deserializeRead.currentBytesLength + deserializeRead.currentBytesStart)); + default: + throw new Error("Unknown primitive category " + primitiveTypeInfo.getPrimitiveCategory()); + } + } + + public static Object deserializeReadComplexType(DeserializeRead deserializeRead, + TypeInfo typeInfo) throws IOException { + + boolean isNull = !deserializeRead.readNextField(); + if (isNull) { + return null; + } + return getComplexField(deserializeRead, typeInfo); + } + + private static Object getComplexField(DeserializeRead deserializeRead, + TypeInfo typeInfo) throws IOException { + switch (typeInfo.getCategory()) { + case PRIMITIVE: + return doReadComplexPrimitiveField(deserializeRead, (PrimitiveTypeInfo) typeInfo); + case LIST: + { + ListTypeInfo listTypeInfo = (ListTypeInfo) typeInfo; + TypeInfo elementTypeInfo = listTypeInfo.getListElementTypeInfo(); + ArrayList list = new ArrayList(); + Object eleObj; + boolean isNull; + while (deserializeRead.isNextComplexMultiValue()) { + isNull = !deserializeRead.readComplexField(); + if (isNull) { + eleObj = null; + } else { + eleObj = getComplexField(deserializeRead, elementTypeInfo); + } + list.add(eleObj); + } + return list; + } + case MAP: + { + MapTypeInfo mapTypeInfo = (MapTypeInfo) typeInfo; + TypeInfo keyTypeInfo = mapTypeInfo.getMapKeyTypeInfo(); + TypeInfo valueTypeInfo = mapTypeInfo.getMapValueTypeInfo(); + HashMap hashMap = new HashMap(); + Object keyObj; + Object valueObj; + boolean isNull; + while (deserializeRead.isNextComplexMultiValue()) { + isNull = !deserializeRead.readComplexField(); + if (isNull) { + keyObj = null; + } else { + keyObj = getComplexField(deserializeRead, keyTypeInfo); + } + isNull = !deserializeRead.readComplexField(); + if (isNull) { + valueObj = null; + } else { + valueObj = getComplexField(deserializeRead, valueTypeInfo); + } + hashMap.put(keyObj, valueObj); + } + return hashMap; + } + case STRUCT: + { + StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo; + ArrayList fieldTypeInfos = structTypeInfo.getAllStructFieldTypeInfos(); + final int size = fieldTypeInfos.size(); + ArrayList fieldValues = new ArrayList(); + Object fieldObj; + boolean isNull; + for (int i = 0; i < size; i++) { + isNull = !deserializeRead.readComplexField(); + if (isNull) { + fieldObj = null; + } else { + fieldObj = getComplexField(deserializeRead, fieldTypeInfos.get(i)); + } + fieldValues.add(fieldObj); + } + deserializeRead.finishComplexVariableFieldsType(); + return fieldValues; + } + case UNION: + { + UnionTypeInfo unionTypeInfo =
(UnionTypeInfo) typeInfo; + List unionTypeInfos = unionTypeInfo.getAllUnionObjectTypeInfos(); + final int size = unionTypeInfos.size(); + Object tagObj; + int tag; + Object unionObj; + boolean isNull = !deserializeRead.readComplexField(); + if (isNull) { + unionObj = null; + } else { + // Get the tag value. + tagObj = getComplexField(deserializeRead, TypeInfoFactory.intTypeInfo); + tag = ((IntWritable) tagObj).get(); + + isNull = !deserializeRead.readComplexField(); + if (isNull) { + unionObj = null; + } else { + // Get the union value. + unionObj = new StandardUnion((byte) tag, getComplexField(deserializeRead, unionTypeInfos.get(tag))); + } + } + + deserializeRead.finishComplexVariableFieldsType(); + return unionObj; } - break; default: - throw new Error("Unknown primitive category " + primitiveTypeInfo.getPrimitiveCategory().name()); + throw new Error("Unexpected category " + typeInfo.getCategory()); } } } \ No newline at end of file diff --git serde/src/test/org/apache/hadoop/hive/serde2/binarysortable/MyTestClass.java serde/src/test/org/apache/hadoop/hive/serde2/binarysortable/MyTestClass.java index df5e8dbf6b..77982a642d 100644 --- serde/src/test/org/apache/hadoop/hive/serde2/binarysortable/MyTestClass.java +++ serde/src/test/org/apache/hadoop/hive/serde2/binarysortable/MyTestClass.java @@ -230,6 +230,9 @@ public static void nonRandomRowFill(Object[][] rows, PrimitiveCategory[] primiti for (int i = 0; i < minCount; i++) { Object[] row = rows[i]; for (int c = 0; c < primitiveCategories.length; c++) { + if (primitiveCategories[c] == null) { + continue; + } Object object = row[c]; // Current value. switch (primitiveCategories[c]) { case BOOLEAN: diff --git serde/src/test/org/apache/hadoop/hive/serde2/binarysortable/TestBinarySortableFast.java serde/src/test/org/apache/hadoop/hive/serde2/binarysortable/TestBinarySortableFast.java index b3694626e2..311e369f54 100644 --- serde/src/test/org/apache/hadoop/hive/serde2/binarysortable/TestBinarySortableFast.java +++ serde/src/test/org/apache/hadoop/hive/serde2/binarysortable/TestBinarySortableFast.java @@ -17,6 +17,7 @@ */ package org.apache.hadoop.hive.serde2.binarysortable; +import java.io.IOException; import java.util.ArrayList; import java.io.EOFException; import java.util.Arrays; @@ -30,13 +31,16 @@ import org.apache.hadoop.hive.serde2.VerifyFast; import org.apache.hadoop.hive.serde2.binarysortable.fast.BinarySortableDeserializeRead; import org.apache.hadoop.hive.serde2.binarysortable.fast.BinarySortableSerializeWrite; +import org.apache.hadoop.hive.serde2.lazy.VerifyLazy; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; -import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; +import org.apache.hadoop.hive.serde2.objectinspector.UnionObject; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; import org.apache.hadoop.io.BytesWritable; -import org.apache.hadoop.io.Writable; import junit.framework.TestCase; +import org.junit.Assert; public class TestBinarySortableFast extends TestCase { @@ -48,11 +52,11 @@ private void testBinarySortableFast( boolean[] columnSortOrderIsDesc, byte[] columnNullMarker, byte[] columnNotNullMarker, AbstractSerDe serde, StructObjectInspector rowOI, AbstractSerDe serde_fewer, StructObjectInspector writeRowOI, - boolean ascending, PrimitiveTypeInfo[] primitiveTypeInfos, + boolean ascending, TypeInfo[] typeInfos, boolean 
useIncludeColumns, boolean doWriteFewerColumns, Random r) throws Throwable { int rowCount = rows.length; - int columnCount = primitiveTypeInfos.length; + int columnCount = typeInfos.length; boolean[] columnsToInclude = null; if (useIncludeColumns) { @@ -83,10 +87,7 @@ private void testBinarySortableFast( int[] perFieldWriteLengths = new int[columnCount]; for (int index = 0; index < writeColumnCount; index++) { - - Writable writable = (Writable) row[index]; - - VerifyFast.serializeWrite(binarySortableSerializeWrite, primitiveTypeInfos[index], writable); + VerifyFast.serializeWrite(binarySortableSerializeWrite, typeInfos[index], row[index]); perFieldWriteLengths[index] = output.getLength(); } perFieldWriteLengthsArray[i] = perFieldWriteLengths; @@ -95,7 +96,8 @@ private void testBinarySortableFast( bytesWritable.set(output.getData(), 0, output.getLength()); serializeWriteBytes[i] = bytesWritable; if (i > 0) { - int compareResult = serializeWriteBytes[i - 1].compareTo(serializeWriteBytes[i]); + BytesWritable previousBytesWritable = serializeWriteBytes[i - 1]; + int compareResult = previousBytesWritable.compareTo(bytesWritable); if ((compareResult < 0 && !ascending) || (compareResult > 0 && ascending)) { System.out.println("Test failed in " @@ -117,7 +119,7 @@ private void testBinarySortableFast( Object[] row = rows[i]; BinarySortableDeserializeRead binarySortableDeserializeRead = new BinarySortableDeserializeRead( - primitiveTypeInfos, + typeInfos, /* useExternalBuffer */ false, columnSortOrderIsDesc, columnNullMarker, @@ -132,10 +134,9 @@ private void testBinarySortableFast( binarySortableDeserializeRead.skipNextField(); } else if (index >= writeColumnCount) { // Should come back a null. - VerifyFast.verifyDeserializeRead(binarySortableDeserializeRead, primitiveTypeInfos[index], null); + VerifyFast.verifyDeserializeRead(binarySortableDeserializeRead, typeInfos[index], null); } else { - Writable writable = (Writable) row[index]; - VerifyFast.verifyDeserializeRead(binarySortableDeserializeRead, primitiveTypeInfos[index], writable); + verifyRead(binarySortableDeserializeRead, typeInfos[index], row[index]); } } if (writeColumnCount == columnCount) { @@ -147,7 +148,7 @@ private void testBinarySortableFast( */ BinarySortableDeserializeRead binarySortableDeserializeRead2 = new BinarySortableDeserializeRead( - primitiveTypeInfos, + typeInfos, /* useExternalBuffer */ false, columnSortOrderIsDesc, columnNullMarker, @@ -157,22 +158,24 @@ private void testBinarySortableFast( bytesWritable.getBytes(), 0, bytesWritable.getLength() - 1); // One fewer byte. 
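+          // The last field must fail with EOFException on the truncated buffer
+          // unless it is null, in which case not throwing is tolerated (see below).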
for (int index = 0; index < writeColumnCount; index++) { - Writable writable = (Writable) row[index]; if (index == writeColumnCount - 1) { boolean threw = false; try { - VerifyFast.verifyDeserializeRead(binarySortableDeserializeRead2, primitiveTypeInfos[index], writable); + verifyRead(binarySortableDeserializeRead2, typeInfos[index], row[index]); } catch (EOFException e) { // debugDetailedReadPositionString = binarySortableDeserializeRead2.getDetailedReadPositionString(); // debugStackTrace = e.getStackTrace(); threw = true; } - TestCase.assertTrue(threw); + + if (!threw && row[index] != null) { + Assert.fail("Expected EOFException reading a truncated buffer"); + } } else { if (useIncludeColumns && !columnsToInclude[index]) { binarySortableDeserializeRead2.skipNextField(); } else { - VerifyFast.verifyDeserializeRead(binarySortableDeserializeRead2, primitiveTypeInfos[index], writable); + verifyRead(binarySortableDeserializeRead2, typeInfos[index], row[index]); } } } @@ -270,7 +273,7 @@ private void testBinarySortableFast( "\nSerDe: " + serDeFields.toString() + "\nperFieldWriteLengths " + Arrays.toString(perFieldWriteLengthsArray[i]) + - "\nprimitiveTypeInfos " + Arrays.toString(primitiveTypeInfos) + + "\ntypeInfos " + Arrays.toString(typeInfos) + "\nrow " + Arrays.toString(row)); } } @@ -282,7 +285,7 @@ private void testBinarySortableFast( Object[] row = rows[i]; BinarySortableDeserializeRead binarySortableDeserializeRead = new BinarySortableDeserializeRead( - primitiveTypeInfos, + typeInfos, /* useExternalBuffer */ false, columnSortOrderIsDesc, columnNullMarker, @@ -297,10 +300,9 @@ private void testBinarySortableFast( binarySortableDeserializeRead.skipNextField(); } else if (index >= writeColumnCount) { // Should come back a null. - VerifyFast.verifyDeserializeRead(binarySortableDeserializeRead, primitiveTypeInfos[index], null); + verifyRead(binarySortableDeserializeRead, typeInfos[index], null); } else { - Writable writable = (Writable) row[index]; - VerifyFast.verifyDeserializeRead(binarySortableDeserializeRead, primitiveTypeInfos[index], writable); + verifyRead(binarySortableDeserializeRead, typeInfos[index], row[index]); } } if (writeColumnCount == columnCount) { @@ -309,11 +311,47 @@ private void testBinarySortableFast( } } - private void testBinarySortableFastCase(int caseNum, boolean doNonRandomFill, Random r) + private void verifyRead(BinarySortableDeserializeRead binarySortableDeserializeRead, + TypeInfo typeInfo, Object expectedObject) throws IOException { + if (typeInfo.getCategory() == ObjectInspector.Category.PRIMITIVE) { + VerifyFast.verifyDeserializeRead(binarySortableDeserializeRead, typeInfo, expectedObject); + } else { + Object complexFieldObj = VerifyFast.deserializeReadComplexType(binarySortableDeserializeRead, typeInfo); + if (expectedObject == null) { + if (complexFieldObj != null) { + TestCase.fail("Field reports not null but object is null (class " + complexFieldObj.getClass().getName() + + ", " + complexFieldObj.toString() + ")"); + } + } else { + if (complexFieldObj == null) { + // It's hard to distinguish a union with null from a null union.
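+            // (Both serialize identically, so a null result is accepted when the
+            // expected union carries a null member.)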
+ if (expectedObject instanceof UnionObject) { + UnionObject expectedUnion = (UnionObject) expectedObject; + if (expectedUnion.getObject() == null) { + return; + } + } + TestCase.fail("Field reports null but object is not null (class " + expectedObject.getClass().getName() + + ", " + expectedObject.toString() + ")"); + } + } + if (!VerifyLazy.lazyCompare(typeInfo, complexFieldObj, expectedObject)) { + TestCase.fail("Comparison failed typeInfo " + typeInfo.toString()); + } + } + } + + private void testBinarySortableFastCase( + int caseNum, boolean doNonRandomFill, Random r, SerdeRandomRowSource.SupportedTypes supportedTypes, int depth) throws Throwable { SerdeRandomRowSource source = new SerdeRandomRowSource(); - source.init(r); + source.init(r, supportedTypes, depth); int rowCount = 1000; Object[][] rows = source.randomRows(rowCount); @@ -327,8 +365,8 @@ private void testBinarySortableFastCase(int caseNum, boolean doNonRandomFill, Ra StructObjectInspector rowStructObjectInspector = source.rowStructObjectInspector(); - PrimitiveTypeInfo[] primitiveTypeInfos = source.primitiveTypeInfos(); - int columnCount = primitiveTypeInfos.length; + TypeInfo[] typeInfos = source.typeInfos(); + int columnCount = typeInfos.length; int writeColumnCount = columnCount; StructObjectInspector writeRowStructObjectInspector = rowStructObjectInspector; @@ -385,14 +423,14 @@ private void testBinarySortableFastCase(int caseNum, boolean doNonRandomFill, Ra columnSortOrderIsDesc, columnNullMarker, columnNotNullMarker, serde_ascending, rowStructObjectInspector, serde_ascending_fewer, writeRowStructObjectInspector, - /* ascending */ true, primitiveTypeInfos, + /* ascending */ true, typeInfos, /* useIncludeColumns */ false, /* doWriteFewerColumns */ false, r); testBinarySortableFast(source, rows, columnSortOrderIsDesc, columnNullMarker, columnNotNullMarker, serde_ascending, rowStructObjectInspector, serde_ascending_fewer, writeRowStructObjectInspector, - /* ascending */ true, primitiveTypeInfos, + /* ascending */ true, typeInfos, /* useIncludeColumns */ true, /* doWriteFewerColumns */ false, r); if (doWriteFewerColumns) { @@ -400,14 +438,14 @@ private void testBinarySortableFastCase(int caseNum, boolean doNonRandomFill, Ra columnSortOrderIsDesc, columnNullMarker, columnNotNullMarker, serde_ascending, rowStructObjectInspector, serde_ascending_fewer, writeRowStructObjectInspector, - /* ascending */ true, primitiveTypeInfos, + /* ascending */ true, typeInfos, /* useIncludeColumns */ false, /* doWriteFewerColumns */ true, r); testBinarySortableFast(source, rows, columnSortOrderIsDesc, columnNullMarker, columnNotNullMarker, serde_ascending, rowStructObjectInspector, serde_ascending_fewer, writeRowStructObjectInspector, - /* ascending */ true, primitiveTypeInfos, + /* ascending */ true, typeInfos, /* useIncludeColumns */ true, /* doWriteFewerColumns */ true, r); } @@ -420,14 +458,14 @@ private void testBinarySortableFastCase(int caseNum, boolean doNonRandomFill, Ra columnSortOrderIsDesc, columnNullMarker, columnNotNullMarker, serde_descending, rowStructObjectInspector, serde_ascending_fewer, writeRowStructObjectInspector, - /* ascending */ false, primitiveTypeInfos, + /* ascending */ false, typeInfos, /* useIncludeColumns */ false, /* doWriteFewerColumns */ false, r); testBinarySortableFast(source, rows, columnSortOrderIsDesc, columnNullMarker, columnNotNullMarker, serde_descending, rowStructObjectInspector, serde_ascending_fewer,
writeRowStructObjectInspector, - /* ascending */ false, primitiveTypeInfos, + /* ascending */ false, typeInfos, /* useIncludeColumns */ true, /* doWriteFewerColumns */ false, r); if (doWriteFewerColumns) { @@ -435,27 +473,27 @@ private void testBinarySortableFastCase(int caseNum, boolean doNonRandomFill, Ra columnSortOrderIsDesc, columnNullMarker, columnNotNullMarker, serde_descending, rowStructObjectInspector, serde_descending_fewer, writeRowStructObjectInspector, - /* ascending */ false, primitiveTypeInfos, + /* ascending */ false, typeInfos, /* useIncludeColumns */ false, /* doWriteFewerColumns */ true, r); testBinarySortableFast(source, rows, columnSortOrderIsDesc, columnNullMarker, columnNotNullMarker, serde_descending, rowStructObjectInspector, serde_descending_fewer, writeRowStructObjectInspector, - /* ascending */ false, primitiveTypeInfos, + /* ascending */ false, typeInfos, /* useIncludeColumns */ true, /* doWriteFewerColumns */ true, r); } } - public void testBinarySortableFast() throws Throwable { + public void testBinarySortableFast(SerdeRandomRowSource.SupportedTypes supportedTypes, int depth) throws Throwable { try { Random r = new Random(35790); int caseNum = 0; for (int i = 0; i < 10; i++) { - testBinarySortableFastCase(caseNum, (i % 2 == 0), r); + testBinarySortableFastCase(caseNum, (i % 2 == 0), r, supportedTypes, depth); caseNum++; } @@ -465,6 +503,18 @@ public void testBinarySortableFast() throws Throwable { } } + public void testBinarySortableFastPrimitive() throws Throwable { + testBinarySortableFast(SerdeRandomRowSource.SupportedTypes.PRIMITIVE, 0); + } + + public void testBinarySortableFastComplexDepthOne() throws Throwable { + testBinarySortableFast(SerdeRandomRowSource.SupportedTypes.ALL_EXCEPT_MAP, 1); + } + + public void testBinarySortableFastComplexDepthFour() throws Throwable { + testBinarySortableFast(SerdeRandomRowSource.SupportedTypes.ALL_EXCEPT_MAP, 4); + } + private static String displayBytes(byte[] bytes, int start, int length) { StringBuilder sb = new StringBuilder(); for (int i = start; i < start + length; i++) { diff --git serde/src/test/org/apache/hadoop/hive/serde2/lazy/TestLazySimpleFast.java serde/src/test/org/apache/hadoop/hive/serde2/lazy/TestLazySimpleFast.java index c857b42f98..04c9805842 100644 --- serde/src/test/org/apache/hadoop/hive/serde2/lazy/TestLazySimpleFast.java +++ serde/src/test/org/apache/hadoop/hive/serde2/lazy/TestLazySimpleFast.java @@ -17,6 +17,8 @@ */ package org.apache.hadoop.hive.serde2.lazy; +import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; import java.util.Properties; import java.util.Random; @@ -33,10 +35,11 @@ import org.apache.hadoop.hive.serde2.lazy.fast.LazySimpleSerializeWrite; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; -import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; +import org.apache.hadoop.hive.serde2.objectinspector.UnionObject; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; import org.apache.hadoop.io.BytesWritable; import org.apache.hadoop.io.Text; -import org.apache.hadoop.io.Writable; import junit.framework.TestCase; @@ -46,12 +49,12 @@ private void testLazySimpleFast( SerdeRandomRowSource source, Object[][] rows, LazySimpleSerDe serde, StructObjectInspector rowOI, LazySimpleSerDe serde_fewer, StructObjectInspector writeRowOI, - byte separator, 
LazySerDeParameters serdeParams, LazySerDeParameters serdeParams_fewer, - PrimitiveTypeInfo[] primitiveTypeInfos, + LazySerDeParameters serdeParams, LazySerDeParameters serdeParams_fewer, + TypeInfo[] typeInfos, boolean useIncludeColumns, boolean doWriteFewerColumns, Random r) throws Throwable { int rowCount = rows.length; - int columnCount = primitiveTypeInfos.length; + int columnCount = typeInfos.length; boolean[] columnsToInclude = null; if (useIncludeColumns) { @@ -62,10 +65,10 @@ private void testLazySimpleFast( } int writeColumnCount = columnCount; - PrimitiveTypeInfo[] writePrimitiveTypeInfos = primitiveTypeInfos; + TypeInfo[] writeTypeInfos = typeInfos; if (doWriteFewerColumns) { writeColumnCount = writeRowOI.getAllStructFieldRefs().size(); - writePrimitiveTypeInfos = Arrays.copyOf(primitiveTypeInfos, writeColumnCount); + writeTypeInfos = Arrays.copyOf(typeInfos, writeColumnCount); } // Try to serialize @@ -75,16 +78,12 @@ private void testLazySimpleFast( Output output = new Output(); LazySimpleSerializeWrite lazySimpleSerializeWrite = - new LazySimpleSerializeWrite(columnCount, - separator, serdeParams); + new LazySimpleSerializeWrite(columnCount, serdeParams); lazySimpleSerializeWrite.set(output); for (int index = 0; index < columnCount; index++) { - - Writable writable = (Writable) row[index]; - - VerifyFast.serializeWrite(lazySimpleSerializeWrite, primitiveTypeInfos[index], writable); + VerifyFast.serializeWrite(lazySimpleSerializeWrite, typeInfos[index], row[index]); } BytesWritable bytesWritable = new BytesWritable(); @@ -97,29 +96,24 @@ private void testLazySimpleFast( Object[] row = rows[i]; LazySimpleDeserializeRead lazySimpleDeserializeRead = new LazySimpleDeserializeRead( - writePrimitiveTypeInfos, + writeTypeInfos, /* useExternalBuffer */ false, - separator, serdeParams); + serdeParams); BytesWritable bytesWritable = serializeWriteBytes[i]; byte[] bytes = bytesWritable.getBytes(); int length = bytesWritable.getLength(); lazySimpleDeserializeRead.set(bytes, 0, length); - char[] chars = new char[length]; - for (int c = 0; c < chars.length; c++) { - chars[c] = (char) (bytes[c] & 0xFF); - } - for (int index = 0; index < columnCount; index++) { if (useIncludeColumns && !columnsToInclude[index]) { lazySimpleDeserializeRead.skipNextField(); } else if (index >= writeColumnCount) { // Should come back a null. - VerifyFast.verifyDeserializeRead(lazySimpleDeserializeRead, primitiveTypeInfos[index], null); + verifyReadNull(lazySimpleDeserializeRead, typeInfos[index]); } else { - Writable writable = (Writable) row[index]; - VerifyFast.verifyDeserializeRead(lazySimpleDeserializeRead, primitiveTypeInfos[index], writable); + Object expectedObject = row[index]; + verifyRead(lazySimpleDeserializeRead, typeInfos[index], expectedObject); } } if (writeColumnCount == columnCount) { @@ -128,28 +122,22 @@ private void testLazySimpleFast( } // Try to deserialize using SerDe class our Writable row objects created by SerializeWrite. 
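+    // Complex fields come back from LazyStruct.getField as Lazy objects, so they
+    // are compared with VerifyLazy.lazyCompare rather than Writable.equals.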
- for (int i = 0; i < rowCount; i++) { - BytesWritable bytesWritable = serializeWriteBytes[i]; + for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) { + BytesWritable bytesWritable = serializeWriteBytes[rowIndex]; LazyStruct lazySimpleStruct = (LazyStruct) serde.deserialize(bytesWritable); - Object[] row = rows[i]; + Object[] row = rows[rowIndex]; for (int index = 0; index < columnCount; index++) { - PrimitiveTypeInfo primitiveTypeInfo = primitiveTypeInfos[index]; - Writable writable = (Writable) row[index]; - LazyPrimitive lazyPrimitive = (LazyPrimitive) lazySimpleStruct.getField(index); - Object object; - if (lazyPrimitive != null) { - object = lazyPrimitive.getWritableObject(); - } else { - object = null; - } - if (writable == null || object == null) { - if (writable != null || object != null) { + TypeInfo typeInfo = typeInfos[index]; + Object expectedObject = row[index]; + Object object = lazySimpleStruct.getField(index); + if (expectedObject == null || object == null) { + if (expectedObject != null || object != null) { fail("SerDe deserialized NULL column mismatch"); } } else { - if (!object.equals(writable)) { + if (!VerifyLazy.lazyCompare(typeInfo, object, expectedObject)) { fail("SerDe deserialized value does not match"); } } @@ -185,9 +173,9 @@ private void testLazySimpleFast( LazySimpleDeserializeRead lazySimpleDeserializeRead = new LazySimpleDeserializeRead( - writePrimitiveTypeInfos, + writeTypeInfos, /* useExternalBuffer */ false, - separator, serdeParams); + serdeParams); byte[] bytes = serdeBytes[i]; lazySimpleDeserializeRead.set(bytes, 0, bytes.length); @@ -197,10 +185,10 @@ private void testLazySimpleFast( lazySimpleDeserializeRead.skipNextField(); } else if (index >= writeColumnCount) { // Should come back a null. - VerifyFast.verifyDeserializeRead(lazySimpleDeserializeRead, primitiveTypeInfos[index], null); + verifyReadNull(lazySimpleDeserializeRead, typeInfos[index]); } else { - Writable writable = (Writable) row[index]; - VerifyFast.verifyDeserializeRead(lazySimpleDeserializeRead, primitiveTypeInfos[index], writable); + Object expectedObject = row[index]; + verifyRead(lazySimpleDeserializeRead, typeInfos[index], expectedObject); } } if (writeColumnCount == columnCount) { @@ -209,6 +197,51 @@ private void testLazySimpleFast( } } + private void verifyReadNull(LazySimpleDeserializeRead lazySimpleDeserializeRead, + TypeInfo typeInfo) throws IOException { + if (typeInfo.getCategory() == Category.PRIMITIVE) { + VerifyFast.verifyDeserializeRead(lazySimpleDeserializeRead, typeInfo, null); + } else { + Object complexFieldObj = VerifyFast.deserializeReadComplexType(lazySimpleDeserializeRead, typeInfo); + if (complexFieldObj != null) { + TestCase.fail("Field reports not null but object is null"); + } + } + } + + private void verifyRead(LazySimpleDeserializeRead lazySimpleDeserializeRead, + TypeInfo typeInfo, Object expectedObject) throws IOException { + if (typeInfo.getCategory() == Category.PRIMITIVE) { + VerifyFast.verifyDeserializeRead(lazySimpleDeserializeRead, typeInfo, expectedObject); + } else { + Object complexFieldObj = VerifyFast.deserializeReadComplexType(lazySimpleDeserializeRead, typeInfo); + if (expectedObject == null) { + if (complexFieldObj != null) { + TestCase.fail("Field reports not null but object is null (class " + complexFieldObj.getClass().getName() + ", " + complexFieldObj.toString() + ")"); + } + } else {
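+        // A null complex result for a non-null expected object is only acceptable
+        // in the union case handled below.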
+ if (complexFieldObj == null) { + // It's hard to distinguish a union with null from a null union. + if (expectedObject instanceof UnionObject) { + UnionObject expectedUnion = (UnionObject) expectedObject; + if (expectedUnion.getObject() == null) { + return; + } + } + TestCase.fail("Field reports null but object is not null (class " + expectedObject.getClass().getName() + ", " + expectedObject.toString() + ")"); + } + } + if (!VerifyLazy.lazyCompare(typeInfo, complexFieldObj, expectedObject)) { + TestCase.fail("Comparison failed typeInfo " + typeInfo.toString()); + } + } + } + private byte[] copyBytes(Text serialized) { byte[] result = new byte[serialized.getLength()]; System.arraycopy(serialized.getBytes(), 0, result, 0, serialized.getLength()); @@ -238,19 +271,25 @@ private LazySimpleSerDe getSerDe(String fieldNames, String fieldTypes) throws Se return serDe; } - private LazySerDeParameters getSerDeParams(String fieldNames, String fieldTypes) throws SerDeException { + private LazySerDeParameters getSerDeParams(String fieldNames, String fieldTypes, + byte[] separators) throws SerDeException { Configuration conf = new Configuration(); Properties tbl = createProperties(fieldNames, fieldTypes); - return new LazySerDeParameters(conf, tbl, LazySimpleSerDe.class.getName()); + LazySerDeParameters lazySerDeParams = new LazySerDeParameters(conf, tbl, LazySimpleSerDe.class.getName()); + for (int i = 0; i < separators.length; i++) { + lazySerDeParams.setSeparator(i, separators[i]); + } + return lazySerDeParams; } - public void testLazySimpleFastCase(int caseNum, boolean doNonRandomFill, Random r) + public void testLazySimpleFastCase( + int caseNum, boolean doNonRandomFill, Random r, SerdeRandomRowSource.SupportedTypes supportedTypes, int depth) throws Throwable { SerdeRandomRowSource source = new SerdeRandomRowSource(); - source.init(r); + source.init(r, supportedTypes, depth); - int rowCount = 1000; + int rowCount = 100; Object[][] rows = source.randomRows(rowCount); if (doNonRandomFill) { @@ -259,8 +298,8 @@ public void testLazySimpleFastCase(int caseNum, boolean doNonRandomFill, Random StructObjectInspector rowStructObjectInspector = source.rowStructObjectInspector(); - PrimitiveTypeInfo[] primitiveTypeInfos = source.primitiveTypeInfos(); - int columnCount = primitiveTypeInfos.length; + TypeInfo[] typeInfos = source.typeInfos(); + int columnCount = typeInfos.length; int writeColumnCount = columnCount; StructObjectInspector writeRowStructObjectInspector = rowStructObjectInspector; @@ -277,8 +316,11 @@ public void testLazySimpleFastCase(int caseNum, boolean doNonRandomFill, Random String fieldNames = ObjectInspectorUtils.getFieldNames(rowStructObjectInspector); String fieldTypes = ObjectInspectorUtils.getFieldTypes(rowStructObjectInspector); + // Use different separator values.
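+    // LazySimpleSerDe consumes one separator per nesting level, so distinct bytes
+    // here keep nested complex fields unambiguous.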
+ byte[] separators = new byte[] {(byte) 9, (byte) 2, (byte) 3, (byte) 4, (byte) 5, (byte) 6, (byte) 7, (byte) 8}; + LazySimpleSerDe serde = getSerDe(fieldNames, fieldTypes); - LazySerDeParameters serdeParams = getSerDeParams(fieldNames, fieldTypes); + LazySerDeParameters serdeParams = getSerDeParams(fieldNames, fieldTypes, separators); LazySimpleSerDe serde_fewer = null; LazySerDeParameters serdeParams_fewer = null; @@ -287,22 +329,22 @@ public void testLazySimpleFastCase(int caseNum, boolean doNonRandomFill, Random String partialFieldTypes = ObjectInspectorUtils.getFieldTypes(writeRowStructObjectInspector); serde_fewer = getSerDe(fieldNames, fieldTypes); - serdeParams_fewer = getSerDeParams(partialFieldNames, partialFieldTypes); + serdeParams_fewer = getSerDeParams(partialFieldNames, partialFieldTypes, separators); } - byte separator = (byte) '\t'; + testLazySimpleFast( source, rows, serde, rowStructObjectInspector, serde_fewer, writeRowStructObjectInspector, - separator, serdeParams, serdeParams_fewer, primitiveTypeInfos, + serdeParams, serdeParams_fewer, typeInfos, /* useIncludeColumns */ false, /* doWriteFewerColumns */ false, r); testLazySimpleFast( source, rows, serde, rowStructObjectInspector, serde_fewer, writeRowStructObjectInspector, - separator, serdeParams, serdeParams_fewer, primitiveTypeInfos, + serdeParams, serdeParams_fewer, typeInfos, /* useIncludeColumns */ true, /* doWriteFewerColumns */ false, r); if (doWriteFewerColumns) { @@ -310,26 +352,26 @@ public void testLazySimpleFastCase(int caseNum, boolean doNonRandomFill, Random source, rows, serde, rowStructObjectInspector, serde_fewer, writeRowStructObjectInspector, - separator, serdeParams, serdeParams_fewer, primitiveTypeInfos, + serdeParams, serdeParams_fewer, typeInfos, /* useIncludeColumns */ false, /* doWriteFewerColumns */ true, r); testLazySimpleFast( source, rows, serde, rowStructObjectInspector, serde_fewer, writeRowStructObjectInspector, - separator, serdeParams, serdeParams_fewer, primitiveTypeInfos, + serdeParams, serdeParams_fewer, typeInfos, /* useIncludeColumns */ true, /* doWriteFewerColumns */ true, r); } } - public void testLazySimpleFast() throws Throwable { + public void testLazySimpleFast(SerdeRandomRowSource.SupportedTypes supportedTypes, int depth) throws Throwable { try { - Random r = new Random(35790); + Random r = new Random(8322); int caseNum = 0; - for (int i = 0; i < 10; i++) { - testLazySimpleFastCase(caseNum, (i % 2 == 0), r); + for (int i = 0; i < 20; i++) { + testLazySimpleFastCase(caseNum, (i % 2 == 0), r, supportedTypes, depth); caseNum++; } @@ -338,4 +380,16 @@ public void testLazySimpleFast() throws Throwable { throw e; } } + + public void testLazySimplePrimitive() throws Throwable { + testLazySimpleFast(SerdeRandomRowSource.SupportedTypes.PRIMITIVE, 0); + } + + public void testLazySimpleComplexDepthOne() throws Throwable { + testLazySimpleFast(SerdeRandomRowSource.SupportedTypes.ALL, 1); + } + + public void testLazySimpleComplexDepthFour() throws Throwable { + testLazySimpleFast(SerdeRandomRowSource.SupportedTypes.ALL, 4); + } } \ No newline at end of file diff --git serde/src/test/org/apache/hadoop/hive/serde2/lazybinary/TestLazyBinaryFast.java serde/src/test/org/apache/hadoop/hive/serde2/lazybinary/TestLazyBinaryFast.java index e62a80a1d6..631650d235 100644 --- serde/src/test/org/apache/hadoop/hive/serde2/lazybinary/TestLazyBinaryFast.java +++ serde/src/test/org/apache/hadoop/hive/serde2/lazybinary/TestLazyBinaryFast.java @@ -17,6 +17,8 @@ */ package
org.apache.hadoop.hive.serde2.lazybinary; +import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; import java.util.Random; @@ -27,11 +29,14 @@ import org.apache.hadoop.hive.serde2.SerdeRandomRowSource; import org.apache.hadoop.hive.serde2.VerifyFast; import org.apache.hadoop.hive.serde2.binarysortable.MyTestClass; +import org.apache.hadoop.hive.serde2.lazy.VerifyLazy; import org.apache.hadoop.hive.serde2.lazybinary.fast.LazyBinaryDeserializeRead; import org.apache.hadoop.hive.serde2.lazybinary.fast.LazyBinarySerializeWrite; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; -import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; +import org.apache.hadoop.hive.serde2.objectinspector.UnionObject; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; import org.apache.hadoop.io.BytesWritable; import org.apache.hadoop.io.Writable; @@ -41,11 +46,11 @@ private void testLazyBinaryFast( SerdeRandomRowSource source, Object[][] rows, AbstractSerDe serde, StructObjectInspector rowOI, AbstractSerDe serde_fewer, StructObjectInspector writeRowOI, - PrimitiveTypeInfo[] primitiveTypeInfos, + TypeInfo[] typeInfos, boolean useIncludeColumns, boolean doWriteFewerColumns, Random r) throws Throwable { int rowCount = rows.length; - int columnCount = primitiveTypeInfos.length; + int columnCount = typeInfos.length; boolean[] columnsToInclude = null; if (useIncludeColumns) { @@ -56,10 +61,10 @@ private void testLazyBinaryFast( } int writeColumnCount = columnCount; - PrimitiveTypeInfo[] writePrimitiveTypeInfos = primitiveTypeInfos; + TypeInfo[] writeTypeInfos = typeInfos; if (doWriteFewerColumns) { writeColumnCount = writeRowOI.getAllStructFieldRefs().size(); - writePrimitiveTypeInfos = Arrays.copyOf(primitiveTypeInfos, writeColumnCount); + writeTypeInfos = Arrays.copyOf(typeInfos, writeColumnCount); } LazyBinarySerializeWrite lazyBinarySerializeWrite = @@ -73,10 +78,7 @@ private void testLazyBinaryFast( lazyBinarySerializeWrite.set(output); for (int index = 0; index < writeColumnCount; index++) { - - Writable writable = (Writable) row[index]; - - VerifyFast.serializeWrite(lazyBinarySerializeWrite, primitiveTypeInfos[index], writable); + VerifyFast.serializeWrite(lazyBinarySerializeWrite, typeInfos[index], row[index]); } BytesWritable bytesWritable = new BytesWritable(); @@ -92,7 +94,7 @@ private void testLazyBinaryFast( // column. LazyBinaryDeserializeRead lazyBinaryDeserializeRead = new LazyBinaryDeserializeRead( - writePrimitiveTypeInfos, + writeTypeInfos, /* useExternalBuffer */ false); BytesWritable bytesWritable = serializeWriteBytes[i]; @@ -103,10 +105,9 @@ private void testLazyBinaryFast( lazyBinaryDeserializeRead.skipNextField(); } else if (index >= writeColumnCount) { // Should come back a null. 
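+          // Columns at or beyond writeColumnCount were never serialized, so the
+          // reader must report null for them.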
- VerifyFast.verifyDeserializeRead(lazyBinaryDeserializeRead, primitiveTypeInfos[index], null); + VerifyFast.verifyDeserializeRead(lazyBinaryDeserializeRead, typeInfos[index], null); } else { - Writable writable = (Writable) row[index]; - VerifyFast.verifyDeserializeRead(lazyBinaryDeserializeRead, primitiveTypeInfos[index], writable); + verifyRead(lazyBinaryDeserializeRead, typeInfos[index], row[index]); } } if (writeColumnCount == columnCount) { @@ -127,15 +128,14 @@ private void testLazyBinaryFast( Object[] row = rows[i]; for (int index = 0; index < writeColumnCount; index++) { - PrimitiveTypeInfo primitiveTypeInfo = primitiveTypeInfos[index]; - Writable writable = (Writable) row[index]; + TypeInfo typeInfo = typeInfos[index]; Object object = lazyBinaryStruct.getField(index); - if (writable == null || object == null) { - if (writable != null || object != null) { + if (row[index] == null || object == null) { + if (row[index] != null || object != null) { fail("SerDe deserialized NULL column mismatch"); } } else { - if (!object.equals(writable)) { + if (!VerifyLazy.lazyCompare(typeInfo, object, row[index])) { fail("SerDe deserialized value does not match"); } } @@ -172,10 +172,10 @@ private void testLazyBinaryFast( if (bytes1.length != bytes2.length) { fail("SerializeWrite length " + bytes2.length + " and " + "SerDe serialization length " + bytes1.length + - " do not match (" + Arrays.toString(primitiveTypeInfos) + ")"); + " do not match (" + Arrays.toString(typeInfos) + ")"); } if (!Arrays.equals(bytes1, bytes2)) { - fail("SerializeWrite and SerDe serialization does not match (" + Arrays.toString(primitiveTypeInfos) + ")"); + fail("SerializeWrite and SerDe serialization does not match (" + Arrays.toString(typeInfos) + ")"); } serdeBytes[i] = bytesWritable; } @@ -187,7 +187,7 @@ private void testLazyBinaryFast( // When doWriteFewerColumns, try to read more fields than exist in buffer. LazyBinaryDeserializeRead lazyBinaryDeserializeRead = new LazyBinaryDeserializeRead( - primitiveTypeInfos, + typeInfos, /* useExternalBuffer */ false); BytesWritable bytesWritable = serdeBytes[i]; @@ -198,10 +198,9 @@ private void testLazyBinaryFast( lazyBinaryDeserializeRead.skipNextField(); } else if (index >= writeColumnCount) { // Should come back a null. 
- VerifyFast.verifyDeserializeRead(lazyBinaryDeserializeRead, primitiveTypeInfos[index], null); + VerifyFast.verifyDeserializeRead(lazyBinaryDeserializeRead, typeInfos[index], null); } else { - Writable writable = (Writable) row[index]; - VerifyFast.verifyDeserializeRead(lazyBinaryDeserializeRead, primitiveTypeInfos[index], writable); + verifyRead(lazyBinaryDeserializeRead, typeInfos[index], row[index]); } } if (writeColumnCount == columnCount) { @@ -210,12 +209,48 @@ private void testLazyBinaryFast( } } - public void testLazyBinaryFastCase(int caseNum, boolean doNonRandomFill, Random r) throws Throwable { + private void verifyRead(LazyBinaryDeserializeRead lazyBinaryDeserializeRead, + TypeInfo typeInfo, Object expectedObject) throws IOException { + if (typeInfo.getCategory() == ObjectInspector.Category.PRIMITIVE) { + VerifyFast.verifyDeserializeRead(lazyBinaryDeserializeRead, typeInfo, expectedObject); + } else { + Object complexFieldObj = VerifyFast.deserializeReadComplexType(lazyBinaryDeserializeRead, typeInfo); + if (expectedObject == null) { + if (complexFieldObj != null) { + TestCase.fail("Field reports not null but object is null (class " + complexFieldObj.getClass().getName() + + ", " + complexFieldObj.toString() + ")"); + } + } else { + if (complexFieldObj == null) { + // It's hard to distinguish a union with null from a null union. + if (expectedObject instanceof UnionObject) { + UnionObject expectedUnion = (UnionObject) expectedObject; + if (expectedUnion.getObject() == null) { + return; + } + } + TestCase.fail("Field reports null but object is not null (class " + expectedObject.getClass().getName() + + ", " + expectedObject.toString() + ")"); + } + } + if (!VerifyLazy.lazyCompare(typeInfo, complexFieldObj, expectedObject)) { + TestCase.fail("Comparison failed typeInfo " + typeInfo.toString()); + } + } + } + + public void testLazyBinaryFastCase( + int caseNum, boolean doNonRandomFill, Random r, SerdeRandomRowSource.SupportedTypes supportedTypes, int depth) + throws Throwable { SerdeRandomRowSource source = new SerdeRandomRowSource(); - source.init(r); - int rowCount = 1000; + source.init(r, supportedTypes, depth); + + int rowCount = 100; Object[][] rows = source.randomRows(rowCount); if (doNonRandomFill) { @@ -224,8 +259,8 @@ public void testLazyBinaryFastCase(int caseNum, boolean doNonRandomFill, Random StructObjectInspector rowStructObjectInspector = source.rowStructObjectInspector(); - PrimitiveTypeInfo[] primitiveTypeInfos = source.primitiveTypeInfos(); - int columnCount = primitiveTypeInfos.length; + TypeInfo[] typeInfos = source.typeInfos(); + int columnCount = typeInfos.length; int writeColumnCount = columnCount; StructObjectInspector writeRowStructObjectInspector = rowStructObjectInspector; @@ -256,14 +291,14 @@ public void testLazyBinaryFastCase(int caseNum, boolean doNonRandomFill, Random source, rows, serde, rowStructObjectInspector, serde_fewer, writeRowStructObjectInspector, - primitiveTypeInfos, + typeInfos, /* useIncludeColumns */ false, /* doWriteFewerColumns */ false, r); testLazyBinaryFast( source, rows, serde, rowStructObjectInspector, serde_fewer, writeRowStructObjectInspector, - primitiveTypeInfos, + typeInfos, /* useIncludeColumns */ true, /* doWriteFewerColumns */ false, r); /* @@ -286,14 +321,13 @@ public void testLazyBinaryFastCase(int caseNum, boolean doNonRandomFill, Random // } } - public void testLazyBinaryFast() throws Throwable { - + private
void testLazyBinaryFast(SerdeRandomRowSource.SupportedTypes supportedTypes, int depth) throws Throwable { try { - Random r = new Random(35790); + Random r = new Random(9983); int caseNum = 0; for (int i = 0; i < 10; i++) { - testLazyBinaryFastCase(caseNum, (i % 2 == 0), r); + testLazyBinaryFastCase(caseNum, (i % 2 == 0), r, supportedTypes, depth); caseNum++; } @@ -302,4 +336,16 @@ public void testLazyBinaryFast() throws Throwable { throw e; } } + + public void testLazyBinaryFastPrimitive() throws Throwable { + testLazyBinaryFast(SerdeRandomRowSource.SupportedTypes.PRIMITIVE, 0); + } + + public void testLazyBinaryFastComplexDepthOne() throws Throwable { + testLazyBinaryFast(SerdeRandomRowSource.SupportedTypes.ALL, 1); + } + + public void testLazyBinaryFastComplexDepthFour() throws Throwable { + testLazyBinaryFast(SerdeRandomRowSource.SupportedTypes.ALL, 4); + } } \ No newline at end of file
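
Reviewer note (addendum, not part of the patch): a minimal sketch of how the reworked helpers compose end to end for a complex type. It assumes only the signatures introduced above, VerifyFast.serializeWrite(SerializeWrite, TypeInfo, Object) and VerifyFast.deserializeReadComplexType(DeserializeRead, TypeInfo), plus the single-field LazyBinarySerializeWrite constructor used by the tests; the class name and the array<int> values are illustrative.

import java.util.ArrayList;
import org.apache.hadoop.hive.serde2.ByteStream.Output;
import org.apache.hadoop.hive.serde2.VerifyFast;
import org.apache.hadoop.hive.serde2.lazybinary.fast.LazyBinaryDeserializeRead;
import org.apache.hadoop.hive.serde2.lazybinary.fast.LazyBinarySerializeWrite;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
import org.apache.hadoop.io.IntWritable;

public class ComplexRoundTripSketch {
  public static void main(String[] args) throws Exception {
    // A single column of type array<int> (hypothetical example data).
    TypeInfo listTypeInfo = TypeInfoUtils.getTypeInfoFromTypeString("array<int>");
    ArrayList<Object> list = new ArrayList<>();
    list.add(new IntWritable(1));
    list.add(null); // complex values may carry null elements
    list.add(new IntWritable(3));

    // Serialize the one-column row with the fast LazyBinary writer.
    Output output = new Output();
    LazyBinarySerializeWrite serializeWrite = new LazyBinarySerializeWrite(1);
    serializeWrite.set(output);
    VerifyFast.serializeWrite(serializeWrite, listTypeInfo, list);

    // Read it back into standard objects (an ArrayList of IntWritable).
    LazyBinaryDeserializeRead deserializeRead = new LazyBinaryDeserializeRead(
        new TypeInfo[] {listTypeInfo}, /* useExternalBuffer */ false);
    deserializeRead.set(output.getData(), 0, output.getLength());
    Object result = VerifyFast.deserializeReadComplexType(deserializeRead, listTypeInfo);
    System.out.println(result); // expected: [1, null, 3]
  }
}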