diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorAssignRow.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorAssignRow.java index 9c84937..9621483 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorAssignRow.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorAssignRow.java @@ -18,12 +18,7 @@ package org.apache.hadoop.hive.ql.exec.vector; -import java.sql.Date; -import java.sql.Timestamp; -import java.util.List; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import com.google.common.base.Preconditions; import org.apache.hadoop.hive.common.type.HiveChar; import org.apache.hadoop.hive.common.type.HiveDecimal; import org.apache.hadoop.hive.common.type.HiveIntervalDayTime; @@ -41,18 +36,27 @@ import org.apache.hadoop.hive.serde2.io.HiveVarcharWritable; import org.apache.hadoop.hive.serde2.io.ShortWritable; import org.apache.hadoop.hive.serde2.io.TimestampWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; +import org.apache.hadoop.hive.serde2.objectinspector.StandardUnionObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.StructField; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.UnionObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; import org.apache.hadoop.hive.serde2.typeinfo.CharTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils; +import org.apache.hadoop.hive.serde2.typeinfo.UnionTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo; import org.apache.hadoop.io.BooleanWritable; import org.apache.hadoop.io.BytesWritable; @@ -61,8 +65,15 @@ import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.io.Writable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; -import com.google.common.base.Preconditions; +import java.sql.Date; +import java.sql.Timestamp; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; /** @@ -89,11 +100,8 @@ // Assigning can be a subset of columns, so this is the projection -- // the batch column numbers. - Category[] targetCategories; - // The data type category of each column being assigned. - - PrimitiveCategory[] targetPrimitiveCategories; - // The data type primitive category of each column being assigned. + TypeInfo[] targetTypeInfos; + // The type information of each column being assigned. 
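The consolidation above works because both of the removed arrays are recoverable from a TypeInfo on demand; a minimal sketch of the derivation (illustration only, using names from this patch):

    TypeInfo typeInfo = targetTypeInfos[logicalColumnIndex];
    Category category = typeInfo.getCategory();
    if (category == Category.PRIMITIVE) {
      PrimitiveCategory primitiveCategory =
          ((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory();
    }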
  int[] maxLengths;
  // For the CHAR and VARCHAR data types, the maximum character length of
@@ -103,11 +111,11 @@
   * These members have information for data type conversion.
   * Not defined if there is no conversion.
   */
-  PrimitiveObjectInspector[] convertSourcePrimitiveObjectInspectors;
+  Map<Integer, ObjectInspector> objectInspectors;
  // The object inspector of the source data type for any column being
  // converted, keyed by logical column index. Otherwise, no entry.

-  Writable[] convertTargetWritables;
+  Map<ColumnVector, Writable> convertTargetWritables;
  // Conversion to the target data type requires a "helper" target writable in a
  // few cases.

@@ -117,8 +125,7 @@
  private void allocateArrays(int count) {
    isConvert = new boolean[count];
    projectionColumnNums = new int[count];
-    targetCategories = new Category[count];
-    targetPrimitiveCategories = new PrimitiveCategory[count];
+    targetTypeInfos = new TypeInfo[count];
    maxLengths = new int[count];
  }

@@ -126,8 +133,8 @@ private void allocateArrays(int count) {
  /*
   * Allocate the source conversion related arrays (optional).
   */
  private void allocateConvertArrays(int count) {
-    convertSourcePrimitiveObjectInspectors = new PrimitiveObjectInspector[count];
-    convertTargetWritables = new Writable[count];
+    objectInspectors = new HashMap<Integer, ObjectInspector>(count);
+    convertTargetWritables = new HashMap<ColumnVector, Writable>();
  }

  /*
@@ -137,11 +144,11 @@ private void initTargetEntry(int logicalColumnIndex, int projectionColumnNum, Ty
    isConvert[logicalColumnIndex] = false;
    projectionColumnNums[logicalColumnIndex] = projectionColumnNum;
    Category category = typeInfo.getCategory();
-    targetCategories[logicalColumnIndex] = category;
+    targetTypeInfos[logicalColumnIndex] = typeInfo;
+
    if (category == Category.PRIMITIVE) {
      PrimitiveTypeInfo primitiveTypeInfo = (PrimitiveTypeInfo) typeInfo;
      PrimitiveCategory primitiveCategory = primitiveTypeInfo.getPrimitiveCategory();
-      targetPrimitiveCategories[logicalColumnIndex] = primitiveCategory;
      switch (primitiveCategory) {
      case CHAR:
        maxLengths[logicalColumnIndex] = ((CharTypeInfo) primitiveTypeInfo).getLength();
@@ -162,27 +169,51 @@ private void initTargetEntry(int logicalColumnIndex, int projectionColumnNum, Ty
   */
  private void initConvertSourceEntry(int logicalColumnIndex, TypeInfo convertSourceTypeInfo) {
    isConvert[logicalColumnIndex] = true;
-    Category convertSourceCategory = convertSourceTypeInfo.getCategory();
-    if (convertSourceCategory == Category.PRIMITIVE) {
-      PrimitiveTypeInfo convertSourcePrimitiveTypeInfo = (PrimitiveTypeInfo) convertSourceTypeInfo;
-      convertSourcePrimitiveObjectInspectors[logicalColumnIndex] =
-          PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(
-              convertSourcePrimitiveTypeInfo);
-
-      // These need to be based on the target.
-      PrimitiveCategory targetPrimitiveCategory = targetPrimitiveCategories[logicalColumnIndex];
-      switch (targetPrimitiveCategory) {
-      case DATE:
-        convertTargetWritables[logicalColumnIndex] = new DateWritable();
-        break;
-      case STRING:
-        convertTargetWritables[logicalColumnIndex] = new Text();
-        break;
-      default:
-        // No additional data type specific setting.
-        break;
-      }
-    }
+    objectInspectors.put(logicalColumnIndex,
+        createObjectInspector(convertSourceTypeInfo));
  }
+
+  private ObjectInspector createObjectInspector(TypeInfo typeInfo) {
+    switch (typeInfo.getCategory()) {
+    case PRIMITIVE:
+      return PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(
+          (PrimitiveTypeInfo) typeInfo);
+    case LIST:
+      return ObjectInspectorFactory.getStandardListObjectInspector(
+          createObjectInspector(((ListTypeInfo) typeInfo).getListElementTypeInfo()));
+    case MAP:
+      {
+        MapTypeInfo mapTypeInfo = (MapTypeInfo) typeInfo;
+        ObjectInspector keyObjectInspector =
+            createObjectInspector(mapTypeInfo.getMapKeyTypeInfo());
+        ObjectInspector valueObjectInspector =
+            createObjectInspector(mapTypeInfo.getMapValueTypeInfo());
+        return ObjectInspectorFactory.getStandardMapObjectInspector(
+            keyObjectInspector, valueObjectInspector);
+      }
+    case STRUCT:
+      {
+        StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo;
+        List<ObjectInspector> fieldObjectInspectors = new ArrayList<ObjectInspector>();
+        for (TypeInfo fieldTypeInfo : structTypeInfo.getAllStructFieldTypeInfos()) {
+          fieldObjectInspectors.add(createObjectInspector(fieldTypeInfo));
+        }
+        return ObjectInspectorFactory.getStandardStructObjectInspector(
+            structTypeInfo.getAllStructFieldNames(), fieldObjectInspectors);
+      }
+    case UNION:
+      {
+        List<TypeInfo> unionObjectTypeInfos =
+            ((UnionTypeInfo) typeInfo).getAllUnionObjectTypeInfos();
+        List<ObjectInspector> fieldObjectInspectors =
+            new ArrayList<ObjectInspector>(unionObjectTypeInfos.size());
+        for (TypeInfo unionObjectTypeInfo : unionObjectTypeInfos) {
+          fieldObjectInspectors.add(createObjectInspector(unionObjectTypeInfo));
+        }
+        return ObjectInspectorFactory.getStandardUnionObjectInspector(fieldObjectInspectors);
+      }
+    default:
+      throw new RuntimeException("Category " + typeInfo.getCategory().name() + " not supported");
+    }
+  }

  /*
@@ -335,7 +366,8 @@ public int initConversion(TypeInfo[] sourceTypeInfos, TypeInfo[] targetTypeInfos
   */
  public void assignRowColumn(VectorizedRowBatch batch, int batchIndex, int logicalColumnIndex,
      Object object) {
-    Category targetCategory = targetCategories[logicalColumnIndex];
-    if (targetCategory == null) {
+    TypeInfo logicalTypeInfo = targetTypeInfos[logicalColumnIndex];
+    if (logicalTypeInfo == null) {
      /*
       * This is a column that we don't want (i.e. not included) -- we are done.
       */
      return;
    }
    final int projectionColumnNum = projectionColumnNums[logicalColumnIndex];
+    ColumnVector projectionColumnVector = batch.cols[projectionColumnNum];
+    assignRowColumn(batchIndex, logicalTypeInfo, projectionColumnVector, object);
+  }
+
+  public void assignRowColumn(int batchIndex, TypeInfo logicalTypeInfo,
+      ColumnVector projectionColumnVector, Object object) {
    if (object == null) {
-      VectorizedBatchUtil.setNullColIsNullValue(batch.cols[projectionColumnNum], batchIndex);
+      VectorizedBatchUtil.setNullColIsNullValue(projectionColumnVector, batchIndex);
      return;
    }
+    Category targetCategory = logicalTypeInfo.getCategory();
    switch (targetCategory) {
    case PRIMITIVE:
      {
-        PrimitiveCategory targetPrimitiveCategory = targetPrimitiveCategories[logicalColumnIndex];
+        PrimitiveCategory targetPrimitiveCategory =
+            ((PrimitiveTypeInfo) logicalTypeInfo).getPrimitiveCategory();
        switch (targetPrimitiveCategory) {
        case VOID:
-          VectorizedBatchUtil.setNullColIsNullValue(batch.cols[projectionColumnNum], batchIndex);
+          VectorizedBatchUtil.setNullColIsNullValue(projectionColumnVector, batchIndex);
          return;
        case BOOLEAN:
-          ((LongColumnVector) batch.cols[projectionColumnNum]).vector[batchIndex] =
-              (((BooleanWritable) object).get() ?
1 : 0); + ((LongColumnVector) projectionColumnVector).vector[batchIndex] = (((BooleanWritable) object).get() ? 1 : 0); break; case BYTE: - ((LongColumnVector) batch.cols[projectionColumnNum]).vector[batchIndex] = - ((ByteWritable) object).get(); + ((LongColumnVector) projectionColumnVector).vector[batchIndex] = ((ByteWritable) object).get(); break; case SHORT: - ((LongColumnVector) batch.cols[projectionColumnNum]).vector[batchIndex] = - ((ShortWritable) object).get(); + ((LongColumnVector) projectionColumnVector).vector[batchIndex] = ((ShortWritable) object).get(); break; case INT: - ((LongColumnVector) batch.cols[projectionColumnNum]).vector[batchIndex] = - ((IntWritable) object).get(); + ((LongColumnVector) projectionColumnVector).vector[batchIndex] = ((IntWritable) object).get(); break; case LONG: - ((LongColumnVector) batch.cols[projectionColumnNum]).vector[batchIndex] = - ((LongWritable) object).get(); + ((LongColumnVector) projectionColumnVector).vector[batchIndex] = ((LongWritable) object).get(); break; case TIMESTAMP: - ((TimestampColumnVector) batch.cols[projectionColumnNum]).set( - batchIndex, ((TimestampWritable) object).getTimestamp()); + ((TimestampColumnVector) projectionColumnVector).set(batchIndex, ((TimestampWritable) object).getTimestamp()); break; case DATE: - ((LongColumnVector) batch.cols[projectionColumnNum]).vector[batchIndex] = - ((DateWritable) object).getDays(); + ((LongColumnVector) projectionColumnVector).vector[batchIndex] = ((DateWritable) object).getDays(); break; case FLOAT: - ((DoubleColumnVector) batch.cols[projectionColumnNum]).vector[batchIndex] = - ((FloatWritable) object).get(); + ((DoubleColumnVector) projectionColumnVector).vector[batchIndex] = ((FloatWritable) object).get(); break; case DOUBLE: - ((DoubleColumnVector) batch.cols[projectionColumnNum]).vector[batchIndex] = - ((DoubleWritable) object).get(); + ((DoubleColumnVector) projectionColumnVector).vector[batchIndex] = ((DoubleWritable) object).get(); break; case BINARY: { BytesWritable bw = (BytesWritable) object; - ((BytesColumnVector) batch.cols[projectionColumnNum]).setVal( + ((BytesColumnVector) projectionColumnVector).setVal( batchIndex, bw.getBytes(), 0, bw.getLength()); } break; case STRING: { Text tw = (Text) object; - ((BytesColumnVector) batch.cols[projectionColumnNum]).setVal( - batchIndex, tw.getBytes(), 0, tw.getLength()); + ((BytesColumnVector) projectionColumnVector).setVal(batchIndex, tw.getBytes(), 0, tw.getLength()); } break; case VARCHAR: @@ -420,8 +449,7 @@ public void assignRowColumn(VectorizedRowBatch batch, int batchIndex, int logica // TODO: HIVE-13624 Do we need maxLength checking? byte[] bytes = hiveVarchar.getValue().getBytes(); - ((BytesColumnVector) batch.cols[projectionColumnNum]).setVal( - batchIndex, bytes, 0, bytes.length); + ((BytesColumnVector) projectionColumnVector).setVal(batchIndex, bytes, 0, bytes.length); } break; case CHAR: @@ -440,33 +468,102 @@ public void assignRowColumn(VectorizedRowBatch batch, int batchIndex, int logica // We store CHAR in vector row batch with padding stripped. 
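          // (Illustrative note, not part of the patch: a CHAR(4) value "ab" is
          // padded to "ab  " in its HiveChar form; getStrippedValue() below
          // returns "ab", so only the two significant bytes are written into
          // the BytesColumnVector.)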
byte[] bytes = hiveChar.getStrippedValue().getBytes(); - ((BytesColumnVector) batch.cols[projectionColumnNum]).setVal( - batchIndex, bytes, 0, bytes.length); + ((BytesColumnVector) projectionColumnVector).setVal(batchIndex, bytes, 0, bytes.length); } break; case DECIMAL: if (object instanceof HiveDecimal) { - ((DecimalColumnVector) batch.cols[projectionColumnNum]).set( - batchIndex, (HiveDecimal) object); + ((DecimalColumnVector) projectionColumnVector).set(batchIndex, (HiveDecimal) object); } else { - ((DecimalColumnVector) batch.cols[projectionColumnNum]).set( - batchIndex, (HiveDecimalWritable) object); + ((DecimalColumnVector) projectionColumnVector).set(batchIndex, (HiveDecimalWritable) object); } break; case INTERVAL_YEAR_MONTH: - ((LongColumnVector) batch.cols[projectionColumnNum]).vector[batchIndex] = + ((LongColumnVector) projectionColumnVector).vector[batchIndex] = ((HiveIntervalYearMonthWritable) object).getHiveIntervalYearMonth().getTotalMonths(); break; case INTERVAL_DAY_TIME: - ((IntervalDayTimeColumnVector) batch.cols[projectionColumnNum]).set( + ((IntervalDayTimeColumnVector) projectionColumnVector).set( batchIndex, ((HiveIntervalDayTimeWritable) object).getHiveIntervalDayTime()); break; default: throw new RuntimeException("Primitive category " + targetPrimitiveCategory.name() + " not supported"); + } + break; + } + case LIST: + { + ListColumnVector listColumnVector = (ListColumnVector) projectionColumnVector; + ColumnVector childColumnVector = listColumnVector.child; + TypeInfo elementTypeInfo = ((ListTypeInfo) logicalTypeInfo).getListElementTypeInfo(); + + List list = (List) object; + int size = list.size(); + int offset = listColumnVector.childCount; + listColumnVector.offsets[batchIndex] = offset; + listColumnVector.lengths[batchIndex] = size; + listColumnVector.childCount += size; + childColumnVector.ensureSize(offset + size, true); + + for (int i = 0; i < size; i++) { + assignRowColumn(offset + i, elementTypeInfo, childColumnVector, list.get(i)); + } + } + break; + case MAP: + { + MapColumnVector mapColumnVector = (MapColumnVector) projectionColumnVector; + ColumnVector keyColumnVector = mapColumnVector.keys; + ColumnVector valueColumnVector = mapColumnVector.values; + MapTypeInfo mapTypeInfo = (MapTypeInfo) logicalTypeInfo; + TypeInfo keyTypeInfo = mapTypeInfo.getMapKeyTypeInfo(); + TypeInfo valueTypeInfo = mapTypeInfo.getMapValueTypeInfo(); + + Map map = (Map) object; + List entries = new ArrayList(map.entrySet()); + int size = map.size(); + int offset = mapColumnVector.childCount; + mapColumnVector.offsets[batchIndex] = offset; + mapColumnVector.lengths[batchIndex] = size; + mapColumnVector.childCount += size; + keyColumnVector.ensureSize(offset + size, true); + valueColumnVector.ensureSize(offset + size, true); + + for (int i = 0; i < size; i++) { + Map.Entry entry = entries.get(i); + assignRowColumn(offset + i, keyTypeInfo, keyColumnVector, entry.getKey()); + assignRowColumn(offset + i, valueTypeInfo, valueColumnVector, entry.getValue()); + } + } + break; + case STRUCT: + { + StructColumnVector structColumnVector = (StructColumnVector) projectionColumnVector; + ColumnVector[] fieldColumnVectors = structColumnVector.fields; + StructTypeInfo structTypeInfo = (StructTypeInfo) logicalTypeInfo; + List typeInfos = structTypeInfo.getAllStructFieldTypeInfos(); + + List list = (List) object; + int size = list.size(); + + for (int i = 0; i < size; i++) { + assignRowColumn(batchIndex, typeInfos.get(i), fieldColumnVectors[i], list.get(i)); } } break; + case UNION: + { + 
UnionColumnVector unionColumnVector = (UnionColumnVector) projectionColumnVector;
+        ColumnVector[] fieldColumnVectors = unionColumnVector.fields;
+        StandardUnionObjectInspector.StandardUnion union =
+            (StandardUnionObjectInspector.StandardUnion) object;
+        byte tag = union.getTag();
+        UnionTypeInfo unionTypeInfo = (UnionTypeInfo) logicalTypeInfo;
+        unionColumnVector.tags[batchIndex] = tag;
+        assignRowColumn(batchIndex, unionTypeInfo.getAllUnionObjectTypeInfos().get(tag),
+            fieldColumnVectors[tag], union.getObject());
+      }
+      break;
    default:
      throw new RuntimeException("Category " + targetCategory.name() + " not supported");
    }
@@ -474,7 +571,7 @@ public void assignRowColumn(VectorizedRowBatch batch, int batchIndex, int logica
    /*
     * We always set the null flag to false when there is a value.
     */
-    batch.cols[projectionColumnNum].isNull[batchIndex] = false;
+    projectionColumnVector.isNull[batchIndex] = false;
  }

  /**
@@ -493,112 +590,118 @@ public void assignRowColumn(VectorizedRowBatch batch, int batchIndex, int logica
  public void assignConvertRowColumn(VectorizedRowBatch batch, int batchIndex,
      int logicalColumnIndex, Object object) {
    Preconditions.checkState(isConvert[logicalColumnIndex]);
-    Category targetCategory = targetCategories[logicalColumnIndex];
+
+    final int projectionColumnNum = projectionColumnNums[logicalColumnIndex];
+    ColumnVector projectionColumnVector = batch.cols[projectionColumnNum];
+    ObjectInspector logicalObjectInspector = objectInspectors.get(logicalColumnIndex);
+
+    TypeInfo typeInfo = targetTypeInfos[logicalColumnIndex];
-    if (targetCategory == null) {
+    if (typeInfo == null) {
      /*
       * This is a column that we don't want (i.e. not included) -- we are done.
       */
      return;
    }
-    final int projectionColumnNum = projectionColumnNums[logicalColumnIndex];
    if (object == null) {
-      VectorizedBatchUtil.setNullColIsNullValue(batch.cols[projectionColumnNum], batchIndex);
+      VectorizedBatchUtil.setNullColIsNullValue(projectionColumnVector, batchIndex);
      return;
    }
+    assignConvertRowColumn(batchIndex, typeInfo, logicalObjectInspector,
+        projectionColumnVector, object);
+  }
+
+  public void assignConvertRowColumn(int batchIndex, TypeInfo typeInfo,
+      ObjectInspector objectInspector, ColumnVector columnVector, Object object) {
+    Category targetCategory = typeInfo.getCategory();
    try {
      switch (targetCategory) {
      case PRIMITIVE:
-        PrimitiveCategory targetPrimitiveCategory = targetPrimitiveCategories[logicalColumnIndex];
+        PrimitiveCategory targetPrimitiveCategory =
+            ((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory();
+        PrimitiveObjectInspector primitiveLogicalOI = (PrimitiveObjectInspector) objectInspector;
        switch (targetPrimitiveCategory) {
        case VOID:
-          VectorizedBatchUtil.setNullColIsNullValue(batch.cols[projectionColumnNum], batchIndex);
+          VectorizedBatchUtil.setNullColIsNullValue(columnVector, batchIndex);
          return;
        case BOOLEAN:
-          ((LongColumnVector) batch.cols[projectionColumnNum]).vector[batchIndex] =
-              (PrimitiveObjectInspectorUtils.getBoolean(
-                  object, convertSourcePrimitiveObjectInspectors[logicalColumnIndex]) ? 1 : 0);
+          ((LongColumnVector) columnVector).vector[batchIndex] =
+              (PrimitiveObjectInspectorUtils.getBoolean(object, primitiveLogicalOI) ?
1 : 0); break; case BYTE: - ((LongColumnVector) batch.cols[projectionColumnNum]).vector[batchIndex] = - PrimitiveObjectInspectorUtils.getByte( - object, convertSourcePrimitiveObjectInspectors[logicalColumnIndex]); + ((LongColumnVector) columnVector).vector[batchIndex] = + PrimitiveObjectInspectorUtils.getByte(object, primitiveLogicalOI); break; case SHORT: - ((LongColumnVector) batch.cols[projectionColumnNum]).vector[batchIndex] = - PrimitiveObjectInspectorUtils.getShort( - object, convertSourcePrimitiveObjectInspectors[logicalColumnIndex]); + ((LongColumnVector) columnVector).vector[batchIndex] = + PrimitiveObjectInspectorUtils.getShort(object, primitiveLogicalOI); break; case INT: - ((LongColumnVector) batch.cols[projectionColumnNum]).vector[batchIndex] = - PrimitiveObjectInspectorUtils.getInt( - object, convertSourcePrimitiveObjectInspectors[logicalColumnIndex]); + ((LongColumnVector) columnVector).vector[batchIndex] = + PrimitiveObjectInspectorUtils.getInt(object, primitiveLogicalOI); break; case LONG: - ((LongColumnVector) batch.cols[projectionColumnNum]).vector[batchIndex] = - PrimitiveObjectInspectorUtils.getLong( - object, convertSourcePrimitiveObjectInspectors[logicalColumnIndex]); + ((LongColumnVector) columnVector).vector[batchIndex] = + PrimitiveObjectInspectorUtils.getLong(object, primitiveLogicalOI); break; case TIMESTAMP: { Timestamp timestamp = - PrimitiveObjectInspectorUtils.getTimestamp( - object, convertSourcePrimitiveObjectInspectors[logicalColumnIndex]); + PrimitiveObjectInspectorUtils.getTimestamp(object, primitiveLogicalOI); if (timestamp == null) { - VectorizedBatchUtil.setNullColIsNullValue(batch.cols[projectionColumnNum], batchIndex); + VectorizedBatchUtil.setNullColIsNullValue(columnVector, batchIndex); return; } - ((TimestampColumnVector) batch.cols[projectionColumnNum]).set( - batchIndex, timestamp); + ((TimestampColumnVector) columnVector).set(batchIndex, timestamp); } break; case DATE: { - Date date = PrimitiveObjectInspectorUtils.getDate( - object, convertSourcePrimitiveObjectInspectors[logicalColumnIndex]); + Date date = PrimitiveObjectInspectorUtils.getDate(object, primitiveLogicalOI); if (date == null) { - VectorizedBatchUtil.setNullColIsNullValue(batch.cols[projectionColumnNum], batchIndex); + VectorizedBatchUtil.setNullColIsNullValue(columnVector, batchIndex); return; } - DateWritable dateWritable = (DateWritable) convertTargetWritables[logicalColumnIndex]; + if (!convertTargetWritables.containsKey(columnVector)) { + convertTargetWritables.put(columnVector, new DateWritable()); + } + DateWritable dateWritable = (DateWritable) convertTargetWritables.get(columnVector); dateWritable.set(date); - ((LongColumnVector) batch.cols[projectionColumnNum]).vector[batchIndex] = + ((LongColumnVector) columnVector).vector[batchIndex] = dateWritable.getDays(); } break; case FLOAT: - ((DoubleColumnVector) batch.cols[projectionColumnNum]).vector[batchIndex] = - PrimitiveObjectInspectorUtils.getFloat( - object, convertSourcePrimitiveObjectInspectors[logicalColumnIndex]); + ((DoubleColumnVector) columnVector).vector[batchIndex] = + PrimitiveObjectInspectorUtils.getFloat(object, primitiveLogicalOI); break; case DOUBLE: - ((DoubleColumnVector) batch.cols[projectionColumnNum]).vector[batchIndex] = - PrimitiveObjectInspectorUtils.getDouble( - object, convertSourcePrimitiveObjectInspectors[logicalColumnIndex]); + ((DoubleColumnVector) columnVector).vector[batchIndex] = + PrimitiveObjectInspectorUtils.getDouble(object, primitiveLogicalOI); break; case BINARY: { BytesWritable 
bytesWritable = - PrimitiveObjectInspectorUtils.getBinary( - object, convertSourcePrimitiveObjectInspectors[logicalColumnIndex]); + PrimitiveObjectInspectorUtils.getBinary(object, primitiveLogicalOI); if (bytesWritable == null) { - VectorizedBatchUtil.setNullColIsNullValue(batch.cols[projectionColumnNum], batchIndex); + VectorizedBatchUtil.setNullColIsNullValue(columnVector, batchIndex); return; } - ((BytesColumnVector) batch.cols[projectionColumnNum]).setVal( + ((BytesColumnVector) columnVector).setVal( batchIndex, bytesWritable.getBytes(), 0, bytesWritable.getLength()); } break; case STRING: { - String string = PrimitiveObjectInspectorUtils.getString( - object, convertSourcePrimitiveObjectInspectors[logicalColumnIndex]); + String string = PrimitiveObjectInspectorUtils.getString(object, primitiveLogicalOI); if (string == null) { - VectorizedBatchUtil.setNullColIsNullValue(batch.cols[projectionColumnNum], batchIndex); + VectorizedBatchUtil.setNullColIsNullValue(columnVector, batchIndex); return; } - Text text = (Text) convertTargetWritables[logicalColumnIndex]; + if (!convertTargetWritables.containsKey(columnVector)) { + convertTargetWritables.put(columnVector, new Text()); + } + Text text = (Text) convertTargetWritables.get(columnVector); text.set(string); - ((BytesColumnVector) batch.cols[projectionColumnNum]).setVal( + ((BytesColumnVector) columnVector).setVal( batchIndex, text.getBytes(), 0, text.getLength()); } break; @@ -607,18 +710,16 @@ public void assignConvertRowColumn(VectorizedRowBatch batch, int batchIndex, // UNDONE: Performance problem with conversion to String, then bytes... HiveVarchar hiveVarchar = - PrimitiveObjectInspectorUtils.getHiveVarchar( - object, convertSourcePrimitiveObjectInspectors[logicalColumnIndex]); + PrimitiveObjectInspectorUtils.getHiveVarchar(object, primitiveLogicalOI); if (hiveVarchar == null) { - VectorizedBatchUtil.setNullColIsNullValue(batch.cols[projectionColumnNum], batchIndex); + VectorizedBatchUtil.setNullColIsNullValue(columnVector, batchIndex); return; } // TODO: Do we need maxLength checking? byte[] bytes = hiveVarchar.getValue().getBytes(); - ((BytesColumnVector) batch.cols[projectionColumnNum]).setVal( - batchIndex, bytes, 0, bytes.length); + ((BytesColumnVector) columnVector).setVal(batchIndex, bytes, 0, bytes.length); } break; case CHAR: @@ -626,10 +727,9 @@ public void assignConvertRowColumn(VectorizedRowBatch batch, int batchIndex, // UNDONE: Performance problem with conversion to String, then bytes... HiveChar hiveChar = - PrimitiveObjectInspectorUtils.getHiveChar( - object, convertSourcePrimitiveObjectInspectors[logicalColumnIndex]); + PrimitiveObjectInspectorUtils.getHiveChar(object, primitiveLogicalOI); if (hiveChar == null) { - VectorizedBatchUtil.setNullColIsNullValue(batch.cols[projectionColumnNum], batchIndex); + VectorizedBatchUtil.setNullColIsNullValue(columnVector, batchIndex); return; } // We store CHAR in vector row batch with padding stripped. @@ -637,47 +737,42 @@ public void assignConvertRowColumn(VectorizedRowBatch batch, int batchIndex, // TODO: Do we need maxLength checking? 
byte[] bytes = hiveChar.getStrippedValue().getBytes(); - ((BytesColumnVector) batch.cols[projectionColumnNum]).setVal( + ((BytesColumnVector) columnVector).setVal( batchIndex, bytes, 0, bytes.length); } break; case DECIMAL: { HiveDecimal hiveDecimal = - PrimitiveObjectInspectorUtils.getHiveDecimal( - object, convertSourcePrimitiveObjectInspectors[logicalColumnIndex]); + PrimitiveObjectInspectorUtils.getHiveDecimal(object, primitiveLogicalOI); if (hiveDecimal == null) { - VectorizedBatchUtil.setNullColIsNullValue(batch.cols[projectionColumnNum], batchIndex); + VectorizedBatchUtil.setNullColIsNullValue(columnVector, batchIndex); return; } - ((DecimalColumnVector) batch.cols[projectionColumnNum]).set( - batchIndex, hiveDecimal); + ((DecimalColumnVector) columnVector).set(batchIndex, hiveDecimal); } break; case INTERVAL_YEAR_MONTH: { HiveIntervalYearMonth intervalYearMonth = - PrimitiveObjectInspectorUtils.getHiveIntervalYearMonth( - object, convertSourcePrimitiveObjectInspectors[logicalColumnIndex]); + PrimitiveObjectInspectorUtils.getHiveIntervalYearMonth(object, primitiveLogicalOI); if (intervalYearMonth == null) { - VectorizedBatchUtil.setNullColIsNullValue(batch.cols[projectionColumnNum], batchIndex); + VectorizedBatchUtil.setNullColIsNullValue(columnVector, batchIndex); return; } - ((LongColumnVector) batch.cols[projectionColumnNum]).vector[batchIndex] = + ((LongColumnVector) columnVector).vector[batchIndex] = intervalYearMonth.getTotalMonths(); } break; case INTERVAL_DAY_TIME: { HiveIntervalDayTime intervalDayTime = - PrimitiveObjectInspectorUtils.getHiveIntervalDayTime( - object, convertSourcePrimitiveObjectInspectors[logicalColumnIndex]); + PrimitiveObjectInspectorUtils.getHiveIntervalDayTime(object, primitiveLogicalOI); if (intervalDayTime == null) { - VectorizedBatchUtil.setNullColIsNullValue(batch.cols[projectionColumnNum], batchIndex); + VectorizedBatchUtil.setNullColIsNullValue(columnVector, batchIndex); return; } - ((IntervalDayTimeColumnVector) batch.cols[projectionColumnNum]).set( - batchIndex, intervalDayTime); + ((IntervalDayTimeColumnVector) columnVector).set(batchIndex, intervalDayTime); } break; default: @@ -685,18 +780,102 @@ public void assignConvertRowColumn(VectorizedRowBatch batch, int batchIndex, " not supported"); } break; + case LIST: + { + ListColumnVector listColumnVector = (ListColumnVector) columnVector; + ColumnVector childColumnVector = listColumnVector.child; + TypeInfo elementTypeInfo = ((ListTypeInfo) typeInfo).getListElementTypeInfo(); + ListObjectInspector listObjectInspector = (ListObjectInspector) objectInspector; + ObjectInspector elementObjectInspector = listObjectInspector.getListElementObjectInspector(); + + int size = listObjectInspector.getListLength(object); + int offset = listColumnVector.childCount; + listColumnVector.offsets[batchIndex] = offset; + listColumnVector.lengths[batchIndex] = size; + listColumnVector.childCount += size; + childColumnVector.ensureSize(offset + size, true); + + for (int i = 0; i < size; i++) { + assignConvertRowColumn( + offset + i, elementTypeInfo, elementObjectInspector, childColumnVector, + listObjectInspector.getListElement(object, i)); + } + } + break; + case MAP: + { + MapColumnVector mapColumnVector = (MapColumnVector) columnVector; + ColumnVector keyColumnVector = mapColumnVector.keys; + ColumnVector valueColumnVector = mapColumnVector.values; + MapTypeInfo mapTypeInfo = (MapTypeInfo) typeInfo; + TypeInfo keyTypeInfo = mapTypeInfo.getMapKeyTypeInfo(); + TypeInfo valueTypeInfo = 
mapTypeInfo.getMapValueTypeInfo();
+        MapObjectInspector mapObjectInspector = (MapObjectInspector) objectInspector;
+        ObjectInspector keyObjectInspector = mapObjectInspector.getMapKeyObjectInspector();
+        ObjectInspector valueObjectInspector = mapObjectInspector.getMapValueObjectInspector();
+
+        Map map = mapObjectInspector.getMap(object);
+        List entries = new ArrayList(map.entrySet());
+        int size = map.size();
+        int offset = mapColumnVector.childCount;
+        mapColumnVector.offsets[batchIndex] = offset;
+        mapColumnVector.lengths[batchIndex] = size;
+        mapColumnVector.childCount += size;
+        keyColumnVector.ensureSize(offset + size, true);
+        valueColumnVector.ensureSize(offset + size, true);
+
+        for (int i = 0; i < size; i++) {
+          Map.Entry entry = (Map.Entry) entries.get(i);
+          assignConvertRowColumn(
+              offset + i, keyTypeInfo, keyObjectInspector, keyColumnVector, entry.getKey());
+          assignConvertRowColumn(
+              offset + i, valueTypeInfo, valueObjectInspector, valueColumnVector, entry.getValue());
+        }
+      }
+      break;
+      case STRUCT:
+      {
+        StructColumnVector structColumnVector = (StructColumnVector) columnVector;
+        ColumnVector[] fieldColumnVectors = structColumnVector.fields;
+        StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo;
+        List<TypeInfo> typeInfos = structTypeInfo.getAllStructFieldTypeInfos();
+        StructObjectInspector structObjectInspector = (StructObjectInspector) objectInspector;
+        List<? extends StructField> fields = structObjectInspector.getAllStructFieldRefs();
+        int size = fields.size();
+
+        for (int i = 0; i < size; i++) {
+          StructField field = fields.get(i);
+          ObjectInspector fieldObjectInspector = field.getFieldObjectInspector();
+          assignConvertRowColumn(batchIndex, typeInfos.get(i), fieldObjectInspector,
+              fieldColumnVectors[i], structObjectInspector.getStructFieldData(object, field));
+        }
+      }
+      break;
+      case UNION:
+      {
+        UnionColumnVector unionColumnVector = (UnionColumnVector) columnVector;
+        ColumnVector[] fieldColumnVectors = unionColumnVector.fields;
+        UnionObjectInspector unionObjectInspector = (UnionObjectInspector) objectInspector;
+        byte tag = unionObjectInspector.getTag(object);
+        ObjectInspector elementObjectInspector =
+            unionObjectInspector.getObjectInspectors().get(tag);
+        UnionTypeInfo unionTypeInfo = (UnionTypeInfo) typeInfo;
+        unionColumnVector.tags[batchIndex] = tag;
+        assignConvertRowColumn(batchIndex,
+            unionTypeInfo.getAllUnionObjectTypeInfos().get(tag), elementObjectInspector,
+            fieldColumnVectors[tag], unionObjectInspector.getField(object));
+      }
+      break;
      default:
        throw new RuntimeException("Category " + targetCategory.name() + " not supported");
      }
    } catch (NumberFormatException e) {
      // Some of the conversion methods throw this exception on numeric parsing errors.
-      VectorizedBatchUtil.setNullColIsNullValue(batch.cols[projectionColumnNum], batchIndex);
+      VectorizedBatchUtil.setNullColIsNullValue(columnVector, batchIndex);
      return;
    }

    // We always set the null flag to false when there is a value.
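    // (Illustrative note, not part of the patch: batches are recycled, so
    // isNull[batchIndex] may still hold a stale true from an earlier row or
    // batch; it must be overwritten explicitly whenever a real value lands.)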
- batch.cols[projectionColumnNum].isNull[batchIndex] = false; + columnVector.isNull[batchIndex] = false; } /* diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorExtractRow.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorExtractRow.java index defaf90..bf60814 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorExtractRow.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorExtractRow.java @@ -18,8 +18,17 @@ package org.apache.hadoop.hive.ql.exec.vector; +import java.sql.Timestamp; +import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; +import org.apache.hadoop.hive.serde2.objectinspector.StandardUnionObjectInspector; +import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.UnionTypeInfo; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.hadoop.hive.ql.exec.vector.expressions.StringExpr; @@ -73,17 +82,9 @@ // Extraction can be a subset of columns, so this is the projection -- // the batch column numbers. - Category[] categories; - // The data type category of each column being extracted. + TypeInfo[] typeInfos; - PrimitiveCategory[] primitiveCategories; - // The data type primitive category of each column being assigned. - - int[] maxLengths; - // For the CHAR and VARCHAR data types, the maximum character length of - // the columns. Otherwise, 0. - - Writable[] primitiveWritables; + Map primitiveWritables; // The extracted values will be placed in these writables. /* @@ -91,10 +92,8 @@ */ private void allocateArrays(int count) { projectionColumnNums = new int[count]; - categories = new Category[count]; - primitiveCategories = new PrimitiveCategory[count]; - maxLengths = new int[count]; - primitiveWritables = new Writable[count]; + typeInfos = new TypeInfo[count]; + primitiveWritables = new HashMap(count); } /* @@ -102,28 +101,7 @@ private void allocateArrays(int count) { */ private void initEntry(int logicalColumnIndex, int projectionColumnNum, TypeInfo typeInfo) { projectionColumnNums[logicalColumnIndex] = projectionColumnNum; - Category category = typeInfo.getCategory(); - categories[logicalColumnIndex] = category; - if (category == Category.PRIMITIVE) { - PrimitiveTypeInfo primitiveTypeInfo = (PrimitiveTypeInfo) typeInfo; - PrimitiveCategory primitiveCategory = primitiveTypeInfo.getPrimitiveCategory(); - primitiveCategories[logicalColumnIndex] = primitiveCategory; - - switch (primitiveCategory) { - case CHAR: - maxLengths[logicalColumnIndex] = ((CharTypeInfo) primitiveTypeInfo).getLength(); - break; - case VARCHAR: - maxLengths[logicalColumnIndex] = ((VarcharTypeInfo) primitiveTypeInfo).getLength(); - break; - default: - // No additional data type specific setting. - break; - } - - primitiveWritables[logicalColumnIndex] = - VectorizedBatchUtil.getPrimitiveWritable(primitiveCategory); - } + typeInfos[logicalColumnIndex] = typeInfo; } /* @@ -206,62 +184,76 @@ public Object extractRowColumn(VectorizedRowBatch batch, int batchIndex, int log // may ask for them.. 
return null; } + TypeInfo typeInfo = typeInfos[logicalColumnIndex]; + return extractRowColumn(colVector, batchIndex, typeInfo, projectionColumnNum, true); + } + + private Object extractRowColumn(ColumnVector colVector, int batchIndex, TypeInfo typeInfo, int projectionColumnNum, boolean reuseWritable) { int adjustedIndex = (colVector.isRepeating ? 0 : batchIndex); if (!colVector.noNulls && colVector.isNull[adjustedIndex]) { return null; } - Category category = categories[logicalColumnIndex]; + Category category = typeInfo.getCategory(); switch (category) { case PRIMITIVE: { - Writable primitiveWritable = - primitiveWritables[logicalColumnIndex]; - PrimitiveCategory primitiveCategory = primitiveCategories[logicalColumnIndex]; + PrimitiveTypeInfo primitiveTypeInfo = (PrimitiveTypeInfo) typeInfo; + PrimitiveCategory primitiveCategory = primitiveTypeInfo.getPrimitiveCategory(); + Writable primitiveWritable; + if (reuseWritable) { + primitiveWritable = primitiveWritables.get(colVector); + if (primitiveWritable == null) { + primitiveWritable = VectorizedBatchUtil.getPrimitiveWritable(primitiveCategory); + primitiveWritables.put(colVector, primitiveWritable); + } + } else { + primitiveWritable = VectorizedBatchUtil.getPrimitiveWritable(primitiveCategory); + } switch (primitiveCategory) { case VOID: return null; case BOOLEAN: ((BooleanWritable) primitiveWritable).set( - ((LongColumnVector) batch.cols[projectionColumnNum]).vector[adjustedIndex] == 0 ? + ((LongColumnVector) colVector).vector[adjustedIndex] == 0 ? false : true); return primitiveWritable; case BYTE: ((ByteWritable) primitiveWritable).set( - (byte) ((LongColumnVector) batch.cols[projectionColumnNum]).vector[adjustedIndex]); + (byte) ((LongColumnVector) colVector).vector[adjustedIndex]); return primitiveWritable; case SHORT: ((ShortWritable) primitiveWritable).set( - (short) ((LongColumnVector) batch.cols[projectionColumnNum]).vector[adjustedIndex]); + (short) ((LongColumnVector) colVector).vector[adjustedIndex]); return primitiveWritable; case INT: ((IntWritable) primitiveWritable).set( - (int) ((LongColumnVector) batch.cols[projectionColumnNum]).vector[adjustedIndex]); + (int) ((LongColumnVector) colVector).vector[adjustedIndex]); return primitiveWritable; case LONG: ((LongWritable) primitiveWritable).set( - ((LongColumnVector) batch.cols[projectionColumnNum]).vector[adjustedIndex]); + ((LongColumnVector) colVector).vector[adjustedIndex]); return primitiveWritable; case TIMESTAMP: - ((TimestampWritable) primitiveWritable).set( - ((TimestampColumnVector) batch.cols[projectionColumnNum]).asScratchTimestamp(adjustedIndex)); + Timestamp timestamp = ((TimestampWritable) primitiveWritable).getTimestamp(); + ((TimestampColumnVector) colVector).timestampUpdate(timestamp, adjustedIndex); return primitiveWritable; case DATE: ((DateWritable) primitiveWritable).set( - (int) ((LongColumnVector) batch.cols[projectionColumnNum]).vector[adjustedIndex]); + (int) ((LongColumnVector) colVector).vector[adjustedIndex]); return primitiveWritable; case FLOAT: ((FloatWritable) primitiveWritable).set( - (float) ((DoubleColumnVector) batch.cols[projectionColumnNum]).vector[adjustedIndex]); + (float) ((DoubleColumnVector) colVector).vector[adjustedIndex]); return primitiveWritable; case DOUBLE: ((DoubleWritable) primitiveWritable).set( - ((DoubleColumnVector) batch.cols[projectionColumnNum]).vector[adjustedIndex]); + ((DoubleColumnVector) colVector).vector[adjustedIndex]); return primitiveWritable; case BINARY: { BytesColumnVector bytesColVector = - 
((BytesColumnVector) batch.cols[projectionColumnNum]); + ((BytesColumnVector) colVector); byte[] bytes = bytesColVector.vector[adjustedIndex]; int start = bytesColVector.start[adjustedIndex]; int length = bytesColVector.length[adjustedIndex]; @@ -277,7 +269,7 @@ public Object extractRowColumn(VectorizedRowBatch batch, int batchIndex, int log case STRING: { BytesColumnVector bytesColVector = - ((BytesColumnVector) batch.cols[projectionColumnNum]); + ((BytesColumnVector) colVector); byte[] bytes = bytesColVector.vector[adjustedIndex]; int start = bytesColVector.start[adjustedIndex]; int length = bytesColVector.length[adjustedIndex]; @@ -293,7 +285,7 @@ public Object extractRowColumn(VectorizedRowBatch batch, int batchIndex, int log case VARCHAR: { BytesColumnVector bytesColVector = - ((BytesColumnVector) batch.cols[projectionColumnNum]); + ((BytesColumnVector) colVector); byte[] bytes = bytesColVector.vector[adjustedIndex]; int start = bytesColVector.start[adjustedIndex]; int length = bytesColVector.length[adjustedIndex]; @@ -303,7 +295,7 @@ public Object extractRowColumn(VectorizedRowBatch batch, int batchIndex, int log } int adjustedLength = StringExpr.truncate(bytes, start, length, - maxLengths[logicalColumnIndex]); + ((VarcharTypeInfo) typeInfo).getLength()); HiveVarcharWritable hiveVarcharWritable = (HiveVarcharWritable) primitiveWritable; hiveVarcharWritable.set(new String(bytes, start, adjustedLength, Charsets.UTF_8), -1); @@ -312,7 +304,7 @@ public Object extractRowColumn(VectorizedRowBatch batch, int batchIndex, int log case CHAR: { BytesColumnVector bytesColVector = - ((BytesColumnVector) batch.cols[projectionColumnNum]); + ((BytesColumnVector) colVector); byte[] bytes = bytesColVector.vector[adjustedIndex]; int start = bytesColVector.start[adjustedIndex]; int length = bytesColVector.length[adjustedIndex]; @@ -321,32 +313,96 @@ public Object extractRowColumn(VectorizedRowBatch batch, int batchIndex, int log nullBytesReadError(primitiveCategory, batchIndex, projectionColumnNum); } - int adjustedLength = StringExpr.rightTrimAndTruncate(bytes, start, length, - maxLengths[logicalColumnIndex]); + int maxLength = ((CharTypeInfo) typeInfo).getLength(); + int adjustedLength = StringExpr.rightTrimAndTruncate(bytes, start, length, maxLength); HiveCharWritable hiveCharWritable = (HiveCharWritable) primitiveWritable; - hiveCharWritable.set(new String(bytes, start, adjustedLength, Charsets.UTF_8), - maxLengths[logicalColumnIndex]); + hiveCharWritable.set(new String(bytes, start, adjustedLength, Charsets.UTF_8), maxLength); return primitiveWritable; } case DECIMAL: // The HiveDecimalWritable set method will quickly copy the deserialized decimal writable fields. 
((HiveDecimalWritable) primitiveWritable).set( - ((DecimalColumnVector) batch.cols[projectionColumnNum]).vector[adjustedIndex]); + ((DecimalColumnVector) colVector).vector[adjustedIndex]); return primitiveWritable; case INTERVAL_YEAR_MONTH: ((HiveIntervalYearMonthWritable) primitiveWritable).set( - (int) ((LongColumnVector) batch.cols[projectionColumnNum]).vector[adjustedIndex]); + (int) ((LongColumnVector) colVector).vector[adjustedIndex]); return primitiveWritable; case INTERVAL_DAY_TIME: ((HiveIntervalDayTimeWritable) primitiveWritable).set( - ((IntervalDayTimeColumnVector) batch.cols[projectionColumnNum]).asScratchIntervalDayTime(adjustedIndex)); + ((IntervalDayTimeColumnVector) colVector).asScratchIntervalDayTime(adjustedIndex)); return primitiveWritable; default: throw new RuntimeException("Primitive category " + primitiveCategory.name() + " not supported"); } } + case LIST: + { + ListColumnVector listColVector = (ListColumnVector) colVector; + ColumnVector childColVector = listColVector.child; + int offset = (int) listColVector.offsets[adjustedIndex]; + int length = (int) listColVector.lengths[adjustedIndex]; + ListTypeInfo listTypeInfo = (ListTypeInfo) typeInfo; + TypeInfo elementTypeInfo = listTypeInfo.getListElementTypeInfo(); + List list = new ArrayList(length); + for (int i = 0; i < length; i++) { + list.add(extractRowColumn(childColVector, offset + i, + elementTypeInfo, projectionColumnNum, false)); + } + return list; + } + case MAP: + { + MapColumnVector mapColVector = (MapColumnVector) colVector; + ColumnVector keyColVector = mapColVector.keys; + ColumnVector valueColVector = mapColVector.values; + int offset = (int) mapColVector.offsets[adjustedIndex]; + int length = (int) mapColVector.lengths[adjustedIndex]; + MapTypeInfo mapTypeInfo = (MapTypeInfo) typeInfo; + TypeInfo keyTypeInfo = mapTypeInfo.getMapKeyTypeInfo(); + TypeInfo valueTypeInfo = mapTypeInfo.getMapValueTypeInfo(); + Map map = new HashMap(); + + for (int i = 0; i < length; i++) { + Object key = extractRowColumn(keyColVector, offset + i, + keyTypeInfo, projectionColumnNum, false); + Object value = extractRowColumn(valueColVector, offset + i, + valueTypeInfo, projectionColumnNum, false); + map.put(key, value); + } + return map; + } + case STRUCT: + { + StructColumnVector structColVector = (StructColumnVector) colVector; + ColumnVector[] fieldColVectors = structColVector.fields; + StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo; + List fieldTypeInfos = structTypeInfo.getAllStructFieldTypeInfos(); + int size = fieldColVectors.length; + List list = new ArrayList(size); + for (int i = 0; i < size; i++) { + list.add(extractRowColumn(fieldColVectors[i], adjustedIndex, + fieldTypeInfos.get(i), projectionColumnNum, false)); + } + return list; + } + case UNION: + { + UnionColumnVector unionColVector = (UnionColumnVector) colVector; + ColumnVector[] fieldColVectors = unionColVector.fields; + UnionTypeInfo unionTypeInfo = (UnionTypeInfo) typeInfo; + List objectTypeInfos = unionTypeInfo.getAllUnionObjectTypeInfos(); + int tag = unionColVector.tags[adjustedIndex]; + Object object = extractRowColumn(fieldColVectors[tag], adjustedIndex, + objectTypeInfos.get(tag), projectionColumnNum, false); + StandardUnionObjectInspector.StandardUnion standardUnion = + new StandardUnionObjectInspector.StandardUnion(); + standardUnion.setObject(object); + standardUnion.setTag((byte) tag); + return standardUnion; + } default: throw new RuntimeException("Category " + category.name() + " not supported"); } diff --git 
ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorSerializeRow.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorSerializeRow.java index 319b4a8..257df26 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorSerializeRow.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorSerializeRow.java @@ -20,8 +20,10 @@ import java.io.IOException; import java.sql.Timestamp; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.ListIterator; import org.apache.hadoop.hive.common.type.HiveIntervalDayTime; import org.apache.hadoop.hive.ql.metadata.HiveException; @@ -30,9 +32,13 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; import org.apache.hadoop.hive.serde2.fast.SerializeWrite; +import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils; +import org.apache.hadoop.hive.serde2.typeinfo.UnionTypeInfo; /** * This class serializes columns from a row in a VectorizedRowBatch into a serialization format. @@ -50,8 +56,7 @@ private T serializeWrite; - private Category[] categories; - private PrimitiveCategory[] primitiveCategories; + private TypeInfo[] typeInfos; private int[] outputColumnNums; @@ -67,32 +72,24 @@ private VectorSerializeRow() { public void init(List typeNames, int[] columnMap) throws HiveException { final int size = typeNames.size(); - categories = new Category[size]; - primitiveCategories = new PrimitiveCategory[size]; + typeInfos = new TypeInfo[size]; outputColumnNums = Arrays.copyOf(columnMap, size); TypeInfo typeInfo; for (int i = 0; i < size; i++) { typeInfo = TypeInfoUtils.getTypeInfoFromTypeString(typeNames.get(i)); - categories[i] = typeInfo.getCategory(); - if (categories[i] == Category.PRIMITIVE) { - primitiveCategories[i] = ((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory(); - } + typeInfos[i] = typeInfo; } } public void init(List typeNames) throws HiveException { final int size = typeNames.size(); - categories = new Category[size]; - primitiveCategories = new PrimitiveCategory[size]; + typeInfos = new TypeInfo[size]; outputColumnNums = new int[size]; TypeInfo typeInfo; for (int i = 0; i < size; i++) { typeInfo = TypeInfoUtils.getTypeInfoFromTypeString(typeNames.get(i)); - categories[i] = typeInfo.getCategory(); - if (categories[i] == Category.PRIMITIVE) { - primitiveCategories[i] = ((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory(); - } + typeInfos[i] = typeInfo; outputColumnNums[i] = i; } } @@ -101,21 +98,12 @@ public void init(TypeInfo[] typeInfos, int[] columnMap) throws HiveException { final int size = typeInfos.length; - categories = new Category[size]; - primitiveCategories = new PrimitiveCategory[size]; outputColumnNums = Arrays.copyOf(columnMap, size); - TypeInfo typeInfo; - for (int i = 0; i < typeInfos.length; i++) { - typeInfo = typeInfos[i]; - categories[i] = typeInfo.getCategory(); - if (categories[i] == Category.PRIMITIVE) { - primitiveCategories[i] = ((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory(); - } - } + this.typeInfos = typeInfos; } public int getCount() { - return categories.length; + return typeInfos.length; } public void setOutput(Output 
output) { @@ -134,14 +122,20 @@ public void setOutputAppend(Output output) { * been performed. batchIndex is the actual index of the row. */ public void serializeWrite(VectorizedRowBatch batch, int batchIndex) throws IOException { - hasAnyNulls = false; isAllNulls = true; ColumnVector colVector; - int adjustedBatchIndex; - final int size = categories.length; + final int size = typeInfos.length; for (int i = 0; i < size; i++) { colVector = batch.cols[outputColumnNums[i]]; + List indices = new ArrayList(); + indices.add(i); + serializeWrite(colVector, indices, batchIndex); + } + } + + private void serializeWrite(ColumnVector colVector, List indices, int batchIndex) throws IOException { + int adjustedBatchIndex; if (colVector.isRepeating) { adjustedBatchIndex = 0; } else { @@ -150,12 +144,16 @@ public void serializeWrite(VectorizedRowBatch batch, int batchIndex) throws IOEx if (!colVector.noNulls && colVector.isNull[adjustedBatchIndex]) { serializeWrite.writeNull(); hasAnyNulls = true; - continue; + return; } isAllNulls = false; - switch (categories[i]) { + TypeInfo typeInfo = getTypeInfo(indices); + Category category = typeInfo.getCategory(); + switch (category) { case PRIMITIVE: - switch (primitiveCategories[i]) { + PrimitiveTypeInfo primitiveTypeInfo = (PrimitiveTypeInfo) typeInfo; + PrimitiveCategory primitiveCategory = primitiveTypeInfo.getPrimitiveCategory(); + switch (primitiveTypeInfo.getPrimitiveCategory()) { case BOOLEAN: serializeWrite.writeBoolean(((LongColumnVector) colVector).vector[adjustedBatchIndex] != 0); break; @@ -217,13 +215,127 @@ public void serializeWrite(VectorizedRowBatch batch, int batchIndex) throws IOEx serializeWrite.writeHiveIntervalDayTime(((IntervalDayTimeColumnVector) colVector).asScratchIntervalDayTime(adjustedBatchIndex)); break; default: - throw new RuntimeException("Unexpected primitive category " + primitiveCategories[i]); + throw new RuntimeException("Unexpected primitive category " + primitiveCategory); + } + break; + case LIST: + { + indices.add(0); + + ListColumnVector listColVector = (ListColumnVector) colVector; + ColumnVector childColVector = listColVector.child; + int offset = (int) listColVector.offsets[adjustedBatchIndex]; + int length = (int) listColVector.lengths[adjustedBatchIndex]; + + serializeWrite.writeInt(length); + + for (int i = 0; i < length; i++) { + serializeWrite(childColVector, indices, offset + i); + } + + indices.remove(indices.size() - 1); + } + break; + case MAP: + { + indices.add(0); + int lastIndexPosition = indices.size() - 1; + + MapColumnVector mapColVector = (MapColumnVector) colVector; + ColumnVector keyColVector = mapColVector.keys; + ColumnVector valueColVector = mapColVector.values; + int offset = (int) mapColVector.offsets[adjustedBatchIndex]; + int length = (int) mapColVector.lengths[adjustedBatchIndex]; + + serializeWrite.writeInt(length); + + for (int i = 0; i < length; i++) { + indices.set(lastIndexPosition, 0); + serializeWrite(keyColVector, indices, offset + i); + indices.set(lastIndexPosition, 1); + serializeWrite(valueColVector, indices, offset + i); + } + + indices.remove(indices.size() - 1); + } + break; + case STRUCT: + { + indices.add(0); + int lastIndexPosition = indices.size() - 1; + + StructColumnVector structColVector = (StructColumnVector) colVector; + ColumnVector[] fieldColVectors = structColVector.fields; + + for (int i = 0; i < fieldColVectors.length; i++) { + indices.set(lastIndexPosition, i); + serializeWrite(fieldColVectors[i], indices, adjustedBatchIndex); + } + + 
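+          // (Illustrative note, not part of the patch: `indices` is the path
+          // from the top-level column down into the nested type, resolved by
+          // getTypeInfo() below. For a column 2 of type map<int,array<string>>,
+          // the path [2, 1, 0] selects the map's value type array<string> and
+          // then its string element: map keys use index 0 and values index 1,
+          // lists always use 0, structs use the field position, unions the tag.)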
indices.remove(indices.size() - 1); + } + break; + case UNION: + { + UnionColumnVector unionColVector = (UnionColumnVector) colVector; + ColumnVector[] fieldColVectors = unionColVector.fields; + int tag = unionColVector.tags[adjustedBatchIndex]; + indices.add(tag); + serializeWrite(fieldColVectors[tag], indices, batchIndex); + indices.remove(indices.size() - 1); } break; default: - throw new RuntimeException("Unexpected category " + categories[i]); + throw new RuntimeException("Unexpected category " + category); + } + } + + private TypeInfo getTypeInfo(List indices) { + int size = indices.size(); + if (size == 0) { + return null; + } else if (size == 1) { + return typeInfos[indices.get(0)]; + } else { + return getTypeInfo(typeInfos[indices.get(0)], indices.listIterator(1)); + } + } + + private TypeInfo getTypeInfo(TypeInfo typeInfo, ListIterator iterator) { + if (!iterator.hasNext()) { + return typeInfo; + } + int index = iterator.next(); + switch (typeInfo.getCategory()) { + case LIST: + { + ListTypeInfo listTypeInfo = (ListTypeInfo) typeInfo; + return getTypeInfo(listTypeInfo.getListElementTypeInfo(), iterator); + } + case MAP: + { + MapTypeInfo mapTypeInfo = (MapTypeInfo) typeInfo; + if (index == 0) { + return getTypeInfo(mapTypeInfo.getMapKeyTypeInfo(), iterator); + } else if (index == 1) { + return getTypeInfo(mapTypeInfo.getMapValueTypeInfo(), iterator); + } + } + break; + case STRUCT: + { + StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo; + List fieldTypeInfos = structTypeInfo.getAllStructFieldTypeInfos(); + return getTypeInfo(fieldTypeInfos.get(index), iterator); + } + case UNION: + { + UnionTypeInfo unionTypeInfo = (UnionTypeInfo) typeInfo; + List objectTypeInfos = unionTypeInfo.getAllUnionObjectTypeInfos(); + return getTypeInfo(objectTypeInfos.get(index), iterator); } } + return null; } public boolean getHasAnyNulls() { diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorRowObject.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorRowObject.java index e9ce8e8..d882b46 100644 --- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorRowObject.java +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorRowObject.java @@ -39,8 +39,17 @@ void examineBatch(VectorizedRowBatch batch, VectorExtractRow vectorExtractRow, vectorExtractRow.extractRow(batch, i, row); Object[] expectedRow = randomRows[firstRandomRowIndex + i]; for (int c = 0; c < rowSize; c++) { - if (!row[c].equals(expectedRow[c])) { - fail("Row " + (firstRandomRowIndex + i) + " and column " + c + " mismatch"); + Object actual = row[c]; + Object expected = expectedRow[c]; + + if (actual == null) { + if (expected != null) { + fail("Row " + (firstRandomRowIndex + i) + " and column " + c + " mismatch"); + } + } else { + if (!actual.equals(expected)) { + fail("Row " + (firstRandomRowIndex + i) + " and column " + c + " mismatch"); + } } } } @@ -51,7 +60,7 @@ void testVectorRowObject(int caseNum, boolean sort, Random r) throws HiveExcepti String[] emptyScratchTypeNames = new String[0]; VectorRandomRowSource source = new VectorRandomRowSource(); - source.init(r); + source.init(r, true, 4); VectorizedRowBatchCtx batchContext = new VectorizedRowBatchCtx(); batchContext.init(source.rowStructObjectInspector(), emptyScratchTypeNames); @@ -69,7 +78,7 @@ void testVectorRowObject(int caseNum, boolean sort, Random r) throws HiveExcepti VectorExtractRow vectorExtractRow = new VectorExtractRow(); vectorExtractRow.init(source.typeNames()); - Object[][] randomRows = 
source.randomRows(10000); + Object[][] randomRows = source.randomRows(1000); if (sort) { source.sort(randomRows); } diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorSerDeRow.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorSerDeRow.java index b29bb8b..e33b4bd 100644 --- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorSerDeRow.java +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorSerDeRow.java @@ -315,7 +315,7 @@ void testVectorSerializeRow(Random r, SerializationType serializationType) String[] emptyScratchTypeNames = new String[0]; VectorRandomRowSource source = new VectorRandomRowSource(); - source.init(r); + source.init(r, false, 4); VectorizedRowBatchCtx batchContext = new VectorizedRowBatchCtx(); batchContext.init(source.rowStructObjectInspector(), emptyScratchTypeNames); @@ -341,7 +341,7 @@ void testVectorSerializeRow(Random r, SerializationType serializationType) StructObjectInspector rowObjectInspector = source.rowStructObjectInspector(); LazySerDeParameters lazySerDeParams = getSerDeParams(rowObjectInspector); byte separator = (byte) '\t'; - deserializeRead = new LazySimpleDeserializeRead(source.primitiveTypeInfos(), /* useExternalBuffer */ false, + deserializeRead = new LazySimpleDeserializeRead(source.typeInfos(), /* useExternalBuffer */ false, separator, lazySerDeParams); serializeWrite = new LazySimpleSerializeWrite(fieldCount, separator, lazySerDeParams); @@ -353,7 +353,7 @@ void testVectorSerializeRow(Random r, SerializationType serializationType) VectorSerializeRow vectorSerializeRow = new VectorSerializeRow(serializeWrite); vectorSerializeRow.init(source.typeNames()); - Object[][] randomRows = source.randomRows(100000); + Object[][] randomRows = source.randomRows(10000); int firstRandomRowIndex = 0; for (int i = 0; i < randomRows.length; i++) { Object[] row = randomRows[i]; @@ -553,7 +553,7 @@ void testVectorDeserializeRow(Random r, SerializationType serializationType, String[] emptyScratchTypeNames = new String[0]; VectorRandomRowSource source = new VectorRandomRowSource(); - source.init(r); + source.init(r, false, 4); VectorizedRowBatchCtx batchContext = new VectorizedRowBatchCtx(); batchContext.init(source.rowStructObjectInspector(), emptyScratchTypeNames); diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomRowSource.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomRowSource.java index cbde615..8997b54 100644 --- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomRowSource.java +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomRowSource.java @@ -34,12 +34,23 @@ import org.apache.hadoop.hive.common.type.HiveIntervalYearMonth; import org.apache.hadoop.hive.common.type.HiveVarchar; import org.apache.hadoop.hive.common.type.RandomTypeUtil; -import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.io.HiveCharWritable; +import org.apache.hadoop.hive.serde2.io.HiveVarcharWritable; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import 
diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomRowSource.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomRowSource.java
index cbde615..8997b54 100644
--- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomRowSource.java
+++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomRowSource.java
@@ -34,12 +34,23 @@
 import org.apache.hadoop.hive.common.type.HiveIntervalYearMonth;
 import org.apache.hadoop.hive.common.type.HiveVarchar;
 import org.apache.hadoop.hive.common.type.RandomTypeUtil;
-import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.io.HiveCharWritable;
+import org.apache.hadoop.hive.serde2.io.HiveVarcharWritable;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StandardStructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StandardUnionObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StandardUnionObjectInspector.StandardUnion;
+import org.apache.hadoop.hive.serde2.objectinspector.StructField;
 import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.UnionObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableBooleanObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableByteObjectInspector;
@@ -58,11 +69,19 @@
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableTimestampObjectInspector;
 import org.apache.hadoop.hive.serde2.typeinfo.CharTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
+import org.apache.hadoop.hive.serde2.typeinfo.UnionTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo;
 import org.apache.hive.common.util.DateUtils;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.BytesWritable;
+import com.google.common.base.Preconditions;
 import com.google.common.base.Charsets;
 
 /**
@@ -76,6 +95,14 @@
 
   private List<String> typeNames;
 
+  private Category[] categories;
+
+  private TypeInfo[] typeInfos;
+
+  private List<ObjectInspector> objectInspectorList;
+
+  // Primitive.
+
   private PrimitiveCategory[] primitiveCategories;
 
   private PrimitiveTypeInfo[] primitiveTypeInfos;
@@ -93,6 +120,14 @@
     return typeNames;
   }
 
+  public Category[] categories() {
+    return categories;
+  }
+
+  public TypeInfo[] typeInfos() {
+    return typeInfos;
+  }
+
   public PrimitiveCategory[] primitiveCategories() {
     return primitiveCategories;
   }
@@ -106,30 +141,57 @@ public StructObjectInspector rowStructObjectInspector() {
   }
 
   public StructObjectInspector partialRowStructObjectInspector(int partialFieldCount) {
-    ArrayList<ObjectInspector> partialPrimitiveObjectInspectorList =
+    ArrayList<ObjectInspector> partialObjectInspectorList =
         new ArrayList<ObjectInspector>(partialFieldCount);
     List<String> columnNames = new ArrayList<String>(partialFieldCount);
     for (int i = 0; i < partialFieldCount; i++) {
       columnNames.add(String.format("partial%d", i));
-      partialPrimitiveObjectInspectorList.add(
-          PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(
-              primitiveTypeInfos[i]));
+      partialObjectInspectorList.add(getObjectInspector(typeInfos[i]));
     }
 
     return ObjectInspectorFactory.getStandardStructObjectInspector(
-        columnNames, primitiveObjectInspectorList);
+        columnNames, partialObjectInspectorList);
   }
 
   public void init(Random r) {
+    init(r, false, 0);
+  }
+
+  public void init(Random r, boolean includeComplexTypes, int maxComplexDepth) {
     this.r = r;
-    chooseSchema();
+    chooseSchema(includeComplexTypes, maxComplexDepth);
+  }
+
+  public void init(int keyCount, Random random, List<ObjectInspector> keyPrimitiveObjectInspectorList,
+      PrimitiveCategory[] keyPrimitiveCategories, PrimitiveTypeInfo[] keyPrimitiveTypeInfos) {
+    this.r = random;
+    columnCount = keyCount;
+    typeNames = new ArrayList<String>(keyCount);
+    List<String> columnNames = new ArrayList<String>(keyCount);
+    categories = new Category[keyCount];
+
+    for (int i = 0; i < keyCount; i++) {
+      columnNames.add(String.format("col%d", i));
+      typeNames.add(keyPrimitiveTypeInfos[i].getTypeName());
+      categories[i] = Category.PRIMITIVE;
+    }
+
+    typeInfos = keyPrimitiveTypeInfos;
+    objectInspectorList = keyPrimitiveObjectInspectorList;
+    primitiveCategories = keyPrimitiveCategories;
+    primitiveTypeInfos = keyPrimitiveTypeInfos;
+    primitiveObjectInspectorList = keyPrimitiveObjectInspectorList;
+
+    rowStructObjectInspector =
+        ObjectInspectorFactory.getStandardStructObjectInspector(columnNames, objectInspectorList);
+    alphabets = new String[keyCount];
   }
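The row source now has two initialization paths: a random schema (optionally with complex types nested to a maximum depth) and an externally supplied primitive key schema. A usage sketch, assuming the patched VectorRandomRowSource is on the classpath — the demo class, seed, and counts are illustrative:

import java.util.Random;
import org.apache.hadoop.hive.ql.exec.vector.VectorRandomRowSource;

public class RowSourceInitDemo {
  public static void main(String[] args) {
    Random r = new Random(1234);

    // Random schema, primitives only (the old init(r) behavior): depth unused.
    VectorRandomRowSource primitiveSource = new VectorRandomRowSource();
    primitiveSource.init(r, false, 0);

    // Random schema that may nest array/map/struct/uniontype up to 4 levels.
    VectorRandomRowSource complexSource = new VectorRandomRowSource();
    complexSource.init(r, true, 4);

    Object[][] rows = complexSource.randomRows(100);
    System.out.println(complexSource.typeNames() + " -> " + rows.length + " rows");
  }
}

The third overload, init(keyCount, random, ...), is used by the map-join hash table test further down to generate key rows for a fixed primitive schema instead of a random one.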

  /*
   * For now, exclude CHAR until we determine why there is a difference (blank padding)
   * serializing with LazyBinarySerializeWrite and the regular SerDe...
   */
-  private static String[] possibleHiveTypeNames = {
+  private static String[] possibleHivePrimitiveTypeNames = {
       "boolean",
       "tinyint",
       "smallint",
@@ -149,7 +211,146 @@ public void init(Random r) {
       "decimal"
   };
 
-  private void chooseSchema() {
+  private static String[] possibleHiveComplexTypeNames = {
+      "array",
+      "map",
+      "struct",
+      "uniontype"
+  };
+
+  private String getRandomTypeName(boolean includeComplexTypes) {
+    String typeName;
+    if (!includeComplexTypes || r.nextInt(10) != 0) {
+      typeName = possibleHivePrimitiveTypeNames[r.nextInt(possibleHivePrimitiveTypeNames.length)];
+    } else {
+      typeName = possibleHiveComplexTypeNames[r.nextInt(possibleHiveComplexTypeNames.length)];
+    }
+    return typeName;
+  }
+
+  private String getDecoratedTypeName(String typeName, boolean includeComplexTypes, int depth, int maxDepth) {
+    depth++;
+    boolean includeChildrenComplexTypes = includeComplexTypes && depth < maxDepth;
+    if (typeName.equals("char")) {
+      int maxLength = 1 + r.nextInt(100);
+      typeName = String.format("char(%d)", maxLength);
+    } else if (typeName.equals("varchar")) {
+      int maxLength = 1 + r.nextInt(100);
+      typeName = String.format("varchar(%d)", maxLength);
+    } else if (typeName.equals("decimal")) {
+      typeName = String.format("decimal(%d,%d)", HiveDecimal.SYSTEM_DEFAULT_PRECISION, HiveDecimal.SYSTEM_DEFAULT_SCALE);
+    } else if (typeName.equals("array")) {
+      String elementTypeName = getRandomTypeName(includeChildrenComplexTypes);
+      elementTypeName = getDecoratedTypeName(elementTypeName, includeChildrenComplexTypes, depth, maxDepth);
+      typeName = String.format("array<%s>", elementTypeName);
+    } else if (typeName.equals("map")) {
+      String keyTypeName = getRandomTypeName(includeChildrenComplexTypes);
+      keyTypeName = getDecoratedTypeName(keyTypeName, includeChildrenComplexTypes, depth, maxDepth);
+      String valueTypeName = getRandomTypeName(includeChildrenComplexTypes);
+      valueTypeName = getDecoratedTypeName(valueTypeName, includeChildrenComplexTypes, depth, maxDepth);
+      typeName = String.format("map<%s,%s>", keyTypeName, valueTypeName);
+    } else if (typeName.equals("struct")) {
+      final int fieldCount = 1 + r.nextInt(10);
+      StringBuilder sb = new StringBuilder();
+      for (int i = 0; i < fieldCount; i++) {
+        String fieldTypeName = getRandomTypeName(includeChildrenComplexTypes);
+        fieldTypeName = getDecoratedTypeName(fieldTypeName, includeChildrenComplexTypes, depth, maxDepth);
+        if (i > 0) {
+          sb.append(",");
+        }
+        sb.append("col");
+        sb.append(i);
+        sb.append(":");
+        sb.append(fieldTypeName);
+      }
+      typeName = String.format("struct<%s>", sb.toString());
+    } else if (typeName.equals("uniontype")) {
+      final int fieldCount = 1 + r.nextInt(10);
+      StringBuilder sb = new StringBuilder();
+      for (int i = 0; i < fieldCount; i++) {
+        String fieldTypeName = getRandomTypeName(includeChildrenComplexTypes);
+        fieldTypeName = getDecoratedTypeName(fieldTypeName, includeChildrenComplexTypes, depth, maxDepth);
+        if (i > 0) {
+          sb.append(",");
+        }
+        sb.append(fieldTypeName);
+      }
+      typeName = String.format("uniontype<%s>", sb.toString());
+    }
+    return typeName;
+  }
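getDecoratedTypeName turns a bare type name into a full Hive type string, recursing for complex children until the depth budget runs out. The shapes it can emit, and a check that each parses back through Hive's type-string parser — the sample strings below are illustrative (lengths, precision, and field counts are randomized in the real method):

import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;

public class DecoratedTypeNameDemo {
  public static void main(String[] args) {
    String[] samples = {
        "char(37)",
        "varchar(12)",
        "decimal(38,18)",                           // HiveDecimal system defaults
        "array<map<string,int>>",
        "struct<col0:bigint,col1:array<double>>",   // fields named col0, col1, ...
        "uniontype<int,string,timestamp>"
    };
    for (String typeString : samples) {
      // Every generated name must round-trip through the type-string parser.
      TypeInfo typeInfo = TypeInfoUtils.getTypeInfoFromTypeString(typeString);
      System.out.println(typeString + " -> category " + typeInfo.getCategory());
    }
  }
}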
+
+  private ObjectInspector getObjectInspector(TypeInfo typeInfo) {
+    ObjectInspector objectInspector;
+    switch (typeInfo.getCategory()) {
+    case PRIMITIVE:
+      {
+        PrimitiveTypeInfo primitiveType = (PrimitiveTypeInfo) typeInfo;
+        objectInspector =
+            PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(primitiveType);
+      }
+      break;
+    case MAP:
+      {
+        MapTypeInfo mapType = (MapTypeInfo) typeInfo;
+        MapObjectInspector mapInspector =
+            ObjectInspectorFactory.getStandardMapObjectInspector(
+                getObjectInspector(mapType.getMapKeyTypeInfo()),
+                getObjectInspector(mapType.getMapValueTypeInfo()));
+        objectInspector = mapInspector;
+      }
+      break;
+    case LIST:
+      {
+        ListTypeInfo listType = (ListTypeInfo) typeInfo;
+        ListObjectInspector listInspector =
+            ObjectInspectorFactory.getStandardListObjectInspector(
+                getObjectInspector(listType.getListElementTypeInfo()));
+        objectInspector = listInspector;
+      }
+      break;
+    case STRUCT:
+      {
+        StructTypeInfo structType = (StructTypeInfo) typeInfo;
+        List<TypeInfo> fieldTypes = structType.getAllStructFieldTypeInfos();
+
+        List<ObjectInspector> fieldInspectors = new ArrayList<ObjectInspector>();
+        for (TypeInfo fieldType : fieldTypes) {
+          fieldInspectors.add(getObjectInspector(fieldType));
+        }
+
+        StructObjectInspector structInspector =
+            ObjectInspectorFactory.getStandardStructObjectInspector(
+                structType.getAllStructFieldNames(), fieldInspectors);
+        objectInspector = structInspector;
+      }
+      break;
+    case UNION:
+      {
+        UnionTypeInfo unionType = (UnionTypeInfo) typeInfo;
+        List<TypeInfo> fieldTypes = unionType.getAllUnionObjectTypeInfos();
+
+        List<ObjectInspector> fieldInspectors = new ArrayList<ObjectInspector>();
+        for (TypeInfo fieldType : fieldTypes) {
+          fieldInspectors.add(getObjectInspector(fieldType));
+        }
+
+        UnionObjectInspector unionInspector =
+            ObjectInspectorFactory.getStandardUnionObjectInspector(
+                fieldInspectors);
+        objectInspector = unionInspector;
+      }
+      break;
+    default:
+      throw new RuntimeException("Unexpected category " + typeInfo.getCategory());
+    }
+    Preconditions.checkState(objectInspector != null);
+    return objectInspector;
+  }
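getObjectInspector mirrors the TypeInfo tree with standard writable object inspectors, recursing into children. A hand-built inspector for one concrete type, showing the same shape the recursion produces — the demo class is illustrative:

import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

public class NestedInspectorDemo {
  public static void main(String[] args) {
    // Hand-built inspector for map<string,array<int>>, the same shape
    // getObjectInspector derives by recursing over the TypeInfo.
    ObjectInspector intOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector;
    ObjectInspector stringOI = PrimitiveObjectInspectorFactory.writableStringObjectInspector;
    ObjectInspector arrayOfIntOI =
        ObjectInspectorFactory.getStandardListObjectInspector(intOI);
    ObjectInspector mapOI =
        ObjectInspectorFactory.getStandardMapObjectInspector(stringOI, arrayOfIntOI);
    System.out.println(mapOI.getTypeName());  // expected: map<string,array<int>>
  }
}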
+
+  private void chooseSchema(boolean includeComplexTypes, int maxComplexDepth) {
     HashSet<Integer> hashSet = null;
     boolean allTypes;
     boolean onlyOne = (r.nextInt(100) == 7);
@@ -160,13 +361,20 @@ private void chooseSchema() {
       allTypes = r.nextBoolean();
       if (allTypes) {
         // One of each type.
-        columnCount = possibleHiveTypeNames.length;
+        columnCount = possibleHivePrimitiveTypeNames.length;
+        if (includeComplexTypes) {
+          columnCount += possibleHiveComplexTypeNames.length;
+        }
         hashSet = new HashSet<Integer>();
       } else {
         columnCount = 1 + r.nextInt(20);
       }
     }
     typeNames = new ArrayList<String>(columnCount);
+    categories = new Category[columnCount];
+    typeInfos = new TypeInfo[columnCount];
+    objectInspectorList = new ArrayList<ObjectInspector>(columnCount);
+
     primitiveCategories = new PrimitiveCategory[columnCount];
     primitiveTypeInfos = new PrimitiveTypeInfo[columnCount];
     primitiveObjectInspectorList = new ArrayList<ObjectInspector>(columnCount);
@@ -176,12 +384,18 @@ private void chooseSchema() {
       String typeName;
 
       if (onlyOne) {
-        typeName = possibleHiveTypeNames[r.nextInt(possibleHiveTypeNames.length)];
+        typeName = getRandomTypeName(includeComplexTypes);
       } else {
         int typeNum;
         if (allTypes) {
+          int maxTypeNum = possibleHivePrimitiveTypeNames.length;
+          if (includeComplexTypes) {
+            maxTypeNum += possibleHiveComplexTypeNames.length;
+          }
           while (true) {
-            typeNum = r.nextInt(possibleHiveTypeNames.length);
+            typeNum = r.nextInt(maxTypeNum);
             Integer typeNumInteger = Integer.valueOf(typeNum);
             if (!hashSet.contains(typeNumInteger)) {
               hashSet.add(typeNumInteger);
@@ -189,30 +403,94 @@ private void chooseSchema() {
             }
           }
         } else {
-          typeNum = r.nextInt(possibleHiveTypeNames.length);
+          if (!includeComplexTypes || r.nextInt(10) != 0) {
+            typeNum = r.nextInt(possibleHivePrimitiveTypeNames.length);
+          } else {
+            typeNum = possibleHivePrimitiveTypeNames.length + r.nextInt(possibleHiveComplexTypeNames.length);
+          }
+        }
+        if (typeNum < possibleHivePrimitiveTypeNames.length) {
+          typeName = possibleHivePrimitiveTypeNames[typeNum];
+        } else {
+          typeName = possibleHiveComplexTypeNames[typeNum - possibleHivePrimitiveTypeNames.length];
         }
-        typeName = possibleHiveTypeNames[typeNum];
       }
+
+      String decoratedTypeName = getDecoratedTypeName(typeName, includeComplexTypes, 0, maxComplexDepth);
+
+      TypeInfo typeInfo;
+      try {
+        typeInfo = TypeInfoUtils.getTypeInfoFromTypeString(decoratedTypeName);
+      } catch (Exception e) {
+        throw new RuntimeException("Cannot convert type name " + decoratedTypeName + " to a type " + e);
+      }
+
+      typeInfos[c] = typeInfo;
+      Category category = typeInfo.getCategory();
+      categories[c] = category;
+      ObjectInspector objectInspector = getObjectInspector(typeInfo);
+      switch (category) {
+      case PRIMITIVE:
+        {
+          PrimitiveTypeInfo primitiveTypeInfo = (PrimitiveTypeInfo) typeInfo;
+          objectInspector =
+              PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(primitiveTypeInfo);
+          primitiveTypeInfos[c] = primitiveTypeInfo;
+          PrimitiveCategory primitiveCategory = primitiveTypeInfo.getPrimitiveCategory();
+          primitiveCategories[c] = primitiveCategory;
+          primitiveObjectInspectorList.add(objectInspector);
+        }
-      if (typeName.equals("char")) {
-        int maxLength = 1 + r.nextInt(100);
-        typeName = String.format("char(%d)", maxLength);
-      } else if (typeName.equals("varchar")) {
-        int maxLength = 1 + r.nextInt(100);
-        typeName = String.format("varchar(%d)", maxLength);
-      } else if (typeName.equals("decimal")) {
-        typeName = String.format("decimal(%d,%d)", HiveDecimal.SYSTEM_DEFAULT_PRECISION, HiveDecimal.SYSTEM_DEFAULT_SCALE);
-      }
+        break;
+      case LIST:
+      case MAP:
+      case STRUCT:
+      case UNION:
+        primitiveObjectInspectorList.add(null);
+        break;
+      default:
+        throw new RuntimeException("Unexpected category " + category);
+      }
+      objectInspectorList.add(objectInspector);
-      PrimitiveTypeInfo primitiveTypeInfo =
-          (PrimitiveTypeInfo) TypeInfoUtils.getTypeInfoFromTypeString(typeName);
-      primitiveTypeInfos[c] = primitiveTypeInfo;
-      PrimitiveCategory primitiveCategory = primitiveTypeInfo.getPrimitiveCategory();
-      primitiveCategories[c] = primitiveCategory;
-      primitiveObjectInspectorList.add(PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(primitiveTypeInfo));
-      typeNames.add(typeName);
+      typeNames.add(decoratedTypeName);
     }
 
-    rowStructObjectInspector = ObjectInspectorFactory.getStandardStructObjectInspector(columnNames, primitiveObjectInspectorList);
+    rowStructObjectInspector = ObjectInspectorFactory.getStandardStructObjectInspector(
+        columnNames, objectInspectorList);
     alphabets = new String[columnCount];
   }
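The net effect of chooseSchema is a set of parallel per-column arrays; the primitive-side lists are padded with null slots for complex columns so legacy primitive-only callers can keep indexing by column number. A distilled sketch of that padding rule — the demo class and its fixed schema are illustrative:

import java.util.ArrayList;
import java.util.List;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;

public class ParallelColumnArraysDemo {
  public static void main(String[] args) {
    String[] typeNames = { "int", "array<string>", "map<int,string>" };
    List<ObjectInspector> primitiveObjectInspectorList = new ArrayList<ObjectInspector>();
    for (String name : typeNames) {
      TypeInfo typeInfo = TypeInfoUtils.getTypeInfoFromTypeString(name);
      if (typeInfo.getCategory() == Category.PRIMITIVE) {
        primitiveObjectInspectorList.add(
            PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(
                (PrimitiveTypeInfo) typeInfo));
      } else {
        // Complex columns keep the primitive-side arrays aligned via null slots.
        primitiveObjectInspectorList.add(null);
      }
    }
    for (int c = 0; c < typeNames.length; c++) {
      ObjectInspector oi = primitiveObjectInspectorList.get(c);
      System.out.println(typeNames[c] + " -> "
          + (oi == null ? "null (complex column)" : oi.getTypeName()));
    }
  }
}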
 
+  public Object[][] randomRows(int n) {
+    Object[][] result = new Object[n][];
+    for (int i = 0; i < n; i++) {
+      result[i] = randomRow();
+    }
+    return result;
+  }
+
+  public Object[] randomRow() {
+    Object row[] = new Object[columnCount];
+    for (int c = 0; c < columnCount; c++) {
+      row[c] = randomWritable(c);
+    }
+    return row;
+  }
+
+  public Object[] randomPrimitiveRow(int columnCount) {
+    return randomPrimitiveRow(columnCount, r, primitiveTypeInfos);
+  }
+
+  public static Object[] randomPrimitiveRow(int columnCount, Random r,
+      PrimitiveTypeInfo[] primitiveTypeInfos) {
+    Object row[] = new Object[columnCount];
+    for (int c = 0; c < columnCount; c++) {
+      row[c] = randomPrimitiveObject(r, primitiveTypeInfos[c]);
+    }
+    return row;
+  }
+
   public void addBinarySortableAlphabets() {
 
     for (int c = 0; c < columnCount; c++) {
       switch (primitiveCategories[c]) {
@@ -241,52 +519,6 @@ public void addEscapables(String needsEscapeStr) {
     this.needsEscapeStr = needsEscapeStr;
   }
 
-  public Object[][] randomRows(int n) {
-    Object[][] result = new Object[n][];
-    for (int i = 0; i < n; i++) {
-      result[i] = randomRow();
-    }
-    return result;
-  }
-
-  public Object[] randomRow() {
-    Object row[] = new Object[columnCount];
-    for (int c = 0; c < columnCount; c++) {
-      Object object = randomObject(c);
-      if (object == null) {
-        throw new Error("Unexpected null for column " + c);
-      }
-      row[c] = getWritableObject(c, object);
-      if (row[c] == null) {
-        throw new Error("Unexpected null for writable for column " + c);
-      }
-    }
-    return row;
-  }
-
-  public Object[] randomRow(int columnCount) {
-    return randomRow(columnCount, r, primitiveObjectInspectorList, primitiveCategories,
-        primitiveTypeInfos);
-  }
-
-  public static Object[] randomRow(int columnCount, Random r,
-      List<ObjectInspector> primitiveObjectInspectorList, PrimitiveCategory[] primitiveCategories,
-      PrimitiveTypeInfo[] primitiveTypeInfos) {
-    Object row[] = new Object[columnCount];
-    for (int c = 0; c < columnCount; c++) {
-      Object object = randomObject(c, r, primitiveCategories, primitiveTypeInfos);
-      if (object == null) {
-        throw new Error("Unexpected null for column " + c);
-      }
-      row[c] = getWritableObject(c, object, primitiveObjectInspectorList,
-          primitiveCategories, primitiveTypeInfos);
-      if (row[c] == null) {
-        throw new Error("Unexpected null for writable for column " + c);
-      }
-    }
-    return row;
-  }
-
   public static void sort(Object[][] rows, ObjectInspector oi) {
     for (int i = 0; i < rows.length; i++) {
       for (int j = i + 1; j < rows.length; j++) {
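The generation entry points split in two: the instance path produces whole rows of writables for the chosen (possibly complex) schema, while the static randomPrimitiveRow serves callers with an external primitive-only schema. A usage sketch, assuming the patched class is on the classpath — the demo class, seed, and key types are illustrative:

import java.util.Random;
import org.apache.hadoop.hive.ql.exec.vector.VectorRandomRowSource;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;

public class RandomRowsDemo {
  public static void main(String[] args) {
    Random r = new Random(42);

    // Instance path: whole rows of writables for a random schema.
    VectorRandomRowSource source = new VectorRandomRowSource();
    source.init(r, true, 4);
    Object[][] rows = source.randomRows(1000);
    source.sort(rows);  // orders rows via the row struct object inspector

    // Static path: primitive values for an externally supplied key schema.
    PrimitiveTypeInfo[] keyTypeInfos = {
        (PrimitiveTypeInfo) TypeInfoUtils.getTypeInfoFromTypeString("int"),
        (PrimitiveTypeInfo) TypeInfoUtils.getTypeInfoFromTypeString("string")
    };
    Object[] keyRow = VectorRandomRowSource.randomPrimitiveRow(2, r, keyTypeInfos);
    System.out.println(rows.length + " rows, key row of " + keyRow.length + " columns");
  }
}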
@@ -303,18 +535,9 @@ public void sort(Object[][] rows) {
     VectorRandomRowSource.sort(rows, rowStructObjectInspector);
   }
 
-  public Object getWritableObject(int column, Object object) {
-    return getWritableObject(column, object, primitiveObjectInspectorList,
-        primitiveCategories, primitiveTypeInfos);
-  }
-
-  public static Object getWritableObject(int column, Object object,
-      List<ObjectInspector> primitiveObjectInspectorList, PrimitiveCategory[] primitiveCategories,
-      PrimitiveTypeInfo[] primitiveTypeInfos) {
-    ObjectInspector objectInspector = primitiveObjectInspectorList.get(column);
-    PrimitiveCategory primitiveCategory = primitiveCategories[column];
-    PrimitiveTypeInfo primitiveTypeInfo = primitiveTypeInfos[column];
-    switch (primitiveCategory) {
+  public Object getWritablePrimitiveObject(PrimitiveTypeInfo primitiveTypeInfo,
+      ObjectInspector objectInspector, Object object) {
+    switch (primitiveTypeInfo.getPrimitiveCategory()) {
     case BOOLEAN:
       return ((WritableBooleanObjectInspector) objectInspector).create((boolean) object);
     case BYTE:
@@ -361,99 +584,203 @@ public static Object getWritableObject(int column, Object object,
       return result;
     }
     default:
-      throw new Error("Unknown primitive category " + primitiveCategory);
+      throw new Error("Unknown primitive category " + primitiveTypeInfo.getPrimitiveCategory());
     }
   }
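getWritablePrimitiveObject converts the plain Java values produced by randomPrimitiveObject into the Hadoop writables the target object inspector expects, using the same settable-inspector create() pattern per category. A single-column sketch of that pattern — the demo class is illustrative:

import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
import org.apache.hadoop.io.IntWritable;

public class WritableWrapDemo {
  public static void main(String[] args) {
    PrimitiveTypeInfo intTypeInfo =
        (PrimitiveTypeInfo) TypeInfoUtils.getTypeInfoFromTypeString("int");
    WritableIntObjectInspector intOI =
        (WritableIntObjectInspector)
            PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(intTypeInfo);
    // Same cast-and-create() pattern getWritablePrimitiveObject uses per category.
    IntWritable writable = (IntWritable) intOI.create(42);
    System.out.println(writable.get());  // 42
  }
}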
 
-  public Object randomObject(int column) {
-    return randomObject(column, r, primitiveCategories, primitiveTypeInfos, alphabets, addEscapables, needsEscapeStr);
-  }
-
-  public static Object randomObject(int column, Random r, PrimitiveCategory[] primitiveCategories,
-      PrimitiveTypeInfo[] primitiveTypeInfos) {
-    return randomObject(column, r, primitiveCategories, primitiveTypeInfos, null, false, "");
-  }
-
-  public static Object randomObject(int column, Random r, PrimitiveCategory[] primitiveCategories,
-      PrimitiveTypeInfo[] primitiveTypeInfos, String[] alphabets, boolean addEscapables, String needsEscapeStr) {
-    PrimitiveCategory primitiveCategory = primitiveCategories[column];
-    PrimitiveTypeInfo primitiveTypeInfo = primitiveTypeInfos[column];
-    try {
-      switch (primitiveCategory) {
-      case BOOLEAN:
-        return Boolean.valueOf(r.nextInt(1) == 1);
-      case BYTE:
-        return Byte.valueOf((byte) r.nextInt());
-      case SHORT:
-        return Short.valueOf((short) r.nextInt());
-      case INT:
-        return Integer.valueOf(r.nextInt());
-      case LONG:
-        return Long.valueOf(r.nextLong());
-      case DATE:
-        return RandomTypeUtil.getRandDate(r);
-      case FLOAT:
-        return Float.valueOf(r.nextFloat() * 10 - 5);
-      case DOUBLE:
-        return Double.valueOf(r.nextDouble() * 10 - 5);
-      case STRING:
-      case CHAR:
-      case VARCHAR:
-        {
-          String result;
-          if (alphabets != null && alphabets[column] != null) {
-            result = RandomTypeUtil.getRandString(r, alphabets[column], r.nextInt(10));
-          } else {
-            result = RandomTypeUtil.getRandString(r);
+  public Object randomWritable(int column) {
+    return randomWritable(typeInfos[column], objectInspectorList.get(column));
+  }
+
+  public Object randomWritable(TypeInfo typeInfo, ObjectInspector objectInspector) {
+    switch (typeInfo.getCategory()) {
+    case PRIMITIVE:
+      {
+        Object object = randomPrimitiveObject(r, (PrimitiveTypeInfo) typeInfo);
+        return getWritablePrimitiveObject((PrimitiveTypeInfo) typeInfo, objectInspector, object);
+      }
+    case LIST:
+      {
+        if (r.nextInt(20) == 0) {
+          return null;
+        }
+        // Always generate a list with at least 1 value?
+        final int elementCount = 1 + r.nextInt(100);
+        StandardListObjectInspector listObjectInspector =
+            (StandardListObjectInspector) objectInspector;
+        ObjectInspector elementObjectInspector =
+            listObjectInspector.getListElementObjectInspector();
+        TypeInfo elementTypeInfo =
+            TypeInfoUtils.getTypeInfoFromObjectInspector(
+                elementObjectInspector);
+        boolean isStringFamily = false;
+        PrimitiveCategory primitiveCategory = null;
+        if (elementTypeInfo.getCategory() == Category.PRIMITIVE) {
+          primitiveCategory = ((PrimitiveTypeInfo) elementTypeInfo).getPrimitiveCategory();
+          if (primitiveCategory == PrimitiveCategory.STRING ||
+              primitiveCategory == PrimitiveCategory.BINARY ||
+              primitiveCategory == PrimitiveCategory.CHAR ||
+              primitiveCategory == PrimitiveCategory.VARCHAR) {
+            isStringFamily = true;
           }
-          if (addEscapables && result.length() > 0) {
-            int escapeCount = 1 + r.nextInt(2);
-            for (int i = 0; i < escapeCount; i++) {
-              int index = r.nextInt(result.length());
-              String begin = result.substring(0, index);
-              String end = result.substring(index);
-              Character needsEscapeChar = needsEscapeStr.charAt(r.nextInt(needsEscapeStr.length()));
-              result = begin + needsEscapeChar + end;
-            }
+        }
+        Object listObj = listObjectInspector.create(elementCount);
+        for (int i = 0; i < elementCount; i++) {
+          Object ele = randomWritable(elementTypeInfo, elementObjectInspector);
+          // UNDONE: For now, a 1-element list with a null element is a null list...
+          if (ele == null && elementCount == 1) {
+            return null;
           }
-          switch (primitiveCategory) {
-          case STRING:
-            return result;
-          case CHAR:
-            return new HiveChar(result, ((CharTypeInfo) primitiveTypeInfo).getLength());
-          case VARCHAR:
-            return new HiveVarchar(result, ((VarcharTypeInfo) primitiveTypeInfo).getLength());
-          default:
-            throw new Error("Unknown primitive category " + primitiveCategory);
+          if (isStringFamily && elementCount == 1) {
+            switch (primitiveCategory) {
+            case STRING:
+              if (((Text) ele).getLength() == 0) {
+                return null;
+              }
+              break;
+            case BINARY:
+              if (((BytesWritable) ele).getLength() == 0) {
+                return null;
+              }
+              break;
+            case CHAR:
+              if (((HiveCharWritable) ele).getHiveChar().getStrippedValue().isEmpty()) {
+                return null;
+              }
+              break;
+            case VARCHAR:
+              if (((HiveVarcharWritable) ele).getHiveVarchar().getValue().isEmpty()) {
+                return null;
+              }
+              break;
+            default:
+              throw new RuntimeException("Unexpected primitive category " + primitiveCategory);
+            }
           }
+          listObjectInspector.set(listObj, i, ele);
         }
-      case BINARY:
-        return getRandBinary(r, 1 + r.nextInt(100));
-      case TIMESTAMP:
-        return RandomTypeUtil.getRandTimestamp(r);
-      case INTERVAL_YEAR_MONTH:
-        return getRandIntervalYearMonth(r);
-      case INTERVAL_DAY_TIME:
-        return getRandIntervalDayTime(r);
-      case DECIMAL:
-        return getRandHiveDecimal(r, (DecimalTypeInfo) primitiveTypeInfo);
-      default:
-        throw new Error("Unknown primitive category " + primitiveCategory);
+        return listObj;
+      }
+    case MAP:
+      {
+        if (r.nextInt(20) == 0) {
+          return null;
+        }
+        final int keyPairCount = r.nextInt(100);
+        StandardMapObjectInspector mapObjectInspector =
+            (StandardMapObjectInspector) objectInspector;
+        ObjectInspector keyObjectInspector =
+            mapObjectInspector.getMapKeyObjectInspector();
+        TypeInfo keyTypeInfo =
+            TypeInfoUtils.getTypeInfoFromObjectInspector(
+                keyObjectInspector);
+        ObjectInspector valueObjectInspector =
+            mapObjectInspector.getMapValueObjectInspector();
+        TypeInfo valueTypeInfo =
+            TypeInfoUtils.getTypeInfoFromObjectInspector(
+                valueObjectInspector);
+        Object mapObj = mapObjectInspector.create();
+        for (int i = 0; i < keyPairCount; i++) {
+          Object key = randomWritable(keyTypeInfo, keyObjectInspector);
+          Object value = randomWritable(valueTypeInfo, valueObjectInspector);
+          mapObjectInspector.put(mapObj, key, value);
+        }
+        return mapObj;
+      }
+    case STRUCT:
+      {
+        if (r.nextInt(20) == 0) {
+          return null;
+        }
+        StandardStructObjectInspector structObjectInspector =
+            (StandardStructObjectInspector) objectInspector;
+        List<? extends StructField> fieldRefsList = structObjectInspector.getAllStructFieldRefs();
+        final int fieldCount = fieldRefsList.size();
+        Object structObj = structObjectInspector.create();
+        for (int i = 0; i < fieldCount; i++) {
+          StructField fieldRef = fieldRefsList.get(i);
+          ObjectInspector fieldObjectInspector =
+              fieldRef.getFieldObjectInspector();
+          TypeInfo fieldTypeInfo =
+              TypeInfoUtils.getTypeInfoFromObjectInspector(
+                  fieldObjectInspector);
+          Object fieldObj = randomWritable(fieldTypeInfo, fieldObjectInspector);
+          structObjectInspector.setStructFieldData(structObj, fieldRef, fieldObj);
+        }
+        return structObj;
+      }
+    case UNION:
+      {
+        StandardUnionObjectInspector unionObjectInspector =
+            (StandardUnionObjectInspector) objectInspector;
+        List<ObjectInspector> unionFieldObjectInspectors = unionObjectInspector.getObjectInspectors();
+        final int unionCount = unionFieldObjectInspectors.size();
+        final byte tag = (byte) r.nextInt(unionCount);
+        ObjectInspector fieldObjectInspector =
+            unionFieldObjectInspectors.get(tag);
+        TypeInfo fieldTypeInfo =
+            TypeInfoUtils.getTypeInfoFromObjectInspector(
+                fieldObjectInspector);
+        Object fieldObj = randomWritable(fieldTypeInfo, fieldObjectInspector);
+        return new StandardUnion(tag, fieldObj);
+      }
-    } catch (Exception e) {
-      throw new RuntimeException("randomObject failed on column " + column + " type " + primitiveCategory, e);
+    default:
+      throw new RuntimeException("Unexpected category " + typeInfo.getCategory());
     }
   }
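The UNION branch above picks a random tag, generates a writable for just that field, and wraps it in a StandardUnion. A standalone sketch for uniontype<int,string> — the demo class is illustrative, and the tag is hardcoded for determinism where randomWritable would draw it with r.nextInt(unionCount):

import java.util.ArrayList;
import java.util.List;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.StandardUnionObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardUnionObjectInspector.StandardUnion;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.Text;

public class RandomUnionDemo {
  public static void main(String[] args) {
    List<ObjectInspector> fieldInspectors = new ArrayList<ObjectInspector>();
    fieldInspectors.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    fieldInspectors.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
    StandardUnionObjectInspector unionOI =
        ObjectInspectorFactory.getStandardUnionObjectInspector(fieldInspectors);

    // Tag 1 selects the string field of uniontype<int,string>.
    StandardUnion union = new StandardUnion((byte) 1, new Text("hello"));
    System.out.println(unionOI.getTag(union) + " -> " + unionOI.getField(union));
  }
}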
Error("Unknown primitive category " + primitiveTypeInfo.getCategory()); + } } public static HiveChar getRandHiveChar(Random r, CharTypeInfo charTypeInfo) { - return getRandHiveChar(r, charTypeInfo, "abcdefghijklmnopqrstuvwxyz"); + int maxLength = 1 + r.nextInt(charTypeInfo.getLength()); + String randomString = RandomTypeUtil.getRandString(r, "abcdefghijklmnopqrstuvwxyz", 100); + HiveChar hiveChar = new HiveChar(randomString, maxLength); + return hiveChar; } public static HiveVarchar getRandHiveVarchar(Random r, VarcharTypeInfo varcharTypeInfo, String alphabet) { diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/mapjoin/fast/TestVectorMapJoinFastRowHashMap.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/mapjoin/fast/TestVectorMapJoinFastRowHashMap.java index ebb243e..331f48b 100644 --- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/mapjoin/fast/TestVectorMapJoinFastRowHashMap.java +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/mapjoin/fast/TestVectorMapJoinFastRowHashMap.java @@ -88,6 +88,10 @@ private void addAndVerifyRows(VectorRandomRowSource valueSource, Object[][] rows SerializeWrite valueSerializeWrite = new LazyBinarySerializeWrite(columnCount); + VectorRandomRowSource source = new VectorRandomRowSource(); + source.init(keyCount, random, keyPrimitiveObjectInspectorList, + keyPrimitiveCategories, keyPrimitiveTypeInfos); + final int count = rows.length; for (int i = 0; i < count; i++) { @@ -108,9 +112,7 @@ private void addAndVerifyRows(VectorRandomRowSource valueSource, Object[][] rows // Add a new key or add a value to an existing key? byte[] key; if (random.nextBoolean() || verifyTable.getCount() == 0) { - Object[] keyRow = - VectorRandomRowSource.randomRow(keyCount, random, keyPrimitiveObjectInspectorList, - keyPrimitiveCategories, keyPrimitiveTypeInfos); + Object[] keyRow = source.randomRow(); Output keyOutput = new Output(); keySerializeWrite.set(keyOutput); diff --git serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/StandardUnionObjectInspector.java serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/StandardUnionObjectInspector.java index f26c9ec..83f6a53 100644 --- serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/StandardUnionObjectInspector.java +++ serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/StandardUnionObjectInspector.java @@ -79,6 +79,31 @@ public byte getTag() { public String toString() { return tag + ":" + object; } + + @Override + public int hashCode() { + return object.hashCode() ^ tag; + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof StandardUnion)) { + return false; + } + StandardUnion that = (StandardUnion) obj; + if (this.tag != that.tag) { + return false; + } + + if (this.object == null) { + return that.object == null; + } else { + return this.object.equals(that.object); + } + } } /**