diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/AMRMClient.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/AMRMClient.java index 23f4ea1..90781a8 100644 --- hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/AMRMClient.java +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/AMRMClient.java @@ -18,6 +18,7 @@ package org.apache.hadoop.yarn.client.api; +import com.google.common.base.Supplier; import java.io.IOException; import java.util.Collection; import java.util.List; @@ -329,4 +330,22 @@ public NMTokenCache getNMTokenCache() { return nmTokenCache; } -} + + /** + * Wait for check returns true. + * See also {@link #waitFor(com.google.common.base.Supplier, int)} + * @param check + */ + public void waitFor(Supplier check) throws InterruptedException { + waitFor(check, 1000); + } + + /** + * Wait for check returns true. + * @param check user defined checker + * @param checkEveryMillis interval to call check + */ + public abstract void waitFor(Supplier check, int checkEveryMillis) + throws InterruptedException; + +} \ No newline at end of file diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/async/AMRMClientAsync.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/async/AMRMClientAsync.java index e726b73..60862d6 100644 --- hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/async/AMRMClientAsync.java +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/async/AMRMClientAsync.java @@ -18,9 +18,11 @@ package org.apache.hadoop.yarn.client.api.async; +import com.google.common.base.Supplier; import java.io.IOException; import java.util.Collection; import java.util.List; +import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicInteger; import org.apache.hadoop.classification.InterfaceAudience.Private; @@ -189,6 +191,23 @@ public abstract void unregisterApplicationMaster( */ public abstract int getClusterNodeCount(); + /** + * Wait for check returns true. + * See also {@link #waitFor(com.google.common.base.Supplier, int)} + * @param check + */ + public void waitFor(Supplier check) throws InterruptedException { + waitFor(check, 1000); + } + + /** + * Wait for check returns true. + * @param check user defined checker + * @param checkEveryMillis interval to call check + */ + public abstract void waitFor(Supplier check, int checkEveryMillis) + throws InterruptedException; + public interface CallbackHandler { /** diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/async/impl/AMRMClientAsyncImpl.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/async/impl/AMRMClientAsyncImpl.java index 57acb2c..6272d47 100644 --- hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/async/impl/AMRMClientAsyncImpl.java +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/async/impl/AMRMClientAsyncImpl.java @@ -18,17 +18,20 @@ package org.apache.hadoop.yarn.client.api.async.impl; +import com.google.common.base.Supplier; import java.io.IOException; import java.util.Collection; import java.util.List; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeoutException; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.classification.InterfaceAudience.Private; import org.apache.hadoop.classification.InterfaceStability.Unstable; import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.util.Time; import org.apache.hadoop.yarn.api.protocolrecords.AllocateResponse; import org.apache.hadoop.yarn.api.protocolrecords.RegisterApplicationMasterResponse; import org.apache.hadoop.yarn.api.records.AMCommand; @@ -66,7 +69,7 @@ private volatile float progress; private volatile Throwable savedException; - + public AMRMClientAsyncImpl(int intervalMs, CallbackHandler callbackHandler) { this(new AMRMClientImpl(), intervalMs, callbackHandler); } @@ -150,12 +153,12 @@ public RegisterApplicationMasterResponse registerApplicationMaster( * @throws IOException */ public void unregisterApplicationMaster(FinalApplicationStatus appStatus, - String appMessage, String appTrackingUrl) throws YarnException, - IOException { - synchronized (unregisterHeartbeatLock) { - keepRunning = false; - client.unregisterApplicationMaster(appStatus, appMessage, appTrackingUrl); - } + String appMessage, String appTrackingUrl) throws YarnException, + IOException { + synchronized (unregisterHeartbeatLock) { + keepRunning = false; + client.unregisterApplicationMaster(appStatus, appMessage, appTrackingUrl); + } } /** @@ -253,7 +256,25 @@ public void run() { } } } - + + @Override + public void waitFor(Supplier check, int checkEveryMillis) + throws InterruptedException { + do { + if (LOG.isDebugEnabled()) { + LOG.debug("Check the condition for main loop."); + } + + boolean result = check.get(); + if (result) { + LOG.info("Exits the main loop."); + return; + } + + Thread.sleep(checkEveryMillis); + } while (true); + } + private class CallbackHandlerThread extends Thread { public CallbackHandlerThread() { super("AMRM Callback Handler Thread"); diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/impl/AMRMClientImpl.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/impl/AMRMClientImpl.java index 1eebaac..893ee17 100644 --- hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/impl/AMRMClientImpl.java +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/impl/AMRMClientImpl.java @@ -18,6 +18,7 @@ package org.apache.hadoop.yarn.client.api.impl; +import com.google.common.base.Supplier; import java.io.IOException; import java.util.ArrayList; import java.util.Collection; @@ -675,4 +676,22 @@ public synchronized void updateBlacklist(List blacklistAdditions, "blacklistRemovals in updateBlacklist."); } } + + @Override + public void waitFor(Supplier check, int checkEveryMillis) + throws InterruptedException { + do { + if (LOG.isDebugEnabled()) { + LOG.debug("Check the condition for main loop."); + } + + boolean result = check.get(); + if (result) { + LOG.info("Exits the main loop."); + return; + } + + Thread.sleep(checkEveryMillis); + } while (true); + } } diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/test/java/org/apache/hadoop/yarn/client/api/async/impl/TestAMRMClientAsync.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/test/java/org/apache/hadoop/yarn/client/api/async/impl/TestAMRMClientAsync.java index e21c4ba..228dafd 100644 --- hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/test/java/org/apache/hadoop/yarn/client/api/async/impl/TestAMRMClientAsync.java +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/test/java/org/apache/hadoop/yarn/client/api/async/impl/TestAMRMClientAsync.java @@ -18,6 +18,7 @@ package org.apache.hadoop.yarn.client.api.async.impl; +import com.google.common.base.Supplier; import static org.mockito.Matchers.anyFloat; import static org.mockito.Matchers.anyInt; import static org.mockito.Matchers.anyString; @@ -235,6 +236,40 @@ public void testAMRMClientAsyncReboot() throws Exception { // stopping should have joined all threads and completed all callbacks Assert.assertTrue(callbackHandler.callbackCount == 0); } + + @Test + public void testAMRMClientAsyncRebootWithWaitFor() throws Exception { + Configuration conf = new Configuration(); + final TestCallbackHandler callbackHandler = new TestCallbackHandler(); + @SuppressWarnings("unchecked") + AMRMClient client = mock(AMRMClientImpl.class); + + final AllocateResponse rebootResponse = createAllocateResponse( + new ArrayList(), new ArrayList(), null); + rebootResponse.setAMCommand(AMCommand.AM_RESYNC); + when(client.allocate(anyFloat())).thenReturn(rebootResponse); + + AMRMClientAsync asyncClient = + AMRMClientAsync.createAMRMClientAsync(client, 20, callbackHandler); + asyncClient.init(conf); + asyncClient.start(); + + Supplier checker = new Supplier() { + @Override + public Boolean get() { + return callbackHandler.reboot; + } + }; + + synchronized (callbackHandler.notifier) { + asyncClient.registerApplicationMaster("localhost", 1234, null); + asyncClient.waitFor(checker); + } + + asyncClient.stop(); + // stopping should have joined all threads and completed all callbacks + Assert.assertTrue(callbackHandler.callbackCount == 0); + } @Test (timeout = 10000) public void testAMRMClientAsyncShutDown() throws Exception { @@ -295,6 +330,41 @@ public void testCallAMRMClientAsyncStopFromCallbackHandler() } } + @Test (timeout = 5000) + public void testCallAMRMClientAsyncStopFromCallbackHandlerWithWaitFor() + throws YarnException, IOException, InterruptedException { + Configuration conf = new Configuration(); + final TestCallbackHandler2 callbackHandler = new TestCallbackHandler2(); + @SuppressWarnings("unchecked") + AMRMClient client = mock(AMRMClientImpl.class); + + List completed = Arrays.asList( + ContainerStatus.newInstance(newContainerId(0, 0, 0, 0), + ContainerState.COMPLETE, "", 0)); + final AllocateResponse response = createAllocateResponse(completed, + new ArrayList(), null); + + when(client.allocate(anyFloat())).thenReturn(response); + + AMRMClientAsync asyncClient = + AMRMClientAsync.createAMRMClientAsync(client, 20, callbackHandler); + callbackHandler.asynClient = asyncClient; + asyncClient.init(conf); + asyncClient.start(); + + Supplier checker = new Supplier() { + @Override + public Boolean get() { + return callbackHandler.notify; + } + }; + + synchronized (callbackHandler.notifier) { + asyncClient.registerApplicationMaster("localhost", 1234, null); + asyncClient.waitFor(checker); + } + } + void runCallBackThrowOutException(TestCallbackHandler2 callbackHandler) throws InterruptedException, YarnException, IOException { Configuration conf = new Configuration(); diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/test/java/org/apache/hadoop/yarn/client/api/impl/TestAMRMClient.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/test/java/org/apache/hadoop/yarn/client/api/impl/TestAMRMClient.java index 5961532..b720db8 100644 --- hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/test/java/org/apache/hadoop/yarn/client/api/impl/TestAMRMClient.java +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/test/java/org/apache/hadoop/yarn/client/api/impl/TestAMRMClient.java @@ -18,6 +18,7 @@ package org.apache.hadoop.yarn.client.api.impl; +import com.google.common.base.Supplier; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -800,6 +801,37 @@ public AllocateResponse answer(InvocationOnMock invocation) assertEquals(0, amClient.ask.size()); assertEquals(0, amClient.release.size()); } + + @Test + public void testWaitFor() throws InterruptedException { + AMRMClientImpl amClient = null; + Supplier countDownChecker = new Supplier() { + int counter = 0; + @Override + public Boolean get() { + counter++; + if (counter > 3) { + return true; + } else { + return false; + } + } + }; + + try { + // start am rm client + amClient = + (AMRMClientImpl) AMRMClient + . createAMRMClient(); + amClient.init(new YarnConfiguration()); + amClient.start(); + amClient.waitFor(countDownChecker, 1000); + } finally { + if (amClient != null) { + amClient.stop(); + } + } + } private void sleep(int sleepTime) { try {