Index: ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestConstantVectorExpression.java
===================================================================
--- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestConstantVectorExpression.java	(revision 1580023)
+++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestConstantVectorExpression.java	(working copy)
@@ -23,7 +23,9 @@
 
 import java.util.Arrays;
 
+import org.apache.hadoop.hive.common.type.Decimal128;
 import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector;
 import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector;
 import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector;
 import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
@@ -39,27 +41,57 @@
   public void testConstantExpression() {
     ConstantVectorExpression longCve = new ConstantVectorExpression(0, 17);
     ConstantVectorExpression doubleCve = new ConstantVectorExpression(1, 17.34);
-    ConstantVectorExpression bytesCve = new ConstantVectorExpression(2, "alpha".getBytes());
+    String str = "alpha";
+    ConstantVectorExpression bytesCve = new ConstantVectorExpression(2, str.getBytes());
+    Decimal128 decVal = new Decimal128(25.8, (short) 1);
+    ConstantVectorExpression decimalCve = new ConstantVectorExpression(3, decVal);
     int size = 20;
-    VectorizedRowBatch vrg = VectorizedRowGroupGenUtil.getVectorizedRowBatch(size, 3, 0);
+    VectorizedRowBatch vrg = VectorizedRowGroupGenUtil.getVectorizedRowBatch(size, 4, 0);
     LongColumnVector lcv = (LongColumnVector) vrg.cols[0];
     DoubleColumnVector dcv = new DoubleColumnVector(size);
     BytesColumnVector bcv = new BytesColumnVector(size);
+    DecimalColumnVector dv = new DecimalColumnVector(5, 1);
     vrg.cols[1] = dcv;
     vrg.cols[2] = bcv;
+    vrg.cols[3] = dv;
 
     longCve.evaluate(vrg);
     doubleCve.evaluate(vrg);
-    bytesCve.evaluate(vrg);
-
+    bytesCve.evaluate(vrg);
+    decimalCve.evaluate(vrg);
     assertTrue(lcv.isRepeating);
     assertTrue(dcv.isRepeating);
     assertTrue(bcv.isRepeating);
     assertEquals(17, lcv.vector[0]);
     assertTrue(17.34 == dcv.vector[0]);
-    assertTrue(Arrays.equals("alpha".getBytes(), bcv.vector[0]));
+
+    byte[] alphaBytes = "alpha".getBytes();
+    assertTrue(bcv.length[0] == alphaBytes.length);
+    assertTrue(sameFirstKBytes(alphaBytes, bcv.vector[0], alphaBytes.length));
+    // Evaluation of the bytes Constant Vector Expression after the vector is
+    // modified.
+    ((BytesColumnVector) (vrg.cols[2])).vector[0] = "beta".getBytes();
+    bytesCve.evaluate(vrg);
+    assertTrue(bcv.length[0] == alphaBytes.length);
+    assertTrue(sameFirstKBytes(alphaBytes, bcv.vector[0], alphaBytes.length));
+
+    assertTrue(25.8 == dv.vector[0].doubleValue());
+    // Evaluation of the decimal Constant Vector Expression after the vector is
+    // modified.
+    ((DecimalColumnVector) (vrg.cols[3])).vector[0] = new Decimal128(39.7, (short) 1);
+    decimalCve.evaluate(vrg);
+    assertTrue(25.8 == dv.vector[0].doubleValue());
   }
+
+  private boolean sameFirstKBytes(byte[] o1, byte[] o2, int k) {
+    for (int i = 0; i != k; i++) {
+      if (o1[i] != o2[i]) {
+        return false;
+      }
+    }
+    return true;
+  }
 }
Index: ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorizationContext.java
===================================================================
--- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorizationContext.java	(revision 1580023)
+++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorizationContext.java	(working copy)
@@ -29,6 +29,7 @@
 
 import junit.framework.Assert;
 
+import org.apache.hadoop.hive.common.type.HiveDecimal;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.ColAndCol;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.ColOrCol;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.DoubleColumnInList;
@@ -124,8 +125,10 @@
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFPower;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFRound;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPPlus;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFToDecimal;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFToUnixTimeStamp;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFTimestamp;
+import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
 import org.junit.Test;
 
@@ -1185,4 +1188,25 @@
     ve = vc.getVectorExpression(exprDesc);
     assertTrue(ve instanceof IfExprStringScalarStringColumn);
   }
+
+  @Test
+  public void testFoldConstantsForUnaryExpression() throws HiveException {
+    ExprNodeConstantDesc constDesc = new ExprNodeConstantDesc(new Integer(1));
+    GenericUDFToDecimal udf = new GenericUDFToDecimal();
+    udf.setTypeInfo(new DecimalTypeInfo(5, 2));
+    List<ExprNodeDesc> children = new ArrayList<ExprNodeDesc>();
+    children.add(constDesc);
+    ExprNodeGenericFuncDesc exprDesc = new ExprNodeGenericFuncDesc();
+    exprDesc.setGenericUDF(udf);
+    exprDesc.setChildren(children);
+    exprDesc.setTypeInfo(new DecimalTypeInfo(5, 2));
+    Map<String, Integer> columnMap = new HashMap<String, Integer>();
+    columnMap.put("col1", 1);
+    VectorizationContext vc = new VectorizationContext(columnMap, 1);
+    ExprNodeDesc constFoldNodeDesc = vc.foldConstantsForUnaryExpression(exprDesc);
+    assertTrue(constFoldNodeDesc instanceof ExprNodeConstantDesc);
+    assertTrue(((HiveDecimal)
+        (((ExprNodeConstantDesc) constFoldNodeDesc).getValue())).toString().equals("1"));
+  }
+
 }
Index: ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/ConstantVectorExpression.java
===================================================================
--- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/ConstantVectorExpression.java	(revision 1580023)
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/ConstantVectorExpression.java	(working copy)
@@ -93,14 +93,15 @@
     BytesColumnVector cv = (BytesColumnVector) vrg.cols[outputColumn];
     cv.isRepeating = true;
     cv.noNulls = true;
-    cv.setRef(0, bytesValue, 0, bytesValueLength);
+    cv.initBuffer();
+    cv.setVal(0, bytesValue, 0, bytesValueLength);
   }
 
   private void evaluateDecimal(VectorizedRowBatch vrg) {
     DecimalColumnVector dcv = (DecimalColumnVector) vrg.cols[outputColumn];
     dcv.isRepeating = true;
     dcv.noNulls = true;
-    dcv.vector[0] = decimalValue;
+    dcv.vector[0].update(decimalValue);
   }
 
   @Override
Index: ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java
===================================================================
--- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java	(revision 1580023)
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java	(working copy)
@@ -96,6 +96,7 @@
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
 import org.apache.hadoop.hive.serde2.typeinfo.*;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption;
 
 /**
  * Context class for vectorization execution.
@@ -307,6 +308,10 @@
     if (ve == null) {
       throw new HiveException("Could not vectorize expression: "+exprDesc.getName());
     }
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("Input Expression = " + exprDesc.getTypeInfo()
+          + ", Vectorized Expression = " + ve.toString());
+    }
     return ve;
   }
@@ -432,6 +437,9 @@
       // If castType is decimal, try not to lose precision for numeric types.
       castType = updatePrecision(inputTypeInfo, (DecimalTypeInfo) castType);
       GenericUDFToDecimal castToDecimalUDF = new GenericUDFToDecimal();
+      castToDecimalUDF.setTypeInfo(new DecimalTypeInfo(
+          HiveDecimalUtils.getPrecisionForType((PrimitiveTypeInfo) castType),
+          HiveDecimalUtils.getScaleForType((PrimitiveTypeInfo) castType)));
       List<ExprNodeDesc> children = new ArrayList<ExprNodeDesc>();
       children.add(child);
       ExprNodeDesc desc = new ExprNodeGenericFuncDesc(castType, castToDecimalUDF, children);
@@ -506,6 +514,9 @@
       genericUdf = new GenericUDFBridge();
       ((GenericUDFBridge) genericUdf).setUdfClassName(udfClass.getClass().getName());
     }
+    if (genericUdf instanceof SettableUDF) {
+      ((SettableUDF) genericUdf).setTypeInfo(castType);
+    }
     return genericUdf;
   }
@@ -593,11 +604,11 @@
    * expression.
    * @throws HiveException
    */
-  private ExprNodeDesc foldConstantsForUnaryExpression(ExprNodeDesc exprDesc) throws HiveException {
+  ExprNodeDesc foldConstantsForUnaryExpression(ExprNodeDesc exprDesc) throws HiveException {
     if (!(exprDesc instanceof ExprNodeGenericFuncDesc)) {
       return exprDesc;
     }
-    
+
     if (exprDesc.getChildren() == null || (exprDesc.getChildren().size() != 1) ||
         (!( exprDesc.getChildren().get(0) instanceof ExprNodeConstantDesc))) {
       return exprDesc;
@@ -605,15 +616,17 @@
     GenericUDF gudf = ((ExprNodeGenericFuncDesc) exprDesc).getGenericUDF();
     if (gudf instanceof GenericUDFOPNegative || gudf instanceof GenericUDFOPPositive
-        || castExpressionUdfs.contains(gudf)
+        || castExpressionUdfs.contains(gudf.getClass())
         || ((gudf instanceof GenericUDFBridge)
             && castExpressionUdfs.contains(((GenericUDFBridge) gudf).getUdfClass()))) {
       ExprNodeEvaluator evaluator = ExprNodeEvaluatorFactory.get(exprDesc);
       ObjectInspector output = evaluator.initialize(null);
       Object constant = evaluator.evaluate(null);
-      Object java = ObjectInspectorUtils.copyToStandardJavaObject(constant, output);
-      return new ExprNodeConstantDesc(java);
-    }
+      Object java = ObjectInspectorUtils.copyToStandardJavaObject(constant, output);
+      return new ExprNodeConstantDesc(TypeInfoUtils.getTypeInfoFromObjectInspector(
+          ObjectInspectorUtils.getStandardObjectInspector(output,
+              ObjectInspectorCopyOption.JAVA)), java);
+    }
     return exprDesc;
   }
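The behavioral core of the ConstantVectorExpression change above is that the constant's bytes are now copied into the column vector's own buffer (initBuffer() followed by setVal()) instead of being shared by reference (setRef()), and the Decimal128 constant is likewise copied with update() rather than aliased; the new assertions in TestConstantVectorExpression exercise exactly this by mutating the vector after evaluation. The standalone sketch below is illustrative only and not part of the patch: the class name and main() harness are hypothetical, and it assumes the Hive ql classes are on the classpath.

// Illustrative sketch (not part of the patch): contrasts setRef(), which stores a
// reference to the caller's array, with initBuffer() + setVal(), which copies the
// bytes into the vector's own buffer, as ConstantVectorExpression now does.
import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector;

public class ConstantBytesCopySketch {
  public static void main(String[] args) {
    byte[] constant = "alpha".getBytes();

    BytesColumnVector byRef = new BytesColumnVector(1);
    byRef.setRef(0, constant, 0, constant.length);   // shares the caller's array

    BytesColumnVector byCopy = new BytesColumnVector(1);
    byCopy.initBuffer();                             // allocate the vector's own buffer
    byCopy.setVal(0, constant, 0, constant.length);  // copy the bytes into that buffer

    constant[0] = 'X';                               // mutate the shared source array

    // byRef now reflects the mutation ("Xlpha"); byCopy still holds "alpha".
    System.out.println(new String(byRef.vector[0], byRef.start[0], byRef.length[0]));
    System.out.println(new String(byCopy.vector[0], byCopy.start[0], byCopy.length[0]));
  }
}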