diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/ConvertJoinMapJoin.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/ConvertJoinMapJoin.java
new file mode 100644
index 0000000..dc2188f
--- /dev/null
+++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/ConvertJoinMapJoin.java
@@ -0,0 +1,189 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.optimizer;
+
+import java.util.Set;
+import java.util.Stack;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.ql.exec.JoinOperator;
+import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
+import org.apache.hadoop.hive.ql.exec.Operator;
+import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
+import org.apache.hadoop.hive.ql.lib.Node;
+import org.apache.hadoop.hive.ql.lib.NodeProcessor;
+import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.optimizer.MapJoinProcessor;
+import org.apache.hadoop.hive.ql.parse.OptimizeTezProcContext;
+import org.apache.hadoop.hive.ql.parse.ParseContext;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.plan.OperatorDesc;
+import org.apache.hadoop.hive.ql.plan.Statistics;
+
+/**
+ * ConvertJoinMapJoin is an optimization that replaces a common join
+ * (aka shuffle join) with a map join (aka broadcast or fragment replicate
+ * join) when possible. Map joins have restrictions on which joins can be
+ * converted (e.g.: full outer joins cannot be handled as map joins) as well
+ * as memory restrictions (one side of the join has to fit into memory).
+ */
+public class ConvertJoinMapJoin implements NodeProcessor {
+
+  static final private Log LOG = LogFactory.getLog(ConvertJoinMapJoin.class.getName());
+
+  @Override
+  /*
+   * (non-Javadoc)
+   * We should ideally not modify the tree we traverse. However, since we
+   * need to walk the tree anyway when we modify an operator, we might as
+   * well do it here.
+   */
+  public Object process(Node nd, Stack<Node> stack,
+      NodeProcessorCtx procCtx, Object... nodeOutputs)
+      throws SemanticException {
+
+    OptimizeTezProcContext context = (OptimizeTezProcContext) procCtx;
+
+    if (!context.conf.getBoolVar(HiveConf.ConfVars.HIVECONVERTJOIN)) {
+      return null;
+    }
+
+    JoinOperator joinOp = (JoinOperator) nd;
+
+    Set<Integer> bigTableCandidateSet = MapJoinProcessor.
+        getBigTableCandidates(joinOp.getConf().getConds());
+
+    long maxSize = context.conf.getLongVar(
+        HiveConf.ConfVars.HIVECONVERTJOINNOCONDITIONALTASKTHRESHOLD);
+
+    int bigTablePosition = -1;
+
+    Statistics bigInputStat = null;
+    long totalSize = 0;
+    int pos = 0;
+
+    // bigTableFound means we've encountered a table that's bigger than the
+    // max. This table is either the big table or we cannot convert.
+    boolean bigTableFound = false;
+
+    for (Operator<? extends OperatorDesc> parentOp : joinOp.getParentOperators()) {
+
+      Statistics currInputStat = null;
+      try {
+        currInputStat = parentOp.getStatistics(context.conf);
+      } catch (HiveException e) {
+        return null;
+      }
+
+      long inputSize = currInputStat.getNumberOfBytes();
+      if ((bigInputStat == null) ||
+          ((bigInputStat != null) &&
+           (inputSize > bigInputStat.getNumberOfBytes()))) {
+
+        if (bigTableFound) {
+          // cannot convert to map join; we've already chosen a big table
+          // on size and there's another one that's bigger.
+          return null;
+        }
+
+        if (inputSize > maxSize) {
+          if (!bigTableCandidateSet.contains(pos)) {
+            // the current table is not a valid big-table candidate, and it
+            // is too big to be held in memory on the map side.
+            return null;
+          }
+
+          bigTableFound = true;
+        }
+
+        if (bigInputStat != null) {
+          // we're replacing the current big table with a new one. Need
+          // to count the current one as a map table then.
+          totalSize += bigInputStat.getNumberOfBytes();
+        }
+
+        if (totalSize > maxSize) {
+          // the sum of the small tables' sizes in this join exceeds the
+          // configured limit, hence we cannot convert.
+          return null;
+        }
+
+        if (bigTableCandidateSet.contains(pos)) {
+          bigTablePosition = pos;
+          bigInputStat = currInputStat;
+        }
+      } else {
+        totalSize += currInputStat.getNumberOfBytes();
+        if (totalSize > maxSize) {
+          // cannot hold all map tables in memory. Cannot convert.
+          return null;
+        }
+      }
+      pos++;
+    }
+
+    if (bigTablePosition == -1) {
+      // all tables have size 0. We let the shuffle join handle this case.
+      return null;
+    }
+
+    /*
+     * Once we have decided on the map join, the tree would transform from
+     *
+     *        |                     |
+     *       Join                MapJoin
+     *       / \                  /   \
+     *     RS   RS   --->       RS    TS (big table)
+     *    /      \             /
+     *  TS        TS          TS (small table)
+     *
+     * for tez.
+     */
+
+    // convert to a map join operator with this information
+    ParseContext parseContext = context.parseContext;
+    MapJoinOperator mapJoinOp = MapJoinProcessor.
+        convertJoinOpMapJoinOp(parseContext.getOpParseCtx(),
+            joinOp, parseContext.getJoinContext().get(joinOp), bigTablePosition, true, false);
+
+    Operator<? extends OperatorDesc> parentBigTableOp
+        = mapJoinOp.getParentOperators().get(bigTablePosition);
+
+    if (parentBigTableOp instanceof ReduceSinkOperator) {
+      mapJoinOp.getParentOperators().remove(bigTablePosition);
+      if (!(mapJoinOp.getParentOperators().contains(
+          parentBigTableOp.getParentOperators().get(0)))) {
+        mapJoinOp.getParentOperators().add(bigTablePosition,
+            parentBigTableOp.getParentOperators().get(0));
+      }
+      parentBigTableOp.getParentOperators().get(0).removeChild(parentBigTableOp);
+      for (Operator<? extends OperatorDesc> op : mapJoinOp.getParentOperators()) {
+        if (!(op.getChildOperators().contains(mapJoinOp))) {
+          op.getChildOperators().add(mapJoinOp);
+        }
+        op.getChildOperators().remove(joinOp);
+      }
+    }
+
+    return null;
+  }
+}
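For reference, the decision logic above reduces to a small size-based policy: pick the largest input that is allowed to act as the big (streamed) table, and convert only if all remaining inputs together fit under hive.auto.convert.join.noconditionaltask.size. The following is a minimal, self-contained sketch of that policy with illustrative names only; it is not part of the patch and does not use Hive's operator or statistics APIs.

    import java.util.Set;

    final class BigTableSelectionSketch {

      /**
       * Returns the position of the input to stream (the "big table"), or -1
       * if the join cannot be converted under the given size threshold.
       */
      static int selectBigTable(long[] inputSizes, Set<Integer> bigTableCandidates, long maxSize) {
        int bigTablePos = -1;     // current choice of streamed input
        long bigTableSize = -1;   // size of that choice
        long smallTablesSize = 0; // running total of everything that must be hashed

        for (int pos = 0; pos < inputSizes.length; pos++) {
          long size = inputSizes[pos];
          if (size > maxSize && !bigTableCandidates.contains(pos)) {
            return -1; // an input that must be hashed does not fit in memory
          }
          if (size > bigTableSize && bigTableCandidates.contains(pos)) {
            // demote the previous choice to the hashed side
            if (bigTablePos != -1) {
              smallTablesSize += bigTableSize;
            }
            bigTablePos = pos;
            bigTableSize = size;
          } else {
            smallTablesSize += size;
          }
          if (smallTablesSize > maxSize) {
            return -1; // the hashed side as a whole does not fit in memory
          }
        }
        return bigTablePos;
      }
    }

For example, with a 10 MB threshold, inputs of 100 MB, 4 MB and 3 MB, and only position 0 allowed to stream, position 0 is chosen and the remaining 7 MB fits, so the join converts; if either small input grew past the threshold, the method would return -1 and the shuffle join would be kept, mirroring the early returns above.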
diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/MapJoinProcessor.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/MapJoinProcessor.java
index 31ca07a..0a08bec 100644
--- ql/src/java/org/apache/hadoop/hive/ql/optimizer/MapJoinProcessor.java
+++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/MapJoinProcessor.java
@@ -373,21 +373,90 @@ public static MapJoinOperator convertMapJoin(
       pos++;
     }
 
-    // get the join keys from old parent ReduceSink operators
+    // create the map-join operator
+    MapJoinOperator mapJoinOp = convertJoinOpMapJoinOp(opParseCtxMap,
+        op, joinTree, mapJoinPos, noCheckOuterJoin, validateMapJoinTree);
+
+
+    // remove old parents
     for (pos = 0; pos < newParentOps.size(); pos++) {
-      ReduceSinkOperator oldPar = (ReduceSinkOperator) oldReduceSinkParentOps.get(pos);
-      ReduceSinkDesc rsconf = oldPar.getConf();
+      newParentOps.get(pos).removeChild(oldReduceSinkParentOps.get(pos));
+      newParentOps.get(pos).getChildOperators().add(mapJoinOp);
+    }
+
+
+    mapJoinOp.getParentOperators().removeAll(oldReduceSinkParentOps);
+    mapJoinOp.setParentOperators(newParentOps);
+
+
+    // change the children of the original join operator to point to the map
+    // join operator
+
+    return mapJoinOp;
+  }
+
+  public static MapJoinOperator convertJoinOpMapJoinOp(
+      LinkedHashMap<Operator<? extends OperatorDesc>, OpParseContext> opParseCtxMap,
+      JoinOperator op, QBJoinTree joinTree, int mapJoinPos, boolean noCheckOuterJoin,
+      boolean validateMapJoinTree)
+      throws SemanticException {
+
+    JoinDesc desc = op.getConf();
+    JoinCondDesc[] condns = desc.getConds();
+    Byte[] tagOrder = desc.getTagOrder();
+
+    // outer join cannot be performed on a table which is being cached
+    if (!noCheckOuterJoin) {
+      if (checkMapJoin(mapJoinPos, condns) < 0) {
+        throw new SemanticException(ErrorMsg.NO_OUTER_MAPJOIN.getMsg());
+      }
+    }
+
+    Map<Byte, List<ExprNodeDesc>> keyExprMap = new HashMap<Byte, List<ExprNodeDesc>>();
+
+    // Walk over all the sources (which are guaranteed to be reduce sink
+    // operators).
+    // The join outputs a concatenation of all the inputs.
+    QBJoinTree leftSrc = joinTree.getJoinSrc();
+    List<Operator<? extends OperatorDesc>> oldReduceSinkParentOps =
+        new ArrayList<Operator<? extends OperatorDesc>>();
+    if (leftSrc != null) {
+      // assert mapJoinPos == 0;
+      Operator<? extends OperatorDesc> parentOp = op.getParentOperators().get(0);
+      assert parentOp.getParentOperators().size() == 1;
+      oldReduceSinkParentOps.add(parentOp);
+    }
+
+
+    byte pos = 0;
+    for (String src : joinTree.getBaseSrc()) {
+      if (src != null) {
+        Operator<? extends OperatorDesc> parentOp = op.getParentOperators().get(pos);
+        assert parentOp.getParentOperators().size() == 1;
+        Operator<? extends OperatorDesc> grandParentOp =
+            parentOp.getParentOperators().get(0);
+
+        oldReduceSinkParentOps.add(parentOp);
+      }
+      pos++;
+    }
+
+    // get the join keys from old parent ReduceSink operators
+    for (pos = 0; pos < op.getParentOperators().size(); pos++) {
+      ReduceSinkOperator parent = (ReduceSinkOperator) oldReduceSinkParentOps.get(pos);
+      ReduceSinkDesc rsconf = parent.getConf();
       List<ExprNodeDesc> keys = rsconf.getKeyCols();
       keyExprMap.put(pos, keys);
     }
 
-    // removing RS, only ExprNodeDesc is changed (key/value/filter exprs and colExprMap)
-    // others (output column-name, RR, schema) remain intact
-    Map<String, ExprNodeDesc> colExprMap = op.getColumnExprMap();
-    List<String> outputColumnNames = op.getConf().getOutputColumnNames();
+    List<ExprNodeDesc> keyCols = keyExprMap.get(Byte.valueOf((byte) 0));
+    StringBuilder keyOrder = new StringBuilder();
+    for (int i = 0; i < keyCols.size(); i++) {
+      keyOrder.append("+");
+    }
+    Map<String, ExprNodeDesc> colExprMap = op.getColumnExprMap();
     List<ColumnInfo> schema = new ArrayList<ColumnInfo>(op.getSchema().getSignature());
-
     Map<Byte, List<ExprNodeDesc>> valueExprs = op.getConf().getExprs();
     Map<Byte, List<ExprNodeDesc>> newValueExprs = new HashMap<Byte, List<ExprNodeDesc>>();
     for (Map.Entry<Byte, List<ExprNodeDesc>> entry : valueExprs.entrySet()) {
@@ -411,45 +480,12 @@ public static MapJoinOperator convertMapJoin(
       }
     }
 
-    Map<Byte, List<ExprNodeDesc>> filters = desc.getFilters();
-    Map<Byte, List<ExprNodeDesc>> newFilters = new HashMap<Byte, List<ExprNodeDesc>>();
-    for (Map.Entry<Byte, List<ExprNodeDesc>> entry : filters.entrySet()) {
-      byte srcTag = entry.getKey();
-      List<ExprNodeDesc> filter = entry.getValue();
-
-      Operator<?> terminal = oldReduceSinkParentOps.get(srcTag);
-      newFilters.put(srcTag, ExprNodeDescUtils.backtrack(filter, op, terminal));
-    }
-    desc.setFilters(filters = newFilters);
-
-    // remove old parents
-    for (pos = 0; pos < newParentOps.size(); pos++) {
-      newParentOps.get(pos).removeChild(oldReduceSinkParentOps.get(pos));
-    }
-
-    JoinCondDesc[] joinCondns = op.getConf().getConds();
-
-    Operator[] newPar = new Operator[newParentOps.size()];
-    pos = 0;
-    for (Operator<? extends OperatorDesc> o : newParentOps) {
-      newPar[pos++] = o;
-    }
-
-    List<ExprNodeDesc> keyCols = keyExprMap.get(Byte.valueOf((byte) 0));
-    StringBuilder keyOrder = new StringBuilder();
-    for (int i = 0; i < keyCols.size(); i++) {
-      keyOrder.append("+");
-    }
-
-    TableDesc keyTableDesc = PlanUtils.getMapJoinKeyTableDesc(PlanUtils
-        .getFieldSchemasFromColumnList(keyCols, MAPJOINKEY_FIELDPREFIX));
-
+    // construct valueTableDescs and valueFilteredTableDescs
     List<TableDesc> valueTableDescs = new ArrayList<TableDesc>();
     List<TableDesc> valueFiltedTableDescs = new ArrayList<TableDesc>();
-    int[][] filterMap = desc.getFilterMap();
-    for (pos = 0; pos < newParentOps.size(); pos++) {
-      List<ExprNodeDesc> valueCols = newValueExprs.get(pos);
+    for (pos = 0; pos < op.getParentOperators().size(); pos++) {
+      List<ExprNodeDesc> valueCols = newValueExprs.get(Byte.valueOf((byte) pos));
       int length = valueCols.size();
       List<ExprNodeDesc> valueFilteredCols = new ArrayList<ExprNodeDesc>(length);
       // deep copy expr node desc
@@ -476,6 +512,19 @@ public static MapJoinOperator convertMapJoin(
       valueTableDescs.add(valueTableDesc);
       valueFiltedTableDescs.add(valueFilteredTableDesc);
     }
+
+    Map<Byte, List<ExprNodeDesc>> filters = desc.getFilters();
+    Map<Byte, List<ExprNodeDesc>> newFilters = new HashMap<Byte, List<ExprNodeDesc>>();
+    for (Map.Entry<Byte, List<ExprNodeDesc>> entry : filters.entrySet()) {
+      byte srcTag = entry.getKey();
+      List<ExprNodeDesc> filter = entry.getValue();
+
+      Operator<?> terminal = op.getParentOperators().get(srcTag);
+      newFilters.put(srcTag, ExprNodeDescUtils.backtrack(filter, op, terminal));
+    }
+    desc.setFilters(filters = newFilters);
+
+    // create dumpfile prefix needed to create descriptor
     String dumpFilePrefix = "";
     if( joinTree.getMapAliases() != null ) {
       for(String mapAlias : joinTree.getMapAliases()) {
@@ -485,6 +534,11 @@ public static MapJoinOperator convertMapJoin(
     } else {
       dumpFilePrefix = "mapfile"+PlanUtils.getCountForMapJoinDumpFilePrefix();
     }
+
+    List<String> outputColumnNames = op.getConf().getOutputColumnNames();
+    TableDesc keyTableDesc = PlanUtils.getMapJoinKeyTableDesc(PlanUtils
+        .getFieldSchemasFromColumnList(keyCols, MAPJOINKEY_FIELDPREFIX));
+    JoinCondDesc[] joinCondns = op.getConf().getConds();
     MapJoinDesc mapJoinDescriptor = new MapJoinDesc(keyExprMap, keyTableDesc, newValueExprs,
         valueTableDescs, valueFiltedTableDescs, outputColumnNames, mapJoinPos, joinCondns,
         filters, op.getConf().getNoOuterJoin(), dumpFilePrefix);
@@ -492,8 +546,11 @@ public static MapJoinOperator convertMapJoin(
     mapJoinDescriptor.setNullSafes(desc.getNullSafes());
     mapJoinDescriptor.setFilterMap(desc.getFilterMap());
 
+    // reduce sink row resolver used to generate map join op
+    RowResolver outputRS = opParseCtxMap.get(op).getRowResolver();
+
     MapJoinOperator mapJoinOp = (MapJoinOperator) OperatorFactory.getAndMakeChild(
-        mapJoinDescriptor, new RowSchema(outputRS.getColumnInfos()), newPar);
+        mapJoinDescriptor, new RowSchema(outputRS.getColumnInfos()), op.getParentOperators());
 
     OpParseContext ctx = new OpParseContext(outputRS);
     opParseCtxMap.put(mapJoinOp, ctx);
@@ -501,15 +558,12 @@ public static MapJoinOperator convertMapJoin(
     mapJoinOp.getConf().setReversedExprs(op.getConf().getReversedExprs());
     mapJoinOp.setColumnExprMap(colExprMap);
 
-    // change the children of the original join operator to point to the map
-    // join operator
     List<Operator<? extends OperatorDesc>> childOps = op.getChildOperators();
     for (Operator<? extends OperatorDesc> childOp : childOps) {
       childOp.replaceParent(op, mapJoinOp);
     }
 
     mapJoinOp.setChildOperators(childOps);
-    mapJoinOp.setParentOperators(newParentOps);
     op.setChildOperators(null);
     op.setParentOperators(null);
@@ -519,6 +573,7 @@ public static MapJoinOperator convertMapJoin(
     }
 
     return mapJoinOp;
+
   }
 
   /**
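As refactored above, convertMapJoin now delegates construction of the MapJoinOperator to convertJoinOpMapJoinOp, which builds the new operator directly on top of the join's ReduceSink parents; the callers are then responsible for splicing those ReduceSinks out of the operator DAG (convertMapJoin swaps in their parents wholesale, while ConvertJoinMapJoin removes only the one above the big table). A toy model of that splice, using a hypothetical node class rather than Hive's Operator API, is sketched below; it is illustrative only.

    import java.util.ArrayList;
    import java.util.List;

    /** Toy DAG node used only to illustrate the ReduceSink splice; not Hive's Operator API. */
    final class DagNodeSketch {
      final String name;
      final List<DagNodeSketch> parents = new ArrayList<DagNodeSketch>();
      final List<DagNodeSketch> children = new ArrayList<DagNodeSketch>();

      DagNodeSketch(String name) {
        this.name = name;
      }

      /** Splice this node out of the DAG, assuming it has exactly one parent. */
      void bypass() {
        DagNodeSketch parent = parents.get(0);
        parent.children.remove(this);
        for (DagNodeSketch child : children) {
          if (!parent.children.contains(child)) {
            parent.children.add(child);
          }
          // keep the child's parent slot (the join tag) stable
          child.parents.set(child.parents.indexOf(this), parent);
        }
        parents.clear();
        children.clear();
      }
    }

In the patch the same effect is achieved with removeChild and replaceParent calls, plus the explicit add at bigTablePosition in ConvertJoinMapJoin so the big table keeps its original position among the map join's parents.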
diff --git ql/src/java/org/apache/hadoop/hive/ql/parse/TezCompiler.java ql/src/java/org/apache/hadoop/hive/ql/parse/TezCompiler.java
index 7e51310..42cc024 100644
--- ql/src/java/org/apache/hadoop/hive/ql/parse/TezCompiler.java
+++ ql/src/java/org/apache/hadoop/hive/ql/parse/TezCompiler.java
@@ -32,6 +32,7 @@
 import org.apache.hadoop.hive.ql.Context;
 import org.apache.hadoop.hive.ql.exec.ConditionalTask;
 import org.apache.hadoop.hive.ql.exec.FileSinkOperator;
+import org.apache.hadoop.hive.ql.exec.JoinOperator;
 import org.apache.hadoop.hive.ql.exec.Operator;
 import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
 import org.apache.hadoop.hive.ql.exec.Task;
@@ -45,6 +46,7 @@
 import org.apache.hadoop.hive.ql.lib.NodeProcessor;
 import org.apache.hadoop.hive.ql.lib.Rule;
 import org.apache.hadoop.hive.ql.lib.RuleRegExp;
+import org.apache.hadoop.hive.ql.optimizer.ConvertJoinMapJoin;
 import org.apache.hadoop.hive.ql.plan.BaseWork;
 import org.apache.hadoop.hive.ql.plan.MapWork;
 import org.apache.hadoop.hive.ql.plan.MoveWork;
@@ -79,6 +81,8 @@ protected void optimizeOperatorPlan(ParseContext pCtx, Set<ReadEntity> inputs,
     opRules.put(new RuleRegExp(new String("Set parallelism - ReduceSink"),
         ReduceSinkOperator.getOperatorName() + "%"),
         new SetReducerParallelism());
+    opRules.put(new RuleRegExp(new String("Convert Join to Map-join"),
+        JoinOperator.getOperatorName() + "%"), new ConvertJoinMapJoin());
 
     // if this is an explain statement add rule to generate statistics for
     // the whole tree.
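The new rule only fires when automatic join conversion is enabled, and the memory budget comes from the noconditionaltask threshold read in ConvertJoinMapJoin; both settings already exist in HiveConf and are referenced directly in the patch. A small sketch of how a test or session setup might enable the path, assuming the standard HiveConf setters (the 10 MB figure is arbitrary):

    import org.apache.hadoop.hive.conf.HiveConf;

    public class EnableMapJoinConversionSketch {
      public static void main(String[] args) {
        HiveConf conf = new HiveConf();
        // gate checked at the top of ConvertJoinMapJoin.process()
        conf.setBoolVar(HiveConf.ConfVars.HIVECONVERTJOIN, true);
        // upper bound on the combined size of the hashed (small-table) side
        conf.setLongVar(HiveConf.ConfVars.HIVECONVERTJOINNOCONDITIONALTASKTHRESHOLD,
            10L * 1024 * 1024);
      }
    }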