diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/SparkRemoveDynamicPruningBySize.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/SparkRemoveDynamicPruningBySize.java
index a6bf3af..c41a0c8 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/SparkRemoveDynamicPruningBySize.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/SparkRemoveDynamicPruningBySize.java
@@ -23,12 +23,12 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.hadoop.hive.conf.HiveConf.ConfVars;
-import org.apache.hadoop.hive.ql.exec.Operator;
 import org.apache.hadoop.hive.ql.lib.Node;
 import org.apache.hadoop.hive.ql.lib.NodeProcessor;
 import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
 import org.apache.hadoop.hive.ql.optimizer.spark.SparkPartitionPruningSinkDesc;
 import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.parse.spark.GenSparkUtils;
 import org.apache.hadoop.hive.ql.parse.spark.OptimizeSparkProcContext;
 import org.apache.hadoop.hive.ql.parse.spark.SparkPartitionPruningSinkOperator;
@@ -54,15 +54,7 @@ public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procContext,
 
     if (desc.getStatistics().getDataSize() > context.getConf()
         .getLongVar(ConfVars.SPARK_DYNAMIC_PARTITION_PRUNING_MAX_DATA_SIZE)) {
-      Operator<?> child = op;
-      Operator<?> curr = op;
-
-      while (curr.getChildOperators().size() <= 1) {
-        child = curr;
-        curr = curr.getParentOperators().get(0);
-      }
-
-      curr.removeChild(child);
+      GenSparkUtils.removeBranch(op);
       // at this point we've found the fork in the op pipeline that has the pruning as a child plan.
       LOG.info("Disabling dynamic pruning for: " + desc.getTableScan().getName()
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkUtils.java b/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkUtils.java
index 8a85574..7b2b3c0 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkUtils.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkUtils.java
@@ -596,4 +596,23 @@ private void findRoots(Operator<?> op, List<Operator<?>> ops) {
       findRoots(p, ops);
     }
   }
+
+  /**
+   * Remove the branch that contains the specified operator. Do nothing if there's no branching,
+   * i.e. all the upstream operators have only one child.
+   */
+  public static void removeBranch(Operator<?> op) {
+    Operator<?> child = op;
+    Operator<?> curr = op;
+
+    while (curr.getChildOperators().size() <= 1) {
+      child = curr;
+      if (curr.getParentOperators() == null || curr.getParentOperators().isEmpty()) {
+        return;
+      }
+      curr = curr.getParentOperators().get(0);
+    }
+
+    curr.removeChild(child);
+  }
 }
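A note on the new GenSparkUtils.removeBranch helper above: it walks up from the pruning sink through operators that have at most one child until it reaches the first operator with more than one child (the fork), and detaches the branch there. The added null/empty-parent check turns the previously inlined walk, which assumed a fork always exists, into a safe no-op when the traversal reaches a root first. Below is a minimal, self-contained sketch of that traversal using a toy Node class rather than Hive's Operator API; the class and operator names are illustrative only.

import java.util.ArrayList;
import java.util.List;

public class RemoveBranchSketch {

  /** Toy stand-in for a plan operator: just parent/child links and a name. */
  static class Node {
    final String name;
    final List<Node> parents = new ArrayList<>();
    final List<Node> children = new ArrayList<>();

    Node(String name) {
      this.name = name;
    }

    Node addChild(Node child) {
      children.add(child);
      child.parents.add(this);
      return child;
    }

    @Override
    public String toString() {
      return name;
    }
  }

  /** Mirrors the patched logic: climb single-child ancestors, stop quietly at a root. */
  static void removeBranch(Node op) {
    Node child = op;
    Node curr = op;
    while (curr.children.size() <= 1) {
      child = curr;
      if (curr.parents.isEmpty()) {
        return;                      // reached a root without finding a fork: nothing to cut
      }
      curr = curr.parents.get(0);
    }
    curr.children.remove(child);     // detach the whole branch at the fork
  }

  public static void main(String[] args) {
    // TS -> FIL -> SEL, where SEL forks into GBY (main plan) and SEL2 -> SINK (pruning branch)
    Node ts = new Node("TS");
    Node fil = ts.addChild(new Node("FIL"));
    Node sel = fil.addChild(new Node("SEL"));
    sel.addChild(new Node("GBY"));
    Node sel2 = sel.addChild(new Node("SEL2"));
    Node sink = sel2.addChild(new Node("SINK"));

    removeBranch(sink);
    System.out.println(sel.children); // prints [GBY] -- only the pruning branch was removed
  }
}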
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkCompiler.java b/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkCompiler.java
index baf77c7..71528e8 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkCompiler.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkCompiler.java
@@ -20,11 +20,13 @@
 import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.Stack;
+import java.util.concurrent.atomic.AtomicInteger;
 
 import org.apache.hadoop.hive.conf.HiveConf;
 import org.apache.hadoop.hive.ql.Context;
@@ -73,6 +75,7 @@
 import org.apache.hadoop.hive.ql.optimizer.spark.SetSparkReducerParallelism;
 import org.apache.hadoop.hive.ql.optimizer.spark.SparkJoinHintOptimizer;
 import org.apache.hadoop.hive.ql.optimizer.spark.SparkJoinOptimizer;
+import org.apache.hadoop.hive.ql.optimizer.spark.SparkPartitionPruningSinkDesc;
 import org.apache.hadoop.hive.ql.optimizer.spark.SparkReduceSinkMapJoinProc;
 import org.apache.hadoop.hive.ql.optimizer.spark.SparkSkewJoinResolver;
 import org.apache.hadoop.hive.ql.optimizer.spark.SplitSparkWorkResolver;
@@ -116,9 +119,118 @@ protected void optimizeOperatorPlan(ParseContext pCtx, Set<ReadEntity> inputs,
     // Run Join releated optimizations
     runJoinOptimizations(procCtx);
 
+    // Remove cyclic dependencies for DPP
+    runCycleAnalysisForPartitionPruning(procCtx);
+
     PERF_LOGGER.PerfLogEnd(CLASS_NAME, PerfLogger.SPARK_OPTIMIZE_OPERATOR_TREE);
   }
 
+  private void runCycleAnalysisForPartitionPruning(OptimizeSparkProcContext procCtx) {
+    if (!conf.getBoolVar(HiveConf.ConfVars.SPARK_DYNAMIC_PARTITION_PRUNING)) {
+      return;
+    }
+
+    boolean cycleFree = false;
+    while (!cycleFree) {
+      cycleFree = true;
+      Set<Set<Operator<?>>> components = getComponents(procCtx);
+      for (Set<Operator<?>> component : components) {
+        if (LOG.isDebugEnabled()) {
+          LOG.debug("Component: ");
+          for (Operator<?> co : component) {
+            LOG.debug("Operator: " + co.getName() + ", " + co.getIdentifier());
+          }
+        }
+        if (component.size() != 1) {
+          LOG.info("Found cycle in operator plan...");
+          cycleFree = false;
+          removeDPPOperator(component, procCtx);
+          break;
+        }
+      }
+      LOG.info("Cycle free: " + cycleFree);
+    }
+  }
+
+  private void removeDPPOperator(Set<Operator<?>> component, OptimizeSparkProcContext context) {
+    SparkPartitionPruningSinkOperator toRemove = null;
+    for (Operator<?> o : component) {
+      if (o instanceof SparkPartitionPruningSinkOperator) {
+        // we want to remove the DPP with bigger data size
+        if (toRemove == null
+            || o.getConf().getStatistics().getDataSize() > toRemove.getConf().getStatistics()
+                .getDataSize()) {
+          toRemove = (SparkPartitionPruningSinkOperator) o;
+        }
+      }
+    }
+
+    if (toRemove == null) {
+      return;
+    }
+
+    GenSparkUtils.removeBranch(toRemove);
+    // at this point we've found the fork in the op pipeline that has the pruning as a child plan.
+    LOG.info("Disabling dynamic pruning for: "
+        + toRemove.getConf().getTableScan().toString() + ". Needed to break cyclic dependency");
+  }
+
+  // Tarjan's algo
+  private Set<Set<Operator<?>>> getComponents(OptimizeSparkProcContext procCtx) {
+    AtomicInteger index = new AtomicInteger();
+    Map<Operator<?>, Integer> indexes = new HashMap<Operator<?>, Integer>();
+    Map<Operator<?>, Integer> lowLinks = new HashMap<Operator<?>, Integer>();
+    Stack<Operator<?>> nodes = new Stack<Operator<?>>();
+    Set<Set<Operator<?>>> components = new HashSet<Set<Operator<?>>>();
+
+    for (Operator<?> o : procCtx.getParseContext().getTopOps().values()) {
+      if (!indexes.containsKey(o)) {
+        connect(o, index, nodes, indexes, lowLinks, components);
+      }
+    }
+    return components;
+  }
+
+  private void connect(Operator<?> o, AtomicInteger index, Stack<Operator<?>> nodes,
+      Map<Operator<?>, Integer> indexes, Map<Operator<?>, Integer> lowLinks,
+      Set<Set<Operator<?>>> components) {
+
+    indexes.put(o, index.get());
+    lowLinks.put(o, index.get());
+    index.incrementAndGet();
+    nodes.push(o);
+
+    List<Operator<? extends OperatorDesc>> children;
+    if (o instanceof SparkPartitionPruningSinkOperator) {
+      children = new ArrayList<>();
+      children.addAll(o.getChildOperators());
+      TableScanOperator ts = ((SparkPartitionPruningSinkDesc) o.getConf()).getTableScan();
+      LOG.debug("Adding special edge: " + o.getName() + " --> " + ts.toString());
+      children.add(ts);
+    } else {
+      children = o.getChildOperators();
+    }
+
+    for (Operator<?> child : children) {
+      if (!indexes.containsKey(child)) {
+        connect(child, index, nodes, indexes, lowLinks, components);
+        lowLinks.put(o, Math.min(lowLinks.get(o), lowLinks.get(child)));
+      } else if (nodes.contains(child)) {
+        lowLinks.put(o, Math.min(lowLinks.get(o), indexes.get(child)));
+      }
+    }
+
+    if (lowLinks.get(o).equals(indexes.get(o))) {
+      Set<Operator<?>> component = new HashSet<Operator<?>>();
+      components.add(component);
+      Operator<?> current;
+      do {
+        current = nodes.pop();
+        component.add(current);
+      } while (current != o);
+    }
+  }
+
   private void runStatsAnnotation(OptimizeSparkProcContext procCtx) throws SemanticException {
     new AnnotateWithStatistics().transform(procCtx.getParseContext());
     new AnnotateWithOpTraits().transform(procCtx.getParseContext());
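For readers unfamiliar with the approach: runCycleAnalysisForPartitionPruning repeatedly computes the strongly connected components of the operator graph, with an extra "special edge" from each SparkPartitionPruningSinkOperator to the TableScanOperator it prunes, so a scan that (transitively) feeds its own pruner shows up as a component of size greater than one. Whenever such a component is found, the pruning sink with the largest data size is removed and the analysis is re-run, since each removal can change the component structure. The getComponents/connect pair is Tarjan's algorithm. Below is a compact, self-contained sketch of the same index/low-link bookkeeping over a toy integer graph; the Graph class, scc method, and node ids are illustrative and not part of the patch.

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Stack;

public class TarjanSketch {

  /** Toy directed graph over int node ids (adjacency lists). */
  static class Graph {
    final Map<Integer, List<Integer>> adj = new HashMap<>();

    void addEdge(int from, int to) {
      adj.computeIfAbsent(from, k -> new ArrayList<>()).add(to);
      adj.computeIfAbsent(to, k -> new ArrayList<>());
    }
  }

  private final Graph graph;
  private int index = 0;                                  // plays the role of the AtomicInteger
  private final Map<Integer, Integer> indexes = new HashMap<>();
  private final Map<Integer, Integer> lowLinks = new HashMap<>();
  private final Stack<Integer> nodes = new Stack<>();
  private final Set<Set<Integer>> components = new HashSet<>();

  TarjanSketch(Graph graph) {
    this.graph = graph;
  }

  /** Same overall shape as getComponents(): run connect() from every unvisited node. */
  Set<Set<Integer>> scc() {
    for (Integer v : graph.adj.keySet()) {
      if (!indexes.containsKey(v)) {
        connect(v);
      }
    }
    return components;
  }

  private void connect(int v) {
    indexes.put(v, index);
    lowLinks.put(v, index);
    index++;
    nodes.push(v);

    for (int child : graph.adj.get(v)) {
      if (!indexes.containsKey(child)) {
        connect(child);                                   // tree edge: inherit the child's low-link
        lowLinks.put(v, Math.min(lowLinks.get(v), lowLinks.get(child)));
      } else if (nodes.contains(child)) {
        lowLinks.put(v, Math.min(lowLinks.get(v), indexes.get(child)));  // back edge into the stack
      }
    }

    // v is the root of a component: pop everything above it (inclusive) off the stack
    if (lowLinks.get(v).equals(indexes.get(v))) {
      Set<Integer> component = new HashSet<>();
      components.add(component);
      int current;
      do {
        current = nodes.pop();
        component.add(current);
      } while (current != v);
    }
  }

  public static void main(String[] args) {
    Graph g = new Graph();
    g.addEdge(1, 2);   // think: table scan feeding, eventually, a pruning sink
    g.addEdge(2, 1);   // the "special edge" from the sink back to its table scan closes a cycle
    g.addEdge(2, 3);
    // Prints {1, 2} and {3} (order may vary): the size-2 component is the cycle to break.
    System.out.println(new TarjanSketch(g).scc());
  }
}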