diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/SparkMapJoinResolver.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/SparkMapJoinResolver.java
new file mode 100644
index 0000000..795553a
--- /dev/null
+++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/SparkMapJoinResolver.java
@@ -0,0 +1,262 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.optimizer.physical;
+
+import com.clearspring.analytics.util.Preconditions;
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.ql.exec.HashTableSinkOperator;
+import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
+import org.apache.hadoop.hive.ql.exec.Operator;
+import org.apache.hadoop.hive.ql.exec.TableScanOperator;
+import org.apache.hadoop.hive.ql.exec.Task;
+import org.apache.hadoop.hive.ql.exec.TaskFactory;
+import org.apache.hadoop.hive.ql.exec.spark.SparkTask;
+import org.apache.hadoop.hive.ql.lib.Dispatcher;
+import org.apache.hadoop.hive.ql.lib.Node;
+import org.apache.hadoop.hive.ql.lib.TaskGraphWalker;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.plan.BaseWork;
+import org.apache.hadoop.hive.ql.plan.OperatorDesc;
+import org.apache.hadoop.hive.ql.plan.SparkEdgeProperty;
+import org.apache.hadoop.hive.ql.plan.SparkWork;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Set;
+import java.util.Stack;
+
+public class SparkMapJoinResolver implements PhysicalPlanResolver {
+
+  @Override
+  public PhysicalContext resolve(PhysicalContext pctx) throws SemanticException {
+
+    Dispatcher dispatcher = new SparkMapJoinTaskDispatcher(pctx);
+    TaskGraphWalker graphWalker = new TaskGraphWalker(dispatcher);
+
+    ArrayList<Node> topNodes = new ArrayList<Node>();
+    topNodes.addAll(pctx.getRootTasks());
+    graphWalker.startWalking(topNodes, null);
+    return pctx;
+  }
+
+  class SparkMapJoinTaskDispatcher implements Dispatcher {
+
+    private final PhysicalContext physicalContext;
+
+    public SparkMapJoinTaskDispatcher(PhysicalContext pc) {
+      super();
+      physicalContext = pc;
+    }
+
+    // Returns true if the given BaseWork contains an operator of the given class.
+    private boolean containsOp(BaseWork work, Class<?> clazz) {
+      for (Operator<? extends OperatorDesc> op : work.getAllOperators()) {
+        if (clazz.isInstance(op)) {
+          return true;
+        }
+      }
+      return false;
+    }
+
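+    // Editor's note -- abbreviations in the diagrams below, inferred from the
+    // operator classes this resolver handles: MW = MapWork, RW = ReduceWork,
+    // MJ = MapJoinOperator, HTS = HashTableSinkOperator, FS = FileSinkOperator.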
+    // Iterate over the path, adding each BaseWork from the sourceWork to the targetWork.
+    // Stop if the currentWork is the last one on the path, OR if the currentWork still
+    // has a parent in the sourceWork.
+    private void addPathToWork(SparkWork sourceWork, SparkWork targetWork, List<BaseWork> path) {
+      SparkEdgeProperty edge = null;
+
+      // The last work on the path is the work with the MJ, so skip it.
+      for (int i = 0; i < path.size() - 1; i++) {
+        BaseWork current = path.get(i);
+        Preconditions.checkArgument(sourceWork.contains(current),
+            "AssertionError: a BaseWork on the path should be in the sourceWork");
+        // In a case like this:
+        //   MW  MW   <- path start
+        //    \  /
+        //     RW
+        //     |
+        //     MW     <- work with MJ
+        // we should NOT remove RW yet, since another path still goes through it.
+        // It will be removed when we traverse that other path.
+        if (sourceWork.getParents(current).size() > 0) {
+          return;
+        }
+        targetWork.add(current);
+        if (edge != null) {
+          targetWork.connect(path.get(i - 1), current, edge);
+        }
+        edge = sourceWork.getEdgeProperty(current, path.get(i + 1));
+        sourceWork.remove(current);
+      }
+    }
+
+    // Split the incoming SparkWork into two: one with all the "surface" works (branches)
+    // ending with HTS. These works are connected to other works through MJ. As a result,
+    // they are removed from the input SparkWork and added to a new SparkWork. The edges
+    // are preserved.
+    // If no BaseWork in the input SparkWork contains a MJ operator, this method
+    // returns null.
+    private SparkWork split(SparkWork sparkWork, Set<BaseWork> MJsToStop) {
+      SparkWork newSparkWork = null;
+      Stack<LinkedList<BaseWork>> pathStack = new Stack<LinkedList<BaseWork>>();
+      Set<BaseWork> MJsViaHTS = new HashSet<BaseWork>();
+
+      // We need to handle this case:
+      //   MW  HTS
+      //   |  /
+      //   MJ
+      //   |
+      //   HTS
+      //   |
+      //   MJ
+      //   |
+      //   FS
+      // which should be distinguished from this case:
+      //   MW
+      //   |
+      //   MJ
+      //   |
+      //   HTS
+      //   |
+      //   MJ
+      //   |
+      //   FS
+      // In the first case, we should stop at the first MJ, while in the second
+      // case, we should go down the tree until we hit the second MJ.
+      // This is resolved by using a "MJsToStop" set, which contains all the MJ
+      // works at which we should stop going down. Whenever we reach a MJ from a
+      // HTS, we remove the MJ from this set, so that it can be used to reach
+      // more HTSs.
+
+      // Another case we should handle:
+      //   MW  HTS
+      //   |  /
+      //   MJ  HTS
+      //   |  /
+      //   MJ
+      //   |
+      //   FS
+      // These two MJs can be put in the same SparkWork.
+
+      // And one more:
+      //   HTS
+      //   |
+      //   MJ
+      //   |
+      //   HTS  HTS
+      //   |   /
+      //   MJ
+      //   |
+      //   HTS
+      //   |
+      //   MJ
+      //   |
+      //   FS
+      // In this case, in the first round we should stop at the first MJ.
+      for (BaseWork root : sparkWork.getRoots()) {
+        Preconditions.checkArgument(containsOp(root, TableScanOperator.class),
+            "AssertionError: a root BaseWork should contain a TableScanOperator");
+        pathStack.clear();
+        LinkedList<BaseWork> p = new LinkedList<BaseWork>();
+        p.add(root);
+        pathStack.push(p);
+
+        while (!pathStack.isEmpty()) {
+          LinkedList<BaseWork> currentPath = pathStack.pop();
+          BaseWork currentWork = currentPath.getLast();
+          if (containsOp(currentWork, MapJoinOperator.class)) {
+            if (currentPath.size() > 1 &&
+                containsOp(currentPath.get(currentPath.size() - 2), HashTableSinkOperator.class)) {
+              if (newSparkWork == null) {
+                newSparkWork = new SparkWork(
+                    physicalContext.conf.getVar(HiveConf.ConfVars.HIVEQUERYID));
+              }
+              addPathToWork(sparkWork, newSparkWork, currentPath);
+              MJsViaHTS.add(currentWork);
+              continue;
+            }
+
+            // Don't go below this MJ if it still has to be reached through a
+            // HTS in this round (i.e., it is still in MJsToStop).
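+            // e.g., in the first case above, the path from MW reaches the first
+            // MJ without passing a HTS and stops here; once that MJ has been
+            // reached via its HTS branch, it is removed from MJsToStop, so a
+            // later round can descend past it.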
+            if (MJsToStop.contains(currentWork)) {
+              continue;
+            }
+          }
+
+          for (BaseWork childWork : sparkWork.getChildren(currentWork)) {
+            LinkedList<BaseWork> next = new LinkedList<BaseWork>(currentPath);
+            next.add(childWork);
+            pathStack.push(next);
+          }
+        }
+      }
+
+      MJsToStop.removeAll(MJsViaHTS);
+      return newSparkWork;
+    }
+
+    @Override
+    public Object dispatch(Node nd, Stack<Node> stack, Object... nos)
+        throws SemanticException {
+      Task<? extends Serializable> currentTask = (Task<? extends Serializable>) nd;
+      if (currentTask instanceof SparkTask) { // FIXME: add more conditions?
+        List<SparkWork> resultWorks = new LinkedList<SparkWork>();
+        SparkWork sparkWork = ((SparkTask) currentTask).getWork();
+
+        // Find all BaseWorks containing a MJ operator.
+        Set<BaseWork> MJsToStop = new HashSet<BaseWork>();
+        for (BaseWork bw : sparkWork.getAllWork()) {
+          if (containsOp(bw, MapJoinOperator.class)) {
+            MJsToStop.add(bw);
+          }
+        }
+
+        SparkWork newSparkWork = split(sparkWork, MJsToStop);
+        // Keep iterating until we end up with a SparkWork that contains no HTS.
+        while (newSparkWork != null) {
+          resultWorks.add(newSparkWork);
+          newSparkWork = split(sparkWork, MJsToStop);
+        }
+        resultWorks.add(sparkWork);
+
+        // Now convert the SparkWorks into SparkTasks, chained by dependencies:
+        // the first new task takes currentTask's place in the task graph.
+        SparkTask prevTask = null;
+        for (SparkWork work : resultWorks) {
+          SparkTask task = (SparkTask) TaskFactory.get(work, physicalContext.conf);
+          if (prevTask == null) {
+            List<Task<? extends Serializable>> parentTasks = currentTask.getParentTasks();
+            if (parentTasks != null && parentTasks.size() > 0) {
+              for (Task<? extends Serializable> parentTask : parentTasks) {
+                parentTask.addDependentTask(task);
+                parentTask.removeDependentTask(currentTask);
+              }
+            } else {
+              physicalContext.addToRootTask(task);
+              physicalContext.removeFromRootTask(currentTask);
+            }
+          } else {
+            prevTask.addDependentTask(task);
+          }
+          prevTask = task;
+        }
+      }
+
+      return null;
+    }
+  }
+}
\ No newline at end of file
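
For reference, a resolver like this runs as one pass of Hive's physical optimization over the generated task tree. A minimal driver sketch follows; the PhysicalContext construction shown is an assumption for illustration (a populated conf, parseContext, context, rootTasks, and fetchTask are taken as given) and is not part of this patch:

    // Hypothetical driver sketch -- not part of this patch.
    PhysicalContext physicalContext = new PhysicalContext(
        conf, parseContext, context, rootTasks, fetchTask);
    new SparkMapJoinResolver().resolve(physicalContext);
    // After resolve() returns, each SparkTask whose SparkWork contained map
    // joins has been replaced by a chain of smaller SparkTasks, where the
    // upstream tasks build the hash tables (HTS branches) consumed by the
    // downstream map-join tasks.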