diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/functions/CanAggregateDistinct.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/functions/CanAggregateDistinct.java
new file mode 100644
index 0000000..c24f3c0
--- /dev/null
+++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/functions/CanAggregateDistinct.java
@@ -0,0 +1,27 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.optimizer.calcite.functions;
+
+/**
+ * This is the UDAF interface to support DISTINCT function.
+ *
+ */
+public interface CanAggregateDistinct {
+ boolean isDistinct();
+}
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/functions/HiveSqlCountAggFunction.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/functions/HiveSqlCountAggFunction.java
index 58191e5..bc48707 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/functions/HiveSqlCountAggFunction.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/functions/HiveSqlCountAggFunction.java
@@ -30,7 +30,7 @@
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.ImmutableIntList;
-public class HiveSqlCountAggFunction extends SqlAggFunction {
+public class HiveSqlCountAggFunction extends SqlAggFunction implements CanAggregateDistinct {
final boolean isDistinct;
final SqlReturnTypeInference returnTypeInference;
@@ -52,6 +52,7 @@ public HiveSqlCountAggFunction(boolean isDistinct, SqlReturnTypeInference return
this.operandTypeInference = operandTypeInference;
}
+ @Override
public boolean isDistinct() {
return isDistinct;
}
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/functions/HiveSqlSumAggFunction.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/functions/HiveSqlSumAggFunction.java
index 498cd0e..dc286a2 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/functions/HiveSqlSumAggFunction.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/functions/HiveSqlSumAggFunction.java
@@ -46,7 +46,7 @@
* long, float, double), and the result
* is the same type.
*/
-public class HiveSqlSumAggFunction extends SqlAggFunction {
+public class HiveSqlSumAggFunction extends SqlAggFunction implements CanAggregateDistinct{
final boolean isDistinct;
final SqlReturnTypeInference returnTypeInference;
final SqlOperandTypeInference operandTypeInference;
@@ -70,7 +70,7 @@ public HiveSqlSumAggFunction(boolean isDistinct, SqlReturnTypeInference returnTy
}
//~ Methods ----------------------------------------------------------------
-
+ @Override
public boolean isDistinct() {
return isDistinct;
}
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/SqlFunctionConverter.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/SqlFunctionConverter.java
index 19aa414..7d70291 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/SqlFunctionConverter.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/SqlFunctionConverter.java
@@ -45,6 +45,7 @@
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.optimizer.calcite.CalciteSemanticException;
import org.apache.hadoop.hive.ql.optimizer.calcite.CalciteSemanticException.UnsupportedFeature;
+import org.apache.hadoop.hive.ql.optimizer.calcite.functions.CanAggregateDistinct;
import org.apache.hadoop.hive.ql.optimizer.calcite.functions.HiveSqlCountAggFunction;
import org.apache.hadoop.hive.ql.optimizer.calcite.functions.HiveSqlMinMaxAggFunction;
import org.apache.hadoop.hive.ql.optimizer.calcite.functions.HiveSqlSumAggFunction;
@@ -217,24 +218,18 @@ public static ASTNode buildAST(SqlOperator op, List children) {
} else if (op.kind == SqlKind.PLUS_PREFIX) {
node = (ASTNode) ParseDriver.adaptor.create(HiveParser.PLUS, "PLUS");
} else {
- // Handle 'COUNT' function for the case of COUNT(*) and COUNT(DISTINCT)
- if (op instanceof HiveSqlCountAggFunction) {
+ // Handle COUNT/SUM/AVG function for the case of COUNT(*) and COUNT(DISTINCT)
+ if (op instanceof CanAggregateDistinct) {
if (children.size() == 0) {
node = (ASTNode) ParseDriver.adaptor.create(HiveParser.TOK_FUNCTIONSTAR,
"TOK_FUNCTIONSTAR");
} else {
- HiveSqlCountAggFunction countFunction = (HiveSqlCountAggFunction)op;
- if (countFunction.isDistinct()) {
+ CanAggregateDistinct distinctFunction = (CanAggregateDistinct) op;
+ if (distinctFunction.isDistinct()) {
node = (ASTNode) ParseDriver.adaptor.create(HiveParser.TOK_FUNCTIONDI,
"TOK_FUNCTIONDI");
}
}
- } else if (op instanceof HiveSqlSumAggFunction) { // case SUM(DISTINCT)
- HiveSqlSumAggFunction sumFunction = (HiveSqlSumAggFunction) op;
- if (sumFunction.isDistinct()) {
- node = (ASTNode) ParseDriver.adaptor.create(HiveParser.TOK_FUNCTIONDI,
- "TOK_FUNCTIONDI");
- }
}
node.addChild((ASTNode) ParseDriver.adaptor.create(HiveParser.Identifier, op.getName()));
}
@@ -364,11 +359,18 @@ private static HiveToken hToken(int type, String text) {
}
// UDAF is assumed to be deterministic
- public static class CalciteUDAF extends SqlAggFunction {
- public CalciteUDAF(String opName, SqlReturnTypeInference returnTypeInference,
+ public static class CalciteUDAF extends SqlAggFunction implements CanAggregateDistinct {
+ private boolean isDistinct;
+ public CalciteUDAF(boolean isDistinct, String opName, SqlReturnTypeInference returnTypeInference,
SqlOperandTypeInference operandTypeInference, SqlOperandTypeChecker operandTypeChecker) {
super(opName, SqlKind.OTHER_FUNCTION, returnTypeInference, operandTypeInference,
operandTypeChecker, SqlFunctionCategory.USER_DEFINED_FUNCTION);
+ this.isDistinct = isDistinct;
+ }
+
+ @Override
+ public boolean isDistinct() {
+ return isDistinct;
}
}
@@ -466,6 +468,7 @@ public static SqlAggFunction getCalciteAggFn(String hiveUdfName, boolean isDisti
break;
default:
calciteAggFn = new CalciteUDAF(
+ isDistinct,
udfInfo.udfName,
udfInfo.returnTypeInference,
udfInfo.operandTypeInference,
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java
index cd2449f..5547753 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java
@@ -35,9 +35,11 @@
import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
@@ -95,6 +97,14 @@ public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
}
}
+ @Override
+ public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo paramInfo)
+ throws SemanticException {
+ AbstractGenericUDAFAverageEvaluator eval =
+ (AbstractGenericUDAFAverageEvaluator) getEvaluator(paramInfo.getParameters());
+ eval.avgDistinct = paramInfo.isDistinct();
+ return eval;
+ }
public static class GenericUDAFAverageEvaluatorDouble extends AbstractGenericUDAFAverageEvaluator {
@@ -102,6 +112,7 @@ public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
public void doReset(AverageAggregationBuffer aggregation) throws HiveException {
aggregation.count = 0;
aggregation.sum = new Double(0);
+ aggregation.previousValue = null;
}
@Override
@@ -319,15 +330,18 @@ protected HiveDecimalWritable getNextResult(
}
private static class AverageAggregationBuffer implements AggregationBuffer {
+ private Object previousValue;
private long count;
private TYPE sum;
};
@SuppressWarnings("unchecked")
public static abstract class AbstractGenericUDAFAverageEvaluator extends GenericUDAFEvaluator {
+ protected boolean avgDistinct;
// For PARTIAL1 and COMPLETE
protected transient PrimitiveObjectInspector inputOI;
+ protected transient ObjectInspector copiedOI;
// For PARTIAL2 and FINAL
private transient StructObjectInspector soi;
private transient StructField countField;
@@ -359,6 +373,8 @@ public ObjectInspector init(Mode m, ObjectInspector[] parameters)
// init input
if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {
inputOI = (PrimitiveObjectInspector) parameters[0];
+ copiedOI = ObjectInspectorUtils.getStandardObjectInspector(inputOI,
+ ObjectInspectorCopyOption.JAVA);
} else {
soi = (StructObjectInspector) parameters[0];
countField = soi.getStructFieldRef("count");
@@ -412,6 +428,14 @@ public void iterate(AggregationBuffer aggregation, Object[] parameters)
if (parameter != null) {
AverageAggregationBuffer averageAggregation = (AverageAggregationBuffer) aggregation;
try {
+ // Skip the same value if avgDistinct is true
+ if (this.avgDistinct &&
+ ObjectInspectorUtils.compare(parameter, inputOI, averageAggregation.previousValue, copiedOI) == 0) {
+ return;
+ }
+ averageAggregation.previousValue = ObjectInspectorUtils.copyToStandardObject(
+ parameter, inputOI, ObjectInspectorCopyOption.JAVA);
+
doIterate(averageAggregation, inputOI, parameter);
} catch (NumberFormatException e) {
if (!warned) {
diff --git a/ql/src/test/queries/clientpositive/windowing_distinct.q b/ql/src/test/queries/clientpositive/windowing_distinct.q
index 9f6ddfd..bb192a7 100644
--- a/ql/src/test/queries/clientpositive/windowing_distinct.q
+++ b/ql/src/test/queries/clientpositive/windowing_distinct.q
@@ -36,3 +36,11 @@ SELECT SUM(DISTINCT t) OVER (PARTITION BY index),
SUM(DISTINCT ts) OVER (PARTITION BY index),
SUM(DISTINCT dec) OVER (PARTITION BY index)
FROM windowing_distinct;
+
+SELECT AVG(DISTINCT t) OVER (PARTITION BY index),
+ AVG(DISTINCT d) OVER (PARTITION BY index),
+ AVG(DISTINCT s) OVER (PARTITION BY index),
+ AVG(DISTINCT concat('Mr.', s)) OVER (PARTITION BY index),
+ AVG(DISTINCT ts) OVER (PARTITION BY index),
+ AVG(DISTINCT dec) OVER (PARTITION BY index)
+FROM windowing_distinct;
diff --git a/ql/src/test/results/clientpositive/windowing_distinct.q.out b/ql/src/test/results/clientpositive/windowing_distinct.q.out
index 0858f0f..074a594 100644
--- a/ql/src/test/results/clientpositive/windowing_distinct.q.out
+++ b/ql/src/test/results/clientpositive/windowing_distinct.q.out
@@ -102,3 +102,29 @@ POSTHOOK: Input: default@windowing_distinct
235 77.42 0.0 0.0 2.724315837406612E9 69
235 77.42 0.0 0.0 2.724315837406612E9 69
235 77.42 0.0 0.0 2.724315837406612E9 69
+PREHOOK: query: SELECT AVG(DISTINCT t) OVER (PARTITION BY index),
+ AVG(DISTINCT d) OVER (PARTITION BY index),
+ AVG(DISTINCT s) OVER (PARTITION BY index),
+ AVG(DISTINCT concat('Mr.', s)) OVER (PARTITION BY index),
+ AVG(DISTINCT ts) OVER (PARTITION BY index),
+ AVG(DISTINCT dec) OVER (PARTITION BY index)
+FROM windowing_distinct
+PREHOOK: type: QUERY
+PREHOOK: Input: default@windowing_distinct
+#### A masked pattern was here ####
+POSTHOOK: query: SELECT AVG(DISTINCT t) OVER (PARTITION BY index),
+ AVG(DISTINCT d) OVER (PARTITION BY index),
+ AVG(DISTINCT s) OVER (PARTITION BY index),
+ AVG(DISTINCT concat('Mr.', s)) OVER (PARTITION BY index),
+ AVG(DISTINCT ts) OVER (PARTITION BY index),
+ AVG(DISTINCT dec) OVER (PARTITION BY index)
+FROM windowing_distinct
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@windowing_distinct
+#### A masked pattern was here ####
+27.0 28.315 NULL NULL 1.362157918703148E9 28.5000
+27.0 28.315 NULL NULL 1.362157918703148E9 28.5000
+27.0 28.315 NULL NULL 1.362157918703148E9 28.5000
+117.5 38.71 NULL NULL 1.362157918703306E9 34.5000
+117.5 38.71 NULL NULL 1.362157918703306E9 34.5000
+117.5 38.71 NULL NULL 1.362157918703306E9 34.5000