diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/ComparisonOpMethodResolver.java ql/src/java/org/apache/hadoop/hive/ql/exec/ComparisonOpMethodResolver.java index 41a9cb3..fcd55d1 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/ComparisonOpMethodResolver.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/ComparisonOpMethodResolver.java @@ -20,12 +20,10 @@ import java.lang.reflect.Method; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; -import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils; /** * The class implements the method resolution for operators like (> < <= >= = @@ -42,7 +40,7 @@ private final Class udfClass; /** - * Constuctor. + * Constructor. */ public ComparisonOpMethodResolver(Class udfClass) { this.udfClass = udfClass; @@ -78,46 +76,7 @@ public Method getEvalMethod(List argTypeInfos) throws UDFArgumentExcep pTypeInfos.add(TypeInfoFactory.doubleTypeInfo); } - Method udfMethod = null; - - List evaluateMethods = new ArrayList(); - - for (Method m : Arrays.asList(udfClass.getMethods())) { - if (m.getName().equals("evaluate")) { - - evaluateMethods.add(m); - List acceptedTypeInfos = TypeInfoUtils.getParameterTypeInfos( - m, pTypeInfos.size()); - if (acceptedTypeInfos == null) { - // null means the method does not accept number of arguments passed. - continue; - } - - boolean match = (acceptedTypeInfos.size() == pTypeInfos.size()); - - for (int i = 0; i < pTypeInfos.size() && match; i++) { - TypeInfo accepted = acceptedTypeInfos.get(i); - if (accepted != pTypeInfos.get(i)) { - match = false; - } - } - - if (match) { - if (udfMethod != null) { - throw new AmbiguousMethodException(udfClass, argTypeInfos, - Arrays.asList(new Method[]{udfMethod, m})); - } else { - udfMethod = m; - } - } - } - } - - if (udfMethod == null) { - throw new NoMatchingMethodException(udfClass, argTypeInfos, evaluateMethods); - } - - return udfMethod; + return FunctionRegistry.matchMethod(argTypeInfos, pTypeInfos, udfClass, "evaluate"); } } diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java index 0c6a3d4..7f2044c 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java @@ -22,6 +22,7 @@ import java.lang.reflect.Method; import java.net.URL; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.EnumMap; import java.util.HashSet; @@ -972,6 +973,16 @@ public static PrimitiveCategory getCommonCategory(TypeInfo a, TypeInfo b) { PrimitiveCategory pcA = ((PrimitiveTypeInfo)a).getPrimitiveCategory(); PrimitiveCategory pcB = ((PrimitiveTypeInfo)b).getPrimitiveCategory(); + if (pcA == PrimitiveCategory.UNKNOWN && pcB == PrimitiveCategory.UNKNOWN) { + return null; + } + if (pcA == PrimitiveCategory.UNKNOWN) { + return pcB; + } + if (pcB == PrimitiveCategory.UNKNOWN) { + return pcA; + } + PrimitiveGrouping pgA = PrimitiveObjectInspectorUtils.getPrimitiveGrouping(pcA); PrimitiveGrouping pgB = PrimitiveObjectInspectorUtils.getPrimitiveGrouping(pcB); // handle string types properly @@ -1264,6 +1275,9 @@ public static Object invoke(Method m, Object thisObject, Object... 
arguments) */ public static int matchCost(TypeInfo argumentPassed, TypeInfo argumentAccepted, boolean exact) { + if (argumentAccepted == TypeInfoFactory.unknownTypeInfo) { + return 0; + } if (argumentAccepted.equals(argumentPassed) || TypeInfoUtils.doPrimitiveCategoriesMatch(argumentPassed, argumentAccepted)) { // matches @@ -1334,6 +1348,9 @@ static void filterMethodsByTypeAffinity(List udfMethods, List TypeInfoUtils.getParameterTypeInfos(m, argumentsPassed.size()); Iterator argsPassedIter = argumentsPassed.iterator(); for (TypeInfo acceptedType : argumentsAccepted) { + if (acceptedType == TypeInfoFactory.unknownTypeInfo) { + continue; + } // Check the affinity of the argument passed in with the accepted argument, // based on the PrimitiveGrouping TypeInfo passedType = argsPassedIter.next(); @@ -1389,46 +1406,36 @@ public static Method getMethodInternal(Class udfClass, List mlist, bo for (Method m : mlist) { List argumentsAccepted = TypeInfoUtils.getParameterTypeInfos(m, argumentsPassed.size()); - if (argumentsAccepted == null) { + if (argumentsAccepted == null || argumentsAccepted.size() != argumentsPassed.size()) { // null means the method does not accept number of arguments passed. continue; } - boolean match = (argumentsAccepted.size() == argumentsPassed.size()); - int conversionCost = 0; - - for (int i = 0; i < argumentsPassed.size() && match; i++) { - int cost = matchCost(argumentsPassed.get(i), argumentsAccepted.get(i), - exact); - if (cost == -1) { - match = false; - } else { - conversionCost += cost; - } - } + int conversionCost = calculateCost(argumentsPassed, argumentsAccepted, exact); if (LOG.isDebugEnabled()) { - LOG.debug("Method " + (match ? "did" : "didn't") + " match: passed = " + LOG.debug("Method " + (conversionCost >= 0 ? "did" : "didn't") + " match: passed = " + argumentsPassed + " accepted = " + argumentsAccepted + " method = " + m); } - if (match) { - // Always choose the function with least implicit conversions. - if (conversionCost < leastConversionCost) { - udfMethods.clear(); - udfMethods.add(m); - leastConversionCost = conversionCost; - // Found an exact match - if (leastConversionCost == 0) { - break; - } - } else if (conversionCost == leastConversionCost) { - // Ambiguous call: two methods with the same number of implicit - // conversions - udfMethods.add(m); - // Don't break! We might find a better match later. - } else { - // do nothing if implicitConversions > leastImplicitConversions + if (conversionCost < 0) { + continue; + } + // Always choose the function with least implicit conversions. + if (conversionCost < leastConversionCost) { + udfMethods.clear(); + udfMethods.add(m); + leastConversionCost = conversionCost; + // Found an exact match + if (leastConversionCost == 0) { + break; } + } else if (conversionCost == leastConversionCost) { + // Ambiguous call: two methods with the same number of implicit + // conversions + udfMethods.add(m); + // Don't break! We might find a better match later. 
+ } else { + // do nothing if implicitConversions > leastImplicitConversions } } @@ -1466,6 +1473,9 @@ public static Method getMethodInternal(Class udfClass, List mlist, bo Iterator referenceIterator = referenceArguments.iterator(); for (TypeInfo accepted: argumentsAccepted) { + if (accepted == TypeInfoFactory.unknownTypeInfo) { + continue; + } TypeInfo reference = referenceIterator.next(); boolean acceptedIsPrimitive = false; @@ -1505,6 +1515,71 @@ public static Method getMethodInternal(Class udfClass, List mlist, bo return udfMethods.get(0); } + private static int calculateCost( + List argumentsPassed, List argumentsAccepted, boolean exact) { + int conversionCost = 0; + for (int i = 0; i < argumentsPassed.size(); i++) { + int cost = matchCost(argumentsPassed.get(i), argumentsAccepted.get(i), exact); + if (cost == -1) { + return -1; + } + conversionCost += cost; + } + return conversionCost; + } + + private static Method[] findMethod(Class clazz, String methodName) { + List methods = new ArrayList(); + for (Method method : clazz.getMethods()) { + if (method.getName().equals(methodName)) { + methods.add(method); + } + } + return methods.toArray(new Method[methods.size()]); + } + + public static Method matchMethod(List argTypeInfos, List pTypeInfos, + Class clazz, String methodName) throws AmbiguousMethodException, NoMatchingMethodException { + return matchMethod(argTypeInfos, pTypeInfos, clazz, findMethod(clazz, methodName)); + } + + public static Method matchMethod(List argTypeInfos, List pTypeInfos, + Class clazz, Method... methods) throws AmbiguousMethodException, NoMatchingMethodException { + Method found = null; + + for (Method m : methods) { + List argumentTypeInfos = TypeInfoUtils.getParameterTypeInfos( + m, pTypeInfos.size()); + if (argumentTypeInfos == null) { + // null means the method does not accept number of arguments passed. + continue; + } + + boolean match = (argumentTypeInfos.size() == pTypeInfos.size()); + + for (int i = 0; i < pTypeInfos.size() && match; i++) { + TypeInfo accepted = argumentTypeInfos.get(i); + if (accepted != TypeInfoFactory.unknownTypeInfo && !accepted.accept(pTypeInfos.get(i))) { + match = false; + } + } + + if (match) { + if (found != null) { + throw new AmbiguousMethodException(clazz, argTypeInfos, + Arrays.asList(found, m)); + } + found = m; + } + } + + if (found == null) { + throw new NoMatchingMethodException(clazz, argTypeInfos, Arrays.asList(methods)); + } + + return found; + } + /** * A shortcut to get the "index" GenericUDF. This is used for getting elements * out of array and getting values out of map. diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/NumericOpMethodResolver.java ql/src/java/org/apache/hadoop/hive/ql/exec/NumericOpMethodResolver.java index b056554..904ba2f 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/NumericOpMethodResolver.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/NumericOpMethodResolver.java @@ -20,12 +20,10 @@ import java.lang.reflect.Method; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; -import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils; /** * The class implements the method resolution for operators like (+, -, *, %). @@ -46,7 +44,7 @@ Class udfClass; /** - * Constuctor. + * Constructor. 
*/ public NumericOpMethodResolver(Class udfClass) { this.udfClass = udfClass; @@ -107,42 +105,6 @@ public Method getEvalMethod(List argTypeInfos) throws UDFArgumentExcep pTypeInfos.add(commonType); pTypeInfos.add(commonType); - Method udfMethod = null; - - for (Method m : Arrays.asList(udfClass.getMethods())) { - if (m.getName().equals("evaluate")) { - - List argumentTypeInfos = TypeInfoUtils.getParameterTypeInfos( - m, pTypeInfos.size()); - if (argumentTypeInfos == null) { - // null means the method does not accept number of arguments passed. - continue; - } - - boolean match = (argumentTypeInfos.size() == pTypeInfos.size()); - - for (int i = 0; i < pTypeInfos.size() && match; i++) { - TypeInfo accepted = argumentTypeInfos.get(i); - if (!accepted.accept(pTypeInfos.get(i))) { - match = false; - } - } - - if (match) { - if (udfMethod != null) { - throw new AmbiguousMethodException(udfClass, argTypeInfos, - Arrays.asList(new Method[]{udfMethod, m})); - } else { - udfMethod = m; - } - } - } - } - - if (udfMethod == null) { - throw new NoMatchingMethodException(udfClass, argTypeInfos, null); - } - - return udfMethod; + return FunctionRegistry.matchMethod(argTypeInfos, pTypeInfos, udfClass, "evaluate"); } } diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFBridge.java ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFBridge.java index 959007a..15a1fec 100644 --- ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFBridge.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFBridge.java @@ -20,6 +20,8 @@ import java.io.Serializable; import java.lang.reflect.Method; +import java.lang.reflect.Type; +import java.lang.reflect.TypeVariable; import java.util.ArrayList; import org.apache.hadoop.hive.common.JavaUtils; @@ -161,8 +163,22 @@ public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumen realArguments = new Object[arguments.length]; // Get the return ObjectInspector. 
+ Type returnType = udfMethod.getGenericReturnType(); + if (returnType instanceof TypeVariable) { + // should be referenced from parameters at least once + TypeVariable tv = (TypeVariable)returnType; + Type[] parameterTypes = udfMethod.getGenericParameterTypes(); + for (int i = 0; i < parameterTypes.length; i++) { + ObjectInspector oi = TypeInfoUtils.extractOI(tv, parameterTypes[i], arguments[i]); + if (oi != null) { + return ObjectInspectorFactory + .getReflectionObjectInspector(returnType, oi, ObjectInspectorOptions.JAVA); + } + } + throw new UDFArgumentException("Return type " + returnType + " cannot be resolved"); + } ObjectInspector returnOI = ObjectInspectorFactory - .getReflectionObjectInspector(udfMethod.getGenericReturnType(), + .getReflectionObjectInspector(returnType, ObjectInspectorOptions.JAVA); return returnOI; diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFUtils.java ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFUtils.java index 1f70c55..c513960 100644 --- ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFUtils.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFUtils.java @@ -20,9 +20,7 @@ import java.lang.reflect.Array; import java.lang.reflect.Method; -import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; -import java.nio.ByteBuffer; import java.util.HashMap; import org.apache.hadoop.hive.ql.exec.FunctionRegistry; @@ -214,16 +212,6 @@ public Object convertIfNecessary(Object o, ObjectInspector oi) { Object[] convertedParameters; Object[] convertedParametersInArray; - private static Class getClassFromType(Type t) { - if (t instanceof Class) { - return (Class) t; - } else if (t instanceof ParameterizedType) { - ParameterizedType pt = (ParameterizedType) t; - return (Class) pt.getRawType(); - } - return null; - } - /** * Create a PrimitiveConversionHelper for Method m. The ObjectInspector's * input parameters are specified in parameters. @@ -259,35 +247,20 @@ public ConversionHelper(Method m, ObjectInspector[] parameterOIs) for (int i = 0; i < methodParameterTypes.length - 1; i++) { // This method takes Object, so it accepts whatever types that are // passed in. - if (methodParameterTypes[i] == Object.class) { - methodParameterOIs[i] = ObjectInspectorUtils - .getStandardObjectInspector(parameterOIs[i], - ObjectInspectorCopyOption.JAVA); - } else { methodParameterOIs[i] = ObjectInspectorFactory - .getReflectionObjectInspector(methodParameterTypes[i], + .getReflectionObjectInspector(methodParameterTypes[i], parameterOIs[i], ObjectInspectorOptions.JAVA); - } } // Deal with the last entry - if (lastParaElementType == Object.class) { - // This method takes Object[], so it accepts whatever types that are - // passed in. 
- for (int i = methodParameterTypes.length - 1; i < parameterOIs.length; i++) { - methodParameterOIs[i] = ObjectInspectorUtils - .getStandardObjectInspector(parameterOIs[i], - ObjectInspectorCopyOption.JAVA); - } - } else { - // This method takes something like String[], so it only accepts - // something like String - ObjectInspector oi = ObjectInspectorFactory - .getReflectionObjectInspector(lastParaElementType, - ObjectInspectorOptions.JAVA); - for (int i = methodParameterTypes.length - 1; i < parameterOIs.length; i++) { - methodParameterOIs[i] = oi; - } + // This method takes something like String[], so it only accepts + // something like String + ObjectInspector oi = ObjectInspectorFactory + .getReflectionObjectInspector(lastParaElementType, + parameterOIs[methodParameterTypes.length - 1], + ObjectInspectorOptions.JAVA); + for (int i = methodParameterTypes.length - 1; i < parameterOIs.length; i++) { + methodParameterOIs[i] = oi; } } else { @@ -304,15 +277,9 @@ public ConversionHelper(Method m, ObjectInspector[] parameterOIs) for (int i = 0; i < methodParameterTypes.length; i++) { // This method takes Object, so it accepts whatever types that are // passed in. - if (methodParameterTypes[i] == Object.class) { - methodParameterOIs[i] = ObjectInspectorUtils - .getStandardObjectInspector(parameterOIs[i], - ObjectInspectorCopyOption.JAVA); - } else { methodParameterOIs[i] = ObjectInspectorFactory - .getReflectionObjectInspector(methodParameterTypes[i], + .getReflectionObjectInspector(methodParameterTypes[i], parameterOIs[i], ObjectInspectorOptions.JAVA); - } } } @@ -331,7 +298,7 @@ public ConversionHelper(Method m, ObjectInspector[] parameterOIs) if (isVariableLengthArgument) { convertedParameters = new Object[methodParameterTypes.length]; convertedParametersInArray = (Object[]) Array.newInstance( - getClassFromType(lastParaElementType), parameterOIs.length + TypeInfoUtils.getClassFromType(lastParaElementType), parameterOIs.length - methodParameterTypes.length + 1); convertedParameters[convertedParameters.length - 1] = convertedParametersInArray; } else { diff --git ql/src/test/org/apache/hadoop/hive/ql/io/orc/TestRecordReaderImpl.java ql/src/test/org/apache/hadoop/hive/ql/io/orc/TestRecordReaderImpl.java index 22e4724..c94ab27 100644 --- ql/src/test/org/apache/hadoop/hive/ql/io/orc/TestRecordReaderImpl.java +++ ql/src/test/org/apache/hadoop/hive/ql/io/orc/TestRecordReaderImpl.java @@ -50,7 +50,7 @@ public class TestRecordReaderImpl { // can add .verboseLogging() to cause Mockito to log invocations - private final MockSettings settings = Mockito.withSettings().verboseLogging(); + private final MockSettings settings = Mockito.withSettings(); static class BufferInStream extends InputStream implements PositionedReadable, Seekable { diff --git ql/src/test/org/apache/hadoop/hive/ql/udf/UDFObjectToString.java ql/src/test/org/apache/hadoop/hive/ql/udf/UDFObjectToString.java new file mode 100644 index 0000000..3d1e2d9 --- /dev/null +++ ql/src/test/org/apache/hadoop/hive/ql/udf/UDFObjectToString.java @@ -0,0 +1,28 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.udf; + +import org.apache.hadoop.hive.ql.exec.UDF; + +public class UDFObjectToString extends UDF { + + public String evaluate(Object arg) { + return arg == null ? null : String.valueOf(arg); + } +} diff --git ql/src/test/org/apache/hadoop/hive/ql/udf/UDFTypeVariable.java ql/src/test/org/apache/hadoop/hive/ql/udf/UDFTypeVariable.java new file mode 100644 index 0000000..809fa52 --- /dev/null +++ ql/src/test/org/apache/hadoop/hive/ql/udf/UDFTypeVariable.java @@ -0,0 +1,34 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.udf; + +import org.apache.hadoop.hive.ql.exec.UDF; + +import java.util.Map; + +public class UDFTypeVariable extends UDF { + + public T evaluate(Map arg) { + return arg.values().iterator().next(); + } + + public T evaluate(T arg) { + return arg; + } +} diff --git ql/src/test/queries/clientpositive/udf_generics.q ql/src/test/queries/clientpositive/udf_generics.q new file mode 100644 index 0000000..600517e --- /dev/null +++ ql/src/test/queries/clientpositive/udf_generics.q @@ -0,0 +1,16 @@ +set hive.fetch.task.conversion=more; + +create temporary function os as 'org.apache.hadoop.hive.ql.udf.UDFObjectToString'; +create temporary function tv as 'org.apache.hadoop.hive.ql.udf.UDFTypeVariable'; + +explain +select os(key), os(cast(key as smallint)), os(cast(key as int)), os(cast(key as float)), os(cast(key as double)), os(cast(key as bigint)) from src tablesample (1 rows); +select os(key), os(cast(key as smallint)), os(cast(key as int)), os(cast(key as float)), os(cast(key as double)), os(cast(key as bigint)) from src tablesample (1 rows); + +explain +select tv(key), tv(cast(key as smallint)), tv(cast(key as int)), tv(cast(key as float)), tv(cast(key as double)), tv(cast(key as bigint)) from src tablesample (1 rows); +select tv(key), tv(cast(key as smallint)), tv(cast(key as int)), tv(cast(key as float)), tv(cast(key as double)), tv(cast(key as bigint)) from src tablesample (1 rows); + +explain +select tv(map(value, key)), tv(map(value, cast(key as smallint))), tv(map(value, cast(key as int))), tv(map(value, cast(key as float))), tv(map(value, cast(key as double))),tv(map(value, cast(key as bigint))) from src tablesample (1 rows); +select tv(map(value, key)), tv(map(value, cast(key as smallint))), tv(map(value, cast(key as int))), 
tv(map(value, cast(key as float))), tv(map(value, cast(key as double))),tv(map(value, cast(key as bigint))) from src tablesample (1 rows); diff --git ql/src/test/results/clientpositive/udf_generics.q.out ql/src/test/results/clientpositive/udf_generics.q.out new file mode 100644 index 0000000..ca46ce9 --- /dev/null +++ ql/src/test/results/clientpositive/udf_generics.q.out @@ -0,0 +1,111 @@ +PREHOOK: query: create temporary function os as 'org.apache.hadoop.hive.ql.udf.UDFObjectToString' +PREHOOK: type: CREATEFUNCTION +PREHOOK: Output: os +POSTHOOK: query: create temporary function os as 'org.apache.hadoop.hive.ql.udf.UDFObjectToString' +POSTHOOK: type: CREATEFUNCTION +POSTHOOK: Output: os +PREHOOK: query: create temporary function tv as 'org.apache.hadoop.hive.ql.udf.UDFTypeVariable' +PREHOOK: type: CREATEFUNCTION +PREHOOK: Output: tv +POSTHOOK: query: create temporary function tv as 'org.apache.hadoop.hive.ql.udf.UDFTypeVariable' +POSTHOOK: type: CREATEFUNCTION +POSTHOOK: Output: tv +PREHOOK: query: explain +select os(key), os(cast(key as smallint)), os(cast(key as int)), os(cast(key as float)), os(cast(key as double)), os(cast(key as bigint)) from src tablesample (1 rows) +PREHOOK: type: QUERY +POSTHOOK: query: explain +select os(key), os(cast(key as smallint)), os(cast(key as int)), os(cast(key as float)), os(cast(key as double)), os(cast(key as bigint)) from src tablesample (1 rows) +POSTHOOK: type: QUERY +STAGE DEPENDENCIES: + Stage-0 is a root stage + +STAGE PLANS: + Stage: Stage-0 + Fetch Operator + limit: -1 + Processor Tree: + TableScan + alias: src + Row Limit Per Split: 1 + Statistics: Num rows: 58 Data size: 5812 Basic stats: COMPLETE Column stats: NONE + Select Operator + expressions: os(key) (type: string), os(UDFToShort(key)) (type: string), os(UDFToInteger(key)) (type: string), os(UDFToFloat(key)) (type: string), os(UDFToDouble(key)) (type: string), os(UDFToLong(key)) (type: string) + outputColumnNames: _col0, _col1, _col2, _col3, _col4, _col5 + Statistics: Num rows: 58 Data size: 5812 Basic stats: COMPLETE Column stats: NONE + ListSink + +PREHOOK: query: select os(key), os(cast(key as smallint)), os(cast(key as int)), os(cast(key as float)), os(cast(key as double)), os(cast(key as bigint)) from src tablesample (1 rows) +PREHOOK: type: QUERY +PREHOOK: Input: default@src +#### A masked pattern was here #### +POSTHOOK: query: select os(key), os(cast(key as smallint)), os(cast(key as int)), os(cast(key as float)), os(cast(key as double)), os(cast(key as bigint)) from src tablesample (1 rows) +POSTHOOK: type: QUERY +POSTHOOK: Input: default@src +#### A masked pattern was here #### +238 238 238 238.0 238.0 238 +PREHOOK: query: explain +select tv(key), tv(cast(key as smallint)), tv(cast(key as int)), tv(cast(key as float)), tv(cast(key as double)), tv(cast(key as bigint)) from src tablesample (1 rows) +PREHOOK: type: QUERY +POSTHOOK: query: explain +select tv(key), tv(cast(key as smallint)), tv(cast(key as int)), tv(cast(key as float)), tv(cast(key as double)), tv(cast(key as bigint)) from src tablesample (1 rows) +POSTHOOK: type: QUERY +STAGE DEPENDENCIES: + Stage-0 is a root stage + +STAGE PLANS: + Stage: Stage-0 + Fetch Operator + limit: -1 + Processor Tree: + TableScan + alias: src + Row Limit Per Split: 1 + Statistics: Num rows: 58 Data size: 5812 Basic stats: COMPLETE Column stats: NONE + Select Operator + expressions: tv(key) (type: string), tv(UDFToShort(key)) (type: smallint), tv(UDFToInteger(key)) (type: int), tv(UDFToFloat(key)) (type: float), tv(UDFToDouble(key)) 
(type: double), tv(UDFToLong(key)) (type: bigint) + outputColumnNames: _col0, _col1, _col2, _col3, _col4, _col5 + Statistics: Num rows: 58 Data size: 5812 Basic stats: COMPLETE Column stats: NONE + ListSink + +PREHOOK: query: select tv(key), tv(cast(key as smallint)), tv(cast(key as int)), tv(cast(key as float)), tv(cast(key as double)), tv(cast(key as bigint)) from src tablesample (1 rows) +PREHOOK: type: QUERY +PREHOOK: Input: default@src +#### A masked pattern was here #### +POSTHOOK: query: select tv(key), tv(cast(key as smallint)), tv(cast(key as int)), tv(cast(key as float)), tv(cast(key as double)), tv(cast(key as bigint)) from src tablesample (1 rows) +POSTHOOK: type: QUERY +POSTHOOK: Input: default@src +#### A masked pattern was here #### +238 238 238 238.0 238.0 238 +PREHOOK: query: explain +select tv(map(value, key)), tv(map(value, cast(key as smallint))), tv(map(value, cast(key as int))), tv(map(value, cast(key as float))), tv(map(value, cast(key as double))),tv(map(value, cast(key as bigint))) from src tablesample (1 rows) +PREHOOK: type: QUERY +POSTHOOK: query: explain +select tv(map(value, key)), tv(map(value, cast(key as smallint))), tv(map(value, cast(key as int))), tv(map(value, cast(key as float))), tv(map(value, cast(key as double))),tv(map(value, cast(key as bigint))) from src tablesample (1 rows) +POSTHOOK: type: QUERY +STAGE DEPENDENCIES: + Stage-0 is a root stage + +STAGE PLANS: + Stage: Stage-0 + Fetch Operator + limit: -1 + Processor Tree: + TableScan + alias: src + Row Limit Per Split: 1 + Statistics: Num rows: 29 Data size: 5812 Basic stats: COMPLETE Column stats: NONE + Select Operator + expressions: tv(map(value:key)) (type: string), tv(map(value:UDFToShort(key))) (type: smallint), tv(map(value:UDFToInteger(key))) (type: int), tv(map(value:UDFToFloat(key))) (type: float), tv(map(value:UDFToDouble(key))) (type: double), tv(map(value:UDFToLong(key))) (type: bigint) + outputColumnNames: _col0, _col1, _col2, _col3, _col4, _col5 + Statistics: Num rows: 29 Data size: 5812 Basic stats: COMPLETE Column stats: NONE + ListSink + +PREHOOK: query: select tv(map(value, key)), tv(map(value, cast(key as smallint))), tv(map(value, cast(key as int))), tv(map(value, cast(key as float))), tv(map(value, cast(key as double))),tv(map(value, cast(key as bigint))) from src tablesample (1 rows) +PREHOOK: type: QUERY +PREHOOK: Input: default@src +#### A masked pattern was here #### +POSTHOOK: query: select tv(map(value, key)), tv(map(value, cast(key as smallint))), tv(map(value, cast(key as int))), tv(map(value, cast(key as float))), tv(map(value, cast(key as double))),tv(map(value, cast(key as bigint))) from src tablesample (1 rows) +POSTHOOK: type: QUERY +POSTHOOK: Input: default@src +#### A masked pattern was here #### +238 238 238 238.0 238.0 238 diff --git serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/ObjectInspectorFactory.java serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/ObjectInspectorFactory.java index 9a226b3..e9249ed 100644 --- serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/ObjectInspectorFactory.java +++ serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/ObjectInspectorFactory.java @@ -22,14 +22,18 @@ import java.lang.reflect.GenericArrayType; import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; +import java.lang.reflect.TypeVariable; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import 
java.util.concurrent.ConcurrentHashMap; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils; /** * ObjectInspectorFactory is the primary way to create new ObjectInspector @@ -64,9 +68,17 @@ public static ObjectInspector getReflectionObjectInspector(Type t, ObjectInspectorOptions options) { + return getReflectionObjectInspector(t, null, options); + } + + public static ObjectInspector getReflectionObjectInspector(Type t, ObjectInspector toi, + ObjectInspectorOptions options) { + if (TypeInfoUtils.containsTypeVariable(t)) { + return getReflectionObjectInspectorNoCache(t, toi, options); + } ObjectInspector oi = objectInspectorCache.get(t); if (oi == null) { - oi = getReflectionObjectInspectorNoCache(t, options); + oi = getReflectionObjectInspectorNoCache(t, toi, options); objectInspectorCache.put(t, oi); } verifyObjectInspector(options, oi, ObjectInspectorOptions.JAVA, new Class[]{ThriftStructObjectInspector.class, @@ -100,31 +112,51 @@ private static void verifyObjectInspector(ObjectInspectorOptions option, ObjectI } } - private static ObjectInspector getReflectionObjectInspectorNoCache(Type t, + private static ObjectInspector getReflectionObjectInspectorNoCache(Type t, ObjectInspector oi, ObjectInspectorOptions options) { + if (t == Object.class) { + return ObjectInspectorUtils.getStandardObjectInspector( + oi, ObjectInspectorUtils.ObjectInspectorCopyOption.JAVA); + } if (t instanceof GenericArrayType) { GenericArrayType at = (GenericArrayType) t; + if (oi instanceof ListObjectInspector) { + oi = ((ListObjectInspector)oi).getListElementObjectInspector(); + } return getStandardListObjectInspector(getReflectionObjectInspector(at - .getGenericComponentType(), options)); + .getGenericComponentType(), oi, options)); } if (t instanceof ParameterizedType) { ParameterizedType pt = (ParameterizedType) t; // List? - if (List.class.isAssignableFrom((Class) pt.getRawType())) { + if (List.class.isAssignableFrom(TypeInfoUtils.getClassFromType(pt.getRawType()))) { + if (oi instanceof ListObjectInspector) { + oi = ((ListObjectInspector)oi).getListElementObjectInspector(); + } return getStandardListObjectInspector(getReflectionObjectInspector(pt - .getActualTypeArguments()[0], options)); + .getActualTypeArguments()[0], oi, options)); } // Map? if (Map.class.isAssignableFrom((Class) pt.getRawType())) { - return getStandardMapObjectInspector(getReflectionObjectInspector(pt - .getActualTypeArguments()[0], options), - getReflectionObjectInspector(pt.getActualTypeArguments()[1], - options)); + ObjectInspector koi = null; + ObjectInspector voi = null; + if (oi instanceof MapObjectInspector) { + koi = ((MapObjectInspector)oi).getMapKeyObjectInspector(); + voi = ((MapObjectInspector)oi).getMapValueObjectInspector(); + } + koi = getReflectionObjectInspector(pt.getActualTypeArguments()[0], koi, options); + voi = getReflectionObjectInspector(pt.getActualTypeArguments()[1], voi, options); + return getStandardMapObjectInspector(koi, voi); } // Otherwise convert t to RawType so we will fall into the following if // block. t = pt.getRawType(); + } else if (t instanceof TypeVariable) { + TypeVariable tv = (TypeVariable) t; + if (tv.getBounds().length == 1) { + return getReflectionObjectInspectorNoCache(tv.getBounds()[0], oi, options); + } } // Must be a class. 
@@ -166,16 +198,16 @@ private static ObjectInspector getReflectionObjectInspectorNoCache(Type t, assert (!Map.class.isAssignableFrom(c)); // Create StructObjectInspector - ReflectionStructObjectInspector oi; + ReflectionStructObjectInspector soi; switch (options) { case JAVA: - oi = new ReflectionStructObjectInspector(); + soi = new ReflectionStructObjectInspector(); break; case THRIFT: - oi = new ThriftStructObjectInspector(); + soi = new ThriftStructObjectInspector(); break; case PROTOCOL_BUFFERS: - oi = new ProtocolBuffersStructObjectInspector(); + soi = new ProtocolBuffersStructObjectInspector(); break; default: throw new RuntimeException(ObjectInspectorFactory.class.getName() @@ -183,18 +215,34 @@ private static ObjectInspector getReflectionObjectInspectorNoCache(Type t, } // put it into the cache BEFORE it is initialized to make sure we can catch // recursive types. - objectInspectorCache.put(t, oi); + if (!TypeInfoUtils.containsTypeVariable(t)) { + objectInspectorCache.put(t, soi); + } Field[] fields = ObjectInspectorUtils.getDeclaredNonStaticFields(c); ArrayList structFieldObjectInspectors = new ArrayList( fields.length); for (int i = 0; i < fields.length; i++) { - if (!oi.shouldIgnoreField(fields[i].getName())) { + Map mapping = toFieldMap(oi); + if (!soi.shouldIgnoreField(fields[i].getName())) { structFieldObjectInspectors.add(getReflectionObjectInspector(fields[i] - .getGenericType(), options)); + .getGenericType(), mapping.get(fields[i].getName()), options)); } } - oi.init(c, structFieldObjectInspectors); - return oi; + soi.init(c, structFieldObjectInspectors); + return soi; + } + + private static Map toFieldMap(ObjectInspector oi) { + if (!(oi instanceof StructObjectInspector)) { + return Collections.emptyMap(); + } + StructObjectInspector soi = (StructObjectInspector) oi; + List fields = soi.getAllStructFieldRefs(); + Map mapping = new HashMap(fields.size()); + for (StructField field : fields) { + mapping.put(field.getFieldName(), field.getFieldObjectInspector()); + } + return mapping; } static ConcurrentHashMap cachedStandardListObjectInspector = diff --git serde/src/java/org/apache/hadoop/hive/serde2/typeinfo/TypeInfoUtils.java serde/src/java/org/apache/hadoop/hive/serde2/typeinfo/TypeInfoUtils.java index 8dffe63..6cdcea5 100644 --- serde/src/java/org/apache/hadoop/hive/serde2/typeinfo/TypeInfoUtils.java +++ serde/src/java/org/apache/hadoop/hive/serde2/typeinfo/TypeInfoUtils.java @@ -22,8 +22,9 @@ import java.lang.reflect.Method; import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; +import java.lang.reflect.TypeVariable; +import java.lang.reflect.WildcardType; import java.util.ArrayList; -import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; @@ -72,17 +73,18 @@ private static TypeInfo getExtendedTypeInfoFromJavaType(Type t, Method m) { return TypeInfoFactory.unknownTypeInfo; } - if (t instanceof ParameterizedType) { + if (t instanceof GenericArrayType) { + GenericArrayType ga = (GenericArrayType)t; + return TypeInfoFactory.getListTypeInfo(getExtendedTypeInfoFromJavaType( + ga.getGenericComponentType(), m)); + } else if (t instanceof ParameterizedType) { ParameterizedType pt = (ParameterizedType) t; - // List? 
- if (List.class == (Class) pt.getRawType() - || ArrayList.class == (Class) pt.getRawType()) { + Class rawType = TypeInfoUtils.getClassFromType(pt.getRawType()); + if (List.class.isAssignableFrom(rawType)) { return TypeInfoFactory.getListTypeInfo(getExtendedTypeInfoFromJavaType( pt.getActualTypeArguments()[0], m)); } - // Map? - if (Map.class == (Class) pt.getRawType() - || HashMap.class == (Class) pt.getRawType()) { + if (Map.class.isAssignableFrom(rawType)) { return TypeInfoFactory.getMapTypeInfo(getExtendedTypeInfoFromJavaType( pt.getActualTypeArguments()[0], m), getExtendedTypeInfoFromJavaType(pt.getActualTypeArguments()[1], m)); @@ -90,6 +92,11 @@ private static TypeInfo getExtendedTypeInfoFromJavaType(Type t, Method m) { // Otherwise convert t to RawType so we will fall into the following if // block. t = pt.getRawType(); + } else if (t instanceof TypeVariable) { + TypeVariable tv = (TypeVariable) t; + if (tv.getBounds().length == 1) { + return getExtendedTypeInfoFromJavaType(tv.getBounds()[0], m); + } } // Must be a class. @@ -135,6 +142,86 @@ private static TypeInfo getExtendedTypeInfoFromJavaType(Type t, Method m) { return TypeInfoFactory.getStructTypeInfo(fieldNames, fieldTypeInfos); } + public static boolean containsTypeVariable(Type t) { + if (t instanceof ParameterizedType) { + ParameterizedType pt = (ParameterizedType) t; + for (Type child : pt.getActualTypeArguments()) { + if (containsTypeVariable(child)) { + return true; + } + } + return false; + } + if (t instanceof GenericArrayType) { + GenericArrayType ga = (GenericArrayType)t; + return containsTypeVariable(ga.getGenericComponentType()); + } + if (t instanceof WildcardType) { + WildcardType wt = (WildcardType) t; + for (Type child : wt.getUpperBounds()) { + if (containsTypeVariable(child)) { + return true; + } + } + for (Type child : wt.getLowerBounds()) { + if (containsTypeVariable(child)) { + return true; + } + } + return false; + } + return t instanceof TypeVariable; + } + + public static Class getClassFromType(Type t) { + if (t instanceof Class) { + return (Class) t; + } + if (t instanceof ParameterizedType) { + ParameterizedType pt = (ParameterizedType) t; + return getClassFromType(pt.getRawType()); + } + if (t instanceof TypeVariable) { + TypeVariable tv = (TypeVariable) t; + if (tv.getBounds().length == 1) { + return getClassFromType(tv.getBounds()[0]); + } + } + return null; + } + + public static ObjectInspector extractOI(TypeVariable tv, Type t, ObjectInspector oi) { + if (t instanceof TypeVariable) { + return tv == t ? 
oi : null; + } + if (t instanceof ParameterizedType) { + ParameterizedType pt = (ParameterizedType) t; + Class rawType = getClassFromType(pt.getRawType()); + if (rawType == null) { + return null; + } + if (List.class.isAssignableFrom(rawType)) { + ObjectInspector eoi = ((ListObjectInspector) oi).getListElementObjectInspector(); + return extractOI(tv, pt.getActualTypeArguments()[0], eoi); + } + if (Map.class.isAssignableFrom(rawType)) { + ObjectInspector koi = ((MapObjectInspector) oi).getMapKeyObjectInspector(); + ObjectInspector voi = ((MapObjectInspector) oi).getMapValueObjectInspector(); + ObjectInspector found = extractOI(tv, pt.getActualTypeArguments()[0], koi); + if (found == null) { + found = extractOI(tv, pt.getActualTypeArguments()[1], voi); + } + return found; + } + } + if (t instanceof GenericArrayType && oi instanceof ListObjectInspector) { + GenericArrayType ga = (GenericArrayType)t; + ObjectInspector eoi = ((ListObjectInspector) oi).getListElementObjectInspector(); + return extractOI(tv, ga.getGenericComponentType(), eoi); + } + return null; + } + /** * Returns the array element type, if the Type is an array (Object[]), or * GenericArrayType (Map[]). Otherwise return null.
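
Note (not part of the patch): the diff above renders Java generics erased — declarations such as List, Map, Class, and the new test class UDFTypeVariable appear without their type parameters. As a hedged sketch of what this change enables, the bridged UDF below mirrors the test UDF added in ql/src/test/org/apache/hadoop/hive/ql/udf/UDFTypeVariable.java; the class name and the exact type parameters (for example Map<?, T>) are assumptions, not taken verbatim from the patch.

import java.util.Map;

import org.apache.hadoop.hive.ql.exec.UDF;

// Sketch of a UDF whose return type is a type variable. With this patch,
// GenericUDFBridge.initialize() detects that the generic return type is a
// TypeVariable and resolves its ObjectInspector from the argument
// ObjectInspectors via TypeInfoUtils.extractOI().
public class UDFTypeVariableSketch<T> extends UDF {

  // T is bound from the map's value ObjectInspector.
  public T evaluate(Map<?, T> arg) {
    return arg == null || arg.isEmpty() ? null : arg.values().iterator().next();
  }

  // T is bound directly from the argument's ObjectInspector.
  public T evaluate(T arg) {
    return arg;
  }
}

This is why the expressions in udf_generics.q.out above show, for instance, tv(map(value:UDFToLong(key))) typed as bigint rather than string: the map's value ObjectInspector binds T, and the return ObjectInspector is built from it instead of falling back to Object/string handling.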