diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java
index 5214688d24..6b12eac78f 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java
@@ -80,7 +80,6 @@
 import org.apache.commons.lang3.StringEscapeUtils;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.filecache.DistributedCache;
-import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
 import org.apache.hadoop.fs.ContentSummary;
 import org.apache.hadoop.fs.FSDataInputStream;
 import org.apache.hadoop.fs.FSDataOutputStream;
@@ -119,7 +118,7 @@
 import org.apache.hadoop.hive.ql.exec.spark.SparkTask;
 import org.apache.hadoop.hive.ql.exec.tez.DagUtils;
 import org.apache.hadoop.hive.ql.exec.tez.TezTask;
-import org.apache.hadoop.hive.ql.exec.vector.VectorExpressionDescriptor;
+import org.apache.hadoop.hive.ql.exec.util.EfficientDAGTraversal;
 import org.apache.hadoop.hive.ql.exec.vector.VectorizedInputFormatInterface;
 import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatchCtx;
 import org.apache.hadoop.hive.ql.io.AcidUtils;
@@ -137,7 +136,6 @@
 import org.apache.hadoop.hive.ql.io.SelfDescribingInputFormatInterface;
 import org.apache.hadoop.hive.ql.io.merge.MergeFileMapper;
 import org.apache.hadoop.hive.ql.io.merge.MergeFileWork;
-import org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat;
 import org.apache.hadoop.hive.ql.io.rcfile.truncate.ColumnTruncateMapper;
 import org.apache.hadoop.hive.ql.io.rcfile.truncate.ColumnTruncateWork;
 import org.apache.hadoop.hive.ql.log.PerfLogger;
@@ -147,7 +145,6 @@
 import org.apache.hadoop.hive.ql.metadata.InputEstimator;
 import org.apache.hadoop.hive.ql.metadata.Partition;
 import org.apache.hadoop.hive.ql.metadata.Table;
-import org.apache.hadoop.hive.ql.optimizer.physical.Vectorizer;
 import org.apache.hadoop.hive.ql.parse.SemanticException;
 import org.apache.hadoop.hive.ql.plan.BaseWork;
 import org.apache.hadoop.hive.ql.plan.DynamicPartitionCtx;
@@ -167,7 +164,6 @@
 import org.apache.hadoop.hive.ql.stats.StatsFactory;
 import org.apache.hadoop.hive.ql.stats.StatsPublisher;
 import org.apache.hadoop.hive.serde.serdeConstants;
-import org.apache.hadoop.hive.serde2.ColumnProjectionUtils;
 import org.apache.hadoop.hive.serde2.MetadataTypedColumnsetSerDe;
 import org.apache.hadoop.hive.serde2.SerDeException;
 import org.apache.hadoop.hive.serde2.SerDeUtils;
@@ -2570,41 +2566,49 @@ public static boolean isEmptyPath(Configuration job, Path dirPath) throws IOExce
   }
 
   public static List<TezTask> getTezTasks(List<Task<? extends Serializable>> tasks) {
-    return getTasks(tasks, TezTask.class);
+    return getTasks(tasks, new TaskFilterFunction<>(TezTask.class));
   }
 
   public static List<SparkTask> getSparkTasks(List<Task<? extends Serializable>> tasks) {
-    return getTasks(tasks, SparkTask.class);
+    return getTasks(tasks, new TaskFilterFunction<>(SparkTask.class));
   }
 
   public static List<ExecDriver> getMRTasks(List<Task<? extends Serializable>> tasks) {
-    return getTasks(tasks, ExecDriver.class);
+    return getTasks(tasks, new TaskFilterFunction<>(ExecDriver.class));
   }
 
-  @SuppressWarnings("unchecked")
-  public static <T> List<T> getTasks(List<Task<? extends Serializable>> tasks, Class<T> requiredType) {
-    List<T> typeSpecificTasks = new ArrayList<>();
-    if (tasks != null) {
-      Set<Task<? extends Serializable>> visited = new HashSet<>();
-      while (!tasks.isEmpty()) {
-        List<Task<? extends Serializable>> childTasks = new ArrayList<>();
-        for (Task<? extends Serializable> task : tasks) {
-          if (visited.contains(task)) {
-            continue;
-          }
-          if (requiredType.isInstance(task) && !typeSpecificTasks.contains(task)) {
-            typeSpecificTasks.add((T) task);
-          }
-          if (task.getDependentTasks() != null) {
-            childTasks.addAll(task.getDependentTasks());
-          }
-          visited.add(task);
-        }
-        // start recursion
-        tasks = childTasks;
+  static class TaskFilterFunction<T> implements EfficientDAGTraversal.Function {
+    private Set<Task<? extends Serializable>> visited = new HashSet<>();
+    private Class<T> requiredType;
+    private List<T> typeSpecificTasks = new ArrayList<>();
+
+    TaskFilterFunction(Class<T> requiredType) {
+      this.requiredType = requiredType;
+    }
+
+    @Override
+    public void process(Task<? extends Serializable> task) {
+      if (requiredType.isInstance(task) && !typeSpecificTasks.contains(task)) {
+        typeSpecificTasks.add((T) task);
       }
+      visited.add(task);
+    }
+
+    List<T> getTasks() {
+      return typeSpecificTasks;
     }
-    return typeSpecificTasks;
+
+    @Override
+    public boolean skipProcessing(Task<? extends Serializable> task) {
+      return visited.contains(task);
+    }
+  }
+
+  @SuppressWarnings("unchecked")
+  private static <T> List<T> getTasks(List<Task<? extends Serializable>> tasks,
+      TaskFilterFunction<T> function) {
+    EfficientDAGTraversal.traverse(tasks, function);
+    return function.getTasks();
   }
 
   /**
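Reviewer note: a minimal, self-contained sketch of the pattern the refactor above extracts — a breadth-first walk over a task DAG with a visited set, collecting nodes of one type. `Node`, `MapRedNode`, `collect`, and `TypeFilterDemo` are illustrative stand-ins for Hive's `Task` hierarchy, not part of the patch, and the deque-based walk is a simplification of the level-by-level loop in `EfficientDAGTraversal`:

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class TypeFilterDemo {
  static class Node {
    final String name;
    final List<Node> children = new ArrayList<>();
    Node(String name) { this.name = name; }
  }

  static class MapRedNode extends Node {
    MapRedNode(String name) { super(name); }
  }

  // Breadth-first walk with a visited set: shared children in the DAG are
  // expanded only once, mirroring TaskFilterFunction + skipProcessing above.
  static <T extends Node> List<T> collect(List<Node> roots, Class<T> type) {
    List<T> matches = new ArrayList<>();
    Set<Node> visited = new HashSet<>();
    Deque<Node> frontier = new ArrayDeque<>(roots);
    while (!frontier.isEmpty()) {
      Node node = frontier.poll();
      if (!visited.add(node)) {
        continue; // already seen: neither re-process nor re-expand
      }
      if (type.isInstance(node)) {
        matches.add(type.cast(node));
      }
      frontier.addAll(node.children);
    }
    return matches;
  }

  public static void main(String[] args) {
    Node root = new Node("root");
    Node a = new Node("a");
    Node b = new Node("b");
    Node shared = new MapRedNode("mr-1");
    root.children.add(a);
    root.children.add(b);
    a.children.add(shared); // "shared" is reachable via two branches,
    b.children.add(shared); // but the visited set counts it once
    System.out.println(collect(List.of(root), MapRedNode.class).size()); // prints 1
  }
}
```

The visited set is what makes the replacement safe on DAGs (as opposed to trees): without it, a task reachable through several parents would be expanded once per parent.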
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/DependencyCollectionFunction.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/DependencyCollectionFunction.java
new file mode 100644
index 0000000000..34cc091a4f
--- /dev/null
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/DependencyCollectionFunction.java
@@ -0,0 +1,33 @@
+package org.apache.hadoop.hive.ql.exec.repl.bootstrap;
+
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.ql.exec.DependencyCollectionTask;
+import org.apache.hadoop.hive.ql.exec.Task;
+import org.apache.hadoop.hive.ql.exec.TaskFactory;
+import org.apache.hadoop.hive.ql.exec.util.EfficientDAGTraversal;
+import org.apache.hadoop.hive.ql.plan.DependencyCollectionWork;
+
+import java.io.Serializable;
+import java.util.List;
+
+public class DependencyCollectionFunction implements EfficientDAGTraversal.Function {
+  private Task<? extends Serializable> dependencyCollectionTask;
+
+  DependencyCollectionFunction(HiveConf hiveConf,
+      List<Task<? extends Serializable>> postDependencyCollectionTasks) {
+    this.dependencyCollectionTask = TaskFactory.get(new DependencyCollectionWork(), hiveConf);
+    postDependencyCollectionTasks.forEach(dependencyCollectionTask::addDependentTask);
+  }
+
+  @Override
+  public void process(Task<? extends Serializable> task) {
+    if (task.getChildTasks() == null || task.getChildTasks().isEmpty()) {
+      task.addDependentTask(dependencyCollectionTask);
+    }
+  }
+
+  @Override
+  public boolean skipProcessing(Task<? extends Serializable> task) {
+    return task instanceof DependencyCollectionTask;
+  }
+}
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/ReplLoadTask.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/ReplLoadTask.java
index bf5c819e90..9579fe5667 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/ReplLoadTask.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/ReplLoadTask.java
@@ -39,6 +39,7 @@ Licensed to the Apache Software Foundation (ASF) under one
 import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.table.LoadTable;
 import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.table.TableContext;
 import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.util.Context;
+import org.apache.hadoop.hive.ql.exec.util.EfficientDAGTraversal;
 import org.apache.hadoop.hive.ql.parse.SemanticException;
 import org.apache.hadoop.hive.ql.parse.repl.ReplLogger;
 import org.apache.hadoop.hive.ql.plan.api.StageType;
@@ -251,7 +252,11 @@ private TaskTracker updateDatabaseLastReplID(int maxTasks, Context context, Scop
       throws SemanticException {
     TaskTracker taskTracker = new AlterDatabase(context,
         work.databaseEvent(context.hiveConf), work.dbNameToLoadIn, new TaskTracker(maxTasks)).tasks();
-    scope.rootTasks.addAll(taskTracker.tasks());
+
+    DependencyCollectionFunction function =
+        new DependencyCollectionFunction(context.hiveConf, taskTracker.tasks());
+    EfficientDAGTraversal.traverse(scope.rootTasks, function);
+
     return taskTracker;
   }
 
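Reviewer note: the intent of `DependencyCollectionFunction` is to hang a single dependency-collection (barrier) task under every leaf of the existing DAG, so the post-dependency tasks run only after all current branches finish. A minimal sketch of that leaf-attachment idea under illustrative names (`Node`, `attachToLeaves`, `LeafBarrierDemo`), not Hive's actual `Task` API:

```java
import java.util.ArrayList;
import java.util.List;

public class LeafBarrierDemo {
  static class Node {
    final String name;
    final List<Node> children = new ArrayList<>();
    Node(String name) { this.name = name; }
  }

  // Attach one shared barrier node under every leaf. The identity check plays
  // the role of skipProcessing above: we never descend into the barrier itself,
  // so a leaf reached via several parents gets the barrier attached only once.
  static void attachToLeaves(Node node, Node barrier) {
    if (node == barrier) {
      return;
    }
    if (node.children.isEmpty()) {
      node.children.add(barrier); // leaf: the barrier becomes its only child
      return;
    }
    for (Node child : node.children) {
      attachToLeaves(child, barrier);
    }
  }

  public static void main(String[] args) {
    Node root = new Node("root");
    Node left = new Node("left");
    Node right = new Node("right");
    root.children.add(left);
    root.children.add(right);
    Node barrier = new Node("barrier");
    attachToLeaves(root, barrier);
    System.out.println(left.children.get(0).name + " " + right.children.get(0).name);
    // prints: barrier barrier
  }
}
```

Note the patch guards `getChildTasks()` against null before calling `isEmpty()`; in Hive a `Task` with no children returns null rather than an empty list, as the traversal's own `!= null` check suggests.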
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/util/EfficientDAGTraversal.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/util/EfficientDAGTraversal.java
new file mode 100644
index 0000000000..e55ac9a812
--- /dev/null
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/util/EfficientDAGTraversal.java
@@ -0,0 +1,34 @@
+package org.apache.hadoop.hive.ql.exec.util;
+
+import org.apache.hadoop.hive.ql.exec.Task;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+
+public class EfficientDAGTraversal {
+  public static void traverse(List<Task<? extends Serializable>> tasks,
+      Function function) {
+    List<Task<? extends Serializable>> listOfTasks = new ArrayList<>(tasks);
+    while (!listOfTasks.isEmpty()) {
+      List<Task<? extends Serializable>> children = new ArrayList<>();
+      for (Task<? extends Serializable> task : listOfTasks) {
+        // consult skipProcessing first: a skipped task is neither processed nor expanded
+        if (function.skipProcessing(task)) {
+          continue;
+        }
+        if (task.getChildTasks() != null) {
+          children.addAll(task.getChildTasks());
+        }
+        function.process(task);
+      }
+      listOfTasks = children;
+    }
+  }
+
+  public interface Function {
+    void process(Task<? extends Serializable> task);
+
+    boolean skipProcessing(Task<? extends Serializable> task);
+  }
+}
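Reviewer note: a usage sketch of the new utility, assuming Hive's `Task` API is on the classpath (so not runnable standalone). `TaskCounter` is an illustrative name, not part of the patch; it counts the distinct tasks in a plan DAG. Because `traverse` checks `skipProcessing` before expanding children, returning `true` both skips `process()` and stops descent below that task, which is what keeps shared subtrees from being re-walked:

```java
import org.apache.hadoop.hive.ql.exec.Task;
import org.apache.hadoop.hive.ql.exec.util.EfficientDAGTraversal;

import java.io.Serializable;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class TaskCounter implements EfficientDAGTraversal.Function {
  private final Set<Task<? extends Serializable>> seen = new HashSet<>();

  @Override
  public void process(Task<? extends Serializable> task) {
    seen.add(task);
  }

  @Override
  public boolean skipProcessing(Task<? extends Serializable> task) {
    // tasks already seen are neither counted again nor expanded again
    return seen.contains(task);
  }

  public int count() {
    return seen.size();
  }

  // Example entry point: count all tasks reachable from the given roots.
  public static int countTasks(List<Task<? extends Serializable>> rootTasks) {
    TaskCounter counter = new TaskCounter();
    EfficientDAGTraversal.traverse(rootTasks, counter);
    return counter.count();
  }
}
```

This is the same visitor contract `TaskFilterFunction` and `DependencyCollectionFunction` implement above: traversal order and cycle/sharing protection live in one place, and callers only supply `process` and `skipProcessing`.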