diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveAggregateReduceRule.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveAggregateReduceRule.java new file mode 100644 index 0000000..823d958 --- /dev/null +++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveAggregateReduceRule.java @@ -0,0 +1,120 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hive.ql.optimizer.calcite.rules; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.tools.RelBuilder; +import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate; + +import com.google.common.collect.Lists; + +/** + * Planner rule that reduces aggregate functions in + * {@link org.apache.calcite.rel.core.Aggregate}s to simpler forms. + * + *

Rewrites: + *

+ */ +public class HiveAggregateReduceRule extends RelOptRule { + + /** The singleton. */ + public static final HiveAggregateReduceRule INSTANCE = + new HiveAggregateReduceRule(); + + /** Private constructor. */ + private HiveAggregateReduceRule() { + super(operand(HiveAggregate.class, any()), + HiveRelFactories.HIVE_BUILDER, null); + } + + @Override + public void onMatch(RelOptRuleCall call) { + final RelBuilder relBuilder = call.builder(); + final Aggregate aggRel = (Aggregate) call.rel(0); + final RexBuilder rexBuilder = aggRel.getCluster().getRexBuilder(); + + // We try to rewrite COUNT(x) into COUNT(*) if x is not nullable. + // We remove duplicate aggregate calls as well. + boolean rewrite = false; + final Map mapping = new HashMap<>(); + final List indexes = new ArrayList<>(); + final List aggCalls = aggRel.getAggCallList(); + final List newAggCalls = new ArrayList<>(aggCalls.size()); + int nextIdx = aggRel.getGroupCount() + aggRel.getIndicatorCount(); + for (int i = 0; i < aggCalls.size(); i++) { + AggregateCall aggCall = aggCalls.get(i); + if (aggCall.getAggregation().getKind() == SqlKind.COUNT && !aggCall.isDistinct()) { + final List args = aggCall.getArgList(); + final List nullableArgs = new ArrayList<>(args.size()); + for (int arg : args) { + if (aggRel.getInput().getRowType().getFieldList().get(arg).getType().isNullable()) { + nullableArgs.add(arg); + } + } + if (nullableArgs.size() != args.size()) { + aggCall = aggCall.copy(nullableArgs, aggCall.filterArg); + rewrite = true; + } + } + Integer idx = mapping.get(aggCall); + if (idx == null) { + newAggCalls.add(aggCall); + idx = nextIdx++; + mapping.put(aggCall, idx); + } else { + rewrite = true; + } + indexes.add(idx); + } + + if (rewrite) { + // We trigger the transform + final List projList = Lists.newArrayList(); + for (int i = 0; i < aggRel.getGroupCount() + aggRel.getIndicatorCount(); ++i) { + projList.add( + rexBuilder.makeInputRef( + aggRel.getRowType().getFieldList().get(i).getType(), i)); + } + for (int idx : indexes) { + projList.add( + rexBuilder.makeInputRef( + aggRel.getRowType().getFieldList().get(idx).getType(), idx)); + } + + final Aggregate newAggregate = aggRel.copy(aggRel.getTraitSet(), aggRel.getInput(), + aggRel.indicator, aggRel.getGroupSet(), aggRel.getGroupSets(), + newAggCalls); + call.transformTo(relBuilder.push(newAggregate).project(projList).build()); + } + } + +} diff --git ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java index 931e074..efbedef 100644 --- ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java +++ ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java @@ -171,6 +171,7 @@ import org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveAggregateJoinTransposeRule; import org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveAggregateProjectMergeRule; import org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveAggregatePullUpConstantsRule; +import org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveAggregateReduceRule; import org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveDruidProjectFilterTransposeRule; import org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveExceptRewriteRule; import org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveExpandDistinctAggregatesRule; @@ -1682,6 +1683,7 @@ private RelNode applyPreJoinOrderingTransforms(RelNode basePlan, RelMetadataProv rules.add(HiveReduceExpressionsRule.PROJECT_INSTANCE); rules.add(HiveReduceExpressionsRule.FILTER_INSTANCE); rules.add(HiveReduceExpressionsRule.JOIN_INSTANCE); + rules.add(HiveAggregateReduceRule.INSTANCE); if (conf.getBoolVar(HiveConf.ConfVars.HIVEPOINTLOOKUPOPTIMIZER)) { rules.add(new HivePointLookupOptimizerRule.FilterCondition(minNumORClauses)); rules.add(new HivePointLookupOptimizerRule.JoinCondition(minNumORClauses)); diff --git ql/src/test/results/clientpositive/llap/except_distinct.q.out ql/src/test/results/clientpositive/llap/except_distinct.q.out index e4c2941..92c628d 100644 --- ql/src/test/results/clientpositive/llap/except_distinct.q.out +++ ql/src/test/results/clientpositive/llap/except_distinct.q.out @@ -218,11 +218,11 @@ STAGE PLANS: Statistics: Num rows: 500 Data size: 89000 Basic stats: COMPLETE Column stats: COMPLETE Select Operator expressions: key (type: string), value (type: string) - outputColumnNames: _col0, _col1 + outputColumnNames: key, value Statistics: Num rows: 500 Data size: 89000 Basic stats: COMPLETE Column stats: COMPLETE Group By Operator - aggregations: count(2) - keys: _col0 (type: string), _col1 (type: string) + aggregations: count() + keys: key (type: string), value (type: string) mode: hash outputColumnNames: _col0, _col1, _col2 Statistics: Num rows: 250 Data size: 46500 Basic stats: COMPLETE Column stats: COMPLETE @@ -241,11 +241,11 @@ STAGE PLANS: Statistics: Num rows: 500 Data size: 89000 Basic stats: COMPLETE Column stats: COMPLETE Select Operator expressions: key (type: string), value (type: string) - outputColumnNames: _col0, _col1 + outputColumnNames: key, value Statistics: Num rows: 500 Data size: 89000 Basic stats: COMPLETE Column stats: COMPLETE Group By Operator - aggregations: count(1) - keys: _col0 (type: string), _col1 (type: string) + aggregations: count() + keys: key (type: string), value (type: string) mode: hash outputColumnNames: _col0, _col1, _col2 Statistics: Num rows: 250 Data size: 46500 Basic stats: COMPLETE Column stats: COMPLETE @@ -384,11 +384,11 @@ STAGE PLANS: Statistics: Num rows: 500 Data size: 89000 Basic stats: COMPLETE Column stats: COMPLETE Select Operator expressions: key (type: string), value (type: string) - outputColumnNames: _col0, _col1 + outputColumnNames: key, value Statistics: Num rows: 500 Data size: 89000 Basic stats: COMPLETE Column stats: COMPLETE Group By Operator - aggregations: count(2) - keys: _col0 (type: string), _col1 (type: string) + aggregations: count() + keys: key (type: string), value (type: string) mode: hash outputColumnNames: _col0, _col1, _col2 Statistics: Num rows: 250 Data size: 46500 Basic stats: COMPLETE Column stats: COMPLETE @@ -400,11 +400,11 @@ STAGE PLANS: value expressions: _col2 (type: bigint) Select Operator expressions: key (type: string), value (type: string) - outputColumnNames: _col0, _col1 + outputColumnNames: key, value Statistics: Num rows: 500 Data size: 89000 Basic stats: COMPLETE Column stats: COMPLETE Group By Operator - aggregations: count(1) - keys: _col0 (type: string), _col1 (type: string) + aggregations: count() + keys: key (type: string), value (type: string) mode: hash outputColumnNames: _col0, _col1, _col2 Statistics: Num rows: 250 Data size: 46500 Basic stats: COMPLETE Column stats: COMPLETE @@ -416,11 +416,11 @@ STAGE PLANS: value expressions: _col2 (type: bigint) Select Operator expressions: key (type: string), value (type: string) - outputColumnNames: _col0, _col1 + outputColumnNames: key, value Statistics: Num rows: 500 Data size: 89000 Basic stats: COMPLETE Column stats: COMPLETE Group By Operator - aggregations: count(1) - keys: _col0 (type: string), _col1 (type: string) + aggregations: count() + keys: key (type: string), value (type: string) mode: hash outputColumnNames: _col0, _col1, _col2 Statistics: Num rows: 250 Data size: 46500 Basic stats: COMPLETE Column stats: COMPLETE @@ -439,11 +439,11 @@ STAGE PLANS: Statistics: Num rows: 500 Data size: 89000 Basic stats: COMPLETE Column stats: COMPLETE Select Operator expressions: key (type: string), value (type: string) - outputColumnNames: _col0, _col1 + outputColumnNames: key, value Statistics: Num rows: 500 Data size: 89000 Basic stats: COMPLETE Column stats: COMPLETE Group By Operator - aggregations: count(1) - keys: _col0 (type: string), _col1 (type: string) + aggregations: count() + keys: key (type: string), value (type: string) mode: hash outputColumnNames: _col0, _col1, _col2 Statistics: Num rows: 250 Data size: 46500 Basic stats: COMPLETE Column stats: COMPLETE @@ -559,7 +559,7 @@ STAGE PLANS: outputColumnNames: _col0, _col1 Statistics: Num rows: 41 Data size: 7954 Basic stats: COMPLETE Column stats: COMPLETE Group By Operator - aggregations: count(2) + aggregations: count() keys: _col0 (type: string), _col1 (type: string) mode: complete outputColumnNames: _col0, _col1, _col2 @@ -601,7 +601,7 @@ STAGE PLANS: outputColumnNames: _col0, _col1 Statistics: Num rows: 24 Data size: 4656 Basic stats: COMPLETE Column stats: COMPLETE Group By Operator - aggregations: count(2) + aggregations: count() keys: _col0 (type: string), _col1 (type: string) mode: complete outputColumnNames: _col0, _col1, _col2 @@ -768,7 +768,7 @@ STAGE PLANS: outputColumnNames: _col0 Statistics: Num rows: 2 Data size: 6 Basic stats: COMPLETE Column stats: NONE Group By Operator - aggregations: count(2) + aggregations: count() keys: _col0 (type: int) mode: complete outputColumnNames: _col0, _col1 @@ -825,7 +825,7 @@ STAGE PLANS: outputColumnNames: _col0 Statistics: Num rows: 2 Data size: 6 Basic stats: COMPLETE Column stats: NONE Group By Operator - aggregations: count(1) + aggregations: count() keys: _col0 (type: int) mode: complete outputColumnNames: _col0, _col1