diff --git a/hadoop-yarn-project/hadoop-yarn/conf/container-executor.cfg b/hadoop-yarn-project/hadoop-yarn/conf/container-executor.cfg
index 4df53df6892..e0767a40265 100644
--- a/hadoop-yarn-project/hadoop-yarn/conf/container-executor.cfg
+++ b/hadoop-yarn-project/hadoop-yarn/conf/container-executor.cfg
@@ -22,4 +22,9 @@ feature.tc.enabled=false
#[fpga]
# module.enabled=## Enable/Disable the FPGA resource handler module. set to "true" to enable, disabled by default
# fpga.major-device-number=## Major device number of FPGA, by default is 246. Strongly recommend setting this
-# fpga.allowed-device-minor-numbers=## Comma separated allowed minor device numbers, empty means all FPGA devices managed by YARN.
\ No newline at end of file
+# fpga.allowed-device-minor-numbers=## Comma separated allowed minor device numbers, empty means all FPGA devices managed by YARN.
+
+# The configs below deal with settings for resources handled by the pluggable device plugin framework
+#[devices]
+# module.enabled=## Enable/Disable the device resource handler module for isolation. Disabled by default.
+# devices.denied-numbers=## Devices that containers are not permitted to use, as a comma-separated list of "majorNumber:minorNumber" pairs, for instance "195:1,195:2". If left empty, all devices reported by the device plugin are allowed.
\ No newline at end of file
diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/CMakeLists.txt b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/CMakeLists.txt
index 300bb65c322..f0f005d53b5 100644
--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/CMakeLists.txt
+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/CMakeLists.txt
@@ -135,6 +135,7 @@ add_library(container
main/native/container-executor/impl/modules/common/module-configs.c
main/native/container-executor/impl/modules/gpu/gpu-module.c
main/native/container-executor/impl/modules/fpga/fpga-module.c
+ main/native/container-executor/impl/modules/devices/devices-module.c
main/native/container-executor/impl/utils/docker-util.c
)
@@ -169,6 +170,7 @@ add_executable(cetest
main/native/container-executor/test/modules/cgroups/test-cgroups-module.cc
main/native/container-executor/test/modules/gpu/test-gpu-module.cc
main/native/container-executor/test/modules/fpga/test-fpga-module.cc
+ main/native/container-executor/test/modules/devices/test-devices-module.cc
main/native/container-executor/test/test_util.cc
main/native/container-executor/test/utils/test_docker_util.cc)
target_link_libraries(cetest gtest container)
diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/api/deviceplugin/Device.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/api/deviceplugin/Device.java
index c3a25157218..4ad247f681c 100644
--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/api/deviceplugin/Device.java
+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/api/deviceplugin/Device.java
@@ -181,8 +181,8 @@ public String toString() {
// default -1 representing the value is not set
private int id = -1;
private String devPath = "";
- private int majorNumber;
- private int minorNumber;
+ private int majorNumber = -1;
+ private int minorNumber = -1;
private String busID = "";
private boolean isHealthy;
private String status = "";
diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/privileged/PrivilegedOperation.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/privileged/PrivilegedOperation.java
index a1fdb91f9a3..a17daede2a8 100644
--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/privileged/PrivilegedOperation.java
+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/privileged/PrivilegedOperation.java
@@ -54,6 +54,7 @@
RUN_DOCKER_CMD("--run-docker"),
GPU("--module-gpu"),
FPGA("--module-fpga"),
+ DEVICE("--module-devices"),
LIST_AS_USER(""), // no CLI switch supported yet.
ADD_NUMA_PARAMS(""), // no CLI switch supported yet.
REMOVE_DOCKER_CONTAINER("--remove-docker-container"),
diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/com/nvidia/NvidiaGPUPluginForRuntimeV2.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/com/nvidia/NvidiaGPUPluginForRuntimeV2.java
new file mode 100644
index 00000000000..609ec677cad
--- /dev/null
+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/com/nvidia/NvidiaGPUPluginForRuntimeV2.java
@@ -0,0 +1,285 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.com.nvidia;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.ImmutableSet;
+import org.apache.hadoop.util.Shell;
+import org.apache.hadoop.yarn.exceptions.YarnException;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePlugin;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRegisterRequest;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRuntimeSpec;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.YarnRuntimeType;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeSet;
+
+/**
+ * Nvidia GPU plugin supporting both Nvidia container runtime v2 for Docker and
+ * non-Docker container.
+ *
+ * <p>GPUs are discovered by running nvidia-smi; the major device number of each
+ * GPU is read with stat from its /dev/nvidiaN node.
+ */
+public class NvidiaGPUPluginForRuntimeV2 implements DevicePlugin {
+  public static final Logger LOG = LoggerFactory.getLogger(
+      NvidiaGPUPluginForRuntimeV2.class);
+
+  public static final String NV_RESOURCE_NAME = "nvidia.com/gpu";
+
+  private NvidiaCommandExecutor shellExecutor = new NvidiaCommandExecutor();
+
+  // Environment handed to the nvidia-smi process; nothing is added by default
+  private Map<String, String> environment = new HashMap<>();
+
+  // If this environment is set, use it directly
+  private static final String ENV_BINARY_PATH = "NVIDIA_SMI_PATH";
+
+  private static final String DEFAULT_BINARY_NAME = "nvidia-smi";
+
+  private static final String DEV_NAME_PREFIX = "nvidia";
+
+  private String pathOfGpuBinary = null;
+
+  // command should not run more than 10 sec.
+  private static final int MAX_EXEC_TIMEOUT_MS = 10 * 1000;
+
+  // When executable path not set, try to search default dirs
+  // By default search /usr/bin, /bin, and /usr/local/nvidia/bin (when
+  // launched by nvidia-docker).
+  private static final Set<String> DEFAULT_BINARY_SEARCH_DIRS =
+      ImmutableSet.of("/usr/bin", "/bin", "/usr/local/nvidia/bin");
+
+  /**
+   * Registers this plugin under the {@link #NV_RESOURCE_NAME} resource name.
+   */
+  @Override
+  public DeviceRegisterRequest getRegisterRequestInfo() throws Exception {
+    return DeviceRegisterRequest.Builder.newInstance()
+        .setResourceName(NV_RESOURCE_NAME).build();
+  }
+
+  /**
+   * Discovers GPUs by parsing the "index, pci.bus_id" rows printed by
+   * nvidia-smi. A GPU whose major number cannot be determined is skipped.
+   *
+   * @return the discovered devices with ids assigned sequentially from 0;
+   *         empty when nvidia-smi printed no rows
+   * @throws YarnException if running nvidia-smi fails with an IOException
+   * @throws Exception if no binary is found or a row is malformed
+   */
+  @Override
+  public Set<Device> getDevices() throws Exception {
+    shellExecutor.searchBinary();
+    String output;
+    try {
+      output = shellExecutor.getDeviceInfo();
+    } catch (IOException e) {
+      LOG.debug("Failed to get output from {}", pathOfGpuBinary);
+      throw new YarnException(e);
+    }
+    TreeSet<Device> r = new TreeSet<>();
+    int id = 0;
+    for (String oneLine : output.trim().split("\n")) {
+      if (oneLine.trim().isEmpty()) {
+        // No rows at all (or a stray blank line): nothing to parse here.
+        continue;
+      }
+      String[] tokensEachLine = oneLine.split(",");
+      if (tokensEachLine.length != 2) {
+        throw new Exception("Cannot parse the output to get device info. "
+            + "Unexpected format in it:" + oneLine);
+      }
+      String minorNumber = tokensEachLine[0].trim();
+      String busId = tokensEachLine[1].trim();
+      String majorNumber = getMajorNumber(DEV_NAME_PREFIX + minorNumber);
+      if (majorNumber == null) {
+        continue;
+      }
+      r.add(Device.Builder.newInstance()
+          .setId(id)
+          .setMajorNumber(Integer.parseInt(majorNumber))
+          .setMinorNumber(Integer.parseInt(minorNumber))
+          .setBusID(busId)
+          .setDevPath("/dev/" + DEV_NAME_PREFIX + minorNumber)
+          .setHealthy(true)
+          .build());
+      id++;
+    }
+    return r;
+  }
+
+  /**
+   * Builds the Docker runtime spec for the allocated GPUs: the "nvidia"
+   * container runtime plus a NVIDIA_VISIBLE_DEVICES environment variable
+   * holding the comma-separated minor numbers of the devices.
+   *
+   * @return the runtime spec, or null for a non-Docker runtime or when no
+   *         device was allocated
+   */
+  @Override
+  public DeviceRuntimeSpec onDevicesAllocated(Set<Device> allocatedDevices,
+      YarnRuntimeType yarnRuntime) throws Exception {
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("Generating runtime spec for allocated devices: "
+          + allocatedDevices + ", " + yarnRuntime.getName());
+    }
+    if (yarnRuntime != YarnRuntimeType.RUNTIME_DOCKER) {
+      return null;
+    }
+    if (allocatedDevices == null || allocatedDevices.isEmpty()) {
+      // An empty allocation has no minor numbers to export.
+      LOG.warn("No allocated GPU to build a Docker runtime spec from");
+      return null;
+    }
+    String nvidiaRuntime = "nvidia";
+    String nvidiaVisibleDevices = "NVIDIA_VISIBLE_DEVICES";
+    StringBuilder gpuMinorNumbersSB = new StringBuilder();
+    for (Device device : allocatedDevices) {
+      if (gpuMinorNumbersSB.length() > 0) {
+        gpuMinorNumbersSB.append(',');
+      }
+      gpuMinorNumbersSB.append(device.getMinorNumber());
+    }
+    String minorNumbers = gpuMinorNumbersSB.toString();
+    LOG.info("Nvidia Docker v2 assigned GPU: " + minorNumbers);
+    return DeviceRuntimeSpec.Builder.newInstance()
+        .addEnv(nvidiaVisibleDevices, minorNumbers)
+        .setContainerRuntime(nvidiaRuntime)
+        .build();
+  }
+
+  /** This plugin keeps no per-allocation state, so there is nothing to do. */
+  @Override
+  public void onDevicesReleased(Set<Device> releasedDevices) throws Exception {
+    // do nothing
+  }
+
+  /**
+   * Gets the major device number of a device node under /dev.
+   *
+   * @param devName name of the device node relative to /dev
+   * @return the major number in decimal, or null if it cannot be determined
+   */
+  private String getMajorNumber(String devName) {
+    LOG.debug("Get major numbers from /dev/{}", devName);
+    String output;
+    try {
+      // output "major:minor" in hex
+      output = shellExecutor.getMajorMinorInfo(devName);
+    } catch (IOException e) {
+      LOG.warn("Failed to get major number from reading /dev/" + devName, e);
+      return null;
+    }
+    LOG.debug("stat output:{}", output);
+    try {
+      String[] strs = output.trim().split(":");
+      return Integer.toString(Integer.parseInt(strs[0], 16));
+    } catch (NumberFormatException e) {
+      LOG.error("Failed to parse device major number from stat output", e);
+      return null;
+    }
+  }
+
+  /**
+   * A shell wrapper class that is easy to replace in tests.
+   */
+  public class NvidiaCommandExecutor {
+
+    /**
+     * Runs nvidia-smi, asking for one "index, pci.bus_id" CSV row per GPU.
+     *
+     * @return the raw output of the command
+     */
+    public String getDeviceInfo() throws IOException {
+      return Shell.execCommand(environment,
+          new String[]{pathOfGpuBinary, "--query-gpu=index,pci.bus_id",
+              "--format=csv,noheader"}, MAX_EXEC_TIMEOUT_MS);
+    }
+
+    /**
+     * Runs stat on the given device node under /dev.
+     *
+     * @return the stat output, "major:minor" in hex
+     */
+    public String getMajorMinorInfo(String devName) throws IOException {
+      Shell.ShellCommandExecutor shexec = new Shell.ShellCommandExecutor(
+          new String[]{"stat", "-c", "%t:%T", "/dev/" + devName});
+      shexec.execute();
+      return shexec.getOutput();
+    }
+
+    /**
+     * Resolves the path of the nvidia-smi binary once and remembers it. The
+     * NVIDIA_SMI_PATH environment variable wins when it names an existing
+     * file; otherwise the default directories are searched.
+     *
+     * @throws Exception if no binary can be found
+     */
+    public void searchBinary() throws Exception {
+      if (pathOfGpuBinary != null) {
+        LOG.info("Skip searching, the nvidia gpu binary is already set: "
+            + pathOfGpuBinary);
+        return;
+      }
+      // search env for the binary
+      String envBinaryPath = System.getenv(ENV_BINARY_PATH);
+      if (null != envBinaryPath && new File(envBinaryPath).exists()) {
+        pathOfGpuBinary = envBinaryPath;
+        LOG.info("Use nvidia gpu binary: " + pathOfGpuBinary);
+        return;
+      }
+      LOG.info("Search binary..");
+      // search if binary exists in default folders
+      for (String dir : DEFAULT_BINARY_SEARCH_DIRS) {
+        File binaryFile = new File(dir, DEFAULT_BINARY_NAME);
+        if (binaryFile.exists()) {
+          pathOfGpuBinary = binaryFile.getAbsolutePath();
+          LOG.info("Found binary:" + pathOfGpuBinary);
+          return;
+        }
+      }
+      LOG.error("No binary found from env variable: "
+          + ENV_BINARY_PATH + " or path "
+          + DEFAULT_BINARY_SEARCH_DIRS.toString());
+      throw new Exception("No binary found for "
+          + NvidiaGPUPluginForRuntimeV2.class);
+    }
+  }
+
+  /** Sets the nvidia-smi path directly, so searchBinary skips its search. */
+  @VisibleForTesting
+  public void setPathOfGpuBinary(String pathOfGpuBinary) {
+    this.pathOfGpuBinary = pathOfGpuBinary;
+  }
+
+  /** Replaces the command executor, e.g. with a stub in tests. */
+  @VisibleForTesting
+  public void setShellExecutor(
+      NvidiaCommandExecutor shellExecutor) {
+    this.shellExecutor = shellExecutor;
+  }
+}
diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/com/nvidia/package-info.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/com/nvidia/package-info.java
new file mode 100644
index 00000000000..8eb2687331f
--- /dev/null
+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/com/nvidia/package-info.java
@@ -0,0 +1,19 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.com.nvidia;
\ No newline at end of file
diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/DeviceMappingManager.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/DeviceMappingManager.java
index fefccb9753c..9fcbf93cd36 100644
--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/DeviceMappingManager.java
+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/DeviceMappingManager.java
@@ -95,6 +95,20 @@ public DeviceMappingManager(Context context) {
return devicePluginSchedulers;
}
+  /**
+   * Collects the devices of a resource type that are assigned to a container.
+   *
+   * @param resourceName resource type to look up
+   * @param cId container whose devices are wanted
+   * @return the assigned devices; empty when the container holds none or
+   *         the resource type has no usage map
+   */
+  @VisibleForTesting
+  public Set<Device> getAllocatedDevices(String resourceName,
+      ContainerId cId) {
+    Set<Device> assigned = new TreeSet<>();
+    Map<Device, ContainerId> assignedMap =
+        this.getAllUsedDevices().get(resourceName);
+    if (assignedMap == null) {
+      // Unknown resource type: report "nothing assigned" instead of an NPE.
+      return assigned;
+    }
+    for (Map.Entry<Device, ContainerId> entry : assignedMap.entrySet()) {
+      if (entry.getValue().equals(cId)) {
+        assigned.add(entry.getKey());
+      }
+    }
+    return assigned;
+  }
+
public synchronized void addDeviceSet(String resourceName,
Set deviceSet) {
LOG.info("Adding new resource: " + "type:"
@@ -148,8 +162,10 @@ private synchronized DeviceAllocation internalAssignDevices(
ContainerId containerId = container.getContainerId();
int requestedDeviceCount = getRequestedDeviceCount(resourceName,
requestedResource);
- LOG.debug("Try allocating " + requestedDeviceCount
- + " " + resourceName);
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Try allocating " + requestedDeviceCount
+ + " " + resourceName);
+ }
// Assign devices to container if requested some.
if (requestedDeviceCount > 0) {
if (requestedDeviceCount > getAvailableDevices(resourceName)) {
@@ -245,18 +261,24 @@ public synchronized void cleanupAssignedDevices(String resourceName,
ContainerId containerId) {
Iterator> iter =
allUsedDevices.get(resourceName).entrySet().iterator();
+ Map.Entry entry;
while (iter.hasNext()) {
- if (iter.next().getValue().equals(containerId)) {
+ entry = iter.next();
+ if (entry.getValue().equals(containerId)) {
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Recycle devices: " + entry.getKey()
+ + ", type: " + resourceName + " from " + containerId);
+ }
iter.remove();
}
}
}
- public static int getRequestedDeviceCount(String resourceName,
+ public static int getRequestedDeviceCount(String resName,
Resource requestedResource) {
try {
return Long.valueOf(requestedResource.getResourceValue(
- resourceName)).intValue();
+ resName)).intValue();
} catch (ResourceNotFoundException e) {
return 0;
}
@@ -270,10 +292,7 @@ public int getAvailableDevices(String resourceName) {
private long getReleasingDevices(String resourceName) {
long releasingDevices = 0;
Map used = allUsedDevices.get(resourceName);
- Iterator> iter = used.entrySet()
- .iterator();
- while (iter.hasNext()) {
- ContainerId containerId = iter.next().getValue();
+ for (ContainerId containerId : ImmutableSet.copyOf(used.values())) {
Container container = nmContext.getContainers().get(containerId);
if (container != null) {
if (container.isContainerInFinalStates()) {
@@ -295,16 +314,20 @@ private void pickAndDoSchedule(Set allowed,
DevicePluginScheduler dps) throws ResourceHandlerException {
if (null == dps) {
- LOG.debug("Customized device plugin scheduler is preferred "
- + "but not implemented, use default logic");
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Customized device plugin scheduler is preferred "
+ + "but not implemented, use default logic");
+ }
defaultScheduleAction(allowed, used,
assigned, containerId, count);
} else {
- LOG.debug("Customized device plugin implemented,"
- + "use customized logic");
- // Use customized device scheduler
- LOG.debug("Try to schedule " + count
- + "(" + resourceName + ") using " + dps.getClass());
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Customized device plugin implemented,"
+ + "use customized logic");
+ // Use customized device scheduler
+ LOG.debug("Try to schedule " + count
+ + "(" + resourceName + ") using " + dps.getClass());
+ }
// Pass in unmodifiable set
Set dpsAllocated = dps.allocateDevices(
Sets.difference(allowed, used.keySet()),
@@ -345,6 +368,7 @@ private void defaultScheduleAction(Set allowed,
private String resourceName;
private Set allowed = Collections.emptySet();
+
private Set denied = Collections.emptySet();
DeviceAllocation(String resName, Set a,
@@ -362,6 +386,10 @@ private void defaultScheduleAction(Set allowed,
return allowed;
}
+    /**
+     * @return the denied devices recorded for this allocation
+     */
+    public Set<Device> getDenied() {
+      return denied;
+    }
+
@Override
public String toString() {
return "ResourceType: " + resourceName
diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/DevicePluginAdapter.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/DevicePluginAdapter.java
index ed7d0dad284..462e45a52a9 100644
--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/DevicePluginAdapter.java
+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/DevicePluginAdapter.java
@@ -18,6 +18,7 @@
package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.deviceframework;
+import com.google.common.annotations.VisibleForTesting;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.yarn.api.records.ContainerId;
@@ -49,10 +50,20 @@
private final static Log LOG = LogFactory.getLog(DevicePluginAdapter.class);
private final String resourceName;
+
private final DevicePlugin devicePlugin;
private DeviceMappingManager deviceMappingManager;
+
private DeviceResourceHandlerImpl deviceResourceHandler;
private DeviceResourceUpdaterImpl deviceResourceUpdater;
+ private DeviceResourceDockerRuntimePluginImpl deviceDockerCommandPlugin;
+
+
+ /**
+ * Replaces the device resource handler held by this adapter.
+ * Marked as visible for testing only.
+ */
+ @VisibleForTesting
+ public void setDeviceResourceHandler(
+ DeviceResourceHandlerImpl deviceResourceHandler) {
+ this.deviceResourceHandler = deviceResourceHandler;
+ }
public DevicePluginAdapter(String name, DevicePlugin dp,
DeviceMappingManager dmm) {
@@ -65,8 +76,16 @@ public DeviceMappingManager getDeviceMappingManager() {
return deviceMappingManager;
}
+
+ /**
+ * @return the device plugin wrapped by this adapter
+ */
+ public DevicePlugin getDevicePlugin() {
+ return devicePlugin;
+ }
+
@Override
public void initialize(Context context) throws YarnException {
+ deviceDockerCommandPlugin = new DeviceResourceDockerRuntimePluginImpl(
+ resourceName,
+ devicePlugin, this);
deviceResourceUpdater = new DeviceResourceUpdaterImpl(
resourceName, devicePlugin);
LOG.info(resourceName + " plugin adapter initialized");
@@ -78,8 +97,8 @@ public ResourceHandler createResourceHandler(Context nmContext,
CGroupsHandler cGroupsHandler,
PrivilegedOperationExecutor privilegedOperationExecutor) {
this.deviceResourceHandler = new DeviceResourceHandlerImpl(resourceName,
- devicePlugin, this, deviceMappingManager,
- cGroupsHandler, privilegedOperationExecutor);
+ this, deviceMappingManager,
+ cGroupsHandler, privilegedOperationExecutor, nmContext);
return deviceResourceHandler;
}
@@ -95,7 +114,7 @@ public void cleanup() {
@Override
public DockerCommandPlugin getDockerCommandPluginInstance() {
- return null;
+ return deviceDockerCommandPlugin;
}
@Override
diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/DeviceResourceDockerRuntimePluginImpl.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/DeviceResourceDockerRuntimePluginImpl.java
new file mode 100644
index 00000000000..aaa11bd8585
--- /dev/null
+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/DeviceResourceDockerRuntimePluginImpl.java
@@ -0,0 +1,279 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.deviceframework;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.yarn.api.records.ContainerId;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePlugin;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRuntimeSpec;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.MountDeviceSpec;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.MountVolumeSpec;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.VolumeSpec;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.YarnRuntimeType;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.docker.DockerRunCommand;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.docker.DockerVolumeCommand;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.DockerCommandPlugin;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerExecutionException;
+import org.apache.hadoop.yarn.util.LRUCacheHashMap;
+
+import java.util.Collections;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Bridge DevicePlugin and the hooks related to launching a Docker container.
+ * When launching a Docker container, DockerLinuxContainerRuntime will invoke
+ * this class's methods which get needed info back from DevicePlugin.
+ */
+public class DeviceResourceDockerRuntimePluginImpl
+    implements DockerCommandPlugin {
+
+  final static Log LOG = LogFactory.getLog(
+      DeviceResourceDockerRuntimePluginImpl.class);
+
+  private final String resourceName;
+  private final DevicePlugin devicePlugin;
+  private final DevicePluginAdapter devicePluginAdapter;
+
+  private final int maxCacheSize = 100;
+  // LRU to avoid memory leak if getCleanupDockerVolumesCommand not invoked.
+  private final Map<ContainerId, Set<Device>> cachedAllocation =
+      Collections.synchronizedMap(new LRUCacheHashMap<>(maxCacheSize, true));
+
+  private final Map<ContainerId, DeviceRuntimeSpec> cachedSpec =
+      Collections.synchronizedMap(new LRUCacheHashMap<>(maxCacheSize, true));
+
+  /**
+   * Creates the Docker command plugin for one resource type.
+   *
+   * @param resourceName name of the resource type handled by this instance
+   * @param devicePlugin plugin that is asked for the device runtime spec
+   * @param devicePluginAdapter adapter giving access to the mapping manager
+   */
+  public DeviceResourceDockerRuntimePluginImpl(String resourceName,
+      DevicePlugin devicePlugin, DevicePluginAdapter devicePluginAdapter) {
+    this.resourceName = resourceName;
+    this.devicePlugin = devicePlugin;
+    this.devicePluginAdapter = devicePluginAdapter;
+  }
+
+  /**
+   * Applies the plugin's runtime spec to the docker run command: container
+   * runtime, device mounts, volume mounts and environment variables. Nothing
+   * is changed when the container did not request this resource or when no
+   * runtime spec is available.
+   */
+  @Override
+  public void updateDockerRunCommand(DockerRunCommand dockerRunCommand,
+      Container container) throws ContainerExecutionException {
+    String containerId = container.getContainerId().toString();
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("Try to update docker run command for: " + containerId);
+    }
+    if (!requestedDevice(resourceName, container)) {
+      return;
+    }
+    DeviceRuntimeSpec deviceRuntimeSpec = getRuntimeSpec(container);
+    if (deviceRuntimeSpec == null) {
+      LOG.warn("The device plugin: "
+          + devicePlugin.getClass().getCanonicalName()
+          + " returns null device runtime spec value for container: "
+          + containerId);
+      return;
+    }
+    // handle runtime
+    dockerRunCommand.addRuntime(deviceRuntimeSpec.getContainerRuntime());
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("Handle docker container runtime type: "
+          + deviceRuntimeSpec.getContainerRuntime() + " for container: "
+          + containerId);
+    }
+    // handle device mounts
+    Set<MountDeviceSpec> deviceMounts = deviceRuntimeSpec.getDeviceMounts();
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("Handle device mounts: " + deviceMounts + " for container: "
+          + containerId);
+    }
+    for (MountDeviceSpec mountDeviceSpec : deviceMounts) {
+      dockerRunCommand.addDevice(
+          mountDeviceSpec.getDevicePathInHost(),
+          mountDeviceSpec.getDevicePathInContainer());
+    }
+    // handle volume mounts
+    Set<MountVolumeSpec> mountVolumeSpecs =
+        deviceRuntimeSpec.getVolumeMounts();
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("Handle volume mounts: " + mountVolumeSpecs
+          + " for container: " + containerId);
+    }
+    for (MountVolumeSpec mountVolumeSpec : mountVolumeSpecs) {
+      if (mountVolumeSpec.getReadOnly()) {
+        dockerRunCommand.addReadOnlyMountLocation(
+            mountVolumeSpec.getHostPath(),
+            mountVolumeSpec.getMountPath());
+      } else {
+        dockerRunCommand.addReadWriteMountLocation(
+            mountVolumeSpec.getHostPath(),
+            mountVolumeSpec.getMountPath());
+      }
+    }
+    // handle envs
+    dockerRunCommand.addEnv(deviceRuntimeSpec.getEnvs());
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("Handle envs: " + deviceRuntimeSpec.getEnvs()
+          + " for container: " + containerId);
+    }
+  }
+
+  /**
+   * Returns a "volume create" command for the first volume spec of the
+   * runtime spec whose operation is VolumeSpec.CREATE.
+   *
+   * @return the command, or null when the container did not request this
+   *         resource, no runtime spec is available or no volume is to be
+   *         created
+   */
+  @Override
+  public DockerVolumeCommand getCreateDockerVolumeCommand(Container container)
+      throws ContainerExecutionException {
+    if (!requestedDevice(resourceName, container)) {
+      return null;
+    }
+    DeviceRuntimeSpec deviceRuntimeSpec = getRuntimeSpec(container);
+    if (deviceRuntimeSpec == null) {
+      return null;
+    }
+    Set<VolumeSpec> volumeClaims = deviceRuntimeSpec.getVolumeSpecs();
+    for (VolumeSpec volumeSec: volumeClaims) {
+      if (volumeSec.getVolumeOperation().equals(VolumeSpec.CREATE)) {
+        DockerVolumeCommand command = new DockerVolumeCommand(
+            DockerVolumeCommand.VOLUME_CREATE_SUB_COMMAND);
+        command.setDriverName(volumeSec.getVolumeDriver());
+        command.setVolumeName(volumeSec.getVolumeName());
+        if (LOG.isDebugEnabled()) {
+          LOG.debug("Get volume create request from plugin:" + volumeClaims
+              + " for container: " + container.getContainerId().toString());
+        }
+        return command;
+      }
+    }
+    return null;
+  }
+
+  /**
+   * Notifies the plugin that the container's devices are released and drops
+   * the cached allocation and runtime spec of the container. An exception
+   * thrown by the plugin is logged and otherwise ignored.
+   *
+   * @return always null; no volume cleanup command is generated
+   */
+  @Override
+  public DockerVolumeCommand getCleanupDockerVolumesCommand(Container container)
+      throws ContainerExecutionException {
+    if (!requestedDevice(resourceName, container)) {
+      return null;
+    }
+    ContainerId containerId = container.getContainerId();
+    Set<Device> allocated = getAllocatedDevices(container);
+    try {
+      devicePlugin.onDevicesReleased(allocated);
+    } catch (Exception e) {
+      LOG.warn("Exception thrown in onDeviceReleased of "
+          + devicePlugin.getClass() + " for container: " + containerId, e);
+    }
+    // remove cache
+    cachedAllocation.remove(containerId);
+    cachedSpec.remove(containerId);
+    return null;
+  }
+
+  /**
+   * Tells whether the container asked for at least one device of the given
+   * resource type.
+   */
+  protected boolean requestedDevice(String resName, Container container) {
+    return DeviceMappingManager.
+        getRequestedDeviceCount(resName, container.getResource()) > 0;
+  }
+
+  /**
+   * Gets the devices assigned to the container, first from the local cache
+   * and otherwise from the device mapping manager.
+   */
+  private Set<Device> getAllocatedDevices(Container container) {
+    ContainerId containerId = container.getContainerId();
+    Set<Device> allocated = cachedAllocation.get(containerId);
+    if (allocated != null) {
+      return allocated;
+    }
+    allocated = devicePluginAdapter
+        .getDeviceMappingManager()
+        .getAllocatedDevices(resourceName, containerId);
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("Get allocation from deviceMappingManager: "
+          + allocated + ", " + resourceName + " for container: " + containerId);
+    }
+    // Do not cache an empty result: the lookup may have run before devices
+    // were assigned, and a cached empty set would hide a later assignment.
+    if (allocated != null && !allocated.isEmpty()) {
+      cachedAllocation.put(containerId, allocated);
+    }
+    return allocated;
+  }
+
+  /**
+   * Gets the runtime spec of the container, asking the plugin only when no
+   * spec is cached for it yet.
+   *
+   * @return the spec, or null when the container has no allocated device or
+   *         the plugin fails or returns null
+   */
+  public synchronized DeviceRuntimeSpec getRuntimeSpec(Container container) {
+    ContainerId containerId = container.getContainerId();
+    DeviceRuntimeSpec deviceRuntimeSpec = cachedSpec.get(containerId);
+    if (deviceRuntimeSpec != null) {
+      return deviceRuntimeSpec;
+    }
+    Set<Device> allocated = getAllocatedDevices(container);
+    if (allocated == null || allocated.size() == 0) {
+      LOG.error("Cannot get allocation for container:" + containerId);
+      return null;
+    }
+    try {
+      deviceRuntimeSpec = devicePlugin.onDevicesAllocated(allocated,
+          YarnRuntimeType.RUNTIME_DOCKER);
+    } catch (Exception e) {
+      LOG.error("Exception thrown in onDeviceAllocated of "
+          + devicePlugin.getClass() + " for container: " + containerId, e);
+      // Already reported; skip the misleading "null spec" message below.
+      return null;
+    }
+    if (deviceRuntimeSpec == null) {
+      LOG.error("Null DeviceRuntimeSpec value got from "
+          + devicePlugin.getClass() + " for container: "
+          + containerId + ", please check plugin logic");
+      return null;
+    }
+    cachedSpec.put(containerId, deviceRuntimeSpec);
+    return deviceRuntimeSpec;
+  }
+}
diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/DeviceResourceHandlerImpl.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/DeviceResourceHandlerImpl.java
index 5124ab3ad93..e4afd234bf8 100644
--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/DeviceResourceHandlerImpl.java
+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/DeviceResourceHandlerImpl.java
@@ -18,20 +18,29 @@
package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.deviceframework;
+import com.google.common.annotations.VisibleForTesting;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.util.StringUtils;
import org.apache.hadoop.yarn.api.records.ContainerId;
+import org.apache.hadoop.yarn.server.nodemanager.Context;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePlugin;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRuntimeSpec;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.YarnRuntimeType;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperation;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperationException;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperationExecutor;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.CGroupsHandler;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.ResourceHandler;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.ResourceHandlerException;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.DockerLinuxContainerRuntime;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
import java.util.List;
import java.util.Set;
@@ -52,19 +61,45 @@
private final CGroupsHandler cGroupsHandler;
private final PrivilegedOperationExecutor privilegedOperationExecutor;
private final DevicePluginAdapter devicePluginAdapter;
+ private final Context nmContext;
+ private ShellWrapper shellWrapper;
- public DeviceResourceHandlerImpl(String reseName,
- DevicePlugin devPlugin,
+ // Command line option names passed to container-executor's devices module
+ public static final String EXCLUDED_DEVICES_CLI_OPTION = "--excluded_devices";
+ public static final String ALLOWED_DEVICES_CLI_OPTION = "--allowed_devices";
+ public static final String CONTAINER_ID_CLI_OPTION = "--container_id";
+
+  /**
+   * Create a handler that uses a default {@link ShellWrapper}.
+   *
+   * @param resName name of the resource this handler serves
+   * @param devPluginAdapter adapter that supplies the device plugin
+   * @param devMappingManager manager used to assign and clean up devices
+   * @param cgHandler cgroups handler used for the devices controller
+   * @param operation executor for privileged container-executor operations
+   * @param ctx node manager context
+   */
+  public DeviceResourceHandlerImpl(String resName,
+      DevicePluginAdapter devPluginAdapter,
+      DeviceMappingManager devMappingManager,
+      CGroupsHandler cgHandler,
+      PrivilegedOperationExecutor operation,
+      Context ctx) {
+    this(resName, devPluginAdapter, devMappingManager, cgHandler, operation,
+        ctx, new ShellWrapper());
+  }
+
+ @VisibleForTesting
+ public DeviceResourceHandlerImpl(String resName,
DevicePluginAdapter devPluginAdapter,
DeviceMappingManager devMappingManager,
CGroupsHandler cgHandler,
- PrivilegedOperationExecutor operation) {
+ PrivilegedOperationExecutor operation,
+ Context ctx, ShellWrapper shell) {
this.devicePluginAdapter = devPluginAdapter;
- this.resourceName = reseName;
- this.devicePlugin = devPlugin;
+ this.resourceName = resName;
+ this.devicePlugin = devPluginAdapter.getDevicePlugin();
this.cGroupsHandler = cgHandler;
this.privilegedOperationExecutor = operation;
this.deviceMappingManager = devMappingManager;
+ this.nmContext = ctx;
+ this.shellWrapper = shell;
}
@Override
@@ -98,11 +133,13 @@ public DeviceResourceHandlerImpl(String reseName,
String containerIdStr = container.getContainerId().toString();
DeviceMappingManager.DeviceAllocation allocation =
deviceMappingManager.assignDevices(resourceName, container);
- LOG.debug("Allocated to "
- + containerIdStr + ": " + allocation);
-
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Allocated to "
+ + containerIdStr + ": " + allocation);
+ }
+ DeviceRuntimeSpec spec;
try {
- devicePlugin.onDevicesAllocated(
+ spec = devicePlugin.onDevicesAllocated(
allocation.getAllowed(), YarnRuntimeType.RUNTIME_DEFAULT);
} catch (Exception e) {
throw new ResourceHandlerException("Exception thrown from"
@@ -110,13 +147,95 @@ public DeviceResourceHandlerImpl(String reseName,
}
// cgroups operation based on allocation
- /**
- * TODO: implement a general container-executor device module
- * */
+ if (spec != null) {
+ LOG.warn("Runtime spec in non-Docker container is not supported yet!");
+ }
+ // Create device cgroups for the container
+ cGroupsHandler.createCGroup(CGroupsHandler.CGroupController.DEVICES,
+ containerIdStr);
+ // non-Docker, use cgroups to do isolation
+ if (!DockerLinuxContainerRuntime.isDockerContainerRequested(
+ nmContext.getConf(),
+ container.getLaunchContext().getEnvironment())) {
+ tryIsolateDevices(allocation, containerIdStr);
+ List ret = new ArrayList<>();
+ ret.add(new PrivilegedOperation(
+ PrivilegedOperation.OperationType.ADD_PID_TO_CGROUP,
+ PrivilegedOperation.CGROUP_ARG_PREFIX + cGroupsHandler
+ .getPathForCGroupTasks(CGroupsHandler.CGroupController.DEVICES,
+ containerIdStr)));
+ return ret;
+ }
return null;
}
+  /**
+   * Try to set the devices cgroup parameters for the container through
+   * container-executor.
+   *
+   * Denied devices are sent as "type-major:minor-rwm" values and allowed
+   * devices as "major:minor" values. A device is only included when both its
+   * major and minor numbers are set (not -1). The native side turns every
+   * '-' of a denied value into a space, so a -1 number would otherwise be
+   * rewritten into a different device number.
+   * container-executor is only invoked when at least one value is collected.
+   *
+   * @param allocation the allowed and denied devices for the container
+   * @param containerIdStr the container id, also used as the cgroup name
+   * @throws ResourceHandlerException if the privileged operation fails; the
+   *         container's devices cgroup is deleted before rethrowing
+   */
+  private void tryIsolateDevices(
+      DeviceMappingManager.DeviceAllocation allocation,
+      String containerIdStr) throws ResourceHandlerException {
+    try {
+      // Execute c-e to setup device isolation before launching the container
+      PrivilegedOperation privilegedOperation = new PrivilegedOperation(
+          PrivilegedOperation.OperationType.DEVICE,
+          Arrays.asList(CONTAINER_ID_CLI_OPTION, containerIdStr));
+      boolean needNativeDeviceOperation = false;
+
+      List<String> deniedValues = new ArrayList<>();
+      for (Device deniedDevice : allocation.getDenied()) {
+        int majorNumber = deniedDevice.getMajorNumber();
+        int minorNumber = deniedDevice.getMinorNumber();
+        if (majorNumber == -1 || minorNumber == -1) {
+          // Without real device numbers there is no valid cgroups value.
+          LOG.warn("Skip denying device without valid major:minor number: "
+              + deniedDevice);
+          continue;
+        }
+        // The cgroups value needs the device type ("b" or "c") as prefix
+        DeviceType devType = getDeviceType(deniedDevice);
+        if (devType != null) {
+          deniedValues.add(devType.getName() + "-" + majorNumber + ":"
+              + minorNumber + "-rwm");
+        }
+      }
+      if (!deniedValues.isEmpty()) {
+        privilegedOperation.appendArgs(
+            Arrays.asList(EXCLUDED_DEVICES_CLI_OPTION,
+                StringUtils.join(",", deniedValues)));
+        needNativeDeviceOperation = true;
+      }
+
+      List<String> allowedValues = new ArrayList<>();
+      for (Device allowedDevice : allocation.getAllowed()) {
+        int majorNumber = allowedDevice.getMajorNumber();
+        int minorNumber = allowedDevice.getMinorNumber();
+        if (majorNumber != -1 && minorNumber != -1) {
+          allowedValues.add(majorNumber + ":" + minorNumber);
+        }
+      }
+      if (!allowedValues.isEmpty()) {
+        privilegedOperation.appendArgs(
+            Arrays.asList(ALLOWED_DEVICES_CLI_OPTION,
+                StringUtils.join(",", allowedValues)));
+        needNativeDeviceOperation = true;
+      }
+
+      if (needNativeDeviceOperation) {
+        privilegedOperationExecutor.executePrivilegedOperation(
+            privilegedOperation, true);
+      }
+    } catch (PrivilegedOperationException e) {
+      cGroupsHandler.deleteCGroup(CGroupsHandler.CGroupController.DEVICES,
+          containerIdStr);
+      LOG.warn("Could not update cgroup for container", e);
+      throw new ResourceHandlerException(e);
+    }
+  }
+
@Override
public synchronized List reacquireContainer(
ContainerId containerId) throws ResourceHandlerException {
@@ -134,6 +253,8 @@ public DeviceResourceHandlerImpl(String reseName,
public synchronized List postComplete(
ContainerId containerId) throws ResourceHandlerException {
deviceMappingManager.cleanupAssignedDevices(resourceName, containerId);
+ cGroupsHandler.deleteCGroup(CGroupsHandler.CGroupController.DEVICES,
+ containerId.toString());
return null;
}
@@ -151,4 +272,73 @@ public String toString() {
", devicePluginAdapter=" + devicePluginAdapter +
'}';
}
+
+  /**
+   * Decide whether a device is a block or a character device.
+   *
+   * When the device has a path, the type is taken from the output of
+   * {@code shellWrapper.getDeviceFileType}: output starting with "c" means
+   * character device, anything else is treated as block device. When the
+   * path is missing, the type is derived from the major:minor numbers, which
+   * must both be set (not -1).
+   *
+   * @param device the device to inspect
+   * @return the device type, or {@code null} when it cannot be determined
+   */
+  public DeviceType getDeviceType(Device device) {
+    String devName = device.getDevPath();
+    if (devName == null || devName.isEmpty()) {
+      LOG.warn("Empty device path provided, try to get device type from " +
+          "major:minor device number");
+      int major = device.getMajorNumber();
+      int minor = device.getMinorNumber();
+      // Both numbers are required; a half-specified pair cannot identify
+      // a device.
+      if (major == -1 || minor == -1) {
+        LOG.warn("Non device number provided, cannot decide the device type");
+        return null;
+      }
+      // Get type from the device numbers
+      return getDeviceTypeFromDeviceNumber(major, minor);
+    }
+    try {
+      if (LOG.isDebugEnabled()) {
+        LOG.debug("Try to get device type from device path: " + devName);
+      }
+      String output = shellWrapper.getDeviceFileType(devName);
+      if (LOG.isDebugEnabled()) {
+        LOG.debug("stat output:" + output);
+      }
+      return output.startsWith("c") ? DeviceType.CHAR : DeviceType.BLOCK;
+    } catch (IOException e) {
+      // Keep the cause in the log so a failing stat can be diagnosed
+      LOG.warn("Failed to get device type from stat " + devName, e);
+      return null;
+    }
+  }
+
+  /**
+   * Derive the device type used in device cgroups values from the device
+   * numbers. A device is a block device when the sys file
+   * "/sys/dev/block/major:minor" exists and a character device otherwise.
+   * Character is the default because Nvidia GPU doesn't create this sys file.
+   *
+   * @param major major device number
+   * @param minor minor device number
+   * @return {@code BLOCK} if the sys file exists, otherwise {@code CHAR}
+   */
+  public DeviceType getDeviceTypeFromDeviceNumber(int major, int minor) {
+    String sysBlockPath = "/sys/dev/block/" + major + ":" + minor;
+    return shellWrapper.existFile(sysBlockPath)
+        ? DeviceType.BLOCK : DeviceType.CHAR;
+  }
+
+  /**
+   * Linux device types, used when building device cgroups parameters.
+   * "b" stands for a block device, "c" for a character device.
+   */
+  private enum DeviceType {
+    BLOCK("b"),
+    CHAR("c");
+
+    /** Single-letter type code placed in front of the device numbers. */
+    private final String typeCode;
+
+    DeviceType(String code) {
+      typeCode = code;
+    }
+
+    public String getName() {
+      return typeCode;
+    }
+  }
+
}
diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/ShellWrapper.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/ShellWrapper.java
new file mode 100644
index 00000000000..69cbcdf37fb
--- /dev/null
+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/ShellWrapper.java
@@ -0,0 +1,46 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.deviceframework;
+
+import org.apache.hadoop.util.Shell;
+
+import java.io.File;
+import java.io.IOException;
+
+/**
+ * Wraps the shell command and file system probe used for device type
+ * detection so that tests can replace them.
+ */
+public class ShellWrapper {
+
+  /**
+   * Run {@code stat -c %F} on the given device path.
+   *
+   * @param devName path of the device file
+   * @return the output of the stat command
+   * @throws IOException if executing the command fails
+   */
+  public String getDeviceFileType(String devName) throws IOException {
+    String[] statCommand = {"stat", "-c", "%F", devName};
+    Shell.ShellCommandExecutor statExecutor =
+        new Shell.ShellCommandExecutor(statCommand);
+    statExecutor.execute();
+    return statExecutor.getOutput();
+  }
+
+  /**
+   * Check whether a file exists.
+   *
+   * @param path path to check
+   * @return true if the path exists
+   */
+  public boolean existFile(String path) {
+    return new File(path).exists();
+  }
+}
diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/impl/main.c b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/impl/main.c
index af540fd58eb..8507ff89dee 100644
--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/impl/main.c
+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/impl/main.c
@@ -24,6 +24,7 @@
#include "modules/gpu/gpu-module.h"
#include "modules/fpga/fpga-module.h"
#include "modules/cgroups/cgroups-operations.h"
+#include "modules/devices/devices-module.h"
#include "utils/string-utils.h"
#include
@@ -289,6 +290,11 @@ static int validate_arguments(int argc, char **argv , int *operation) {
&argv[1]);
}
+ if (strcmp("--module-devices", argv[1]) == 0) {
+ return handle_devices_request(&update_cgroups_parameters, "devices", argc - 1,
+ &argv[1]);
+ }
+
if (strcmp("--checksetup", argv[1]) == 0) {
*operation = CHECK_SETUP;
return 0;
diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/impl/modules/cgroups/cgroups-operations.c b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/impl/modules/cgroups/cgroups-operations.c
index ea1d36d5532..ab1eab53248 100644
--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/impl/modules/cgroups/cgroups-operations.c
+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/impl/modules/cgroups/cgroups-operations.c
@@ -132,7 +132,7 @@ int update_cgroups_parameters(
goto cleanup;
}
- fprintf(ERRORFILE, "CGroups: Updating cgroups, path=%s, value=%s",
+ fprintf(ERRORFILE, "CGroups: Updating cgroups, path=%s, value=%s\n",
full_path, value);
// Write values to file
diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/impl/modules/devices/devices-module.c b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/impl/modules/devices/devices-module.c
new file mode 100644
index 00000000000..9df6662cc77
--- /dev/null
+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/impl/modules/devices/devices-module.c
@@ -0,0 +1,281 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "configuration.h"
+#include "container-executor.h"
+#include "utils/string-utils.h"
+#include "modules/devices/devices-module.h"
+#include "modules/cgroups/cgroups-operations.h"
+#include "modules/common/module-configs.h"
+#include "modules/common/constants.h"
+#include "util.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#define EXCLUDED_DEVICES_OPTION "excluded_devices"
+#define ALLOWED_DEVICES_OPTION "allowed_devices"
+#define CONTAINER_ID_OPTION "container_id"
+#define MAX_CONTAINER_ID_LEN 128
+
+static const struct section* cfg_section;
+
+// Return 1 when needle occurs in haystack as a complete device number, that
+// is, the occurrence is neither preceded nor followed by another digit.
+// This keeps "243:1" from matching "243:10" while still matching the
+// decorated form "c-243:1-rwm". An empty needle never matches.
+static int contains_device_number(const char* haystack, const char* needle) {
+  size_t needle_len = strlen(needle);
+  if (needle_len == 0) {
+    return 0;
+  }
+  const char* hit = haystack;
+  while ((hit = strstr(hit, needle)) != NULL) {
+    int digit_before = (hit != haystack) &&
+        (hit[-1] >= '0' && hit[-1] <= '9');
+    char next = hit[needle_len];
+    int digit_after = (next >= '0' && next <= '9');
+    if (!digit_before && !digit_after) {
+      return 1;
+    }
+    hit++;
+  }
+  return 0;
+}
+
+// Search a device number in a NULL-terminated string list, return 1 when
+// found. An entry matches when it contains the token, or the token contains
+// the entry, as a whole device number. A NULL list or token yields 0.
+static int search_in_list(char** list, char* token) {
+  if (list == NULL || token == NULL) {
+    return 0;
+  }
+  int i = 0;
+  while (list[i] != NULL) {
+    if (contains_device_number(token, list[i]) ||
+        contains_device_number(list[i], token)) {
+      return 1;
+    }
+    i++;
+  }
+  return 0;
+}
+
+// Return 1 when "/sys/dev/block/<value>" exists, which marks the device
+// number as a block device; return 0 otherwise, including when the path
+// cannot be built or does not fit the buffer.
+static int is_block_device(const char* value) {
+  char block_path[512];
+  int written = snprintf(block_path, sizeof(block_path), "/sys/dev/block/%s",
+      value);
+  // A result >= the buffer size means the path was truncated; checking a
+  // truncated path could report a different file.
+  if (written < 0 || (size_t) written >= sizeof(block_path)) {
+    fprintf(ERRORFILE, "Failed to construct system block device path.\n");
+    return 0;
+  }
+  struct stat sb;
+  // file exists, is block device
+  return stat(block_path, &sb) == 0 ? 1 : 0;
+}
+
+// Apply the device denials for one container.
+//
+// First, when DEVICES_DENIED_NUMBERS is configured, every requested allowed
+// device is checked against it (a hit fails the request) and each configured
+// number not already in the passed-in deny list is denied. Then every
+// passed-in deny value is converted from "c-242:0-rwm" to "c 242:0 rwm" in
+// place and denied.
+//
+// deny_devices_number_tokens must be a non-NULL, NULL-terminated list.
+// allow_devices_number_tokens may be NULL when --allowed_devices was not
+// given. Returns 0 on success and -1 on the first failure.
+static int internal_handle_devices_request(
+    update_cgroups_parameters_function update_cgroups_parameters_func_p,
+    char** deny_devices_number_tokens,
+    char** allow_devices_number_tokens,
+    const char* container_id) {
+  int return_code = 0;
+
+  char** ce_denied_numbers = NULL;
+  // NOTE(review): if get_section_value returns a newly allocated copy, it
+  // should be freed in cleanup - confirm against configuration.h.
+  char* ce_denied_str = get_section_value(DEVICES_DENIED_NUMBERS,
+      cfg_section);
+  // Get denied "major:minor" device numbers from cfg, if not set, means all
+  // devices can be used by YARN.
+  if (ce_denied_str != NULL) {
+    ce_denied_numbers = split_delimiter(ce_denied_str, ",");
+    if (NULL == ce_denied_numbers) {
+      fprintf(ERRORFILE,
+          "Invalid value set for %s, value=%s\n",
+          DEVICES_DENIED_NUMBERS,
+          ce_denied_str);
+      return_code = -1;
+      goto cleanup;
+    }
+    // Check allowed devices passed in. The list is NULL when the caller did
+    // not receive --allowed_devices, so there is nothing to check then.
+    char** allow_iterator = allow_devices_number_tokens;
+    int allow_count = 0;
+    while (allow_iterator != NULL && allow_iterator[allow_count] != NULL) {
+      if (search_in_list(ce_denied_numbers, allow_iterator[allow_count])) {
+        fprintf(ERRORFILE,
+            "Trying to allow device with device number=%s which is not permitted in container-executor.cfg. %s\n",
+            allow_iterator[allow_count],
+            "It could be caused by a mismatch of devices reported by device plugin");
+        return_code = -1;
+        goto cleanup;
+      }
+      allow_count++;
+    }
+
+    // Deny devices configured in c-e.cfg
+    char** ce_iterator = ce_denied_numbers;
+    int ce_count = 0;
+    while (ce_iterator[ce_count] != NULL) {
+      // skip if duplicate with denied numbers passed in
+      if (search_in_list(deny_devices_number_tokens, ce_iterator[ce_count])) {
+        ce_count++;
+        continue;
+      }
+      char param_value[128];
+      char type = 'c';
+      memset(param_value, 0, sizeof(param_value));
+      if (is_block_device(ce_iterator[ce_count])) {
+        type = 'b';
+      }
+      snprintf(param_value, sizeof(param_value), "%c %s rwm",
+          type,
+          ce_iterator[ce_count]);
+      // Update device cgroups value
+      int rc = update_cgroups_parameters_func_p("devices", "deny",
+          container_id, param_value);
+
+      if (0 != rc) {
+        fprintf(ERRORFILE, "CGroups: Failed to update cgroups. %s\n", param_value);
+        return_code = -1;
+        goto cleanup;
+      }
+      ce_count++;
+    }
+  }
+
+  // Deny devices passed from java side
+  char** iterator = deny_devices_number_tokens;
+  int count = 0;
+  char* value = NULL;
+  int index = 0;
+  while (iterator[count] != NULL) {
+    // Replace like "c-242:0-rwm" to "c 242:0 rwm"
+    value = iterator[count];
+    index = 0;
+    while (value[index] != '\0') {
+      if (value[index] == '-') {
+        value[index] = ' ';
+      }
+      index++;
+    }
+    // Update device cgroups value
+    int rc = update_cgroups_parameters_func_p("devices", "deny",
+        container_id, iterator[count]);
+
+    if (0 != rc) {
+      fprintf(ERRORFILE, "CGroups: Failed to update cgroups\n");
+      return_code = -1;
+      goto cleanup;
+    }
+    count++;
+  }
+
+cleanup:
+  if (ce_denied_numbers != NULL) {
+    free_values(ce_denied_numbers);
+  }
+  return return_code;
+}
+
+// Look up the DEVICES_MODULE_SECTION_NAME section in the configuration
+// returned by get_cfg() and cache it in cfg_section for later requests.
+void reload_devices_configuration() {
+ cfg_section = get_configuration_section(DEVICES_MODULE_SECTION_NAME, get_cfg());
+}
+
+/*
+ * Parse a devices request command line and apply it through func.
+ *
+ * Format of devices request commandline:
+ * The excluded_devices is comma separated device cgroups values with device type.
+ * The "-" will be replaced with " " to match the cgroups parameter
+ * c-e --module-devices \
+ * --excluded_devices b-8:16-rwm,c-244:0-rwm,c-244:1-rwm \
+ * --allowed_devices 8:32,8:48,243:2 \
+ * --container_id container_x_y
+ *
+ * --container_id is mandatory. When --excluded_devices is absent, no cgroups
+ * call is made and the request succeeds.
+ *
+ * Returns 0 on success and a non-zero value on failure: -1 when the module
+ * is disabled or a cgroups update fails, 1 for invalid arguments.
+ */
+int handle_devices_request(update_cgroups_parameters_function func,
+    const char* module_name, int module_argc, char** module_argv) {
+  if (!cfg_section) {
+    reload_devices_configuration();
+  }
+
+  if (!module_enabled(cfg_section, DEVICES_MODULE_SECTION_NAME)) {
+    fprintf(ERRORFILE,
+        "Please make sure devices module is enabled before using it.\n");
+    return -1;
+  }
+
+  static struct option long_options[] = {
+    {EXCLUDED_DEVICES_OPTION, required_argument, 0, 'e' },
+    {ALLOWED_DEVICES_OPTION, required_argument, 0, 'a' },
+    {CONTAINER_ID_OPTION, required_argument, 0, 'c' },
+    {0, 0, 0, 0}
+  };
+
+  int c = 0;
+  int option_index = 0;
+
+  char** deny_device_value_tokens = NULL;
+  char** allow_device_value_tokens = NULL;
+  char container_id[MAX_CONTAINER_ID_LEN];
+  memset(container_id, 0, sizeof(container_id));
+  int failed = 0;
+
+  // Restart option scanning; getopt state is global.
+  optind = 1;
+  while((c = getopt_long(module_argc, module_argv, "e:a:c:",
+      long_options, &option_index)) != -1) {
+    switch(c) {
+      case 'e':
+        // A repeated option replaces the earlier value; release it first.
+        if (deny_device_value_tokens) {
+          free_values(deny_device_value_tokens);
+        }
+        deny_device_value_tokens = split_delimiter(optarg, ",");
+        break;
+      case 'a':
+        if (allow_device_value_tokens) {
+          free_values(allow_device_value_tokens);
+        }
+        allow_device_value_tokens = split_delimiter(optarg, ",");
+        break;
+      case 'c':
+        if (!validate_container_id(optarg)) {
+          fprintf(ERRORFILE,
+              "Specified container_id=%s is invalid\n", optarg);
+          failed = 1;
+          goto cleanup;
+        }
+        // Copy at most LEN - 1 bytes and terminate explicitly so that
+        // container_id is always a valid C string.
+        strncpy(container_id, optarg, MAX_CONTAINER_ID_LEN - 1);
+        container_id[MAX_CONTAINER_ID_LEN - 1] = '\0';
+        break;
+      default:
+        fprintf(ERRORFILE,
+            "Unknown option in devices command character %d %c, optionindex = %d\n",
+            c, c, optind);
+        failed = 1;
+        goto cleanup;
+    }
+  }
+
+  if (0 == container_id[0]) {
+    fprintf(ERRORFILE,
+        "[%s] --container_id must be specified.\n", __func__);
+    failed = 1;
+    goto cleanup;
+  }
+
+  if (NULL == deny_device_value_tokens) {
+    // Devices number is null, skip following call.
+    fprintf(ERRORFILE, "--excluded_devices is not specified, skip cgroups call.\n");
+    goto cleanup;
+  }
+
+  failed = internal_handle_devices_request(func,
+      deny_device_value_tokens,
+      allow_device_value_tokens,
+      container_id);
+
+cleanup:
+  if (deny_device_value_tokens) {
+    free_values(deny_device_value_tokens);
+  }
+  if (allow_device_value_tokens) {
+    free_values(allow_device_value_tokens);
+  }
+  return failed;
+}
diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/impl/modules/devices/devices-module.h b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/impl/modules/devices/devices-module.h
new file mode 100644
index 00000000000..c5d67851b83
--- /dev/null
+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/impl/modules/devices/devices-module.h
@@ -0,0 +1,45 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifdef __FreeBSD__
+#define _WITH_GETLINE
+#endif
+
+#ifndef _MODULES_DEVICES_MUDULE_H_
+#define _MODULES_DEVICES_MUDULE_H_
+
+// Denied device list. value format is "major1:minor1,major2:minor2"
+#define DEVICES_DENIED_NUMBERS "devices.denied-numbers"
+#define DEVICES_MODULE_SECTION_NAME "devices"
+
+// For unit test stubbing
+// Function pointer type of the routine used to write a cgroups parameter.
+// The devices module calls it as (controller name, parameter name,
+// container id, value), e.g. ("devices", "deny", container_id, "c 243:0 rwm").
+// A zero return is treated as success.
+typedef int (*update_cgroups_parameters_function)(const char*, const char*,
+ const char*, const char*);
+
+/**
+ * Handle devices requests
+ *
+ * Parses --excluded_devices, --allowed_devices and --container_id from
+ * module_argv and applies the resulting denials through func.
+ * Returns 0 on success and a non-zero value on failure.
+ */
+int handle_devices_request(update_cgroups_parameters_function func,
+ const char* module_name, int module_argc, char** module_argv);
+
+/**
+ * Reload config from filesystem, visible for testing.
+ *
+ * NOTE(review): the implementation re-reads the "devices" section from the
+ * already loaded configuration (get_cfg()); confirm whether callers must
+ * load the file first.
+ */
+void reload_devices_configuration();
+
+#endif
\ No newline at end of file
diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/impl/util.c b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/impl/util.c
index eea3e108ea6..1753954cb2e 100644
--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/impl/util.c
+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/impl/util.c
@@ -44,6 +44,9 @@ char** split_delimiter(char *value, const char *delim) {
memset(return_values, 0, sizeof(char *) * return_values_size);
temp_tok = strtok_r(value, delim, &tempstr);
+ if (NULL == temp_tok) {
+ return_values[size++] = strdup(value);
+ }
while (temp_tok != NULL) {
temp_tok = strdup(temp_tok);
if (NULL == temp_tok) {
diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/test/modules/devices/test-devices-module.cc b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/test/modules/devices/test-devices-module.cc
new file mode 100644
index 00000000000..9c537d6f516
--- /dev/null
+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/test/modules/devices/test-devices-module.cc
@@ -0,0 +1,298 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+
+extern "C" {
+#include "configuration.h"
+#include "container-executor.h"
+#include "modules/cgroups/cgroups-operations.h"
+#include "modules/devices/devices-module.h"
+#include "test/test-container-executor-common.h"
+#include "util.h"
+}
+
+namespace ContainerExecutor {
+
+// Test fixture for the devices module: directs the executor's log streams to
+// stdout/stderr and makes sure TEST_ROOT exists before each test.
+class TestDevicesModule : public ::testing::Test {
+protected:
+  virtual void SetUp() {
+    // Assign the log streams first: the mkdirs failure path below writes to
+    // ERRORFILE, which must be a valid stream by then.
+    LOGFILE = stdout;
+    ERRORFILE = stderr;
+    if (mkdirs(TEST_ROOT, 0755) != 0) {
+      fprintf(ERRORFILE, "Failed to mkdir TEST_ROOT: %s\n", TEST_ROOT);
+      exit(1);
+    }
+  }
+
+  virtual void TearDown() {
+
+  }
+};
+
+static std::vector cgroups_parameters_invoked;
+
+// Append a heap copy of one string to cgroups_parameters_invoked.
+// Returns 0 on success, -1 when the copy cannot be allocated.
+static int record_cgroups_argument(const char* text) {
+  // strdup sizes the copy to the input, so long values cannot overflow.
+  char* copy = strdup(text);
+  if (copy == NULL) {
+    return -1;
+  }
+  cgroups_parameters_invoked.push_back(copy);
+  return 0;
+}
+
+// Stub for the cgroups update routine: records the four arguments, in call
+// order, in cgroups_parameters_invoked so tests can verify them later.
+// The copies are released by clear_cgroups_parameters_invoked().
+// Returns 0 on success, -1 if an argument could not be recorded.
+static int mock_update_cgroups_parameters(
+    const char* controller_name,
+    const char* param_name,
+    const char* group_id,
+    const char* value) {
+  if (record_cgroups_argument(controller_name) != 0 ||
+      record_cgroups_argument(param_name) != 0 ||
+      record_cgroups_argument(group_id) != 0 ||
+      record_cgroups_argument(value) != 0) {
+    return -1;
+  }
+  return 0;
+}
+
+// Release every recorded argument string and leave the record empty.
+static void clear_cgroups_parameters_invoked() {
+  while (!cgroups_parameters_invoked.empty()) {
+    free((void *) cgroups_parameters_invoked.back());
+    cgroups_parameters_invoked.pop_back();
+  }
+}
+
+// Assert that the recorded cgroups arguments are exactly the argc strings in
+// argv, in the same order.
+static void verify_param_updated_to_cgroups(
+    int argc, const char** argv) {
+  ASSERT_EQ(argc, cgroups_parameters_invoked.size());
+
+  for (int position = 0; position < argc; position++) {
+    ASSERT_STREQ(argv[position], cgroups_parameters_invoked[position]);
+  }
+}
+
+// Write a config file containing a [devices] section whose module.enabled
+// is "true" or "false" according to enabled, then load it into the executor
+// and the devices module. Exits the process if the file cannot be opened.
+static void write_and_load_devices_module_to_cfg(const char* cfg_filepath, int enabled) {
+  FILE *cfg = fopen(cfg_filepath, "w");
+  if (cfg == NULL) {
+    printf("FAIL: Could not open configuration file: %s\n", cfg_filepath);
+    exit(1);
+  }
+  const char* enabled_text = enabled ? "true" : "false";
+  fprintf(cfg, "[devices]\n");
+  fprintf(cfg, "module.enabled=%s\n", enabled_text);
+  fclose(cfg);
+
+  // Read config file
+  read_executor_config(cfg_filepath);
+  reload_devices_configuration();
+}
+
+// Append values verbatim to the config file and reload the executor and
+// devices module configuration from it. Exits the process if the file cannot
+// be opened.
+static void append_config(const char* cfg_filepath, char values[]) {
+  FILE *cfg = fopen(cfg_filepath, "a");
+  if (NULL == cfg) {
+    printf("FAIL: Could not open configuration file: %s\n", cfg_filepath);
+    exit(1);
+  }
+  fputs(values, cfg);
+  fclose(cfg);
+
+  // Read config file
+  read_executor_config(cfg_filepath);
+  reload_devices_configuration();
+}
+
+// Run one devices request against a config with the module enabled or
+// disabled and check the return code: 0 when enabled, -1 when disabled.
+static void test_devices_module_enabled_disabled(int enabled) {
+  // Write config file.
+  const char *filename = TEST_ROOT "/test_cgroups_module_enabled_disabled.cfg";
+  write_and_load_devices_module_to_cfg(filename, enabled);
+
+  char excluded_devices[] = "c-243:0-rwm,c-243:1-rwm";
+  char allowed_devices[] = "243:2";
+  char* argv[] = { (char*) "--module-devices", (char*) "--excluded_devices",
+                   excluded_devices,
+                   (char*) "--allowed_devices",
+                   allowed_devices,
+                   (char*) "--container_id",
+                   (char*) "container_1498064906505_0001_01_000001" };
+
+  const int EXPECTED_RC = enabled ? 0 : -1;
+  int rc = handle_devices_request(&mock_update_cgroups_parameters,
+      "devices", 7, argv);
+  ASSERT_EQ(EXPECTED_RC, rc);
+
+  clear_cgroups_parameters_invoked();
+  free_executor_configurations();
+}
+
+// Verifies which cgroups updates handle_devices_request issues when the
+// module is enabled and no denied numbers are configured: one "deny" call per
+// excluded device, and no call at all when --excluded_devices is absent.
+TEST_F(TestDevicesModule, test_verify_device_module_calls_cgroup_parameter) {
+ // Write config file.
+ const char *filename = TEST_ROOT "/test_verify_devices_module_calls_cgroup_parameter.cfg";
+ write_and_load_devices_module_to_cfg(filename, 1);
+
+ char* container_id = (char*) "container_1498064906505_0001_01_000001";
+ // Option values are writable arrays rather than string literals.
+ // NOTE(review): presumably because split_delimiter tokenizes its input in
+ // place - confirm.
+ char excluded_devices[] = "c-243:0-rwm,c-243:1-rwm";
+ char allowed_devices[] = "243:2";
+ char* argv[] = { (char*) "--module-devices", (char*) "--excluded_devices",
+ excluded_devices,
+ (char*) "--allowed_devices",
+ allowed_devices,
+ (char*) "--container_id",
+ container_id };
+ /* Test case 1: block 2 devices */
+ clear_cgroups_parameters_invoked();
+ int rc = handle_devices_request(&mock_update_cgroups_parameters,
+ "devices", 7, argv);
+ ASSERT_EQ(0, rc) << "Should success.\n";
+ // Verify cgroups parameters
+ // Each call records four strings; the '-' separators of the excluded values
+ // are expected to arrive as spaces.
+ const char* expected_cgroups_argv[] = { "devices", "deny", container_id, "c 243:0 rwm",
+ "devices", "deny", container_id, "c 243:1 rwm"};
+ verify_param_updated_to_cgroups(8, expected_cgroups_argv);
+
+ /* Test case 2: block 0 devices */
+ clear_cgroups_parameters_invoked();
+ char* argv_1[] = { (char*) "--module-devices", (char*) "--container_id", container_id };
+ rc = handle_devices_request(&mock_update_cgroups_parameters,
+ "devices", 3, argv_1);
+ ASSERT_EQ(0, rc) << "Should success.\n";
+
+ // Verify cgroups parameters
+ // No excluded devices were passed, so no cgroups update is expected.
+ verify_param_updated_to_cgroups(0, NULL);
+
+ clear_cgroups_parameters_invoked();
+ free_executor_configurations();
+}
+
+TEST_F(TestDevicesModule, test_update_cgroup_parameter_with_config) {
+ // Write and load a config file for the devices module (second argument 1 = enabled).
+ const char *filename = TEST_ROOT "/test_update_cgroup_parameter_with_config.cfg";
+ write_and_load_devices_module_to_cfg(filename, 1);
+ // Append devices.denied-numbers=243:1 to the config file
+ char tokens[] = "devices.denied-numbers=243:1\n";
+ append_config(filename, tokens);
+
+ char* container_id = (char*) "container_1498064906505_0001_01_000001";
+ char excluded_devices[] = "c-243:0-rwm,c-243:1-rwm";
+ char allowed_devices[] = "243:2";
+ char* argv[] = { (char*) "--module-devices", (char*) "--excluded_devices",
+ excluded_devices,
+ (char*) "--allowed_devices",
+ allowed_devices,
+ (char*) "--container_id",
+ container_id };
+ /* Test case 1: block 2 devices; the allowed device 243:2 is not on the denied list */
+ clear_cgroups_parameters_invoked();
+ int rc = handle_devices_request(&mock_update_cgroups_parameters,
+ "devices", 7, argv);
+ ASSERT_EQ(0, rc) << "Should success.\n";
+ // Verify what the mock recorded: 2 deny entries x 4 fields = 8 values
+ const char* expected_cgroups_argv[] = { "devices", "deny", container_id, "c 243:0 rwm",
+ "devices", "deny", container_id, "c 243:1 rwm"};
+ verify_param_updated_to_cgroups(8, expected_cgroups_argv);
+
+ /* Test case 2: block 2 devices but try to allow a device that the config denies */
+ clear_cgroups_parameters_invoked();
+ // The device plugin reported devices 0,1,2,3 in total; 1 and 2 are allocated,
+ // but the config file has device 1 (243:1) denied, so the request must fail.
+ char excluded_devices2[] = "c-243:0-rwm,c-243:3-rwm";
+ char allowed_devices2[] = "243:1,243:2";
+ char* argv1[] = { (char*) "--module-devices", (char*) "--excluded_devices",
+ excluded_devices2,
+ (char*) "--allowed_devices",
+ allowed_devices2,
+ (char*) "--container_id",
+ container_id };
+ rc = handle_devices_request(&mock_update_cgroups_parameters,
+ "devices", 7, argv1);
+ ASSERT_NE(0, rc) << "Should fail.\n";
+
+ clear_cgroups_parameters_invoked();
+ free_executor_configurations();
+}
+
+TEST_F(TestDevicesModule, test_illegal_cli_parameters) {
+ // Write and load a config file for the devices module (second argument 1 = enabled).
+ const char *filename = TEST_ROOT "/test_illegal_cli_parameters.cfg";
+ write_and_load_devices_module_to_cfg(filename, 1);
+ char excluded_devices[] = "c-243:0-rwm,c-243:1-rwm";
+ char allowed_devices[] = "243:2";
+ // Illegal container id - 1: the id is "xxxx"
+ char* argv[] = { (char*) "--module-devices", (char*) "--excluded_devices",
+ excluded_devices,
+ (char*) "--allowed_devices",
+ allowed_devices,
+ (char*) "--container_id", (char*) "xxxx" };
+ int rc = handle_devices_request(&mock_update_cgroups_parameters,
+ "devices", 7, argv);
+ ASSERT_NE(0, rc) << "Should fail.\n";
+
+ // Illegal container id - 2: the id is "container_1"
+ clear_cgroups_parameters_invoked();
+ char* argv_1[] = { (char*) "--module-devices", (char*) "--excluded_devices",
+ excluded_devices,
+ (char*) "--allowed_devices",
+ allowed_devices,
+ (char*) "--container_id", (char*) "container_1" };
+ rc = handle_devices_request(&mock_update_cgroups_parameters,
+ "devices", 7, argv_1);
+ ASSERT_NE(0, rc) << "Should fail.\n";
+
+ // Illegal container id - 3: the --container_id option is not passed at all
+ clear_cgroups_parameters_invoked();
+ char* argv_2[] = { (char*) "--module-devices",
+ (char*) "--excluded_devices",
+ excluded_devices };
+ rc = handle_devices_request(&mock_update_cgroups_parameters,
+ "devices", 3, argv_2);
+ ASSERT_NE(0, rc) << "Should fail.\n";
+
+ clear_cgroups_parameters_invoked();
+ free_executor_configurations();
+}
+
+TEST_F(TestDevicesModule, test_devices_module_disabled) {
+ test_devices_module_enabled_disabled(0); // disabled: the helper expects rc == -1
+}
+
+TEST_F(TestDevicesModule, test_devices_module_enabled) {
+ test_devices_module_enabled_disabled(1); // enabled: the helper expects rc == 0
+}
+} // namespace ContainerExecutor
\ No newline at end of file
diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/TestDeviceMappingManager.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/TestDeviceMappingManager.java
index d69ab420e9f..508e7f784f6 100644
--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/TestDeviceMappingManager.java
+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/TestDeviceMappingManager.java
@@ -25,6 +25,7 @@
import org.apache.hadoop.yarn.api.records.ContainerLaunchContext;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.conf.YarnConfiguration;
+import org.apache.hadoop.yarn.server.nodemanager.Context;
import org.apache.hadoop.yarn.server.nodemanager.NodeManager;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePlugin;
@@ -33,6 +34,8 @@
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.YarnRuntimeType;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.ResourceMappings;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperationExecutor;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.CGroupsHandler;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.ResourceHandlerException;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerRuntimeConstants;
import org.apache.hadoop.yarn.server.nodemanager.recovery.NMStateStoreService;
@@ -74,6 +77,10 @@
private ExecutorService containerLauncher;
private Configuration conf;
+ private CGroupsHandler mockCGroupsHandler;
+ private PrivilegedOperationExecutor mockPrivilegedExecutor;
+ private Context mockCtx;
+
@Before
public void setup() throws Exception {
// setup resource-types.xml
@@ -89,7 +96,7 @@ public void setup() throws Exception {
isA(String.class),
isA(ArrayList.class));
dmm = new DeviceMappingManager(context);
- int deviceCount = 600;
+ int deviceCount = 100;
TreeSet r = new TreeSet<>();
for (int i = 0; i < deviceCount; i++) {
r.add(Device.Builder.newInstance()
@@ -117,6 +124,10 @@ public void setup() throws Exception {
containerLauncher =
Executors.newFixedThreadPool(10);
+ mockCGroupsHandler = mock(CGroupsHandler.class);
+ mockPrivilegedExecutor = mock(PrivilegedOperationExecutor.class);
+ mockCtx = mock(NodeManager.NMContext.class);
+ when(mockCtx.getConf()).thenReturn(conf);
}
@After
@@ -134,7 +145,7 @@ public void tearDown() throws IOException {
@Test
public void testAllocation()
throws InterruptedException, ResourceHandlerException {
- int totalContainerCount = 100;
+ int totalContainerCount = 10;
String resourceName1 = "cmpA.com/hdwA";
String resourceName2 = "cmp.com/cmp";
DeviceMappingManager dmmSpy = spy(dmm);
@@ -158,11 +169,12 @@ public void testAllocation()
resourceName,
num, false);
containerSet.get(resourceName).put(c, num);
-
+ DevicePlugin myPlugin = new MyTestPlugin();
+ DevicePluginAdapter dpa = new DevicePluginAdapter(resourceName,
+ myPlugin, dmm);
DeviceResourceHandlerImpl dri = new DeviceResourceHandlerImpl(
- resourceName,
- new MyTestPlugin(), null,
- dmmSpy, null, null);
+ resourceName, dpa,
+ dmmSpy, mockCGroupsHandler, mockPrivilegedExecutor, mockCtx);
Future f = containerLauncher.submit(new MyContainerLaunch(
dri, c, i, false));
}
@@ -173,12 +185,11 @@ public void testAllocation()
}
Long endTime = System.currentTimeMillis();
- LOG.info("Each container allocation spends roughly: {} ms",
+ LOG.info("Each container preStart spends roughly: {} ms",
(endTime - startTime)/totalContainerCount);
// Ensure invocation times
verify(dmmSpy, times(totalContainerCount)).assignDevices(
anyString(), any(Container.class));
-
// Ensure used devices' count for each type is correct
int totalAllocatedCount = 0;
Map used1 =
@@ -198,23 +209,15 @@ public void testAllocation()
for (Map.Entry entry :
containerSet.get(resourceName1).entrySet()) {
int containerWanted = entry.getValue();
- int actualAllocated = 0;
- for (ContainerId cid : used1.values()) {
- if (cid.equals(entry.getKey().getContainerId())) {
- actualAllocated++;
- }
- }
+ int actualAllocated = dmm.getAllocatedDevices(resourceName1,
+ entry.getKey().getContainerId()).size();
Assert.assertEquals(containerWanted, actualAllocated);
}
for (Map.Entry entry :
containerSet.get(resourceName2).entrySet()) {
int containerWanted = entry.getValue();
- int actualAllocated = 0;
- for (ContainerId cid : used2.values()) {
- if (cid.equals(entry.getKey().getContainerId())) {
- actualAllocated++;
- }
- }
+ int actualAllocated = dmm.getAllocatedDevices(resourceName2,
+ entry.getKey().getContainerId()).size();
Assert.assertEquals(containerWanted, actualAllocated);
}
}
@@ -248,11 +251,12 @@ public void testAllocationAndCleanup()
resourceName,
num, false);
containerSet.get(resourceName).put(c, num);
-
+ DevicePlugin myPlugin = new MyTestPlugin();
+ DevicePluginAdapter dpa = new DevicePluginAdapter(resourceName,
+ myPlugin, dmm);
DeviceResourceHandlerImpl dri = new DeviceResourceHandlerImpl(
- resourceName,
- new MyTestPlugin(), null,
- dmmSpy, null, null);
+ resourceName, dpa,
+ dmmSpy, mockCGroupsHandler, mockPrivilegedExecutor, mockCtx);
Future f = containerLauncher.submit(new MyContainerLaunch(
dri, c, i, true));
}
@@ -262,7 +266,6 @@ public void testAllocationAndCleanup()
LOG.info("Wait for the threads to finish");
}
-
// Ensure invocation times
verify(dmmSpy, times(totalContainerCount)).assignDevices(
anyString(), any(Container.class));
diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/TestDevicePluginAdapter.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/TestDevicePluginAdapter.java
index b9a0763572e..75668f2a5d7 100644
--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/TestDevicePluginAdapter.java
+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/TestDevicePluginAdapter.java
@@ -18,7 +18,6 @@
package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.deviceframework;
-import org.apache.hadoop.service.ServiceOperations;
import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
import org.apache.hadoop.yarn.api.records.ApplicationId;
@@ -34,12 +33,20 @@
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePluginScheduler;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRegisterRequest;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRuntimeSpec;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.MountDeviceSpec;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.MountVolumeSpec;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.VolumeSpec;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.YarnRuntimeType;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.ResourceMappings;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperation;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperationException;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperationExecutor;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.CGroupsHandler;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.ResourceHandlerException;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.docker.DockerRunCommand;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.docker.DockerVolumeCommand;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.DockerCommandPlugin;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.ResourcePluginManager;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerRuntimeConstants;
import org.apache.hadoop.yarn.server.nodemanager.recovery.NMMemoryStateStoreService;
@@ -51,6 +58,7 @@
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
+import org.mockito.ArgumentCaptor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -60,15 +68,21 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.concurrent.ConcurrentHashMap;
+
import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.times;
@@ -89,7 +103,6 @@
private String tempResourceTypesFile;
private CGroupsHandler mockCGroupsHandler;
private PrivilegedOperationExecutor mockPrivilegedExecutor;
- private NodeManager nm;
@Before
public void setup() throws Exception {
@@ -110,13 +123,6 @@ public void tearDown() throws IOException {
if (dest.exists()) {
dest.delete();
}
- if (nm != null) {
- try {
- ServiceOperations.stop(nm);
- } catch (Throwable t) {
- // ignore
- }
- }
}
@@ -130,16 +136,14 @@ public void testBasicWorkflow()
NodeManager.NMContext context = mock(NodeManager.NMContext.class);
NMStateStoreService storeService = mock(NMStateStoreService.class);
when(context.getNMStateStore()).thenReturn(storeService);
+ when(context.getConf()).thenReturn(this.conf);
doNothing().when(storeService).storeAssignedResources(isA(Container.class),
isA(String.class),
isA(ArrayList.class));
-
// Init scheduler manager
DeviceMappingManager dmm = new DeviceMappingManager(context);
-
ResourcePluginManager rpm = mock(ResourcePluginManager.class);
when(rpm.getDeviceMappingManager()).thenReturn(dmm);
-
// Init an plugin
MyPlugin plugin = new MyPlugin();
MyPlugin spyPlugin = spy(plugin);
@@ -150,14 +154,19 @@ public void testBasicWorkflow()
spyPlugin, dmm);
// Bootstrap, adding device
adapter.initialize(context);
- adapter.createResourceHandler(context,
- mockCGroupsHandler, mockPrivilegedExecutor);
+ // Use mock shell when create resourceHandler
+ ShellWrapper mockShellWrapper = mock(ShellWrapper.class);
+ when(mockShellWrapper.existFile(anyString())).thenReturn(true);
+ when(mockShellWrapper.getDeviceFileType(anyString())).thenReturn("c");
+ DeviceResourceHandlerImpl drhl = new DeviceResourceHandlerImpl(resourceName,
+ adapter, dmm, mockCGroupsHandler, mockPrivilegedExecutor, context,
+ mockShellWrapper);
+ adapter.setDeviceResourceHandler(drhl);
adapter.getDeviceResourceHandler().bootstrap(conf);
int size = dmm.getAvailableDevices(resourceName);
Assert.assertEquals(3, size);
-
- // A container c1 requests 1 device
- Container c1 = mockContainerWithDeviceRequest(0,
+ // Case 1. A container c1 requests 1 device
+ Container c1 = mockContainerWithDeviceRequest(1,
resourceName,
1, false);
// preStart
@@ -169,19 +178,33 @@ public void testBasicWorkflow()
dmm.getAllUsedDevices().get(resourceName).size());
Assert.assertEquals(3,
dmm.getAllAllowedDevices().get(resourceName).size());
+ Assert.assertEquals(1,
+ dmm.getAllocatedDevices(resourceName, c1.getContainerId()).size());
+ verify(mockShellWrapper, times(2)).getDeviceFileType(anyString());
+ // check device cgroup create operation
+ checkCgroupOperation(c1.getContainerId().toString(), 1,
+ "c-256:1-rwm,c-256:2-rwm", "256:0");
// postComplete
- adapter.getDeviceResourceHandler().postComplete(getContainerId(0));
+ adapter.getDeviceResourceHandler().postComplete(getContainerId(1));
Assert.assertEquals(3,
dmm.getAvailableDevices(resourceName));
Assert.assertEquals(0,
dmm.getAllUsedDevices().get(resourceName).size());
Assert.assertEquals(3,
dmm.getAllAllowedDevices().get(resourceName).size());
-
- // A container c2 requests 3 device
- Container c2 = mockContainerWithDeviceRequest(1,
+ // check cgroup delete operation
+ verify(mockCGroupsHandler).deleteCGroup(
+ CGroupsHandler.CGroupController.DEVICES,
+ c1.getContainerId().toString());
+ // Case 2. A container c2 requests 3 device
+ Container c2 = mockContainerWithDeviceRequest(2,
resourceName,
3, false);
+ reset(mockShellWrapper);
+ reset(mockCGroupsHandler);
+ reset(mockPrivilegedExecutor);
+ when(mockShellWrapper.existFile(anyString())).thenReturn(true);
+ when(mockShellWrapper.getDeviceFileType(anyString())).thenReturn("c");
// preStart
adapter.getDeviceResourceHandler().preStart(c2);
// check book keeping
@@ -191,19 +214,37 @@ public void testBasicWorkflow()
dmm.getAllUsedDevices().get(resourceName).size());
Assert.assertEquals(3,
dmm.getAllAllowedDevices().get(resourceName).size());
+ Assert.assertEquals(3,
+ dmm.getAllocatedDevices(resourceName, c2.getContainerId()).size());
+ verify(mockShellWrapper, times(0)).getDeviceFileType(anyString());
+ // check device cgroup create operation
+ verify(mockCGroupsHandler).createCGroup(
+ CGroupsHandler.CGroupController.DEVICES,
+ c2.getContainerId().toString());
+ // check device cgroup update operation
+ checkCgroupOperation(c2.getContainerId().toString(), 1,
+ null, "256:0,256:1,256:2");
// postComplete
- adapter.getDeviceResourceHandler().postComplete(getContainerId(1));
+ adapter.getDeviceResourceHandler().postComplete(getContainerId(2));
Assert.assertEquals(3,
dmm.getAvailableDevices(resourceName));
Assert.assertEquals(0,
dmm.getAllUsedDevices().get(resourceName).size());
Assert.assertEquals(3,
dmm.getAllAllowedDevices().get(resourceName).size());
-
- // A container c3 request 0 device
- Container c3 = mockContainerWithDeviceRequest(1,
+ // check cgroup delete operation
+ verify(mockCGroupsHandler).deleteCGroup(
+ CGroupsHandler.CGroupController.DEVICES,
+ c2.getContainerId().toString());
+ // Case 3. A container c3 request 0 device
+ Container c3 = mockContainerWithDeviceRequest(3,
resourceName,
0, false);
+ reset(mockShellWrapper);
+ reset(mockCGroupsHandler);
+ reset(mockPrivilegedExecutor);
+ when(mockShellWrapper.existFile(anyString())).thenReturn(true);
+ when(mockShellWrapper.getDeviceFileType(anyString())).thenReturn("c");
// preStart
adapter.getDeviceResourceHandler().preStart(c3);
// check book keeping
@@ -213,14 +254,57 @@ public void testBasicWorkflow()
dmm.getAllUsedDevices().get(resourceName).size());
Assert.assertEquals(3,
dmm.getAllAllowedDevices().get(resourceName).size());
+ verify(mockShellWrapper, times(3)).getDeviceFileType(anyString());
+ // check device cgroup create operation
+ verify(mockCGroupsHandler).createCGroup(
+ CGroupsHandler.CGroupController.DEVICES,
+ c3.getContainerId().toString());
+ // check device cgroup update operation
+ checkCgroupOperation(c3.getContainerId().toString(), 1,
+ "c-256:0-rwm,c-256:1-rwm,c-256:2-rwm", null);
// postComplete
- adapter.getDeviceResourceHandler().postComplete(getContainerId(1));
+ adapter.getDeviceResourceHandler().postComplete(getContainerId(3));
Assert.assertEquals(3,
dmm.getAvailableDevices(resourceName));
Assert.assertEquals(0,
dmm.getAllUsedDevices().get(resourceName).size());
Assert.assertEquals(3,
dmm.getAllAllowedDevices().get(resourceName).size());
+ Assert.assertEquals(0,
+ dmm.getAllocatedDevices(resourceName, c3.getContainerId()).size());
+ // check cgroup delete operation
+ verify(mockCGroupsHandler).deleteCGroup(
+ CGroupsHandler.CGroupController.DEVICES,
+ c3.getContainerId().toString());
+ }
+
+ private void checkCgroupOperation(String cId,
+ int invokeTimesOfPrivilegedExecutor,
+ String excludedParam, String allowedParam)
+ throws PrivilegedOperationException, ResourceHandlerException {
+ verify(mockCGroupsHandler).createCGroup(
+ CGroupsHandler.CGroupController.DEVICES,
+ cId);
+ // capture the operation given to the privileged executor and check its type
+ ArgumentCaptor args =
+ ArgumentCaptor.forClass(PrivilegedOperation.class);
+ verify(mockPrivilegedExecutor, times(invokeTimesOfPrivilegedExecutor))
+ .executePrivilegedOperation(args.capture(), eq(true));
+ Assert.assertEquals(PrivilegedOperation.OperationType.DEVICE,
+ args.getValue().getOperationType());
+ List expectedArgs = new ArrayList<>(); // expected CLI arguments, in order
+ expectedArgs.add(DeviceResourceHandlerImpl.CONTAINER_ID_CLI_OPTION);
+ expectedArgs.add(cId);
+ if (excludedParam != null && !excludedParam.isEmpty()) {
+ expectedArgs.add(DeviceResourceHandlerImpl.EXCLUDED_DEVICES_CLI_OPTION);
+ expectedArgs.add(excludedParam);
+ }
+ if (allowedParam != null && !allowedParam.isEmpty()) {
+ expectedArgs.add(DeviceResourceHandlerImpl.ALLOWED_DEVICES_CLI_OPTION);
+ expectedArgs.add(allowedParam);
+ }
+ Assert.assertArrayEquals(expectedArgs.toArray(),
+ args.getValue().getArguments().toArray());
}
@Test
@@ -251,6 +335,7 @@ public void testStoreDeviceSchedulerManagerState()
NMStateStoreService realStoreService = new NMMemoryStateStoreService();
NMStateStoreService storeService = spy(realStoreService);
when(context.getNMStateStore()).thenReturn(storeService);
+ when(context.getConf()).thenReturn(this.conf);
doNothing().when(storeService).storeAssignedResources(isA(Container.class),
isA(String.class),
isA(ArrayList.class));
@@ -395,6 +480,7 @@ public void testAssignedDeviceCleanupWhenStoreOpFails()
NodeManager.NMContext context = mock(NodeManager.NMContext.class);
NMStateStoreService realStoreService = new NMMemoryStateStoreService();
NMStateStoreService storeService = spy(realStoreService);
+ when(context.getConf()).thenReturn(this.conf);
when(context.getNMStateStore()).thenReturn(storeService);
doThrow(new IOException("Exception ...")).when(storeService)
.storeAssignedResources(isA(Container.class),
@@ -448,6 +534,7 @@ public void testPreferPluginScheduler() throws IOException, YarnException {
NodeManager.NMContext context = mock(NodeManager.NMContext.class);
NMStateStoreService storeService = mock(NMStateStoreService.class);
when(context.getNMStateStore()).thenReturn(storeService);
+ when(context.getConf()).thenReturn(this.conf);
doNothing().when(storeService).storeAssignedResources(isA(Container.class),
isA(String.class),
isA(ArrayList.class));
@@ -526,6 +613,7 @@ public void testNMResourceInfoRESTAPI() throws IOException, YarnException {
NodeManager.NMContext context = mock(NodeManager.NMContext.class);
NMStateStoreService storeService = mock(NMStateStoreService.class);
when(context.getNMStateStore()).thenReturn(storeService);
+ when(context.getConf()).thenReturn(this.conf);
doNothing().when(storeService).storeAssignedResources(isA(Container.class),
isA(String.class),
isA(ArrayList.class));
@@ -584,6 +672,206 @@ public void testNMResourceInfoRESTAPI() throws IOException, YarnException {
Assert.assertEquals(3, response.getTotalDevices().size());
}
+ /**
+ * Tests how the Docker run command is updated for a container that uses the
+ * Docker runtime with a device plugin that works like Nvidia Docker v1.
+ * */
+ @Test
+ public void testDeviceResourceDockerRuntimePlugin1() throws Exception {
+ NodeManager.NMContext context = mock(NodeManager.NMContext.class);
+ NMStateStoreService storeService = mock(NMStateStoreService.class);
+ when(context.getNMStateStore()).thenReturn(storeService);
+ when(context.getConf()).thenReturn(this.conf);
+ doNothing().when(storeService).storeAssignedResources(isA(Container.class),
+ isA(String.class),
+ isA(ArrayList.class));
+ // Init scheduler manager
+ DeviceMappingManager dmm = new DeviceMappingManager(context);
+ DeviceMappingManager spyDmm = spy(dmm);
+ ResourcePluginManager rpm = mock(ResourcePluginManager.class);
+ when(rpm.getDeviceMappingManager()).thenReturn(spyDmm);
+ // Init a plugin
+ MyPlugin plugin = new MyPlugin();
+ MyPlugin spyPlugin = spy(plugin);
+ String resourceName = MyPlugin.RESOURCE_NAME;
+ // Init an adapter for the plugin
+ DevicePluginAdapter adapter = new DevicePluginAdapter(
+ resourceName,
+ spyPlugin, spyDmm);
+ adapter.initialize(context);
+ // Bootstrap. NOTE(review): this second initialize() looks redundant
+ adapter.initialize(context);
+ adapter.createResourceHandler(context,
+ mockCGroupsHandler, mockPrivilegedExecutor);
+ adapter.getDeviceResourceHandler().bootstrap(conf);
+ // Case 1. A container requests the Docker runtime and 1 device
+ Container c1 = mockContainerWithDeviceRequest(1, resourceName, 1, true);
+ // generate spec based on v1
+ spyPlugin.setDevicePluginVersion("v1");
+ // preStart will do allocation
+ adapter.getDeviceResourceHandler().preStart(c1);
+ Set allocatedDevice = spyDmm.getAllocatedDevices(resourceName,
+ c1.getContainerId());
+ reset(spyDmm);
+ // c1 requests the Docker runtime, so the parent cgroup is created
+ // but no cgroups update operation is needed.
+ // check device cgroup create operation
+ verify(mockCGroupsHandler).createCGroup(
+ CGroupsHandler.CGroupController.DEVICES,
+ c1.getContainerId().toString());
+ // ensure no cgroups update operation
+ verify(mockPrivilegedExecutor, times(0))
+ .executePrivilegedOperation(
+ any(PrivilegedOperation.class), anyBoolean());
+ DockerCommandPlugin dcp = adapter.getDockerCommandPluginInstance();
+ // When DockerLinuxContainerRuntime invokes the DockerCommandPlugin instance,
+ // the first step is to create the volume
+ DockerVolumeCommand dvc = dcp.getCreateDockerVolumeCommand(c1);
+ // ensure that the allocation is fetched once from the device mapping manager
+ verify(spyDmm).getAllocatedDevices(resourceName, c1.getContainerId());
+ // ensure that the plugin's onDevicesAllocated is invoked for both runtimes
+ verify(spyPlugin).onDevicesAllocated(
+ allocatedDevice,
+ YarnRuntimeType.RUNTIME_DEFAULT);
+ verify(spyPlugin).onDevicesAllocated(
+ allocatedDevice,
+ YarnRuntimeType.RUNTIME_DOCKER);
+ Assert.assertEquals("nvidia-docker", dvc.getDriverName());
+ Assert.assertEquals("create", dvc.getSubCommand());
+ Assert.assertEquals("nvidia_driver_352.68", dvc.getVolumeName());
+
+ // then the DockerLinuxContainerRuntime will update the docker run command
+ DockerRunCommand drc =
+ new DockerRunCommand(c1.getContainerId().toString(), "user",
+ "image/tensorflow");
+ // reset so that the invocations made above are not counted below
+ reset(spyPlugin);
+ reset(spyDmm);
+ // Second, update the run command.
+ dcp.updateDockerRunCommand(drc, c1);
+ // The spec is already generated in getCreateDockerVolumeCommand
+ // and there should be a cache hit for the DeviceRuntimeSpec.
+ verify(spyPlugin, times(0)).onDevicesAllocated(
+ allocatedDevice,
+ YarnRuntimeType.RUNTIME_DOCKER);
+ // ensure that the allocation comes from the cache instead of the device
+ // mapping manager
+ verify(spyDmm, times(0)).getAllocatedDevices(resourceName,
+ c1.getContainerId());
+ String runStr = drc.toString();
+ Assert.assertTrue(
+ runStr.contains("nvidia_driver_352.68:/usr/local/nvidia:ro"));
+ Assert.assertTrue(runStr.contains("/dev/hdwA0:/dev/hdwA0"));
+ // Third, cleanup in getCleanupDockerVolumesCommand
+ dcp.getCleanupDockerVolumesCommand(c1);
+ // Ensure the device plugin's onDevicesReleased is invoked
+ verify(spyPlugin).onDevicesReleased(allocatedDevice);
+ // If c1 runs again, no cache is used for the allocation or the spec
+ dcp.getCreateDockerVolumeCommand(c1);
+ verify(spyDmm).getAllocatedDevices(resourceName, c1.getContainerId());
+ verify(spyPlugin).onDevicesAllocated(
+ allocatedDevice,
+ YarnRuntimeType.RUNTIME_DOCKER);
+ }
+
+ /**
+ * Tests how the Docker run command is updated for a container that uses the
+ * Docker runtime with a device plugin that works like Nvidia Docker v2.
+ * */
+ @Test
+ public void testDeviceResourceDockerRuntimePlugin2() throws Exception {
+ NodeManager.NMContext context = mock(NodeManager.NMContext.class);
+ NMStateStoreService storeService = mock(NMStateStoreService.class);
+ when(context.getNMStateStore()).thenReturn(storeService);
+ when(context.getConf()).thenReturn(this.conf);
+ doNothing().when(storeService).storeAssignedResources(isA(Container.class),
+ isA(String.class),
+ isA(ArrayList.class));
+ // Init scheduler manager
+ DeviceMappingManager dmm = new DeviceMappingManager(context);
+ DeviceMappingManager spyDmm = spy(dmm);
+ ResourcePluginManager rpm = mock(ResourcePluginManager.class);
+ when(rpm.getDeviceMappingManager()).thenReturn(spyDmm);
+ // Init a plugin
+ MyPlugin plugin = new MyPlugin();
+ MyPlugin spyPlugin = spy(plugin);
+ String resourceName = MyPlugin.RESOURCE_NAME;
+ // Init an adapter for the plugin
+ DevicePluginAdapter adapter = new DevicePluginAdapter(
+ resourceName,
+ spyPlugin, spyDmm);
+ adapter.initialize(context);
+ // Bootstrap. NOTE(review): this second initialize() looks redundant
+ adapter.initialize(context);
+ adapter.createResourceHandler(context,
+ mockCGroupsHandler, mockPrivilegedExecutor);
+ adapter.getDeviceResourceHandler().bootstrap(conf);
+ // Case 1. A container requests the Docker runtime and 2 devices
+ Container c1 = mockContainerWithDeviceRequest(1, resourceName, 2, true);
+ // generate spec based on v2
+ spyPlugin.setDevicePluginVersion("v2");
+ // preStart will do allocation
+ adapter.getDeviceResourceHandler().preStart(c1);
+ Set allocatedDevice = spyDmm.getAllocatedDevices(resourceName,
+ c1.getContainerId());
+ reset(spyDmm);
+ // c1 requests the Docker runtime, so the parent cgroup is created
+ // but no cgroups update operation is needed.
+ // check device cgroup create operation
+ verify(mockCGroupsHandler).createCGroup(
+ CGroupsHandler.CGroupController.DEVICES,
+ c1.getContainerId().toString());
+ // ensure no cgroups update operation
+ verify(mockPrivilegedExecutor, times(0))
+ .executePrivilegedOperation(
+ any(PrivilegedOperation.class), anyBoolean());
+ DockerCommandPlugin dcp = adapter.getDockerCommandPluginInstance();
+ // When DockerLinuxContainerRuntime invokes the DockerCommandPlugin instance,
+ // the first step is to ask for a volume creation command
+ DockerVolumeCommand dvc = dcp.getCreateDockerVolumeCommand(c1);
+ // ensure that the allocation is fetched once from the device mapping manager
+ verify(spyDmm).getAllocatedDevices(resourceName, c1.getContainerId());
+ // ensure that the plugin's onDevicesAllocated is invoked for both runtimes
+ verify(spyPlugin).onDevicesAllocated(
+ allocatedDevice,
+ YarnRuntimeType.RUNTIME_DEFAULT);
+ verify(spyPlugin).onDevicesAllocated(
+ allocatedDevice,
+ YarnRuntimeType.RUNTIME_DOCKER);
+ // No volume creation request is expected for v2
+ Assert.assertNull(dvc);
+
+ // then the DockerLinuxContainerRuntime will update the docker run command
+ DockerRunCommand drc =
+ new DockerRunCommand(c1.getContainerId().toString(), "user",
+ "image/tensorflow");
+ // reset so that the invocations made above are not counted below
+ reset(spyPlugin);
+ reset(spyDmm);
+ // Second, update the run command.
+ dcp.updateDockerRunCommand(drc, c1);
+ // The spec is already generated in getCreateDockerVolumeCommand
+ // and there should be a cache hit for the DeviceRuntimeSpec.
+ verify(spyPlugin, times(0)).onDevicesAllocated(
+ allocatedDevice,
+ YarnRuntimeType.RUNTIME_DOCKER);
+ // ensure the device mapping manager is not asked for the allocation again
+ verify(spyDmm, times(0)).getAllocatedDevices(resourceName,
+ c1.getContainerId());
+ Assert.assertEquals("0,1", drc.getEnv().get("NVIDIA_VISIBLE_DEVICES"));
+ Assert.assertTrue(drc.toString().contains("runtime=nvidia"));
+ // Third, cleanup in getCleanupDockerVolumesCommand
+ dcp.getCleanupDockerVolumesCommand(c1);
+ // Ensure the device plugin's onDevicesReleased is invoked
+ verify(spyPlugin).onDevicesReleased(allocatedDevice);
+ // If c1 runs again, no cache is used for the allocation or the spec
+ dcp.getCreateDockerVolumeCommand(c1);
+ verify(spyDmm).getAllocatedDevices(resourceName, c1.getContainerId());
+ verify(spyPlugin).onDevicesAllocated(
+ allocatedDevice,
+ YarnRuntimeType.RUNTIME_DOCKER);
+ }
+
private static ContainerId getContainerId(int id) {
return ContainerId.newContainerId(ApplicationAttemptId
.newInstance(ApplicationId.newInstance(1234L, 1), 1), id);
@@ -591,6 +879,15 @@ private static ContainerId getContainerId(int id) {
private class MyPlugin implements DevicePlugin, DevicePluginScheduler {
private final static String RESOURCE_NAME = "cmpA.com/hdwA";
+    // Flavor of Docker integration this fake vendor plugin imitates:
+    // "v1" follows the Nvidia Docker v1 approach, "v2" the v2 approach.
+    private String devicePluginVersion = "v2";
+
+    /** Selects the flavor ("v1" or "v2") passed on to generateSpec. */
+    public void setDevicePluginVersion(String version) {
+      this.devicePluginVersion = version;
+    }
+
@Override
public DeviceRegisterRequest getRegisterRequestInfo() {
return DeviceRegisterRequest.Builder.newInstance()
@@ -613,7 +910,7 @@ public DeviceRegisterRequest getRegisterRequestInfo() {
.setId(1)
.setDevPath("/dev/hdwA1")
.setMajorNumber(256)
- .setMinorNumber(0)
+ .setMinorNumber(1)
.setBusID("0000:80:01.0")
.setHealthy(true)
.build());
@@ -621,7 +918,7 @@ public DeviceRegisterRequest getRegisterRequestInfo() {
.setId(2)
.setDevPath("/dev/hdwA2")
.setMajorNumber(256)
- .setMinorNumber(0)
+ .setMinorNumber(2)
.setBusID("0000:80:02.0")
.setHealthy(true)
.build());
@@ -631,12 +928,69 @@ public DeviceRegisterRequest getRegisterRequestInfo() {
@Override
public DeviceRuntimeSpec onDevicesAllocated(Set allocatedDevices,
YarnRuntimeType yarnRuntime) throws Exception {
+ if (yarnRuntime == YarnRuntimeType.RUNTIME_DEFAULT) {
+ return null;
+ }
+ if (yarnRuntime == YarnRuntimeType.RUNTIME_DOCKER) {
+ return generateSpec(devicePluginVersion, allocatedDevices);
+ }
return null;
}
+    /**
+     * Builds the Docker spec for the allocated devices. "v1" imitates Nvidia
+     * Docker v1, "v2" imitates v2; any other (or null) version adds nothing.
+     */
+    private DeviceRuntimeSpec generateSpec(String version,
+        Set<Device> allocatedDevices) {
+      DeviceRuntimeSpec.Builder builder =
+          DeviceRuntimeSpec.Builder.newInstance();
+      if ("v1".equals(version)) {
+        // Nvidia v1 example; this information comes from its RESTful API:
+        // --device=/dev/nvidiactl --device=/dev/nvidia-uvm
+        // --device=/dev/nvidia0
+        // --volume-driver=nvidia-docker
+        // --volume=nvidia_driver_352.68:/usr/local/nvidia:ro
+        String volumeDriverName = "nvidia-docker";
+        String volumeToBeCreated = "nvidia_driver_352.68";
+        String volumePathInContainer = "/usr/local/nvidia";
+        // describe volumes to be created and mounted
+        builder.addVolumeSpec(
+            VolumeSpec.Builder.newInstance()
+                .setVolumeDriver(volumeDriverName)
+                .setVolumeName(volumeToBeCreated)
+                .setVolumeOperation(VolumeSpec.CREATE).build())
+            .addMountVolumeSpec(
+                MountVolumeSpec.Builder.newInstance()
+                    .setHostPath(volumeToBeCreated)
+                    .setMountPath(volumePathInContainer)
+                    .setReadOnly(true).build());
+        // describe devices to be mounted
+        for (Device device : allocatedDevices) {
+          builder.addMountDeviceSpec(
+              MountDeviceSpec.Builder.newInstance()
+                  .setDevicePathInHost(device.getDevPath())
+                  .setDevicePathInContainer(device.getDevPath())
+                  .setDevicePermission(MountDeviceSpec.RW).build());
+        }
+      } else if ("v2".equals(version)) {
+        // Join separator-first so an empty device set yields "" (no throw).
+        StringBuilder minorNumbers = new StringBuilder();
+        for (Device device : allocatedDevices) {
+          if (minorNumbers.length() > 0) {
+            minorNumbers.append(',');
+          }
+          minorNumbers.append(device.getMinorNumber());
+        }
+        builder.addEnv("NVIDIA_VISIBLE_DEVICES", minorNumbers.toString())
+            .setContainerRuntime("nvidia");
+      }
+      return builder.build();
+    }
+
 @Override
 public void onDevicesReleased(Set releasedDevices) {
-
+      // Intentionally empty: this test plugin does nothing on release.
 }
@Override
diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/nvidia/com/TestNvidiaGpuPlugin.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/nvidia/com/TestNvidiaGpuPlugin.java
new file mode 100644
index 00000000000..33154d84eaf
--- /dev/null
+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/nvidia/com/TestNvidiaGpuPlugin.java
@@ -0,0 +1,108 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.nvidia.com;
+
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRuntimeSpec;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.YarnRuntimeType;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.com.nvidia.NvidiaGPUPluginForRuntimeV2;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.Set;
+import java.util.TreeSet;
+
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * Test case for Nvidia GPU device plugin.
+ */
+public class TestNvidiaGpuPlugin {
+
+  /** Major device number shared by the fake GPUs; "c3" in hex. */
+  private static final int GPU_MAJOR_NUMBER = 195;
+
+  /**
+   * Builds a healthy fake GPU whose id and minor number both equal
+   * {@code index} and whose device path is "/dev/nvidia" plus that index.
+   */
+  private static Device newGpu(int index, String busId) {
+    return Device.Builder.newInstance()
+        .setId(index).setHealthy(true)
+        .setBusID(busId)
+        .setDevPath("/dev/nvidia" + index)
+        .setMajorNumber(GPU_MAJOR_NUMBER)
+        .setMinorNumber(index).build();
+  }
+
+  /**
+   * Device discovery should turn the mocked command output into one
+   * {@link Device} per listed GPU.
+   */
+  @Test
+  public void testGetNvidiaDevices() throws Exception {
+    NvidiaGPUPluginForRuntimeV2.NvidiaCommandExecutor mockShell =
+        mock(NvidiaGPUPluginForRuntimeV2.NvidiaCommandExecutor.class);
+    // One "index, bus id" line per fake GPU.
+    when(mockShell.getDeviceInfo())
+        .thenReturn("0, 00000000:04:00.0\n" + "1, 00000000:82:00.0");
+    // "major:minor"; the expected devices below read "c3" as hex (195).
+    when(mockShell.getMajorMinorInfo("nvidia0")).thenReturn("c3:0");
+    when(mockShell.getMajorMinorInfo("nvidia1")).thenReturn("c3:1");
+    NvidiaGPUPluginForRuntimeV2 plugin = new NvidiaGPUPluginForRuntimeV2();
+    plugin.setShellExecutor(mockShell);
+    plugin.setPathOfGpuBinary("/fake/nvidia-smi");
+
+    Set<Device> expectedDevices = new TreeSet<>();
+    expectedDevices.add(newGpu(0, "00000000:04:00.0"));
+    expectedDevices.add(newGpu(1, "00000000:82:00.0"));
+    Set<Device> devices = plugin.getDevices();
+    Assert.assertEquals(expectedDevices, devices);
+  }
+
+  /**
+   * The default runtime gets no spec; the Docker runtime gets the "nvidia"
+   * runtime plus the minor numbers in NVIDIA_VISIBLE_DEVICES.
+   */
+  @Test
+  public void testOnDeviceAllocated() throws Exception {
+    NvidiaGPUPluginForRuntimeV2 plugin = new NvidiaGPUPluginForRuntimeV2();
+    Set<Device> allocatedDevices = new TreeSet<>();
+
+    // Nothing allocated yet and the default runtime: no spec is expected.
+    DeviceRuntimeSpec spec = plugin.onDevicesAllocated(allocatedDevices,
+        YarnRuntimeType.RUNTIME_DEFAULT);
+    Assert.assertNull(spec);
+
+    // allocate one device
+    allocatedDevices.add(newGpu(0, "00000000:04:00.0"));
+    spec = plugin.onDevicesAllocated(allocatedDevices,
+        YarnRuntimeType.RUNTIME_DOCKER);
+    Assert.assertEquals("nvidia", spec.getContainerRuntime());
+    Assert.assertEquals("0", spec.getEnvs().get("NVIDIA_VISIBLE_DEVICES"));
+
+    // Allocate a second, distinct device (id 1, not a reused id 0).
+    allocatedDevices.add(newGpu(1, "00000000:82:00.0"));
+    spec = plugin.onDevicesAllocated(allocatedDevices,
+        YarnRuntimeType.RUNTIME_DOCKER);
+    Assert.assertEquals("nvidia", spec.getContainerRuntime());
+    Assert.assertEquals("0,1", spec.getEnvs().get("NVIDIA_VISIBLE_DEVICES"));
+  }
+}