diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/VectorExpressionWriterFactory.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/VectorExpressionWriterFactory.java index d49136f..eeb76d7 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/VectorExpressionWriterFactory.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/VectorExpressionWriterFactory.java @@ -478,12 +478,15 @@ public Object writeValue(Decimal128 value) throws HiveException { @Override public Object setValue(Object field, Decimal128 value) { + if (null == field) { + field = initValue(null); + } return ((SettableHiveDecimalObjectInspector) this.objectInspector).set(field, HiveDecimal.create(value.toBigDecimal())); } @Override - public Object initValue(Object ignored) throws HiveException { + public Object initValue(Object ignored) { return ((SettableHiveDecimalObjectInspector) this.objectInspector).create( HiveDecimal.create(BigDecimal.ZERO)); } diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorExpressionWriters.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorExpressionWriters.java index 3ee7a21..98a6527 100644 --- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorExpressionWriters.java +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorExpressionWriters.java @@ -25,8 +25,11 @@ import junit.framework.Assert; +import org.apache.hadoop.hive.common.type.Decimal128; +import org.apache.hadoop.hive.common.type.HiveDecimal; import org.apache.hadoop.hive.common.type.HiveVarchar; import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector; import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector; import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector; import org.apache.hadoop.hive.ql.exec.vector.TimestampUtils; @@ -36,6 +39,7 @@ import org.apache.hadoop.hive.ql.plan.ExprNodeDesc; import org.apache.hadoop.hive.serde2.io.ByteWritable; import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; import org.apache.hadoop.hive.serde2.io.HiveVarcharWritable; import org.apache.hadoop.hive.serde2.io.ShortWritable; import org.apache.hadoop.hive.serde2.io.TimestampWritable; @@ -44,6 +48,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.StandardStructObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; import org.apache.hadoop.io.BooleanWritable; @@ -79,6 +84,10 @@ private Writable getWritableValue(TypeInfo ti, double value) { return null; } + private Writable getWritableValue(TypeInfo ti, Decimal128 value) { + return new HiveDecimalWritable(HiveDecimal.create(value.toBigDecimal())); + } + private Writable getWritableValue(TypeInfo ti, byte[] value) { if (ti.equals(TypeInfoFactory.stringTypeInfo)) { return new Text(value); @@ -135,7 +144,7 @@ private void testSetterDouble(TypeInfo type) throws HiveException { VectorExpressionWriter vew = getWriter(type); for (int i = 0; i < vectorSize; i++) { - values[i] = vew.initValue(null); + values[i] = null; // setValue() should be able to handle null input values[i] = vew.setValue(values[i], dcv, i); if (values[i] != null) { Writable expected = getWritableValue(type, dcv.vector[i]); @@ -146,6 +155,41 @@ private void testSetterDouble(TypeInfo type) throws HiveException { } } + private void testWriterDecimal(DecimalTypeInfo type) throws HiveException { + DecimalColumnVector dcv = VectorizedRowGroupGenUtil.generateDecimalColumnVector(type, true, false, + this.vectorSize, new Random(10)); + dcv.isNull[2] = true; + VectorExpressionWriter vew = getWriter(type); + for (int i = 0; i < vectorSize; i++) { + Writable w = (Writable) vew.writeValue(dcv, i); + if (w != null) { + Writable expected = getWritableValue(type, dcv.vector[i]); + Assert.assertEquals(expected, w); + } else { + Assert.assertTrue(dcv.isNull[i]); + } + } + } + + private void testSetterDecimal(DecimalTypeInfo type) throws HiveException { + DecimalColumnVector dcv = VectorizedRowGroupGenUtil.generateDecimalColumnVector(type, true, false, + this.vectorSize, new Random(10)); + dcv.isNull[2] = true; + Object[] values = new Object[this.vectorSize]; + + VectorExpressionWriter vew = getWriter(type); + for (int i = 0; i < vectorSize; i++) { + values[i] = null; // setValue() should be able to handle null input + values[i] = vew.setValue(values[i], dcv, i); + if (values[i] != null) { + Writable expected = getWritableValue(type, dcv.vector[i]); + Assert.assertEquals(expected, values[i]); + } else { + Assert.assertTrue(dcv.isNull[i]); + } + } + } + private void testWriterLong(TypeInfo type) throws HiveException { LongColumnVector lcv = VectorizedRowGroupGenUtil.generateLongColumnVector(true, false, vectorSize, new Random(10)); @@ -178,7 +222,7 @@ private void testSetterLong(TypeInfo type) throws HiveException { VectorExpressionWriter vew = getWriter(type); for (int i = 0; i < vectorSize; i++) { - values[i] = vew.initValue(null); + values[i] = null; // setValue() should be able to handle null input values[i] = vew.setValue(values[i], lcv, i); if (values[i] != null) { Writable expected = getWritableValue(type, lcv.vector[i]); @@ -290,7 +334,7 @@ private void testSetterText(TypeInfo type) throws HiveException { Object[] values = new Object[this.vectorSize]; VectorExpressionWriter vew = getWriter(type); for (int i = 0; i < vectorSize; i++) { - values[i] = vew.initValue(null); + values[i] = null; // setValue() should be able to handle null input Writable w = (Writable) vew.setValue(values[i], bcv, i); if (w != null) { byte [] val = new byte[bcv.length[i]]; @@ -327,7 +371,19 @@ public void testVectorExpressionSetterFloat() throws HiveException { public void testVectorExpressionWriterLong() throws HiveException { testWriterLong(TypeInfoFactory.longTypeInfo); } - + + @Test + public void testVectorExpressionWriterDecimal() throws HiveException { + DecimalTypeInfo typeInfo = TypeInfoFactory.getDecimalTypeInfo(38, 18); + testWriterDecimal(typeInfo); + } + + @Test + public void testVectorExpressionSetterDecimal() throws HiveException { + DecimalTypeInfo typeInfo = TypeInfoFactory.getDecimalTypeInfo(38, 18); + testSetterDecimal(typeInfo); + } + @Test public void testVectorExpressionSetterLong() throws HiveException { testSetterLong(TypeInfoFactory.longTypeInfo); diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/util/VectorizedRowGroupGenUtil.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/util/VectorizedRowGroupGenUtil.java index 238d40f..b36b097 100644 --- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/util/VectorizedRowGroupGenUtil.java +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/util/VectorizedRowGroupGenUtil.java @@ -20,9 +20,12 @@ import java.util.Random; +import org.apache.hadoop.hive.common.type.Decimal128; +import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector; import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector; import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector; import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; +import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; public class VectorizedRowGroupGenUtil { @@ -103,6 +106,42 @@ public static DoubleColumnVector generateDoubleColumnVector(boolean nulls, return dcv; } + public static DecimalColumnVector generateDecimalColumnVector(DecimalTypeInfo typeInfo, boolean nulls, + boolean repeating, int size, Random rand) { + DecimalColumnVector dcv = + new DecimalColumnVector(size, typeInfo.precision(), typeInfo.scale()); + + dcv.noNulls = !nulls; + dcv.isRepeating = repeating; + + Decimal128 repeatingValue = new Decimal128(); + do{ + repeatingValue.update(rand.nextDouble(), (short)typeInfo.scale()); + }while(repeatingValue.doubleValue() == 0); + + int nullFrequency = generateNullFrequency(rand); + + for(int i = 0; i < size; i++) { + if(nulls && (repeating || i % nullFrequency == 0)) { + dcv.isNull[i] = true; + dcv.vector[i] = null;//Decimal128.ONE; + + }else { + dcv.isNull[i] = false; + if (repeating) { + dcv.vector[i].update(repeatingValue); + } else { + dcv.vector[i].update(rand.nextDouble(), (short) typeInfo.scale()); + } + + if(dcv.vector[i].doubleValue() == 0) { + i--; + } + } + } + return dcv; + } + private static int generateNullFrequency(Random rand) { return 60 + rand.nextInt(20); }