diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/GroupByOperator.java ql/src/java/org/apache/hadoop/hive/ql/exec/GroupByOperator.java
index 46f0ecd..129689f 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/GroupByOperator.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/GroupByOperator.java
@@ -352,6 +352,21 @@ protected void initializeOp(Configuration hconf) throws HiveException {
       }
     }
 
+    // the grouping id, which is the last of the key columns, should be pruned;
+    // see ColumnPrunerGroupByProc
+    outputKeyLength = conf.pruneGroupingSetId() ? keyFields.length - 1 : keyFields.length;
+
+    // init objectInspectors (must happen before aggregation buffers are created)
+    ObjectInspector[] objectInspectors =
+        new ObjectInspector[outputKeyLength + aggregationEvaluators.length];
+    for (int i = 0; i < outputKeyLength; i++) {
+      objectInspectors[i] = currentKeyObjectInspectors[i];
+    }
+    for (int i = 0; i < aggregationEvaluators.length; i++) {
+      objectInspectors[outputKeyLength + i] = aggregationEvaluators[i].init(conf.getAggregators()
+          .get(i).getMode(), aggregationParameterObjectInspectors[i]);
+    }
+
     aggregationsParametersLastInvoke = new Object[conf.getAggregators().size()][];
     if ((conf.getMode() != GroupByDesc.Mode.HASH || conf.getBucketGroup()) &&
         (!groupingSetsPresent)) {
@@ -374,21 +389,6 @@ protected void initializeOp(Configuration hconf) throws HiveException {
 
     List<String> fieldNames = new ArrayList<String>(conf.getOutputColumnNames());
 
-    // grouping id should be pruned, which is the last of key columns
-    // see ColumnPrunerGroupByProc
-    outputKeyLength = conf.pruneGroupingSetId() ? keyFields.length - 1 : keyFields.length;
-
-    // init objectInspectors
-    ObjectInspector[] objectInspectors =
-        new ObjectInspector[outputKeyLength + aggregationEvaluators.length];
-    for (int i = 0; i < outputKeyLength; i++) {
-      objectInspectors[i] = currentKeyObjectInspectors[i];
-    }
-    for (int i = 0; i < aggregationEvaluators.length; i++) {
-      objectInspectors[outputKeyLength + i] = aggregationEvaluators[i].init(conf.getAggregators()
-          .get(i).getMode(), aggregationParameterObjectInspectors[i]);
-    }
-
     outputObjInspector = ObjectInspectorFactory
         .getStandardStructObjectInspector(fieldNames, Arrays.asList(objectInspectors));
 
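The hunk above is the substance of the fix: the aggregationEvaluators[i].init(...) calls are moved ahead of the block later in initializeOp that allocates aggregation buffers (the hash-mode path calls getNewAggregationBuffer()). An evaluator is entitled to choose its buffer implementation from state that init() establishes, which is exactly what GenericUDAFSum2 below does. A minimal sketch of that dependency follows; the class and helper names are hypothetical, not part of the patch:

    import org.apache.hadoop.hive.ql.metadata.HiveException;
    import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
    import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;

    // Hypothetical evaluator: the buffer choice depends on outputOI, which is
    // only assigned in init(). If GroupByOperator asked for a buffer before
    // calling init(), the switch below would throw a NullPointerException.
    abstract class OiDependentEvaluator extends GenericUDAFEvaluator {
      protected PrimitiveObjectInspector outputOI; // assigned in init()

      @Override
      public AggregationBuffer getNewAggregationBuffer() throws HiveException {
        // relies on init() having run first -- the ordering the hunk above guarantees
        switch (outputOI.getPrimitiveCategory()) {
        case LONG:
          return newLongBuffer();    // hypothetical helper
        case DOUBLE:
          return newDoubleBuffer();  // hypothetical helper
        default:
          throw new HiveException("unsupported category");
        }
      }

      abstract AggregationBuffer newLongBuffer();
      abstract AggregationBuffer newDoubleBuffer();
    }

GenericUDAFSum2's getNewAggregationBuffer() has the same shape: it switches on outputOI.getPrimitiveCategory(), and outputOI is only assigned in init().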
diff --git ql/src/test/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum2.java ql/src/test/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum2.java
new file mode 100644
index 0000000..882eb74
--- /dev/null
+++ ql/src/test/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum2.java
@@ -0,0 +1,261 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hive.ql.udf.generic;
+
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorConverter;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.util.StringUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.HashSet;
+
+/**
+ * Refactored version of GenericUDAFSum, used to illustrate the issue in HIVE-15513.
+ */
+@Description(name = "sum", value = "_FUNC_(x) - Returns the sum of a set of numbers")
+public class GenericUDAFSum2 extends AbstractGenericUDAFResolver {
+
+  static final Logger LOG = LoggerFactory.getLogger(GenericUDAFSum2.class.getName());
+
+  @Override
+  public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
+      throws SemanticException {
+    if (parameters.length != 1) {
+      throw new UDFArgumentTypeException(parameters.length - 1,
+          "Exactly one argument is expected.");
+    }
+
+    if (parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
+      throw new UDFArgumentTypeException(0,
+          "Only primitive type arguments are accepted but "
+              + parameters[0].getTypeName() + " is passed.");
+    }
+    return new GenericUDAFSumEvaluator();
+  }
+
+  @Override
+  public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo info)
+      throws SemanticException {
+    TypeInfo[] parameters = info.getParameters();
+
+    GenericUDAFSumEvaluator eval = (GenericUDAFSumEvaluator) getEvaluator(parameters);
+    eval.setWindowing(info.isWindowing());
+    eval.setSumDistinct(info.isDistinct());
+
+    return eval;
+  }
+
+  public static class GenericUDAFSumEvaluator extends GenericUDAFEvaluator {
+
+    private boolean warned;
+
+    static abstract class SumAgg<T> extends AbstractAggregationBuffer {
+
+      T sum;
+      HashSet<T> uniqueObjects = new HashSet<>(); // Unique rows.
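+      // Every accepted value is recorded here: a windowed SUM(DISTINCT) uses
+      // it to skip repeats (see isEligibleValue), and terminate() uses its
+      // emptiness to return NULL when no rows were aggregated.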
+
+      protected SumAgg() {
+        initSum();
+      }
+
+      protected abstract void initSum();
+
+      public void add(T v) {
+        doAdd(v);
+        uniqueObjects.add(v);
+      }
+
+      protected abstract void doAdd(T v);
+
+      public void reset() {
+        initSum();
+        uniqueObjects.clear();
+      }
+
+      public abstract Writable getSumAsWritable();
+    }
+
+    protected PrimitiveObjectInspector inputOI;
+    protected PrimitiveObjectInspector outputOI;
+    protected boolean isWindowing;
+    protected boolean sumDistinct;
+
+    @Override
+    public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
+      assert (parameters.length == 1);
+      super.init(m, parameters);
+      inputOI = (PrimitiveObjectInspector) parameters[0];
+      switch (inputOI.getPrimitiveCategory()) {
+      case BYTE:
+      case SHORT:
+      case INT:
+      case LONG:
+        outputOI = PrimitiveObjectInspectorFactory.writableLongObjectInspector;
+        break;
+      case TIMESTAMP:
+      case FLOAT:
+      case DOUBLE:
+        outputOI = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
+        break;
+      case STRING:
+      case VARCHAR:
+      case CHAR:
+      case DECIMAL:
+        // TODO
+      case BOOLEAN:
+      case DATE:
+      default:
+        throw new UDFArgumentTypeException(0, "Only numeric type arguments are accepted but "
+            + inputOI.getTypeName() + " is passed.");
+      }
+      return outputOI;
+    }
+
+    public void setWindowing(boolean isWindowing) {
+      this.isWindowing = isWindowing;
+    }
+
+    public void setSumDistinct(boolean sumDistinct) {
+      this.sumDistinct = sumDistinct;
+    }
+
+    protected boolean isWindowingDistinct() {
+      return isWindowing && sumDistinct;
+    }
+
+    @Override
+    public AggregationBuffer getNewAggregationBuffer() throws HiveException {
+      switch (outputOI.getPrimitiveCategory()) {
+      case BYTE:
+      case SHORT:
+      case INT:
+      case LONG:
+        return new SumAgg<Number>() {
+          @Override
+          protected void doAdd(Number v) {
+            sum = sum.longValue() + v.longValue();
+          }
+
+          @Override
+          public Writable getSumAsWritable() {
+            return new LongWritable(sum.longValue());
+          }
+
+          @Override
+          protected void initSum() {
+            sum = 0L;
+          }
+        };
+      case TIMESTAMP:
+      case FLOAT:
+      case DOUBLE:
+        return new SumAgg<Number>() {
+          @Override
+          protected void doAdd(Number v) {
+            sum = sum.doubleValue() + v.doubleValue();
+          }
+
+          @Override
+          public Writable getSumAsWritable() {
+            return new DoubleWritable(sum.doubleValue());
+          }
+
+          @Override
+          protected void initSum() {
+            sum = 0.0;
+          }
+        };
+      case STRING:
+      case VARCHAR:
+      case CHAR:
+      case DECIMAL:
+        // TODO
+      case BOOLEAN:
+      case DATE:
+      }
+      throw new UDFArgumentTypeException(0, "Only numeric type arguments are accepted but "
+          + inputOI.getTypeName() + " is passed.");
+    }
+
+    @Override
+    public void reset(AggregationBuffer agg) throws HiveException {
+      ((SumAgg) agg).reset();
+    }
+
+    @Override
+    public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
+      assert (parameters.length == 1);
+      try {
+        if (!isEligibleValue((SumAgg) agg, parameters[0])) return;
+        ((SumAgg) agg).add(inputOI.getPrimitiveJavaObject(parameters[0]));
+      } catch (NumberFormatException e) {
+        if (!warned) {
+          warned = true;
+          LOG.warn(getClass().getSimpleName() + " " + StringUtils.stringifyException(e));
+          LOG.warn(getClass().getSimpleName() + " ignoring similar exceptions.");
+        }
+      }
+    }
+
+    @Override
+    public Object terminatePartial(AggregationBuffer agg) throws HiveException {
+      return terminate(agg);
+    }
+
+    @Override
+    public void merge(AggregationBuffer agg, Object partial) throws HiveException {
+      if (partial == null) return;
+      SumAgg myagg = (SumAgg) agg;
+      if (isWindowingDistinct()) {
+        throw new HiveException("Distinct windowing UDAF doesn't support merge and terminatePartial");
+      }
+      myagg.add(inputOI.getPrimitiveJavaObject(partial));
+    }
+
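+    // SUM over zero rows must be NULL in SQL; terminate() below checks the
+    // buffer for that case instead of returning the zero-initialized sum.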
+    @Override
+    public Object terminate(AggregationBuffer agg) throws HiveException {
+      SumAgg myagg = (SumAgg) agg;
+      if (myagg.sum == null || myagg.uniqueObjects.isEmpty()) return null;
+      return myagg.getSumAsWritable();
+    }
+
+    /**
+     * Check whether the input object is eligible to contribute to the sum:
+     * a null input, or an input equal to a value already aggregated in the
+     * case of SUM(DISTINCT), is skipped.
+     *
+     * @param agg the aggregation buffer tracking the values seen so far
+     * @param input the input object
+     * @return true if the input is non-null and either this is not a windowed
+     *         SUM(DISTINCT) or the value has not been seen before
+     */
+    protected boolean isEligibleValue(SumAgg agg, Object input) {
+      if (input == null) return false;
+      if (!isWindowingDistinct()) return true;
+      return !agg.uniqueObjects.contains(input);
+    }
+  }
+
+}
diff --git ql/src/test/queries/clientpositive/udaf_aggbuf_based_on_oi.q ql/src/test/queries/clientpositive/udaf_aggbuf_based_on_oi.q
new file mode 100644
index 0000000..faa849a
--- /dev/null
+++ ql/src/test/queries/clientpositive/udaf_aggbuf_based_on_oi.q
@@ -0,0 +1,3 @@
+create temporary function sum2 as 'org.apache.hadoop.hive.ql.udf.generic.GenericUDAFSum2';
+
+select sum2(CAST(key AS INT)) from src;
diff --git ql/src/test/results/clientpositive/udaf_aggbuf_based_on_oi.q.out ql/src/test/results/clientpositive/udaf_aggbuf_based_on_oi.q.out
new file mode 100644
index 0000000..60c2515
--- /dev/null
+++ ql/src/test/results/clientpositive/udaf_aggbuf_based_on_oi.q.out
@@ -0,0 +1,15 @@
+PREHOOK: query: create temporary function sum2 as 'org.apache.hadoop.hive.ql.udf.generic.GenericUDAFSum2'
+PREHOOK: type: CREATEFUNCTION
+PREHOOK: Output: sum2
+POSTHOOK: query: create temporary function sum2 as 'org.apache.hadoop.hive.ql.udf.generic.GenericUDAFSum2'
+POSTHOOK: type: CREATEFUNCTION
+POSTHOOK: Output: sum2
+PREHOOK: query: select sum2(CAST(key AS INT)) from src
+PREHOOK: type: QUERY
+PREHOOK: Input: default@src
+#### A masked pattern was here ####
+POSTHOOK: query: select sum2(CAST(key AS INT)) from src
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@src
+#### A masked pattern was here ####
+130091
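The expected 130091 is the usual sum of the keys in the 500-row src test table, so the qtest passes only if the evaluator is initialized before its buffer is created. For exercising the evaluator outside the qtest framework, a rough standalone driver is sketched below; it is not part of the patch, the class name is made up, and it assumes Hive's ql and serde2 classes are on the classpath:

    import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
    import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer;
    import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFSum2;
    import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
    import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
    import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
    import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;

    // Drives GenericUDAFSum2 through the COMPLETE-mode lifecycle by hand.
    public class Sum2LifecycleDemo {
      public static void main(String[] args) throws Exception {
        GenericUDAFSum2 resolver = new GenericUDAFSum2();
        GenericUDAFEvaluator eval =
            resolver.getEvaluator(new TypeInfo[] { TypeInfoFactory.intTypeInfo });

        ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
        // init() first: this is where the evaluator assigns outputOI
        // (a long sum for integer input).
        eval.init(GenericUDAFEvaluator.Mode.COMPLETE, new ObjectInspector[] { intOI });
        // Only now is it safe to create a buffer, because the buffer
        // implementation is chosen from outputOI.
        AggregationBuffer buf = eval.getNewAggregationBuffer();

        for (int v : new int[] { 1, 2, 3 }) {
          eval.iterate(buf, new Object[] { v });
        }
        System.out.println(eval.terminate(buf)); // 6, as a LongWritable
      }
    }

This is the same init-then-buffer ordering that the GroupByOperator hunk at the top of the patch now guarantees inside Hive.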