diff --git ql/pom.xml ql/pom.xml
index d73deba440..2d3a0c5468 100644
--- ql/pom.xml
+++ ql/pom.xml
@@ -770,6 +770,12 @@
       <version>${powermock.version}</version>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>com.google.guava</groupId>
+      <artifactId>guava-testlib</artifactId>
+      <version>${guava.version}</version>
+      <scope>test</scope>
+    </dependency>
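
The guava-testlib dependency is pulled in for FakeTicker, which the new unit
test uses to advance a Guava cache's clock deterministically instead of
sleeping. A minimal sketch of that pattern (class name is illustrative):

    import com.google.common.cache.Cache;
    import com.google.common.cache.CacheBuilder;
    import com.google.common.testing.FakeTicker;
    import java.util.concurrent.TimeUnit;

    public class FakeTickerSketch {
      public static void main(String[] args) {
        FakeTicker ticker = new FakeTicker();
        Cache<String, String> cache = CacheBuilder.newBuilder()
            .expireAfterAccess(30, TimeUnit.SECONDS)
            .ticker(ticker)
            .build();
        cache.put("k", "v");
        ticker.advance(60, TimeUnit.SECONDS);        // simulate a minute passing
        cache.cleanUp();                             // evicts the expired entry
        System.out.println(cache.getIfPresent("k")); // prints null
      }
    }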
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HashTableLoader.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HashTableLoader.java
index cf27e92baf..e60dbaef8e 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HashTableLoader.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HashTableLoader.java
@@ -21,6 +21,7 @@
import java.util.Arrays;
import java.util.List;
import java.util.Set;
+import java.util.concurrent.ExecutionException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -159,22 +160,19 @@ private MapJoinTableContainer load(FileSystem fs, Path path,
MapJoinTableContainerSerDe mapJoinTableSerde) throws HiveException {
LOG.info("\tLoad back all hashtable files from tmp folder uri:" + path);
if (!SparkUtilities.isDedicatedCluster(hconf)) {
- return useFastContainer ? mapJoinTableSerde.loadFastContainer(desc, fs, path, hconf) :
- mapJoinTableSerde.load(fs, path, hconf);
+ return loadMapJoinTableContainer(fs, path, mapJoinTableSerde);
}
- MapJoinTableContainer mapJoinTable = SmallTableCache.get(path);
- if (mapJoinTable == null) {
- synchronized (path.toString().intern()) {
- mapJoinTable = SmallTableCache.get(path);
- if (mapJoinTable == null) {
- mapJoinTable = useFastContainer ?
- mapJoinTableSerde.loadFastContainer(desc, fs, path, hconf) :
- mapJoinTableSerde.load(fs, path, hconf);
- SmallTableCache.cache(path, mapJoinTable);
- }
- }
+
+ try {
+ return SmallTableCache.get(path, () -> loadMapJoinTableContainer(fs, path, mapJoinTableSerde));
+ } catch (ExecutionException e) {
+ throw new HiveException(e);
}
- return mapJoinTable;
+ }
+
+  private MapJoinTableContainer loadMapJoinTableContainer(FileSystem fs, Path path,
+      MapJoinTableContainerSerDe mapJoinTableSerde) throws HiveException {
+ return useFastContainer ? mapJoinTableSerde.loadFastContainer(desc, fs, path, hconf) :
+ mapJoinTableSerde.load(fs, path, hconf);
}
private void loadDirectly(MapJoinTableContainer[] mapJoinTables, String inputFileName)
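
The manual double-checked locking on path.toString().intern() is replaced by
Guava's Cache.get(key, loader), which already guarantees the loader runs at
most once per key while concurrent callers for that key block. A hedged
sketch of that contract (names here are illustrative, not Hive APIs):

    import com.google.common.cache.Cache;
    import com.google.common.cache.CacheBuilder;
    import java.util.concurrent.ExecutionException;

    public class GetOrLoadSketch {
      private static final Cache<String, String> CACHE = CacheBuilder.newBuilder().build();

      static String getOrLoad(String key) throws ExecutionException {
        // Concurrent callers for the same key wait for the first loader to
        // finish and then all observe the loaded value.
        return CACHE.get(key, () -> expensiveLoad(key));
      }

      private static String expensiveLoad(String key) {
        return "loaded:" + key; // stands in for mapJoinTableSerde.load(...)
      }
    }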
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SmallTableCache.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SmallTableCache.java
index 3293100af9..a1b8952b7a 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SmallTableCache.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SmallTableCache.java
@@ -17,20 +17,32 @@
*/
package org.apache.hadoop.hive.ql.exec.spark;
-import java.util.concurrent.ConcurrentHashMap;
-
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Ticker;
+import com.google.common.cache.Cache;
+import com.google.common.cache.CacheBuilder;
+import com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.apache.commons.lang3.math.NumberUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.exec.persistence.MapJoinTableContainer;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+import java.util.function.BiConsumer;
public class SmallTableCache {
private static final Logger LOG = LoggerFactory.getLogger(SmallTableCache.class.getName());
-  private static final ConcurrentHashMap<Path, MapJoinTableContainer>
-      tableContainerMap = new ConcurrentHashMap<Path, MapJoinTableContainer>();
+  private static final SmallTableLocalCache<Path, MapJoinTableContainer>
+      tableContainerCache = new SmallTableLocalCache<>();
+
private static volatile String queryId;
/**
@@ -40,13 +52,10 @@
public static void initialize(Configuration conf) {
String currentQueryId = conf.get(HiveConf.ConfVars.HIVEQUERYID.varname);
if (!currentQueryId.equals(queryId)) {
- if (!tableContainerMap.isEmpty()) {
- synchronized (tableContainerMap) {
- if (!currentQueryId.equals(queryId) && !tableContainerMap.isEmpty()) {
- for (MapJoinTableContainer tableContainer: tableContainerMap.values()) {
- tableContainer.clear();
- }
- tableContainerMap.clear();
+ if (tableContainerCache.size() != 0) {
+ synchronized (tableContainerCache) {
+ if (!currentQueryId.equals(queryId) && tableContainerCache.size() != 0) {
+ tableContainerCache.clear((path, tableContainer) -> tableContainer.clear());
if (LOG.isDebugEnabled()) {
LOG.debug("Cleaned up small table cache for query " + queryId);
}
@@ -57,17 +66,59 @@ public static void initialize(Configuration conf) {
}
}
- public static void cache(Path path, MapJoinTableContainer tableContainer) {
- if (tableContainerMap.putIfAbsent(path, tableContainer) == null && LOG.isDebugEnabled()) {
- LOG.debug("Cached small table file " + path + " for query " + queryId);
- }
+  public static MapJoinTableContainer get(Path path, Callable<MapJoinTableContainer> valueLoader)
+ throws ExecutionException {
+ return tableContainerCache.get(path, valueLoader);
}
- public static MapJoinTableContainer get(Path path) {
- MapJoinTableContainer tableContainer = tableContainerMap.get(path);
- if (tableContainer != null && LOG.isDebugEnabled()) {
- LOG.debug("Loaded small table file " + path + " from cache for query " + queryId);
+ @VisibleForTesting
+  static class SmallTableLocalCache<K, V> {
+
+ private static final int MAINTENANCE_THREAD_CLEANUP_PERIOD = 10;
+ private static final int L1_CACHE_EXPIRE_DURATION = 30;
+
+    private final Cache<K, V> L1;
+    private final Cache<K, V> L2;
+ private final ScheduledExecutorService scheduler;
+
+ public SmallTableLocalCache() {
+ this(Ticker.systemTicker());
+ }
+
+ @VisibleForTesting
+ SmallTableLocalCache(Ticker ticker) {
+      scheduler = Executors.newScheduledThreadPool(NumberUtils.INTEGER_ONE,
+          new ThreadFactoryBuilder().setNameFormat("SmallTableCache maintenance thread").setDaemon(true).build());
+      L1 = CacheBuilder.newBuilder()
+          .expireAfterAccess(L1_CACHE_EXPIRE_DURATION, TimeUnit.SECONDS).ticker(ticker).build();
+      L2 = CacheBuilder.newBuilder().softValues().build();
+ scheduler.scheduleAtFixedRate(() -> {
+ cleanup();
+ }, NumberUtils.INTEGER_ZERO, MAINTENANCE_THREAD_CLEANUP_PERIOD, TimeUnit.SECONDS);
+ }
+
+    // L2 >= L1, because if a cached item is in L1 then it's in L2 as well.
+    public long size() {
+ return L2.size();
+ }
+
+    public void clear(BiConsumer<? super K, ? super V> action) {
+ L1.invalidateAll();
+ L2.asMap().forEach(action);
+ L2.invalidateAll();
+ }
+
+ @VisibleForTesting
+ void cleanup() {
+ L1.cleanUp();
+ }
+
+    public V get(K key, Callable<? extends V> valueLoader) throws ExecutionException {
+ V value = L1.getIfPresent(key);
+ if (value == null) {
+ value = L2.get(key, valueLoader);
+ L1.put(key, value);
+ }
+ return value;
}
- return tableContainer;
}
+
}
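
For context, the two-tier layout above works as follows: L1 holds strong
references and expires entries 30 seconds after last access, so a recently
used table cannot be garbage collected; L2 holds soft values, so once L1
drops its strong reference the GC may reclaim the entry under memory
pressure. A minimal self-contained sketch of the same idea (hypothetical
names, not the patched class):

    import com.google.common.cache.Cache;
    import com.google.common.cache.CacheBuilder;
    import java.util.concurrent.Callable;
    import java.util.concurrent.ExecutionException;
    import java.util.concurrent.TimeUnit;

    class TwoTierCacheSketch<K, V> {
      private final Cache<K, V> l1 =
          CacheBuilder.newBuilder().expireAfterAccess(30, TimeUnit.SECONDS).build();
      private final Cache<K, V> l2 =
          CacheBuilder.newBuilder().softValues().build();

      V get(K key, Callable<? extends V> loader) throws ExecutionException {
        V value = l1.getIfPresent(key);
        if (value == null) {
          value = l2.get(key, loader); // load, or revive a still soft-reachable value
          l1.put(key, value);          // re-pin it with a strong reference
        }
        return value;
      }
    }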
diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/spark/TestSmallTableCache.java ql/src/test/org/apache/hadoop/hive/ql/exec/spark/TestSmallTableCache.java
new file mode 100644
index 0000000000..ade989ece5
--- /dev/null
+++ ql/src/test/org/apache/hadoop/hive/ql/exec/spark/TestSmallTableCache.java
@@ -0,0 +1,151 @@
+package org.apache.hadoop.hive.ql.exec.spark;
+
+import com.google.common.testing.FakeTicker;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.junit.Assert.assertEquals;
+
+public class TestSmallTableCache {
+ private static final String KEY = "Test";
+ private static final String TEST_VALUE_1 = "TestValue1";
+ private static final String TEST_VALUE_2 = "TestValue2";
+  private SmallTableCache.SmallTableLocalCache<String, String> cache;
+ private AtomicInteger counter;
+
+ @Before
+  public void setUp() {
+    this.cache = new SmallTableCache.SmallTableLocalCache<>();
+    this.counter = new AtomicInteger(0);
+  }
+
+ @Test
+ public void testEmptyCache() throws ExecutionException {
+
+    // new String(...) creates a fresh instance, so the cache (not the interned
+    // string constant) holds the only long-lived strong reference to the value.
+    String res = cache.get(KEY, () -> {
+ counter.incrementAndGet();
+ return new String(TEST_VALUE_1);
+ });
+
+ assertEquals(TEST_VALUE_1, res);
+ assertEquals(1, counter.get());
+ assertEquals(1, cache.size());
+ }
+
+ @Test
+ public void testL1Hit() throws ExecutionException {
+ cache.get(KEY, () -> {
+ counter.incrementAndGet();
+ return new String(TEST_VALUE_1);
+ });
+
+ String res = cache.get(KEY, () -> {
+ counter.incrementAndGet();
+ return new String(TEST_VALUE_2);
+ });
+
+ assertEquals(TEST_VALUE_1, res);
+ assertEquals(1, counter.get());
+ assertEquals(1, cache.size());
+ }
+
+ @Test
+ public void testL2Hit() throws ExecutionException {
+
+ FakeTicker ticker = new FakeTicker();
+ cache = new SmallTableCache.SmallTableLocalCache<>(ticker);
+
+ cache.get(KEY, () -> {
+ counter.incrementAndGet();
+ return new String(TEST_VALUE_1);
+ });
+
+ ticker.advance(60, TimeUnit.SECONDS);
+
+ String res = cache.get(KEY, () -> {
+ counter.incrementAndGet();
+ return new String(TEST_VALUE_2);
+ });
+
+ assertEquals(TEST_VALUE_1, res);
+ assertEquals(1, counter.get());
+ assertEquals(1, cache.size());
+ }
+
+ @Test
+ public void testL2Miss() throws ExecutionException {
+
+ FakeTicker ticker = new FakeTicker();
+ cache = new SmallTableCache.SmallTableLocalCache<>(ticker);
+
+ cache.get(KEY, () -> {
+ counter.incrementAndGet();
+ return new String(TEST_VALUE_1);
+ });
+
+ ticker.advance(60, TimeUnit.SECONDS);
+ cache.cleanup();
+ forceOOMToClearSoftValues();
+
+ String res = cache.get(KEY, () -> {
+ counter.incrementAndGet();
+ return new String(TEST_VALUE_2);
+ });
+
+ assertEquals(TEST_VALUE_2, res);
+ assertEquals(2, counter.get());
+ assertEquals(1, cache.size());
+ }
+
+ @Test
+ public void testL2IsNotClearedIfTheItemIsInL1() throws ExecutionException {
+
+ FakeTicker ticker = new FakeTicker();
+ cache = new SmallTableCache.SmallTableLocalCache<>(ticker);
+
+ cache.get(KEY, () -> {
+ counter.incrementAndGet();
+ return new String(TEST_VALUE_1);
+ });
+
+ forceOOMToClearSoftValues();
+ ticker.advance(60, TimeUnit.SECONDS);
+ cache.cleanup();
+
+ String res = cache.get(KEY, () -> {
+ counter.incrementAndGet();
+ return new String(TEST_VALUE_2);
+ });
+
+ assertEquals(TEST_VALUE_1, res);
+ assertEquals(1, counter.get());
+ assertEquals(1, cache.size());
+ }
+
+ @Test
+ public void testClear() throws ExecutionException {
+ cache.get(KEY, () -> {
+ counter.incrementAndGet();
+ return new String(TEST_VALUE_1);
+ });
+ cache.clear((k, v) -> {
+ });
+
+ assertEquals(1, counter.get());
+ assertEquals(0, cache.size());
+ }
+
+  private void forceOOMToClearSoftValues() {
+    try {
+      // Allocate until the JVM throws OutOfMemoryError. The JVM guarantees that
+      // all softly reachable objects are cleared before the error is thrown,
+      // which empties the soft-valued L2 cache.
+      while (true) {
+        Object[] ignored = new Object[Integer.MAX_VALUE / 2];
+      }
+    } catch (OutOfMemoryError e) {
+      // expected: the memory pressure has cleared the soft references
+    }
+  }
+}
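
The OOM-forcing helper at the end of the test leans on a JVM guarantee: all
softly reachable objects are cleared before an OutOfMemoryError is thrown, so
provoking and catching an OOM reliably empties a softValues() cache. A small
sketch of that guarantee in isolation (illustrative only):

    import java.lang.ref.SoftReference;

    public class SoftRefSketch {
      public static void main(String[] args) {
        SoftReference<byte[]> ref = new SoftReference<>(new byte[1024]);
        try {
          while (true) {
            byte[] pressure = new byte[Integer.MAX_VALUE / 2]; // allocate until OOM
          }
        } catch (OutOfMemoryError expected) {
          // soft references were cleared before the error was thrown
        }
        System.out.println(ref.get()); // prints null
      }
    }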