diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/VectorHashKeyWrapperBatch.java ql/src/java/org/apache/hadoop/hive/ql/exec/VectorHashKeyWrapperBatch.java index cd57151..59bede4 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/VectorHashKeyWrapperBatch.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/VectorHashKeyWrapperBatch.java @@ -25,6 +25,7 @@ import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector; import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression; +import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpressionWriter; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.serde2.io.DoubleWritable; import org.apache.hadoop.io.BytesWritable; @@ -82,18 +83,6 @@ private KeyLookupHelper[] indexLookup; /** - * preallocated and reused LongWritable objects for emiting row mode key values - */ - private LongWritable[] longKeyValueOutput; - - /** - * preallocated and reused DoubleWritable objects for emiting row mode key values - */ - private DoubleWritable[] doubleKeyValueOutput; - - private BytesWritable[] stringKeyValueOutput; - - /** * Accessor for the batch-sized array of key wrappers */ public VectorHashKeyWrapper[] getVectorHashKeyWrappers() { @@ -452,15 +441,19 @@ public static VectorHashKeyWrapperBatch compileKeyWrapperBatch(VectorExpression[ for(int i=0; i < keyExpressions.length; ++i) { indexLookup[i] = new KeyLookupHelper(); String outputType = keyExpressions[i].getOutputType(); - if (outputType.equalsIgnoreCase("long") || - outputType.equalsIgnoreCase("bigint") || - outputType.equalsIgnoreCase("int")) { + if (outputType.equalsIgnoreCase("tinyint") || + outputType.equalsIgnoreCase("smallint") || + outputType.equalsIgnoreCase("int") || + outputType.equalsIgnoreCase("bigint") || + outputType.equalsIgnoreCase("timestamp") || + outputType.equalsIgnoreCase("boolean")) { longIndices[longIndicesIndex] = i; indexLookup[i].longIndex = longIndicesIndex; indexLookup[i].doubleIndex = -1; indexLookup[i].stringIndex = -1; ++longIndicesIndex; - } else if (outputType.equalsIgnoreCase("double")) { + } else if (outputType.equalsIgnoreCase("double") || + outputType.equalsIgnoreCase("float")) { doubleIndices[doubleIndicesIndex] = i; indexLookup[i].longIndex = -1; indexLookup[i].doubleIndex = doubleIndicesIndex; @@ -477,18 +470,6 @@ public static VectorHashKeyWrapperBatch compileKeyWrapperBatch(VectorExpression[ } } compiledKeyWrapperBatch.indexLookup = indexLookup; - compiledKeyWrapperBatch.longKeyValueOutput = new LongWritable[longIndicesIndex]; - for (int i=0; i < longIndicesIndex; ++i) { - compiledKeyWrapperBatch.longKeyValueOutput[i] = new LongWritable(); - } - compiledKeyWrapperBatch.doubleKeyValueOutput = new DoubleWritable[doubleIndicesIndex]; - for (int i=0; i < doubleIndicesIndex; ++i) { - compiledKeyWrapperBatch.doubleKeyValueOutput[i] = new DoubleWritable(); - } - compiledKeyWrapperBatch.stringKeyValueOutput = new BytesWritable[stringIndicesIndex]; - for (int i = 0; i < stringIndicesIndex; ++i) { - compiledKeyWrapperBatch.stringKeyValueOutput[i] = new BytesWritable(); - } compiledKeyWrapperBatch.longIndices = Arrays.copyOf(longIndices, longIndicesIndex); compiledKeyWrapperBatch.doubleIndices = Arrays.copyOf(doubleIndices, doubleIndicesIndex); compiledKeyWrapperBatch.stringIndices = Arrays.copyOf(stringIndices, stringIndicesIndex); @@ -503,23 +484,22 @@ public static VectorHashKeyWrapperBatch compileKeyWrapperBatch(VectorExpression[ /** * Get 
the row-mode writable object value of a key from a key wrapper + * @param keyOutputWriter */ - public Object getWritableKeyValue(VectorHashKeyWrapper kw, int i) + public Object getWritableKeyValue(VectorHashKeyWrapper kw, int i, + VectorExpressionWriter keyOutputWriter) throws HiveException { if (kw.getIsNull(i)) { return null; } KeyLookupHelper klh = indexLookup[i]; if (klh.longIndex >= 0) { - longKeyValueOutput[klh.longIndex].set(kw.getLongValue(i)); - return longKeyValueOutput[klh.longIndex]; + return keyOutputWriter.writeValue(kw.getLongValue(i)); } else if (klh.doubleIndex >= 0) { - doubleKeyValueOutput[klh.doubleIndex].set(kw.getDoubleValue(i)); - return doubleKeyValueOutput[klh.doubleIndex]; + return keyOutputWriter.writeValue(kw.getDoubleValue(i)); } else if (klh.stringIndex >= 0) { - stringKeyValueOutput[klh.stringIndex].set( + return keyOutputWriter.writeValue( kw.getBytes(i), kw.getByteStart(i), kw.getByteLength(i)); - return stringKeyValueOutput[klh.stringIndex]; } else { throw new HiveException(String.format( "Internal inconsistent KeyLookupHelper at index [%d]:%d %d %d", diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/TimestampUtils.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/TimestampUtils.java new file mode 100644 index 0000000..b9b7744 --- /dev/null +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/TimestampUtils.java @@ -0,0 +1,35 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.hadoop.hive.ql.exec.vector;
+
+import java.sql.Timestamp;
+
+public final class TimestampUtils {
+
+  public static void assignTimeInNanoSec(long timeInNanoSec, Timestamp t) {
+    t.setTime((timeInNanoSec)/1000000);
+    t.setNanos((int)((t.getNanos()) + (timeInNanoSec % 1000000)));
+  }
+
+  public static long getTimeNanoSec(Timestamp t) {
+    long time = t.getTime();
+    int nanos = t.getNanos();
+    return (time * 1000000) + (nanos % 1000000);
+  }
+}
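For reference, a minimal round-trip sketch of the two helpers above: getTimeNanoSec encodes whole milliseconds times 1,000,000 plus the sub-millisecond nanosecond remainder, and assignTimeInNanoSec reverses it into a reused Timestamp. The class name and the timestamp literal are hypothetical; only the two TimestampUtils calls come from this patch.

    import java.sql.Timestamp;
    import org.apache.hadoop.hive.ql.exec.vector.TimestampUtils;

    public class TimestampNanoRoundTrip {
      public static void main(String[] args) {
        Timestamp original = Timestamp.valueOf("2013-05-17 12:34:56.789123456");
        // Encode: whole milliseconds * 1,000,000 plus the sub-millisecond nanos.
        long nanos = TimestampUtils.getTimeNanoSec(original);
        // Decode into a reusable Timestamp, as VectorizedColumnarSerDe does further below.
        Timestamp decoded = new Timestamp(0);
        TimestampUtils.assignTimeInNanoSec(nanos, decoded);
        System.out.println(original.equals(decoded));  // true for non-negative epoch times
      }
    }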
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorGroupByOperator.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorGroupByOperator.java
index 91366dd..07eccea 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorGroupByOperator.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorGroupByOperator.java
@@ -33,9 +33,12 @@
 import org.apache.hadoop.hive.ql.exec.VectorHashKeyWrapper;
 import org.apache.hadoop.hive.ql.exec.VectorHashKeyWrapperBatch;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpressionWriter;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpressionWriterFactory;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.hive.ql.plan.AggregationDesc;
+import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
 import org.apache.hadoop.hive.ql.plan.GroupByDesc;
 import org.apache.hadoop.hive.ql.plan.OperatorDesc;
 import org.apache.hadoop.hive.ql.plan.api.OperatorType;
@@ -66,6 +69,8 @@
    */
  private transient VectorExpression[] keyExpressions;

+  private VectorExpressionWriter[] keyOutputWriters;
+
   /**
    * The aggregation buffers to use for the current batch.
    */
@@ -98,13 +103,18 @@
 protected void initializeOp(Configuration hconf) throws HiveException {
     try {
       vContext.setOperatorType(OperatorType.GROUPBY);
-      ArrayList<AggregationDesc> aggrDesc = conf.getAggregators();
-      keyExpressions = vContext.getVectorExpressions(conf.getKeys());
+      List<ExprNodeDesc> keysDesc = conf.getKeys();
+      keyExpressions = vContext.getVectorExpressions(keysDesc);
+
+      keyOutputWriters = new VectorExpressionWriter[keyExpressions.length];

       for(int i = 0; i < keyExpressions.length; ++i) {
-        objectInspectors.add(vContext.createObjectInspector(keyExpressions[i]));
+        keyOutputWriters[i] = VectorExpressionWriterFactory.
+            genVectorExpressionWritable(keysDesc.get(i));
+        objectInspectors.add(keyOutputWriters[i].getObjectInspector());
       }
-
+
+      ArrayList<AggregationDesc> aggrDesc = conf.getAggregators();
       aggregators = new VectorAggregateExpression[aggrDesc.size()];
       for (int i = 0; i < aggrDesc.size(); ++i) {
         AggregationDesc desc = aggrDesc.get(i);
@@ -233,7 +243,8 @@ public void closeOp(boolean aborted) throws HiveException {
           int fi = 0;
           for (int i = 0; i < keyExpressions.length; ++i) {
             VectorHashKeyWrapper kw = (VectorHashKeyWrapper)pair.getKey();
-            forwardCache[fi++] = keyWrappersBatch.getWritableKeyValue (kw, i);
+            forwardCache[fi++] = keyWrappersBatch.getWritableKeyValue (
+                kw, i, keyOutputWriters[i]);
           }
           for (int i = 0; i < aggregators.length; ++i) {
             forwardCache[fi++] = aggregators[i].evaluateOutput(pair.getValue()
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorReduceSinkOperator.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorReduceSinkOperator.java
index f61fcb6..d177e29 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorReduceSinkOperator.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorReduceSinkOperator.java
@@ -28,6 +28,8 @@
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.hive.ql.exec.TerminalOperator;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpressionWriter;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpressionWriterFactory;
 import org.apache.hadoop.hive.ql.io.HiveKey;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.hive.ql.plan.OperatorDesc;
@@ -57,11 +59,24 @@
    * the reducer side. Key columns are passed to the reducer in the "key".
    */
   protected transient VectorExpression[] keyEval;
+
+  /**
+   * The key value writers. These know how to write the necessary writable type
+   * based on key column metadata, from the primitive vector type.
+   */
+  protected transient VectorExpressionWriter[] keyWriters;
+
   /**
    * The evaluators for the value columns. Value columns are passed to reducer
    * in the "value".
    */
   protected transient VectorExpression[] valueEval;
+
+  /**
+   * The output value writers. These know how to write the necessary writable type
+   * based on value column metadata, from the primitive vector type.
+   */
+  protected transient VectorExpressionWriter[] valueWriters;

   /**
    * The evaluators for the partition columns (CLUSTER BY or DISTRIBUTE BY in
@@ -69,6 +84,12 @@
    * goes to. Partition columns are not passed to reducer.
    */
   protected transient VectorExpression[] partitionEval;
+
+  /**
+   * The partition value writers. These know how to write the necessary writable type
+   * based on partition column metadata, from the primitive vector type.
+ */ + protected transient VectorExpressionWriter[] partitionWriters; private int numDistributionKeys; @@ -112,15 +133,22 @@ protected void initializeOp(Configuration hconf) throws HiveException { .newInstance(); keySerializer.initialize(null, keyTableDesc.getProperties()); keyIsText = keySerializer.getSerializedClass().equals(Text.class); - - keyObjectInspector = vContext.createObjectInspector(keyEval, - conf.getOutputKeyColumnNames()); - - partitionObjectInspectors = new ObjectInspector[partitionEval.length]; - for (int i = 0; i < partitionEval.length; i++) { - partitionObjectInspectors[i] = vContext.createObjectInspector(partitionEval[i]); - } - + + /* + * Compute and assign the key writers and the key object inspector + */ + VectorExpressionWriterFactory.processVectorExpressions( + conf.getKeyCols(), + conf.getOutputKeyColumnNames(), + new VectorExpressionWriterFactory.Closure() { + @Override + public void assign(VectorExpressionWriter[] writers, + ObjectInspector objectInspector) { + keyWriters = writers; + keyObjectInspector = objectInspector; + } + }); + String colNames = ""; for(String colName : conf.getOutputKeyColumnNames()) { colNames = String.format("%s %s", colNames, colName); @@ -131,18 +159,27 @@ protected void initializeOp(Configuration hconf) throws HiveException { keyObjectInspector, colNames)); - conf.getOutputKeyColumnNames(); - conf.getOutputValueColumnNames(); - - //keyObjectInspector = ObjectInspectorFactory. - + partitionWriters = VectorExpressionWriterFactory.getExpressionWriters(conf.getPartitionCols()); + TableDesc valueTableDesc = conf.getValueSerializeInfo(); valueSerializer = (Serializer) valueTableDesc.getDeserializerClass() .newInstance(); valueSerializer.initialize(null, valueTableDesc.getProperties()); - - valueObjectInspector = vContext.createObjectInspector (valueEval, - conf.getOutputValueColumnNames()); + + /* + * Compute and assign the value writers and the value object inspector + */ + VectorExpressionWriterFactory.processVectorExpressions( + conf.getValueCols(), + conf.getOutputValueColumnNames(), + new VectorExpressionWriterFactory.Closure() { + @Override + public void assign(VectorExpressionWriter[] writers, + ObjectInspector objectInspector) { + valueWriters = writers; + valueObjectInspector = objectInspector; + } + }); colNames = ""; for(String colName : conf.getOutputValueColumnNames()) { @@ -202,7 +239,7 @@ public void processOp(Object row, int tag) throws HiveException { for (int i = 0; i < valueEval.length; i++) { int batchColumn = valueEval[i].getOutputColumn(); ColumnVector vectorColumn = vrg.cols[batchColumn]; - cachedValues[i] = vectorColumn.getWritableObject(rowIndex); + cachedValues[i] = valueWriters[i].writeValue(vectorColumn, rowIndex); } // Serialize the value value = valueSerializer.serialize(cachedValues, valueObjectInspector); @@ -210,7 +247,7 @@ public void processOp(Object row, int tag) throws HiveException { for (int i = 0; i < keyEval.length; i++) { int batchColumn = keyEval[i].getOutputColumn(); ColumnVector vectorColumn = vrg.cols[batchColumn]; - distributionKeys[i] = vectorColumn.getWritableObject(rowIndex); + distributionKeys[i] = keyWriters[i].writeValue(vectorColumn, rowIndex); } // no distinct key System.arraycopy(distributionKeys, 0, cachedKeys[0], 0, numDistributionKeys); @@ -255,11 +292,13 @@ public void processOp(Object row, int tag) throws HiveException { keyHashCode = random.nextInt(); } else { for (int p = 0; p < partitionEval.length; p++) { + ColumnVector columnVector = 
              vrg.cols[partitionEval[p].getOutputColumn()];
+          Object partitionValue = partitionWriters[p].writeValue(columnVector, rowIndex);
           keyHashCode = keyHashCode * 31 + ObjectInspectorUtils.hashCode(
-              vrg.cols[partitionEval[p].getOutputColumn()].getWritableObject(rowIndex),
-              partitionObjectInspectors[i]);
+              partitionValue,
+              partitionWriters[p].getObjectInspector());
         }
       }
       keyWritable.setHashCode(keyHashCode);
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java
index c320356..d58a8c1 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java
@@ -1049,11 +1049,6 @@ private VectorizedRowBatch allocateRowBatch(int rowCount) throws HiveException {
     return ret;
   }

-  Object[][] mapObjectInspectors = {
-      {"double", PrimitiveObjectInspectorFactory.writableDoubleObjectInspector},
-      {"long", PrimitiveObjectInspectorFactory.writableLongObjectInspector},
-  };
-
   public Map<Integer, String> getOutputColumnTypeMap() {
     Map<Integer, String> map = new HashMap<Integer, String>();
     for (int i = 0; i < ocm.outputColCount; i++) {
@@ -1073,34 +1068,6 @@ public ColumnVector allocateColumnVector(String type, int defaultSize) {
     }
   }

-  public ObjectInspector createObjectInspector(VectorExpression vectorExpression)
-      throws HiveException {
-    String columnType = vectorExpression.getOutputType();
-    if (columnType.equalsIgnoreCase("long") ||
-        columnType.equalsIgnoreCase("bigint") ||
-        columnType.equalsIgnoreCase("int") ||
-        columnType.equalsIgnoreCase("smallint")) {
-      return PrimitiveObjectInspectorFactory.writableLongObjectInspector;
-    } else if (columnType.equalsIgnoreCase("double")) {
-      return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
-    } else if (columnType.equalsIgnoreCase("string")) {
-      return PrimitiveObjectInspectorFactory.writableBinaryObjectInspector;
-    } else {
-      throw new HiveException(String.format("Must implement type %s", columnType));
-    }
-  }
-
-  public ObjectInspector createObjectInspector(
-      VectorExpression[] vectorExpressions, List<String> columnNames)
-      throws HiveException {
-    List<ObjectInspector> oids = new ArrayList<ObjectInspector>();
-    for (VectorExpression vexpr : vectorExpressions) {
-      ObjectInspector oi = createObjectInspector(vexpr);
-      oids.add(oi);
-    }
-    return ObjectInspectorFactory.getStandardStructObjectInspector(columnNames,
-        oids);
-  }

   public void addToColumnMap(String columnName, int outputColumn) {
     if (columnMap != null) {
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizedBatchUtil.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizedBatchUtil.java
index ffd7ef2..80bf671 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizedBatchUtil.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizedBatchUtil.java
@@ -173,7 +173,7 @@ public static void AddRowToBatch(Object row, StructObjectInspector oi, int rowIn
         LongColumnVector lcv = (LongColumnVector) batch.cols[i];
         if (writableCol != null) {
           Timestamp t = ((TimestampWritable) writableCol).getTimestamp();
-          lcv.vector[rowIndex] = (t.getTime() * 1000000) + (t.getNanos() % 1000000);
+          lcv.vector[rowIndex] = TimestampUtils.getTimeNanoSec(t);
           lcv.isNull[rowIndex] = false;
         } else {
           lcv.vector[rowIndex] = 1;
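For context, a minimal sketch of how the writer machinery introduced later in this patch (VectorExpressionWriter / VectorExpressionWriterFactory) is intended to be used: the expression's compile-time metadata (an ExprNodeDesc) selects the concrete Writable and ObjectInspector, and writeValue() converts a primitive read from a ColumnVector. The class name, column name "key", table alias "t", and the choice of int are hypothetical; ExprNodeColumnDesc, TypeInfoFactory and LongColumnVector are pre-existing Hive classes.

    import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector;
    import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpressionWriter;
    import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpressionWriterFactory;
    import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc;
    import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;

    public class VectorExpressionWriterSketch {
      public static void main(String[] args) throws Exception {
        // Column metadata decides the concrete Writable (int -> IntWritable), so type
        // information lost by vectorization (int vs bigint, float vs double) is preserved.
        ExprNodeColumnDesc keyDesc =
            new ExprNodeColumnDesc(TypeInfoFactory.intTypeInfo, "key", "t", false);
        VectorExpressionWriter writer =
            VectorExpressionWriterFactory.genVectorExpressionWritable(keyDesc);

        // int columns are carried as longs in a LongColumnVector.
        LongColumnVector lcv = new LongColumnVector(1024);
        lcv.vector[0] = 42;

        System.out.println(writer.writeValue(lcv, 0));                  // IntWritable: 42
        System.out.println(writer.getObjectInspector().getTypeName());  // int
      }
    }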
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizedColumnarSerDe.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizedColumnarSerDe.java
index aeff313..69553d9 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizedColumnarSerDe.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizedColumnarSerDe.java
@@ -161,8 +161,7 @@ public Writable serializeVector(VectorizedRowBatch vrg, ObjectInspector objInspe
               LongColumnVector tcv = (LongColumnVector) batch.cols[k];
               long timeInNanoSec = tcv.vector[rowIndex];
               Timestamp t = new Timestamp(0);
-              t.setTime((timeInNanoSec)/1000000);
-              t.setNanos((int)((t.getNanos()) + (timeInNanoSec % 1000000)));
+              TimestampUtils.assignTimeInNanoSec(timeInNanoSec, t);
               TimestampWritable tw = new TimestampWritable();
               tw.set(t);
               LazyTimestamp.writeUTF8(serializeVectorStream, tw);
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/VectorExpressionWriter.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/VectorExpressionWriter.java
new file mode 100644
index 0000000..890cf4c
--- /dev/null
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/VectorExpressionWriter.java
@@ -0,0 +1,36 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec.vector.expressions;
+
+import org.apache.hadoop.hive.ql.exec.vector.ColumnVector;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.io.Writable;
+
+/**
+ * Interface used to create Writable objects from vector expression primitives.
+ *
+ */
+public interface VectorExpressionWriter {
+  ObjectInspector getObjectInspector();
+  Object writeValue(ColumnVector column, int row) throws HiveException;
+  Object writeValue(long value) throws HiveException;
+  Object writeValue(double value) throws HiveException;
+  Object writeValue(byte[] value, int start, int length) throws HiveException;
+}
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/VectorExpressionWriterFactory.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/VectorExpressionWriterFactory.java
new file mode 100644
index 0000000..be1668b
--- /dev/null
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/VectorExpressionWriterFactory.java
@@ -0,0 +1,390 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.exec.vector.expressions; + +import java.sql.Timestamp; +import java.util.ArrayList; +import java.util.List; + +import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.TimestampUtils; +import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpressionWriterFactory.Closure; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.plan.AggregationDesc; +import org.apache.hadoop.hive.ql.plan.ExprNodeDesc; +import org.apache.hadoop.hive.serde2.io.ShortWritable; +import org.apache.hadoop.hive.serde2.io.TimestampWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.io.ByteWritable; +import org.apache.hadoop.io.BooleanWritable; +import org.apache.hadoop.io.BytesWritable; +import org.apache.hadoop.io.DoubleWritable; +import org.apache.hadoop.io.FloatWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; + +/** + * VectorExpressionWritableFactory helper class for generating VectorExpressionWritable objects. + */ +public final class VectorExpressionWriterFactory { + + /** + * VectorExpressionWriter base implementation, to be specialized for Long/Double/Bytes columns + */ + private static abstract class VectorExpressionWriterBase implements VectorExpressionWriter { + + protected ObjectInspector objectInspector; + + /** + * The object inspector associated with this expression. This is created from the expression + * NodeDesc (compile metadata) not from the VectorColumn info and thus preserves the type info + * lost by the vectorization process. + */ + public ObjectInspector getObjectInspector() { + return objectInspector; + } + + public VectorExpressionWriter init(ExprNodeDesc nodeDesc) throws HiveException { + this.objectInspector = nodeDesc.getWritableObjectInspector(); + return this; + } + + /** + * The base implementation must be overridden by the Long specialization + */ + @Override + public Object writeValue(long value) throws HiveException { + throw new HiveException("Internal error: should not reach here"); + } + + /** + * The base implementation must be overridden by the Double specialization + */ + @Override + public Object writeValue(double value) throws HiveException { + throw new HiveException("Internal error: should not reach here"); + } + + /** + * The base implementation must be overridden by the Bytes specialization + */ + @Override + public Object writeValue(byte[] value, int start, int length) throws HiveException { + throw new HiveException("Internal error: should not reach here"); + } + } + + /** + * Specialized writer for LongVectorColumn expressions. Will throw cast exception + * if the wrong vector column is used. 
+ */ + private static abstract class VectorExpressionWriterLong + extends VectorExpressionWriterBase { + @Override + public Object writeValue(ColumnVector column, int row) throws HiveException { + LongColumnVector lcv = (LongColumnVector) column; + if (lcv.noNulls && !lcv.isRepeating) { + return writeValue(lcv.vector[row]); + } else if (lcv.noNulls && lcv.isRepeating) { + return writeValue(lcv.vector[0]); + } else if (!lcv.noNulls && !lcv.isRepeating && !lcv.isNull[row]) { + return writeValue(lcv.vector[row]); + } else if (!lcv.noNulls && !lcv.isRepeating && lcv.isNull[row]) { + return null; + } else if (!lcv.noNulls && lcv.isRepeating && !lcv.isNull[0]) { + return writeValue(lcv.vector[0]); + } else if (!lcv.noNulls && lcv.isRepeating && lcv.isNull[0]) { + return null; + } + throw new HiveException( + String.format( + "Incorrect null/repeating: row:%d noNulls:%b isRepeating:%b isNull[row]:%b isNull[0]:%b", + row, lcv.noNulls, lcv.isRepeating, lcv.isNull[row], lcv.isNull[0])); + } + } + + /** + * Specialized writer for DoubleColumnVector. Will throw cast exception + * if the wrong vector column is used. + */ + private static abstract class VectorExpressionWriterDouble extends VectorExpressionWriterBase { + @Override + public Object writeValue(ColumnVector column, int row) throws HiveException { + DoubleColumnVector dcv = (DoubleColumnVector) column; + if (dcv.noNulls && !dcv.isRepeating) { + return writeValue(dcv.vector[row]); + } else if (dcv.noNulls && dcv.isRepeating) { + return writeValue(dcv.vector[0]); + } else if (!dcv.noNulls && !dcv.isRepeating && !dcv.isNull[row]) { + return writeValue(dcv.vector[row]); + } else if (!dcv.noNulls && !dcv.isRepeating && dcv.isNull[row]) { + return null; + } else if (!dcv.noNulls && dcv.isRepeating && !dcv.isNull[0]) { + return writeValue(dcv.vector[0]); + } else if (!dcv.noNulls && dcv.isRepeating && dcv.isNull[0]) { + return null; + } + throw new HiveException( + String.format( + "Incorrect null/repeating: row:%d noNulls:%b isRepeating:%b isNull[row]:%b isNull[0]:%b", + row, dcv.noNulls, dcv.isRepeating, dcv.isNull[row], dcv.isNull[0])); + } + } + + /** + * Specialized writer for BytesColumnVector. Will throw cast exception + * if the wrong vector column is used. 
+ */ + private static abstract class VectorExpressionWriterBytes extends VectorExpressionWriterBase { + @Override + public Object writeValue(ColumnVector column, int row) throws HiveException { + BytesColumnVector bcv = (BytesColumnVector) column; + if (bcv.noNulls && !bcv.isRepeating) { + return writeValue(bcv.vector[row], bcv.start[row], bcv.length[row]); + } else if (bcv.noNulls && bcv.isRepeating) { + return writeValue(bcv.vector[0], bcv.start[0], bcv.length[0]); + } else if (!bcv.noNulls && !bcv.isRepeating && !bcv.isNull[row]) { + return writeValue(bcv.vector[row], bcv.start[row], bcv.length[row]); + } else if (!bcv.noNulls && !bcv.isRepeating && bcv.isNull[row]) { + return null; + } else if (!bcv.noNulls && bcv.isRepeating && !bcv.isNull[0]) { + return writeValue(bcv.vector[0], bcv.start[0], bcv.length[0]); + } else if (!bcv.noNulls && bcv.isRepeating && bcv.isNull[0]) { + return null; + } + throw new HiveException( + String.format( + "Incorrect null/repeating: row:%d noNulls:%b isRepeating:%b isNull[row]:%b isNull[0]:%b", + row, bcv.noNulls, bcv.isRepeating, bcv.isNull[row], bcv.isNull[0])); + } + } + + /** + * Compiles the appropriate vector expression writer based on an expression info (ExprNodeDesc) + */ + public static VectorExpressionWriter genVectorExpressionWritable(ExprNodeDesc nodeDesc) + throws HiveException { + String nodeType = nodeDesc.getTypeString(); + if (nodeType.equalsIgnoreCase("tinyint")) { + return new VectorExpressionWriterLong() + { + private ByteWritable writable; + + @Override + public VectorExpressionWriter init(ExprNodeDesc nodeDesc) throws HiveException { + super.init(nodeDesc); + writable = new ByteWritable(); + return this; + } + + @Override + public Object writeValue(long value) { + writable.set((byte) value); + return writable; + } + }.init(nodeDesc); + } else if (nodeType.equalsIgnoreCase("smallint")) { + return new VectorExpressionWriterLong() + { + private ShortWritable writable; + @Override + public VectorExpressionWriter init(ExprNodeDesc nodeDesc) throws HiveException { + super.init(nodeDesc); + writable = new ShortWritable(); + return this; + } + + @Override + public Object writeValue(long value) { + writable.set((short) value); + return writable; + } + }.init(nodeDesc); + } else if (nodeType.equalsIgnoreCase("int")) { + return new VectorExpressionWriterLong() + { + private IntWritable writable; + @Override + public VectorExpressionWriter init(ExprNodeDesc nodeDesc) throws HiveException { + super.init(nodeDesc); + writable = new IntWritable(); + return this; + } + + @Override + public Object writeValue(long value) { + writable.set((int) value); + return writable; + } + }.init(nodeDesc); + } else if (nodeType.equalsIgnoreCase("bigint")) { + return new VectorExpressionWriterLong() + { + private LongWritable writable; + @Override + public VectorExpressionWriter init(ExprNodeDesc nodeDesc) throws HiveException { + super.init(nodeDesc); + writable = new LongWritable(); + return this; + } + + @Override + public Object writeValue(long value) { + writable.set(value); + return writable; + } + }.init(nodeDesc); + } else if (nodeType.equalsIgnoreCase("boolean")) { + return new VectorExpressionWriterLong() + { + private BooleanWritable writable; + @Override + public VectorExpressionWriter init(ExprNodeDesc nodeDesc) throws HiveException { + super.init(nodeDesc); + writable = new BooleanWritable(); + return this; + } + + @Override + public Object writeValue(long value) { + writable.set(value != 0 ? 
true : false); + return writable; + } + }.init(nodeDesc); + } else if (nodeType.equalsIgnoreCase("timestamp")) { + return new VectorExpressionWriterLong() + { + private TimestampWritable writable; + private Timestamp timestamp; + @Override + public VectorExpressionWriter init(ExprNodeDesc nodeDesc) throws HiveException { + super.init(nodeDesc); + writable = new TimestampWritable(); + timestamp = new Timestamp(0); + return this; + } + + @Override + public Object writeValue(long value) { + TimestampUtils.assignTimeInNanoSec(value, timestamp); + writable.set(timestamp); + return writable; + } + }.init(nodeDesc); + } else if (nodeType.equalsIgnoreCase("string")) { + return new VectorExpressionWriterBytes() + { + private BytesWritable writable; + public VectorExpressionWriter init(ExprNodeDesc nodeDesc) throws HiveException { + super.init(nodeDesc); + writable = new BytesWritable(); + return this; + } + + @Override + public Object writeValue(byte[] value, int start, int length) throws HiveException { + writable.set(value, start, length); + return writable; + } + }.init(nodeDesc); + } else if (nodeType.equalsIgnoreCase("float")) { + return new VectorExpressionWriterDouble() + { + private FloatWritable writable; + @Override + public VectorExpressionWriter init(ExprNodeDesc nodeDesc) throws HiveException { + super.init(nodeDesc); + writable = new FloatWritable(); + return this; + } + + @Override + public Object writeValue(double value) { + writable.set((float)value); + return writable; + } + }.init(nodeDesc); + } else if (nodeType.equalsIgnoreCase("double")) { + return new VectorExpressionWriterDouble() + { + private DoubleWritable writable; + @Override + public VectorExpressionWriter init(ExprNodeDesc nodeDesc) throws HiveException { + super.init(nodeDesc); + writable = new DoubleWritable(); + return this; + } + + @Override + public Object writeValue(double value) { + writable.set(value); + return writable; + } + }.init(nodeDesc); + } + + throw new HiveException(String.format( + "Unimplemented genVectorExpressionWritable type: %s for expression: %s", + nodeType, nodeDesc)); + } + + /** + * Helper function to create an array of writers from a list of expression descriptors. 
+   */
+  public static VectorExpressionWriter[] getExpressionWriters(List<ExprNodeDesc> nodesDesc)
+      throws HiveException {
+    VectorExpressionWriter[] writers = new VectorExpressionWriter[nodesDesc.size()];
+    for(int i=0; i<nodesDesc.size(); ++i) {
+      writers[i] = genVectorExpressionWritable(nodesDesc.get(i));
+    }
+    return writers;
+  }
+
+  public static interface Closure {
+    void assign(VectorExpressionWriter[] writers, ObjectInspector objectInspector);
+  }
+
+  public static void processVectorExpressions(
+      List<ExprNodeDesc> nodesDesc,
+      List<String> outputColumnNames,
+      Closure closure)
+      throws HiveException {
+    VectorExpressionWriter[] writers = getExpressionWriters(nodesDesc);
+    List<ObjectInspector> oids = new ArrayList<ObjectInspector>(writers.length);
+    for(int i=0; i<writers.length; ++i) {
+      oids.add(writers[i].getObjectInspector());
+    }
+    ObjectInspector objectInspector = ObjectInspectorFactory.
+        getStandardStructObjectInspector(outputColumnNames, oids);
+    closure.assign(writers, objectInspector);
+  }
+}

 expected) throws HiveException {
+
+    Map<String, Integer> mapColumnNames = new HashMap<String, Integer>();
+    mapColumnNames.put("Key", 0);
+    mapColumnNames.put("Value", 1);
+    VectorizationContext ctx = new VectorizationContext(mapColumnNames, 2);
+    Set<Object> keys = new HashSet<Object>();
+
+    AggregationDesc agg = buildAggregationDesc(ctx, aggregateName,
+        "Value", TypeInfoFactory.getPrimitiveTypeInfo(data.getTypes()[1]));
+    ArrayList<AggregationDesc> aggs = new ArrayList<AggregationDesc>();
+    aggs.add(agg);
+
+    ArrayList<String> outputColumnNames = new ArrayList<String>();
+    outputColumnNames.add("_col0");
+
+    GroupByDesc desc = new GroupByDesc();
+    desc.setOutputColumnNames(outputColumnNames);
+    desc.setAggregators(aggs);
+
+    ExprNodeDesc keyExp = buildColumnDesc(ctx, "Key",
+        TypeInfoFactory.getPrimitiveTypeInfo(data.getTypes()[0]));
+    ArrayList<ExprNodeDesc> keysDesc = new ArrayList<ExprNodeDesc>();
+    keysDesc.add(keyExp);
+    desc.setKeys(keysDesc);
+
+    VectorGroupByOperator vgo = new VectorGroupByOperator(ctx, desc);
+
+    FakeCaptureOutputOperator out = FakeCaptureOutputOperator.addCaptureOutputChild(vgo);
+    vgo.initialize(null, null);
+    out.setOutputInspector(new FakeCaptureOutputOperator.OutputInspector() {
+
+      private int rowIndex;
+      private String aggregateName;
+      private Map<Object, Object> expected;
+      private Set<Object> keys;
+
+      @Override
+      public void inspectRow(Object row, int tag) throws HiveException {
+        assertTrue(row instanceof Object[]);
+        Object[] fields = (Object[]) row;
+        assertEquals(2, fields.length);
+        Object key = fields[0];
+        Object keyValue = null;
+        if (null == key) {
+          keyValue = null;
+        } else if (key instanceof ByteWritable) {
+          ByteWritable bwKey = (ByteWritable)key;
+          keyValue = bwKey.get();
+        } else if (key instanceof ShortWritable) {
+          ShortWritable swKey = (ShortWritable)key;
+          keyValue = swKey.get();
+        } else if (key instanceof IntWritable) {
+          IntWritable iwKey = (IntWritable)key;
+          keyValue = iwKey.get();
+        } else if (key instanceof LongWritable) {
+          LongWritable lwKey = (LongWritable)key;
+          keyValue = lwKey.get();
+        } else if (key instanceof TimestampWritable) {
+          TimestampWritable twKey = (TimestampWritable)key;
+          keyValue = twKey.getTimestamp();
+        } else if (key instanceof DoubleWritable) {
+          DoubleWritable dwKey = (DoubleWritable)key;
+          keyValue = dwKey.get();
+        } else if (key instanceof FloatWritable) {
+          FloatWritable fwKey = (FloatWritable)key;
+          keyValue = fwKey.get();
+        } else if (key instanceof BooleanWritable) {
+          BooleanWritable bwKey = (BooleanWritable)key;
+          keyValue = bwKey.get();
+        } else {
+          Assert.fail(String.format("Not implemented key output type %s: %s",
+              key.getClass().getName(), key));
+        }
+
+        assertTrue(expected.containsKey(keyValue));
+        Object expectedValue = expected.get(keyValue);
+        Object value = fields[1];
+        Validator validator = getValidator(aggregateName);
+        validator.validate(expectedValue, new Object[] {value});
+        keys.add(keyValue);
+      }
+
+      private FakeCaptureOutputOperator.OutputInspector init(
+          String aggregateName, Map<Object, Object> expected, Set<Object> keys) {
+        this.aggregateName = aggregateName;
+        this.expected = expected;
+        this.keys = keys;
+        return this;
+      }
+    }.init(aggregateName, expected, keys));
+
+    for (VectorizedRowBatch unit: data) {
+      vgo.process(unit, 0);
+    }
+    vgo.close(false);
+
+    List<Object> outBatchList = out.getCapturedRows();
+    assertNotNull(outBatchList);
+    assertEquals(expected.size(), outBatchList.size());
+    assertEquals(expected.size(), keys.size());
+  }

   public void testAggregateLongRepeats (