diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveSemiJoinRule.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveSemiJoinRule.java index 07ce76287a..c3e82ca7c0 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveSemiJoinRule.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveSemiJoinRule.java @@ -22,7 +22,6 @@ import org.apache.calcite.plan.RelOptRuleCall; import org.apache.calcite.plan.RelOptRuleOperand; import org.apache.calcite.plan.RelOptUtil; -import org.apache.calcite.plan.hep.HepRelVertex; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.Aggregate; import org.apache.calcite.rel.core.AggregateCall; @@ -67,32 +66,33 @@ public static final HiveAggregateToSemiJoinRule INSTANCE_AGGREGATE = new HiveAggregateToSemiJoinRule(HiveRelFactories.HIVE_BUILDER); - private HiveSemiJoinRule(RelOptRuleOperand operand, RelBuilderFactory relBuilder) { + private HiveSemiJoinRule(final RelOptRuleOperand operand, final RelBuilderFactory relBuilder) { super(operand, relBuilder, null); } - private RelNode buildProject(final Aggregate aggregate, RexBuilder rexBuilder, RelBuilder relBuilder) { - assert(!aggregate.indicator && aggregate.getAggCallList().isEmpty()); + private RelNode buildProject(final Aggregate aggregate, final RexBuilder rexBuilder, + final RelBuilder relBuilder) { + assert (!aggregate.indicator && aggregate.getAggCallList().isEmpty()); RelNode input = aggregate.getInput(); List groupingKeys = aggregate.getGroupSet().asList(); List projects = new ArrayList<>(); - for(Integer keys:groupingKeys) { + for (Integer keys:groupingKeys) { projects.add(rexBuilder.makeInputRef(input, keys.intValue())); } return relBuilder.push(aggregate.getInput()).project(projects).build(); } private boolean needProject(final RelNode input, final RelNode aggregate) { - if((input instanceof HepRelVertex - && ((HepRelVertex)input).getCurrentRel() instanceof Join) + if (input instanceof Join || input.getRowType().getFieldCount() != aggregate.getRowType().getFieldCount()) { return true; } return false; } - protected void perform(RelOptRuleCall call, ImmutableBitSet topRefs, - RelNode topOperator, Join join, RelNode left, Aggregate aggregate) { + protected void perform(final RelOptRuleCall call, final ImmutableBitSet topRefs, + final RelNode topOperator, final Join join, final RelNode left, + final Aggregate aggregate, final RelNode aggregateInput) { LOG.debug("Matched HiveSemiJoinRule"); final RelOptCluster cluster = join.getCluster(); final RexBuilder rexBuilder = cluster.getRexBuilder(); @@ -109,7 +109,11 @@ protected void perform(RelOptRuleCall call, ImmutableBitSet topRefs, // By the way, neither a super-set nor a sub-set would work. return; } - if(join.getJoinType() == JoinRelType.LEFT) { + if (!joinInfo.isEqui()) { + return; + } + + if (join.getJoinType() == JoinRelType.LEFT) { // since for LEFT join we are only interested in rows from LEFT we can get rid of right side call.transformTo(topOperator.copy(topOperator.getTraitSet(), ImmutableList.of(left))); return; @@ -117,18 +121,15 @@ protected void perform(RelOptRuleCall call, ImmutableBitSet topRefs, if (join.getJoinType() != JoinRelType.INNER) { return; } - if (!joinInfo.isEqui()) { - return; - } + LOG.debug("All conditions matched for HiveSemiJoinRule. Going to apply transformation."); final List newRightKeyBuilder = Lists.newArrayList(); final List aggregateKeys = aggregate.getGroupSet().asList(); for (int key : joinInfo.rightKeys) { newRightKeyBuilder.add(aggregateKeys.get(key)); } - RelNode input = aggregate.getInput(); - final RelNode newRight = needProject(input, aggregate) ? - buildProject(aggregate, rexBuilder, call.builder()) : input; + final RelNode newRight = needProject(aggregateInput, aggregate) + ? buildProject(aggregate, rexBuilder, call.builder()) : aggregateInput; final RexNode newCondition = RelOptUtil.createEquiJoinCondition(left, joinInfo.leftKeys, newRight, joinInfo.rightKeys, rexBuilder); @@ -142,43 +143,45 @@ protected void perform(RelOptRuleCall call, ImmutableBitSet topRefs, public static class HiveProjectToSemiJoinRule extends HiveSemiJoinRule { /** Creates a HiveProjectToSemiJoinRule. */ - public HiveProjectToSemiJoinRule(RelBuilderFactory relBuilder) { + public HiveProjectToSemiJoinRule(final RelBuilderFactory relBuilder) { super( operand(Project.class, some(operand(Join.class, some( operand(RelNode.class, any()), - operand(Aggregate.class, any()))))), + operand(Aggregate.class, + some(operand(RelNode.class, any()))))))), relBuilder); } - @Override public void onMatch(RelOptRuleCall call) { + @Override public void onMatch(final RelOptRuleCall call) { final Project project = call.rel(0); final Join join = call.rel(1); final RelNode left = call.rel(2); final Aggregate aggregate = call.rel(3); final ImmutableBitSet topRefs = RelOptUtil.InputFinder.bits(project.getChildExps(), null); - perform(call, topRefs, project, join, left, aggregate); + final RelNode aggregateInput = call.rel(4); + perform(call, topRefs, project, join, left, aggregate, aggregateInput); } } /** SemiJoinRule that matches a Project on top of a Join with an Aggregate - * as its right child. */ + * as its left child. */ public static class HiveProjectToSemiJoinRuleSwapInputs extends HiveSemiJoinRule { /** Creates a HiveProjectToSemiJoinRule. */ - public HiveProjectToSemiJoinRuleSwapInputs(RelBuilderFactory relBuilder) { + public HiveProjectToSemiJoinRuleSwapInputs(final RelBuilderFactory relBuilder) { super( operand(Project.class, some(operand(Join.class, some( - operand(Aggregate.class, any()), + operand(Aggregate.class, some(operand(RelNode.class, any()))), operand(RelNode.class, any()))))), relBuilder); } - private Project swapInputs(Join join, Project topProject, RelBuilder builder) { + private Project swapInputs(final Join join, final Project topProject, final RelBuilder builder) { RexBuilder rexBuilder = join.getCluster().getRexBuilder(); int rightInputSize = join.getRight().getRowType().getFieldCount(); @@ -188,7 +191,7 @@ private Project swapInputs(Join join, Project topProject, RelBuilder builder) { //swap the join inputs //adjust join condition int[] condAdjustments = new int[joinFields.size()]; - for(int i=0; i