diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/async/AMRMClientAsync.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/async/AMRMClientAsync.java index e726b73..b72b306 100644 --- hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/async/AMRMClientAsync.java +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/async/AMRMClientAsync.java @@ -18,9 +18,11 @@ package org.apache.hadoop.yarn.client.api.async; +import com.google.common.base.Supplier; import java.io.IOException; import java.util.Collection; import java.util.List; +import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicInteger; import org.apache.hadoop.classification.InterfaceAudience.Private; @@ -189,6 +191,20 @@ public abstract void unregisterApplicationMaster( */ public abstract int getClusterNodeCount(); + public void waitFor(Supplier check) + throws TimeoutException, InterruptedException { + waitFor(check, 30000, 1000); + } + + public void waitFor(Supplier check, + int waitForMillis) throws TimeoutException, InterruptedException { + waitFor(check, waitForMillis, 1000); + } + + public abstract void waitFor(Supplier check, + int waitForMillis, int checkEveryMillis) + throws TimeoutException, InterruptedException; + public interface CallbackHandler { /** diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/async/impl/AMRMClientAsyncImpl.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/async/impl/AMRMClientAsyncImpl.java index 57acb2c..4a52cc7 100644 --- hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/async/impl/AMRMClientAsyncImpl.java +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/async/impl/AMRMClientAsyncImpl.java @@ -18,17 +18,20 @@ package org.apache.hadoop.yarn.client.api.async.impl; +import com.google.common.base.Supplier; import java.io.IOException; import java.util.Collection; import java.util.List; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeoutException; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.classification.InterfaceAudience.Private; import org.apache.hadoop.classification.InterfaceStability.Unstable; import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.util.Time; import org.apache.hadoop.yarn.api.protocolrecords.AllocateResponse; import org.apache.hadoop.yarn.api.protocolrecords.RegisterApplicationMasterResponse; import org.apache.hadoop.yarn.api.records.AMCommand; @@ -66,7 +69,7 @@ private volatile float progress; private volatile Throwable savedException; - + public AMRMClientAsyncImpl(int intervalMs, CallbackHandler callbackHandler) { this(new AMRMClientImpl(), intervalMs, callbackHandler); } @@ -150,12 +153,12 @@ public RegisterApplicationMasterResponse registerApplicationMaster( * @throws IOException */ public void unregisterApplicationMaster(FinalApplicationStatus appStatus, - String appMessage, String appTrackingUrl) throws YarnException, - IOException { - synchronized (unregisterHeartbeatLock) { - keepRunning = false; - client.unregisterApplicationMaster(appStatus, appMessage, appTrackingUrl); - } + String appMessage, String appTrackingUrl) throws YarnException, + IOException { + synchronized (unregisterHeartbeatLock) { + keepRunning = false; + client.unregisterApplicationMaster(appStatus, appMessage, appTrackingUrl); + } } /** @@ -253,7 +256,28 @@ public void run() { } } } - + + public void waitFor(Supplier check, + int waitForMillis) throws TimeoutException, InterruptedException { + waitFor(check, waitForMillis, 1000); + } + + public void waitFor(Supplier check, + int waitForMillis, int checkEveryMillis) + throws TimeoutException, InterruptedException { + long st = Time.now(); + do { + boolean result = check.get(); + if (result) { + return; + } + + Thread.sleep(checkEveryMillis); + } while (Time.now() - st < waitForMillis); + + throw new TimeoutException("Timed out waiting for condition."); + } + private class CallbackHandlerThread extends Thread { public CallbackHandlerThread() { super("AMRM Callback Handler Thread"); @@ -262,6 +286,7 @@ public CallbackHandlerThread() { public void run() { while (true) { if (!keepRunning) { + handler.onShutdownRequest(); return; } try { diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/test/java/org/apache/hadoop/yarn/client/api/async/impl/TestAMRMClientAsync.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/test/java/org/apache/hadoop/yarn/client/api/async/impl/TestAMRMClientAsync.java index e21c4ba..22028cf 100644 --- hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/test/java/org/apache/hadoop/yarn/client/api/async/impl/TestAMRMClientAsync.java +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/test/java/org/apache/hadoop/yarn/client/api/async/impl/TestAMRMClientAsync.java @@ -18,6 +18,8 @@ package org.apache.hadoop.yarn.client.api.async.impl; +import com.google.common.base.Supplier; +import java.util.concurrent.TimeoutException; import static org.mockito.Matchers.anyFloat; import static org.mockito.Matchers.anyInt; import static org.mockito.Matchers.anyString; @@ -206,7 +208,7 @@ private void runHeartBeatThrowOutException(Exception ex) throws Exception{ @Test//(timeout=10000) public void testAMRMClientAsyncReboot() throws Exception { Configuration conf = new Configuration(); - TestCallbackHandler callbackHandler = new TestCallbackHandler(); + final TestCallbackHandler callbackHandler = new TestCallbackHandler(); @SuppressWarnings("unchecked") AMRMClient client = mock(AMRMClientImpl.class); @@ -219,16 +221,17 @@ public void testAMRMClientAsyncReboot() throws Exception { AMRMClientAsync.createAMRMClientAsync(client, 20, callbackHandler); asyncClient.init(conf); asyncClient.start(); - + + Supplier rebootFlagChecker = new Supplier() { + @Override + public Boolean get() { + return callbackHandler.reboot; + } + }; + synchronized (callbackHandler.notifier) { asyncClient.registerApplicationMaster("localhost", 1234, null); - while(callbackHandler.reboot == false) { - try { - callbackHandler.notifier.wait(); - } catch (InterruptedException e) { - e.printStackTrace(); - } - } + asyncClient.waitFor(rebootFlagChecker, 5000); } asyncClient.stop(); @@ -295,7 +298,7 @@ public void testCallAMRMClientAsyncStopFromCallbackHandler() } } - void runCallBackThrowOutException(TestCallbackHandler2 callbackHandler) throws + void runCallBackThrowOutException(final TestCallbackHandler2 callbackHandler) throws InterruptedException, YarnException, IOException { Configuration conf = new Configuration(); @SuppressWarnings("unchecked") @@ -315,9 +318,21 @@ void runCallBackThrowOutException(TestCallbackHandler2 callbackHandler) throws asyncClient.init(conf); asyncClient.start(); + Supplier notifyFlagChecker = new Supplier() { + @Override + public Boolean get() { + return callbackHandler.notify; + } + }; + // call register and wait for error callback and stop synchronized (callbackHandler.notifier) { asyncClient.registerApplicationMaster("localhost", 1234, null); + try { + asyncClient.waitFor(notifyFlagChecker); + } catch (TimeoutException e) { + Assert.fail("timeout not expected."); + } while(callbackHandler.notify == false) { try { callbackHandler.notifier.wait(); @@ -354,6 +369,11 @@ public void testCallBackThrowOutExceptionNoStop() throws YarnException, runCallBackThrowOutException(callbackHandler); } + @Test (timeout = 5000) + public void testWaitForUnregister() { + + } + private AllocateResponse createAllocateResponse( List completed, List allocated, List nmTokens) {