diff --git a/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientFactory.java b/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientFactory.java
index cf60b13..8cedd30 100644
--- a/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientFactory.java
+++ b/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientFactory.java
@@ -43,20 +43,24 @@
   /** Used by client and driver to share a secret for establishing an RPC session. */
   static final String CONF_KEY_SECRET = "spark.client.authentication.secret";
 
-  private static RpcServer server = null;
-  private static final Object stopLock = new Object();
+  private static volatile RpcServer server = null;
+  private static final Object serverLock = new Object();
 
   /**
    * Initializes the SparkClient library. Must be called before creating client instances.
    *
    * @param conf Map containing configuration parameters for the client library.
    */
-  public static synchronized void initialize(Map<String, String> conf) throws IOException {
+  public static void initialize(Map<String, String> conf) throws IOException {
     if (server == null) {
-      try {
-        server = new RpcServer(conf);
-      } catch (InterruptedException ie) {
-        throw Throwables.propagate(ie);
+      synchronized (serverLock) {
+        if (server == null) {
+          try {
+            server = new RpcServer(conf);
+          } catch (InterruptedException ie) {
+            throw Throwables.propagate(ie);
+          }
+        }
       }
     }
   }
@@ -64,7 +68,7 @@ public static synchronized void initialize(Map<String, String> conf) throws IOEx
 
   /** Stops the SparkClient library. */
   public static void stop() {
     if (server != null) {
-      synchronized (stopLock) {
+      synchronized (serverLock) {
         if (server != null) {
           server.close();
           server = null;
@@ -79,7 +83,7 @@ public static void stop() {
    * @param sparkConf Configuration for the remote Spark application, contains spark.* properties.
    * @param hiveConf Configuration for Hive, contains hive.* properties.
    */
-  public static synchronized SparkClient createClient(Map<String, String> sparkConf, HiveConf hiveConf)
+  public static SparkClient createClient(Map<String, String> sparkConf, HiveConf hiveConf)
      throws IOException, SparkException {
     Preconditions.checkState(server != null, "initialize() not called.");
     return new SparkClientImpl(server, sparkConf, hiveConf);
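Note on initialize() above: the patch swaps the method-level synchronized for double-checked locking, which is only safe because the server field is also made volatile; otherwise another thread could observe a partially constructed RpcServer. A minimal standalone sketch of the same idiom (illustration only, not part of the patch; ExpensiveResource is a hypothetical stand-in for RpcServer):

// Minimal sketch of the double-checked locking idiom used in initialize().
// The volatile field guarantees visibility of the fully constructed object;
// the second null check under the lock guarantees it is created only once.
public final class LazyHolder {
  private static volatile ExpensiveResource resource = null;
  private static final Object lock = new Object();

  public static ExpensiveResource get() {
    if (resource == null) {               // fast path, no locking after initialization
      synchronized (lock) {
        if (resource == null) {           // re-check while holding the lock
          resource = new ExpensiveResource();
        }
      }
    }
    return resource;
  }

  /** Hypothetical stand-in for RpcServer in this sketch. */
  static final class ExpensiveResource { }
}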
diff --git a/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcConfiguration.java b/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcConfiguration.java
index 8c59015..b6fd4f8 100644
--- a/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcConfiguration.java
+++ b/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcConfiguration.java
@@ -19,6 +19,7 @@
 
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -66,7 +67,8 @@
   private static final HiveConf DEFAULT_CONF = new HiveConf();
 
   public RpcConfiguration(Map<String, String> config) {
-    this.config = config;
+    // make sure we don't modify the config in RpcConfiguration
+    this.config = Collections.unmodifiableMap(config);
   }
 
   long getConnectTimeoutMs() {
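For context on the RpcConfiguration change above: Collections.unmodifiableMap wraps the caller's map in a read-only view rather than copying it, so writes through RpcConfiguration now fail fast while updates made through the original map stay visible. A small sketch, assuming placeholder keys (illustration only, not part of the patch):

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

public class UnmodifiableMapDemo {
  public static void main(String[] args) {
    Map<String, String> original = new HashMap<>();
    original.put("some.key", "value");          // placeholder key, not a real Hive property

    // Same wrapping as the new RpcConfiguration constructor.
    Map<String, String> readOnly = Collections.unmodifiableMap(original);

    try {
      readOnly.put("another.key", "value");     // any mutation through the view fails
    } catch (UnsupportedOperationException expected) {
      System.out.println("write rejected: " + expected);
    }

    // It is a view, not a copy: updates to the backing map are still visible.
    original.put("another.key", "value");
    System.out.println(readOnly.size());        // prints 2
  }
}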
diff --git a/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcServer.java b/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcServer.java
index 08fb535..6c6ab74 100644
--- a/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcServer.java
+++ b/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcServer.java
@@ -110,7 +110,6 @@ public void run() {
           }
         })
-      .option(ChannelOption.SO_BACKLOG, 1)
       .option(ChannelOption.SO_REUSEADDR, true)
       .childOption(ChannelOption.SO_KEEPALIVE, true);
diff --git a/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestRpc.java b/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestRpc.java
index 5a4801c..21e4595 100644
--- a/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestRpc.java
+++ b/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestRpc.java
@@ -20,10 +20,17 @@
 import java.io.Closeable;
 import java.net.InetAddress;
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Collection;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
+import java.util.UUID;
+import java.util.concurrent.Callable;
 import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ThreadLocalRandom;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.AtomicInteger;
@@ -53,7 +60,7 @@
   private static final Logger LOG = LoggerFactory.getLogger(TestRpc.class);
 
   private Collection<Closeable> closeables;
-  private Map<String, String> emptyConfig =
+  private static final Map<String, String> emptyConfig =
       ImmutableMap.of(HiveConf.ConfVars.SPARK_RPC_CHANNEL_LOG_LEVEL.varname, "DEBUG");
 
   @Before
@@ -242,7 +249,7 @@ public void testEncryption() throws Exception {
         .put(RpcConfiguration.RPC_SASL_OPT_PREFIX + "qop", Rpc.SASL_AUTH_CONF)
         .build();
     RpcServer server = autoClose(new RpcServer(eConf));
-    Rpc[] rpcs = createRpcConnection(server, eConf);
+    Rpc[] rpcs = createRpcConnection(server, eConf, null);
     Rpc client = rpcs[1];
 
     TestMessage outbound = new TestMessage("Hello World!");
@@ -278,6 +285,35 @@ public void testClientTimeout() throws Exception {
     }
   }
 
+  @Test
+  public void testRpcServerMultiThread() throws Exception {
+    final RpcServer server = autoClose(new RpcServer(emptyConfig));
+    final String msg = "Hello World!";
+    Callable<String> callable = () -> {
+      Rpc[] rpcs = createRpcConnection(server, emptyConfig, UUID.randomUUID().toString());
+      Rpc rpc;
+      if (ThreadLocalRandom.current().nextBoolean()) {
+        rpc = rpcs[0];
+      } else {
+        rpc = rpcs[1];
+      }
+      TestMessage outbound = new TestMessage("Hello World!");
+      Future<TestMessage> call = rpc.call(outbound, TestMessage.class);
+      TestMessage reply = call.get(10, TimeUnit.SECONDS);
+      return reply.message;
+    };
+    final int numThreads = ThreadLocalRandom.current().nextInt(5) + 5;
+    ExecutorService executor = Executors.newFixedThreadPool(numThreads);
+    List<java.util.concurrent.Future<String>> futures = new ArrayList<>(numThreads);
+    for (int i = 0; i < numThreads; i++) {
+      futures.add(executor.submit(callable));
+    }
+    executor.shutdown();
+    for (java.util.concurrent.Future<String> future : futures) {
+      assertEquals(msg, future.get());
+    }
+  }
+
   private void transfer(Rpc serverRpc, Rpc clientRpc) {
     EmbeddedChannel client = (EmbeddedChannel) clientRpc.getChannel();
     EmbeddedChannel server = (EmbeddedChannel) serverRpc.getChannel();
@@ -308,20 +344,23 @@ private void transfer(Rpc serverRpc, Rpc clientRpc) {
    * @return two-tuple (server rpc, client rpc)
    */
   private Rpc[] createRpcConnection(RpcServer server) throws Exception {
-    return createRpcConnection(server, emptyConfig);
+    return createRpcConnection(server, emptyConfig, null);
   }
 
-  private Rpc[] createRpcConnection(RpcServer server, Map<String, String> clientConf)
-      throws Exception {
+  private Rpc[] createRpcConnection(RpcServer server, Map<String, String> clientConf,
+      String clientId) throws Exception {
+    if (clientId == null) {
+      clientId = "client";
+    }
     String secret = server.createSecret();
-    Future<Rpc> serverRpcFuture = server.registerClient("client", secret, new TestDispatcher());
+    Future<Rpc> serverRpcFuture = server.registerClient(clientId, secret, new TestDispatcher());
     NioEventLoopGroup eloop = new NioEventLoopGroup();
     Future<Rpc> clientRpcFuture = Rpc.createClient(clientConf, eloop,
-        "localhost", server.getPort(), "client", secret, new TestDispatcher());
+        "localhost", server.getPort(), clientId, secret, new TestDispatcher());
     Rpc serverRpc = autoClose(serverRpcFuture.get(10, TimeUnit.SECONDS));
     Rpc clientRpc = autoClose(clientRpcFuture.get(10, TimeUnit.SECONDS));
-    return new Rpc[] { serverRpc, clientRpc };
+    return new Rpc[]{serverRpc, clientRpc};
   }
 
   private static class TestMessage {
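The new testRpcServerMultiThread test above follows a fan-out/join pattern: each worker opens its own connection against the shared RpcServer under a random UUID client id, and the main thread joins on every Future so any worker failure is rethrown by get(). A stripped-down sketch of just that pattern, with the RPC round trip replaced by a trivial task (illustration only, not part of the patch):

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class FanOutJoinDemo {
  public static void main(String[] args) throws Exception {
    final String expected = "Hello World!";

    // Stand-in for the connect/call/reply round trip each worker performs in the test.
    Callable<String> task = () -> expected;

    int numThreads = 8;
    ExecutorService executor = Executors.newFixedThreadPool(numThreads);
    List<Future<String>> futures = new ArrayList<>(numThreads);
    for (int i = 0; i < numThreads; i++) {
      futures.add(executor.submit(task));
    }
    executor.shutdown();

    // get() rethrows any exception from the worker, so one failed task fails the whole run.
    for (Future<String> future : futures) {
      if (!expected.equals(future.get())) {
        throw new AssertionError("unexpected reply: " + future.get());
      }
    }
    System.out.println("all " + numThreads + " tasks succeeded");
  }
}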