diff --git spark-client/src/main/java/org/apache/hive/spark/client/JobHandle.java spark-client/src/main/java/org/apache/hive/spark/client/JobHandle.java index 44aa255..c02c403 100644 --- spark-client/src/main/java/org/apache/hive/spark/client/JobHandle.java +++ spark-client/src/main/java/org/apache/hive/spark/client/JobHandle.java @@ -61,14 +61,6 @@ State getState(); /** - * Add a listener to the job handle. If the job's state is not SENT, a callback for the - * corresponding state will be invoked immediately. - * - * @param l The listener to add. - */ - void addListener(Listener<T> l); - - /** * The current state of the submitted job. */ static enum State { diff --git spark-client/src/main/java/org/apache/hive/spark/client/JobHandleImpl.java spark-client/src/main/java/org/apache/hive/spark/client/JobHandleImpl.java index 17c8f40..2ce5e8e 100644 --- spark-client/src/main/java/org/apache/hive/spark/client/JobHandleImpl.java +++ spark-client/src/main/java/org/apache/hive/spark/client/JobHandleImpl.java @@ -122,8 +122,7 @@ public State getState() { return state; } - @Override - public void addListener(Listener<T> l) { + void addListener(Listener<T> l) { synchronized (listeners) { listeners.add(l); // If current state is a final state, notify of Spark job IDs before notifying about the diff --git spark-client/src/main/java/org/apache/hive/spark/client/SparkClient.java spark-client/src/main/java/org/apache/hive/spark/client/SparkClient.java index 3e921a5..e952f27 100644 --- spark-client/src/main/java/org/apache/hive/spark/client/SparkClient.java +++ spark-client/src/main/java/org/apache/hive/spark/client/SparkClient.java @@ -19,6 +19,7 @@ import java.io.Serializable; import java.net.URI; +import java.util.List; import java.util.concurrent.Future; import org.apache.hadoop.hive.common.classification.InterfaceAudience; @@ -38,6 +39,15 @@ JobHandle<T> submit(Job<T> job); /** + * Submits a job for asynchronous execution. + * + * @param job The job to execute. 
+ * @param listeners JobHandle listeners to invoke during the job processing + * @return A handle that can be used to monitor the job. + */ + <T extends Serializable> JobHandle<T> submit(Job<T> job, List<JobHandle.Listener<T>> listeners); + + /** * Asks the remote context to run a job immediately. *

+ * Normally, the remote context will queue jobs and execute them based on how many worker diff --git spark-client/src/main/java/org/apache/hive/spark/client/SparkClientImpl.java spark-client/src/main/java/org/apache/hive/spark/client/SparkClientImpl.java index e2a30a7..baf14a0 100644 --- spark-client/src/main/java/org/apache/hive/spark/client/SparkClientImpl.java +++ spark-client/src/main/java/org/apache/hive/spark/client/SparkClientImpl.java @@ -45,6 +45,7 @@ import java.net.URI; import java.net.URL; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Properties; @@ -137,7 +138,12 @@ public void rpcClosed(Rpc rpc) { @Override public <T extends Serializable> JobHandle<T> submit(Job<T> job) { - return protocol.submit(job); + return protocol.submit(job, Collections.<JobHandle.Listener<T>>emptyList()); + } + + @Override + public <T extends Serializable> JobHandle<T> submit(Job<T> job, List<JobHandle.Listener<T>> listeners) { + return protocol.submit(job, listeners); } @Override @@ -510,10 +516,13 @@ private void redirect(String name, Redirector redirector) { private class ClientProtocol extends BaseProtocol { - <T extends Serializable> JobHandleImpl<T> submit(Job<T> job) { + <T extends Serializable> JobHandleImpl<T> submit(Job<T> job, List<JobHandle.Listener<T>> listeners) { final String jobId = UUID.randomUUID().toString(); final Promise<T> promise = driverRpc.createPromise(); final JobHandleImpl<T> handle = new JobHandleImpl<T>(SparkClientImpl.this, promise, jobId); + for (JobHandle.Listener<T> l : listeners) { + handle.addListener(l); + } jobs.put(jobId, handle); final io.netty.util.concurrent.Future<Void> rpc = driverRpc.call(new JobRequest<T>(jobId, job)); diff --git spark-client/src/test/java/org/apache/hive/spark/client/TestSparkClient.java spark-client/src/test/java/org/apache/hive/spark/client/TestSparkClient.java index b95cd7a..47eabbe 100644 --- spark-client/src/test/java/org/apache/hive/spark/client/TestSparkClient.java +++ spark-client/src/test/java/org/apache/hive/spark/client/TestSparkClient.java @@ -17,6 +17,7 @@ package org.apache.hive.spark.client; +import com.google.common.collect.Lists; 
import org.apache.hive.spark.client.JobHandle.Listener; import org.slf4j.Logger; @@ -27,8 +28,6 @@ import org.mockito.stubbing.Answer; -import org.mockito.stubbing.Answer; - import org.mockito.Mockito; import java.io.File; @@ -37,8 +36,10 @@ import java.io.InputStream; import java.io.Serializable; import java.net.URI; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; @@ -94,8 +95,8 @@ public void testJobSubmission() throws Exception { @Override public void call(SparkClient client) throws Exception { JobHandle.Listener<String> listener = newListener(); - JobHandle<String> handle = client.submit(new SimpleJob()); - handle.addListener(listener); + List<JobHandle.Listener<String>> listeners = Lists.newArrayList(listener); + JobHandle<String> handle = client.submit(new SimpleJob(), listeners); assertEquals("hello", handle.get(TIMEOUT, TimeUnit.SECONDS)); // Try an invalid state transition on the handle. 
This ensures that the actual state @@ -127,8 +128,8 @@ public void testErrorJob() throws Exception { @Override public void call(SparkClient client) throws Exception { JobHandle.Listener<String> listener = newListener(); - JobHandle<String> handle = client.submit(new ErrorJob()); - handle.addListener(listener); + List<JobHandle.Listener<String>> listeners = Lists.newArrayList(listener); + JobHandle<String> handle = client.submit(new ErrorJob(), listeners); try { handle.get(TIMEOUT, TimeUnit.SECONDS); fail("Should have thrown an exception."); @@ -177,8 +178,8 @@ public void testMetricsCollection() throws Exception { @Override public void call(SparkClient client) throws Exception { JobHandle.Listener<Integer> listener = newListener(); - JobHandle<Integer> future = client.submit(new AsyncSparkJob()); - future.addListener(listener); + List<JobHandle.Listener<Integer>> listeners = Lists.newArrayList(listener); + JobHandle<Integer> future = client.submit(new AsyncSparkJob(), listeners); future.get(TIMEOUT, TimeUnit.SECONDS); MetricsCollection metrics = future.getMetrics(); assertEquals(1, metrics.getJobIds().size()); @@ -187,8 +188,8 @@ public void call(SparkClient client) throws Exception { eq(metrics.getJobIds().iterator().next())); JobHandle.Listener<Integer> listener2 = newListener(); - JobHandle<Integer> future2 = client.submit(new AsyncSparkJob()); - future2.addListener(listener2); + List<JobHandle.Listener<Integer>> listeners2 = Lists.newArrayList(listener2); + JobHandle<Integer> future2 = client.submit(new AsyncSparkJob(), listeners2); future2.get(TIMEOUT, TimeUnit.SECONDS); MetricsCollection metrics2 = future2.getMetrics(); assertEquals(1, metrics2.getJobIds().size());