diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/SparkMapJoinResolver.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/SparkMapJoinResolver.java
new file mode 100644
index 0000000..a8b7ac6
--- /dev/null
+++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/SparkMapJoinResolver.java
@@ -0,0 +1,191 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.optimizer.physical;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.Stack;
+
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.ql.exec.HashTableSinkOperator;
+import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
+import org.apache.hadoop.hive.ql.exec.Operator;
+import org.apache.hadoop.hive.ql.exec.Task;
+import org.apache.hadoop.hive.ql.exec.TaskFactory;
+import org.apache.hadoop.hive.ql.exec.spark.SparkTask;
+import org.apache.hadoop.hive.ql.lib.Dispatcher;
+import org.apache.hadoop.hive.ql.lib.Node;
+import org.apache.hadoop.hive.ql.lib.TaskGraphWalker;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.plan.BaseWork;
+import org.apache.hadoop.hive.ql.plan.OperatorDesc;
+import org.apache.hadoop.hive.ql.plan.SparkWork;
+
+public class SparkMapJoinResolver implements PhysicalPlanResolver {
+
+  @Override
+  public PhysicalContext resolve(PhysicalContext pctx) throws SemanticException {
+
+    Dispatcher dispatcher = new SparkMapJoinTaskDispatcher(pctx);
+    TaskGraphWalker graphWalker = new TaskGraphWalker(dispatcher);
+
+    ArrayList<Node> topNodes = new ArrayList<Node>();
+    topNodes.addAll(pctx.getRootTasks());
+    graphWalker.startWalking(topNodes, null);
+    return pctx;
+  }
+
+  // Check whether the specified BaseWork's operator tree contains an operator
+  // of the specified operator class
+  private boolean containsOp(BaseWork work, Class<?> clazz) {
+    for (Operator<? extends OperatorDesc> op : work.getAllOperators()) {
+      if (clazz.isInstance(op)) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  class SparkMapJoinTaskDispatcher implements Dispatcher {
+
+    private final PhysicalContext physicalContext;
+
+    // For each BaseWork with a MapJoin operator, we build a SparkWork for its
+    // small-table BaseWorks. This map records that information.
+    private final Map<BaseWork, SparkWork> sparkWorkMap;
+
+    // SparkWork dependency graph - from a SparkWork with MapJoin operators to all
+    // of its parent SparkWorks for the small tables
+    private final Map<SparkWork, List<SparkWork>> dependencyGraph;
+
+    public SparkMapJoinTaskDispatcher(PhysicalContext pc) {
+      super();
+      physicalContext = pc;
+      sparkWorkMap = new HashMap<BaseWork, SparkWork>();
+      dependencyGraph = new HashMap<SparkWork, List<SparkWork>>();
+    }
+
+    // Move the specified work from the sparkWork to the targetWork. Note that,
+    // in order not to break the graph (since we need it for the edges), we don't
+    // remove the work from the sparkWork here. The removal is done later.
+    private void moveWork(SparkWork sparkWork, BaseWork work, SparkWork targetWork) {
+      List<BaseWork> parentWorks = sparkWork.getParents(work);
+      if (sparkWork != targetWork) {
+        targetWork.add(work);
+
+        // If any child work for this work is already added to the targetWork earlier,
+        // we should connect this work with it
+        for (BaseWork childWork : sparkWork.getChildren(work)) {
+          if (targetWork.contains(childWork)) {
+            targetWork.connect(work, childWork, sparkWork.getEdgeProperty(work, childWork));
+          }
+        }
+      }
+
+      if (!containsOp(work, MapJoinOperator.class)) {
+        for (BaseWork parent : parentWorks) {
+          moveWork(sparkWork, parent, targetWork);
+        }
+      } else {
+        // Create a new SparkWork for all the small tables of this work
+        SparkWork parentWork =
+            new SparkWork(physicalContext.conf.getVar(HiveConf.ConfVars.HIVEQUERYID));
+
+        // Update the dependency graph
+        if (!dependencyGraph.containsKey(targetWork)) {
+          dependencyGraph.put(targetWork, new ArrayList<SparkWork>());
+        }
+        dependencyGraph.get(targetWork).add(parentWork);
+
+        // This work is now moved to the parentWork, thus we should
+        // update this information in sparkWorkMap
+        sparkWorkMap.put(work, parentWork);
+        for (BaseWork parent : parentWorks) {
+          if (containsOp(parent, HashTableSinkOperator.class)) {
+            moveWork(sparkWork, parent, parentWork);
+          } else {
+            moveWork(sparkWork, parent, targetWork);
+          }
+        }
+      }
+    }
+
+    // Create a new SparkTask for the specified SparkWork, and recursively compute
+    // all the parent SparkTasks that this new task depends on, if they don't already exist.
+    private SparkTask createSparkTask(Task<? extends Serializable> originalTask,
+                                      SparkWork sparkWork,
+                                      Map<SparkWork, SparkTask> createdTaskMap) {
+      if (createdTaskMap.containsKey(sparkWork)) {
+        return createdTaskMap.get(sparkWork);
+      }
+      SparkTask resultTask = (SparkTask) TaskFactory.get(sparkWork, physicalContext.conf);
+      // Remember the task so a shared parent SparkWork only yields a single SparkTask
+      createdTaskMap.put(sparkWork, resultTask);
+      if (dependencyGraph.get(sparkWork) != null) {
+        for (SparkWork parentWork : dependencyGraph.get(sparkWork)) {
+          SparkTask parentTask = createSparkTask(originalTask, parentWork, createdTaskMap);
+          parentTask.addDependentTask(resultTask);
+        }
+      } else {
+        List<Task<? extends Serializable>> parentTasks = originalTask.getParentTasks();
+        if (parentTasks != null && parentTasks.size() > 0) {
+          for (Task<? extends Serializable> parentTask : parentTasks) {
+            parentTask.addDependentTask(resultTask);
+          }
+        } else {
+          physicalContext.addToRootTask(resultTask);
+          physicalContext.removeFromRootTask(originalTask);
+        }
+      }
+      return resultTask;
+    }
+
+    @Override
+    public Object dispatch(Node nd, Stack<Node> stack, Object... nos)
+        throws SemanticException {
+      Task<? extends Serializable> currentTask = (Task<? extends Serializable>) nd;
+      if (currentTask instanceof SparkTask) {
+        SparkWork sparkWork = ((SparkTask) currentTask).getWork();
+        Set<BaseWork> leaves = sparkWork.getLeaves();
+        for (BaseWork leaf : leaves) {
+          moveWork(sparkWork, leaf, sparkWork);
+        }
+
+        // Now remove from the original SparkWork all the BaseWorks that were moved
+        // into the new SparkWorks we created
+        for (SparkWork newSparkWork : sparkWorkMap.values()) {
+          for (BaseWork work : newSparkWork.getAllWorkUnsorted()) {
+            sparkWork.remove(work);
+          }
+        }
+
+        Map<SparkWork, SparkTask> createdTaskMap = new HashMap<SparkWork, SparkTask>();
+
+        // Now create SparkTasks from the SparkWorks, and also set up their dependencies
+        for (SparkWork work : dependencyGraph.keySet()) {
+          createSparkTask(currentTask, work, createdTaskMap);
+        }
+      }
+
+      return null;
+    }
+  }
+}
\ No newline at end of file
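
The heart of the new resolver is the pair moveWork()/createSparkTask(): moveWork() splits each SparkWork that contains a MapJoinOperator by pulling the small-table (HashTableSinkOperator) branches into fresh parent SparkWorks and records that relationship in dependencyGraph; createSparkTask() then turns the graph into SparkTasks, memoizing through createdTaskMap so a shared parent SparkWork is only materialized once. The toy program below is a minimal, self-contained sketch of that second step only; every name in it (ToyTaskBuilder, deps, made, createTask) is invented for illustration and is not Hive API, with SparkWork/SparkTask stood in for by plain strings.

    import java.util.Arrays;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    public class ToyTaskBuilder {
      // work -> parent works whose tasks must run first (mirrors dependencyGraph)
      private final Map<String, List<String>> deps = new HashMap<String, List<String>>();
      // works already turned into tasks (mirrors createdTaskMap)
      private final Map<String, String> made = new HashMap<String, String>();

      String createTask(String work) {
        if (made.containsKey(work)) {
          return made.get(work);           // memoized: one task per work
        }
        String task = "task(" + work + ")";
        made.put(work, task);              // record before recursing
        List<String> parents = deps.get(work);
        if (parents != null) {
          for (String parent : parents) {
            // the small-table task becomes a prerequisite of this task
            System.out.println(createTask(parent) + " -> " + task);
          }
        }
        return task;
      }

      public static void main(String[] args) {
        ToyTaskBuilder builder = new ToyTaskBuilder();
        builder.deps.put("bigTableWork", Arrays.asList("smallTableWork1", "smallTableWork2"));
        builder.createTask("bigTableWork");
      }
    }

Running it prints the edges task(smallTableWork1) -> task(bigTableWork) and task(smallTableWork2) -> task(bigTableWork), matching how addDependentTask() wires each small-table SparkTask ahead of the task that consumes its hash tables.
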
diff --git ql/src/java/org/apache/hadoop/hive/ql/plan/SparkWork.java ql/src/java/org/apache/hadoop/hive/ql/plan/SparkWork.java
index 46d02bf..351d533 100644
--- ql/src/java/org/apache/hadoop/hive/ql/plan/SparkWork.java
+++ ql/src/java/org/apache/hadoop/hive/ql/plan/SparkWork.java
@@ -134,6 +134,15 @@ public void addAll(BaseWork[] bws) {
   }
 
   /**
+   * Whether the specified BaseWork is a vertex in this graph
+   * @param w the BaseWork to check
+   * @return whether the specified BaseWork is in this graph
+   */
+  public boolean contains(BaseWork w) {
+    return workGraph.containsKey(w);
+  }
+
+  /**
    * add creates a new node in the graph without any connections
    */
   public void add(BaseWork w) {
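
For context on how the resolver would be exercised: PhysicalPlanResolver implementations are normally chained by a physical optimizer, each one rewriting the task tree in place. The fragment below is a hypothetical driver only; the actual registration of SparkMapJoinResolver with the Spark task compiler is not part of this patch, and physicalContext is assumed to be the PhysicalContext the compiler has already built.

    // Hypothetical wiring, for illustration only.
    List<PhysicalPlanResolver> resolvers = new ArrayList<PhysicalPlanResolver>();
    resolvers.add(new SparkMapJoinResolver());
    for (PhysicalPlanResolver resolver : resolvers) {
      // walks the root tasks and splits map-join SparkWorks into dependent SparkTasks
      physicalContext = resolver.resolve(physicalContext);
    }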