diff --git a/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java b/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
index 43c53fc..4cd98f6 100644
--- a/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
+++ b/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
@@ -2059,6 +2059,11 @@ public void setSparkConfigUpdated(boolean isSparkConfigUpdated) {
"Channel logging level for remote Spark driver. One of {DEBUG, ERROR, INFO, TRACE, WARN}."),
SPARK_RPC_SASL_MECHANISM("hive.spark.client.rpc.sasl.mechanisms", "DIGEST-MD5",
"Name of the SASL mechanism to use for authentication."),
+ SPARK_DYNAMIC_RDD_CACHING("hive.spark.dynamic.rdd.caching", false,
+ "When dynamic RDD caching is enabled, Hive finds the RDDs that can be shared within the " +
+ "generated Spark job, caches them, and reuses them."),
+ SPARK_DYNAMIC_RDD_CACHING_THRESHOLD("hive.spark.dynamic.rdd.caching.threshold", 100 * 1024 * 1024L,
+ "Maximum table rawDataSize for which Hive may try to cache the RDD dynamically; the default is 100MB."),
NWAYJOINREORDER("hive.reorder.nway.joins", true,
"Runs reordering of tables within single n-way join (i.e.: picks streamtable)"),
HIVE_LOG_N_RECORDS("hive.log.every.n.records", 0L, new RangeValidator(0L, null),
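
For context, a minimal usage sketch (not part of the patch) showing how the two new ConfVars could be toggled programmatically; only the ConfVars names come from this hunk, the rest is illustrative:

    HiveConf conf = new HiveConf();
    // Turn on dynamic RDD caching for Hive on Spark.
    conf.setBoolVar(HiveConf.ConfVars.SPARK_DYNAMIC_RDD_CACHING, true);
    // Raise the per-work size cap from the 100MB default to 256MB.
    conf.setLongVar(HiveConf.ConfVars.SPARK_DYNAMIC_RDD_CACHING_THRESHOLD, 256 * 1024 * 1024L);
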
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/CacheTran.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/CacheTran.java
new file mode 100644
index 0000000..f3a53cf
--- /dev/null
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/CacheTran.java
@@ -0,0 +1,54 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec.spark;
+
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.storage.StorageLevel;
+
+public abstract class CacheTran<KI extends WritableComparable, VI, KO extends WritableComparable, VO>
+ implements SparkTran<KI, VI, KO, VO> {
+ // whether to cache current RDD.
+ private boolean caching = false;
+ private JavaPairRDD<KO, VO> cachedRDD;
+
+ protected CacheTran(boolean cache) {
+ this.caching = cache;
+ }
+
+ @Override
+ public JavaPairRDD<KO, VO> transform(
+ JavaPairRDD<KI, VI> input) {
+ if (caching) {
+ if (cachedRDD == null) {
+ cachedRDD = doTransform(input);
+ cachedRDD.persist(StorageLevel.MEMORY_ONLY());
+ }
+ return cachedRDD;
+ } else {
+ return doTransform(input);
+ }
+ }
+
+ public Boolean isCacheEnable() {
+ return caching;
+ }
+
+ protected abstract JavaPairRDD<KO, VO> doTransform(JavaPairRDD<KI, VI> input);
+}
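
The caching contract is easiest to see from the caller's side: with caching enabled, the first transform() call materializes and persists the RDD, and every later call returns the same persisted object. A rough sketch under that assumption; inputRDD (a JavaPairRDD<BytesWritable, BytesWritable>) and the HiveMapFunction wiring are assumed to exist and are not shown in this patch:

    MapTran tran = new MapTran(true);                                        // caching enabled
    JavaPairRDD<HiveKey, BytesWritable> first = tran.transform(inputRDD);    // computes and persists (MEMORY_ONLY)
    JavaPairRDD<HiveKey, BytesWritable> second = tran.transform(inputRDD);   // returns the cached RDD, input ignored
    assert first == second;
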
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/MapTran.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/MapTran.java
index 2170243..2a18991 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/MapTran.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/MapTran.java
@@ -22,12 +22,20 @@
import org.apache.hadoop.io.BytesWritable;
import org.apache.spark.api.java.JavaPairRDD;
-public class MapTran implements SparkTran<BytesWritable, BytesWritable, HiveKey, BytesWritable> {
+public class MapTran extends CacheTran<BytesWritable, BytesWritable, HiveKey, BytesWritable> {
private HiveMapFunction mapFunc;
private String name = "MapTran";
+ public MapTran() {
+ this(false);
+ }
+
+ public MapTran(boolean cache) {
+ super(cache);
+ }
+
@Override
- public JavaPairRDD<HiveKey, BytesWritable> transform(
+ public JavaPairRDD<HiveKey, BytesWritable> doTransform(
JavaPairRDD<BytesWritable, BytesWritable> input) {
return input.mapPartitionsToPair(mapFunc);
}
@@ -42,11 +50,6 @@ public String getName() {
}
@Override
- public Boolean isCacheEnable() {
- return null;
- }
-
- @Override
public void setName(String name) {
this.name = name;
}
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ReduceTran.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ReduceTran.java
index e60dfac..a601a4b 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ReduceTran.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ReduceTran.java
@@ -22,12 +22,20 @@
import org.apache.hadoop.io.BytesWritable;
import org.apache.spark.api.java.JavaPairRDD;
-public class ReduceTran implements SparkTran<HiveKey, Iterable<BytesWritable>, HiveKey, BytesWritable> {
+public class ReduceTran extends CacheTran<HiveKey, Iterable<BytesWritable>, HiveKey, BytesWritable> {
private HiveReduceFunction reduceFunc;
private String name = "Reduce";
+ public ReduceTran() {
+ this(false);
+ }
+
+ public ReduceTran(boolean caching) {
+ super(caching);
+ }
+
@Override
- public JavaPairRDD<HiveKey, BytesWritable> transform(
+ public JavaPairRDD<HiveKey, BytesWritable> doTransform(
JavaPairRDD<HiveKey, Iterable<BytesWritable>> input) {
return input.mapPartitionsToPair(reduceFunc);
}
@@ -42,11 +50,6 @@ public String getName() {
}
@Override
- public Boolean isCacheEnable() {
- return null;
- }
-
- @Override
public void setName(String name) {
this.name = name;
}
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlan.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlan.java
index ee5c78a..02f05ee 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlan.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlan.java
@@ -88,6 +88,7 @@
}
perfLogger.PerfLogEnd(CLASS_NAME, PerfLogger.SPARK_BUILD_RDD_GRAPH);
+ LOG.info("Generated Spark RDD graph:\n" + SparkUtilities.rddGraphToString(finalRDD));
return finalRDD;
}
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlanGenerator.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlanGenerator.java
index 3f240f5..dc3818f 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlanGenerator.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlanGenerator.java
@@ -74,6 +74,7 @@
private final Map<BaseWork, SparkTran> workToParentWorkTranMap;
// a map from each BaseWork to its cloned JobConf
private final Map<BaseWork, JobConf> workToJobConf;
+ private List<BaseWork> worksToClone;
public SparkPlanGenerator(
JavaSparkContext sc,
@@ -96,6 +97,7 @@ public SparkPlan generate(SparkWork sparkWork) throws Exception {
perfLogger.PerfLogBegin(CLASS_NAME, PerfLogger.SPARK_BUILD_PLAN);
SparkPlan sparkPlan = new SparkPlan();
cloneToWork = sparkWork.getCloneToWork();
+ worksToClone = sparkWork.getWorksToClone();
workToTranMap.clear();
workToParentWorkTranMap.clear();
@@ -211,13 +213,17 @@ private SparkTran generate(BaseWork work) throws Exception {
JobConf newJobConf = cloneJobConf(work);
checkSpecs(work, newJobConf);
byte[] confBytes = KryoSerializer.serializeJobConf(newJobConf);
+ boolean caching = false;
+ if (worksToClone != null && worksToClone.contains(work)) {
+ caching = true;
+ }
if (work instanceof MapWork) {
- MapTran mapTran = new MapTran();
+ MapTran mapTran = new MapTran(caching);
HiveMapFunction mapFunc = new HiveMapFunction(confBytes, sparkReporter);
mapTran.setMapFunction(mapFunc);
return mapTran;
} else if (work instanceof ReduceWork) {
- ReduceTran reduceTran = new ReduceTran();
+ ReduceTran reduceTran = new ReduceTran(caching);
HiveReduceFunction reduceFunc = new HiveReduceFunction(confBytes, sparkReporter);
reduceTran.setReduceFunction(reduceFunc);
return reduceTran;
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkUtilities.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkUtilities.java
index e6c845c..ca0ffb6 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkUtilities.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkUtilities.java
@@ -21,6 +21,7 @@
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
+import java.util.Collection;
import java.util.UUID;
import org.apache.commons.io.FilenameUtils;
@@ -34,7 +35,12 @@
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.session.SessionState;
import org.apache.hadoop.io.BytesWritable;
+import org.apache.spark.Dependency;
import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.rdd.UnionRDD;
+import scala.collection.JavaConversions;
/**
* Contains utilities methods used as part of Spark tasks.
@@ -122,4 +128,34 @@ public static SparkSession getSparkSession(HiveConf conf,
SessionState.get().setSparkSession(sparkSession);
return sparkSession;
}
+
+
+ public static String rddGraphToString(JavaPairRDD rdd) {
+ StringBuilder sb = new StringBuilder();
+ rddToString(rdd.rdd(), sb, "");
+ return sb.toString();
+ }
+
+ private static void rddToString(RDD rdd, StringBuilder sb, String offset) {
+ sb.append(offset).append(rdd.getClass().getCanonicalName()).append("[").append(rdd.hashCode()).append("]");
+ if (rdd.getStorageLevel().useMemory()) {
+ sb.append("(cached)");
+ }
+ sb.append("\n");
+ Collection<Dependency> dependencies = JavaConversions.asJavaCollection(rdd.dependencies());
+ if (dependencies != null) {
+ offset += "\t";
+ for (Dependency dependency : dependencies) {
+ RDD parentRdd = dependency.rdd();
+ rddToString(parentRdd, sb, offset);
+ }
+ } else if (rdd instanceof UnionRDD) {
+ UnionRDD unionRDD = (UnionRDD) rdd;
+ offset += "\t";
+ Collection<RDD> parentRdds = JavaConversions.asJavaCollection(unionRDD.rdds());
+ for (RDD parentRdd : parentRdds) {
+ rddToString(parentRdd, sb, offset);
+ }
+ }
+ }
}
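
Given the format built in rddToString(), the new log line in SparkPlan.java prints one RDD per line as its canonical class name plus hash code, indented by lineage depth, with "(cached)" appended for persisted RDDs. An illustrative example (hash codes and the particular lineage are made up, not captured output):

    org.apache.spark.rdd.ShuffledRDD[1087591550]
    	org.apache.spark.rdd.MapPartitionsRDD[749981943](cached)
    		org.apache.spark.rdd.HadoopRDD[1482092883]
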
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/SparkRddCachingResolver.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/SparkRddCachingResolver.java
new file mode 100644
index 0000000..cfe6488
--- /dev/null
+++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/SparkRddCachingResolver.java
@@ -0,0 +1,114 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.optimizer.physical;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Set;
+import java.util.Stack;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.ql.exec.Operator;
+import org.apache.hadoop.hive.ql.exec.spark.SparkTask;
+import org.apache.hadoop.hive.ql.lib.Dispatcher;
+import org.apache.hadoop.hive.ql.lib.Node;
+import org.apache.hadoop.hive.ql.lib.TaskGraphWalker;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.plan.BaseWork;
+import org.apache.hadoop.hive.ql.plan.SparkWork;
+
+/**
+ * A Hive query may scan the same table multiple times, or even share the same sub-plan,
+ * so caching the shared RDD reduces IO and CPU cost, which helps improve Hive performance.
+ * SparkRddCachingResolver is in charge of walking through the Tasks and pattern-matching
+ * all the cacheable MapWork/ReduceWork instances.
+ */
+public class SparkRddCachingResolver implements PhysicalPlanResolver {
+
+ @Override
+ public PhysicalContext resolve(PhysicalContext pctx) throws SemanticException {
+ List<Node> topNodes = new ArrayList<Node>();
+ topNodes.addAll(pctx.getRootTasks());
+ TaskGraphWalker taskWalker = new TaskGraphWalker(new CommonParentWorkMatcher(pctx.getConf()));
+ taskWalker.startWalking(topNodes, null);
+ return pctx;
+ }
+
+ /**
+ * A Work graph like:
+ *              MapWork1
+ *             /        \
+ *   ReduceWork1        ReduceWork2
+ * would be translated into an RDD graph like:
+ *        MapPartitionsRDD1
+ *         /             \
+ *   ShuffledRDD1       ShuffledRDD2
+ *        |                  |
+ *   MapPartitionsRDD3  MapPartitionsRDD4
+ * In the Spark implementation, MapPartitionsRDD1 would be computed twice, so caching it may improve performance.
+ * CommonParentWorkMatcher tries to match all the works in the SparkWork that have multiple children.
+ */
+ class CommonParentWorkMatcher implements Dispatcher {
+ private HiveConf hiveConf;
+
+ CommonParentWorkMatcher(HiveConf hiveConf) {
+ this.hiveConf = hiveConf;
+ }
+
+ @Override
+ public Object dispatch(Node nd, Stack<Node> stack, Object... nodeOutputs) throws SemanticException {
+ if (nd instanceof SparkTask) {
+ SparkTask sparkTask = (SparkTask) nd;
+ SparkWork sparkWork = sparkTask.getWork();
+ List<BaseWork> toCacheWorks = Lists.newLinkedList();
+ long threshold = HiveConf.getLongVar(hiveConf, HiveConf.ConfVars.SPARK_DYNAMIC_RDD_CACHING_THRESHOLD);
+ for (BaseWork work : sparkWork.getRoots()) {
+ checkWork(toCacheWorks, work, sparkWork, threshold);
+ }
+ sparkWork.setWorksToClone(toCacheWorks);
+ }
+ return null;
+ }
+ }
+
+ private void checkWork(List<BaseWork> toCacheWork, BaseWork current, SparkWork sparkWork, long threshold) {
+ List<BaseWork> children = sparkWork.getChildren(current);
+ if (children != null) {
+ if (children.size() > 1) {
+ long estimatedDataSize = getEstimatedWorkDataSize(current);
+ if (estimatedDataSize != 0 && estimatedDataSize < threshold) {
+ toCacheWork.add(current);
+ }
+ }
+ for (BaseWork child : children) {
+ checkWork(toCacheWork, child, sparkWork, threshold);
+ }
+ }
+ }
+
+ private long getEstimatedWorkDataSize(BaseWork work) {
+ long size = 0;
+ Set<Operator<?>> leafOperators = work.getAllLeafOperators();
+ for (Operator<?> operator : leafOperators) {
+ size += operator.getStatistics().getDataSize();
+ }
+ return size;
+ }
+}
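
A hypothetical sketch (not part of the patch) of the diamond shape CommonParentWorkMatcher looks for, built with the existing SparkWork API; the constructor and edge-type names below are recalled from the Hive on Spark branch and may need adjusting:

    SparkWork sparkWork = new SparkWork("caching-example");
    MapWork map1 = new MapWork();
    ReduceWork red1 = new ReduceWork();
    ReduceWork red2 = new ReduceWork();
    sparkWork.add(map1);
    sparkWork.add(red1);
    sparkWork.add(red2);
    sparkWork.connect(map1, red1, new SparkEdgeProperty(SparkEdgeProperty.SHUFFLE_GROUP));
    sparkWork.connect(map1, red2, new SparkEdgeProperty(SparkEdgeProperty.SHUFFLE_GROUP));
    // map1 now has two children; if the summed data size of its leaf operators' statistics
    // is below hive.spark.dynamic.rdd.caching.threshold, the resolver adds map1 to
    // worksToClone and SparkPlanGenerator builds its MapTran with caching enabled.
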
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkCompiler.java b/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkCompiler.java
index 19aae70..1670afe 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkCompiler.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/parse/spark/SparkCompiler.java
@@ -64,6 +64,7 @@
import org.apache.hadoop.hive.ql.optimizer.physical.PhysicalContext;
import org.apache.hadoop.hive.ql.optimizer.physical.SparkCrossProductCheck;
import org.apache.hadoop.hive.ql.optimizer.physical.SparkMapJoinResolver;
+import org.apache.hadoop.hive.ql.optimizer.physical.SparkRddCachingResolver;
import org.apache.hadoop.hive.ql.optimizer.physical.StageIDsRearranger;
import org.apache.hadoop.hive.ql.optimizer.physical.Vectorizer;
import org.apache.hadoop.hive.ql.optimizer.spark.SetSparkReducerParallelism;
@@ -337,6 +338,10 @@ protected void optimizeTaskPlan(List<Task<? extends Serializable>> rootTasks, Pa
LOG.debug("Skipping stage id rearranger");
}
+ if (conf.getBoolVar(HiveConf.ConfVars.SPARK_DYNAMIC_RDD_CACHING)) {
+ physicalCtx = new SparkRddCachingResolver().resolve(physicalCtx);
+ }
+
PERF_LOGGER.PerfLogEnd(CLASS_NAME, PerfLogger.SPARK_OPTIMIZE_TASK_TREE);
return;
}
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/plan/SparkWork.java b/ql/src/java/org/apache/hadoop/hive/ql/plan/SparkWork.java
index bb5dd79..0a1bcae 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/plan/SparkWork.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/plan/SparkWork.java
@@ -32,6 +32,8 @@
import java.util.Map;
import java.util.Set;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Lists;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
@@ -62,9 +64,12 @@
private Map<BaseWork, BaseWork> cloneToWork;
+ private List<BaseWork> worksToClone;
+
public SparkWork(String name) {
this.name = name + ":" + (++counter);
cloneToWork = new HashMap<BaseWork, BaseWork>();
+ worksToClone = Lists.newLinkedList();
}
@@ -423,4 +428,12 @@ public String toString() {
public void setCloneToWork(Map<BaseWork, BaseWork> cloneToWork) {
this.cloneToWork = cloneToWork;
}
+
+ public List<BaseWork> getWorksToClone() {
+ return worksToClone;
+ }
+
+ public void setWorksToClone(List<BaseWork> worksToClone) {
+ this.worksToClone = worksToClone;
+ }
}