diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-common/src/main/java/org/apache/hadoop/yarn/nodelabels/CommonNodeLabelsManager.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-common/src/main/java/org/apache/hadoop/yarn/nodelabels/CommonNodeLabelsManager.java index 4fd4bd6bd72..dfa58f19a9e 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-common/src/main/java/org/apache/hadoop/yarn/nodelabels/CommonNodeLabelsManager.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-common/src/main/java/org/apache/hadoop/yarn/nodelabels/CommonNodeLabelsManager.java @@ -559,6 +559,49 @@ private void replaceNodeForLabels(NodeId node, Set oldLabels, addNodeToLabels(node, newLabels); } + private void addLabelsToNodeInHost(NodeId node, Set labels) + throws IOException { + Host host = nodeCollections.get(node.getHost()); + if (null == host) { + throw new IOException("Should create host before add labels."); + } + Node nm = host.nms.get(node); + if (nm != null) { + Node new_nm = nm.copy(); + if (new_nm.labels == null) { + new_nm.labels = Collections.newSetFromMap(new ConcurrentHashMap()); + } + new_nm.labels.addAll(labels); + + host.nms.put(node, new_nm); + } + } + + protected void removeLabelsFromNodeInHost(NodeId node, Set labels) + throws IOException { + Host host = nodeCollections.get(node.getHost()); + if (null == host) { + throw new IOException("Should create host before remove labels."); + } + Node nm = host.nms.get(node); + if (nm != null) { + if (nm.labels == null) { + nm.labels = new HashSet(); + } else { + nm.labels.removeAll(labels); + } + } + } + + private void replaceLabelsForNode(NodeId node, Set oldLabels, + Set newLabels) throws IOException { + if(oldLabels != null) { + removeLabelsFromNodeInHost(node, oldLabels); + } + addLabelsToNodeInHost(node, newLabels); + } + + @SuppressWarnings("unchecked") protected void internalUpdateLabelsOnNodes( Map> nodeToLabels, NodeLabelUpdateOperation op) @@ -597,10 +640,12 @@ protected void internalUpdateLabelsOnNodes( break; case REPLACE: replaceNodeForLabels(nodeId, host.labels, labels); + replaceLabelsForNode(nodeId, host.labels, labels); host.labels.clear(); host.labels.addAll(labels); for (Node node : host.nms.values()) { replaceNodeForLabels(node.nodeId, node.labels, labels); + replaceLabelsForNode(node.nodeId, node.labels, labels); node.labels = null; } break; @@ -625,6 +670,7 @@ protected void internalUpdateLabelsOnNodes( case REPLACE: oldLabels = getLabelsByNode(nodeId); replaceNodeForLabels(nodeId, oldLabels, labels); + replaceLabelsForNode(nodeId, oldLabels, labels); if (nm.labels == null) { nm.labels = new HashSet(); }