diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/SparkMapJoinResolver.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/SparkMapJoinResolver.java
index a8b7ac6..21c82b6 100644
--- ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/SparkMapJoinResolver.java
+++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/SparkMapJoinResolver.java
@@ -111,11 +111,8 @@ private void moveWork(SparkWork sparkWork, BaseWork work, SparkWork targetWork)
       SparkWork parentWork = new SparkWork(physicalContext.conf.getVar(HiveConf.ConfVars.HIVEQUERYID));
 
-      // Update dependency graph
-      if (!dependencyGraph.containsKey(targetWork)) {
-        dependencyGraph.put(targetWork, new ArrayList<SparkWork>());
-      }
       dependencyGraph.get(targetWork).add(parentWork);
+      dependencyGraph.put(parentWork, new ArrayList<SparkWork>());
 
       // this work is now moved to the parentWork, thus we should
       // update this information in sparkWorkMap
@@ -132,14 +129,15 @@ private void moveWork(SparkWork sparkWork, BaseWork work, SparkWork targetWork)
   // Create a new SparkTask for the specified SparkWork, recursively compute
   // all the parent SparkTasks that this new task is depend on, if they don't already exists.
-  private SparkTask createSparkTask(Task<? extends Serializable> originalTask,
+  private SparkTask createSparkTask(SparkTask originalTask,
       SparkWork sparkWork, Map<SparkWork, SparkTask> createdTaskMap) {
     if (createdTaskMap.containsKey(sparkWork)) {
       return createdTaskMap.get(sparkWork);
     }
-    SparkTask resultTask = (SparkTask) TaskFactory.get(sparkWork, physicalContext.conf);
-    if (dependencyGraph.get(sparkWork) != null) {
+    SparkTask resultTask = originalTask.getWork() == sparkWork ?
+        originalTask : (SparkTask) TaskFactory.get(sparkWork, physicalContext.conf);
+    if (!dependencyGraph.get(sparkWork).isEmpty()) {
       for (SparkWork parentWork : dependencyGraph.get(sparkWork)) {
         SparkTask parentTask = createSparkTask(originalTask, parentWork, createdTaskMap);
         parentTask.addDependentTask(resultTask);
@@ -155,6 +153,8 @@ private SparkTask createSparkTask(Task<? extends Serializable> originalTask,
         physicalContext.removeFromRootTask(originalTask);
       }
     }
+
+    createdTaskMap.put(sparkWork, resultTask);
     return resultTask;
   }
 
@@ -164,6 +164,8 @@ public Object dispatch(Node nd, Stack<Node> stack, Object... nos)
     Task<? extends Serializable> currentTask = (Task<? extends Serializable>) nd;
     if (currentTask instanceof SparkTask) {
       SparkWork sparkWork = ((SparkTask) currentTask).getWork();
+
+      dependencyGraph.put(sparkWork, new ArrayList<SparkWork>());
       Set<BaseWork> leaves = sparkWork.getLeaves();
       for (BaseWork leaf : leaves) {
         moveWork(sparkWork, leaf, sparkWork);
@@ -181,7 +183,7 @@ public Object dispatch(Node nd, Stack<Node> stack, Object... nos)
 
     // Now create SparkTasks from the SparkWorks, also set up dependency
     for (SparkWork work : dependencyGraph.keySet()) {
-      createSparkTask(currentTask, work, createdTaskMap);
+      createSparkTask((SparkTask) currentTask, work, createdTaskMap);
     }
   }
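
For context on the change: dispatch now seeds dependencyGraph with an empty parent list for every SparkWork before moveWork adds edges (which is why the containsKey guard in moveWork can be dropped), and createSparkTask now registers each result in createdTaskMap, so a SparkWork reachable from several children produces exactly one SparkTask. Below is a minimal, self-contained sketch of that memoized recursion; DemoResolver and TaskNode are hypothetical stand-ins with String-keyed works, not the Hive classes.

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class DemoResolver {

  // Parent lists, mirroring dependencyGraph in the patch: every work is
  // seeded with an empty list up front, so get() never returns null.
  private final Map<String, List<String>> dependencyGraph = new HashMap<>();

  // Stand-in for SparkTask: one node per work, linked to its dependents.
  static final class TaskNode {
    final String work;
    final List<TaskNode> dependents = new ArrayList<>();
    TaskNode(String work) { this.work = work; }
    void addDependentTask(TaskNode t) { dependents.add(t); }
  }

  // Memoized recursion analogous to createSparkTask: create (or reuse)
  // the task for each parent work first, then link it to the result.
  TaskNode createTask(String work, Map<String, TaskNode> createdTaskMap) {
    if (createdTaskMap.containsKey(work)) {
      return createdTaskMap.get(work);
    }
    TaskNode resultTask = new TaskNode(work);
    for (String parentWork : dependencyGraph.get(work)) {
      TaskNode parentTask = createTask(parentWork, createdTaskMap);
      parentTask.addDependentTask(resultTask);
    }
    // Register before returning, the step the patch adds; without it a
    // work reachable from several children would be instantiated again
    // on every visit.
    createdTaskMap.put(work, resultTask);
    return resultTask;
  }

  public static void main(String[] args) {
    DemoResolver r = new DemoResolver();
    // Seed every work with an empty parent list, as dispatch now does.
    for (String w : new String[] {"root", "small1", "small2"}) {
      r.dependencyGraph.put(w, new ArrayList<>());
    }
    // moveWork-style edges: the root work depends on two parent works.
    r.dependencyGraph.get("root").add("small1");
    r.dependencyGraph.get("root").add("small2");

    Map<String, TaskNode> createdTaskMap = new HashMap<>();
    for (String w : r.dependencyGraph.keySet()) {
      r.createTask(w, createdTaskMap);
    }
    // Prints 3: each work yields exactly one task despite repeated visits.
    System.out.println("tasks created: " + createdTaskMap.size());
  }
}

Registering the result in createdTaskMap before returning is what makes the recursion idempotent; the memoization map was previously checked but never populated, so shared parents were created once per path.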