diff --git spark-client/src/main/java/org/apache/hive/spark/client/SparkClientImpl.java spark-client/src/main/java/org/apache/hive/spark/client/SparkClientImpl.java index 863aaa8..016ddb7 100644 --- spark-client/src/main/java/org/apache/hive/spark/client/SparkClientImpl.java +++ spark-client/src/main/java/org/apache/hive/spark/client/SparkClientImpl.java @@ -518,45 +518,46 @@ private void redirect(String name, Redirector redirector) { JobHandleImpl submit(Job job, List> listeners) { final String jobId = UUID.randomUUID().toString(); - final Promise promise = driverRpc.createPromise(); + final Promise jobHandlePromise = driverRpc.createPromise(); + final Promise rpcCallPromise = driverRpc.createPromise(); final JobHandleImpl handle = - new JobHandleImpl(SparkClientImpl.this, promise, jobId, listeners); + new JobHandleImpl(SparkClientImpl.this, jobHandlePromise, jobId, listeners); jobs.put(jobId, handle); - final io.netty.util.concurrent.Future rpc = driverRpc.call(new JobRequest(jobId, job)); - LOG.debug("Send JobRequest[{}].", jobId); - - // Link the RPC and the promise so that events from one are propagated to the other as + // Listener used to link the RPC and the promise so that events from one are propagated to the other as // needed. - rpc.addListener(new GenericFutureListener>() { + rpcCallPromise.addListener(new GenericFutureListener>() { @Override public void operationComplete(io.netty.util.concurrent.Future f) { if (f.isSuccess()) { handle.changeState(JobHandle.State.QUEUED); - } else if (!promise.isDone()) { - promise.setFailure(f.cause()); + } else if (!jobHandlePromise.isDone()) { + jobHandlePromise.setFailure(f.cause()); } } }); - promise.addListener(new GenericFutureListener>() { + + jobHandlePromise.addListener(new GenericFutureListener>() { @Override public void operationComplete(Promise p) { if (jobId != null) { jobs.remove(jobId); } - if (p.isCancelled() && !rpc.isDone()) { - rpc.cancel(true); + if (p.isCancelled() && !rpcCallPromise.isDone()) { + rpcCallPromise.cancel(true); } } }); - return handle; + + driverRpc.call(new JobRequest<>(jobId, job), rpcCallPromise); + LOG.debug("Send JobRequest[{}].", jobId); +return handle; } Future run(Job job) { - @SuppressWarnings("unchecked") - final io.netty.util.concurrent.Future rpc = (io.netty.util.concurrent.Future) - driverRpc.call(new SyncJobRequest(job), Serializable.class); - return rpc; + final io.netty.util.concurrent.Promise promise = driverRpc.createPromise(); + driverRpc.call(new SyncJobRequest<>(job), Serializable.class, (Promise) promise); + return promise; } void cancel(String jobId) { @@ -564,7 +565,9 @@ void cancel(String jobId) { } Future endSession() { - return driverRpc.call(new EndSession()); + Promise promise = driverRpc.createPromise(); + driverRpc.call(new EndSession(), promise); + return promise; } private void handle(ChannelHandlerContext ctx, Error msg) { diff --git spark-client/src/main/java/org/apache/hive/spark/client/rpc/Rpc.java spark-client/src/main/java/org/apache/hive/spark/client/rpc/Rpc.java index b2f133b..706dc86 100644 --- spark-client/src/main/java/org/apache/hive/spark/client/rpc/Rpc.java +++ spark-client/src/main/java/org/apache/hive/spark/client/rpc/Rpc.java @@ -252,11 +252,20 @@ public void addListener(Listener l) { } /** - * Send an RPC call to the remote endpoint and returns a future that can be used to monitor the - * operation. + * Send an RPC call to the remote endpoint. + * @param msg message to send */ - public Future call(Object msg) { - return call(msg, Void.class); + public void call(Object msg) { + call(msg, Void.class, this.createPromise()); + } + + /** + * Send an RPC call to the remote endpoint. + * @param msg message to send + * @param promise future used to monitor the operation + */ + public void call(Object msg, Promise promise) { + call(msg, Void.class, promise); } public boolean isActive() { @@ -264,19 +273,17 @@ public boolean isActive() { } /** - * Send an RPC call to the remote endpoint and returns a future that can be used to monitor the - * operation. + * Send an RPC call to the remote endpoint. * * @param msg RPC call to send. * @param retType Type of expected reply. - * @return A future used to monitor the operation. + * @param promise a future used to monitor the operation. */ - public Future call(Object msg, Class retType) { + public void call(Object msg, Class retType, final Promise promise) { Preconditions.checkArgument(msg != null); Preconditions.checkState(channel.isActive(), "RPC channel is closed."); try { final long id = rpcId.getAndIncrement(); - final Promise promise = createPromise(); ChannelFutureListener listener = new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture cf) { @@ -294,7 +301,6 @@ public void operationComplete(ChannelFuture cf) { channel.write(new MessageHeader(id, Rpc.MessageType.CALL)).addListener(listener); channel.writeAndFlush(msg).addListener(listener); } - return promise; } catch (Exception e) { throw Throwables.propagate(e); } diff --git spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestRpc.java spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestRpc.java index 77c3d02..d84f09e 100644 --- spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestRpc.java +++ spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestRpc.java @@ -38,6 +38,7 @@ import io.netty.channel.nio.NioEventLoopGroup; import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.Promise; import org.apache.commons.io.IOUtils; import org.apache.hadoop.hive.conf.HiveConf; import org.slf4j.Logger; @@ -79,7 +80,8 @@ public void testRpcDispatcher() throws Exception { Rpc clientRpc = autoClose(Rpc.createEmbedded(new TestDispatcher())); TestMessage outbound = new TestMessage("Hello World!"); - Future call = clientRpc.call(outbound, TestMessage.class); + Promise call = clientRpc.createPromise(); + clientRpc.call(outbound, TestMessage.class, call); LOG.debug("Transferring messages..."); transfer(serverRpc, clientRpc); @@ -96,18 +98,22 @@ public void testClientServer() throws Exception { Rpc client = rpcs[1]; TestMessage outbound = new TestMessage("Hello World!"); - Future call = client.call(outbound, TestMessage.class); + Promise call = client.createPromise(); + client.call(outbound, TestMessage.class, call); TestMessage reply = call.get(10, TimeUnit.SECONDS); assertEquals(outbound.message, reply.message); TestMessage another = new TestMessage("Hello again!"); - Future anotherCall = client.call(another, TestMessage.class); + Promise anotherCall = client.createPromise(); + client.call(another, TestMessage.class, anotherCall); TestMessage anotherReply = anotherCall.get(10, TimeUnit.SECONDS); assertEquals(another.message, anotherReply.message); String errorMsg = "This is an error."; try { - client.call(new ErrorCall(errorMsg)).get(10, TimeUnit.SECONDS); + Promise errorCall = client.createPromise(); + client.call(new ErrorCall(errorMsg), errorCall); + errorCall.get(10, TimeUnit.SECONDS); } catch (ExecutionException ee) { assertTrue(ee.getCause() instanceof RpcException); assertTrue(ee.getCause().getMessage().indexOf(errorMsg) >= 0); @@ -115,7 +121,8 @@ public void testClientServer() throws Exception { // Test from server to client too. TestMessage serverMsg = new TestMessage("Hello from the server!"); - Future serverCall = serverRpc.call(serverMsg, TestMessage.class); + Promise serverCall = client.createPromise(); + serverRpc.call(serverMsg, TestMessage.class, serverCall); TestMessage serverReply = serverCall.get(10, TimeUnit.SECONDS); assertEquals(serverMsg.message, serverReply.message); } @@ -228,7 +235,9 @@ public void testNotDeserializableRpc() throws Exception { Rpc client = rpcs[1]; try { - client.call(new NotDeserializable(42)).get(10, TimeUnit.SECONDS); + Promise errorCall = client.createPromise(); + client.call(new NotDeserializable(42), errorCall); + errorCall.get(10, TimeUnit.SECONDS); } catch (ExecutionException ee) { assertTrue(ee.getCause() instanceof RpcException); assertTrue(ee.getCause().getMessage().indexOf("KryoException") >= 0); @@ -246,7 +255,8 @@ public void testEncryption() throws Exception { Rpc client = rpcs[1]; TestMessage outbound = new TestMessage("Hello World!"); - Future call = client.call(outbound, TestMessage.class); + Promise call = client.createPromise(); + client.call(outbound, TestMessage.class, call); TestMessage reply = call.get(10, TimeUnit.SECONDS); assertEquals(outbound.message, reply.message); }