diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java index 03cf85d..ca038c3 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java @@ -149,8 +149,11 @@ import org.apache.hadoop.hive.serde.serdeConstants; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils.PrimitiveGrouping; import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils; @@ -1025,6 +1028,50 @@ public static Method getMethodInternal(Class udfClass, List mlist, bo // No matching methods found throw new NoMatchingMethodException(udfClass, argumentsPassed, mlist); } + + if (udfMethods.size() > 1) { + // Prefer methods with a closer signature based on the primitive grouping of each argument. + // For a varchar argument, we would prefer evaluate(string) over evaluate(double). + int currentScore = 0; + int bestMatchScore = 0; + Method bestMatch = null; + for (Method m: udfMethods) { + currentScore = 0; + List argumentsAccepted = + TypeInfoUtils.getParameterTypeInfos(m, argumentsPassed.size()); + Iterator argsPassedIter = argumentsPassed.iterator(); + for (TypeInfo acceptedType : argumentsAccepted) { + // Check the affinity of the argument passed in with the accepted argument, + // based on the PrimitiveGrouping + TypeInfo passedType = argsPassedIter.next(); + if (acceptedType.getCategory() == Category.PRIMITIVE + && passedType.getCategory() == Category.PRIMITIVE) { + PrimitiveGrouping acceptedPg = PrimitiveObjectInspectorUtils.getPrimitiveGrouping( + ((PrimitiveTypeInfo) acceptedType).getPrimitiveCategory()); + PrimitiveGrouping passedPg = PrimitiveObjectInspectorUtils.getPrimitiveGrouping( + ((PrimitiveTypeInfo) passedType).getPrimitiveCategory()); + if (acceptedPg == passedPg) { + // The passed argument matches somewhat closely with an accepted argument + ++currentScore; + } + } + } + // Check if the score for this method is any better relative to others + if (currentScore > bestMatchScore) { + bestMatchScore = currentScore; + bestMatch = m; + } else if (currentScore == bestMatchScore) { + bestMatch = null; // no longer a best match if more than one. + } + } + + if (bestMatch != null) { + // Found a best match during this processing, use it. + udfMethods.clear(); + udfMethods.add(bestMatch); + } + } + if (udfMethods.size() > 1) { // if the only difference is numeric types, pick the method diff --git serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/primitive/PrimitiveObjectInspectorUtils.java serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/primitive/PrimitiveObjectInspectorUtils.java index 772eb43..dae43b5 100644 --- serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/primitive/PrimitiveObjectInspectorUtils.java +++ serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/primitive/PrimitiveObjectInspectorUtils.java @@ -952,6 +952,37 @@ public static Timestamp getTimestamp(Object o, PrimitiveObjectInspector oi) { return t == null ? null : t.primitiveJavaClass; } + /** + * Provide a general grouping for each primitive data type. + */ + public static enum PrimitiveGrouping { + NUMERIC_GROUP, STRING_GROUP, BOOLEAN_GROUP, DATE_GROUP, BINARY_GROUP, UNKNOWN_GROUP + }; + + public static PrimitiveGrouping getPrimitiveGrouping(PrimitiveCategory primitiveCategory) { + switch (primitiveCategory) { + case BYTE: + case SHORT: + case INT: + case LONG: + case FLOAT: + case DOUBLE: + case DECIMAL: + return PrimitiveGrouping.NUMERIC_GROUP; + case STRING: + return PrimitiveGrouping.STRING_GROUP; + case BOOLEAN: + return PrimitiveGrouping.BOOLEAN_GROUP; + case TIMESTAMP: + case DATE: + return PrimitiveGrouping.DATE_GROUP; + case BINARY: + return PrimitiveGrouping.BINARY_GROUP; + default: + return PrimitiveGrouping.UNKNOWN_GROUP; + } + } + private PrimitiveObjectInspectorUtils() { // prevent instantiation }