diff --git c/ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java w/ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java index d59bf1fb6e62a821d79c7d4040bdfb7b43cd3ba5..dc1017861ce978e7c12539e41040d23bbf6f6f1e 100644 --- c/ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java +++ w/ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java @@ -471,6 +471,7 @@ system.registerGenericUDF("map", GenericUDFMap.class); system.registerGenericUDF("struct", GenericUDFStruct.class); system.registerGenericUDF("named_struct", GenericUDFNamedStruct.class); + system.registerGenericUDF("generic_project", GenericProject.class); system.registerGenericUDF("create_union", GenericUDFUnion.class); system.registerGenericUDF("extract_union", GenericUDFExtractUnion.class); diff --git c/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericProject.java w/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericProject.java new file mode 100755 index 0000000000000000000000000000000000000000..717333d363037f4ca4f67c7f95896f93e726975a --- /dev/null +++ w/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericProject.java @@ -0,0 +1,253 @@ +package org.apache.hadoop.hive.ql.udf.generic; + +import com.google.common.base.Preconditions; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructField; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableConstantStringObjectInspector; +import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + + +/** + * Takes a input schema and an object and project only the fields specified by the schema + */ +@Description( + name = "GenericProject", + value = "_FUNC_(struct|array|map|primitive, schema_to_project_as_string) -" + + "Takes a input schema and an object and project only the fields specified by the schema") +public class GenericProject extends GenericUDF { + private static final Logger LOG = LoggerFactory.getLogger(GenericProject.class); + + private ObjectInspector _inputOI; + private ObjectInspector _outputOI; + + @Override + public String getDisplayString(String[] children) { + StringBuilder sb = new StringBuilder(); + sb.append("GenericProject("); + if (children != null && children.length > 0) { + sb.append(children[0]); + for (int i = 1; i < children.length; i++) { + sb.append(","); + sb.append(children[i]); + } + } + sb.append(")"); + return sb.toString(); + } + + @Override + public ObjectInspector initialize(ObjectInspector[] args) + throws UDFArgumentException { + if (args.length != 2) { + throw new UDFArgumentLengthException( + GenericProject.class.getSimpleName() + + " takes two arguments: (obj: struct|array|map|prmitive, schema_to_project_as_string)"); + } + + _inputOI = args[0]; + final TypeInfo schemaToProject = getSchemaToProject(args[1]); + // Instead of calling + // TypeInfoUtils.getStandardJavaObjectInspectorFromTypeInfo() or + // TypeInfoUtils.getStandardWritableObjectInspectorFromTypeInfo() + // we traverse the input object inspector and construct + // our output objectinspector which has the same primitive + // type (writable/java etc) corresponding to the input + // object inspector. + _outputOI = outputObjectInspector(_inputOI, schemaToProject); + return _outputOI; + } + + private static TypeInfo getSchemaToProject(ObjectInspector oi) throws UDFArgumentException { + if (!(oi instanceof WritableConstantStringObjectInspector)) { + throw new UDFArgumentException("Second argument must be a constant string. Unexpected ObjectInspector: " + oi); + } + final String typeInfoStr = ((WritableConstantStringObjectInspector) oi).getWritableConstantValue().toString(); + try { + return TypeInfoUtils.getTypeInfoFromTypeString(typeInfoStr); + } catch (Exception e) { + throw new UDFArgumentException( + "Second argument must be parse-able into Hive's TypeInfo. Received: " + typeInfoStr); + } + } + + @Override + public Object evaluate(DeferredObject[] arguments) + throws HiveException { + final Object obj = arguments[0].get(); + return project(obj, _inputOI, _outputOI); + } + + private static Object project(Object obj, ObjectInspector inputOI, ObjectInspector outputOI) { + if (obj == null) { + return null; + } + if (inputOI.equals(outputOI)) { + return obj; + } + switch (outputOI.getCategory()) { + case LIST: + Preconditions.checkArgument( + inputOI.getCategory() == ObjectInspector.Category.LIST, + "Input Object inspector: " + inputOI.getCategory() + "not of type list"); + return projectList(obj, (ListObjectInspector) inputOI, (ListObjectInspector) outputOI); + + case MAP: + Preconditions.checkArgument( + inputOI.getCategory() == ObjectInspector.Category.MAP, + "Input Object inspector: " + inputOI.getCategory() + "not of type map"); + return projectMap(obj, (MapObjectInspector) inputOI, (MapObjectInspector) outputOI); + + case STRUCT: + Preconditions.checkArgument( + inputOI.getCategory() == ObjectInspector.Category.STRUCT, + "Input Object inspector: " + inputOI.getCategory() + "not of type struct"); + return projectStruct(obj, (StructObjectInspector) inputOI, (StructObjectInspector) outputOI); + + case PRIMITIVE: + if (LOG.isDebugEnabled()) { + LOG.debug("Primitive input and output object inspectors are different. " + + "Converting from " + inputOI + " to " + outputOI); + } + return ObjectInspectorConverters.getConverter(inputOI, outputOI).convert(obj); + + default: + throw new UnsupportedOperationException(outputOI.getCategory() + " not supported"); + } + } + + private static Object projectStruct( + Object obj, + StructObjectInspector inputOI, StructObjectInspector outputOI) { + final List outputFieldRefs = outputOI.getAllStructFieldRefs(); + final List r = new ArrayList(outputFieldRefs.size()); + for (StructField ofr : outputFieldRefs) { + final StructField ifr = inputOI.getStructFieldRef(ofr.getFieldName()); + final Object data = inputOI.getStructFieldData(obj, ifr); + final Object projectedData = project(data, ifr.getFieldObjectInspector(), ofr.getFieldObjectInspector()); + r.add(projectedData); + } + return r; + } + + private static Object projectMap( + final Object obj, + final MapObjectInspector inputOI, final MapObjectInspector outputOI) { + + final ObjectInspector mapKeyInputOI = inputOI.getMapKeyObjectInspector(); + final ObjectInspector mapValInputOI = inputOI.getMapValueObjectInspector(); + + final ObjectInspector mapKeyOutputOI = outputOI.getMapKeyObjectInspector(); + final ObjectInspector mapValOutputOI = outputOI.getMapValueObjectInspector(); + + final Map m = inputOI.getMap(obj); + final Map r = new HashMap(); + for (Map.Entry entry : m.entrySet()) { + r.put( + project(entry.getKey(), mapKeyInputOI, mapKeyOutputOI), + project(entry.getValue(), mapValInputOI, mapValOutputOI)); + } + return r; + } + + private static Object projectList( + final Object obj, + final ListObjectInspector inputOI, final ListObjectInspector outputOI) { + final List l = inputOI.getList(obj); + final List r = new ArrayList(l.size()); + for (Object e : l) { + final Object o = project(e, inputOI.getListElementObjectInspector(), outputOI.getListElementObjectInspector()); + r.add(o); + } + return r; + } + + /** + * Builds an outputObjectInspector corresponding to the schema 'outputSchema' + * @param inputOI Required to traverse the input objectinspector tree to figure out + * whether to use primitive java or writable types for corresponding + * primitive types in input objectinspector + * @param outputSchema the schema for which to construct the output objectinspector for + */ + private static ObjectInspector outputObjectInspector(ObjectInspector inputOI, TypeInfo outputSchema) { + switch (outputSchema.getCategory()) { + case STRUCT: + return structOutputObjectInspector((StructObjectInspector) inputOI, (StructTypeInfo) outputSchema); + case LIST: + return listOutputObjectInspector((ListObjectInspector) inputOI, (ListTypeInfo) outputSchema); + case MAP: + return mapOutputObjectInspector((MapObjectInspector) inputOI, (MapTypeInfo) outputSchema); + case PRIMITIVE: + return primitiveOutputObjectInspector((PrimitiveObjectInspector) inputOI, outputSchema); + default: + throw new UnsupportedOperationException("Unsupported schema: " + outputSchema); + } + } + + private static ObjectInspector structOutputObjectInspector( + StructObjectInspector inputOI, StructTypeInfo outputSchema) { + final List outputFieldNames = outputSchema.getAllStructFieldNames(); + final List outputFieldTIs = outputSchema.getAllStructFieldTypeInfos(); + final List outputFieldOIs = new ArrayList(outputFieldNames.size()); + for (int i = 0; i < outputFieldNames.size(); i++) { + final String outputFieldName = outputFieldNames.get(i); + final TypeInfo outputFieldTI = outputFieldTIs.get(i); + // 'outputFieldName' should always be present in 'inputOI' + final StructField fieldRef = inputOI.getStructFieldRef(outputFieldName); + outputFieldOIs.add( + outputObjectInspector( + fieldRef.getFieldObjectInspector(), + outputFieldTI)); + } + return ObjectInspectorFactory.getStandardStructObjectInspector( + outputFieldNames, outputFieldOIs); + } + + private static ObjectInspector listOutputObjectInspector(ListObjectInspector inputOI, ListTypeInfo outputSchema) { + final ObjectInspector elementOI = inputOI.getListElementObjectInspector(); + final TypeInfo elementTI = outputSchema.getListElementTypeInfo(); + return ObjectInspectorFactory.getStandardListObjectInspector( + outputObjectInspector(elementOI, elementTI)); + } + + private static ObjectInspector mapOutputObjectInspector(MapObjectInspector inputOI, MapTypeInfo outputSchema) { + final ObjectInspector mapKeyOI = inputOI.getMapKeyObjectInspector(); + final ObjectInspector mapValOI = inputOI.getMapValueObjectInspector(); + + final TypeInfo mapKeyTI = outputSchema.getMapKeyTypeInfo(); + final TypeInfo mapValTI = outputSchema.getMapValueTypeInfo(); + + return ObjectInspectorFactory.getStandardMapObjectInspector( + outputObjectInspector(mapKeyOI, mapKeyTI), + outputObjectInspector(mapValOI, mapValTI)); + } + + private static ObjectInspector primitiveOutputObjectInspector( + PrimitiveObjectInspector inputOI, TypeInfo outputSchema) { + if (inputOI.preferWritable()) { + return TypeInfoUtils.getStandardWritableObjectInspectorFromTypeInfo(outputSchema); + } else { + return TypeInfoUtils.getStandardJavaObjectInspectorFromTypeInfo(outputSchema); + } + } +} + diff --git c/ql/src/test/org/apache/hadoop/hive/ql/udf/generic/TestGenericProject.java w/ql/src/test/org/apache/hadoop/hive/ql/udf/generic/TestGenericProject.java new file mode 100755 index 0000000000000000000000000000000000000000..75c1dc7b292069294bac29378f7884773e771793 --- /dev/null +++ w/ql/src/test/org/apache/hadoop/hive/ql/udf/generic/TestGenericProject.java @@ -0,0 +1,271 @@ +package org.apache.hadoop.hive.ql.udf.generic; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + + + public class TestGenericProject { + private GenericProject _udf; + + @Before + public void setup() { + _udf = new GenericProject(); + } + + @Test + public void testConversion() throws Exception { + ObjectInspector inputOI = javaObjectInspector("struct"); + ObjectInspector outputStringOI = PrimitiveObjectInspectorFactory + .getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.stringTypeInfo, new Text("struct")); + + Object input = struct(2); + Object output = runUDF(input, new ObjectInspector[]{inputOI, outputStringOI}, "struct"); + Assert.assertEquals(struct(2.0d), output); + + /// test map key, val conversion + inputOI = javaObjectInspector("struct>"); + outputStringOI = PrimitiveObjectInspectorFactory + .getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.stringTypeInfo, new Text("struct>")); + + input = struct(map(1,2)); + output = runUDF(input, new ObjectInspector[]{inputOI, outputStringOI}, "struct>"); + Assert.assertEquals(struct(map(1.0d,2.0d)), output); + } + + @Test + public void testNestedProjections() throws Exception { + /** + * actual + * + * struct< + * a: array< + * struct< + * b: struct< + * c: int + * d: string>>> + * + * projected + * + * struct< + * a: array< + * struct< + * b: struct< + * c: int>>>> + * + */ + final ObjectInspector inputOI = javaObjectInspector("struct>>>"); + final ObjectInspector outputStringOI = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector + (TypeInfoFactory.stringTypeInfo, new Text("struct>>>")); + + Object input = struct(array(struct(struct(1, "D")), struct(struct(2, "E")))); + final Object output = runUDF( + input, new ObjectInspector[]{inputOI, outputStringOI}, "struct>>>"); + Assert.assertEquals(struct(array(struct(struct(1)), struct(struct(2)))), output); + } + + @Test + public void noProjections() throws Exception { + final String inputSchema = "struct>>>"; + final ObjectInspector inputOI = writableObjectInspector(inputSchema); + final ObjectInspector outputStringOI = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector + (TypeInfoFactory.stringTypeInfo, new Text(inputSchema)); + + Object input = + struct( + array( + struct( + struct(new IntWritable(1), new Text("D"))), + struct( + struct(new IntWritable(2), new Text("E"))))); + + final Object output = runUDF( + input, new ObjectInspector[]{inputOI, outputStringOI}, inputSchema); + Assert.assertTrue(input == output); + } + + @Test + public void multipleProjections() throws Exception { + /** + * actual + * + * array< + * struct< + * b: struct< + * c: int + * d: struct< + * e: int + * f: struct>>>> + * + * projected + * + * array< + * struct< + * b: struct< + * d: struct< + * e: int>>>> + * + */ + final ObjectInspector inputOI = javaObjectInspector( + "array>>>>"); + final String outputSchema = "array>>>"; + final ObjectInspector outputStringOI = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector + (TypeInfoFactory.stringTypeInfo, new Text(outputSchema)); + + Object input = array( + struct(struct(1, struct(2, struct(3)))), + struct(struct(4, struct(5, struct(6))))); + + final Object output = runUDF(input, new ObjectInspector[]{inputOI, outputStringOI}, outputSchema); + Assert.assertEquals( + array(struct(struct(struct(2))), + struct(struct(struct(5)))), + output); + } + + @Test + public void minimalObjectCopy() throws Exception { + /** + * actual: + * struct< + * a: struct< + * b: struct< + * c: array, + * d: struct< + * e: int>>> + * f: struct + * + * projected: + * struct< + * a: struct< + * b: struct< + * d: struct< ; d is not copied + * e: int>>> + * f: struct ; f is not copied + */ + final ObjectInspector inputOI = + writableObjectInspector("struct,d:struct>>,f:struct>"); + final String outputSchema = "struct>>,f:struct>"; + final ObjectInspector outputStringOI = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector + (TypeInfoFactory.stringTypeInfo, new Text(outputSchema)); + + Object c = array(new IntWritable(1), new IntWritable(1), new IntWritable(1), new IntWritable(1)); + Object d = struct(new IntWritable(2)); + Object b = struct(c, d); + Object a = struct(b); + Object f = struct(new IntWritable(3)); + Object input = struct(a, f); + + final Object output = runUDF(input, new ObjectInspector[]{inputOI, outputStringOI}, outputSchema); + + Object bUpdated = struct(d); + Object aUpdated = struct(bUpdated); + Object expectedOutput = struct(aUpdated, f); + Assert.assertEquals(expectedOutput, output); + + final List>>> outputData = (List>>>) output; + + Assert.assertTrue(outputData.get(1) == f); + Assert.assertTrue(outputData.get(0).get(0).get(0) == d); + } + + @Test + public void testNullData() throws Exception { + final ObjectInspector inputOI = + javaObjectInspector("struct,b:struct>>>"); + final String projectedSchema = "struct,b:struct>>>"; + final ObjectInspector projectedSchemaAsStrOI = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.stringTypeInfo, + new Text(projectedSchema)); + + Object input = null; + Object output = runUDF( + input, new ObjectInspector[]{inputOI, projectedSchemaAsStrOI}, projectedSchema); + Assert.assertEquals(null, output); + + // field 'a' is null, 'c' is null, 'd' is null + input = struct(null, struct(null, null)); + output = runUDF( + input, new ObjectInspector[]{inputOI, projectedSchemaAsStrOI}, projectedSchema); + // projected output does not contain 'c' + Assert.assertEquals(struct(null, struct(new Map[]{null})), output); + + // an element in field 'a' is null, a few map entries are null + input = struct( + array(1, null, 3), + struct(4, + map("a", struct(5.0d), + "b", null, + null, null, + "d", struct(new Double[]{null})))); + + output = runUDF( + input, new ObjectInspector[]{inputOI, projectedSchemaAsStrOI}, projectedSchema); + Assert.assertEquals( + struct( + array(1.0d, null, 3.0d), + struct( + map("a", struct(5.0d), + "b", null, + null, null, + "d", struct(new Double[]{null})))), + output + ); + } + + private Object runUDF( + Object input, ObjectInspector[] inputOIs, String outputSchema) throws Exception { + GenericUDF.DeferredObject deferredObject = new GenericUDF.DeferredJavaObject(input); + final ObjectInspector outputOI = _udf.initialize(inputOIs); + final String outputTypeInfoStr = TypeInfoUtils.getTypeInfoFromObjectInspector(outputOI).toString(); + Assert.assertEquals(outputSchema, outputTypeInfoStr); + return _udf.evaluate(new GenericUDF.DeferredObject[]{deferredObject}); + } + + private ObjectInspector javaObjectInspector(String typeInfoStr) { + return TypeInfoUtils.getStandardJavaObjectInspectorFromTypeInfo( + TypeInfoUtils.getTypeInfoFromTypeString(typeInfoStr)); + } + + private ObjectInspector writableObjectInspector(String typeInfoStr) { + return TypeInfoUtils.getStandardWritableObjectInspectorFromTypeInfo( + TypeInfoUtils.getTypeInfoFromTypeString(typeInfoStr)); + } + + private static Object struct(Object... fields) { + final List r = new ArrayList(); + for (Object f : fields) { + r.add(f); + } + return r; + } + + private Object array(Object... fields) { + final List r = new ArrayList(); + for (Object f : fields) { + r.add(f); + } + return r; + } + + private Object map(Object... fields) { + Map m = new HashMap(); + for (int i = 0; i < fields.length; i += 2) { + m.put(fields[i], fields[i + 1]); + } + return m; + } +}