diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SortByShuffler.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SortByShuffler.java index 766813c..9e539f0 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SortByShuffler.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SortByShuffler.java @@ -24,6 +24,7 @@ import org.apache.spark.Partitioner; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.function.PairFlatMapFunction; +import org.apache.spark.storage.StorageLevel; import scala.Tuple2; import java.util.*; @@ -31,12 +32,14 @@ public class SortByShuffler implements SparkShuffler { private final boolean totalOrder; + private final SparkPlan sparkPlan; /** * @param totalOrder whether this shuffler provides total order shuffle. */ - public SortByShuffler(boolean totalOrder) { + public SortByShuffler(boolean totalOrder, SparkPlan sparkPlan) { this.totalOrder = totalOrder; + this.sparkPlan = sparkPlan; } @Override @@ -45,6 +48,8 @@ public SortByShuffler(boolean totalOrder) { JavaPairRDD rdd; if (totalOrder) { if (numPartitions > 0) { + input.persist(StorageLevel.DISK_ONLY()); + sparkPlan.addCachedRDDId(input.id()); rdd = input.sortByKey(true, numPartitions); } else { rdd = input.sortByKey(true); diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlanGenerator.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlanGenerator.java index 6abef4e..66ffe5d 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlanGenerator.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlanGenerator.java @@ -215,9 +215,9 @@ private ShuffleTran generate(SparkPlan sparkPlan, SparkEdgeProperty edge, boolea "AssertionError: SHUFFLE_NONE should only be used for UnionWork."); SparkShuffler shuffler; if (edge.isMRShuffle()) { - shuffler = new SortByShuffler(false); + shuffler = new SortByShuffler(false, sparkPlan); } else if (edge.isShuffleSort()) { - shuffler = new SortByShuffler(true); + shuffler = new SortByShuffler(true, sparkPlan); } else { shuffler = new GroupByShuffler(); }