diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/session/SparkSessionManagerImpl.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/session/SparkSessionManagerImpl.java
index 616807c1f2..d8741f7a87 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/session/SparkSessionManagerImpl.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/session/SparkSessionManagerImpl.java
@@ -80,7 +80,7 @@ public void setup(HiveConf hiveConf) throws HiveException {
     LOG.info("Setting up the session manager.");
     Map<String, String> conf = HiveSparkClientFactory.initiateSparkConf(hiveConf);
     try {
-      SparkClientFactory.initialize(conf);
+      SparkClientFactory.initialize(conf, hiveConf);
       inited = true;
     } catch (IOException e) {
       throw new HiveException("Error initializing SparkClientFactory", e);
diff --git a/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientFactory.java b/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientFactory.java
index 8cedd30e1b..6f8031681d 100644
--- a/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientFactory.java
+++ b/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientFactory.java
@@ -51,12 +51,12 @@
    *
    * @param conf Map containing configuration parameters for the client library.
    */
-  public static void initialize(Map<String, String> conf) throws IOException {
+  public static void initialize(Map<String, String> conf, HiveConf hiveConf) throws IOException {
    if (server == null) {
      synchronized (serverLock) {
        if (server == null) {
          try {
-            server = new RpcServer(conf);
+            server = new RpcServer(conf, hiveConf);
          } catch (InterruptedException ie) {
            throw Throwables.propagate(ie);
          }
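
Note: the two hunks above only thread the HiveConf through to the RPC layer; no behavior changes yet. For orientation, the resulting call chain from the HiveServer2 side, written as a standalone sketch (package names are inferred from the file paths in this patch, and the single-argument initiateSparkConf is assumed from the hunk above):

    import java.util.Map;

    import org.apache.hadoop.hive.conf.HiveConf;
    import org.apache.hadoop.hive.ql.exec.spark.HiveSparkClientFactory;
    import org.apache.hive.spark.client.SparkClientFactory;

    public class InitializeSketch {
      public static void main(String[] args) throws Exception {
        HiveConf hiveConf = new HiveConf();
        // Spark properties derived from the Hive configuration.
        Map<String, String> conf = HiveSparkClientFactory.initiateSparkConf(hiveConf);
        // The HiveConf travels along so the embedded RpcServer can later
        // query YARN for the application's state.
        SparkClientFactory.initialize(conf, hiveConf); // internally: new RpcServer(conf, hiveConf)
      }
    }
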
diff --git a/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientImpl.java b/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientImpl.java
index f002bfe97e..155e563e3d 100644
--- a/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientImpl.java
+++ b/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientImpl.java
@@ -54,6 +54,8 @@
 import java.util.concurrent.Future;
 import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.regex.Pattern;
+import java.util.regex.Matcher;
 
 import org.apache.commons.lang3.StringUtils;
 import org.apache.hadoop.hive.conf.Constants;
@@ -73,6 +75,7 @@
   private static final long serialVersionUID = 1L;
 
   private static final Logger LOG = LoggerFactory.getLogger(SparkClientImpl.class);
+  private static final Pattern YARN_APPLICATION_ID_REGEX = Pattern.compile("\\s(application_[0-9]+_[0-9]+)(\\s|$)");
 
   private static final long DEFAULT_SHUTDOWN_TIMEOUT = 10000; // In milliseconds
   private static final long MAX_ERR_LOG_LINES_FOR_RPC = 1000;
@@ -481,7 +484,8 @@ public void run() {
       final Process child = pb.start();
       String threadName = Thread.currentThread().getName();
       final List<String> childErrorLog = Collections.synchronizedList(new ArrayList<String>());
-      redirect("RemoteDriver-stdout-redir-" + threadName, new Redirector(child.getInputStream()));
+      final List<String> childOutLog = Collections.synchronizedList(new ArrayList<String>());
+      redirect("RemoteDriver-stdout-redir-" + threadName, new Redirector(child.getInputStream(), childOutLog));
       redirect("RemoteDriver-stderr-redir-" + threadName, new Redirector(child.getErrorStream(), childErrorLog));
 
       runnable = new Runnable() {
@@ -489,7 +493,27 @@ public void run() {
         try {
           int exitCode = child.waitFor();
-          if (exitCode != 0) {
+          LOG.info("Spark submit exit code " + exitCode);
+          if (exitCode == 0) {
+            synchronized (childOutLog) {
+              for (String line : childOutLog) {
+                Matcher m = YARN_APPLICATION_ID_REGEX.matcher(line);
+                if (m.find()) {
+                  LOG.info("Found application id " + m.group(1));
+                  rpcServer.setApplicationId(m.group(1));
+                }
+              }
+            }
+            synchronized (childErrorLog) {
+              for (String line : childErrorLog) {
+                Matcher m = YARN_APPLICATION_ID_REGEX.matcher(line);
+                if (m.find()) {
+                  LOG.info("Found application id " + m.group(1));
+                  rpcServer.setApplicationId(m.group(1));
+                }
+              }
+            }
+          } else {
             StringBuilder errStr = new StringBuilder();
             synchronized(childErrorLog) {
               Iterator<String> iter = childErrorLog.iterator();
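
For reference, a standalone sketch of the scan the new code performs over the child process output; the log line below is invented but has the shape of spark-submit output on YARN. The pattern requires whitespace before the id and whitespace or end-of-line after it, so ids embedded in longer tokens are not matched:

    import java.util.regex.Matcher;
    import java.util.regex.Pattern;

    public class YarnAppIdScanDemo {
      // Same pattern as YARN_APPLICATION_ID_REGEX above.
      private static final Pattern YARN_APPLICATION_ID_REGEX =
          Pattern.compile("\\s(application_[0-9]+_[0-9]+)(\\s|$)");

      public static void main(String[] args) {
        // Invented, but representative, spark-submit log line.
        String line = "INFO yarn.Client: Submitting application application_1514764800000_0001 to ResourceManager";
        Matcher m = YARN_APPLICATION_ID_REGEX.matcher(line);
        if (m.find()) {
          System.out.println("Found application id " + m.group(1));
          // prints: Found application id application_1514764800000_0001
        }
      }
    }
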
diff --git a/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcServer.java b/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcServer.java
index d3f295fce1..6a25af2955 100644
--- a/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcServer.java
+++ b/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcServer.java
@@ -60,6 +60,12 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.apache.hadoop.hive.common.classification.InterfaceAudience;
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.yarn.api.records.ApplicationId;
+import org.apache.hadoop.yarn.api.records.ApplicationReport;
+import org.apache.hadoop.yarn.api.records.YarnApplicationState;
+import org.apache.hadoop.yarn.client.api.YarnClient;
+
 
 /**
  * An RPC server. The server matches remote clients based on a secret that is generated on
@@ -78,9 +84,12 @@
   private final int port;
   private final ConcurrentMap<String, ClientInfo> pendingClients;
   private final RpcConfiguration config;
+  private String applicationId;
+  private final HiveConf hiveConf;
 
-  public RpcServer(Map<String, String> mapConf) throws IOException, InterruptedException {
+  public RpcServer(Map<String, String> mapConf, HiveConf hiveConf) throws IOException, InterruptedException {
     this.config = new RpcConfiguration(mapConf);
+    this.hiveConf = hiveConf;
     this.group = new NioEventLoopGroup(
         this.config.getRpcThreadCount(),
         new ThreadFactoryBuilder()
@@ -161,14 +170,89 @@ private ChannelFuture bindServerPort(ServerBootstrap serverBootstrap)
     return registerClient(clientId, secret, serverDispatcher, config.getServerConnectTimeoutMs());
   }
 
+  public void setApplicationId(String applicationId) {
+    this.applicationId = applicationId;
+  }
+
+  /**
+   * Converts an application id given as a String into an {@link ApplicationId} instance.
+   *
+   * @param appIDStr The application id in the form of a string.
+   * @return The application id as an instance of the ApplicationId class.
+   */
+  private static ApplicationId getApplicationIDFromString(String appIDStr) {
+    String[] parts = appIDStr.split("_");
+    if (parts.length < 3) {
+      throw new IllegalStateException("The application id found is not valid. Application id: " + appIDStr);
+    }
+    long timestamp = Long.valueOf(parts[1]);
+    int id = Integer.valueOf(parts[2]);
+    return ApplicationId.newInstance(timestamp, id);
+  }
+
+  static class YarnApplicationStateFinder {
+    public boolean isApplicationAccepted(HiveConf conf, String applicationId) {
+      if (applicationId == null) {
+        return false;
+      }
+      YarnClient yarnClient = null;
+      try {
+        LOG.info("Trying to find " + applicationId);
+        ApplicationId appId = getApplicationIDFromString(applicationId);
+        yarnClient = YarnClient.createYarnClient();
+        yarnClient.init(conf);
+        yarnClient.start();
+        ApplicationReport appReport = yarnClient.getApplicationReport(appId);
+        return appReport != null
+            && appReport.getYarnApplicationState() == YarnApplicationState.ACCEPTED;
+      } catch (Exception ex) {
+        LOG.error("Failed getting application status for: " + applicationId + ": " + ex, ex);
+        return false;
+      } finally {
+        if (yarnClient != null) {
+          try {
+            yarnClient.stop();
+          } catch (Exception ex) {
+            LOG.error("Failed to stop yarn client: " + ex, ex);
+          }
+        }
+      }
+    }
+  }
+
+  Future<Rpc> registerClient(final String clientId, String secret,
+      RpcDispatcher serverDispatcher, final long clientTimeoutMs) {
+    return registerClient(clientId, secret, serverDispatcher, clientTimeoutMs, new YarnApplicationStateFinder());
+  }
+
   @VisibleForTesting
   Future<Rpc> registerClient(final String clientId, String secret,
-      RpcDispatcher serverDispatcher, long clientTimeoutMs) {
+      RpcDispatcher serverDispatcher, final long clientTimeoutMs,
+      final YarnApplicationStateFinder yarnApplicationStateFinder) {
     final Promise<Rpc> promise = group.next().newPromise();
 
     Runnable timeout = new Runnable() {
       @Override
       public void run() {
+        // Before failing the promise, check whether the application is still in the
+        // ACCEPTED state (submitted but waiting for resources):
+        //   if applicationId is not null
+        //     do the equivalent of "yarn application -status $applicationId"
+        //     if state == ACCEPTED
+        //       reschedule this timeout runnable
+        //   otherwise
+        //     set failure as below
+        LOG.info("Trying to find " + applicationId);
+        if (yarnApplicationStateFinder.isApplicationAccepted(hiveConf, applicationId)) {
+          final ClientInfo client = pendingClients.get(clientId);
+          if (client != null) {
+            LOG.info("Extending timeout for client " + clientId);
+            ScheduledFuture<?> oldTimeoutFuture = client.timeoutFuture;
+            client.timeoutFuture = group.schedule(this,
+                clientTimeoutMs,
+                TimeUnit.MILLISECONDS);
+            oldTimeoutFuture.cancel(true);
+            return;
+          }
+        }
         promise.setFailure(new TimeoutException("Timed out waiting for client connection."));
       }
     };
@@ -351,7 +435,7 @@ public void handle(Callback[] callbacks) {
     final Promise<Rpc> promise;
     final String secret;
     final RpcDispatcher dispatcher;
-    final ScheduledFuture<?> timeoutFuture;
+    ScheduledFuture<?> timeoutFuture;
 
     private ClientInfo(String id, Promise<Rpc> promise, String secret, RpcDispatcher dispatcher,
         ScheduledFuture<?> timeoutFuture) {
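
To make the id parsing concrete, here is getApplicationIDFromString traced on a sample id (the id itself is invented): a YARN application id has the form "application_<clusterTimestamp>_<sequenceNumber>", and those two numeric parts are exactly what ApplicationId.newInstance expects:

    import org.apache.hadoop.yarn.api.records.ApplicationId;

    public class AppIdParseDemo {
      public static void main(String[] args) {
        String appIDStr = "application_1514764800000_0001";
        String[] parts = appIDStr.split("_"); // ["application", "1514764800000", "0001"]
        long timestamp = Long.valueOf(parts[1]); // cluster start timestamp
        int id = Integer.valueOf(parts[2]);      // sequence number, parsed as 1
        ApplicationId appId = ApplicationId.newInstance(timestamp, id);
        System.out.println(appId); // prints: application_1514764800000_0001
      }
    }
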
diff --git a/spark-client/src/test/java/org/apache/hive/spark/client/TestSparkClient.java b/spark-client/src/test/java/org/apache/hive/spark/client/TestSparkClient.java
index 406a015f33..72d9459233 100644
--- a/spark-client/src/test/java/org/apache/hive/spark/client/TestSparkClient.java
+++ b/spark-client/src/test/java/org/apache/hive/spark/client/TestSparkClient.java
@@ -268,7 +268,7 @@ public void call(SparkClient client) throws Exception {
   private void runTest(boolean local, TestFunction test) throws Exception {
     Map<String, String> conf = createConf(local);
-    SparkClientFactory.initialize(conf);
+    SparkClientFactory.initialize(conf, HIVECONF);
     SparkClient client = null;
     try {
       test.config(conf);
diff --git a/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestRpc.java b/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestRpc.java
index fa173e56aa..102ec79c21 100644
--- a/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestRpc.java
+++ b/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestRpc.java
@@ -53,6 +53,8 @@
 import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.hadoop.hive.conf.HiveConf;
+
 import static org.junit.Assert.*;
 
 public class TestRpc {
@@ -62,10 +64,12 @@
   private Collection<Closeable> closeables;
   private static final Map<String, String> emptyConfig = ImmutableMap.of(
       HiveConf.ConfVars.SPARK_RPC_CHANNEL_LOG_LEVEL.varname, "DEBUG");
+  private HiveConf hiveConf;
 
   @Before
   public void setUp() {
     closeables = Lists.newArrayList();
+    hiveConf = new HiveConf();
   }
 
   @After
@@ -97,7 +101,7 @@ public void testRpcDispatcher() throws Exception {
 
   @Test
   public void testClientServer() throws Exception {
-    RpcServer server = autoClose(new RpcServer(emptyConfig));
+    RpcServer server = autoClose(new RpcServer(emptyConfig, hiveConf));
     Rpc[] rpcs = createRpcConnection(server);
     Rpc serverRpc = rpcs[0];
     Rpc client = rpcs[1];
@@ -134,25 +138,25 @@ public void testServerAddress() throws Exception {
 
     // Test if rpc_server_address is configured
     config.put(HiveConf.ConfVars.SPARK_RPC_SERVER_ADDRESS.varname, hostAddress);
-    RpcServer server1 = autoClose(new RpcServer(config));
+    RpcServer server1 = autoClose(new RpcServer(config, hiveConf));
     assertTrue("Host address should match the expected one", server1.getAddress() == hostAddress);
 
     // Test if rpc_server_address is not configured but HS2 server host is configured
     config.put(HiveConf.ConfVars.SPARK_RPC_SERVER_ADDRESS.varname, "");
     config.put(HiveConf.ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST.varname, hostAddress);
-    RpcServer server2 = autoClose(new RpcServer(config));
+    RpcServer server2 = autoClose(new RpcServer(config, hiveConf));
     assertTrue("Host address should match the expected one", server2.getAddress() == hostAddress);
 
     // Test if both are not configured
     config.put(HiveConf.ConfVars.SPARK_RPC_SERVER_ADDRESS.varname, "");
     config.put(HiveConf.ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST.varname, "");
-    RpcServer server3 = autoClose(new RpcServer(config));
+    RpcServer server3 = autoClose(new RpcServer(config, hiveConf));
     assertTrue("Host address should match the expected one",
         server3.getAddress() == InetAddress.getLocalHost().getHostName());
   }
 
   @Test
   public void testBadHello() throws Exception {
-    RpcServer server = autoClose(new RpcServer(emptyConfig));
+    RpcServer server = autoClose(new RpcServer(emptyConfig, hiveConf));
 
     Future<Rpc> serverRpcFuture = server.registerClient("client", "newClient",
         new TestDispatcher());
@@ -178,24 +182,24 @@ public void testServerPort() throws Exception {
     Map<String, String> config = new HashMap<String, String>();
 
-    RpcServer server0 = new RpcServer(config);
+    RpcServer server0 = new RpcServer(config, hiveConf);
     assertTrue("Empty port range should return a random valid port: " + server0.getPort(),
         server0.getPort() >= 0);
     IOUtils.closeQuietly(server0);
 
     config.put(HiveConf.ConfVars.SPARK_RPC_SERVER_PORT.varname, "49152-49222,49223,49224-49333");
-    RpcServer server1 = new RpcServer(config);
+    RpcServer server1 = new RpcServer(config, hiveConf);
     assertTrue("Port should be within configured port range:" + server1.getPort(),
         server1.getPort() >= 49152 && server1.getPort() <= 49333);
     IOUtils.closeQuietly(server1);
 
     int expectedPort = 65535;
     config.put(HiveConf.ConfVars.SPARK_RPC_SERVER_PORT.varname, String.valueOf(expectedPort));
-    RpcServer server2 = new RpcServer(config);
+    RpcServer server2 = new RpcServer(config, hiveConf);
     assertTrue("Port should match configured one: " + server2.getPort(),
         server2.getPort() == expectedPort);
     IOUtils.closeQuietly(server2);
 
     config.put(HiveConf.ConfVars.SPARK_RPC_SERVER_PORT.varname, "49552-49222,49223,49224-49333");
     try {
-      autoClose(new RpcServer(config));
+      autoClose(new RpcServer(config, hiveConf));
       assertTrue("Invalid port range should throw an exception", false); // Should not reach here
     } catch(IOException e) {
       assertEquals("Incorrect RPC server port configuration for HiveServer2", e.getMessage());
@@ -204,14 +208,14 @@
     // Retry logic
     expectedPort = 65535;
     config.put(HiveConf.ConfVars.SPARK_RPC_SERVER_PORT.varname, String.valueOf(expectedPort) + ",21-23");
-    RpcServer server3 = new RpcServer(config);
+    RpcServer server3 = new RpcServer(config, hiveConf);
     assertTrue("Port should match configured one:" + server3.getPort(),
         server3.getPort() == expectedPort);
     IOUtils.closeQuietly(server3);
   }
 
   @Test
   public void testCloseListener() throws Exception {
-    RpcServer server = autoClose(new RpcServer(emptyConfig));
+    RpcServer server = autoClose(new RpcServer(emptyConfig, hiveConf));
     Rpc[] rpcs = createRpcConnection(server);
     Rpc client = rpcs[1];
 
@@ -230,7 +234,7 @@ public void rpcClosed(Rpc rpc) {
 
   @Test
   public void testNotDeserializableRpc() throws Exception {
-    RpcServer server = autoClose(new RpcServer(emptyConfig));
+    RpcServer server = autoClose(new RpcServer(emptyConfig, hiveConf));
     Rpc[] rpcs = createRpcConnection(server);
     Rpc client = rpcs[1];
 
@@ -248,7 +252,7 @@ public void testEncryption() throws Exception {
         .putAll(emptyConfig)
         .put(RpcConfiguration.RPC_SASL_OPT_PREFIX + "qop", Rpc.SASL_AUTH_CONF)
         .build();
-    RpcServer server = autoClose(new RpcServer(eConf));
+    RpcServer server = autoClose(new RpcServer(eConf, hiveConf));
     Rpc[] rpcs = createRpcConnection(server, eConf, null);
     Rpc client = rpcs[1];
 
@@ -263,7 +267,7 @@ public void testClientTimeout() throws Exception {
     Map<String, String> conf = ImmutableMap.<String, String>builder()
         .putAll(emptyConfig)
         .build();
-    RpcServer server = autoClose(new RpcServer(conf));
+    RpcServer server = autoClose(new RpcServer(conf, hiveConf));
     String secret = server.createSecret();
 
     try {
@@ -285,9 +289,51 @@
     }
   }
 
+  static class MockYarnApplicationStateFinder extends RpcServer.YarnApplicationStateFinder {
+    private int count = 0;
+
+    @Override
+    public boolean isApplicationAccepted(HiveConf conf, String applicationId) {
+      return count++ < 10;
+    }
+  }
+
+  /**
+   * Tests that a short client timeout is extended while the Spark application is still in the
+   * ACCEPTED state (submitted but not yet running), and that the client eventually times out
+   * once it no longer is.
+   */
+  @Test
+  public void testExtendClientTimeout() throws Exception {
+    Map<String, String> conf = ImmutableMap.<String, String>builder()
+        .putAll(emptyConfig)
+        .build();
+    RpcServer server = autoClose(new RpcServer(conf, hiveConf));
+    String secret = server.createSecret();
+    MockYarnApplicationStateFinder yarnApplicationStateFinder = new MockYarnApplicationStateFinder();
+    Future<Rpc> promise = server.registerClient("client", secret, new TestDispatcher(), 2L,
+        yarnApplicationStateFinder);
+    assertFalse(promise.isDone());
+    Thread.sleep(50);
+    try {
+      promise.get();
+      fail("Server should have timed out client.");
+    } catch (ExecutionException ee) {
+      assertTrue(ee.getCause() instanceof TimeoutException);
+    }
+
+    NioEventLoopGroup eloop = new NioEventLoopGroup();
+    Future<Rpc> clientRpcFuture = Rpc.createClient(conf, eloop,
+        "localhost", server.getPort(), "client", secret, new TestDispatcher());
+    try {
+      autoClose(clientRpcFuture.get());
+      fail("Client should have failed to connect to server.");
+    } catch (ExecutionException ee) {
+      // Error should not be a timeout.
+      assertFalse(ee.getCause() instanceof TimeoutException);
+    }
+  }
+
   @Test
   public void testRpcServerMultiThread() throws Exception {
-    final RpcServer server = autoClose(new RpcServer(emptyConfig));
+    final RpcServer server = autoClose(new RpcServer(emptyConfig, hiveConf));
     final String msg = "Hello World!";
 
     Callable<String> callable = new Callable<String>() {
       public String call() throws Exception {
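
The new test drives the timeout-extension path end to end: with a 2 ms timeout and a mock that reports ACCEPTED ten times, registration stays pending through ten extensions and then fails with a TimeoutException. The underlying pattern, isolated from Netty and YARN as a standalone sketch (not Hive code), is a timeout task that re-arms itself while an external check says to keep waiting:

    import java.util.concurrent.CompletableFuture;
    import java.util.concurrent.ExecutionException;
    import java.util.concurrent.Executors;
    import java.util.concurrent.ScheduledExecutorService;
    import java.util.concurrent.TimeUnit;
    import java.util.concurrent.TimeoutException;

    public class ExtendableTimeoutSketch {
      public static void main(String[] args) throws Exception {
        ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();
        CompletableFuture<Void> promise = new CompletableFuture<>();
        int[] checksLeft = {10}; // stand-in for "YARN still reports ACCEPTED" (cf. count++ < 10 above)

        Runnable timeout = new Runnable() {
          @Override
          public void run() {
            if (checksLeft[0]-- > 0) {
              // Extend: re-schedule this same task, as RpcServer does with
              // client.timeoutFuture = group.schedule(this, ...).
              scheduler.schedule(this, 2, TimeUnit.MILLISECONDS);
              return;
            }
            promise.completeExceptionally(
                new TimeoutException("Timed out waiting for client connection."));
          }
        };
        scheduler.schedule(timeout, 2, TimeUnit.MILLISECONDS);

        try {
          promise.get(); // pending while extensions keep arriving, then fails
        } catch (ExecutionException ee) {
          System.out.println("cause: " + ee.getCause());
        } finally {
          scheduler.shutdown();
        }
      }
    }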