diff --git spark-client/src/main/java/org/apache/hive/spark/client/JobHandle.java spark-client/src/main/java/org/apache/hive/spark/client/JobHandle.java index 44aa255..c02c403 100644 --- spark-client/src/main/java/org/apache/hive/spark/client/JobHandle.java +++ spark-client/src/main/java/org/apache/hive/spark/client/JobHandle.java @@ -61,14 +61,6 @@ State getState(); /** - * Add a listener to the job handle. If the job's state is not SENT, a callback for the - * corresponding state will be invoked immediately. - * - * @param l The listener to add. - */ - void addListener(Listener l); - - /** * The current state of the submitted job. */ static enum State { diff --git spark-client/src/main/java/org/apache/hive/spark/client/JobHandleImpl.java spark-client/src/main/java/org/apache/hive/spark/client/JobHandleImpl.java index 17c8f40..85eeffb 100644 --- spark-client/src/main/java/org/apache/hive/spark/client/JobHandleImpl.java +++ spark-client/src/main/java/org/apache/hive/spark/client/JobHandleImpl.java @@ -24,8 +24,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; -import com.google.common.base.Throwables; -import com.google.common.collect.Lists; +import com.google.common.collect.ImmutableList; import io.netty.util.concurrent.Promise; import org.apache.hive.spark.counter.SparkCounters; @@ -40,19 +39,26 @@ private final MetricsCollection metrics; private final Promise promise; private final List sparkJobIds; - private final List listeners; + private final List> listeners; private volatile State state; private volatile SparkCounters sparkCounters; - JobHandleImpl(SparkClientImpl client, Promise promise, String jobId) { + JobHandleImpl(SparkClientImpl client, Promise promise, String jobId, + List> listeners) { this.client = client; this.jobId = jobId; this.promise = promise; - this.listeners = Lists.newLinkedList(); + this.listeners = ImmutableList.copyOf(listeners); this.metrics = new MetricsCollection(); this.sparkJobIds = new 
CopyOnWriteArrayList(); this.state = State.SENT; this.sparkCounters = null; + + synchronized (this.listeners) { + for (Listener l : listeners) { + initializeListener(l); + } + } } /** Requests a running job to be cancelled. */ @@ -122,29 +128,6 @@ public State getState() { return state; } - @Override - public void addListener(Listener l) { - synchronized (listeners) { - listeners.add(l); - // If current state is a final state, notify of Spark job IDs before notifying about the - // state transition. - if (state.ordinal() >= State.CANCELLED.ordinal()) { - for (Integer i : sparkJobIds) { - l.onSparkJobStarted(this, i); - } - } - - fireStateChange(state, l); - - // Otherwise, notify about Spark jobs after the state notification. - if (state.ordinal() < State.CANCELLED.ordinal()) { - for (Integer i : sparkJobIds) { - l.onSparkJobStarted(this, i); - } - } - } - } - public void setSparkCounters(SparkCounters sparkCounters) { this.sparkCounters = sparkCounters; } @@ -179,7 +162,7 @@ boolean changeState(State newState) { synchronized (listeners) { if (newState.ordinal() > state.ordinal() && state.ordinal() < State.CANCELLED.ordinal()) { state = newState; - for (Listener l : listeners) { + for (Listener l : listeners) { fireStateChange(newState, l); } return true; @@ -191,13 +174,32 @@ boolean changeState(State newState) { void addSparkJobId(int sparkJobId) { synchronized (listeners) { sparkJobIds.add(sparkJobId); - for (Listener l : listeners) { + for (Listener l : listeners) { l.onSparkJobStarted(this, sparkJobId); } } } - private void fireStateChange(State s, Listener l) { + private void initializeListener(Listener l) { + // If current state is a final state, notify of Spark job IDs before notifying about the + // state transition. + if (state.ordinal() >= State.CANCELLED.ordinal()) { + for (Integer i : sparkJobIds) { + l.onSparkJobStarted(this, i); + } + } + + fireStateChange(state, l); + + // Otherwise, notify about Spark jobs after the state notification. 
+ if (state.ordinal() < State.CANCELLED.ordinal()) { + for (Integer i : sparkJobIds) { + l.onSparkJobStarted(this, i); + } + } + } + + private void fireStateChange(State s, Listener l) { switch (s) { case SENT: break; diff --git spark-client/src/main/java/org/apache/hive/spark/client/SparkClient.java spark-client/src/main/java/org/apache/hive/spark/client/SparkClient.java index 3e921a5..e952f27 100644 --- spark-client/src/main/java/org/apache/hive/spark/client/SparkClient.java +++ spark-client/src/main/java/org/apache/hive/spark/client/SparkClient.java @@ -19,6 +19,7 @@ import java.io.Serializable; import java.net.URI; +import java.util.List; import java.util.concurrent.Future; import org.apache.hadoop.hive.common.classification.InterfaceAudience; @@ -38,6 +39,15 @@ JobHandle submit(Job job); /** + * Submits a job for asynchronous execution. + * + * @param job The job to execute. + * @param listeners JobHandle listeners to invoke during the job processing. + * @return A handle that can be used to monitor the job. + */ + JobHandle submit(Job job, List> listeners); + + /** * Asks the remote context to run a job immediately. *

* Normally, the remote context will queue jobs and execute them based on how many worker diff --git spark-client/src/main/java/org/apache/hive/spark/client/SparkClientImpl.java spark-client/src/main/java/org/apache/hive/spark/client/SparkClientImpl.java index e2a30a7..863aaa8 100644 --- spark-client/src/main/java/org/apache/hive/spark/client/SparkClientImpl.java +++ spark-client/src/main/java/org/apache/hive/spark/client/SparkClientImpl.java @@ -45,6 +45,7 @@ import java.net.URI; import java.net.URL; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Properties; @@ -137,7 +138,12 @@ public void rpcClosed(Rpc rpc) { @Override public JobHandle submit(Job job) { - return protocol.submit(job); + return protocol.submit(job, Collections.>emptyList()); + } + + @Override + public JobHandle submit(Job job, List> listeners) { + return protocol.submit(job, listeners); } @Override @@ -510,10 +516,11 @@ private void redirect(String name, Redirector redirector) { private class ClientProtocol extends BaseProtocol { - JobHandleImpl submit(Job job) { + JobHandleImpl submit(Job job, List> listeners) { final String jobId = UUID.randomUUID().toString(); final Promise promise = driverRpc.createPromise(); - final JobHandleImpl handle = new JobHandleImpl(SparkClientImpl.this, promise, jobId); + final JobHandleImpl handle = + new JobHandleImpl(SparkClientImpl.this, promise, jobId, listeners); jobs.put(jobId, handle); final io.netty.util.concurrent.Future rpc = driverRpc.call(new JobRequest(jobId, job)); diff --git spark-client/src/test/java/org/apache/hive/spark/client/TestJobHandle.java spark-client/src/test/java/org/apache/hive/spark/client/TestJobHandle.java index e8f352d..d6b627b 100644 --- spark-client/src/test/java/org/apache/hive/spark/client/TestJobHandle.java +++ spark-client/src/test/java/org/apache/hive/spark/client/TestJobHandle.java @@ -19,6 +19,7 @@ import java.io.Serializable; +import 
com.google.common.collect.Lists; import io.netty.util.concurrent.Promise; import org.junit.Test; import org.junit.runner.RunWith; @@ -38,8 +39,8 @@ @Test public void testStateChanges() throws Exception { - JobHandleImpl handle = new JobHandleImpl(client, promise, "job"); - handle.addListener(listener); + JobHandleImpl handle = + new JobHandleImpl(client, promise, "job", Lists.newArrayList(listener)); assertTrue(handle.changeState(JobHandle.State.QUEUED)); verify(listener).onJobQueued(handle); @@ -60,8 +61,8 @@ public void testStateChanges() throws Exception { @Test public void testFailedJob() throws Exception { - JobHandleImpl handle = new JobHandleImpl(client, promise, "job"); - handle.addListener(listener); + JobHandleImpl handle = + new JobHandleImpl(client, promise, "job", Lists.newArrayList(listener)); Throwable cause = new Exception(); when(promise.cause()).thenReturn(cause); @@ -73,8 +74,8 @@ public void testFailedJob() throws Exception { @Test public void testSucceededJob() throws Exception { - JobHandleImpl handle = new JobHandleImpl(client, promise, "job"); - handle.addListener(listener); + JobHandleImpl handle = + new JobHandleImpl(client, promise, "job", Lists.newArrayList(listener)); Serializable result = new Exception(); when(promise.get()).thenReturn(result); @@ -86,16 +87,15 @@ public void testSucceededJob() throws Exception { @Test public void testImmediateCallback() throws Exception { - JobHandleImpl handle = new JobHandleImpl(client, promise, "job"); + JobHandleImpl handle = + new JobHandleImpl(client, promise, "job", Lists.newArrayList(listener, listener2)); assertTrue(handle.changeState(JobHandle.State.QUEUED)); - handle.addListener(listener); verify(listener).onJobQueued(handle); handle.changeState(JobHandle.State.STARTED); handle.addSparkJobId(1); handle.changeState(JobHandle.State.CANCELLED); - handle.addListener(listener2); InOrder inOrder = inOrder(listener2); inOrder.verify(listener2).onSparkJobStarted(same(handle), eq(1)); 
inOrder.verify(listener2).onJobCancelled(same(handle)); diff --git spark-client/src/test/java/org/apache/hive/spark/client/TestSparkClient.java spark-client/src/test/java/org/apache/hive/spark/client/TestSparkClient.java index b95cd7a..47eabbe 100644 --- spark-client/src/test/java/org/apache/hive/spark/client/TestSparkClient.java +++ spark-client/src/test/java/org/apache/hive/spark/client/TestSparkClient.java @@ -17,6 +17,7 @@ package org.apache.hive.spark.client; +import com.google.common.collect.Lists; import org.apache.hive.spark.client.JobHandle.Listener; import org.slf4j.Logger; @@ -27,8 +28,6 @@ import org.mockito.stubbing.Answer; -import org.mockito.stubbing.Answer; - import org.mockito.Mockito; import java.io.File; @@ -37,8 +36,10 @@ import java.io.InputStream; import java.io.Serializable; import java.net.URI; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; @@ -94,8 +95,8 @@ public void testJobSubmission() throws Exception { @Override public void call(SparkClient client) throws Exception { JobHandle.Listener listener = newListener(); - JobHandle handle = client.submit(new SimpleJob()); - handle.addListener(listener); + List> listeners = Lists.newArrayList(listener); + JobHandle handle = client.submit(new SimpleJob(), listeners); assertEquals("hello", handle.get(TIMEOUT, TimeUnit.SECONDS)); // Try an invalid state transition on the handle. 
This ensures that the actual state @@ -127,8 +128,8 @@ public void testErrorJob() throws Exception { @Override public void call(SparkClient client) throws Exception { JobHandle.Listener listener = newListener(); - JobHandle handle = client.submit(new ErrorJob()); - handle.addListener(listener); + List> listeners = Lists.newArrayList(listener); + JobHandle handle = client.submit(new ErrorJob(), listeners); try { handle.get(TIMEOUT, TimeUnit.SECONDS); fail("Should have thrown an exception."); @@ -177,8 +178,8 @@ public void testMetricsCollection() throws Exception { @Override public void call(SparkClient client) throws Exception { JobHandle.Listener listener = newListener(); - JobHandle future = client.submit(new AsyncSparkJob()); - future.addListener(listener); + List> listeners = Lists.newArrayList(listener); + JobHandle future = client.submit(new AsyncSparkJob(), listeners); future.get(TIMEOUT, TimeUnit.SECONDS); MetricsCollection metrics = future.getMetrics(); assertEquals(1, metrics.getJobIds().size()); @@ -187,8 +188,8 @@ public void call(SparkClient client) throws Exception { eq(metrics.getJobIds().iterator().next())); JobHandle.Listener listener2 = newListener(); - JobHandle future2 = client.submit(new AsyncSparkJob()); - future2.addListener(listener2); + List> listeners2 = Lists.newArrayList(listener2); + JobHandle future2 = client.submit(new AsyncSparkJob(), listeners2); future2.get(TIMEOUT, TimeUnit.SECONDS); MetricsCollection metrics2 = future2.getMetrics(); assertEquals(1, metrics2.getJobIds().size());