diff --git a/ql/src/java/org/apache/hadoop/hive/ql/hooks/HookUtils.java b/ql/src/java/org/apache/hadoop/hive/ql/hooks/HookUtils.java
index 2f0bd88de241fca15d222ef5b8818b61d7e63559..aabe15e7d8207e61155c202dcf9581960c8efd06 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/hooks/HookUtils.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/hooks/HookUtils.java
@@ -18,15 +18,133 @@
 
 package org.apache.hadoop.hive.ql.hooks;
 
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
 import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
 import java.util.List;
+import java.util.Map;
 
-import org.apache.hadoop.hive.common.JavaUtils;
 import org.apache.hadoop.hive.conf.HiveConf;
 import org.apache.hadoop.hive.conf.HiveConf.ConfVars;
 import org.apache.hadoop.hive.ql.exec.Utilities;
 
 public class HookUtils {
+  /**
+   * Marks a {@link Hook} implementation whose instance may be shared: the first instance
+   * created for an annotated class is kept and handed out again on later lookups.
+   * NOTE(review): annotated hooks are reused across calls and therefore presumably need to be
+   * stateless or thread-safe - confirm for each hook before annotating it.
+   */
+  @Retention(RetentionPolicy.RUNTIME)
+  public @interface CacheableHook { }
+
+  private static final HooksLRUCache HOOKS_LRU_CACHE = new HooksLRUCache();
+
+  /**
+   * Turns a comma separated list of hook class names into hook instances, remembering the
+   * resolved classes per list and the instances of {@link CacheableHook} classes.
+   */
+  private static class HooksLRUCache {
+    private static final int SIZE = 100;
+    private static final int INITIAL_SIZE = 11;
+    private static final float THRESHOLD = 0.75F;
+    private static final boolean USE_LRU_ORDER = true;
+
+    // NOTE(review): unlike classNameLRUCache these two maps are never evicted, so every hook
+    // class seen stays referenced for the lifetime of this cache.
+    private final Map<Class<?>, Boolean> cacheable = new HashMap<>();
+    private final Map<Class<?>, Hook> instanceCache = new HashMap<>();
+
+    // Use an LRU so attackers can't spam the system with class-names and overflow the memory.
+    private final Map<String, List<Class<?>>> classNameLRUCache
+        = new LinkedHashMap<String, List<Class<?>>>(INITIAL_SIZE, THRESHOLD, USE_LRU_ORDER) {
+
+      @Override
+      protected boolean removeEldestEntry(Map.Entry<String, List<Class<?>>> eldest) {
+        return size() > SIZE;
+      }
+    };
+
+    /**
+     * Instantiates the hooks named by a comma separated list of class names.
+     *
+     * @param name comma separated hook class names; may be null or blank
+     * @return the hooks in the order they were named; a new, mutable and possibly empty list
+     */
+    @SuppressWarnings("unchecked")
+    public synchronized <T extends Hook> List<T> instantiateHooksFor(String name)
+        throws IllegalAccessException, InstantiationException, ClassNotFoundException {
+      // Always hand out a fresh mutable list, like the code this cache replaces did.
+      List<T> hooks = new ArrayList<T>();
+      if (name == null) {
+        return hooks;
+      }
+
+      String trimmedName = name.trim();
+      if (trimmedName.isEmpty()) {
+        return hooks;
+      }
+
+      for (Class<?> hookClass : classesForName(trimmedName)) {
+        hooks.add((T) instantiate(hookClass));
+      }
+
+      return hooks;
+    }
+
+    /**
+     * Resolves every non-blank entry of the comma separated list to a class. The result is
+     * cached under the complete list, so a failed lookup caches nothing.
+     */
+    private Iterable<Class<?>> classesForName(String name) throws ClassNotFoundException {
+      List<Class<?>> classes = classNameLRUCache.get(name);
+      if (classes == null) {
+        // NOTE(review): the classes are cached by name only, so they stay bound to the
+        // class loader of the session that resolved them first - confirm this is intended.
+        ClassLoader loader = Utilities.getSessionSpecifiedClassLoader();
+        classes = new ArrayList<Class<?>>();
+        for (String className : name.split(",")) {
+          String trimmedClassName = className.trim();
+          if (!trimmedClassName.isEmpty()) {
+            classes.add(Class.forName(trimmedClassName, true, loader));
+          }
+        }
+
+        classNameLRUCache.put(name, classes);
+      }
+
+      return classes;
+    }
+
+    /** Returns the cached instance of the hook class, or creates a new instance. */
+    private Hook instantiate(Class<?> type) throws IllegalAccessException, InstantiationException {
+      Hook cached = instanceCache.get(type);
+      if (cached != null) {
+        return cached;
+      }
+
+      Hook instance = (Hook) type.newInstance();
+      if (isCachable(type)) {
+        instanceCache.put(type, instance);
+      }
+
+      return instance;
+    }
+
+    /** Tells whether the class carries {@link CacheableHook}; the answer is memoized. */
+    private boolean isCachable(Class<?> type) {
+      Boolean known = cacheable.get(type);
+      if (known == null) {
+        known = type.isAnnotationPresent(CacheableHook.class);
+        cacheable.put(type, known);
+      }
+
+      return known;
+    }
+  }
+
   /**
    * Returns the hooks specified in a configuration variable. The hooks are returned
    * in a list in the order they were specified in the configuration variable.
@@ -44,25 +122,9 @@
   public static <T extends Hook> List<T> getHooks(HiveConf conf,
       ConfVars hookConfVar, Class<T> clazz)
       throws InstantiationException, IllegalAccessException, ClassNotFoundException {
-    String csHooks = conf.getVar(hookConfVar);
-    List<T> hooks = new ArrayList<T>();
-    if (csHooks == null) {
-      return hooks;
-    }
-
-    csHooks = csHooks.trim();
-    if (csHooks.equals("")) {
-      return hooks;
-    }
-
-    String[] hookClasses = csHooks.split(",");
-    for (String hookClass : hookClasses) {
-        T hook = (T) Class.forName(hookClass.trim(), true,
-            Utilities.getSessionSpecifiedClassLoader()).newInstance();
-        hooks.add(hook);
-    }
 
-    return hooks;
+    // Class lookups and instances of @CacheableHook classes come from the shared static cache.
+    return HOOKS_LRU_CACHE.instantiateHooksFor(conf.getVar(hookConfVar));
   }
 
   public static String redactLogString(HiveConf conf, String logString)
diff --git a/ql/src/test/org/apache/hadoop/hive/ql/hooks/TestHooks.java b/ql/src/test/org/apache/hadoop/hive/ql/hooks/TestHooks.java
index 8d27762522f0ac5f59696e019ec52a25e77318db..d43305e1268515f23208d32c6ab056f8e1842622 100644
--- a/ql/src/test/org/apache/hadoop/hive/ql/hooks/TestHooks.java
+++ b/ql/src/test/org/apache/hadoop/hive/ql/hooks/TestHooks.java
@@ -18,6 +18,8 @@
 package org.apache.hadoop.hive.ql.hooks;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotSame;
+import static org.junit.Assert.assertSame;
 
 import org.apache.hadoop.hive.conf.HiveConf;
 import org.apache.hadoop.hive.ql.Driver;
@@ -27,6 +29,10 @@
 import org.junit.AfterClass;
 import org.junit.Test;
 
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
 public class TestHooks {
 
   @BeforeClass
@@ -76,6 +82,68 @@ public void testQueryRedactor() throws Exception {
     assertEquals("select 'AAA' from t1", HiveConf.getVar(conf, HiveConf.ConfVars.HIVEQUERYSTRING));
   }
 
+  @HookUtils.CacheableHook
+  public static class CachableRedactor extends Redactor { }
+  public static class NonCachableRedactor extends Redactor { }
+  public static class AnotherNonCachableRedactor extends Redactor { }
+
+  /** Looks up the single configured redactor hook {@code count} times through HookUtils. */
+  private static Hook[] lookUpRedactorHook(Class<? extends Redactor> hookClass, int count)
+      throws Exception {
+    HiveConf conf = new HiveConf(TestHooks.class);
+    HiveConf.setVar(conf, HiveConf.ConfVars.QUERYREDACTORHOOKS, hookClass.getName());
+
+    Hook[] instances = new Hook[count];
+    for (int index = 0; index < count; index++) {
+      List<Hook> hooks = HookUtils.getHooks(conf, HiveConf.ConfVars.QUERYREDACTORHOOKS, Hook.class);
+      instances[index] = hooks.get(0);
+    }
+
+    return instances;
+  }
+
+  @Test
+  public void testNonCachableHook() throws Exception {
+    Hook[] instances = lookUpRedactorHook(NonCachableRedactor.class, 100);
+
+    // Every lookup must have created an instance of its own.
+    for (int lhs = 0; lhs < instances.length; lhs++) {
+      for (int rhs = lhs + 1; rhs < instances.length; rhs++) {
+        assertNotSame(instances[lhs], instances[rhs]);
+      }
+    }
+  }
+
+  @Test
+  public void testCachableHook() throws Exception {
+    Hook[] instances = lookUpRedactorHook(CachableRedactor.class, 100);
+
+    // Every lookup must have returned the one cached instance.
+    for (Hook instance : instances) {
+      assertSame(instances[0], instance);
+    }
+  }
+
+  @Test
+  public void testMultipleHooks() throws Exception {
+    HiveConf conf = new HiveConf(TestHooks.class);
+    HiveConf.setVar(conf, HiveConf.ConfVars.QUERYREDACTORHOOKS,
+        String.format(" %s, %s, %s ",
+            CachableRedactor.class.getName(),
+            NonCachableRedactor.class.getName(),
+            AnotherNonCachableRedactor.class.getName()));
+
+    List<Class<?>> hookClasses = new ArrayList<>();
+    for (Hook hook : HookUtils.getHooks(conf, HiveConf.ConfVars.QUERYREDACTORHOOKS, Hook.class)) {
+      hookClasses.add(hook.getClass());
+    }
+
+    // The hooks must come back complete and in the configured order.
+    List<Class<?>> expected = Arrays.<Class<?>>asList(
+        CachableRedactor.class, NonCachableRedactor.class, AnotherNonCachableRedactor.class);
+    assertEquals(expected, hookClasses);
+  }
+
   public static class SimpleQueryRedactor extends Redactor {
     public String redactQuery(String query) {
       return query.replaceAll("XXX", "AAA");