diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveSemiJoinRule.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveSemiJoinRule.java index 7799090d43..37ddcca469 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveSemiJoinRule.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveSemiJoinRule.java @@ -29,8 +29,10 @@ import org.apache.calcite.rel.core.JoinInfo; import org.apache.calcite.rel.core.JoinRelType; import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.tools.RelBuilder; import org.apache.calcite.tools.RelBuilderFactory; import org.apache.calcite.util.ImmutableBitSet; import org.apache.calcite.util.ImmutableIntList; @@ -42,6 +44,7 @@ import com.google.common.collect.Lists; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; /** @@ -59,6 +62,9 @@ public static final HiveProjectToSemiJoinRule INSTANCE_PROJECT = new HiveProjectToSemiJoinRule(HiveRelFactories.HIVE_BUILDER); + public static final HiveProjectToSemiJoinRuleSwapInputs INSTANCE_PROJECT_SWAPPED = + new HiveProjectToSemiJoinRuleSwapInputs (HiveRelFactories.HIVE_BUILDER); + public static final HiveAggregateToSemiJoinRule INSTANCE_AGGREGATE = new HiveAggregateToSemiJoinRule(HiveRelFactories.HIVE_BUILDER); @@ -153,6 +159,92 @@ public HiveProjectToSemiJoinRule(RelBuilderFactory relBuilder) { } } + /** SemiJoinRule that matches a Project on top of a Join with an Aggregate + * as its right child. */ + public static class HiveProjectToSemiJoinRuleSwapInputs extends HiveSemiJoinRule { + + /** Creates a HiveProjectToSemiJoinRule. */ + public HiveProjectToSemiJoinRuleSwapInputs(RelBuilderFactory relBuilder) { + super( + operand(Project.class, + some(operand(Join.class, + some( + operand(Aggregate.class, any()), + operand(RelNode.class, any()))))), + relBuilder); + } + + private Project swapInputs(Join join, Project topProject, RelBuilder builder) { + RexBuilder rexBuilder = join.getCluster().getRexBuilder(); + + int rightInputSize = join.getRight().getRowType().getFieldCount(); + int leftInputSize = join.getLeft().getRowType().getFieldCount(); + List joinFields = join.getRowType().getFieldList(); + + //swap the join inputs + //adjust join condition + int[] condAdjustments = new int[joinFields.size()]; + for(int i=0; i newProjects = new ArrayList<>(); + + List swappedJoinFeilds = swappedJoin.getRowType().getFieldList(); + for(RexNode project:topProject.getProjects()) { + RexNode newProject = project.accept(new RelOptUtil.RexInputConverter(rexBuilder,swappedJoinFeilds, + swappedJoinFeilds, condAdjustments)); + newProjects.add(newProject); + } + return (Project)builder.push(swappedJoin).project(newProjects).build(); + } + + @Override public void onMatch(RelOptRuleCall call) { + final Project project = call.rel(0); + final Join join = call.rel(1); + final RelNode right = call.rel(3); + final Aggregate aggregate = call.rel(2); + + // make sure the following conditions are met + // Join is INNER + // project above is referring to inputs only from non-aggregate side + if(join.getJoinType() != JoinRelType.INNER) { + return; + } + + // TODO: Ideally this condition should be in match + //FIXME: other condition is if join key is not same as group by key + final ImmutableBitSet topRefs = + RelOptUtil.InputFinder.bits(project.getChildExps(), null); + + final ImmutableBitSet leftBits = + ImmutableBitSet.range(0, join.getLeft().getRowType().getFieldCount()); + + if (topRefs.intersects(leftBits)) { + return; + } + // it is safe to swap inputs + final Project swappedProject = swapInputs(join, project, call.builder()); + final RelNode swappedJoin = swappedProject.getInput(); + assert(swappedJoin instanceof Join); + + final ImmutableBitSet swappedTopRefs = + RelOptUtil.InputFinder.bits(swappedProject.getChildExps(), null); + + perform(call, swappedTopRefs, swappedProject, (Join)swappedJoin, right, aggregate); + } + } + /** SemiJoinRule that matches a Aggregate on top of a Join with an Aggregate * as its right child. */ public static class HiveAggregateToSemiJoinRule extends HiveSemiJoinRule { diff --git a/ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java b/ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java index 82e975a50d..f71b6b19c5 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java @@ -1880,7 +1880,7 @@ public RelNode apply(RelOptCluster cluster, RelOptSchema relOptSchema, SchemaPlu if (conf.getBoolVar(ConfVars.SEMIJOIN_CONVERSION)) { perfLogger.PerfLogBegin(this.getClass().getName(), PerfLogger.OPTIMIZER); calciteOptimizedPlan = hepPlan(calciteOptimizedPlan, false, mdProvider.getMetadataProvider(), null, - HiveSemiJoinRule.INSTANCE_PROJECT, HiveSemiJoinRule.INSTANCE_AGGREGATE); + HiveSemiJoinRule.INSTANCE_PROJECT, HiveSemiJoinRule.INSTANCE_PROJECT_SWAPPED, HiveSemiJoinRule.INSTANCE_AGGREGATE); perfLogger.PerfLogEnd(this.getClass().getName(), PerfLogger.OPTIMIZER, "Calcite: Semijoin conversion"); } diff --git a/ql/src/test/queries/clientpositive/semijoin.q b/ql/src/test/queries/clientpositive/semijoin.q index 144069bbe6..e1d31cc8f3 100644 --- a/ql/src/test/queries/clientpositive/semijoin.q +++ b/ql/src/test/queries/clientpositive/semijoin.q @@ -1,4 +1,5 @@ --! qt:dataset:src +--! qt:dataset:part SET hive.vectorized.execution.enabled=false; set hive.mapred.mode=nonstrict; -- SORT_QUERY_RESULTS @@ -86,3 +87,6 @@ explain select key, value from src outr left semi join select key, value from src outr left semi join (select a.key, b.value from src a join (select distinct value from src) b on a.value > b.value group by a.key, b.value) inr on outr.key=inr.key and outr.value=inr.value; + +explain cbo select pp.p_partkey from (select distinct p_name from part) p join part pp on pp.p_name = p.p_name; +select pp.p_partkey from (select distinct p_name from part) p join part pp on pp.p_name = p.p_name; diff --git a/ql/src/test/results/clientpositive/llap/semijoin.q.out b/ql/src/test/results/clientpositive/llap/semijoin.q.out index 531ef46c78..63a270e57d 100644 --- a/ql/src/test/results/clientpositive/llap/semijoin.q.out +++ b/ql/src/test/results/clientpositive/llap/semijoin.q.out @@ -3076,3 +3076,55 @@ POSTHOOK: query: select key, value from src outr left semi join POSTHOOK: type: QUERY POSTHOOK: Input: default@src #### A masked pattern was here #### +PREHOOK: query: explain cbo select pp.p_partkey from (select distinct p_name from part) p join part pp on pp.p_name = p.p_name +PREHOOK: type: QUERY +PREHOOK: Input: default@part +#### A masked pattern was here #### +POSTHOOK: query: explain cbo select pp.p_partkey from (select distinct p_name from part) p join part pp on pp.p_name = p.p_name +POSTHOOK: type: QUERY +POSTHOOK: Input: default@part +#### A masked pattern was here #### +CBO PLAN: +HiveProject(p_partkey=[$0]) + HiveSemiJoin(condition=[=($1, $3)], joinType=[inner]) + HiveProject(p_partkey=[$0], p_name=[$1]) + HiveFilter(condition=[IS NOT NULL($1)]) + HiveTableScan(table=[[default, part]], table:alias=[pp]) + HiveProject(p_partkey=[$0], p_name=[$1], p_mfgr=[$2], p_brand=[$3], p_type=[$4], p_size=[$5], p_container=[$6], p_retailprice=[$7], p_comment=[$8], BLOCK__OFFSET__INSIDE__FILE=[$9], INPUT__FILE__NAME=[$10], ROW__ID=[$11]) + HiveFilter(condition=[IS NOT NULL($1)]) + HiveTableScan(table=[[default, part]], table:alias=[part]) + +PREHOOK: query: select pp.p_partkey from (select distinct p_name from part) p join part pp on pp.p_name = p.p_name +PREHOOK: type: QUERY +PREHOOK: Input: default@part +#### A masked pattern was here #### +POSTHOOK: query: select pp.p_partkey from (select distinct p_name from part) p join part pp on pp.p_name = p.p_name +POSTHOOK: type: QUERY +POSTHOOK: Input: default@part +#### A masked pattern was here #### +105685 +110592 +112398 +121152 +121152 +132666 +144293 +146985 +15103 +155733 +17273 +17927 +191709 +192697 +195606 +33357 +40982 +42669 +45261 +48427 +49671 +65667 +78486 +85768 +86428 +90681