diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java
index cd2d091a23..82b3823531 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java
@@ -80,7 +80,6 @@
 import org.apache.commons.lang3.StringEscapeUtils;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.filecache.DistributedCache;
-import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
 import org.apache.hadoop.fs.ContentSummary;
 import org.apache.hadoop.fs.FSDataInputStream;
 import org.apache.hadoop.fs.FSDataOutputStream;
@@ -119,7 +118,7 @@
 import org.apache.hadoop.hive.ql.exec.spark.SparkTask;
 import org.apache.hadoop.hive.ql.exec.tez.DagUtils;
 import org.apache.hadoop.hive.ql.exec.tez.TezTask;
-import org.apache.hadoop.hive.ql.exec.vector.VectorExpressionDescriptor;
+import org.apache.hadoop.hive.ql.exec.util.DAGTraversal;
 import org.apache.hadoop.hive.ql.exec.vector.VectorizedInputFormatInterface;
 import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatchCtx;
 import org.apache.hadoop.hive.ql.io.AcidUtils;
@@ -137,7 +136,6 @@
 import org.apache.hadoop.hive.ql.io.SelfDescribingInputFormatInterface;
 import org.apache.hadoop.hive.ql.io.merge.MergeFileMapper;
 import org.apache.hadoop.hive.ql.io.merge.MergeFileWork;
-import org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat;
 import org.apache.hadoop.hive.ql.io.rcfile.truncate.ColumnTruncateMapper;
 import org.apache.hadoop.hive.ql.io.rcfile.truncate.ColumnTruncateWork;
 import org.apache.hadoop.hive.ql.log.PerfLogger;
@@ -147,7 +145,6 @@
 import org.apache.hadoop.hive.ql.metadata.InputEstimator;
 import org.apache.hadoop.hive.ql.metadata.Partition;
 import org.apache.hadoop.hive.ql.metadata.Table;
-import org.apache.hadoop.hive.ql.optimizer.physical.Vectorizer;
 import org.apache.hadoop.hive.ql.parse.SemanticException;
 import org.apache.hadoop.hive.ql.plan.BaseWork;
 import org.apache.hadoop.hive.ql.plan.DynamicPartitionCtx;
@@ -167,7 +164,6 @@
 import org.apache.hadoop.hive.ql.stats.StatsFactory;
 import org.apache.hadoop.hive.ql.stats.StatsPublisher;
 import org.apache.hadoop.hive.serde.serdeConstants;
-import org.apache.hadoop.hive.serde2.ColumnProjectionUtils;
 import org.apache.hadoop.hive.serde2.MetadataTypedColumnsetSerDe;
 import org.apache.hadoop.hive.serde2.SerDeException;
 import org.apache.hadoop.hive.serde2.SerDeUtils;
@@ -2592,41 +2588,49 @@ public static boolean isEmptyPath(Configuration job, Path dirPath) throws IOExce
   }
 
   public static List<TezTask> getTezTasks(List<Task<? extends Serializable>> tasks) {
-    return getTasks(tasks, TezTask.class);
+    return getTasks(tasks, new TaskFilterFunction<>(TezTask.class));
   }
 
   public static List<SparkTask> getSparkTasks(List<Task<? extends Serializable>> tasks) {
-    return getTasks(tasks, SparkTask.class);
+    return getTasks(tasks, new TaskFilterFunction<>(SparkTask.class));
   }
 
   public static List<ExecDriver> getMRTasks(List<Task<? extends Serializable>> tasks) {
-    return getTasks(tasks, ExecDriver.class);
+    return getTasks(tasks, new TaskFilterFunction<>(ExecDriver.class));
   }
 
-  @SuppressWarnings("unchecked")
-  public static <T> List<T> getTasks(List<Task<? extends Serializable>> tasks, Class<T> requiredType) {
-    List<T> typeSpecificTasks = new ArrayList<>();
-    if (tasks != null) {
-      Set<Task<? extends Serializable>> visited = new HashSet<>();
-      while (!tasks.isEmpty()) {
-        List<Task<? extends Serializable>> childTasks = new ArrayList<>();
-        for (Task<? extends Serializable> task : tasks) {
-          if (visited.contains(task)) {
-            continue;
-          }
-          if (requiredType.isInstance(task) && !typeSpecificTasks.contains(task)) {
-            typeSpecificTasks.add((T) task);
-          }
-          if (task.getDependentTasks() != null) {
-            childTasks.addAll(task.getDependentTasks());
-          }
-          visited.add(task);
-        }
-        // start recursion
-        tasks = childTasks;
+  static class TaskFilterFunction<T> implements DAGTraversal.Function {
+    private Set<Task<? extends Serializable>> visited = new HashSet<>();
+    private Class<T> requiredType;
+    private List<T> typeSpecificTasks = new ArrayList<>();
+
+    TaskFilterFunction(Class<T> requiredType) {
+      this.requiredType = requiredType;
+    }
+
+    @Override
+    public void process(Task<? extends Serializable> task) {
+      if (requiredType.isInstance(task) && !typeSpecificTasks.contains(task)) {
+        typeSpecificTasks.add((T) task);
       }
+      visited.add(task);
+    }
+
+    List<T> getTasks() {
+      return typeSpecificTasks;
    }
-    return typeSpecificTasks;
+
+    @Override
+    public boolean skipProcessing(Task<? extends Serializable> task) {
+      return visited.contains(task);
+    }
+  }
+
+  @SuppressWarnings("unchecked")
+  private static <T> List<T> getTasks(List<Task<? extends Serializable>> tasks,
+      TaskFilterFunction<T> function) {
+    DAGTraversal.traverse(tasks, function);
+    return function.getTasks();
   }
 
   /**
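The public getTezTasks/getSparkTasks/getMRTasks helpers keep their existing signatures; only the DAG walk moves into the new DAGTraversal utility, with TaskFilterFunction carrying the type filter and the visited-set deduplication. The same DAGTraversal.Function contract can back other walks over the task graph. A minimal sketch under the signatures shown above (MoveTaskCollector is a hypothetical illustration, not part of this patch):

import org.apache.hadoop.hive.ql.exec.MoveTask;
import org.apache.hadoop.hive.ql.exec.Task;
import org.apache.hadoop.hive.ql.exec.util.DAGTraversal;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

// Hypothetical example: collect every MoveTask reachable from a set of root tasks.
class MoveTaskCollector implements DAGTraversal.Function {
  private final Set<Task<? extends Serializable>> visited = new HashSet<>();
  private final List<MoveTask> moveTasks = new ArrayList<>();

  @Override
  public void process(Task<? extends Serializable> task) {
    if (task instanceof MoveTask) {
      moveTasks.add((MoveTask) task);
    }
    visited.add(task);
  }

  @Override
  public boolean skipProcessing(Task<? extends Serializable> task) {
    // DAGTraversal.traverse() does not track visited nodes itself, so deduplication lives in the Function.
    return visited.contains(task);
  }

  List<MoveTask> getMoveTasks() {
    return moveTasks;
  }
}

// Usage: DAGTraversal.traverse(rootTasks, collector); the matches are then available via collector.getMoveTasks().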
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/AddDependencyToLeaves.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/AddDependencyToLeaves.java
new file mode 100644
index 0000000000..8fef879067
--- /dev/null
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/AddDependencyToLeaves.java
@@ -0,0 +1,33 @@
+package org.apache.hadoop.hive.ql.exec.repl.bootstrap;
+
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.ql.exec.DependencyCollectionTask;
+import org.apache.hadoop.hive.ql.exec.Task;
+import org.apache.hadoop.hive.ql.exec.TaskFactory;
+import org.apache.hadoop.hive.ql.exec.util.DAGTraversal;
+import org.apache.hadoop.hive.ql.plan.DependencyCollectionWork;
+
+import java.io.Serializable;
+import java.util.List;
+
+public class AddDependencyToLeaves implements DAGTraversal.Function {
+  private Task<? extends Serializable> dependencyCollectionTask;
+
+  AddDependencyToLeaves(HiveConf hiveConf,
+      List<Task<? extends Serializable>> postDependencyCollectionTasks) {
+    this.dependencyCollectionTask = TaskFactory.get(new DependencyCollectionWork(), hiveConf);
+    postDependencyCollectionTasks.forEach(dependencyCollectionTask::addDependentTask);
+  }
+
+  @Override
+  public void process(Task<? extends Serializable> task) {
+    if (task.getChildTasks() == null) {
+      task.addDependentTask(dependencyCollectionTask);
+    }
+  }
+
+  @Override
+  public boolean skipProcessing(Task<? extends Serializable> task) {
+    return task instanceof DependencyCollectionTask;
+  }
+}
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/ReplLoadTask.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/ReplLoadTask.java
index bf5c819e90..66be57d5cc 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/ReplLoadTask.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/ReplLoadTask.java
@@ -39,6 +39,7 @@ Licensed to the Apache Software Foundation (ASF) under one
 import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.table.LoadTable;
 import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.table.TableContext;
 import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.util.Context;
+import org.apache.hadoop.hive.ql.exec.util.DAGTraversal;
 import org.apache.hadoop.hive.ql.parse.SemanticException;
 import org.apache.hadoop.hive.ql.parse.repl.ReplLogger;
 import org.apache.hadoop.hive.ql.plan.api.StageType;
@@ -241,6 +242,12 @@ a database ( directory )
   /**
    * There was a database update done before and we want to make sure we update the last repl
    * id on this database as we are now going to switch to processing a new database.
+   *
+   * This has to be the last task in the graph: if the last.repl.id update were a root-level task,
+   * the execution phase would run it first. If any child task of the bootstrap load subsequently
+   * failed, the last repl status of the target database would still report a valid value even
+   * though the bootstrap itself had failed, and so would not reflect the actual state of the
+   * database.
    */
   private TaskTracker updateDatabaseLastReplID(int maxTasks, Context context, Scope scope)
       throws SemanticException {
@@ -251,7 +258,11 @@ private TaskTracker updateDatabaseLastReplID(int maxTasks, Context context, Scop
     TaskTracker taskTracker =
         new AlterDatabase(context, work.databaseEvent(context.hiveConf), work.dbNameToLoadIn,
            new TaskTracker(maxTasks)).tasks();
-    scope.rootTasks.addAll(taskTracker.tasks());
+
+    AddDependencyToLeaves function =
+        new AddDependencyToLeaves(context.hiveConf, taskTracker.tasks());
+    DAGTraversal.traverse(scope.rootTasks, function);
+
     return taskTracker;
   }
 
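Net effect of the two hunks above: the AlterDatabase tasks that update last.repl.id are no longer added as extra root tasks; they sit behind a DependencyCollectionTask barrier that every leaf of the bootstrap DAG points to, so they run only after the whole load has succeeded. A rough sketch of that wiring, assuming Task.addDependentTask() links a parent task to its child (the helper class and method names below are hypothetical, not part of this patch):

package org.apache.hadoop.hive.ql.exec.repl.bootstrap;

import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.exec.Task;
import org.apache.hadoop.hive.ql.exec.util.DAGTraversal;

import java.io.Serializable;
import java.util.List;

// Hypothetical helper illustrating the effect of updateDatabaseLastReplID() above.
class LastReplIdWiringSketch {
  //
  //   before:  roots = { loadTable1, loadTable2, alterDbLastReplId }   // the update may run first
  //
  //   after:   loadTable1 --+
  //                         +--> dependencyCollectionTask --> alterDbLastReplId
  //            loadTable2 --+
  //
  // so the last.repl.id update runs only once every bootstrap task has finished successfully.
  static void attachAsLeafDependency(HiveConf conf, List<Task<? extends Serializable>> rootTasks,
      List<Task<? extends Serializable>> lastReplIdTasks) {
    DAGTraversal.traverse(rootTasks, new AddDependencyToLeaves(conf, lastReplIdTasks));
  }
}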
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/util/DAGTraversal.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/util/DAGTraversal.java
new file mode 100644
index 0000000000..59a8a447cc
--- /dev/null
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/util/DAGTraversal.java
@@ -0,0 +1,34 @@
+package org.apache.hadoop.hive.ql.exec.util;
+
+import org.apache.hadoop.hive.ql.exec.Task;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+
+public class DAGTraversal {
+  public static void traverse(List<Task<? extends Serializable>> tasks,
+      Function function) {
+    List<Task<? extends Serializable>> listOfTasks = new ArrayList<>(tasks);
+    while (!listOfTasks.isEmpty()) {
+      List<Task<? extends Serializable>> children = new ArrayList<>();
+      for (Task<? extends Serializable> task : listOfTasks) {
+        // check skipProcessing() first; a skipped task is neither processed nor traversed further
+        if (function.skipProcessing(task)) {
+          continue;
+        }
+        if (task.getDependentTasks() != null) {
+          children.addAll(task.getDependentTasks());
+        }
+        function.process(task);
+      }
+      listOfTasks = children;
+    }
+  }
+
+  public interface Function {
+    void process(Task<? extends Serializable> task);
+
+    boolean skipProcessing(Task<? extends Serializable> task);
+  }
+}
diff --git a/ql/src/test/org/apache/hadoop/hive/ql/exec/util/DAGTraversalTest.java b/ql/src/test/org/apache/hadoop/hive/ql/exec/util/DAGTraversalTest.java
new file mode 100644
index 0000000000..6e280afd8a
--- /dev/null
+++ b/ql/src/test/org/apache/hadoop/hive/ql/exec/util/DAGTraversalTest.java
@@ -0,0 +1,64 @@
+package org.apache.hadoop.hive.ql.exec.util;
+
+import org.apache.hadoop.hive.ql.exec.Task;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.runners.MockitoJUnitRunner;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.Collections;
+
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+@RunWith(MockitoJUnitRunner.class)
+public class DAGTraversalTest {
+
+  static class CountLeafFunction implements DAGTraversal.Function {
+    int count = 0;
+
+    @Override
+    public void process(Task<? extends Serializable> task) {
+      if (task.getDependentTasks() == null || task.getDependentTasks().isEmpty()) {
+        count++;
+      }
+    }
+
+    @Override
+    public boolean skipProcessing(Task<? extends Serializable> task) {
+      return false;
+    }
+  }
+
+  @Test
+  public void shouldCountNumberOfLeafNodesCorrectly() {
+    Task<? extends Serializable> taskWith5NodeTree = linearTree(5);
+    Task<? extends Serializable> taskWith1NodeTree = linearTree(1);
+    Task<? extends Serializable> taskWith3NodeTree = linearTree(3);
+    @SuppressWarnings("unchecked") Task<? extends Serializable> rootTask = mock(Task.class);
+    when(rootTask.getDependentTasks())
+        .thenReturn(Arrays.asList(taskWith1NodeTree, taskWith3NodeTree, taskWith5NodeTree));
+
+    CountLeafFunction function = new CountLeafFunction();
+    DAGTraversal.traverse(Collections.singletonList(rootTask), function);
+    assertEquals(3, function.count);
+  }
+
+  private Task<? extends Serializable> linearTree(int numOfNodes) {
+    Task<? extends Serializable> current = null, head = null;
+    for (int i = 0; i < numOfNodes; i++) {
+      @SuppressWarnings("unchecked") Task<? extends Serializable> task = mock(Task.class);
+      if (current != null) {
+        when(current.getDependentTasks()).thenReturn(Collections.singletonList(task));
+      }
+      if (head == null) {
+        head = task;
+      }
+      current = task;
+    }
+    return head;
+  }
+
+}
\ No newline at end of file