diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/SparkMapJoinResolver.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/SparkMapJoinResolver.java
new file mode 100644
index 0000000..d59290f
--- /dev/null
+++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/SparkMapJoinResolver.java
@@ -0,0 +1,264 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.optimizer.physical;
+
+import com.clearspring.analytics.util.Preconditions;
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.ql.exec.HashTableSinkOperator;
+import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
+import org.apache.hadoop.hive.ql.exec.Operator;
+import org.apache.hadoop.hive.ql.exec.TableScanOperator;
+import org.apache.hadoop.hive.ql.exec.Task;
+import org.apache.hadoop.hive.ql.exec.TaskFactory;
+import org.apache.hadoop.hive.ql.exec.spark.SparkTask;
+import org.apache.hadoop.hive.ql.lib.Dispatcher;
+import org.apache.hadoop.hive.ql.lib.Node;
+import org.apache.hadoop.hive.ql.lib.TaskGraphWalker;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.plan.BaseWork;
+import org.apache.hadoop.hive.ql.plan.OperatorDesc;
+import org.apache.hadoop.hive.ql.plan.SparkEdgeProperty;
+import org.apache.hadoop.hive.ql.plan.SparkWork;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.Stack;
+
+public class SparkMapJoinResolver implements PhysicalPlanResolver {
+
+  @Override
+  public PhysicalContext resolve(PhysicalContext pctx) throws SemanticException {
+
+    Dispatcher dispatcher = new SparkMapJoinTaskDispatcher(pctx);
+    TaskGraphWalker graphWalker = new TaskGraphWalker(dispatcher);
+
+    ArrayList<Node> topNodes = new ArrayList<Node>();
+    topNodes.addAll(pctx.getRootTasks());
+    graphWalker.startWalking(topNodes, null);
+    return pctx;
+  }
+
+  class SparkMapJoinTaskDispatcher implements Dispatcher {
+
+    private final PhysicalContext physicalContext;
+
+    public SparkMapJoinTaskDispatcher(PhysicalContext pc) {
+      super();
+      physicalContext = pc;
+    }
+
+    private boolean containsOp(BaseWork work, Class<?> clazz) {
+      for (Operator<? extends OperatorDesc> op : work.getAllOperators()) {
+        if (clazz.isInstance(op)) {
+          return true;
+        }
+      }
+      return false;
+    }
+
+    // Iterate over the path, adding each BaseWork from the sourceWork to the targetWork.
+    // Stop if the currentWork is the last one on the path, OR if the currentWork has any parent.
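+    // In the diagrams in the comments below, MW and RW denote a map work and a reduce work,
+    // while MJ, HTS, and FS denote works containing a MapJoinOperator, a HashTableSinkOperator,
+    // and a FileSink respectively.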
+    private void addPathToWork(SparkWork sourceWork, SparkWork targetWork, List<BaseWork> path) {
+      SparkEdgeProperty edge = null;
+
+      // the last work on the path is the work with MJ
+      for (int i = 0; i < path.size() - 1; i++) {
+        BaseWork current = path.get(i);
+        Preconditions.checkArgument(sourceWork.contains(current),
+            "AssertionError: a BaseWork on the path should be in the sourceWork");
+        // In a case like this:
+        //   MW  MW   <- path start
+        //    \  /
+        //     RW
+        //     |
+        //     MW     <- work with MJ
+        // we should NOT remove RW yet, since there's another path that will go through it.
+        // It will be removed when we traverse the other path.
+        if (sourceWork.getParents(current).size() > 0) {
+          return;
+        }
+        targetWork.add(current);
+        if (edge != null) {
+          targetWork.connect(path.get(i - 1), current, edge);
+        }
+        edge = sourceWork.getEdgeProperty(current, path.get(i + 1));
+        sourceWork.remove(current);
+      }
+    }
+
+    // Split the incoming SparkWork into two: one with all the "surface" works (branches)
+    // that end with HTS. These works are connected to some other works with MJ. As a result,
+    // they are removed from the input SparkWork and added to a new SparkWork. The edges
+    // are preserved.
+    // If there is no BaseWork in the input SparkWork that contains a MJ operator,
+    // this method returns null.
+    private SparkWork split(SparkWork sparkWork, Map<BaseWork, Set<BaseWork>> MJtoHTSMap) {
+      SparkWork newSparkWork = null;
+      Stack<LinkedList<BaseWork>> pathStack = new Stack<LinkedList<BaseWork>>();
+
+      // We need to handle this case:
+      //   MW  HTS
+      //   |   /
+      //   MJ
+      //   |
+      //   HTS
+      //   |
+      //   MJ
+      //   |
+      //   FS
+      // which should be differentiated from this case:
+      //   MW
+      //   |
+      //   MJ
+      //   |
+      //   HTS
+      //   |
+      //   MJ
+      //   |
+      //   FS
+      // In the first case, we should stop at the first MJ, while in the second
+      // case, we should go down the tree until we hit the second MJ.
+      // This is resolved by using a MJ->HTS map, which maps a work with MJ to
+      // all parent works with HTS. In general, we shouldn't go past a MJ until
+      // all the HTSs associated with it have been seen.
+
+      // Another case we should handle:
+      //   MW  HTS
+      //   |   /
+      //   MJ  HTS
+      //   |   /
+      //   MJ
+      //   |
+      //   FS
+      // These two MJs can be put in the same SparkWork.
+
+      for (BaseWork root : sparkWork.getRoots()) {
+        Preconditions.checkArgument(containsOp(root, TableScanOperator.class),
+            "AssertionError: a root BaseWork should contain a TableScanOperator");
+        pathStack.clear();
+        LinkedList<BaseWork> p = new LinkedList<BaseWork>();
+        p.add(root);
+        pathStack.push(p);
+
+        while (!pathStack.isEmpty()) {
+          LinkedList<BaseWork> currentPath = pathStack.pop();
+          BaseWork currentWork = currentPath.getLast();
+          if (containsOp(currentWork, MapJoinOperator.class)) {
+            Set<BaseWork> parentWorksWithHTS = MJtoHTSMap.get(currentWork);
+            if (currentPath.size() > 1) {
+              BaseWork parentWork = currentPath.get(currentPath.size() - 2);
+              if (containsOp(parentWork, HashTableSinkOperator.class)) {
+                Preconditions.checkArgument(parentWorksWithHTS.contains(parentWork),
+                    "AssertionError: when we are in a HTS-MJ case, the HTS should also " +
+                    "be present in the MJtoHTSMap");
+                if (newSparkWork == null) {
+                  newSparkWork = new SparkWork(
+                      physicalContext.conf.getVar(HiveConf.ConfVars.HIVEQUERYID));
+                }
+                addPathToWork(sparkWork, newSparkWork, currentPath);
+                parentWorksWithHTS.remove(parentWork);
+                if (parentWorksWithHTS.size() == 0) {
+                  MJtoHTSMap.remove(currentWork);
+                }
+
+                // If we have seen a MJ-HTS case, don't go further
+                continue;
+              }
+            }
+
+            // Also don't go further if we haven't processed
+            // all HTSs associated with the MJ.
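+            // The map entry for this MJ work is removed only after all of its HTS parents
+            // have been handled in the HTS-MJ branch above, so a remaining entry means some
+            // small-table branch still has to be traversed first.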
+            if (MJtoHTSMap.containsKey(currentWork)) {
+              continue;
+            }
+          }
+
+          for (BaseWork childWork : sparkWork.getChildren(currentWork)) {
+            LinkedList<BaseWork> next = new LinkedList<BaseWork>(currentPath);
+            next.add(childWork);
+            pathStack.push(next);
+          }
+        }
+      }
+
+      return newSparkWork;
+    }
+
+    @Override
+    public Object dispatch(Node nd, Stack<Node> stack, Object... nos)
+        throws SemanticException {
+      Task<? extends Serializable> currentTask = (Task<? extends Serializable>) nd;
+      if (currentTask instanceof SparkTask) { // FIXME: add more conditions?
+        List<SparkWork> resultWorks = new LinkedList<SparkWork>();
+        SparkWork sparkWork = ((SparkTask) currentTask).getWork();
+
+        // For each BaseWork with a MJ operator, find all the parent BaseWorks with HTS,
+        // and add them to the map.
+        Map<BaseWork, Set<BaseWork>> MJtoHTSMap = new HashMap<BaseWork, Set<BaseWork>>();
+        for (BaseWork bw : sparkWork.getAllWork()) {
+          if (containsOp(bw, MapJoinOperator.class)) {
+            MJtoHTSMap.put(bw, new HashSet<BaseWork>());
+            for (BaseWork pw : sparkWork.getParents(bw)) {
+              if (containsOp(pw, HashTableSinkOperator.class)) {
+                MJtoHTSMap.get(bw).add(pw);
+              }
+            }
+          }
+        }
+
+        SparkWork newSparkWork = split(sparkWork, MJtoHTSMap);
+        // Keep iterating until we end up with a SparkWork that contains no HTS.
+        while (newSparkWork != null) {
+          resultWorks.add(newSparkWork);
+          newSparkWork = split(sparkWork, MJtoHTSMap);
+        }
+        resultWorks.add(sparkWork);
+
+        // Now convert the SparkWorks into SparkTasks
+        SparkTask prevTask = null;
+        for (SparkWork work : resultWorks) {
+          SparkTask task = (SparkTask) TaskFactory.get(work, physicalContext.conf);
+          if (prevTask == null) {
+            List<Task<? extends Serializable>> parentTasks = currentTask.getParentTasks();
+            if (parentTasks != null && parentTasks.size() > 0) {
+              for (Task<? extends Serializable> parentTask : parentTasks) {
+                parentTask.addDependentTask(task);
+                parentTask.removeDependentTask(currentTask);
+              }
+            } else {
+              physicalContext.addToRootTask(task);
+              physicalContext.removeFromRootTask(currentTask);
+            }
+          } else {
+            prevTask.addDependentTask(task);
+          }
+          prevTask = task;
+        }
+      }
+
+      return null;
+    }
+  }
+}
\ No newline at end of file
diff --git ql/src/java/org/apache/hadoop/hive/ql/plan/SparkWork.java ql/src/java/org/apache/hadoop/hive/ql/plan/SparkWork.java
index 66fd6b6..708fa63 100644
--- ql/src/java/org/apache/hadoop/hive/ql/plan/SparkWork.java
+++ ql/src/java/org/apache/hadoop/hive/ql/plan/SparkWork.java
@@ -134,6 +134,15 @@ public void addAll(BaseWork[] bws) {
   }
 
   /**
+   * Whether the specified BaseWork is a vertex in this graph
+   * @param w the BaseWork to check
+   * @return whether the specified BaseWork is in this graph
+   */
+  public boolean contains(BaseWork w) {
+    return workGraph.containsKey(w);
+  }
+
+  /**
    * add creates a new node in the graph without any connections
    */
   public void add(BaseWork w) {
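For readers following the traversal in split(): below is a minimal, standalone sketch of the same path-enumeration idea, using a toy DAG of strings in place of SparkWork/BaseWork. The class name, the example edges, and the string-prefix check standing in for containsOp(..., MapJoinOperator.class) are illustrative assumptions only, not part of the patch.

import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

public class SplitTraversalSketch {
  public static void main(String[] args) {
    // Toy DAG (parent -> children), mirroring the simplest HTS-MJ case:
    //   MW1  HTS1
    //    \   /
    //     MJ1
    //      |
    //     FS1
    Map<String, List<String>> children = new HashMap<String, List<String>>();
    children.put("MW1", Arrays.asList("MJ1"));
    children.put("HTS1", Arrays.asList("MJ1"));
    children.put("MJ1", Arrays.asList("FS1"));
    children.put("FS1", Collections.<String>emptyList());
    List<String> roots = Arrays.asList("MW1", "HTS1");

    // Depth-first enumeration of root-to-map-join paths, like split() does with
    // its Stack of LinkedList<BaseWork> paths.
    Deque<LinkedList<String>> pathStack = new ArrayDeque<LinkedList<String>>();
    for (String root : roots) {
      LinkedList<String> start = new LinkedList<String>();
      start.add(root);
      pathStack.push(start);
      while (!pathStack.isEmpty()) {
        LinkedList<String> path = pathStack.pop();
        String current = path.getLast();
        if (current.startsWith("MJ")) {
          // In the resolver this is where an HTS-terminated branch of the path
          // would be moved into a new SparkWork via addPathToWork(); here we just print it.
          System.out.println("Path reaching a map-join work: " + path);
          continue;
        }
        List<String> kids = children.get(current);
        if (kids == null) {
          continue;
        }
        for (String child : kids) {
          LinkedList<String> next = new LinkedList<String>(path);
          next.add(child);
          pathStack.push(next);
        }
      }
    }
  }
}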