diff --git a/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java b/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
index 049e83713e..3372c7942f 100644
--- a/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
+++ b/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
@@ -1666,7 +1666,7 @@ private static void populateLlapDaemonVarsSet(Set<String> llapDaemonVarsSetLocal
         "When CBO estimates output rows for a join involving multiple columns, the default behavior assumes" +
         "the columns are independent. Setting this flag to true will cause the estimator to assume" +
         "the columns are correlated."),
-    AGGR_JOIN_TRANSPOSE("hive.transpose.aggr.join", false, "push aggregates through join"),
+    AGGR_JOIN_TRANSPOSE("hive.transpose.aggr.join", true, "push aggregates through join"),
     SEMIJOIN_CONVERSION("hive.optimize.semijoin.conversion", true, "convert group by followed by inner equi join into semijoin"),
     HIVE_COLUMN_ALIGNMENT("hive.order.columnalignment", true, "Flag to control whether we want to try to align" +
         "columns in operators such as Aggregate or Join so that we try to reduce the number of shuffling stages"),
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveDefaultRelMetadataProvider.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveDefaultRelMetadataProvider.java
index 0a2714255e..7b5b149941 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveDefaultRelMetadataProvider.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveDefaultRelMetadataProvider.java
@@ -36,6 +36,7 @@
 import org.apache.calcite.rel.metadata.RelMetadataProvider;
 import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.ql.optimizer.calcite.cost.HiveCostModelWithAggregate;
 import org.apache.hadoop.hive.ql.optimizer.calcite.cost.HiveDefaultCostModel;
 import org.apache.hadoop.hive.ql.optimizer.calcite.cost.HiveOnTezCostModel;
 import org.apache.hadoop.hive.ql.optimizer.calcite.cost.HiveRelMdCost;
 import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate;
@@ -91,6 +92,27 @@
               HiveRelMdPredicates.SOURCE,
               JaninoRelMetadataProvider.DEFAULT)));
 
+  /**
+   * Metadata provider using a cost model that considers Aggregate cost in
+   * addition to Join cost; used by HiveAggregateJoinTransposeRule.
+   */
+  private static final JaninoRelMetadataProvider EXTENDED_PROVIDER =
+      JaninoRelMetadataProvider.of(
+          ChainedRelMetadataProvider.of(
+              ImmutableList.of(
+                  HiveRelMdDistinctRowCount.SOURCE,
+                  new HiveRelMdCost(HiveCostModelWithAggregate.getCostModel()).getMetadataProvider(),
+                  HiveRelMdSelectivity.SOURCE,
+                  HiveRelMdRowCount.SOURCE,
+                  HiveRelMdUniqueKeys.SOURCE,
+                  HiveRelMdColumnUniqueness.SOURCE,
+                  HiveRelMdSize.SOURCE,
+                  HiveRelMdMemory.SOURCE,
+                  HiveRelMdDistribution.SOURCE,
+                  HiveRelMdCollation.SOURCE,
+                  HiveRelMdPredicates.SOURCE,
+                  JaninoRelMetadataProvider.DEFAULT)));
+
   /**
    * This is the list of operators that are specifically used in Hive and
    * should be loaded by the metadata providers.
@@ -175,6 +197,10 @@ public RelMetadataProvider getMetadataProvider() {
     return metadataProvider;
   }
 
+  public RelMetadataProvider getMetadataProviderWithAggregateCost() {
+    return EXTENDED_PROVIDER;
+  }
+
   /**
    * This method can be called at startup time to pre-register all the
    * additional Hive classes (compared to Calcite core classes) that may
@@ -186,5 +212,7 @@ public static void initializeMetadataProviderClass() {
         HiveDefaultRelMetadataProvider.HIVE_REL_NODE_CLASSES);
     // This will register the classes in the default Hive implementation
     DEFAULT.register(HiveDefaultRelMetadataProvider.HIVE_REL_NODE_CLASSES);
+
+    EXTENDED_PROVIDER.register(HiveDefaultRelMetadataProvider.HIVE_REL_NODE_CLASSES);
   }
 }
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/cost/HiveCostModelWithAggregate.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/cost/HiveCostModelWithAggregate.java
new file mode 100644
index 0000000000..c71a1680e1
--- /dev/null
+++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/cost/HiveCostModelWithAggregate.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hive.ql.optimizer.calcite.cost;
+
+import org.apache.calcite.plan.RelOptCost;
+import org.apache.calcite.rel.RelCollation;
+import org.apache.calcite.rel.RelDistribution;
+import org.apache.calcite.rel.RelDistributions;
+import org.apache.calcite.rel.metadata.RelMetadataQuery;
+import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate;
+import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveJoin;
+import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveTableScan;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Sets;
+
+/**
+ * Cost model that, unlike HiveDefaultCostModel, also takes the cost of
+ * Aggregate operators into account. Used by HiveAggregateJoinTransposeRule
+ * to make a cost-based decision on pushing an aggregate through a join.
+ */
+public class HiveCostModelWithAggregate extends HiveCostModel {
+
+  private static HiveCostModelWithAggregate INSTANCE;
+
+  public static synchronized HiveCostModelWithAggregate getCostModel() {
+    if (INSTANCE == null) {
+      INSTANCE = new HiveCostModelWithAggregate();
+    }
+
+    return INSTANCE;
+  }
+
+  private HiveCostModelWithAggregate() {
+    super(Sets.newHashSet(DefaultJoinAlgorithm.INSTANCE));
+  }
+
+  @Override
+  public RelOptCost getDefaultCost() {
+    return HiveCost.FACTORY.makeZeroCost();
+  }
+
+  @Override
+  public RelOptCost getScanCost(HiveTableScan ts, RelMetadataQuery mq) {
+    return HiveCost.FACTORY.makeZeroCost();
+  }
+
+  @Override
+  public RelOptCost getAggregateCost(HiveAggregate aggregate) {
+    final RelMetadataQuery mq = aggregate.getCluster().getMetadataQuery();
+    double rowCount = mq.getRowCount(aggregate.getInput());
+    return HiveCost.FACTORY.makeCost(rowCount, 0.0, 0.0);
+  }
+
+  /**
+   * Default join algorithm. Cost is based on cardinality.
+   */
+  public static class DefaultJoinAlgorithm implements JoinAlgorithm {
+
+    public static final JoinAlgorithm INSTANCE = new DefaultJoinAlgorithm();
+    private static final String ALGORITHM_NAME = "none";
+
+    @Override
+    public String toString() {
+      return ALGORITHM_NAME;
+    }
+
+    @Override
+    public boolean isExecutable(HiveJoin join) {
+      return true;
+    }
+
+    @Override
+    public RelOptCost getCost(HiveJoin join) {
+      final RelMetadataQuery mq = join.getCluster().getMetadataQuery();
+      double leftRCount = mq.getRowCount(join.getLeft());
+      double rightRCount = mq.getRowCount(join.getRight());
+      return HiveCost.FACTORY.makeCost(leftRCount + rightRCount, 0.0, 0.0);
+    }
+
+    @Override
+    public ImmutableList<RelCollation> getCollation(HiveJoin join) {
+      return ImmutableList.of();
+    }
+
+    @Override
+    public RelDistribution getDistribution(HiveJoin join) {
+      return RelDistributions.SINGLETON;
+    }
+
+    @Override
+    public Double getMemory(HiveJoin join) {
+      return null;
+    }
+
+    @Override
+    public Double getCumulativeMemoryWithinPhaseSplit(HiveJoin join) {
+      return null;
+    }
+
+    @Override
+    public Boolean isPhaseTransition(HiveJoin join) {
+      return false;
+    }
+
+    @Override
+    public Integer getSplitCount(HiveJoin join) {
+      return 1;
+    }
+  }
+}
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveAggregateJoinTransposeRule.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveAggregateJoinTransposeRule.java
index ed6659c6cc..884e3fdde8 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveAggregateJoinTransposeRule.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveAggregateJoinTransposeRule.java
@@ -280,12 +280,18 @@ public void onMatch(RelOptRuleCall call) {
       }
     }
 
-    if (!aggConvertedToProjects) {
+    if (aggConvertedToProjects) {
+      // If the aggregate was removed from the top of the plan, there is no need
+      // to compare costs: the new plan is at least as good as the original one.
+      RelNode r = relBuilder.build();
+      call.transformTo(r);
+      return;
+    }
+
     relBuilder.aggregate(
         relBuilder.groupKey(Mappings.apply(mapping, aggregate.getGroupSet()),
             Mappings.apply2(mapping, aggregate.getGroupSets())),
         newAggCalls);
-    }
 
     // Make a cost based decision to pick cheaper plan
     RelNode r = relBuilder.build();
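Reviewer note: the sketch below replays the cost comparison that HiveAggregateJoinTransposeRule performs under the new cost model. Only the two formulas (join cost = left rows + right rows, aggregate cost = input row count) come from HiveCostModelWithAggregate above; the class name, the assumed join output cardinality, and the plan shapes are illustrative assumptions based on the employee/department test added further down.

// Illustrative sketch only; not part of the patch.
public final class AggrJoinTransposeCostSketch {

  // DefaultJoinAlgorithm.getCost(): sum of input row counts (CPU/IO are zero).
  static double joinCost(double leftRows, double rightRows) {
    return leftRows + rightRows;
  }

  // getAggregateCost(): row count of the aggregate's input.
  static double aggregateCost(double inputRows) {
    return inputRows;
  }

  public static void main(String[] args) {
    double employeeRows = 5.0;   // rows inserted into employee in the test
    double departmentRows = 3.0; // rows inserted into department in the test
    double joinOutputRows = 2.0; // assumed estimate of the equi-join output

    // Original plan: Aggregate(Join(employee, department)).
    double original = joinCost(employeeRows, departmentRows)
        + aggregateCost(joinOutputRows);

    // Transposed plan: an extra Aggregate below the Join on the employee side;
    // grouping on employee.id keeps all five rows, and the top Aggregate over
    // the join output is still needed.
    double transposed = aggregateCost(employeeRows)
        + joinCost(employeeRows, departmentRows)
        + aggregateCost(joinOutputRows);

    // Prints original=10.0 transposed=15.0: the transposed plan is costlier,
    // so the rule keeps the Aggregate above the Join, which is what the new
    // groupby_join_pushdown.q test asserts.
    System.out.printf("original=%.1f transposed=%.1f%n", original, transposed);
  }
}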
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java b/ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java
index f5c5a105bb..d023bce7ed 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java
@@ -1921,7 +1921,8 @@ public RelNode apply(RelOptCluster cluster, RelOptSchema relOptSchema, SchemaPlu
       if (conf.getBoolVar(ConfVars.AGGR_JOIN_TRANSPOSE)) {
         perfLogger.PerfLogBegin(this.getClass().getName(), PerfLogger.OPTIMIZER);
         try {
-          calciteOptimizedPlan = hepPlan(calciteOptimizedPlan, false, mdProvider.getMetadataProvider(), null,
+          calciteOptimizedPlan = hepPlan(calciteOptimizedPlan, false,
+              mdProvider.getMetadataProviderWithAggregateCost(), null,
               HepMatchOrder.BOTTOM_UP, HiveAggregateJoinTransposeRule.INSTANCE);
         } catch (Exception e) {
           boolean isMissingStats = noColsMissingStats.get() > 0;
diff --git a/ql/src/test/queries/clientpositive/groupby_join_pushdown.q b/ql/src/test/queries/clientpositive/groupby_join_pushdown.q
index 37ecd90bdb..1559571556 100644
--- a/ql/src/test/queries/clientpositive/groupby_join_pushdown.q
+++ b/ql/src/test/queries/clientpositive/groupby_join_pushdown.q
@@ -73,3 +73,21 @@ explain
 SELECT sum(f.cint), f.ctinyint
 FROM alltypesorc f JOIN alltypesorc g ON(f.ctinyint = g.ctinyint)
 GROUP BY f.ctinyint, g.ctinyint;
+
+CREATE TABLE employee(id int, name string, deptid int);
+INSERT INTO employee values(1,'emp1', 100),(2,'emp2', 200),(3,'emp3', 400),
+                           (5,'emp5', 500),(6,'emp6', 600);
+ANALYZE TABLE employee compute statistics;
+ANALYZE TABLE employee compute statistics for columns;
+
+CREATE TABLE department(deptid int, deptname string);
+INSERT INTO department values(100,'dept1'),(200, 'dept2'),(300, 'dept3');
+ANALYZE TABLE department compute statistics;
+ANALYZE TABLE department compute statistics for columns;
+
+-- Group by shouldn't be pushed down
+explain cbo SELECT count(*) FROM employee JOIN department ON employee.deptid = department.deptid
+  GROUP BY employee.id;
+
+DROP TABLE employee;
+DROP TABLE department;
diff --git a/ql/src/test/results/clientpositive/groupby_join_pushdown.q.out b/ql/src/test/results/clientpositive/groupby_join_pushdown.q.out
index dc1e808dcc..64fac4ddaa 100644
--- a/ql/src/test/results/clientpositive/groupby_join_pushdown.q.out
+++ b/ql/src/test/results/clientpositive/groupby_join_pushdown.q.out
@@ -1796,3 +1796,117 @@ STAGE PLANS:
       Processor Tree:
         ListSink
 
+PREHOOK: query: CREATE TABLE employee(id int, name string, deptid int)
+PREHOOK: type: CREATETABLE
+PREHOOK: Output: database:default
+PREHOOK: Output: default@employee
+POSTHOOK: query: CREATE TABLE employee(id int, name string, deptid int)
+POSTHOOK: type: CREATETABLE
+POSTHOOK: Output: database:default
+POSTHOOK: Output: default@employee
+PREHOOK: query: INSERT INTO employee values(1,'emp1', 100),(2,'emp2', 200),(3,'emp3', 400),
+ (5,'emp5', 500),(6,'emp6', 600)
+PREHOOK: type: QUERY
+PREHOOK: Input: _dummy_database@_dummy_table
+PREHOOK: Output: default@employee
+POSTHOOK: query: INSERT INTO employee values(1,'emp1', 100),(2,'emp2', 200),(3,'emp3', 400),
+ (5,'emp5', 500),(6,'emp6', 600)
+POSTHOOK: type: QUERY
+POSTHOOK: Input: _dummy_database@_dummy_table
+POSTHOOK: Output: default@employee
+POSTHOOK: Lineage: employee.deptid SCRIPT []
+POSTHOOK: Lineage: employee.id SCRIPT []
+POSTHOOK: Lineage: employee.name SCRIPT []
+PREHOOK: query: ANALYZE TABLE employee compute statistics
+PREHOOK: type: QUERY
+PREHOOK: Input: default@employee
+PREHOOK: Output: default@employee
+POSTHOOK: query: ANALYZE TABLE employee compute statistics
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@employee
+POSTHOOK: Output: default@employee
+PREHOOK: query: ANALYZE TABLE employee compute statistics for columns
+PREHOOK: type: ANALYZE_TABLE
+PREHOOK: Input: default@employee
+PREHOOK: Output: default@employee
+#### A masked pattern was here ####
+POSTHOOK: query: ANALYZE TABLE employee compute statistics for columns
+POSTHOOK: type: ANALYZE_TABLE
+POSTHOOK: Input: default@employee
+POSTHOOK: Output: default@employee
+#### A masked pattern was here ####
+PREHOOK: query: CREATE TABLE department(deptid int, deptname string)
+PREHOOK: type: CREATETABLE
+PREHOOK: Output: database:default
+PREHOOK: Output: default@department
+POSTHOOK: query: CREATE TABLE department(deptid int, deptname string)
+POSTHOOK: type: CREATETABLE
+POSTHOOK: Output: database:default
+POSTHOOK: Output: default@department
+PREHOOK: query: INSERT INTO department values(100,'dept1'),(200, 'dept2'),(300, 'dept3')
+PREHOOK: type: QUERY
+PREHOOK: Input: _dummy_database@_dummy_table
+PREHOOK: Output: default@department
+POSTHOOK: query: INSERT INTO department values(100,'dept1'),(200, 'dept2'),(300, 'dept3')
+POSTHOOK: type: QUERY
+POSTHOOK: Input: _dummy_database@_dummy_table
+POSTHOOK: Output: default@department
+POSTHOOK: Lineage: department.deptid SCRIPT []
+POSTHOOK: Lineage: department.deptname SCRIPT []
+PREHOOK: query: ANALYZE TABLE department compute statistics
+PREHOOK: type: QUERY
+PREHOOK: Input: default@department
+PREHOOK: Output: default@department
+POSTHOOK: query: ANALYZE TABLE department compute statistics
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@department
+POSTHOOK: Output: default@department
+PREHOOK: query: ANALYZE TABLE department compute statistics for columns
+PREHOOK: type: ANALYZE_TABLE
+PREHOOK: Input: default@department
+PREHOOK: Output: default@department
+#### A masked pattern was here ####
+POSTHOOK: query: ANALYZE TABLE department compute statistics for columns
+POSTHOOK: type: ANALYZE_TABLE
+POSTHOOK: Input: default@department
+POSTHOOK: Output: default@department
+#### A masked pattern was here ####
+PREHOOK: query: explain cbo SELECT count(*) FROM employee JOIN department ON employee.deptid = department.deptid
+ GROUP BY employee.id
+PREHOOK: type: QUERY
+PREHOOK: Input: default@department
+PREHOOK: Input: default@employee
+#### A masked pattern was here ####
+POSTHOOK: query: explain cbo SELECT count(*) FROM employee JOIN department ON employee.deptid = department.deptid
+ GROUP BY employee.id
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@department
+POSTHOOK: Input: default@employee
+#### A masked pattern was here ####
+CBO PLAN:
+HiveProject(_o__c0=[$1])
+  HiveAggregate(group=[{0}], agg#0=[count()])
+    HiveJoin(condition=[=($1, $2)], joinType=[inner], algorithm=[none], cost=[not available])
+      HiveProject(id=[$0], deptid=[$2])
+        HiveFilter(condition=[IS NOT NULL($2)])
+          HiveTableScan(table=[[default, employee]], table:alias=[employee])
+      HiveProject(deptid=[$0])
+        HiveFilter(condition=[IS NOT NULL($0)])
+          HiveTableScan(table=[[default, department]], table:alias=[department])
+
+PREHOOK: query: DROP TABLE employee
+PREHOOK: type: DROPTABLE
+PREHOOK: Input: default@employee
+PREHOOK: Output: default@employee
+POSTHOOK: query: DROP TABLE employee
+POSTHOOK: type: DROPTABLE
+POSTHOOK: Input: default@employee
+POSTHOOK: Output: default@employee
+PREHOOK: query: DROP TABLE department
+PREHOOK: type: DROPTABLE
+PREHOOK: Input: default@department
+PREHOOK: Output: default@department
+POSTHOOK: query: DROP TABLE department
+POSTHOOK: type: DROPTABLE
+POSTHOOK: Input: default@department
+POSTHOOK: Output: default@department
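Since this patch changes the default of hive.transpose.aggr.join from false to true, deployments that hit plan regressions can turn the optimization back off per session ("set hive.transpose.aggr.join=false;") or in hive-site.xml. Below is a minimal sketch of doing the same programmatically against HiveConf; the wrapper class is illustrative, only the config enum and the setBoolVar/getBoolVar accessors come from the patched HiveConf.

import org.apache.hadoop.hive.conf.HiveConf;

// Illustrative sketch: disables aggregate/join transposition on a HiveConf,
// e.g. from embedded or test code.
public final class DisableAggrJoinTranspose {
  public static void main(String[] args) {
    HiveConf conf = new HiveConf();
    conf.setBoolVar(HiveConf.ConfVars.AGGR_JOIN_TRANSPOSE, false);
    // Verify the override took effect.
    System.out.println(conf.getBoolVar(HiveConf.ConfVars.AGGR_JOIN_TRANSPOSE));
  }
}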