diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java index 4802489..b86f1ba 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java @@ -104,20 +104,7 @@ import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc; import org.apache.hadoop.hive.ql.plan.GroupByDesc; import org.apache.hadoop.hive.ql.udf.SettableUDF; -import org.apache.hadoop.hive.ql.udf.UDFConv; -import org.apache.hadoop.hive.ql.udf.UDFFromUnixTime; -import org.apache.hadoop.hive.ql.udf.UDFHex; -import org.apache.hadoop.hive.ql.udf.UDFRegExpExtract; -import org.apache.hadoop.hive.ql.udf.UDFRegExpReplace; -import org.apache.hadoop.hive.ql.udf.UDFSign; -import org.apache.hadoop.hive.ql.udf.UDFToBoolean; -import org.apache.hadoop.hive.ql.udf.UDFToByte; -import org.apache.hadoop.hive.ql.udf.UDFToDouble; -import org.apache.hadoop.hive.ql.udf.UDFToFloat; -import org.apache.hadoop.hive.ql.udf.UDFToInteger; -import org.apache.hadoop.hive.ql.udf.UDFToLong; -import org.apache.hadoop.hive.ql.udf.UDFToShort; -import org.apache.hadoop.hive.ql.udf.UDFToString; +import org.apache.hadoop.hive.ql.udf.*; import org.apache.hadoop.hive.ql.udf.generic.*; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.Mode; import org.apache.hadoop.hive.serde2.ByteStream.Output; @@ -359,6 +346,58 @@ public void addProjectionColumn(String columnName, int vectorBatchColIndex) { castExpressionUdfs.add(UDFToShort.class); } + // Set of GenericUDFs which require implicit type casting of decimal parameters. + // Vectorization for mathematical functions currently depends on decimal params automatically + // being converted to the return type (see getImplicitCastExpression()), which is not correct + // in the general case. This set restricts automatic type conversion to just these functions. 
+ private static Set> udfsNeedingImplicitDecimalCast = new HashSet>(); + static { + udfsNeedingImplicitDecimalCast.add(GenericUDFOPPlus.class); + udfsNeedingImplicitDecimalCast.add(GenericUDFOPMinus.class); + udfsNeedingImplicitDecimalCast.add(GenericUDFOPMultiply.class); + udfsNeedingImplicitDecimalCast.add(GenericUDFOPDivide.class); + udfsNeedingImplicitDecimalCast.add(GenericUDFOPMod.class); + udfsNeedingImplicitDecimalCast.add(GenericUDFRound.class); + udfsNeedingImplicitDecimalCast.add(GenericUDFBRound.class); + udfsNeedingImplicitDecimalCast.add(GenericUDFFloor.class); + udfsNeedingImplicitDecimalCast.add(GenericUDFCbrt.class); + udfsNeedingImplicitDecimalCast.add(GenericUDFCeil.class); + udfsNeedingImplicitDecimalCast.add(GenericUDFAbs.class); + udfsNeedingImplicitDecimalCast.add(GenericUDFPosMod.class); + udfsNeedingImplicitDecimalCast.add(GenericUDFPower.class); + udfsNeedingImplicitDecimalCast.add(GenericUDFFactorial.class); + udfsNeedingImplicitDecimalCast.add(GenericUDFOPPositive.class); + udfsNeedingImplicitDecimalCast.add(GenericUDFOPNegative.class); + udfsNeedingImplicitDecimalCast.add(GenericUDFCoalesce.class); + udfsNeedingImplicitDecimalCast.add(GenericUDFElt.class); + udfsNeedingImplicitDecimalCast.add(GenericUDFGreatest.class); + udfsNeedingImplicitDecimalCast.add(GenericUDFLeast.class); + udfsNeedingImplicitDecimalCast.add(UDFSqrt.class); + udfsNeedingImplicitDecimalCast.add(UDFRand.class); + udfsNeedingImplicitDecimalCast.add(UDFLn.class); + udfsNeedingImplicitDecimalCast.add(UDFLog2.class); + udfsNeedingImplicitDecimalCast.add(UDFSin.class); + udfsNeedingImplicitDecimalCast.add(UDFAsin.class); + udfsNeedingImplicitDecimalCast.add(UDFCos.class); + udfsNeedingImplicitDecimalCast.add(UDFAcos.class); + udfsNeedingImplicitDecimalCast.add(UDFLog10.class); + udfsNeedingImplicitDecimalCast.add(UDFLog.class); + udfsNeedingImplicitDecimalCast.add(UDFExp.class); + udfsNeedingImplicitDecimalCast.add(UDFDegrees.class); + 
udfsNeedingImplicitDecimalCast.add(UDFRadians.class); + udfsNeedingImplicitDecimalCast.add(UDFAtan.class); + udfsNeedingImplicitDecimalCast.add(UDFTan.class); + udfsNeedingImplicitDecimalCast.add(UDFOPLongDivide.class); + } + + protected boolean needsImplicitCastForDecimal(GenericUDF udf) { + Class<?> udfClass = udf.getClass(); + if (udf instanceof GenericUDFBridge) { + udfClass = ((GenericUDFBridge) udf).getUdfClass(); + } + return udfsNeedingImplicitDecimalCast.contains(udfClass); + } + protected int getInputColumnIndex(String name) throws HiveException { if (name == null) { throw new HiveException("Null column name"); } @@ -764,24 +803,26 @@ private ExprNodeDesc getImplicitCastExpression(GenericUDF udf, ExprNodeDesc chil } if (castTypeDecimal && !inputTypeDecimal) { - - // Cast the input to decimal - // If castType is decimal, try not to lose precision for numeric types. - castType = updatePrecision(inputTypeInfo, (DecimalTypeInfo) castType); - GenericUDFToDecimal castToDecimalUDF = new GenericUDFToDecimal(); - castToDecimalUDF.setTypeInfo(castType); - List<ExprNodeDesc> children = new ArrayList<ExprNodeDesc>(); - children.add(child); - ExprNodeDesc desc = new ExprNodeGenericFuncDesc(castType, castToDecimalUDF, children); - return desc; + if (needsImplicitCastForDecimal(udf)) { + // Cast the input to decimal + // If castType is decimal, try not to lose precision for numeric types. 
+ castType = updatePrecision(inputTypeInfo, (DecimalTypeInfo) castType); + GenericUDFToDecimal castToDecimalUDF = new GenericUDFToDecimal(); + castToDecimalUDF.setTypeInfo(castType); + List<ExprNodeDesc> children = new ArrayList<ExprNodeDesc>(); + children.add(child); + ExprNodeDesc desc = new ExprNodeGenericFuncDesc(castType, castToDecimalUDF, children); + return desc; + } } else if (!castTypeDecimal && inputTypeDecimal) { - - // Cast decimal input to returnType - GenericUDF genericUdf = getGenericUDFForCast(castType); - List<ExprNodeDesc> children = new ArrayList<ExprNodeDesc>(); - children.add(child); - ExprNodeDesc desc = new ExprNodeGenericFuncDesc(castType, genericUdf, children); - return desc; + if (needsImplicitCastForDecimal(udf)) { + // Cast decimal input to returnType + GenericUDF genericUdf = getGenericUDFForCast(castType); + List<ExprNodeDesc> children = new ArrayList<ExprNodeDesc>(); + children.add(child); + ExprNodeDesc desc = new ExprNodeGenericFuncDesc(castType, genericUdf, children); + return desc; + } } else { // Casts to exact types including long to double etc. are needed in some special cases. 
diff --git a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorizationContext.java b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorizationContext.java index bb37a04..9fcb392 100644 --- a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorizationContext.java +++ b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorizationContext.java @@ -31,11 +31,13 @@ import org.apache.hadoop.hive.ql.exec.vector.expressions.ColAndCol; import org.apache.hadoop.hive.ql.exec.vector.expressions.ColOrCol; import org.apache.hadoop.hive.ql.exec.vector.expressions.DoubleColumnInList; +import org.apache.hadoop.hive.ql.exec.vector.expressions.DynamicValueVectorExpression; import org.apache.hadoop.hive.ql.exec.vector.expressions.FilterExprAndExpr; import org.apache.hadoop.hive.ql.exec.vector.expressions.FilterExprOrExpr; import org.apache.hadoop.hive.ql.exec.vector.expressions.FuncLogWithBaseDoubleToDouble; import org.apache.hadoop.hive.ql.exec.vector.expressions.FuncLogWithBaseLongToDouble; import org.apache.hadoop.hive.ql.exec.vector.expressions.FuncPowerDoubleToDouble; +import org.apache.hadoop.hive.ql.exec.vector.expressions.IdentityExpression; import org.apache.hadoop.hive.ql.exec.vector.expressions.IfExprCharScalarStringGroupColumn; import org.apache.hadoop.hive.ql.exec.vector.expressions.IfExprDoubleColumnDoubleColumn; import org.apache.hadoop.hive.ql.exec.vector.expressions.IfExprLongColumnLongColumn; @@ -66,6 +68,7 @@ import org.apache.hadoop.hive.ql.exec.vector.expressions.StringLower; import org.apache.hadoop.hive.ql.exec.vector.expressions.StringUpper; import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression; +import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorInBloomFilterColDynamicValue; import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorUDFUnixTimeStampDate; import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorUDFUnixTimeStampTimestamp; import 
org.apache.hadoop.hive.ql.exec.vector.expressions.FilterStringColumnInList; @@ -110,9 +113,11 @@ import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.LongColUnaryMinus; import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.LongScalarSubtractLongColumn; import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.plan.DynamicValue; import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc; import org.apache.hadoop.hive.ql.plan.ExprNodeConstantDesc; import org.apache.hadoop.hive.ql.plan.ExprNodeDesc; +import org.apache.hadoop.hive.ql.plan.ExprNodeDynamicValueDesc; import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc; import org.apache.hadoop.hive.ql.udf.UDFLog; import org.apache.hadoop.hive.ql.udf.UDFSin; @@ -123,6 +128,7 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDFBridge; import org.apache.hadoop.hive.ql.udf.generic.GenericUDFIf; import org.apache.hadoop.hive.ql.udf.generic.GenericUDFIn; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDFInBloomFilter; import org.apache.hadoop.hive.ql.udf.generic.GenericUDFLTrim; import org.apache.hadoop.hive.ql.udf.generic.GenericUDFLower; import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPAnd; @@ -1584,4 +1590,35 @@ public void testSIMDNotEqual() { b = 1; assertEquals(a != b ? 
1 : 0, ((a - b) ^ (b - a)) >>> 63); } + + @Test + public void testInBloomFilter() throws Exception { + // Setup InBloomFilter() UDF + ExprNodeColumnDesc colExpr = new ExprNodeColumnDesc(TypeInfoFactory.getDecimalTypeInfo(10, 5), "a", "table", false); + ExprNodeDesc bfExpr = new ExprNodeDynamicValueDesc(new DynamicValue("id1", TypeInfoFactory.binaryTypeInfo)); + + ExprNodeGenericFuncDesc inBloomFilterExpr = new ExprNodeGenericFuncDesc(); + GenericUDF inBloomFilterUdf = new GenericUDFInBloomFilter(); + inBloomFilterExpr.setTypeInfo(TypeInfoFactory.booleanTypeInfo); + inBloomFilterExpr.setGenericUDF(inBloomFilterUdf); + List<ExprNodeDesc> children1 = new ArrayList<ExprNodeDesc>(2); + children1.add(colExpr); + children1.add(bfExpr); + inBloomFilterExpr.setChildren(children1); + + // Setup VectorizationContext + List<String> columns = new ArrayList<String>(); + columns.add("b"); + columns.add("a"); + VectorizationContext vc = new VectorizationContext("name", columns); + + // Create vectorized expr + VectorExpression ve = vc.getVectorExpression(inBloomFilterExpr, VectorExpressionDescriptor.Mode.FILTER); + Assert.assertEquals(VectorInBloomFilterColDynamicValue.class, ve.getClass()); + VectorInBloomFilterColDynamicValue vectorizedInBloomFilterExpr = (VectorInBloomFilterColDynamicValue) ve; + VectorExpression[] children = vectorizedInBloomFilterExpr.getChildExpressions(); + // VectorInBloomFilterColDynamicValue should have all of the necessary information to vectorize. + // Should be no need for child vector expressions, which would imply casting/conversion. + Assert.assertNull(children); + } }