diff --git a/ql/src/java/org/apache/hadoop/hive/ql/parse/WindowingSpec.java b/ql/src/java/org/apache/hadoop/hive/ql/parse/WindowingSpec.java
index 5ce7200..ef5186a 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/parse/WindowingSpec.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/parse/WindowingSpec.java
@@ -124,9 +124,6 @@ public void validateAndMakeEffective() throws SemanticException {
       WindowFunctionSpec wFn = (WindowFunctionSpec) expr;
       WindowSpec wdwSpec = wFn.getWindowSpec();
 
-      // 0. Precheck supported syntax
-      precheckSyntax(wFn, wdwSpec);
-
       // 1. For Wdw Specs that refer to Window Defns, inherit missing components
       if ( wdwSpec != null ) {
         ArrayList<String> sources = new ArrayList<String>();
@@ -153,14 +150,6 @@ public void validateAndMakeEffective() throws SemanticException {
     }
   }
 
-  private void precheckSyntax(WindowFunctionSpec wFn, WindowSpec wdwSpec) throws SemanticException {
-    if (wdwSpec != null ) {
-      if (wFn.isDistinct && (wdwSpec.windowFrame != null || wdwSpec.getOrder() != null) ) {
-        throw new SemanticException("Function with DISTINCT cannot work with partition ORDER BY or windowing clause.");
-      }
-    }
-  }
-
   private void fillInWindowSpec(String sourceId, WindowSpec dest, ArrayList<String> visited)
       throws SemanticException {
@@ -509,9 +498,6 @@ protected void ensureOrderSpec(WindowFunctionSpec wFn) throws SemanticException
     if ( getOrder() == null ) {
       OrderSpec order = new OrderSpec();
       order.prefixBy(getPartition());
-      if (wFn.isDistinct) {
-        order.addExpressions(wFn.getArgs());
-      }
       setOrder(order);
     }
   }
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/plan/ptf/WindowFunctionDef.java b/ql/src/java/org/apache/hadoop/hive/ql/plan/ptf/WindowFunctionDef.java
index ed6c671..84ac614 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/plan/ptf/WindowFunctionDef.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/plan/ptf/WindowFunctionDef.java
@@ -124,4 +124,4 @@ public void setPivotResult(boolean pivotResult) {
     this.pivotResult = pivotResult;
   }
 
-}
\ No newline at end of file
+}
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java
index 3c1ce26..47fe0f8 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java
@@ -18,6 +18,7 @@
 package org.apache.hadoop.hive.ql.udf.generic;
 
 import java.util.ArrayList;
+import java.util.HashSet;
 
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -38,6 +39,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.StructField;
 import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorObject;
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
@@ -115,7 +117,7 @@ public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo paramInfo)
       public void doReset(AverageAggregationBuffer<Double> aggregation) throws HiveException {
         aggregation.count = 0;
         aggregation.sum = new Double(0);
-        aggregation.previousValue = null;
+        aggregation.uniqueObjects = new HashSet<ObjectInspectorObject>();
       }
 
       @Override
@@ -145,6 +147,12 @@ protected void doMerge(AverageAggregationBuffer<Double> aggregation, Long partia
       }
 
       @Override
+      protected Double doMergeAdd(Double sum,
+          ObjectInspectorObject obj) {
+        return sum + PrimitiveObjectInspectorUtils.getDouble(obj.getValues()[0], copiedOI);
+      }
+
+      @Override
       protected void doTerminatePartial(AverageAggregationBuffer<Double> aggregation) {
         if(partialResult[1] == null) {
           partialResult[1] = new DoubleWritable(0);
@@ -172,6 +180,10 @@ public AggregationBuffer getNewAggregationBuffer() throws HiveException {
 
       @Override
       public GenericUDAFEvaluator getWindowingEvaluator(WindowFrameDef wFrameDef) {
+        // Don't use streaming for distinct cases
+        if (avgDistinct) {
+          return null;
+        }
         return new GenericUDAFStreamingEvaluator.SumAvgEnhancer<DoubleWritable, Double>(this,
             wFrameDef) {
@@ -212,6 +224,7 @@ protected DoubleWritable getNextResult(
       public void doReset(AverageAggregationBuffer<HiveDecimal> aggregation) throws HiveException {
         aggregation.count = 0;
         aggregation.sum = HiveDecimal.ZERO;
+        aggregation.uniqueObjects = new HashSet<ObjectInspectorObject>();
       }
 
       @Override
@@ -263,6 +276,14 @@ protected void doMerge(AverageAggregationBuffer<HiveDecimal> aggregation, Long p
         }
       }
+
+      @Override
+      protected HiveDecimal doMergeAdd(
+          HiveDecimal sum,
+          ObjectInspectorObject obj) {
+        return sum.add(PrimitiveObjectInspectorUtils.getHiveDecimal(obj.getValues()[0], copiedOI));
+      }
+
       @Override
       protected void doTerminatePartial(AverageAggregationBuffer<HiveDecimal> aggregation) {
         if(partialResult[1] == null && aggregation.sum != null) {
@@ -296,6 +317,10 @@ public AggregationBuffer getNewAggregationBuffer() throws HiveException {
 
       @Override
       public GenericUDAFEvaluator getWindowingEvaluator(WindowFrameDef wFrameDef) {
+        // Don't use streaming for distinct cases
+        if (avgDistinct) {
+          return null;
+        }
         return new GenericUDAFStreamingEvaluator.SumAvgEnhancer<HiveDecimalWritable, HiveDecimal>(
             this, wFrameDef) {
@@ -333,7 +358,7 @@ protected HiveDecimalWritable getNextResult(
 
   private static class AverageAggregationBuffer<TYPE> implements AggregationBuffer {
-    private Object previousValue;
+    private HashSet<ObjectInspectorObject> uniqueObjects; // Unique rows.
     private long count;
     private TYPE sum;
   };
 
@@ -341,10 +366,9 @@ protected HiveDecimalWritable getNextResult(
   @SuppressWarnings("unchecked")
   public static abstract class AbstractGenericUDAFAverageEvaluator<TYPE> extends GenericUDAFEvaluator {
     protected boolean avgDistinct;
-
     // For PARTIAL1 and COMPLETE
     protected transient PrimitiveObjectInspector inputOI;
-    protected transient ObjectInspector copiedOI;
+    protected transient PrimitiveObjectInspector copiedOI;
     // For PARTIAL2 and FINAL
     private transient StructObjectInspector soi;
     private transient StructField countField;
@@ -363,6 +387,7 @@ protected abstract void doIterate(AverageAggregationBuffer<TYPE> aggregation,
         PrimitiveObjectInspector inputOI, Object parameter);
     protected abstract void doMerge(AverageAggregationBuffer<TYPE> aggregation, Long partialCount,
         ObjectInspector sumFieldOI, Object partialSum);
+    protected abstract TYPE doMergeAdd(TYPE sum, ObjectInspectorObject obj);
     protected abstract void doTerminatePartial(AverageAggregationBuffer<TYPE> aggregation);
     protected abstract Object doTerminate(AverageAggregationBuffer<TYPE> aggregation);
     protected abstract void doReset(AverageAggregationBuffer<TYPE> aggregation) throws HiveException;
@@ -376,7 +401,7 @@ public ObjectInspector init(Mode m, ObjectInspector[] parameters)
       // init input
       if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {
         inputOI = (PrimitiveObjectInspector) parameters[0];
-        copiedOI = ObjectInspectorUtils.getStandardObjectInspector(inputOI,
+        copiedOI = (PrimitiveObjectInspector) ObjectInspectorUtils.getStandardObjectInspector(inputOI,
             ObjectInspectorCopyOption.JAVA);
       } else {
         soi = (StructObjectInspector) parameters[0];
@@ -432,12 +457,15 @@ public void iterate(AggregationBuffer aggregation, Object[] parameters)
       AverageAggregationBuffer<TYPE> averageAggregation = (AverageAggregationBuffer<TYPE>) aggregation;
       try {
         // Skip the same value if avgDistinct is true
-        if (this.avgDistinct &&
-            ObjectInspectorUtils.compare(parameter, inputOI, averageAggregation.previousValue, copiedOI) == 0) {
-          return;
+        if (this.avgDistinct) {
+          ObjectInspectorObject obj = new ObjectInspectorObject(
+              ObjectInspectorUtils.copyToStandardObject(parameter, inputOI, ObjectInspectorCopyOption.JAVA),
+              copiedOI);
+          if (averageAggregation.uniqueObjects.contains(obj)) {
+            return;
+          }
+          averageAggregation.uniqueObjects.add(obj);
         }
-        averageAggregation.previousValue = ObjectInspectorUtils.copyToStandardObject(
-            parameter, inputOI, ObjectInspectorCopyOption.JAVA);
 
         doIterate(averageAggregation, inputOI, parameter);
       } catch (NumberFormatException e) {
@@ -451,6 +479,10 @@ public void iterate(AggregationBuffer aggregation, Object[] parameters)
 
     @Override
     public Object terminatePartial(AggregationBuffer aggregation) throws HiveException {
+      if (avgDistinct) {
+        return aggregation;
+      }
+
       doTerminatePartial((AverageAggregationBuffer<TYPE>) aggregation);
       return partialResult;
     }
@@ -459,9 +491,21 @@ public Object terminatePartial(AggregationBuffer aggregation) throws HiveExcepti
 
     @Override
     public void merge(AggregationBuffer aggregation, Object partial) throws HiveException {
       if (partial != null) {
-        doMerge((AverageAggregationBuffer<TYPE>) aggregation,
-            countFieldOI.get(soi.getStructFieldData(partial, countField)),
-            sumFieldOI, soi.getStructFieldData(partial, sumField));
+        if (avgDistinct) {
+          AverageAggregationBuffer<TYPE> mergeAgg = (AverageAggregationBuffer<TYPE>) aggregation;
+          AverageAggregationBuffer<TYPE> partialAgg = (AverageAggregationBuffer<TYPE>) partial;
+          for (ObjectInspectorObject obj : partialAgg.uniqueObjects) {
+            if (!mergeAgg.uniqueObjects.contains(obj)) {
+              mergeAgg.uniqueObjects.add(obj);
+              ++mergeAgg.count;
+              mergeAgg.sum = doMergeAdd(mergeAgg.sum, obj);
+            }
+          }
+        } else {
+          doMerge((AverageAggregationBuffer<TYPE>) aggregation,
+              countFieldOI.get(soi.getStructFieldData(partial, countField)),
+              sumFieldOI, soi.getStructFieldData(partial, sumField));
+        }
       }
     }
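
Note on the GenericUDAFAverage changes above: instead of comparing each value against a single previousValue (which only catches duplicates that arrive adjacently, hence the forced sort that was removed from WindowingSpec), iterate() now records every value it has seen in a HashSet and skips values already present. Stripped of the ObjectInspector machinery, the pattern reduces to the following self-contained sketch; the class and member names here are illustrative, not Hive API:

    import java.util.HashSet;
    import java.util.Set;

    // Minimal sketch of set-based AVG(DISTINCT): aggregate a value only the
    // first time it appears. The patch wraps each row in ObjectInspectorObject
    // so equality and hashing go through ObjectInspectors; this sketch uses
    // doubles directly.
    public class DistinctAvgSketch {
      private final Set<Double> seen = new HashSet<Double>(); // uniqueObjects analogue
      private long count;
      private double sum;

      public void iterate(double value) {
        if (seen.add(value)) { // add() returns false for duplicates
          count++;
          sum += value;
        }
      }

      public Double terminate() {
        return count == 0 ? null : sum / count;
      }

      public static void main(String[] args) {
        DistinctAvgSketch avg = new DistinctAvgSketch();
        for (double v : new double[] { 26.43, 26.43, 96.91 }) {
          avg.iterate(v);
        }
        System.out.println(avg.terminate()); // ~61.67, since 26.43 contributes once
      }
    }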
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFCount.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFCount.java
index 2825045..82ac129 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFCount.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFCount.java
@@ -17,6 +17,8 @@
  */
 package org.apache.hadoop.hive.ql.udf.generic;
 
+import java.util.HashSet;
+
 import org.apache.hadoop.hive.ql.exec.Description;
 import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
@@ -25,6 +27,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorObject;
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
@@ -99,6 +102,7 @@ public ObjectInspector init(Mode m, ObjectInspector[] parameters)
           ObjectInspectorCopyOption.JAVA);
     }
     result = new LongWritable(0);
+
     return PrimitiveObjectInspectorFactory.writableLongObjectInspector;
   }
 
@@ -113,7 +117,7 @@ private void setCountDistinct(boolean countDistinct) {
   /** class for storing count value. */
   @AggregationType(estimable = true)
   static class CountAgg extends AbstractAggregationBuffer {
-    Object[] prevColumns = null; // Column values from previous row. Used to compare with current row for the case of COUNT(DISTINCT).
+    HashSet<ObjectInspectorObject> uniqueObjects; // Unique rows.
     long value;
     @Override
     public int estimate() { return JavaDataModel.PRIMITIVES2; }
@@ -128,8 +132,8 @@ public AggregationBuffer getNewAggregationBuffer() throws HiveException {
 
   @Override
   public void reset(AggregationBuffer agg) throws HiveException {
-    ((CountAgg) agg).prevColumns = null;
     ((CountAgg) agg).value = 0;
+    ((CountAgg) agg).uniqueObjects = new HashSet<ObjectInspectorObject>();
   }
 
   @Override
@@ -153,17 +157,14 @@ public void iterate(AggregationBuffer agg, Object[] parameters)
 
       // Skip the counting if the values are the same for COUNT(DISTINCT) case
      if (countThisRow && countDistinct) {
-        Object[] prevColumns = ((CountAgg) agg).prevColumns;
-        if (prevColumns == null) {
-          ((CountAgg) agg).prevColumns = new Object[parameters.length];
-        } else if (ObjectInspectorUtils.compare(parameters, inputOI, prevColumns, outputOI) == 0) {
-          countThisRow = false;
-        }
-
-        // We need to keep a copy of values from previous row.
-        if (countThisRow) {
-          ((CountAgg) agg).prevColumns = ObjectInspectorUtils.copyToStandardObject(
-              parameters, inputOI, ObjectInspectorCopyOption.JAVA);
+        HashSet<ObjectInspectorObject> uniqueObjs = ((CountAgg) agg).uniqueObjects;
+        ObjectInspectorObject obj = new ObjectInspectorObject(
+            ObjectInspectorUtils.copyToStandardObject(parameters, inputOI, ObjectInspectorCopyOption.JAVA),
+            outputOI);
+        if (!uniqueObjs.contains(obj)) {
+          uniqueObjs.add(obj);
+        } else {
+          countThisRow = false;
         }
       }
 
@@ -177,8 +178,17 @@ public void iterate(AggregationBuffer agg, Object[] parameters)
 
   @Override
   public void merge(AggregationBuffer agg, Object partial) throws HiveException {
     if (partial != null) {
-      long p = partialCountAggOI.get(partial);
-      ((CountAgg) agg).value += p;
+      CountAgg countAgg = (CountAgg) agg;
+
+      // For count distinct, need to merge the two sets and recount
+      if (countDistinct) {
+        CountAgg partialResult = (CountAgg) partial;
+        countAgg.uniqueObjects.addAll(partialResult.uniqueObjects);
+        countAgg.value = countAgg.uniqueObjects.size();
+      } else {
+        long p = partialCountAggOI.get(partial);
+        countAgg.value += p;
+      }
     }
   }
 
@@ -190,7 +200,11 @@ public Object terminate(AggregationBuffer agg) throws HiveException {
 
   @Override
   public Object terminatePartial(AggregationBuffer agg) throws HiveException {
-    return terminate(agg);
+    if (countDistinct) {
+      return agg;
+    } else {
+      return terminate(agg);
+    }
   }
 }
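
The merge() change above also redefines what a partial result is for COUNT(DISTINCT): terminatePartial() now hands back the whole aggregation buffer, and merging is a set union followed by a recount rather than adding partial counts (summing them would double-count values seen on both sides). A minimal sketch of that behavior, with String standing in for ObjectInspectorObject:

    import java.util.Arrays;
    import java.util.HashSet;
    import java.util.Set;

    // Sketch of the COUNT(DISTINCT) merge semantics: union the sets of unique
    // values from two partials, then recount from the merged set.
    public class DistinctCountMergeSketch {
      static long mergedCount(Set<String> left, Set<String> right) {
        Set<String> merged = new HashSet<String>(left);
        merged.addAll(right); // duplicates across partials collapse here
        return merged.size();
      }

      public static void main(String[] args) {
        Set<String> p1 = new HashSet<String>(Arrays.asList("a", "b"));
        Set<String> p2 = new HashSet<String>(Arrays.asList("b", "c"));
        System.out.println(mergedCount(p1, p2)); // 3, not 4
      }
    }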
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java
index 7b1d6e5..66e6c3c 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java
@@ -17,6 +17,8 @@
  */
 package org.apache.hadoop.hive.ql.udf.generic;
 
+import java.util.HashSet;
+
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.apache.hadoop.hive.common.type.HiveDecimal;
@@ -32,6 +34,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorObject;
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
 import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo;
@@ -39,6 +42,7 @@ import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
 import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Writable;
 import org.apache.hadoop.util.StringUtils;
 
 /**
@@ -125,15 +129,15 @@ public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo info)
    * The base type for sum operator evaluator
    *
    */
-  public static abstract class GenericUDAFSumEvaluator<ResultType> extends GenericUDAFEvaluator {
+  public static abstract class GenericUDAFSumEvaluator<ResultType extends Writable> extends GenericUDAFEvaluator {
     static abstract class SumAgg<T> extends AbstractAggregationBuffer {
       boolean empty;
       T sum;
-      Object previousValue = null;
+      HashSet<ObjectInspectorObject> uniqueObjects; // Unique rows.
     }
 
     protected PrimitiveObjectInspector inputOI;
-    protected ObjectInspector outputOI;
+    protected PrimitiveObjectInspector outputOI;
     protected ResultType result;
     protected boolean sumDistinct;
 
@@ -145,6 +149,15 @@ public void setSumDistinct(boolean sumDistinct) {
       this.sumDistinct = sumDistinct;
     }
 
+    @Override
+    public Object terminatePartial(AggregationBuffer agg) throws HiveException {
+      if (sumDistinct) {
+        return agg;
+      } else {
+        return terminate(agg);
+      }
+    }
+
     /**
      * Check if the input object is the same as the previous one for the case of
      * SUM(DISTINCT).
@@ -152,17 +165,21 @@ public void setSumDistinct(boolean sumDistinct) {
      * @return True if sumDistinct is false or the input is different from the previous object
      */
     protected boolean checkDistinct(SumAgg agg, Object input) {
-      if (this.sumDistinct &&
-          ObjectInspectorUtils.compare(input, inputOI, agg.previousValue, outputOI) == 0) {
+      if (this.sumDistinct) {
+        HashSet<ObjectInspectorObject> uniqueObjs = agg.uniqueObjects;
+        ObjectInspectorObject obj = new ObjectInspectorObject(
+            ObjectInspectorUtils.copyToStandardObject(input, inputOI, ObjectInspectorCopyOption.JAVA),
+            outputOI);
+        if (!uniqueObjs.contains(obj)) {
+          uniqueObjs.add(obj);
+          return true;
+        }
+
         return false;
       }
-      agg.previousValue = ObjectInspectorUtils.copyToStandardObject(
-          input, inputOI, ObjectInspectorCopyOption.JAVA);
 
       return true;
     }
-
-
   }
 
   /**
@@ -177,7 +194,7 @@ public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveExc
       super.init(m, parameters);
       result = new HiveDecimalWritable(HiveDecimal.ZERO);
       inputOI = (PrimitiveObjectInspector) parameters[0];
-      outputOI = ObjectInspectorUtils.getStandardObjectInspector(inputOI,
+      outputOI = (PrimitiveObjectInspector) ObjectInspectorUtils.getStandardObjectInspector(inputOI,
           ObjectInspectorCopyOption.JAVA);
       // The output precision is 10 greater than the input which should cover at least
       // 10b rows. The scale is the same as the input.
@@ -208,6 +225,7 @@ public void reset(AggregationBuffer agg) throws HiveException {
       SumAgg<HiveDecimal> bdAgg = (SumAgg<HiveDecimal>) agg;
       bdAgg.empty = true;
       bdAgg.sum = HiveDecimal.ZERO;
+      bdAgg.uniqueObjects = new HashSet<ObjectInspectorObject>();
     }
 
     boolean warned = false;
@@ -216,8 +234,9 @@ public void reset(AggregationBuffer agg) throws HiveException {
     public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
       assert (parameters.length == 1);
       try {
-        if (checkDistinct((SumAgg) agg, parameters[0])) {
-          merge(agg, parameters[0]);
+        if (checkDistinct((SumHiveDecimalAgg) agg, parameters[0])) {
+          ((SumHiveDecimalAgg) agg).empty = false;
+          ((SumHiveDecimalAgg) agg).sum = ((SumHiveDecimalAgg) agg).sum.add(PrimitiveObjectInspectorUtils.getHiveDecimal(parameters[0], inputOI));
         }
       } catch (NumberFormatException e) {
         if (!warned) {
@@ -232,20 +251,21 @@ public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveExcep
     }
 
     @Override
-    public Object terminatePartial(AggregationBuffer agg) throws HiveException {
-      return terminate(agg);
-    }
-
-    @Override
     public void merge(AggregationBuffer agg, Object partial) throws HiveException {
       if (partial != null) {
         SumHiveDecimalAgg myagg = (SumHiveDecimalAgg) agg;
-        if (myagg.sum == null) {
-          return;
-        }
 
         myagg.empty = false;
-        myagg.sum = myagg.sum.add(PrimitiveObjectInspectorUtils.getHiveDecimal(partial, inputOI));
+        if (sumDistinct) {
+          SumHiveDecimalAgg partialAgg = (SumHiveDecimalAgg) partial;
+          for (ObjectInspectorObject obj : partialAgg.uniqueObjects) {
+            if (myagg.uniqueObjects.add(obj)) {
+              myagg.sum = myagg.sum.add(PrimitiveObjectInspectorUtils.getHiveDecimal(obj.getValues()[0], outputOI));
+            }
+          }
+        } else {
+          myagg.sum = myagg.sum.add(PrimitiveObjectInspectorUtils.getHiveDecimal(partial, outputOI));
+        }
       }
     }
@@ -261,6 +281,11 @@ public Object terminate(AggregationBuffer agg) throws HiveException {
 
     @Override
     public GenericUDAFEvaluator getWindowingEvaluator(WindowFrameDef wFrameDef) {
+      // Don't use streaming for distinct cases
+      if (sumDistinct) {
+        return null;
+      }
+
       return new GenericUDAFStreamingEvaluator.SumAvgEnhancer<HiveDecimalWritable, HiveDecimal>(
           this, wFrameDef) {
@@ -301,7 +326,7 @@ public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveExc
       super.init(m, parameters);
       result = new DoubleWritable(0);
       inputOI = (PrimitiveObjectInspector) parameters[0];
-      outputOI = ObjectInspectorUtils.getStandardObjectInspector(inputOI,
+      outputOI = (PrimitiveObjectInspector) ObjectInspectorUtils.getStandardObjectInspector(inputOI,
           ObjectInspectorCopyOption.JAVA);
       return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
     }
@@ -325,6 +350,7 @@ public void reset(AggregationBuffer agg) throws HiveException {
       SumDoubleAgg myagg = (SumDoubleAgg) agg;
       myagg.empty = true;
       myagg.sum = 0.0;
+      myagg.uniqueObjects = new HashSet<ObjectInspectorObject>();
     }
 
     boolean warned = false;
@@ -333,8 +359,9 @@ public void reset(AggregationBuffer agg) throws HiveException {
     public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
       assert (parameters.length == 1);
       try {
-        if (checkDistinct((SumAgg) agg, parameters[0])) {
-          merge(agg, parameters[0]);
+        if (checkDistinct((SumDoubleAgg) agg, parameters[0])) {
+          ((SumDoubleAgg) agg).empty = false;
+          ((SumDoubleAgg) agg).sum += PrimitiveObjectInspectorUtils.getDouble(parameters[0], inputOI);
         }
       } catch (NumberFormatException e) {
         if (!warned) {
@@ -349,16 +376,20 @@ public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveExcep
     }
 
     @Override
-    public Object terminatePartial(AggregationBuffer agg) throws HiveException {
-      return terminate(agg);
-    }
-
-    @Override
     public void merge(AggregationBuffer agg, Object partial) throws HiveException {
       if (partial != null) {
         SumDoubleAgg myagg = (SumDoubleAgg) agg;
         myagg.empty = false;
-        myagg.sum += PrimitiveObjectInspectorUtils.getDouble(partial, inputOI);
+        if (sumDistinct) {
+          SumDoubleAgg partialAgg = (SumDoubleAgg) partial;
+          for (ObjectInspectorObject obj : partialAgg.uniqueObjects) {
+            if (myagg.uniqueObjects.add(obj)) {
+              myagg.sum += PrimitiveObjectInspectorUtils.getDouble(obj.getValues()[0], outputOI);
+            }
+          }
+        } else {
+          myagg.sum += PrimitiveObjectInspectorUtils.getDouble(partial, outputOI);
+        }
       }
     }
@@ -374,6 +405,11 @@ public Object terminate(AggregationBuffer agg) throws HiveException {
 
     @Override
     public GenericUDAFEvaluator getWindowingEvaluator(WindowFrameDef wFrameDef) {
+      // Don't use streaming for distinct cases
+      if (sumDistinct) {
+        return null;
+      }
+
       return new GenericUDAFStreamingEvaluator.SumAvgEnhancer<DoubleWritable, Double>(this,
           wFrameDef) {
@@ -415,7 +451,7 @@ public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveExc
       super.init(m, parameters);
       result = new LongWritable(0);
       inputOI = (PrimitiveObjectInspector) parameters[0];
-      outputOI = ObjectInspectorUtils.getStandardObjectInspector(inputOI,
+      outputOI = (PrimitiveObjectInspector) ObjectInspectorUtils.getStandardObjectInspector(inputOI,
          ObjectInspectorCopyOption.JAVA);
       return PrimitiveObjectInspectorFactory.writableLongObjectInspector;
     }
@@ -439,6 +475,7 @@ public void reset(AggregationBuffer agg) throws HiveException {
       SumLongAgg myagg = (SumLongAgg) agg;
       myagg.empty = true;
       myagg.sum = 0L;
+      myagg.uniqueObjects = new HashSet<ObjectInspectorObject>();
     }
 
     private boolean warned = false;
@@ -447,8 +484,9 @@ public void reset(AggregationBuffer agg) throws HiveException {
     public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
       assert (parameters.length == 1);
       try {
-        if (checkDistinct((SumAgg) agg, parameters[0])) {
-          merge(agg, parameters[0]);
+        if (checkDistinct((SumLongAgg) agg, parameters[0])) {
+          ((SumLongAgg) agg).empty = false;
+          ((SumLongAgg) agg).sum += PrimitiveObjectInspectorUtils.getLong(parameters[0], inputOI);
         }
       } catch (NumberFormatException e) {
         if (!warned) {
@@ -460,16 +498,20 @@ public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveExcep
     }
 
     @Override
-    public Object terminatePartial(AggregationBuffer agg) throws HiveException {
-      return terminate(agg);
-    }
-
-    @Override
     public void merge(AggregationBuffer agg, Object partial) throws HiveException {
       if (partial != null) {
         SumLongAgg myagg = (SumLongAgg) agg;
-        myagg.sum += PrimitiveObjectInspectorUtils.getLong(partial, inputOI);
         myagg.empty = false;
+        if (sumDistinct) {
+          SumLongAgg partialAgg = (SumLongAgg) partial;
+          for (ObjectInspectorObject obj : partialAgg.uniqueObjects) {
+            if (myagg.uniqueObjects.add(obj)) {
+              myagg.sum += PrimitiveObjectInspectorUtils.getLong(obj.getValues()[0], outputOI);
+            }
+          }
+        } else {
+          myagg.sum += PrimitiveObjectInspectorUtils.getLong(partial, outputOI);
+        }
       }
     }
@@ -485,6 +527,11 @@ public Object terminate(AggregationBuffer agg) throws HiveException {
 
     @Override
     public GenericUDAFEvaluator getWindowingEvaluator(WindowFrameDef wFrameDef) {
+      // Don't use streaming for distinct cases
+      if (sumDistinct) {
+        return null;
+      }
+
       return new GenericUDAFStreamingEvaluator.SumAvgEnhancer<LongWritable, Long>(this,
           wFrameDef) {
@@ -509,7 +556,6 @@ protected Long getCurrentIntermediateResult(
             SumLongAgg myagg = (SumLongAgg) ss.wrappedBuf;
             return myagg.empty ? null : new Long(myagg.sum);
           }
-
         };
     }
   }
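
All three SUM evaluators above (and both AVG evaluators earlier) now bail out of getWindowingEvaluator() when DISTINCT is set. Roughly speaking, the streaming SumAvgEnhancer adds the value entering the frame and retracts the contribution of the value leaving it, and with set-based dedup there is no cheap way to decide whether a departing row should be retracted. Returning null pushes the caller back to buffered, whole-frame evaluation, which is why the null guards are added to WindowingTableFunction below. A sketch of the resulting contract, under illustrative names rather than Hive API:

    // Factory that may return null when streaming is unsupported, plus a caller
    // that falls back to buffered evaluation, mirroring the pattern the patch
    // uses for getWindowingEvaluator() and its call sites.
    interface FrameEvaluator {
      double evaluate(double[] frame);
    }

    public class StreamingFallbackSketch {
      static FrameEvaluator getStreamingEvaluator(boolean distinct) {
        if (distinct) {
          return null; // no streaming: a sliding frame would have to retract values
        }
        return frame -> {
          double sum = 0;
          for (double v : frame) {
            sum += v;
          }
          return sum;
        };
      }

      public static void main(String[] args) {
        FrameEvaluator eval = getStreamingEvaluator(true);
        // Same null check the WindowingTableFunction hunks add before using
        // a streaming evaluator.
        if (eval != null) {
          System.out.println("streaming");
        } else {
          System.out.println("fall back to buffered evaluation");
        }
      }
    }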
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/WindowingTableFunction.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/WindowingTableFunction.java
index 858b47a..b89c14e 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/WindowingTableFunction.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/WindowingTableFunction.java
@@ -54,6 +54,7 @@ import org.apache.hadoop.hive.ql.plan.ptf.WindowTableFunctionDef;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFStreamingEvaluator.SumAvgEnhancer;
 import org.apache.hadoop.hive.ql.udf.generic.ISupportStreamingModeForWindowing;
 import org.apache.hadoop.hive.serde2.SerDe;
 import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
@@ -392,8 +393,6 @@ public void startPartition() throws HiveException {
     }
     streamingState.rollingPart.append(row);
-    row = streamingState.rollingPart
-        .getAt(streamingState.rollingPart.size() - 1);
 
     WindowTableFunctionDef tabDef = (WindowTableFunctionDef) getTableDef();
 
@@ -408,7 +407,8 @@ public void startPartition() throws HiveException {
       }
     }
 
-    if (fnEval instanceof ISupportStreamingModeForWindowing) {
+    if (fnEval != null &&
+        fnEval instanceof ISupportStreamingModeForWindowing) {
       fnEval.aggregate(streamingState.aggBuffers[i], streamingState.funcArgs[i]);
       Object out = ((ISupportStreamingModeForWindowing) fnEval)
           .getNextResult(streamingState.aggBuffers[i]);
@@ -472,7 +472,8 @@ public void startPartition() throws HiveException {
       GenericUDAFEvaluator fnEval = wFn.getWFnEval();
 
       int numRowsRemaining = wFn.getWindowFrame().getEnd().getRelativeOffset();
-      if (fnEval instanceof ISupportStreamingModeForWindowing) {
+      if (fnEval != null &&
+          fnEval instanceof ISupportStreamingModeForWindowing) {
         fnEval.terminate(streamingState.aggBuffers[i]);
 
         WindowingFunctionInfoHelper wFnInfo = getWindowingFunctionInfoHelper(wFn.getName());
diff --git a/ql/src/test/queries/clientpositive/windowing_distinct.q b/ql/src/test/queries/clientpositive/windowing_distinct.q
index bb192a7..6b49978 100644
--- a/ql/src/test/queries/clientpositive/windowing_distinct.q
+++ b/ql/src/test/queries/clientpositive/windowing_distinct.q
@@ -44,3 +44,21 @@ SELECT AVG(DISTINCT t) OVER (PARTITION BY index),
        AVG(DISTINCT ts) OVER (PARTITION BY index),
        AVG(DISTINCT dec) OVER (PARTITION BY index)
 FROM windowing_distinct;
+
+-- count
+select index, f, count(distinct f) over (partition by index order by f rows between 2 preceding and 1 preceding),
+       count(distinct f) over (partition by index order by f rows between unbounded preceding and 1 preceding),
+       count(distinct f) over (partition by index order by f rows between 1 following and 2 following),
+       count(distinct f) over (partition by index order by f rows between unbounded preceding and 1 following) from windowing_distinct;
+
+-- sum
+select index, f, sum(distinct f) over (partition by index order by f rows between 2 preceding and 1 preceding),
+       sum(distinct f) over (partition by index order by f rows between unbounded preceding and 1 preceding),
+       sum(distinct f) over (partition by index order by f rows between 1 following and 2 following),
+       sum(distinct f) over (partition by index order by f rows between unbounded preceding and 1 following) from windowing_distinct;
+
+-- avg
+select index, f, avg(distinct f) over (partition by index order by f rows between 2 preceding and 1 preceding),
+       avg(distinct f) over (partition by index order by f rows between unbounded preceding and 1 preceding),
+       avg(distinct f) over (partition by index order by f rows between 1 following and 2 following),
+       avg(distinct f) over (partition by index order by f rows between unbounded preceding and 1 following) from windowing_distinct;
diff --git a/ql/src/test/results/clientpositive/windowing_distinct.q.out b/ql/src/test/results/clientpositive/windowing_distinct.q.out
index 074a594..86d1cdd 100644
--- a/ql/src/test/results/clientpositive/windowing_distinct.q.out
+++ b/ql/src/test/results/clientpositive/windowing_distinct.q.out
@@ -128,3 +128,69 @@ POSTHOOK: Input: default@windowing_distinct
 117.5 38.71 NULL NULL 1.362157918703306E9 34.5000
 117.5 38.71 NULL NULL 1.362157918703306E9 34.5000
 117.5 38.71 NULL NULL 1.362157918703306E9 34.5000
+PREHOOK: query: -- count
+select index, f, count(distinct f) over (partition by index order by f rows between 2 preceding and 1 preceding),
+       count(distinct f) over (partition by index order by f rows between unbounded preceding and 1 preceding),
+       count(distinct f) over (partition by index order by f rows between 1 following and 2 following),
+       count(distinct f) over (partition by index order by f rows between unbounded preceding and 1 following) from windowing_distinct
+PREHOOK: type: QUERY
+PREHOOK: Input: default@windowing_distinct
+#### A masked pattern was here ####
+POSTHOOK: query: -- count
+select index, f, count(distinct f) over (partition by index order by f rows between 2 preceding and 1 preceding),
+       count(distinct f) over (partition by index order by f rows between unbounded preceding and 1 preceding),
+       count(distinct f) over (partition by index order by f rows between 1 following and 2 following),
+       count(distinct f) over (partition by index order by f rows between unbounded preceding and 1 following) from windowing_distinct
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@windowing_distinct
+#### A masked pattern was here ####
+1 26.43 0 0 2 1
+1 26.43 1 1 1 2
+1 96.91 1 1 0 2
+2 13.01 0 0 1 2
+2 74.72 1 1 1 2
+2 74.72 2 2 0 2
+PREHOOK: query: -- sum
+select index, f, sum(distinct f) over (partition by index order by f rows between 2 preceding and 1 preceding),
+       sum(distinct f) over (partition by index order by f rows between unbounded preceding and 1 preceding),
+       sum(distinct f) over (partition by index order by f rows between 1 following and 2 following),
+       sum(distinct f) over (partition by index order by f rows between unbounded preceding and 1 following) from windowing_distinct
+PREHOOK: type: QUERY
+PREHOOK: Input: default@windowing_distinct
+#### A masked pattern was here ####
+POSTHOOK: query: -- sum
+select index, f, sum(distinct f) over (partition by index order by f rows between 2 preceding and 1 preceding),
+       sum(distinct f) over (partition by index order by f rows between unbounded preceding and 1 preceding),
+       sum(distinct f) over (partition by index order by f rows between 1 following and 2 following),
+       sum(distinct f) over (partition by index order by f rows between unbounded preceding and 1 following) from windowing_distinct
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@windowing_distinct
+#### A masked pattern was here ####
+1 26.43 NULL NULL 123.34000396728516 26.43000030517578
+1 26.43 26.43000030517578 26.43000030517578 96.91000366210938 123.34000396728516
+1 96.91 26.43000030517578 26.43000030517578 NULL 123.34000396728516
+2 13.01 NULL NULL 74.72000122070312 87.73000144958496
+2 74.72 13.010000228881836 13.010000228881836 74.72000122070312 87.73000144958496
+2 74.72 87.73000144958496 87.73000144958496 NULL 87.73000144958496
+PREHOOK: query: -- avg
+select index, f, avg(distinct f) over (partition by index order by f rows between 2 preceding and 1 preceding),
+       avg(distinct f) over (partition by index order by f rows between unbounded preceding and 1 preceding),
+       avg(distinct f) over (partition by index order by f rows between 1 following and 2 following),
+       avg(distinct f) over (partition by index order by f rows between unbounded preceding and 1 following) from windowing_distinct
+PREHOOK: type: QUERY
+PREHOOK: Input: default@windowing_distinct
+#### A masked pattern was here ####
+POSTHOOK: query: -- avg
+select index, f, avg(distinct f) over (partition by index order by f rows between 2 preceding and 1 preceding),
+       avg(distinct f) over (partition by index order by f rows between unbounded preceding and 1 preceding),
+       avg(distinct f) over (partition by index order by f rows between 1 following and 2 following),
+       avg(distinct f) over (partition by index order by f rows between unbounded preceding and 1 following) from windowing_distinct
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@windowing_distinct
+#### A masked pattern was here ####
+1 26.43 NULL NULL 61.67000198364258 26.43000030517578
+1 26.43 26.43000030517578 26.43000030517578 96.91000366210938 61.67000198364258
+1 96.91 26.43000030517578 26.43000030517578 NULL 61.67000198364258
+2 13.01 NULL NULL 74.72000122070312 43.86500072479248
+2 74.72 13.010000228881836 13.010000228881836 74.72000122070312 43.86500072479248
+2 74.72 43.86500072479248 43.86500072479248 NULL 43.86500072479248
diff --git a/serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/ObjectInspectorUtils.java b/serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/ObjectInspectorUtils.java
index c58e8ed..1ac72c6 100644
--- a/serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/ObjectInspectorUtils.java
+++ b/serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/ObjectInspectorUtils.java
@@ -117,6 +117,44 @@
   }
 
   /**
+   * This class can be used to wrap Hive objects and put them in a HashMap or HashSet.
+   * The wrapped objects are compared using their ObjectInspectors.
+   *
+   */
+  public static class ObjectInspectorObject {
+    private final Object[] objects;
+    private final ObjectInspector[] oi;
+
+    public ObjectInspectorObject(Object object, ObjectInspector oi) {
+      this.objects = new Object[] { object };
+      this.oi = new ObjectInspector[] { oi };
+    }
+
+    public ObjectInspectorObject(Object[] objects, ObjectInspector[] oi) {
+      this.objects = objects;
+      this.oi = oi;
+    }
+
+    public Object[] getValues() {
+      return objects;
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+      if (this == obj) return true;
+      if (obj == null || obj.getClass() != this.getClass()) { return false; }
+
+      ObjectInspectorObject comparedObject = (ObjectInspectorObject) obj;
+      return ObjectInspectorUtils.compare(objects, oi, comparedObject.objects, comparedObject.oi) == 0;
+    }
+
+    @Override
+    public int hashCode() {
+      return ObjectInspectorUtils.getBucketHashCode(objects, oi);
+    }
+  }
+
+  /**
    * Calculates the hash code for array of Objects that contains writables. This is used
    * to work around the buggy Hadoop DoubleWritable hashCode implementation. This should
    * only be used for process-local hash codes; don't replace stored hash codes like bucketing.
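
For reference, a usage sketch for the new wrapper (it assumes hive-serde with this patch applied is on the classpath; javaDoubleObjectInspector is an existing factory constant): two wrappers around equal values compare equal and hash alike, so a HashSet collapses them even when the raw objects would not compare cleanly across writable and plain Java representations.

    import java.util.HashSet;
    import java.util.Set;

    import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
    import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorObject;
    import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

    public class ObjectInspectorObjectDemo {
      public static void main(String[] args) {
        ObjectInspector oi = PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;

        // equals()/hashCode() delegate to ObjectInspectorUtils.compare() and
        // getBucketHashCode(), so the set sees the two 26.43 entries as one.
        Set<ObjectInspectorObject> unique = new HashSet<ObjectInspectorObject>();
        unique.add(new ObjectInspectorObject(Double.valueOf(26.43), oi));
        unique.add(new ObjectInspectorObject(Double.valueOf(26.43), oi));
        unique.add(new ObjectInspectorObject(Double.valueOf(96.91), oi));

        System.out.println(unique.size()); // 2
      }
    }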