diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/AMRMClient.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/AMRMClient.java index 3daa156..9b83aa6 100644 --- hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/AMRMClient.java +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/AMRMClient.java @@ -18,6 +18,7 @@ package org.apache.hadoop.yarn.client.api; +import com.google.common.base.Supplier; import java.io.IOException; import java.util.Collection; import java.util.List; @@ -336,4 +337,38 @@ public NMTokenCache getNMTokenCache() { return nmTokenCache; } + /** + * Wait for check to return true for each 1000 ms. + * See also {@link #waitFor(com.google.common.base.Supplier, int)} + * and {@link #waitFor(com.google.common.base.Supplier, int, int)} + * @param check + */ + public void waitFor(Supplier check) throws InterruptedException { + waitFor(check, 1000); + } + + /** + * Wait for check to return true for each + * checkEveryMillis ms. + * See also {@link #waitFor(com.google.common.base.Supplier, int, int)} + * @param check user defined checker + * @param checkEveryMillis interval to call check + */ + public void waitFor(Supplier check, int checkEveryMillis) + throws InterruptedException, IllegalArgumentException { + waitFor(check, checkEveryMillis, 1); + }; + + /** + * Wait for check to return true for each + * checkEveryMillis ms. In the main loop, this method will log + * the message "waiting in main loop" for each logInterval times + * iteration to confirm the thread is alive. + * @param check user defined checker + * @param checkEveryMillis interval to call check + * @param logInterval interval to log for each + */ + public abstract void waitFor(Supplier check, int checkEveryMillis, + int logInterval) throws InterruptedException, IllegalArgumentException; + } diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/async/AMRMClientAsync.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/async/AMRMClientAsync.java index e726b73..6229a77 100644 --- hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/async/AMRMClientAsync.java +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/async/AMRMClientAsync.java @@ -18,9 +18,11 @@ package org.apache.hadoop.yarn.client.api.async; +import com.google.common.base.Supplier; import java.io.IOException; import java.util.Collection; import java.util.List; +import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicInteger; import org.apache.hadoop.classification.InterfaceAudience.Private; @@ -189,7 +191,42 @@ public abstract void unregisterApplicationMaster( */ public abstract int getClusterNodeCount(); - public interface CallbackHandler { + /** + * Wait for check to return true for each 1000 ms. + * See also {@link #waitFor(com.google.common.base.Supplier, int)} + * and {@link #waitFor(com.google.common.base.Supplier, int, int)} + * @param check + */ + public void waitFor(Supplier check) throws InterruptedException { + waitFor(check, 1000); + } + + /** + * Wait for check to return true for each + * checkEveryMillis ms. + * See also {@link #waitFor(com.google.common.base.Supplier, int, int)} + * @param check user defined checker + * @param checkEveryMillis interval to call check + */ + public void waitFor(Supplier check, int checkEveryMillis) + throws InterruptedException, IllegalArgumentException { + waitFor(check, checkEveryMillis, 1); + }; + + /** + * Wait for check to return true for each + * checkEveryMillis ms. In the main loop, this method will log + * the message "waiting in main loop" for each logInterval times + * iteration to confirm the thread is alive. + * @param check user defined checker + * @param checkEveryMillis interval to call check + * @param logInterval interval to log for each + */ + public abstract void waitFor(Supplier check, int checkEveryMillis, + int logInterval) throws InterruptedException, IllegalArgumentException; + + + public interface CallbackHandler { /** * Called when the ResourceManager responds to a heartbeat with completed diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/async/impl/AMRMClientAsyncImpl.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/async/impl/AMRMClientAsyncImpl.java index e7659bd..e3b7fc3 100644 --- hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/async/impl/AMRMClientAsyncImpl.java +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/async/impl/AMRMClientAsyncImpl.java @@ -18,6 +18,8 @@ package org.apache.hadoop.yarn.client.api.async.impl; +import com.google.common.base.Preconditions; +import com.google.common.base.Supplier; import java.io.IOException; import java.util.Collection; import java.util.List; @@ -315,4 +317,32 @@ public void run() { } } } + + @Override + public void waitFor(Supplier check, int checkEveryMillis, + int logInterval) throws InterruptedException, IllegalArgumentException { + Preconditions.checkNotNull(check, "check is null"); + Preconditions.checkArgument(checkEveryMillis <= 0 || logInterval <= 0, + "checkEveryMillis and logInterval should be positive value"); + + int loggingCounter = logInterval; + do { + if (LOG.isDebugEnabled()) { + LOG.debug("Check the condition for main loop."); + } + + boolean result = check.get(); + if (result) { + LOG.info("Exits the main loop."); + return; + } + if (--loggingCounter <= 0) { + LOG.info("Waiting in main loop."); + loggingCounter = logInterval; + } + + Thread.sleep(checkEveryMillis); + } while (true); + } + } diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/impl/AMRMClientImpl.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/impl/AMRMClientImpl.java index 1db7054..d14cb0d 100644 --- hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/impl/AMRMClientImpl.java +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/impl/AMRMClientImpl.java @@ -18,6 +18,7 @@ package org.apache.hadoop.yarn.client.api.impl; +import com.google.common.base.Supplier; import java.io.IOException; import java.util.ArrayList; import java.util.Collection; @@ -743,4 +744,31 @@ public synchronized void updateBlacklist(List blacklistAdditions, "blacklistRemovals in updateBlacklist."); } } + + @Override + public void waitFor(Supplier check, int checkEveryMillis, + int logInterval) throws InterruptedException, IllegalArgumentException { + Preconditions.checkNotNull(check, "check is null"); + Preconditions.checkArgument(checkEveryMillis <= 0 || logInterval <= 0, + "checkEveryMillis and logInterval should be positive value"); + + int loggingCounter = logInterval; + do { + if (LOG.isDebugEnabled()) { + LOG.debug("Check the condition for main loop."); + } + + boolean result = check.get(); + if (result) { + LOG.info("Exits the main loop."); + return; + } + if (--loggingCounter <= 0) { + LOG.info("Waiting in main loop."); + loggingCounter = logInterval; + } + + Thread.sleep(checkEveryMillis); + } while (true); + } } diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/test/java/org/apache/hadoop/yarn/client/api/async/impl/TestAMRMClientAsync.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/test/java/org/apache/hadoop/yarn/client/api/async/impl/TestAMRMClientAsync.java index 728a558..0d205b6 100644 --- hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/test/java/org/apache/hadoop/yarn/client/api/async/impl/TestAMRMClientAsync.java +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/test/java/org/apache/hadoop/yarn/client/api/async/impl/TestAMRMClientAsync.java @@ -18,6 +18,7 @@ package org.apache.hadoop.yarn.client.api.async.impl; +import com.google.common.base.Supplier; import static org.mockito.Matchers.anyFloat; import static org.mockito.Matchers.anyInt; import static org.mockito.Matchers.anyString; @@ -180,7 +181,7 @@ private void runHeartBeatThrowOutException(Exception ex) throws Exception{ AMRMClient client = mock(AMRMClientImpl.class); when(client.allocate(anyFloat())).thenThrow(ex); - AMRMClientAsync asyncClient = + AMRMClientAsync asyncClient = AMRMClientAsync.createAMRMClientAsync(client, 20, callbackHandler); asyncClient.init(conf); asyncClient.start(); @@ -204,6 +205,73 @@ private void runHeartBeatThrowOutException(Exception ex) throws Exception{ } @Test (timeout = 10000) + public void testAMRMClientAsyncReboot() throws Exception { + Configuration conf = new Configuration(); + TestCallbackHandler callbackHandler = new TestCallbackHandler(); + @SuppressWarnings("unchecked") + AMRMClient client = mock(AMRMClientImpl.class); + + final AllocateResponse rebootResponse = createAllocateResponse( + new ArrayList(), new ArrayList(), null); + rebootResponse.setAMCommand(AMCommand.AM_RESYNC); + when(client.allocate(anyFloat())).thenReturn(rebootResponse); + + AMRMClientAsync asyncClient = + AMRMClientAsync.createAMRMClientAsync(client, 20, callbackHandler); + asyncClient.init(conf); + asyncClient.start(); + + synchronized (callbackHandler.notifier) { + asyncClient.registerApplicationMaster("localhost", 1234, null); + while(callbackHandler.reboot == false) { + try { + callbackHandler.notifier.wait(); + } catch (InterruptedException e) { + e.printStackTrace(); + } + } + } + + asyncClient.stop(); + // stopping should have joined all threads and completed all callbacks + Assert.assertTrue(callbackHandler.callbackCount == 0); + } + + @Test + public void testAMRMClientAsyncRebootWithWaitFor() throws Exception { + Configuration conf = new Configuration(); + final TestCallbackHandler callbackHandler = new TestCallbackHandler(); + @SuppressWarnings("unchecked") + AMRMClient client = mock(AMRMClientImpl.class); + + final AllocateResponse rebootResponse = createAllocateResponse( + new ArrayList(), new ArrayList(), null); + rebootResponse.setAMCommand(AMCommand.AM_RESYNC); + when(client.allocate(anyFloat())).thenReturn(rebootResponse); + + AMRMClientAsync asyncClient = + AMRMClientAsync.createAMRMClientAsync(client, 20, callbackHandler); + asyncClient.init(conf); + asyncClient.start(); + + Supplier checker = new Supplier() { + @Override + public Boolean get() { + synchronized (callbackHandler) { + return callbackHandler.reboot; + } + } + }; + + asyncClient.registerApplicationMaster("localhost", 1234, null); + asyncClient.waitFor(checker); + + asyncClient.stop(); + // stopping should have joined all threads and completed all callbacks + Assert.assertTrue(callbackHandler.callbackCount == 0); + } + + @Test (timeout = 10000) public void testAMRMClientAsyncShutDown() throws Exception { Configuration conf = new Configuration(); TestCallbackHandler callbackHandler = new TestCallbackHandler(); @@ -262,6 +330,42 @@ public void testCallAMRMClientAsyncStopFromCallbackHandler() } } + @Test (timeout = 5000) + public void testCallAMRMClientAsyncStopFromCallbackHandlerWithWaitFor() + throws YarnException, IOException, InterruptedException { + Configuration conf = new Configuration(); + final TestCallbackHandler2 callbackHandler = new TestCallbackHandler2(); + @SuppressWarnings("unchecked") + AMRMClient client = mock(AMRMClientImpl.class); + + List completed = Arrays.asList( + ContainerStatus.newInstance(newContainerId(0, 0, 0, 0), + ContainerState.COMPLETE, "", 0)); + final AllocateResponse response = createAllocateResponse(completed, + new ArrayList(), null); + + when(client.allocate(anyFloat())).thenReturn(response); + + AMRMClientAsync asyncClient = + AMRMClientAsync.createAMRMClientAsync(client, 20, callbackHandler); + callbackHandler.asynClient = asyncClient; + asyncClient.init(conf); + asyncClient.start(); + + Supplier checker = new Supplier() { + @Override + public Boolean get() { + synchronized (callbackHandler) { + return callbackHandler.notify; + } + } + }; + + asyncClient.registerApplicationMaster("localhost", 1234, null); + asyncClient.waitFor(checker); + Assert.assertTrue(checker.get()); + } + void runCallBackThrowOutException(TestCallbackHandler2 callbackHandler) throws InterruptedException, YarnException, IOException { Configuration conf = new Configuration(); diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/test/java/org/apache/hadoop/yarn/client/api/impl/TestAMRMClient.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/test/java/org/apache/hadoop/yarn/client/api/impl/TestAMRMClient.java index 5961532..601fe8c 100644 --- hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/test/java/org/apache/hadoop/yarn/client/api/impl/TestAMRMClient.java +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/test/java/org/apache/hadoop/yarn/client/api/impl/TestAMRMClient.java @@ -18,6 +18,7 @@ package org.apache.hadoop.yarn.client.api.impl; +import com.google.common.base.Supplier; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -800,6 +801,40 @@ public AllocateResponse answer(InvocationOnMock invocation) assertEquals(0, amClient.ask.size()); assertEquals(0, amClient.release.size()); } + + class CountDownSupplier implements Supplier { + int counter = 0; + @Override + public Boolean get() { + counter++; + if (counter >= 3) { + return true; + } else { + return false; + } + } + }; + + @Test + public void testWaitFor() throws InterruptedException { + AMRMClientImpl amClient = null; + CountDownSupplier countDownChecker = new CountDownSupplier(); + + try { + // start am rm client + amClient = + (AMRMClientImpl) AMRMClient + . createAMRMClient(); + amClient.init(new YarnConfiguration()); + amClient.start(); + amClient.waitFor(countDownChecker, 1000); + assertEquals(3, countDownChecker.counter); + } finally { + if (amClient != null) { + amClient.stop(); + } + } + } private void sleep(int sleepTime) { try {