diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveAggregateReduceRule.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveAggregateReduceRule.java
new file mode 100644
index 0000000..823d958
--- /dev/null
+++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveAggregateReduceRule.java
@@ -0,0 +1,120 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hive.ql.optimizer.calcite.rules;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.rel.core.Aggregate;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.tools.RelBuilder;
+import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories;
+import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate;
+
+import com.google.common.collect.Lists;
+
+/**
+ * Planner rule that reduces aggregate functions in
+ * {@link org.apache.calcite.rel.core.Aggregate}s to simpler forms.
+ *
+ *
+ * <p>Rewrites:
+ *
+ * <ul>
+ * <li>COUNT(x) → COUNT(*) if x is not nullable
+ * </ul>
+ */
+public class HiveAggregateReduceRule extends RelOptRule {
+
+  /** The singleton. */
+  public static final HiveAggregateReduceRule INSTANCE =
+      new HiveAggregateReduceRule();
+
+  /** Private constructor. */
+  private HiveAggregateReduceRule() {
+    super(operand(HiveAggregate.class, any()),
+        HiveRelFactories.HIVE_BUILDER, null);
+  }
+
+  @Override
+  public void onMatch(RelOptRuleCall call) {
+    final RelBuilder relBuilder = call.builder();
+    final Aggregate aggRel = (Aggregate) call.rel(0);
+    final RexBuilder rexBuilder = aggRel.getCluster().getRexBuilder();
+
+    // We try to rewrite COUNT(x) into COUNT(*) if x is not nullable.
+    // We remove duplicate aggregate calls as well.
+    boolean rewrite = false;
+    final Map<AggregateCall, Integer> mapping = new HashMap<>();
+    final List<Integer> indexes = new ArrayList<>();
+    final List<AggregateCall> aggCalls = aggRel.getAggCallList();
+    final List<AggregateCall> newAggCalls = new ArrayList<>(aggCalls.size());
+    int nextIdx = aggRel.getGroupCount() + aggRel.getIndicatorCount();
+    for (int i = 0; i < aggCalls.size(); i++) {
+      AggregateCall aggCall = aggCalls.get(i);
+      if (aggCall.getAggregation().getKind() == SqlKind.COUNT && !aggCall.isDistinct()) {
+        // Keep only the nullable arguments; if all arguments are non-nullable,
+        // COUNT(x, ...) degenerates to COUNT(*) (empty argument list).
+        final List<Integer> args = aggCall.getArgList();
+        final List<Integer> nullableArgs = new ArrayList<>(args.size());
+        for (int arg : args) {
+          if (aggRel.getInput().getRowType().getFieldList().get(arg).getType().isNullable()) {
+            nullableArgs.add(arg);
+          }
+        }
+        if (nullableArgs.size() != args.size()) {
+          aggCall = aggCall.copy(nullableArgs, aggCall.filterArg);
+          rewrite = true;
+        }
+      }
+      // Deduplicate: identical calls map to the same output position.
+      Integer idx = mapping.get(aggCall);
+      if (idx == null) {
+        newAggCalls.add(aggCall);
+        idx = nextIdx++;
+        mapping.put(aggCall, idx);
+      } else {
+        rewrite = true;
+      }
+      indexes.add(idx);
+    }
+
+    if (rewrite) {
+      // We trigger the transform: project restores the original column layout
+      // (group keys and indicators first, then one ref per original agg call).
+      final List<RexNode> projList = Lists.newArrayList();
+      for (int i = 0; i < aggRel.getGroupCount() + aggRel.getIndicatorCount(); ++i) {
+        projList.add(
+            rexBuilder.makeInputRef(
+                aggRel.getRowType().getFieldList().get(i).getType(), i));
+      }
+      for (int idx : indexes) {
+        projList.add(
+            rexBuilder.makeInputRef(
+                aggRel.getRowType().getFieldList().get(idx).getType(), idx));
+      }
+
+      final Aggregate newAggregate = aggRel.copy(aggRel.getTraitSet(), aggRel.getInput(),
+          aggRel.indicator, aggRel.getGroupSet(), aggRel.getGroupSets(),
+          newAggCalls);
+      call.transformTo(relBuilder.push(newAggregate).project(projList).build());
+    }
+  }
+
+}
diff --git ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java
index 931e074..efbedef 100644
--- ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java
+++ ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java
@@ -171,6 +171,7 @@
import org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveAggregateJoinTransposeRule;
import org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveAggregateProjectMergeRule;
import org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveAggregatePullUpConstantsRule;
+import org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveAggregateReduceRule;
import org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveDruidProjectFilterTransposeRule;
import org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveExceptRewriteRule;
import org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveExpandDistinctAggregatesRule;
@@ -1682,6 +1683,7 @@ private RelNode applyPreJoinOrderingTransforms(RelNode basePlan, RelMetadataProv
rules.add(HiveReduceExpressionsRule.PROJECT_INSTANCE);
rules.add(HiveReduceExpressionsRule.FILTER_INSTANCE);
rules.add(HiveReduceExpressionsRule.JOIN_INSTANCE);
+ rules.add(HiveAggregateReduceRule.INSTANCE);
if (conf.getBoolVar(HiveConf.ConfVars.HIVEPOINTLOOKUPOPTIMIZER)) {
rules.add(new HivePointLookupOptimizerRule.FilterCondition(minNumORClauses));
rules.add(new HivePointLookupOptimizerRule.JoinCondition(minNumORClauses));
diff --git ql/src/test/results/clientpositive/llap/except_distinct.q.out ql/src/test/results/clientpositive/llap/except_distinct.q.out
index e4c2941..92c628d 100644
--- ql/src/test/results/clientpositive/llap/except_distinct.q.out
+++ ql/src/test/results/clientpositive/llap/except_distinct.q.out
@@ -218,11 +218,11 @@ STAGE PLANS:
Statistics: Num rows: 500 Data size: 89000 Basic stats: COMPLETE Column stats: COMPLETE
Select Operator
expressions: key (type: string), value (type: string)
- outputColumnNames: _col0, _col1
+ outputColumnNames: key, value
Statistics: Num rows: 500 Data size: 89000 Basic stats: COMPLETE Column stats: COMPLETE
Group By Operator
- aggregations: count(2)
- keys: _col0 (type: string), _col1 (type: string)
+ aggregations: count()
+ keys: key (type: string), value (type: string)
mode: hash
outputColumnNames: _col0, _col1, _col2
Statistics: Num rows: 250 Data size: 46500 Basic stats: COMPLETE Column stats: COMPLETE
@@ -241,11 +241,11 @@ STAGE PLANS:
Statistics: Num rows: 500 Data size: 89000 Basic stats: COMPLETE Column stats: COMPLETE
Select Operator
expressions: key (type: string), value (type: string)
- outputColumnNames: _col0, _col1
+ outputColumnNames: key, value
Statistics: Num rows: 500 Data size: 89000 Basic stats: COMPLETE Column stats: COMPLETE
Group By Operator
- aggregations: count(1)
- keys: _col0 (type: string), _col1 (type: string)
+ aggregations: count()
+ keys: key (type: string), value (type: string)
mode: hash
outputColumnNames: _col0, _col1, _col2
Statistics: Num rows: 250 Data size: 46500 Basic stats: COMPLETE Column stats: COMPLETE
@@ -384,11 +384,11 @@ STAGE PLANS:
Statistics: Num rows: 500 Data size: 89000 Basic stats: COMPLETE Column stats: COMPLETE
Select Operator
expressions: key (type: string), value (type: string)
- outputColumnNames: _col0, _col1
+ outputColumnNames: key, value
Statistics: Num rows: 500 Data size: 89000 Basic stats: COMPLETE Column stats: COMPLETE
Group By Operator
- aggregations: count(2)
- keys: _col0 (type: string), _col1 (type: string)
+ aggregations: count()
+ keys: key (type: string), value (type: string)
mode: hash
outputColumnNames: _col0, _col1, _col2
Statistics: Num rows: 250 Data size: 46500 Basic stats: COMPLETE Column stats: COMPLETE
@@ -400,11 +400,11 @@ STAGE PLANS:
value expressions: _col2 (type: bigint)
Select Operator
expressions: key (type: string), value (type: string)
- outputColumnNames: _col0, _col1
+ outputColumnNames: key, value
Statistics: Num rows: 500 Data size: 89000 Basic stats: COMPLETE Column stats: COMPLETE
Group By Operator
- aggregations: count(1)
- keys: _col0 (type: string), _col1 (type: string)
+ aggregations: count()
+ keys: key (type: string), value (type: string)
mode: hash
outputColumnNames: _col0, _col1, _col2
Statistics: Num rows: 250 Data size: 46500 Basic stats: COMPLETE Column stats: COMPLETE
@@ -416,11 +416,11 @@ STAGE PLANS:
value expressions: _col2 (type: bigint)
Select Operator
expressions: key (type: string), value (type: string)
- outputColumnNames: _col0, _col1
+ outputColumnNames: key, value
Statistics: Num rows: 500 Data size: 89000 Basic stats: COMPLETE Column stats: COMPLETE
Group By Operator
- aggregations: count(1)
- keys: _col0 (type: string), _col1 (type: string)
+ aggregations: count()
+ keys: key (type: string), value (type: string)
mode: hash
outputColumnNames: _col0, _col1, _col2
Statistics: Num rows: 250 Data size: 46500 Basic stats: COMPLETE Column stats: COMPLETE
@@ -439,11 +439,11 @@ STAGE PLANS:
Statistics: Num rows: 500 Data size: 89000 Basic stats: COMPLETE Column stats: COMPLETE
Select Operator
expressions: key (type: string), value (type: string)
- outputColumnNames: _col0, _col1
+ outputColumnNames: key, value
Statistics: Num rows: 500 Data size: 89000 Basic stats: COMPLETE Column stats: COMPLETE
Group By Operator
- aggregations: count(1)
- keys: _col0 (type: string), _col1 (type: string)
+ aggregations: count()
+ keys: key (type: string), value (type: string)
mode: hash
outputColumnNames: _col0, _col1, _col2
Statistics: Num rows: 250 Data size: 46500 Basic stats: COMPLETE Column stats: COMPLETE
@@ -559,7 +559,7 @@ STAGE PLANS:
outputColumnNames: _col0, _col1
Statistics: Num rows: 41 Data size: 7954 Basic stats: COMPLETE Column stats: COMPLETE
Group By Operator
- aggregations: count(2)
+ aggregations: count()
keys: _col0 (type: string), _col1 (type: string)
mode: complete
outputColumnNames: _col0, _col1, _col2
@@ -601,7 +601,7 @@ STAGE PLANS:
outputColumnNames: _col0, _col1
Statistics: Num rows: 24 Data size: 4656 Basic stats: COMPLETE Column stats: COMPLETE
Group By Operator
- aggregations: count(2)
+ aggregations: count()
keys: _col0 (type: string), _col1 (type: string)
mode: complete
outputColumnNames: _col0, _col1, _col2
@@ -768,7 +768,7 @@ STAGE PLANS:
outputColumnNames: _col0
Statistics: Num rows: 2 Data size: 6 Basic stats: COMPLETE Column stats: NONE
Group By Operator
- aggregations: count(2)
+ aggregations: count()
keys: _col0 (type: int)
mode: complete
outputColumnNames: _col0, _col1
@@ -825,7 +825,7 @@ STAGE PLANS:
outputColumnNames: _col0
Statistics: Num rows: 2 Data size: 6 Basic stats: COMPLETE Column stats: NONE
Group By Operator
- aggregations: count(1)
+ aggregations: count()
keys: _col0 (type: int)
mode: complete
outputColumnNames: _col0, _col1