diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java index 565696d..535e4b3 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java @@ -717,13 +717,15 @@ ExprNodeDesc foldConstantsForUnaryExpression(ExprNodeDesc exprDesc) throws HiveE private VectorExpression getConstantVectorExpression(Object constantValue, TypeInfo typeInfo, Mode mode) throws HiveException { - String type = typeInfo.getTypeName(); + String type = typeInfo.getTypeName(); String colVectorType = getNormalizedTypeName(type); int outCol = -1; if (mode == Mode.PROJECTION) { outCol = ocm.allocateOutputColumn(colVectorType); } - if (decimalTypePattern.matcher(type).matches()) { + if (constantValue == null) { + return new ConstantVectorExpression(outCol, type, true); + } else if (decimalTypePattern.matcher(type).matches()) { VectorExpression ve = new ConstantVectorExpression(outCol, (Decimal128) constantValue); ve.setOutputType(typeInfo.getTypeName()); return ve; diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/ConstantVectorExpression.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/ConstantVectorExpression.java index a1a5584..9fd3853 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/ConstantVectorExpression.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/ConstantVectorExpression.java @@ -41,6 +41,7 @@ private double doubleValue = 0; private byte[] bytesValue = null; private Decimal128 decimalValue = null; + private boolean isNullValue = false; private Type type; private int bytesValueLength = 0; @@ -74,34 +75,58 @@ public ConstantVectorExpression(int outputColumn, Decimal128 value) { this(outputColumn, "decimal"); setDecimalValue(value); } - + + /* + * Support for null constant object + */ + public ConstantVectorExpression(int outputColumn, String typeString, boolean isNull) { + this(outputColumn, typeString); + isNullValue = isNull; + } + private void evaluateLong(VectorizedRowBatch vrg) { LongColumnVector cv = (LongColumnVector) vrg.cols[outputColumn]; cv.isRepeating = true; - cv.noNulls = true; - cv.vector[0] = longValue; + cv.noNulls = !isNullValue; + if (!isNullValue) { + cv.vector[0] = longValue; + } else { + cv.isNull[0] = true; + } } private void evaluateDouble(VectorizedRowBatch vrg) { DoubleColumnVector cv = (DoubleColumnVector) vrg.cols[outputColumn]; cv.isRepeating = true; - cv.noNulls = true; - cv.vector[0] = doubleValue; + cv.noNulls = !isNullValue; + if (!isNullValue) { + cv.vector[0] = doubleValue; + } else { + cv.isNull[0] = true; + } } private void evaluateBytes(VectorizedRowBatch vrg) { BytesColumnVector cv = (BytesColumnVector) vrg.cols[outputColumn]; cv.isRepeating = true; - cv.noNulls = true; + cv.noNulls = !isNullValue; cv.initBuffer(); - cv.setVal(0, bytesValue, 0, bytesValueLength); + if (!isNullValue) { + cv.setVal(0, bytesValue, 0, bytesValueLength); + } else { + cv.isNull[0] = true; + } } private void evaluateDecimal(VectorizedRowBatch vrg) { DecimalColumnVector dcv = (DecimalColumnVector) vrg.cols[outputColumn]; dcv.isRepeating = true; - dcv.noNulls = true; - dcv.vector[0].update(decimalValue); + dcv.noNulls = !isNullValue; + if (!isNullValue) { + dcv.vector[0].update(decimalValue); + } else { + dcv.isNull[0] = true; + } } @Override diff --git a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestConstantVectorExpression.java b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestConstantVectorExpression.java index d2a7816..3b9e3ff 100644 --- a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestConstantVectorExpression.java +++ b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestConstantVectorExpression.java @@ -45,28 +45,36 @@ public void testConstantExpression() { ConstantVectorExpression bytesCve = new ConstantVectorExpression(2, str.getBytes()); Decimal128 decVal = new Decimal128(25.8, (short) 1); ConstantVectorExpression decimalCve = new ConstantVectorExpression(3, decVal); - + ConstantVectorExpression nullCve = new ConstantVectorExpression(4, "string", true); + int size = 20; - VectorizedRowBatch vrg = VectorizedRowGroupGenUtil.getVectorizedRowBatch(size, 4, 0); + VectorizedRowBatch vrg = VectorizedRowGroupGenUtil.getVectorizedRowBatch(size, 5, 0); LongColumnVector lcv = (LongColumnVector) vrg.cols[0]; DoubleColumnVector dcv = new DoubleColumnVector(size); BytesColumnVector bcv = new BytesColumnVector(size); DecimalColumnVector dv = new DecimalColumnVector(5, 1); + BytesColumnVector bcvn = new BytesColumnVector(size); vrg.cols[1] = dcv; vrg.cols[2] = bcv; vrg.cols[3] = dv; + vrg.cols[4] = bcvn; longCve.evaluate(vrg); doubleCve.evaluate(vrg); bytesCve.evaluate(vrg); decimalCve.evaluate(vrg); + nullCve.evaluate(vrg); assertTrue(lcv.isRepeating); assertTrue(dcv.isRepeating); assertTrue(bcv.isRepeating); assertEquals(17, lcv.vector[0]); assertTrue(17.34 == dcv.vector[0]); + assertTrue(bcvn.isRepeating); + assertTrue(bcvn.isNull[0]); + assertTrue(!bcvn.noNulls); + byte[] alphaBytes = "alpha".getBytes(); assertTrue(bcv.length[0] == alphaBytes.length); assertTrue(sameFirstKBytes(alphaBytes, bcv.vector[0], alphaBytes.length));