diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HiveBytesWritableCache.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HiveBytesWritableCache.java
new file mode 100644
index 0000000..3cb4669
--- /dev/null
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HiveBytesWritableCache.java
@@ -0,0 +1,258 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hive.ql.exec.spark;
+
+import com.clearspring.analytics.util.Preconditions;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+import org.apache.hadoop.fs.FileUtil;
+import org.apache.hadoop.io.BytesWritable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.NoSuchElementException;
+
+/**
+ * A cache with a fixed buffer size for {@link BytesWritable}s. If the buffer is full,
+ * new entries will be spilled to disk.
+ * NOTE: this class is NOT thread safe. It also only implements an internal iterator.
+ *
+ * Use this class in the following pattern:
+ *
+ * <pre>
+ *   HiveBytesWritableCache cache = new ...
+ *
+ *   // Write entries to cache. May persist to disk.
+ *   while (...) {
+ *     cache.add(..);
+ *   }
+ *
+ *   // Done with writing. Start reading from cache.
+ *   cache.startRead();
+ *   while (cache.hasNext()) {
+ *     BytesWritable bw = cache.next();
+ *     ...
+ *   }
+ *
+ *   // Done with reading. Close and clear the cache.
+ *   cache.close();
+ * </pre>
+ */
+public class HiveBytesWritableCache implements Iterable<BytesWritable>, Iterator<BytesWritable> {
+  private static final Logger LOG = LoggerFactory.getLogger(HiveBytesWritableCache.class.getName());
+
+  private static final int DEFAULT_IN_MEMORY_NUM_ROWS = 1024;
+
+  private final List<BytesWritable> buffer;
+
+  // Indicates whether we have flushed any data to disk.
+  // If this is not set, then all data can be held in a single buffer, and thus
+  // there is no need to flush to disk and initialize input & output.
+  private boolean flushed;
+
+  private File parentFile;
+  private File tmpFile;
+
+  private int cursor;
+  private final int maxBufferEntries;
+
+  private Input input;
+  private Output output;
+
+  HiveBytesWritableCache() {
+    this(DEFAULT_IN_MEMORY_NUM_ROWS);
+  }
+
+  HiveBytesWritableCache(int maxBufferEntries) {
+    this.flushed = false;
+    this.buffer = new ArrayList<>();
+    this.maxBufferEntries = maxBufferEntries;
+    this.cursor = 0;
+  }
+
+  public void add(BytesWritable value) {
+    if (buffer.size() >= maxBufferEntries) {
+      flushBuffer();
+    }
+    buffer.add(value);
+  }
+
+  /**
+   * Start reading from the cache. MUST be called before calling any
+   * of the iterator methods.
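+   * After this is called, no further entries should be added via {@link #add(BytesWritable)}.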
+   */
+  public void startRead() {
+    if (flushed && !buffer.isEmpty()) {
+      flushBuffer();
+    }
+    if (output != null) {
+      output.close();
+      output = null;
+    }
+    if (flushed) {
+      loadBuffer();
+    }
+  }
+
+  @Override
+  public boolean hasNext() {
+    return cursor < buffer.size() || hasMoreInput();
+  }
+
+  @Override
+  public BytesWritable next() {
+    if (!hasNext()) {
+      throw new NoSuchElementException("No more next");
+    }
+    if (cursor >= buffer.size()) {
+      loadBuffer();
+    }
+    return buffer.get(cursor++);
+  }
+
+  @Override
+  public void remove() {
+    throw new UnsupportedOperationException("Remove is not supported");
+  }
+
+  @Override
+  public Iterator<BytesWritable> iterator() {
+    return this;
+  }
+
+  /**
+   * Close this cache. Idempotent.
+   */
+  public void close() {
+    cursor = 0;
+    buffer.clear();
+
+    if (parentFile != null) {
+      if (input != null) {
+        try {
+          input.close();
+        } catch (Throwable e) {
+          LOG.warn("Error when closing cache input.", e);
+        }
+        input = null;
+      }
+      if (output != null) {
+        try {
+          output.close();
+        } catch (Throwable e) {
+          LOG.warn("Error when closing cache output.", e);
+        }
+        output = null;
+      }
+      FileUtil.fullyDelete(parentFile);
+      parentFile = null;
+      tmpFile = null;
+    }
+  }
+
+  /**
+   * Whether there's more data to load from disk.
+   */
+  private boolean hasMoreInput() {
+    return input != null && !input.eof();
+  }
+
+  private void flushBuffer() {
+    // Initialize the output temporary file if not already set
+    if (output == null) {
+      try {
+        setupOutput();
+      } catch (IOException e) {
+        close();
+        throw new RuntimeException("Failed to set up output stream.", e);
+      }
+    }
+
+    try {
+      for (int i = 0; i < buffer.size(); i++) {
+        BytesWritable writable = buffer.get(i);
+        writeValue(output, writable);
+      }
+      buffer.clear();
+      flushed = true;
+    } catch (Exception e) {
+      close();
+      throw new RuntimeException("Error when spilling to disk", e);
+    }
+  }
+
+  private void loadBuffer() {
+    // Initialize the input stream from the temporary file if not already set
+    if (input == null) {
+      try {
+        FileInputStream fis = new FileInputStream(tmpFile);
+        input = new Input(fis);
+      } catch (IOException e) {
+        close();
+        throw new RuntimeException("Failed to set up input stream.", e);
+      }
+    }
+
+    cursor = 0;
+    buffer.clear();
+    for (int i = 0; i < maxBufferEntries; i++) {
+      if (input.eof()) {
+        break;
+      }
+      buffer.add(readValue(input));
+    }
+  }
+
+  private void setupOutput() throws IOException {
+    Preconditions.checkState(parentFile == null && tmpFile == null);
+    while (true) {
+      parentFile = File.createTempFile("hive-resultcache", "");
+      if (parentFile.delete() && parentFile.mkdir()) {
+        parentFile.deleteOnExit();
+        break;
+      }
+      LOG.debug("Retry creating tmp result-cache directory...");
+    }
+
+    tmpFile = File.createTempFile("ResultCache", ".tmp", parentFile);
+    LOG.info("ResultCache created temp file " + tmpFile.getAbsolutePath());
+    tmpFile.deleteOnExit();
+
+    FileOutputStream fos = new FileOutputStream(tmpFile);
+    output = new Output(fos);
+  }
+
+  private BytesWritable readValue(Input input) {
+    return new BytesWritable(input.readBytes(input.readInt()));
+  }
+
+  private void writeValue(Output output, BytesWritable bytesWritable) {
+    int size = bytesWritable.getLength();
+    output.writeInt(size);
+    output.writeBytes(bytesWritable.getBytes(), 0, size);
+  }
+}
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SortByShuffler.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SortByShuffler.java
index 997ab7e..dff0d63 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SortByShuffler.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SortByShuffler.java
@@ -27,7 +27,9 @@
 import org.apache.spark.storage.StorageLevel;
 
 import scala.Tuple2;
 
-import java.util.*;
+import java.util.Iterator;
+import java.util.NoSuchElementException;
+
 
 public class SortByShuffler implements SparkShuffler {
@@ -80,7 +82,7 @@ public String getName() {
       // Use input iterator to back returned iterable object.
       return new Iterator<Tuple2<HiveKey, Iterable<BytesWritable>>>() {
         HiveKey curKey = null;
-        List<BytesWritable> curValues = new ArrayList<BytesWritable>();
+        HiveBytesWritableCache curValues = new HiveBytesWritableCache();
 
         @Override
         public boolean hasNext() {
@@ -89,16 +91,14 @@ public boolean hasNext() {
 
         @Override
         public Tuple2<HiveKey, Iterable<BytesWritable>> next() {
-          // TODO: implement this by accumulating rows with the same key into a list.
-          // Note that this list needs to improved to prevent excessive memory usage, but this
-          // can be done in later phase.
           while (it.hasNext()) {
             Tuple2<HiveKey, BytesWritable> pair = it.next();
             if (curKey != null && !curKey.equals(pair._1())) {
               HiveKey key = curKey;
-              List<BytesWritable> values = curValues;
+              HiveBytesWritableCache values = curValues;
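+              // Flip the cache from write mode to read mode before handing it to the consumer.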
+              values.startRead();
               curKey = pair._1();
-              curValues = new ArrayList<BytesWritable>();
+              curValues = new HiveBytesWritableCache();
               curValues.add(pair._2());
               return new Tuple2<HiveKey, Iterable<BytesWritable>>(key, values);
             }
@@ -111,6 +111,7 @@ public boolean hasNext() {
 
           // if we get here, this should be the last element we have
           HiveKey key = curKey;
           curKey = null;
+          curValues.startRead();
           return new Tuple2<HiveKey, Iterable<BytesWritable>>(key, curValues);
         }
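
For illustration only, not part of the patch: a minimal sketch that exercises the spill path of HiveBytesWritableCache. The class and its constructors are package-private, so the sketch assumes it is compiled in the same package; the demo class name and the buffer size of 4 (chosen to force spilling) are hypothetical.

package org.apache.hadoop.hive.ql.exec.spark;

import org.apache.hadoop.io.BytesWritable;

// Hypothetical demo class; not part of the patch above.
public class HiveBytesWritableCacheDemo {
  public static void main(String[] args) {
    // A 4-entry buffer forces the 10 rows below to spill to a temp file.
    HiveBytesWritableCache cache = new HiveBytesWritableCache(4);
    for (int i = 0; i < 10; i++) {
      cache.add(new BytesWritable(("row-" + i).getBytes()));
    }

    // Flush any remaining rows and switch to read mode, then iterate exactly once.
    cache.startRead();
    for (BytesWritable bw : cache) {
      System.out.println(new String(bw.getBytes(), 0, bw.getLength()));
    }

    // Drops the in-memory buffer and deletes the temp directory, if any.
    cache.close();
  }
}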