diff --git common/src/java/org/apache/hadoop/hive/conf/HiveConf.java common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
index 43c53fc..73f5f24 100644
--- common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
+++ common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
@@ -2059,6 +2059,11 @@ public void setSparkConfigUpdated(boolean isSparkConfigUpdated) {
"Channel logging level for remote Spark driver. One of {DEBUG, ERROR, INFO, TRACE, WARN}."),
SPARK_RPC_SASL_MECHANISM("hive.spark.client.rpc.sasl.mechanisms", "DIGEST-MD5",
"Name of the SASL mechanism to use for authentication."),
+ SPARK_DYNAMIC_RDD_CACHING("hive.spark.dynamic.rdd.caching", true,
+ "When dynamic rdd caching is enabled, Hive would find all the possible shared rdd, cache and reuse it in the" +
+ "generated Spark job finally."),
+ SPARK_DYNAMIC_RDD_CACHING_THRESHOLD("hive.spark.dynamic.rdd.caching.threshold", 100 * 1024 * 1024L,
+ "Maximum table rawDataSize which Hive may try to cache dynamically, default is 100MB."),
NWAYJOINREORDER("hive.reorder.nway.joins", true,
"Runs reordering of tables within single n-way join (i.e.: picks streamtable)"),
HIVE_LOG_N_RECORDS("hive.log.every.n.records", 0L, new RangeValidator(0L, null),
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/CacheTran.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/CacheTran.java
new file mode 100644
index 0000000..f3a53cf
--- /dev/null
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/CacheTran.java
@@ -0,0 +1,54 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec.spark;
+
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.storage.StorageLevel;
+
+public abstract class CacheTran<KI extends WritableComparable, VI, KO extends WritableComparable, VO>
+ implements SparkTran<KI, VI, KO, VO> {
+ // whether to cache current RDD.
+ private boolean caching = false;
+ private JavaPairRDD<KO, VO> cachedRDD;
+
+ protected CacheTran(boolean cache) {
+ this.caching = cache;
+ }
+
+ @Override
+ public JavaPairRDD<KO, VO> transform(
+ JavaPairRDD<KI, VI> input) {
+ if (caching) {
+ if (cachedRDD == null) {
+ cachedRDD = doTransform(input);
+ cachedRDD.persist(StorageLevel.MEMORY_ONLY());
+ }
+ return cachedRDD;
+ } else {
+ return doTransform(input);
+ }
+ }
+
+ public Boolean isCacheEnable() {
+ return caching;
+ }
+
+ protected abstract JavaPairRDD<KO, VO> doTransform(JavaPairRDD<KI, VI> input);
+}
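To illustrate the CacheTran contract above (doTransform() builds the RDD, transform() persists and memoizes it when caching is enabled), here is a hypothetical subclass; IdentityTran is not part of the patch, and its generic arguments follow the CacheTran declaration.

package org.apache.hadoop.hive.ql.exec.spark;

import org.apache.hadoop.hive.ql.io.HiveKey;
import org.apache.hadoop.io.BytesWritable;
import org.apache.spark.api.java.JavaPairRDD;

public class IdentityTran extends CacheTran<HiveKey, BytesWritable, HiveKey, BytesWritable> {
  private String name = "Identity";

  public IdentityTran(boolean cache) {
    super(cache);
  }

  @Override
  protected JavaPairRDD<HiveKey, BytesWritable> doTransform(JavaPairRDD<HiveKey, BytesWritable> input) {
    // Pass-through transformation; when caching is enabled, CacheTran.transform()
    // persists the returned RDD with MEMORY_ONLY and reuses it on later calls.
    return input;
  }

  @Override
  public String getName() {
    return name;
  }

  @Override
  public void setName(String name) {
    this.name = name;
  }
}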
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/LocalHiveSparkClient.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/LocalHiveSparkClient.java
index 19d3fee..65d9ba7 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/LocalHiveSparkClient.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/LocalHiveSparkClient.java
@@ -136,7 +136,7 @@ public SparkJobRef execute(DriverContext driverContext, SparkWork sparkWork) thr
// As we always use foreach action to submit RDD graph, it would only trigger one job.
int jobId = future.jobIds().get(0);
LocalSparkJobStatus sparkJobStatus = new LocalSparkJobStatus(
- sc, jobId, jobMetricsListener, sparkCounters, plan.getCachedRDDIds(), future);
+ sc, jobId, jobMetricsListener, sparkCounters, future);
return new LocalSparkJobRef(Integer.toString(jobId), hiveConf, sparkJobStatus, sc);
}
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/MapInput.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/MapInput.java
index 26cfebd..a7f9406 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/MapInput.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/MapInput.java
@@ -18,23 +18,15 @@
package org.apache.hadoop.hive.ql.exec.spark;
-import org.apache.hadoop.conf.Configuration;
+import com.google.common.base.Preconditions;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableComparable;
-import org.apache.hadoop.io.WritableUtils;
import org.apache.spark.api.java.JavaPairRDD;
-import org.apache.spark.api.java.function.PairFunction;
-import org.apache.spark.storage.StorageLevel;
-
-import scala.Tuple2;
-
-import com.google.common.base.Preconditions;
-public class MapInput implements SparkTran<WritableComparable, Writable, WritableComparable, Writable> {
+public class MapInput extends CacheTran<WritableComparable, Writable, WritableComparable, Writable> {
private JavaPairRDD hadoopRDD;
- private boolean toCache;
private final SparkPlan sparkPlan;
private String name = "MapInput";
@@ -44,47 +36,17 @@ public MapInput(SparkPlan sparkPlan, JavaPairRDD h
public MapInput(SparkPlan sparkPlan,
JavaPairRDD hadoopRDD, boolean toCache) {
+ super(toCache);
this.hadoopRDD = hadoopRDD;
- this.toCache = toCache;
this.sparkPlan = sparkPlan;
}
- public void setToCache(boolean toCache) {
- this.toCache = toCache;
- }
-
@Override
- public JavaPairRDD<WritableComparable, Writable> transform(
+ public JavaPairRDD<WritableComparable, Writable> doTransform(
JavaPairRDD<WritableComparable, Writable> input) {
Preconditions.checkArgument(input == null,
"AssertionError: MapInput doesn't take any input");
- JavaPairRDD result;
- if (toCache) {
- result = hadoopRDD.mapToPair(new CopyFunction());
- sparkPlan.addCachedRDDId(result.id());
- result = result.persist(StorageLevel.MEMORY_AND_DISK());
- } else {
- result = hadoopRDD;
- }
- return result;
- }
-
- private static class CopyFunction implements PairFunction<Tuple2<WritableComparable, Writable>,
- WritableComparable, Writable> {
-
- private transient Configuration conf;
-
- @Override
- public Tuple2
- call(Tuple2 tuple) throws Exception {
- if (conf == null) {
- conf = new Configuration();
- }
-
- return new Tuple2(tuple._1(),
- WritableUtils.clone(tuple._2(), conf));
- }
-
+ return hadoopRDD;
}
@Override
@@ -93,11 +55,6 @@ public String getName() {
}
@Override
- public Boolean isCacheEnable() {
- return new Boolean(toCache);
- }
-
- @Override
public void setName(String name) {
this.name = name;
}
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/MapTran.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/MapTran.java
index 2170243..2a18991 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/MapTran.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/MapTran.java
@@ -22,12 +22,20 @@
import org.apache.hadoop.io.BytesWritable;
import org.apache.spark.api.java.JavaPairRDD;
-public class MapTran implements SparkTran<BytesWritable, BytesWritable, HiveKey, BytesWritable> {
+public class MapTran extends CacheTran<BytesWritable, BytesWritable, HiveKey, BytesWritable> {
private HiveMapFunction mapFunc;
private String name = "MapTran";
+ public MapTran() {
+ this(false);
+ }
+
+ public MapTran(boolean cache) {
+ super(cache);
+ }
+
@Override
- public JavaPairRDD<HiveKey, BytesWritable> transform(
+ public JavaPairRDD<HiveKey, BytesWritable> doTransform(
JavaPairRDD<BytesWritable, BytesWritable> input) {
return input.mapPartitionsToPair(mapFunc);
}
@@ -42,11 +50,6 @@ public String getName() {
}
@Override
- public Boolean isCacheEnable() {
- return null;
- }
-
- @Override
public void setName(String name) {
this.name = name;
}
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ReduceTran.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ReduceTran.java
index e60dfac..3d56876 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ReduceTran.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ReduceTran.java
@@ -22,12 +22,20 @@
import org.apache.hadoop.io.BytesWritable;
import org.apache.spark.api.java.JavaPairRDD;
-public class ReduceTran implements SparkTran<HiveKey, Iterable<BytesWritable>, HiveKey, BytesWritable> {
+public class ReduceTran extends CacheTran<HiveKey, Iterable<BytesWritable>, HiveKey, BytesWritable> {
private HiveReduceFunction reduceFunc;
private String name = "Reduce";
+ public ReduceTran() {
+ this(false);
+ }
+
+ public ReduceTran(boolean caching) {
+ super(caching);
+ }
+
@Override
- public JavaPairRDD<HiveKey, BytesWritable> transform(
+ public JavaPairRDD<HiveKey, BytesWritable> doTransform(
JavaPairRDD<HiveKey, Iterable<BytesWritable>> input) {
return input.mapPartitionsToPair(reduceFunc);
}
@@ -42,11 +50,6 @@ public String getName() {
}
@Override
- public Boolean isCacheEnable() {
- return null;
- }
-
- @Override
public void setName(String name) {
this.name = name;
}
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/RemoteHiveSparkClient.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/RemoteHiveSparkClient.java
index 8b15099..2a4686b 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/RemoteHiveSparkClient.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/RemoteHiveSparkClient.java
@@ -259,7 +259,7 @@ public Serializable call(JobContext jc) throws Exception {
JavaPairRDD finalRDD = plan.generateGraph();
// We use Spark RDD async action to submit job as it's the only way to get jobId now.
JavaFutureAction future = finalRDD.foreachAsync(HiveVoidFunction.getInstance());
- jc.monitor(future, sparkCounters, plan.getCachedRDDIds());
+ jc.monitor(future, sparkCounters);
return null;
}
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ShuffleTran.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ShuffleTran.java
index a774395..12e67e6 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ShuffleTran.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ShuffleTran.java
@@ -23,10 +23,9 @@
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.storage.StorageLevel;
-public class ShuffleTran implements SparkTran<HiveKey, BytesWritable, HiveKey, Iterable<BytesWritable>> {
+public class ShuffleTran extends CacheTran<HiveKey, BytesWritable, HiveKey, Iterable<BytesWritable>> {
private final SparkShuffler shuffler;
private final int numOfPartitions;
- private final boolean toCache;
private final SparkPlan sparkPlan;
private String name = "Shuffle";
@@ -35,19 +34,15 @@ public ShuffleTran(SparkPlan sparkPlan, SparkShuffler sf, int n) {
}
public ShuffleTran(SparkPlan sparkPlan, SparkShuffler sf, int n, boolean toCache) {
+ super(toCache);
shuffler = sf;
numOfPartitions = n;
- this.toCache = toCache;
this.sparkPlan = sparkPlan;
}
@Override
- public JavaPairRDD<HiveKey, Iterable<BytesWritable>> transform(JavaPairRDD<HiveKey, BytesWritable> input) {
+ public JavaPairRDD<HiveKey, Iterable<BytesWritable>> doTransform(JavaPairRDD<HiveKey, BytesWritable> input) {
JavaPairRDD<HiveKey, Iterable<BytesWritable>> result = shuffler.shuffle(input, numOfPartitions);
- if (toCache) {
- sparkPlan.addCachedRDDId(result.id());
- result = result.persist(StorageLevel.MEMORY_AND_DISK());
- }
return result;
}
@@ -61,11 +56,6 @@ public String getName() {
}
@Override
- public Boolean isCacheEnable() {
- return new Boolean(toCache);
- }
-
- @Override
public void setName(String name) {
this.name = name;
}
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlan.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlan.java
index ee5c78a..0cb20c9 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlan.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlan.java
@@ -88,6 +88,7 @@
}
perfLogger.PerfLogEnd(CLASS_NAME, PerfLogger.SPARK_BUILD_RDD_GRAPH);
+ LOG.info("print generated spark rdd graph:\n" + SparkUtilities.rddGraphToString(finalRDD));
return finalRDD;
}
@@ -238,14 +239,6 @@ public void addTran(SparkTran tran) {
leafTrans.add(tran);
}
- public void addCachedRDDId(int rddId) {
- cachedRDDIds.add(rddId);
- }
-
- public Set getCachedRDDIds() {
- return cachedRDDIds;
- }
-
/**
* This method returns a topologically sorted list of SparkTran.
*/
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlanGenerator.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlanGenerator.java
index 3f240f5..c91de08 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlanGenerator.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlanGenerator.java
@@ -22,9 +22,13 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.atomic.AtomicInteger;
import com.google.common.base.Preconditions;
+import com.google.common.collect.Maps;
+import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.Path;
@@ -69,12 +73,12 @@
private Context context;
private Path scratchDir;
private SparkReporter sparkReporter;
- private Map cloneToWork;
private final Map workToTranMap;
- private final Map workToParentWorkTranMap;
// a map from each BaseWork to its cloned JobConf
private final Map workToJobConf;
+ private List cachingWorks;
+
public SparkPlanGenerator(
JavaSparkContext sc,
Context context,
@@ -87,17 +91,15 @@ public SparkPlanGenerator(
this.jobConf = jobConf;
this.scratchDir = scratchDir;
this.workToTranMap = new HashMap();
- this.workToParentWorkTranMap = new HashMap();
this.sparkReporter = sparkReporter;
this.workToJobConf = new HashMap();
}
public SparkPlan generate(SparkWork sparkWork) throws Exception {
perfLogger.PerfLogBegin(CLASS_NAME, PerfLogger.SPARK_BUILD_PLAN);
+ cachingWorks = sparkWork.getCachingWorks();
SparkPlan sparkPlan = new SparkPlan();
- cloneToWork = sparkWork.getCloneToWork();
workToTranMap.clear();
- workToParentWorkTranMap.clear();
try {
for (BaseWork work : sparkWork.getAllWork()) {
@@ -122,13 +124,6 @@ public SparkPlan generate(SparkWork sparkWork) throws Exception {
// Generate (possibly get from a cached result) parent SparkTran
private SparkTran generateParentTran(SparkPlan sparkPlan, SparkWork sparkWork,
BaseWork work) throws Exception {
- if (cloneToWork.containsKey(work)) {
- BaseWork originalWork = cloneToWork.get(work);
- if (workToParentWorkTranMap.containsKey(originalWork)) {
- return workToParentWorkTranMap.get(originalWork);
- }
- }
-
SparkTran result;
if (work instanceof MapWork) {
result = generateMapInput(sparkPlan, (MapWork)work);
@@ -136,7 +131,7 @@ private SparkTran generateParentTran(SparkPlan sparkPlan, SparkWork sparkWork,
} else if (work instanceof ReduceWork) {
List parentWorks = sparkWork.getParents(work);
result = generate(sparkPlan,
- sparkWork.getEdgeProperty(parentWorks.get(0), work), cloneToWork.containsKey(work));
+ sparkWork.getEdgeProperty(parentWorks.get(0), work));
sparkPlan.addTran(result);
for (BaseWork parentWork : parentWorks) {
sparkPlan.connect(workToTranMap.get(parentWork), result);
@@ -146,10 +141,6 @@ private SparkTran generateParentTran(SparkPlan sparkPlan, SparkWork sparkWork,
+ "but found " + work.getClass().getName());
}
- if (cloneToWork.containsKey(work)) {
- workToParentWorkTranMap.put(cloneToWork.get(work), result);
- }
-
return result;
}
@@ -181,18 +172,17 @@ private SparkTran generateParentTran(SparkPlan sparkPlan, SparkWork sparkWork,
@SuppressWarnings("unchecked")
private MapInput generateMapInput(SparkPlan sparkPlan, MapWork mapWork)
- throws Exception {
+ throws Exception {
JobConf jobConf = cloneJobConf(mapWork);
Class ifClass = getInputFormat(jobConf, mapWork);
JavaPairRDD hadoopRDD = sc.hadoopRDD(jobConf, ifClass,
- WritableComparable.class, Writable.class);
- // Caching is disabled for MapInput due to HIVE-8920
- MapInput result = new MapInput(sparkPlan, hadoopRDD, false/*cloneToWork.containsKey(mapWork)*/);
+ WritableComparable.class, Writable.class);
+ MapInput result = new MapInput(sparkPlan, hadoopRDD);
return result;
}
- private ShuffleTran generate(SparkPlan sparkPlan, SparkEdgeProperty edge, boolean toCache) {
+ private ShuffleTran generate(SparkPlan sparkPlan, SparkEdgeProperty edge) {
Preconditions.checkArgument(!edge.isShuffleNone(),
"AssertionError: SHUFFLE_NONE should only be used for UnionWork.");
SparkShuffler shuffler;
@@ -203,7 +193,7 @@ private ShuffleTran generate(SparkPlan sparkPlan, SparkEdgeProperty edge, boolea
} else {
shuffler = new GroupByShuffler();
}
- return new ShuffleTran(sparkPlan, shuffler, edge.getNumPartitions(), toCache);
+ return new ShuffleTran(sparkPlan, shuffler, edge.getNumPartitions());
}
private SparkTran generate(BaseWork work) throws Exception {
@@ -211,13 +201,17 @@ private SparkTran generate(BaseWork work) throws Exception {
JobConf newJobConf = cloneJobConf(work);
checkSpecs(work, newJobConf);
byte[] confBytes = KryoSerializer.serializeJobConf(newJobConf);
+ boolean caching = false;
+ if (cachingWorks != null && cachingWorks.contains(work)) {
+ caching = true;
+ }
if (work instanceof MapWork) {
- MapTran mapTran = new MapTran();
+ MapTran mapTran = new MapTran(caching);
HiveMapFunction mapFunc = new HiveMapFunction(confBytes, sparkReporter);
mapTran.setMapFunction(mapFunc);
return mapTran;
} else if (work instanceof ReduceWork) {
- ReduceTran reduceTran = new ReduceTran();
+ ReduceTran reduceTran = new ReduceTran(caching);
HiveReduceFunction reduceFunc = new HiveReduceFunction(confBytes, sparkReporter);
reduceTran.setReduceFunction(reduceFunc);
return reduceTran;
@@ -295,5 +289,4 @@ private void initStatsPublisher(BaseWork work) throws HiveException {
}
}
}
-
}
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkUtilities.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkUtilities.java
index e6c845c..2529d8a 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkUtilities.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkUtilities.java
@@ -21,6 +21,7 @@
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
+import java.util.Collection;
import java.util.UUID;
import org.apache.commons.io.FilenameUtils;
@@ -34,7 +35,12 @@
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.session.SessionState;
import org.apache.hadoop.io.BytesWritable;
+import org.apache.spark.Dependency;
import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.rdd.UnionRDD;
+import scala.collection.JavaConversions;
/**
* Contains utilities methods used as part of Spark tasks.
@@ -122,4 +128,33 @@ public static SparkSession getSparkSession(HiveConf conf,
SessionState.get().setSparkSession(sparkSession);
return sparkSession;
}
+
+ public static String rddGraphToString(JavaPairRDD rdd) {
+ StringBuilder sb = new StringBuilder();
+ rddToString(rdd.rdd(), sb, "");
+ return sb.toString();
+ }
+
+ private static void rddToString(RDD rdd, StringBuilder sb, String offset) {
+ sb.append(offset).append(rdd.getClass().getCanonicalName()).append("[").append(rdd.hashCode()).append("]");
+ if (rdd.getStorageLevel().useMemory()) {
+ sb.append("(cached)");
+ }
+ sb.append("\n");
+ Collection dependencies = JavaConversions.asJavaCollection(rdd.dependencies());
+ if (dependencies != null) {
+ offset += "\t";
+ for (Dependency dependency : dependencies) {
+ RDD parentRdd = dependency.rdd();
+ rddToString(parentRdd, sb, offset);
+ }
+ } else if (rdd instanceof UnionRDD) {
+ UnionRDD unionRDD = (UnionRDD) rdd;
+ offset += "\t";
+ Collection parentRdds = JavaConversions.asJavaCollection(unionRDD.rdds());
+ for (RDD parentRdd : parentRdds) {
+ rddToString(parentRdd, sb, offset);
+ }
+ }
+ }
}
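As a usage illustration (assuming this patch is applied), rddGraphToString can be called on any JavaPairRDD: each output line holds the RDD's canonical class name and hashCode, a "(cached)" marker when the storage level uses memory, and one tab of indentation per dependency level. A standalone sketch with illustrative names only:

import java.util.Arrays;

import org.apache.hadoop.hive.ql.exec.spark.SparkUtilities;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import scala.Tuple2;

public class RddGraphDumpExample {
  public static void main(String[] args) {
    JavaSparkContext sc = new JavaSparkContext(
        new SparkConf().setMaster("local").setAppName("rdd-graph-dump"));
    JavaPairRDD<Integer, String> pairs = sc.parallelizePairs(
        Arrays.asList(new Tuple2<Integer, String>(1, "a"), new Tuple2<Integer, String>(2, "b")));
    // groupByKey introduces a ShuffledRDD on top of the parallelized RDD,
    // so the dump shows a small two-level lineage.
    JavaPairRDD<Integer, Iterable<String>> grouped = pairs.groupByKey(2);
    System.out.println(SparkUtilities.rddGraphToString(grouped));
    sc.stop();
  }
}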
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/status/impl/LocalSparkJobStatus.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/status/impl/LocalSparkJobStatus.java
index 5d62596..4081232 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/status/impl/LocalSparkJobStatus.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/status/impl/LocalSparkJobStatus.java
@@ -52,16 +52,14 @@
private JobMetricsListener jobMetricsListener;
private SparkCounters sparkCounters;
private JavaFutureAction future;
- private Set cachedRDDIds;
public LocalSparkJobStatus(JavaSparkContext sparkContext, int jobId,
JobMetricsListener jobMetricsListener, SparkCounters sparkCounters,
- Set cachedRDDIds, JavaFutureAction future) {
+ JavaFutureAction future) {
this.sparkContext = sparkContext;
this.jobId = jobId;
this.jobMetricsListener = jobMetricsListener;
this.sparkCounters = sparkCounters;
- this.cachedRDDIds = cachedRDDIds;
this.future = future;
}
@@ -141,11 +139,6 @@ public SparkStatistics getSparkStatistics() {
@Override
public void cleanup() {
jobMetricsListener.cleanup(jobId);
- if (cachedRDDIds != null) {
- for (Integer cachedRDDId: cachedRDDIds) {
- sparkContext.sc().unpersistRDD(cachedRDDId, false);
- }
- }
}
private Map combineJobLevelMetrics(Map> jobMetric) {
diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/SparkMapJoinResolver.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/SparkMapJoinResolver.java
index 8e56263..7b3026b 100644
--- ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/SparkMapJoinResolver.java
+++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/SparkMapJoinResolver.java
@@ -163,8 +163,6 @@ private void moveWork(SparkWork sparkWork, BaseWork work, SparkWork targetWork)
// Create a new SparkWork for all the small tables of this work
SparkWork parentWork =
new SparkWork(physicalContext.conf.getVar(HiveConf.ConfVars.HIVEQUERYID));
- // copy cloneToWork to ensure RDD cache still works
- parentWork.setCloneToWork(sparkWork.getCloneToWork());
dependencyGraph.get(targetWork).add(parentWork);
dependencyGraph.put(parentWork, new ArrayList());
diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/SparkRddCachingResolver.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/SparkRddCachingResolver.java
new file mode 100644
index 0000000..8ff8fd7
--- /dev/null
+++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/SparkRddCachingResolver.java
@@ -0,0 +1,181 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.optimizer.physical;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.LinkedHashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.Stack;
+import java.util.UUID;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
+import org.apache.commons.lang3.tuple.MutablePair;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.ql.DriverContext;
+import org.apache.hadoop.hive.ql.exec.MapOperator;
+import org.apache.hadoop.hive.ql.exec.Operator;
+import org.apache.hadoop.hive.ql.exec.TableScanOperator;
+import org.apache.hadoop.hive.ql.exec.Task;
+import org.apache.hadoop.hive.ql.exec.spark.SparkTask;
+import org.apache.hadoop.hive.ql.io.orc.OrcInputFormat;
+import org.apache.hadoop.hive.ql.io.orc.OrcNewInputFormat;
+import org.apache.hadoop.hive.ql.lib.Dispatcher;
+import org.apache.hadoop.hive.ql.lib.Node;
+import org.apache.hadoop.hive.ql.lib.TaskGraphWalker;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.metadata.Table;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.plan.BaseWork;
+import org.apache.hadoop.hive.ql.plan.MapWork;
+import org.apache.hadoop.hive.ql.plan.OperatorDesc;
+import org.apache.hadoop.hive.ql.plan.PartitionDesc;
+import org.apache.hadoop.hive.ql.plan.SparkWork;
+import org.apache.hadoop.hive.ql.plan.TableDesc;
+import org.apache.hadoop.hive.ql.plan.TableScanDesc;
+import org.apache.hadoop.hive.ql.plan.api.StageType;
+import org.apache.hadoop.hive.ql.stats.StatsUtils;
+import org.apache.hadoop.hive.serde2.Deserializer;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.mapred.InputFormat;
+import parquet.hadoop.ParquetInputFormat;
+
+/**
+ * A Hive query may scan the same table multiple times, or even share the same sub-plan within the query,
+ * so caching the shared RDD reduces IO and CPU cost, which helps to improve Hive performance.
+ * SparkRddCachingResolver is in charge of walking through the Tasks and pattern-matching all
+ * the cacheable MapWork/ReduceWork.
+ */
+public class SparkRddCachingResolver implements PhysicalPlanResolver {
+
+ private static final Log LOG = LogFactory.getLog(SparkRddCachingResolver.class);
+
+ @Override
+ public PhysicalContext resolve(PhysicalContext pctx) throws SemanticException {
+ List<SparkTask> sparkTasks = getAllSparkTasks(pctx);
+ ArrayList<Node> topNodes = new ArrayList<Node>();
+ if (sparkTasks.size() > 0) {
+ topNodes.addAll(sparkTasks);
+ TaskGraphWalker sparkWorkOgw = new TaskGraphWalker(new ShareParentWorkMatching(pctx.getConf()));
+ sparkWorkOgw.startWalking(topNodes, null);
+ } else {
+ LOG.info("No SparkWork found, skip SparkRddCachingResolver.");
+ }
+
+ return pctx;
+ }
+
+ /**
+ * A Work graph like:
+ *            MapWork1
+ *           /        \
+ *   ReduceWork1    ReduceWork2
+ * would be translated into an RDD graph like:
+ *          MapPartitionsRDD1
+ *           /            \
+ *    ShuffledRDD1     ShuffledRDD2
+ *         /                 \
+ *  MapPartitionsRDD3    MapPartitionsRDD4
+ * In the Spark implementation, MapPartitionsRDD1 would be computed twice, so caching it may improve
+ * performance. ShareParentWorkMatching matches all the works that have multiple children in the SparkWork.
+ */
+ class ShareParentWorkMatching implements Dispatcher {
+ private HiveConf hiveConf;
+
+ ShareParentWorkMatching(HiveConf hiveConf) {
+ this.hiveConf = hiveConf;
+ }
+
+ @Override
+ public Object dispatch(Node nd, Stack stack, Object... nodeOutputs) throws SemanticException {
+ if (nd instanceof SparkTask) {
+ SparkTask sparkTask = (SparkTask) nd;
+ SparkWork sparkWork = sparkTask.getWork();
+ List<BaseWork> toCacheWorks = Lists.newLinkedList();
+ List<Pair<BaseWork, BaseWork>> pairs = Lists.newArrayList(sparkWork.getEdgeProperties().keySet());
+ int size = pairs.size();
+ for (int i = 0; i < size; i++) {
+ Pair<BaseWork, BaseWork> first = pairs.get(i);
+ for (int j = i + 1; j < size; j++) {
+ Pair<BaseWork, BaseWork> second = pairs.get(j);
+ if (first.getKey().equals(second.getKey()) && !first.getValue().equals(second.getValue())) {
+ BaseWork work = first.getKey();
+ long estimatedDataSize = getWorkDataSize(work);
+ long threshold = HiveConf.getLongVar(hiveConf, HiveConf.ConfVars.SPARK_DYNAMIC_RDD_CACHING_THRESHOLD);
+ if (estimatedDataSize <= threshold) {
+ toCacheWorks.add(work);
+ }
+ }
+ }
+ }
+ sparkWork.setCachingWorks(toCacheWorks);
+ }
+ return null;
+ }
+ }
+
+ private List<SparkTask> getAllSparkTasks(PhysicalContext pctx) {
+ List<SparkTask> sparkTasks = new LinkedList<SparkTask>();
+ List<Task<? extends Serializable>> rootTasks = pctx.getRootTasks();
+ if (rootTasks != null) {
+ for (Task<? extends Serializable> task : rootTasks) {
+ getSparkTask(task, sparkTasks);
+ }
+ }
+ return sparkTasks;
+ }
+
+ private void getSparkTask(Task<? extends Serializable> task, List<SparkTask> sparkTasks) {
+ if (task instanceof SparkTask) {
+ sparkTasks.add((SparkTask) task);
+ List<Task<? extends Serializable>> childTasks = task.getChildTasks();
+ if (childTasks != null) {
+ for (Task<? extends Serializable> childTask : childTasks) {
+ getSparkTask(childTask, sparkTasks);
+ }
+ }
+ }
+ }
+
+ private long getWorkDataSize(BaseWork work) {
+ long size = 0;
+ for (Operator<?> operator : work.getAllLeafOperators()) {
+ size += operator.getStatistics().getDataSize();
+ }
+ return size;
+ }
+}
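The pair-wise scan in ShareParentWorkMatching boils down to collecting every BaseWork that is the source of at least two distinct edges and then applying the rawDataSize threshold. A simplified, self-contained sketch of that matching idea; SharedParentFinder is illustrative and not part of the patch.

import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.apache.commons.lang3.tuple.Pair;

public class SharedParentFinder {
  /**
   * Given the (parent, child) edge pairs of a work graph, returns every parent
   * that feeds more than one distinct child.
   */
  public static <W> Set<W> findSharedParents(List<Pair<W, W>> edges) {
    Map<W, Set<W>> childrenByParent = new HashMap<W, Set<W>>();
    for (Pair<W, W> edge : edges) {
      Set<W> children = childrenByParent.get(edge.getKey());
      if (children == null) {
        children = new HashSet<W>();
        childrenByParent.put(edge.getKey(), children);
      }
      children.add(edge.getValue());
    }
    Set<W> shared = new HashSet<W>();
    for (Map.Entry<W, Set<W>> entry : childrenByParent.entrySet()) {
      if (entry.getValue().size() > 1) {
        shared.add(entry.getKey());
      }
    }
    return shared;
  }
}

Works found by such a scan would still be filtered by hive.spark.dynamic.rdd.caching.threshold before being handed to SparkWork.setCachingWorks(), as the resolver above does.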
diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SparkSkewJoinProcFactory.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SparkSkewJoinProcFactory.java
index 5990d17..aba6954 100644
--- ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SparkSkewJoinProcFactory.java
+++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SparkSkewJoinProcFactory.java
@@ -121,7 +121,6 @@ private static void splitTask(SparkTask currentTask, ReduceWork reduceWork,
// remove them from current spark work
for (BaseWork baseWork : newWork.getAllWorkUnsorted()) {
currentWork.remove(baseWork);
- currentWork.getCloneToWork().remove(baseWork);
}
// create TS to read intermediate data
Context baseCtx = parseContext.getContext();
diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SplitSparkWorkResolver.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SplitSparkWorkResolver.java
index fb20080..67252da 100644
--- ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SplitSparkWorkResolver.java
+++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SplitSparkWorkResolver.java
@@ -128,7 +128,6 @@ private void splitBaseWork(SparkWork sparkWork, BaseWork parentWork, List> rootTasks, Pa
LOG.debug("Skipping stage id rearranger");
}
+ if (conf.getBoolVar(HiveConf.ConfVars.SPARK_DYNAMIC_RDD_CACHING)) {
+ physicalCtx = new SparkRddCachingResolver().resolve(physicalCtx);
+ }
+
PERF_LOGGER.PerfLogEnd(CLASS_NAME, PerfLogger.SPARK_OPTIMIZE_TASK_TREE);
return;
}
diff --git ql/src/java/org/apache/hadoop/hive/ql/plan/SparkWork.java ql/src/java/org/apache/hadoop/hive/ql/plan/SparkWork.java
index bb5dd79..56d9714 100644
--- ql/src/java/org/apache/hadoop/hive/ql/plan/SparkWork.java
+++ ql/src/java/org/apache/hadoop/hive/ql/plan/SparkWork.java
@@ -19,23 +19,24 @@
package org.apache.hadoop.hive.ql.plan;
import java.io.Serializable;
-
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
+import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.LinkedList;
-import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Lists;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
-import com.google.common.base.Preconditions;
/**
* This class encapsulates all the work objects that can be executed
@@ -60,11 +61,10 @@
private Map> requiredCounterPrefix;
- private Map cloneToWork;
+ private List cachingWorks;
public SparkWork(String name) {
this.name = name + ":" + (++counter);
- cloneToWork = new HashMap();
}
@@ -416,11 +416,28 @@ public String toString() {
return result;
}
- public Map getCloneToWork() {
- return cloneToWork;
+ /**
+ * @return all map works of this spark work, in topologically sorted order.
+ */
+ public List getAllMapWork() {
+ List result = Lists.newLinkedList();
+ for (BaseWork work : getAllWork()) {
+ if (work instanceof MapWork) {
+ result.add((MapWork) work);
+ }
+ }
+ return result;
+ }
+
+ public Map<Pair<BaseWork, BaseWork>, SparkEdgeProperty> getEdgeProperties() {
+ return ImmutableMap.copyOf(edgeProperties);
+ }
+
+ public List getCachingWorks() {
+ return cachingWorks;
}
- public void setCloneToWork(Map cloneToWork) {
- this.cloneToWork = cloneToWork;
+ public void setCachingWorks(List cachingWorks) {
+ this.cachingWorks = cachingWorks;
}
}
diff --git spark-client/src/main/java/org/apache/hive/spark/client/JobContext.java spark-client/src/main/java/org/apache/hive/spark/client/JobContext.java
index af6332e..590a604 100644
--- spark-client/src/main/java/org/apache/hive/spark/client/JobContext.java
+++ spark-client/src/main/java/org/apache/hive/spark/client/JobContext.java
@@ -47,7 +47,7 @@
* @return The job (unmodified).
*/
<T> JavaFutureAction<T> monitor(
- JavaFutureAction<T> job, SparkCounters sparkCounters, Set<Integer> cachedRDDIds);
+ JavaFutureAction<T> job, SparkCounters sparkCounters);
/**
* Return a map from client job Id to corresponding JavaFutureActions.
diff --git spark-client/src/main/java/org/apache/hive/spark/client/JobContextImpl.java spark-client/src/main/java/org/apache/hive/spark/client/JobContextImpl.java
index beed8a3..7732476 100644
--- spark-client/src/main/java/org/apache/hive/spark/client/JobContextImpl.java
+++ spark-client/src/main/java/org/apache/hive/spark/client/JobContextImpl.java
@@ -54,8 +54,8 @@ public JavaSparkContext sc() {
@Override
public <T> JavaFutureAction<T> monitor(JavaFutureAction<T> job,
- SparkCounters sparkCounters, Set<Integer> cachedRDDIds) {
- monitorCb.get().call(job, sparkCounters, cachedRDDIds);
+ SparkCounters sparkCounters) {
+ monitorCb.get().call(job, sparkCounters);
return job;
}
diff --git spark-client/src/main/java/org/apache/hive/spark/client/MonitorCallback.java spark-client/src/main/java/org/apache/hive/spark/client/MonitorCallback.java
index e1e899e..b09f28e 100644
--- spark-client/src/main/java/org/apache/hive/spark/client/MonitorCallback.java
+++ spark-client/src/main/java/org/apache/hive/spark/client/MonitorCallback.java
@@ -25,6 +25,6 @@
interface MonitorCallback {
- void call(JavaFutureAction<?> future, SparkCounters sparkCounters, Set<Integer> cachedRDDIds);
+ void call(JavaFutureAction<?> future, SparkCounters sparkCounters);
}
diff --git spark-client/src/main/java/org/apache/hive/spark/client/RemoteDriver.java spark-client/src/main/java/org/apache/hive/spark/client/RemoteDriver.java
index b77c9e8..1703309 100644
--- spark-client/src/main/java/org/apache/hive/spark/client/RemoteDriver.java
+++ spark-client/src/main/java/org/apache/hive/spark/client/RemoteDriver.java
@@ -318,7 +318,7 @@ private Object handle(ChannelHandlerContext ctx, SyncJobRequest msg) throws Exce
jc.setMonitorCb(new MonitorCallback() {
@Override
public void call(JavaFutureAction<?> future,
- SparkCounters sparkCounters, Set<Integer> cachedRDDIds) {
+ SparkCounters sparkCounters) {
throw new IllegalStateException(
"JobContext.monitor() is not available for synchronous jobs.");
}
@@ -338,8 +338,6 @@ public void call(JavaFutureAction> future,
private final List> jobs;
private final AtomicInteger completed;
private SparkCounters sparkCounters;
- private Set cachedRDDIds;
-
private Future> future;
JobWrapper(BaseProtocol.JobRequest req) {
@@ -347,7 +345,6 @@ public void call(JavaFutureAction> future,
this.jobs = Lists.newArrayList();
this.completed = new AtomicInteger();
this.sparkCounters = null;
- this.cachedRDDIds = null;
}
@Override
@@ -358,8 +355,8 @@ public Void call() throws Exception {
jc.setMonitorCb(new MonitorCallback() {
@Override
public void call(JavaFutureAction<?> future,
- SparkCounters sparkCounters, Set<Integer> cachedRDDIds) {
- monitorJob(future, sparkCounters, cachedRDDIds);
+ SparkCounters sparkCounters) {
+ monitorJob(future, sparkCounters);
}
});
@@ -393,7 +390,6 @@ public void call(JavaFutureAction> future,
} finally {
jc.setMonitorCb(null);
activeJobs.remove(req.id);
- releaseCache();
}
return null;
}
@@ -409,30 +405,14 @@ void jobDone() {
}
}
- /**
- * Release cached RDDs as soon as the job is done.
- * This is different from local Spark client so as
- * to save a RPC call/trip, avoid passing cached RDD
- * id information around. Otherwise, we can follow
- * the local Spark client way to be consistent.
- */
- void releaseCache() {
- if (cachedRDDIds != null) {
- for (Integer cachedRDDId: cachedRDDIds) {
- jc.sc().sc().unpersistRDD(cachedRDDId, false);
- }
- }
- }
-
private void monitorJob(JavaFutureAction<?> job,
- SparkCounters sparkCounters, Set<Integer> cachedRDDIds) {
+ SparkCounters sparkCounters) {
jobs.add(job);
if (!jc.getMonitoredJobs().containsKey(req.id)) {
jc.getMonitoredJobs().put(req.id, new CopyOnWriteArrayList>());
}
jc.getMonitoredJobs().get(req.id).add(job);
this.sparkCounters = sparkCounters;
- this.cachedRDDIds = cachedRDDIds;
protocol.jobSubmitted(req.id, job.jobIds().get(0));
}
diff --git spark-client/src/test/java/org/apache/hive/spark/client/TestSparkClient.java spark-client/src/test/java/org/apache/hive/spark/client/TestSparkClient.java
index d33ad7e..df91079 100644
--- spark-client/src/test/java/org/apache/hive/spark/client/TestSparkClient.java
+++ spark-client/src/test/java/org/apache/hive/spark/client/TestSparkClient.java
@@ -319,7 +319,7 @@ public Integer call(JobContext jc) throws Exception {
public void call(Integer l) throws Exception {
}
- }), null, null);
+ }), null);
future.get(TIMEOUT, TimeUnit.SECONDS);
@@ -380,7 +380,7 @@ public String call(JobContext jc) {
counters.createCounter("group2", "counter2");
jc.monitor(jc.sc().parallelize(Arrays.asList(1, 2, 3, 4, 5), 5).foreachAsync(this),
- counters, null);
+ counters);
return null;
}