diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/GroupByShuffler.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/GroupByShuffler.java
index e128dd2..7406b52 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/GroupByShuffler.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/GroupByShuffler.java
@@ -18,23 +18,138 @@
 package org.apache.hadoop.hive.ql.exec.spark;
 
+import java.util.Iterator;
+
 import org.apache.hadoop.hive.ql.io.HiveKey;
 import org.apache.hadoop.io.BytesWritable;
 import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.function.PairFlatMapFunction;
+
+import scala.Tuple2;
 
 public class GroupByShuffler implements SparkShuffler {
 
   @Override
   public JavaPairRDD<HiveKey, Iterable<BytesWritable>> shuffle(
       JavaPairRDD<HiveKey, BytesWritable> input, int numPartitions) {
+    JavaPairRDD<HiveKey, BytesWritable> rdd;
     if (numPartitions > 0) {
-      return input.groupByKey(numPartitions);
+      rdd = input.repartition(numPartitions);
+    } else {
+      rdd = input.repartition(1);
     }
-    return input.groupByKey();
+    return rdd.mapPartitionsToPair(new ShuffleFunction());
   }
 
   @Override
   public String getName() {
     return "GroupBy";
   }
+
+  private static class ShuffleFunction implements
+      PairFlatMapFunction<Iterator<Tuple2<HiveKey, BytesWritable>>, HiveKey, Iterable<BytesWritable>> {
+    // make eclipse happy
+    private static final long serialVersionUID = 1L;
+
+    @Override
+    public Iterator<Tuple2<HiveKey, Iterable<BytesWritable>>> call(
+        final Iterator<Tuple2<HiveKey, BytesWritable>> it) throws Exception {
+      // Use the input iterator to back both the returned iterator and the iterable value object.
+      // Since the two are backed by a single iterator, it's expected that for each tuple in the
+      // outer iterator, the value iterator must be fully consumed before moving on to the next
+      // tuple in the outer iterator.
+      final PeekableIterator peekableIt = new PeekableIterator(it);
+      return new Iterator<Tuple2<HiveKey, Iterable<BytesWritable>>>() {
+
+        @Override
+        public boolean hasNext() {
+          // This is fine since it's expected that, when this is called, peekableIt is either
+          // positioned at the beginning of a new key or has nothing left at all.
+          return peekableIt.hasNext();
+        }
+
+        @Override
+        public Tuple2<HiveKey, Iterable<BytesWritable>> next() {
+          Tuple2<HiveKey, BytesWritable> pair = peekableIt.peek();
+          final HiveKey key = pair._1();
+
+          final Iterator<BytesWritable> values = new Iterator<BytesWritable>() {
+
+            @Override
+            public boolean hasNext() {
+              Tuple2<HiveKey, BytesWritable> nextPair = peekableIt.peek();
+              return nextPair != null && key.equals(nextPair._1());
+            }
+
+            @Override
+            public BytesWritable next() {
+              return peekableIt.next()._2();
+            }
+
+            @Override
+            public void remove() {
+              throw new UnsupportedOperationException();
+            }
+
+          };
+
+          return new Tuple2<HiveKey, Iterable<BytesWritable>>(key, new Iterable<BytesWritable>() {
+
+            @Override
+            public Iterator<BytesWritable> iterator() {
+              return values;
+            }
+
+          });
+        }
+
+        @Override
+        public void remove() {
+          // Not implemented.
+          throw new UnsupportedOperationException();
+        }
+
+      };
+    }
+  }
+
+  private static class PeekableIterator implements Iterator<Tuple2<HiveKey, BytesWritable>> {
+    private final Iterator<Tuple2<HiveKey, BytesWritable>> it;
+    private Tuple2<HiveKey, BytesWritable> next;
+
+    public PeekableIterator(Iterator<Tuple2<HiveKey, BytesWritable>> it) {
+      this.it = it;
+      next = it.hasNext() ? it.next() : null;
+    }
+
+    @Override
+    public boolean hasNext() {
+      return next != null;
+    }
+
+    @Override
+    public Tuple2<HiveKey, BytesWritable> next() {
+      if (next == null) {
+        throw new RuntimeException(
+            "PeekableIterator.next() shouldn't be called unless hasNext() returns true");
+      }
+      Tuple2<HiveKey, BytesWritable> current = next;
+      next = it.hasNext() ? it.next() : null;
+      return current;
+    }
+
+    public Tuple2<HiveKey, BytesWritable> peek() {
+      return next;
+    }
+
+    @Override
+    public void remove() {
+      // Not implemented.
+      throw new UnsupportedOperationException();
+    }
+
+  }
+
 }
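
Note (illustrative, not part of the patch): because the per-key Iterable returned by ShuffleFunction
is lazily backed by the same partition iterator as the outer key iterator, any downstream consumer
must drain each value Iterable before advancing the outer iterator. The sketch below shows that
consumption contract under that assumption; the class name, the consume() method, and the printing
logic are hypothetical placeholders, not code from this patch.

import java.util.Iterator;

import org.apache.hadoop.hive.ql.io.HiveKey;
import org.apache.hadoop.io.BytesWritable;

import scala.Tuple2;

public class GroupedConsumptionSketch {

  // Iterates one partition of the shuffled RDD. Each per-key Iterable is fully
  // consumed before the outer iterator advances, which is what the shared
  // backing iterator inside ShuffleFunction requires.
  static void consume(Iterator<Tuple2<HiveKey, Iterable<BytesWritable>>> grouped) {
    while (grouped.hasNext()) {
      Tuple2<HiveKey, Iterable<BytesWritable>> pair = grouped.next();
      HiveKey key = pair._1();
      for (BytesWritable value : pair._2()) {
        // Placeholder for real row processing.
        System.out.println(key.hashCode() + " -> " + value.getLength());
      }
      // Only after the inner loop has drained the values is it safe to call
      // grouped.hasNext()/next() again.
    }
  }
}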