diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/AMRMClient.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/AMRMClient.java
index 3daa156..1389f59 100644
--- hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/AMRMClient.java
+++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/AMRMClient.java
@@ -18,6 +18,7 @@
package org.apache.hadoop.yarn.client.api;
+import com.google.common.base.Supplier;
import java.io.IOException;
import java.util.Collection;
import java.util.List;
@@ -336,4 +337,22 @@ public NMTokenCache getNMTokenCache() {
return nmTokenCache;
}
+
+ /**
+ * Wait for check returns true.
+ * See also {@link #waitFor(com.google.common.base.Supplier, int)}
+ * @param check
+ */
+ public void waitFor(Supplier check) throws InterruptedException {
+ waitFor(check, 1000);
+ }
+
+ /**
+ * Wait for check returns true.
+ * @param check user defined checker
+ * @param checkEveryMillis interval to call check
+ */
+ public abstract void waitFor(Supplier check, int checkEveryMillis)
+ throws InterruptedException;
+
}
diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/async/AMRMClientAsync.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/async/AMRMClientAsync.java
index e726b73..60862d6 100644
--- hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/async/AMRMClientAsync.java
+++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/async/AMRMClientAsync.java
@@ -18,9 +18,11 @@
package org.apache.hadoop.yarn.client.api.async;
+import com.google.common.base.Supplier;
import java.io.IOException;
import java.util.Collection;
import java.util.List;
+import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.hadoop.classification.InterfaceAudience.Private;
@@ -189,6 +191,23 @@ public abstract void unregisterApplicationMaster(
*/
public abstract int getClusterNodeCount();
+ /**
+ * Wait for check returns true.
+ * See also {@link #waitFor(com.google.common.base.Supplier, int)}
+ * @param check
+ */
+ public void waitFor(Supplier check) throws InterruptedException {
+ waitFor(check, 1000);
+ }
+
+ /**
+ * Wait for check returns true.
+ * @param check user defined checker
+ * @param checkEveryMillis interval to call check
+ */
+ public abstract void waitFor(Supplier check, int checkEveryMillis)
+ throws InterruptedException;
+
public interface CallbackHandler {
/**
diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/async/impl/AMRMClientAsyncImpl.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/async/impl/AMRMClientAsyncImpl.java
index e7659bd..a09eede 100644
--- hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/async/impl/AMRMClientAsyncImpl.java
+++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/async/impl/AMRMClientAsyncImpl.java
@@ -18,17 +18,20 @@
package org.apache.hadoop.yarn.client.api.async.impl;
+import com.google.common.base.Supplier;
import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeoutException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.classification.InterfaceAudience.Private;
import org.apache.hadoop.classification.InterfaceStability.Unstable;
import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.util.Time;
import org.apache.hadoop.yarn.api.protocolrecords.AllocateResponse;
import org.apache.hadoop.yarn.api.protocolrecords.RegisterApplicationMasterResponse;
import org.apache.hadoop.yarn.api.records.AMCommand;
@@ -66,7 +69,7 @@
private volatile float progress;
private volatile Throwable savedException;
-
+
public AMRMClientAsyncImpl(int intervalMs, CallbackHandler callbackHandler) {
this(new AMRMClientImpl(), intervalMs, callbackHandler);
}
@@ -150,12 +153,12 @@ public RegisterApplicationMasterResponse registerApplicationMaster(
* @throws IOException
*/
public void unregisterApplicationMaster(FinalApplicationStatus appStatus,
- String appMessage, String appTrackingUrl) throws YarnException,
- IOException {
- synchronized (unregisterHeartbeatLock) {
- keepRunning = false;
- client.unregisterApplicationMaster(appStatus, appMessage, appTrackingUrl);
- }
+ String appMessage, String appTrackingUrl) throws YarnException,
+ IOException {
+ synchronized (unregisterHeartbeatLock) {
+ keepRunning = false;
+ client.unregisterApplicationMaster(appStatus, appMessage, appTrackingUrl);
+ }
}
/**
@@ -252,7 +255,35 @@ public void run() {
}
}
}
-
+
+ @Override
+ public void waitFor(Supplier check, int checkEveryMillis)
+ throws InterruptedException {
+ if (checkEveryMillis <= 0) {
+ checkEveryMillis = 10000;
+ }
+ final int loggingCounterInitValue = 60000 / checkEveryMillis;
+ int loggingCounter = loggingCounterInitValue;
+
+ do {
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Check the condition for main loop.");
+ }
+
+ boolean result = check.get();
+ if (result) {
+ LOG.info("Exits the main loop.");
+ return;
+ }
+ if (--loggingCounter <= 0) {
+ LOG.info("Waiting in main loop.");
+ loggingCounter = loggingCounterInitValue;
+ }
+
+ Thread.sleep(checkEveryMillis);
+ } while (true);
+ }
+
private class CallbackHandlerThread extends Thread {
public CallbackHandlerThread() {
super("AMRM Callback Handler Thread");
diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/impl/AMRMClientImpl.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/impl/AMRMClientImpl.java
index 1db7054..d21d177 100644
--- hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/impl/AMRMClientImpl.java
+++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/main/java/org/apache/hadoop/yarn/client/api/impl/AMRMClientImpl.java
@@ -18,6 +18,7 @@
package org.apache.hadoop.yarn.client.api.impl;
+import com.google.common.base.Supplier;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
@@ -743,4 +744,32 @@ public synchronized void updateBlacklist(List blacklistAdditions,
"blacklistRemovals in updateBlacklist.");
}
}
+
+ @Override
+ public void waitFor(Supplier check, int checkEveryMillis)
+ throws InterruptedException {
+ if (checkEveryMillis <= 0) {
+ checkEveryMillis = 10000;
+ }
+ final int loggingCounterInitValue = 60000 / checkEveryMillis;
+ int loggingCounter = loggingCounterInitValue;
+
+ do {
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Check the condition for main loop.");
+ }
+
+ boolean result = check.get();
+ if (result) {
+ LOG.info("Exits the main loop.");
+ return;
+ }
+ if (--loggingCounter <= 0) {
+ LOG.info("Waiting in main loop.");
+ loggingCounter = loggingCounterInitValue;
+ }
+
+ Thread.sleep(checkEveryMillis);
+ } while (true);
+ }
}
diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/test/java/org/apache/hadoop/yarn/client/api/async/impl/TestAMRMClientAsync.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/test/java/org/apache/hadoop/yarn/client/api/async/impl/TestAMRMClientAsync.java
index 728a558..29560f7 100644
--- hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/test/java/org/apache/hadoop/yarn/client/api/async/impl/TestAMRMClientAsync.java
+++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/test/java/org/apache/hadoop/yarn/client/api/async/impl/TestAMRMClientAsync.java
@@ -18,6 +18,7 @@
package org.apache.hadoop.yarn.client.api.async.impl;
+import com.google.common.base.Supplier;
import static org.mockito.Matchers.anyFloat;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.anyString;
@@ -180,7 +181,7 @@ private void runHeartBeatThrowOutException(Exception ex) throws Exception{
AMRMClient client = mock(AMRMClientImpl.class);
when(client.allocate(anyFloat())).thenThrow(ex);
- AMRMClientAsync asyncClient =
+ AMRMClientAsync asyncClient =
AMRMClientAsync.createAMRMClientAsync(client, 20, callbackHandler);
asyncClient.init(conf);
asyncClient.start();
@@ -203,6 +204,73 @@ private void runHeartBeatThrowOutException(Exception ex) throws Exception{
Assert.assertTrue(callbackHandler.callbackCount == 0);
}
+ @Test//(timeout=10000)
+ public void testAMRMClientAsyncReboot() throws Exception {
+ Configuration conf = new Configuration();
+ TestCallbackHandler callbackHandler = new TestCallbackHandler();
+ @SuppressWarnings("unchecked")
+ AMRMClient client = mock(AMRMClientImpl.class);
+
+ final AllocateResponse rebootResponse = createAllocateResponse(
+ new ArrayList(), new ArrayList(), null);
+ rebootResponse.setAMCommand(AMCommand.AM_RESYNC);
+ when(client.allocate(anyFloat())).thenReturn(rebootResponse);
+
+ AMRMClientAsync asyncClient =
+ AMRMClientAsync.createAMRMClientAsync(client, 20, callbackHandler);
+ asyncClient.init(conf);
+ asyncClient.start();
+
+ synchronized (callbackHandler.notifier) {
+ asyncClient.registerApplicationMaster("localhost", 1234, null);
+ while(callbackHandler.reboot == false) {
+ try {
+ callbackHandler.notifier.wait();
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
+ }
+ }
+
+ asyncClient.stop();
+ // stopping should have joined all threads and completed all callbacks
+ Assert.assertTrue(callbackHandler.callbackCount == 0);
+ }
+
+ @Test
+ public void testAMRMClientAsyncRebootWithWaitFor() throws Exception {
+ Configuration conf = new Configuration();
+ final TestCallbackHandler callbackHandler = new TestCallbackHandler();
+ @SuppressWarnings("unchecked")
+ AMRMClient client = mock(AMRMClientImpl.class);
+
+ final AllocateResponse rebootResponse = createAllocateResponse(
+ new ArrayList(), new ArrayList(), null);
+ rebootResponse.setAMCommand(AMCommand.AM_RESYNC);
+ when(client.allocate(anyFloat())).thenReturn(rebootResponse);
+
+ AMRMClientAsync asyncClient =
+ AMRMClientAsync.createAMRMClientAsync(client, 20, callbackHandler);
+ asyncClient.init(conf);
+ asyncClient.start();
+
+ Supplier checker = new Supplier() {
+ @Override
+ public Boolean get() {
+ return callbackHandler.reboot;
+ }
+ };
+
+ synchronized (callbackHandler.notifier) {
+ asyncClient.registerApplicationMaster("localhost", 1234, null);
+ asyncClient.waitFor(checker);
+ }
+
+ asyncClient.stop();
+ // stopping should have joined all threads and completed all callbacks
+ Assert.assertTrue(callbackHandler.callbackCount == 0);
+ }
+
@Test (timeout = 10000)
public void testAMRMClientAsyncShutDown() throws Exception {
Configuration conf = new Configuration();
@@ -262,6 +330,41 @@ public void testCallAMRMClientAsyncStopFromCallbackHandler()
}
}
+ @Test (timeout = 5000)
+ public void testCallAMRMClientAsyncStopFromCallbackHandlerWithWaitFor()
+ throws YarnException, IOException, InterruptedException {
+ Configuration conf = new Configuration();
+ final TestCallbackHandler2 callbackHandler = new TestCallbackHandler2();
+ @SuppressWarnings("unchecked")
+ AMRMClient client = mock(AMRMClientImpl.class);
+
+ List completed = Arrays.asList(
+ ContainerStatus.newInstance(newContainerId(0, 0, 0, 0),
+ ContainerState.COMPLETE, "", 0));
+ final AllocateResponse response = createAllocateResponse(completed,
+ new ArrayList(), null);
+
+ when(client.allocate(anyFloat())).thenReturn(response);
+
+ AMRMClientAsync asyncClient =
+ AMRMClientAsync.createAMRMClientAsync(client, 20, callbackHandler);
+ callbackHandler.asynClient = asyncClient;
+ asyncClient.init(conf);
+ asyncClient.start();
+
+ Supplier checker = new Supplier() {
+ @Override
+ public Boolean get() {
+ return callbackHandler.notify;
+ }
+ };
+
+ synchronized (callbackHandler.notifier) {
+ asyncClient.registerApplicationMaster("localhost", 1234, null);
+ asyncClient.waitFor(checker);
+ }
+ }
+
void runCallBackThrowOutException(TestCallbackHandler2 callbackHandler) throws
InterruptedException, YarnException, IOException {
Configuration conf = new Configuration();
diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/test/java/org/apache/hadoop/yarn/client/api/impl/TestAMRMClient.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/test/java/org/apache/hadoop/yarn/client/api/impl/TestAMRMClient.java
index 5961532..b720db8 100644
--- hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/test/java/org/apache/hadoop/yarn/client/api/impl/TestAMRMClient.java
+++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/test/java/org/apache/hadoop/yarn/client/api/impl/TestAMRMClient.java
@@ -18,6 +18,7 @@
package org.apache.hadoop.yarn.client.api.impl;
+import com.google.common.base.Supplier;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
@@ -800,6 +801,37 @@ public AllocateResponse answer(InvocationOnMock invocation)
assertEquals(0, amClient.ask.size());
assertEquals(0, amClient.release.size());
}
+
+ @Test
+ public void testWaitFor() throws InterruptedException {
+ AMRMClientImpl amClient = null;
+ Supplier countDownChecker = new Supplier() {
+ int counter = 0;
+ @Override
+ public Boolean get() {
+ counter++;
+ if (counter > 3) {
+ return true;
+ } else {
+ return false;
+ }
+ }
+ };
+
+ try {
+ // start am rm client
+ amClient =
+ (AMRMClientImpl) AMRMClient
+ . createAMRMClient();
+ amClient.init(new YarnConfiguration());
+ amClient.start();
+ amClient.waitFor(countDownChecker, 1000);
+ } finally {
+ if (amClient != null) {
+ amClient.stop();
+ }
+ }
+ }
private void sleep(int sleepTime) {
try {