diff --git ql/src/java/org/apache/hadoop/hive/ql/lib/PreOrderWalker.java ql/src/java/org/apache/hadoop/hive/ql/lib/PreOrderWalker.java
index 9e4612d..f22694b 100644
--- ql/src/java/org/apache/hadoop/hive/ql/lib/PreOrderWalker.java
+++ ql/src/java/org/apache/hadoop/hive/ql/lib/PreOrderWalker.java
@@ -18,6 +18,8 @@
 
 package org.apache.hadoop.hive.ql.lib;
 
+import org.apache.hadoop.hive.ql.exec.ConditionalTask;
+import org.apache.hadoop.hive.ql.exec.Task;
 import org.apache.hadoop.hive.ql.parse.SemanticException;
 
 /**
@@ -58,6 +60,12 @@ public void walk(Node nd) throws SemanticException {
       for (Node n : nd.getChildren()) {
         walk(n);
       }
+    } else if (nd instanceof ConditionalTask) {
+      for (Task n : ((ConditionalTask) nd).getListTasks()) {
+        if (n.getParentTasks() == null || n.getParentTasks().isEmpty()) {
+          walk(n);
+        }
+      }
     }
 
     opStack.pop();
diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SparkSkewJoinProcFactory.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SparkSkewJoinProcFactory.java
index 4c2f42e..04007e8 100644
--- ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SparkSkewJoinProcFactory.java
+++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SparkSkewJoinProcFactory.java
@@ -50,6 +50,7 @@
 import org.apache.hadoop.hive.ql.plan.TableDesc;
 
 import java.io.Serializable;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
 import java.util.Stack;
@@ -58,6 +59,9 @@
  * Spark-version of SkewJoinProcFactory.
  */
 public class SparkSkewJoinProcFactory {
+  // let's remember the join operators we have processed
+  private static Set visitedJoinOp = new HashSet();
+
   private SparkSkewJoinProcFactory() {
     // prevent instantiation
   }
@@ -84,10 +88,11 @@ public Object process(Node nd, Stack stack, NodeProcessorCtx procCtx,
       if (!op.getConf().isFixedAsSorted() && currentTsk instanceof SparkTask && reduceWork != null
           && ((SparkTask) currentTsk).getWork().contains(reduceWork)
           && GenSparkSkewJoinProcessor.supportRuntimeSkewJoin(
-          op, currentTsk, parseContext.getConf())) {
+          op, currentTsk, parseContext.getConf()) && !visitedJoinOp.contains(op)) {
         // first we try to split the task
         splitTask((SparkTask) currentTsk, reduceWork, parseContext);
         GenSparkSkewJoinProcessor.processSkewJoin(op, currentTsk, reduceWork, parseContext);
+        visitedJoinOp.add(op);
       }
       return null;
     }
@@ -112,8 +117,7 @@ private static void splitTask(SparkTask currentTask, ReduceWork reduceWork,
       SparkWork newWork =
          new SparkWork(parseContext.getConf().getVar(HiveConf.ConfVars.HIVEQUERYID));
       newWork.add(childWork);
-      copyWorkGraph(currentWork, newWork, childWork, true);
-      copyWorkGraph(currentWork, newWork, childWork, false);
+      copyWorkGraph(currentWork, newWork, childWork);
       // remove them from current spark work
       for (BaseWork baseWork : newWork.getAllWorkUnsorted()) {
         currentWork.remove(baseWork);
@@ -196,21 +200,21 @@ private static boolean canSplit(SparkWork sparkWork) {
   /**
    * Copy a sub-graph from originWork to newWork.
    */
-  private static void copyWorkGraph(SparkWork originWork, SparkWork newWork,
-      BaseWork baseWork, boolean upWards) {
-    if (upWards) {
-      for (BaseWork parent : originWork.getParents(baseWork)) {
-        newWork.add(parent);
-        SparkEdgeProperty edgeProperty = originWork.getEdgeProperty(parent, baseWork);
-        newWork.connect(parent, baseWork, edgeProperty);
-        copyWorkGraph(originWork, newWork, parent, true);
-      }
-    } else {
-      for (BaseWork child : originWork.getChildren(baseWork)) {
+  private static void copyWorkGraph(SparkWork originWork, SparkWork newWork, BaseWork baseWork) {
+    for (BaseWork child : originWork.getChildren(baseWork)) {
+      if (!newWork.contains(child)) {
         newWork.add(child);
         SparkEdgeProperty edgeProperty = originWork.getEdgeProperty(baseWork, child);
         newWork.connect(baseWork, child, edgeProperty);
-        copyWorkGraph(originWork, newWork, child, false);
+        copyWorkGraph(originWork, newWork, child);
+      }
+    }
+    for (BaseWork parent : originWork.getParents(baseWork)) {
+      if (!newWork.contains(parent)) {
+        newWork.add(parent);
+        SparkEdgeProperty edgeProperty = originWork.getEdgeProperty(parent, baseWork);
+        newWork.connect(parent, baseWork, edgeProperty);
+        copyWorkGraph(originWork, newWork, parent);
       }
     }
   }
diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SparkSkewJoinResolver.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SparkSkewJoinResolver.java
index 984380d..f2d406f 100644
--- ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SparkSkewJoinResolver.java
+++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SparkSkewJoinResolver.java
@@ -37,6 +37,7 @@
 import org.apache.hadoop.hive.ql.lib.GraphWalker;
 import org.apache.hadoop.hive.ql.lib.Node;
 import org.apache.hadoop.hive.ql.lib.NodeProcessor;
+import org.apache.hadoop.hive.ql.lib.PreOrderWalker;
 import org.apache.hadoop.hive.ql.lib.Rule;
 import org.apache.hadoop.hive.ql.lib.RuleRegExp;
 import org.apache.hadoop.hive.ql.optimizer.physical.PhysicalContext;
@@ -54,7 +55,8 @@
   @Override
   public PhysicalContext resolve(PhysicalContext pctx) throws SemanticException {
     Dispatcher disp = new SparkSkewJoinTaskDispatcher(pctx);
-    GraphWalker ogw = new DefaultGraphWalker(disp);
+    // since we may split current task, use a pre-order walker
+    GraphWalker ogw = new PreOrderWalker(disp);
     ArrayList topNodes = new ArrayList();
     topNodes.addAll(pctx.getRootTasks());
     ogw.startWalking(topNodes, null);
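
For context on the reworked copyWorkGraph: instead of walking one direction per call, it now copies the whole connected sub-graph around the given work node, visiting children first and then parents, and it relies on newWork.contains(...) as the visited check so the mutual recursion terminates even though parent and child loops can lead back to nodes already copied. Below is a minimal, self-contained sketch of the same traversal pattern on a toy directed graph; ToyGraph, copyConnected, and the node names are hypothetical stand-ins for illustration only, not Hive's SparkWork/BaseWork/SparkEdgeProperty API.

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Hypothetical stand-in for SparkWork: a directed graph with parent/child lookups.
class ToyGraph<T> {
  private final Set<T> nodes = new LinkedHashSet<T>();
  private final Map<T, List<T>> children = new HashMap<T, List<T>>();
  private final Map<T, List<T>> parents = new HashMap<T, List<T>>();

  void add(T node) {
    nodes.add(node);
    children.putIfAbsent(node, new ArrayList<T>());
    parents.putIfAbsent(node, new ArrayList<T>());
  }

  boolean contains(T node) {
    return nodes.contains(node);
  }

  void connect(T parent, T child) {
    add(parent);
    add(child);
    children.get(parent).add(child);
    parents.get(child).add(parent);
  }

  List<T> getChildren(T node) {
    return children.getOrDefault(node, new ArrayList<T>());
  }

  List<T> getParents(T node) {
    return parents.getOrDefault(node, new ArrayList<T>());
  }

  Set<T> getAllNodes() {
    return nodes;
  }
}

public class SubGraphCopySketch {

  // Same shape as the patched copyWorkGraph: walk children, then parents,
  // and skip anything the target graph already contains so recursion terminates.
  static <T> void copyConnected(ToyGraph<T> origin, ToyGraph<T> target, T start) {
    for (T child : origin.getChildren(start)) {
      if (!target.contains(child)) {
        target.add(child);                 // mirrors newWork.add(child)
        target.connect(start, child);      // mirrors newWork.connect(baseWork, child, ...)
        copyConnected(origin, target, child);
      }
    }
    for (T parent : origin.getParents(start)) {
      if (!target.contains(parent)) {
        target.add(parent);                // mirrors newWork.add(parent)
        target.connect(parent, start);     // mirrors newWork.connect(parent, baseWork, ...)
        copyConnected(origin, target, parent);
      }
    }
  }

  public static void main(String[] args) {
    // a -> b, a -> c, c -> d : a small DAG
    ToyGraph<String> origin = new ToyGraph<String>();
    origin.connect("a", "b");
    origin.connect("a", "c");
    origin.connect("c", "d");

    ToyGraph<String> target = new ToyGraph<String>();
    target.add("c");                       // seed with the "childWork" equivalent
    copyConnected(origin, target, "c");

    // Prints all four nodes: the whole connected component around "c" is copied once.
    System.out.println(target.getAllNodes());
  }
}

Note that, as in the patched method, an edge is only connected at the moment its far endpoint is first added to the target graph; an edge between two nodes that are both already present is not re-added.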