diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HashTableLoader.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HashTableLoader.java
index d30ae51..95dad90 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HashTableLoader.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HashTableLoader.java
@@ -76,7 +76,7 @@ public void load(
     }
     // All HashTables share the same base dir,
     // which is passed in as the tmp path
-    Path baseDir = localWork.getTmpHDFSPath();
+    Path baseDir = localWork.getTmpPath();
     if (baseDir == null) {
       return;
     }
diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/SparkMapJoinResolver.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/SparkMapJoinResolver.java
index 4b9a6cb..91f06b4 100644
--- ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/SparkMapJoinResolver.java
+++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/SparkMapJoinResolver.java
@@ -21,23 +21,29 @@
 import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.HashMap;
+import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.Stack;
 
+import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.hive.conf.HiveConf;
-import org.apache.hadoop.hive.ql.exec.HashTableSinkOperator;
+import org.apache.hadoop.hive.ql.Context;
 import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
 import org.apache.hadoop.hive.ql.exec.Operator;
+import org.apache.hadoop.hive.ql.exec.SparkHashTableSinkOperator;
 import org.apache.hadoop.hive.ql.exec.Task;
 import org.apache.hadoop.hive.ql.exec.TaskFactory;
+import org.apache.hadoop.hive.ql.exec.Utilities;
 import org.apache.hadoop.hive.ql.exec.spark.SparkTask;
 import org.apache.hadoop.hive.ql.lib.Dispatcher;
 import org.apache.hadoop.hive.ql.lib.Node;
 import org.apache.hadoop.hive.ql.lib.TaskGraphWalker;
 import org.apache.hadoop.hive.ql.parse.SemanticException;
 import org.apache.hadoop.hive.ql.plan.BaseWork;
+import org.apache.hadoop.hive.ql.plan.FetchWork;
+import org.apache.hadoop.hive.ql.plan.MapredLocalWork;
 import org.apache.hadoop.hive.ql.plan.OperatorDesc;
 import org.apache.hadoop.hive.ql.plan.SparkWork;
 
@@ -120,7 +126,7 @@ private void moveWork(SparkWork sparkWork, BaseWork work, SparkWork targetWork)
       // update this information in sparkWorkMap
       sparkWorkMap.put(work, parentWork);
       for (BaseWork parent : parentWorks) {
-        if (containsOp(parent, HashTableSinkOperator.class)) {
+        if (containsOp(parent, SparkHashTableSinkOperator.class)) {
           moveWork(sparkWork, parent, parentWork);
         } else {
           moveWork(sparkWork, parent, targetWork);
@@ -129,6 +135,44 @@ private void moveWork(SparkWork sparkWork, BaseWork work, SparkWork targetWork)
     }
   }
 
+  private void generateLocalWork(Task<? extends Serializable> currentTask,
+      SparkWork sparkWork) {
+    for (BaseWork work : sparkWork.getAllWorkUnsorted()) {
+      if (containsOp(work, SparkHashTableSinkOperator.class) ||
+          containsOp(work, MapJoinOperator.class)) {
+        work.setMapRedLocalWork(new MapredLocalWork());
+      }
+    }
+
+    Context ctx = physicalContext.getContext();
+
+    for (BaseWork work : sparkWork.getAllWorkUnsorted()) {
+      if (containsOp(work, MapJoinOperator.class)) {
+        Path tmpPath = Utilities.generateTmpPath(ctx.getMRTmpPath(), currentTask.getId());
+        MapredLocalWork bigTableLocalWork = work.getMapRedLocalWork();
+        List<Operator<? extends OperatorDesc>> dummyOps =
+            new ArrayList<Operator<? extends OperatorDesc>>(work.getDummyOps());
+        bigTableLocalWork.setDummyParentOp(dummyOps);
+
+        for (BaseWork parentWork : sparkWork.getParents(work)) {
+          if (containsOp(parentWork, SparkHashTableSinkOperator.class)) {
+            parentWork.getMapRedLocalWork().setTmpHDFSPath(tmpPath);
+            parentWork.getMapRedLocalWork().setDummyParentOp(
+                new ArrayList<Operator<? extends OperatorDesc>>());
+          }
+        }
+
+        bigTableLocalWork.setAliasToWork(
+            new LinkedHashMap<String, Operator<? extends OperatorDesc>>());
+        bigTableLocalWork.setAliasToFetchWork(new LinkedHashMap<String, FetchWork>());
+        bigTableLocalWork.setTmpPath(tmpPath);
+
+        // TODO: set inputFileChangeSensitive and BucketMapjoinContext,
+        // TODO: enable non-staged mapjoin
+      }
+    }
+  }
+
   // Create a new SparkTask for the specified SparkWork, recursively compute
   // all the parent SparkTasks that this new task is depend on, if they don't already exists.
   private SparkTask createSparkTask(SparkTask originalTask,
@@ -167,6 +211,8 @@ public Object dispatch(Node nd, Stack<Node> stack, Object... nos)
       if (currentTask instanceof SparkTask) {
         SparkWork sparkWork = ((SparkTask) currentTask).getWork();
 
+        generateLocalWork(currentTask, sparkWork);
+
         dependencyGraph.put(sparkWork, new ArrayList<SparkWork>());
         Set<BaseWork> leaves = sparkWork.getLeaves();
         for (BaseWork leaf : leaves) {
diff --git ql/src/java/org/apache/hadoop/hive/ql/plan/MapredLocalWork.java ql/src/java/org/apache/hadoop/hive/ql/plan/MapredLocalWork.java
index 785e4a0..6fbdcd2 100644
--- ql/src/java/org/apache/hadoop/hive/ql/plan/MapredLocalWork.java
+++ ql/src/java/org/apache/hadoop/hive/ql/plan/MapredLocalWork.java
@@ -45,7 +45,7 @@
   private BucketMapJoinContext bucketMapjoinContext;
   private Path tmpPath;
   private String stageID;
-  // Temp HDFS path for Spark HashTable sink and loader
+  // Temp HDFS path for Spark HashTable sink
   private Path tmpHDFSPath;
 
   private List<Operator<? extends OperatorDesc>> dummyParentOp;
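
Illustrative note, not part of the patch: a minimal sketch of the directory convention the resolver sets up, assuming Utilities.generateTmpPath() resolves to a per-task subdirectory under the MR scratch dir; the scratch dir, task id, and hash table file name below are hypothetical. The point is that the small-table works (via setTmpHDFSPath) and the big-table work (via setTmpPath) receive the same tmpPath, which is why HashTableLoader can read the base dir from getTmpPath() after this change.

  import org.apache.hadoop.fs.Path;

  public class TmpPathSketch {
    public static void main(String[] args) {
      // Hypothetical values standing in for ctx.getMRTmpPath() and currentTask.getId().
      Path mrScratchDir = new Path("hdfs:///tmp/hive/scratch/-mr-10000");
      String taskId = "Stage-2";

      // One tmp dir per SparkTask; the resolver hands this same path to both sides.
      Path tmpPath = new Path(mrScratchDir, taskId);

      // Small-table side: SparkHashTableSinkOperator writes its hash table file
      // somewhere under tmpPath (the file name here is made up for illustration).
      Path sinkFile = new Path(tmpPath, "MapJoin-mapfile01--.hashtable");

      // Big-table side: HashTableLoader resolves the same base dir via getTmpPath().
      System.out.println("sink writes:  " + sinkFile);
      System.out.println("loader scans: " + tmpPath);
    }
  }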