diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java index ca038c3..eb325a1 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java @@ -23,7 +23,7 @@ import java.net.URL; import java.util.ArrayList; import java.util.Collections; -import java.util.HashMap; +import java.util.EnumMap; import java.util.HashSet; import java.util.Iterator; import java.util.LinkedHashMap; @@ -149,6 +149,7 @@ import org.apache.hadoop.hive.serde.serdeConstants; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils.PrimitiveGrouping; import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; @@ -600,24 +601,53 @@ public static FunctionInfo getFunctionInfo(String functionName) { return synonyms; } - static Map numericTypes = new HashMap(); - static List numericTypeList = new ArrayList(); + // The ordering of types here is used to determine which numeric types + // are common/convertible to one another. Probably better to rely on the + // ordering explicitly defined here than to assume that the enum values + // that were arbitrarily assigned in PrimitiveCategory work for our purposes. + static EnumMap numericTypes = + new EnumMap(PrimitiveCategory.class); + static List numericTypeList = new ArrayList(); - static void registerNumericType(String typeName, int level) { - TypeInfo t = TypeInfoFactory.getPrimitiveTypeInfo(typeName); - numericTypeList.add(t); - numericTypes.put(t, level); + static void registerNumericType(PrimitiveCategory primitiveCategory, int level) { + numericTypeList.add(primitiveCategory); + numericTypes.put(primitiveCategory, level); } static { - registerNumericType(serdeConstants.TINYINT_TYPE_NAME, 1); - registerNumericType(serdeConstants.SMALLINT_TYPE_NAME, 2); - registerNumericType(serdeConstants.INT_TYPE_NAME, 3); - registerNumericType(serdeConstants.BIGINT_TYPE_NAME, 4); - registerNumericType(serdeConstants.FLOAT_TYPE_NAME, 5); - registerNumericType(serdeConstants.DOUBLE_TYPE_NAME, 6); - registerNumericType(serdeConstants.DECIMAL_TYPE_NAME, 7); - registerNumericType(serdeConstants.STRING_TYPE_NAME, 8); + registerNumericType(PrimitiveCategory.BYTE, 1); + registerNumericType(PrimitiveCategory.SHORT, 2); + registerNumericType(PrimitiveCategory.INT, 3); + registerNumericType(PrimitiveCategory.LONG, 4); + registerNumericType(PrimitiveCategory.FLOAT, 5); + registerNumericType(PrimitiveCategory.DOUBLE, 6); + registerNumericType(PrimitiveCategory.DECIMAL, 7); + registerNumericType(PrimitiveCategory.STRING, 8); + } + + /** + * Given 2 TypeInfo types and the PrimitiveCategory selected as the common class between the two, + * return a TypeInfo corresponding to the common PrimitiveCategory, and with type qualifiers + * (if applicable) that match the 2 TypeInfo types. + * Examples: + * varchar(10), varchar(20), primitive category varchar => varchar(20) + * date, string, primitive category string => string + * @param a TypeInfo of the first type + * @param b TypeInfo of the second type + * @param typeCategory PrimitiveCategory of the designated common type between a and b + * @return TypeInfo represented by the primitive category, with any applicable type qualifiers. + */ + public static TypeInfo getTypeInfoForPrimitiveCategory( + PrimitiveTypeInfo a, PrimitiveTypeInfo b, PrimitiveCategory typeCategory) { + // For types with parameters (like varchar), we need to determine the type parameters + // that should be added to this type, based on the original 2 TypeInfos. + switch (typeCategory) { + + default: + // Type doesn't require any qualifiers. + return TypeInfoFactory.getPrimitiveTypeInfo( + PrimitiveObjectInspectorUtils.getTypeEntryFromPrimitiveCategory(typeCategory).typeName); + } } /** @@ -627,18 +657,38 @@ public static TypeInfo getCommonClassForUnionAll(TypeInfo a, TypeInfo b) { if (a.equals(b)) { return a; } + if (a.getCategory() != Category.PRIMITIVE || b.getCategory() != Category.PRIMITIVE) { + return null; + } + PrimitiveCategory pcA = ((PrimitiveTypeInfo)a).getPrimitiveCategory(); + PrimitiveCategory pcB = ((PrimitiveTypeInfo)b).getPrimitiveCategory(); + + if (pcA == pcB) { + // Same primitive category but different qualifiers. + return getTypeInfoForPrimitiveCategory((PrimitiveTypeInfo)a, (PrimitiveTypeInfo)b, pcA); + } + + PrimitiveGrouping pgA = PrimitiveObjectInspectorUtils.getPrimitiveGrouping(pcA); + PrimitiveGrouping pgB = PrimitiveObjectInspectorUtils.getPrimitiveGrouping(pcB); + // handle string types properly + if (pgA == PrimitiveGrouping.STRING_GROUP && pgB == PrimitiveGrouping.STRING_GROUP) { + return getTypeInfoForPrimitiveCategory( + (PrimitiveTypeInfo)a, (PrimitiveTypeInfo)b,PrimitiveCategory.STRING); + } + if (FunctionRegistry.implicitConvertable(a, b)) { - return b; + return getTypeInfoForPrimitiveCategory((PrimitiveTypeInfo)a, (PrimitiveTypeInfo)b, pcB); } if (FunctionRegistry.implicitConvertable(b, a)) { - return a; + return getTypeInfoForPrimitiveCategory((PrimitiveTypeInfo)a, (PrimitiveTypeInfo)b, pcA); } - for (TypeInfo t : numericTypeList) { - if (FunctionRegistry.implicitConvertable(a, t) - && FunctionRegistry.implicitConvertable(b, t)) { - return t; + for (PrimitiveCategory t : numericTypeList) { + if (FunctionRegistry.implicitConvertable(pcA, t) + && FunctionRegistry.implicitConvertable(pcB, t)) { + return getTypeInfoForPrimitiveCategory((PrimitiveTypeInfo)a, (PrimitiveTypeInfo)b, t); } } + return null; } @@ -656,12 +706,34 @@ public static TypeInfo getCommonClassForComparison(TypeInfo a, TypeInfo b) { if (a.equals(b)) { return a; } - for (TypeInfo t : numericTypeList) { - if (FunctionRegistry.implicitConvertable(a, t) - && FunctionRegistry.implicitConvertable(b, t)) { - return t; + if (a.getCategory() != Category.PRIMITIVE || b.getCategory() != Category.PRIMITIVE) { + return null; + } + PrimitiveCategory pcA = ((PrimitiveTypeInfo)a).getPrimitiveCategory(); + PrimitiveCategory pcB = ((PrimitiveTypeInfo)b).getPrimitiveCategory(); + + if (pcA == pcB) { + // Same primitive category but different qualifiers. + // Rely on getTypeInfoForPrimitiveCategory() to sort out the type params. + return getTypeInfoForPrimitiveCategory((PrimitiveTypeInfo)a, (PrimitiveTypeInfo)b, pcA); + } + + PrimitiveGrouping pgA = PrimitiveObjectInspectorUtils.getPrimitiveGrouping(pcA); + PrimitiveGrouping pgB = PrimitiveObjectInspectorUtils.getPrimitiveGrouping(pcB); + // handle string types properly + if (pgA == PrimitiveGrouping.STRING_GROUP && pgB == PrimitiveGrouping.STRING_GROUP) { + // Compare as strings. Char comparison semantics may be different if/when implemented. + return getTypeInfoForPrimitiveCategory( + (PrimitiveTypeInfo)a, (PrimitiveTypeInfo)b,PrimitiveCategory.STRING); + } + + for (PrimitiveCategory t : numericTypeList) { + if (FunctionRegistry.implicitConvertable(pcA, t) + && FunctionRegistry.implicitConvertable(pcB, t)) { + return getTypeInfoForPrimitiveCategory((PrimitiveTypeInfo)a, (PrimitiveTypeInfo)b, t); } } + return null; } @@ -677,45 +749,60 @@ public static TypeInfo getCommonClass(TypeInfo a, TypeInfo b) { if (a.equals(b)) { return a; } - Integer ai = numericTypes.get(a); - Integer bi = numericTypes.get(b); + if (a.getCategory() != Category.PRIMITIVE || b.getCategory() != Category.PRIMITIVE) { + return null; + } + PrimitiveCategory pcA = ((PrimitiveTypeInfo)a).getPrimitiveCategory(); + PrimitiveCategory pcB = ((PrimitiveTypeInfo)b).getPrimitiveCategory(); + + PrimitiveGrouping pgA = PrimitiveObjectInspectorUtils.getPrimitiveGrouping(pcA); + PrimitiveGrouping pgB = PrimitiveObjectInspectorUtils.getPrimitiveGrouping(pcB); + // handle string types properly + if (pgA == PrimitiveGrouping.STRING_GROUP && pgB == PrimitiveGrouping.STRING_GROUP) { + return getTypeInfoForPrimitiveCategory( + (PrimitiveTypeInfo)a, (PrimitiveTypeInfo)b,PrimitiveCategory.STRING); + } + + Integer ai = numericTypes.get(pcA); + Integer bi = numericTypes.get(pcB); if (ai == null || bi == null) { // If either is not a numeric type, return null. return null; } - return (ai > bi) ? a : b; + PrimitiveCategory pcCommon = (ai > bi) ? pcA : pcB; + return getTypeInfoForPrimitiveCategory((PrimitiveTypeInfo)a, (PrimitiveTypeInfo)b, pcCommon); } - /** - * Returns whether it is possible to implicitly convert an object of Class - * from to Class to. - */ - public static boolean implicitConvertable(TypeInfo from, TypeInfo to) { - if (from.equals(to)) { + public static boolean implicitConvertable(PrimitiveCategory from, PrimitiveCategory to) { + if (from == to) { return true; } + + PrimitiveGrouping fromPg = PrimitiveObjectInspectorUtils.getPrimitiveGrouping(from); + PrimitiveGrouping toPg = PrimitiveObjectInspectorUtils.getPrimitiveGrouping(to); + // Allow implicit String to Double conversion - if (from.equals(TypeInfoFactory.stringTypeInfo) - && to.equals(TypeInfoFactory.doubleTypeInfo)) { + if (fromPg == PrimitiveGrouping.STRING_GROUP && to == PrimitiveCategory.DOUBLE) { return true; } // Allow implicit String to Decimal conversion - if (from.equals(TypeInfoFactory.stringTypeInfo) - && to.equals(TypeInfoFactory.decimalTypeInfo)) { + if (fromPg == PrimitiveGrouping.STRING_GROUP && to == PrimitiveCategory.DECIMAL) { return true; } // Void can be converted to any type - if (from.equals(TypeInfoFactory.voidTypeInfo)) { + if (from == PrimitiveCategory.VOID) { return true; } // Allow implicit String to Date conversion - if (from.equals(TypeInfoFactory.dateTypeInfo) - && to.equals(TypeInfoFactory.stringTypeInfo)) { + if (fromPg == PrimitiveGrouping.DATE_GROUP && toPg == PrimitiveGrouping.STRING_GROUP) { return true; } - - if (from.equals(TypeInfoFactory.timestampTypeInfo) - && to.equals(TypeInfoFactory.stringTypeInfo)) { + // Allow implicit Numeric to String conversion + if (fromPg == PrimitiveGrouping.NUMERIC_GROUP && toPg == PrimitiveGrouping.STRING_GROUP) { + return true; + } + // Allow implicit String to varchar conversion, and vice versa + if (fromPg == PrimitiveGrouping.STRING_GROUP && toPg == PrimitiveGrouping.STRING_GROUP) { return true; } @@ -733,6 +820,27 @@ public static boolean implicitConvertable(TypeInfo from, TypeInfo to) { } /** + * Returns whether it is possible to implicitly convert an object of Class + * from to Class to. + */ + public static boolean implicitConvertable(TypeInfo from, TypeInfo to) { + if (from.equals(to)) { + return true; + } + + // Reimplemented to use PrimitiveCategory rather than TypeInfo, because + // 2 TypeInfos from the same qualified type (varchar, decimal) should still be + // seen as equivalent. + if (from.getCategory() == Category.PRIMITIVE && to.getCategory() == Category.PRIMITIVE) { + return implicitConvertable( + ((PrimitiveTypeInfo)from).getPrimitiveCategory(), + ((PrimitiveTypeInfo)to).getPrimitiveCategory()); + } + return false; + } + + + /** * Get the GenericUDAF evaluator for the name and argumentClasses. * * @param name @@ -1097,9 +1205,15 @@ public static Method getMethodInternal(Class udfClass, List mlist, bo for (TypeInfo accepted: argumentsAccepted) { TypeInfo reference = referenceIterator.next(); - if (numericTypes.containsKey(accepted)) { + boolean acceptedIsPrimitive = false; + PrimitiveCategory acceptedPrimCat = PrimitiveCategory.UNKNOWN; + if (accepted.getCategory() == Category.PRIMITIVE) { + acceptedIsPrimitive = true; + acceptedPrimCat = ((PrimitiveTypeInfo) accepted).getPrimitiveCategory(); + } + if (acceptedIsPrimitive && numericTypes.containsKey(acceptedPrimCat)) { // We're looking for the udf with the smallest maximum numeric type. - int typeValue = numericTypes.get(accepted); + int typeValue = numericTypes.get(acceptedPrimCat); maxNumericType = typeValue > maxNumericType ? typeValue : maxNumericType; } else if (!accepted.equals(reference)) { // There are non-numeric arguments that don't match from one UDF to diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFBaseCompare.java ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFBaseCompare.java index 2cb65d8..a05b277 100644 --- ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFBaseCompare.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFBaseCompare.java @@ -135,29 +135,11 @@ public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumen if (oiTypeInfo0 != oiTypeInfo1) { compareType = CompareType.NEED_CONVERT; - - if ((oiTypeInfo0.equals(TypeInfoFactory.stringTypeInfo) - && oiTypeInfo1.equals(TypeInfoFactory.dateTypeInfo)) - || (oiTypeInfo0.equals(TypeInfoFactory.dateTypeInfo) - && oiTypeInfo1.equals(TypeInfoFactory.stringTypeInfo))) { - // Date should be comparable with string - compareOI = TypeInfoUtils.getStandardWritableObjectInspectorFromTypeInfo( - TypeInfoFactory.stringTypeInfo); - - } else if (oiTypeInfo0.equals(TypeInfoFactory.stringTypeInfo) - || oiTypeInfo1.equals(TypeInfoFactory.stringTypeInfo)) { - // If either argument is a string, we convert to a double because a number - // in string form should always be convertible into a double - compareOI = TypeInfoUtils.getStandardWritableObjectInspectorFromTypeInfo( - TypeInfoFactory.doubleTypeInfo); - } else { - TypeInfo compareType = FunctionRegistry.getCommonClass(oiTypeInfo0, oiTypeInfo1); - - // For now, we always convert to double if we can't find a common type - compareOI = TypeInfoUtils.getStandardWritableObjectInspectorFromTypeInfo( - (compareType == null) ? - TypeInfoFactory.doubleTypeInfo : compareType); - } + TypeInfo compareType = FunctionRegistry.getCommonClassForComparison(oiTypeInfo0, oiTypeInfo1); + // For now, we always convert to double if we can't find a common type + compareOI = TypeInfoUtils.getStandardWritableObjectInspectorFromTypeInfo( + (compareType == null) ? + TypeInfoFactory.doubleTypeInfo : compareType); converter0 = ObjectInspectorConverters.getConverter(arguments[0], compareOI); converter1 = ObjectInspectorConverters.getConverter(arguments[1], compareOI); diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/TestFunctionRegistry.java ql/src/test/org/apache/hadoop/hive/ql/exec/TestFunctionRegistry.java index 3875b5d..2868718 100644 --- ql/src/test/org/apache/hadoop/hive/ql/exec/TestFunctionRegistry.java +++ ql/src/test/org/apache/hadoop/hive/ql/exec/TestFunctionRegistry.java @@ -29,6 +29,8 @@ import org.apache.hadoop.hive.serde2.io.DoubleWritable; import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; import org.apache.hadoop.hive.serde2.io.TimestampWritable; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; +import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; import org.apache.hadoop.io.BytesWritable; @@ -144,6 +146,33 @@ public void testCommonClassComparison() { TypeInfoFactory.doubleTypeInfo); } + private void unionAll(TypeInfo a, TypeInfo b, TypeInfo result) { + assertEquals(result, FunctionRegistry.getCommonClassForUnionAll(a,b)); + } + + public void testCommonClassUnionAll() { + unionAll(TypeInfoFactory.intTypeInfo, TypeInfoFactory.decimalTypeInfo, + TypeInfoFactory.decimalTypeInfo); + unionAll(TypeInfoFactory.stringTypeInfo, TypeInfoFactory.decimalTypeInfo, + TypeInfoFactory.decimalTypeInfo); + unionAll(TypeInfoFactory.doubleTypeInfo, TypeInfoFactory.decimalTypeInfo, + TypeInfoFactory.decimalTypeInfo); + unionAll(TypeInfoFactory.doubleTypeInfo, TypeInfoFactory.stringTypeInfo, + TypeInfoFactory.stringTypeInfo); + } + + public void testGetTypeInfoForPrimitiveCategory() { + // non-qualified types should simply return the TypeInfo associated with that type + assertEquals(TypeInfoFactory.stringTypeInfo, FunctionRegistry.getTypeInfoForPrimitiveCategory( + (PrimitiveTypeInfo) TypeInfoFactory.stringTypeInfo, + (PrimitiveTypeInfo) TypeInfoFactory.stringTypeInfo, + PrimitiveCategory.STRING)); + assertEquals(TypeInfoFactory.doubleTypeInfo, FunctionRegistry.getTypeInfoForPrimitiveCategory( + (PrimitiveTypeInfo) TypeInfoFactory.doubleTypeInfo, + (PrimitiveTypeInfo) TypeInfoFactory.stringTypeInfo, + PrimitiveCategory.DOUBLE)); + } + @Override protected void tearDown() { }