diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java
index 5ad5c06..a83fd11 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java
@@ -19,18 +19,22 @@
 import java.util.ArrayList;
 import java.util.HashSet;
+import java.util.List;
 
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.apache.hadoop.hive.common.type.HiveDecimal;
 import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.PTFPartition;
 import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.plan.ptf.PTFExpressionDef;
 import org.apache.hadoop.hive.ql.plan.ptf.WindowFrameDef;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AbstractAggregationBuffer;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationType;
+import org.apache.hadoop.hive.ql.udf.ptf.BasePartitionEvaluator;
 import org.apache.hadoop.hive.ql.util.JavaDataModel;
 import org.apache.hadoop.hive.serde2.io.DoubleWritable;
 import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;
@@ -219,6 +223,19 @@ protected DoubleWritable getNextResult(
       };
     }
 
+
+    @Override
+    protected BasePartitionEvaluator createPartitionEvaluator(
+        WindowFrameDef winFrame,
+        PTFPartition partition,
+        List<PTFExpressionDef> parameters,
+        ObjectInspector outputOI) {
+      try {
+        return new BasePartitionEvaluator.AvgPartitionDoubleEvaluator(this, winFrame, partition, parameters, inputOI, outputOI);
+      } catch (HiveException e) {
+        return super.createPartitionEvaluator(winFrame, partition, parameters, outputOI);
+      }
+    }
   }
 
   public static class GenericUDAFAverageEvaluatorDecimal extends AbstractGenericUDAFAverageEvaluator {
@@ -358,6 +375,19 @@ protected HiveDecimalWritable getNextResult(
       };
     }
 
+
+    @Override
+    protected BasePartitionEvaluator createPartitionEvaluator(
+        WindowFrameDef winFrame,
+        PTFPartition partition,
+        List<PTFExpressionDef> parameters,
+        ObjectInspector outputOI) {
+      try {
+        return new BasePartitionEvaluator.AvgPartitionHiveDecimalEvaluator(this, winFrame, partition, parameters, inputOI, outputOI);
+      } catch (HiveException e) {
+        return super.createPartitionEvaluator(winFrame, partition, parameters, outputOI);
+      }
+    }
   }
 
   @AggregationType(estimable = true)
@@ -445,7 +475,7 @@ public ObjectInspector init(Mode m, ObjectInspector[] parameters)
       }
     }
 
-    protected boolean isWindowingDistinct() {
+    public boolean isWindowingDistinct() {
       return isWindowing && avgDistinct;
     }
 
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/BasePartitionEvaluator.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/BasePartitionEvaluator.java
index f5f9f7b..7578d84 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/BasePartitionEvaluator.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/BasePartitionEvaluator.java
@@ -19,6 +19,7 @@
 import java.util.List;
 
+import org.apache.hadoop.hive.common.type.HiveDecimal;
 import org.apache.hadoop.hive.ql.exec.PTFOperator;
 import org.apache.hadoop.hive.ql.exec.PTFPartition;
 import org.apache.hadoop.hive.ql.exec.PTFPartition.PTFPartitionIterator;
@@ -30,9 +31,12 @@
 import org.apache.hadoop.hive.ql.plan.ptf.BoundaryDef;
 import org.apache.hadoop.hive.ql.plan.ptf.PTFExpressionDef;
 import org.apache.hadoop.hive.ql.plan.ptf.WindowFrameDef;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage.AbstractGenericUDAFAverageEvaluator;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AbstractAggregationBuffer;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.Mode;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFSum;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFSum.GenericUDAFSumEvaluator;
 import org.apache.hadoop.hive.serde2.io.DoubleWritable;
 import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;
@@ -84,6 +88,96 @@ public int getSize() {
     }
   }
 
+  /**
+   * Define some type-specific operations to be used in the subclasses
+   */
+  private static abstract class TypeOperationBase<ResultType> {
+    public abstract ResultType add(ResultType t1, ResultType t2);
+    public abstract ResultType minus(ResultType t1, ResultType t2);
+    public abstract ResultType div(ResultType sum, int numRows);
+  }
+
+  private static class TypeOperationLongWritable extends TypeOperationBase<LongWritable> {
+    @Override
+    public LongWritable add(LongWritable t1, LongWritable t2) {
+      if (t1 == null && t2 == null) return null;
+      return new LongWritable((t1 == null ? 0 : t1.get()) + (t2 == null ? 0 : t2.get()));
+    }
+
+    @Override
+    public LongWritable minus(LongWritable t1, LongWritable t2) {
+      if (t1 == null && t2 == null) return null;
+      return new LongWritable((t1 == null ? 0 : t1.get()) - (t2 == null ? 0 : t2.get()));
+    }
+
+    @Override
+    public LongWritable div(LongWritable sum, int numRows) {
+      return null; // Not used
+    }
+  }
+
+  private static class TypeOperationDoubleWritable extends TypeOperationBase<DoubleWritable> {
+    @Override
+    public DoubleWritable add(DoubleWritable t1, DoubleWritable t2) {
+      if (t1 == null && t2 == null) return null;
+      return new DoubleWritable((t1 == null ? 0 : t1.get()) + (t2 == null ? 0 : t2.get()));
+    }
+
+    @Override
+    public DoubleWritable minus(DoubleWritable t1, DoubleWritable t2) {
+      if (t1 == null && t2 == null) return null;
+      return new DoubleWritable((t1 == null ? 0 : t1.get()) - (t2 == null ? 0 : t2.get()));
+    }
+
+    @Override
+    public DoubleWritable div(DoubleWritable sum, int numRows) {
+      if (sum == null) return null;
+
+      return new DoubleWritable(sum.get() / (double) numRows);
+    }
+  }
+
+  private static class TypeOperationHiveDecimalWritable extends TypeOperationBase<HiveDecimalWritable> {
+    @Override
+    public HiveDecimalWritable div(HiveDecimalWritable sum, int numRows) {
+      if (sum == null) return null;
+
+      HiveDecimalWritable result = new HiveDecimalWritable(sum);
+      result.mutateDivide(HiveDecimal.create(numRows));
+      return result;
+    }
+
+    @Override
+    public HiveDecimalWritable add(HiveDecimalWritable t1, HiveDecimalWritable t2) {
+      if (t1 == null && t2 == null) return null;
+
+      if (t1 == null) {
+        return new HiveDecimalWritable(t2);
+      } else {
+        HiveDecimalWritable result = new HiveDecimalWritable(t1);
+        if (t2 != null) {
+          result.mutateAdd(t2);
+        }
+        return result;
+      }
+    }
+
+    @Override
+    public HiveDecimalWritable minus(HiveDecimalWritable t1, HiveDecimalWritable t2) {
+      if (t1 == null && t2 == null) return null;
+
+      if (t2 == null) {
+        return new HiveDecimalWritable(t1);
+      } else {
+        HiveDecimalWritable result = new HiveDecimalWritable(t2);
+        result.mutateNegate();
+        if (t1 != null) {
+          result.mutateAdd(t1);
+        }
+        return result;
+      }
+    }
+  }
+
   public BasePartitionEvaluator(
       GenericUDAFEvaluator wrappedEvaluator,
       WindowFrameDef winFrame,
@@ -218,6 +312,7 @@ private static int getRowBoundaryEnd(BoundaryDef b, int currRow, PTFPartition p)
    */
   public static abstract class SumPartitionEvaluator<ResultType> extends BasePartitionEvaluator {
     protected final WindowSumAgg<ResultType> sumAgg;
+    protected TypeOperationBase<ResultType> typeOperation;
 
     public SumPartitionEvaluator(
         GenericUDAFEvaluator wrappedEvaluator,
@@ -235,9 +330,6 @@ public SumPartitionEvaluator(
       boolean empty;
     }
 
-    public abstract ResultType add(ResultType t1, ResultType t2);
-    public abstract ResultType minus(ResultType t1, ResultType t2);
-
     @SuppressWarnings({ "unchecked", "rawtypes" })
     @Override
     public Object iterate(int currentRow, LeadLagInfo leadLagInfo) throws HiveException {
@@ -262,7 +354,8 @@ public Object iterate(int currentRow, LeadLagInfo leadLagInfo) throws HiveExcept
         Range r2 = new Range(sumAgg.prevRange.end, currentRange.end, partition);
         ResultType sum1 = (ResultType)calcFunctionValue(r1.iterator(), leadLagInfo);
         ResultType sum2 = (ResultType)calcFunctionValue(r2.iterator(), leadLagInfo);
-        result = add(minus(sumAgg.prevSum, sum1), sum2);
+        result = typeOperation.add(typeOperation.minus(sumAgg.prevSum, sum1), sum2);
+
         sumAgg.prevRange = currentRange;
         sumAgg.prevSum = result;
       }
@@ -276,18 +369,7 @@ public SumPartitionDoubleEvaluator(GenericUDAFEvaluator wrappedEvaluator,
         WindowFrameDef winFrame, PTFPartition partition,
         List<PTFExpressionDef> parameters, ObjectInspector outputOI) {
       super(wrappedEvaluator, winFrame, partition, parameters, outputOI);
-    }
-
-    @Override
-    public DoubleWritable add(DoubleWritable t1, DoubleWritable t2) {
-      if (t1 == null && t2 == null) return null;
-      return new DoubleWritable((t1 == null ? 0 : t1.get()) + (t2 == null ? 0 : t2.get()));
-    }
-
-    @Override
-    public DoubleWritable minus(DoubleWritable t1, DoubleWritable t2) {
-      if (t1 == null && t2 == null) return null;
-      return new DoubleWritable((t1 == null ? 0 : t1.get()) - (t2 == null ? 0 : t2.get()));
+      this.typeOperation = new TypeOperationDoubleWritable();
     }
   }
 
@@ -296,18 +378,7 @@ public SumPartitionLongEvaluator(GenericUDAFEvaluator wrappedEvaluator,
         WindowFrameDef winFrame, PTFPartition partition,
         List<PTFExpressionDef> parameters, ObjectInspector outputOI) {
       super(wrappedEvaluator, winFrame, partition, parameters, outputOI);
-    }
-
-    @Override
-    public LongWritable add(LongWritable t1, LongWritable t2) {
-      if (t1 == null && t2 == null) return null;
-      return new LongWritable((t1 == null ? 0 : t1.get()) + (t2 == null ? 0 : t2.get()));
-    }
-
-    @Override
-    public LongWritable minus(LongWritable t1, LongWritable t2) {
-      if (t1 == null && t2 == null) return null;
-      return new LongWritable((t1 == null ? 0 : t1.get()) - (t2 == null ? 0 : t2.get()));
+      this.typeOperation = new TypeOperationLongWritable();
     }
   }
 
@@ -316,33 +387,69 @@ public SumPartitionHiveDecimalEvaluator(GenericUDAFEvaluator wrappedEvaluator,
         WindowFrameDef winFrame, PTFPartition partition,
         List<PTFExpressionDef> parameters, ObjectInspector outputOI) {
       super(wrappedEvaluator, winFrame, partition, parameters, outputOI);
+      this.typeOperation = new TypeOperationHiveDecimalWritable();
     }
+  }
 
-    @Override
-    public HiveDecimalWritable add(HiveDecimalWritable t1, HiveDecimalWritable t2) {
-      if (t1 == null && t2 == null) return null;
-      if (t1 == null) {
-        return t2;
-      } else {
-        if (t2 != null) {
-          t1.mutateAdd(t2);
-        }
-        return t1;
-      }
-    }
+  /**
+   * The partition evaluator for the average function
+   * @param <ResultType>
+   */
+  public static abstract class AvgPartitionEvaluator<ResultType>
+      extends BasePartitionEvaluator {
+    protected SumPartitionEvaluator<ResultType> sumEvaluator;
+    protected TypeOperationBase<ResultType> typeOperation;
+
+    public AvgPartitionEvaluator(
+        GenericUDAFEvaluator wrappedEvaluator,
+        WindowFrameDef winFrame,
+        PTFPartition partition,
+        List<PTFExpressionDef> parameters,
+        ObjectInspector outputOI) {
+      super(wrappedEvaluator, winFrame, partition, parameters, outputOI);
+    }
 
+    @SuppressWarnings({ "unchecked", "rawtypes" })
     @Override
-    public HiveDecimalWritable minus(HiveDecimalWritable t1, HiveDecimalWritable t2) {
-      if (t1 == null && t2 == null) return null;
-      if (t1 == null) {
-        t2.mutateNegate();
-        return t2;
-      } else {
-        if (t2 != null) {
-          t1.mutateSubtract(t2);
-        }
-        return t1;
-      }
-    }
+    public Object iterate(int currentRow, LeadLagInfo leadLagInfo) throws HiveException {
+      // Currently avg(distinct) is not supported in the PartitionEvaluator
+      if (((AbstractGenericUDAFAverageEvaluator) wrappedEvaluator).isWindowingDistinct()) {
+        return super.iterate(currentRow, leadLagInfo);
+      }
+
+      // Use the SumPartitionEvaluator to calculate the sum, then sum / numRows to get the average
+      ResultType sum = (ResultType) sumEvaluator.iterate(currentRow, leadLagInfo);
+      Range currentRange = getRange(winFrame, currentRow, partition);
+      int numRows = currentRange.getSize();
+      return typeOperation.div(sum, numRows);
+    }
+  }
+
+  public static class AvgPartitionDoubleEvaluator extends AvgPartitionEvaluator<DoubleWritable> {
+
+    public AvgPartitionDoubleEvaluator(GenericUDAFEvaluator wrappedEvaluator,
+        WindowFrameDef winFrame, PTFPartition partition,
+        List<PTFExpressionDef> parameters, ObjectInspector inputOI, ObjectInspector outputOI) throws HiveException {
+      super(wrappedEvaluator, winFrame, partition, parameters, outputOI);
+
+      // Create a SumEvaluator to calculate the sum
+      GenericUDAFEvaluator wrappedSumEvaluator = new GenericUDAFSum.GenericUDAFSumDouble();
+      ObjectInspector sumOutputOI = wrappedSumEvaluator.init(Mode.COMPLETE, new ObjectInspector[] { inputOI });
+      this.sumEvaluator = new SumPartitionDoubleEvaluator(wrappedSumEvaluator, winFrame, partition, parameters, sumOutputOI);
+      this.typeOperation = new TypeOperationDoubleWritable();
+    }
+  }
 
+  public static class AvgPartitionHiveDecimalEvaluator extends AvgPartitionEvaluator<HiveDecimalWritable> {
+
+    public AvgPartitionHiveDecimalEvaluator(GenericUDAFEvaluator wrappedEvaluator,
+        WindowFrameDef winFrame, PTFPartition partition,
+        List<PTFExpressionDef> parameters, ObjectInspector inputOI, ObjectInspector outputOI) throws HiveException {
+      super(wrappedEvaluator, winFrame, partition, parameters, outputOI);
+
+      // Create a SumEvaluator to calculate the sum
+      GenericUDAFEvaluator wrappedSumEvaluator = new GenericUDAFSum.GenericUDAFSumHiveDecimal();
+      ObjectInspector sumOutputOI = wrappedSumEvaluator.init(Mode.COMPLETE, new ObjectInspector[] { inputOI });
+      this.sumEvaluator = new SumPartitionHiveDecimalEvaluator(wrappedSumEvaluator, winFrame, partition, parameters, sumOutputOI);
+      this.typeOperation = new TypeOperationHiveDecimalWritable();
+    }
   }
 }
\ No newline at end of file
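
For reviewers, the core technique of the patch in one self-contained example: instead of re-scanning the whole window frame for every row, the sum evaluator keeps the previous frame's sum and applies result = add(minus(prevSum, sum1), sum2), where sum1 covers the rows that slid out of the frame and sum2 the rows that slid in; the average evaluator then only divides by the frame's row count. The sketch below is illustrative, not part of the patch or the Hive API (SlidingAvgSketch and slidingAvg are made-up names), and assumes a ROWS-based frame that only moves forward:

import java.util.Arrays;

public class SlidingAvgSketch {

  // Windowed average over [i - preceding, i + following], reusing the previous
  // frame's sum: O(n) for the whole partition instead of O(n * frameSize).
  static double[] slidingAvg(double[] values, int preceding, int following) {
    double[] result = new double[values.length];
    double prevSum = 0;               // plays the role of sumAgg.prevSum
    int prevStart = 0, prevEnd = 0;   // previous frame as [start, end), like sumAgg.prevRange
    for (int i = 0; i < values.length; i++) {
      int start = Math.max(0, i - preceding);
      int end = Math.min(values.length, i + following + 1);
      // minus(prevSum, sum1): drop the rows that left the frame
      for (int j = prevStart; j < start; j++) prevSum -= values[j];
      // add(..., sum2): fold in the rows that entered the frame
      for (int j = prevEnd; j < end; j++) prevSum += values[j];
      prevStart = start;
      prevEnd = end;
      // div(sum, numRows): the only step AvgPartitionEvaluator adds on top of the sum
      result[i] = prevSum / (end - start);
    }
    return result;
  }

  public static void main(String[] args) {
    // Equivalent of: avg(x) over (rows between 1 preceding and 1 following)
    System.out.println(Arrays.toString(slidingAvg(new double[] {1, 2, 3, 4, 5}, 1, 1)));
    // -> [1.5, 2.0, 3.0, 4.0, 4.5]
  }
}

Note that, as in the patch, avg(distinct) cannot be computed incrementally this way, which is why AvgPartitionEvaluator.iterate falls back to super.iterate when isWindowingDistinct() returns true.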