diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java
index eb8c4c5..ce5c37e 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java
@@ -19,7 +19,7 @@ package org.apache.hadoop.hive.ql.exec.vector;
 import java.lang.reflect.Constructor;
-import java.lang.reflect.InvocationTargetException;
+import java.sql.Date;
 import java.sql.Timestamp;
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -39,11 +39,33 @@ import org.apache.hadoop.hive.ql.exec.FunctionInfo;
 import org.apache.hadoop.hive.ql.exec.FunctionRegistry;
 import org.apache.hadoop.hive.ql.exec.UDF;
-import org.apache.hadoop.hive.ql.exec.vector.TimestampUtils;
-import org.apache.hadoop.hive.ql.exec.vector.VectorExpressionDescriptor.ArgumentType;
 import org.apache.hadoop.hive.ql.exec.vector.VectorExpressionDescriptor.InputExpressionType;
 import org.apache.hadoop.hive.ql.exec.vector.VectorExpressionDescriptor.Mode;
-import org.apache.hadoop.hive.ql.exec.vector.expressions.*;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.CastBooleanToStringViaLongToString;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.CastDecimalToDecimal;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.CastDecimalToDouble;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.CastDecimalToString;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.CastDoubleToDecimal;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.CastLongToDecimal;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.CastLongToString;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.CastStringToDecimal;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.CastTimestampToDecimal;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.ConstantVectorExpression;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.DoubleColumnInList;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.FilterConstantBooleanVectorExpression;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.FilterDoubleColumnInList;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.FilterLongColumnInList;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.FilterStringColumnInList;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.IDoubleInExpr;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.ILongInExpr;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.IStringInExpr;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.IdentityExpression;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.LongColumnInList;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.SelectColumnIsTrue;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.StringColumnInList;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.StringLength;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorCoalesce;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorUDAFAvgDecimal;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorUDAFCount;
@@ -90,12 +112,46 @@ import org.apache.hadoop.hive.ql.plan.ExprNodeConstantDesc;
 import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
 import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc;
-import org.apache.hadoop.hive.ql.udf.*;
-import org.apache.hadoop.hive.ql.udf.generic.*;
+import org.apache.hadoop.hive.ql.udf.SettableUDF;
+import org.apache.hadoop.hive.ql.udf.UDFConv;
+import org.apache.hadoop.hive.ql.udf.UDFHex;
+import org.apache.hadoop.hive.ql.udf.UDFSign;
+import org.apache.hadoop.hive.ql.udf.UDFToBoolean;
+import org.apache.hadoop.hive.ql.udf.UDFToByte;
+import org.apache.hadoop.hive.ql.udf.UDFToDouble;
+import org.apache.hadoop.hive.ql.udf.UDFToFloat;
+import org.apache.hadoop.hive.ql.udf.UDFToInteger;
+import org.apache.hadoop.hive.ql.udf.UDFToLong;
+import org.apache.hadoop.hive.ql.udf.UDFToShort;
+import org.apache.hadoop.hive.ql.udf.UDFToString;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFBaseCompare;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFBetween;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFBridge;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFCase;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFCoalesce;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFIn;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPAnd;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPNegative;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPOr;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPPositive;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFRound;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFTimestamp;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFToBinary;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFToChar;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFToDate;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFToDecimal;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFToUnixTimeStamp;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFToUtcTimestamp;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFToVarchar;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFWhen;
+import org.apache.hadoop.hive.serde2.io.DateWritable;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
-import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
-import org.apache.hadoop.hive.serde2.typeinfo.*;
+import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.HiveDecimalUtils;
+import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
 
 /**
  * Context class for vectorization execution.
@@ -307,6 +363,10 @@ public VectorExpression getVectorExpression(ExprNodeDesc exprDesc, Mode mode) th
     if (ve == null) {
       throw new HiveException("Could not vectorize expression: "+exprDesc.getName());
     }
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("Input Expression = " + exprDesc.getTypeInfo()
+          + ", Vectorized Expression = " + ve.toString());
+    }
     return ve;
   }
@@ -398,7 +458,7 @@ private TypeInfo updatePrecision(TypeInfo inputTypeInfo, DecimalTypeInfo returnT
       return returnType;
     }
     PrimitiveTypeInfo ptinfo = (PrimitiveTypeInfo) inputTypeInfo;
-    int precision = HiveDecimalUtils.getPrecisionForType(ptinfo);
+    int precision = getPrecisionForType(ptinfo);
     int scale = HiveDecimalUtils.getScaleForType(ptinfo);
     return new DecimalTypeInfo(precision, scale);
   }
@@ -432,6 +492,9 @@ private ExprNodeDesc getImplicitCastExpression(GenericUDF udf, ExprNodeDesc chil
       // If castType is decimal, try not to lose precision for numeric types.
       castType = updatePrecision(inputTypeInfo, (DecimalTypeInfo) castType);
       GenericUDFToDecimal castToDecimalUDF = new GenericUDFToDecimal();
+      castToDecimalUDF.setTypeInfo(new DecimalTypeInfo(
+          getPrecisionForType((PrimitiveTypeInfo) castType),
+          HiveDecimalUtils.getScaleForType((PrimitiveTypeInfo) castType)));
       List<ExprNodeDesc> children = new ArrayList<ExprNodeDesc>();
       children.add(child);
       ExprNodeDesc desc = new ExprNodeGenericFuncDesc(castType, castToDecimalUDF, children);
@@ -457,6 +520,13 @@ private ExprNodeDesc getImplicitCastExpression(GenericUDF udf, ExprNodeDesc chil
     }
     return null;
   }
+
+  private int getPrecisionForType(PrimitiveTypeInfo typeInfo) {
+    if (isFloatFamily(typeInfo.getTypeName())) {
+      return HiveDecimal.MAX_PRECISION;
+    }
+    return HiveDecimalUtils.getPrecisionForType(typeInfo);
+  }
 
   private GenericUDF getGenericUDFForCast(TypeInfo castType) throws HiveException {
     UDF udfClass = null;
@@ -506,6 +576,9 @@ private GenericUDF getGenericUDFForCast(TypeInfo castType) throws HiveException
       genericUdf = new GenericUDFBridge();
       ((GenericUDFBridge) genericUdf).setUdfClassName(udfClass.getClass().getName());
     }
+    if (genericUdf instanceof SettableUDF) {
+      ((SettableUDF) genericUdf).setTypeInfo(castType);
+    }
     return genericUdf;
   }
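
Note on the decimal hunks above: they exist so that a synthesized implicit cast to decimal carries a usable precision and scale, and so that float/double inputs, which have no fixed decimal precision of their own, are widened to the maximum rather than truncated. A minimal sketch of that precision rule in plain Java with the Hive types stubbed out (MAX_PRECISION = 38 is assumed to mirror HiveDecimal.MAX_PRECISION, and precisionFor is a hypothetical stand-in for the patched getPrecisionForType):

public class PrecisionRuleSketch {
  // Assumed to mirror HiveDecimal.MAX_PRECISION in this era of Hive.
  static final int MAX_PRECISION = 38;

  static boolean isFloatFamily(String typeName) {
    return typeName.equalsIgnoreCase("double") || typeName.equalsIgnoreCase("float");
  }

  // Floating-point inputs get the widest precision so the implicit decimal
  // cast cannot truncate; exact types keep their declared precision.
  static int precisionFor(String typeName, int declaredPrecision) {
    if (isFloatFamily(typeName)) {
      return MAX_PRECISION;
    }
    return declaredPrecision;
  }

  public static void main(String[] args) {
    System.out.println(precisionFor("double", 0)); // 38
    System.out.println(precisionFor("int", 10));   // 10
  }
}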
@@ -593,27 +666,30 @@ public static boolean isCustomUDF(ExprNodeGenericFuncDesc expr) {
    * expression.
    * @throws HiveException
    */
-  private ExprNodeDesc foldConstantsForUnaryExpression(ExprNodeDesc exprDesc) throws HiveException {
+  ExprNodeDesc foldConstantsForUnaryExpression(ExprNodeDesc exprDesc) throws HiveException {
     if (!(exprDesc instanceof ExprNodeGenericFuncDesc)) {
       return exprDesc;
     }
-
+
     if (exprDesc.getChildren() == null || (exprDesc.getChildren().size() != 1) ||
         (!( exprDesc.getChildren().get(0) instanceof ExprNodeConstantDesc))) {
       return exprDesc;
     }
 
+    ExprNodeConstantDesc encd = (ExprNodeConstantDesc) exprDesc.getChildren().get(0);
+    ObjectInspector childoi = encd.getWritableObjectInspector();
     GenericUDF gudf = ((ExprNodeGenericFuncDesc) exprDesc).getGenericUDF();
+
     if (gudf instanceof GenericUDFOPNegative || gudf instanceof GenericUDFOPPositive
-        || castExpressionUdfs.contains(gudf)
+        || castExpressionUdfs.contains(gudf.getClass())
         || ((gudf instanceof GenericUDFBridge)
             && castExpressionUdfs.contains(((GenericUDFBridge) gudf).getUdfClass()))) {
       ExprNodeEvaluator evaluator = ExprNodeEvaluatorFactory.get(exprDesc);
-      ObjectInspector output = evaluator.initialize(null);
+      ObjectInspector output = evaluator.initialize(childoi);
       Object constant = evaluator.evaluate(null);
-      Object java = ObjectInspectorUtils.copyToStandardJavaObject(constant, output);
-      return new ExprNodeConstantDesc(java);
-    }
+      Object java = ObjectInspectorUtils.copyToStandardJavaObject(constant, output);
+      return new ExprNodeConstantDesc(exprDesc.getTypeInfo(), java);
+    }
 
     return exprDesc;
   }
@@ -754,7 +830,7 @@ private VectorExpression createVectorExpression(Class<?> vectorClass,
         }
         arguments[i] = colIndex;
       } else if (child instanceof ExprNodeConstantDesc) {
-        Object scalarValue = getScalarValue((ExprNodeConstantDesc) child);
+        Object scalarValue = getVectorTypeScalarValue((ExprNodeConstantDesc) child);
         arguments[i] = scalarValue;
       } else {
        throw new HiveException("Cannot handle expression type: " + child.getClass().getSimpleName());
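
The foldConstantsForUnaryExpression hunk above makes two behavioral fixes: the evaluator is now initialized with the constant child's object inspector rather than null, and the folded result keeps the expression's type info instead of collapsing to an untyped Java value. A sketch of that contract using hypothetical stand-in classes, not Hive's ExprNodeDesc API:

import java.math.BigDecimal;

public class FoldSketch {
  // Hypothetical stand-in for ExprNodeConstantDesc: a value plus its type name.
  static class Constant {
    final String typeName;
    final Object value;
    Constant(String typeName, Object value) {
      this.typeName = typeName;
      this.value = value;
    }
  }

  // Folding CAST(1 AS DECIMAL(5,2)): evaluate the cast once at plan time and
  // keep the expression's result type on the folded constant, which is what
  // the patched "new ExprNodeConstantDesc(exprDesc.getTypeInfo(), java)" does.
  static Constant foldCastToDecimal(Constant child, String resultType) {
    BigDecimal folded = new BigDecimal(child.value.toString());
    return new Constant(resultType, folded);
  }

  public static void main(String[] args) {
    Constant c = foldCastToDecimal(new Constant("int", 1), "decimal(5,2)");
    System.out.println(c.typeName + " -> " + c.value); // decimal(5,2) -> 1
  }
}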
@@ -941,8 +1017,8 @@ private VectorExpression getInExpression(List<ExprNodeDesc> childExpr, Mode mode
     }
     expr = createVectorExpression(cl, childExpr.subList(0, 1), Mode.PROJECTION, colTypeInfo);
     ((IDoubleInExpr) expr).setInListValues(inValsD);
-    }
-    
+  }
+
   // Return the desired VectorExpression if found. Otherwise, return null to cause
   // execution to fall back to row mode.
   return expr;
@@ -1071,7 +1147,7 @@ private VectorExpression getCastToDoubleExpression(Class<?> udf, List<ExprNodeDesc> childExpr
     } else {
       cl = FilterLongColumnBetween.class;
     }
-  }
+  }
   return createVectorExpression(cl, childrenAfterNot, Mode.PROJECTION, null);
 }
@@ -1252,6 +1328,14 @@ public static boolean isDatetimeFamily(String resultType) {
     return resultType.equalsIgnoreCase("timestamp") || resultType.equalsIgnoreCase("date");
   }
 
+  public static boolean isTimestampFamily(String resultType) {
+    return resultType.equalsIgnoreCase("timestamp");
+  }
+
+  public static boolean isDateFamily(String resultType) {
+    return resultType.equalsIgnoreCase("date");
+  }
+
   // return true if this is any kind of float
   public static boolean isFloatFamily(String resultType) {
     return resultType.equalsIgnoreCase("double")
@@ -1322,6 +1406,24 @@ private double getNumericScalarAsDouble(ExprNodeDesc constDesc)
     }
     throw new HiveException("Unexpected type when converting to double");
   }
+
+  private Object getVectorTypeScalarValue(ExprNodeConstantDesc constDesc) throws HiveException {
+    String t = constDesc.getTypeInfo().getTypeName();
+    if (isIntFamily(t)) {
+      return getIntFamilyScalarAsLong(constDesc);
+    } else if (isFloatFamily(t)) {
+      return getNumericScalarAsDouble(constDesc);
+    } else if (isDecimalFamily(t)) {
+      return getScalarValue(constDesc);
+    } else if (isTimestampFamily(t)) {
+      return TimestampUtils.getTimeNanoSec((Timestamp) getScalarValue(constDesc));
+    } else if (isDateFamily(t)) {
+      return DateWritable.dateToDays((Date) getScalarValue(constDesc));
+    } else if (isStringFamily(t)) {
+      return getScalarValue(constDesc);
+    }
+    throw new HiveException("Unexpected type for ExprNodeConstantDesc : "+t);
+  }
 
   // Get a timestamp as a long in number of nanos, from a string constant or cast
   private long getTimestampScalar(ExprNodeDesc expr) throws HiveException {
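
getVectorTypeScalarValue, added above, hands constants to vectorized expressions in the representation the column vectors actually store: timestamps as nanosecond longs, dates as epoch-day counts. A self-contained sketch of those two conversions; the nanosecond formula follows my reading of TimestampUtils.getTimeNanoSec, and the day count is a UTC approximation of DateWritable.dateToDays, which additionally corrects for the local time zone:

import java.sql.Date;
import java.sql.Timestamp;

public class ScalarConversionSketch {
  // Timestamp -> nanoseconds since epoch. getTime() already includes the
  // millisecond part of getNanos(), so subtract it before scaling up.
  static long timestampToNanos(Timestamp t) {
    long millis = t.getTime();
    int nanos = t.getNanos();
    return (millis - nanos / 1000000) * 1000000L + nanos;
  }

  // Date -> days since epoch (UTC approximation; floorDiv handles pre-1970).
  static long dateToEpochDays(Date d) {
    return Math.floorDiv(d.getTime(), 86400000L);
  }

  public static void main(String[] args) {
    // 1970-01-02T00:00:00 UTC is exactly one epoch day.
    System.out.println(dateToEpochDays(new Date(86400000L))); // 1
    System.out.println(timestampToNanos(new Timestamp(0L)));  // 0
  }
}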
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/ConstantVectorExpression.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/ConstantVectorExpression.java
index d4c00ab..a1a5584 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/ConstantVectorExpression.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/ConstantVectorExpression.java
@@ -93,14 +93,15 @@ private void evaluateBytes(VectorizedRowBatch vrg) {
     BytesColumnVector cv = (BytesColumnVector) vrg.cols[outputColumn];
     cv.isRepeating = true;
     cv.noNulls = true;
-    cv.setRef(0, bytesValue, 0, bytesValueLength);
+    cv.initBuffer();
+    cv.setVal(0, bytesValue, 0, bytesValueLength);
   }
 
   private void evaluateDecimal(VectorizedRowBatch vrg) {
     DecimalColumnVector dcv = (DecimalColumnVector) vrg.cols[outputColumn];
     dcv.isRepeating = true;
     dcv.noNulls = true;
-    dcv.vector[0] = decimalValue;
+    dcv.vector[0].update(decimalValue);
   }
 
   @Override
diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorizationContext.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorizationContext.java
index 7fa9730..5ebab70 100644
--- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorizationContext.java
+++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorizationContext.java
@@ -29,6 +29,7 @@
 import junit.framework.Assert;
 
+import org.apache.hadoop.hive.common.type.HiveDecimal;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.ColAndCol;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.ColOrCol;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.DoubleColumnInList;
@@ -124,8 +125,10 @@
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFPower;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFRound;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPPlus;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFToDecimal;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFToUnixTimeStamp;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFTimestamp;
+import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
 import org.junit.Test;
@@ -1130,6 +1133,7 @@ public void testIfConditionalExprs() throws HiveException {
     // timestamp scalar/column
     children1.set(2, col3Expr);
+
     ve = vc.getVectorExpression(exprDesc);
     assertTrue(IfExprLongColumnLongColumn.class == ve.getClass()
         || IfExprLongScalarLongColumn.class == ve.getClass());
@@ -1185,4 +1189,25 @@
     ve = vc.getVectorExpression(exprDesc);
     assertTrue(ve instanceof IfExprStringScalarStringColumn);
   }
+
+  @Test
+  public void testFoldConstantsForUnaryExpression() throws HiveException {
+    ExprNodeConstantDesc constDesc = new ExprNodeConstantDesc(new Integer(1));
+    GenericUDFToDecimal udf = new GenericUDFToDecimal();
+    udf.setTypeInfo(new DecimalTypeInfo(5, 2));
+    List<ExprNodeDesc> children = new ArrayList<ExprNodeDesc>();
+    children.add(constDesc);
+    ExprNodeGenericFuncDesc exprDesc = new ExprNodeGenericFuncDesc();
+    exprDesc.setGenericUDF(udf);
+    exprDesc.setChildren(children);
+    exprDesc.setTypeInfo(new DecimalTypeInfo(5, 2));
+    Map<String, Integer> columnMap = new HashMap<String, Integer>();
+    columnMap.put("col1", 1);
+    VectorizationContext vc = new VectorizationContext(columnMap, 1);
+    ExprNodeDesc constFoldNodeDesc = vc.foldConstantsForUnaryExpression(exprDesc);
+    assertTrue(constFoldNodeDesc instanceof ExprNodeConstantDesc);
+    assertTrue(((HiveDecimal)
+        (((ExprNodeConstantDesc) constFoldNodeDesc).getValue())).toString().equals("1"));
+  }
+
 }
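
A subtlety in testFoldConstantsForUnaryExpression above: casting integer 1 to decimal(5,2) is asserted to stringify as "1", not "1.00", i.e. the folded HiveDecimal drops trailing zeros. Assuming HiveDecimal normalizes the way java.math.BigDecimal.stripTrailingZeros does (an analogy on my part, not a statement about its implementation):

import java.math.BigDecimal;

public class DecimalToStringSketch {
  public static void main(String[] args) {
    BigDecimal d = new BigDecimal("1").setScale(2); // the value 1.00
    System.out.println(d.toPlainString());                      // 1.00
    System.out.println(d.stripTrailingZeros().toPlainString()); // 1
  }
}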
diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestConstantVectorExpression.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestConstantVectorExpression.java
index 4321545..d2a7816 100644
--- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestConstantVectorExpression.java
+++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestConstantVectorExpression.java
@@ -23,7 +23,9 @@
 import java.util.Arrays;
 
+import org.apache.hadoop.hive.common.type.Decimal128;
 import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector;
 import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector;
 import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector;
 import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
@@ -39,27 +41,57 @@ public void testConstantExpression() {
     ConstantVectorExpression longCve = new ConstantVectorExpression(0, 17);
     ConstantVectorExpression doubleCve = new ConstantVectorExpression(1, 17.34);
-    ConstantVectorExpression bytesCve = new ConstantVectorExpression(2, "alpha".getBytes());
+    String str = "alpha";
+    ConstantVectorExpression bytesCve = new ConstantVectorExpression(2, str.getBytes());
+    Decimal128 decVal = new Decimal128(25.8, (short) 1);
+    ConstantVectorExpression decimalCve = new ConstantVectorExpression(3, decVal);
     int size = 20;
-    VectorizedRowBatch vrg = VectorizedRowGroupGenUtil.getVectorizedRowBatch(size, 3, 0);
+    VectorizedRowBatch vrg = VectorizedRowGroupGenUtil.getVectorizedRowBatch(size, 4, 0);
     LongColumnVector lcv = (LongColumnVector) vrg.cols[0];
     DoubleColumnVector dcv = new DoubleColumnVector(size);
     BytesColumnVector bcv = new BytesColumnVector(size);
+    DecimalColumnVector dv = new DecimalColumnVector(5, 1);
     vrg.cols[1] = dcv;
     vrg.cols[2] = bcv;
+    vrg.cols[3] = dv;
 
     longCve.evaluate(vrg);
     doubleCve.evaluate(vrg);
-    bytesCve.evaluate(vrg);
-
+    bytesCve.evaluate(vrg);
+    decimalCve.evaluate(vrg);
     assertTrue(lcv.isRepeating);
     assertTrue(dcv.isRepeating);
     assertTrue(bcv.isRepeating);
     assertEquals(17, lcv.vector[0]);
     assertTrue(17.34 == dcv.vector[0]);
-    assertTrue(Arrays.equals("alpha".getBytes(), bcv.vector[0]));
+
+    byte[] alphaBytes = "alpha".getBytes();
+    assertTrue(bcv.length[0] == alphaBytes.length);
+    assertTrue(sameFirstKBytes(alphaBytes, bcv.vector[0], alphaBytes.length));
+    // Evaluation of the bytes Constant Vector Expression after the vector is
+    // modified.
+    ((BytesColumnVector) (vrg.cols[2])).vector[0] = "beta".getBytes();
+    bytesCve.evaluate(vrg);
+    assertTrue(bcv.length[0] == alphaBytes.length);
+    assertTrue(sameFirstKBytes(alphaBytes, bcv.vector[0], alphaBytes.length));
+
+    assertTrue(25.8 == dv.vector[0].doubleValue());
+    // Evaluation of the decimal Constant Vector Expression after the vector is
+    // modified.
+    ((DecimalColumnVector) (vrg.cols[3])).vector[0] = new Decimal128(39.7, (short) 1);
+    decimalCve.evaluate(vrg);
+    assertTrue(25.8 == dv.vector[0].doubleValue());
+  }
+
+  private boolean sameFirstKBytes(byte[] o1, byte[] o2, int k) {
+    for (int i = 0; i != k; i++) {
+      if (o1[i] != o2[i]) {
+        return false;
+      }
+    }
+    return true;
+  }
 }
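
The re-evaluation checks in TestConstantVectorExpression pin down why ConstantVectorExpression switched from setRef to initBuffer/setVal for bytes and from reference assignment to update() for decimals: an output slot must not alias a mutable object that the expression also treats as its constant, or a downstream in-place write corrupts every later batch. A minimal plain-Java demonstration of that aliasing hazard, with no Hive types involved:

import java.util.Arrays;

public class AliasingDemo {
  public static void main(String[] args) {
    byte[] constant = "alpha".getBytes();

    // setRef-style: the output slot aliases the constant's backing array,
    // so any in-place write through the slot corrupts the "constant".
    byte[] slotByRef = constant;
    slotByRef[0] = 'X';
    System.out.println(new String(constant)); // Xlpha -- corrupted

    // setVal-style: the slot owns a copy; the constant survives writes.
    byte[] constant2 = "alpha".getBytes();
    byte[] slotByVal = Arrays.copyOf(constant2, constant2.length);
    slotByVal[0] = 'X';
    System.out.println(new String(constant2)); // alpha -- intact
  }
}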