diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/session/SparkSessionManagerImpl.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/session/SparkSessionManagerImpl.java
index 79a56bd715..8dae54d7bd 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/session/SparkSessionManagerImpl.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/session/SparkSessionManagerImpl.java
@@ -96,7 +96,7 @@ public void setup(HiveConf hiveConf) throws HiveException {
       startTimeoutThread();
       Map<String, String> sparkConf = HiveSparkClientFactory.initiateSparkConf(hiveConf, null);
       try {
-        SparkClientFactory.initialize(sparkConf);
+        SparkClientFactory.initialize(sparkConf, hiveConf);
         inited = true;
       } catch (IOException e) {
         throw new HiveException("Error initializing SparkClientFactory", e);
diff --git a/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientFactory.java b/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientFactory.java
index 1974e88523..54ecdf08e1 100644
--- a/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientFactory.java
+++ b/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientFactory.java
@@ -47,12 +47,12 @@
    *
    * @param conf Map containing configuration parameters for the client library.
    */
-  public static void initialize(Map<String, String> conf) throws IOException {
+  public static void initialize(Map<String, String> conf, HiveConf hiveConf) throws IOException {
     if (server == null) {
       synchronized (serverLock) {
         if (server == null) {
           try {
-            server = new RpcServer(conf);
+            server = new RpcServer(conf, hiveConf);
           } catch (InterruptedException ie) {
             throw Throwables.propagate(ie);
           }
diff --git a/spark-client/src/main/java/org/apache/hive/spark/client/SparkSubmitSparkClient.java b/spark-client/src/main/java/org/apache/hive/spark/client/SparkSubmitSparkClient.java
index 7a6e77bdc6..1879829700 100644
--- a/spark-client/src/main/java/org/apache/hive/spark/client/SparkSubmitSparkClient.java
+++ b/spark-client/src/main/java/org/apache/hive/spark/client/SparkSubmitSparkClient.java
@@ -31,6 +31,8 @@
 import java.util.concurrent.Callable;
 import java.util.concurrent.Future;
 import java.util.concurrent.FutureTask;
+import java.util.regex.Pattern;
+import java.util.regex.Matcher;
 
 import org.apache.commons.lang3.StringUtils;
 
@@ -51,6 +53,7 @@
 
   private static final Logger LOG = LoggerFactory.getLogger(SparkSubmitSparkClient.class);
 
+  private static final Pattern YARN_APPLICATION_ID_REGEX = Pattern.compile("\\s(application_[0-9]+_[0-9]+)(\\s|$)");
   private static final String SPARK_HOME_ENV = "SPARK_HOME";
   private static final String SPARK_HOME_KEY = "spark.home";
 
@@ -191,17 +194,37 @@ private String getSparkJobCredentialProviderPassword() {
     final Process child = pb.start();
     String threadName = Thread.currentThread().getName();
     final List<String> childErrorLog = Collections.synchronizedList(new ArrayList<String>());
+    final List<String> childOutLog = Collections.synchronizedList(new ArrayList<String>());
     final LogRedirector.LogSourceCallback callback = () -> isAlive;
 
     LogRedirector.redirect("spark-submit-stdout-redir-" + threadName,
-        new LogRedirector(child.getInputStream(), LOG, callback));
+        new LogRedirector(child.getInputStream(), LOG, childOutLog, callback));
     LogRedirector.redirect("spark-submit-stderr-redir-" + threadName,
         new LogRedirector(child.getErrorStream(), LOG, childErrorLog, callback));
 
     runnable = () -> {
       try {
         int exitCode = child.waitFor();
-        if (exitCode != 0) {
+        if (exitCode == 0) {
+          synchronized (childOutLog) {
+            for (String line : childOutLog) {
+              Matcher m = YARN_APPLICATION_ID_REGEX.matcher(line);
+              if (m.find()) {
+                LOG.info("Found application id " + m.group(1));
+                rpcServer.setApplicationId(m.group(1));
+              }
+            }
+          }
+          synchronized (childErrorLog) {
+            for (String line : childErrorLog) {
+              Matcher m = YARN_APPLICATION_ID_REGEX.matcher(line);
+              if (m.find()) {
+                LOG.info("Found application id " + m.group(1));
+                rpcServer.setApplicationId(m.group(1));
+              }
+            }
+          }
+        } else {
           List<String> errorMessages = new ArrayList<>();
           synchronized (childErrorLog) {
             for (String line : childErrorLog) {
diff --git a/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcServer.java b/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcServer.java
index 0c67ffd813..0172119b25 100644
--- a/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcServer.java
+++ b/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcServer.java
@@ -61,6 +61,11 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.apache.hadoop.hive.common.classification.InterfaceAudience;
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.yarn.api.records.ApplicationId;
+import org.apache.hadoop.yarn.api.records.ApplicationReport;
+import org.apache.hadoop.yarn.api.records.YarnApplicationState;
+import org.apache.hadoop.yarn.client.api.YarnClient;
 
 /**
  * An RPC server. The server matches remote clients based on a secret that is generated on
@@ -79,9 +84,12 @@
   private final int port;
   private final ConcurrentMap<String, ClientInfo> pendingClients;
   private final RpcConfiguration config;
+  private String applicationId;
+  private final HiveConf hiveConf;
 
-  public RpcServer(Map<String, String> mapConf) throws IOException, InterruptedException {
+  public RpcServer(Map<String, String> mapConf, HiveConf hiveConf) throws IOException, InterruptedException {
     this.config = new RpcConfiguration(mapConf);
+    this.hiveConf = hiveConf;
     this.group = new NioEventLoopGroup(
         this.config.getRpcThreadCount(),
         new ThreadFactoryBuilder()
@@ -166,14 +174,116 @@ private ChannelFuture bindServerPort(ServerBootstrap serverBootstrap)
     return registerClient(clientId, secret, serverDispatcher, config.getServerConnectTimeoutMs());
   }
 
+  public void setApplicationId(String applicationId) {
+    this.applicationId = applicationId;
+  }
+
+  /**
+   * Converts an application id in the form of a String into an {@link ApplicationId} instance.
+   *
+   * @param appIDStr The application id in the form of a string.
+   * @return The application id as an instance of the ApplicationId class.
+   */
+  private static ApplicationId getApplicationIDFromString(String appIDStr) {
+    String[] parts = appIDStr.split("_");
+    if (parts.length < 3) {
application id: " + appIDStr); + } + long timestamp = Long.valueOf(parts[1]); + int id = Integer.valueOf(parts[2]); + return ApplicationId.newInstance(timestamp, id); + } + + public static boolean isApplicationAccepted(HiveConf conf, String applicationId) { + if (applicationId == null) { + return false; + } + YarnClient yarnClient = null; + try { + ApplicationId appId = getApplicationIDFromString(applicationId); + yarnClient = YarnClient.createYarnClient(); + yarnClient.init(conf); + yarnClient.start(); + ApplicationReport appReport = yarnClient.getApplicationReport(appId); + return appReport != null && appReport.getYarnApplicationState() == YarnApplicationState.ACCEPTED; + } catch (Exception ex) { + LOG.error("Failed getting application status for: " + applicationId + ": " + ex, ex); + return false; + } finally { + if (yarnClient != null) { + try { + yarnClient.stop(); + } catch (Exception ex) { + LOG.error("Failed to stop yarn client: " + ex, ex); + } + } + } + } + + static class YarnApplicationStateFinder { + public boolean isApplicationAccepted(HiveConf conf, String applicationId) { + if (applicationId == null) { + return false; + } + YarnClient yarnClient = null; + try { + LOG.info("Trying to find " + applicationId); + ApplicationId appId = getApplicationIDFromString(applicationId); + yarnClient = YarnClient.createYarnClient(); + yarnClient.init(conf); + yarnClient.start(); + ApplicationReport appReport = yarnClient.getApplicationReport(appId); + return appReport != null && appReport.getYarnApplicationState() == YarnApplicationState.ACCEPTED; + } catch (Exception ex) { + LOG.error("Failed getting application status for: " + applicationId + ": " + ex, ex); + return false; + } finally { + if (yarnClient != null) { + try { + yarnClient.stop(); + } catch (Exception ex) { + LOG.error("Failed to stop yarn client: " + ex, ex); + } + } + } + } + } + @VisibleForTesting Future registerClient(final String clientId, String secret, - RpcDispatcher serverDispatcher, long clientTimeoutMs) { + RpcDispatcher serverDispatcher, final long clientTimeoutMs) { + return registerClient(clientId, secret, serverDispatcher, clientTimeoutMs, new YarnApplicationStateFinder()); + } + + @VisibleForTesting + Future registerClient(final String clientId, String secret, + RpcDispatcher serverDispatcher, long clientTimeoutMs, + YarnApplicationStateFinder yarnApplicationStateFinder) { final Promise promise = group.next().newPromise(); Runnable timeout = new Runnable() { @Override public void run() { + // check to see if application is in ACCEPTED state, if so, don't set failure + // if applicationId is not null + // do yarn application -status $applicationId + // if state == ACCEPTED + // reschedule timeout runnable + // else + // set failure as below + LOG.info("Trying to find " + applicationId); + if (yarnApplicationStateFinder.isApplicationAccepted(hiveConf, applicationId)) { + final ClientInfo client = pendingClients.get(clientId); + if (client != null) { + LOG.info("Extending timeout for client " + clientId); + ScheduledFuture oldTimeoutFuture = client.timeoutFuture; + client.timeoutFuture = group.schedule(this, + clientTimeoutMs, + TimeUnit.MILLISECONDS); + oldTimeoutFuture.cancel(true); + return; + } + } promise.setFailure(new TimeoutException( String.format("Client '%s' timed out waiting for connection from the Remote Spark" + " Driver", clientId))); @@ -369,7 +479,7 @@ public void handle(Callback[] callbacks) { final Promise promise; final String secret; final RpcDispatcher dispatcher; - final ScheduledFuture 
-    final ScheduledFuture<?> timeoutFuture;
+    ScheduledFuture<?> timeoutFuture;
 
     private ClientInfo(String id, Promise<Rpc> promise, String secret, RpcDispatcher dispatcher,
         ScheduledFuture<?> timeoutFuture) {
diff --git a/spark-client/src/test/java/org/apache/hive/spark/client/TestSparkClient.java b/spark-client/src/test/java/org/apache/hive/spark/client/TestSparkClient.java
index d7380038fa..996b24ed7f 100644
--- a/spark-client/src/test/java/org/apache/hive/spark/client/TestSparkClient.java
+++ b/spark-client/src/test/java/org/apache/hive/spark/client/TestSparkClient.java
@@ -339,7 +339,7 @@ public Void answer(InvocationOnMock invocation) throws Throwable {
 
   private void runTest(TestFunction test) throws Exception {
     Map<String, String> conf = createConf();
-    SparkClientFactory.initialize(conf);
+    SparkClientFactory.initialize(conf, HIVECONF);
     SparkClient client = null;
     try {
       test.config(conf);
diff --git a/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestRpc.java b/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestRpc.java
index 013bcff30c..435d1b0544 100644
--- a/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestRpc.java
+++ b/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestRpc.java
@@ -54,6 +54,8 @@
 import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.hadoop.hive.conf.HiveConf;
+
 import static org.junit.Assert.*;
 
 public class TestRpc {
@@ -64,10 +66,12 @@
   private static final Map<String, String> emptyConfig = ImmutableMap.of(
       HiveConf.ConfVars.SPARK_RPC_CHANNEL_LOG_LEVEL.varname, "DEBUG");
   private static final int RETRY_ACQUIRE_PORT_COUNT = 10;
+  private HiveConf hiveConf;
 
   @Before
   public void setUp() {
     closeables = Lists.newArrayList();
+    hiveConf = new HiveConf();
   }
 
   @After
@@ -99,7 +103,7 @@ public void testRpcDispatcher() throws Exception {
 
   @Test
   public void testClientServer() throws Exception {
-    RpcServer server = autoClose(new RpcServer(emptyConfig));
+    RpcServer server = autoClose(new RpcServer(emptyConfig, hiveConf));
     Rpc[] rpcs = createRpcConnection(server);
     Rpc serverRpc = rpcs[0];
     Rpc client = rpcs[1];
@@ -136,25 +140,25 @@ public void testServerAddress() throws Exception {
 
     // Test if rpc_server_address is configured
     config.put(HiveConf.ConfVars.SPARK_RPC_SERVER_ADDRESS.varname, hostAddress);
-    RpcServer server1 = autoClose(new RpcServer(config));
+    RpcServer server1 = autoClose(new RpcServer(config, hiveConf));
     assertTrue("Host address should match the expected one", server1.getAddress() == hostAddress);
 
     // Test if rpc_server_address is not configured but HS2 server host is configured
     config.put(HiveConf.ConfVars.SPARK_RPC_SERVER_ADDRESS.varname, "");
     config.put(HiveConf.ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST.varname, hostAddress);
-    RpcServer server2 = autoClose(new RpcServer(config));
+    RpcServer server2 = autoClose(new RpcServer(config, hiveConf));
     assertTrue("Host address should match the expected one", server2.getAddress() == hostAddress);
 
     // Test if both are not configured
     config.put(HiveConf.ConfVars.SPARK_RPC_SERVER_ADDRESS.varname, "");
     config.put(HiveConf.ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST.varname, "");
-    RpcServer server3 = autoClose(new RpcServer(config));
+    RpcServer server3 = autoClose(new RpcServer(config, hiveConf));
     assertTrue("Host address should match the expected one",
         server3.getAddress() == InetAddress.getLocalHost().getHostName());
   }
 
   @Test
   public void testBadHello() throws Exception {
-    RpcServer server = autoClose(new RpcServer(emptyConfig));
+    RpcServer server = autoClose(new RpcServer(emptyConfig, hiveConf));
 
server.registerClient("client", "newClient", new TestDispatcher()); @@ -180,12 +184,12 @@ public void testBadHello() throws Exception { public void testServerPort() throws Exception { Map config = new HashMap(); - RpcServer server0 = new RpcServer(config); + RpcServer server0 = new RpcServer(config, hiveConf); assertTrue("Empty port range should return a random valid port: " + server0.getPort(), server0.getPort() >= 0); IOUtils.closeQuietly(server0); config.put(HiveConf.ConfVars.SPARK_RPC_SERVER_PORT.varname, "49152-49222,49223,49224-49333"); - RpcServer server1 = new RpcServer(config); + RpcServer server1 = new RpcServer(config, hiveConf); assertTrue("Port should be within configured port range:" + server1.getPort(), server1.getPort() >= 49152 && server1.getPort() <= 49333); IOUtils.closeQuietly(server1); @@ -194,7 +198,7 @@ public void testServerPort() throws Exception { for (int i = 0; i < RETRY_ACQUIRE_PORT_COUNT; i++) { try { config.put(HiveConf.ConfVars.SPARK_RPC_SERVER_PORT.varname, String.valueOf(expectedPort)); - server2 = new RpcServer(config); + server2 = new RpcServer(config, hiveConf); break; } catch (Exception e) { LOG.debug("Error while connecting to port " + expectedPort + " retrying: " + e.getMessage()); @@ -208,7 +212,7 @@ public void testServerPort() throws Exception { config.put(HiveConf.ConfVars.SPARK_RPC_SERVER_PORT.varname, "49552-49222,49223,49224-49333"); try { - autoClose(new RpcServer(config)); + autoClose(new RpcServer(config, hiveConf)); assertTrue("Invalid port range should throw an exception", false); // Should not reach here } catch(IllegalArgumentException e) { assertEquals( @@ -222,7 +226,7 @@ public void testServerPort() throws Exception { for (int i = 0; i < RETRY_ACQUIRE_PORT_COUNT; i++) { try { config.put(HiveConf.ConfVars.SPARK_RPC_SERVER_PORT.varname, String.valueOf(expectedPort) + ",21-23"); - server3 = new RpcServer(config); + server3 = new RpcServer(config, hiveConf); break; } catch (Exception e) { LOG.debug("Error while connecting to port " + expectedPort + " retrying"); @@ -236,7 +240,7 @@ public void testServerPort() throws Exception { @Test public void testCloseListener() throws Exception { - RpcServer server = autoClose(new RpcServer(emptyConfig)); + RpcServer server = autoClose(new RpcServer(emptyConfig, hiveConf)); Rpc[] rpcs = createRpcConnection(server); Rpc client = rpcs[1]; @@ -255,7 +259,7 @@ public void rpcClosed(Rpc rpc) { @Test public void testNotDeserializableRpc() throws Exception { - RpcServer server = autoClose(new RpcServer(emptyConfig)); + RpcServer server = autoClose(new RpcServer(emptyConfig, hiveConf)); Rpc[] rpcs = createRpcConnection(server); Rpc client = rpcs[1]; @@ -273,7 +277,7 @@ public void testEncryption() throws Exception { .putAll(emptyConfig) .put(RpcConfiguration.RPC_SASL_OPT_PREFIX + "qop", Rpc.SASL_AUTH_CONF) .build(); - RpcServer server = autoClose(new RpcServer(eConf)); + RpcServer server = autoClose(new RpcServer(eConf, hiveConf)); Rpc[] rpcs = createRpcConnection(server, eConf, null); Rpc client = rpcs[1]; @@ -288,7 +292,7 @@ public void testClientTimeout() throws Exception { Map conf = ImmutableMap.builder() .putAll(emptyConfig) .build(); - RpcServer server = autoClose(new RpcServer(conf)); + RpcServer server = autoClose(new RpcServer(conf, hiveConf)); String secret = server.createSecret(); try { @@ -310,9 +314,53 @@ public void testClientTimeout() throws Exception { } } + static class MockYarnApplicationStateFinder extends RpcServer.YarnApplicationStateFinder { + private int count = 0; + public boolean 
+    public boolean isApplicationAccepted(HiveConf conf, String applicationId) {
+      return count++ < 10;
+    }
+  }
+
+  /**
+   * Tests that a short client timeout is extended while the Spark application is still in
+   * ACCEPTED state, and that the client eventually times out once it no longer is.
+   */
+  @Test
+  public void testExtendClientTimeout() throws Exception {
+    Map<String, String> conf = ImmutableMap.<String, String>builder()
+        .putAll(emptyConfig)
+        .build();
+    RpcServer server = autoClose(new RpcServer(conf, hiveConf));
+    String secret = server.createSecret();
+    MockYarnApplicationStateFinder yarnApplicationStateFinder = new MockYarnApplicationStateFinder();
+    Future<Rpc> promise = server.registerClient("client", secret, new TestDispatcher(), 2L,
+        yarnApplicationStateFinder);
+    assertFalse(promise.isDone());
+    Thread.sleep(50);
+    try {
+      promise.get();
+      fail("Server should have timed out client.");
+    } catch (ExecutionException ee) {
+      assertTrue(ee.getCause() instanceof TimeoutException);
+    }
+
+    NioEventLoopGroup eloop = new NioEventLoopGroup();
+    Future<Rpc> clientRpcFuture = Rpc.createClient(conf, eloop,
+        "localhost", server.getPort(), "client", secret, new TestDispatcher());
+    try {
+      autoClose(clientRpcFuture.get());
+      fail("Client should have failed to connect to server.");
+    } catch (ExecutionException ee) {
+      // Error should not be a timeout.
+      assertFalse(ee.getCause() instanceof TimeoutException);
+    }
+  }
+
   @Test
   public void testRpcServerMultiThread() throws Exception {
-    final RpcServer server = autoClose(new RpcServer(emptyConfig));
+    final RpcServer server = autoClose(new RpcServer(emptyConfig, hiveConf));
     final String msg = "Hello World!";
     Callable<Void> callable = () -> {
       Rpc[] rpcs = createRpcConnection(server, emptyConfig, UUID.randomUUID().toString());