diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/SparkMapJoinResolver.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/SparkMapJoinResolver.java
index a8b7ac6..d47163b 100644
--- ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/SparkMapJoinResolver.java
+++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/SparkMapJoinResolver.java
@@ -21,23 +21,31 @@
 import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.HashMap;
+import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.Stack;
 
+import com.google.common.base.Preconditions;
+import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.ql.Context;
 import org.apache.hadoop.hive.ql.exec.HashTableSinkOperator;
 import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
 import org.apache.hadoop.hive.ql.exec.Operator;
 import org.apache.hadoop.hive.ql.exec.Task;
 import org.apache.hadoop.hive.ql.exec.TaskFactory;
+import org.apache.hadoop.hive.ql.exec.Utilities;
 import org.apache.hadoop.hive.ql.exec.spark.SparkTask;
 import org.apache.hadoop.hive.ql.lib.Dispatcher;
 import org.apache.hadoop.hive.ql.lib.Node;
 import org.apache.hadoop.hive.ql.lib.TaskGraphWalker;
 import org.apache.hadoop.hive.ql.parse.SemanticException;
 import org.apache.hadoop.hive.ql.plan.BaseWork;
+import org.apache.hadoop.hive.ql.plan.FetchWork;
+import org.apache.hadoop.hive.ql.plan.MapWork;
+import org.apache.hadoop.hive.ql.plan.MapredLocalWork;
 import org.apache.hadoop.hive.ql.plan.OperatorDesc;
 import org.apache.hadoop.hive.ql.plan.SparkWork;
@@ -130,6 +138,67 @@ private void moveWork(SparkWork sparkWork, BaseWork work, SparkWork targetWork)
     }
   }
 
+  private List<MapWork> findAllMapWorks(SparkWork sparkWork, BaseWork work) {
+    List<MapWork> result = new ArrayList<MapWork>();
+    if (work instanceof MapWork) {
+      result.add((MapWork)work);
+    } else {
+      for (BaseWork parentWork : sparkWork.getParents(work)) {
+        result.addAll(findAllMapWorks(sparkWork, parentWork));
+      }
+    }
+    return result;
+  }
+
+  private void generateLocalWork(Task<? extends Serializable> currentTask,
+      SparkWork sparkWork) {
+    for (BaseWork work : sparkWork.getAllWorkUnsorted()) {
+      if (containsOp(work, MapJoinOperator.class)) {
+        // If the MJ operator is in a ReduceWork, we need to go up to find
+        // the MapWork with the TS for the big table
+        Context ctx = physicalContext.getContext();
+        BaseWork bigTableWork = work;
+        MapredLocalWork localWork = new MapredLocalWork();
+        LinkedHashMap<String, Operator<? extends OperatorDesc>> aliasToWorkMap
+            = new LinkedHashMap<String, Operator<? extends OperatorDesc>>();
+
+        List<Operator<? extends OperatorDesc>> dummyOps =
+            new ArrayList<Operator<? extends OperatorDesc>>(bigTableWork.getDummyOps());
+        localWork.setDummyParentOp(dummyOps);
+
+        for (BaseWork parentWork : sparkWork.getParents(work)) {
+          if (containsOp(parentWork, HashTableSinkOperator.class)) {
+            for (MapWork mw : findAllMapWorks(sparkWork, parentWork)) {
+              aliasToWorkMap.putAll(mw.getAliasToWork());
+            }
+          } else {
+            // The MJ operator is in a ReduceWork
+            List<MapWork> parentWorks = findAllMapWorks(sparkWork, parentWork);
+            Preconditions.checkArgument(parentWorks.size() == 1,
+                "AssertionError: should only contain one MapWork");
+            bigTableWork = parentWorks.get(0);
+          }
+        }
+        localWork.setAliasToWork(aliasToWorkMap);
+        // TODO: enable non-staged map join optimization
+        localWork.setAliasToFetchWork(new LinkedHashMap<String, FetchWork>());
+
+        Preconditions.checkArgument(bigTableWork instanceof MapWork,
+            "AssertionError: BaseWork with a big table should be a MapWork");
+
+        // Set up the shared tmp URI
+        Path tmpPath = Utilities.generateTmpPath(ctx.getMRTmpPath(),
+            currentTask.getId());
+        localWork.setTmpPath(tmpPath);
+        ((MapWork)bigTableWork).setTmpHDFSPath(
+            Utilities.generateTmpPath(ctx.getMRTmpPath(), currentTask.getId()));
+
+        // TODO: set inputFileChangeSensitive and BucketMapjoinContext
+
+        bigTableWork.setMapRedLocalWork(localWork);
+      }
+    }
+  }
+
   // Create a new SparkTask for the specified SparkWork, recursively compute
   // all the parent SparkTasks that this new task is depend on, if they don't already exists.
   private SparkTask createSparkTask(Task<? extends Serializable> originalTask,
@@ -164,6 +233,10 @@ public Object dispatch(Node nd, Stack<Node> stack, Object... nos)
       Task<? extends Serializable> currentTask = (Task<? extends Serializable>) nd;
       if (currentTask instanceof SparkTask) {
         SparkWork sparkWork = ((SparkTask) currentTask).getWork();
+
+        // First, create MapredLocalWork and attach to each MapWork for big table
+        generateLocalWork(currentTask, sparkWork);
+
         Set<BaseWork> leaves = sparkWork.getLeaves();
         for (BaseWork leaf : leaves) {
           moveWork(sparkWork, leaf, sparkWork);
diff --git ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkCompiler.java ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkCompiler.java
index 11e711e..0e6d67a 100644
--- ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkCompiler.java
+++ ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkCompiler.java
@@ -34,6 +34,8 @@
 import org.apache.hadoop.hive.ql.Context;
 import org.apache.hadoop.hive.ql.exec.ConditionalTask;
 import org.apache.hadoop.hive.ql.exec.FileSinkOperator;
+import org.apache.hadoop.hive.ql.exec.JoinOperator;
+import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
 import org.apache.hadoop.hive.ql.exec.Operator;
 import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
 import org.apache.hadoop.hive.ql.exec.SMBMapJoinOperator;
@@ -58,9 +60,12 @@
 import org.apache.hadoop.hive.ql.optimizer.physical.CrossProductCheck;
 import org.apache.hadoop.hive.ql.optimizer.physical.NullScanOptimizer;
 import org.apache.hadoop.hive.ql.optimizer.physical.PhysicalContext;
+import org.apache.hadoop.hive.ql.optimizer.physical.SparkMapJoinResolver;
 import org.apache.hadoop.hive.ql.optimizer.physical.StageIDsRearranger;
 import org.apache.hadoop.hive.ql.optimizer.physical.Vectorizer;
 import org.apache.hadoop.hive.ql.optimizer.spark.SetSparkReducerParallelism;
+import org.apache.hadoop.hive.ql.optimizer.spark.SparkMapJoinOptimizer;
+import org.apache.hadoop.hive.ql.optimizer.spark.SparkReduceSinkMapJoinProc;
 import org.apache.hadoop.hive.ql.optimizer.spark.SparkSortMergeJoinFactory;
 import org.apache.hadoop.hive.ql.parse.GlobalLimitCtx;
 import org.apache.hadoop.hive.ql.parse.ParseContext;
@@ -113,8 +118,8 @@ protected void optimizeOperatorPlan(ParseContext pCtx, Set<ReadEntity> inputs,
         new SetSparkReducerParallelism());
 
     // TODO: need to research and verify support convert join to map join optimization.
-    //opRules.put(new RuleRegExp(new String("Convert Join to Map-join"),
-    //  JoinOperator.getOperatorName() + "%"), new SparkMapJoinOptimizer());
+    opRules.put(new RuleRegExp(new String("Convert Join to Map-join"),
+        JoinOperator.getOperatorName() + "%"), new SparkMapJoinOptimizer());
 
     // The dispatcher fires the processor corresponding to the closest matching
     // rule and passes the context along
@@ -146,8 +151,8 @@ protected void generateTaskTree(List<Task<? extends Serializable>> rootTasks, Pa
     opRules.put(new RuleRegExp("Split Work - ReduceSink",
         ReduceSinkOperator.getOperatorName() + "%"), genSparkWork);
 
-    //opRules.put(new RuleRegExp("No more walking on ReduceSink-MapJoin",
-    //  MapJoinOperator.getOperatorName() + "%"), new SparkReduceSinkMapJoinProc());
+    opRules.put(new RuleRegExp("No more walking on ReduceSink-MapJoin",
+        MapJoinOperator.getOperatorName() + "%"), new SparkReduceSinkMapJoinProc());
 
     opRules.put(new RuleRegExp("Split Work + Move/Merge - FileSink",
         FileSinkOperator.getOperatorName() + "%"),
@@ -262,6 +267,8 @@ protected void optimizeTaskPlan(List<Task<? extends Serializable>> rootTasks, Pa
     PhysicalContext physicalCtx = new PhysicalContext(conf, pCtx, pCtx.getContext(),
         rootTasks, pCtx.getFetchTask());
 
+    physicalCtx = new SparkMapJoinResolver().resolve(physicalCtx);
+
     if (conf.getBoolVar(HiveConf.ConfVars.HIVENULLSCANOPTIMIZE)) {
       physicalCtx = new NullScanOptimizer().resolve(physicalCtx);
    } else {
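Note: the new generateLocalWork() relies on a containsOp(BaseWork, Class) helper that already
exists in SparkMapJoinResolver and is not shown in these hunks. As a rough sketch of the contract
that helper is assumed to satisfy (walk a work unit's operator DAG and report whether any operator
is an instance of the given class), something like the following would do. It assumes
BaseWork#getAllRootOperators() and Operator#getChildOperators(); it is illustrative only, not the
actual helper from the file.

import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashSet;
import java.util.Set;

import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.plan.BaseWork;

// Illustrative sketch only: breadth-first scan over a BaseWork's operator DAG.
final class WorkOperatorScan {
  private WorkOperatorScan() {
  }

  // Returns true if any operator reachable from the work's root operators
  // is an instance of the given class.
  static boolean containsOp(BaseWork work, Class<?> clazz) {
    Deque<Operator<?>> toVisit = new ArrayDeque<Operator<?>>(work.getAllRootOperators());
    Set<Operator<?>> seen = new HashSet<Operator<?>>();
    while (!toVisit.isEmpty()) {
      Operator<?> op = toVisit.poll();
      if (!seen.add(op)) {
        continue; // operator DAGs can share children; visit each operator once
      }
      if (clazz.isInstance(op)) {
        return true;
      }
      if (op.getChildOperators() != null) {
        toVisit.addAll(op.getChildOperators());
      }
    }
    return false;
  }
}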