diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java index ba97d22ed9..af8267177d 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java @@ -2383,62 +2383,79 @@ public static boolean isEmptyPath(Configuration job, Path dirPath) throws IOExce public static List getTezTasks(List> tasks) { List tezTasks = new ArrayList(); + Set> visited = new HashSet>(); if (tasks != null) { - getTezTasks(tasks, tezTasks); + getTezTasks(tasks, tezTasks, visited); } return tezTasks; } - private static void getTezTasks(List> tasks, List tezTasks) { + private static void getTezTasks(List> tasks, List tezTasks, + Set> visited) { for (Task task : tasks) { + if (visited.contains(task)) { + continue; + } if (task instanceof TezTask && !tezTasks.contains(task)) { tezTasks.add((TezTask) task); } if (task.getDependentTasks() != null) { - getTezTasks(task.getDependentTasks(), tezTasks); + getTezTasks(task.getDependentTasks(), tezTasks, visited); } + visited.add(task); } } public static List getSparkTasks(List> tasks) { List sparkTasks = new ArrayList(); + Set> visited = new HashSet>(); if (tasks != null) { - getSparkTasks(tasks, sparkTasks); + getSparkTasks(tasks, sparkTasks, visited); } return sparkTasks; } private static void getSparkTasks(List> tasks, - List sparkTasks) { + List sparkTasks, Set> visited) { for (Task task : tasks) { + if (visited.contains(task)) { + continue; + } if (task instanceof SparkTask && !sparkTasks.contains(task)) { sparkTasks.add((SparkTask) task); } if (task.getDependentTasks() != null) { - getSparkTasks(task.getDependentTasks(), sparkTasks); + getSparkTasks(task.getDependentTasks(), sparkTasks, visited); } + visited.add(task); } } public static List getMRTasks(List> tasks) { List mrTasks = new ArrayList(); + Set> visited = new HashSet>(); if (tasks != null) { - getMRTasks(tasks, mrTasks); + getMRTasks(tasks, mrTasks, visited); } return mrTasks; } - private static void getMRTasks(List> tasks, List mrTasks) { + private static void getMRTasks(List> tasks, List mrTasks, + Set> visited) { for (Task task : tasks) { + if (visited.contains(task)) { + continue; + } if (task instanceof ExecDriver && !mrTasks.contains(task)) { mrTasks.add((ExecDriver) task); } if (task.getDependentTasks() != null) { - getMRTasks(task.getDependentTasks(), mrTasks); + getMRTasks(task.getDependentTasks(), mrTasks, visited); } + visited.add(task); } } diff --git a/ql/src/test/org/apache/hadoop/hive/ql/exec/TestUtilities.java b/ql/src/test/org/apache/hadoop/hive/ql/exec/TestUtilities.java index 434e20622f..dcd8f95eb0 100644 --- a/ql/src/test/org/apache/hadoop/hive/ql/exec/TestUtilities.java +++ b/ql/src/test/org/apache/hadoop/hive/ql/exec/TestUtilities.java @@ -32,6 +32,7 @@ import java.io.File; import java.io.IOException; +import java.io.Serializable; import java.sql.Timestamp; import java.util.ArrayList; import java.util.LinkedHashMap; @@ -49,10 +50,14 @@ import org.apache.hadoop.hive.conf.HiveConf; import org.apache.hadoop.hive.metastore.api.hive_metastoreConstants; import org.apache.hadoop.hive.ql.Context; +import org.apache.hadoop.hive.ql.exec.mr.ExecDriver; +import org.apache.hadoop.hive.ql.exec.spark.SparkTask; +import org.apache.hadoop.hive.ql.exec.tez.TezTask; import org.apache.hadoop.hive.ql.io.*; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.metadata.InputEstimator; import org.apache.hadoop.hive.ql.metadata.Table; +import org.apache.hadoop.hive.ql.plan.DependencyCollectionWork; import org.apache.hadoop.hive.ql.plan.DynamicPartitionCtx; import org.apache.hadoop.hive.ql.plan.ExprNodeConstantDesc; import org.apache.hadoop.hive.ql.plan.ExprNodeDesc; @@ -685,4 +690,124 @@ private ContentSummary runTestGetInputSummary(JobConf jobConf, Properties proper } } } + + private Task getDependencyCollectionTask(){ + return TaskFactory.get(new DependencyCollectionWork(), new HiveConf()); + } + + /** + * Generates a task graph that looks like this: + * + * ---->DTa---- + * / \ + * root ----->DTb-----*-->DTd---> ProvidedTask --> DTe + * \ / + * ---->DTc---- + */ + private List> getTestDiamondTaskGraph(Task providedTask){ + // Note: never instantiate a task without TaskFactory.get() if you're not + // okay with .equals() breaking. Doing it via TaskFactory.get makes sure + // that an id is generated, and two tasks of the same type don't show + // up as "equal", which is important for things like iterating over an + // array. Without this, DTa, DTb, and DTc would show up as one item in + // the list of children. Thus, we're instantiating via a helper method + // that instantiates via TaskFactory.get() + Task root = getDependencyCollectionTask(); + Task DTa = getDependencyCollectionTask(); + Task DTb = getDependencyCollectionTask(); + Task DTc = getDependencyCollectionTask(); + Task DTd = getDependencyCollectionTask(); + Task DTe = getDependencyCollectionTask(); + + root.addDependentTask(DTa); + root.addDependentTask(DTb); + root.addDependentTask(DTc); + + DTa.addDependentTask(DTd); + DTb.addDependentTask(DTd); + DTc.addDependentTask(DTd); + + DTd.addDependentTask(providedTask); + + providedTask.addDependentTask(DTe); + + List> retVals = new ArrayList>(); + retVals.add(root); + return retVals; + } + + /** + * DependencyCollectionTask that counts how often getDependentTasks on it + * (and thus, on its descendants) is called counted via Task.getDependentTasks. + * It is used to wrap another task to intercept calls on it. + */ + public class CountingWrappingTask extends DependencyCollectionTask { + int count; + Task wrappedDep = null; + + public CountingWrappingTask(Task dep) { + count = 0; + wrappedDep = dep; + super.addDependentTask(wrappedDep); + } + + public boolean addDependentTask(Task dependent) { + return wrappedDep.addDependentTask(dependent); + } + + @Override + public List> getDependentTasks() { + count++; + System.err.println("YAH:getDepTasks got called!"); + (new Exception()).printStackTrace(System.err); + LOG.info("YAH!getDepTasks", new Exception()); + return super.getDependentTasks(); + } + + public int getDepCallCount() { + return count; + } + + @Override + public String getName() { + return "COUNTER_TASK"; + } + + @Override + public String toString() { + return getName() + "_" + wrappedDep.toString(); + } + }; + + /** + * This test tests that Utilities.get*Tasks do not repeat themselves in the process + * of extracting tasks from a given set of root tasks when given DAGs that can have + * multiple paths, such as the case with Diamond-shaped DAGs common to replication. + */ + @Test + public void testGetTasksHaveNoRepeats() { + + CountingWrappingTask mrTask = new CountingWrappingTask(new ExecDriver()); + CountingWrappingTask tezTask = new CountingWrappingTask(new TezTask()); + CountingWrappingTask sparkTask = new CountingWrappingTask(new SparkTask()); + + // First check - we should not have repeats in results + assertEquals("No repeated MRTasks from Utilities.getMRTasks", 1, + Utilities.getMRTasks(getTestDiamondTaskGraph(mrTask)).size()); + assertEquals("No repeated TezTasks from Utilities.getTezTasks", 1, + Utilities.getTezTasks(getTestDiamondTaskGraph(tezTask)).size()); + assertEquals("No repeated TezTasks from Utilities.getSparkTasks", 1, + Utilities.getSparkTasks(getTestDiamondTaskGraph(sparkTask)).size()); + + // Second check - the tasks we looked for must not have been accessed more than + // once as a result of the traversal (note that we actually wind up accessing + // 2 times , because each visit counts twice, once to check for existence, and + // once to visit. + + assertEquals("MRTasks should have been visited only once", 2, mrTask.getDepCallCount()); + assertEquals("TezTasks should have been visited only once", 2, tezTask.getDepCallCount()); + assertEquals("SparkTasks should have been visited only once", 2, sparkTask.getDepCallCount()); + + } + }