diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveRelOptUtil.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveRelOptUtil.java index 50fbb78a94..268284a6da 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveRelOptUtil.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveRelOptUtil.java @@ -21,13 +21,19 @@ import java.util.ArrayList; import java.util.List; +import com.google.common.collect.ImmutableList; import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelOptUtil; import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.Filter; +import org.apache.calcite.rel.core.RelFactories; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexFieldAccess; +import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexUtil; import org.apache.calcite.sql.SqlKind; @@ -347,4 +353,111 @@ public String get(int index) { }, true, relBuilder); } + public static RexNode splitCorrelatedFilterCondition( + Filter filter, + List joinKeys, + List correlatedJoinKeys, + boolean extractCorrelatedFieldAccess) { + final List nonEquiList = new ArrayList<>(); + + splitCorrelatedFilterCondition( + filter, + filter.getCondition(), + joinKeys, + correlatedJoinKeys, + nonEquiList, + extractCorrelatedFieldAccess); + + // Convert the remainders into a list that are AND'ed together. 
+ return RexUtil.composeConjunction( + filter.getCluster().getRexBuilder(), nonEquiList, true); + } + + private static void splitCorrelatedFilterCondition( + Filter filter, + RexNode condition, + List joinKeys, + List correlatedJoinKeys, + List nonEquiList, + boolean extractCorrelatedFieldAccess) { + if (condition instanceof RexCall) { + RexCall call = (RexCall) condition; + if (call.getOperator().getKind() == SqlKind.AND) { + for (RexNode operand : call.getOperands()) { + splitCorrelatedFilterCondition( + filter, + operand, + joinKeys, + correlatedJoinKeys, + nonEquiList, + extractCorrelatedFieldAccess); + } + return; + } + + if (call.getOperator().getKind() == SqlKind.EQUALS) { + final List operands = call.getOperands(); + RexNode op0 = operands.get(0); + RexNode op1 = operands.get(1); + + if (extractCorrelatedFieldAccess) { + if (!RexUtil.containsFieldAccess(op0) + && (op1 instanceof RexFieldAccess)) { + joinKeys.add(op0); + correlatedJoinKeys.add(op1); + return; + } else if ( + (op0 instanceof RexFieldAccess) + && !RexUtil.containsFieldAccess(op1)) { + correlatedJoinKeys.add(op0); + joinKeys.add(op1); + return; + } + } else { + if (!(RexUtil.containsInputRef(op0)) + && (op1 instanceof RexInputRef)) { + correlatedJoinKeys.add(op0); + joinKeys.add(op1); + return; + } else if ( + (op0 instanceof RexInputRef) + && !(RexUtil.containsInputRef(op1))) { + joinKeys.add(op0); + correlatedJoinKeys.add(op1); + return; + } + } + } + } + + // The operator is not of RexCall type + // So we fail. Fall through. + // Add this condition to the list of non-equi-join conditions. + nonEquiList.add(condition); + } + + /** + * Creates an aggregate with an empty group set that applies SINGLE_VALUE to + * every field of an underlying relational expression.
+ * + * @param rel underlying rel + * @return rel implementing SingleValueAgg + */ + public static RelNode createSingleValueAggRel( + RelOptCluster cluster, + RelNode rel, + RelFactories.AggregateFactory aggregateFactory) { + // assert (rel.getRowType().getFieldCount() == 1); + final int aggCallCnt = rel.getRowType().getFieldCount(); + final List aggCalls = new ArrayList<>(); + + for (int i = 0; i < aggCallCnt; i++) { + aggCalls.add( + AggregateCall.create( + SqlStdOperatorTable.SINGLE_VALUE, false, false, + ImmutableList.of(i), -1, 0, rel, null, null)); + } + + return aggregateFactory.createAggregate(rel, false, ImmutableBitSet.of(), null, aggCalls); + } } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/cost/HiveVolcanoPlanner.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/cost/HiveVolcanoPlanner.java index 23ff518ada..5dcbff6659 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/cost/HiveVolcanoPlanner.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/cost/HiveVolcanoPlanner.java @@ -97,7 +97,10 @@ public RelOptCost getCost(RelNode rel, RelMetadataQuery mq) { if (rel instanceof RelSubset) { // Get cost of the subset, best rel may have been chosen or not RelSubset subset = (RelSubset) rel; - return getCost(Util.first(subset.getBest(), subset.getOriginal()), mq); + if (subset.getBest() != null) { + return getCost(subset.getBest(), mq); + } + return costFactory.makeInfiniteCost(); } if (rel.getTraitSet().getTrait(ConventionTraitDef.INSTANCE) == Convention.NONE) { diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelDecorrelator.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelDecorrelator.java index 24e22a0bf3..7e3d495933 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelDecorrelator.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelDecorrelator.java @@ -53,22 +53,14 @@ import 
org.apache.calcite.rel.core.Correlate; import org.apache.calcite.rel.core.CorrelationId; import org.apache.calcite.rel.core.Filter; +import org.apache.calcite.rel.core.Join; import org.apache.calcite.rel.core.JoinRelType; import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rel.core.RelFactories; import org.apache.calcite.rel.core.Sort; import org.apache.calcite.rel.core.Values; -import org.apache.calcite.rel.logical.LogicalAggregate; import org.apache.calcite.rel.logical.LogicalCorrelate; -import org.apache.calcite.rel.logical.LogicalFilter; -import org.apache.calcite.rel.logical.LogicalIntersect; -import org.apache.calcite.rel.logical.LogicalJoin; -import org.apache.calcite.rel.logical.LogicalProject; -import org.apache.calcite.rel.logical.LogicalUnion; import org.apache.calcite.rel.metadata.RelMdUtil; import org.apache.calcite.rel.metadata.RelMetadataQuery; -import org.apache.calcite.rel.rules.FilterJoinRule; -import org.apache.calcite.rel.rules.FilterProjectTransposeRule; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rel.type.RelDataTypeField; @@ -104,6 +96,8 @@ import org.apache.calcite.util.Stacks; import org.apache.calcite.util.Util; import org.apache.calcite.util.mapping.Mappings; +import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories; +import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelOptUtil; import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelShuttleImpl; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveFilter; @@ -137,7 +131,7 @@ * NOTE: this whole logic is replicated from Calcite's RelDecorrelator * and is exteneded to make it suitable for HIVE * We should get rid of this and replace it with Calcite's RelDecorrelator - * once that works with Join, Project etc instead of LogicalJoin, LogicalProject. 
+ * once that works with Join, Project etc instead of LogicalJoin, LogicalProject. * At this point this has differed from Calcite's version significantly so cannot * get rid of this. * @@ -198,8 +192,7 @@ private HiveRelDecorrelator( this.cm = cm; this.rexBuilder = cluster.getRexBuilder(); this.context = context; - relBuilder = RelFactories.LOGICAL_BUILDER.create(cluster, null); - + relBuilder = HiveRelFactories.HIVE_BUILDER.create(cluster, null); } //~ Methods ---------------------------------------------------------------- @@ -245,8 +238,8 @@ private RelNode decorrelate(RelNode root) { HepProgram program = HepProgram.builder() .addRuleInstance(new AdjustProjectForCountAggregateRule(false)) .addRuleInstance(new AdjustProjectForCountAggregateRule(true)) - .addRuleInstance(FilterJoinRule.FILTER_ON_JOIN) - .addRuleInstance(FilterProjectTransposeRule.INSTANCE) + .addRuleInstance(HiveFilterJoinRule.FILTER_ON_JOIN) + .addRuleInstance(HiveFilterProjectTransposeRule.INSTANCE) // FilterCorrelateRule rule mistakenly pushes a FILTER, consiting of correlated vars, // on top of LogicalCorrelate to within left input for scalar corr queries // which causes exception during decorrelation. This has been disabled for now. @@ -265,8 +258,8 @@ private RelNode decorrelate(RelNode root) { if (frame != null) { // has been rewritten; apply rules post-decorrelation final HepProgram program2 = HepProgram.builder() - .addRuleInstance(FilterJoinRule.FILTER_ON_JOIN) - .addRuleInstance(FilterJoinRule.JOIN) + .addRuleInstance(HiveFilterJoinRule.FILTER_ON_JOIN) + .addRuleInstance(HiveFilterJoinRule.JOIN) .build(); final HepPlanner planner2 = createPlanner(program2); @@ -504,11 +497,11 @@ public Frame decorrelateRel(Values rel) { } /** - * Rewrites a {@link LogicalAggregate}. + * Rewrites a {@link Aggregate}.
* * @param rel Aggregate to rewrite */ - public Frame decorrelateRel(LogicalAggregate rel) throws SemanticException{ + public Frame decorrelateRel(Aggregate rel) throws SemanticException{ if (rel.getGroupType() != Aggregate.Group.SIMPLE) { throw new AssertionError(Bug.CALCITE_461_FIXED); } @@ -654,12 +647,8 @@ public Frame decorrelateRel(LogicalAggregate rel) throws SemanticException{ newInputOutputFieldCount + i); } - relBuilder.push( - LogicalAggregate.create(newProject, - false, - newGroupSet, - null, - newAggCalls)); + relBuilder.push(newProject) + .aggregate(relBuilder.groupKey(newGroupSet, null), newAggCalls); if (!omittedConstants.isEmpty()) { final List postProjects = new ArrayList<>(relBuilder.fields()); @@ -877,17 +866,17 @@ public Frame decorrelateRel(HiveProject rel) throws SemanticException{ final List oldProjects = rel.getProjects(); final List relOutput = rel.getRowType().getFieldList(); - // LogicalProject projects the original expressions, + // Project projects the original expressions, // plus any correlated variables the input wants to pass along. final List> projects = Lists.newArrayList(); - // If this LogicalProject has correlated reference, create value generator + // If this Project has correlated reference, create value generator // and produce the correlated variables in the new output. if (cm.mapRefRelToCorRef.containsKey(rel)) { frame = decorrelateInputWithValueGenerator(rel); } - // LogicalProject projects the original expressions + // Project projects the original expressions final Map mapOldToNewOutputs = new HashMap<>(); int newPos; for (newPos = 0; newPos < oldProjects.size(); newPos++) { @@ -917,11 +906,11 @@ public Frame decorrelateRel(HiveProject rel) throws SemanticException{ corDefOutputs); } /** - * Rewrite LogicalProject. + * Rewrite Project. 
* * @param rel the project rel to rewrite */ - public Frame decorrelateRel(LogicalProject rel) throws SemanticException{ + public Frame decorrelateRel(Project rel) throws SemanticException{ // // Rewrite logic: // @@ -937,17 +926,17 @@ public Frame decorrelateRel(LogicalProject rel) throws SemanticException{ final List oldProjects = rel.getProjects(); final List relOutput = rel.getRowType().getFieldList(); - // LogicalProject projects the original expressions, + // Project projects the original expressions, // plus any correlated variables the input wants to pass along. final List> projects = Lists.newArrayList(); - // If this LogicalProject has correlated reference, create value generator + // If this Project has correlated reference, create value generator // and produce the correlated variables in the new output. if (cm.mapRefRelToCorRef.containsKey(rel)) { frame = decorrelateInputWithValueGenerator(rel); } - // LogicalProject projects the original expressions + // Project projects the original expressions final Map mapOldToNewOutputs = new HashMap<>(); int newPos; for (newPos = 0; newPos < oldProjects.size(); newPos++) { @@ -977,13 +966,6 @@ public Frame decorrelateRel(LogicalProject rel) throws SemanticException{ /** * Create RelNode tree that produces a list of correlated variables. 
- * - * @param correlations correlated variables to generate - * @param valueGenFieldOffset offset in the output that generated columns - * will start - * @param mapCorVarToOutputPos output positions for the correlated variables - * generated - * @return RelNode the root of the resultant RelNode tree */ private RelNode createValueGenerator( Iterable correlations, @@ -1039,11 +1021,9 @@ private RelNode createValueGenerator( assert newInput != null; if (!joinedInputs.contains(newInput)) { - RelNode project = - RelOptUtil.createProject( - newInput, - mapNewInputToOutputs.get(newInput)); - RelNode distinct = RelOptUtil.createDistinctRel(project); + RelNode project = RelOptUtil.createProject( + HiveRelFactories.HIVE_PROJECT_FACTORY, newInput, mapNewInputToOutputs.get(newInput)); + RelNode distinct = relBuilder.push(project).distinct().build(); RelOptCluster cluster = distinct.getCluster(); joinedInputs.add(newInput); @@ -1053,10 +1033,8 @@ private RelNode createValueGenerator( if (r == null) { r = distinct; } else { - r = - LogicalJoin.create(r, distinct, - cluster.getRexBuilder().makeLiteral(true), - ImmutableSet.of(), JoinRelType.INNER); + r = relBuilder.push(r).push(distinct) + .join(JoinRelType.INNER, cluster.getRexBuilder().makeLiteral(true)).build(); } } } @@ -1153,13 +1131,12 @@ private Frame decorrelateInputWithValueGenerator(RelNode rel) { leftInputOutputCount, corDefOutputs); - RelNode join = - LogicalJoin.create(frame.r, valueGenRel, rexBuilder.makeLiteral(true), - ImmutableSet.of(), JoinRelType.INNER); + RelNode join = relBuilder.push(frame.r).push(valueGenRel) + .join(JoinRelType.INNER, rexBuilder.makeLiteral(true)).build(); - // LogicalJoin or LogicalFilter does not change the old input ordering. All + // Join or Filter does not change the old input ordering. All // input fields from newLeftInput(i.e. the original input to the old - // LogicalFilter) are in the output and in the same position. + // Filter) are in the output and in the same position. 
return register(oldInput, join, frame.oldToNewOutputs, corDefOutputs); } @@ -1242,16 +1219,16 @@ public Frame decorrelateRel(HiveFilter rel) throws SemanticException { // // Rewrite logic: // - // 1. If a LogicalFilter references a correlated field in its filter - // condition, rewrite the LogicalFilter to be - // LogicalFilter - // LogicalJoin(cross product) + // 1. If a Filter references a correlated field in its filter + // condition, rewrite the Filter to be + // Filter + // Join(cross product) // OriginalFilterInput // ValueGenerator(produces distinct sets of correlated variables) // and rewrite the correlated fieldAccess in the filter condition to - // reference the LogicalJoin output. + // reference the Join output. // - // 2. If LogicalFilter does not reference correlated variables, simply + // 2. If Filter does not reference correlated variables, simply // rewrite the filter condition using new input. // @@ -1263,7 +1240,7 @@ public Frame decorrelateRel(HiveFilter rel) throws SemanticException { } Frame oldInputFrame = frame; - // If this LogicalFilter has correlated reference, create value generator + // If this Filter has correlated reference, create value generator // and produce the correlated variables in the new output. if (cm.mapRefRelToCorRef.containsKey(rel)) { frame = decorrelateInputWithValueGenerator(rel); @@ -1306,24 +1283,24 @@ public Frame decorrelateRel(HiveFilter rel) throws SemanticException { } /** - * Rewrite LogicalFilter. + * Rewrite Filter. * * @param rel the filter rel to rewrite */ - public Frame decorrelateRel(LogicalFilter rel) { + public Frame decorrelateRel(Filter rel) { // // Rewrite logic: // - // 1. If a LogicalFilter references a correlated field in its filter - // condition, rewrite the LogicalFilter to be - // LogicalFilter - // LogicalJoin(cross product) + // 1. 
If a Filter references a correlated field in its filter + // condition, rewrite the Filter to be + // Filter + // Join(cross product) // OriginalFilterInput // ValueGenerator(produces distinct sets of correlated variables) // and rewrite the correlated fieldAccess in the filter condition to - // reference the LogicalJoin output. + // reference the Join output. // - // 2. If LogicalFilter does not reference correlated variables, simply + // 2. If Filter does not reference correlated variables, simply // rewrite the filter condition using new input. // @@ -1334,7 +1311,7 @@ public Frame decorrelateRel(LogicalFilter rel) { return null; } - // If this LogicalFilter has correlated reference, create value generator + // If this Filter has correlated reference, create value generator // and produce the correlated variables in the new output. if (cm.mapRefRelToCorRef.containsKey(rel)) { frame = decorrelateInputWithValueGenerator(rel); @@ -1500,9 +1477,8 @@ public Frame decorrelateRel(LogicalCorrelate rel) { rightFrame.oldToNewOutputs.get(i) + newLeftFieldCount); } - newJoin = LogicalJoin.create(leftFrame.r, rightFrame.r, condition, - ImmutableSet.of(), rel.getJoinType().toJoinType()); - + newJoin = relBuilder.push(leftFrame.r).push(rightFrame.r) + .join(rel.getJoinType().toJoinType(), condition).build(); } valueGen.pop(); @@ -1564,11 +1540,11 @@ public Frame decorrelateRel(HiveJoin rel) throws SemanticException{ return register(rel, newJoin, mapOldToNewOutputs, corDefOutputs); } /** - * Rewrite LogicalJoin. + * Rewrite Join. 
* - * @param rel LogicalJoin + * @param rel Join */ - public Frame decorrelateRel(LogicalJoin rel) { + public Frame decorrelateRel(Join rel) { // // Rewrite logic: // @@ -1684,11 +1660,11 @@ private RexInputRef getNewForOldInputRef(RexInputRef oldInputRef) { * @param join Join * @param project Original project as the right-hand input of the join * @param nullIndicatorPos Position of null indicator - * @return the subtree with the new LogicalProject at the root + * @return the subtree with the new Project at the root */ private RelNode projectJoinOutputWithNullability( - LogicalJoin join, - LogicalProject project, + Join join, + Project project, int nullIndicatorPos) { final RelDataTypeFactory typeFactory = join.getCluster().getTypeFactory(); final RelNode left = join.getLeft(); @@ -1730,7 +1706,8 @@ private RelNode projectJoinOutputWithNullability( newProjExprs.add(Pair.of(newProjExpr, pair.right)); } - return RelOptUtil.createProject(join, newProjExprs, false); + return RelOptUtil.createProject(join, Pair.left(newProjExprs), Pair.right(newProjExprs), + false, relBuilder); } /** @@ -1741,11 +1718,11 @@ private RelNode projectJoinOutputWithNullability( * @param project the original project as the RHS input of the join * @param isCount Positions which are calls to the COUNT * aggregation function - * @return the subtree with the new LogicalProject at the root + * @return the subtree with the new Project at the root */ private RelNode aggregateCorrelatorOutput( Correlate correlate, - LogicalProject project, + Project project, Set isCount) { final RelNode left = correlate.getLeft(); final JoinRelType joinType = correlate.getJoinType().toJoinType(); @@ -1777,7 +1754,8 @@ private RelNode aggregateCorrelatorOutput( newProjects.add(Pair.of(newProjExpr, pair.right)); } - return RelOptUtil.createProject(correlate, newProjects, false); + return RelOptUtil.createProject(correlate, Pair.left(newProjects), Pair.right(newProjects), + false, relBuilder); } /** @@ -1792,8 +1770,8 
@@ private RelNode aggregateCorrelatorOutput( */ private boolean checkCorVars( LogicalCorrelate correlate, - LogicalProject project, - LogicalFilter filter, + Project project, + Filter filter, List correlatedJoinKeys) { if (filter != null) { assert correlatedJoinKeys != null; @@ -1852,7 +1830,7 @@ private void removeCorVarFromTree(LogicalCorrelate correlate) { * * @param input Input relational expression * @param additionalExprs Additional expressions and names - * @return the new LogicalProject + * @return the new Project */ private RelNode createProjectWithAdditionalExprs( RelNode input, @@ -1868,7 +1846,8 @@ private RelNode createProjectWithAdditionalExprs( field.e.getName())); } projects.addAll(additionalExprs); - return RelOptUtil.createProject(input, projects, false); + return RelOptUtil.createProject(input, Pair.left(projects), Pair.right(projects), + false, relBuilder); } /* Returns an immutable map with the identity [0: 0, .., count-1: count-1]. */ @@ -2206,16 +2185,16 @@ private RexNode createCaseExpression( RemoveSingleAggregateRule() { super( operand( - LogicalAggregate.class, + Aggregate.class, operand( - LogicalProject.class, - operand(LogicalAggregate.class, any())))); + Project.class, + operand(Aggregate.class, any())))); } public void onMatch(RelOptRuleCall call) { - LogicalAggregate singleAggregate = call.rel(0); - LogicalProject project = call.rel(1); - LogicalAggregate aggregate = call.rel(2); + Aggregate singleAggregate = call.rel(0); + Project project = call.rel(1); + Aggregate aggregate = call.rel(2); // check singleAggRel is single_value agg if ((!singleAggregate.getGroupSet().isEmpty()) @@ -2241,15 +2220,11 @@ public void onMatch(RelOptRuleCall call) { // singleAggRel produces a nullable type, so create the new // projection that casts proj expr to a nullable type. 
final RelOptCluster cluster = project.getCluster(); - RelNode newProject = - RelOptUtil.createProject(aggregate, - ImmutableList.of( - rexBuilder.makeCast( - cluster.getTypeFactory().createTypeWithNullability( - projExprs.get(0).getType(), - true), - projExprs.get(0))), - null); + RelNode newProject = RelOptUtil.createProject(aggregate, + ImmutableList.of(rexBuilder.makeCast( + cluster.getTypeFactory().createTypeWithNullability(projExprs.get(0).getType(), true), + projExprs.get(0))), + null, false, relBuilder); call.transformTo(newProject); } } @@ -2260,16 +2235,16 @@ public void onMatch(RelOptRuleCall call) { super( operand(LogicalCorrelate.class, operand(RelNode.class, any()), - operand(LogicalAggregate.class, - operand(LogicalProject.class, + operand(Aggregate.class, + operand(Project.class, operand(RelNode.class, any()))))); } public void onMatch(RelOptRuleCall call) { final LogicalCorrelate correlate = call.rel(0); final RelNode left = call.rel(1); - final LogicalAggregate aggregate = call.rel(2); - final LogicalProject project = call.rel(3); + final Aggregate aggregate = call.rel(2); + final Project project = call.rel(3); RelNode right = call.rel(4); final RelOptCluster cluster = correlate.getCluster(); @@ -2281,8 +2256,8 @@ public void onMatch(RelOptRuleCall call) { // // CorrelateRel(left correlation, condition = true) // LeftInputRel - // LogicalAggregate (groupby (0) single_value()) - // LogicalProject-A (may reference coVar) + // Aggregate (groupby (0) single_value()) + // Project-A (may reference coVar) // RightInputRel final JoinRelType joinType = correlate.getJoinType().toJoinType(); @@ -2311,18 +2286,18 @@ public void onMatch(RelOptRuleCall call) { int nullIndicatorPos; - if ((right instanceof LogicalFilter) + if ((right instanceof Filter) && cm.mapRefRelToCorRef.containsKey(right)) { // rightInputRel has this shape: // - // LogicalFilter (references corvar) + // Filter (references corvar) // FilterInputRel // If rightInputRel is a filter and 
contains correlated // reference, make sure the correlated keys in the filter // condition forms a unique key of the RHS. - LogicalFilter filter = (LogicalFilter) right; + Filter filter = (Filter) right; right = filter.getInput(); assert right instanceof HepRelVertex; @@ -2342,7 +2317,7 @@ public void onMatch(RelOptRuleCall call) { // refs. These comparisons are AND'ed together. List tmpRightJoinKeys = Lists.newArrayList(); List correlatedJoinKeys = Lists.newArrayList(); - RelOptUtil.splitCorrelatedFilterCondition( + HiveRelOptUtil.splitCorrelatedFilterCondition( filter, tmpRightJoinKeys, correlatedJoinKeys, @@ -2386,8 +2361,8 @@ public void onMatch(RelOptRuleCall call) { // Change the plan to this structure. // Note that the aggregateRel is removed. // - // LogicalProject-A' (replace corvar to input ref from the LogicalJoin) - // LogicalJoin (replace corvar to input ref from LeftInputRel) + // Project-A' (replace corvar to input ref from the Join) + // Join (replace corvar to input ref from LeftInputRel) // LeftInputRel // RightInputRel(oreviously FilterInputRel) @@ -2410,11 +2385,11 @@ public void onMatch(RelOptRuleCall call) { // Change the plan to this structure. // - // LogicalProject-A' (replace corvar to input ref from LogicalJoin) - // LogicalJoin (left, condition = true) + // Project-A' (replace corvar to input ref from Join) + // Join (left, condition = true) // LeftInputRel - // LogicalAggregate(groupby(0), single_value(0), s_v(1)....) - // LogicalProject-B (everything from input plus literal true) + // Aggregate(groupby(0), single_value(0), s_v(1)....) 
+ // Project-B (everything from input plus literal true) // ProjInputRel // make the new projRel to provide a null indicator @@ -2426,7 +2401,7 @@ public void onMatch(RelOptRuleCall call) { // make the new aggRel right = - RelOptUtil.createSingleValueAggRel(cluster, right); + HiveRelOptUtil.createSingleValueAggRel(cluster, right, HiveRelFactories.HIVE_AGGREGATE_FACTORY); // The last field: // single_value(true) @@ -2439,9 +2414,8 @@ public void onMatch(RelOptRuleCall call) { } // make the new join rel - LogicalJoin join = - LogicalJoin.create(left, right, joinCond, - ImmutableSet.of(), joinType); + Join join = (Join) relBuilder.push(left).push(right) + .join(joinType, joinCond).build(); RelNode newProject = projectJoinOutputWithNullability(join, project, nullIndicatorPos); @@ -2459,18 +2433,18 @@ public void onMatch(RelOptRuleCall call) { super( operand(LogicalCorrelate.class, operand(RelNode.class, any()), - operand(LogicalProject.class, - operand(LogicalAggregate.class, null, Aggregate.IS_SIMPLE, - operand(LogicalProject.class, + operand(Project.class, + operand(Aggregate.class, null, Aggregate.IS_SIMPLE, + operand(Project.class, operand(RelNode.class, any())))))); } public void onMatch(RelOptRuleCall call) { final LogicalCorrelate correlate = call.rel(0); final RelNode left = call.rel(1); - final LogicalProject aggOutputProject = call.rel(2); - final LogicalAggregate aggregate = call.rel(3); - final LogicalProject aggInputProject = call.rel(4); + final Project aggOutputProject = call.rel(2); + final Aggregate aggregate = call.rel(3); + final Project aggInputProject = call.rel(4); RelNode right = call.rel(5); final RelOptCluster cluster = correlate.getCluster(); @@ -2482,9 +2456,9 @@ public void onMatch(RelOptRuleCall call) { // // CorrelateRel(left correlation, condition = true) // LeftInputRel - // LogicalProject-A (a RexNode) - // LogicalAggregate (groupby (0), agg0(), agg1()...) 
- // LogicalProject-B (references coVar) + // Project-A (a RexNode) + // Aggregate (groupby (0), agg0(), agg1()...) + // Project-B (references coVar) // rightInputRel // check aggOutputProject projects only one expression @@ -2523,13 +2497,13 @@ public void onMatch(RelOptRuleCall call) { } } - if ((right instanceof LogicalFilter) + if ((right instanceof Filter) && cm.mapRefRelToCorRef.containsKey(right)) { // rightInputRel has this shape: // - // LogicalFilter (references corvar) + // Filter (references corvar) // FilterInputRel - LogicalFilter filter = (LogicalFilter) right; + Filter filter = (Filter) right; right = filter.getInput(); assert right instanceof HepRelVertex; @@ -2550,7 +2524,7 @@ public void onMatch(RelOptRuleCall call) { // expressions. These comparisons are AND'ed together. List rightJoinKeys = Lists.newArrayList(); List tmpCorrelatedJoinKeys = Lists.newArrayList(); - RelOptUtil.splitCorrelatedFilterCondition( + HiveRelOptUtil.splitCorrelatedFilterCondition( filter, rightJoinKeys, tmpCorrelatedJoinKeys, @@ -2600,21 +2574,21 @@ public void onMatch(RelOptRuleCall call) { // // CorrelateRel(left correlation, condition = true) // LeftInputRel - // LogicalProject-A (a RexNode) - // LogicalAggregate (groupby(0), agg0(),agg1()...) - // LogicalProject-B (may reference coVar) - // LogicalFilter (references corVar) + // Project-A (a RexNode) + // Aggregate (groupby(0), agg0(),agg1()...) + // Project-B (may reference coVar) + // Filter (references corVar) // RightInputRel (no correlated reference) // // to this plan: // - // LogicalProject-A' (all gby keys + rewritten nullable ProjExpr) - // LogicalAggregate (groupby(all left input refs) + // Project-A' (all gby keys + rewritten nullable ProjExpr) + // Aggregate (groupby(all left input refs) // agg0(rewritten expression), // agg1()...) 
- // LogicalProject-B' (rewriten original projected exprs) - // LogicalJoin(replace corvar w/ input ref from LeftInputRel) + // Project-B' (rewriten original projected exprs) + // Join(replace corvar w/ input ref from LeftInputRel) // LeftInputRel // RightInputRel // @@ -2626,14 +2600,14 @@ public void onMatch(RelOptRuleCall call) { // projection list from the RHS for simplicity to avoid // searching for non-null fields. // - // LogicalProject-A' (all gby keys + rewritten nullable ProjExpr) - // LogicalAggregate (groupby(all left input refs), + // Project-A' (all gby keys + rewritten nullable ProjExpr) + // Aggregate (groupby(all left input refs), // count(nullIndicator), other aggs...) - // LogicalProject-B' (all left input refs plus + // Project-B' (all left input refs plus // the rewritten original projected exprs) - // LogicalJoin(replace corvar to input ref from LeftInputRel) + // Join(replace corvar to input ref from LeftInputRel) // LeftInputRel - // LogicalProject (everything from RightInputRel plus + // Project (everything from RightInputRel plus // the nullIndicator "true") // RightInputRel // @@ -2668,20 +2642,20 @@ public void onMatch(RelOptRuleCall call) { // // CorrelateRel(left correlation, condition = true) // LeftInputRel - // LogicalProject-A (a RexNode) - // LogicalAggregate (groupby(0), agg0(), agg1()...) - // LogicalProject-B (references coVar) + // Project-A (a RexNode) + // Aggregate (groupby(0), agg0(), agg1()...) + // Project-B (references coVar) // RightInputRel (no correlated reference) // // to this plan: // - // LogicalProject-A' (all gby keys + rewritten nullable ProjExpr) - // LogicalAggregate (groupby(all left input refs) + // Project-A' (all gby keys + rewritten nullable ProjExpr) + // Aggregate (groupby(all left input refs) // agg0(rewritten expression), // agg1()...) 
- // LogicalProject-B' (rewriten original projected exprs) - // LogicalJoin (LOJ cond = true) + // Project-B' (rewriten original projected exprs) + // Join (LOJ cond = true) // LeftInputRel // RightInputRel // @@ -2693,14 +2667,14 @@ public void onMatch(RelOptRuleCall call) { // projection list from the RHS for simplicity to avoid // searching for non-null fields. // - // LogicalProject-A' (all gby keys + rewritten nullable ProjExpr) - // LogicalAggregate (groupby(all left input refs), + // Project-A' (all gby keys + rewritten nullable ProjExpr) + // Aggregate (groupby(all left input refs), // count(nullIndicator), other aggs...) - // LogicalProject-B' (all left input refs plus + // Project-B' (all left input refs plus // the rewritten original projected exprs) - // LogicalJoin(replace corvar to input ref from LeftInputRel) + // Join(replace corvar to input ref from LeftInputRel) // LeftInputRel - // LogicalProject (everything from RightInputRel plus + // Project (everything from RightInputRel plus // the nullIndicator "true") // RightInputRel } else { @@ -2718,9 +2692,8 @@ public void onMatch(RelOptRuleCall call) { Pair.of(rexBuilder.makeLiteral(true), "nullIndicator"))); - LogicalJoin join = - LogicalJoin.create(left, right, joinCond, - ImmutableSet.of(), joinType); + Join join = (Join) relBuilder.push(left).push(right) + .join(joinType, joinCond).build(); // To the consumer of joinOutputProjRel, nullIndicator is located // at the end @@ -2754,11 +2727,8 @@ public void onMatch(RelOptRuleCall call) { joinOutputProjects.add( rexBuilder.makeInputRef(join, nullIndicatorPos)); - RelNode joinOutputProject = - RelOptUtil.createProject( - join, - joinOutputProjects, - null); + RelNode joinOutputProject = RelOptUtil.createProject( + join, joinOutputProjects, null, false, relBuilder); // nullIndicator is now at a different location in the output of // the join @@ -2793,12 +2763,9 @@ public void onMatch(RelOptRuleCall call) { ImmutableBitSet groupSet = 
ImmutableBitSet.range(groupCount); - LogicalAggregate newAggregate = - LogicalAggregate.create(joinOutputProject, - false, - groupSet, - null, - newAggCalls); + Aggregate newAggregate = + (Aggregate) relBuilder.push(joinOutputProject) + .aggregate(relBuilder.groupKey(groupSet, null), newAggCalls).build(); List newAggOutputProjectList = Lists.newArrayList(); for (int i : groupSet) { @@ -2815,11 +2782,8 @@ public void onMatch(RelOptRuleCall call) { true), newAggOutputProjects)); - RelNode newAggOutputProject = - RelOptUtil.createProject( - newAggregate, - newAggOutputProjectList, - null); + RelNode newAggOutputProject = RelOptUtil.createProject( + newAggregate, newAggOutputProjectList, null, false, relBuilder); call.transformTo(newAggOutputProject); @@ -2844,19 +2808,19 @@ public void onMatch(RelOptRuleCall call) { flavor ? operand(LogicalCorrelate.class, operand(RelNode.class, any()), - operand(LogicalProject.class, - operand(LogicalAggregate.class, any()))) + operand(Project.class, + operand(Aggregate.class, any()))) : operand(LogicalCorrelate.class, operand(RelNode.class, any()), - operand(LogicalAggregate.class, any()))); + operand(Aggregate.class, any()))); this.flavor = flavor; } public void onMatch(RelOptRuleCall call) { final LogicalCorrelate correlate = call.rel(0); final RelNode left = call.rel(1); - final LogicalProject aggOutputProject; - final LogicalAggregate aggregate; + final Project aggOutputProject; + final Aggregate aggregate; if (flavor) { aggOutputProject = call.rel(2); aggregate = call.rel(3); @@ -2870,11 +2834,8 @@ public void onMatch(RelOptRuleCall call) { for (int i = 0; i < fields.size(); i++) { projects.add(RexInputRef.of2(projects.size(), fields)); } - aggOutputProject = - (LogicalProject) RelOptUtil.createProject( - aggregate, - projects, - false); + aggOutputProject = (Project) RelOptUtil.createProject( + aggregate, Pair.left(projects), Pair.right(projects), false, relBuilder); } onMatch2(call, correlate, left, aggOutputProject, 
aggregate); } @@ -2883,8 +2844,8 @@ private void onMatch2( RelOptRuleCall call, LogicalCorrelate correlate, RelNode leftInput, - LogicalProject aggOutputProject, - LogicalAggregate aggregate) { + Project aggOutputProject, + Aggregate aggregate) { if (generatedCorRels.contains(correlate)) { // This correlator was generated by a previous invocation of // this rule. No further work to do. @@ -2899,8 +2860,8 @@ private void onMatch2( // // CorrelateRel(left correlation, condition = true) // LeftInputRel - // LogicalProject-A (a RexNode) - // LogicalAggregate (groupby (0), agg0(), agg1()...) + // Project-A (a RexNode) + // Aggregate (groupby (0), agg0(), agg1()...) // check aggOutputProj projects only one expression List aggOutputProjExprs = aggOutputProject.getProjects(); @@ -2940,7 +2901,7 @@ private void onMatch2( // replacing references to count() with case statement) // Correlator(left correlation, condition = true) // LeftInputRel - // LogicalAggregate (groupby (0), agg0(), agg1()...) + // Aggregate (groupby (0), agg0(), agg1()...) 
// LogicalCorrelate newCorrelate = LogicalCorrelate.create(leftInput, aggregate, @@ -3179,24 +3140,10 @@ public RelNode visit(HiveUnion rel) { mightRequireValueGen = true; return rel; } - public RelNode visit(LogicalUnion rel) { - mightRequireValueGen = true; - return rel; - } - public RelNode visit(LogicalIntersect rel) { - mightRequireValueGen = true; - return rel; - } - public RelNode visit(HiveIntersect rel) { mightRequireValueGen = true; return rel; } - - @Override public RelNode visit(LogicalJoin rel) { - mightRequireValueGen = true; - return rel; - } @Override public RelNode visit(HiveProject rel) { if(!(hasRexOver(((HiveProject)rel).getProjects()))) { mightRequireValueGen = false; @@ -3206,15 +3153,6 @@ public RelNode visit(HiveIntersect rel) { return rel; } } - @Override public RelNode visit(LogicalProject rel) { - if(!(hasRexOver(((LogicalProject)rel).getProjects()))) { - mightRequireValueGen = false; - return super.visit(rel); - } else { - mightRequireValueGen = true; - return rel; - } - } @Override public RelNode visit(HiveAggregate rel) { // if there are aggregate functions or grouping sets we will need // value generator @@ -3228,17 +3166,6 @@ public RelNode visit(HiveIntersect rel) { return rel; } } - @Override public RelNode visit(LogicalAggregate rel) { - if(rel.getAggCallList().isEmpty() && !rel.indicator) { - this.mightRequireValueGen = false; - return super.visit(rel); - } else { - // need to reset to true in case previous aggregate/project - // has set it to false - this.mightRequireValueGen = true; - return rel; - } - } @Override public RelNode visit(LogicalCorrelate rel) { // this means we are hitting nested subquery so don't // need to go further @@ -3279,16 +3206,6 @@ CorelMap build(RelNode rel) { mapFieldAccessToCorVar); } - @Override public RelNode visit(LogicalJoin join) { - try { - Stacks.push(stack, join); - join.getCondition().accept(rexVisitor(join)); - } finally { - Stacks.pop(stack, join); - } - return visitJoin(join); - } - public 
RelNode visit(HiveJoin join) { try { Stacks.push(stack, join); @@ -3329,6 +3246,7 @@ public RelNode visit(final HiveProject project) { } return super.visit(project); } + public RelNode visit(final HiveFilter filter) { try { Stacks.push(stack, filter); @@ -3338,27 +3256,6 @@ public RelNode visit(final HiveFilter filter) { } return super.visit(filter); } - @Override public RelNode visit(final LogicalFilter filter) { - try { - Stacks.push(stack, filter); - filter.getCondition().accept(rexVisitor(filter)); - } finally { - Stacks.pop(stack, filter); - } - return super.visit(filter); - } - - @Override public RelNode visit(LogicalProject project) { - try { - Stacks.push(stack, project); - for (RexNode node : project.getProjects()) { - node.accept(rexVisitor(project)); - } - } finally { - Stacks.pop(stack, project); - } - return super.visit(project); - } private RexVisitorImpl rexVisitor(final RelNode rel) { return new RexVisitorImpl(true) {