Index: ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java
===================================================================
--- ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java
@@ -161,6 +161,7 @@
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFIf;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFIn;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFIndex;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFInFile;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFInstr;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFLocate;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFMap;
@@ -426,6 +427,7 @@
     registerGenericUDF("hash", GenericUDFHash.class);
     registerGenericUDF("coalesce", GenericUDFCoalesce.class);
     registerGenericUDF("index", GenericUDFIndex.class);
+    registerGenericUDF("in_file", GenericUDFInFile.class);
     registerGenericUDF("instr", GenericUDFInstr.class);
     registerGenericUDF("locate", GenericUDFLocate.class);
     registerGenericUDF("elt", GenericUDFElt.class);
@@ -710,7 +712,7 @@
     GenericUDAFEvaluator udafEvaluator = null;
     ObjectInspector args[] = new ObjectInspector[argumentOIs.size()];
-    // Can't use toArray here because Java is dumb when it comes to 
+    // Can't use toArray here because Java is dumb when it comes to
     // generics + arrays.
     for (int ii = 0; ii < argumentOIs.size(); ++ii) {
       args[ii] = argumentOIs.get(ii);
Index: ql/src/java/org/apache/hadoop/hive/ql/plan/ExprNodeGenericFuncDesc.java
===================================================================
--- ql/src/java/org/apache/hadoop/hive/ql/plan/ExprNodeGenericFuncDesc.java
+++ ql/src/java/org/apache/hadoop/hive/ql/plan/ExprNodeGenericFuncDesc.java
@@ -212,6 +212,25 @@
     }
 
     ObjectInspector oi = genericUDF.initializeAndFoldConstants(childrenOIs);
+
+    String[] requiredJars = genericUDF.getRequiredJars();
+    String[] requiredFiles = genericUDF.getRequiredFiles();
+    SessionState ss = SessionState.get();
+
+    if (requiredJars != null) {
+      SessionState.ResourceType t = SessionState.find_resource_type("JAR");
+      for (String jarPath : requiredJars) {
+        ss.add_resource(t, jarPath);
+      }
+    }
+
+    if (requiredFiles != null) {
+      SessionState.ResourceType t = SessionState.find_resource_type("FILE");
+      for (String filePath : requiredFiles) {
+        ss.add_resource(t, filePath);
+      }
+    }
+
     return new ExprNodeGenericFuncDesc(oi, genericUDF, children);
   }
Index: ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDF.java
===================================================================
--- ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDF.java
+++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDF.java
@@ -75,7 +75,7 @@
   /**
    * Initialize this GenericUDF. This will be called once and only once per
    * GenericUDF instance.
-   * 
+   *
   * @param arguments
    *          The ObjectInspector for the arguments
    * @throws UDFArgumentException
@@ -97,6 +97,13 @@
 
     ObjectInspector oi = initialize(arguments);
 
+    // If the UDF depends on any external resources, we can't fold because the
+    // resources may not be available at compile time.
+    if (getRequiredFiles() != null ||
+        getRequiredJars() != null) {
+      return oi;
+    }
+
     boolean allConstant = true;
     for (int ii = 0; ii < arguments.length; ++ii) {
       if (!ObjectInspectorUtils.isConstantObjectInspector(arguments[ii])) {
@@ -127,6 +134,19 @@
   }
 
   /**
+   * The following two functions can be overridden to automatically include
+   * additional resources required by this UDF. The return types should be
+   * arrays of paths.
+   */
+  public String[] getRequiredJars() {
+    return null;
+  }
+
+  public String[] getRequiredFiles() {
+    return null;
+  }
+
+  /**
    * Evaluate the GenericUDF with the arguments.
    *
    * @param arguments
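For reference, a minimal sketch of how a UDF author would use the new hooks (this class is hypothetical and not part of the patch; the package name and resource path are made up). Paths returned from getRequiredJars()/getRequiredFiles() are registered as session resources by the ExprNodeGenericFuncDesc change above, and initializeAndFoldConstants() skips constant folding for the expression, since the resources may not exist at compile time:

package org.example.hive.udf; // illustrative package, not part of the patch

import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

public class GenericUDFWithSideFile extends GenericUDF {

  @Override
  public String[] getRequiredFiles() {
    // Shipped to every task with ADD FILE semantics; available in the task's
    // working directory under its base name. The path here is made up.
    return new String[] {"/path/to/lookup.txt"};
  }

  @Override
  public ObjectInspector initialize(ObjectInspector[] arguments)
      throws UDFArgumentException {
    return PrimitiveObjectInspectorFactory.javaBooleanObjectInspector;
  }

  @Override
  public Object evaluate(DeferredObject[] arguments) throws HiveException {
    return Boolean.TRUE; // a real UDF would consult the side file here
  }

  @Override
  public String getDisplayString(String[] children) {
    return "udf_with_side_file()";
  }
}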
Index: ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFInFile.java
===================================================================
--- /dev/null
+++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFInFile.java
@@ -0,0 +1,145 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.udf.generic;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileNotFoundException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+
+import java.util.HashSet;
+
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+
+/**
+ * IN_FILE(str, filename) returns true if 'str' appears in the file specified
+ * by 'filename'. A string is considered to be in the file if that string
+ * appears as a line in the file.
+ *
+ * If either argument is NULL then NULL is returned.
+ */
+@Description(name = "in_file",
+    value = "_FUNC_(str, filename) - Returns true if str appears in the file")
+public class GenericUDFInFile extends GenericUDF {
+
+  HashSet<String> set;
+  ObjectInspector strObjectInspector;
+  ObjectInspector fileObjectInspector;
+
+  @Override
+  public ObjectInspector initialize(ObjectInspector[] arguments)
+      throws UDFArgumentException {
+    if (arguments.length != 2) {
+      throw new UDFArgumentLengthException(
+        "IN_FILE() accepts exactly 2 arguments.");
+    }
+
+    for (int i = 0; i < arguments.length; i++) {
+      if (!String.class.equals(PrimitiveObjectInspectorUtils.
+          getJavaPrimitiveClassFromObjectInspector(arguments[i]))) {
+        throw new UDFArgumentTypeException(i, "The " +
+          GenericUDFUtils.getOrdinal(i + 1) +
+          " argument of function IN_FILE must be a string but " +
+          arguments[i].toString() + " was given.");
+      }
+    }
+
+    strObjectInspector = arguments[0];
+    fileObjectInspector = arguments[1];
+
+    if (!ObjectInspectorUtils.isConstantObjectInspector(fileObjectInspector)) {
+      throw new UDFArgumentTypeException(1,
+        "The second argument of IN_FILE() must be a constant string but " +
+        fileObjectInspector.toString() + " was given.");
+    }
+
+    return PrimitiveObjectInspectorFactory.javaBooleanObjectInspector;
+  }
+
+  @Override
+  public String[] getRequiredFiles() {
+    return new String[] {
+      ObjectInspectorUtils.getWritableConstantValue(fileObjectInspector)
+        .toString()
+    };
+  }
+
+  @Override
+  public Object evaluate(DeferredObject[] arguments) throws HiveException {
+    if (arguments[0].get() == null || arguments[1].get() == null) {
+      return null;
+    }
+
+    String str = (String)ObjectInspectorUtils.copyToStandardJavaObject(
+      arguments[0].get(), strObjectInspector);
+
+    if (set == null) {
+      String fileName = (String)ObjectInspectorUtils.copyToStandardJavaObject(
+        arguments[1].get(), fileObjectInspector);
+      try {
+        // The file was shipped to the task as an added resource (see
+        // getRequiredFiles() above), so it is available in the task's
+        // working directory under its base name.
+        load(new FileInputStream((new File(fileName)).getName()));
+      } catch (FileNotFoundException e) {
+        throw new HiveException(e);
+      }
+    }
+
+    return Boolean.valueOf(set.contains(str));
+  }
+
+  /**
+   * Load the file from an InputStream.
+   * @param is The InputStream containing the file data.
+   * @throws HiveException
+   */
+  public void load(InputStream is) throws HiveException {
+    BufferedReader reader =
+      new BufferedReader(new InputStreamReader(is));
+
+    set = new HashSet<String>();
+
+    try {
+      String line;
+      while ((line = reader.readLine()) != null) {
+        set.add(line);
+      }
+    } catch (Exception e) {
+      throw new HiveException(e);
+    }
+  }
+
+  @Override
+  public String getDisplayString(String[] children) {
+    assert (children.length == 2);
+    return "in_file(" + children[0] + ", " + children[1] + ")";
+  }
+}
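Because load(InputStream) is public, the line-set semantics can be exercised without a filesystem. A minimal sketch (hypothetical, not part of the patch; it assumes compilation in the same package so the package-private 'set' field is visible):

package org.apache.hadoop.hive.ql.udf.generic;

import java.io.ByteArrayInputStream;

import org.apache.hadoop.hive.ql.metadata.HiveException;

public class GenericUDFInFileSketch {
  public static void main(String[] args) throws HiveException {
    GenericUDFInFile udf = new GenericUDFInFile();
    // Each line of the stream becomes one entry in the lookup set.
    udf.load(new ByteArrayInputStream("303\n304\n".getBytes()));
    System.out.println(udf.set.contains("303")); // true
    System.out.println(udf.set.contains("305")); // false
  }
}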
Index: ql/src/test/queries/clientpositive/udf_in_file.q
===================================================================
--- /dev/null
+++ ql/src/test/queries/clientpositive/udf_in_file.q
@@ -0,0 +1,12 @@
+DESCRIBE FUNCTION in_file;
+
+EXPLAIN
+SELECT in_file("303", "../data/files/test2.dat"),
+       in_file("304", "../data/files/test2.dat"),
+       in_file(CAST(NULL AS STRING), "../data/files/test2.dat")
+FROM src LIMIT 1;
+
+SELECT in_file("303", "../data/files/test2.dat"),
+       in_file("304", "../data/files/test2.dat"),
+       in_file(CAST(NULL AS STRING), "../data/files/test2.dat")
+FROM src LIMIT 1;
Index: ql/src/test/results/clientpositive/show_functions.q.out
===================================================================
--- ql/src/test/results/clientpositive/show_functions.q.out
+++ ql/src/test/results/clientpositive/show_functions.q.out
@@ -76,6 +76,7 @@
 hour
 if
 in
+in_file
 index
 instr
 int
Index: ql/src/test/results/clientpositive/udf_in_file.q.out
===================================================================
--- /dev/null
+++ ql/src/test/results/clientpositive/udf_in_file.q.out
@@ -0,0 +1,68 @@
+PREHOOK: query: DESCRIBE FUNCTION in_file
+PREHOOK: type: DESCFUNCTION
+POSTHOOK: query: DESCRIBE FUNCTION in_file
+POSTHOOK: type: DESCFUNCTION
+in_file(str, filename) - Returns true if str appears in the file
+PREHOOK: query: EXPLAIN
+SELECT in_file("303", "../data/files/test2.dat"),
+       in_file("304", "../data/files/test2.dat"),
+       in_file(CAST(NULL AS STRING), "../data/files/test2.dat")
+FROM src LIMIT 1
+PREHOOK: type: QUERY
+POSTHOOK: query: EXPLAIN
+SELECT in_file("303", "../data/files/test2.dat"),
+       in_file("304", "../data/files/test2.dat"),
+       in_file(CAST(NULL AS STRING), "../data/files/test2.dat")
+FROM src LIMIT 1
+POSTHOOK: type: QUERY
+ABSTRACT SYNTAX TREE:
+  (TOK_QUERY (TOK_FROM (TOK_TABREF (TOK_TABNAME src))) (TOK_INSERT (TOK_DESTINATION (TOK_DIR TOK_TMP_FILE)) (TOK_SELECT (TOK_SELEXPR (TOK_FUNCTION in_file "303" "../data/files/test2.dat")) (TOK_SELEXPR (TOK_FUNCTION in_file "304" "../data/files/test2.dat")) (TOK_SELEXPR (TOK_FUNCTION in_file (TOK_FUNCTION TOK_STRING TOK_NULL) "../data/files/test2.dat"))) (TOK_LIMIT 1)))
+
+STAGE DEPENDENCIES:
+  Stage-1 is a root stage
+  Stage-0 is a root stage
+
+STAGE PLANS:
+  Stage: Stage-1
+    Map Reduce
+      Alias -> Map Operator Tree:
+        src 
+          TableScan
+            alias: src
+            Select Operator
+              expressions:
+                    expr: in_file('303', '../data/files/test2.dat')
+                    type: boolean
+                    expr: in_file('304', '../data/files/test2.dat')
+                    type: boolean
+                    expr: in_file(UDFToString(null), '../data/files/test2.dat')
+                    type: boolean
+              outputColumnNames: _col0, _col1, _col2
+              Limit
+                File Output Operator
+                  compressed: false
+                  GlobalTableId: 0
+                  table:
+                      input format: org.apache.hadoop.mapred.TextInputFormat
+                      output format: org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat
+
+  Stage: Stage-0
+    Fetch Operator
+      limit: 1
+
+
+PREHOOK: query: SELECT in_file("303", "../data/files/test2.dat"),
+       in_file("304", "../data/files/test2.dat"),
+       in_file(CAST(NULL AS STRING), "../data/files/test2.dat")
+FROM src LIMIT 1
+PREHOOK: type: QUERY
+PREHOOK: Input: default@src
+PREHOOK: Output: file:/var/folders/71/h_j6fpg10r33hvx1lcxlgttcw61_4s/T/jonchang/hive_2011-11-20_01-14-47_787_473222432824773238/-mr-10000
+POSTHOOK: query: SELECT in_file("303", "../data/files/test2.dat"),
+       in_file("304", "../data/files/test2.dat"),
+       in_file(CAST(NULL AS STRING), "../data/files/test2.dat")
+FROM src LIMIT 1
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@src
+POSTHOOK: Output: file:/var/folders/71/h_j6fpg10r33hvx1lcxlgttcw61_4s/T/jonchang/hive_2011-11-20_01-14-47_787_473222432824773238/-mr-10000
+true	false	NULL
"../data/files/test2.dat") +FROM src LIMIT 1 +PREHOOK: type: QUERY +POSTHOOK: query: EXPLAIN +SELECT in_file("303", "../data/files/test2.dat"), + in_file("304", "../data/files/test2.dat"), + in_file(CAST(NULL AS STRING), "../data/files/test2.dat") +FROM src LIMIT 1 +POSTHOOK: type: QUERY +ABSTRACT SYNTAX TREE: + (TOK_QUERY (TOK_FROM (TOK_TABREF (TOK_TABNAME src))) (TOK_INSERT (TOK_DESTINATION (TOK_DIR TOK_TMP_FILE)) (TOK_SELECT (TOK_SELEXPR (TOK_FUNCTION in_file "303" "../data/files/test2.dat")) (TOK_SELEXPR (TOK_FUNCTION in_file "304" "../data/files/test2.dat")) (TOK_SELEXPR (TOK_FUNCTION in_file (TOK_FUNCTION TOK_STRING TOK_NULL) "../data/files/test2.dat"))) (TOK_LIMIT 1))) + +STAGE DEPENDENCIES: + Stage-1 is a root stage + Stage-0 is a root stage + +STAGE PLANS: + Stage: Stage-1 + Map Reduce + Alias -> Map Operator Tree: + src + TableScan + alias: src + Select Operator + expressions: + expr: in_file('303', '../data/files/test2.dat') + type: boolean + expr: in_file('304', '../data/files/test2.dat') + type: boolean + expr: in_file(UDFToString(null), '../data/files/test2.dat') + type: boolean + outputColumnNames: _col0, _col1, _col2 + Limit + File Output Operator + compressed: false + GlobalTableId: 0 + table: + input format: org.apache.hadoop.mapred.TextInputFormat + output format: org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat + + Stage: Stage-0 + Fetch Operator + limit: 1 + + +PREHOOK: query: SELECT in_file("303", "../data/files/test2.dat"), + in_file("304", "../data/files/test2.dat"), + in_file(CAST(NULL AS STRING), "../data/files/test2.dat") +FROM src LIMIT 1 +PREHOOK: type: QUERY +PREHOOK: Input: default@src +PREHOOK: Output: file:/var/folders/71/h_j6fpg10r33hvx1lcxlgttcw61_4s/T/jonchang/hive_2011-11-20_01-14-47_787_473222432824773238/-mr-10000 +POSTHOOK: query: SELECT in_file("303", "../data/files/test2.dat"), + in_file("304", "../data/files/test2.dat"), + in_file(CAST(NULL AS STRING), "../data/files/test2.dat") +FROM src LIMIT 1 +POSTHOOK: type: QUERY +POSTHOOK: Input: default@src +POSTHOOK: Output: file:/var/folders/71/h_j6fpg10r33hvx1lcxlgttcw61_4s/T/jonchang/hive_2011-11-20_01-14-47_787_473222432824773238/-mr-10000 +true false NULL Index: serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/ObjectInspectorUtils.java =================================================================== --- serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/ObjectInspectorUtils.java +++ serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/ObjectInspectorUtils.java @@ -208,6 +208,9 @@ return copyToStandardObject(o, oi, ObjectInspectorCopyOption.DEFAULT); } + public static Object copyToStandardJavaObject(Object o, ObjectInspector oi) { + return copyToStandardObject(o, oi, ObjectInspectorCopyOption.JAVA); + } public static Object copyToStandardObject(Object o, ObjectInspector oi, ObjectInspectorCopyOption objectInspectorOption) { @@ -490,7 +493,7 @@ ObjectInspector valueOI = mapOI.getMapValueObjectInspector(); Map map = mapOI.getMap(o); for (Map.Entry entry : map.entrySet()) { - r += hashCode(entry.getKey(), keyOI) ^ + r += hashCode(entry.getKey(), keyOI) ^ hashCode(entry.getValue(), valueOI); } return r; @@ -564,7 +567,7 @@ ObjectInspector oi2) { return compare(o1, oi1, o2, oi2, new FullMapEqualComparer()); } - + /** * Compare two objects with their respective ObjectInspectors. 
Index: serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/primitive/PrimitiveObjectInspectorUtils.java
===================================================================
--- serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/primitive/PrimitiveObjectInspectorUtils.java
+++ serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/primitive/PrimitiveObjectInspectorUtils.java
@@ -33,6 +33,8 @@
 import org.apache.hadoop.hive.serde2.lazy.ByteArrayRef;
 import org.apache.hadoop.hive.serde2.lazy.LazyInteger;
 import org.apache.hadoop.hive.serde2.lazy.LazyLong;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory;
 import org.apache.hadoop.io.BooleanWritable;
@@ -796,6 +798,16 @@
     return result;
   }
 
+  public static Class<?> getJavaPrimitiveClassFromObjectInspector(ObjectInspector oi) {
+    if (oi.getCategory() != Category.PRIMITIVE) {
+      return null;
+    }
+    PrimitiveObjectInspector poi = (PrimitiveObjectInspector)oi;
+    PrimitiveTypeEntry t =
+      getTypeEntryFromPrimitiveCategory(poi.getPrimitiveCategory());
+    return t == null ? null : t.primitiveJavaClass;
+  }
+
   private PrimitiveObjectInspectorUtils() {
     // prevent instantiation
   }
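getJavaPrimitiveClassFromObjectInspector() maps any primitive ObjectInspector to the boxed Java class of its type and returns null for non-primitive categories, so IN_FILE's argument check works no matter how the string values are physically represented. An illustrative sketch (not part of the patch):

import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;

public class PrimitiveClassSketch {
  public static void main(String[] args) {
    // Both a java-backed and a writable-backed string OI map to String.class.
    System.out.println(PrimitiveObjectInspectorUtils
        .getJavaPrimitiveClassFromObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector));
    System.out.println(PrimitiveObjectInspectorUtils
        .getJavaPrimitiveClassFromObjectInspector(
            PrimitiveObjectInspectorFactory.writableStringObjectInspector));
    // A non-string OI yields a different class (here java.lang.Integer),
    // which IN_FILE's initialize() rejects.
    System.out.println(PrimitiveObjectInspectorUtils
        .getJavaPrimitiveClassFromObjectInspector(
            PrimitiveObjectInspectorFactory.javaIntObjectInspector));
  }
}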