diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HashTableLoader.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HashTableLoader.java
index d30ae51..95dad90 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HashTableLoader.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HashTableLoader.java
@@ -76,7 +76,7 @@ public void load(
     }
     // All HashTables share the same base dir,
     // which is passed in as the tmp path
-    Path baseDir = localWork.getTmpHDFSPath();
+    Path baseDir = localWork.getTmpPath();
    if (baseDir == null) {
      return;
    }
diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/SparkMapJoinResolver.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/SparkMapJoinResolver.java
index 9ce1a18..96481f1 100644
--- ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/SparkMapJoinResolver.java
+++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/SparkMapJoinResolver.java
@@ -20,24 +20,31 @@
 
 import java.io.Serializable;
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.HashMap;
+import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.Stack;
 
+import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.hive.conf.HiveConf;
-import org.apache.hadoop.hive.ql.exec.HashTableSinkOperator;
+import org.apache.hadoop.hive.ql.Context;
 import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
 import org.apache.hadoop.hive.ql.exec.Operator;
+import org.apache.hadoop.hive.ql.exec.SparkHashTableSinkOperator;
 import org.apache.hadoop.hive.ql.exec.Task;
 import org.apache.hadoop.hive.ql.exec.TaskFactory;
+import org.apache.hadoop.hive.ql.exec.Utilities;
 import org.apache.hadoop.hive.ql.exec.spark.SparkTask;
 import org.apache.hadoop.hive.ql.lib.Dispatcher;
 import org.apache.hadoop.hive.ql.lib.Node;
 import org.apache.hadoop.hive.ql.lib.TaskGraphWalker;
 import org.apache.hadoop.hive.ql.parse.SemanticException;
 import org.apache.hadoop.hive.ql.plan.BaseWork;
+import org.apache.hadoop.hive.ql.plan.FetchWork;
+import org.apache.hadoop.hive.ql.plan.MapredLocalWork;
 import org.apache.hadoop.hive.ql.plan.OperatorDesc;
 import org.apache.hadoop.hive.ql.plan.SparkWork;
@@ -120,7 +127,7 @@ private void moveWork(SparkWork sparkWork, BaseWork work, SparkWork targetWork)
         // update this information in sparkWorkMap
         sparkWorkMap.put(work, parentWork);
         for (BaseWork parent : parentWorks) {
-          if (containsOp(parent, HashTableSinkOperator.class)) {
+          if (containsOp(parent, SparkHashTableSinkOperator.class)) {
            moveWork(sparkWork, parent, parentWork);
          } else {
            moveWork(sparkWork, parent, targetWork);
@@ -129,6 +136,46 @@ private void moveWork(SparkWork sparkWork, BaseWork work, SparkWork targetWork)
       }
     }
 
+    private void generateLocalWork(SparkTask originalTask) {
+      SparkWork originalWork = originalTask.getWork();
+      Collection<BaseWork> allBaseWorks = originalWork.getAllWorkUnsorted();
+
+      for (BaseWork work : allBaseWorks) {
+        if (containsOp(work, SparkHashTableSinkOperator.class) ||
+            containsOp(work, MapJoinOperator.class)) {
+          work.setMapRedLocalWork(new MapredLocalWork());
+        }
+      }
+
+      Context ctx = physicalContext.getContext();
+
+      for (BaseWork work : allBaseWorks) {
+        if (containsOp(work, MapJoinOperator.class)) {
+          Path tmpPath = Utilities.generateTmpPath(ctx.getMRTmpPath(), originalTask.getId());
+          MapredLocalWork bigTableLocalWork = work.getMapRedLocalWork();
+          List<Operator<? extends OperatorDesc>> dummyOps =
+              new ArrayList<Operator<? extends OperatorDesc>>(work.getDummyOps());
+          bigTableLocalWork.setDummyParentOp(dummyOps);
+
+          for (BaseWork parentWork : originalWork.getParents(work)) {
+            if (containsOp(parentWork, SparkHashTableSinkOperator.class)) {
+              parentWork.getMapRedLocalWork().setTmpHDFSPath(tmpPath);
+              parentWork.getMapRedLocalWork().setDummyParentOp(
+                  new ArrayList<Operator<? extends OperatorDesc>>());
+            }
+          }
+
+          bigTableLocalWork.setAliasToWork(
+              new LinkedHashMap<String, Operator<? extends OperatorDesc>>());
+          bigTableLocalWork.setAliasToFetchWork(new LinkedHashMap<String, FetchWork>());
+          bigTableLocalWork.setTmpPath(tmpPath);
+
+          // TODO: set inputFileChangeSensitive and BucketMapjoinContext,
+          // TODO: enable non-staged mapjoin
+        }
+      }
+    }
+
     // Create a new SparkTask for the specified SparkWork, recursively compute
     // all the parent SparkTasks that this new task is depend on, if they don't already exists.
     private SparkTask createSparkTask(SparkTask originalTask,
@@ -167,7 +214,11 @@ public Object dispatch(Node nd, Stack<Node> stack, Object... nos)
         throws SemanticException {
       Task<? extends Serializable> currentTask = (Task<? extends Serializable>) nd;
       if (currentTask instanceof SparkTask) {
-        SparkWork sparkWork = ((SparkTask) currentTask).getWork();
+        SparkTask sparkTask = (SparkTask) currentTask;
+        SparkWork sparkWork = sparkTask.getWork();
+
+        // Generate MapredLocalWorks for MJ and HTS
+        generateLocalWork(sparkTask);
 
         dependencyGraph.put(sparkWork, new ArrayList<SparkWork>());
         Set<BaseWork> leaves = sparkWork.getLeaves();
@@ -187,7 +238,7 @@ public Object dispatch(Node nd, Stack<Node> stack, Object... nos)
 
         // Now create SparkTasks from the SparkWorks, also set up dependency
         for (SparkWork work : dependencyGraph.keySet()) {
-          createSparkTask((SparkTask)currentTask, work, createdTaskMap);
+          createSparkTask(sparkTask, work, createdTaskMap);
         }
       }
 
diff --git ql/src/java/org/apache/hadoop/hive/ql/plan/MapredLocalWork.java ql/src/java/org/apache/hadoop/hive/ql/plan/MapredLocalWork.java
index 785e4a0..6fbdcd2 100644
--- ql/src/java/org/apache/hadoop/hive/ql/plan/MapredLocalWork.java
+++ ql/src/java/org/apache/hadoop/hive/ql/plan/MapredLocalWork.java
@@ -45,7 +45,7 @@
   private BucketMapJoinContext bucketMapjoinContext;
   private Path tmpPath;
   private String stageID;
 
-  // Temp HDFS path for Spark HashTable sink and loader
+  // Temp HDFS path for Spark HashTable sink
   private Path tmpHDFSPath;
   private List<Operator<? extends OperatorDesc>> dummyParentOp;