diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/functions/HiveSqlSumAggFunction.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/functions/HiveSqlSumAggFunction.java index 056eaeb..498cd0e 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/functions/HiveSqlSumAggFunction.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/functions/HiveSqlSumAggFunction.java @@ -71,6 +71,9 @@ public HiveSqlSumAggFunction(boolean isDistinct, SqlReturnTypeInference returnTy //~ Methods ---------------------------------------------------------------- + public boolean isDistinct() { + return isDistinct; + } @Override public T unwrap(Class clazz) { diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/SqlFunctionConverter.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/SqlFunctionConverter.java index 75c38fa..19aa414 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/SqlFunctionConverter.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/SqlFunctionConverter.java @@ -229,6 +229,12 @@ public static ASTNode buildAST(SqlOperator op, List children) { "TOK_FUNCTIONDI"); } } + } else if (op instanceof HiveSqlSumAggFunction) { // case SUM(DISTINCT) + HiveSqlSumAggFunction sumFunction = (HiveSqlSumAggFunction) op; + if (sumFunction.isDistinct()) { + node = (ASTNode) ParseDriver.adaptor.create(HiveParser.TOK_FUNCTIONDI, + "TOK_FUNCTIONDI"); + } } node.addChild((ASTNode) ParseDriver.adaptor.create(HiveParser.Identifier, op.getName())); } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFCount.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFCount.java index f526c43..2825045 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFCount.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFCount.java @@ -17,13 +17,11 @@ */ package org.apache.hadoop.hive.ql.udf.generic; -import org.apache.commons.lang.ArrayUtils; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.parse.SemanticException; import org.apache.hadoop.hive.ql.util.JavaDataModel; -import org.apache.hadoop.hive.serde2.lazy.LazyString; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption; @@ -31,7 +29,6 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; import org.apache.hadoop.io.LongWritable; -import org.apache.hadoop.io.Text; /** * This class implements the COUNT aggregation function as in SQL. diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java index 0968008..7b1d6e5 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java @@ -24,14 +24,14 @@ import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.parse.SemanticException; -import org.apache.hadoop.hive.ql.parse.WindowingSpec.BoundarySpec; -import org.apache.hadoop.hive.ql.plan.ptf.BoundaryDef; import org.apache.hadoop.hive.ql.plan.ptf.WindowFrameDef; import org.apache.hadoop.hive.ql.util.JavaDataModel; import org.apache.hadoop.hive.serde2.io.DoubleWritable; import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; @@ -87,6 +87,17 @@ public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) } } + @Override + public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo info) + throws SemanticException { + TypeInfo[] parameters = info.getParameters(); + + GenericUDAFSumEvaluator eval = (GenericUDAFSumEvaluator) getEvaluator(parameters); + eval.setSumDistinct(info.isDistinct()); + + return eval; + } + public static PrimitiveObjectInspector.PrimitiveCategory getReturnType(TypeInfo type) { if (type.getCategory() != ObjectInspector.Category.PRIMITIVE) { return null; @@ -111,12 +122,54 @@ public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) } /** + * The base type for sum operator evaluator + * + */ + public static abstract class GenericUDAFSumEvaluator extends GenericUDAFEvaluator { + static abstract class SumAgg extends AbstractAggregationBuffer { + boolean empty; + T sum; + Object previousValue = null; + } + + protected PrimitiveObjectInspector inputOI; + protected ObjectInspector outputOI; + protected ResultType result; + protected boolean sumDistinct; + + public boolean sumDistinct() { + return sumDistinct; + } + + public void setSumDistinct(boolean sumDistinct) { + this.sumDistinct = sumDistinct; + } + + /** + * Check if the input object is the same as the previous one for the case of + * SUM(DISTINCT). + * @param input the input object + * @return True if sumDistinct is false or the input is different from the previous object + */ + protected boolean checkDistinct(SumAgg agg, Object input) { + if (this.sumDistinct && + ObjectInspectorUtils.compare(input, inputOI, agg.previousValue, outputOI) == 0) { + return false; + } + + agg.previousValue = ObjectInspectorUtils.copyToStandardObject( + input, inputOI, ObjectInspectorCopyOption.JAVA); + return true; + } + + + } + + /** * GenericUDAFSumHiveDecimal. * */ - public static class GenericUDAFSumHiveDecimal extends GenericUDAFEvaluator { - private PrimitiveObjectInspector inputOI; - private HiveDecimalWritable result; + public static class GenericUDAFSumHiveDecimal extends GenericUDAFSumEvaluator { @Override public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException { @@ -124,6 +177,8 @@ public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveExc super.init(m, parameters); result = new HiveDecimalWritable(HiveDecimal.ZERO); inputOI = (PrimitiveObjectInspector) parameters[0]; + outputOI = ObjectInspectorUtils.getStandardObjectInspector(inputOI, + ObjectInspectorCopyOption.JAVA); // The output precision is 10 greater than the input which should cover at least // 10b rows. The scale is the same as the input. DecimalTypeInfo outputTypeInfo = null; @@ -138,9 +193,7 @@ public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveExc /** class for storing decimal sum value. */ @AggregationType(estimable = false) // hard to know exactly for decimals - static class SumHiveDecimalAgg extends AbstractAggregationBuffer { - boolean empty; - HiveDecimal sum; + static class SumHiveDecimalAgg extends SumAgg { } @Override @@ -152,7 +205,7 @@ public AggregationBuffer getNewAggregationBuffer() throws HiveException { @Override public void reset(AggregationBuffer agg) throws HiveException { - SumHiveDecimalAgg bdAgg = (SumHiveDecimalAgg) agg; + SumAgg bdAgg = (SumAgg) agg; bdAgg.empty = true; bdAgg.sum = HiveDecimal.ZERO; } @@ -163,7 +216,9 @@ public void reset(AggregationBuffer agg) throws HiveException { public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException { assert (parameters.length == 1); try { - merge(agg, parameters[0]); + if (checkDistinct((SumAgg) agg, parameters[0])) { + merge(agg, parameters[0]); + } } catch (NumberFormatException e) { if (!warned) { warned = true; @@ -239,24 +294,21 @@ protected HiveDecimal getCurrentIntermediateResult( * GenericUDAFSumDouble. * */ - public static class GenericUDAFSumDouble extends GenericUDAFEvaluator { - private PrimitiveObjectInspector inputOI; - private DoubleWritable result; - + public static class GenericUDAFSumDouble extends GenericUDAFSumEvaluator { @Override public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException { assert (parameters.length == 1); super.init(m, parameters); result = new DoubleWritable(0); inputOI = (PrimitiveObjectInspector) parameters[0]; + outputOI = ObjectInspectorUtils.getStandardObjectInspector(inputOI, + ObjectInspectorCopyOption.JAVA); return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector; } /** class for storing double sum value. */ @AggregationType(estimable = true) - static class SumDoubleAgg extends AbstractAggregationBuffer { - boolean empty; - double sum; + static class SumDoubleAgg extends SumAgg { @Override public int estimate() { return JavaDataModel.PRIMITIVES1 + JavaDataModel.PRIMITIVES2; } } @@ -272,7 +324,7 @@ public AggregationBuffer getNewAggregationBuffer() throws HiveException { public void reset(AggregationBuffer agg) throws HiveException { SumDoubleAgg myagg = (SumDoubleAgg) agg; myagg.empty = true; - myagg.sum = 0; + myagg.sum = 0.0; } boolean warned = false; @@ -281,7 +333,9 @@ public void reset(AggregationBuffer agg) throws HiveException { public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException { assert (parameters.length == 1); try { - merge(agg, parameters[0]); + if (checkDistinct((SumAgg) agg, parameters[0])) { + merge(agg, parameters[0]); + } } catch (NumberFormatException e) { if (!warned) { warned = true; @@ -354,24 +408,21 @@ protected Double getCurrentIntermediateResult( * GenericUDAFSumLong. * */ - public static class GenericUDAFSumLong extends GenericUDAFEvaluator { - private PrimitiveObjectInspector inputOI; - protected LongWritable result; - + public static class GenericUDAFSumLong extends GenericUDAFSumEvaluator { @Override public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException { assert (parameters.length == 1); super.init(m, parameters); result = new LongWritable(0); inputOI = (PrimitiveObjectInspector) parameters[0]; + outputOI = ObjectInspectorUtils.getStandardObjectInspector(inputOI, + ObjectInspectorCopyOption.JAVA); return PrimitiveObjectInspectorFactory.writableLongObjectInspector; } /** class for storing double sum value. */ @AggregationType(estimable = true) - static class SumLongAgg extends AbstractAggregationBuffer { - boolean empty; - long sum; + static class SumLongAgg extends SumAgg { @Override public int estimate() { return JavaDataModel.PRIMITIVES1 + JavaDataModel.PRIMITIVES2; } } @@ -387,7 +438,7 @@ public AggregationBuffer getNewAggregationBuffer() throws HiveException { public void reset(AggregationBuffer agg) throws HiveException { SumLongAgg myagg = (SumLongAgg) agg; myagg.empty = true; - myagg.sum = 0; + myagg.sum = 0L; } private boolean warned = false; @@ -396,7 +447,9 @@ public void reset(AggregationBuffer agg) throws HiveException { public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException { assert (parameters.length == 1); try { - merge(agg, parameters[0]); + if (checkDistinct((SumAgg) agg, parameters[0])) { + merge(agg, parameters[0]); + } } catch (NumberFormatException e) { if (!warned) { warned = true; @@ -460,5 +513,4 @@ protected Long getCurrentIntermediateResult( }; } } - } diff --git a/ql/src/test/queries/clientpositive/windowing_distinct.q b/ql/src/test/queries/clientpositive/windowing_distinct.q index 94f4044..9f6ddfd 100644 --- a/ql/src/test/queries/clientpositive/windowing_distinct.q +++ b/ql/src/test/queries/clientpositive/windowing_distinct.q @@ -28,3 +28,11 @@ SELECT COUNT(DISTINCT t) OVER (PARTITION BY index), COUNT(DISTINCT dec) OVER (PARTITION BY index), COUNT(DISTINCT bin) OVER (PARTITION BY index) FROM windowing_distinct; + +SELECT SUM(DISTINCT t) OVER (PARTITION BY index), + SUM(DISTINCT d) OVER (PARTITION BY index), + SUM(DISTINCT s) OVER (PARTITION BY index), + SUM(DISTINCT concat('Mr.', s)) OVER (PARTITION BY index), + SUM(DISTINCT ts) OVER (PARTITION BY index), + SUM(DISTINCT dec) OVER (PARTITION BY index) +FROM windowing_distinct; diff --git a/ql/src/test/results/clientpositive/windowing_distinct.q.out b/ql/src/test/results/clientpositive/windowing_distinct.q.out index 50f8ff8..7129dd3 100644 --- a/ql/src/test/results/clientpositive/windowing_distinct.q.out +++ b/ql/src/test/results/clientpositive/windowing_distinct.q.out @@ -76,3 +76,29 @@ POSTHOOK: Input: default@windowing_distinct 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 +PREHOOK: query: SELECT SUM(DISTINCT t) OVER (PARTITION BY index), + SUM(DISTINCT d) OVER (PARTITION BY index), + SUM(DISTINCT s) OVER (PARTITION BY index), + SUM(DISTINCT concat('Mr.', s)) OVER (PARTITION BY index), + SUM(DISTINCT ts) OVER (PARTITION BY index), + SUM(DISTINCT dec) OVER (PARTITION BY index) +FROM windowing_distinct +PREHOOK: type: QUERY +PREHOOK: Input: default@windowing_distinct +#### A masked pattern was here #### +POSTHOOK: query: SELECT SUM(DISTINCT t) OVER (PARTITION BY index), + SUM(DISTINCT d) OVER (PARTITION BY index), + SUM(DISTINCT s) OVER (PARTITION BY index), + SUM(DISTINCT concat('Mr.', s)) OVER (PARTITION BY index), + SUM(DISTINCT ts) OVER (PARTITION BY index), + SUM(DISTINCT dec) OVER (PARTITION BY index) +FROM windowing_distinct +POSTHOOK: type: QUERY +POSTHOOK: Input: default@windowing_distinct +#### A masked pattern was here #### +73 94.4 0.0 0.0 4.086473756109513E9 87 +73 94.4 0.0 0.0 4.086473756109513E9 87 +73 94.4 0.0 0.0 4.086473756109513E9 87 +359 119.89 0.0 0.0 4.086473756109914E9 114 +359 119.89 0.0 0.0 4.086473756109914E9 114 +359 119.89 0.0 0.0 4.086473756109914E9 114