diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/VectorExpressionWriterFactory.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/VectorExpressionWriterFactory.java
index b041fc5..8379385 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/VectorExpressionWriterFactory.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/VectorExpressionWriterFactory.java
@@ -43,6 +43,7 @@
 import org.apache.hadoop.io.FloatWritable;
 import org.apache.hadoop.io.IntWritable;
 import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.NullWritable;
 import org.apache.hadoop.io.Text;
 
 /**
@@ -112,11 +113,11 @@ public Object writeValue(ColumnVector column, int row) throws HiveException {
       } else if (!lcv.noNulls && !lcv.isRepeating && !lcv.isNull[row]) {
         return writeValue(lcv.vector[row]);
       } else if (!lcv.noNulls && !lcv.isRepeating && lcv.isNull[row]) {
-        return null;
+        return NullWritable.get();
       } else if (!lcv.noNulls && lcv.isRepeating && !lcv.isNull[0]) {
         return writeValue(lcv.vector[0]);
       } else if (!lcv.noNulls && lcv.isRepeating && lcv.isNull[0]) {
-        return null;
+        return NullWritable.get();
       }
       throw new HiveException(
         String.format(
@@ -140,11 +141,11 @@ public Object writeValue(ColumnVector column, int row) throws HiveException {
       } else if (!dcv.noNulls && !dcv.isRepeating && !dcv.isNull[row]) {
         return writeValue(dcv.vector[row]);
       } else if (!dcv.noNulls && !dcv.isRepeating && dcv.isNull[row]) {
-        return null;
+        return NullWritable.get();
       } else if (!dcv.noNulls && dcv.isRepeating && !dcv.isNull[0]) {
         return writeValue(dcv.vector[0]);
       } else if (!dcv.noNulls && dcv.isRepeating && dcv.isNull[0]) {
-        return null;
+        return NullWritable.get();
       }
       throw new HiveException(
         String.format(
@@ -168,11 +169,11 @@ public Object writeValue(ColumnVector column, int row) throws HiveException {
       } else if (!bcv.noNulls && !bcv.isRepeating && !bcv.isNull[row]) {
         return writeValue(bcv.vector[row], bcv.start[row], bcv.length[row]);
       } else if (!bcv.noNulls && !bcv.isRepeating && bcv.isNull[row]) {
-        return null;
+        return NullWritable.get();
       } else if (!bcv.noNulls && bcv.isRepeating && !bcv.isNull[0]) {
         return writeValue(bcv.vector[0], bcv.start[0], bcv.length[0]);
       } else if (!bcv.noNulls && bcv.isRepeating && bcv.isNull[0]) {
-        return null;
+        return NullWritable.get();
       }
       throw new HiveException(
         String.format(
diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java
index 0050ebc..7003b9e 100644
--- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java
+++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java
@@ -43,18 +43,17 @@
 import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc;
 import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
 import org.apache.hadoop.hive.ql.plan.GroupByDesc;
+import org.apache.hadoop.hive.serde2.io.ByteWritable;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
 import org.apache.hadoop.hive.serde2.io.ShortWritable;
 import org.apache.hadoop.hive.serde2.io.TimestampWritable;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
 import org.apache.hadoop.io.BooleanWritable;
-import org.apache.hadoop.io.ByteWritable;
-import org.apache.hadoop.io.BytesWritable;
-import org.apache.hadoop.io.DoubleWritable;
 import org.apache.hadoop.io.FloatWritable;
 import org.apache.hadoop.io.IntWritable;
 import org.apache.hadoop.io.LongWritable;
-import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
-import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
-
+import org.apache.hadoop.io.Text;
 import org.junit.Assert;
 import org.junit.Test;
 
@@ -151,12 +150,12 @@ private static GroupByDesc buildKeyGroupByDesc(
     ArrayList keys = new ArrayList();
     keys.add(keyExp);
     desc.setKeys(keys);
-    
+
     desc.getOutputColumnNames().add("_col1");
 
     return desc;
   }
-  
+
   @Test
   public void testDoubleValueTypeSum() throws HiveException {
     testKeyTypeAggregate(
@@ -168,7 +167,7 @@ public void testDoubleValueTypeSum() throws HiveException {
         Arrays.asList(new Object[]{13.0,null,7.0, 19.0})),
         buildHashMap((byte)1, 20.0, null, 19.0));
   }
-  
+
   @Test
   public void testDoubleValueTypeSumOneKey() throws HiveException {
     testKeyTypeAggregate(
@@ -179,8 +178,8 @@ public void testDoubleValueTypeSumOneKey() throws HiveException {
         Arrays.asList(new Object[]{ 1, 1, 1, 1}),
         Arrays.asList(new Object[]{13.0,null,7.0, 19.0})),
         buildHashMap((byte)1, 39.0));
-  } 
-  
+  }
+
   @Test
   public void testDoubleValueTypeCount() throws HiveException {
     testKeyTypeAggregate(
@@ -192,7 +191,7 @@ public void testDoubleValueTypeCount() throws HiveException {
         Arrays.asList(new Object[]{13.0,null,7.0, 19.0})),
         buildHashMap((byte)1, 2L, null, 1L));
   }
-  
+
   public void testDoubleValueTypeCountOneKey() throws HiveException {
     testKeyTypeAggregate(
         "count",
@@ -202,8 +201,8 @@ public void testDoubleValueTypeCountOneKey() throws HiveException {
         Arrays.asList(new Object[]{ 1, 1, 1, 1}),
         Arrays.asList(new Object[]{13.0,null,7.0, 19.0})),
         buildHashMap((byte)1, 3L));
-  } 
-  
+  }
+
   @Test
   public void testDoubleValueTypeAvg() throws HiveException {
     testKeyTypeAggregate(
@@ -215,7 +214,7 @@ public void testDoubleValueTypeAvg() throws HiveException {
         Arrays.asList(new Object[]{13.0,null,7.0, 19.0})),
         buildHashMap((byte)1, 10.0, null, 19.0));
   }
-  
+
   @Test
   public void testDoubleValueTypeAvgOneKey() throws HiveException {
     testKeyTypeAggregate(
@@ -226,8 +225,8 @@ public void testDoubleValueTypeAvgOneKey() throws HiveException {
         Arrays.asList(new Object[]{ 1, 1, 1, 1}),
         Arrays.asList(new Object[]{13.0,null,7.0, 19.0})),
         buildHashMap((byte)1, 13.0));
-  } 
-  
+  }
+
   @Test
   public void testDoubleValueTypeMin() throws HiveException {
     testKeyTypeAggregate(
@@ -239,7 +238,7 @@ public void testDoubleValueTypeMin() throws HiveException {
         Arrays.asList(new Object[]{13.0,null,7.0, 19.0})),
         buildHashMap((byte)1, 7.0, null, 19.0));
   }
-  
+
   @Test
   public void testDoubleValueTypeMinOneKey() throws HiveException {
     testKeyTypeAggregate(
@@ -251,7 +250,7 @@ public void testDoubleValueTypeMinOneKey() throws HiveException {
         Arrays.asList(new Object[]{13.0,null,7.0, 19.0})),
         buildHashMap((byte)1, 7.0));
   }
-  
+
   @Test
   public void testDoubleValueTypeMax() throws HiveException {
     testKeyTypeAggregate(
@@ -287,7 +286,7 @@ public void testDoubleValueTypeVariance() throws HiveException {
         Arrays.asList(new Object[]{13.0,null,7.0, 19.0})),
         buildHashMap((byte)1, 9.0, null, 0.0));
   }
-  
+
   @Test
   public void testDoubleValueTypeVarianceOneKey() throws HiveException {
     testKeyTypeAggregate(
@@ -298,7 +297,7 @@ public void testDoubleValueTypeVarianceOneKey() throws HiveException {
         Arrays.asList(new Object[]{ 1, 1, 1, 1}),
         Arrays.asList(new Object[]{13.0,null,7.0, 19.0})),
         buildHashMap((byte)1, 24.0));
-  } 
+  }
   @Test
   public void testTinyintKeyTypeAggregate() throws HiveException {
     testKeyTypeAggregate(
@@ -310,7 +309,7 @@ public void testTinyintKeyTypeAggregate() throws HiveException {
         Arrays.asList(new Object[]{13L,null,7L, 19L})),
         buildHashMap((byte)1, 20L, null, 19L));
   }
-  
+
   @Test
   public void testSmallintKeyTypeAggregate() throws HiveException {
     testKeyTypeAggregate(
@@ -321,8 +320,8 @@ public void testSmallintKeyTypeAggregate() throws HiveException {
         Arrays.asList(new Object[]{ 1,null, 1, null}),
         Arrays.asList(new Object[]{13L,null,7L, 19L})),
         buildHashMap((short)1, 20L, null, 19L));
-  } 
-  
+  }
+
   @Test
   public void testIntKeyTypeAggregate() throws HiveException {
     testKeyTypeAggregate(
@@ -334,7 +333,7 @@ public void testIntKeyTypeAggregate() throws HiveException {
         Arrays.asList(new Object[]{13L,null,7L, 19L})),
         buildHashMap((int)1, 20L, null, 19L));
   }
-  
+
   @Test
   public void testBigintKeyTypeAggregate() throws HiveException {
     testKeyTypeAggregate(
@@ -346,7 +345,7 @@ public void testBigintKeyTypeAggregate() throws HiveException {
         Arrays.asList(new Object[]{13L,null,7L, 19L})),
         buildHashMap((long)1L, 20L, null, 19L));
   }
-  
+
   @Test
   public void testBooleanKeyTypeAggregate() throws HiveException {
     testKeyTypeAggregate(
@@ -358,7 +357,7 @@ public void testBooleanKeyTypeAggregate() throws HiveException {
         Arrays.asList(new Object[]{13L,null,7L, 19L})),
         buildHashMap(true, 20L, null, 19L));
   }
-  
+
   @Test
   public void testTimestampKeyTypeAggregate() throws HiveException {
     testKeyTypeAggregate(
@@ -369,8 +368,8 @@ public void testTimestampKeyTypeAggregate() throws HiveException {
         Arrays.asList(new Object[]{new Timestamp(1),null, new Timestamp(1), null}),
         Arrays.asList(new Object[]{13L,null,7L, 19L})),
         buildHashMap(new Timestamp(1), 20L, null, 19L));
-  } 
-  
+  }
+
   @Test
   public void testFloatKeyTypeAggregate() throws HiveException {
     testKeyTypeAggregate(
@@ -382,7 +381,7 @@ public void testFloatKeyTypeAggregate() throws HiveException {
         Arrays.asList(new Object[]{13L,null,7L, 19L})),
         buildHashMap((float)1.0, 20L, null, 19L));
   }
-  
+
   @Test
   public void testDoubleKeyTypeAggregate() throws HiveException {
     testKeyTypeAggregate(
@@ -393,8 +392,8 @@ public void testDoubleKeyTypeAggregate() throws HiveException {
         Arrays.asList(new Object[]{ 1,null, 1, null}),
         Arrays.asList(new Object[]{13L,null,7L, 19L})),
         buildHashMap((double)1.0, 20L, null, 19L));
-  } 
-  
+  }
+
   @Test
   public void testCountStar() throws HiveException {
     testAggregateCountStar(
@@ -1262,18 +1261,18 @@ public void testStdDevSampLongRepeat () throws HiveException {
         1024,
         (double)0);
   }
-  
+
   private void testKeyTypeAggregate(
       String aggregateName,
       FakeVectorRowBatchFromObjectIterables data,
       Map expected) throws HiveException {
-    
+
     Map mapColumnNames = new HashMap();
     mapColumnNames.put("Key", 0);
     mapColumnNames.put("Value", 1);
     VectorizationContext ctx = new VectorizationContext(mapColumnNames, 2);
     Set keys = new HashSet();
-    
+
     AggregationDesc agg = buildAggregationDesc(ctx, aggregateName,
         "Value", TypeInfoFactory.getPrimitiveTypeInfo(data.getTypes()[1]));
     ArrayList aggs = new ArrayList();
@@ -1287,7 +1286,7 @@ private void testKeyTypeAggregate(
     desc.setOutputColumnNames(outputColumnNames);
     desc.setAggregators(aggs);
 
-    ExprNodeDesc keyExp = buildColumnDesc(ctx, "Key", 
+    ExprNodeDesc keyExp = buildColumnDesc(ctx, "Key",
         TypeInfoFactory.getPrimitiveTypeInfo(data.getTypes()[0]));
     ArrayList keysDesc = new ArrayList();
     keysDesc.add(keyExp);
@@ -1338,10 +1337,10 @@ public void inspectRow(Object row, int tag) throws HiveException {
           BooleanWritable bwKey = (BooleanWritable)key;
           keyValue = bwKey.get();
         } else {
-          Assert.fail(String.format("Not implemented key output type %s: %s", 
+          Assert.fail(String.format("Not implemented key output type %s: %s",
              key.getClass().getName(), key));
         }
-        
+
         assertTrue(expected.containsKey(keyValue));
         Object expectedValue = expected.get(keyValue);
         Object value = fields[1];
@@ -1367,8 +1366,8 @@ public void inspectRow(Object row, int tag) throws HiveException {
     List outBatchList = out.getCapturedRows();
     assertNotNull(outBatchList);
     assertEquals(expected.size(), outBatchList.size());
-    assertEquals(expected.size(), keys.size()); 
-  } 
+    assertEquals(expected.size(), keys.size());
+  }
 
 
   public void testAggregateLongRepeats (
@@ -1498,8 +1497,8 @@ public void validate(Object expected, Object result) {
       } else if (arr[0] instanceof LongWritable) {
         LongWritable lw = (LongWritable) arr[0];
         assertEquals((Long) expected, (Long) lw.get());
-      } else if (arr[0] instanceof BytesWritable) {
-        BytesWritable bw = (BytesWritable) arr[0];
-        String sbw = new String(bw.getBytes());
+      } else if (arr[0] instanceof Text) {
+        Text bw = (Text) arr[0];
+        String sbw = bw.toString();
         assertEquals((String) expected, sbw);
       } else if (arr[0] instanceof DoubleWritable) {
@@ -1849,9 +1848,9 @@ public void inspectRow(Object row, int tag) throws HiveException {
           Object key = fields[0];
           String keyValue = null;
           if (null != key) {
-            assertTrue(key instanceof BytesWritable);
-            BytesWritable bwKey = (BytesWritable)key;
-            keyValue = new String(bwKey.get());
+            assertTrue(key instanceof Text);
+            Text bwKey = (Text)key;
+            keyValue = bwKey.toString();
           }
           assertTrue(expected.containsKey(keyValue));
           Object expectedValue = expected.get(keyValue);
diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorExpressionWriters.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorExpressionWriters.java
new file mode 100644
index 0000000..cde5b1f
--- /dev/null
+++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorExpressionWriters.java
@@ -0,0 +1,206 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec.vector.expressions;
+
+
+import java.sql.Timestamp;
+import java.util.Random;
+
+import junit.framework.Assert;
+
+import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.TimestampUtils;
+import org.apache.hadoop.hive.ql.exec.vector.util.VectorizedRowGroupGenUtil;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc;
+import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
+import org.apache.hadoop.hive.serde2.io.ByteWritable;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.io.ShortWritable;
+import org.apache.hadoop.hive.serde2.io.TimestampWritable;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
+import org.apache.hadoop.io.BooleanWritable;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.junit.Test;
+
+public class TestVectorExpressionWriters {
+
+  private final int vectorSize = 5;
+
+  private VectorExpressionWriter getWriter(TypeInfo colTypeInfo) throws HiveException {
+    ExprNodeDesc columnDesc = new ExprNodeColumnDesc();
+    columnDesc.setTypeInfo(colTypeInfo);
+    VectorExpressionWriter vew = VectorExpressionWriterFactory
+        .genVectorExpressionWritable(columnDesc);
+    return vew;
+  }
+
+  private Writable getWritableValue(TypeInfo ti, double value) {
+    if (ti.equals(TypeInfoFactory.floatTypeInfo)) {
+      return new FloatWritable((float) value);
+    } else if (ti.equals(TypeInfoFactory.doubleTypeInfo)) {
+      return new DoubleWritable(value);
+    }
+    return null;
+  }
+
+  private Writable getWritableValue(TypeInfo ti, byte[] value) {
+    if (ti.equals(TypeInfoFactory.stringTypeInfo)) {
+      return new Text(value);
+    }
+    return null;
+  }
+
+  private Writable getWritableValue(TypeInfo ti, long value) {
+    if (ti.equals(TypeInfoFactory.byteTypeInfo)) {
+      return new ByteWritable((byte) value);
+    } else if (ti.equals(TypeInfoFactory.shortTypeInfo)) {
+      return new ShortWritable((short) value);
+    } else if (ti.equals(TypeInfoFactory.intTypeInfo)) {
+      return new IntWritable((int) value);
+    } else if (ti.equals(TypeInfoFactory.longTypeInfo)) {
+      return new LongWritable(value);
+    } else if (ti.equals(TypeInfoFactory.booleanTypeInfo)) {
+      return new BooleanWritable(value != 0);
+    } else if (ti.equals(TypeInfoFactory.timestampTypeInfo)) {
+      Timestamp ts = new Timestamp(value);
+      TimestampUtils.assignTimeInNanoSec(value, ts);
+      TimestampWritable tw = new TimestampWritable(ts);
+      return tw;
+    }
+    return null;
+  }
+
+  private void testWriterDouble(TypeInfo type) throws HiveException {
+    DoubleColumnVector dcv = VectorizedRowGroupGenUtil.generateDoubleColumnVector(true, false,
+        this.vectorSize, new Random(10));
+    dcv.isNull[2] = true;
+    VectorExpressionWriter vew = getWriter(type);
+    for (int i = 0; i < vectorSize; i++) {
+      Writable w = (Writable) vew.writeValue(dcv, i);
+      if (!(w instanceof NullWritable)) {
+        Writable expected = getWritableValue(type, dcv.vector[i]);
+        Assert.assertEquals(expected, w);
+      } else {
+        Assert.assertTrue(dcv.isNull[i]);
+      }
+    }
+  }
+
+  private void testWriterLong(TypeInfo type) throws HiveException {
+    LongColumnVector lcv = VectorizedRowGroupGenUtil.generateLongColumnVector(true, false,
+        vectorSize, new Random(10));
+    lcv.isNull[3] = true;
+    VectorExpressionWriter vew = getWriter(type);
+    for (int i = 0; i < vectorSize; i++) {
+      Writable w = (Writable) vew.writeValue(lcv, i);
+      if (!(w instanceof NullWritable)) {
+        Writable expected = getWritableValue(type, lcv.vector[i]);
+        if (expected instanceof TimestampWritable) {
+          TimestampWritable t1 = (TimestampWritable) expected;
+          TimestampWritable t2 = (TimestampWritable) w;
+          Assert.assertTrue(t1.getNanos() == t2.getNanos());
+          Assert.assertTrue(t1.getSeconds() == t2.getSeconds());
+          continue;
+        }
+        Assert.assertEquals(expected, w);
+      } else {
+        Assert.assertTrue(lcv.isNull[i]);
+      }
+    }
+  }
+
+  private void testWriterBytes(TypeInfo type) throws HiveException {
+    Text t1 = new Text("alpha");
+    Text t2 = new Text("beta");
+    BytesColumnVector bcv = new BytesColumnVector(vectorSize);
+    bcv.noNulls = false;
+    bcv.initBuffer();
+    bcv.setVal(0, t1.getBytes(), 0, t1.getLength());
+    bcv.isNull[1] = true;
+    bcv.setVal(2, t2.getBytes(), 0, t2.getLength());
+    bcv.isNull[3] = true;
+    bcv.setVal(4, t1.getBytes(), 0, t1.getLength());
+    VectorExpressionWriter vew = getWriter(type);
+    for (int i = 0; i < vectorSize; i++) {
+      Writable w = (Writable) vew.writeValue(bcv, i);
+      if (!(w instanceof NullWritable)) {
+        byte[] val = new byte[bcv.length[i]];
+        System.arraycopy(bcv.vector[i], bcv.start[i], val, 0, bcv.length[i]);
+        Writable expected = getWritableValue(type, val);
+        Assert.assertEquals(expected, w);
+      } else {
+        Assert.assertTrue(bcv.isNull[i]);
+      }
+    }
+  }
+
+  @Test
+  public void testVectorExpressionWriterDouble() throws HiveException {
+    testWriterDouble(TypeInfoFactory.doubleTypeInfo);
+  }
+
+  @Test
+  public void testVectorExpressionWriterFloat() throws HiveException {
+    testWriterDouble(TypeInfoFactory.floatTypeInfo);
+  }
+
+  @Test
+  public void testVectorExpressionWriterLong() throws HiveException {
+    testWriterLong(TypeInfoFactory.longTypeInfo);
+  }
+
+  @Test
+  public void testVectorExpressionWriterInt() throws HiveException {
+    testWriterLong(TypeInfoFactory.intTypeInfo);
+  }
+
+  @Test
+  public void testVectorExpressionWriterShort() throws HiveException {
+    testWriterLong(TypeInfoFactory.shortTypeInfo);
+  }
+
+  @Test
+  public void testVectorExpressionWriterBoolean() throws HiveException {
+    testWriterLong(TypeInfoFactory.booleanTypeInfo);
+  }
+
+  @Test
+  public void testVectorExpressionWriterTimestamp() throws HiveException {
+    testWriterLong(TypeInfoFactory.timestampTypeInfo);
+  }
+
+  @Test
+  public void testVectorExpressionWriterByte() throws HiveException {
+    testWriterLong(TypeInfoFactory.byteTypeInfo);
+  }
+
+  @Test
+  public void testVectorExpressionWriterBytes() throws HiveException {
+    testWriterBytes(TypeInfoFactory.stringTypeInfo);
+  }
+}
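
Note on the contract this patch establishes (not part of the patch itself): after the
VectorExpressionWriterFactory change, VectorExpressionWriter.writeValue returns the
NullWritable singleton for a null row instead of a bare Java null, so downstream
consumers can dispatch on the returned Writable without a null guard. A minimal
sketch of that behavior from a caller's point of view; the demo class name and the
two-row column vector are hypothetical, for illustration only, while the factory and
writer calls are the ones exercised by the new test above:

    import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector;
    import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpressionWriter;
    import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpressionWriterFactory;
    import org.apache.hadoop.hive.ql.metadata.HiveException;
    import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc;
    import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
    import org.apache.hadoop.io.NullWritable;

    public class WriteValueNullContractDemo {
      public static void main(String[] args) throws HiveException {
        // Build a writer for a bigint column, the same way the new test does.
        ExprNodeColumnDesc col = new ExprNodeColumnDesc();
        col.setTypeInfo(TypeInfoFactory.longTypeInfo);
        VectorExpressionWriter writer =
            VectorExpressionWriterFactory.genVectorExpressionWritable(col);

        // Two-row vector: row 0 holds a value, row 1 is null.
        LongColumnVector lcv = new LongColumnVector(2);
        lcv.noNulls = false;
        lcv.vector[0] = 42;
        lcv.isNull[1] = true;

        // With this patch, the null row yields the NullWritable singleton
        // rather than a bare Java null, so an instanceof check suffices.
        System.out.println(writer.writeValue(lcv, 0));                       // 42
        System.out.println(writer.writeValue(lcv, 1) == NullWritable.get()); // true
      }
    }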