diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HiveBytesWritableCache.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HiveBytesWritableCache.java
new file mode 100644
index 0000000..3cb4669
--- /dev/null
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HiveBytesWritableCache.java
@@ -0,0 +1,268 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hive.ql.exec.spark;
+
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+import com.google.common.base.Preconditions;
+import org.apache.hadoop.fs.FileUtil;
+import org.apache.hadoop.io.BytesWritable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.NoSuchElementException;
+
+/**
+ * A cache with a fixed-size in-memory buffer for {@link BytesWritable}s. When the buffer
+ * fills up, buffered entries are spilled to disk.
+ * NOTE: this class is NOT thread safe. It also only implements an internal iterator.
+ *
+ * Use this class in the following pattern:
+ *
+ * <pre>{@code
+ * HiveBytesWritableCache cache = new HiveBytesWritableCache();
+ *
+ * // Write entries to the cache. May spill to disk.
+ * while (...) {
+ *   cache.add(..);
+ * }
+ *
+ * // Done with writing. Start reading from the cache.
+ * cache.startRead();
+ * while (cache.hasNext()) {
+ *   BytesWritable bw = cache.next();
+ *   ...
+ * }
+ *
+ * // Done with reading. Close and clear the cache.
+ * cache.close();
+ * }</pre>
+ */
+public class HiveBytesWritableCache implements Iterable<BytesWritable>, Iterator<BytesWritable> {
+ private static final Logger LOG = LoggerFactory.getLogger(HiveBytesWritableCache.class.getName());
+
+ private static final int DEFAULT_IN_MEMORY_NUM_ROWS = 1024;
+
+ private final List<BytesWritable> buffer;
+
+ // Indicates whether we have flushed any data to disk.
+ // If this is not set, then all data can be held in a single in-memory buffer,
+ // and there is no need to spill to disk or initialize the input & output streams.
+ private boolean flushed;
+
+ private File parentFile;
+ private File tmpFile;
+
+ private int cursor;
+ private final int maxBufferEntries;
+
+ private Input input;
+ private Output output;
+
+ HiveBytesWritableCache() {
+ this(DEFAULT_IN_MEMORY_NUM_ROWS);
+ }
+
+ HiveBytesWritableCache(int maxBufferEntries) {
+ this.flushed = false;
+ this.buffer = new ArrayList<>();
+ this.maxBufferEntries = maxBufferEntries;
+ this.cursor = 0;
+ }
+
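+ /**
+ * Adds a value to the cache. If the in-memory buffer is full, its current
+ * contents are spilled to disk before the new value is buffered.
+ */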
+ public void add(BytesWritable value) {
+ if (buffer.size() >= maxBufferEntries) {
+ flushBuffer();
+ }
+ buffer.add(value);
+ }
+
+ /**
+ * Start reading from the cache. MUST be called before calling any
+ * of the iterator methods.
+ */
+ public void startRead() {
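+ // If we have spilled before, flush any remaining buffered entries so that
+ // the spill file contains every entry in insertion order.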
+ if (flushed && !buffer.isEmpty()) {
+ flushBuffer();
+ }
+ if (output != null) {
+ output.close();
+ output = null;
+ }
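+ // Reload the first chunk from disk, if anything was spilled.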
+ if (flushed) {
+ loadBuffer();
+ }
+ }
+
+ @Override
+ public boolean hasNext() {
+ return cursor < buffer.size() || hasMoreInput();
+ }
+
+ @Override
+ public BytesWritable next() {
+ if (!hasNext()) {
+ throw new NoSuchElementException("No more elements in the cache");
+ }
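+ // The in-memory buffer is exhausted; refill it from the spill file.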
+ if (cursor >= buffer.size()) {
+ loadBuffer();
+ }
+ return buffer.get(cursor++);
+ }
+
+ @Override
+ public void remove() {
+ throw new UnsupportedOperationException("Remove is not supported");
+ }
+
+ @Override
+ public Iterator<BytesWritable> iterator() {
+ return this;
+ }
+
+ /**
+ * Close this cache. Idempotent.
+ */
+ public void close() {
+ cursor = 0;
+ buffer.clear();
+
+ if (parentFile != null) {
+ if (input != null) {
+ try {
+ input.close();
+ } catch (Throwable e) {
+ LOG.warn("Error when closing cache input.", e);
+ }
+ input = null;
+ }
+ if (output != null) {
+ try {
+ output.close();
+ } catch (Throwable e) {
+ LOG.warn("Error when closing cache output.", e);
+ }
+ output = null;
+ }
+ FileUtil.fullyDelete(parentFile);
+ parentFile = null;
+ tmpFile = null;
+ }
+ }
+
+ /**
+ * Returns whether there is more data to load from disk.
+ */
+ private boolean hasMoreInput() {
+ return input != null && !input.eof();
+ }
+
+ private void flushBuffer() {
+ // Initialize output temporary file if not already set
+ if (output == null) {
+ try {
+ setupOutput();
+ } catch (IOException e) {
+ close();
+ throw new RuntimeException("Failed to set up output stream.", e);
+ }
+ }
+
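+ // Append every buffered entry to the spill file, then clear the buffer.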
+ try {
+ for (int i = 0; i < buffer.size(); i++) {
+ BytesWritable writable = buffer.get(i);
+ writeValue(output, writable);
+ }
+ buffer.clear();
+ flushed = true;
+ } catch (Exception e) {
+ close();
+ throw new RuntimeException("Error when spilling to disk", e);
+ }
+ }
+
+ private void loadBuffer() {
+ // Initialize input stream from the temporary file if not already set
+ if (input == null) {
+ try {
+ FileInputStream fis = new FileInputStream(tmpFile);
+ input = new Input(fis);
+ } catch (IOException e) {
+ close();
+ throw new RuntimeException("Failed to set up input stream.", e);
+ }
+ }
+
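+ // Refill the buffer with the next chunk of up to maxBufferEntries entries.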
+ cursor = 0;
+ buffer.clear();
+ for (int i = 0; i < maxBufferEntries; i++) {
+ if (input.eof()) {
+ break;
+ }
+ buffer.add(readValue(input));
+ }
+ }
+
+ private void setupOutput() throws IOException {
+ Preconditions.checkState(parentFile == null && tmpFile == null);
+ while (true) {
+ parentFile = File.createTempFile("hive-resultcache", "");
+ if (parentFile.delete() && parentFile.mkdir()) {
+ parentFile.deleteOnExit();
+ break;
+ }
+ LOG.debug("Retry creating tmp result-cache directory...");
+ }
+
+ tmpFile = File.createTempFile("ResultCache", ".tmp", parentFile);
+ LOG.info("ResultCache created temp file {}", tmpFile.getAbsolutePath());
+ tmpFile.deleteOnExit();
+
+ FileOutputStream fos = new FileOutputStream(tmpFile);
+ output = new Output(fos);
+ }
+
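+ // Each record on disk is a 4-byte length prefix followed by the raw bytes.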
+ private BytesWritable readValue(Input input) {
+ return new BytesWritable(input.readBytes(input.readInt()));
+ }
+
+ private void writeValue(Output output, BytesWritable bytesWritable) {
+ int size = bytesWritable.getLength();
+ output.writeInt(size);
+ output.writeBytes(bytesWritable.getBytes(), 0, size);
+ }
+
+}
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SortByShuffler.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SortByShuffler.java
index 997ab7e..dff0d63 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SortByShuffler.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SortByShuffler.java
@@ -27,7 +27,9 @@
import org.apache.spark.storage.StorageLevel;
import scala.Tuple2;
-import java.util.*;
+import java.util.Iterator;
+import java.util.NoSuchElementException;
+
public class SortByShuffler implements SparkShuffler {
@@ -80,7 +82,7 @@ public String getName() {
// Use input iterator to back returned iterable object.
return new Iterator<Tuple2<HiveKey, Iterable<BytesWritable>>>() {
HiveKey curKey = null;
- List<BytesWritable> curValues = new ArrayList<BytesWritable>();
+ HiveBytesWritableCache curValues = new HiveBytesWritableCache();
@Override
public boolean hasNext() {
@@ -89,16 +91,15 @@ public boolean hasNext() {
@Override
public Tuple2<HiveKey, Iterable<BytesWritable>> next() {
- // TODO: implement this by accumulating rows with the same key into a list.
- // Note that this list needs to improved to prevent excessive memory usage, but this
- // can be done in later phase.
while (it.hasNext()) {
Tuple2<HiveKey, BytesWritable> pair = it.next();
if (curKey != null && !curKey.equals(pair._1())) {
HiveKey key = curKey;
- List<BytesWritable> values = curValues;
+ HiveBytesWritableCache values = curValues;
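+ // Done writing this key's values; switch the cache to read mode.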
+ values.startRead();
curKey = pair._1();
- curValues = new ArrayList();
+ curValues = new HiveBytesWritableCache();
curValues.add(pair._2());
return new Tuple2<HiveKey, Iterable<BytesWritable>>(key, values);
}
@@ -111,6 +111,8 @@ public boolean hasNext() {
// if we get here, this should be the last element we have
HiveKey key = curKey;
curKey = null;
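+ // Done writing the last key's values; switch the cache to read mode.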
+ curValues.startRead();
return new Tuple2<HiveKey, Iterable<BytesWritable>>(key, curValues);
}