diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveFilterSetOpTransposeRule.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveFilterSetOpTransposeRule.java index 3ee29e0482..c617f9b415 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveFilterSetOpTransposeRule.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveFilterSetOpTransposeRule.java @@ -17,14 +17,28 @@ */ package org.apache.hadoop.hive.ql.optimizer.calcite.rules; +import java.util.ArrayList; +import java.util.List; + +import org.apache.calcite.plan.RelOptPredicateList; import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.Filter; +import org.apache.calcite.rel.core.SetOp; +import org.apache.calcite.rel.core.Union; +import org.apache.calcite.rel.metadata.RelMetadataQuery; import org.apache.calcite.rel.rules.FilterSetOpTransposeRule; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexSimplify; +import org.apache.calcite.tools.RelBuilder; import org.apache.calcite.tools.RelBuilderFactory; import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil; import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories; +import com.google.common.collect.ImmutableList; public class HiveFilterSetOpTransposeRule extends FilterSetOpTransposeRule { @@ -48,4 +62,61 @@ public boolean matches(RelOptRuleCall call) { return super.matches(call); } + + + //~ Methods ---------------------------------------------------------------- + + // implement RelOptRule + public void onMatch(RelOptRuleCall call) { + Filter filterRel = call.rel(0); + SetOp setOp = call.rel(1); + + RexNode condition = filterRel.getCondition(); + + // create filters on top of each setop child, modifying the filter + // condition to reference each setop child + RexBuilder rexBuilder = filterRel.getCluster().getRexBuilder(); + final RelBuilder relBuilder = call.builder(); + List origFields = setOp.getRowType().getFieldList(); + int[] adjustments = new int[origFields.size()]; + final List newSetOpInputs = new ArrayList<>(); + RelNode lastInput = null; + for (int index = 0; index < setOp.getInputs().size(); index++) { + RelNode input = setOp.getInput(index); + RexNode newCondition = condition.accept(new RelOptUtil.RexInputConverter(rexBuilder, + origFields, input.getRowType().getFieldList(), adjustments)); + if (setOp instanceof Union && setOp.all) { + final RelMetadataQuery mq = RelMetadataQuery.instance(); + final RelOptPredicateList predicates = mq.getPulledUpPredicates(input); + if (predicates != null) { + ImmutableList.Builder listBuilder = ImmutableList.builder(); + listBuilder.addAll(predicates.pulledUpPredicates); + listBuilder.add(newCondition); + RexSimplify simplifierUnknownAsFalse = + new RexSimplify(rexBuilder, true, filterRel.getCluster().getPlanner().getExecutor()); + final RexNode x = simplifierUnknownAsFalse.simplifyAnds(listBuilder.build()); + if (x.isAlwaysFalse()) { + // this is the last branch, and it is always false + if (index == setOp.getInputs().size() - 1) { + lastInput = relBuilder.push(input).filter(newCondition).build(); + } + // remove this branch + continue; + } + } + } + newSetOpInputs.add(relBuilder.push(input).filter(newCondition).build()); + } + if (newSetOpInputs.size() > 1) { + // create a new setop whose children are the filters created above + SetOp newSetOp = setOp.copy(setOp.getTraitSet(), newSetOpInputs); + call.transformTo(newSetOp); + } else if (newSetOpInputs.size() == 1) { + call.transformTo(newSetOpInputs.get(0)); + } else { + // we have to keep at least a branch before we support empty values() in + // hive + call.transformTo(lastInput); + } + } }