diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/CacheTran.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/CacheTran.java
index 5ec27ec..490aa30 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/CacheTran.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/CacheTran.java
@@ -38,7 +38,7 @@ protected CacheTran(boolean cache) {
     if (caching) {
       if (cachedRDD == null) {
         cachedRDD = doTransform(input);
-        cachedRDD.persist(StorageLevel.MEMORY_AND_DISK());
+        //cachedRDD.persist(StorageLevel.MEMORY_AND_DISK());
       }
       return cachedRDD;
     } else {
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/GroupByShuffler.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/GroupByShuffler.java
index e128dd2..7406b52 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/GroupByShuffler.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/GroupByShuffler.java
@@ -18,23 +18,138 @@
 
 package org.apache.hadoop.hive.ql.exec.spark;
 
+import java.util.Iterator;
+
 import org.apache.hadoop.hive.ql.io.HiveKey;
 import org.apache.hadoop.io.BytesWritable;
 import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.function.PairFlatMapFunction;
+
+import scala.Tuple2;
 
 public class GroupByShuffler implements SparkShuffler {
 
   @Override
   public JavaPairRDD<HiveKey, Iterable<BytesWritable>> shuffle(
       JavaPairRDD<HiveKey, BytesWritable> input, int numPartitions) {
+    JavaPairRDD<HiveKey, BytesWritable> rdd;
     if (numPartitions > 0) {
-      return input.groupByKey(numPartitions);
+      rdd = input.repartition(numPartitions);
     }
-    return input.groupByKey();
+    rdd = input.repartition(1);
+    return rdd.mapPartitionsToPair(new ShuffleFunction());
   }
 
   @Override
   public String getName() {
     return "GroupBy";
   }
+
+  private static class ShuffleFunction implements
+      PairFlatMapFunction<Iterator<Tuple2<HiveKey, BytesWritable>>,
+      HiveKey, Iterable<BytesWritable>> {
+    // make eclipse happy
+    private static final long serialVersionUID = 1L;
+
+    @Override
+    public Iterator<Tuple2<HiveKey, Iterable<BytesWritable>>> call(
+        final Iterator<Tuple2<HiveKey, BytesWritable>> it) throws Exception {
+      // Use input iterator to back returned iterable object as well as the iterable value object.
+      // Since two iterables are both backed by a single iterator, it's expected that for each
+      // tuple in the outer iterator, the value iterator must be consumed before moving on to the
+      // next tuple in the outer iterator.
+      final PeekableIterator peekableIt = new PeekableIterator(it);
+      return new Iterator<Tuple2<HiveKey, Iterable<BytesWritable>>>() {
+
+        @Override
+        public boolean hasNext() {
+          // This should be fine since it's expected that when this is called, peekableIt should
+          // have a new key at the beginning or nothing left at all.
+          return peekableIt.hasNext();
+        }
+
+        @Override
+        public Tuple2<HiveKey, Iterable<BytesWritable>> next() {
+          Tuple2<HiveKey, BytesWritable> pair = peekableIt.peek();
+          final HiveKey key = pair._1();
+
+          final Iterator<BytesWritable> values = new Iterator<BytesWritable>() {
+
+            @Override
+            public boolean hasNext() {
+              Tuple2<HiveKey, BytesWritable> nextPair = peekableIt.peek();
+              return nextPair != null && key.equals(nextPair._1());
+            }
+
+            @Override
+            public BytesWritable next() {
+              return peekableIt.next()._2();
+            }
+
+            @Override
+            public void remove() {
+              throw new UnsupportedOperationException();
+            }
+
+          };
+
+          return new Tuple2<HiveKey, Iterable<BytesWritable>>(key, new Iterable<BytesWritable>() {
+
+            @Override
+            public Iterator<BytesWritable> iterator() {
+              return values;
+            }
+
+          });
+        }
+
+        @Override
+        public void remove() {
+          // Not implemented.
+          // throw Unsupported Method Invocation Exception.
+          throw new UnsupportedOperationException();
+        }
+
+      };
+    }
+  }
+
+  private static class PeekableIterator implements Iterator<Tuple2<HiveKey, BytesWritable>> {
+    private Iterator<Tuple2<HiveKey, BytesWritable>> it;
+    Tuple2<HiveKey, BytesWritable> next = null;
+
+    public PeekableIterator(Iterator<Tuple2<HiveKey, BytesWritable>> it) {
+      this.it = it;
+      next = it.hasNext() ? it.next() : null;
+    }
+
+    @Override
+    public boolean hasNext() {
+      return (next != null);
+    }
+
+    @Override
+    public Tuple2<HiveKey, BytesWritable> next() {
+      if (next == null) {
+        throw new RuntimeException("PeekableIterator.next() shouldn't be called unless hasNext() returns true");
+      }
+      Tuple2<HiveKey, BytesWritable> current = next;
+      next = (it.hasNext() ? it.next() : null);
+      return current;
+    }
+
+    public Tuple2<HiveKey, BytesWritable> peek() {
+      return next;
+    }
+
+    @Override
+    public void remove() {
+      // Not implemented.
+      // throw Unsupported Method Invocation Exception.
+      throw new UnsupportedOperationException();
+    }
+
+  }
+
 }
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/MapInput.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/MapInput.java
index 26cfebd..166fe70 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/MapInput.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/MapInput.java
@@ -59,7 +59,7 @@ public void setToCache(boolean toCache) {
     Preconditions.checkArgument(input == null,
         "AssertionError: MapInput doesn't take any input");
     JavaPairRDD<WritableComparable, Writable> result;
-    if (toCache) {
+    if (false) {
       result = hadoopRDD.mapToPair(new CopyFunction());
       sparkPlan.addCachedRDDId(result.id());
       result = result.persist(StorageLevel.MEMORY_AND_DISK());
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ShuffleTran.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ShuffleTran.java
index a774395..653747f 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ShuffleTran.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ShuffleTran.java
@@ -44,7 +44,7 @@ public ShuffleTran(SparkPlan sparkPlan, SparkShuffler sf, int n, boolean toCache
   @Override
   public JavaPairRDD<HiveKey, Iterable<BytesWritable>> transform(JavaPairRDD<HiveKey, BytesWritable> input) {
     JavaPairRDD<HiveKey, Iterable<BytesWritable>> result = shuffler.shuffle(input, numOfPartitions);
-    if (toCache) {
+    if (false) {
       sparkPlan.addCachedRDDId(result.id());
       result = result.persist(StorageLevel.MEMORY_AND_DISK());
     }
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SortByShuffler.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SortByShuffler.java
index 997ab7e..7c38343 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SortByShuffler.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SortByShuffler.java
@@ -48,7 +48,7 @@ public SortByShuffler(boolean totalOrder, SparkPlan sparkPlan) {
     JavaPairRDD<HiveKey, BytesWritable> rdd;
     if (totalOrder) {
       if (numPartitions > 0) {
-        if (numPartitions > 1 && input.getStorageLevel() == StorageLevel.NONE()) {
+        if (numPartitions > 1 && input.getStorageLevel() == StorageLevel.NONE() && false) {
           input.persist(StorageLevel.DISK_ONLY());
           sparkPlan.addCachedRDDId(input.id());
         }
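
Note on the new GroupByShuffler code above: ShuffleFunction replaces Spark's groupByKey with a single pass over each partition's iterator, grouping values under consecutive equal keys via PeekableIterator. Below is a minimal standalone sketch of that grouping idea (illustrative only, not part of the patch; it assumes equal keys arrive adjacently, uses plain Strings in place of HiveKey/BytesWritable, and the class name ConsecutiveGrouping is made up):

import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

public class ConsecutiveGrouping {
  public static void main(String[] args) {
    // Keys already arranged so that equal keys are adjacent, which is what
    // ShuffleFunction assumes about each partition's iterator.
    List<String> keys = Arrays.asList("a", "a", "b", "b", "b", "c");
    Iterator<String> it = keys.iterator();

    // Single pass: remember the current key, then drain all adjacent equal keys
    // before emitting one group, mirroring how the inner value iterator must be
    // consumed before the outer iterator advances to the next key.
    String current = it.hasNext() ? it.next() : null;
    while (current != null) {
      String key = current;
      int count = 1;
      String next = it.hasNext() ? it.next() : null;
      while (next != null && next.equals(key)) {
        count++;
        next = it.hasNext() ? it.next() : null;
      }
      System.out.println(key + " -> " + count + " value(s)");
      current = next;
    }
  }
}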