diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java
index 47e68fd..6dd2f06 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java
@@ -370,6 +370,7 @@
     registerGenericUDAF("histogram_numeric", new GenericUDAFHistogramNumeric());
     registerGenericUDAF("percentile_approx", new GenericUDAFPercentileApprox());
     registerGenericUDAF("collect_set", new GenericUDAFCollectSet());
+    registerGenericUDAF("collect_map", new GenericUDAFCollectMap());
 
     registerGenericUDAF("ngrams", new GenericUDAFnGrams());
     registerGenericUDAF("context_ngrams", new GenericUDAFContextNGrams());
diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFCollectMap.java ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFCollectMap.java
new file mode 100644
index 0000000..0a36585
--- /dev/null
+++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFCollectMap.java
@@ -0,0 +1,192 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.udf.generic;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
+
+/**
+ * GenericUDAFCollectMap.
+ *
+ * Resolver for the "collect_map" aggregate: gathers the (key, value) argument
+ * pairs of a group into a single map. Rows whose key is null are skipped, and
+ * when a key occurs more than once only one of its values is kept.
+ */
+@Description(name = "collect_map", value = "_FUNC_(x,y) - Returns a map of entries formed by taking x as the key and y as the value. "
+    + "Groups with duplicate keys will contain one entry. The value of the entry is indeterminate.")
+public class GenericUDAFCollectMap extends AbstractGenericUDAFResolver {
+
+  static final Log LOG = LogFactory.getLog(GenericUDAFCollectMap.class.getName());
+
+  public GenericUDAFCollectMap() {
+  }
+
+  /**
+   * Validates the argument types and creates the evaluator.
+   *
+   * @param parameters types of the arguments; exactly two are required and
+   *          the first (the map key) must be a primitive type
+   * @return a new {@link GenericUDAFCollectMapEvaluator}
+   * @throws SemanticException (as UDFArgumentTypeException) if the argument
+   *           count or the key type is not acceptable
+   */
+  @Override
+  public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
+      throws SemanticException {
+
+    if (parameters.length != 2) {
+      throw new UDFArgumentTypeException(parameters.length - 1,
+          "Exactly two arguments are expected.");
+    }
+
+    if (parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
+      throw new UDFArgumentTypeException(0,
+          "Only primitive type arguments are accepted but "
+          + parameters[0].getTypeName() + " was passed as parameter 1.");
+    }
+
+    return new GenericUDAFCollectMapEvaluator();
+  }
+
+  /** Aggregation state: the entries collected so far. */
+  static class CollectMapAggregationBuffer implements AggregationBuffer {
+    Map<Object, Object> container;
+  }
+
+  /**
+   * Evaluator that accumulates standard-object copies of keys and values in
+   * a HashMap.
+   */
+  public static class GenericUDAFCollectMapEvaluator extends GenericUDAFEvaluator {
+
+    // Inspectors for the objects handed to put(): the raw arguments when
+    // iterating over rows, the entries of a partial map when merging.
+    private PrimitiveObjectInspector keyOI;
+    private ObjectInspector valueOI;
+
+    // Inspector for the partial maps passed to merge(); set only in the
+    // PARTIAL2 and FINAL modes.
+    private StandardMapObjectInspector internalMergeOI;
+
+    @Override
+    public ObjectInspector init(Mode m, ObjectInspector[] parameters)
+        throws HiveException {
+      super.init(m, parameters);
+
+      if (m == Mode.PARTIAL2 || m == Mode.FINAL) {
+        // The single parameter is a partial aggregation, which is a map.
+        internalMergeOI = (StandardMapObjectInspector) parameters[0];
+        keyOI = (PrimitiveObjectInspector) internalMergeOI.getMapKeyObjectInspector();
+        valueOI = internalMergeOI.getMapValueObjectInspector();
+        return ObjectInspectorUtils.getStandardObjectInspector(internalMergeOI);
+      }
+
+      // PARTIAL1 and COMPLETE: iterate() is fed the original arguments, so
+      // they have to be copied with the inspectors they arrived with.
+      keyOI = (PrimitiveObjectInspector) parameters[0];
+      valueOI = parameters[1];
+      // Partial and final results alike are maps of standard objects.
+      return ObjectInspectorFactory.getStandardMapObjectInspector(
+          (PrimitiveObjectInspector) ObjectInspectorUtils.getStandardObjectInspector(keyOI),
+          ObjectInspectorUtils.getStandardObjectInspector(valueOI));
+    }
+
+    @Override
+    public void reset(AggregationBuffer agg) throws HiveException {
+      ((CollectMapAggregationBuffer) agg).container = new HashMap<Object, Object>();
+    }
+
+    @Override
+    public AggregationBuffer getNewAggregationBuffer() throws HiveException {
+      CollectMapAggregationBuffer ret = new CollectMapAggregationBuffer();
+      reset(ret);
+      return ret;
+    }
+
+    @Override
+    public void iterate(AggregationBuffer agg, Object[] parameters)
+        throws HiveException {
+      assert (parameters.length == 2);
+      Object k = parameters[0];
+      Object v = parameters[1];
+
+      // Rows with a null key contribute nothing to the map.
+      if (k != null) {
+        put(k, v, (CollectMapAggregationBuffer) agg);
+      }
+    }
+
+    @Override
+    public Object terminatePartial(AggregationBuffer agg) throws HiveException {
+      return snapshot((CollectMapAggregationBuffer) agg);
+    }
+
+    @Override
+    public void merge(AggregationBuffer agg, Object partial)
+        throws HiveException {
+      // Nothing to fold in when there is no partial map.
+      if (partial == null) {
+        return;
+      }
+      Map<?, ?> partialResult = internalMergeOI.getMap(partial);
+      if (partialResult == null) {
+        return;
+      }
+      CollectMapAggregationBuffer myagg = (CollectMapAggregationBuffer) agg;
+      for (Map.Entry<?, ?> e : partialResult.entrySet()) {
+        put(e.getKey(), e.getValue(), myagg);
+      }
+    }
+
+    @Override
+    public Object terminate(AggregationBuffer agg) throws HiveException {
+      return snapshot((CollectMapAggregationBuffer) agg);
+    }
+
+    /** Copies the buffer's map so the returned result is detached from it. */
+    private Map<Object, Object> snapshot(CollectMapAggregationBuffer myagg) {
+      return new HashMap<Object, Object>(myagg.container);
+    }
+
+    /**
+     * Stores a standard-object copy of the entry in the buffer, replacing
+     * any entry already held for an equal key.
+     */
+    private void put(Object k, Object v, CollectMapAggregationBuffer myagg) {
+      Object kCopy = ObjectInspectorUtils.copyToStandardObject(k, this.keyOI);
+      Object vCopy = ObjectInspectorUtils.copyToStandardObject(v, this.valueOI);
+      myagg.container.put(kCopy, vCopy);
+    }
+  }
+
+}