diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/GenSparkSkewJoinProcessor.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/GenSparkSkewJoinProcessor.java
new file mode 100644
index 0000000..105463f
--- /dev/null
+++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/GenSparkSkewJoinProcessor.java
@@ -0,0 +1,467 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.optimizer.physical;
+
+import com.google.common.base.Preconditions;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.ql.exec.ColumnInfo;
+import org.apache.hadoop.hive.ql.exec.ConditionalTask;
+import org.apache.hadoop.hive.ql.exec.HashTableDummyOperator;
+import org.apache.hadoop.hive.ql.exec.JoinOperator;
+import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
+import org.apache.hadoop.hive.ql.exec.Operator;
+import org.apache.hadoop.hive.ql.exec.OperatorFactory;
+import org.apache.hadoop.hive.ql.exec.RowSchema;
+import org.apache.hadoop.hive.ql.exec.SparkHashTableSinkOperator;
+import org.apache.hadoop.hive.ql.exec.TableScanOperator;
+import org.apache.hadoop.hive.ql.exec.Task;
+import org.apache.hadoop.hive.ql.exec.TaskFactory;
+import org.apache.hadoop.hive.ql.exec.Utilities;
+import org.apache.hadoop.hive.ql.exec.spark.SparkTask;
+import org.apache.hadoop.hive.ql.io.HiveInputFormat;
+import org.apache.hadoop.hive.ql.optimizer.GenMapRedUtils;
+import org.apache.hadoop.hive.ql.parse.ParseContext;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.parse.spark.GenSparkUtils;
+import org.apache.hadoop.hive.ql.plan.BaseWork;
+import org.apache.hadoop.hive.ql.plan.ConditionalResolverSkewJoin;
+import org.apache.hadoop.hive.ql.plan.ConditionalWork;
+import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc;
+import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
+import org.apache.hadoop.hive.ql.plan.HashTableDummyDesc;
+import org.apache.hadoop.hive.ql.plan.JoinDesc;
+import org.apache.hadoop.hive.ql.plan.MapJoinDesc;
+import org.apache.hadoop.hive.ql.plan.MapWork;
+import org.apache.hadoop.hive.ql.plan.OperatorDesc;
+import org.apache.hadoop.hive.ql.plan.PartitionDesc;
+import org.apache.hadoop.hive.ql.plan.PlanUtils;
+import org.apache.hadoop.hive.ql.plan.ReduceWork;
+import org.apache.hadoop.hive.ql.plan.SparkEdgeProperty;
+import org.apache.hadoop.hive.ql.plan.SparkHashTableSinkDesc;
+import org.apache.hadoop.hive.ql.plan.SparkWork;
+import org.apache.hadoop.hive.ql.plan.TableDesc;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Copied from GenMRSkewJoinProcessor. It is used to handle skew joins for Spark tasks.
+ */
+public class GenSparkSkewJoinProcessor {
+  private static final Log LOG = LogFactory.getLog(GenSparkSkewJoinProcessor.class.getName());
+
+  private GenSparkSkewJoinProcessor() {
+    // prevent instantiation
+  }
+
+  public static void processSkewJoin(JoinOperator joinOp, Task<? extends Serializable> currTask,
+      ReduceWork reduceWork, ParseContext parseCtx) throws SemanticException {
+
+    // We are trying to add map joins to handle skew keys, and map join right
+    // now does not work with outer joins
+    if (!GenMRSkewJoinProcessor.skewJoinEnabled(parseCtx.getConf(), joinOp) ||
+        !(currTask instanceof SparkTask)) {
+      return;
+    }
+    SparkWork currentWork = ((SparkTask) currTask).getWork();
+    if (!supportRuntimeSkewJoin(currentWork, reduceWork)) {
+      return;
+    }
+
+    List<Task<? extends Serializable>> children = currTask.getChildTasks();
+    if (children != null && children.size() > 1) {
+      LOG.warn("Skip runtime skew join as current task has multiple children.");
+      return;
+    }
+
+    Task<? extends Serializable> child =
+        children != null && children.size() == 1 ? children.get(0) : null;
+
+    Path baseTmpDir = parseCtx.getContext().getMRTmpPath();
+
+    JoinDesc joinDescriptor = joinOp.getConf();
+    Map<Byte, List<ExprNodeDesc>> joinValues = joinDescriptor.getExprs();
+    int numAliases = joinValues.size();
+
+    Map<Byte, Path> bigKeysDirMap = new HashMap<Byte, Path>();
+    Map<Byte, Map<Byte, Path>> smallKeysDirMap = new HashMap<Byte, Map<Byte, Path>>();
+    Map<Byte, Path> skewJoinJobResultsDir = new HashMap<Byte, Path>();
+    Byte[] tags = joinDescriptor.getTagOrder();
+    // for each joining table, set dir for big key and small keys properly
+    for (int i = 0; i < numAliases; i++) {
+      Byte alias = tags[i];
+      bigKeysDirMap.put(alias, GenMRSkewJoinProcessor.getBigKeysDir(baseTmpDir, alias));
+      Map<Byte, Path> smallKeysMap = new HashMap<Byte, Path>();
+      smallKeysDirMap.put(alias, smallKeysMap);
+      for (Byte src2 : tags) {
+        if (!src2.equals(alias)) {
+          smallKeysMap.put(src2, GenMRSkewJoinProcessor.getSmallKeysDir(baseTmpDir, alias, src2));
+        }
+      }
+      skewJoinJobResultsDir.put(alias,
+          GenMRSkewJoinProcessor.getBigKeysSkewJoinResultDir(baseTmpDir, alias));
+    }
+
+    joinDescriptor.setHandleSkewJoin(true);
+    joinDescriptor.setBigKeysDirMap(bigKeysDirMap);
+    joinDescriptor.setSmallKeysDirMap(smallKeysDirMap);
+    joinDescriptor.setSkewKeyDefinition(HiveConf.getIntVar(parseCtx.getConf(),
+        HiveConf.ConfVars.HIVESKEWJOINKEY));
+
+    // create proper table/column desc for spilled tables
+    TableDesc keyTblDesc = (TableDesc) reduceWork.getKeyDesc().clone();
+    List<String> joinKeys = Utilities
+        .getColumnNames(keyTblDesc.getProperties());
+    List<String> joinKeyTypes = Utilities.getColumnTypes(keyTblDesc
+        .getProperties());
+
+    Map<Byte, TableDesc> tableDescList = new HashMap<Byte, TableDesc>();
+    Map<Byte, RowSchema> rowSchemaList = new HashMap<Byte, RowSchema>();
+    Map<Byte, List<ExprNodeDesc>> newJoinValues = new HashMap<Byte, List<ExprNodeDesc>>();
+    Map<Byte, List<ExprNodeDesc>> newJoinKeys = new HashMap<Byte, List<ExprNodeDesc>>();
+    // used to create the MapJoinDesc; entries must be in tag order
+    List<TableDesc> newJoinValueTblDesc = new ArrayList<TableDesc>();
+
+    for (Byte tag : tags) {
+      newJoinValueTblDesc.add(null);
+    }
+
+    for (int i = 0; i < numAliases; i++) {
+      Byte alias = tags[i];
+      List<ExprNodeDesc> valueCols = joinValues.get(alias);
+      String colNames = "";
+      String colTypes = "";
+      int columnSize = valueCols.size();
+      List<ExprNodeDesc> newValueExpr = new ArrayList<ExprNodeDesc>();
+      List<ExprNodeDesc> newKeyExpr = new ArrayList<ExprNodeDesc>();
+      ArrayList<ColumnInfo> columnInfos = new ArrayList<ColumnInfo>();
+
+      boolean first = true;
+      for (int k = 0; k < columnSize; k++) {
+        TypeInfo type = valueCols.get(k).getTypeInfo();
+        String newColName = i + "_VALUE_" + k; // any name, it does not matter.
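+        // wrap the value column in a ColumnInfo and reference it through an ExprNodeColumnDesc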
+        ColumnInfo columnInfo = new ColumnInfo(newColName, type, alias.toString(), false);
+        columnInfos.add(columnInfo);
+        newValueExpr.add(new ExprNodeColumnDesc(
+            columnInfo.getType(), columnInfo.getInternalName(),
+            columnInfo.getTabAlias(), false));
+        if (!first) {
+          colNames = colNames + ",";
+          colTypes = colTypes + ",";
+        }
+        first = false;
+        colNames = colNames + newColName;
+        colTypes = colTypes + valueCols.get(k).getTypeString();
+      }
+
+      // we are putting join keys at last part of the spilled table
+      for (int k = 0; k < joinKeys.size(); k++) {
+        if (!first) {
+          colNames = colNames + ",";
+          colTypes = colTypes + ",";
+        }
+        first = false;
+        colNames = colNames + joinKeys.get(k);
+        colTypes = colTypes + joinKeyTypes.get(k);
+        ColumnInfo columnInfo = new ColumnInfo(joinKeys.get(k), TypeInfoFactory
+            .getPrimitiveTypeInfo(joinKeyTypes.get(k)), alias.toString(), false);
+        columnInfos.add(columnInfo);
+        newKeyExpr.add(new ExprNodeColumnDesc(
+            columnInfo.getType(), columnInfo.getInternalName(),
+            columnInfo.getTabAlias(), false));
+      }
+
+      newJoinValues.put(alias, newValueExpr);
+      newJoinKeys.put(alias, newKeyExpr);
+      tableDescList.put(alias, Utilities.getTableDesc(colNames, colTypes));
+      rowSchemaList.put(alias, new RowSchema(columnInfos));
+
+      // construct value table Desc
+      String valueColNames = "";
+      String valueColTypes = "";
+      first = true;
+      for (int k = 0; k < columnSize; k++) {
+        String newColName = i + "_VALUE_" + k; // any name, it does not matter.
+        if (!first) {
+          valueColNames = valueColNames + ",";
+          valueColTypes = valueColTypes + ",";
+        }
+        valueColNames = valueColNames + newColName;
+        valueColTypes = valueColTypes + valueCols.get(k).getTypeString();
+        first = false;
+      }
+      newJoinValueTblDesc.set((byte) i, Utilities.getTableDesc(
+          valueColNames, valueColTypes));
+    }
+
+    joinDescriptor.setSkewKeysValuesTables(tableDescList);
+    joinDescriptor.setKeyTableDesc(keyTblDesc);
+
+    // create N-1 map join tasks
+    HashMap<Path, Task<? extends Serializable>> bigKeysDirToTaskMap =
+        new HashMap<Path, Task<? extends Serializable>>();
+    List<Serializable> listWorks = new ArrayList<Serializable>();
+    List<Task<? extends Serializable>> listTasks = new ArrayList<Task<? extends Serializable>>();
+    for (int i = 0; i < numAliases - 1; i++) {
+      Byte src = tags[i];
+      HiveConf hiveConf = new HiveConf(parseCtx.getConf(),
+          GenSparkSkewJoinProcessor.class);
+      SparkWork sparkWork = new SparkWork(parseCtx.getConf().getVar(HiveConf.ConfVars.HIVEQUERYID));
+      Task<? extends Serializable> skewJoinMapJoinTask = TaskFactory.get(sparkWork, hiveConf);
+      skewJoinMapJoinTask.setFetchSource(currTask.isFetchSource());
+
+      // create N TableScans
+      Operator<? extends OperatorDesc>[] parentOps = new TableScanOperator[tags.length];
+      for (int k = 0; k < tags.length; k++) {
+        Operator<? extends OperatorDesc> ts =
+            GenMapRedUtils.createTemporaryTableScanOperator(rowSchemaList.get((byte) k));
+        ((TableScanOperator) ts).setTableDesc(tableDescList.get((byte) k));
+        parentOps[k] = ts;
+      }
+
+      // create the MapJoinOperator
+      String dumpFilePrefix = "mapfile" + PlanUtils.getCountForMapJoinDumpFilePrefix();
+      MapJoinDesc mapJoinDescriptor = new MapJoinDesc(newJoinKeys, keyTblDesc,
+          newJoinValues, newJoinValueTblDesc, newJoinValueTblDesc, joinDescriptor
+          .getOutputColumnNames(), i, joinDescriptor.getConds(),
+          joinDescriptor.getFilters(), joinDescriptor.getNoOuterJoin(), dumpFilePrefix);
+      mapJoinDescriptor.setTagOrder(tags);
+      mapJoinDescriptor.setHandleSkewJoin(false);
+      mapJoinDescriptor.setNullSafes(joinDescriptor.getNullSafes());
+      // temporarily, mark it as child of all the TS
+      MapJoinOperator mapJoinOp = (MapJoinOperator) OperatorFactory
+          .getAndMakeChild(mapJoinDescriptor, null, parentOps);
+
+      // clone the original join operator, and replace it with the MJ
+      // this makes sure MJ has the same downstream operator plan as the original join
+      List<Operator<? extends OperatorDesc>> reducerList = new ArrayList<Operator<? extends OperatorDesc>>();
+      reducerList.add(reduceWork.getReducer());
+      Operator<? extends OperatorDesc> reducer = Utilities.cloneOperatorTree(
+          parseCtx.getConf(), reducerList).get(0);
+      Preconditions.checkArgument(reducer instanceof JoinOperator,
+          "Reducer should be join operator, but actually is " + reducer.getName());
+      JoinOperator cloneJoinOp = (JoinOperator) reducer;
+      List<Operator<? extends OperatorDesc>> childOps = cloneJoinOp
+          .getChildOperators();
+      for (Operator<? extends OperatorDesc> childOp : childOps) {
+        childOp.replaceParent(cloneJoinOp, mapJoinOp);
+      }
+      mapJoinOp.setChildOperators(childOps);
+
+      // set memory usage for the MJ operator
+      setMemUsage(mapJoinOp, skewJoinMapJoinTask, parseCtx);
+
+      // create N MapWorks and add them to the SparkWork
+      MapWork bigMapWork = null;
+      Map<Byte, Path> smallTblDirs = smallKeysDirMap.get(src);
+      for (int j = 0; j < tags.length; j++) {
+        MapWork mapWork = PlanUtils.getMapRedWork().getMapWork();
+        sparkWork.add(mapWork);
+        // This code has only been added for testing
+        boolean mapperCannotSpanPartns =
+            parseCtx.getConf().getBoolVar(
+                HiveConf.ConfVars.HIVE_MAPPER_CANNOT_SPAN_MULTIPLE_PARTITIONS);
+        mapWork.setMapperCannotSpanPartns(mapperCannotSpanPartns);
+        Operator<? extends OperatorDesc> tableScan = parentOps[j];
+        String alias = tags[j].toString();
+        ArrayList<String> aliases = new ArrayList<String>();
+        aliases.add(alias);
+        Path path;
+        if (j == i) {
+          path = bigKeysDirMap.get(tags[j]);
+          bigKeysDirToTaskMap.put(path, skewJoinMapJoinTask);
+          bigMapWork = mapWork;
+          // in MR, ReduceWork is a terminal work, but that's not the case for spark, therefore for
+          // big dir MapWork, we'll have to clone all dependent works in the original work graph
+          cloneWorkGraph(currentWork, sparkWork, reduceWork, mapWork);
+        } else {
+          path = smallTblDirs.get(tags[j]);
+        }
+        mapWork.getPathToAliases().put(path.toString(), aliases);
+        mapWork.getAliasToWork().put(alias, tableScan);
+        PartitionDesc partitionDesc = new PartitionDesc(tableDescList.get(tags[j]), null);
+        mapWork.getPathToPartitionInfo().put(path.toString(), partitionDesc);
+        mapWork.getAliasToPartnInfo().put(alias, partitionDesc);
+        mapWork.setNumMapTasks(HiveConf.getIntVar(hiveConf,
+            HiveConf.ConfVars.HIVESKEWJOINMAPJOINNUMMAPTASK));
+        mapWork.setMinSplitSize(HiveConf.getLongVar(hiveConf,
+            HiveConf.ConfVars.HIVESKEWJOINMAPJOINMINSPLIT));
+        mapWork.setInputformat(HiveInputFormat.class.getName());
+        mapWork.setName("Map " + GenSparkUtils.getUtils().getNextSeqNumber());
+      }
+      // connect all small dir map work to the big dir map work
+      Preconditions.checkArgument(bigMapWork != null, "Haven't identified big dir MapWork");
+      for (BaseWork work : sparkWork.getRoots()) {
+        Preconditions.checkArgument(work instanceof MapWork,
+            "All root work should be MapWork, but got " + work.getClass().getSimpleName());
+        if (work != bigMapWork) {
+          sparkWork.connect(work, bigMapWork,
+              new SparkEdgeProperty(SparkEdgeProperty.SHUFFLE_NONE));
+        }
+      }
+
+      // insert SparkHashTableSink and Dummy operators
+      for (int j = 0; j < tags.length; j++) {
+        if (j != i) {
+          insertSHTS(tags[j], (TableScanOperator) parentOps[j], bigMapWork);
+        }
+      }
+
+      // keep it as reference in case we need fetch work
+//      localPlan.getAliasToFetchWork().put(small_alias.toString(),
+//          new FetchWork(tblDir, tableDescList.get(small_alias)));
+
+      listWorks.add(skewJoinMapJoinTask.getWork());
+      listTasks.add(skewJoinMapJoinTask);
+    }
+    if (children != null) {
+      for (Task<? extends Serializable> tsk : listTasks) {
+        for (Task<? extends Serializable> oldChild : children) {
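+          // the original child task still has to run after each new map join task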
+          tsk.addDependentTask(oldChild);
+        }
+      }
+    }
+    if (child != null) {
+      currTask.removeDependentTask(child);
+      listTasks.add(child);
+    }
+    ConditionalResolverSkewJoin.ConditionalResolverSkewJoinCtx context =
+        new ConditionalResolverSkewJoin.ConditionalResolverSkewJoinCtx(bigKeysDirToTaskMap, child);
+
+    ConditionalWork cndWork = new ConditionalWork(listWorks);
+    ConditionalTask cndTsk = (ConditionalTask) TaskFactory.get(cndWork, parseCtx.getConf());
+    cndTsk.setListTasks(listTasks);
+    cndTsk.setResolver(new ConditionalResolverSkewJoin());
+    cndTsk.setResolverCtx(context);
+    currTask.setChildTasks(new ArrayList<Task<? extends Serializable>>());
+    currTask.addDependentTask(cndTsk);
+  }
+
+  /**
+   * Insert SparkHashTableSink and HashTableDummy between small dir TS and MJ.
+   */
+  private static void insertSHTS(byte tag, TableScanOperator tableScan, MapWork bigMapWork) {
+    Preconditions.checkArgument(tableScan.getChildOperators().size() == 1
+        && tableScan.getChildOperators().get(0) instanceof MapJoinOperator);
+    HashTableDummyDesc desc = new HashTableDummyDesc();
+    HashTableDummyOperator dummyOp = (HashTableDummyOperator) OperatorFactory.get(desc);
+    dummyOp.getConf().setTbl(tableScan.getTableDesc());
+    MapJoinOperator mapJoinOp = (MapJoinOperator) tableScan.getChildOperators().get(0);
+    mapJoinOp.replaceParent(tableScan, dummyOp);
+    List<Operator<? extends OperatorDesc>> mapJoinChildren =
+        new ArrayList<Operator<? extends OperatorDesc>>();
+    mapJoinChildren.add(mapJoinOp);
+    dummyOp.setChildOperators(mapJoinChildren);
+    bigMapWork.addDummyOp(dummyOp);
+    MapJoinDesc mjDesc = mapJoinOp.getConf();
+    SparkHashTableSinkDesc hashTableSinkDesc = new SparkHashTableSinkDesc(mjDesc);
+    SparkHashTableSinkOperator hashTableSinkOp =
+        (SparkHashTableSinkOperator) OperatorFactory.get(hashTableSinkDesc);
+    int[] valueIndex = mjDesc.getValueIndex(tag);
+    if (valueIndex != null) {
+      List<ExprNodeDesc> newValues = new ArrayList<ExprNodeDesc>();
+      List<ExprNodeDesc> values = hashTableSinkDesc.getExprs().get(tag);
+      for (int index = 0; index < values.size(); index++) {
+        if (valueIndex[index] < 0) {
+          newValues.add(values.get(index));
+        }
+      }
+      hashTableSinkDesc.getExprs().put(tag, newValues);
+    }
+    tableScan.replaceChild(mapJoinOp, hashTableSinkOp);
+    List<Operator<? extends OperatorDesc>> tableScanParents =
+        new ArrayList<Operator<? extends OperatorDesc>>();
+    tableScanParents.add(tableScan);
+    hashTableSinkOp.setParentOperators(tableScanParents);
+  }
+
+  private static void setMemUsage(MapJoinOperator mapJoinOp, Task<? extends Serializable> task,
+      ParseContext parseContext) {
+    MapJoinResolver.LocalMapJoinProcCtx context =
+        new MapJoinResolver.LocalMapJoinProcCtx(task, parseContext);
+    try {
+      new LocalMapJoinProcFactory.LocalMapJoinProcessor().hasGroupBy(mapJoinOp,
+          context);
+    } catch (Exception e) {
+      LOG.warn("Error setting memory usage.", e);
+      return;
+    }
+    MapJoinDesc mapJoinDesc = mapJoinOp.getConf();
+    // map join should not be affected by join reordering
+    mapJoinDesc.resetOrder();
+    HiveConf conf = context.getParseCtx().getConf();
+    float hashtableMemoryUsage;
+    if (context.isFollowedByGroupBy()) {
+      hashtableMemoryUsage = conf.getFloatVar(
+          HiveConf.ConfVars.HIVEHASHTABLEFOLLOWBYGBYMAXMEMORYUSAGE);
+    } else {
+      hashtableMemoryUsage = conf.getFloatVar(
+          HiveConf.ConfVars.HIVEHASHTABLEMAXMEMORYUSAGE);
+    }
+    mapJoinDesc.setHashTableMemoryUsage(hashtableMemoryUsage);
+  }
+
+  private static void cloneWorkGraph(SparkWork originSparkWork, SparkWork newSparkWork,
+      BaseWork originWork, BaseWork newWork) {
+    for (BaseWork child : originSparkWork.getChildren(originWork)) {
+      SparkEdgeProperty edgeProperty = originSparkWork.getEdgeProperty(originWork, child);
+      BaseWork cloneChild = Utilities.cloneBaseWork(child);
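+      // rename the clone with a fresh sequence number so work names stay unique in the new SparkWork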
+      cloneChild.setName(cloneChild.getName().replaceAll("^([a-zA-Z]+)(\\s+)(\\d+)",
+          "$1$2" + GenSparkUtils.getUtils().getNextSeqNumber()));
+      newSparkWork.add(cloneChild);
+      newSparkWork.connect(newWork, cloneChild, edgeProperty);
+      cloneWorkGraph(originSparkWork, newSparkWork, child, cloneChild);
+    }
+  }
+
+  /**
+   * ReduceWork is not a terminal work in Spark, so we disable runtime skew join for
+   * some complicated cases for now, leaving them to future tasks.
+   * As an example, consider the following spark work graph:
+   *
+   *     M1    M5
+   *       \   /
+   *    R2 (join)   M6
+   *          \     /
+   *         R3 (join)
+   *             |
+   *         R4 (group)
+   *
+   * If we create a map join task for R2, we have to clone M6 as well so that the results
+   * get joined properly.
+   *
+   * For now we only support the case where all downstream works of the current ReduceWork
+   * have a single parent.
+   */
+  private static boolean supportRuntimeSkewJoin(SparkWork sparkWork, BaseWork work) {
+    for (BaseWork child : sparkWork.getChildren(work)) {
+      if (sparkWork.getParents(child).size() > 1 || !supportRuntimeSkewJoin(sparkWork, child)) {
+        return false;
+      }
+    }
+    return true;
+  }
+}
diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SparkSkewJoinProcFactory.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SparkSkewJoinProcFactory.java
new file mode 100644
index 0000000..6a9c308
--- /dev/null
+++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SparkSkewJoinProcFactory.java
@@ -0,0 +1,71 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.optimizer.spark;
+
+import org.apache.hadoop.hive.ql.exec.JoinOperator;
+import org.apache.hadoop.hive.ql.exec.Task;
+import org.apache.hadoop.hive.ql.exec.spark.SparkTask;
+import org.apache.hadoop.hive.ql.lib.Node;
+import org.apache.hadoop.hive.ql.lib.NodeProcessor;
+import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
+import org.apache.hadoop.hive.ql.optimizer.physical.GenSparkSkewJoinProcessor;
+import org.apache.hadoop.hive.ql.optimizer.physical.SkewJoinProcFactory;
+import org.apache.hadoop.hive.ql.parse.ParseContext;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+
+import java.io.Serializable;
+import java.util.Stack;
+
+/**
+ * Spark version of SkewJoinProcFactory.
+ */
+public class SparkSkewJoinProcFactory {
+  private SparkSkewJoinProcFactory() {
+    // prevent instantiation
+  }
+
+  public static NodeProcessor getDefaultProc() {
+    return SkewJoinProcFactory.getDefaultProc();
+  }
+
+  public static NodeProcessor getJoinProc() {
+    return new SparkSkewJoinJoinProcessor();
+  }
+
+  public static class SparkSkewJoinJoinProcessor implements NodeProcessor {
+
+    @Override
+    public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx,
+        Object... nodeOutputs) throws SemanticException {
+      SparkSkewJoinResolver.SparkSkewJoinProcCtx context =
+          (SparkSkewJoinResolver.SparkSkewJoinProcCtx) procCtx;
+      JoinOperator op = (JoinOperator) nd;
+      if (op.getConf().isFixedAsSorted()) {
+        return null;
+      }
+      ParseContext parseContext = context.getParseCtx();
+      Task<? extends Serializable> currentTsk = context.getCurrentTask();
+      if (currentTsk instanceof SparkTask) {
+        GenSparkSkewJoinProcessor.processSkewJoin(op, currentTsk,
+            context.getReducerToReduceWork().get(op), parseContext);
+      }
+      return null;
+    }
+  }
+}
diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SparkSkewJoinResolver.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SparkSkewJoinResolver.java
new file mode 100644
index 0000000..9b180f1
--- /dev/null
+++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SparkSkewJoinResolver.java
@@ -0,0 +1,122 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.optimizer.spark;
+
+import org.apache.hadoop.hive.ql.exec.CommonJoinOperator;
+import org.apache.hadoop.hive.ql.exec.ConditionalTask;
+import org.apache.hadoop.hive.ql.exec.Operator;
+import org.apache.hadoop.hive.ql.exec.Task;
+import org.apache.hadoop.hive.ql.exec.spark.SparkTask;
+import org.apache.hadoop.hive.ql.lib.DefaultGraphWalker;
+import org.apache.hadoop.hive.ql.lib.DefaultRuleDispatcher;
+import org.apache.hadoop.hive.ql.lib.Dispatcher;
+import org.apache.hadoop.hive.ql.lib.GraphWalker;
+import org.apache.hadoop.hive.ql.lib.Node;
+import org.apache.hadoop.hive.ql.lib.NodeProcessor;
+import org.apache.hadoop.hive.ql.lib.Rule;
+import org.apache.hadoop.hive.ql.lib.RuleRegExp;
+import org.apache.hadoop.hive.ql.optimizer.physical.PhysicalContext;
+import org.apache.hadoop.hive.ql.optimizer.physical.PhysicalPlanResolver;
+import org.apache.hadoop.hive.ql.optimizer.physical.SkewJoinResolver;
+import org.apache.hadoop.hive.ql.parse.ParseContext;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.plan.ReduceWork;
+import org.apache.hadoop.hive.ql.plan.SparkWork;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.Stack;
+
+/**
+ * Spark version of SkewJoinResolver.
+ */
+public class SparkSkewJoinResolver implements PhysicalPlanResolver {
+  @Override
+  public PhysicalContext resolve(PhysicalContext pctx) throws SemanticException {
+    Dispatcher disp = new SparkSkewJoinTaskDispatcher(pctx);
+    GraphWalker ogw = new DefaultGraphWalker(disp);
+    ArrayList<Node> topNodes = new ArrayList<Node>();
+    topNodes.addAll(pctx.getRootTasks());
+    ogw.startWalking(topNodes, null);
+    return pctx;
+  }
+
+  class SparkSkewJoinTaskDispatcher implements Dispatcher {
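+    // for each SparkTask, walks its reducer operator trees and applies the skew join rule to join operators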
+    private PhysicalContext physicalContext;
+
+    public SparkSkewJoinTaskDispatcher(PhysicalContext context) {
+      super();
+      physicalContext = context;
+    }
+
+    @Override
+    public Object dispatch(Node nd, Stack<Node> stack, Object... nodeOutputs)
+        throws SemanticException {
+
+      Task<? extends Serializable> task = (Task<? extends Serializable>) nd;
+      if (task instanceof SparkTask) {
+        SparkWork sparkWork = ((SparkTask) task).getWork();
+        if (sparkWork.getAllReduceWork().isEmpty()) {
+          return null;
+        }
+        SparkSkewJoinProcCtx skewJoinProcCtx =
+            new SparkSkewJoinProcCtx(task, physicalContext.getParseContext());
+        Map<Rule, NodeProcessor> opRules = new LinkedHashMap<Rule, NodeProcessor>();
+        opRules.put(new RuleRegExp("R1", CommonJoinOperator.getOperatorName() + "%"),
+            SparkSkewJoinProcFactory.getJoinProc());
+        Dispatcher disp = new DefaultRuleDispatcher(
+            SparkSkewJoinProcFactory.getDefaultProc(), opRules, skewJoinProcCtx);
+        GraphWalker ogw = new DefaultGraphWalker(disp);
+        ArrayList<Node> topNodes = new ArrayList<Node>();
+        for (ReduceWork reduceWork : sparkWork.getAllReduceWork()) {
+          topNodes.add(reduceWork.getReducer());
+          skewJoinProcCtx.getReducerToReduceWork().put(reduceWork.getReducer(), reduceWork);
+        }
+        ogw.startWalking(topNodes, null);
+      }
+      return null;
+    }
+
+    public PhysicalContext getPhysicalContext() {
+      return physicalContext;
+    }
+
+    public void setPhysicalContext(PhysicalContext physicalContext) {
+      this.physicalContext = physicalContext;
+    }
+  }
+
+  public static class SparkSkewJoinProcCtx extends SkewJoinResolver.SkewJoinProcCtx {
+    // need a map from the reducer to the corresponding ReduceWork
+    private Map<Operator<?>, ReduceWork> reducerToReduceWork;
+
+    public SparkSkewJoinProcCtx(Task<? extends Serializable> task,
+        ParseContext parseCtx) {
+      super(task, parseCtx);
+      reducerToReduceWork = new HashMap<Operator<?>, ReduceWork>();
+    }
+
+    public Map<Operator<?>, ReduceWork> getReducerToReduceWork() {
+      return reducerToReduceWork;
+    }
+  }
+}
diff --git ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkCompiler.java ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkCompiler.java
index 7a01b1f..24e1460 100644
--- ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkCompiler.java
+++ ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkCompiler.java
@@ -67,6 +67,7 @@
 import org.apache.hadoop.hive.ql.optimizer.spark.SetSparkReducerParallelism;
 import org.apache.hadoop.hive.ql.optimizer.spark.SparkMapJoinOptimizer;
 import org.apache.hadoop.hive.ql.optimizer.spark.SparkReduceSinkMapJoinProc;
+import org.apache.hadoop.hive.ql.optimizer.spark.SparkSkewJoinResolver;
 import org.apache.hadoop.hive.ql.optimizer.spark.SparkSortMergeJoinFactory;
 import org.apache.hadoop.hive.ql.optimizer.spark.SplitSparkWorkResolver;
 import org.apache.hadoop.hive.ql.parse.GlobalLimitCtx;
@@ -302,6 +303,13 @@ protected void optimizeTaskPlan(List<Task<? extends Serializable>> rootTasks, Pa
     } else {
       LOG.debug("Skipping stage id rearranger");
     }
+
+    if (conf.getBoolVar(HiveConf.ConfVars.HIVESKEWJOIN)) {
+      // TODO: enable after HIVE-8913 is done
+      //(new SparkSkewJoinResolver()).resolve(physicalCtx);
+    } else {
+      LOG.debug("Skipping runtime skew join optimization");
+    }
     return;
   }
 }
diff --git ql/src/java/org/apache/hadoop/hive/ql/plan/SparkWork.java ql/src/java/org/apache/hadoop/hive/ql/plan/SparkWork.java
index b5aa99c..275d567 100644
--- ql/src/java/org/apache/hadoop/hive/ql/plan/SparkWork.java
+++ ql/src/java/org/apache/hadoop/hive/ql/plan/SparkWork.java
@@ -19,6 +19,8 @@
 package org.apache.hadoop.hive.ql.plan;
 
 import java.io.Serializable;
+
+import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
@@ -395,6 +397,17 @@ public String toString() {
     return result;
   }
 
+  // get all reduce works in this spark work
+  public List<ReduceWork> getAllReduceWork() {
+    List<ReduceWork> result = new ArrayList<ReduceWork>();
+    for (BaseWork work : getAllWork()) {
+      if (work instanceof ReduceWork) {
+        result.add((ReduceWork) work);
+      }
+    }
+    return result;
+  }
+
   public Map<BaseWork, BaseWork> getCloneToWork() {
     return cloneToWork;
   }