diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HiveBaseFunctionResultList.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HiveBaseFunctionResultList.java
index 0df2580..80f370f 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HiveBaseFunctionResultList.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HiveBaseFunctionResultList.java
@@ -17,19 +17,19 @@
  */
 package org.apache.hadoop.hive.ql.exec.spark;
 
-import com.google.common.base.Preconditions;
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.Iterator;
+import java.util.NoSuchElementException;
+
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.hive.ql.io.HiveKey;
 import org.apache.hadoop.io.BytesWritable;
 import org.apache.hadoop.mapred.OutputCollector;
+
 import scala.Tuple2;
 
-import java.io.IOException;
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.Iterator;
-import java.util.List;
-import java.util.NoSuchElementException;
+import com.google.common.base.Preconditions;
 
 /**
  * Base class for
@@ -38,9 +38,10 @@
  * are processed in lazy fashion i.e when output records are requested
  * through Iterator interface.
  */
+@SuppressWarnings("rawtypes")
 public abstract class HiveBaseFunctionResultList implements
     Iterable, OutputCollector, Serializable {
-
+  private static final long serialVersionUID = -5756253667434351741L;
   private final Iterator inputIterator;
   private boolean isClosed = false;
 
@@ -106,11 +107,9 @@ public boolean hasNext(){
     while (inputIterator.hasNext() && !processingDone()) {
       try {
         processNextRecord(inputIterator.next());
-        // TODO Current HiveKVResultCache does not support read-then-write,
-        // should not enable lazy execution here. See HIVE-7873
-        // if (lastRecordOutput.hasNext()) {
-        //   return true;
-        // }
+        if (lastRecordOutput.hasNext()) {
+          return true;
+        }
       } catch (IOException ex) {
         // TODO: better handling of exception.
         throw new RuntimeException("Error while processing input.", ex);
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HiveKVResultCache.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HiveKVResultCache.java
index a6b9037..b5e5df0 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HiveKVResultCache.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HiveKVResultCache.java
@@ -17,7 +17,9 @@
  */
 package org.apache.hadoop.hive.ql.exec.spark;
 
-import com.google.common.base.Preconditions;
+import java.util.ArrayList;
+import java.util.List;
+
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.hive.conf.HiveConf;
 import org.apache.hadoop.hive.ql.exec.persistence.RowContainer;
@@ -32,14 +34,21 @@
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption;
 import org.apache.hadoop.io.BytesWritable;
 import org.apache.hadoop.mapred.Reporter;
+
 import scala.Tuple2;
 
-import java.util.ArrayList;
-import java.util.List;
+import com.google.common.base.Preconditions;
 
 /**
  * Wrapper around {@link org.apache.hadoop.hive.ql.exec.persistence.RowContainer}
+ *
+ * This class is thread safe under one condition: only one thread invokes
+ * {@link #add(HiveKey, BytesWritable)} and the other methods like {@link #next()}
+ * or {@link #hasNext()}, while other threads invoke only {@link #add(HiveKey, BytesWritable)}
+ * concurrently. It is not fully thread safe, for example, we can't have two threads
+ * invoke {@link #next()} at the same time.
  */
+@SuppressWarnings({"deprecation", "unchecked", "rawtypes"})
 public class HiveKVResultCache {
 
   public static final int IN_MEMORY_CACHE_SIZE = 512;
@@ -47,14 +56,20 @@
   private static final String COL_TYPES =
       serdeConstants.BINARY_TYPE_NAME + ":" + serdeConstants.BINARY_TYPE_NAME;
 
+  // Used to cache rows added while container is iterated.
+  private RowContainer backupContainer;
+
   private RowContainer container;
 
+  private Configuration conf;
   private int cursor = 0;
 
   public HiveKVResultCache(Configuration conf) {
-    initRowContainer(conf);
+    container = initRowContainer(conf);
+    this.conf = conf;
   }
 
-  private void initRowContainer(Configuration conf) {
+  private static RowContainer initRowContainer(Configuration conf) {
+    RowContainer container;
     try {
       container = new RowContainer(IN_MEMORY_CACHE_SIZE, conf, Reporter.NULL);
@@ -71,6 +86,7 @@ private void initRowContainer(Configuration conf) {
     } catch(Exception ex) {
       throw new RuntimeException("Failed to create RowContainer", ex);
     }
+    return container;
   }
 
   public void add(HiveKey key, BytesWritable value) {
@@ -80,14 +96,26 @@ public void add(HiveKey key, BytesWritable value) {
     row.add(wrappedHiveKey);
     row.add(value);
 
-    try {
-      container.addRow(row);
-    } catch (HiveException ex) {
-      throw new RuntimeException("Failed to add KV pair to RowContainer", ex);
+    synchronized (this) {
+      try {
+        if (cursor == 0) {
+          container.addRow(row);
+        } else {
+          if (backupContainer == null) {
+            backupContainer = initRowContainer(conf);
+          }
+          backupContainer.addRow(row);
+        }
+      } catch (HiveException ex) {
+        throw new RuntimeException("Failed to add KV pair to RowContainer", ex);
+      }
     }
   }
 
   public void clear() {
+    if (cursor == 0) {
+      return;
+    }
     try {
       container.clearRows();
     } catch(HiveException ex) {
@@ -97,7 +125,21 @@
   }
 
   public boolean hasNext() {
-    return container.rowCount() > 0 && cursor < container.rowCount();
+    if (container.rowCount() > 0 && cursor < container.rowCount()) {
+      return true;
+    }
+    synchronized (this) {
+      if (backupContainer == null
+          || backupContainer.rowCount() == 0) {
+        return false;
+      }
+      clear();
+      // Switch containers
+      RowContainer tmp = container;
+      container = backupContainer;
+      backupContainer = tmp;
+      return true;
+    }
   }
 
   public Tuple2 next() {
@@ -106,11 +148,15 @@ public boolean hasNext() {
     try {
       List<BytesWritable> row;
       if (cursor == 0) {
-        row = container.first();
+        synchronized (this) {
+          // Prevent racing with add()
+          row = container.first();
+          cursor++;
+        }
       } else {
         row = container.next();
+        cursor++;
       }
-      cursor++;
       HiveKey key = KryoSerializer.deserialize(row.get(0).getBytes(), HiveKey.class);
       return new Tuple2(key, row.get(1));
     } catch (HiveException ex) {
diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/spark/TestHiveKVResultCache.java ql/src/test/org/apache/hadoop/hive/ql/exec/spark/TestHiveKVResultCache.java
index 496a11f..6144436 100644
--- ql/src/test/org/apache/hadoop/hive/ql/exec/spark/TestHiveKVResultCache.java
+++ ql/src/test/org/apache/hadoop/hive/ql/exec/spark/TestHiveKVResultCache.java
@@ -17,14 +17,27 @@
  */
 package org.apache.hadoop.hive.ql.exec.spark;
 
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.concurrent.LinkedBlockingQueue;
+
+import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.hive.conf.HiveConf;
 import org.apache.hadoop.hive.ql.io.HiveKey;
 import org.apache.hadoop.io.BytesWritable;
 import org.junit.Test;
+
 import scala.Tuple2;
 
-import static org.junit.Assert.assertTrue;
+import com.clearspring.analytics.util.Preconditions;
 
+@SuppressWarnings({"unchecked", "rawtypes"})
 public class TestHiveKVResultCache {
   @Test
   public void testSimple() throws Exception {
@@ -87,4 +100,237 @@ private void testSpillingHelper(HiveKVResultCache cache, int numRecords) {
 
     cache.clear();
   }
-}
\ No newline at end of file
+
+  @Test
+  public void testResultList() throws Exception {
+    scanAndVerify(10000, 0, 0, "a", "b");
+    scanAndVerify(10000, 512, 0, "a", "b");
+    scanAndVerify(10000, 512 * 2, 0, "a", "b");
+    scanAndVerify(10000, 512, 10, "a", "b");
+    scanAndVerify(10000, 512 * 2, 10, "a", "b");
+  }
+
+  private static void scanAndVerify(
+      long rows, int threshold, int separate, String prefix1, String prefix2) {
+    ArrayList<Tuple2<HiveKey, BytesWritable>> output =
+      new ArrayList<Tuple2<HiveKey, BytesWritable>>((int)rows);
+    scanResultList(rows, threshold, separate, output, prefix1, prefix2);
+    assertEquals(rows, output.size());
+    long primaryRows = rows * (100 - separate) / 100;
+    long separateRows = rows - primaryRows;
+    HashSet<Long> primaryRowKeys = new HashSet<Long>();
+    HashSet<Long> separateRowKeys = new HashSet<Long>();
+    for (Tuple2<HiveKey, BytesWritable> item: output) {
+      String key = new String(item._1.copyBytes());
+      String value = new String(item._2.copyBytes());
+      String prefix = key.substring(0, key.indexOf('_'));
+      Long id = Long.valueOf(key.substring(5 + prefix.length()));
+      if (prefix.equals(prefix1)) {
+        assertTrue(id >= 0 && id < primaryRows);
+        primaryRowKeys.add(id);
+      } else {
+        assertEquals(prefix2, prefix);
+        assertTrue(id >= 0 && id < separateRows);
+        separateRowKeys.add(id);
+      }
+      assertEquals(prefix + "_value_" + id, value);
+    }
+    assertEquals(separateRows, separateRowKeys.size());
+    assertEquals(primaryRows, primaryRowKeys.size());
+  }
+
+  private static class MyHiveFunctionResultList extends HiveBaseFunctionResultList {
+    private static final long serialVersionUID = 3924109264432927384L;
+
+    // Total rows to emit during the whole iteration,
+    // excluding the rows emitted by the separate thread.
+    private long primaryRows;
+    // Batch of rows to emit per processNextRecord() call.
+    private int thresholdRows;
+    // Rows to be emitted with a separate thread per processNextRecord() call.
+    private long separateRows;
+    // Thread to generate the separate rows beside the normal thread.
+    private Thread separateRowGenerator;
+
+    // Counter for rows emitted
+    private long rowsEmitted;
+    private long separateRowsEmitted;
+
+    // Prefix for primary row keys
+    private String prefix1;
+    // Prefix for separate row keys
+    private String prefix2;
+
+    // A queue to notify separateRowGenerator to generate the next batch of rows.
+    private LinkedBlockingQueue<Boolean> queue;
+
+    MyHiveFunctionResultList(Configuration conf, Iterator inputIterator) {
+      super(conf, inputIterator);
+    }
+
+    void init(long rows, int threshold, int separate, String p1, String p2) {
+      Preconditions.checkArgument((threshold > 0 || separate == 0)
+        && separate < 100 && separate >= 0 && rows > 0);
+      primaryRows = rows * (100 - separate) / 100;
+      separateRows = rows - primaryRows;
+      thresholdRows = threshold;
+      prefix1 = p1;
+      prefix2 = p2;
+      if (separateRows > 0) {
+        separateRowGenerator = new Thread(new Runnable() {
+          @Override
+          public void run() {
+            try {
+              long separateBatchSize = thresholdRows * separateRows / primaryRows;
+              while (!queue.take().booleanValue()) {
+                for (int i = 0; i < separateBatchSize; i++) {
+                  collect(prefix2, separateRowsEmitted++);
+                }
+              }
+            } catch (InterruptedException e) {
+              e.printStackTrace();
+            }
+            for (; separateRowsEmitted < separateRows;) {
+              collect(prefix2, separateRowsEmitted++);
+            }
+          }
+        });
+        queue = new LinkedBlockingQueue<Boolean>();
+        separateRowGenerator.start();
+      }
+    }
+
+    public void collect(String prefix, long id) {
+      String k = prefix + "_key_" + id;
+      String v = prefix + "_value_" + id;
+      HiveKey key = new HiveKey(k.getBytes(), k.hashCode());
+      BytesWritable value = new BytesWritable(v.getBytes());
+      try {
+        collect(key, value);
+      } catch (IOException e) {
+        e.printStackTrace();
+      }
+    }
+
+    @Override
+    protected void processNextRecord(Object inputRecord) throws IOException {
+      for (int i = 0; i < thresholdRows; i++) {
+        collect(prefix1, rowsEmitted++);
+      }
+      if (separateRowGenerator != null) {
+        queue.add(Boolean.FALSE);
+      }
+    }
+
+    @Override
+    protected boolean processingDone() {
+      return false;
+    }
+
+    @Override
+    protected void closeRecordProcessor() {
+      for (; rowsEmitted < primaryRows;) {
+        collect(prefix1, rowsEmitted++);
+      }
+      if (separateRowGenerator != null) {
+        queue.add(Boolean.TRUE);
+        try {
+          separateRowGenerator.join();
+        } catch (InterruptedException e) {
+          e.printStackTrace();
+        }
+      }
+    }
+  }
+
+  private static long scanResultList(long rows, int threshold, int separate,
+      List<Tuple2<HiveKey, BytesWritable>> output, String prefix1, String prefix2) {
+    final long iteratorCount = threshold == 0 ? 1 : rows * (100 - separate) / 100 / threshold;
+    MyHiveFunctionResultList resultList = new MyHiveFunctionResultList(
+      new HiveConf(), new Iterator() {
+        // Input record iterator, not used
+        private int i = 0;
+        @Override
+        public boolean hasNext() {
+          return i++ < iteratorCount;
+        }
+
+        @Override
+        public Object next() {
+          return Integer.valueOf(i);
+        }
+
+        @Override
+        public void remove() {
+        }
+      });
+
+    resultList.init(rows, threshold, separate, prefix1, prefix2);
+    long startTime = System.currentTimeMillis();
+    Iterator it = resultList.iterator();
+    while (it.hasNext()) {
+      Object item = it.next();
+      if (output != null) {
+        output.add((Tuple2)item);
+      }
+    }
+    long endTime = System.currentTimeMillis();
+    return endTime - startTime;
+  }
+
+  private static long[] scanResultList(long rows, int threshold, int extra) {
+    // 1. Simulate emitting all records in closeRecordProcessor().
+    long t1 = scanResultList(rows, 0, 0, null, "a", "b");
+
+    // 2. Simulate emitting records in processNextRecord() with small memory usage limit.
+    long t2 = scanResultList(rows, threshold, 0, null, "c", "d");
+
+    // 3. Simulate emitting records in processNextRecord() with large memory usage limit.
+    long t3 = scanResultList(rows, threshold * 10, 0, null, "e", "f");
+
+    // 4. Same as 2. Also emit extra records from a separate thread.
+    long t4 = scanResultList(rows, threshold, extra, null, "g", "h");
+
+    // 5. Same as 3. Also emit extra records from a separate thread.
+    long t5 = scanResultList(rows, threshold * 10, extra, null, "i", "j");
+
+    return new long[] {t1, t2, t3, t4, t5};
+  }
+
+  public static void main(String[] args) throws Exception {
+    long rows = 1000000; // total rows to generate
+    int threshold = 512; // # of rows to cache at most
+    int extra = 5; // percentile of extra rows to generate by a different thread
+
+    if (args.length > 0) {
+      rows = Long.parseLong(args[0]);
+    }
+    if (args.length > 1) {
+      threshold = Integer.parseInt(args[1]);
+    }
+    if (args.length > 2) {
+      extra = Integer.parseInt(args[2]);
+    }
+
+    // Warm up couple times
+    for (int i = 0; i < 2; i++) {
+      scanResultList(rows, threshold, extra);
+    }
+
+    int count = 5;
+    long[] t = new long[count];
+    // Run count times and get average
+    for (int i = 0; i < count; i++) {
+      long[] tmp = scanResultList(rows, threshold, extra);
+      for (int k = 0; k < count; k++) {
+        t[k] += tmp[k];
+      }
+    }
+    for (int i = 0; i < count; i++) {
+      t[i] /= count;
+    }
+
+    System.out.println(t[0] + "\t" + t[1] + "\t" + t[2]
+        + "\t" + t[3] + "\t" + t[4]);
+  }
+}
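
Note (illustrative, not part of the patch): the thread-safety contract described in the new HiveKVResultCache javadoc is what allows HiveBaseFunctionResultList.hasNext() to hand out output rows lazily while records are still being produced. Below is a minimal sketch of that contract, assuming only the API visible in this diff (the Configuration-based constructor, add(HiveKey, BytesWritable), hasNext(), next()); the class name ResultCacheUsageSketch and the row count are invented for illustration.

    import java.util.ArrayList;
    import java.util.List;

    import org.apache.hadoop.conf.Configuration;
    import org.apache.hadoop.hive.conf.HiveConf;
    import org.apache.hadoop.hive.ql.exec.spark.HiveKVResultCache;
    import org.apache.hadoop.hive.ql.io.HiveKey;
    import org.apache.hadoop.io.BytesWritable;

    import scala.Tuple2;

    // Hypothetical usage sketch: one producer thread calls only add(), while a
    // single consumer thread calls hasNext()/next(), matching the condition in
    // the new HiveKVResultCache javadoc.
    public class ResultCacheUsageSketch {
      public static void main(String[] args) throws Exception {
        Configuration conf = new HiveConf();
        final HiveKVResultCache cache = new HiveKVResultCache(conf);

        // Producer thread: only ever invokes add().
        Thread producer = new Thread(new Runnable() {
          @Override
          public void run() {
            for (int i = 0; i < 1000; i++) {
              String k = "key_" + i;
              String v = "value_" + i;
              cache.add(new HiveKey(k.getBytes(), k.hashCode()),
                  new BytesWritable(v.getBytes()));
            }
          }
        });
        producer.start();

        // Single consumer: drains rows while the producer may still be adding;
        // hasNext() swaps in the backup container once the primary is exhausted.
        List<Tuple2> drained = new ArrayList<Tuple2>();
        while (producer.isAlive() || cache.hasNext()) {
          while (cache.hasNext()) {
            drained.add(cache.next());
          }
        }
        producer.join();
        System.out.println("Drained " + drained.size() + " rows");
      }
    }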