diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/VectorHashKeyWrapper.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/VectorHashKeyWrapper.java
index 35712d0..d9d10a2 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/VectorHashKeyWrapper.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/VectorHashKeyWrapper.java
@@ -20,29 +20,38 @@
 import java.util.Arrays;
 
+import org.apache.hadoop.hive.ql.exec.vector.expressions.StringExpr;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 
 /**
  * A hash map key wrapper for vectorized processing.
  * It stores the key values as primitives in arrays for each supported primitive type.
- * This works in conjunction with 
+ * This works in conjunction with
  * {@link org.apache.hadoop.hive.ql.exec.VectorHashKeyWrapperBatch VectorHashKeyWrapperBatch}
- * to hash vectorized processing units (batches). 
+ * to hash vectorized processing units (batches).
  */
 public class VectorHashKeyWrapper extends KeyWrapper {
-
+
   private long[] longValues;
   private double[] doubleValues;
+
+  private byte[][] byteValues;
+  private int[] byteStarts;
+  private int[] byteLengths;
+
   private boolean[] isNull;
   private int hashcode;
-
-  public VectorHashKeyWrapper(int longValuesCount, int doubleValuesCount) {
+
+  public VectorHashKeyWrapper(int longValuesCount, int doubleValuesCount, int byteValuesCount) {
     longValues = new long[longValuesCount];
     doubleValues = new double[doubleValuesCount];
-    isNull = new boolean[longValuesCount + doubleValuesCount];
+    byteValues = new byte[byteValuesCount][];
+    byteStarts = new int[byteValuesCount];
+    byteLengths = new int[byteValuesCount];
+    isNull = new boolean[longValuesCount + doubleValuesCount + byteValuesCount];
   }
-
+
   private VectorHashKeyWrapper() {
   }
 
@@ -56,32 +65,90 @@
   void setHashKey() {
     hashcode = Arrays.hashCode(longValues) ^
         Arrays.hashCode(doubleValues) ^
         Arrays.hashCode(isNull);
+
+    // This code, branches and all, is not executed if there are no string keys.
+    for (int i = 0; i < byteValues.length; ++i) {
+      /*
+       * Hashing the string is potentially expensive, so it is better to branch.
+       * Additionally, skipping the values of null keys means we never have to reset them.
+       */
+      if (!isNull[longValues.length + doubleValues.length + i]) {
+        byte[] bytes = byteValues[i];
+        int start = byteStarts[i];
+        int length = byteLengths[i];
+        if (length == bytes.length && start == 0) {
+          hashcode ^= Arrays.hashCode(bytes);
+        }
+        else {
+          // Unfortunately there is no Arrays.hashCode(byte[], start, length)
+          for (int j = start; j < start + length; ++j) {
+            // use 461 because it is a prime
+            hashcode ^= 461 * bytes[j];
+          }
+        }
+      }
+    }
   }
-
+
   @Override
   public int hashCode() {
     return hashcode;
   }
-
-  @Override 
+
+  @Override
   public boolean equals(Object that) {
     if (that instanceof VectorHashKeyWrapper) {
       VectorHashKeyWrapper keyThat = (VectorHashKeyWrapper)that;
       return hashcode == keyThat.hashcode &&
           Arrays.equals(longValues, keyThat.longValues) &&
           Arrays.equals(doubleValues, keyThat.doubleValues) &&
-          Arrays.equals(isNull, keyThat.isNull);
+          Arrays.equals(isNull, keyThat.isNull) &&
+          byteValues.length == keyThat.byteValues.length &&
+          (0 == byteValues.length || bytesEquals(keyThat));
     }
     return false;
   }
-
+
+  private boolean bytesEquals(VectorHashKeyWrapper keyThat) {
+    // By the time we get here, byteValues.length and isNull have already been compared.
+    for (int i = 0; i < byteValues.length; ++i) {
+      // the byte comparison is potentially expensive, so it is better to branch on null
+      if (!isNull[longValues.length + doubleValues.length + i]) {
+        if (0 != StringExpr.compare(
+            byteValues[i],
+            byteStarts[i],
+            byteLengths[i],
+            keyThat.byteValues[i],
+            keyThat.byteStarts[i],
+            keyThat.byteLengths[i])) {
+          return false;
+        }
+      }
+    }
+    return true;
+  }
+
   @Override
   protected Object clone() {
     VectorHashKeyWrapper clone = new VectorHashKeyWrapper();
     clone.longValues = longValues.clone();
     clone.doubleValues = doubleValues.clone();
     clone.isNull = isNull.clone();
+
+    clone.byteValues = new byte[byteValues.length][];
+    clone.byteStarts = new int[byteValues.length];
+    clone.byteLengths = byteLengths.clone();
+    for (int i = 0; i < byteValues.length; ++i) {
+      // avoid allocating and copying null keys, because it is potentially expensive; branch instead
+      if (!isNull[longValues.length + doubleValues.length + i]) {
+        clone.byteValues[i] = Arrays.copyOfRange(
+            byteValues[i],
+            byteStarts[i],
+            byteStarts[i] + byteLengths[i]);
+      }
+    }
     clone.hashcode = hashcode;
+    assert clone.equals(this);
     return clone;
   }
 
@@ -121,19 +188,32 @@
   public void assignNullLong(int index) {
     longValues[index] = 0; // assign 0 to simplify hashcode
     isNull[index] = true;
   }
-
+
+  public void assignString(int index, byte[] bytes, int start, int length) {
+    byteValues[index] = bytes;
+    byteStarts[index] = start;
+    byteLengths[index] = length;
+    isNull[longValues.length + doubleValues.length + index] = false;
+  }
+
+  public void assignNullString(int index) {
+    // We do not reset the byte values because they are never read while the key is null.
+    isNull[longValues.length + doubleValues.length + index] = true;
+  }
+
   @Override
-  public String toString() 
+  public String toString()
   {
-    return String.format("%d[%s] %d[%s]",
+    return String.format("%d[%s] %d[%s] %d[%s]",
         longValues.length, Arrays.toString(longValues),
-        doubleValues.length, Arrays.toString(doubleValues));
+        doubleValues.length, Arrays.toString(doubleValues),
+        byteValues.length, Arrays.toString(byteValues));
   }
 
   public boolean getIsNull(int i) {
     return isNull[i];
   }
-
+
   public long getLongValue(int i) {
     return longValues[i];
   }
@@ -142,4 +222,18 @@
   public double getDoubleValue(int i) {
     return doubleValues[i - longValues.length];
   }
 
+  public byte[] getBytes(int i) {
+    return byteValues[i - longValues.length - doubleValues.length];
+  }
+
+  public int getByteStart(int i) {
+    return byteStarts[i - longValues.length - doubleValues.length];
+  }
+
+  public int getByteLength(int i) {
+    return byteLengths[i - longValues.length - doubleValues.length];
+  }
+
+
 }
+
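The inner loop in setHashKey() above works around the JDK's lack of a range-aware Arrays.hashCode(byte[], start, length). For comparison only (this sketch is not part of the patch; the class and method names are illustrative), a helper that reproduces Arrays.hashCode(byte[]) semantics over a sub-range could look like this:

import java.util.Arrays;

public final class ByteRangeHashSketch {

  private ByteRangeHashSketch() {
  }

  // Mirrors Arrays.hashCode(byte[]) over the sub-range [start, start + length).
  // Unlike the XOR of 461 * bytes[j] used in setHashKey(), this polynomial form
  // is sensitive to byte order.
  public static int hashCode(byte[] bytes, int start, int length) {
    if (bytes == null) {
      return 0;
    }
    int result = 1;
    for (int i = start; i < start + length; ++i) {
      result = 31 * result + bytes[i];
    }
    return result;
  }

  public static void main(String[] args) {
    byte[] buffer = "xxHELLOxx".getBytes();
    int rangeHash = hashCode(buffer, 2, 5);
    int copyHash = Arrays.hashCode(Arrays.copyOfRange(buffer, 2, 7));
    System.out.println(rangeHash == copyHash); // prints true
  }
}
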
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/VectorHashKeyWrapperBatch.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/VectorHashKeyWrapperBatch.java
index c23614c..2312536 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/VectorHashKeyWrapperBatch.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/VectorHashKeyWrapperBatch.java
@@ -19,22 +19,25 @@
 package org.apache.hadoop.hive.ql.exec;
 
 import java.util.Arrays;
+
+import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector;
 import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector;
 import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector;
 import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.io.BytesWritable;
 import org.apache.hadoop.io.LongWritable;
 
 /**
- * Class for handling vectorized hash map key wrappers. It evaluates the key columns in a 
+ * Class for handling vectorized hash map key wrappers. It evaluates the key columns in a
  * row batch in a vectorized fashion.
  * This class stores additional information about keys needed to evaluate and output the key values.
  *
  */
 public class VectorHashKeyWrapperBatch {
-
+
   /**
    * Helper class for looking up a key value based on key index
    */
@@ -42,53 +45,61 @@
   private static class KeyLookupHelper {
     public int longIndex;
     public int doubleIndex;
+    public int stringIndex;
   }
-
+
   /**
    * The key expressions that require evaluation and output the primitive values for each key.
    */
   private VectorExpression[] keyExpressions;
-
+
   /**
    * indices of LONG primitive keys
    */
   private int[] longIndices;
-
+
   /**
    * indices of DOUBLE primitive keys
   */
   private int[] doubleIndices;
-
+
+  /**
+   * indices of string (byte[]) primitive keys
+   */
+  private int[] stringIndices;
+
   /**
-   * pre-allocated batch size vector of keys wrappers. 
+   * pre-allocated batch size vector of key wrappers.
    * N.B. these keys are **mutable** and should never be used in a HashMap.
-   * Always clone the key wrapper to obtain an immutable keywrapper suitable 
+   * Always clone the key wrapper to obtain an immutable key wrapper suitable
    * to use a key in a HashMap.
*/ private VectorHashKeyWrapper[] vectorHashKeyWrappers; - + /** * lookup vector to map from key index to primitive type index */ private KeyLookupHelper[] indexLookup; - + /** - * preallocated and reused LongWritable objects for emiting row mode key values + * preallocated and reused LongWritable objects for emiting row mode key values */ private LongWritable[] longKeyValueOutput; - + /** * preallocated and reused DoubleWritable objects for emiting row mode key values */ private DoubleWritable[] doubleKeyValueOutput; - + + private BytesWritable[] stringKeyValueOutput; + /** - * Accessor for the batch-sized array of key wrappers + * Accessor for the batch-sized array of key wrappers */ public VectorHashKeyWrapper[] getVectorHashKeyWrappers() { return vectorHashKeyWrappers; } - + /** * Processes a batch: * - * @param vrb + * @param batch * @throws HiveException */ - public void evaluateBatch (VectorizedRowBatch vrb) throws HiveException { + public void evaluateBatch (VectorizedRowBatch batch) throws HiveException { for(int i = 0; i < keyExpressions.length; ++i) { - keyExpressions[i].evaluate(vrb); + keyExpressions[i].evaluate(batch); } for(int i = 0; i< longIndices.length; ++i) { int keyIndex = longIndices[i]; int columnIndex = keyExpressions[keyIndex].getOutputColumn(); - LongColumnVector columnVector = (LongColumnVector) vrb.cols[columnIndex]; - if (columnVector.noNulls && !columnVector.isRepeating && !vrb.selectedInUse) { - assignLongNoNullsNoRepeatingNoSelection(i, vrb.size, columnVector); - } else if (columnVector.noNulls && !columnVector.isRepeating && vrb.selectedInUse) { - assignLongNoNullsNoRepeatingSelection(i, vrb.size, columnVector, vrb.selected); + LongColumnVector columnVector = (LongColumnVector) batch.cols[columnIndex]; + if (columnVector.noNulls && !columnVector.isRepeating && !batch.selectedInUse) { + assignLongNoNullsNoRepeatingNoSelection(i, batch.size, columnVector); + } else if (columnVector.noNulls && !columnVector.isRepeating && batch.selectedInUse) { + assignLongNoNullsNoRepeatingSelection(i, batch.size, columnVector, batch.selected); } else if (columnVector.noNulls && columnVector.isRepeating) { - assignLongNoNullsRepeating(i, vrb.size, columnVector); - } else if (!columnVector.noNulls && !columnVector.isRepeating && !vrb.selectedInUse) { - assignLongNullsNoRepeatingNoSelection(i, vrb.size, columnVector); + assignLongNoNullsRepeating(i, batch.size, columnVector); + } else if (!columnVector.noNulls && !columnVector.isRepeating && !batch.selectedInUse) { + assignLongNullsNoRepeatingNoSelection(i, batch.size, columnVector); } else if (!columnVector.noNulls && columnVector.isRepeating) { - assignLongNullsRepeating(i, vrb.size, columnVector); - } else if (!columnVector.noNulls && !columnVector.isRepeating && vrb.selectedInUse) { - assignLongNullsNoRepeatingSelection (i, vrb.size, columnVector, vrb.selected); + assignLongNullsRepeating(i, batch.size, columnVector); + } else if (!columnVector.noNulls && !columnVector.isRepeating && batch.selectedInUse) { + assignLongNullsNoRepeatingSelection (i, batch.size, columnVector, batch.selected); } else { throw new HiveException (String.format("Unimplemented Long null/repeat/selected combination %b/%b/%b", - columnVector.noNulls, columnVector.isRepeating, vrb.selectedInUse)); + columnVector.noNulls, columnVector.isRepeating, batch.selectedInUse)); } } for(int i=0;i= 0) { doubleKeyValueOutput[klh.doubleIndex].set(kw.getDoubleValue(i)); return doubleKeyValueOutput[klh.doubleIndex]; + } else if (klh.stringIndex >= 0) { + 
stringKeyValueOutput[klh.stringIndex].set( + kw.getBytes(i), kw.getByteStart(i), kw.getByteLength(i)); + return stringKeyValueOutput[klh.stringIndex]; } else { throw new HiveException(String.format( - "Internal inconsistent KeyLookupHelper at index [%d]:%d %d", - i, klh.longIndex, klh.doubleIndex)); + "Internal inconsistent KeyLookupHelper at index [%d]:%d %d %d", + i, klh.longIndex, klh.doubleIndex, klh.stringIndex)); } - } + } } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java index 1ef4955..609aa61 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java @@ -1063,7 +1063,7 @@ public ColumnVector allocateColumnVector(String type, int defaultSize) { } } - public ObjectInspector createObjectInspector(VectorExpression vectorExpression) + public ObjectInspector createObjectInspector(VectorExpression vectorExpression) throws HiveException { String columnType = vectorExpression.getOutputType(); if (columnType.equalsIgnoreCase("long") || @@ -1071,6 +1071,8 @@ public ObjectInspector createObjectInspector(VectorExpression vectorExpression) return PrimitiveObjectInspectorFactory.writableLongObjectInspector; } else if (columnType.equalsIgnoreCase("double")) { return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector; + } else if (columnType.equalsIgnoreCase("string")) { + return PrimitiveObjectInspectorFactory.writableStringObjectInspector; } else { throw new HiveException(String.format("Must implement type %s", columnType)); } diff --git a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/FakeVectorRowBatchFromObjectIterables.java b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/FakeVectorRowBatchFromObjectIterables.java new file mode 100644 index 0000000..6824ee7 --- /dev/null +++ b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/FakeVectorRowBatchFromObjectIterables.java @@ -0,0 +1,140 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.exec.vector; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +import org.apache.hadoop.hive.ql.exec.vector.util.FakeVectorRowBatchBase; +import org.apache.hadoop.hive.ql.metadata.HiveException; + +/** + * Test helper class that creates vectorized execution batches from arbitrary type iterables. 
+ */
+public class FakeVectorRowBatchFromObjectIterables extends FakeVectorRowBatchBase {
+
+  private final String[] types;
+  private final List<Iterator<Object>> iterators;
+  private final VectorizedRowBatch batch;
+  private boolean eof;
+  private final int batchSize;
+
+  /**
+   * Helper interface for assigning values to primitive vector column types.
+   */
+  private static interface ColumnVectorAssign
+  {
+    public void assign(
+        ColumnVector columnVector,
+        int row,
+        Object value);
+  }
+
+  private final ColumnVectorAssign[] columnAssign;
+
+  public FakeVectorRowBatchFromObjectIterables(int batchSize, String[] types,
+      Iterable<Object> ...iterables) throws HiveException {
+    this.types = types;
+    this.batchSize = batchSize;
+    iterators = new ArrayList<Iterator<Object>>(types.length);
+    columnAssign = new ColumnVectorAssign[types.length];
+
+    batch = new VectorizedRowBatch(types.length, batchSize);
+    for(int i=0; i< types.length; ++i) {
+      if (types[i].equalsIgnoreCase("long")) {
+        batch.cols[i] = new LongColumnVector(batchSize);
+        columnAssign[i] = new ColumnVectorAssign() {
+          @Override
+          public void assign(
+              ColumnVector columnVector,
+              int row,
+              Object value) {
+            LongColumnVector lcv = (LongColumnVector) columnVector;
+            lcv.vector[row] = (Long) value;
+          }
+        };
+      } else if (types[i].equalsIgnoreCase("string")) {
+        batch.cols[i] = new BytesColumnVector(batchSize);
+        columnAssign[i] = new ColumnVectorAssign() {
+          @Override
+          public void assign(
+              ColumnVector columnVector,
+              int row,
+              Object value) {
+            BytesColumnVector bcv = (BytesColumnVector) columnVector;
+            String s = (String) value;
+            byte[] bytes = s.getBytes();
+            bcv.vector[row] = bytes;
+            bcv.start[row] = 0;
+            bcv.length[row] = bytes.length;
+          }
+        };
+      } else if (types[i].equalsIgnoreCase("double")) {
+        batch.cols[i] = new DoubleColumnVector(batchSize);
+        columnAssign[i] = new ColumnVectorAssign() {
+          @Override
+          public void assign(
+              ColumnVector columnVector,
+              int row,
+              Object value) {
+            DoubleColumnVector dcv = (DoubleColumnVector) columnVector;
+            dcv.vector[row] = (Double) value;
+          }
+        };
+      } else {
+        throw new HiveException("Unimplemented type " + types[i]);
+      }
+      iterators.add(iterables[i].iterator());
+    }
+  }
+
+  @Override
+  public VectorizedRowBatch produceNextBatch() {
+    batch.size = 0;
+    batch.selectedInUse = false;
+    for (int i=0; i < types.length; ++i) {
+      ColumnVector col = batch.cols[i];
+      col.noNulls = true;
+      col.isRepeating = false;
+    }
+    while (!eof && batch.size < this.batchSize){
+      int r = batch.size;
+      for (int i=0; i < types.length; ++i) {
+        Iterator<Object> it = iterators.get(i);
+        if (!it.hasNext()) {
+          eof = true;
+          break;
+        }
+        Object value = it.next();
+        if (null == value) {
+          batch.cols[i].isNull[batch.size] = true;
+          batch.cols[i].noNulls = false;
+        } else {
+          columnAssign[i].assign(batch.cols[i], batch.size, value);
+        }
+      }
+      if (!eof) {
+        batch.size += 1;
+      }
+    }
+    return batch;
+  }
+}
+
diff --git a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java
index b3b5cd2..f9c05cf 100644
--- a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java
+++ b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java
@@ -35,7 +35,7 @@
 import org.apache.hadoop.hive.ql.exec.vector.util.FakeCaptureOutputOperator;
 import org.apache.hadoop.hive.ql.exec.vector.util.FakeVectorRowBatchFromConcat;
-import org.apache.hadoop.hive.ql.exec.vector.util.FakeVectorRowBatchFromIterables;
+import 
org.apache.hadoop.hive.ql.exec.vector.util.FakeVectorRowBatchFromLongIterables; import org.apache.hadoop.hive.ql.exec.vector.util.FakeVectorRowBatchFromRepeats; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.plan.AggregationDesc; @@ -43,7 +43,9 @@ import org.apache.hadoop.hive.ql.plan.ExprNodeDesc; import org.apache.hadoop.hive.ql.plan.GroupByDesc; import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; +import org.apache.hadoop.io.BytesWritable; import org.apache.hadoop.io.LongWritable; import org.junit.Test; @@ -54,9 +56,11 @@ private static ExprNodeDesc buildColumnDesc( VectorizationContext ctx, - String column) { + String column, + TypeInfo typeInfo) { + return new ExprNodeColumnDesc( - TypeInfoFactory.longTypeInfo, column, "table", false); + typeInfo, column, "table", false); } private static AggregationDesc buildAggregationDesc( @@ -64,7 +68,7 @@ private static AggregationDesc buildAggregationDesc( String aggregate, String column) { - ExprNodeDesc inputColumn = buildColumnDesc(ctx, column); + ExprNodeDesc inputColumn = buildColumnDesc(ctx, column, TypeInfoFactory.longTypeInfo); ArrayList params = new ArrayList(); params.add(inputColumn); @@ -99,19 +103,40 @@ private static GroupByDesc buildKeyGroupByDesc( VectorizationContext ctx, String aggregate, String column, + TypeInfo typeInfo, String key) { GroupByDesc desc = buildGroupByDesc(ctx, aggregate, column); - - ExprNodeDesc keyExp = buildColumnDesc(ctx, key); + + ExprNodeDesc keyExp = buildColumnDesc(ctx, key, typeInfo); ArrayList keys = new ArrayList(); keys.add(keyExp); desc.setKeys(keys); - + return desc; } @Test + public void testMinLongNullStringKeys() throws HiveException { + testAggregateStringKeyAggregate( + "min", + 2, + Arrays.asList(new Object[]{"A",null,"A",null}), + Arrays.asList(new Object[]{13L, 5L, 7L,19L}), + buildHashMap("A", 7L, null, 5L)); + } + + @Test + public void testMinLongStringKeys() throws HiveException { + testAggregateStringKeyAggregate( + "min", + 2, + Arrays.asList(new Object[]{"A","B","A","B"}), + Arrays.asList(new Object[]{13L, 5L, 7L,19L}), + buildHashMap("A", 7L, "B", 5L)); + } + + @Test public void testMinLongKeyGroupByCompactBatch() throws HiveException { testAggregateLongKeyAggregate( "min", @@ -120,7 +145,7 @@ public void testMinLongKeyGroupByCompactBatch() throws HiveException { Arrays.asList(new Long[]{13L,5L,7L,19L}), buildHashMap(1L, 5L, 2L, 7L)); } - + @Test public void testMinLongKeyGroupBySingleBatch() throws HiveException { testAggregateLongKeyAggregate( @@ -130,7 +155,7 @@ public void testMinLongKeyGroupBySingleBatch() throws HiveException { Arrays.asList(new Long[]{13L,5L,7L,19L}), buildHashMap(1L, 5L, 2L, 7L)); } - + @Test public void testMinLongKeyGroupByCrossBatch() throws HiveException { testAggregateLongKeyAggregate( @@ -170,7 +195,7 @@ public void testMaxLongNullKeyGroupBySingleBatch() throws HiveException { Arrays.asList(new Long[]{13L,5L,7L,19L}), buildHashMap(null, 13L, 2L, 19L)); } - + @Test public void testCountLongNullKeyGroupBySingleBatch() throws HiveException { testAggregateLongKeyAggregate( @@ -180,7 +205,7 @@ public void testCountLongNullKeyGroupBySingleBatch() throws HiveException { Arrays.asList(new Long[]{13L,5L,7L,19L}), buildHashMap(null, 2L, 2L, 2L)); } - + @Test public void testSumLongNullKeyGroupBySingleBatch() throws HiveException { testAggregateLongKeyAggregate( @@ -190,7 +215,7 @@ public void 
testSumLongNullKeyGroupBySingleBatch() throws HiveException { Arrays.asList(new Long[]{13L,5L,7L,19L}), buildHashMap(null, 20L, 2L, 24L)); } - + @Test public void testAvgLongNullKeyGroupBySingleBatch() throws HiveException { testAggregateLongKeyAggregate( @@ -210,7 +235,7 @@ public void testVarLongNullKeyGroupBySingleBatch() throws HiveException { Arrays.asList(new Long[]{13L, 5L,18L,19L,12L,15L}), buildHashMap(null, 0.0, 2L, 49.0, 01L, 6.0)); } - + @Test public void testMinNullLongNullKeyGroupBy() throws HiveException { testAggregateLongKeyAggregate( @@ -230,7 +255,7 @@ public void testMinLongGroupBy() throws HiveException { 5L); } - + @Test public void testMinLongSimple() throws HiveException { testAggregateLongAggregate( @@ -408,7 +433,7 @@ public void testMinLongRepeatConcatValues () throws HiveException { new FakeVectorRowBatchFromConcat( new FakeVectorRowBatchFromRepeats( new Long[] {19L}, 10, 2), - new FakeVectorRowBatchFromIterables( + new FakeVectorRowBatchFromLongIterables( 3, Arrays.asList(new Long[]{13L, 7L, 23L, 29L}))), 7L); @@ -485,7 +510,7 @@ public void testCountLongRepeatConcatValues () throws HiveException { new FakeVectorRowBatchFromConcat( new FakeVectorRowBatchFromRepeats( new Long[] {19L}, 10, 2), - new FakeVectorRowBatchFromIterables( + new FakeVectorRowBatchFromLongIterables( 3, Arrays.asList(new Long[]{13L, 7L, 23L, 29L}))), 14L); @@ -561,7 +586,7 @@ public void testSumLongRepeatConcatValues () throws HiveException { new FakeVectorRowBatchFromConcat( new FakeVectorRowBatchFromRepeats( new Long[] {19L}, 10, 2), - new FakeVectorRowBatchFromIterables( + new FakeVectorRowBatchFromLongIterables( 3, Arrays.asList(new Long[]{13L, 7L, 23L, 29L}))), 19L*10L + 13L + 7L + 23L +29L); @@ -692,7 +717,7 @@ public void testAvgLongRepeatConcatValues () throws HiveException { new FakeVectorRowBatchFromConcat( new FakeVectorRowBatchFromRepeats( new Long[] {19L}, 10, 2), - new FakeVectorRowBatchFromIterables( + new FakeVectorRowBatchFromLongIterables( 3, Arrays.asList(new Long[]{13L, 7L, 23L, 29L}))), (double) (19L*10L + 13L + 7L + 23L +29L) / (double) 14 ); @@ -873,7 +898,7 @@ public void testAggregateLongRepeats ( new Long[] {value}, repeat, batchSize); testAggregateLongIterable (aggregateName, fdr, expected); } - + public HashMap buildHashMap(Object... 
pairs) { HashMap map = new HashMap(); for(int i = 0; i < pairs.length; i += 2) { @@ -882,19 +907,35 @@ public void testAggregateLongRepeats ( return map; } + public void testAggregateStringKeyAggregate ( + String aggregateName, + int batchSize, + Iterable list, + Iterable values, + HashMap expected) throws HiveException { + + @SuppressWarnings("unchecked") + FakeVectorRowBatchFromObjectIterables fdr = new FakeVectorRowBatchFromObjectIterables( + batchSize, + new String[] {"string", "long"}, + list, + values); + testAggregateStringKeyIterable (aggregateName, fdr, expected); + } + public void testAggregateLongKeyAggregate ( String aggregateName, int batchSize, - Iterable keys, + List list, Iterable values, HashMap expected) throws HiveException { @SuppressWarnings("unchecked") - FakeVectorRowBatchFromIterables fdr = new FakeVectorRowBatchFromIterables(batchSize, keys, values); + FakeVectorRowBatchFromLongIterables fdr = new FakeVectorRowBatchFromLongIterables(batchSize, list, values); testAggregateLongKeyIterable (aggregateName, fdr, expected); } - + public void testAggregateLongAggregate ( String aggregateName, int batchSize, @@ -902,7 +943,7 @@ public void testAggregateLongAggregate ( Object expected) throws HiveException { @SuppressWarnings("unchecked") - FakeVectorRowBatchFromIterables fdr = new FakeVectorRowBatchFromIterables(batchSize, values); + FakeVectorRowBatchFromLongIterables fdr = new FakeVectorRowBatchFromLongIterables(batchSize, values); testAggregateLongIterable (aggregateName, fdr, expected); } @@ -1085,19 +1126,20 @@ public void testAggregateLongKeyIterable ( VectorizationContext ctx = new VectorizationContext(mapColumnNames, 2); Set keys = new HashSet(); - GroupByDesc desc = buildKeyGroupByDesc (ctx, aggregateName, "Value", "Key"); + GroupByDesc desc = buildKeyGroupByDesc (ctx, aggregateName, "Value", + TypeInfoFactory.longTypeInfo, "Key"); VectorGroupByOperator vgo = new VectorGroupByOperator(ctx, desc); FakeCaptureOutputOperator out = FakeCaptureOutputOperator.addCaptureOutputChild(vgo); vgo.initialize(null, null); out.setOutputInspector(new FakeCaptureOutputOperator.OutputInspector() { - + private int rowIndex; private String aggregateName; private HashMap expected; private Set keys; - + @Override public void inspectRow(Object row, int tag) throws HiveException { assertTrue(row instanceof Object[]); @@ -1117,7 +1159,72 @@ public void inspectRow(Object row, int tag) throws HiveException { validator.validate(expectedValue, new Object[] {value}); keys.add(keyValue); } - + + private FakeCaptureOutputOperator.OutputInspector init( + String aggregateName, HashMap expected, Set keys) { + this.aggregateName = aggregateName; + this.expected = expected; + this.keys = keys; + return this; + } + }.init(aggregateName, expected, keys)); + + for (VectorizedRowBatch unit: data) { + vgo.process(unit, 0); + } + vgo.close(false); + + List outBatchList = out.getCapturedRows(); + assertNotNull(outBatchList); + assertEquals(expected.size(), outBatchList.size()); + assertEquals(expected.size(), keys.size()); + } + + public void testAggregateStringKeyIterable ( + String aggregateName, + Iterable data, + HashMap expected) throws HiveException { + Map mapColumnNames = new HashMap(); + mapColumnNames.put("Key", 0); + mapColumnNames.put("Value", 1); + VectorizationContext ctx = new VectorizationContext(mapColumnNames, 2); + Set keys = new HashSet(); + + GroupByDesc desc = buildKeyGroupByDesc (ctx, aggregateName, "Value", + TypeInfoFactory.stringTypeInfo, "Key"); + + VectorGroupByOperator vgo = new 
VectorGroupByOperator(ctx, desc); + + FakeCaptureOutputOperator out = FakeCaptureOutputOperator.addCaptureOutputChild(vgo); + vgo.initialize(null, null); + out.setOutputInspector(new FakeCaptureOutputOperator.OutputInspector() { + + private int rowIndex; + private String aggregateName; + private HashMap expected; + private Set keys; + + @SuppressWarnings("deprecation") + @Override + public void inspectRow(Object row, int tag) throws HiveException { + assertTrue(row instanceof Object[]); + Object[] fields = (Object[]) row; + assertEquals(2, fields.length); + Object key = fields[0]; + String keyValue = null; + if (null != key) { + assertTrue(key instanceof BytesWritable); + BytesWritable bwKey = (BytesWritable)key; + keyValue = new String(bwKey.get()); + } + assertTrue(expected.containsKey(keyValue)); + Object expectedValue = expected.get(keyValue); + Object value = fields[1]; + Validator validator = getValidator(aggregateName); + validator.validate(expectedValue, new Object[] {value}); + keys.add(keyValue); + } + private FakeCaptureOutputOperator.OutputInspector init( String aggregateName, HashMap expected, Set keys) { this.aggregateName = aggregateName; @@ -1131,11 +1238,13 @@ public void inspectRow(Object row, int tag) throws HiveException { vgo.process(unit, 0); } vgo.close(false); - + List outBatchList = out.getCapturedRows(); assertNotNull(outBatchList); assertEquals(expected.size(), outBatchList.size()); assertEquals(expected.size(), keys.size()); } + + } diff --git a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/util/FakeVectorRowBatchFromIterables.java b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/util/FakeVectorRowBatchFromIterables.java deleted file mode 100644 index cf3399d..0000000 --- a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/util/FakeVectorRowBatchFromIterables.java +++ /dev/null @@ -1,86 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.hadoop.hive.ql.exec.vector.util; - -import java.util.ArrayList; -import java.util.Iterator; -import java.util.List; - -import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; -import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector; -import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; - -/** - * VectorizedRowBatch test source from individual column values (as iterables) - * Used in unit test only. 
- */ -public class FakeVectorRowBatchFromIterables extends FakeVectorRowBatchBase { - private VectorizedRowBatch vrg; - private final int numCols; - private final int batchSize; - private List> iterators; - private boolean eof; - - public FakeVectorRowBatchFromIterables(int batchSize, Iterable...iterables) { - numCols = iterables.length; - this.batchSize = batchSize; - iterators = new ArrayList>(); - vrg = new VectorizedRowBatch(numCols, batchSize); - for (int i =0; i < numCols; i++) { - vrg.cols[i] = new LongColumnVector(batchSize); - iterators.add(iterables[i].iterator()); - } - } - - @Override - public VectorizedRowBatch produceNextBatch() { - vrg.size = 0; - vrg.selectedInUse = false; - for (int i=0; i < numCols; ++i) { - ColumnVector col = vrg.cols[i]; - col.noNulls = true; - col.isRepeating = false; - } - while (!eof && vrg.size < this.batchSize){ - int r = vrg.size; - for (int i=0; i < numCols; ++i) { - Iterator it = iterators.get(i); - if (!it.hasNext()) { - eof = true; - break; - } - LongColumnVector col = (LongColumnVector)vrg.cols[i]; - Long value = it.next(); - if (null == value) { - col.noNulls = false; - col.isNull[vrg.size] = true; - } else { - long[] vector = col.vector; - vector[r] = value; - col.isNull[vrg.size] = false; - } - } - if (!eof) { - vrg.size += 1; - } - } - return vrg; - } -} - diff --git a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/util/FakeVectorRowBatchFromLongIterables.java b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/util/FakeVectorRowBatchFromLongIterables.java new file mode 100644 index 0000000..1ececc7 --- /dev/null +++ b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/util/FakeVectorRowBatchFromLongIterables.java @@ -0,0 +1,86 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.exec.vector.util; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; + +/** + * VectorizedRowBatch test source from individual column values (as iterables) + * Used in unit test only. 
+ */
+public class FakeVectorRowBatchFromLongIterables extends FakeVectorRowBatchBase {
+  private VectorizedRowBatch batch;
+  private final int numCols;
+  private final int batchSize;
+  private List<Iterator<Long>> iterators;
+  private boolean eof;
+
+  public FakeVectorRowBatchFromLongIterables(int batchSize, Iterable<Long>...iterables) {
+    numCols = iterables.length;
+    this.batchSize = batchSize;
+    iterators = new ArrayList<Iterator<Long>>();
+    batch = new VectorizedRowBatch(numCols, batchSize);
+    for (int i =0; i < numCols; i++) {
+      batch.cols[i] = new LongColumnVector(batchSize);
+      iterators.add(iterables[i].iterator());
+    }
+  }
+
+  @Override
+  public VectorizedRowBatch produceNextBatch() {
+    batch.size = 0;
+    batch.selectedInUse = false;
+    for (int i=0; i < numCols; ++i) {
+      ColumnVector col = batch.cols[i];
+      col.noNulls = true;
+      col.isRepeating = false;
+    }
+    while (!eof && batch.size < this.batchSize){
+      int r = batch.size;
+      for (int i=0; i < numCols; ++i) {
+        Iterator<Long> it = iterators.get(i);
+        if (!it.hasNext()) {
+          eof = true;
+          break;
+        }
+        LongColumnVector col = (LongColumnVector)batch.cols[i];
+        Long value = it.next();
+        if (null == value) {
+          col.noNulls = false;
+          col.isNull[batch.size] = true;
+        } else {
+          long[] vector = col.vector;
+          vector[r] = value;
+          col.isNull[batch.size] = false;
+        }
+      }
+      if (!eof) {
+        batch.size += 1;
+      }
+    }
+    return batch;
+  }
+}
+
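As a usage illustration (not part of the patch), the new FakeVectorRowBatchFromObjectIterables helper can be driven directly through produceNextBatch(); the column types, batch size, and values below mirror testMinLongNullStringKeys, while the wrapper class and main method are hypothetical scaffolding:

import java.util.Arrays;

import org.apache.hadoop.hive.ql.exec.vector.FakeVectorRowBatchFromObjectIterables;
import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
import org.apache.hadoop.hive.ql.metadata.HiveException;

public class FakeBatchSourceSketch {

  public static void main(String[] args) throws HiveException {
    // One string key column and one long value column, emitted two rows per batch.
    @SuppressWarnings("unchecked")
    FakeVectorRowBatchFromObjectIterables source = new FakeVectorRowBatchFromObjectIterables(
        2,
        new String[] {"string", "long"},
        Arrays.asList(new Object[]{"A", null, "A", null}),
        Arrays.asList(new Object[]{13L, 5L, 7L, 19L}));

    VectorizedRowBatch batch;
    while ((batch = source.produceNextBatch()).size > 0) {
      // cols[0] is a BytesColumnVector, cols[1] is a LongColumnVector;
      // null cells are reported through each column's noNulls/isNull flags.
      System.out.println("batch with " + batch.size + " rows");
    }
  }
}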