diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java index 03cf85d..880a8ac 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java @@ -149,8 +149,11 @@ import org.apache.hadoop.hive.serde.serdeConstants; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils.PrimitiveGrouping; import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils; @@ -956,6 +959,59 @@ public static int matchCost(TypeInfo argumentPassed, } /** + * Given a set of candidate methods and list of argument types, try to + * select the best candidate based on how close the passed argument types are + * to the candidate argument types. + * For a varchar argument, we would prefer evaluate(string) over evaluate(double). + * @param udfMethods list of candidate methods + * @param argumentsPassed list of argument types to match to the candidate methods + */ + static void filterMethodsByTypeAffinity(List udfMethods, List argumentsPassed) { + if (udfMethods.size() > 1) { + // Prefer methods with a closer signature based on the primitive grouping of each argument. + // Score each method based on its similarity to the passed argument types. + int currentScore = 0; + int bestMatchScore = 0; + Method bestMatch = null; + for (Method m: udfMethods) { + currentScore = 0; + List argumentsAccepted = + TypeInfoUtils.getParameterTypeInfos(m, argumentsPassed.size()); + Iterator argsPassedIter = argumentsPassed.iterator(); + for (TypeInfo acceptedType : argumentsAccepted) { + // Check the affinity of the argument passed in with the accepted argument, + // based on the PrimitiveGrouping + TypeInfo passedType = argsPassedIter.next(); + if (acceptedType.getCategory() == Category.PRIMITIVE + && passedType.getCategory() == Category.PRIMITIVE) { + PrimitiveGrouping acceptedPg = PrimitiveObjectInspectorUtils.getPrimitiveGrouping( + ((PrimitiveTypeInfo) acceptedType).getPrimitiveCategory()); + PrimitiveGrouping passedPg = PrimitiveObjectInspectorUtils.getPrimitiveGrouping( + ((PrimitiveTypeInfo) passedType).getPrimitiveCategory()); + if (acceptedPg == passedPg) { + // The passed argument matches somewhat closely with an accepted argument + ++currentScore; + } + } + } + // Check if the score for this method is any better relative to others + if (currentScore > bestMatchScore) { + bestMatchScore = currentScore; + bestMatch = m; + } else if (currentScore == bestMatchScore) { + bestMatch = null; // no longer a best match if more than one. + } + } + + if (bestMatch != null) { + // Found a best match during this processing, use it. + udfMethods.clear(); + udfMethods.add(bestMatch); + } + } + } + + /** * Gets the closest matching method corresponding to the argument list from a * list of methods. * @@ -1025,6 +1081,13 @@ public static Method getMethodInternal(Class udfClass, List mlist, bo // No matching methods found throw new NoMatchingMethodException(udfClass, argumentsPassed, mlist); } + + if (udfMethods.size() > 1) { + // First try selecting methods based on the type affinity of the arguments passed + // to the candidate method arguments. + filterMethodsByTypeAffinity(udfMethods, argumentsPassed); + } + if (udfMethods.size() > 1) { // if the only difference is numeric types, pick the method diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/TestFunctionRegistry.java ql/src/test/org/apache/hadoop/hive/ql/exec/TestFunctionRegistry.java index 3875b5d..27a9832 100644 --- ql/src/test/org/apache/hadoop/hive/ql/exec/TestFunctionRegistry.java +++ ql/src/test/org/apache/hadoop/hive/ql/exec/TestFunctionRegistry.java @@ -18,7 +18,9 @@ package org.apache.hadoop.hive.ql.exec; +import java.lang.reflect.Type; import java.lang.reflect.Method; +import java.util.ArrayList; import java.util.LinkedList; import java.util.List; @@ -33,6 +35,7 @@ import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; import org.apache.hadoop.io.BytesWritable; import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; public class TestFunctionRegistry extends TestCase { @@ -45,6 +48,11 @@ public void one(IntWritable x, IntWritable y) {} public void mismatch(DateWritable x, HiveDecimalWritable y) {} public void mismatch(TimestampWritable x, HiveDecimalWritable y) {} public void mismatch(BytesWritable x, DoubleWritable y) {} + public void typeaffinity1(DateWritable x) {} + public void typeaffinity1(DoubleWritable x) {}; + public void typeaffinity1(Text x) {} + public void typeaffinity2(IntWritable x) {} + public void typeaffinity2(DoubleWritable x) {} } @Override @@ -64,6 +72,52 @@ public void testImplicitConversion() { implicit(TypeInfoFactory.timestampTypeInfo, TypeInfoFactory.decimalTypeInfo, false); } + private static List getMethods(Class udfClass, String methodName) { + List mlist = new ArrayList(); + + for (Method m : udfClass.getMethods()) { + if (m.getName().equals(methodName)) { + mlist.add(m); + } + } + return mlist; + } + + private void typeAffinity(String methodName, TypeInfo inputType, + int expectedNumFoundMethods, Class expectedFoundType) { + List mlist = getMethods(TestUDF.class, methodName); + assertEquals(true, 1 < mlist.size()); + List inputTypes = new ArrayList(); + inputTypes.add(inputType); + + // narrow down the possible choices based on type affinity + FunctionRegistry.filterMethodsByTypeAffinity(mlist, inputTypes); + assertEquals(expectedNumFoundMethods, mlist.size()); + if (expectedNumFoundMethods == 1) { + assertEquals(expectedFoundType, mlist.get(0).getParameterTypes()[0]); + } + } + + public void testTypeAffinity() { + // Prefer numeric type arguments over other method signatures + typeAffinity("typeaffinity1", TypeInfoFactory.shortTypeInfo, 1, DoubleWritable.class); + typeAffinity("typeaffinity1", TypeInfoFactory.intTypeInfo, 1, DoubleWritable.class); + typeAffinity("typeaffinity1", TypeInfoFactory.floatTypeInfo, 1, DoubleWritable.class); + + // Prefer date type arguments over other method signatures + typeAffinity("typeaffinity1", TypeInfoFactory.dateTypeInfo, 1, DateWritable.class); + typeAffinity("typeaffinity1", TypeInfoFactory.timestampTypeInfo, 1, DateWritable.class); + + // String type affinity + typeAffinity("typeaffinity1", TypeInfoFactory.stringTypeInfo, 1, Text.class); + + // Type affinity does not help when multiple methods have the same type affinity. + typeAffinity("typeaffinity2", TypeInfoFactory.shortTypeInfo, 2, null); + + // Type affinity does not help when type affinity does not match input args + typeAffinity("typeaffinity2", TypeInfoFactory.dateTypeInfo, 2, null); + } + private void verify(Class udf, String name, TypeInfo ta, TypeInfo tb, Class a, Class b, boolean throwException) { List args = new LinkedList(); diff --git serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/primitive/PrimitiveObjectInspectorUtils.java serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/primitive/PrimitiveObjectInspectorUtils.java index 772eb43..7b8f947 100644 --- serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/primitive/PrimitiveObjectInspectorUtils.java +++ serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/primitive/PrimitiveObjectInspectorUtils.java @@ -952,6 +952,44 @@ public static Timestamp getTimestamp(Object o, PrimitiveObjectInspector oi) { return t == null ? null : t.primitiveJavaClass; } + /** + * Provide a general grouping for each primitive data type. + */ + public static enum PrimitiveGrouping { + NUMERIC_GROUP, STRING_GROUP, BOOLEAN_GROUP, DATE_GROUP, BINARY_GROUP, UNKNOWN_GROUP + }; + + /** + * Based on the PrimitiveCategory of a type, return the PrimitiveGrouping + * that the PrimitiveCategory belongs to (numeric, string, date, etc). + * @param primitiveCategory Primitive category of the type + * @return PrimitveGrouping corresponding to the PrimitiveCategory, + * or UNKNOWN_GROUP if the type does not match to a grouping. + */ + public static PrimitiveGrouping getPrimitiveGrouping(PrimitiveCategory primitiveCategory) { + switch (primitiveCategory) { + case BYTE: + case SHORT: + case INT: + case LONG: + case FLOAT: + case DOUBLE: + case DECIMAL: + return PrimitiveGrouping.NUMERIC_GROUP; + case STRING: + return PrimitiveGrouping.STRING_GROUP; + case BOOLEAN: + return PrimitiveGrouping.BOOLEAN_GROUP; + case TIMESTAMP: + case DATE: + return PrimitiveGrouping.DATE_GROUP; + case BINARY: + return PrimitiveGrouping.BINARY_GROUP; + default: + return PrimitiveGrouping.UNKNOWN_GROUP; + } + } + private PrimitiveObjectInspectorUtils() { // prevent instantiation } diff --git serde/src/test/org/apache/hadoop/hive/serde2/objectinspector/primitive/TestPrimitiveObjectInspectorUtils.java serde/src/test/org/apache/hadoop/hive/serde2/objectinspector/primitive/TestPrimitiveObjectInspectorUtils.java new file mode 100644 index 0000000..bbf8d6e --- /dev/null +++ serde/src/test/org/apache/hadoop/hive/serde2/objectinspector/primitive/TestPrimitiveObjectInspectorUtils.java @@ -0,0 +1,45 @@ +package org.apache.hadoop.hive.serde2.objectinspector.primitive; + +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils.PrimitiveGrouping; + +import junit.framework.TestCase; + +public class TestPrimitiveObjectInspectorUtils extends TestCase { + + public void testGetPrimitiveGrouping() { + assertEquals(PrimitiveGrouping.NUMERIC_GROUP, + PrimitiveObjectInspectorUtils.getPrimitiveGrouping(PrimitiveCategory.BYTE)); + assertEquals(PrimitiveGrouping.NUMERIC_GROUP, + PrimitiveObjectInspectorUtils.getPrimitiveGrouping(PrimitiveCategory.SHORT)); + assertEquals(PrimitiveGrouping.NUMERIC_GROUP, + PrimitiveObjectInspectorUtils.getPrimitiveGrouping(PrimitiveCategory.INT)); + assertEquals(PrimitiveGrouping.NUMERIC_GROUP, + PrimitiveObjectInspectorUtils.getPrimitiveGrouping(PrimitiveCategory.LONG)); + assertEquals(PrimitiveGrouping.NUMERIC_GROUP, + PrimitiveObjectInspectorUtils.getPrimitiveGrouping(PrimitiveCategory.FLOAT)); + assertEquals(PrimitiveGrouping.NUMERIC_GROUP, + PrimitiveObjectInspectorUtils.getPrimitiveGrouping(PrimitiveCategory.DOUBLE)); + assertEquals(PrimitiveGrouping.NUMERIC_GROUP, + PrimitiveObjectInspectorUtils.getPrimitiveGrouping(PrimitiveCategory.DECIMAL)); + + assertEquals(PrimitiveGrouping.STRING_GROUP, + PrimitiveObjectInspectorUtils.getPrimitiveGrouping(PrimitiveCategory.STRING)); + + assertEquals(PrimitiveGrouping.DATE_GROUP, + PrimitiveObjectInspectorUtils.getPrimitiveGrouping(PrimitiveCategory.DATE)); + assertEquals(PrimitiveGrouping.DATE_GROUP, + PrimitiveObjectInspectorUtils.getPrimitiveGrouping(PrimitiveCategory.TIMESTAMP)); + + assertEquals(PrimitiveGrouping.BOOLEAN_GROUP, + PrimitiveObjectInspectorUtils.getPrimitiveGrouping(PrimitiveCategory.BOOLEAN)); + + assertEquals(PrimitiveGrouping.BINARY_GROUP, + PrimitiveObjectInspectorUtils.getPrimitiveGrouping(PrimitiveCategory.BINARY)); + + assertEquals(PrimitiveGrouping.UNKNOWN_GROUP, + PrimitiveObjectInspectorUtils.getPrimitiveGrouping(PrimitiveCategory.UNKNOWN)); + assertEquals(PrimitiveGrouping.UNKNOWN_GROUP, + PrimitiveObjectInspectorUtils.getPrimitiveGrouping(PrimitiveCategory.VOID)); + } +}