diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorHashKeyWrapper.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorHashKeyWrapper.java index 8d1f0e1..f083d86 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorHashKeyWrapper.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorHashKeyWrapper.java @@ -197,7 +197,7 @@ public void assignString(int index, byte[] bytes, int start, int length) { } public void assignNullString(int index) { - // We do not assign the value to [] because the value is never used on null + // We do not assign the value to byteValues[] because the value is never used on null isNull[longValues.length + doubleValues.length + index] = true; } @@ -210,28 +210,37 @@ public String toString() byteValues.length, Arrays.toString(byteValues)); } - public boolean getIsNull(int i) { + public boolean getIsLongNull(int i) { return isNull[i]; } + public boolean getIsDoubleNull(int i) { + return isNull[longValues.length + i]; + } + + public boolean getIsBytesNull(int i) { + return isNull[longValues.length + doubleValues.length + i]; + } + + public long getLongValue(int i) { return longValues[i]; } public double getDoubleValue(int i) { - return doubleValues[i - longValues.length]; + return doubleValues[i]; } public byte[] getBytes(int i) { - return byteValues[i - longValues.length - doubleValues.length]; + return byteValues[i]; } public int getByteStart(int i) { - return byteStarts[i - longValues.length - doubleValues.length]; + return byteStarts[i]; } public int getByteLength(int i) { - return byteLengths[i - longValues.length - doubleValues.length]; + return byteLengths[i]; } public int getVariableSize() { diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorHashKeyWrapperBatch.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorHashKeyWrapperBatch.java index 06a91d0..5c92005 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorHashKeyWrapperBatch.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorHashKeyWrapperBatch.java @@ -469,9 +469,10 @@ public static VectorHashKeyWrapperBatch compileKeyWrapperBatch(VectorExpression[ indexLookup[i].stringIndex = -1; ++doubleIndicesIndex; } else if (outputType.equalsIgnoreCase("string")) { + stringIndices[stringIndicesIndex]= i; indexLookup[i].longIndex = -1; indexLookup[i].doubleIndex = -1; - stringIndices[stringIndicesIndex]= i; + indexLookup[i].stringIndex = stringIndicesIndex; ++stringIndicesIndex; } else { @@ -516,17 +517,20 @@ public static VectorHashKeyWrapperBatch compileKeyWrapperBatch(VectorExpression[ public Object getWritableKeyValue(VectorHashKeyWrapper kw, int i, VectorExpressionWriter keyOutputWriter) throws HiveException { - if (kw.getIsNull(i)) { - return null; - } + KeyLookupHelper klh = indexLookup[i]; if (klh.longIndex >= 0) { - return keyOutputWriter.writeValue(kw.getLongValue(i)); + return kw.getIsLongNull(klh.longIndex) ? null : + keyOutputWriter.writeValue(kw.getLongValue(klh.longIndex)); } else if (klh.doubleIndex >= 0) { - return keyOutputWriter.writeValue(kw.getDoubleValue(i)); + return kw.getIsDoubleNull(klh.doubleIndex) ? null : + keyOutputWriter.writeValue(kw.getDoubleValue(klh.doubleIndex)); } else if (klh.stringIndex >= 0) { - return keyOutputWriter.writeValue( - kw.getBytes(i), kw.getByteStart(i), kw.getByteLength(i)); + return kw.getIsBytesNull(klh.stringIndex) ? null : + keyOutputWriter.writeValue( + kw.getBytes(klh.stringIndex), + kw.getByteStart(klh.stringIndex), + kw.getByteLength(klh.stringIndex)); } else { throw new HiveException(String.format( "Internal inconsistent KeyLookupHelper at index [%d]:%d %d %d", diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java index ba6e8c4..f3b15f2 100644 --- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java @@ -158,6 +158,114 @@ private static GroupByDesc buildKeyGroupByDesc( } @Test + public void testMultiKeyIntStringInt() throws HiveException { + testMultiKey( + "sum", + new FakeVectorRowBatchFromObjectIterables( + 2, + new String[] {"int", "string", "int", "double"}, + Arrays.asList(new Object[]{null, 1, 1, null, 2, 2, null}), + Arrays.asList(new Object[]{ "A", "A", "A", "C", null, null, "A"}), + Arrays.asList(new Object[]{null, 2, 2, null, 2, 2, null}), + Arrays.asList(new Object[]{1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 64.0})), + buildHashMap( + Arrays.asList( 1, "A", 2), 6.0, + Arrays.asList(null, "C", null), 8.0, + Arrays.asList( 2, null, 2), 48.0, + Arrays.asList(null, "A", null), 65.0)); + } + + @Test + public void testMultiKeyStringByteString() throws HiveException { + testMultiKey( + "sum", + new FakeVectorRowBatchFromObjectIterables( + 1, + new String[] {"string", "tinyint", "string", "double"}, + Arrays.asList(new Object[]{"A", "A", null}), + Arrays.asList(new Object[]{ 1, 1, 1}), + Arrays.asList(new Object[]{ "A", "A", "A"}), + Arrays.asList(new Object[]{ 1.0, 1.0, 1.0})), + buildHashMap( + Arrays.asList( "A", (byte)1, "A"), 2.0, + Arrays.asList( null, (byte)1, "A"), 1.0)); + } + + @Test + public void testMultiKeyStringIntString() throws HiveException { + testMultiKey( + "sum", + new FakeVectorRowBatchFromObjectIterables( + 2, + new String[] {"string", "int", "string", "double"}, + Arrays.asList(new Object[]{ "A", "A", "A", "C", null, null, "A"}), + Arrays.asList(new Object[]{null, 1, 1, null, 2, 2, null}), + Arrays.asList(new Object[]{ "A", "A", "A", "C", null, null, "A"}), + Arrays.asList(new Object[]{ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0})), + buildHashMap( + Arrays.asList(null, 2, null), 2.0, + Arrays.asList( "C", null, "C"), 1.0, + Arrays.asList( "A", 1, "A"), 2.0, + Arrays.asList( "A", null, "A"), 2.0)); + } + + @Test + public void testMultiKeyIntStringString() throws HiveException { + testMultiKey( + "sum", + new FakeVectorRowBatchFromObjectIterables( + 2, + new String[] {"int", "string", "string", "double"}, + Arrays.asList(new Object[]{null, 1, 1, null, 2, 2, null}), + Arrays.asList(new Object[]{ "A", "A", "A", "C", null, null, "A"}), + Arrays.asList(new Object[]{ "A", "A", "A", "C", null, null, "A"}), + Arrays.asList(new Object[]{ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0})), + buildHashMap( + Arrays.asList( 2, null, null), 2.0, + Arrays.asList(null, "C", "C"), 1.0, + Arrays.asList( 1, "A", "A"), 2.0, + Arrays.asList(null, "A", "A"), 2.0)); + } + + @Test + public void testMultiKeyDoubleStringInt() throws HiveException { + testMultiKey( + "sum", + new FakeVectorRowBatchFromObjectIterables( + 2, + new String[] {"double", "string", "int", "double"}, + Arrays.asList(new Object[]{null, 1.0, 1.0, null, 2.0, 2.0, null}), + Arrays.asList(new Object[]{ "A", "A", "A", "C", null, null, "A"}), + Arrays.asList(new Object[]{null, 2, 2, null, 2, 2, null}), + Arrays.asList(new Object[]{1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 64.0})), + buildHashMap( + Arrays.asList( 1.0, "A", 2), 6.0, + Arrays.asList(null, "C", null), 8.0, + Arrays.asList( 2.0, null, 2), 48.0, + Arrays.asList(null, "A", null), 65.0)); + } + + @Test + public void testMultiKeyDoubleShortString() throws HiveException { + short s = 2; + testMultiKey( + "sum", + new FakeVectorRowBatchFromObjectIterables( + 2, + new String[] {"double", "smallint", "string", "double"}, + Arrays.asList(new Object[]{null, 1.0, 1.0, null, 2.0, 2.0, null}), + Arrays.asList(new Object[]{null, s, s, null, s, s, null}), + Arrays.asList(new Object[]{ "A", "A", "A", "C", null, null, "A"}), + Arrays.asList(new Object[]{1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 64.0})), + buildHashMap( + Arrays.asList( 1.0, s, "A"), 6.0, + Arrays.asList(null, null, "C"), 8.0, + Arrays.asList( 2.0, s, null), 48.0, + Arrays.asList(null, null, "A"), 65.0)); + } + + + @Test public void testDoubleValueTypeSum() throws HiveException { testKeyTypeAggregate( "sum", @@ -1263,6 +1371,132 @@ public void testStdDevSampLongRepeat () throws HiveException { (double)0); } + private void testMultiKey( + String aggregateName, + FakeVectorRowBatchFromObjectIterables data, + HashMap expected) throws HiveException { + + Map mapColumnNames = new HashMap(); + ArrayList outputColumnNames = new ArrayList(); + ArrayList keysDesc = new ArrayList(); + Set keys = new HashSet(); + + // The types array tells us the number of columns in the data + final String[] columnTypes = data.getTypes(); + + // Columns 0..N-1 are keys. Column N is the aggregate value input + int i=0; + for(; i aggs = new ArrayList(1); + aggs.add( + buildAggregationDesc(ctx, aggregateName, + "value", TypeInfoFactory.getPrimitiveTypeInfo(columnTypes[i]))); + + for(i=0; i expected; + private Set keys; + + @Override + public void inspectRow(Object row, int tag) throws HiveException { + assertTrue(row instanceof Object[]); + Object[] fields = (Object[]) row; + assertEquals(columnTypes.length, fields.length); + ArrayList keyValue = new ArrayList(columnTypes.length-1); + for(int i=0; i expected, Set keys) { + this.aggregateName = aggregateName; + this.expected = expected; + this.keys = keys; + return this; + } + }.init(aggregateName, expected, keys)); + + for (VectorizedRowBatch unit: data) { + vgo.process(unit, 0); + } + vgo.close(false); + + List outBatchList = out.getCapturedRows(); + assertNotNull(outBatchList); + assertEquals(expected.size(), outBatchList.size()); + assertEquals(expected.size(), keys.size()); + } + + private void testKeyTypeAggregate( String aggregateName, FakeVectorRowBatchFromObjectIterables data, @@ -1342,11 +1576,13 @@ public void inspectRow(Object row, int tag) throws HiveException { key.getClass().getName(), key)); } + String keyValueAsString = String.format("%s", keyValue); + assertTrue(expected.containsKey(keyValue)); Object expectedValue = expected.get(keyValue); Object value = fields[1]; Validator validator = getValidator(aggregateName); - validator.validate(expectedValue, new Object[] {value}); + validator.validate(keyValueAsString, expectedValue, new Object[] {value}); keys.add(keyValue); } @@ -1482,33 +1718,33 @@ public void testAggregateCountStar ( public static interface Validator { - void validate (Object expected, Object result); + void validate (String key, Object expected, Object result); }; public static class ValueValidator implements Validator { @Override - public void validate(Object expected, Object result) { + public void validate(String key, Object expected, Object result) { assertEquals(true, result instanceof Object[]); Object[] arr = (Object[]) result; assertEquals(1, arr.length); if (expected == null) { - Assert.assertSame(NullWritable.get(), arr[0]); + Assert.assertSame(key, NullWritable.get(), arr[0]); } else if (arr[0] instanceof LongWritable) { LongWritable lw = (LongWritable) arr[0]; - assertEquals((Long) expected, (Long) lw.get()); + assertEquals(key, (Long) expected, (Long) lw.get()); } else if (arr[0] instanceof Text) { Text tx = (Text) arr[0]; String sbw = tx.toString(); - assertEquals((String) expected, sbw); + assertEquals(key, (String) expected, sbw); } else if (arr[0] instanceof DoubleWritable) { DoubleWritable dw = (DoubleWritable) arr[0]; - assertEquals ((Double) expected, (Double) dw.get()); + assertEquals (key, (Double) expected, (Double) dw.get()); } else if (arr[0] instanceof Double) { - assertEquals ((Double) expected, (Double) arr[0]); + assertEquals (key, (Double) expected, (Double) arr[0]); } else if (arr[0] instanceof Long) { - assertEquals ((Long) expected, (Long) arr[0]); + assertEquals (key, (Long) expected, (Long) arr[0]); } else { Assert.fail("Unsupported result type: " + arr[0].getClass().getName()); } @@ -1518,12 +1754,12 @@ public void validate(Object expected, Object result) { public static class AvgValidator implements Validator { @Override - public void validate(Object expected, Object result) { + public void validate(String key, Object expected, Object result) { Object[] arr = (Object[]) result; assertEquals (1, arr.length); if (expected == null) { - Assert.assertSame(NullWritable.get(), arr[0]); + Assert.assertSame(key, NullWritable.get(), arr[0]); } else { assertEquals (true, arr[0] instanceof Object[]); Object[] vals = (Object[]) arr[0]; @@ -1534,7 +1770,7 @@ public void validate(Object expected, Object result) { LongWritable lw = (LongWritable) vals[0]; DoubleWritable dw = (DoubleWritable) vals[1]; assertFalse (lw.get() == 0L); - assertEquals ((Double) expected, (Double) (dw.get() / lw.get())); + assertEquals (key, (Double) expected, (Double) (dw.get() / lw.get())); } } @@ -1542,11 +1778,11 @@ public void validate(Object expected, Object result) { public abstract static class BaseVarianceValidator implements Validator { - abstract void validateVariance ( + abstract void validateVariance (String key, double expected, long cnt, double sum, double variance); @Override - public void validate(Object expected, Object result) { + public void validate(String key, Object expected, Object result) { Object[] arr = (Object[]) result; assertEquals (1, arr.length); @@ -1564,7 +1800,7 @@ public void validate(Object expected, Object result) { DoubleWritable sum = (DoubleWritable) vals[1]; DoubleWritable var = (DoubleWritable) vals[2]; assertTrue (1 <= cnt.get()); - validateVariance ((Double) expected, cnt.get(), sum.get(), var.get()); + validateVariance (key, (Double) expected, cnt.get(), sum.get(), var.get()); } } } @@ -1572,32 +1808,32 @@ public void validate(Object expected, Object result) { public static class VarianceValidator extends BaseVarianceValidator { @Override - void validateVariance(double expected, long cnt, double sum, double variance) { - assertEquals (expected, variance /cnt, 0.0); + void validateVariance(String key, double expected, long cnt, double sum, double variance) { + assertEquals (key, expected, variance /cnt, 0.0); } } public static class VarianceSampValidator extends BaseVarianceValidator { @Override - void validateVariance(double expected, long cnt, double sum, double variance) { - assertEquals (expected, variance /(cnt-1), 0.0); + void validateVariance(String key, double expected, long cnt, double sum, double variance) { + assertEquals (key, expected, variance /(cnt-1), 0.0); } } public static class StdValidator extends BaseVarianceValidator { @Override - void validateVariance(double expected, long cnt, double sum, double variance) { - assertEquals (expected, Math.sqrt(variance / cnt), 0.0); + void validateVariance(String key, double expected, long cnt, double sum, double variance) { + assertEquals (key, expected, Math.sqrt(variance / cnt), 0.0); } } public static class StdSampValidator extends BaseVarianceValidator { @Override - void validateVariance(double expected, long cnt, double sum, double variance) { - assertEquals (expected, Math.sqrt(variance / (cnt-1)), 0.0); + void validateVariance(String key, double expected, long cnt, double sum, double variance) { + assertEquals (key, expected, Math.sqrt(variance / (cnt-1)), 0.0); } } @@ -1658,7 +1894,7 @@ public void testAggregateCountStarIterable ( Object result = outBatchList.get(0); Validator validator = getValidator("count"); - validator.validate(expected, result); + validator.validate("_total", expected, result); } public void testAggregateStringIterable ( @@ -1688,7 +1924,7 @@ public void testAggregateStringIterable ( Object result = outBatchList.get(0); Validator validator = getValidator(aggregateName); - validator.validate(expected, result); + validator.validate("_total", expected, result); } public void testAggregateDoubleIterable ( @@ -1718,7 +1954,7 @@ public void testAggregateDoubleIterable ( Object result = outBatchList.get(0); Validator validator = getValidator(aggregateName); - validator.validate(expected, result); + validator.validate("_total", expected, result); } public void testAggregateLongIterable ( @@ -1748,7 +1984,7 @@ public void testAggregateLongIterable ( Object result = outBatchList.get(0); Validator validator = getValidator(aggregateName); - validator.validate(expected, result); + validator.validate("_total", expected, result); } public void testAggregateLongKeyIterable ( @@ -1788,10 +2024,11 @@ public void inspectRow(Object row, int tag) throws HiveException { keyValue = lwKey.get(); } assertTrue(expected.containsKey(keyValue)); + String keyAsString = String.format("%s", key); Object expectedValue = expected.get(keyValue); Object value = fields[1]; Validator validator = getValidator(aggregateName); - validator.validate(expectedValue, new Object[] {value}); + validator.validate(keyAsString, expectedValue, new Object[] {value}); keys.add(keyValue); } @@ -1857,7 +2094,8 @@ public void inspectRow(Object row, int tag) throws HiveException { Object expectedValue = expected.get(keyValue); Object value = fields[1]; Validator validator = getValidator(aggregateName); - validator.validate(expectedValue, new Object[] {value}); + String keyAsString = String.format("%s", key); + validator.validate(keyAsString, expectedValue, new Object[] {value}); keys.add(keyValue); } diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/util/FakeVectorRowBatchFromObjectIterables.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/util/FakeVectorRowBatchFromObjectIterables.java index 35389b1..579f931 100644 --- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/util/FakeVectorRowBatchFromObjectIterables.java +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/util/FakeVectorRowBatchFromObjectIterables.java @@ -168,6 +168,8 @@ public VectorizedRowBatch produceNextBatch() { batch.cols[i].isNull[batch.size] = true; batch.cols[i].noNulls = false; } else { + // Must reset the isNull, could be set from prev batch use + batch.cols[i].isNull[batch.size] = false; columnAssign[i].assign(batch.cols[i], batch.size, value); } }