diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java
index 7dd5ce3..06aa348 100644
--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java
+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java
@@ -2757,6 +2757,11 @@ public static boolean areNodeLabelsEnabled(
public static final String TIMELINE_XFS_OPTIONS =
TIMELINE_XFS_PREFIX + "xframe-options";
+ /** Number of threads to use for private localization fetching. */
+ public static final String NM_PRIVATE_LOCALIZER_FETCH_THREAD_COUNT = NM_PREFIX
+ + "localizer.private.fetch.thread-count";
+ /** Default number of private localization fetch threads: a single thread. */
+ public static final int DEFAULT_NM_PRIVATE_LOCALIZER_FETCH_THREAD_COUNT = 1;
+
public YarnConfiguration() {
super();
}
diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-common/src/main/resources/yarn-default.xml b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-common/src/main/resources/yarn-default.xml
index 3c30ed3..4566eae 100644
--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-common/src/main/resources/yarn-default.xml
+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-common/src/main/resources/yarn-default.xml
@@ -2999,4 +2999,12 @@
3000
+  <property>
+    <description>
+    Number of threads to use for private localization fetching
+    </description>
+    <name>yarn.nodemanager.localizer.private.fetch.thread-count</name>
+    <value>1</value>
+  </property>
+
diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/ContainerLocalizer.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/ContainerLocalizer.java
index 04be631..1724f3f 100644
--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/ContainerLocalizer.java
+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/ContainerLocalizer.java
@@ -35,6 +35,9 @@
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.ThreadFactory;
+import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import org.apache.commons.logging.Log;
@@ -53,7 +56,6 @@
import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.util.DiskValidator;
import org.apache.hadoop.util.DiskValidatorFactory;
-import org.apache.hadoop.util.concurrent.HadoopExecutors;
import org.apache.hadoop.yarn.YarnUncaughtExceptionHandler;
import org.apache.hadoop.yarn.api.records.LocalResource;
import org.apache.hadoop.yarn.api.records.SerializedException;
@@ -173,11 +175,11 @@ public LocalizationProtocol run() {
ugi.addToken(token);
}
- ExecutorService exec = null;
+ ThreadPoolExecutor exec = null;
try {
exec = createDownloadThreadPool();
CompletionService ecs = createCompletionService(exec);
- localizeFiles(nodeManager, ecs, ugi);
+ localizeFiles(nodeManager, ecs, ugi, exec);
return;
} catch (Throwable e) {
throw new IOException(e);
@@ -192,10 +194,22 @@ public LocalizationProtocol run() {
}
}
}
+
+  /**
+   * Returns the number of threads used to download resources in parallel,
+   * read from
+   * {@link YarnConfiguration#NM_PRIVATE_LOCALIZER_FETCH_THREAD_COUNT}.
+   *
+   * Values below 1 are raised to 1. The result sizes the download
+   * {@link ThreadPoolExecutor}, whose constructor throws
+   * IllegalArgumentException for a non-positive maximum pool size; without
+   * the clamp a misconfigured value would make every localization fail.
+   *
+   * @return the configured thread count, never less than 1
+   */
+  int getDownloadThreadCount() {
+    int configured = conf.getInt(
+        YarnConfiguration.NM_PRIVATE_LOCALIZER_FETCH_THREAD_COUNT,
+        YarnConfiguration.DEFAULT_NM_PRIVATE_LOCALIZER_FETCH_THREAD_COUNT);
+    return Math.max(1, configured);
+  }
- ExecutorService createDownloadThreadPool() {
- return HadoopExecutors.newSingleThreadExecutor(new ThreadFactoryBuilder()
- .setNameFormat("ContainerLocalizer Downloader").build());
+  /**
+   * Creates the pool that runs resource downloads.
+   *
+   * The pool is fixed-size: core and maximum size are both the value of
+   * {@code getDownloadThreadCount()}. Tasks that cannot start immediately
+   * wait in a LinkedBlockingQueue created without a capacity bound.
+   * Threads are named "ContainerLocalizer Downloader #N".
+   *
+   * @return a new ThreadPoolExecutor; the concrete type is returned so
+   *         callers can query it (e.g. for its active thread count)
+   */
+  ThreadPoolExecutor createDownloadThreadPool() {
+    int nThreads = getDownloadThreadCount();
+    ThreadFactory tf = new ThreadFactoryBuilder()
+        .setNameFormat("ContainerLocalizer Downloader #%d")
+        .build();
+    // Typed queue: a raw LinkedBlockingQueue compiles only with an
+    // unchecked warning against BlockingQueue<Runnable>.
+    return new ThreadPoolExecutor(nThreads, nThreads,
+        0L, TimeUnit.MILLISECONDS,
+        new LinkedBlockingQueue<Runnable>(),
+        tf);
}
CompletionService createCompletionService(ExecutorService exec) {
@@ -235,9 +249,11 @@ protected void closeFileSystems(UserGroupInformation ugi) {
}
protected void localizeFiles(LocalizationProtocol nodemanager,
- CompletionService cs, UserGroupInformation ugi)
+ CompletionService cs, UserGroupInformation ugi, ThreadPoolExecutor exec)
throws IOException, YarnException {
+ int downloadThreadCount = getDownloadThreadCount();
while (true) {
+ boolean newRequestReceived = false;
try {
LocalizerStatus status = createStatus();
LocalizerHeartbeatResponse response = nodemanager.heartbeat(status);
@@ -246,6 +262,7 @@ protected void localizeFiles(LocalizationProtocol nodemanager,
List newRsrcs = response.getResourceSpecs();
for (ResourceLocalizationSpec newRsrc : newRsrcs) {
if (!pendingResources.containsKey(newRsrc.getResource())) {
+ newRequestReceived = true;
pendingResources.put(newRsrc.getResource(), cs.submit(download(
new Path(newRsrc.getDestinationDirectory().getFile()),
newRsrc.getResource(), ugi)));
@@ -269,7 +286,14 @@ protected void localizeFiles(LocalizationProtocol nodemanager,
}
return;
}
- cs.poll(1000, TimeUnit.MILLISECONDS);
+ if (exec.getActiveCount() >= downloadThreadCount
+ || !newRequestReceived) {
+ // Each heartbeat yields at most one resource, so we only wait here
+ // once every download thread is busy or no new resource arrived;
+ // otherwise we heartbeat again at once to fill the idle threads.
+ // cs.poll(...) returns early as soon as any download completes.
+ cs.poll(1000, TimeUnit.MILLISECONDS);
+ }
} catch (InterruptedException e) {
return;
} catch (YarnException e) {
diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/ResourceLocalizationService.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/ResourceLocalizationService.java
index 37473e3..a51b824 100644
--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/ResourceLocalizationService.java
+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/ResourceLocalizationService.java
@@ -1110,10 +1110,9 @@ LocalizerHeartbeatResponse processHeartbeat(
List rsrcs =
new ArrayList();
- /*
- * TODO : It doesn't support multiple downloads per ContainerLocalizer
- * at the same time. We need to think whether we should support this.
- */
+ // Return one resource per heartbeat.
+ // ContainerLocalizer can run multiple heartbeats to get multiple
+ // resources
LocalResource next = findNextResource();
if (next != null) {
try {
diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/TestContainerLocalizer.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/TestContainerLocalizer.java
index fac7086..9b89328 100644
--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/TestContainerLocalizer.java
+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/TestContainerLocalizer.java
@@ -43,8 +43,7 @@
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletionService;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Future;
+import java.util.concurrent.ThreadPoolExecutor;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@@ -64,6 +63,7 @@
import org.apache.hadoop.yarn.api.records.LocalResourceType;
import org.apache.hadoop.yarn.api.records.LocalResourceVisibility;
import org.apache.hadoop.yarn.api.records.URL;
+import org.apache.hadoop.yarn.conf.YarnConfiguration;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.exceptions.YarnRuntimeException;
import org.apache.hadoop.yarn.factories.RecordFactory;
@@ -72,7 +72,6 @@
import org.apache.hadoop.yarn.server.nodemanager.api.protocolrecords.LocalResourceStatus;
import org.apache.hadoop.yarn.server.nodemanager.api.protocolrecords.LocalizerAction;
import org.apache.hadoop.yarn.server.nodemanager.api.protocolrecords.LocalizerStatus;
-import org.apache.hadoop.yarn.util.ConverterUtils;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.ArgumentMatcher;
@@ -200,6 +199,120 @@ public boolean matches(Object o) {
}));
}
+  /**
+   * Convenience overload: delegates to
+   * {@code setupContainerLocalizerForTest(Configuration)} with a new,
+   * default Configuration.
+   */
+  private ContainerLocalizer setupContainerLocalizerForTest() throws Exception {
+    return setupContainerLocalizerForTest(new Configuration());
+  }
+
+ /**
+ * Runs a localizer configured with a 4-thread download pool against six
+ * resources offered one per heartbeat, each backed by a FakeLargeDownload,
+ * and checks that the first four downloads were started while the last
+ * two were not.
+ */
+ @Test
+ public void testMultipleThreadDownload() throws Exception {
+ FileContext fs = FileContext.getLocalFSFileContext();
+ spylfs = spy(fs.getDefaultFileSystem());
+ Configuration conf = new Configuration();
+ // size the download pool to 4 threads for this test
+ conf.setInt(YarnConfiguration.NM_PRIVATE_LOCALIZER_FETCH_THREAD_COUNT, 4);
+ ContainerLocalizer localizer = setupContainerLocalizerForTest(conf);
+
+ // collect the per-local-dir private and application cache paths
+ List privCacheList = new ArrayList();
+ List appCacheList = new ArrayList();
+ for (Path p : localDirs) {
+ Path base = new Path(new Path(p, ContainerLocalizer.USERCACHE), appUser);
+ Path privcache = new Path(base, ContainerLocalizer.FILECACHE);
+ privCacheList.add(privcache);
+ Path appDir = new Path(base,
+ new Path(ContainerLocalizer.APPCACHE, appId));
+ Path appcache = new Path(appDir, ContainerLocalizer.FILECACHE);
+ appCacheList.add(appcache);
+ }
+
+ // mock heartbeat responses from NM
+ ResourceLocalizationSpec rsrcA = getMockRsrc(random,
+ LocalResourceVisibility.PRIVATE, privCacheList.get(0));
+ ResourceLocalizationSpec rsrcB = getMockRsrc(random,
+ LocalResourceVisibility.PRIVATE, privCacheList.get(0));
+ ResourceLocalizationSpec rsrcC = getMockRsrc(random,
+ LocalResourceVisibility.APPLICATION, appCacheList.get(0));
+ ResourceLocalizationSpec rsrcD = getMockRsrc(random,
+ LocalResourceVisibility.PRIVATE, privCacheList.get(0));
+ ResourceLocalizationSpec rsrcE = getMockRsrc(random,
+ LocalResourceVisibility.APPLICATION, appCacheList.get(0));
+ ResourceLocalizationSpec rsrcF = getMockRsrc(random,
+ LocalResourceVisibility.PRIVATE, privCacheList.get(0));
+
+ // six LIVE heartbeats carrying one resource each, then one empty LIVE
+ // heartbeat, then DIE
+ when(nmProxy.heartbeat(isA(LocalizerStatus.class)))
+ .thenReturn(new MockLocalizerHeartbeatResponse(LocalizerAction.LIVE,
+ Collections.singletonList(rsrcA)))
+ .thenReturn(new MockLocalizerHeartbeatResponse(LocalizerAction.LIVE,
+ Collections.singletonList(rsrcB)))
+ .thenReturn(new MockLocalizerHeartbeatResponse(LocalizerAction.LIVE,
+ Collections.singletonList(rsrcC)))
+ .thenReturn(new MockLocalizerHeartbeatResponse(LocalizerAction.LIVE,
+ Collections.singletonList(rsrcD)))
+ .thenReturn(new MockLocalizerHeartbeatResponse(LocalizerAction.LIVE,
+ Collections.singletonList(rsrcE)))
+ .thenReturn(new MockLocalizerHeartbeatResponse(LocalizerAction.LIVE,
+ Collections.singletonList(rsrcF)))
+ .thenReturn(new MockLocalizerHeartbeatResponse(LocalizerAction.LIVE,
+ Collections. emptyList()))
+ .thenReturn(
+ new MockLocalizerHeartbeatResponse(LocalizerAction.DIE, null));
+
+ LocalResource tRsrcA = rsrcA.getResource();
+ LocalResource tRsrcB = rsrcB.getResource();
+ LocalResource tRsrcC = rsrcC.getResource();
+ LocalResource tRsrcD = rsrcD.getResource();
+ LocalResource tRsrcE = rsrcE.getResource();
+ LocalResource tRsrcF = rsrcF.getResource();
+ // one sleeping fake download per resource; each records whether it ran
+ FakeLargeDownload download1 = new FakeLargeDownload(
+ rsrcA.getResource().getResource().getFile(), true);
+ FakeLargeDownload download2 = new FakeLargeDownload(
+ rsrcB.getResource().getResource().getFile(), true);
+ FakeLargeDownload download3 = new FakeLargeDownload(
+ rsrcC.getResource().getResource().getFile(), true);
+ FakeLargeDownload download4 = new FakeLargeDownload(
+ rsrcD.getResource().getResource().getFile(), true);
+ FakeLargeDownload download5 = new FakeLargeDownload(
+ rsrcE.getResource().getResource().getFile(), true);
+ FakeLargeDownload download6 = new FakeLargeDownload(
+ rsrcF.getResource().getResource().getFile(), true);
+ // hand the matching fake download back when the localizer asks to
+ // download a given resource
+ doReturn(download1).when(localizer).download(isA(Path.class), eq(tRsrcA),
+ isA(UserGroupInformation.class));
+ doReturn(download2).when(localizer).download(isA(Path.class), eq(tRsrcB),
+ isA(UserGroupInformation.class));
+ doReturn(download3).when(localizer).download(isA(Path.class), eq(tRsrcC),
+ isA(UserGroupInformation.class));
+ doReturn(download4).when(localizer).download(isA(Path.class), eq(tRsrcD),
+ isA(UserGroupInformation.class));
+ doReturn(download5).when(localizer).download(isA(Path.class), eq(tRsrcE),
+ isA(UserGroupInformation.class));
+ doReturn(download6).when(localizer).download(isA(Path.class), eq(tRsrcF),
+ isA(UserGroupInformation.class));
+
+ // run localization
+ localizer.runLocalization(nmAddr);
+ for (Path p : localDirs) {
+ Path base = new Path(new Path(p, ContainerLocalizer.USERCACHE), appUser);
+ Path privcache = new Path(base, ContainerLocalizer.FILECACHE);
+ // $x/usercache/$user/filecache
+ verify(spylfs).mkdir(eq(privcache), eq(CACHE_DIR_PERM), eq(false));
+ Path appDir = new Path(base,
+ new Path(ContainerLocalizer.APPCACHE, appId));
+ // $x/usercache/$user/appcache/$appId/filecache
+ Path appcache = new Path(appDir, ContainerLocalizer.FILECACHE);
+ verify(spylfs).mkdir(eq(appcache), eq(CACHE_DIR_PERM), eq(false));
+ }
+ // verify tokens read at expected location
+ verify(spylfs).open(tokenPath);
+
+ // The pool has 4 threads and every download sleeps, so only the first 4
+ // downloads are expected to have started; the last 2 must never run.
+ Assert.assertTrue("rsrcA unable to download", download1.called);
+ Assert.assertTrue("rsrcB unable to download", download2.called);
+ Assert.assertTrue("rsrcC unable to download", download3.called);
+ Assert.assertTrue("rsrcD unable to download", download4.called);
+ Assert.assertFalse("rsrcE must not be downloaded", download5.called);
+ Assert.assertFalse("rsrcF must not be downloaded", download6.called);
+ }
+
@Test(timeout = 15000)
public void testMainFailure() throws Exception {
@@ -227,7 +340,8 @@ public void testLocalizerTokenIsGettingRemoved() throws Exception {
spylfs = spy(fs.getDefaultFileSystem());
ContainerLocalizer localizer = setupContainerLocalizerForTest();
doNothing().when(localizer).localizeFiles(any(LocalizationProtocol.class),
- any(CompletionService.class), any(UserGroupInformation.class));
+ any(CompletionService.class), any(UserGroupInformation.class),
+ any(ThreadPoolExecutor.class));
localizer.runLocalization(nmAddr);
verify(spylfs, times(1)).delete(tokenPath, false);
}
@@ -242,7 +356,8 @@ public void testContainerLocalizerClosesFilesystems() throws Exception {
ContainerLocalizer localizer = setupContainerLocalizerForTest();
doNothing().when(localizer).localizeFiles(any(LocalizationProtocol.class),
- any(CompletionService.class), any(UserGroupInformation.class));
+ any(CompletionService.class), any(UserGroupInformation.class),
+ any(ThreadPoolExecutor.class));
verify(localizer, never()).closeFileSystems(
any(UserGroupInformation.class));
@@ -253,9 +368,10 @@ public void testContainerLocalizerClosesFilesystems() throws Exception {
// verify filesystems are closed when localizer fails
localizer = setupContainerLocalizerForTest();
- doThrow(new YarnRuntimeException("Forced Failure")).when(localizer).localizeFiles(
- any(LocalizationProtocol.class), any(CompletionService.class),
- any(UserGroupInformation.class));
+ doThrow(new YarnRuntimeException("Forced Failure")).when(localizer)
+ .localizeFiles(any(LocalizationProtocol.class),
+ any(CompletionService.class), any(UserGroupInformation.class),
+ any(ThreadPoolExecutor.class));
verify(localizer, never()).closeFileSystems(
any(UserGroupInformation.class));
try {
@@ -267,13 +383,12 @@ public void testContainerLocalizerClosesFilesystems() throws Exception {
}
@SuppressWarnings("unchecked") // mocked generics
- private ContainerLocalizer setupContainerLocalizerForTest()
+ private ContainerLocalizer setupContainerLocalizerForTest(Configuration conf)
throws Exception {
// don't actually create dirs
doNothing().when(spylfs).mkdir(
isA(Path.class), isA(FsPermission.class), anyBoolean());
- Configuration conf = new Configuration();
FileContext lfs = FileContext.getFileContext(spylfs, conf);
localDirs = new ArrayList();
for (int i = 0; i < 4; ++i) {
@@ -298,26 +413,11 @@ private ContainerLocalizer setupContainerLocalizerForTest()
).when(spylfs).open(tokenPath);
nmProxy = mock(LocalizationProtocol.class);
doReturn(nmProxy).when(localizer).getProxy(nmAddr);
+ doReturn(
+ conf.getInt(YarnConfiguration.NM_PRIVATE_LOCALIZER_FETCH_THREAD_COUNT,
+ YarnConfiguration.DEFAULT_NM_PRIVATE_LOCALIZER_FETCH_THREAD_COUNT))
+ .when(localizer).getDownloadThreadCount();
doNothing().when(localizer).sleep(anyInt());
-
-
- // return result instantly for deterministic test
- ExecutorService syncExec = mock(ExecutorService.class);
- CompletionService cs = mock(CompletionService.class);
- when(cs.submit(isA(Callable.class)))
- .thenAnswer(new Answer>() {
- @Override
- public Future answer(InvocationOnMock invoc)
- throws Throwable {
- Future done = mock(Future.class);
- when(done.isDone()).thenReturn(true);
- FakeDownload d = (FakeDownload) invoc.getArguments()[0];
- when(done.get()).thenReturn(d.call());
- return done;
- }
- });
- doReturn(syncExec).when(localizer).createDownloadThreadPool();
- doReturn(cs).when(localizer).createCompletionService(syncExec);
return localizer;
}
@@ -348,8 +448,8 @@ public boolean matches(Object o) {
}
static class FakeDownload implements Callable {
- private final Path localPath;
- private final boolean succeed;
+ final Path localPath;
+ final boolean succeed;
FakeDownload(String absPath, boolean succeed) {
this.localPath = new Path("file:///localcache" + absPath);
this.succeed = succeed;
@@ -362,6 +462,25 @@ public Path call() throws IOException {
return localPath;
}
}
+
+  /**
+   * Download stub that records that it was invoked and then sleeps for ten
+   * seconds, keeping its worker thread occupied for the rest of the test.
+   * It is already a Callable through FakeDownload, so the interface is not
+   * re-declared here.
+   */
+  static class FakeLargeDownload extends FakeDownload {
+    // Written on a download thread and read on the test thread, so it
+    // must be volatile for the write to be reliably visible.
+    public volatile boolean called;
+
+    FakeLargeDownload(String absPath, boolean succeed) {
+      super(absPath, succeed);
+    }
+
+    /**
+     * Marks this download as called, then either fails or sleeps.
+     *
+     * @return the fake local path once the sleep ends or is interrupted
+     * @throws IOException if this download was created with succeed=false
+     */
+    @Override
+    public Path call() throws IOException {
+      called = true;
+      if (!succeed) {
+        throw new IOException("FAIL " + localPath);
+      }
+      try {
+        Thread.sleep(10000);
+      } catch (InterruptedException e) {
+        // Stop sleeping, but restore the interrupt status instead of
+        // silently discarding it so the owning thread can observe it.
+        Thread.currentThread().interrupt();
+      }
+      return localPath;
+    }
+  }
static RecordFactory getMockLocalizerRecordFactory() {
RecordFactory mockRF = mock(RecordFactory.class);