diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/Registry.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/Registry.java index a4584e3..64513e3 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/exec/Registry.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/Registry.java @@ -80,7 +80,7 @@ * Persistent map contains refcounts that are only modified in synchronized methods for now, * so there's no separate effort to make refcount operations thread-safe. */ - private final Map, Integer> persistent = new ConcurrentHashMap<>(); + private final Map persistent = new ConcurrentHashMap<>(); private final Set mSessionUDFLoaders = new LinkedHashSet(); private final boolean isNative; @@ -309,7 +309,7 @@ public boolean isBuiltInFunc(Class udfClass) { } public boolean isPermanentFunc(Class udfClass) { - return udfClass != null && persistent.containsKey(udfClass); + return udfClass != null && persistent.containsKey(udfClass.getCanonicalName()); } public Set getCurrentFunctionNames() { @@ -463,9 +463,9 @@ private void addFunction(String functionName, FunctionInfo function) { if (function.isBuiltIn()) { builtIns.add(function.getFunctionClass()); } else if (function.isPersistent()) { - Class functionClass = getPermanentUdfClass(function); - Integer refCount = persistent.get(functionClass); - persistent.put(functionClass, Integer.valueOf(refCount == null ? 1 : refCount + 1)); + String className = function.getClassName(); + Integer refCount = persistent.get(className); + persistent.put(className, Integer.valueOf(refCount == null ? 1 : refCount + 1)); } } finally { lock.unlock(); @@ -507,13 +507,13 @@ public void unregisterFunction(String functionName) throws HiveException { } private void removePersistentFunctionUnderLock(FunctionInfo fi) { - Class functionClass = getPermanentUdfClass(fi); - Integer refCount = persistent.get(functionClass); + String className = fi.getClassName(); + Integer refCount = persistent.get(className); assert refCount != null; if (refCount == 1) { - persistent.remove(functionClass); + persistent.remove(className); } else { - persistent.put(functionClass, Integer.valueOf(refCount - 1)); + persistent.put(className, Integer.valueOf(refCount - 1)); } } diff --git a/ql/src/test/org/apache/hadoop/hive/ql/exec/TestRegistry.java b/ql/src/test/org/apache/hadoop/hive/ql/exec/TestRegistry.java new file mode 100644 index 0000000..41a2441 --- /dev/null +++ b/ql/src/test/org/apache/hadoop/hive/ql/exec/TestRegistry.java @@ -0,0 +1,96 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.exec; + +import com.google.common.io.ByteStreams; +import com.google.common.io.Files; +import org.apache.commons.io.FileUtils; +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hadoop.hive.ql.session.SessionState; +import org.junit.Test; + +import javax.tools.JavaCompiler; +import javax.tools.JavaFileObject; +import javax.tools.SimpleJavaFileObject; +import javax.tools.ToolProvider; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.net.URI; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.jar.JarEntry; +import java.util.jar.JarOutputStream; + +public class TestRegistry { + private static class JavaSourceFromCode extends SimpleJavaFileObject { + private final String code; + + public JavaSourceFromCode(URI uri, String code) { + super(uri, Kind.SOURCE); + this.code = code; + } + + @Override + public CharSequence getCharContent(boolean ignoreEncodingErrors) throws IOException { + return code; + } + } + + @Test + public void testAddRemoveFunction() throws Exception { + final SessionState ss = new SessionState(new HiveConf()); + SessionState.setCurrentSessionState(ss); + final File tmpDir = Files.createTempDir(); + try { + final File srcDir = new File(tmpDir, "repro/"); + final String className = "MyFooUDF"; + JavaFileObject source = new JavaSourceFromCode( + new File(srcDir, className + ".java").toURI(), + "package repro;import org.apache.hadoop.hive.ql.exec.UDF;" + + "public class " + className + " extends UDF{}"); + JavaCompiler compiler = ToolProvider.getSystemJavaCompiler(); + compiler.getTask(null, null, null, + new ArrayList(), null, Arrays.asList(source)).call(); + File classFile = new File(tmpDir, className + ".class"); + Files.move(new File(className + ".class"), classFile); + File jarFile = new File(tmpDir, "test.jar"); + FileOutputStream jarFileStream = new FileOutputStream(jarFile); + JarOutputStream jarStream = new JarOutputStream(jarFileStream, new java.util.jar.Manifest()); + JarEntry jarEntry = new JarEntry(Paths.get("repro", classFile.getName()).toString()); + jarStream.putNextEntry(jarEntry); + FileInputStream in = new FileInputStream(classFile); + ByteStreams.copy(in, jarStream); + in.close(); + jarStream.close(); + jarFileStream.close(); + Registry registry = new Registry(true); + FunctionInfo.FunctionResource resource = new FunctionInfo.FunctionResource( + SessionState.ResourceType.JAR, jarFile.toURI().toString()); + registry.registerPermanentFunction(className, "repro." + className, false, resource); + registry.unregisterFunction(className); + registry.registerPermanentFunction(className, "repro." + className, true, resource); + registry.unregisterFunction(className); + } finally { + FileUtils.deleteDirectory(tmpDir); + } + } +}