diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/SparkMapJoinResolver.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/SparkMapJoinResolver.java
new file mode 100644
index 0000000..58995e3
--- /dev/null
+++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/SparkMapJoinResolver.java
@@ -0,0 +1,269 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.optimizer.physical;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.Stack;
+
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.ql.exec.OperatorFactory;
+import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
+import org.apache.hadoop.hive.ql.exec.HashTableSinkOperator;
+import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
+import org.apache.hadoop.hive.ql.exec.Operator;
+import org.apache.hadoop.hive.ql.exec.Task;
+import org.apache.hadoop.hive.ql.exec.TaskFactory;
+import org.apache.hadoop.hive.ql.exec.spark.SparkTask;
+import org.apache.hadoop.hive.ql.lib.DefaultGraphWalker;
+import org.apache.hadoop.hive.ql.lib.DefaultRuleDispatcher;
+import org.apache.hadoop.hive.ql.lib.Dispatcher;
+import org.apache.hadoop.hive.ql.lib.GraphWalker;
+import org.apache.hadoop.hive.ql.lib.Node;
+import org.apache.hadoop.hive.ql.lib.NodeProcessor;
+import org.apache.hadoop.hive.ql.lib.Rule;
+import org.apache.hadoop.hive.ql.lib.RuleRegExp;
+import org.apache.hadoop.hive.ql.lib.TaskGraphWalker;
+import org.apache.hadoop.hive.ql.parse.ParseContext;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.parse.spark.GenSparkUtils;
+import org.apache.hadoop.hive.ql.plan.BaseWork;
+import org.apache.hadoop.hive.ql.plan.HashTableSinkDesc;
+import org.apache.hadoop.hive.ql.plan.MapJoinDesc;
+import org.apache.hadoop.hive.ql.plan.MapWork;
+import org.apache.hadoop.hive.ql.plan.OperatorDesc;
+import org.apache.hadoop.hive.ql.plan.SparkWork;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * This class is similar to MapJoinResolver. The difference, though, is that here we
+ * split a SparkWork into two SparkWorks: one containing all the BaseWorks for the
+ * small tables, and the other containing the BaseWork for the big table.
+ * We also set up the dependency between the two new SparkWorks.
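+ *
+ * For example (an illustrative sketch; the work names are hypothetical): given a
+ * single SparkWork in which MapWork1 (small table, ending in a ReduceSinkOperator)
+ * feeds MapWork2 (big table, containing the MapJoinOperator), the ReduceSink in
+ * MapWork1 is replaced with a HashTableSinkOperator and the plan is split into two
+ * SparkWorks: SparkWork1 = {MapWork1} and SparkWork2 = {MapWork2}, with SparkWork2
+ * depending on SparkWork1.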
+ */
+public class SparkMapJoinResolver implements PhysicalPlanResolver {
+
+  @Override
+  public PhysicalContext resolve(PhysicalContext pctx) throws SemanticException {
+
+    Dispatcher dispatcher = new SparkMapJoinTaskDispatcher(pctx);
+    TaskGraphWalker graphWalker = new TaskGraphWalker(dispatcher);
+
+    ArrayList<Node> topNodes = new ArrayList<Node>();
+    topNodes.addAll(pctx.getRootTasks());
+    graphWalker.startWalking(topNodes, null);
+    return pctx;
+  }
+
+  class SparkMapJoinTaskDispatcher implements Dispatcher {
+
+    private final PhysicalContext physicalContext;
+
+    // Maps each BaseWork to the SparkWork that currently contains it.
+    private final Map<BaseWork, SparkWork> workMap;
+
+    // Maps a SparkWork to the set of SparkWorks it depends on.
+    private final Map<SparkWork, Set<SparkWork>> dependencyMap;
+
+    public SparkMapJoinTaskDispatcher(PhysicalContext pc) {
+      physicalContext = pc;
+      workMap = new HashMap<BaseWork, SparkWork>();
+      dependencyMap = new HashMap<SparkWork, Set<SparkWork>>();
+    }
+
+    private boolean containsOp(BaseWork work, Class<?> clazz) {
+      for (Operator<? extends OperatorDesc> op : work.getAllOperators()) {
+        if (clazz.isInstance(op)) {
+          return true;
+        }
+      }
+      return false;
+    }
+
+    private Operator<? extends OperatorDesc> getOp(BaseWork work, Class<?> clazz) {
+      for (Operator<? extends OperatorDesc> op : work.getAllOperators()) {
+        if (clazz.isInstance(op)) {
+          return op;
+        }
+      }
+      return null;
+    }
+
+    // Merge "sourceWork" into "targetWork", and adjust workMap and
+    // dependencyMap to reflect this change.
+    private void mergeSparkWork(SparkWork sourceWork, SparkWork targetWork) {
+      if (sourceWork == targetWork) {
+        // DON'T merge a SparkWork into itself.
+        return;
+      }
+      for (BaseWork work : sourceWork.getAllWork()) {
+        workMap.put(work, targetWork);
+        targetWork.add(work);
+        for (BaseWork parentWork : sourceWork.getParents(work)) {
+          targetWork.connect(parentWork, work, sourceWork.getEdgeProperty(parentWork, work));
+        }
+      }
+
+      // Anything that depended on sourceWork now depends on targetWork.
+      for (Set<SparkWork> workSet : dependencyMap.values()) {
+        if (workSet.contains(sourceWork)) {
+          workSet.remove(sourceWork);
+          workSet.add(targetWork);
+        }
+      }
+
+      // Dependencies of sourceWork become dependencies of targetWork.
+      if (dependencyMap.containsKey(sourceWork)) {
+        Set<SparkWork> setToAdd = dependencyMap.get(sourceWork);
+        if (!dependencyMap.containsKey(targetWork)) {
+          dependencyMap.put(targetWork, new HashSet<SparkWork>());
+        }
+        dependencyMap.get(targetWork).addAll(setToAdd);
+        setToAdd.clear();
+      }
+    }
+
+    // Create a SparkTask from the input SparkWork, and set up its dependencies
+    // using the information in dependencyMap. If the SparkTasks this task depends
+    // on are not available yet, recursively create those first.
+    private SparkTask createSparkTask(Task<? extends Serializable> currentTask,
+        SparkWork work, Map<SparkWork, SparkTask> taskMap) {
+      if (taskMap.containsKey(work)) {
+        return taskMap.get(work);
+      }
+      SparkTask newTask = (SparkTask) TaskFactory.get(work, physicalContext.getConf());
+      // Remember the new task so that a shared parent SparkWork is only turned
+      // into a SparkTask once.
+      taskMap.put(work, newTask);
+      List<Task<? extends Serializable>> parentTasks = currentTask.getParentTasks();
+      if (!dependencyMap.containsKey(work) || dependencyMap.get(work).isEmpty()) {
+        if (parentTasks != null) {
+          // Hook the new task under the parents of the original task.
+          for (Task<? extends Serializable> parentTask : parentTasks) {
+            parentTask.addDependentTask(newTask);
+            parentTask.removeDependentTask(currentTask);
+          }
+        } else {
+          // The original task was a root task, so register the new one as a root.
+          physicalContext.addToRootTask(newTask);
+          physicalContext.removeFromRootTask(currentTask);
+        }
+      } else {
+        for (SparkWork parentWork : dependencyMap.get(work)) {
+          SparkTask parentTask = createSparkTask(currentTask, parentWork, taskMap);
+          parentTask.addDependentTask(newTask);
+        }
+      }
+      return newTask;
+    }
+
+    // Add a dependency edge so that "sourceWork" is dependent on "targetWork".
+    private void addDependency(SparkWork sourceWork, SparkWork targetWork) {
+      if (!dependencyMap.containsKey(sourceWork)) {
+        dependencyMap.put(sourceWork, new HashSet<SparkWork>());
+      }
+      dependencyMap.get(sourceWork).add(targetWork);
+    }
+
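+    // Illustrative example (work names are hypothetical): for a plan where smallWork
+    // (ending in a ReduceSink) feeds bigWork (containing a MapJoin), dispatch() below
+    // leaves workMap = {smallWork -> SW1, bigWork -> SW2} and dependencyMap = {SW2 -> {SW1}};
+    // createSparkTask() then turns this into two SparkTasks, with SW2's task added as a
+    // dependent task of SW1's task.
+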
+    @Override
+    public Object dispatch(Node nd, Stack<Node> stack, Object... nos)
+        throws SemanticException {
+      Task<? extends Serializable> currentTask = (Task<? extends Serializable>) nd;
+      if (currentTask instanceof SparkTask
+          /*
+          TODO: uncomment this condition later - Task.CONVERTED_MAPJOIN
+          should get set in CommonJoinResolver
+          && currentTask.getTaskTag() == Task.CONVERTED_MAPJOIN */) {
+
+        // Right now, we assume that a work will NOT contain multiple HTS/MJ.
+        HiveConf conf = physicalContext.getConf();
+        workMap.clear();
+        dependencyMap.clear();
+        SparkWork sparkWork = ((SparkTask) currentTask).getWork();
+
+        for (BaseWork work : sparkWork.getAllWork()) {
+          SparkWork currentSparkWork = new SparkWork(conf.getVar(HiveConf.ConfVars.HIVEQUERYID));
+          SparkWork mergedParentSparkWork = null;
+          currentSparkWork.add(work);
+
+          for (BaseWork parentWork : sparkWork.getParents(work)) {
+            SparkWork parentSparkWork = workMap.get(parentWork);
+            if (containsOp(work, MapJoinOperator.class)) {
+              MapJoinOperator mjOp = (MapJoinOperator) getOp(work, MapJoinOperator.class);
+              BaseWork modifiedParentWork =
+                  replaceReduceSinkWithHashTableSink(parentWork, physicalContext, mjOp);
+              if (containsOp(modifiedParentWork, HashTableSinkOperator.class)) {
+                if (mergedParentSparkWork == null) {
+                  mergedParentSparkWork = parentSparkWork;
+                  addDependency(currentSparkWork, mergedParentSparkWork);
+                } else {
+                  // Collect all small-table works into the same parent SparkWork,
+                  // so the big-table SparkWork ends up with a single parent.
+                  mergeSparkWork(parentSparkWork, mergedParentSparkWork);
+                }
+              } else {
+                mergeSparkWork(parentSparkWork, currentSparkWork);
+              }
+            } else {
+              // Current work doesn't contain a MJ - we can merge it with the parent work.
+              mergeSparkWork(parentSparkWork, currentSparkWork);
+              currentSparkWork.connect(parentWork, work, sparkWork.getEdgeProperty(parentWork, work));
+            }
+          }
+
+          workMap.put(work, currentSparkWork);
+        }
+
+        // Now create SparkTasks.
+        // TODO: need to handle ConditionalTask
+        Map<SparkWork, SparkTask> taskMap = new HashMap<SparkWork, SparkTask>();
+        for (SparkWork work : workMap.values()) {
+          createSparkTask(currentTask, work, taskMap);
+        }
+      }
+
+      return null;
+    }
+
+    /**
+     * For map-join, replace the ReduceSink operator with a HashTableSink operator
+     * in the operator tree. This is partly based on MapJoinResolver.adjustLocalTask in M/R.
+     * @param smallTableMapWork the work for the small-table side
+     * @param phyCtx the physical context
+     * @param mjOp the MapJoinOperator this small table feeds
+     * @return the modified work
+     * @throws SemanticException
+     */
+    private BaseWork replaceReduceSinkWithHashTableSink(BaseWork smallTableMapWork,
+        PhysicalContext phyCtx, MapJoinOperator mjOp) throws SemanticException {
+
+      ParseContext pc = phyCtx.getParseContext();
+      ReduceSinkOperator rsOp =
+          (ReduceSinkOperator) getOp(smallTableMapWork, ReduceSinkOperator.class);
+      if (rsOp == null) {
+        // Nothing to replace - this work does not end in a ReduceSink.
+        return smallTableMapWork;
+      }
+      MapJoinDesc mjDesc = mjOp.getConf();
+
+      // TODO: set hashtable memory usage lower if the map-join is followed by a group-by.
+      // Investigate whether this is relevant for Spark.
+      /*
+      HiveConf conf = pc.getConf();
+      float hashtableMemoryUsage = conf.getFloatVar(
+          HiveConf.ConfVars.HIVEHASHTABLEMAXMEMORYUSAGE);*/
+
+      HashTableSinkDesc hashTableSinkDesc = new HashTableSinkDesc(mjDesc);
+      HashTableSinkOperator hashTableSinkOp = (HashTableSinkOperator) OperatorFactory
+          .get(hashTableSinkDesc);
+
+      // Get all parents of the ReduceSink and point them at the HashTableSink instead.
+      List<Operator<? extends OperatorDesc>> parentsOp = rsOp.getParentOperators();
+      for (Operator<? extends OperatorDesc> parent : parentsOp) {
+        parent.replaceChild(rsOp, hashTableSinkOp);
+      }
+
+      return smallTableMapWork;
+    }
+  }
+}
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkCompiler.java b/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkCompiler.java
index 795a5d7..ba683a3 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkCompiler.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkCompiler.java
@@ -17,12 +17,25 @@
  */
 package org.apache.hadoop.hive.ql.parse.spark;
 
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Deque;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.Stack;
+
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.hive.conf.HiveConf;
 import org.apache.hadoop.hive.ql.Context;
 import org.apache.hadoop.hive.ql.exec.ConditionalTask;
 import org.apache.hadoop.hive.ql.exec.FileSinkOperator;
+import org.apache.hadoop.hive.ql.exec.JoinOperator;
+import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
 import org.apache.hadoop.hive.ql.exec.Operator;
 import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
 import org.apache.hadoop.hive.ql.exec.SMBMapJoinOperator;
@@ -47,9 +60,12 @@
 import org.apache.hadoop.hive.ql.optimizer.physical.CrossProductCheck;
 import org.apache.hadoop.hive.ql.optimizer.physical.NullScanOptimizer;
 import org.apache.hadoop.hive.ql.optimizer.physical.PhysicalContext;
+import org.apache.hadoop.hive.ql.optimizer.physical.SparkMapJoinResolver;
 import org.apache.hadoop.hive.ql.optimizer.physical.StageIDsRearranger;
 import org.apache.hadoop.hive.ql.optimizer.physical.Vectorizer;
 import org.apache.hadoop.hive.ql.optimizer.spark.SetSparkReducerParallelism;
+import org.apache.hadoop.hive.ql.optimizer.spark.SparkMapJoinOptimizer;
+import org.apache.hadoop.hive.ql.optimizer.spark.SparkReduceSinkMapJoinProc;
 import org.apache.hadoop.hive.ql.optimizer.spark.SparkSortMergeJoinFactory;
 import org.apache.hadoop.hive.ql.parse.GlobalLimitCtx;
 import org.apache.hadoop.hive.ql.parse.ParseContext;
@@ -61,17 +77,6 @@
 import org.apache.hadoop.hive.ql.plan.OperatorDesc;
 import org.apache.hadoop.hive.ql.plan.SparkWork;
 import org.apache.hadoop.hive.ql.session.SessionState.LogHelper;
-
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.Deque;
-import java.util.HashMap;
-import java.util.LinkedHashMap;
-import java.util.LinkedList;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-import java.util.Stack;
 
 /**
  * SparkCompiler translates the operator plan into SparkTasks.
  *
@@ -112,9 +117,8 @@ protected void optimizeOperatorPlan(ParseContext pCtx, Set<ReadEntity> inputs,
         ReduceSinkOperator.getOperatorName() + "%"),
         new SetSparkReducerParallelism());
 
-    // TODO: need to research and verify support convert join to map join optimization.
-    //opRules.put(new RuleRegExp(new String("Convert Join to Map-join"),
-    //    JoinOperator.getOperatorName() + "%"), new SparkMapJoinOptimizer());
+    opRules.put(new RuleRegExp(new String("Convert Join to Map-join"),
+        JoinOperator.getOperatorName() + "%"), new SparkMapJoinOptimizer());
 
     // The dispatcher fires the processor corresponding to the closest matching
     // rule and passes the context along
@@ -146,8 +150,8 @@ protected void generateTaskTree(List<Task<? extends Serializable>> rootTasks, Pa
     opRules.put(new RuleRegExp("Split Work - ReduceSink",
         ReduceSinkOperator.getOperatorName() + "%"), genSparkWork);
 
-    //opRules.put(new RuleRegExp("No more walking on ReduceSink-MapJoin",
-    //    MapJoinOperator.getOperatorName() + "%"), new SparkReduceSinkMapJoinProc());
+    opRules.put(new RuleRegExp("No more walking on ReduceSink-MapJoin",
+        MapJoinOperator.getOperatorName() + "%"), new SparkReduceSinkMapJoinProc());
 
     opRules.put(new RuleRegExp("Split Work + Move/Merge - FileSink",
         FileSinkOperator.getOperatorName() + "%"),
@@ -262,6 +266,9 @@ protected void optimizeTaskPlan(List<Task<? extends Serializable>> rootTasks, Pa
     PhysicalContext physicalCtx = new PhysicalContext(conf, pCtx, pCtx.getContext(),
         rootTasks, pCtx.getFetchTask());
 
+    SparkMapJoinResolver mapJoinResolver = new SparkMapJoinResolver();
+    mapJoinResolver.resolve(physicalCtx);
+
     if (conf.getBoolVar(HiveConf.ConfVars.HIVENULLSCANOPTIMIZE)) {
       physicalCtx = new NullScanOptimizer().resolve(physicalCtx);
     } else {
@@ -285,6 +292,7 @@ protected void optimizeTaskPlan(List<Task<? extends Serializable>> rootTasks, Pa
     } else {
       LOG.debug("Skipping stage id rearranger");
     }
+    return;
   }
 
 }