diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java
index 5ad5c06..f163b05 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java
@@ -19,18 +19,22 @@
 import java.util.ArrayList;
 import java.util.HashSet;
+import java.util.List;
 
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.hadoop.hive.common.type.HiveDecimal;
 import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.PTFPartition;
 import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.plan.ptf.PTFExpressionDef;
 import org.apache.hadoop.hive.ql.plan.ptf.WindowFrameDef;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AbstractAggregationBuffer;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationType;
+import org.apache.hadoop.hive.ql.udf.ptf.BasePartitionEvaluator;
 import org.apache.hadoop.hive.ql.util.JavaDataModel;
 import org.apache.hadoop.hive.serde2.io.DoubleWritable;
 import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;
@@ -219,6 +223,19 @@ protected DoubleWritable getNextResult(
       };
     }
+
+    @Override
+    protected BasePartitionEvaluator createPartitionEvaluator(
+        WindowFrameDef winFrame,
+        PTFPartition partition,
+        List<PTFExpressionDef> parameters,
+        ObjectInspector outputOI) {
+      try {
+        return new BasePartitionEvaluator.AvgPartitionDoubleEvaluator(this, winFrame, partition, parameters, inputOI, outputOI);
+      } catch (HiveException e) {
+        return super.createPartitionEvaluator(winFrame, partition, parameters, outputOI);
+      }
+    }
   }
 
   public static class GenericUDAFAverageEvaluatorDecimal extends AbstractGenericUDAFAverageEvaluator {
@@ -358,6 +375,19 @@ protected HiveDecimalWritable getNextResult(
       };
     }
+
+    @Override
+    protected BasePartitionEvaluator createPartitionEvaluator(
+        WindowFrameDef winFrame,
+        PTFPartition partition,
+        List<PTFExpressionDef> parameters,
+        ObjectInspector outputOI) {
+      try {
+        return new BasePartitionEvaluator.AvgPartitionHiveDecimalEvaluator(this, winFrame, partition, parameters, inputOI, outputOI);
+      } catch (HiveException e) {
+        return super.createPartitionEvaluator(winFrame, partition, parameters, outputOI);
+      }
+    }
   }
 
   @AggregationType(estimable = true)
@@ -409,6 +439,8 @@ public ObjectInspector init(Mode m, ObjectInspector[] parameters)
       super.init(m, parameters);
 
       // init input
+      partialResult = new Object[2];
+      partialResult[0] = new LongWritable(0);
       if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {
         inputOI = (PrimitiveObjectInspector) parameters[0];
         copiedOI = (PrimitiveObjectInspector)ObjectInspectorUtils.getStandardObjectInspector(inputOI,
@@ -436,8 +468,6 @@ public ObjectInspector init(Mode m, ObjectInspector[] parameters)
         fname.add("count");
         fname.add("sum");
         fname.add("input");
-        partialResult = new Object[2];
-        partialResult[0] = new LongWritable(0);
         // index 1 set by child
         return ObjectInspectorFactory.getStandardStructObjectInspector(fname, foi);
       } else {
@@ -445,7 +475,7 @@ public ObjectInspector init(Mode m, ObjectInspector[] parameters)
       }
     }
 
-    protected boolean isWindowingDistinct() {
+    public boolean isWindowingDistinct() {
       return isWindowing && avgDistinct;
     }
 
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/BasePartitionEvaluator.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/BasePartitionEvaluator.java
index f5f9f7b..19d0864 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/BasePartitionEvaluator.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/BasePartitionEvaluator.java
@@ -19,6 +19,7 @@
 import java.util.List;
 
+import org.apache.hadoop.hive.common.type.HiveDecimal;
 import org.apache.hadoop.hive.ql.exec.PTFOperator;
 import org.apache.hadoop.hive.ql.exec.PTFPartition;
 import org.apache.hadoop.hive.ql.exec.PTFPartition.PTFPartitionIterator;
@@ -30,9 +31,12 @@
 import org.apache.hadoop.hive.ql.plan.ptf.BoundaryDef;
 import org.apache.hadoop.hive.ql.plan.ptf.PTFExpressionDef;
 import org.apache.hadoop.hive.ql.plan.ptf.WindowFrameDef;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage.AbstractGenericUDAFAverageEvaluator;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AbstractAggregationBuffer;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.Mode;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFSum;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFSum.GenericUDAFSumEvaluator;
 import org.apache.hadoop.hive.serde2.io.DoubleWritable;
 import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;
@@ -84,6 +88,96 @@ public int getSize() {
     }
   }
 
+  /**
+   * Defines the type-specific operations used by the subclasses.
+   */
+  private static abstract class TypeOperationBase<ResultType> {
+    public abstract ResultType add(ResultType t1, ResultType t2);
+    public abstract ResultType minus(ResultType t1, ResultType t2);
+    public abstract ResultType div(ResultType sum, long numRows);
+  }
+
+  private static class TypeOperationLongWritable extends TypeOperationBase<LongWritable> {
+    @Override
+    public LongWritable add(LongWritable t1, LongWritable t2) {
+      if (t1 == null && t2 == null) return null;
+      return new LongWritable((t1 == null ? 0 : t1.get()) + (t2 == null ? 0 : t2.get()));
+    }
+
+    @Override
+    public LongWritable minus(LongWritable t1, LongWritable t2) {
+      if (t1 == null && t2 == null) return null;
+      return new LongWritable((t1 == null ? 0 : t1.get()) - (t2 == null ? 0 : t2.get()));
+    }
+
+    @Override
+    public LongWritable div(LongWritable sum, long numRows) {
+      return null; // Not used
+    }
+  }
+
+  private static class TypeOperationDoubleWritable extends TypeOperationBase<DoubleWritable> {
+    @Override
+    public DoubleWritable add(DoubleWritable t1, DoubleWritable t2) {
+      if (t1 == null && t2 == null) return null;
+      return new DoubleWritable((t1 == null ? 0 : t1.get()) + (t2 == null ? 0 : t2.get()));
+    }
+
+    @Override
+    public DoubleWritable minus(DoubleWritable t1, DoubleWritable t2) {
+      if (t1 == null && t2 == null) return null;
+      return new DoubleWritable((t1 == null ? 0 : t1.get()) - (t2 == null ? 0 : t2.get()));
+    }
+
+    @Override
+    public DoubleWritable div(DoubleWritable sum, long numRows) {
+      if (sum == null || numRows == 0) return null;
+
+      return new DoubleWritable(sum.get() / (double) numRows);
+    }
+  }
+
+  private static class TypeOperationHiveDecimalWritable extends TypeOperationBase<HiveDecimalWritable> {
+    @Override
+    public HiveDecimalWritable div(HiveDecimalWritable sum, long numRows) {
+      if (sum == null || numRows == 0) return null;
+
+      HiveDecimalWritable result = new HiveDecimalWritable(sum);
+      result.mutateDivide(HiveDecimal.create(numRows));
+      return result;
+    }
+
+    @Override
+    public HiveDecimalWritable add(HiveDecimalWritable t1, HiveDecimalWritable t2) {
+      if (t1 == null && t2 == null) return null;
+
+      if (t1 == null) {
+        return new HiveDecimalWritable(t2);
+      } else {
+        HiveDecimalWritable result = new HiveDecimalWritable(t1);
+        if (t2 != null) {
+          result.mutateAdd(t2);
+        }
+        return result;
+      }
+    }
+
+    @Override
+    public HiveDecimalWritable minus(HiveDecimalWritable t1, HiveDecimalWritable t2) {
+      if (t1 == null && t2 == null) return null;
+
+      if (t2 == null) {
+        return new HiveDecimalWritable(t1);
+      } else {
+        HiveDecimalWritable result = new HiveDecimalWritable(t2);
+        result.mutateNegate();
+        if (t1 != null) {
+          result.mutateAdd(t1);
+        }
+        return result;
+      }
+    }
+  }
+
   public BasePartitionEvaluator(
       GenericUDAFEvaluator wrappedEvaluator,
       WindowFrameDef winFrame,
@@ -217,7 +311,14 @@ private static int getRowBoundaryEnd(BoundaryDef b, int currRow, PTFPartition p)
    *
    */
   public static abstract class SumPartitionEvaluator<ResultType> extends BasePartitionEvaluator {
+    static class WindowSumAgg<ResultType> extends AbstractAggregationBuffer {
+      Range prevRange;
+      ResultType prevSum;
+      boolean empty;
+    }
+
     protected final WindowSumAgg<ResultType> sumAgg;
+    protected TypeOperationBase<ResultType> typeOperation;
 
     public SumPartitionEvaluator(
         GenericUDAFEvaluator wrappedEvaluator,
@@ -229,15 +330,6 @@ public SumPartitionEvaluator(
         WindowFrameDef winFrame,
         PTFPartition partition,
         List<PTFExpressionDef> parameters,
         ObjectInspector outputOI) {
       super(wrappedEvaluator, winFrame, partition, parameters, outputOI);
       sumAgg = new WindowSumAgg<ResultType>();
     }
 
-    static class WindowSumAgg<ResultType> extends AbstractAggregationBuffer {
-      Range prevRange;
-      ResultType prevSum;
-      boolean empty;
-    }
-
-    public abstract ResultType add(ResultType t1, ResultType t2);
-    public abstract ResultType minus(ResultType t1, ResultType t2);
-
     @SuppressWarnings({ "unchecked", "rawtypes" })
     @Override
     public Object iterate(int currentRow, LeadLagInfo leadLagInfo) throws HiveException {
@@ -262,7 +354,8 @@ public Object iterate(int currentRow, LeadLagInfo leadLagInfo) throws HiveExcept
         Range r2 = new Range(sumAgg.prevRange.end, currentRange.end, partition);
         ResultType sum1 = (ResultType)calcFunctionValue(r1.iterator(), leadLagInfo);
         ResultType sum2 = (ResultType)calcFunctionValue(r2.iterator(), leadLagInfo);
-        result = add(minus(sumAgg.prevSum, sum1), sum2);
+        result = typeOperation.add(typeOperation.minus(sumAgg.prevSum, sum1), sum2);
+
         sumAgg.prevRange = currentRange;
         sumAgg.prevSum = result;
       }
@@ -276,18 +369,7 @@ public SumPartitionDoubleEvaluator(GenericUDAFEvaluator wrappedEvaluator,
         WindowFrameDef winFrame, PTFPartition partition,
         List<PTFExpressionDef> parameters, ObjectInspector outputOI) {
       super(wrappedEvaluator, winFrame, partition, parameters, outputOI);
-    }
-
-    @Override
-    public DoubleWritable add(DoubleWritable t1, DoubleWritable t2) {
-      if (t1 == null && t2 == null) return null;
-      return new DoubleWritable((t1 == null ? 0 : t1.get()) + (t2 == null ? 0 : t2.get()));
-    }
-
-    @Override
-    public DoubleWritable minus(DoubleWritable t1, DoubleWritable t2) {
-      if (t1 == null && t2 == null) return null;
-      return new DoubleWritable((t1 == null ? 0 : t1.get()) - (t2 == null ? 0 : t2.get()));
+      this.typeOperation = new TypeOperationDoubleWritable();
     }
   }
 
@@ -296,18 +378,7 @@ public SumPartitionLongEvaluator(GenericUDAFEvaluator wrappedEvaluator,
         WindowFrameDef winFrame, PTFPartition partition,
         List<PTFExpressionDef> parameters, ObjectInspector outputOI) {
       super(wrappedEvaluator, winFrame, partition, parameters, outputOI);
-    }
-
-    @Override
-    public LongWritable add(LongWritable t1, LongWritable t2) {
-      if (t1 == null && t2 == null) return null;
-      return new LongWritable((t1 == null ? 0 : t1.get()) + (t2 == null ? 0 : t2.get()));
-    }
-
-    @Override
-    public LongWritable minus(LongWritable t1, LongWritable t2) {
-      if (t1 == null && t2 == null) return null;
-      return new LongWritable((t1 == null ? 0 : t1.get()) - (t2 == null ? 0 : t2.get()));
+      this.typeOperation = new TypeOperationLongWritable();
     }
   }
 
@@ -316,33 +387,119 @@ public SumPartitionHiveDecimalEvaluator(GenericUDAFEvaluator wrappedEvaluator,
         WindowFrameDef winFrame, PTFPartition partition,
         List<PTFExpressionDef> parameters, ObjectInspector outputOI) {
       super(wrappedEvaluator, winFrame, partition, parameters, outputOI);
+      this.typeOperation = new TypeOperationHiveDecimalWritable();
     }
+  }
 
-    @Override
-    public HiveDecimalWritable add(HiveDecimalWritable t1, HiveDecimalWritable t2) {
-      if (t1 == null && t2 == null) return null;
-      if (t1 == null) {
-        return t2;
-      } else {
-        if (t2 != null) {
-          t1.mutateAdd(t2);
+  /**
+   * The partition evaluator for the average function.
+   * @param <ResultType>
+   */
+  public static abstract class AvgPartitionEvaluator<ResultType>
+      extends BasePartitionEvaluator {
+    static class WindowAvgAgg<ResultType> extends AbstractAggregationBuffer {
+      Range prevRange;
+      ResultType prevSum;
+      long prevCount;
+      boolean empty;
+    }
+
+    protected SumPartitionEvaluator<ResultType> sumEvaluator;
+    protected TypeOperationBase<ResultType> typeOperation;
+    WindowAvgAgg<ResultType> avgAgg = new WindowAvgAgg<ResultType>();
+
+    public AvgPartitionEvaluator(
+        GenericUDAFEvaluator wrappedEvaluator,
+        WindowFrameDef winFrame,
+        PTFPartition partition,
+        List<PTFExpressionDef> parameters,
+        ObjectInspector outputOI) {
+      super(wrappedEvaluator, winFrame, partition, parameters, outputOI);
+    }
+
+    /**
+     * Calculate the partial result (sum and count) for a given partition range.
+     * @return a 2-element Object array of [count long, sum ResultType]
+     */
+    private Object[] calcPartialResult(PTFPartitionIterator<Object> pItr, LeadLagInfo leadLagInfo)
+        throws HiveException {
+      // To handle cases like SUM(LAG(f)) OVER(), where the aggregation function includes a
+      // LAG/LEAD call
+      PTFOperator.connectLeadLagFunctionsToPartition(leadLagInfo, pItr);
+
+      AggregationBuffer aggBuffer = wrappedEvaluator.getNewAggregationBuffer();
+      Object[] argValues = new Object[parameters == null ? 0 : parameters.size()];
+      while (pItr.hasNext()) {
+        Object row = pItr.next();
+        int i = 0;
+        if (parameters != null) {
+          for (PTFExpressionDef param : parameters) {
+            argValues[i++] = param.getExprEvaluator().evaluate(row);
+          }
         }
-        return t1;
+        wrappedEvaluator.aggregate(aggBuffer, argValues);
       }
+
+      // The object [count LongWritable, sum ResultType] is reused during evaluation
+      Object[] partial = (Object[]) wrappedEvaluator.terminatePartial(aggBuffer);
+      return new Object[] {((LongWritable) partial[0]).get(), ObjectInspectorUtils.copyToStandardObject(partial[1], outputOI)};
     }
 
+    @SuppressWarnings({ "unchecked", "rawtypes" })
     @Override
-    public HiveDecimalWritable minus(HiveDecimalWritable t1, HiveDecimalWritable t2) {
-      if (t1 == null && t2 == null) return null;
-      if (t1 == null) {
-        t2.mutateNegate();
-        return t2;
+    public Object iterate(int currentRow, LeadLagInfo leadLagInfo) throws HiveException {
+      // Currently avg(distinct) is not supported in the PartitionEvaluator
+      if (((AbstractGenericUDAFAverageEvaluator) wrappedEvaluator).isWindowingDistinct()) {
+        return super.iterate(currentRow, leadLagInfo);
+      }
+
+      Range currentRange = getRange(winFrame, currentRow, partition);
+      if (currentRow == 0 ||  // Reset for the new partition
+          avgAgg.prevRange == null ||
+          currentRange.getSize() <= currentRange.getDiff(avgAgg.prevRange)) {
+        Object[] partial = (Object[]) calcPartialResult(currentRange.iterator(), leadLagInfo);
+        avgAgg.prevRange = currentRange;
+        avgAgg.empty = false;
+        avgAgg.prevSum = (ResultType) partial[1];
+        avgAgg.prevCount = (long) partial[0];
       } else {
-        if (t2 != null) {
-          t1.mutateSubtract(t2);
-        }
-        return t1;
+        // Given the previous range and the current range, calculate the new sum
+        // from the previous sum and the difference, to save computation.
+        Range r1 = new Range(avgAgg.prevRange.start, currentRange.start, partition);
+        Range r2 = new Range(avgAgg.prevRange.end, currentRange.end, partition);
+        Object[] partial1 = (Object[]) calcPartialResult(r1.iterator(), leadLagInfo);
+        Object[] partial2 = (Object[]) calcPartialResult(r2.iterator(), leadLagInfo);
+        ResultType sum = typeOperation.add(typeOperation.minus(avgAgg.prevSum, (ResultType) partial1[1]), (ResultType) partial2[1]);
+        long count = avgAgg.prevCount - (long) partial1[0] + (long) partial2[0];
+
+        avgAgg.prevRange = currentRange;
+        avgAgg.prevSum = sum;
+        avgAgg.prevCount = count;
       }
+
+      return typeOperation.div(avgAgg.prevSum, avgAgg.prevCount);
+    }
+  }
+
+  public static class AvgPartitionDoubleEvaluator extends AvgPartitionEvaluator<DoubleWritable> {
+
+    public AvgPartitionDoubleEvaluator(GenericUDAFEvaluator wrappedEvaluator,
+        WindowFrameDef winFrame, PTFPartition partition,
+        List<PTFExpressionDef> parameters, ObjectInspector inputOI, ObjectInspector outputOI) throws HiveException {
+      super(wrappedEvaluator, winFrame, partition, parameters, outputOI);
+      this.typeOperation = new TypeOperationDoubleWritable();
+    }
+  }
+
+  public static class AvgPartitionHiveDecimalEvaluator extends AvgPartitionEvaluator<HiveDecimalWritable> {
+
+    public AvgPartitionHiveDecimalEvaluator(GenericUDAFEvaluator wrappedEvaluator,
+        WindowFrameDef winFrame, PTFPartition partition,
+        List<PTFExpressionDef> parameters, ObjectInspector inputOI, ObjectInspector outputOI) throws HiveException {
+      super(wrappedEvaluator, winFrame, partition, parameters, outputOI);
+      this.typeOperation = new TypeOperationHiveDecimalWritable();
    }
+  }
 }
\ No newline at end of file
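
Illustrative note (not part of the patch): the sliding-frame optimization that SumPartitionEvaluator.iterate() and AvgPartitionEvaluator.iterate() implement above can be shown on a plain double[] with half-open row ranges instead of PTFPartition/Range. When the current frame largely overlaps the previous one, the previous sum is reused and only the rows that left or entered the frame are subtracted or added; otherwise the sum is recomputed from scratch, mirroring the currentRange.getSize() <= currentRange.getDiff(prevRange) test. The class and method names below (SlidingAvgSketch, slidingAverage, sumRange) are hypothetical and assume a frame that only moves forward, as in the windowing case handled by the patch.

// Minimal standalone sketch of the incremental windowed-average computation.
public class SlidingAvgSketch {

  // Sum of rows in the half-open range [start, end).
  static double sumRange(double[] rows, int start, int end) {
    double s = 0;
    for (int i = start; i < end; i++) {
      s += rows[i];
    }
    return s;
  }

  /**
   * Average over the current frame [curStart, curEnd), reusing the previous frame's sum:
   * newSum = prevSum - sum(prevStart..curStart) + sum(prevEnd..curEnd).
   * Falls back to a full recomputation when the frames barely overlap.
   * Assumes curStart >= prevStart and curEnd >= prevEnd (forward-moving frame).
   */
  static double slidingAverage(double[] rows, int prevStart, int prevEnd, double prevSum,
      int curStart, int curEnd) {
    int curSize = curEnd - curStart;
    int diff = (curStart - prevStart) + (curEnd - prevEnd);
    double sum;
    if (curSize <= diff) {
      // Cheaper to recompute the whole frame from scratch.
      sum = sumRange(rows, curStart, curEnd);
    } else {
      sum = prevSum
          - sumRange(rows, prevStart, curStart)   // rows that left the frame
          + sumRange(rows, prevEnd, curEnd);      // rows that entered the frame
    }
    return sum / curSize;
  }

  public static void main(String[] args) {
    double[] rows = {1, 2, 3, 4, 5, 6};
    double prevSum = sumRange(rows, 0, 3);        // frame [0,3) -> sum 6.0
    // Slide the frame to [1,4): expected average = (2 + 3 + 4) / 3 = 3.0
    System.out.println(slidingAverage(rows, 0, 3, prevSum, 1, 4));
  }
}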