Index: ql/src/test/results/clientpositive/udaf_argmax.q.out =================================================================== --- ql/src/test/results/clientpositive/udaf_argmax.q.out (revision 0) +++ ql/src/test/results/clientpositive/udaf_argmax.q.out (revision 0) @@ -0,0 +1,75 @@ +PREHOOK: query: DROP TABLE covar_tab +PREHOOK: type: DROPTABLE +POSTHOOK: query: DROP TABLE covar_tab +POSTHOOK: type: DROPTABLE +PREHOOK: query: CREATE TABLE covar_tab (a INT, b INT, c INT) +ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' +STORED AS TEXTFILE +PREHOOK: type: CREATETABLE +POSTHOOK: query: CREATE TABLE covar_tab (a INT, b INT, c INT) +ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' +STORED AS TEXTFILE +POSTHOOK: type: CREATETABLE +POSTHOOK: Output: default@covar_tab +PREHOOK: query: LOAD DATA LOCAL INPATH '../data/files/covar_tab.txt' OVERWRITE +INTO TABLE covar_tab +PREHOOK: type: LOAD +POSTHOOK: query: LOAD DATA LOCAL INPATH '../data/files/covar_tab.txt' OVERWRITE +INTO TABLE covar_tab +POSTHOOK: type: LOAD +POSTHOOK: Output: default@covar_tab +PREHOOK: query: DESCRIBE FUNCTION argmax +PREHOOK: type: DESCFUNCTION +POSTHOOK: query: DESCRIBE FUNCTION argmax +POSTHOOK: type: DESCFUNCTION +argmax(x,y) - Returns the value y that maximizes x. +PREHOOK: query: DESCRIBE FUNCTION EXTENDED argmax +PREHOOK: type: DESCFUNCTION +POSTHOOK: query: DESCRIBE FUNCTION EXTENDED argmax +POSTHOOK: type: DESCFUNCTION +argmax(x,y) - Returns the value y that maximizes x. +The function takes as arguments any pair (x,y) where x can be of any type that +supports comparison, and returns the value y that maximizes x. Any pair with +x=NULL will be ignored. If the function is applied to an empty set, NULL will +be returnd. If more than one value of y maximize x, the function returns one +of them arbitrarily. +PREHOOK: query: SELECT argmax(b, c) FROM covar_tab WHERE a < 1 +PREHOOK: type: QUERY +PREHOOK: Input: default@covar_tab +PREHOOK: Output: file:/tmp/nhuyn/hive_2010-08-27_21-11-01_507_6691310769026128174/-mr-10000 +POSTHOOK: query: SELECT argmax(b, c) FROM covar_tab WHERE a < 1 +POSTHOOK: type: QUERY +POSTHOOK: Input: default@covar_tab +POSTHOOK: Output: file:/tmp/nhuyn/hive_2010-08-27_21-11-01_507_6691310769026128174/-mr-10000 +NULL +PREHOOK: query: SELECT argmax(b, c) FROM covar_tab WHERE a < 2 +PREHOOK: type: QUERY +PREHOOK: Input: default@covar_tab +PREHOOK: Output: file:/tmp/nhuyn/hive_2010-08-27_21-11-05_185_6909630609225918573/-mr-10000 +POSTHOOK: query: SELECT argmax(b, c) FROM covar_tab WHERE a < 2 +POSTHOOK: type: QUERY +POSTHOOK: Input: default@covar_tab +POSTHOOK: Output: file:/tmp/nhuyn/hive_2010-08-27_21-11-05_185_6909630609225918573/-mr-10000 +NULL +PREHOOK: query: SELECT * FROM covar_tab +PREHOOK: type: QUERY +PREHOOK: Input: default@covar_tab +PREHOOK: Output: file:/tmp/nhuyn/hive_2010-08-27_21-11-08_447_6687040914262005732/-mr-10000 +POSTHOOK: query: SELECT * FROM covar_tab +POSTHOOK: type: QUERY +POSTHOOK: Input: default@covar_tab +POSTHOOK: Output: file:/tmp/nhuyn/hive_2010-08-27_21-11-08_447_6687040914262005732/-mr-10000 +1 NULL 15 +2 3 NULL +3 7 12 +4 4 14 +5 8 17 +6 2 11 +PREHOOK: query: DROP TABLE covar_tab +PREHOOK: type: DROPTABLE +PREHOOK: Input: default@covar_tab +PREHOOK: Output: default@covar_tab +POSTHOOK: query: DROP TABLE covar_tab +POSTHOOK: type: DROPTABLE +POSTHOOK: Input: default@covar_tab +POSTHOOK: Output: default@covar_tab Index: ql/src/test/queries/clientpositive/udaf_argmax.q =================================================================== --- ql/src/test/queries/clientpositive/udaf_argmax.q (revision 0) +++ ql/src/test/queries/clientpositive/udaf_argmax.q (revision 0) @@ -0,0 +1,17 @@ +DROP TABLE covar_tab; +CREATE TABLE covar_tab (a INT, b INT, c INT) +ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' +STORED AS TEXTFILE; +LOAD DATA LOCAL INPATH '../data/files/covar_tab.txt' OVERWRITE +INTO TABLE covar_tab; + +DESCRIBE FUNCTION argmax; +DESCRIBE FUNCTION EXTENDED argmax; +SELECT argmax(b, c) FROM covar_tab WHERE a < 1; +SELECT argmax(b, c) FROM covar_tab WHERE a < 2; +SELECT argmax(b, c) FROM covar_tab; +SELECT a, argmax(b, c) FROM covar_tab GROUP BY a ORDER BY a; +SELECT argmax(b, 'anything') FROM covar_tab; +SELECT argmax('anycomparable', 'anything') FROM covar_tab; + +DROP TABLE covar_tab; Index: ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java =================================================================== --- ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java (revision 990399) +++ ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java (working copy) @@ -130,6 +130,7 @@ import org.apache.hadoop.hive.ql.udf.UDFUpper; import org.apache.hadoop.hive.ql.udf.UDFWeekOfYear; import org.apache.hadoop.hive.ql.udf.UDFYear; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFArgMax; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFBridge; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectSet; @@ -376,6 +377,7 @@ registerGenericUDAF("histogram_numeric", new GenericUDAFHistogramNumeric()); registerGenericUDAF("percentile_approx", new GenericUDAFPercentileApprox()); registerGenericUDAF("collect_set", new GenericUDAFCollectSet()); + registerGenericUDAF("argmax", new GenericUDAFArgMax()); registerGenericUDAF("ngrams", new GenericUDAFnGrams()); registerGenericUDAF("context_ngrams", new GenericUDAFContextNGrams()); Index: ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFArgMax.java =================================================================== --- ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFArgMax.java (revision 0) +++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFArgMax.java (revision 0) @@ -0,0 +1,185 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hive.ql.udf.generic; + +import java.util.ArrayList; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.parse.SemanticException; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.StructField; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils; + +/** + * Argmax(x, y) computes the value y that maximizes x. If there are more than one such + * y value, returns one arbitrarily. Any pair (NULL,y) will be ignored. If Argmax is + * applied to an empty set, then NULL will be returned. + * + */ +@Description(name = "argmax", + value = "_FUNC_(x,y) - Returns the value y that maximizes x.", + extended = "The function takes as arguments any pair (x,y) where x can be of any type that\n" + + "supports comparison, and returns the value y that maximizes x. Any pair with\n" + + "x=NULL will be ignored. If the function is applied to an empty set, NULL will\n" + + "be returnd. If more than one value of y maximize x, the function returns one\n" + + "of them arbitrarily.") +public class GenericUDAFArgMax extends AbstractGenericUDAFResolver { + + static final Log LOG = LogFactory.getLog(GenericUDAFArgMax.class.getName()); + + @Override + public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException { + if (parameters.length != 2) { + throw new UDFArgumentTypeException(parameters.length - 1, + "Exactly two arguments are expected."); + } + ObjectInspector oi = TypeInfoUtils.getStandardJavaObjectInspectorFromTypeInfo(parameters[0]); + if (!ObjectInspectorUtils.compareSupported(oi)) { + throw new UDFArgumentTypeException(0, "Argument type does not support comparison."); + } + return new GenericUDAFArgMaxEvaluator(); + } // getEvaluator + + /** + * + */ + public static class GenericUDAFArgMaxEvaluator extends GenericUDAFEvaluator { + + // For PARTIAL1 and COMPLETE + private ObjectInspector xInputOI; + private ObjectInspector yInputOI; + + // For PARTIAL2 and FINAL + private StructObjectInspector soi; + private StructField xField; + private StructField yField; + + // For PARTIAL1 and PARTIAL2 +// private Object[] partialResult; + ArrayList partialResult; + + @Override + public ObjectInspector init(Mode m, ObjectInspector[] parameters) + throws HiveException { + super.init(m, parameters); + + // init input object inspectors + if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) { + assert (parameters.length == 2); + xInputOI = parameters[0]; + yInputOI = parameters[1]; + } else { + assert (parameters.length == 1); + soi = (StructObjectInspector) parameters[0]; + xField = soi.getStructFieldRef("max"); + yField = soi.getStructFieldRef("argmax"); + } + + // init output object inspectors + if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) { + // The output of a partial aggregation is a struct containing + // a comparable max and an object argmax. + ArrayList foi = new ArrayList(); + foi.add(xInputOI); + foi.add(yInputOI); + ArrayList fname = new ArrayList(); + fname.add("max"); + fname.add("argmax"); +// partialResult = new Object[2]; + partialResult = new ArrayList(2); + return ObjectInspectorFactory.getStandardStructObjectInspector(fname, foi); + } else { + return yInputOI; + } + } // init + + static class MaxAgg implements AggregationBuffer { + Object max; // current max x value + Object argmax; // current y value that maximizes x + }; + + @Override + public AggregationBuffer getNewAggregationBuffer() throws HiveException { + MaxAgg myagg = new MaxAgg(); + reset(myagg); + return myagg; + } + + @Override + public void reset(AggregationBuffer agg) throws HiveException { + MaxAgg myagg = (MaxAgg) agg; + myagg.max = null; + myagg.argmax = null; + } + + @Override + public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException { + assert (parameters.length == 2); + internalMerge(agg, parameters[0], parameters[1]); + } // iterate + + @Override + public Object terminatePartial(AggregationBuffer agg) throws HiveException { + MaxAgg myagg = (MaxAgg) agg; + partialResult.set(0, myagg.max); + partialResult.set(1, myagg.argmax); + return partialResult; + } // terminatePartial + + @Override + public void merge(AggregationBuffer agg, Object partial) throws HiveException { + if (partial != null) { + Object partialMax = soi.getStructFieldData(partial, xField); + Object partialArgmax = soi.getStructFieldData(partial, yField); + internalMerge(agg, partialMax, partialArgmax); + } + } // merge + + private void internalMerge(AggregationBuffer agg, Object px, Object py) + throws HiveException { + MaxAgg myagg = (MaxAgg) agg; + int r = ObjectInspectorUtils.compare(myagg.max, xInputOI, px, xInputOI); + if (px != null && (myagg.max == null || r < 0)) { + // Replace A with B + myagg.max = ObjectInspectorUtils.copyToStandardObject(px, xInputOI, + ObjectInspectorUtils.ObjectInspectorCopyOption.DEFAULT); + myagg.argmax = ObjectInspectorUtils.copyToStandardObject(py, yInputOI, + ObjectInspectorUtils.ObjectInspectorCopyOption.DEFAULT); + } + } // internal_merge + + @Override + public Object terminate(AggregationBuffer agg) throws HiveException { + MaxAgg myagg = (MaxAgg) agg; + if (myagg.max == null) { + return null; + } else { + return myagg.argmax; + } + } // terminate + + } // GenericUDAFArgMaxEvaluator +} // GenericUDAFArgMax