diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/api/deviceplugin/DevicePluginScheduler.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/api/deviceplugin/DevicePluginScheduler.java index 80ffdf73829..479a19db603 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/api/deviceplugin/DevicePluginScheduler.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/api/deviceplugin/DevicePluginScheduler.java @@ -18,6 +18,7 @@ package org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin; +import java.util.Map; import java.util.Set; /** @@ -29,10 +30,15 @@ * Called when allocating devices. The framework will do all device book * keeping and fail recovery. So this hook could be stateless and only do * scheduling based on available devices passed in. It could be - * invoked multiple times by the framework. + * invoked multiple times by the framework. The hint in environment variables + * passed in could be potentially used in making better scheduling decision. + * For instance, GPU scheduling might support different kind of policy. The + * container can set it through environment variables. * @param availableDevices Devices allowed to be chosen from. * @param count Number of device to be allocated. + * @param env Environment variables of the container. * @return A set of {@link Device} allocated * */ - Set allocateDevices(Set availableDevices, int count); + Set allocateDevices(Set availableDevices, int count, + Map env); } diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/com/nvidia/NvidiaGPUPlugin.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/com/nvidia/NvidiaGPUPlugin.java index cac04958b50..81b84a1eed7 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/com/nvidia/NvidiaGPUPlugin.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/com/nvidia/NvidiaGPUPlugin.java @@ -24,6 +24,7 @@ import org.apache.hadoop.yarn.exceptions.YarnException; import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device; import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePlugin; +import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePluginScheduler; import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRegisterRequest; import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRuntimeSpec; import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.YarnRuntimeType; @@ -32,15 +33,20 @@ import java.io.File; import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; import java.util.Map; import java.util.Set; import java.util.TreeSet; /** * Nvidia GPU plugin supporting both Nvidia Docker v2 and non-Docker container. + * It has topology aware as well as simple scheduling ability. * */ -public class NvidiaGPUPlugin implements DevicePlugin { +public class NvidiaGPUPlugin implements DevicePlugin, DevicePluginScheduler { public static final Logger LOG = LoggerFactory.getLogger( NvidiaGPUPlugin.class); @@ -66,6 +72,45 @@ private static final Set DEFAULT_BINARY_SEARCH_DIRS = ImmutableSet.of( "/usr/bin", "/bin", "/usr/local/nvidia/bin"); + private boolean topoInitialized = false; + + private Set lastTimeFoundDevices; + + /** + * It caches the combination of different devices and the communication cost. + * The key is device count + * The value is a map whose key is device combination, value is cost. + * For instance: + * { 2=> {[device1,device2]=>0, [device1,device3]=>10} + * 3 => {[device1,device2,device3]=>10, [device2,device3,device5]=>20}, + * } + * */ + private Map, Integer>> costTable = new HashMap<>(); + + /** + * The key is a pair of minors. For instance, "0-1" indicates 0 to 1 + * The value is weight between the two devices. + * */ + private Map devicePairToWeight = new HashMap<>(); + + /** + * The container set this environment variable to tell the scheduler what's + * the policy to use when do scheduling + * */ + public static final String TOPOLOGY_POLICY_ENV_KEY = "NVIDIA_TOPO_POLICY"; + + /** + * Schedule policy that prefer the faster GPU-GPU communication. + * Suitable for heavy GPU computation workload generally. + * */ + public static final String TOPOLOGY_POLICY_PACK = "PACK"; + + /** + * Schedule policy that prefer the faster CPU-GPU communication. + * Suitable for heavy CPU-GPU IO operations generally. + * */ + public static final String TOPOLOGY_POLICY_SPREAD = "SPREAD"; + @Override public DeviceRegisterRequest getRegisterRequestInfo() throws Exception { return DeviceRegisterRequest.Builder.newInstance() @@ -96,8 +141,11 @@ public DeviceRegisterRequest getRegisterRequestInfo() throws Exception { .setDevPath("/dev/" + DEV_NAME_PREFIX + minorNumber) .setHealthy(true) .build()); + id++; } } + // cache it which help to topology scheduling + lastTimeFoundDevices = r; return r; } catch (IOException e) { LOG.debug("Failed to get output from " + pathOfGpuBinary); @@ -154,6 +202,347 @@ private String getMajorNumber(String devName) { return output; } + @Override + public Set allocateDevices(Set availableDevices, int count, + Map envs) { + Set allocation = new TreeSet<>(); + // corner cases + if (count == 1 || count == availableDevices.size()) { + basicSchedule(allocation, count, availableDevices); + return allocation; + } + try { + if (!topoInitialized) { + initCostTable(); + } + // topology aware scheduling + topologyAwareSchedule(allocation, count, + envs, availableDevices, this.costTable); + if (allocation.size() != count) { + LOG.error("Failed to do topology scheduling. Skip to use basic " + + "scheduling"); + } + return allocation; + } catch (IOException e) { + LOG.error("Error in getting GPU topology info. " + + "Skip topology aware scheduling"); + } + // basic scheduling + basicSchedule(allocation, count, availableDevices); + return allocation; + } + + @VisibleForTesting + public void initCostTable() throws IOException { + // get topology + String topo = shellExecutor.getTopologyInfo(); + // build the graph + parseTopo(topo, devicePairToWeight); + // build the cost table of different device combinations + if (lastTimeFoundDevices == null) { + try { + getDevices(); + } catch (Exception e) { + LOG.error("Failed to get devices!"); + return; + } + } + buildCostTable(costTable, lastTimeFoundDevices); + this.topoInitialized = true; + } + + /** + * Generate combination of devices and its cost. + * costTable + * */ + private void buildCostTable( + Map, Integer>> costTable, + Set lastTimeFoundDevices) { + Device[] deviceList = new Device[lastTimeFoundDevices.size()]; + lastTimeFoundDevices.toArray(deviceList); + generateAllDeviceCombination(costTable, deviceList, deviceList.length); + } + + /** + * For every possible combination of i elements. + * We generate a map whose key is the combination, value is cost. + */ + private void generateAllDeviceCombination( + Map, Integer>> costTable, + Device[] allDevices, int n) { + // allocated devices count range from 1 to n-1 + for (int i = 2; i < n; i++) { + Map, Integer> combinationToCost = + new HashMap<>(); + buildCombination(combinationToCost, allDevices, n, i); + costTable.put(i, combinationToCost); + } + } + + private void buildCombination(Map, Integer> combinationToCost, + Device[] allDevices, int n, int r) { + // A temporary list to store all combination one by one + Device[] subDeviceList = new Device[r]; + combinationRecursive(combinationToCost, allDevices, subDeviceList, + 0, n - 1, 0, r); + } + + /** + * Populate combination to cost map recursively. + * + * @param cTc combinationToCost map. The key is device set, the value is cost + * @param allDevices all devices used to assign value to subDevicelist + * @param subDeviceList store a subset of devices temporary + * @param start start index in the allDevices + * @param end last index in the allDevices + * @param index dynamic index in the subDeviceList need to be assigned + * @param r the length of the subDeviceList + */ + void combinationRecursive(Map, Integer> cTc, + Device[] allDevices, Device[] subDeviceList, + int start, int end, int index, int r) { + // sub device list's length is ready to compute the cost + if (index == r) { + Set oneSet = new TreeSet<>(Arrays.asList(subDeviceList)); + int cost = computeCostOfDevices(subDeviceList); + cTc.put(oneSet, cost); + return; + } + for (int i = start; i <= end; i++) { + subDeviceList[index] = allDevices[i]; + combinationRecursive(cTc, allDevices, subDeviceList, + i + 1, end, index + 1, r); + } + } + + /** + * The cost function used to calculate costs of a sub set of devices. + * It calculate link weight of each pair in non-duplicated combination of + * devices. + */ + @VisibleForTesting + public int computeCostOfDevices(Device[] devices) { + int cost = 0; + String gpuIndex0; + String gpuIndex1; + for (int i = 0; i < devices.length; i++) { + gpuIndex0 = String.valueOf(devices[i].getMinorNumber()); + for (int j = i + 1; j < devices.length; j++) { + gpuIndex1 = String.valueOf(devices[j].getMinorNumber()); + cost += this.devicePairToWeight.get(gpuIndex0 + "-" + gpuIndex1); + } + } + return cost; + } + + /** + * Topology Aware schedule algorithm. + * It doesn't consider CPU affinity or NUMA or bus bandwidths. + * It support two plicy: "spread" and "pack" which can be set by container's + * environment variable. Use pack by default which means prefer the faster + * GPU-GPU. "Spread" means prefer the faster CPU-GPU. + * It can potentially be extend to take GPU attribute like GPU chip memory + * into consideration. + * */ + @VisibleForTesting + public void topologyAwareSchedule(Set allocation, int count, + Map envs, + Set availableDevices, + Map, Integer>> costTable) { + int num = 0; + String policy = envs.get(TOPOLOGY_POLICY_ENV_KEY); + if (policy == null) { + policy = TOPOLOGY_POLICY_PACK; + } + + /** + * Get combinations from costTable given the count of device want to + * allocate. + * */ + if (costTable == null) { + LOG.error("No cost table initialized!"); + return; + } + Map, Integer> combinationsToCost = costTable.get(count); + List, Integer>> listSortedByCost = + new LinkedList<>(combinationsToCost.entrySet()); + + // the container needs PACK policy + if (policy.equalsIgnoreCase(TOPOLOGY_POLICY_PACK)) { + Collections.sort(listSortedByCost, + (o1, o2) -> (o1.getValue()).compareTo(o2.getValue())); + // search from low cost to high cost for combinations of count devices + for (Map.Entry, Integer> entry : listSortedByCost) { + if (availableDevices.containsAll(entry.getKey())) { + allocation.addAll(entry.getKey()); + LOG.info("Topology scheduler allocated: " + allocation); + return; + } + } + LOG.error("Unknown error happened in topology scheduler"); + } + // the container needs spread policy + if (policy.equalsIgnoreCase(TOPOLOGY_POLICY_SPREAD)) { + Collections.sort(listSortedByCost, + (o1, o2) -> (o2.getValue()).compareTo(o1.getValue())); + // search from high cost to low cost + for (Map.Entry, Integer> entry : listSortedByCost) { + if (availableDevices.containsAll(entry.getKey())) { + allocation.addAll(entry.getKey()); + LOG.info("Topology scheduler allocated: " + allocation); + return; + } + } + LOG.error("Unknown error happened in topology scheduler"); + } + } + + @VisibleForTesting + public void basicSchedule(Set allocation, int count, + Set availableDevices) { + // Basic scheduling + // allocate all available + if (count == availableDevices.size()) { + allocation.addAll(availableDevices); + return; + } + int number = 0; + for (Device d : availableDevices) { + allocation.add(d); + number++; + if (number == count) { + break; + } + } + } + + /** + * A typical sample topo output: + * + * GPU0 GPU1 GPU2 GPU3 CPU Affinity + * GPU0 X PHB SOC SOC 0-31 + * GPU1 PHB X SOC SOC 0-31 + * GPU2 SOC SOC X PHB 0-31 + * GPU3 SOC SOC PHB X 0-31 + * + * + * Legend: + * + * X = Self + * SOC = Connection traversing PCIe as well as the SMP link between + * CPU sockets(e.g. QPI) + * PHB = Connection traversing PCIe as well as a PCIe Host Bridge + * (typically the CPU) + * PXB = Connection traversing multiple PCIe switches + * (without traversing the PCIe Host Bridge) + * PIX = Connection traversing a single PCIe switch + * NV# = Connection traversing a bonded set of # NVLinks」 + * */ + public void parseTopo(String topo, + Map deviceLinkToWeight) { + String[] lines = topo.split("\n"); + int rowMinor; + int colMinor; + String legend; + String tempType; + for (String oneLine : lines) { + oneLine = oneLine.trim(); + if (oneLine.isEmpty()) { + continue; + } + // To the end. No more metrics info + if (oneLine.startsWith("Legend")) { + break; + } + // Skip header + if (oneLine.contains("Affinity")) { + continue; + } + String[] tokens = oneLine.split(("\\s+")); + String name = tokens[0]; + rowMinor = Integer.parseInt(name.substring(name.lastIndexOf("U") + 1)); + for (int i = 1; i < tokens.length; i++) { + tempType = tokens[i]; + colMinor = i - 1; + // self, skip + if (tempType.equals("X")) { + continue; + } + if (tempType.equals("SOC") || tempType.equals("SYS")) { + populateGraphEdgeWeight(DeviceLinkType.P2PLinkCrossCPUSocket, + rowMinor, colMinor, deviceLinkToWeight); + } + if (tempType.equals("PHB") || tempType.equals("NODE")) { + populateGraphEdgeWeight(DeviceLinkType.P2PLinkSameCPUSocket, + rowMinor, colMinor, deviceLinkToWeight); + } + if (tempType.equals("PXB")) { + populateGraphEdgeWeight(DeviceLinkType.P2PLinkMultiSwitch, + rowMinor, colMinor, deviceLinkToWeight); + } + if (tempType.equals("PIX")) { + populateGraphEdgeWeight(DeviceLinkType.P2PLinkSingleSwitch, + rowMinor, colMinor, deviceLinkToWeight); + } + if (tempType.startsWith("NV")) { + populateGraphEdgeWeight(DeviceLinkType.P2PLinkNVLink, + rowMinor, colMinor, deviceLinkToWeight); + } + } // end one line handling + } + } + + private void populateGraphEdgeWeight( + DeviceLinkType linkType, + int leftVertex, + int rightVertex, + Map deviceLinkToWeight) { + deviceLinkToWeight.putIfAbsent(leftVertex + "-" + rightVertex, + linkType.getWeight()); + } + + /** + * Different type of link + * */ + public enum DeviceLinkType { + /** + * For Nvdia GPU NVLink + * */ + P2PLinkNVLink(1), + + /** + * Connected to same CPU (Same NUMA node) + * */ + P2PLinkSameCPUSocket(2), + + /** + * Cross CPU through socket-level link (e.g. QPI). + * Usually cross NUMA node + * */ + P2PLinkCrossCPUSocket(4), + + /** + * Just need to traverse one PCIe switch to talk + * */ + P2PLinkSingleSwitch(16), + + /** + * Need to traverse multiple PCIe switch to talk + * */ + P2PLinkMultiSwitch(32); + + // A higher link level means slower communication + private int weight; + + public int getWeight() { + return weight; + } + + DeviceLinkType(int w) { + this.weight = w; + } + } + /** * A shell wrapper class easy for test. * */ @@ -174,6 +563,13 @@ public String getMajorMinorInfo(String devName) throws IOException { return shexec.getOutput(); } + // Get the topology metrics info from nvdia-smi + public String getTopologyInfo() throws IOException { + return Shell.execCommand(environment, + new String[]{pathOfGpuBinary, "topo", + "-m"}, MAX_EXEC_TIMEOUT_MS); + } + public void searchBinary() throws Exception { if (pathOfGpuBinary != null) { return; @@ -218,4 +614,20 @@ public void setShellExecutor( MyShellExecutor shellExecutor) { this.shellExecutor = shellExecutor; } + + @VisibleForTesting + public boolean isTopoInitialized() { + return topoInitialized; + } + + @VisibleForTesting + public Map, Integer>> getCostTable() { + return costTable; + } + + @VisibleForTesting + public Map getDevicePairToWeight() { + return devicePairToWeight; + } + } diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/DeviceMappingManager.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/DeviceMappingManager.java index c4003ca6c22..0e737747f6e 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/DeviceMappingManager.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/DeviceMappingManager.java @@ -19,6 +19,7 @@ package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.deviceframework; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Sets; import org.apache.commons.logging.Log; @@ -190,7 +191,8 @@ private synchronized DeviceAllocation internalAssignDevices( DevicePluginScheduler dps = devicePluginSchedulers.get(resourceName); // Prefer DevicePluginScheduler logic pickAndDoSchedule(allowedDevices, usedDevices, assignedDevices, - containerId, requestedDeviceCount, resourceName, dps); + containerId, requestedDeviceCount, resourceName, dps, + container.getLaunchContext().getEnvironment()); // Record in state store if we allocated anything if (!assignedDevices.isEmpty()) { @@ -310,7 +312,8 @@ private long getReleasingDevices(String resourceName) { private void pickAndDoSchedule(Set allowed, Map used, Set assigned, ContainerId containerId, int count, String resourceName, - DevicePluginScheduler dps) throws ResourceHandlerException { + DevicePluginScheduler dps, Map env) + throws ResourceHandlerException { if (null == dps) { LOG.debug("Customized device plugin scheduler is preferred " @@ -326,7 +329,8 @@ private void pickAndDoSchedule(Set allowed, // Pass in unmodifiable set Set dpsAllocated = dps.allocateDevices( Sets.difference(allowed, used.keySet()), - count); + count, + ImmutableMap.copyOf(env)); if (dpsAllocated.size() != count) { throw new ResourceHandlerException(dps.getClass() + " should allocate " + count diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/FakeTestDevicePlugin1.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/FakeTestDevicePlugin1.java index 12f106411b4..69736fd9696 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/FakeTestDevicePlugin1.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/FakeTestDevicePlugin1.java @@ -20,6 +20,7 @@ import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.*; +import java.util.Map; import java.util.Set; import java.util.TreeSet; @@ -62,7 +63,7 @@ public void onDevicesReleased(Set allocatedDevices) { @Override public Set allocateDevices(Set availableDevices, - int count) { + int count, Map env) { Set allocated = new TreeSet(); int number = 0; for (Device d : availableDevices) { diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/TestDevicePluginAdapter.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/TestDevicePluginAdapter.java index 457a5f2fcca..da10b559413 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/TestDevicePluginAdapter.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/TestDevicePluginAdapter.java @@ -77,6 +77,7 @@ import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyBoolean; import static org.mockito.Matchers.anyInt; +import static org.mockito.Matchers.anyMap; import static org.mockito.Matchers.anyString; import static org.mockito.Matchers.eq; import static org.mockito.Matchers.isA; @@ -570,7 +571,7 @@ public void testPreferPluginScheduler() throws IOException, YarnException { adapter.getDeviceResourceHandler().preStart(c1); // Use customized scheduler verify(spyPlugin, times(1)).allocateDevices( - any(TreeSet.class), anyInt()); + any(TreeSet.class), anyInt(), anyMap()); Assert.assertEquals(2, dmm.getAvailableDevices(resourceName)); Assert.assertEquals(1, @@ -994,7 +995,7 @@ public void onDevicesReleased(Set releasedDevices) { @Override public Set allocateDevices(Set availableDevices, - int count) { + int count, Map env) { Set allocated = new TreeSet<>(); int number = 0; for (Device d : availableDevices) { diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/nvidia/com/TestNvidiaGpuPlugin.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/nvidia/com/TestNvidiaGpuPlugin.java index 4b93cae8669..9a7aaf9d721 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/nvidia/com/TestNvidiaGpuPlugin.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/nvidia/com/TestNvidiaGpuPlugin.java @@ -24,11 +24,24 @@ import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.com.nvidia.NvidiaGPUPlugin; import org.junit.Assert; import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import java.io.IOException; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; import java.util.Set; import java.util.TreeSet; +import static org.mockito.Matchers.anyInt; +import static org.mockito.Matchers.anyMap; +import static org.mockito.Matchers.anySet; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; /** @@ -36,6 +49,9 @@ * */ public class TestNvidiaGpuPlugin { + private static final Logger LOG = + LoggerFactory.getLogger(TestNvidiaGpuPlugin.class); + @Test public void testGetNvidiaDevices() throws Exception { NvidiaGPUPlugin.MyShellExecutor mockShell = @@ -62,7 +78,7 @@ public void testGetNvidiaDevices() throws Exception { .setMajorNumber(195) .setMinorNumber(0).build()); expectedDevices.add(Device.Builder.newInstance() - .setId(0).setHealthy(true) + .setId(1).setHealthy(true) .setBusID("00000000:82:00.0") .setDevPath("/dev/nvidia1") .setMajorNumber(195) @@ -103,6 +119,265 @@ public void testOnDeviceAllocated() throws Exception { YarnRuntimeType.RUNTIME_DOCKER); Assert.assertEquals("nvidia", spec.getContainerRuntime()); Assert.assertEquals("0,1", spec.getEnvs().get("NVIDIA_VISIBLE_DEVICES")); + } + + private NvidiaGPUPlugin mockFourGPUPlugin() throws IOException { + String topoInfo = "\tGPU0\tGPU1\tGPU2\tGPU3\tCPU Affinity\n" + + "GPU0\t X \tPHB\tSOC\tSOC\t0-31\n" + + "GPU1\tPHB\t X \tSOC\tSOC\t0-31\n" + + "GPU2\tSOC\tSOC\t X \tPHB\t0-31\n" + + "GPU3\tSOC\tSOC\tPHB\t X \t0-31\n" + + "\n" + + "\n" + + " Legend:\n" + + "\n" + + " X = Self\n" + + " SOC = Connection traversing PCIe as well as the SMP link between\n" + + " CPU sockets(e.g. QPI)\n" + + " PHB = Connection traversing PCIe as well as a PCIe Host Bridge\n" + + " (typically the CPU)\n" + + " PXB = Connection traversing multiple PCIe switches\n" + + " (without traversing the PCIe Host Bridge)\n" + + " PIX = Connection traversing a single PCIe switch\n" + + " NV# = Connection traversing a bonded set of # NVLinks"; + + String deviceInfoShellOutput = "0, 00000000:04:00.0\n" + + "1, 00000000:82:00.0\n" + + "2, 00000000:83:00.0\n" + + "3, 00000000:84:00.0"; + String majorMinorNumber0 = "c3:0"; + String majorMinorNumber1 = "c3:1"; + String majorMinorNumber2 = "c3:2"; + String majorMinorNumber3 = "c3:3"; + NvidiaGPUPlugin.MyShellExecutor mockShell = + mock(NvidiaGPUPlugin.MyShellExecutor.class); + when(mockShell.getDeviceInfo()).thenReturn(deviceInfoShellOutput); + when(mockShell.getMajorMinorInfo("nvidia0")) + .thenReturn(majorMinorNumber0); + when(mockShell.getMajorMinorInfo("nvidia1")) + .thenReturn(majorMinorNumber1); + when(mockShell.getMajorMinorInfo("nvidia2")) + .thenReturn(majorMinorNumber2); + when(mockShell.getMajorMinorInfo("nvidia3")) + .thenReturn(majorMinorNumber3); + when(mockShell.getTopologyInfo()).thenReturn(topoInfo); + when(mockShell.getDeviceInfo()).thenReturn(deviceInfoShellOutput); + + NvidiaGPUPlugin plugin = new NvidiaGPUPlugin(); + plugin.setShellExecutor(mockShell); + plugin.setPathOfGpuBinary("/fake/nvidia-smi"); + return plugin; + } + + @Test + public void testTopologySchedulingWithPackPolicy() throws Exception { + NvidiaGPUPlugin plugin = mockFourGPUPlugin(); + NvidiaGPUPlugin spyPlugin = spy(plugin); + // cache the total devices + Set allDevices = spyPlugin.getDevices(); + // environment variable to use PACK policy + Map env = new HashMap<>(); + env.put(NvidiaGPUPlugin.TOPOLOGY_POLICY_ENV_KEY, + NvidiaGPUPlugin.TOPOLOGY_POLICY_PACK); + // Case 1. allocate 1 device + Set allocation = spyPlugin.allocateDevices(allDevices, 1, env); + // ensure no topology scheduling needed + Assert.assertEquals(allocation.size(), 1); + verify(spyPlugin).basicSchedule(anySet(), anyInt(), anySet()); + reset(spyPlugin); + // Case 2. allocate all available + allocation = spyPlugin.allocateDevices(allDevices, allDevices.size(), env); + Assert.assertEquals(allocation.size(), allDevices.size()); + verify(spyPlugin).basicSchedule(anySet(), anyInt(), anySet()); + // Case 3. allocate 2 devices + reset(spyPlugin); + int count = 2; + Map pairToWeight = spyPlugin.getDevicePairToWeight(); + allocation = spyPlugin.allocateDevices(allDevices, count, env); + Assert.assertEquals(allocation.size(), count); + // the costTable should be init and used topology scheduling + verify(spyPlugin).initCostTable(); + Assert.assertTrue(spyPlugin.isTopoInitialized()); + verify(spyPlugin).topologyAwareSchedule(anySet(), anyInt(), anyMap(), + anySet(), anyMap()); + Assert.assertEquals(allocation.size(), count); + Device[] allocatedDevices = + allocation.toArray(new Device[count]); + // Check weights + Assert.assertEquals(2, spyPlugin.computeCostOfDevices(allocatedDevices)); + // Case 4. allocate 3 devices + reset(spyPlugin); + count = 3; + allocation = spyPlugin.allocateDevices(allDevices, count, env); + Assert.assertEquals(allocation.size(), count); + // the costTable should be init and used topology scheduling + verify(spyPlugin, times(0)).initCostTable(); + Assert.assertTrue(spyPlugin.isTopoInitialized()); + verify(spyPlugin).topologyAwareSchedule(anySet(), anyInt(), anyMap(), + anySet(), anyMap()); + Assert.assertEquals(allocation.size(), count); + allocatedDevices = + allocation.toArray(new Device[count]); + // check weights + Assert.assertEquals(2 + 4 + 4, + spyPlugin.computeCostOfDevices(allocatedDevices)); + // Case 5. allocate 2 GPUs from three available devices + reset(spyPlugin); + Iterator iterator = allDevices.iterator(); + iterator.next(); + // remove GPU0 + iterator.remove(); + count = 2; + allocation = spyPlugin.allocateDevices(allDevices, count, env); + Assert.assertEquals(allocation.size(), count); + // the costTable should be init and used topology scheduling + verify(spyPlugin, times(0)).initCostTable(); + Assert.assertTrue(spyPlugin.isTopoInitialized()); + verify(spyPlugin).topologyAwareSchedule(anySet(), anyInt(), anyMap(), + anySet(), anyMap()); + Assert.assertEquals(allocation.size(), count); + allocatedDevices = + allocation.toArray(new Device[count]); + // check weights + Assert.assertEquals(2, + spyPlugin.computeCostOfDevices(allocatedDevices)); + // it should allocate GPU 2 and 3 + for (Device device : allocation) { + if (device.getMinorNumber() == 2) { + Assert.assertTrue(true); + } else if (device.getMinorNumber() == 3) { + Assert.assertTrue(true); + } else { + Assert.assertTrue("Should allocate GPU 2 and 3",false); + } + } + } + + @Test + public void testTopologySchedulingWithSpreadPolicy() throws Exception { + NvidiaGPUPlugin plugin = mockFourGPUPlugin(); + NvidiaGPUPlugin spyPlugin = spy(plugin); + // cache the total devices + Set allDevices = spyPlugin.getDevices(); + // environment variable to use PACK policy + Map env = new HashMap<>(); + env.put(NvidiaGPUPlugin.TOPOLOGY_POLICY_ENV_KEY, + NvidiaGPUPlugin.TOPOLOGY_POLICY_SPREAD); + // Case 1. allocate 1 device + Set allocation = spyPlugin.allocateDevices(allDevices, 1, env); + // ensure no topology scheduling needed + Assert.assertEquals(allocation.size(), 1); + verify(spyPlugin).basicSchedule(anySet(), anyInt(), anySet()); + reset(spyPlugin); + // Case 2. allocate all available + allocation = spyPlugin.allocateDevices(allDevices, allDevices.size(), env); + Assert.assertEquals(allocation.size(), allDevices.size()); + verify(spyPlugin).basicSchedule(anySet(), anyInt(), anySet()); + // Case 3. allocate 2 devices + reset(spyPlugin); + int count = 2; + Map pairToWeight = spyPlugin.getDevicePairToWeight(); + allocation = spyPlugin.allocateDevices(allDevices, count, env); + Assert.assertEquals(allocation.size(), count); + // the costTable should be init and used topology scheduling + verify(spyPlugin).initCostTable(); + Assert.assertTrue(spyPlugin.isTopoInitialized()); + verify(spyPlugin).topologyAwareSchedule(anySet(), anyInt(), anyMap(), + anySet(), anyMap()); + Assert.assertEquals(allocation.size(), count); + Device[] allocatedDevices = + allocation.toArray(new Device[count]); + // Check weights + Assert.assertEquals(4, spyPlugin.computeCostOfDevices(allocatedDevices)); + // Case 4. allocate 3 devices + reset(spyPlugin); + count = 3; + allocation = spyPlugin.allocateDevices(allDevices, count, env); + Assert.assertEquals(allocation.size(), count); + // the costTable should be init and used topology scheduling + verify(spyPlugin, times(0)).initCostTable(); + Assert.assertTrue(spyPlugin.isTopoInitialized()); + verify(spyPlugin).topologyAwareSchedule(anySet(), anyInt(), anyMap(), + anySet(), anyMap()); + Assert.assertEquals(allocation.size(), count); + allocatedDevices = + allocation.toArray(new Device[count]); + // check weights + Assert.assertEquals(2 + 4 + 4, + spyPlugin.computeCostOfDevices(allocatedDevices)); + // Case 5. allocate 2 GPUs from three available devices + reset(spyPlugin); + Iterator iterator = allDevices.iterator(); + iterator.next(); + // remove GPU0 + iterator.remove(); + count = 2; + allocation = spyPlugin.allocateDevices(allDevices, count, env); + Assert.assertEquals(allocation.size(), count); + // the costTable should be init and used topology scheduling + verify(spyPlugin, times(0)).initCostTable(); + Assert.assertTrue(spyPlugin.isTopoInitialized()); + verify(spyPlugin).topologyAwareSchedule(anySet(), anyInt(), anyMap(), + anySet(), anyMap()); + Assert.assertEquals(allocation.size(), count); + allocatedDevices = + allocation.toArray(new Device[count]); + // check weights + Assert.assertEquals(4, + spyPlugin.computeCostOfDevices(allocatedDevices)); + // it should allocate GPU 1 and 2 + for (Device device : allocation) { + if (device.getMinorNumber() == 2) { + Assert.assertTrue(true); + } else if (device.getMinorNumber() == 1) { + Assert.assertTrue(true); + } else { + Assert.assertTrue("Should allocate GPU 1 and 2",false); + } + } + } + + /** + * Test the key cost table used for topology scheduling + * */ + @Test + public void testCostTable() throws IOException { + NvidiaGPUPlugin plugin = mockFourGPUPlugin(); + NvidiaGPUPlugin spyPlugin = spy(plugin); + // verify the device pair to weight map + spyPlugin.initCostTable(); + Map devicePairToWeight = spyPlugin.getDevicePairToWeight(); + // 12 combinations when choose 2 GPUs from 4 respect the order + Assert.assertEquals(12, devicePairToWeight.size()); + int sameCPUWeight = + NvidiaGPUPlugin.DeviceLinkType.P2PLinkSameCPUSocket.getWeight(); + int crossCPUWeight = + NvidiaGPUPlugin.DeviceLinkType.P2PLinkCrossCPUSocket.getWeight(); + // GPU 0 to 1, weight is 2 + Assert.assertEquals(sameCPUWeight, (int)devicePairToWeight.get("0-1")); + Assert.assertEquals(sameCPUWeight, (int)devicePairToWeight.get("1-0")); + // GPU 0 to 2, weight is 4 + Assert.assertEquals(crossCPUWeight, (int)devicePairToWeight.get("0-2")); + Assert.assertEquals(crossCPUWeight, (int)devicePairToWeight.get("2-0")); + // GPU 0 to 3, weight is 4 + Assert.assertEquals(crossCPUWeight, (int)devicePairToWeight.get("0-3")); + Assert.assertEquals(crossCPUWeight, (int)devicePairToWeight.get("3-0")); + // GPU 1 to 2, weight is 4 + Assert.assertEquals(crossCPUWeight, (int)devicePairToWeight.get("1-2")); + Assert.assertEquals(crossCPUWeight, (int)devicePairToWeight.get("2-1")); + // GPU 1 to 3, weight is 4 + Assert.assertEquals(crossCPUWeight, (int)devicePairToWeight.get("1-3")); + Assert.assertEquals(crossCPUWeight, (int)devicePairToWeight.get("3-1")); + // GPU 2 to 3, weight is 2 + Assert.assertEquals(sameCPUWeight, (int)devicePairToWeight.get("2-3")); + Assert.assertEquals(sameCPUWeight, (int)devicePairToWeight.get("3-2")); + // verify cost Table + Map, Integer>> costTable = + spyPlugin.getCostTable(); + Assert.assertNull(costTable.get(1)); + Assert.assertEquals(6, costTable.get(2).size()); + Assert.assertEquals(4, costTable.get(3).size()); + Assert.assertNull(costTable.get(4)); } }