diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/util/Shell.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/util/Shell.java index 83877b730a..600f799fa2 100644 --- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/util/Shell.java +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/util/Shell.java @@ -27,7 +27,9 @@ import java.nio.charset.Charset; import java.util.Arrays; import java.util.Collections; +import java.util.HashSet; import java.util.Map; +import java.util.Set; import java.util.Timer; import java.util.TimerTask; import java.util.WeakHashMap; @@ -50,8 +52,8 @@ @InterfaceAudience.Public @InterfaceStability.Evolving public abstract class Shell { - private static final Map CHILD_PROCESSES = - Collections.synchronizedMap(new WeakHashMap()); + private static final Map CHILD_SHELLS = + Collections.synchronizedMap(new WeakHashMap()); public static final Logger LOG = LoggerFactory.getLogger(Shell.class); /** @@ -820,6 +822,7 @@ private static boolean isSetsidSupported() { private File dir; private Process process; // sub process used to execute the command private int exitCode; + private Thread thread; /** Flag to indicate whether or not the script has finished executing. */ private final AtomicBoolean completed = new AtomicBoolean(false); @@ -920,7 +923,9 @@ private void runCommand() throws IOException { } else { process = builder.start(); } - CHILD_PROCESSES.put(process, null); + + thread = Thread.currentThread(); + CHILD_SHELLS.put(this, null); if (timeOutInterval > 0) { timeOutTimer = new Timer("Shell command timeout"); @@ -1017,7 +1022,7 @@ public void run() { LOG.warn("Error while closing the error stream", ioe); } process.destroy(); - CHILD_PROCESSES.remove(process); + CHILD_SHELLS.remove(this); lastTime = Time.monotonicNow(); } } @@ -1065,6 +1070,14 @@ public int getExitCode() { return exitCode; } + /** get the thread that is running this instance of Shell + * @return the thread of the process + */ + public Thread getThread() { + return thread; + } + + /** * This is an IOException with exit code added. */ @@ -1319,19 +1332,38 @@ public void run() { /** * Static method to destroy all running Shell processes - * Iterates through a list of all currently running Shell - * processes and destroys them one by one. This method is thread safe and - * is intended to be used in a shutdown hook. + * Iterates through a map of all currently running Shell + * processes and destroys them one by one. This method is thread safe */ - public static void destroyAllProcesses() { - synchronized (CHILD_PROCESSES) { - for (Process key : CHILD_PROCESSES.keySet()) { - Process process = key; - if (key != null) { - process.destroy(); + public static void destroyAllShellProcesses() { + synchronized (CHILD_SHELLS) { + for (Shell shell : CHILD_SHELLS.keySet()) { + if (shell.getProcess() != null) { + shell.getProcess().destroy(); } } - CHILD_PROCESSES.clear(); + CHILD_SHELLS.clear(); + } + } + + /** + * Static method to return a Set of all Shell objects + */ + public static Set getAllShells() { + synchronized (CHILD_SHELLS) { + Set childShells = new HashSet <> + (CHILD_SHELLS.keySet()); + return childShells; + } + } + + /** + * Static method to remove a Shell object from the Set of all + * currently running Shells + */ + public static void removeShell(Shell shell) { + synchronized (CHILD_SHELLS) { + CHILD_SHELLS.remove(shell); } } } diff --git a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/util/TestShell.java b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/util/TestShell.java index 88859b55aa..4f30edb005 100644 --- a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/util/TestShell.java +++ b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/util/TestShell.java @@ -474,7 +474,7 @@ public void testBashQuote() { } @Test(timeout=120000) - public void testShellKillAllProcesses() throws Throwable { + public void testDestroyAllShellProcesses() throws Throwable { Assume.assumeFalse(WINDOWS); StringBuffer sleepCommand = new StringBuffer(); sleepCommand.append("sleep 200"); @@ -519,7 +519,7 @@ public Boolean get() { } }, 10, 10000); - Shell.destroyAllProcesses(); + Shell.destroyAllShellProcesses(); shexc1.getProcess().waitFor(); shexc2.getProcess().waitFor(); } diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/ContainerLocalizer.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/ContainerLocalizer.java index 04be6318c5..f9b12040f4 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/ContainerLocalizer.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/ContainerLocalizer.java @@ -25,9 +25,11 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.Callable; import java.util.concurrent.CancellationException; import java.util.concurrent.CompletionService; @@ -53,6 +55,7 @@ import org.apache.hadoop.security.token.TokenIdentifier; import org.apache.hadoop.util.DiskValidator; import org.apache.hadoop.util.DiskValidatorFactory; +import org.apache.hadoop.util.Shell; import org.apache.hadoop.util.concurrent.HadoopExecutors; import org.apache.hadoop.yarn.YarnUncaughtExceptionHandler; import org.apache.hadoop.yarn.api.records.LocalResource; @@ -75,6 +78,8 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.util.concurrent.ThreadFactoryBuilder; +import static org.apache.hadoop.util.Shell.*; + public class ContainerLocalizer { static final Log LOG = LogFactory.getLog(ContainerLocalizer.class); @@ -101,6 +106,8 @@ private final String appCacheDirContextName; private final DiskValidator diskValidator; + private Set localizingThreads = new HashSet<>(); + public ContainerLocalizer(FileContext lfs, String user, String appId, String localizerId, List localDirs, RecordFactory recordFactory) throws IOException { @@ -184,7 +191,9 @@ public LocalizationProtocol run() { } finally { try { if (exec != null) { - exec.shutdownNow(); + exec.shutdown(); + destroyShellProcesses(getAllShells()); + exec.awaitTermination(10, TimeUnit.SECONDS); } LocalDirAllocator.removeContext(appCacheDirContextName); } finally { @@ -202,10 +211,34 @@ ExecutorService createDownloadThreadPool() { return new ExecutorCompletionService(exec); } + public class FSDownloadWrapper extends FSDownload { + + public FSDownloadWrapper(FileContext files, UserGroupInformation ugi, Configuration conf, + Path destDirPath, LocalResource resource) { + super(files, ugi, conf, destDirPath, resource); + } + + @Override + public Path call() throws Exception { + Thread currentThread = Thread.currentThread(); + localizingThreads.add(currentThread); + try { + return doDownloadCall(); + } finally { + localizingThreads.remove(currentThread); + } + } + + public Path doDownloadCall() throws Exception { + return super.call(); + } + + } + Callable download(Path path, LocalResource rsrc, UserGroupInformation ugi) throws IOException { diskValidator.checkStatus(new File(path.toUri().getRawPath())); - return new FSDownload(lfs, ugi, conf, path, rsrc); + return new FSDownloadWrapper(lfs, ugi, conf, path, rsrc); } static long getEstimatedSize(LocalResource rsrc) { @@ -363,6 +396,7 @@ public static void buildMainArgs(List command, public static void main(String[] argv) throws Throwable { Thread.setDefaultUncaughtExceptionHandler(new YarnUncaughtExceptionHandler()); + int nRet = 0; // usage: $0 user appId locId host port app_log_dir user_dir [user_dir]* // let $x = $x/usercache for $local.dir // MKDIR $x/$user/appcache/$appid @@ -399,7 +433,9 @@ public static void main(String[] argv) throws Throwable { // space in both DefaultCE and LCE cases e.printStackTrace(System.out); LOG.error("Exception in main:", e); - System.exit(-1); + nRet = -1; + } finally { + System.exit(nRet); } } @@ -436,4 +472,13 @@ private static void createDir(FileContext lfs, Path dirPath, lfs.setPermission(dirPath, perms); } } + + private void destroyShellProcesses(Set shells) { + for (Shell shell : shells) { + if(localizingThreads.contains(shell.getThread())) { + shell.getProcess().destroy(); + removeShell(shell); + } + } + } } diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/TestContainerLocalizer.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/TestContainerLocalizer.java index fac708655f..8c3fa31683 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/TestContainerLocalizer.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/TestContainerLocalizer.java @@ -17,6 +17,8 @@ */ package org.apache.hadoop.yarn.server.nodemanager.containermanager.localizer; +import static junit.framework.TestCase.assertFalse; +import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyBoolean; @@ -25,6 +27,7 @@ import static org.mockito.Matchers.eq; import static org.mockito.Matchers.isA; import static org.mockito.Matchers.same; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; @@ -45,7 +48,10 @@ import java.util.concurrent.CompletionService; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import com.google.common.base.Supplier; +import com.google.common.util.concurrent.ThreadFactoryBuilder; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.conf.Configuration; @@ -60,6 +66,10 @@ import org.apache.hadoop.security.Credentials; import org.apache.hadoop.security.UserGroupInformation; import org.apache.hadoop.security.token.Token; +import org.apache.hadoop.test.GenericTestUtils; +import org.apache.hadoop.util.Shell; +import org.apache.hadoop.util.Shell.ShellCommandExecutor; +import org.apache.hadoop.util.concurrent.HadoopExecutors; import org.apache.hadoop.yarn.api.records.LocalResource; import org.apache.hadoop.yarn.api.records.LocalResourceType; import org.apache.hadoop.yarn.api.records.LocalResourceVisibility; @@ -76,6 +86,7 @@ import org.junit.Assert; import org.junit.Test; import org.mockito.ArgumentMatcher; +import org.mockito.Mockito; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; @@ -104,6 +115,7 @@ public void testMain() throws Exception { spylfs = spy(fs.getDefaultFileSystem()); ContainerLocalizer localizer = setupContainerLocalizerForTest(); + mockOutDownloads(localizer); // verify created cache List privCacheList = new ArrayList(); @@ -131,7 +143,7 @@ public void testMain() throws Exception { ResourceLocalizationSpec rsrcD = getMockRsrc(random, LocalResourceVisibility.PRIVATE, privCacheList.get(0)); - + when(nmProxy.heartbeat(isA(LocalizerStatus.class))) .thenReturn(new MockLocalizerHeartbeatResponse(LocalizerAction.LIVE, Collections.singletonList(rsrcA))) @@ -206,6 +218,7 @@ public void testMainFailure() throws Exception { FileContext fs = FileContext.getLocalFSFileContext(); spylfs = spy(fs.getDefaultFileSystem()); ContainerLocalizer localizer = setupContainerLocalizerForTest(); + mockOutDownloads(localizer); // Assume the NM heartbeat fails say because of absent tokens. when(nmProxy.heartbeat(isA(LocalizerStatus.class))).thenThrow( @@ -226,6 +239,7 @@ public void testLocalizerTokenIsGettingRemoved() throws Exception { FileContext fs = FileContext.getLocalFSFileContext(); spylfs = spy(fs.getDefaultFileSystem()); ContainerLocalizer localizer = setupContainerLocalizerForTest(); + mockOutDownloads(localizer); doNothing().when(localizer).localizeFiles(any(LocalizationProtocol.class), any(CompletionService.class), any(UserGroupInformation.class)); localizer.runLocalization(nmAddr); @@ -241,6 +255,7 @@ public void testContainerLocalizerClosesFilesystems() throws Exception { spylfs = spy(fs.getDefaultFileSystem()); ContainerLocalizer localizer = setupContainerLocalizerForTest(); + mockOutDownloads(localizer); doNothing().when(localizer).localizeFiles(any(LocalizationProtocol.class), any(CompletionService.class), any(UserGroupInformation.class)); verify(localizer, never()).closeFileSystems( @@ -266,8 +281,107 @@ public void testContainerLocalizerClosesFilesystems() throws Exception { } } + @Test + public void testMultipleLocalizers() throws Exception { + FakeTestContainerLocalizer testA = new FakeTestContainerLocalizer(); + FakeTestContainerLocalizer testB = new FakeTestContainerLocalizer(); + + FakeContainerLocalizer localizerA = testA.init(); + FakeContainerLocalizer localizerB = testB.init(); + + // run localization + Thread threadA = new Thread() { + @Override + public void run() { + try { + localizerA.runLocalization(nmAddr); + } catch (Exception e) { + + } + } + }; + Thread threadB = new Thread() { + @Override + public void run() { + try { + localizerB.runLocalization(nmAddr); + } catch (Exception e) { + + } + } + }; + ShellCommandExecutor shexcA = null; + ShellCommandExecutor shexcB = null; + try { + threadA.start(); + threadB.start(); + + GenericTestUtils.waitFor(new Supplier() { + @Override + public Boolean get() { + FakeContainerLocalizer.FakeLongDownload downloader = localizerA.getDownloader(); + if (downloader != null && downloader.getShexc() != null && + downloader.getShexc().getProcess() != null) { + return true; + } else { + return false; + } + } + }, 1000, 30000); + + GenericTestUtils.waitFor(new Supplier() { + @Override + public Boolean get() { + FakeContainerLocalizer.FakeLongDownload downloader = localizerB.getDownloader(); + if (downloader != null && downloader.getShexc() != null && + downloader.getShexc().getProcess() != null) { + return true; + } else { + return false; + } + } + }, 1000, 30000); + + shexcA = localizerA.getDownloader().getShexc(); + shexcB = localizerB.getDownloader().getShexc(); + + assertTrue("Localizer A process not running, but should be", + shexcA.getProcess().isAlive()); + assertTrue("Localizer B process not running, but should be", + shexcB.getProcess().isAlive()); + + // Stop heartbeat from giving anymore resources to download + testA.heartbeatResponse++; + testB.heartbeatResponse++; + + // Send DIE to localizerA. This should kill its subprocesses + testA.heartbeatResponse++; + + threadA.join(); + shexcA.getProcess().waitFor(10000, TimeUnit.MILLISECONDS); + + assertFalse("Localizer A process is still running, but shouldn't be", + shexcA.getProcess().isAlive()); + assertTrue("Localizer B process not running, but should be", + shexcB.getProcess().isAlive()); + + } finally { + // Make sure everything gets cleaned up + // Process A should already be dead + shexcA.getProcess().destroy(); + shexcB.getProcess().destroy(); + shexcA.getProcess().waitFor(10000, TimeUnit.MILLISECONDS); + shexcB.getProcess().waitFor(10000, TimeUnit.MILLISECONDS); + + threadA.join(); + // Send DIE to localizer B + testB.heartbeatResponse++; + threadB.join(); + } + } + @SuppressWarnings("unchecked") // mocked generics - private ContainerLocalizer setupContainerLocalizerForTest() + private FakeContainerLocalizer setupContainerLocalizerForTest() throws Exception { // don't actually create dirs doNothing().when(spylfs).mkdir( @@ -280,9 +394,9 @@ private ContainerLocalizer setupContainerLocalizerForTest() localDirs.add(lfs.makeQualified(new Path(basedir, i + ""))); } RecordFactory mockRF = getMockLocalizerRecordFactory(); - ContainerLocalizer concreteLoc = new ContainerLocalizer(lfs, appUser, + FakeContainerLocalizer concreteLoc = new FakeContainerLocalizer(lfs, appUser, appId, containerId, localDirs, mockRF); - ContainerLocalizer localizer = spy(concreteLoc); + FakeContainerLocalizer localizer = spy(concreteLoc); // return credential stream instead of opening local file random = new Random(); @@ -299,8 +413,11 @@ private ContainerLocalizer setupContainerLocalizerForTest() nmProxy = mock(LocalizationProtocol.class); doReturn(nmProxy).when(localizer).getProxy(nmAddr); doNothing().when(localizer).sleep(anyInt()); - + return localizer; + } + + private void mockOutDownloads(ContainerLocalizer localizer) { // return result instantly for deterministic test ExecutorService syncExec = mock(ExecutorService.class); CompletionService cs = mock(CompletionService.class); @@ -318,8 +435,6 @@ private ContainerLocalizer setupContainerLocalizerForTest() }); doReturn(syncExec).when(localizer).createDownloadThreadPool(); doReturn(cs).when(localizer).createCompletionService(syncExec); - - return localizer; } static class HBMatches extends ArgumentMatcher { @@ -363,6 +478,87 @@ public Path call() throws IOException { } } + class FakeContainerLocalizer extends ContainerLocalizer { + FakeLongDownload downloader; + + public FakeContainerLocalizer(FileContext lfs, String user, String appId, + String localizerId, List localDirs, + RecordFactory recordFactory) throws IOException { + super(lfs, user, appId, localizerId, localDirs, recordFactory); + } + + public FakeLongDownload getDownloader() { + return downloader; + } + + @Override + Callable download(Path path, LocalResource rsrc, UserGroupInformation ugi) throws IOException { + downloader = new FakeLongDownload(Mockito.mock(FileContext.class), ugi, new Configuration(), path, rsrc); + return downloader; + } + + class FakeLongDownload extends ContainerLocalizer.FSDownloadWrapper { + private final Path localPath; + Shell.ShellCommandExecutor shexc; + public FakeLongDownload(FileContext files, UserGroupInformation ugi, Configuration conf, + Path destDirPath, LocalResource resource) { + super(files, ugi, conf, destDirPath, resource); + this.localPath = new Path("file:///localcache"); + } + + public Shell.ShellCommandExecutor getShexc() { + return shexc; + } + + @Override + public Path doDownloadCall() throws IOException { + StringBuffer sleepCommand = new StringBuffer(); + sleepCommand.append("sleep 30"); + String[] shellCmd = {"bash", "-c", sleepCommand.toString()}; + shexc = new Shell.ShellCommandExecutor(shellCmd); + shexc.execute(); + + return localPath; + } + } + } + + class FakeTestContainerLocalizer extends TestContainerLocalizer { + int heartbeatResponse = 0; + public FakeContainerLocalizer init() throws Exception { + FileContext fs = FileContext.getLocalFSFileContext(); + spylfs = spy(fs.getDefaultFileSystem()); + FakeContainerLocalizer localizer = setupContainerLocalizerForTest(); + + // verify created cache + List privCacheList = new ArrayList(); + for (Path p : localDirs) { + Path base = new Path(new Path(p, ContainerLocalizer.USERCACHE), appUser); + Path privcache = new Path(base, ContainerLocalizer.FILECACHE); + privCacheList.add(privcache); + } + + ResourceLocalizationSpec rsrc = getMockRsrc(random, + LocalResourceVisibility.PRIVATE, privCacheList.get(0)); + + // mock heartbeat responses from NM + doAnswer(new Answer() { + @Override + public MockLocalizerHeartbeatResponse answer(InvocationOnMock invocationOnMock) throws Throwable { + if(heartbeatResponse == 0) { + return new MockLocalizerHeartbeatResponse(LocalizerAction.LIVE, Collections.singletonList(rsrc)); + } else if (heartbeatResponse < 2) { + return new MockLocalizerHeartbeatResponse(LocalizerAction.LIVE, Collections.emptyList()); + } else { + return new MockLocalizerHeartbeatResponse(LocalizerAction.DIE, null); + } + } + }).when(nmProxy).heartbeat(isA(LocalizerStatus.class)); + + return localizer; + } + } + static RecordFactory getMockLocalizerRecordFactory() { RecordFactory mockRF = mock(RecordFactory.class); when(mockRF.newRecordInstance(same(LocalResourceStatus.class)))