diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java index 7dd5ce3..06aa348 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java @@ -2757,6 +2757,11 @@ public static boolean areNodeLabelsEnabled( public static final String TIMELINE_XFS_OPTIONS = TIMELINE_XFS_PREFIX + "xframe-options"; + /** Number of threads to use for private localization fetching. */ + public static final String NM_PRIVATE_LOCALIZER_FETCH_THREAD_COUNT = NM_PREFIX + + "localizer.private.fetch.thread-count"; + public static final int DEFAULT_NM_PRIVATE_LOCALIZER_FETCH_THREAD_COUNT = 1; + public YarnConfiguration() { super(); } diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-common/src/main/resources/yarn-default.xml b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-common/src/main/resources/yarn-default.xml index 3c30ed3..4566eae 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-common/src/main/resources/yarn-default.xml +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-common/src/main/resources/yarn-default.xml @@ -2999,4 +2999,12 @@ 3000 + + + Number of threads to use for private localization fetching + + yarn.nodemanager.localizer.private.fetch.thread-count + 1 + + diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/ContainerLocalizer.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/ContainerLocalizer.java index 04be631..1724f3f 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/ContainerLocalizer.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/ContainerLocalizer.java @@ -35,6 +35,9 @@ import java.util.concurrent.ExecutorCompletionService; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; import org.apache.commons.logging.Log; @@ -53,7 +56,6 @@ import org.apache.hadoop.security.token.TokenIdentifier; import org.apache.hadoop.util.DiskValidator; import org.apache.hadoop.util.DiskValidatorFactory; -import org.apache.hadoop.util.concurrent.HadoopExecutors; import org.apache.hadoop.yarn.YarnUncaughtExceptionHandler; import org.apache.hadoop.yarn.api.records.LocalResource; import org.apache.hadoop.yarn.api.records.SerializedException; @@ -173,11 +175,11 @@ public LocalizationProtocol run() { ugi.addToken(token); } - ExecutorService exec = null; + ThreadPoolExecutor exec = null; try { exec = createDownloadThreadPool(); CompletionService ecs = createCompletionService(exec); - localizeFiles(nodeManager, ecs, ugi); + localizeFiles(nodeManager, ecs, ugi, exec); return; } catch (Throwable e) { throw new IOException(e); @@ -192,10 +194,22 @@ public LocalizationProtocol run() { } } } + + int getDownloadThreadCount() { + return conf.getInt( + YarnConfiguration.NM_PRIVATE_LOCALIZER_FETCH_THREAD_COUNT, + YarnConfiguration.DEFAULT_NM_PRIVATE_LOCALIZER_FETCH_THREAD_COUNT); + } - ExecutorService createDownloadThreadPool() { - return HadoopExecutors.newSingleThreadExecutor(new ThreadFactoryBuilder() - .setNameFormat("ContainerLocalizer Downloader").build()); + ThreadPoolExecutor createDownloadThreadPool() { + int nThreads = getDownloadThreadCount(); + ThreadFactory tf = new ThreadFactoryBuilder() + .setNameFormat("ContainerLocalizer Downloader #%d") + .build(); + return new ThreadPoolExecutor(nThreads, nThreads, + 0L, TimeUnit.MILLISECONDS, + new LinkedBlockingQueue(), + tf); } CompletionService createCompletionService(ExecutorService exec) { @@ -235,9 +249,11 @@ protected void closeFileSystems(UserGroupInformation ugi) { } protected void localizeFiles(LocalizationProtocol nodemanager, - CompletionService cs, UserGroupInformation ugi) + CompletionService cs, UserGroupInformation ugi, ThreadPoolExecutor exec) throws IOException, YarnException { + int downloadThreadCount = getDownloadThreadCount(); while (true) { + boolean newRequestReceived = false; try { LocalizerStatus status = createStatus(); LocalizerHeartbeatResponse response = nodemanager.heartbeat(status); @@ -246,6 +262,7 @@ protected void localizeFiles(LocalizationProtocol nodemanager, List newRsrcs = response.getResourceSpecs(); for (ResourceLocalizationSpec newRsrc : newRsrcs) { if (!pendingResources.containsKey(newRsrc.getResource())) { + newRequestReceived = true; pendingResources.put(newRsrc.getResource(), cs.submit(download( new Path(newRsrc.getDestinationDirectory().getFile()), newRsrc.getResource(), ugi))); @@ -269,7 +286,14 @@ protected void localizeFiles(LocalizationProtocol nodemanager, } return; } - cs.poll(1000, TimeUnit.MILLISECONDS); + if (exec.getActiveCount() >= downloadThreadCount + || !newRequestReceived) { + // Each heartbeat gives us only 1 resource to download. Don't wait + // for the first 'threadCount' heartbeats to allow parallel download. + // Subsequent downloads are also parallel because cs.poll(...) + // returns early when any download finishes before the timeout. + cs.poll(1000, TimeUnit.MILLISECONDS); + } } catch (InterruptedException e) { return; } catch (YarnException e) { diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/ResourceLocalizationService.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/ResourceLocalizationService.java index 37473e3..a51b824 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/ResourceLocalizationService.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/ResourceLocalizationService.java @@ -1110,10 +1110,9 @@ LocalizerHeartbeatResponse processHeartbeat( List rsrcs = new ArrayList(); - /* - * TODO : It doesn't support multiple downloads per ContainerLocalizer - * at the same time. We need to think whether we should support this. - */ + // Return one resource per heartbeat. + // ContainerLocalizer can run multiple heartbeats to get multiple + // resources LocalResource next = findNextResource(); if (next != null) { try { diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/TestContainerLocalizer.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/TestContainerLocalizer.java index fac7086..9b89328 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/TestContainerLocalizer.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/TestContainerLocalizer.java @@ -43,8 +43,7 @@ import java.util.Random; import java.util.concurrent.Callable; import java.util.concurrent.CompletionService; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Future; +import java.util.concurrent.ThreadPoolExecutor; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -64,6 +63,7 @@ import org.apache.hadoop.yarn.api.records.LocalResourceType; import org.apache.hadoop.yarn.api.records.LocalResourceVisibility; import org.apache.hadoop.yarn.api.records.URL; +import org.apache.hadoop.yarn.conf.YarnConfiguration; import org.apache.hadoop.yarn.exceptions.YarnException; import org.apache.hadoop.yarn.exceptions.YarnRuntimeException; import org.apache.hadoop.yarn.factories.RecordFactory; @@ -72,7 +72,6 @@ import org.apache.hadoop.yarn.server.nodemanager.api.protocolrecords.LocalResourceStatus; import org.apache.hadoop.yarn.server.nodemanager.api.protocolrecords.LocalizerAction; import org.apache.hadoop.yarn.server.nodemanager.api.protocolrecords.LocalizerStatus; -import org.apache.hadoop.yarn.util.ConverterUtils; import org.junit.Assert; import org.junit.Test; import org.mockito.ArgumentMatcher; @@ -200,6 +199,120 @@ public boolean matches(Object o) { })); } + private ContainerLocalizer setupContainerLocalizerForTest() throws Exception { + Configuration conf = new Configuration(); + return setupContainerLocalizerForTest(conf); + } + + @Test + public void testMultipleThreadDownload() throws Exception { + FileContext fs = FileContext.getLocalFSFileContext(); + spylfs = spy(fs.getDefaultFileSystem()); + Configuration conf = new Configuration(); + conf.setInt(YarnConfiguration.NM_PRIVATE_LOCALIZER_FETCH_THREAD_COUNT, 4); + ContainerLocalizer localizer = setupContainerLocalizerForTest(conf); + + // verify created cache + List privCacheList = new ArrayList(); + List appCacheList = new ArrayList(); + for (Path p : localDirs) { + Path base = new Path(new Path(p, ContainerLocalizer.USERCACHE), appUser); + Path privcache = new Path(base, ContainerLocalizer.FILECACHE); + privCacheList.add(privcache); + Path appDir = new Path(base, + new Path(ContainerLocalizer.APPCACHE, appId)); + Path appcache = new Path(appDir, ContainerLocalizer.FILECACHE); + appCacheList.add(appcache); + } + + // mock heartbeat responses from NM + ResourceLocalizationSpec rsrcA = getMockRsrc(random, + LocalResourceVisibility.PRIVATE, privCacheList.get(0)); + ResourceLocalizationSpec rsrcB = getMockRsrc(random, + LocalResourceVisibility.PRIVATE, privCacheList.get(0)); + ResourceLocalizationSpec rsrcC = getMockRsrc(random, + LocalResourceVisibility.APPLICATION, appCacheList.get(0)); + ResourceLocalizationSpec rsrcD = getMockRsrc(random, + LocalResourceVisibility.PRIVATE, privCacheList.get(0)); + ResourceLocalizationSpec rsrcE = getMockRsrc(random, + LocalResourceVisibility.APPLICATION, appCacheList.get(0)); + ResourceLocalizationSpec rsrcF = getMockRsrc(random, + LocalResourceVisibility.PRIVATE, privCacheList.get(0)); + + when(nmProxy.heartbeat(isA(LocalizerStatus.class))) + .thenReturn(new MockLocalizerHeartbeatResponse(LocalizerAction.LIVE, + Collections.singletonList(rsrcA))) + .thenReturn(new MockLocalizerHeartbeatResponse(LocalizerAction.LIVE, + Collections.singletonList(rsrcB))) + .thenReturn(new MockLocalizerHeartbeatResponse(LocalizerAction.LIVE, + Collections.singletonList(rsrcC))) + .thenReturn(new MockLocalizerHeartbeatResponse(LocalizerAction.LIVE, + Collections.singletonList(rsrcD))) + .thenReturn(new MockLocalizerHeartbeatResponse(LocalizerAction.LIVE, + Collections.singletonList(rsrcE))) + .thenReturn(new MockLocalizerHeartbeatResponse(LocalizerAction.LIVE, + Collections.singletonList(rsrcF))) + .thenReturn(new MockLocalizerHeartbeatResponse(LocalizerAction.LIVE, + Collections. emptyList())) + .thenReturn( + new MockLocalizerHeartbeatResponse(LocalizerAction.DIE, null)); + + LocalResource tRsrcA = rsrcA.getResource(); + LocalResource tRsrcB = rsrcB.getResource(); + LocalResource tRsrcC = rsrcC.getResource(); + LocalResource tRsrcD = rsrcD.getResource(); + LocalResource tRsrcE = rsrcE.getResource(); + LocalResource tRsrcF = rsrcF.getResource(); + FakeLargeDownload download1 = new FakeLargeDownload( + rsrcA.getResource().getResource().getFile(), true); + FakeLargeDownload download2 = new FakeLargeDownload( + rsrcB.getResource().getResource().getFile(), true); + FakeLargeDownload download3 = new FakeLargeDownload( + rsrcC.getResource().getResource().getFile(), true); + FakeLargeDownload download4 = new FakeLargeDownload( + rsrcD.getResource().getResource().getFile(), true); + FakeLargeDownload download5 = new FakeLargeDownload( + rsrcE.getResource().getResource().getFile(), true); + FakeLargeDownload download6 = new FakeLargeDownload( + rsrcF.getResource().getResource().getFile(), true); + doReturn(download1).when(localizer).download(isA(Path.class), eq(tRsrcA), + isA(UserGroupInformation.class)); + doReturn(download2).when(localizer).download(isA(Path.class), eq(tRsrcB), + isA(UserGroupInformation.class)); + doReturn(download3).when(localizer).download(isA(Path.class), eq(tRsrcC), + isA(UserGroupInformation.class)); + doReturn(download4).when(localizer).download(isA(Path.class), eq(tRsrcD), + isA(UserGroupInformation.class)); + doReturn(download5).when(localizer).download(isA(Path.class), eq(tRsrcE), + isA(UserGroupInformation.class)); + doReturn(download6).when(localizer).download(isA(Path.class), eq(tRsrcF), + isA(UserGroupInformation.class)); + + // run localization + localizer.runLocalization(nmAddr); + for (Path p : localDirs) { + Path base = new Path(new Path(p, ContainerLocalizer.USERCACHE), appUser); + Path privcache = new Path(base, ContainerLocalizer.FILECACHE); + // $x/usercache/$user/filecache + verify(spylfs).mkdir(eq(privcache), eq(CACHE_DIR_PERM), eq(false)); + Path appDir = new Path(base, + new Path(ContainerLocalizer.APPCACHE, appId)); + // $x/usercache/$user/appcache/$appId/filecache + Path appcache = new Path(appDir, ContainerLocalizer.FILECACHE); + verify(spylfs).mkdir(eq(appcache), eq(CACHE_DIR_PERM), eq(false)); + } + // verify tokens read at expected location + verify(spylfs).open(tokenPath); + + // 4 is thread pool size, and as each download will sleep, only 4 download + Assert.assertTrue("rsrcA unable to download", download1.called); + Assert.assertTrue("rsrcB unable to download", download2.called); + Assert.assertTrue("rsrcC unable to download", download3.called); + Assert.assertTrue("rsrcD unable to download", download4.called); + Assert.assertFalse("rsrcE must not be downloaded", download5.called); + Assert.assertFalse("rsrcF must not be downloaded", download6.called); + } + @Test(timeout = 15000) public void testMainFailure() throws Exception { @@ -227,7 +340,8 @@ public void testLocalizerTokenIsGettingRemoved() throws Exception { spylfs = spy(fs.getDefaultFileSystem()); ContainerLocalizer localizer = setupContainerLocalizerForTest(); doNothing().when(localizer).localizeFiles(any(LocalizationProtocol.class), - any(CompletionService.class), any(UserGroupInformation.class)); + any(CompletionService.class), any(UserGroupInformation.class), + any(ThreadPoolExecutor.class)); localizer.runLocalization(nmAddr); verify(spylfs, times(1)).delete(tokenPath, false); } @@ -242,7 +356,8 @@ public void testContainerLocalizerClosesFilesystems() throws Exception { ContainerLocalizer localizer = setupContainerLocalizerForTest(); doNothing().when(localizer).localizeFiles(any(LocalizationProtocol.class), - any(CompletionService.class), any(UserGroupInformation.class)); + any(CompletionService.class), any(UserGroupInformation.class), + any(ThreadPoolExecutor.class)); verify(localizer, never()).closeFileSystems( any(UserGroupInformation.class)); @@ -253,9 +368,10 @@ public void testContainerLocalizerClosesFilesystems() throws Exception { // verify filesystems are closed when localizer fails localizer = setupContainerLocalizerForTest(); - doThrow(new YarnRuntimeException("Forced Failure")).when(localizer).localizeFiles( - any(LocalizationProtocol.class), any(CompletionService.class), - any(UserGroupInformation.class)); + doThrow(new YarnRuntimeException("Forced Failure")).when(localizer) + .localizeFiles(any(LocalizationProtocol.class), + any(CompletionService.class), any(UserGroupInformation.class), + any(ThreadPoolExecutor.class)); verify(localizer, never()).closeFileSystems( any(UserGroupInformation.class)); try { @@ -267,13 +383,12 @@ public void testContainerLocalizerClosesFilesystems() throws Exception { } @SuppressWarnings("unchecked") // mocked generics - private ContainerLocalizer setupContainerLocalizerForTest() + private ContainerLocalizer setupContainerLocalizerForTest(Configuration conf) throws Exception { // don't actually create dirs doNothing().when(spylfs).mkdir( isA(Path.class), isA(FsPermission.class), anyBoolean()); - Configuration conf = new Configuration(); FileContext lfs = FileContext.getFileContext(spylfs, conf); localDirs = new ArrayList(); for (int i = 0; i < 4; ++i) { @@ -298,26 +413,11 @@ private ContainerLocalizer setupContainerLocalizerForTest() ).when(spylfs).open(tokenPath); nmProxy = mock(LocalizationProtocol.class); doReturn(nmProxy).when(localizer).getProxy(nmAddr); + doReturn( + conf.getInt(YarnConfiguration.NM_PRIVATE_LOCALIZER_FETCH_THREAD_COUNT, + YarnConfiguration.DEFAULT_NM_PRIVATE_LOCALIZER_FETCH_THREAD_COUNT)) + .when(localizer).getDownloadThreadCount(); doNothing().when(localizer).sleep(anyInt()); - - - // return result instantly for deterministic test - ExecutorService syncExec = mock(ExecutorService.class); - CompletionService cs = mock(CompletionService.class); - when(cs.submit(isA(Callable.class))) - .thenAnswer(new Answer>() { - @Override - public Future answer(InvocationOnMock invoc) - throws Throwable { - Future done = mock(Future.class); - when(done.isDone()).thenReturn(true); - FakeDownload d = (FakeDownload) invoc.getArguments()[0]; - when(done.get()).thenReturn(d.call()); - return done; - } - }); - doReturn(syncExec).when(localizer).createDownloadThreadPool(); - doReturn(cs).when(localizer).createCompletionService(syncExec); return localizer; } @@ -348,8 +448,8 @@ public boolean matches(Object o) { } static class FakeDownload implements Callable { - private final Path localPath; - private final boolean succeed; + final Path localPath; + final boolean succeed; FakeDownload(String absPath, boolean succeed) { this.localPath = new Path("file:///localcache" + absPath); this.succeed = succeed; @@ -362,6 +462,25 @@ public Path call() throws IOException { return localPath; } } + + static class FakeLargeDownload extends FakeDownload implements Callable { + public boolean called; + FakeLargeDownload(String absPath, boolean succeed) { + super(absPath, succeed); + } + @Override + public Path call() throws IOException { + called = true; + if (!succeed) { + throw new IOException("FAIL " + localPath); + } + try { + Thread.sleep(10000); + } catch (InterruptedException e) { + } + return localPath; + } + } static RecordFactory getMockLocalizerRecordFactory() { RecordFactory mockRF = mock(RecordFactory.class);