diff --git a/spark-client/src/main/java/org/apache/hive/spark/client/RemoteDriver.java b/spark-client/src/main/java/org/apache/hive/spark/client/RemoteDriver.java index ede8ce9e40..68f91ea80d 100644 --- a/spark-client/src/main/java/org/apache/hive/spark/client/RemoteDriver.java +++ b/spark-client/src/main/java/org/apache/hive/spark/client/RemoteDriver.java @@ -88,7 +88,7 @@ private volatile JobContextImpl jc; private volatile boolean running; - private RemoteDriver(String[] args) throws Exception { + public RemoteDriver(String[] args) throws Exception { this.activeJobs = Maps.newConcurrentMap(); this.jcLock = new Object(); this.shutdownLock = new Object(); @@ -176,7 +176,7 @@ public void rpcClosed(Rpc rpc) { } } - private void run() throws InterruptedException { + public void run() throws InterruptedException { synchronized (shutdownLock) { while (running) { shutdownLock.wait(); @@ -201,7 +201,7 @@ private void submit(JobWrapper job) { } } - private synchronized void shutdown(Throwable error) { + public synchronized void shutdown(Throwable error) { if (running) { if (error == null) { LOG.info("Shutting down remote driver."); diff --git a/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientImpl.java b/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientImpl.java index f6a23dc600..2144479a96 100644 --- a/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientImpl.java +++ b/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientImpl.java @@ -90,6 +90,9 @@ private final ClientProtocol protocol; private volatile boolean isAlive; + // Purely for testing purposes + private RemoteDriver remoteDriver; + SparkClientImpl(RpcServer rpcServer, Map conf, HiveConf hiveConf) throws IOException, SparkException { this.conf = conf; this.hiveConf = hiveConf; @@ -177,6 +180,10 @@ public void stop() { LOG.warn("Timed out shutting down remote driver, interrupting..."); driverThread.interrupt(); } + + if (remoteDriver != null) { + remoteDriver.shutdown(null); + } } @Override @@ -236,7 +243,8 @@ public void run() { args.add(String.format("%s=%s", e.getKey(), conf.get(e.getKey()))); } try { - RemoteDriver.main(args.toArray(new String[args.size()])); + remoteDriver = new RemoteDriver(args.toArray(new String[args.size()])); + remoteDriver.run(); } catch (Exception e) { LOG.error("Error running driver.", e); }