diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/MapJoinProcessor.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/MapJoinProcessor.java
index 46dcfaf..536dfc9 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/MapJoinProcessor.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/MapJoinProcessor.java
@@ -384,7 +384,7 @@ public static MapJoinOperator convertMapJoin(HiveConf conf,
     return mapJoinOp;
   }
 
-  static MapJoinOperator convertJoinOpMapJoinOp(HiveConf hconf,
+  public static MapJoinOperator convertJoinOpMapJoinOp(HiveConf hconf,
      LinkedHashMap<Operator<? extends OperatorDesc>, OpParseContext> opParseCtxMap,
      JoinOperator op, QBJoinTree joinTree, int mapJoinPos, boolean noCheckOuterJoin)
      throws SemanticException {
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SparkMapJoinOptimizer.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SparkMapJoinOptimizer.java
new file mode 100644
index 0000000..6c7dc58
--- /dev/null
+++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SparkMapJoinOptimizer.java
@@ -0,0 +1,476 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.optimizer.spark;
+
+import java.util.HashSet;
+import java.util.Set;
+import java.util.Stack;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.ql.exec.AppMasterEventOperator;
+import org.apache.hadoop.hive.ql.exec.FileSinkOperator;
+import org.apache.hadoop.hive.ql.exec.GroupByOperator;
+import org.apache.hadoop.hive.ql.exec.JoinOperator;
+import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
+import org.apache.hadoop.hive.ql.exec.MuxOperator;
+import org.apache.hadoop.hive.ql.exec.Operator;
+import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
+import org.apache.hadoop.hive.ql.lib.Node;
+import org.apache.hadoop.hive.ql.lib.NodeProcessor;
+import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
+import org.apache.hadoop.hive.ql.optimizer.MapJoinProcessor;
+import org.apache.hadoop.hive.ql.parse.ParseContext;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.parse.spark.OptimizeSparkProcContext;
+import org.apache.hadoop.hive.ql.plan.DynamicPruningEventDesc;
+import org.apache.hadoop.hive.ql.plan.OpTraits;
+import org.apache.hadoop.hive.ql.plan.OperatorDesc;
+import org.apache.hadoop.hive.ql.plan.Statistics;
+/**
+ * SparkMapJoinOptimizer, cloned from ConvertJoinMapJoin, is an optimization that replaces a
+ * common join (aka shuffle join) with a map join (aka broadcast or fragment replicate
+ * join) when possible.
Map joins have restrictions on which joins can be + * converted (e.g.: full outer joins cannot be handled as map joins) as well + * as memory restrictions (one side of the join has to fit into memory). + */ +public class SparkMapJoinOptimizer implements NodeProcessor { + + private static final Log LOG = LogFactory.getLog(SparkMapJoinOptimizer.class.getName()); + + @SuppressWarnings("unchecked") + @Override + /* + * (non-Javadoc) we should ideally not modify the tree we traverse. However, + * since we need to walk the tree at any time when we modify the operator, we + * might as well do it here. + */ + public Object + process(Node nd, Stack stack, NodeProcessorCtx procCtx, Object... nodeOutputs) + throws SemanticException { + + OptimizeSparkProcContext context = (OptimizeSparkProcContext) procCtx; + HiveConf conf = context.getConf(); + ParseContext parseContext = context.getParseContext(); + JoinOperator joinOp = (JoinOperator) nd; + + /* + if (!conf.getBoolVar(HiveConf.ConfVars.HIVECONVERTJOIN) + && !(conf.getBoolVar(HiveConf.ConfVars.HIVE_AUTO_SORTMERGE_JOIN))) { + // we are just converting to a common merge join operator. The shuffle + // join in map-reduce case. + int pos = 0; // it doesn't matter which position we use in this case. + convertJoinSMBJoin(joinOp, context, pos, 0, false, false); + return null; + }*/ + + // if we have traits, and table info is present in the traits, we know the + // exact number of buckets. Else choose the largest number of estimated + // reducers from the parent operators. + //TODO enable later. disabling this check for now + int numBuckets = 1; + + LOG.info("Estimated number of buckets " + numBuckets); + int mapJoinConversionPos = getMapJoinConversionPos(joinOp, context, numBuckets); + /* TODO: handle this later + if (mapJoinConversionPos < 0) { + // we cannot convert to bucket map join, we cannot convert to + // map join either based on the size. Check if we can convert to SMB join. + if (conf.getBoolVar(HiveConf.ConfVars.HIVE_AUTO_SORTMERGE_JOIN) == false) { + convertJoinSMBJoin(joinOp, context, 0, 0, false, false); + return null; + } + Class bigTableMatcherClass = null; + try { + bigTableMatcherClass = + (Class) (Class.forName(HiveConf.getVar( + parseContext.getConf(), + HiveConf.ConfVars.HIVE_AUTO_SORTMERGE_JOIN_BIGTABLE_SELECTOR))); + } catch (ClassNotFoundException e) { + throw new SemanticException(e.getMessage()); + } + + BigTableSelectorForAutoSMJ bigTableMatcher = + ReflectionUtils.newInstance(bigTableMatcherClass, null); + JoinDesc joinDesc = joinOp.getConf(); + JoinCondDesc[] joinCondns = joinDesc.getConds(); + Set joinCandidates = MapJoinProcessor.getBigTableCandidates(joinCondns); + if (joinCandidates.isEmpty()) { + // This is a full outer join. This can never be a map-join + // of any type. So return false. + return false; + } + mapJoinConversionPos = + bigTableMatcher.getBigTablePosition(parseContext, joinOp, joinCandidates); + if (mapJoinConversionPos < 0) { + // contains aliases from sub-query + // we are just converting to a common merge join operator. The shuffle + // join in map-reduce case. + int pos = 0; // it doesn't matter which position we use in this case. 
+ convertJoinSMBJoin(joinOp, context, pos, 0, false, false); + return null; + } + + if (checkConvertJoinSMBJoin(joinOp, context, mapJoinConversionPos, tezBucketJoinProcCtx)) { + convertJoinSMBJoin(joinOp, context, mapJoinConversionPos, + tezBucketJoinProcCtx.getNumBuckets(), tezBucketJoinProcCtx.isSubQuery(), true); + } else { + // we are just converting to a common merge join operator. The shuffle + // join in map-reduce case. + int pos = 0; // it doesn't matter which position we use in this case. + convertJoinSMBJoin(joinOp, context, pos, 0, false, false); + } + return null; + } + + if (numBuckets > 1) { + if (conf.getBoolVar(HiveConf.ConfVars.HIVE_CONVERT_JOIN_BUCKET_MAPJOIN_TEZ)) { + if (convertJoinBucketMapJoin(joinOp, context, mapJoinConversionPos, tezBucketJoinProcCtx)) { + return null; + } + } + }*/ + + LOG.info("Convert to non-bucketed map join"); + // check if we can convert to map join no bucket scaling. + mapJoinConversionPos = getMapJoinConversionPos(joinOp, context, 1); + + + /* + if (mapJoinConversionPos < 0) { + // we are just converting to a common merge join operator. The shuffle + // join in map-reduce case. + int pos = 0; // it doesn't matter which position we use in this case. + convertJoinSMBJoin(joinOp, context, pos, 0, false, false); + return null; + }*/ + + MapJoinOperator mapJoinOp = convertJoinMapJoin(joinOp, context, mapJoinConversionPos); + // map join operator by default has no bucket cols + mapJoinOp.setOpTraits(new OpTraits(null, -1, null)); + mapJoinOp.setStatistics(joinOp.getStatistics()); + // propagate this change till the next RS + for (Operator childOp : mapJoinOp.getChildOperators()) { + setAllChildrenTraitsToNull(childOp); + } + + return null; + } + + // replaces the join operator with a new CommonJoinOperator, removes the + // parent reduce sinks + /* + private void convertJoinSMBJoin(JoinOperator joinOp, OptimizeSparkProcContext context, + int mapJoinConversionPos, int numBuckets, boolean isSubQuery, boolean adjustParentsChildren) + throws SemanticException { + ParseContext parseContext = context.parseContext; + MapJoinDesc mapJoinDesc = null; + if (adjustParentsChildren) { + mapJoinDesc = MapJoinProcessor.getMapJoinDesc(context.conf, parseContext.getOpParseCtx(), + joinOp, parseContext.getJoinContext().get(joinOp), mapJoinConversionPos, true); + } else { + JoinDesc joinDesc = joinOp.getConf(); + // retain the original join desc in the map join. 
+ mapJoinDesc = + new MapJoinDesc(null, null, joinDesc.getExprs(), null, null, + joinDesc.getOutputColumnNames(), mapJoinConversionPos, joinDesc.getConds(), + joinDesc.getFilters(), joinDesc.getNoOuterJoin(), null); + } + + @SuppressWarnings("unchecked") + CommonMergeJoinOperator mergeJoinOp = + (CommonMergeJoinOperator) OperatorFactory.get(new CommonMergeJoinDesc(numBuckets, + isSubQuery, mapJoinConversionPos, mapJoinDesc)); + OpTraits opTraits = + new OpTraits(joinOp.getOpTraits().getBucketColNames(), numBuckets, joinOp.getOpTraits() + .getSortCols()); + mergeJoinOp.setOpTraits(opTraits); + mergeJoinOp.setStatistics(joinOp.getStatistics()); + + for (Operator parentOp : joinOp.getParentOperators()) { + int pos = parentOp.getChildOperators().indexOf(joinOp); + parentOp.getChildOperators().remove(pos); + parentOp.getChildOperators().add(pos, mergeJoinOp); + } + + for (Operator childOp : joinOp.getChildOperators()) { + int pos = childOp.getParentOperators().indexOf(joinOp); + childOp.getParentOperators().remove(pos); + childOp.getParentOperators().add(pos, mergeJoinOp); + } + + List> childOperators = mergeJoinOp.getChildOperators(); + if (childOperators == null) { + childOperators = new ArrayList>(); + mergeJoinOp.setChildOperators(childOperators); + } + + List> parentOperators = mergeJoinOp.getParentOperators(); + if (parentOperators == null) { + parentOperators = new ArrayList>(); + mergeJoinOp.setParentOperators(parentOperators); + } + + childOperators.clear(); + parentOperators.clear(); + childOperators.addAll(joinOp.getChildOperators()); + parentOperators.addAll(joinOp.getParentOperators()); + mergeJoinOp.getConf().setGenJoinKeys(false); + + if (adjustParentsChildren) { + mergeJoinOp.getConf().setGenJoinKeys(true); + List> newParentOpList = + new ArrayList>(); + for (Operator parentOp : mergeJoinOp.getParentOperators()) { + for (Operator grandParentOp : parentOp.getParentOperators()) { + grandParentOp.getChildOperators().remove(parentOp); + grandParentOp.getChildOperators().add(mergeJoinOp); + newParentOpList.add(grandParentOp); + } + } + mergeJoinOp.getParentOperators().clear(); + mergeJoinOp.getParentOperators().addAll(newParentOpList); + List> parentOps = + new ArrayList>(mergeJoinOp.getParentOperators()); + for (Operator parentOp : parentOps) { + int parentIndex = mergeJoinOp.getParentOperators().indexOf(parentOp); + if (parentIndex == mapJoinConversionPos) { + continue; + } + + // insert the dummy store operator here + DummyStoreOperator dummyStoreOp = new TezDummyStoreOperator(); + dummyStoreOp.setParentOperators(new ArrayList>()); + dummyStoreOp.setChildOperators(new ArrayList>()); + dummyStoreOp.getChildOperators().add(mergeJoinOp); + int index = parentOp.getChildOperators().indexOf(mergeJoinOp); + parentOp.getChildOperators().remove(index); + parentOp.getChildOperators().add(index, dummyStoreOp); + dummyStoreOp.getParentOperators().add(parentOp); + mergeJoinOp.getParentOperators().remove(parentIndex); + mergeJoinOp.getParentOperators().add(parentIndex, dummyStoreOp); + } + } + mergeJoinOp.cloneOriginalParentsList(mergeJoinOp.getParentOperators()); + } + */ + private void setAllChildrenTraitsToNull(Operator currentOp) { + if (currentOp instanceof ReduceSinkOperator) { + return; + } + currentOp.setOpTraits(new OpTraits(null, -1, null)); + for (Operator childOp : currentOp.getChildOperators()) { + if ((childOp instanceof ReduceSinkOperator) || (childOp instanceof GroupByOperator)) { + break; + } + setAllChildrenTraitsToNull(childOp); + } + } + + + private void 
setNumberOfBucketsOnChildren(Operator currentOp) { + int numBuckets = currentOp.getOpTraits().getNumBuckets(); + for (Operatorop : currentOp.getChildOperators()) { + if (!(op instanceof ReduceSinkOperator) && !(op instanceof GroupByOperator)) { + op.getOpTraits().setNumBuckets(numBuckets); + setNumberOfBucketsOnChildren(op); + } + } + } + + /** + * This method returns the big table position in a map-join. If the given join + * cannot be converted to a map-join (This could happen for several reasons - one + * of them being presence of 2 or more big tables that cannot fit in-memory), it returns -1. + * + * Otherwise, it returns an int value that is the index of the big table in the set + * MapJoinProcessor.bigTableCandidateSet + * + * @param joinOp + * @param context + * @param buckets + * @return + */ + private int getMapJoinConversionPos(JoinOperator joinOp, OptimizeSparkProcContext context, + int buckets) { + Set bigTableCandidateSet = + MapJoinProcessor.getBigTableCandidates(joinOp.getConf().getConds()); + + long maxSize = context.getConf().getLongVar( + HiveConf.ConfVars.HIVECONVERTJOINNOCONDITIONALTASKTHRESHOLD); + + int bigTablePosition = -1; + + Statistics bigInputStat = null; + long totalSize = 0; + int pos = 0; + + // bigTableFound means we've encountered a table that's bigger than the + // max. This table is either the the big table or we cannot convert. + boolean bigTableFound = false; + + for (Operator parentOp : joinOp.getParentOperators()) { + + Statistics currInputStat = parentOp.getStatistics(); + if (currInputStat == null) { + LOG.warn("Couldn't get statistics from: "+parentOp); + return -1; + } + + long inputSize = currInputStat.getDataSize(); + if ((bigInputStat == null) || + ((bigInputStat != null) && + (inputSize > bigInputStat.getDataSize()))) { + + if (bigTableFound) { + // cannot convert to map join; we've already chosen a big table + // on size and there's another one that's bigger. + return -1; + } + + if (inputSize/buckets > maxSize) { + if (!bigTableCandidateSet.contains(pos)) { + // can't use the current table as the big table, but it's too + // big for the map side. + return -1; + } + + bigTableFound = true; + } + + if (bigInputStat != null) { + // we're replacing the current big table with a new one. Need + // to count the current one as a map table then. + totalSize += bigInputStat.getDataSize(); + } + + if (totalSize/buckets > maxSize) { + // sum of small tables size in this join exceeds configured limit + // hence cannot convert. + return -1; + } + + if (bigTableCandidateSet.contains(pos)) { + bigTablePosition = pos; + bigInputStat = currInputStat; + } + } else { + totalSize += currInputStat.getDataSize(); + if (totalSize/buckets > maxSize) { + // cannot hold all map tables in memory. Cannot convert. + return -1; + } + } + pos++; + } + + return bigTablePosition; + } + + /* + * Once we have decided on the map join, the tree would transform from + * + * | | + * Join MapJoin + * / \ / \ + * RS RS ---> RS TS (big table) + * / \ / + * TS TS TS (small table) + * + * for spark. + */ + + public MapJoinOperator convertJoinMapJoin(JoinOperator joinOp, OptimizeSparkProcContext context, + int bigTablePosition) throws SemanticException { + // bail on mux operator because currently the mux operator masks the emit keys + // of the constituent reduce sinks. + for (Operator parentOp : joinOp.getParentOperators()) { + if (parentOp instanceof MuxOperator) { + return null; + } + } + + //can safely convert the join to a map join. 
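+    // The call below asks MapJoinProcessor to build a MapJoinOperator out of this JoinOperator,
+    // using the QBJoinTree recorded in the parse context and the big-table position chosen by
+    // getMapJoinConversionPos. The big-table ReduceSink parent is then unhooked further down,
+    // so the big side is streamed into the map join instead of being shuffled.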
+ ParseContext parseContext = context.getParseContext(); + MapJoinOperator mapJoinOp = + MapJoinProcessor.convertJoinOpMapJoinOp(context.getConf(), parseContext.getOpParseCtx(), joinOp, + parseContext.getJoinContext().get(joinOp), bigTablePosition, true); + + Operator parentBigTableOp = + mapJoinOp.getParentOperators().get(bigTablePosition); + if (parentBigTableOp instanceof ReduceSinkOperator) { + for (Operator p : parentBigTableOp.getParentOperators()) { + // we might have generated a dynamic partition operator chain. Since + // we're removing the reduce sink we need do remove that too. + Set> dynamicPartitionOperators = new HashSet>(); + for (Operator c : p.getChildOperators()) { + if (hasDynamicPartitionBroadcast(c)) { + dynamicPartitionOperators.add(c); + } + } + for (Operator c : dynamicPartitionOperators) { + p.removeChild(c); + } + } + mapJoinOp.getParentOperators().remove(bigTablePosition); + if (!(mapJoinOp.getParentOperators().contains(parentBigTableOp.getParentOperators().get(0)))) { + mapJoinOp.getParentOperators().add(bigTablePosition, + parentBigTableOp.getParentOperators().get(0)); + } + parentBigTableOp.getParentOperators().get(0).removeChild(parentBigTableOp); + for (Operator op : mapJoinOp.getParentOperators()) { + if (!(op.getChildOperators().contains(mapJoinOp))) { + op.getChildOperators().add(mapJoinOp); + } + op.getChildOperators().remove(joinOp); + } + } + + return mapJoinOp; + } + + private boolean hasDynamicPartitionBroadcast(Operator parent) { + boolean hasDynamicPartitionPruning = false; + + for (Operator op: parent.getChildOperators()) { + while (op != null) { + if (op instanceof AppMasterEventOperator && op.getConf() instanceof DynamicPruningEventDesc) { + // found dynamic partition pruning operator + hasDynamicPartitionPruning = true; + break; + } + + if (op instanceof ReduceSinkOperator || op instanceof FileSinkOperator) { + // crossing reduce sink or file sink means the pruning isn't for this parent. + break; + } + + if (op.getChildOperators().size() != 1) { + // dynamic partition pruning pipeline doesn't have multiple children + break; + } + + op = op.getChildOperators().get(0); + } + } + + return hasDynamicPartitionPruning; + } +} diff --git a/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkProcContext.java b/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkProcContext.java index ed88c60..d33d877 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkProcContext.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkProcContext.java @@ -43,6 +43,7 @@ import org.apache.hadoop.hive.ql.parse.ParseContext; import org.apache.hadoop.hive.ql.plan.BaseWork; import org.apache.hadoop.hive.ql.plan.DependencyCollectionWork; +import org.apache.hadoop.hive.ql.plan.MapWork; import org.apache.hadoop.hive.ql.plan.MoveWork; import org.apache.hadoop.hive.ql.plan.OperatorDesc; import org.apache.hadoop.hive.ql.plan.SparkEdgeProperty; @@ -129,16 +130,28 @@ // remember which reducesinks we've already connected public final Set connectedReduceSinks; + // Alias to operator map, from the semantic analyzer. + // This is necessary as sometimes semantic analyzer's mapping is different than operator's own alias. + public final Map> topOps; + + // Keep track of the current table alias (from last TableScan) + public String currentAliasId; + + // Keep track of the current Table-Scan. 
+ public TableScanOperator currentTs; + + @SuppressWarnings("unchecked") public GenSparkProcContext(HiveConf conf, ParseContext parseContext, List> moveTask, List> rootTasks, - Set inputs, Set outputs) { + Set inputs, Set outputs, Map> topOps) { this.conf = conf; this.parseContext = parseContext; this.moveTask = moveTask; this.rootTasks = rootTasks; this.inputs = inputs; this.outputs = outputs; + this.topOps = topOps; this.currentTask = (SparkTask) TaskFactory.get( new SparkWork(conf.getVar(HiveConf.ConfVars.HIVEQUERYID)), conf); this.rootTasks.add(currentTask); diff --git a/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkUtils.java b/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkUtils.java index 8e28887..3d08d49 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkUtils.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkUtils.java @@ -23,34 +23,42 @@ import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; -import java.util.List; import java.util.LinkedList; +import java.util.List; import java.util.Map; import java.util.Set; -import com.google.common.base.Strings; -import org.apache.hadoop.fs.Path; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; import org.apache.hadoop.hive.conf.HiveConf; -import org.apache.hadoop.hive.ql.exec.*; -import org.apache.hadoop.hive.ql.exec.spark.SparkTask; +import org.apache.hadoop.hive.ql.exec.FetchTask; +import org.apache.hadoop.hive.ql.exec.FileSinkOperator; +import org.apache.hadoop.hive.ql.exec.GroupByOperator; +import org.apache.hadoop.hive.ql.exec.HashTableDummyOperator; +import org.apache.hadoop.hive.ql.exec.JoinOperator; +import org.apache.hadoop.hive.ql.exec.Operator; +import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator; +import org.apache.hadoop.hive.ql.exec.TableScanOperator; +import org.apache.hadoop.hive.ql.exec.UnionOperator; +import org.apache.hadoop.hive.ql.exec.Utilities; import org.apache.hadoop.hive.ql.optimizer.GenMapRedUtils; import org.apache.hadoop.hive.ql.parse.ParseContext; import org.apache.hadoop.hive.ql.parse.PrunedPartitionList; import org.apache.hadoop.hive.ql.parse.SemanticException; - -import com.google.common.base.Preconditions; import org.apache.hadoop.hive.ql.plan.BaseWork; import org.apache.hadoop.hive.ql.plan.MapWork; +import org.apache.hadoop.hive.ql.plan.MapredLocalWork; import org.apache.hadoop.hive.ql.plan.OperatorDesc; import org.apache.hadoop.hive.ql.plan.ReduceWork; import org.apache.hadoop.hive.ql.plan.SparkEdgeProperty; import org.apache.hadoop.hive.ql.plan.SparkWork; -import org.apache.hadoop.hive.ql.plan.TableDesc; import org.apache.hadoop.hive.ql.plan.UnionWork; +import com.google.common.base.Preconditions; +import com.google.common.base.Strings; + /** * GenSparkUtils is a collection of shared helper methods to produce SparkWork * Cloned from GenTezUtils. @@ -313,4 +321,28 @@ public static boolean isSortNecessary(ReduceSinkOperator reduceSinkOperator) { } return true; } + + + /** + * Is an operator of the given class a child of the given operator. This is more flexible + * than GraphWalker to tell apart subclasses such as SMBMapJoinOp vs MapJoinOp that have a common name. 
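+   * Returns the first matching descendant found by a depth-first search; the search starts at
+   * {@code op} itself and returns null when no operator of the requested class is found.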
+ * @param op parent operator to start search + * @param klazz given class + * @return + * @throws SemanticException + */ + public static Operator getChildOperator(Operator op, Class klazz) throws SemanticException { + if (klazz.isInstance(op)) { + return op; + } + List> childOperators = op.getChildOperators(); + for (Operator childOp : childOperators) { + Operator result = getChildOperator(childOp, klazz); + if (result != null) { + return result; + } + } + return null; + } + } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkWork.java b/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkWork.java index 4f5feca..b94db6b 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkWork.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkWork.java @@ -25,19 +25,29 @@ import java.util.Map.Entry; import java.util.Stack; -import com.google.common.base.Strings; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.hive.conf.HiveConf; -import org.apache.hadoop.hive.ql.exec.*; +import org.apache.hadoop.hive.ql.exec.HashTableDummyOperator; +import org.apache.hadoop.hive.ql.exec.JoinOperator; +import org.apache.hadoop.hive.ql.exec.MapJoinOperator; +import org.apache.hadoop.hive.ql.exec.Operator; +import org.apache.hadoop.hive.ql.exec.OperatorFactory; +import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator; import org.apache.hadoop.hive.ql.lib.Node; import org.apache.hadoop.hive.ql.lib.NodeProcessor; import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx; import org.apache.hadoop.hive.ql.optimizer.GenMapRedUtils; import org.apache.hadoop.hive.ql.parse.SemanticException; -import org.apache.hadoop.hive.ql.plan.*; +import org.apache.hadoop.hive.ql.plan.BaseWork; +import org.apache.hadoop.hive.ql.plan.ReduceSinkDesc; +import org.apache.hadoop.hive.ql.plan.ReduceWork; +import org.apache.hadoop.hive.ql.plan.SparkEdgeProperty; +import org.apache.hadoop.hive.ql.plan.SparkWork; +import org.apache.hadoop.hive.ql.plan.UnionWork; import com.google.common.base.Preconditions; +import com.google.common.base.Strings; /** * GenSparkWork separates the operator tree into spark tasks. 
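For illustration, a caller could use the GenSparkUtils.getChildOperator helper introduced above to test whether the optimizer placed a map join somewhere below a given operator. This is only a sketch against this patch: the starting operator (tableScanOp) and what is done with the result are assumptions, not code from the change.

// Hypothetical caller; tableScanOp is assumed to be some TableScanOperator in the plan.
Operator<? extends OperatorDesc> mapJoin =
    GenSparkUtils.getChildOperator(tableScanOp, MapJoinOperator.class);
if (mapJoin != null) {
  // The branch rooted at tableScanOp eventually feeds a MapJoinOperator, so it needs
  // map-join (small-table/broadcast) handling rather than a plain shuffle edge.
}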
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkCompiler.java b/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkCompiler.java
index 1c663c4..6d21eee 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkCompiler.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkCompiler.java
@@ -34,6 +34,8 @@
 import org.apache.hadoop.hive.ql.Context;
 import org.apache.hadoop.hive.ql.exec.ConditionalTask;
 import org.apache.hadoop.hive.ql.exec.FileSinkOperator;
+import org.apache.hadoop.hive.ql.exec.JoinOperator;
+import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
 import org.apache.hadoop.hive.ql.exec.Operator;
 import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
 import org.apache.hadoop.hive.ql.exec.TableScanOperator;
@@ -59,6 +61,8 @@
 import org.apache.hadoop.hive.ql.optimizer.physical.StageIDsRearranger;
 import org.apache.hadoop.hive.ql.optimizer.physical.Vectorizer;
 import org.apache.hadoop.hive.ql.optimizer.spark.SetSparkReducerParallelism;
+import org.apache.hadoop.hive.ql.optimizer.spark.SparkMapJoinOptimizer;
+import org.apache.hadoop.hive.ql.optimizer.spark.SparkReduceSinkMapJoinProc;
 import org.apache.hadoop.hive.ql.parse.GlobalLimitCtx;
 import org.apache.hadoop.hive.ql.parse.ParseContext;
 import org.apache.hadoop.hive.ql.parse.SemanticException;
@@ -69,7 +73,6 @@
 import org.apache.hadoop.hive.ql.plan.OperatorDesc;
 import org.apache.hadoop.hive.ql.plan.SparkWork;
 import org.apache.hadoop.hive.ql.session.SessionState.LogHelper;
-
 /**
  * SparkCompiler translates the operator plan into SparkTasks.
  *
@@ -111,8 +114,8 @@ protected void optimizeOperatorPlan(ParseContext pCtx, Set inputs,
         new SetSparkReducerParallelism());
 
     // TODO: need to research and verify support convert join to map join optimization.
-    //opRules.put(new RuleRegExp(new String("Convert Join to Map-join"),
-    //  JoinOperator.getOperatorName() + "%"), new ConvertJoinMapJoin());
+    opRules.put(new RuleRegExp(new String("Convert Join to Map-join"),
+        JoinOperator.getOperatorName() + "%"), new SparkMapJoinOptimizer());
 
     // The dispatcher fires the processor corresponding to the closest matching
     // rule and passes the context along
@@ -136,7 +139,7 @@ protected void generateTaskTree(List> rootTasks, Pa
     GenSparkWork genSparkWork = new GenSparkWork(GenSparkUtils.getUtils());
 
     GenSparkProcContext procCtx = new GenSparkProcContext(
-        conf, tempParseContext, mvTask, rootTasks, inputs, outputs);
+        conf, tempParseContext, mvTask, rootTasks, inputs, outputs, pCtx.getTopOps());
 
     // create a walker which walks the tree in a DFS manner while maintaining
     // the operator stack. The dispatcher generates the plan from the operator tree
@@ -144,8 +147,8 @@ protected void generateTaskTree(List> rootTasks, Pa
     opRules.put(new RuleRegExp("Split Work - ReduceSink",
         ReduceSinkOperator.getOperatorName() + "%"), genSparkWork);
 
-// opRules.put(new RuleRegExp("No more walking on ReduceSink-MapJoin",
-//     MapJoinOperator.getOperatorName() + "%"), new ReduceSinkMapJoinProc());
+    opRules.put(new RuleRegExp("No more walking on ReduceSink-MapJoin",
+        MapJoinOperator.getOperatorName() + "%"), new SparkReduceSinkMapJoinProc());
 
     opRules.put(new RuleRegExp("Split Work + Move/Merge - FileSink",
         FileSinkOperator.getOperatorName() + "%"),
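The size-based check in SparkMapJoinOptimizer.getMapJoinConversionPos (added above) can be read apart from the operator tree: at most one input may exceed the HIVECONVERTJOINNOCONDITIONALTASKTHRESHOLD limit and be streamed as the big table, and only if the join conditions allow that position to be the big table, while all remaining inputs must fit under the limit in aggregate. The sketch below restates that heuristic over plain sizes with the bucket divisor fixed at 1, which is how the Spark path above invokes it; the class and method names (BigTableSelectionSketch, chooseBigTable) are illustrative only and not part of Hive.

// Illustrative, self-contained restatement of the big-table selection heuristic.
import java.util.Set;

public final class BigTableSelectionSketch {

  /**
   * @param inputSizes         estimated data size of each join input, in bytes
   * @param bigTableCandidates positions allowed to be the streamed big table,
   *                           as derived from the join conditions
   * @param maxSmallTableSize  threshold corresponding to
   *                           HiveConf.ConfVars.HIVECONVERTJOINNOCONDITIONALTASKTHRESHOLD
   * @return chosen big-table position, or -1 if the join cannot become a map join
   */
  public static int chooseBigTable(long[] inputSizes, Set<Integer> bigTableCandidates,
      long maxSmallTableSize) {
    int bigTablePosition = -1;
    long bigTableSize = -1;   // size of the current big-table choice
    long totalSmallSize = 0;  // sum of everything that must fit in memory
    boolean oversizedTableSeen = false;

    for (int pos = 0; pos < inputSizes.length; pos++) {
      long size = inputSizes[pos];
      if (size > bigTableSize) {
        // This input is bigger than the current big-table choice.
        if (oversizedTableSeen) {
          return -1; // a second over-threshold input: neither side fits in memory
        }
        if (size > maxSmallTableSize) {
          if (!bigTableCandidates.contains(pos)) {
            return -1; // too big for the hash side, but not allowed to be streamed
          }
          oversizedTableSeen = true;
        }
        if (bigTablePosition >= 0) {
          totalSmallSize += bigTableSize; // previous choice now counts as a small table
        }
        if (totalSmallSize > maxSmallTableSize) {
          return -1; // small tables no longer fit in memory
        }
        if (bigTableCandidates.contains(pos)) {
          bigTablePosition = pos;
          bigTableSize = size;
        }
      } else {
        totalSmallSize += size;
        if (totalSmallSize > maxSmallTableSize) {
          return -1;
        }
      }
    }
    return bigTablePosition;
  }
}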