diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/NodeUtils.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/NodeUtils.java index 5aae311..39b2f22 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/exec/NodeUtils.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/NodeUtils.java @@ -20,8 +20,11 @@ import org.apache.hadoop.hive.ql.lib.Node; +import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.HashSet; +import java.util.List; import java.util.Set; /** @@ -31,47 +34,60 @@ public static void iterateTask(Collection> tasks, Class clazz, Function function) { Set visited = new HashSet(); - for (Task task : tasks) { - iterateTask(task, clazz, function, visited); + while (!tasks.isEmpty()) { + tasks = iterateTask(tasks, clazz, function, visited); } return; } - private static void iterateTask(Task task, Class clazz, Function function, Set visited) { - if (!visited.add(task)) { - return; - } - if (clazz.isInstance(task)) { - function.apply(clazz.cast(task)); - } - // this is for ConditionalTask - if (task.getDependentTasks() != null) { - for (Task dependent : task.getDependentTasks()) { - iterateTask(dependent, clazz, function, visited); + private static Collection> iterateTask(Collection> tasks, + Class clazz, + Function function, + Set visited) { + Collection> childTasks = new ArrayList<>(); + for (Task task : tasks) { + if (!visited.add(task)) { + continue; + } + if (clazz.isInstance(task)) { + function.apply(clazz.cast(task)); + } + // this is for ConditionalTask + if (task.getDependentTasks() != null) { + childTasks.addAll(task.getDependentTasks()); } } + return childTasks; } public static void iterate(Collection nodes, Class clazz, Function function) { Set visited = new HashSet(); - for (Node task : nodes) { - iterate(task, clazz, function, visited); + List> listNodes = Collections.singletonList(nodes); + while (!listNodes.isEmpty()) { + listNodes = iterate(listNodes, clazz, function, visited); } return; } - private static void iterate(Node node, Class clazz, Function function, Set visited) { - if (!visited.add(node)) { - return; - } - if (clazz.isInstance(node)) { - function.apply(clazz.cast(node)); - } - if (node.getChildren() != null) { - for (Node child : node.getChildren()) { - iterate(child, clazz, function, visited); + private static List> iterate(List> listNodes, + Class clazz, + Function function, + Set visited) { + List> childListNodes = new ArrayList<>(); + for (Collection nodes : listNodes) { + for (Node node : nodes) { + if (!visited.add(node)) { + continue; + } + if (clazz.isInstance(node)) { + function.apply(clazz.cast(node)); + } + if (node.getChildren() != null) { + childListNodes.add(node.getChildren()); + } } } + return childListNodes; } public static interface Function { diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java index af82671..cc2e119 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java @@ -2383,15 +2383,20 @@ public static boolean isEmptyPath(Configuration job, Path dirPath) throws IOExce public static List getTezTasks(List> tasks) { List tezTasks = new ArrayList(); - Set> visited = new HashSet>(); if (tasks != null) { - getTezTasks(tasks, tezTasks, visited); + Set> visited = new HashSet>(); + while (!tasks.isEmpty()) { + tasks = getTezTasks(tasks, tezTasks, visited); + } } return tezTasks; } - private static void getTezTasks(List> tasks, List tezTasks, - Set> visited) { + private static List> getTezTasks( + List> tasks, + List tezTasks, + Set> visited) { + List> childTasks = new ArrayList<>(); for (Task task : tasks) { if (visited.contains(task)) { continue; @@ -2401,23 +2406,29 @@ private static void getTezTasks(List> tasks, List getSparkTasks(List> tasks) { List sparkTasks = new ArrayList(); - Set> visited = new HashSet>(); if (tasks != null) { - getSparkTasks(tasks, sparkTasks, visited); + Set> visited = new HashSet>(); + while (!tasks.isEmpty()) { + tasks = getSparkTasks(tasks, sparkTasks, visited); + } } return sparkTasks; } - private static void getSparkTasks(List> tasks, - List sparkTasks, Set> visited) { + private static List> getSparkTasks( + List> tasks, + List sparkTasks, + Set> visited) { + List> childTasks = new ArrayList<>(); for (Task task : tasks) { if (visited.contains(task)) { continue; @@ -2427,23 +2438,29 @@ private static void getSparkTasks(List> tasks, } if (task.getDependentTasks() != null) { - getSparkTasks(task.getDependentTasks(), sparkTasks, visited); + childTasks.addAll(task.getDependentTasks()); } visited.add(task); } + return childTasks; } public static List getMRTasks(List> tasks) { List mrTasks = new ArrayList(); - Set> visited = new HashSet>(); if (tasks != null) { - getMRTasks(tasks, mrTasks, visited); + Set> visited = new HashSet>(); + while (!tasks.isEmpty()) { + tasks = getMRTasks(tasks, mrTasks, visited); + } } return mrTasks; } - private static void getMRTasks(List> tasks, List mrTasks, - Set> visited) { + private static List> getMRTasks( + List> tasks, + List mrTasks, + Set> visited) { + List> childTasks = new ArrayList<>(); for (Task task : tasks) { if (visited.contains(task)) { continue; @@ -2453,10 +2470,11 @@ private static void getMRTasks(List> tasks, List