diff --git ql/pom.xml ql/pom.xml index d73deba440..2d3a0c5468 100644 --- ql/pom.xml +++ ql/pom.xml @@ -770,6 +770,12 @@ ${powermock.version} test + + com.google.guava + guava-testlib + ${guava.version} + test + diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HashTableLoader.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HashTableLoader.java index cf27e92baf..e60dbaef8e 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HashTableLoader.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HashTableLoader.java @@ -21,6 +21,7 @@ import java.util.Arrays; import java.util.List; import java.util.Set; +import java.util.concurrent.ExecutionException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -159,22 +160,19 @@ private MapJoinTableContainer load(FileSystem fs, Path path, MapJoinTableContainerSerDe mapJoinTableSerde) throws HiveException { LOG.info("\tLoad back all hashtable files from tmp folder uri:" + path); if (!SparkUtilities.isDedicatedCluster(hconf)) { - return useFastContainer ? mapJoinTableSerde.loadFastContainer(desc, fs, path, hconf) : - mapJoinTableSerde.load(fs, path, hconf); + return loadMapJoinTableContainer(fs, path, mapJoinTableSerde); } - MapJoinTableContainer mapJoinTable = SmallTableCache.get(path); - if (mapJoinTable == null) { - synchronized (path.toString().intern()) { - mapJoinTable = SmallTableCache.get(path); - if (mapJoinTable == null) { - mapJoinTable = useFastContainer ? 
- mapJoinTableSerde.loadFastContainer(desc, fs, path, hconf) : - mapJoinTableSerde.load(fs, path, hconf); - SmallTableCache.cache(path, mapJoinTable); - } - } + + try { + return SmallTableCache.get(path, () -> loadMapJoinTableContainer(fs, path, mapJoinTableSerde)); + } catch (ExecutionException e) { + throw new HiveException(e); } - return mapJoinTable; + } + + private MapJoinTableContainer loadMapJoinTableContainer(FileSystem fs, Path path, MapJoinTableContainerSerDe mapJoinTableSerde) throws HiveException { + return useFastContainer ? mapJoinTableSerde.loadFastContainer(desc, fs, path, hconf) : + mapJoinTableSerde.load(fs, path, hconf); } private void loadDirectly(MapJoinTableContainer[] mapJoinTables, String inputFileName) diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SmallTableCache.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SmallTableCache.java index 3293100af9..a1b8952b7a 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SmallTableCache.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SmallTableCache.java @@ -17,20 +17,32 @@ */ package org.apache.hadoop.hive.ql.exec.spark; -import java.util.concurrent.ConcurrentHashMap; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Ticker; +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; +import com.google.common.util.concurrent.ThreadFactoryBuilder; +import org.apache.commons.lang3.math.NumberUtils; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.hadoop.hive.conf.HiveConf; import org.apache.hadoop.hive.ql.exec.persistence.MapJoinTableContainer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import 
java.util.concurrent.TimeUnit; +import java.util.function.BiConsumer; public class SmallTableCache { private static final Logger LOG = LoggerFactory.getLogger(SmallTableCache.class.getName()); - private static final ConcurrentHashMap - tableContainerMap = new ConcurrentHashMap(); + private static final SmallTableLocalCache + tableContainerCache = new SmallTableLocalCache(); + private static volatile String queryId; /** @@ -40,13 +52,10 @@ public static void initialize(Configuration conf) { String currentQueryId = conf.get(HiveConf.ConfVars.HIVEQUERYID.varname); if (!currentQueryId.equals(queryId)) { - if (!tableContainerMap.isEmpty()) { - synchronized (tableContainerMap) { - if (!currentQueryId.equals(queryId) && !tableContainerMap.isEmpty()) { - for (MapJoinTableContainer tableContainer: tableContainerMap.values()) { - tableContainer.clear(); - } - tableContainerMap.clear(); + if (tableContainerCache.size() != 0) { + synchronized (tableContainerCache) { + if (!currentQueryId.equals(queryId) && tableContainerCache.size() != 0) { + tableContainerCache.clear((path, tableContainer) -> tableContainer.clear()); if (LOG.isDebugEnabled()) { LOG.debug("Cleaned up small table cache for query " + queryId); } @@ -57,17 +66,59 @@ public static void initialize(Configuration conf) { } } - public static void cache(Path path, MapJoinTableContainer tableContainer) { - if (tableContainerMap.putIfAbsent(path, tableContainer) == null && LOG.isDebugEnabled()) { - LOG.debug("Cached small table file " + path + " for query " + queryId); - } + public static MapJoinTableContainer get(Path path, Callable valueLoader) + throws ExecutionException { + return tableContainerCache.get(path, valueLoader); } - public static MapJoinTableContainer get(Path path) { - MapJoinTableContainer tableContainer = tableContainerMap.get(path); - if (tableContainer != null && LOG.isDebugEnabled()) { - LOG.debug("Loaded small table file " + path + " from cache for query " + queryId); + @VisibleForTesting + static 
class SmallTableLocalCache { + + private static final int MAINTENANCE_THREAD_CLEANUP_PERIOD = 10; + private static final int L1_CACHE_EXPIRE_DURATION = 30; + + private final Cache L1; + private final Cache L2; + private final ScheduledExecutorService scheduler; + + public SmallTableLocalCache() { + this(Ticker.systemTicker()); + } + + @VisibleForTesting + SmallTableLocalCache(Ticker ticker) { + scheduler = Executors.newScheduledThreadPool(NumberUtils.INTEGER_ONE, new ThreadFactoryBuilder().setNameFormat("SmallTableCache maintenance thread").setDaemon(true).build()); + L1 = CacheBuilder.newBuilder().expireAfterAccess(L1_CACHE_EXPIRE_DURATION, TimeUnit.SECONDS).ticker(ticker).build(); + L2 = CacheBuilder.newBuilder().softValues().build(); + scheduler.scheduleAtFixedRate(() -> { + cleanup(); + }, NumberUtils.INTEGER_ZERO, MAINTENANCE_THREAD_CLEANUP_PERIOD, TimeUnit.SECONDS); + } + + // L2 >= L1, because if a cached item is in L1 then its in L2 as well. + public long size(){ + return L2.size(); + } + + public void clear(BiConsumer action) { + L1.invalidateAll(); + L2.asMap().forEach(action); + L2.invalidateAll(); + } + + @VisibleForTesting + void cleanup() { + L1.cleanUp(); + } + + public V get(K key, Callable valueLoader) throws ExecutionException { + V value = L1.getIfPresent(key); + if (value == null) { + value = L2.get(key, valueLoader); + L1.put(key, value); + } + return value; } - return tableContainer; } + } diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/spark/TestSmallTableCache.java ql/src/test/org/apache/hadoop/hive/ql/exec/spark/TestSmallTableCache.java new file mode 100644 index 0000000000..ade989ece5 --- /dev/null +++ ql/src/test/org/apache/hadoop/hive/ql/exec/spark/TestSmallTableCache.java @@ -0,0 +1,151 @@ +package org.apache.hadoop.hive.ql.exec.spark; + +import com.google.common.testing.FakeTicker; +import org.junit.Before; +import org.junit.Test; + +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import 
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/**
 * Unit tests for the two-level (L1 strong / L2 soft-reference) cache inside
 * {@link SmallTableCache}.
 *
 * Values are created with {@code new String(...)} on purpose: a fresh,
 * non-interned instance can actually be reclaimed when soft references are
 * cleared, whereas a string literal is interned and would never be collected.
 */
public class TestSmallTableCache {
  private static final String KEY = "Test";
  private static final String TEST_VALUE_1 = "TestValue1";
  private static final String TEST_VALUE_2 = "TestValue2";
  // Cache under test; recreated per test (some tests replace it with a
  // FakeTicker-driven instance to control L1 expiry).
  private SmallTableCache.SmallTableLocalCache<String, String> cache;
  // Counts how many times a value loader actually ran (i.e. cache misses).
  private AtomicInteger counter;

  @Before
  public void setUp() {
    this.cache = new SmallTableCache.SmallTableLocalCache<>();
    this.counter = new AtomicInteger(0);

  }

  @Test
  public void testEmptyCache() throws ExecutionException {

    String res = cache.get(KEY, () -> {
      counter.incrementAndGet();
      return new String(TEST_VALUE_1);
    });

    // First access: loader runs once and the value is cached.
    assertEquals(TEST_VALUE_1, res);
    assertEquals(1, counter.get());
    assertEquals(1, cache.size());
  }

  @Test
  public void testL1Hit() throws ExecutionException {
    cache.get(KEY, () -> {
      counter.incrementAndGet();
      return new String(TEST_VALUE_1);
    });

    String res = cache.get(KEY, () -> {
      counter.incrementAndGet();
      return new String(TEST_VALUE_2);
    });

    // Second loader must not run: the value is served from L1.
    assertEquals(TEST_VALUE_1, res);
    assertEquals(1, counter.get());
    assertEquals(1, cache.size());
  }

  @Test
  public void testL2Hit() throws ExecutionException {

    FakeTicker ticker = new FakeTicker();
    cache = new SmallTableCache.SmallTableLocalCache<>(ticker);

    cache.get(KEY, () -> {
      counter.incrementAndGet();
      return new String(TEST_VALUE_1);
    });

    // Advance past the L1 expire-after-access window; no GC pressure, so the
    // soft-valued L2 entry is still present and the loader must not rerun.
    ticker.advance(60, TimeUnit.SECONDS);

    String res = cache.get(KEY, () -> {
      counter.incrementAndGet();
      return new String(TEST_VALUE_2);
    });

    assertEquals(TEST_VALUE_1, res);
    assertEquals(1, counter.get());
    assertEquals(1, cache.size());
  }

  @Test
  public void testL2Miss() throws ExecutionException {

    FakeTicker ticker = new FakeTicker();
    cache = new SmallTableCache.SmallTableLocalCache<>(ticker);

    cache.get(KEY, () -> {
      counter.incrementAndGet();
      return new String(TEST_VALUE_1);
    });

    // Expire the entry out of L1, then clear soft references via an OOM so
    // L2 drops it as well: the next get must reload.
    ticker.advance(60, TimeUnit.SECONDS);
    cache.cleanup();
    forceOOMToClearSoftValues();

    String res = cache.get(KEY, () -> {
      counter.incrementAndGet();
      return new String(TEST_VALUE_2);
    });

    assertEquals(TEST_VALUE_2, res);
    assertEquals(2, counter.get());
    assertEquals(1, cache.size());
  }

  @Test
  public void testL2IsNotClearedIfTheItemIsInL1() throws ExecutionException {

    FakeTicker ticker = new FakeTicker();
    cache = new SmallTableCache.SmallTableLocalCache<>(ticker);

    cache.get(KEY, () -> {
      counter.incrementAndGet();
      return new String(TEST_VALUE_1);
    });

    // While the value is strongly referenced from L1, the soft reference in
    // L2 cannot be cleared even under memory pressure.
    forceOOMToClearSoftValues();
    ticker.advance(60, TimeUnit.SECONDS);
    cache.cleanup();

    String res = cache.get(KEY, () -> {
      counter.incrementAndGet();
      return new String(TEST_VALUE_2);
    });

    assertEquals(TEST_VALUE_1, res);
    assertEquals(1, counter.get());
    assertEquals(1, cache.size());
  }

  @Test
  public void testClear() throws ExecutionException {
    cache.get(KEY, () -> {
      counter.incrementAndGet();
      return new String(TEST_VALUE_1);
    });
    cache.clear((k, v) -> {
    });

    assertEquals(1, counter.get());
    assertEquals(0, cache.size());
  }

  /**
   * Forces an OutOfMemoryError by requesting an oversized allocation; the JVM
   * guarantees all softly reachable objects are cleared before OOME is thrown.
   */
  private void forceOOMToClearSoftValues() {
    try {
      while (true) {
        Object[] ignored = new Object[Integer.MAX_VALUE / 2];
      }
    } catch (OutOfMemoryError e) {
      // expected: soft references have been cleared, recovery is immediate
    }
  }
}