diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/CombineEquivalentWorkResolver.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/CombineEquivalentWorkResolver.java
new file mode 100644
index 0000000..20d3f01
--- /dev/null
+++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/CombineEquivalentWorkResolver.java
@@ -0,0 +1,313 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hive.ql.optimizer.spark;
+
+import java.lang.annotation.Annotation;
+import java.lang.reflect.Method;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.Stack;
+
+import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hive.ql.exec.Operator;
+import org.apache.hadoop.hive.ql.exec.spark.SparkTask;
+import org.apache.hadoop.hive.ql.lib.Dispatcher;
+import org.apache.hadoop.hive.ql.lib.Node;
+import org.apache.hadoop.hive.ql.lib.TaskGraphWalker;
+import org.apache.hadoop.hive.ql.optimizer.physical.PhysicalContext;
+import org.apache.hadoop.hive.ql.optimizer.physical.PhysicalPlanResolver;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.plan.BaseWork;
+import org.apache.hadoop.hive.ql.plan.Explain;
+import org.apache.hadoop.hive.ql.plan.OperatorDesc;
+import org.apache.hadoop.hive.ql.plan.SparkEdgeProperty;
+import org.apache.hadoop.hive.ql.plan.SparkWork;
+import org.apache.hive.common.util.AnnotationUtils;
+
+/**
+ * CombineEquivalentWorkResolver searches a SparkWork for equivalent BaseWorks and combines
+ * each group of equivalent works into a single work.
+ */
+public class CombineEquivalentWorkResolver implements PhysicalPlanResolver {
+ private static final Log LOG = LogFactory.getLog(CombineEquivalentWorkResolver.class);
+
+ @Override
+ public PhysicalContext resolve(PhysicalContext pctx) throws SemanticException {
+ List<Node> topNodes = new ArrayList<Node>();
+ topNodes.addAll(pctx.getRootTasks());
+ TaskGraphWalker taskWalker = new TaskGraphWalker(new EquivalentWorkMatcher());
+ taskWalker.startWalking(topNodes, null);
+ return pctx;
+ }
+
+ class EquivalentWorkMatcher implements Dispatcher {
+ // key: a work; value: the set of all works equivalent to it (including itself).
+ private Map<BaseWork, Set<BaseWork>> equivalentWorks = Maps.newHashMap();
+
+ @Override
+ public Object dispatch(Node nd, Stack<Node> stack, Object... nodeOutputs) throws SemanticException {
+ if (nd instanceof SparkTask) {
+ SparkTask sparkTask = (SparkTask) nd;
+ SparkWork sparkWork = sparkTask.getWork();
+ Set<BaseWork> roots = sparkWork.getRoots();
+ equivalentWorks.clear();
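+ // Compare every pair of root works; equivalence of downstream works is discovered recursively.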
+ if (roots.size() > 1) {
+ BaseWork[] rootWorks = new BaseWork[roots.size()];
+ roots.toArray(rootWorks);
+ for (int i = 0; i < rootWorks.length; i++) {
+ for (int j = i + 1; j < rootWorks.length; j++) {
+ compareWorkRecursively(rootWorks[i], rootWorks[j], sparkWork);
+ }
+ }
+ }
+
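+ // Fold each group of equivalent works into a single representative work.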
+ Set<BaseWork> replacedWorks = Sets.newHashSet();
+
+ for (BaseWork current : equivalentWorks.keySet()) {
+ // Work has been replaced, skip.
+ if (replacedWorks.contains(current)) {
+ continue;
+ }
+ Set<BaseWork> currentEquivalent = equivalentWorks.get(current);
+ for (BaseWork equiWork : currentEquivalent) {
+ if (equiWork != current) {
+ replaceWork(equiWork, current, sparkWork);
+ replacedWorks.add(equiWork);
+ }
+ }
+ }
+ }
+ return null;
+ }
+
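+ /**
+ * Record two works as equivalent if their parents are equivalent and their operator
+ * trees match, then recurse into their children.
+ */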
+ private void compareWorkRecursively(BaseWork first, BaseWork second, SparkWork sparkWork) {
+ if (first == second) {
+ return;
+ }
+ // make sure their parents are equivalent before comparing the works themselves.
+ List<BaseWork> firstParents = sparkWork.getParents(first);
+ List<BaseWork> secondParents = sparkWork.getParents(second);
+ if (firstParents == null && secondParents != null) {
+ return;
+ } else if (firstParents != null && secondParents == null) {
+ return;
+ } else if (firstParents != null && firstParents.size() != secondParents.size()) {
+ return;
+ }
+
+ if (firstParents != null && secondParents != null) {
+ for (BaseWork parent : firstParents) {
+ Set<BaseWork> parentEquivalentWorks = equivalentWorks.get(parent);
+ if (parentEquivalentWorks == null) {
+ return;
+ }
+ boolean result = false;
+ BaseWork matchedWork = null;
+ for (BaseWork secondParent : secondParents) {
+ Set<BaseWork> secondParentEquivalentWorks = equivalentWorks.get(secondParent);
+ if (secondParentEquivalentWorks != null && parentEquivalentWorks.equals(secondParentEquivalentWorks)) {
+ result = true;
+ matchedWork = secondParent;
+ break;
+ }
+ }
+ if (result) {
+ secondParents.remove(matchedWork);
+ } else {
+ return;
+ }
+ }
+ }
+
+ if (compareWork(first, second)) {
+ Set<BaseWork> firstEquivalentWorks = equivalentWorks.get(first);
+ if (firstEquivalentWorks == null) {
+ firstEquivalentWorks = Sets.newHashSet();
+ equivalentWorks.put(first, firstEquivalentWorks);
+ }
+ Set<BaseWork> secondEquivalentWorks = equivalentWorks.get(second);
+ if (secondEquivalentWorks == null) {
+ secondEquivalentWorks = Sets.newHashSet();
+ equivalentWorks.put(second, secondEquivalentWorks);
+ }
+
+ firstEquivalentWorks.add(first);
+ firstEquivalentWorks.add(second);
+ secondEquivalentWorks.add(first);
+ secondEquivalentWorks.add(second);
+
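+ // Recurse into the union of both works' children so downstream equivalence can be detected.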
+ Set<BaseWork> childrenWorks = Sets.newHashSet();
+ childrenWorks.addAll(sparkWork.getChildren(first));
+ childrenWorks.addAll(sparkWork.getChildren(second));
+
+ if (childrenWorks.size() > 1) {
+ BaseWork[] childrenWorkArray = new BaseWork[childrenWorks.size()];
+ childrenWorks.toArray(childrenWorkArray);
+ for (int i = 0; i < childrenWorkArray.length; i++) {
+ for (int j = i + 1; j < childrenWorkArray.length; j++) {
+ compareWorkRecursively(childrenWorkArray[i], childrenWorkArray[j], sparkWork);
+ }
+ }
+
+ }
+ }
+ }
+
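+ /**
+ * Redirect all incoming and outgoing edges of the previous work to the current work,
+ * then remove the previous work from the SparkWork.
+ */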
+ private void replaceWork(BaseWork previous, BaseWork current, SparkWork sparkWork) {
+ List<BaseWork> parents = sparkWork.getParents(previous);
+ List<BaseWork> children = sparkWork.getChildren(previous);
+ for (BaseWork parent : parents) {
+ SparkEdgeProperty edgeProperty = sparkWork.getEdgeProperty(parent, previous);
+ sparkWork.disconnect(parent, previous);
+ sparkWork.connect(parent, current, edgeProperty);
+ }
+ for (BaseWork child : children) {
+ SparkEdgeProperty edgeProperty = sparkWork.getEdgeProperty(previous, child);
+ sparkWork.disconnect(previous, child);
+ sparkWork.connect(current, child, edgeProperty);
+ }
+ sparkWork.remove(previous);
+ }
+
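+ /**
+ * Two works are considered equal if they are of the same class and their root operator
+ * chains compare equal, pairwise in iteration order.
+ */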
+ private boolean compareWork(BaseWork first, BaseWork second) {
+
+ if (!first.getClass().getName().equals(second.getClass().getName())) {
+ return false;
+ }
+
+ Set<Operator<? extends OperatorDesc>> firstRootOperators = first.getAllRootOperators();
+ Set<Operator<? extends OperatorDesc>> secondRootOperators = second.getAllRootOperators();
+ if (firstRootOperators.size() != secondRootOperators.size()) {
+ return false;
+ }
+
+ Iterator<Operator<? extends OperatorDesc>> firstIterator = firstRootOperators.iterator();
+ Iterator<Operator<? extends OperatorDesc>> secondIterator = secondRootOperators.iterator();
+ while (firstIterator.hasNext()) {
+ boolean result = compareOperatorChain(firstIterator.next(), secondIterator.next());
+ if (!result) {
+ return result;
+ }
+ }
+
+ return true;
+ }
+
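+ /**
+ * Recursively compare two operator trees, operator by operator and child by child.
+ */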
+ private boolean compareOperatorChain(Operator<? extends OperatorDesc> firstOperator, Operator<? extends OperatorDesc> secondOperator) {
+ boolean result = compareCurrentOperator(firstOperator, secondOperator);
+ if (!result) {
+ return result;
+ }
+
+ List<Operator<? extends OperatorDesc>> firstOperatorChildOperators = firstOperator.getChildOperators();
+ List<Operator<? extends OperatorDesc>> secondOperatorChildOperators = secondOperator.getChildOperators();
+ if (firstOperatorChildOperators == null && secondOperatorChildOperators != null) {
+ return false;
+ } else if (firstOperatorChildOperators != null && secondOperatorChildOperators == null) {
+ return false;
+ } else if (firstOperatorChildOperators != null && secondOperatorChildOperators != null) {
+ if (firstOperatorChildOperators.size() != secondOperatorChildOperators.size()) {
+ return false;
+ }
+ int size = firstOperatorChildOperators.size();
+ for (int i = 0; i < size; i++) {
+ result = compareOperatorChain(firstOperatorChildOperators.get(i), secondOperatorChildOperators.get(i));
+ if (!result) {
+ return false;
+ }
+ }
+ }
+
+ return true;
+ }
+
+ /**
+ * Compare two operators by the values of the descriptor getters annotated with
+ * {@link Explain} at the DEFAULT or EXTENDED level.
+ * @param firstOperator
+ * @param secondOperator
+ * @return true if the two operators are considered equivalent
+ */
+ private boolean compareCurrentOperator(Operator<? extends OperatorDesc> firstOperator, Operator<? extends OperatorDesc> secondOperator) {
+ if (!firstOperator.getClass().getName().equals(secondOperator.getClass().getName())) {
+ return false;
+ }
+
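+ // Compare the descriptor getters exposed to EXPLAIN; values are compared via toString().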
+ Method[] methods = firstOperator.getConf().getClass().getMethods();
+ for (Method m : methods) {
+ Annotation note = AnnotationUtils.getAnnotation(m, Explain.class);
+
+ if (note instanceof Explain) {
+ Explain explain = (Explain) note;
+ if (Explain.Level.EXTENDED.in(explain.explainLevels()) ||
+ Explain.Level.DEFAULT.in(explain.explainLevels())) {
+ Object firstObj = null;
+ Object secondObj = null;
+ try {
+ firstObj = m.invoke(firstOperator.getConf());
+ } catch (Exception e) {
+ LOG.debug("Failed to get method return value.", e);
+ firstObj = null;
+ }
+
+ try {
+ secondObj = m.invoke(secondOperator.getConf());
+ } catch (Exception e) {
+ LOG.debug("Failed to get method return value.", e);
+ secondObj = null;
+ }
+
+ if ((firstObj == null && secondObj != null) ||
+ (firstObj != null && secondObj == null) ||
+ (firstObj != null && secondObj != null &&
+ !firstObj.toString().equals(secondObj.toString()))) {
+ return false;
+ }
+ }
+ }
+ }
+ return true;
+ }
+ }
+}
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkCompiler.java b/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkCompiler.java
index 19aae70..7f2c079 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkCompiler.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkCompiler.java
@@ -66,6 +66,7 @@
import org.apache.hadoop.hive.ql.optimizer.physical.SparkMapJoinResolver;
import org.apache.hadoop.hive.ql.optimizer.physical.StageIDsRearranger;
import org.apache.hadoop.hive.ql.optimizer.physical.Vectorizer;
+import org.apache.hadoop.hive.ql.optimizer.spark.CombineEquivalentWorkResolver;
import org.apache.hadoop.hive.ql.optimizer.spark.SetSparkReducerParallelism;
import org.apache.hadoop.hive.ql.optimizer.spark.SparkJoinHintOptimizer;
import org.apache.hadoop.hive.ql.optimizer.spark.SparkJoinOptimizer;
@@ -337,6 +338,8 @@ protected void optimizeTaskPlan(List<Task<? extends Serializable>> rootTasks, Pa
LOG.debug("Skipping stage id rearranger");
}

+ new CombineEquivalentWorkResolver().resolve(physicalCtx);
+
PERF_LOGGER.PerfLogEnd(CLASS_NAME, PerfLogger.SPARK_OPTIMIZE_TASK_TREE);
return;
}