diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SparkSortMergeJoinFactory.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SparkSortMergeJoinFactory.java
index 6e0ac38..aca0630 100644
--- ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SparkSortMergeJoinFactory.java
+++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SparkSortMergeJoinFactory.java
@@ -1,20 +1,20 @@
 /**
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements. See the NOTICE file
-* distributed with this work for additional information
-* regarding copyright ownership. The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License. You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*/
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
 package org.apache.hadoop.hive.ql.optimizer.spark;
 
 import java.util.List;
@@ -36,8 +36,8 @@
 import org.apache.hadoop.hive.ql.plan.OperatorDesc;
 
 /**
-* Operator factory for Spark SMBJoin processing.
-*/
+ * Operator factory for Spark SMBJoin processing.
+ */
 public final class SparkSortMergeJoinFactory {
 
   private SparkSortMergeJoinFactory() {
@@ -45,131 +45,79 @@ private SparkSortMergeJoinFactory() {
   }
 
   /**
-   * Get the branch on which we are invoked (walking) from. See diagram below.
-   * We are at the SMBJoinOp and could have come from TS of any of the input tables.
-   */
-  public static int getPositionParent(SMBMapJoinOperator op,
-      Stack<Node> stack) {
-    int size = stack.size();
-    assert size >= 2 && stack.get(size - 1) == op;
-    @SuppressWarnings("unchecked")
-    Operator<? extends OperatorDesc> parent =
-        (Operator<? extends OperatorDesc>) stack.get(size - 2);
-    List<Operator<? extends OperatorDesc>> parOp = op.getParentOperators();
-    int pos = parOp.indexOf(parent);
-    return pos;
-  }
-
-  /**
-   * SortMergeMapJoin processor, input is a SMBJoinOp that is part of a MapWork:
-   *
-   *  MapWork:
-   *
-   *   (Big)   (Small)  (Small)
-   *    TS       TS       TS
-   *     \       |       /
-   *       \     DS     DS
-   *         \   |     /
-   *          SMBJoinOP
+   * Annotate MapWork: the inputs are an SMBJoinOp that is part of the MapWork, and its root TS operator.
    *
    * 1. Initializes the MapWork's aliasToWork, pointing to big-table's TS.
    * 2. Adds the bucketing information to the MapWork.
    * 3. Adds localwork to the MapWork, with localWork's aliasToWork pointing to small-table's TS.
+   * @param context proc walker context
+   * @param mapWork MapWork to annotate
+   * @param smbMapJoinOp SMB map join operator to read data from
+   * @param ts table scan operator to read data from
+   * @param local whether ts is from a 'local' source (a small table that will be loaded by the SMBJoin 'local' task)
    */
-  private static class SortMergeJoinProcessor implements NodeProcessor {
-
-    public static void setupBucketMapJoinInfo(MapWork plan, SMBMapJoinOperator currMapJoinOp) {
-      if (currMapJoinOp != null) {
-        Map<String, Map<String, List<String>>> aliasBucketFileNameMapping =
-            currMapJoinOp.getConf().getAliasBucketFileNameMapping();
-        if (aliasBucketFileNameMapping != null) {
-          MapredLocalWork localPlan = plan.getMapRedLocalWork();
-          if (localPlan == null) {
-            localPlan = currMapJoinOp.getConf().getLocalWork();
-          } else {
-            // local plan is not null, we want to merge it into SMBMapJoinOperator's local work
-            MapredLocalWork smbLocalWork = currMapJoinOp.getConf().getLocalWork();
-            if (smbLocalWork != null) {
-              localPlan.getAliasToFetchWork().putAll(smbLocalWork.getAliasToFetchWork());
-              localPlan.getAliasToWork().putAll(smbLocalWork.getAliasToWork());
-            }
-          }
+  public static void annotateMapWork(GenSparkProcContext context, MapWork mapWork,
+      SMBMapJoinOperator smbMapJoinOp, TableScanOperator ts, boolean local)
+      throws SemanticException {
+    initSMBJoinPlan(context, mapWork, ts, local);
+    setupBucketMapJoinInfo(mapWork, smbMapJoinOp);
+  }
 
-          if (localPlan == null) {
-            return;
+  private static void setupBucketMapJoinInfo(MapWork plan, SMBMapJoinOperator currMapJoinOp) {
+    if (currMapJoinOp != null) {
+      Map<String, Map<String, List<String>>> aliasBucketFileNameMapping =
+          currMapJoinOp.getConf().getAliasBucketFileNameMapping();
+      if (aliasBucketFileNameMapping != null) {
+        MapredLocalWork localPlan = plan.getMapRedLocalWork();
+        if (localPlan == null) {
+          localPlan = currMapJoinOp.getConf().getLocalWork();
+        } else {
+          // local plan is not null, we want to merge it into SMBMapJoinOperator's local work
+          MapredLocalWork smbLocalWork = currMapJoinOp.getConf().getLocalWork();
+          if (smbLocalWork != null) {
+            localPlan.getAliasToFetchWork().putAll(smbLocalWork.getAliasToFetchWork());
+            localPlan.getAliasToWork().putAll(smbLocalWork.getAliasToWork());
           }
+        }
-          plan.setMapRedLocalWork(null);
-          currMapJoinOp.getConf().setLocalWork(localPlan);
-
-          BucketMapJoinContext bucketMJCxt = new BucketMapJoinContext();
-          localPlan.setBucketMapjoinContext(bucketMJCxt);
-          bucketMJCxt.setAliasBucketFileNameMapping(aliasBucketFileNameMapping);
-          bucketMJCxt.setBucketFileNameMapping(
-              currMapJoinOp.getConf().getBigTableBucketNumMapping());
-          localPlan.setInputFileChangeSensitive(true);
-          bucketMJCxt.setMapJoinBigTableAlias(currMapJoinOp.getConf().getBigTableAlias());
-          bucketMJCxt
-              .setBucketMatcherClass(org.apache.hadoop.hive.ql.exec.DefaultBucketMatcher.class);
-          bucketMJCxt.setBigTablePartSpecToFileMapping(
-              currMapJoinOp.getConf().getBigTablePartSpecToFileMapping());
-
-          plan.setUseBucketizedHiveInputFormat(true);
-        }
-      }
-    }
-
-    /**
-     * Initialize the mapWork.
-     *
-     * @param opProcCtx
-     *          processing context
-     */
-    private static void initSMBJoinPlan(MapWork mapWork,
-        GenSparkProcContext opProcCtx, boolean local)
-        throws SemanticException {
-      TableScanOperator ts = (TableScanOperator) opProcCtx.currentRootOperator;
-      String currAliasId = findAliasId(opProcCtx, ts);
-      GenMapRedUtils.setMapWork(mapWork, opProcCtx.parseContext,
-          opProcCtx.inputs, null, ts, currAliasId, opProcCtx.conf, local);
-    }
 
-    private static String findAliasId(GenSparkProcContext opProcCtx, TableScanOperator ts) {
-      for (String alias : opProcCtx.topOps.keySet()) {
-        if (opProcCtx.topOps.get(alias) == ts) {
-          return alias;
+        if (localPlan == null) {
+          return;
         }
+        plan.setMapRedLocalWork(null);
+        currMapJoinOp.getConf().setLocalWork(localPlan);
+
+        BucketMapJoinContext bucketMJCxt = new BucketMapJoinContext();
+        localPlan.setBucketMapjoinContext(bucketMJCxt);
+        bucketMJCxt.setAliasBucketFileNameMapping(aliasBucketFileNameMapping);
+        bucketMJCxt.setBucketFileNameMapping(
+            currMapJoinOp.getConf().getBigTableBucketNumMapping());
+        localPlan.setInputFileChangeSensitive(true);
+        bucketMJCxt.setMapJoinBigTableAlias(currMapJoinOp.getConf().getBigTableAlias());
+        bucketMJCxt
+            .setBucketMatcherClass(org.apache.hadoop.hive.ql.exec.DefaultBucketMatcher.class);
+        bucketMJCxt.setBigTablePartSpecToFileMapping(
+            currMapJoinOp.getConf().getBigTablePartSpecToFileMapping());
+
+        plan.setUseBucketizedHiveInputFormat(true);
+
-      return null;
       }
+    }
+  }
 
-    /**
-     * 1. Initializes the MapWork's aliasToWork, pointing to big-table's TS.
-     * 2. Adds the bucketing information to the MapWork.
-     * 3. Adds localwork to the MapWork, with localWork's aliasToWork pointing to small-table's TS.
-     */
-    @Override
-    public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx,
-        Object... nodeOutputs) throws SemanticException {
-      SMBMapJoinOperator mapJoin = (SMBMapJoinOperator) nd;
-      GenSparkProcContext ctx = (GenSparkProcContext) procCtx;
-
-      // find the branch on which this processor was invoked
-      int pos = getPositionParent(mapJoin, stack);
-      boolean local = pos != mapJoin.getConf().getPosBigTable();
-
-      MapWork mapWork = ctx.smbJoinWorkMap.get(mapJoin);
-      initSMBJoinPlan(mapWork, ctx, local);
-
-      // find the associated mapWork that contains this processor.
-      setupBucketMapJoinInfo(mapWork, mapJoin);
-
-      // local aliases need not to hand over context further
-      return false;
-    }
+  private static void initSMBJoinPlan(GenSparkProcContext opProcCtx,
+      MapWork mapWork, TableScanOperator currentRootOperator, boolean local)
+      throws SemanticException {
+    String currAliasId = findAliasId(opProcCtx, currentRootOperator);
+    GenMapRedUtils.setMapWork(mapWork, opProcCtx.parseContext,
+        opProcCtx.inputs, null, currentRootOperator, currAliasId, opProcCtx.conf, local);
   }
 
-  public static NodeProcessor getTableScanMapJoin() {
-    return new SortMergeJoinProcessor();
+  private static String findAliasId(GenSparkProcContext opProcCtx, TableScanOperator ts) {
+    for (String alias : opProcCtx.topOps.keySet()) {
+      if (opProcCtx.topOps.get(alias) == ts) {
+        return alias;
+      }
+    }
+    return null;
   }
 }
diff --git ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkProcContext.java ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkProcContext.java
index 773cfbd..447f104 100644
--- ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkProcContext.java
+++ ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkProcContext.java
@@ -36,7 +36,6 @@
 import org.apache.hadoop.hive.ql.parse.ParseContext;
 import org.apache.hadoop.hive.ql.plan.BaseWork;
 import org.apache.hadoop.hive.ql.plan.DependencyCollectionWork;
-import org.apache.hadoop.hive.ql.plan.MapWork;
 import org.apache.hadoop.hive.ql.plan.MoveWork;
 import org.apache.hadoop.hive.ql.plan.OperatorDesc;
 import org.apache.hadoop.hive.ql.plan.ReduceWork;
@@ -44,6 +43,7 @@
 import org.apache.hadoop.hive.ql.plan.SparkWork;
 
 import java.io.Serializable;
+import java.util.HashMap;
 import java.util.LinkedHashMap;
 import java.util.LinkedHashSet;
 import java.util.LinkedList;
@@ -103,8 +103,8 @@
   // map that says which mapjoin belongs to which work item
   public final Map<MapJoinOperator, List<BaseWork>> mapJoinWorkMap;
 
-  // a map to keep track of which MapWork item holds which SMBMapJoinOp
-  public final Map<SMBMapJoinOperator, MapWork> smbJoinWorkMap;
+  // Map to keep track of the SMB join operators and the information used to annotate their MapWork.
+  public final Map<SMBMapJoinOperator, SparkSMBMapJoinInfo> smbMapJoinCtxMap;
 
   // a map to keep track of which root generated which work
   public final Map<Operator<?>, BaseWork> rootToWorkMap;
@@ -160,7 +160,7 @@ public GenSparkProcContext(HiveConf conf,
         new LinkedHashMap>();
     this.linkOpWithWorkMap = new LinkedHashMap<Operator<?>, Map<BaseWork, SparkEdgeProperty>>();
     this.linkWorkWithReduceSinkMap = new LinkedHashMap<BaseWork, List<ReduceSinkOperator>>();
-    this.smbJoinWorkMap = new LinkedHashMap<SMBMapJoinOperator, MapWork>();
+    this.smbMapJoinCtxMap = new HashMap<SMBMapJoinOperator, SparkSMBMapJoinInfo>();
     this.mapJoinWorkMap = new LinkedHashMap<MapJoinOperator, List<BaseWork>>();
     this.rootToWorkMap = new LinkedHashMap<Operator<?>, BaseWork>();
     this.childToWorkMap = new LinkedHashMap<Operator<?>, List<BaseWork>>();
diff --git ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkUtils.java ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkUtils.java
index 0eac6e1..c19bc21 100644
--- ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkUtils.java
+++ ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkUtils.java
@@ -41,10 +41,12 @@
 import org.apache.hadoop.hive.ql.exec.JoinOperator;
 import org.apache.hadoop.hive.ql.exec.Operator;
 import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
+import org.apache.hadoop.hive.ql.exec.SMBMapJoinOperator;
 import org.apache.hadoop.hive.ql.exec.TableScanOperator;
 import org.apache.hadoop.hive.ql.exec.UnionOperator;
 import org.apache.hadoop.hive.ql.exec.Utilities;
 import org.apache.hadoop.hive.ql.optimizer.GenMapRedUtils;
+import org.apache.hadoop.hive.ql.optimizer.spark.SparkSortMergeJoinFactory;
 import org.apache.hadoop.hive.ql.parse.ParseContext;
 import org.apache.hadoop.hive.ql.parse.PrunedPartitionList;
 import org.apache.hadoop.hive.ql.parse.SemanticException;
@@ -443,6 +445,25 @@ private static boolean isSame(List list1, List list2
     return null;
   }
 
+  /**
+   * Fill MapWork with 'local' work and bucket information for SMB join.
+   * @param context context containing references to the MapWorks and their SMB information
+   * @throws SemanticException
+   */
+  public void annotateMapWork(GenSparkProcContext context) throws SemanticException {
+    for (SMBMapJoinOperator smbMapJoinOp : context.smbMapJoinCtxMap.keySet()) {
+      // Initialize the MapWork with the SMB map join information.
+      SparkSMBMapJoinInfo smbMapJoinInfo = context.smbMapJoinCtxMap.get(smbMapJoinOp);
+      MapWork work = smbMapJoinInfo.mapWork;
+      SparkSortMergeJoinFactory.annotateMapWork(context, work, smbMapJoinOp,
+          (TableScanOperator) smbMapJoinInfo.bigTableRootOp, false);
+      for (Operator<?> smallTableRootOp : smbMapJoinInfo.smallTableRootOps) {
+        SparkSortMergeJoinFactory.annotateMapWork(context, work, smbMapJoinOp,
+            (TableScanOperator) smallTableRootOp, true);
+      }
+    }
+  }
+
   public synchronized int getNextSeqNumber() {
     return ++sequenceNumber;
   }
diff --git ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkWork.java ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkWork.java
index cb5d4fe..3dd6d92 100644
--- ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkWork.java
+++ ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkWork.java
@@ -34,10 +34,12 @@
 import org.apache.hadoop.hive.ql.exec.OperatorFactory;
 import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
 import org.apache.hadoop.hive.ql.exec.SMBMapJoinOperator;
+import org.apache.hadoop.hive.ql.exec.TableScanOperator;
 import org.apache.hadoop.hive.ql.lib.Node;
 import org.apache.hadoop.hive.ql.lib.NodeProcessor;
 import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
 import org.apache.hadoop.hive.ql.optimizer.GenMapRedUtils;
+import org.apache.hadoop.hive.ql.optimizer.spark.SparkSortMergeJoinFactory;
 import org.apache.hadoop.hive.ql.parse.SemanticException;
 import org.apache.hadoop.hive.ql.plan.BaseWork;
 import org.apache.hadoop.hive.ql.plan.MapWork;
@@ -118,18 +120,12 @@ public Object process(Node nd, Stack<Node> stack,
     } else {
       // create a new vertex
       if (context.preceedingWork == null) {
-        if (smbOp != null) {
-          // This logic is for SortMergeBucket MapJoin case.
-          // This MapWork (of big-table, see above..) is later initialized by SparkMapJoinFactory
-          // processor, so don't initialize it here. Just keep track of it in the context,
-          // for later processing.
-          work = utils.createMapWork(context, root, sparkWork, null, true);
-          if (context.smbJoinWorkMap.get(smbOp) != null) {
-            throw new SemanticException("Each SMBMapJoin should be associated only with one Mapwork");
-          }
-          context.smbJoinWorkMap.put(smbOp, (MapWork) work);
-        } else {
+        if (smbOp == null) {
           work = utils.createMapWork(context, root, sparkWork, null);
+        } else {
+          // Save the work to be initialized later with the SMB information.
+          work = utils.createMapWork(context, root, sparkWork, null, true);
+          context.smbMapJoinCtxMap.get(smbOp).mapWork = (MapWork) work;
         }
       } else {
         work = utils.createReduceWork(context, root, sparkWork);
diff --git ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkCompiler.java ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkCompiler.java
index 3a7477a..19aae70 100644
--- ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkCompiler.java
+++ ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkCompiler.java
@@ -186,17 +186,30 @@ public Object process(Node n, Stack<Node> s,
      *
      * Some of the other processors are expecting only one traversal beyond SMBJoinOp.
      * We need to traverse from the big-table path only, and stop traversing on the small-table path once we reach SMBJoinOp.
+     * Also add some SMB join information to the context, so we can properly annotate the MapWork later on.
      */
     opRules.put(new TypeRule(SMBMapJoinOperator.class),
       new NodeProcessor() {
        @Override
        public Object process(Node currNode, Stack<Node> stack, NodeProcessorCtx procCtx, Object...
            os) throws SemanticException {
+         GenSparkProcContext context = (GenSparkProcContext) procCtx;
+         SMBMapJoinOperator currSmbNode = (SMBMapJoinOperator) currNode;
+         SparkSMBMapJoinInfo smbMapJoinCtx = context.smbMapJoinCtxMap.get(currSmbNode);
+         if (smbMapJoinCtx == null) {
+           smbMapJoinCtx = new SparkSMBMapJoinInfo();
+           context.smbMapJoinCtxMap.put(currSmbNode, smbMapJoinCtx);
+         }
+
         for (Node stackNode : stack) {
           if (stackNode instanceof DummyStoreOperator) {
+            // If coming from the small-table side, do some book-keeping and skip the traversal.
+            smbMapJoinCtx.smallTableRootOps.add(context.currentRootOperator);
             return true;
           }
         }
+        // If coming from the big-table side, do some book-keeping and continue the traversal.
+        smbMapJoinCtx.bigTableRootOp = context.currentRootOperator;
         return false;
        }
      }
@@ -210,24 +223,14 @@
 
     GraphWalker ogw = new GenSparkWorkWalker(disp, procCtx);
     ogw.startWalking(topNodes, null);
-
-    // ------------------- Second Pass -----------------------
-    // SMB Join optimizations to add the "localWork" and bucketing data structures to MapWork.
-    opRules.clear();
-    opRules.put(new TypeRule(SMBMapJoinOperator.class),
-      SparkSortMergeJoinFactory.getTableScanMapJoin());
-
-    disp = new DefaultRuleDispatcher(null, opRules, procCtx);
-    topNodes = new ArrayList<Node>();
-    topNodes.addAll(pCtx.getTopOps().values());
-    ogw = new GenSparkWorkWalker(disp, procCtx);
-    ogw.startWalking(topNodes, null);
-
     // we need to clone some operator plans and remove union operators still
     for (BaseWork w: procCtx.workWithUnionOperators) {
       GenSparkUtils.getUtils().removeUnionOperators(conf, procCtx, w);
     }
 
+    // we need to fill MapWork with 'local' work and bucket information for SMB join
+    GenSparkUtils.getUtils().annotateMapWork(procCtx);
+
     // finally make sure the file sink operators are set up right
     for (FileSinkOperator fileSink: procCtx.fileSinkSet) {
       GenSparkUtils.getUtils().processFileSink(procCtx, fileSink);
diff --git ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkSMBMapJoinInfo.java ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkSMBMapJoinInfo.java
new file mode 100644
index 0000000..9dad202
--- /dev/null
+++ ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkSMBMapJoinInfo.java
@@ -0,0 +1,33 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hive.ql.parse.spark;
+
+import org.apache.hadoop.hive.ql.exec.Operator;
+import org.apache.hadoop.hive.ql.plan.MapWork;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Data structure to keep track of SMBMapJoin operators during query compilation for Spark.
+ */
+public class SparkSMBMapJoinInfo {
+  Operator<?> bigTableRootOp;
+  List<Operator<?>> smallTableRootOps = new ArrayList<Operator<?>>();
+  MapWork mapWork;
+}
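
For readers who want the new control flow at a glance: the patch above replaces the second walker pass with per-join book-keeping during the first pass (the SparkCompiler rule records the big-table and small-table root operators in SparkSMBMapJoinInfo, GenSparkWork records the MapWork) followed by one post-walk annotation step (GenSparkUtils.annotateMapWork delegating to SparkSortMergeJoinFactory.annotateMapWork). The stand-alone sketch below only mirrors that two-phase shape; it is not Hive code, and every name in it (SmbJoinSketch, SmbJoinInfo, the String/StringBuilder stand-ins) is hypothetical.

// Illustrative sketch only -- NOT Hive code. The types below are simplified
// stand-ins for SMBMapJoinOperator / TableScanOperator / MapWork.
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class SmbJoinSketch {

  // Mirrors SparkSMBMapJoinInfo: one record per SMB join, filled in during the walk.
  static class SmbJoinInfo {
    String bigTableRootTs;                              // stand-in for the big-table TableScanOperator
    List<String> smallTableRootTs = new ArrayList<>();  // stand-ins for small-table TableScanOperators
    StringBuilder mapWork = new StringBuilder();        // stand-in for the MapWork to annotate
  }

  // Phase 1: the walk-time rule only records which root the walker came from
  // (big-table vs. small-table side); the caller skips traversal for the small side.
  static void onSmbJoinVisited(Map<String, SmbJoinInfo> ctx, String smbJoinId,
      String currentRootTs, boolean cameFromSmallTableSide) {
    SmbJoinInfo info = ctx.computeIfAbsent(smbJoinId, k -> new SmbJoinInfo());
    if (cameFromSmallTableSide) {
      info.smallTableRootTs.add(currentRootTs);   // book-keeping only; walker returns true
    } else {
      info.bigTableRootTs = currentRootTs;        // big-table path; walker returns false
    }
  }

  // Phase 2: after the walk (the annotateMapWork step), each recorded join
  // annotates its MapWork once for the big table and once per small table.
  static void annotateMapWork(Map<String, SmbJoinInfo> ctx) {
    for (SmbJoinInfo info : ctx.values()) {
      info.mapWork.append("aliasToWork -> ").append(info.bigTableRootTs).append('\n');
      for (String small : info.smallTableRootTs) {
        info.mapWork.append("localWork.aliasToWork -> ").append(small).append('\n');
      }
    }
  }

  public static void main(String[] args) {
    Map<String, SmbJoinInfo> ctx = new HashMap<>();
    onSmbJoinVisited(ctx, "SMBJOIN_25", "TS[big]", false);
    onSmbJoinVisited(ctx, "SMBJOIN_25", "TS[small_a]", true);
    onSmbJoinVisited(ctx, "SMBJOIN_25", "TS[small_b]", true);
    annotateMapWork(ctx);
    System.out.print(ctx.get("SMBJOIN_25").mapWork);
  }
}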