diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java
index 1be7eab48d..b78c930cf5 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java
@@ -121,6 +121,7 @@ import org.apache.hadoop.hive.ql.exec.spark.SparkTask;
 import org.apache.hadoop.hive.ql.exec.tez.DagUtils;
 import org.apache.hadoop.hive.ql.exec.tez.TezTask;
+import org.apache.hadoop.hive.ql.exec.util.DAGTraversal;
 import org.apache.hadoop.hive.ql.exec.vector.VectorizedInputFormatInterface;
 import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatchCtx;
 import org.apache.hadoop.hive.ql.io.AcidUtils;
@@ -2569,41 +2570,49 @@ public static boolean isEmptyPath(Configuration job, Path dirPath) throws IOExce
   }

   public static List<TezTask> getTezTasks(List<Task<? extends Serializable>> tasks) {
-    return getTasks(tasks, TezTask.class);
+    return getTasks(tasks, new TaskFilterFunction<>(TezTask.class));
   }

   public static List<SparkTask> getSparkTasks(List<Task<? extends Serializable>> tasks) {
-    return getTasks(tasks, SparkTask.class);
+    return getTasks(tasks, new TaskFilterFunction<>(SparkTask.class));
   }

   public static List<ExecDriver> getMRTasks(List<Task<? extends Serializable>> tasks) {
-    return getTasks(tasks, ExecDriver.class);
+    return getTasks(tasks, new TaskFilterFunction<>(ExecDriver.class));
   }

-  @SuppressWarnings("unchecked")
-  public static <T> List<T> getTasks(List<Task<? extends Serializable>> tasks, Class<T> requiredType) {
-    List<T> typeSpecificTasks = new ArrayList<>();
-    if (tasks != null) {
-      Set<Task<? extends Serializable>> visited = new HashSet<>();
-      while (!tasks.isEmpty()) {
-        List<Task<? extends Serializable>> childTasks = new ArrayList<>();
-        for (Task<? extends Serializable> task : tasks) {
-          if (visited.contains(task)) {
-            continue;
-          }
-          if (requiredType.isInstance(task) && !typeSpecificTasks.contains(task)) {
-            typeSpecificTasks.add((T) task);
-          }
-          if (task.getDependentTasks() != null) {
-            childTasks.addAll(task.getDependentTasks());
-          }
-          visited.add(task);
-        }
-        // start recursion
-        tasks = childTasks;
+  static class TaskFilterFunction<T> implements DAGTraversal.Function {
+    private Set<Task<? extends Serializable>> visited = new HashSet<>();
+    private Class<T> requiredType;
+    private List<T> typeSpecificTasks = new ArrayList<>();
+
+    TaskFilterFunction(Class<T> requiredType) {
+      this.requiredType = requiredType;
+    }
+
+    @Override
+    public void process(Task<? extends Serializable> task) {
+      if (requiredType.isInstance(task) && !typeSpecificTasks.contains(task)) {
+        typeSpecificTasks.add((T) task);
       }
+      visited.add(task);
+    }
+
+    List<T> getTasks() {
+      return typeSpecificTasks;
     }
-    return typeSpecificTasks;
+
+    @Override
+    public boolean skipProcessing(Task<? extends Serializable> task) {
+      return visited.contains(task);
+    }
+  }
+
+  @SuppressWarnings("unchecked")
+  private static <T> List<T> getTasks(List<Task<? extends Serializable>> tasks,
+      TaskFilterFunction<T> function) {
+    DAGTraversal.traverse(tasks, function);
+    return function.getTasks();
   }

   /**
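Note: the refactor above keeps the public behaviour of getTezTasks/getSparkTasks/getMRTasks while delegating the walk to DAGTraversal. A minimal sketch of the intended call path, for illustration only (the caller class below is invented; Utilities, Task and TezTask are the real Hive classes):

    import java.io.Serializable;
    import java.util.List;

    import org.apache.hadoop.hive.ql.exec.Task;
    import org.apache.hadoop.hive.ql.exec.Utilities;
    import org.apache.hadoop.hive.ql.exec.tez.TezTask;

    class TezTaskLookup {
      // Returns every TezTask reachable from the given roots exactly once:
      // TaskFilterFunction.process() collects matches, and its skipProcessing()
      // doubles as the visited-set check so shared subtrees are not re-walked.
      static List<TezTask> findTezTasks(List<Task<? extends Serializable>> rootTasks) {
        return Utilities.getTezTasks(rootTasks);
      }
    }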
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/AddDependencyToLeaves.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/AddDependencyToLeaves.java
new file mode 100644
index 0000000000..cf838e1d13
--- /dev/null
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/AddDependencyToLeaves.java
@@ -0,0 +1,51 @@
+/*
+  Licensed to the Apache Software Foundation (ASF) under one
+  or more contributor license agreements.  See the NOTICE file
+  distributed with this work for additional information
+  regarding copyright ownership.  The ASF licenses this file
+  to you under the Apache License, Version 2.0 (the
+  "License"); you may not use this file except in compliance
+  with the License.  You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing, software
+  distributed under the License is distributed on an "AS IS" BASIS,
+  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  See the License for the specific language governing permissions and
+  limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec.repl.bootstrap;
+
+import org.apache.hadoop.hive.ql.exec.Task;
+import org.apache.hadoop.hive.ql.exec.util.DAGTraversal;
+
+import java.io.Serializable;
+import java.util.Collections;
+import java.util.List;
+
+public class AddDependencyToLeaves implements DAGTraversal.Function {
+  private List<Task<? extends Serializable>> postDependencyCollectionTasks;
+
+  AddDependencyToLeaves(List<Task<? extends Serializable>> postDependencyCollectionTasks) {
+    this.postDependencyCollectionTasks = postDependencyCollectionTasks;
+  }
+
+  public AddDependencyToLeaves(Task<? extends Serializable> postDependencyTask) {
+    this(Collections.singletonList(postDependencyTask));
+  }
+
+  @Override
+  public void process(Task<? extends Serializable> task) {
+    if (task.getChildTasks() == null) {
+      postDependencyCollectionTasks.forEach(task::addDependentTask);
+    }
+  }
+
+  @Override
+  public boolean skipProcessing(Task<? extends Serializable> task) {
+    return postDependencyCollectionTasks.contains(task);
+  }
+}
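Note: AddDependencyToLeaves is the iterative replacement for the recursive ReplLoadTask.dependency() helper removed further below. A minimal usage sketch, illustration only (the wrapper class and method names are invented; the Hive classes are real):

    import java.io.Serializable;
    import java.util.List;

    import org.apache.hadoop.hive.ql.exec.Task;
    import org.apache.hadoop.hive.ql.exec.repl.bootstrap.AddDependencyToLeaves;
    import org.apache.hadoop.hive.ql.exec.util.DAGTraversal;

    class BarrierWiring {
      // Attaches `barrier` under every leaf (task with no child tasks) so it runs
      // only after everything currently in the graph has finished. skipProcessing()
      // keeps the traversal from descending into the barrier task itself if it is
      // already wired into the graph.
      static void runAfterEverything(List<Task<? extends Serializable>> rootTasks,
          Task<? extends Serializable> barrier) {
        DAGTraversal.traverse(rootTasks, new AddDependencyToLeaves(barrier));
      }
    }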
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/ReplLoadTask.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/ReplLoadTask.java
index bf5c819e90..bfbec45d94 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/ReplLoadTask.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/ReplLoadTask.java
@@ -39,12 +39,14 @@ Licensed to the Apache Software Foundation (ASF) under one
 import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.table.LoadTable;
 import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.table.TableContext;
 import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.util.Context;
+import org.apache.hadoop.hive.ql.exec.util.DAGTraversal;
 import org.apache.hadoop.hive.ql.parse.SemanticException;
 import org.apache.hadoop.hive.ql.parse.repl.ReplLogger;
 import org.apache.hadoop.hive.ql.plan.api.StageType;

 import java.io.Serializable;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;

 import static org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.LoadDatabase.AlterDatabase;
@@ -225,22 +227,28 @@ a database ( directory )
     return 0;
   }

-  private Task<? extends Serializable> createEndReplLogTask(Context context, Scope scope,
+  private void createEndReplLogTask(Context context, Scope scope,
       ReplLogger replLogger) throws SemanticException {
     Database dbInMetadata = work.databaseEvent(context.hiveConf).dbInMetadata(work.dbNameToLoadIn);
     ReplStateLogWork replLogWork = new ReplStateLogWork(replLogger, dbInMetadata.getParameters());
     Task<? extends Serializable> replLogTask = TaskFactory.get(replLogWork, conf);
-    if (null == scope.rootTasks) {
+    if (scope.rootTasks.isEmpty()) {
       scope.rootTasks.add(replLogTask);
     } else {
-      dependency(scope.rootTasks, replLogTask);
+      DAGTraversal.traverse(scope.rootTasks,
+          new AddDependencyToLeaves(Collections.singletonList(replLogTask)));
     }
-    return replLogTask;
   }

   /**
    * There was a database update done before and we want to make sure we update the last repl
    * id on this database as we are now going to switch to processing a new database.
+   *
+   * This has to be the last task in the graph: if the last.repl.id update were a root-level
+   * task while intermediate tasks still existed, the execution phase would run the root-level
+   * tasks first. If any child task of the bootstrap load then failed, the last repl status of
+   * the target database would still return a valid value even though the bootstrap had failed,
+   * and so would misrepresent the state of the database.
    */
   private TaskTracker updateDatabaseLastReplID(int maxTasks, Context context, Scope scope)
       throws SemanticException {
@@ -251,7 +259,10 @@ private TaskTracker updateDatabaseLastReplID(int maxTasks, Context context, Scop
     TaskTracker taskTracker =
         new AlterDatabase(context, work.databaseEvent(context.hiveConf), work.dbNameToLoadIn,
             new TaskTracker(maxTasks)).tasks();
-    scope.rootTasks.addAll(taskTracker.tasks());
+
+    AddDependencyToLeaves function = new AddDependencyToLeaves(taskTracker.tasks());
+    DAGTraversal.traverse(scope.rootTasks, function);
+
     return taskTracker;
   }

@@ -288,27 +299,8 @@ private void createBuilderTask(List<Task<? extends Serializable>> rootTasks,
      */
     if (shouldCreateAnotherLoadTask) {
       Task<? extends Serializable> loadTask = TaskFactory.get(work, conf, true);
-      dependency(rootTasks, loadTask);
-    }
-  }
-
-  /**
-   * add the dependency to the leaf node
-   */
-  public static boolean dependency(List<Task<? extends Serializable>> tasks, Task<? extends Serializable> tailTask) {
-    if (tasks == null || tasks.isEmpty()) {
-      return true;
-    }
-    for (Task<? extends Serializable> task : tasks) {
-      if (task == tailTask) {
-        continue;
-      }
-      boolean leafNode = dependency(task.getChildTasks(), tailTask);
-      if (leafNode) {
-        task.addDependentTask(tailTask);
-      }
+      DAGTraversal.traverse(rootTasks, new AddDependencyToLeaves(loadTask));
     }
-    return false;
   }

   @Override
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/LoadFunction.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/LoadFunction.java
index 8852a60d15..ef4ed4d4cc 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/LoadFunction.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/LoadFunction.java
@@ -22,12 +22,13 @@ Licensed to the Apache Software Foundation (ASF) under one
 import org.apache.hadoop.hive.ql.exec.Task;
 import org.apache.hadoop.hive.ql.exec.TaskFactory;
 import org.apache.hadoop.hive.ql.exec.repl.ReplStateLogWork;
+import org.apache.hadoop.hive.ql.exec.repl.bootstrap.AddDependencyToLeaves;
 import org.apache.hadoop.hive.ql.exec.repl.bootstrap.events.FunctionEvent;
 import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.util.Context;
+import org.apache.hadoop.hive.ql.exec.util.DAGTraversal;
 import org.apache.hadoop.hive.ql.parse.EximUtil;
 import org.apache.hadoop.hive.ql.parse.SemanticException;
 import org.apache.hadoop.hive.ql.parse.repl.ReplLogger;
-import org.apache.hadoop.hive.ql.exec.repl.bootstrap.ReplLoadTask;
 import org.apache.hadoop.hive.ql.parse.repl.load.message.CreateFunctionHandler;
 import org.apache.hadoop.hive.ql.parse.repl.load.message.MessageHandler;
 import org.slf4j.Logger;
@@ -61,7 +62,7 @@ private void createFunctionReplLogTask(List<Task<? extends Serializable>> functi
       String functionName) {
     ReplStateLogWork replLogWork = new ReplStateLogWork(replLogger, functionName);
     Task<? extends Serializable> replLogTask = TaskFactory.get(replLogWork, context.hiveConf);
-    ReplLoadTask.dependency(functionTasks, replLogTask);
+    DAGTraversal.traverse(functionTasks, new AddDependencyToLeaves(replLogTask));
   }

   public TaskTracker tasks() throws IOException, SemanticException {
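Note: to make the ordering argument in updateDatabaseLastReplID concrete, a small worked sketch (task names are invented; scope and taskTracker are the variables from the hunk above):

    // Starting graph built by the bootstrap load, roots = [createDb]:
    //   createDb -> loadTable -> loadPartition
    DAGTraversal.traverse(scope.rootTasks, new AddDependencyToLeaves(taskTracker.tasks()));
    // Resulting graph:
    //   createDb -> loadTable -> loadPartition -> alterDbLastReplId
    // The old scope.rootTasks.addAll(taskTracker.tasks()) would instead have produced
    //   roots = [createDb, alterDbLastReplId]
    // letting the repl id update succeed even when a later load task failed.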
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/table/LoadPartitions.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/table/LoadPartitions.java
index 03608167d8..262225fc20 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/table/LoadPartitions.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/table/LoadPartitions.java
@@ -27,12 +27,13 @@ Licensed to the Apache Software Foundation (ASF) under one
 import org.apache.hadoop.hive.ql.exec.TaskFactory;
 import org.apache.hadoop.hive.ql.exec.Utilities;
 import org.apache.hadoop.hive.ql.exec.repl.ReplStateLogWork;
-import org.apache.hadoop.hive.ql.exec.repl.bootstrap.ReplLoadTask;
+import org.apache.hadoop.hive.ql.exec.repl.bootstrap.AddDependencyToLeaves;
 import org.apache.hadoop.hive.ql.exec.repl.bootstrap.events.TableEvent;
 import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.ReplicationState;
 import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.TaskTracker;
 import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.util.Context;
 import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.util.PathUtils;
+import org.apache.hadoop.hive.ql.exec.util.DAGTraversal;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.hive.ql.metadata.Partition;
 import org.apache.hadoop.hive.ql.metadata.Table;
@@ -47,7 +48,6 @@ Licensed to the Apache Software Foundation (ASF) under one
 import org.apache.hadoop.hive.ql.plan.LoadTableDesc.LoadFileType;
 import org.apache.hadoop.hive.ql.plan.MoveWork;
 import org.apache.hadoop.hive.ql.session.SessionState;
-import org.mortbay.jetty.servlet.AbstractSessionManager;
 import org.datanucleus.util.StringUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -118,7 +118,7 @@ private void createTableReplLogTask() throws SemanticException {
     if (tracker.tasks().isEmpty()) {
       tracker.addTask(replLogTask);
     } else {
-      ReplLoadTask.dependency(tracker.tasks(), replLogTask);
+      DAGTraversal.traverse(tracker.tasks(), new AddDependencyToLeaves(replLogTask));

       List<Task<? extends Serializable>> visited = new ArrayList<>();
       tracker.updateTaskCount(replLogTask, visited);
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/table/LoadTable.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/table/LoadTable.java
index 766a9a92c6..bb1f4e5050 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/table/LoadTable.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/table/LoadTable.java
@@ -28,11 +28,12 @@ Licensed to the Apache Software Foundation (ASF) under one
 import org.apache.hadoop.hive.ql.exec.TaskFactory;
 import org.apache.hadoop.hive.ql.exec.Utilities;
 import org.apache.hadoop.hive.ql.exec.repl.ReplStateLogWork;
-import org.apache.hadoop.hive.ql.exec.repl.bootstrap.ReplLoadTask;
+import org.apache.hadoop.hive.ql.exec.repl.bootstrap.AddDependencyToLeaves;
 import org.apache.hadoop.hive.ql.exec.repl.bootstrap.events.TableEvent;
 import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.TaskTracker;
 import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.util.Context;
 import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.util.PathUtils;
+import org.apache.hadoop.hive.ql.exec.util.DAGTraversal;
 import org.apache.hadoop.hive.ql.metadata.Table;
 import org.apache.hadoop.hive.ql.parse.EximUtil;
 import org.apache.hadoop.hive.ql.parse.ImportSemanticAnalyzer;
@@ -78,12 +79,12 @@ public LoadTable(TableEvent event, Context context, ReplLogger replLogger,
   private void createTableReplLogTask(String tableName, TableType tableType)
       throws SemanticException {
     ReplStateLogWork replLogWork = new ReplStateLogWork(replLogger, tableName, tableType);
     Task<? extends Serializable> replLogTask = TaskFactory.get(replLogWork, context.hiveConf);
-    ReplLoadTask.dependency(tracker.tasks(), replLogTask);
+    DAGTraversal.traverse(tracker.tasks(), new AddDependencyToLeaves(replLogTask));

     if (tracker.tasks().isEmpty()) {
       tracker.addTask(replLogTask);
     } else {
-      ReplLoadTask.dependency(tracker.tasks(), replLogTask);
+      DAGTraversal.traverse(tracker.tasks(), new AddDependencyToLeaves(replLogTask));

       List<Task<? extends Serializable>> visited = new ArrayList<>();
       tracker.updateTaskCount(replLogTask, visited);
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/util/DAGTraversal.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/util/DAGTraversal.java
new file mode 100644
index 0000000000..1e436bad54
--- /dev/null
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/util/DAGTraversal.java
@@ -0,0 +1,55 @@
+/*
+  Licensed to the Apache Software Foundation (ASF) under one
+  or more contributor license agreements.  See the NOTICE file
+  distributed with this work for additional information
+  regarding copyright ownership.  The ASF licenses this file
+  to you under the Apache License, Version 2.0 (the
+  "License"); you may not use this file except in compliance
+  with the License.  You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing, software
+  distributed under the License is distributed on an "AS IS" BASIS,
+  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  See the License for the specific language governing permissions and
+  limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec.util;
+
+import org.apache.hadoop.hive.ql.exec.Task;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * This DAG traversal is deliberately iterative rather than recursive: recursing over a
+ * large DAG can overflow the stack.
+ */
+public class DAGTraversal {
+  public static void traverse(List<Task<? extends Serializable>> tasks, Function function) {
+    List<Task<? extends Serializable>> listOfTasks = new ArrayList<>(tasks);
+    while (!listOfTasks.isEmpty()) {
+      List<Task<? extends Serializable>> children = new ArrayList<>();
+      for (Task<? extends Serializable> task : listOfTasks) {
+        // the skipProcessing check must run before the task is processed or its
+        // children are queued
+        if (function.skipProcessing(task)) {
+          continue;
+        }
+        if (task.getDependentTasks() != null) {
+          children.addAll(task.getDependentTasks());
+        }
+        function.process(task);
+      }
+      listOfTasks = children;
+    }
+  }
+
+  public interface Function {
+    void process(Task<? extends Serializable> task);
+
+    boolean skipProcessing(Task<? extends Serializable> task);
+  }
+}
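Note: DAGTraversal itself keeps no visited set; exactly-once semantics are the Function's responsibility via skipProcessing(). A minimal sketch of a custom Function that visits each reachable task once (the class is illustrative, not part of this patch):

    import java.io.Serializable;
    import java.util.HashSet;
    import java.util.Set;

    import org.apache.hadoop.hive.ql.exec.Task;
    import org.apache.hadoop.hive.ql.exec.util.DAGTraversal;

    class CountingFunction implements DAGTraversal.Function {
      private final Set<Task<? extends Serializable>> seen = new HashSet<>();

      @Override
      public void process(Task<? extends Serializable> task) {
        seen.add(task); // record the task; skipProcessing keeps revisits out
      }

      @Override
      public boolean skipProcessing(Task<? extends Serializable> task) {
        return seen.contains(task);
      }

      int count() {
        return seen.size();
      }
    }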
diff --git a/ql/src/test/org/apache/hadoop/hive/ql/exec/repl/bootstrap/AddDependencyToLeavesTest.java b/ql/src/test/org/apache/hadoop/hive/ql/exec/repl/bootstrap/AddDependencyToLeavesTest.java
new file mode 100644
index 0000000000..a807483f0a
--- /dev/null
+++ b/ql/src/test/org/apache/hadoop/hive/ql/exec/repl/bootstrap/AddDependencyToLeavesTest.java
@@ -0,0 +1,85 @@
+/*
+  Licensed to the Apache Software Foundation (ASF) under one
+  or more contributor license agreements.  See the NOTICE file
+  distributed with this work for additional information
+  regarding copyright ownership.  The ASF licenses this file
+  to you under the Apache License, Version 2.0 (the
+  "License"); you may not use this file except in compliance
+  with the License.  You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing, software
+  distributed under the License is distributed on an "AS IS" BASIS,
+  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  See the License for the specific language governing permissions and
+  limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec.repl.bootstrap;
+
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.ql.exec.Task;
+import org.apache.hadoop.hive.ql.exec.TaskFactory;
+import org.apache.hadoop.hive.ql.exec.util.DAGTraversal;
+import org.apache.hadoop.hive.ql.plan.DependencyCollectionWork;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.runners.MockitoJUnitRunner;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+@RunWith(MockitoJUnitRunner.class)
+public class AddDependencyToLeavesTest {
+  @Mock
+  private HiveConf hiveConf;
+
+  @Test
+  public void shouldNotSkipIntermediateDependencyCollectionTasks() {
+    Task<DependencyCollectionWork> collectionWorkTaskOne =
+        TaskFactory.get(new DependencyCollectionWork(), hiveConf);
+    Task<DependencyCollectionWork> collectionWorkTaskTwo =
+        TaskFactory.get(new DependencyCollectionWork(), hiveConf);
+    Task<DependencyCollectionWork> collectionWorkTaskThree =
+        TaskFactory.get(new DependencyCollectionWork(), hiveConf);
+
+    @SuppressWarnings("unchecked") Task<? extends Serializable> rootTask = mock(Task.class);
+    when(rootTask.getDependentTasks())
+        .thenReturn(
+            Arrays.asList(collectionWorkTaskOne, collectionWorkTaskTwo, collectionWorkTaskThree));
+    @SuppressWarnings("unchecked") List<Task<? extends Serializable>> tasksPostCurrentGraph =
+        Arrays.asList(mock(Task.class), mock(Task.class));
+
+    DAGTraversal.traverse(Collections.singletonList(rootTask),
+        new AddDependencyToLeaves(tasksPostCurrentGraph));
+
+    List<Task<? extends Serializable>> dependentTasksForOne =
+        collectionWorkTaskOne.getDependentTasks();
+    List<Task<? extends Serializable>> dependentTasksForTwo =
+        collectionWorkTaskTwo.getDependentTasks();
+    List<Task<? extends Serializable>> dependentTasksForThree =
+        collectionWorkTaskThree.getDependentTasks();
+
+    assertEquals(2, dependentTasksForOne.size());
+    assertEquals(2, dependentTasksForTwo.size());
+    assertEquals(2, dependentTasksForThree.size());
+    assertTrue(tasksPostCurrentGraph.containsAll(dependentTasksForOne));
+    assertTrue(tasksPostCurrentGraph.containsAll(dependentTasksForTwo));
+    assertTrue(tasksPostCurrentGraph.containsAll(dependentTasksForThree));
+  }
+}
\ No newline at end of file
diff --git a/ql/src/test/org/apache/hadoop/hive/ql/exec/util/DAGTraversalTest.java b/ql/src/test/org/apache/hadoop/hive/ql/exec/util/DAGTraversalTest.java
new file mode 100644
index 0000000000..4bce6bc58b
--- /dev/null
+++ b/ql/src/test/org/apache/hadoop/hive/ql/exec/util/DAGTraversalTest.java
@@ -0,0 +1,82 @@
+/*
+  Licensed to the Apache Software Foundation (ASF) under one
+  or more contributor license agreements.  See the NOTICE file
+  distributed with this work for additional information
+  regarding copyright ownership.  The ASF licenses this file
+  to you under the Apache License, Version 2.0 (the
+  "License"); you may not use this file except in compliance
+  with the License.  You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing, software
+  distributed under the License is distributed on an "AS IS" BASIS,
+  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  See the License for the specific language governing permissions and
+  limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec.util;
+
+import org.apache.hadoop.hive.ql.exec.Task;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.runners.MockitoJUnitRunner;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.Collections;
+
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+@RunWith(MockitoJUnitRunner.class)
+public class DAGTraversalTest {
+
+  static class CountLeafFunction implements DAGTraversal.Function {
+    int count = 0;
+
+    @Override
+    public void process(Task<? extends Serializable> task) {
+      if (task.getDependentTasks() == null || task.getDependentTasks().isEmpty()) {
+        count++;
+      }
+    }
+
+    @Override
+    public boolean skipProcessing(Task<? extends Serializable> task) {
+      return false;
+    }
+  }
+
+  @Test
+  public void shouldCountNumberOfLeafNodesCorrectly() {
+    Task<? extends Serializable> taskWith5NodeTree = linearTree(5);
+    Task<? extends Serializable> taskWith1NodeTree = linearTree(1);
+    Task<? extends Serializable> taskWith3NodeTree = linearTree(3);
+    @SuppressWarnings("unchecked") Task<? extends Serializable> rootTask = mock(Task.class);
+    when(rootTask.getDependentTasks())
+        .thenReturn(Arrays.asList(taskWith1NodeTree, taskWith3NodeTree, taskWith5NodeTree));
+
+    CountLeafFunction function = new CountLeafFunction();
+    DAGTraversal.traverse(Collections.singletonList(rootTask), function);
+    assertEquals(3, function.count);
+  }
+
+  private Task<? extends Serializable> linearTree(int numOfNodes) {
+    Task<? extends Serializable> current = null, head = null;
+    for (int i = 0; i < numOfNodes; i++) {
+      @SuppressWarnings("unchecked") Task<? extends Serializable> task = mock(Task.class);
+      if (current != null) {
+        when(current.getDependentTasks()).thenReturn(Collections.singletonList(task));
+      }
+      if (head == null) {
+        head = task;
+      }
+      current = task;
+    }
+    return head;
+  }
+}
\ No newline at end of file
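Note: one caveat the test above leaves implicit. CountLeafFunction never skips, so on a true DAG (as opposed to the linear trees built by linearTree) a shared node is offered once per parent. A sketch in the same Mockito style, reusing the test's imports:

    @SuppressWarnings("unchecked") Task<? extends Serializable> left = mock(Task.class);
    @SuppressWarnings("unchecked") Task<? extends Serializable> right = mock(Task.class);
    @SuppressWarnings("unchecked") Task<? extends Serializable> sharedLeaf = mock(Task.class);
    // diamond shape: both parents point at the same child
    when(left.getDependentTasks()).thenReturn(Collections.singletonList(sharedLeaf));
    when(right.getDependentTasks()).thenReturn(Collections.singletonList(sharedLeaf));

    CountLeafFunction function = new CountLeafFunction();
    DAGTraversal.traverse(Arrays.asList(left, right), function);
    // function.count is 2 here although there is only one distinct leaf; a Function
    // that needs exactly-once behaviour must de-duplicate in skipProcessing(), as
    // TaskFilterFunction does with its visited set.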