diff --git ql/pom.xml ql/pom.xml
index 8c3e55eaf4..7d8f9398c6 100644
--- ql/pom.xml
+++ ql/pom.xml
@@ -770,6 +770,12 @@
       <version>${powermock.version}</version>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>com.google.guava</groupId>
+      <artifactId>guava-testlib</artifactId>
+      <version>${guava.version}</version>
+      <scope>test</scope>
+    </dependency>
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/MapJoinOperator.java ql/src/java/org/apache/hadoop/hive/ql/exec/MapJoinOperator.java
index da1dd426c9..b0b4ec8610 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/MapJoinOperator.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/MapJoinOperator.java
@@ -54,6 +54,7 @@
 import org.apache.hadoop.hive.ql.exec.persistence.MatchTracker;
 import org.apache.hadoop.hive.ql.exec.persistence.ObjectContainer;
 import org.apache.hadoop.hive.ql.exec.persistence.UnwrapRowContainer;
+import org.apache.hadoop.hive.ql.exec.spark.SmallTableCache;
 import org.apache.hadoop.hive.ql.exec.spark.SparkUtilities;
 import org.apache.hadoop.hive.ql.exec.tez.LlapObjectCache;
 import org.apache.hadoop.hive.ql.exec.tez.LlapObjectSubCache;
@@ -738,6 +739,21 @@ protected void generateFullOuterSmallTableNoMatches(byte smallTablePos,
 
   @Override
   public void closeOp(boolean abort) throws HiveException {
+    // Call the small table cache's cache method, so that when a task finishes we still keep the
+    // small table around for at least 30 seconds, giving any tasks scheduled in the future a chance to re-use it.
+    if (HiveConf.getVar(hconf, ConfVars.HIVE_EXECUTION_ENGINE).equals("spark") &&
+        SparkUtilities.isDedicatedCluster(hconf)) {
+
+      for (byte pos = 0; pos < mapJoinTables.length; pos++) {
+        if (pos != conf.getPosBigTable()) {
+          MapJoinTableContainer container = mapJoinTables[pos];
+          if (container != null && container.getKey() != null) {
+            SmallTableCache.cache(container.getKey(), container);
+          }
+        }
+      }
+    }
+
     if (isFullOuterMapJoin) {
       // FULL OUTER MapJoin: After matching the Big Table row keys against the Small Table, we now
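Reviewer note (not part of the patch): the 30 seconds mentioned in the comment above comes from the L1 `expireAfterAccess` timeout introduced in SmallTableCache further down. Because Guava's `expireAfterAccess` restarts an entry's clock on every read and write, re-caching the container in `closeOp()` buys it another full window. A minimal, self-contained sketch of that behavior; `ManualTicker` is a hypothetical helper, not Hive or Guava API:

```java
import com.google.common.base.Ticker;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;

import java.util.concurrent.TimeUnit;

// Standalone sketch (not Hive code): with expireAfterAccess, every put() or
// getIfPresent() restarts the entry's 30-second expiry clock, so re-caching a
// container in closeOp() extends its life by another full window.
public class ExpireAfterAccessDemo {

  // Hypothetical manual ticker so the demo can advance time without sleeping.
  static class ManualTicker extends Ticker {
    private long nanos;

    @Override
    public long read() {
      return nanos;
    }

    void advanceSeconds(long seconds) {
      nanos += TimeUnit.SECONDS.toNanos(seconds);
    }
  }

  public static void main(String[] args) {
    ManualTicker ticker = new ManualTicker();
    Cache<String, String> cache = CacheBuilder.newBuilder()
        .expireAfterAccess(30, TimeUnit.SECONDS)
        .ticker(ticker)
        .build();

    cache.put("smallTable", "container");                 // first cache() call
    ticker.advanceSeconds(20);
    cache.put("smallTable", "container");                 // closeOp() re-caches: clock restarts
    ticker.advanceSeconds(20);                            // 40s overall, but only 20s since last access
    System.out.println(cache.getIfPresent("smallTable")); // "container" -> still alive

    ticker.advanceSeconds(31);                            // more than 30s with no access
    System.out.println(cache.getIfPresent("smallTable")); // null -> expired
  }
}
```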
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/persistence/AbstractMapJoinTableContainer.java ql/src/java/org/apache/hadoop/hive/ql/exec/persistence/AbstractMapJoinTableContainer.java
index 9e65fd98d6..ee12162e4f 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/persistence/AbstractMapJoinTableContainer.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/persistence/AbstractMapJoinTableContainer.java
@@ -28,6 +28,8 @@
   protected static final String THESHOLD_NAME = "threshold";
   protected static final String LOAD_NAME = "load";
 
+  private String key;
+
   /** Creates metadata for implementation classes' ctors from threshold and load factor. */
   protected static Map<String, String> createConstructorMetaData(int threshold, float loadFactor) {
     Map<String, String> metaData = new HashMap<String, String>();
@@ -48,4 +50,14 @@ protected AbstractMapJoinTableContainer(Map<String, String> metaData) {
   protected void putMetaData(String key, String value) {
     metaData.put(key, value);
   }
+
+  @Override
+  public void setKey(String key) {
+    this.key = key;
+  }
+
+  @Override
+  public String getKey() {
+    return key;
+  }
 }
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/persistence/HybridHashTableContainer.java ql/src/java/org/apache/hadoop/hive/ql/exec/persistence/HybridHashTableContainer.java
index 54377428ea..545a729652 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/persistence/HybridHashTableContainer.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/persistence/HybridHashTableContainer.java
@@ -121,6 +121,7 @@
   private final List<Object> EMPTY_LIST = new ArrayList<Object>(0);
 
   private final String spillLocalDirs;
+  private String key;
 
   @Override
   public long getEstimatedMemorySize() {
@@ -1296,4 +1297,14 @@ public void setSerde(MapJoinObjectSerDeContext keyCtx, MapJoinObjectSerDeContext valCtx)
       }
     }
   }
+
+  @Override
+  public void setKey(String key) {
+    this.key = key;
+  }
+
+  @Override
+  public String getKey() {
+    return key;
+  }
 }
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/persistence/MapJoinBytesTableContainer.java ql/src/java/org/apache/hadoop/hive/ql/exec/persistence/MapJoinBytesTableContainer.java
index 0e4b8df036..4081cc3087 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/persistence/MapJoinBytesTableContainer.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/persistence/MapJoinBytesTableContainer.java
@@ -100,6 +100,7 @@
   private DirectKeyValueWriter directWriteHelper;
 
   private final List<Object> EMPTY_LIST = new ArrayList<Object>(0);
+  private String key;
 
   public MapJoinBytesTableContainer(Configuration hconf,
       MapJoinObjectSerDeContext valCtx, long keyCount, long memUsage) throws SerDeException {
@@ -443,6 +444,16 @@ public void setSerde(MapJoinObjectSerDeContext keyContext, MapJoinObjectSerDeContext valContext)
     }
   }
 
+  @Override
+  public void setKey(String key) {
+    this.key = key;
+  }
+
+  @Override
+  public String getKey() {
+    return key;
+  }
+
   @SuppressWarnings("deprecation")
   @Override
   public MapJoinKey putRow(Writable currentKey, Writable currentValue) throws SerDeException {
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/persistence/MapJoinTableContainer.java ql/src/java/org/apache/hadoop/hive/ql/exec/persistence/MapJoinTableContainer.java
index 74e0b120ea..595271a2f6 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/persistence/MapJoinTableContainer.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/persistence/MapJoinTableContainer.java
@@ -173,4 +173,14 @@ MapJoinKey putRow(Writable currentKey, Writable currentValue)
 
   void setSerde(MapJoinObjectSerDeContext keyCtx, MapJoinObjectSerDeContext valCtx)
       throws SerDeException;
+
+  /**
+   * Assign a key to the container, which can be used to cache it.
+   */
+  void setKey(String key);
+
+  /**
+   * Return the assigned key.
+   */
+  String getKey();
 }
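Reviewer note: all four container implementations now carry the same key contract. A tiny illustrative stub (hypothetical class and path, not Hive code) of how the key is meant to round-trip: the loader records the folder the small table was read from, and `closeOp()` later re-caches any container that has a key.

```java
// Hypothetical stub showing the intended round-trip of the container key.
public class StubContainerDemo {

  // Stand-in for a MapJoinTableContainer implementation.
  static class StubTableContainer {
    private String key;

    public void setKey(String key) {
      this.key = key;
    }

    public String getKey() {
      return key;
    }
  }

  public static void main(String[] args) {
    StubTableContainer container = new StubTableContainer();
    // MapJoinTableContainerSerDe.load()/loadFastContainer() do this with folder.toString():
    container.setKey("/tmp/hive/HashTable-Stage-3/MapJoin-mapfile01");
    // MapJoinOperator.closeOp() later re-caches any container that has a key:
    if (container.getKey() != null) {
      System.out.println("re-cache under key " + container.getKey());
    }
  }
}
```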
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/persistence/MapJoinTableContainerSerDe.java ql/src/java/org/apache/hadoop/hive/ql/exec/persistence/MapJoinTableContainerSerDe.java
index 24b8fea338..5fff1e3508 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/persistence/MapJoinTableContainerSerDe.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/persistence/MapJoinTableContainerSerDe.java
@@ -164,6 +164,7 @@ public MapJoinTableContainer load(
       }
     }
     if (tableContainer != null) {
+      tableContainer.setKey(folder.toString());
       tableContainer.seal();
     }
     return tableContainer;
@@ -261,8 +262,8 @@ public MapJoinTableContainer loadFastContainer(MapJoinDesc mapJoinDesc,
          }
        }
      }
+     tableContainer.setKey(folder.toString());
    }
-   tableContainer.seal();
    return tableContainer;
  } catch (IOException e) {
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HashTableLoader.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HashTableLoader.java
index cf27e92baf..4e6a8eed65 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HashTableLoader.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HashTableLoader.java
@@ -21,6 +21,7 @@
 import java.util.Arrays;
 import java.util.List;
 import java.util.Set;
+import java.util.concurrent.ExecutionException;
 
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -159,22 +160,20 @@
   private MapJoinTableContainer load(FileSystem fs, Path path,
       MapJoinTableContainerSerDe mapJoinTableSerde) throws HiveException {
     LOG.info("\tLoad back all hashtable files from tmp folder uri:" + path);
     if (!SparkUtilities.isDedicatedCluster(hconf)) {
-      return useFastContainer ? mapJoinTableSerde.loadFastContainer(desc, fs, path, hconf) :
-          mapJoinTableSerde.load(fs, path, hconf);
+      return loadMapJoinTableContainer(fs, path, mapJoinTableSerde);
     }
-    MapJoinTableContainer mapJoinTable = SmallTableCache.get(path);
-    if (mapJoinTable == null) {
-      synchronized (path.toString().intern()) {
-        mapJoinTable = SmallTableCache.get(path);
-        if (mapJoinTable == null) {
-          mapJoinTable = useFastContainer ?
-              mapJoinTableSerde.loadFastContainer(desc, fs, path, hconf) :
-              mapJoinTableSerde.load(fs, path, hconf);
-          SmallTableCache.cache(path, mapJoinTable);
-        }
-      }
+
+    try {
+      return SmallTableCache.get(path.toString(), () -> loadMapJoinTableContainer(fs, path, mapJoinTableSerde));
+    } catch (ExecutionException e) {
+      throw new HiveException(e);
     }
-    return mapJoinTable;
+  }
+
+  private MapJoinTableContainer loadMapJoinTableContainer(
+      FileSystem fs, Path path, MapJoinTableContainerSerDe mapJoinTableSerde) throws HiveException {
+    return useFastContainer ?
+        mapJoinTableSerde.loadFastContainer(desc, fs, path, hconf) :
+        mapJoinTableSerde.load(fs, path, hconf);
   }
 
   private void loadDirectly(MapJoinTableContainer[] mapJoinTables, String inputFileName)
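Reviewer note: the old code synchronized on `path.toString().intern()` and double-checked the map; Guava's `Cache.get(key, valueLoader)` already provides per-key "one thread loads, the rest wait" semantics, which is what makes the refactor above safe. A small sketch of that property, with plain strings standing in for containers:

```java
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;

import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicInteger;

// Sketch of the guarantee the refactor relies on: for a given key, Guava runs
// the Callable in exactly one thread; concurrent callers block and receive the
// same loaded value.
public class LoadOnceDemo {
  public static void main(String[] args) throws InterruptedException {
    Cache<String, String> cache = CacheBuilder.newBuilder().build();
    AtomicInteger loads = new AtomicInteger();

    Runnable task = () -> {
      try {
        cache.get("/tmp/small-table", () -> {
          loads.incrementAndGet(); // simulated expensive load from HDFS
          return "container";
        });
      } catch (ExecutionException e) {
        throw new RuntimeException(e);
      }
    };

    Thread t1 = new Thread(task);
    Thread t2 = new Thread(task);
    t1.start();
    t2.start();
    t1.join();
    t2.join();

    System.out.println(loads.get()); // 1: the loader ran exactly once
  }
}
```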
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SmallTableCache.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SmallTableCache.java
index 3293100af9..ccffef04c3 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SmallTableCache.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SmallTableCache.java
@@ -17,20 +17,33 @@
  */
 package org.apache.hadoop.hive.ql.exec.spark;
 
-import java.util.concurrent.ConcurrentHashMap;
-
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Ticker;
+import com.google.common.cache.Cache;
+import com.google.common.cache.CacheBuilder;
+import com.google.common.util.concurrent.ThreadFactoryBuilder;
 
 import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.hive.conf.HiveConf;
 import org.apache.hadoop.hive.ql.exec.persistence.MapJoinTableContainer;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+import java.util.function.BiConsumer;
+
+import static org.apache.commons.lang3.math.NumberUtils.INTEGER_ONE;
+import static org.apache.commons.lang3.math.NumberUtils.INTEGER_ZERO;
 
 public class SmallTableCache {
   private static final Logger LOG = LoggerFactory.getLogger(SmallTableCache.class.getName());
 
-  private static final ConcurrentHashMap<Path, MapJoinTableContainer>
-    tableContainerMap = new ConcurrentHashMap<Path, MapJoinTableContainer>();
+  private static final SmallTableLocalCache<String, MapJoinTableContainer>
+    TABLE_CONTAINER_CACHE = new SmallTableLocalCache<>();
+
   private static volatile String queryId;
 
   /**
@@ -40,13 +53,10 @@
   public static void initialize(Configuration conf) {
     String currentQueryId = conf.get(HiveConf.ConfVars.HIVEQUERYID.varname);
     if (!currentQueryId.equals(queryId)) {
-      if (!tableContainerMap.isEmpty()) {
-        synchronized (tableContainerMap) {
-          if (!currentQueryId.equals(queryId) && !tableContainerMap.isEmpty()) {
-            for (MapJoinTableContainer tableContainer: tableContainerMap.values()) {
-              tableContainer.clear();
-            }
-            tableContainerMap.clear();
+      if (TABLE_CONTAINER_CACHE.size() != 0) {
+        synchronized (TABLE_CONTAINER_CACHE) {
+          if (!currentQueryId.equals(queryId) && TABLE_CONTAINER_CACHE.size() != 0) {
+            TABLE_CONTAINER_CACHE.clear((path, tableContainer) -> tableContainer.clear());
             if (LOG.isDebugEnabled()) {
               LOG.debug("Cleaned up small table cache for query " + queryId);
             }
@@ -57,17 +67,84 @@
     }
   }
 
-  public static void cache(Path path, MapJoinTableContainer tableContainer) {
-    if (tableContainerMap.putIfAbsent(path, tableContainer) == null && LOG.isDebugEnabled()) {
-      LOG.debug("Cached small table file " + path + " for query " + queryId);
-    }
+  public static void cache(String key, MapJoinTableContainer tableContainer) {
+    TABLE_CONTAINER_CACHE.put(key, tableContainer);
+  }
+
+  public static MapJoinTableContainer get(String key, Callable<MapJoinTableContainer> valueLoader)
+      throws ExecutionException {
+    return TABLE_CONTAINER_CACHE.get(key, valueLoader);
   }
 
-  public static MapJoinTableContainer get(Path path) {
-    MapJoinTableContainer tableContainer = tableContainerMap.get(path);
-    if (tableContainer != null && LOG.isDebugEnabled()) {
-      LOG.debug("Loaded small table file " + path + " from cache for query " + queryId);
+  /**
+   * Two-level cache implementation. The level 1 cache keeps the cached values for 30 seconds,
+   * while the level 2 cache keeps the values (using soft references) until the GC decides that it
+   * needs the memory occupied by the cache.
+   *
+   * @param <K> the type of the key
+   * @param <V> the type of the value
+   */
+  @VisibleForTesting
+  static class SmallTableLocalCache<K, V> {
+
+    private static final int MAINTENANCE_THREAD_CLEANUP_PERIOD = 10; // seconds
+    private static final int L1_CACHE_EXPIRE_DURATION = 30; // seconds
+
+    private final Cache<K, V> cacheL1;
+    private final Cache<K, V> cacheL2;
+    private final ScheduledExecutorService cleanupService;
+
+    SmallTableLocalCache() {
+      this(Ticker.systemTicker());
+    }
+
+    @VisibleForTesting
+    SmallTableLocalCache(Ticker ticker) {
+      cleanupService = Executors.newScheduledThreadPool(INTEGER_ONE,
+          new ThreadFactoryBuilder().setNameFormat("SmallTableCache Cleanup Thread").setDaemon(true).build());
+      cacheL1 = CacheBuilder.newBuilder().expireAfterAccess(L1_CACHE_EXPIRE_DURATION, TimeUnit.SECONDS)
+          .ticker(ticker).build();
+      cacheL2 = CacheBuilder.newBuilder().softValues().build();
+      cleanupService.scheduleAtFixedRate(() -> {
+        cleanup();
+      }, INTEGER_ZERO, MAINTENANCE_THREAD_CLEANUP_PERIOD, TimeUnit.SECONDS);
+    }
+
+    /**
+     * Return the number of cached elements.
+     */
+    // L2 >= L1, because if a cached item is in L1 then it's in L2 as well.
+    public long size() {
+      return cacheL2.size();
+    }
+
+    /**
+     * Invalidate the cache, and call the action on the elements if additional cleanup is required.
+     */
+    public void clear(BiConsumer<K, V> action) {
+      cacheL1.invalidateAll();
+      cacheL2.asMap().forEach(action);
+      cacheL2.invalidateAll();
+    }
+
+    @VisibleForTesting
+    void cleanup() {
+      cacheL1.cleanUp();
+    }
+
+    /**
+     * Put an item into the cache. If the item was already there, it will be overwritten.
+     */
+    public void put(K key, V value) {
+      cacheL2.put(key, value);
+      cacheL1.put(key, value);
+    }
+
+    /**
+     * Retrieve an item from the cache; if it's not there, use the valueLoader to load it and cache it.
+     */
+    public V get(K key, Callable<V> valueLoader) throws ExecutionException {
+      return cacheL1.get(key, () -> cacheL2.get(key, valueLoader));
     }
-    return tableContainer;
   }
 }
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/mapjoin/fast/VectorMapJoinFastTableContainer.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/mapjoin/fast/VectorMapJoinFastTableContainer.java
index e8dcbf18cb..4ab8902a3f 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/mapjoin/fast/VectorMapJoinFastTableContainer.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/mapjoin/fast/VectorMapJoinFastTableContainer.java
@@ -61,6 +61,7 @@
   private final VectorMapJoinFastHashTable vectorMapJoinFastHashTable;
+  private String key;
 
   public VectorMapJoinFastTableContainer(MapJoinDesc desc, Configuration hconf,
       long estimatedKeyCount) throws SerDeException {
@@ -88,6 +89,16 @@ public VectorMapJoinHashTable vectorMapJoinHashTable() {
     return vectorMapJoinFastHashTable;
   }
 
+  @Override
+  public void setKey(String key) {
+    this.key = key;
+  }
+
+  @Override
+  public String getKey() {
+    return key;
+  }
+
   private VectorMapJoinFastHashTable createHashTable(int newThreshold) {
     VectorMapJoinDesc vectorDesc = (VectorMapJoinDesc) desc.getVectorDesc();
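Reviewer note: condensed sketch of the two-level wiring above. On an L1 miss, `cacheL1.get()` runs its fallback, which consults the soft-valued L2; whatever comes back (an L2 hit or a freshly loaded value) is re-inserted into L1, promoting warm entries back to strong references for another 30 seconds. `put()` writes L2 before L1, which is what keeps L2 a superset of L1 and lets `size()` simply report `cacheL2.size()`:

```java
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;

import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;

// Minimal sketch of the same two-level lookup, without the cleanup thread.
public class TwoLevelCacheSketch<K, V> {
  private final Cache<K, V> l1 = CacheBuilder.newBuilder()
      .expireAfterAccess(30, TimeUnit.SECONDS).build();   // strong refs, time-bounded
  private final Cache<K, V> l2 = CacheBuilder.newBuilder()
      .softValues().build();                              // GC-bounded safety net

  public V get(K key, Callable<V> valueLoader) throws ExecutionException {
    // An L1 miss falls through to L2; an L2 miss falls through to the loader.
    // Either way, the result is stored back into L1 (and, on a full miss, into L2).
    return l1.get(key, () -> l2.get(key, valueLoader));
  }

  public void put(K key, V value) {
    l2.put(key, value); // L2 first, so L2 is always a superset of L1
    l1.put(key, value);
  }
}
```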
Reviewer note: this two-level split is also why guava-testlib is added to ql/pom.xml: its `FakeTicker` lets the unit test below advance the L1 cache's clock deterministically instead of sleeping. A minimal sketch of the pattern the test relies on:

```java
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.testing.FakeTicker;

import java.util.concurrent.TimeUnit;

// Sketch: FakeTicker (from guava-testlib) moves the cache's clock forward
// instantly, so time-based expiry can be tested without Thread.sleep().
public class FakeTickerDemo {
  public static void main(String[] args) {
    FakeTicker ticker = new FakeTicker();
    Cache<String, String> cache = CacheBuilder.newBuilder()
        .expireAfterAccess(30, TimeUnit.SECONDS)
        .ticker(ticker)
        .build();

    cache.put("k", "v");
    ticker.advance(60, TimeUnit.SECONDS);        // jump past the expiry window
    cache.cleanUp();                             // force eviction processing
    System.out.println(cache.getIfPresent("k")); // null: the entry has expired
  }
}
```

diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/spark/TestSmallTableCache.java ql/src/test/org/apache/hadoop/hive/ql/exec/spark/TestSmallTableCache.java
new file mode 100644
index 0000000000..6568a22c7b
--- /dev/null
+++ ql/src/test/org/apache/hadoop/hive/ql/exec/spark/TestSmallTableCache.java
@@ -0,0 +1,202 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec.spark;
+
+import com.google.common.testing.FakeTicker;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Test the two level cache. The loaders create values with new String(...) so that every cached
+ * value is a distinct, non-interned instance that can actually be dropped when its soft reference
+ * is cleared.
+ */
+public class TestSmallTableCache {
+
+  private static final String KEY = "Test";
+  private static final String TEST_VALUE_1 = "TestValue1";
+  private static final String TEST_VALUE_2 = "TestValue2";
+
+  private SmallTableCache.SmallTableLocalCache<String, String> cache;
+  private AtomicInteger counter;
+
+  @Before
+  public void setUp() {
+    this.cache = new SmallTableCache.SmallTableLocalCache<>();
+    this.counter = new AtomicInteger(0);
+  }
+
+  @Test
+  public void testEmptyCache() throws ExecutionException {
+
+    String res = cache.get(KEY, () -> {
+      counter.incrementAndGet();
+      return new String(TEST_VALUE_1);
+    });
+
+    assertEquals(TEST_VALUE_1, res);
+    assertEquals(1, counter.get());
+    assertEquals(1, cache.size());
+  }
+
+  @Test
+  public void testL1Hit() throws ExecutionException {
+    cache.get(KEY, () -> {
+      counter.incrementAndGet();
+      return new String(TEST_VALUE_1);
+    });
+
+    String res = cache.get(KEY, () -> {
+      counter.incrementAndGet();
+      return new String(TEST_VALUE_2);
+    });
+
+    assertEquals(TEST_VALUE_1, res);
+    assertEquals(1, counter.get());
+    assertEquals(1, cache.size());
+  }
+
+  @Test
+  public void testL2Hit() throws ExecutionException {
+
+    FakeTicker ticker = new FakeTicker();
+    cache = new SmallTableCache.SmallTableLocalCache<>(ticker);
+
+    cache.get(KEY, () -> {
+      counter.incrementAndGet();
+      return new String(TEST_VALUE_1);
+    });
+
+    // Expire the entry from L1; it is still softly reachable in L2.
+    ticker.advance(60, TimeUnit.SECONDS);
+
+    String res = cache.get(KEY, () -> {
+      counter.incrementAndGet();
+      return new String(TEST_VALUE_2);
+    });
+
+    assertEquals(TEST_VALUE_1, res);
+    assertEquals(1, counter.get());
+    assertEquals(1, cache.size());
+  }
+
+  @Test
+  public void testL2Miss() throws ExecutionException {
+
+    FakeTicker ticker = new FakeTicker();
+    cache = new SmallTableCache.SmallTableLocalCache<>(ticker);
+
+    cache.get(KEY, () -> {
+      counter.incrementAndGet();
+      return new String(TEST_VALUE_1);
+    });
+
+    ticker.advance(60, TimeUnit.SECONDS);
+    cache.cleanup();
+    forceOOMToClearSoftValues();
+
+    String res = cache.get(KEY, () -> {
+      counter.incrementAndGet();
+      return new String(TEST_VALUE_2);
+    });
+
+    assertEquals(TEST_VALUE_2, res);
+    assertEquals(2, counter.get());
+    assertEquals(1, cache.size());
+  }
+
+  @Test
+  public void testL2IsNotClearedIfTheItemIsInL1() throws ExecutionException {
+
+    FakeTicker ticker = new FakeTicker();
+    cache = new SmallTableCache.SmallTableLocalCache<>(ticker);
+
+    cache.get(KEY, () -> {
+      counter.incrementAndGet();
+      return new String(TEST_VALUE_1);
+    });
+
+    // While the item is strongly referenced from L1, its soft reference in L2 survives the OOM.
+    forceOOMToClearSoftValues();
+    ticker.advance(60, TimeUnit.SECONDS);
+    cache.cleanup();
+
+    String res = cache.get(KEY, () -> {
+      counter.incrementAndGet();
+      return new String(TEST_VALUE_2);
+    });
+
+    assertEquals(TEST_VALUE_1, res);
+    assertEquals(1, counter.get());
+    assertEquals(1, cache.size());
+  }
+
+  @Test
+  public void testClear() throws ExecutionException {
+    cache.get(KEY, () -> {
+      counter.incrementAndGet();
+      return new String(TEST_VALUE_1);
+    });
+    cache.clear((k, v) -> {
+    });
+
+    assertEquals(1, counter.get());
+    assertEquals(0, cache.size());
+  }
+
+  @Test
+  public void testPutL1() throws ExecutionException {
+    cache.put(KEY, new String(TEST_VALUE_1));
+    String res = cache.get(KEY, () -> {
+      counter.incrementAndGet();
+      return new String(TEST_VALUE_2);
+    });
+
+    assertEquals(TEST_VALUE_1, res);
+    assertEquals(0, counter.get());
+  }
+
+  @Test
+  public void testPutL2() throws ExecutionException {
+
+    FakeTicker ticker = new FakeTicker();
+    cache = new SmallTableCache.SmallTableLocalCache<>(ticker);
+
+    cache.put(KEY, new String(TEST_VALUE_1));
+    String res = cache.get(KEY, () -> {
+      counter.incrementAndGet();
+      return new String(TEST_VALUE_2);
+    });
+
+    ticker.advance(60, TimeUnit.SECONDS);
+    cache.cleanup();
+
+    assertEquals(TEST_VALUE_1, res);
+    assertEquals(0, counter.get());
+  }
+
+  private void forceOOMToClearSoftValues() {
+    try {
+      while (true) {
+        Object[] ignored = new Object[Integer.MAX_VALUE / 2];
+      }
+    } catch (OutOfMemoryError e) {
+      // Expected: the JVM clears all soft references before it throws OutOfMemoryError.
+    }
+  }
+}
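Reviewer note on forceOOMToClearSoftValues(): the JLS guarantees that all softly reachable objects are cleared before the VM throws OutOfMemoryError, so deliberately exhausting the heap (and catching the error) is a reliable way to empty a softValues() cache in a test. A standalone sketch of the mechanism, assuming a standard HotSpot JVM; it is best run with a small heap such as -Xmx64m so it finishes quickly:

```java
import java.lang.ref.SoftReference;
import java.util.ArrayList;
import java.util.List;

// Minimal sketch: soft references are cleared before OutOfMemoryError is
// thrown, so a forced OOM empties soft-valued caches.
public class SoftRefOomDemo {
  public static void main(String[] args) {
    SoftReference<byte[]> ref = new SoftReference<>(new byte[16]);
    List<byte[]> hog = new ArrayList<>();
    try {
      while (true) {
        hog.add(new byte[1 << 20]); // retain 1 MiB chunks until the heap is exhausted
      }
    } catch (OutOfMemoryError expected) {
      hog.clear(); // release the ballast so the program can continue
    }
    System.out.println(ref.get()); // prints null: the soft reference was cleared
  }
}
```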