diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/CombineEquivalentWorkResolver.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/CombineEquivalentWorkResolver.java
new file mode 100644
index 0000000..20d3f01
--- /dev/null
+++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/CombineEquivalentWorkResolver.java
@@ -0,0 +1,294 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hive.ql.optimizer.spark;
+
+import java.lang.annotation.Annotation;
+import java.lang.reflect.Method;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.Stack;
+
+import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hive.ql.exec.Operator;
+import org.apache.hadoop.hive.ql.exec.spark.SparkTask;
+import org.apache.hadoop.hive.ql.lib.Dispatcher;
+import org.apache.hadoop.hive.ql.lib.Node;
+import org.apache.hadoop.hive.ql.lib.TaskGraphWalker;
+import org.apache.hadoop.hive.ql.optimizer.physical.PhysicalContext;
+import org.apache.hadoop.hive.ql.optimizer.physical.PhysicalPlanResolver;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.plan.BaseWork;
+import org.apache.hadoop.hive.ql.plan.Explain;
+import org.apache.hadoop.hive.ql.plan.OperatorDesc;
+import org.apache.hadoop.hive.ql.plan.SparkEdgeProperty;
+import org.apache.hadoop.hive.ql.plan.SparkWork;
+import org.apache.hive.common.util.AnnotationUtils;
+
+/**
+ * CombineEquivalentWorkResolver searches inside a SparkWork, finds equivalent works, and
+ * combines them.
+ */
+public class CombineEquivalentWorkResolver implements PhysicalPlanResolver {
+  protected static transient Log LOG = LogFactory.getLog(CombineEquivalentWorkResolver.class);
+
+  @Override
+  public PhysicalContext resolve(PhysicalContext pctx) throws SemanticException {
+    List<Node> topNodes = new ArrayList<Node>();
+    topNodes.addAll(pctx.getRootTasks());
+    TaskGraphWalker taskWalker = new TaskGraphWalker(new EquivalentWorkMatcher());
+    taskWalker.startWalking(topNodes, null);
+    return pctx;
+  }
+
+  class EquivalentWorkMatcher implements Dispatcher {
+    // key: a work; value: all works equivalent to the key (including the key itself).
+    private Map<BaseWork, Set<BaseWork>> equivalentWorks = Maps.newHashMap();
+
+    @Override
+    public Object dispatch(Node nd, Stack<Node> stack, Object... nodeOutputs) throws SemanticException {
+      if (nd instanceof SparkTask) {
+        SparkTask sparkTask = (SparkTask) nd;
+        SparkWork sparkWork = sparkTask.getWork();
+        Set<BaseWork> roots = sparkWork.getRoots();
+        equivalentWorks.clear();
+        if (roots.size() > 1) {
+          BaseWork[] rootWorks = new BaseWork[roots.size()];
+          roots.toArray(rootWorks);
+          for (int i = 0; i < rootWorks.length; i++) {
+            for (int j = i + 1; j < rootWorks.length; j++) {
+              compareWorkRecursively(rootWorks[i], rootWorks[j], sparkWork);
+            }
+          }
+        }
+
+        Set<BaseWork> replacedWorks = Sets.newHashSet();
+
+        for (BaseWork current : equivalentWorks.keySet()) {
+          // Work has been replaced, skip.
+          if (replacedWorks.contains(current)) {
+            continue;
+          }
+          Set<BaseWork> currentEquivalent = equivalentWorks.get(current);
+          for (BaseWork equiWork : currentEquivalent) {
+            if (equiWork != current) {
+              replaceWork(equiWork, current, sparkWork);
+              replacedWorks.add(equiWork);
+            }
+          }
+        }
+      }
+      return null;
+    }
+
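+    /**
+     * Compares two works and, if they are equivalent, records them in equivalentWorks, then
+     * recursively compares their children. Two works are only recorded as equivalent when all of
+     * their parents have already been matched as equivalent.
+     */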
+    private void compareWorkRecursively(BaseWork first, BaseWork second, SparkWork sparkWork) {
+      if (first == second) {
+        return;
+      }
+      // make sure their parents are equivalent before comparing the works themselves.
+      List<BaseWork> firstParents = sparkWork.getParents(first);
+      List<BaseWork> secondParents = sparkWork.getParents(second);
+      if (firstParents == null && secondParents != null) {
+        return;
+      } else if (firstParents != null && secondParents == null) {
+        return;
+      } else if (firstParents.size() != secondParents.size()) {
+        return;
+      }
+
+      if (firstParents != null && secondParents != null && firstParents.size() == secondParents.size()) {
+        for (BaseWork parent : firstParents) {
+          Set<BaseWork> parentEquivalentWorks = equivalentWorks.get(parent);
+          if (parentEquivalentWorks == null) {
+            return;
+          }
+          boolean result = false;
+          BaseWork matchedWork = null;
+          for (BaseWork secondParent : secondParents) {
+            Set<BaseWork> secondParentEquivalentWorks = equivalentWorks.get(secondParent);
+            if (secondParentEquivalentWorks != null && parentEquivalentWorks.equals(secondParentEquivalentWorks)) {
+              result = true;
+              matchedWork = secondParent;
+              break;
+            }
+          }
+          if (result) {
+            secondParents.remove(matchedWork);
+          } else {
+            return;
+          }
+        }
+      }
+
+      if (compareWork(first, second)) {
+        Set<BaseWork> firstEquivalentWorks = equivalentWorks.get(first);
+        if (firstEquivalentWorks == null) {
+          firstEquivalentWorks = Sets.newHashSet();
+          equivalentWorks.put(first, firstEquivalentWorks);
+        }
+        Set<BaseWork> secondEquivalentWorks = equivalentWorks.get(second);
+        if (secondEquivalentWorks == null) {
+          secondEquivalentWorks = Sets.newHashSet();
+          equivalentWorks.put(second, secondEquivalentWorks);
+        }
+
+        firstEquivalentWorks.add(first);
+        firstEquivalentWorks.add(second);
+        secondEquivalentWorks.add(first);
+        secondEquivalentWorks.add(second);
+
+        Set<BaseWork> childrenWorks = Sets.newHashSet();
+        childrenWorks.addAll(sparkWork.getChildren(first));
+        childrenWorks.addAll(sparkWork.getChildren(second));
+
+        if (childrenWorks.size() > 1) {
+          BaseWork[] childrenWorkArray = new BaseWork[childrenWorks.size()];
+          childrenWorks.toArray(childrenWorkArray);
+          for (int i = 0; i < childrenWorkArray.length; i++) {
+            for (int j = i + 1; j < childrenWorkArray.length; j++) {
+              compareWorkRecursively(childrenWorkArray[i], childrenWorkArray[j], sparkWork);
+            }
+          }
+
+        }
+      }
+    }
+
+    private void replaceWork(BaseWork previous, BaseWork current, SparkWork sparkWork) {
+      List<BaseWork> parents = sparkWork.getParents(previous);
+      List<BaseWork> children = sparkWork.getChildren(previous);
+      for (BaseWork parent : parents) {
+        SparkEdgeProperty edgeProperty = sparkWork.getEdgeProperty(parent, previous);
+        sparkWork.disconnect(parent, previous);
+        sparkWork.connect(parent, current, edgeProperty);
+      }
+      for (BaseWork child : children) {
+        SparkEdgeProperty edgeProperty = sparkWork.getEdgeProperty(previous, child);
+        sparkWork.disconnect(previous, child);
+        sparkWork.connect(current, child, edgeProperty);
+      }
+      sparkWork.remove(previous);
+    }
+
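+    /**
+     * Returns true if the two works are of the same class and their root operator chains are
+     * equivalent when compared one by one.
+     */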
+    private boolean compareWork(BaseWork first, BaseWork second) {
+
+      if (!first.getClass().getName().equals(second.getClass().getName())) {
+        return false;
+      }
+
+      Set<Operator<?>> firstRootOperators = first.getAllRootOperators();
+      Set<Operator<?>> secondRootOperators = second.getAllRootOperators();
+      if (firstRootOperators.size() != secondRootOperators.size()) {
+        return false;
+      }
+
+      Iterator<Operator<?>> firstIterator = firstRootOperators.iterator();
+      Iterator<Operator<?>> secondIterator = secondRootOperators.iterator();
+      while (firstIterator.hasNext()) {
+        boolean result = compareOperatorChain(firstIterator.next(), secondIterator.next());
+        if (!result) {
+          return result;
+        }
+      }
+
+      return true;
+    }
+
+    private boolean compareOperatorChain(Operator<?> firstOperator, Operator<?> secondOperator) {
+      boolean result = compareCurrentOperator(firstOperator, secondOperator);
+      if (!result) {
+        return result;
+      }
+
+      List<Operator<? extends OperatorDesc>> firstOperatorChildOperators = firstOperator.getChildOperators();
+      List<Operator<? extends OperatorDesc>> secondOperatorChildOperators = secondOperator.getChildOperators();
+      if (firstOperatorChildOperators == null && secondOperatorChildOperators != null) {
+        return false;
+      } else if (firstOperatorChildOperators != null && secondOperatorChildOperators == null) {
+        return false;
+      } else if (firstOperatorChildOperators != null && secondOperatorChildOperators != null) {
+        if (firstOperatorChildOperators.size() != secondOperatorChildOperators.size()) {
+          return false;
+        }
+        int size = firstOperatorChildOperators.size();
+        for (int i = 0; i < size; i++) {
+          result = compareOperatorChain(firstOperatorChildOperators.get(i), secondOperatorChildOperators.get(i));
+          if (!result) {
+            return false;
+          }
+        }
+      }
+
+      return true;
+    }
+
+    /**
+     * Compare two operators through the string values returned by the Explain-annotated methods
+     * of their descriptors.
+     *
+     * @param firstOperator the first operator to compare
+     * @param secondOperator the second operator to compare
+     * @return true if the two operators are considered equivalent
+     */
+    private boolean compareCurrentOperator(Operator<?> firstOperator, Operator<?> secondOperator) {
+      if (!firstOperator.getClass().getName().equals(secondOperator.getClass().getName())) {
+        return false;
+      }
+
+      Method[] methods = firstOperator.getConf().getClass().getMethods();
+      for (Method m : methods) {
+        Annotation note = AnnotationUtils.getAnnotation(m, Explain.class);
+
+        if (note instanceof Explain) {
+          Explain explain = (Explain) note;
+          if (Explain.Level.EXTENDED.in(explain.explainLevels()) ||
+              Explain.Level.DEFAULT.in(explain.explainLevels())) {
+            Object firstObj = null;
+            Object secondObj = null;
+            try {
+              firstObj = m.invoke(firstOperator.getConf());
+            } catch (Exception e) {
+              LOG.debug("Failed to get method return value.", e);
+              firstObj = null;
+            }
+
+            try {
+              secondObj = m.invoke(secondOperator.getConf());
+            } catch (Exception e) {
+              LOG.debug("Failed to get method return value.", e);
+              secondObj = null;
+            }
+
+            if ((firstObj == null && secondObj != null) ||
+                (firstObj != null && secondObj == null) ||
+                (firstObj != null && secondObj != null &&
+                    !firstObj.toString().equals(secondObj.toString()))) {
+              return false;
+            }
+          }
+        }
+      }
+      return true;
+    }
+  }
+}
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkCompiler.java b/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkCompiler.java
index 19aae70..7f2c079 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkCompiler.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkCompiler.java
@@ -66,6 +66,7 @@
 import org.apache.hadoop.hive.ql.optimizer.physical.SparkMapJoinResolver;
 import org.apache.hadoop.hive.ql.optimizer.physical.StageIDsRearranger;
 import org.apache.hadoop.hive.ql.optimizer.physical.Vectorizer;
+import org.apache.hadoop.hive.ql.optimizer.spark.CombineEquivalentWorkResolver;
 import org.apache.hadoop.hive.ql.optimizer.spark.SetSparkReducerParallelism;
 import org.apache.hadoop.hive.ql.optimizer.spark.SparkJoinHintOptimizer;
 import org.apache.hadoop.hive.ql.optimizer.spark.SparkJoinOptimizer;
@@ -337,6 +338,8 @@ protected void optimizeTaskPlan(List<Task<? extends Serializable>> rootTasks, Pa
       LOG.debug("Skipping stage id rearranger");
     }
 
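+    // Find and combine equivalent works inside the generated SparkWork.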
+    new CombineEquivalentWorkResolver().resolve(physicalCtx);
+
     PERF_LOGGER.PerfLogEnd(CLASS_NAME, PerfLogger.SPARK_OPTIMIZE_TASK_TREE);
     return;
   }