diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java
index 8899ccd..f15707d 100644
--- hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java
+++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java
@@ -2743,6 +2743,167 @@ public static boolean areNodeLabelsEnabled(
public static final String TIMELINE_XFS_OPTIONS =
TIMELINE_XFS_PREFIX + "xframe-options";
+ // Prevent DoS attack configuration for ApplicationMasterProtocol
+
+ public static final String DOS_PREFIX = "yarn.dos.";
+
+ public static final String DOS_SLIDING_WINDOW_PREFIX =
+ DOS_PREFIX + "sliding-windows.";
+
+ public static final String DOS_SLIDING_WINDOW_SIZE =
+ DOS_SLIDING_WINDOW_PREFIX + "size";
+ public static final int DEFAULT_DOS_SLIDING_WINDOW_SIZE = 10;
+
+ public static final String DOS_SLIDING_WINDOW_ADVANCE_TIME_SEC =
+ DOS_SLIDING_WINDOW_PREFIX + "advance-time";
+ public static final int DEFAULT_DOS_SLIDING_WINDOW_ADVANCE_TIME_SEC = 60;
+
+ public static final String DOS_TRUNCATE_STRING_ENABLED =
+ DOS_PREFIX + "truncate-string.enable";
+ public static final boolean DEFAULT_DOS_TRUNCATE_STRING_ENABLED = true;
+
+ public static final String DOS_ACTION = DOS_PREFIX + "action";
+ public static final String DEFAULT_DOS_ACTION = "IGNORE";
+
+ // Single request limit
+
+ public static final String DOS_REQUEST_LIMIT = DOS_PREFIX + "request-limit.";
+
+ public static final String DOS_REQUEST_LIMIT_LENGTH_HOST =
+ DOS_REQUEST_LIMIT + "length-host";
+ public static final int DEFAULT_DOS_REQUEST_LIMIT_LENGTH_HOST = 150;
+
+ public static final String DOS_REQUEST_LIMIT_RPC_PORT =
+ DOS_REQUEST_LIMIT + "rpc-port";
+ public static final int DEFAULT_DOS_REQUEST_LIMIT_RPC_PORT = 10000;
+
+ public static final String DOS_REQUEST_LIMIT_LENGTH_DIAGNOSTIC =
+ DOS_REQUEST_LIMIT + "length-diagnostic";
+ public static final int DEFAULT_DOS_REQUEST_LIMIT_LENGTH_DIAGNOSTIC = 100;
+
+ public static final String DOS_REQUEST_LIMIT_MIN_PRIORITY =
+ DOS_REQUEST_LIMIT + "min-priority";
+ public static final int DEFAULT_DOS_REQUEST_LIMIT_MIN_PRIORITY = 0;
+
+ public static final String DOS_REQUEST_LIMIT_MAX_PRIORITY =
+ DOS_REQUEST_LIMIT + "max-priority";
+ public static final int DEFAULT_DOS_REQUEST_LIMIT_MAX_PRIORITY = 100;
+
+ public static final String DOS_REQUEST_LIMIT_LENGTH_RESOURCE =
+ DOS_REQUEST_LIMIT + "length-resource";
+ public static final int DEFAULT_DOS_REQUEST_LIMIT_LENGTH_RESOURCE = 150;
+
+ public static final String DOS_REQUEST_LIMIT_LENGTH_NODELABEL =
+ DOS_REQUEST_LIMIT + "length-nodelabel";
+ public static final int DEFAULT_DOS_REQUEST_LIMIT_LENGTH_NODELABEL = 200;
+
+ public static final String DOS_REQUEST_LIMIT_MAX_MEMORY =
+ DOS_REQUEST_LIMIT + "max-memory";
+ public static final int DEFAULT_DOS_REQUEST_LIMIT_MAX_MEMORY = 100 * 1024;
+
+ public static final String DOS_REQUEST_LIMIT_MAX_VCORES =
+ DOS_REQUEST_LIMIT + "max-vcores";
+ public static final int DEFAULT_DOS_REQUEST_LIMIT_MAX_VCORES = 24;
+
+ public static final String DOS_REQUEST_LIMIT_ASK_LIST =
+ DOS_REQUEST_LIMIT + "ask-list";
+ public static final int DEFAULT_DOS_REQUEST_LIMIT_ASK_LIST = 100;
+
+ public static final String DOS_REQUEST_LIMIT_RELEASE_LIST =
+ DOS_REQUEST_LIMIT + "release-list";
+ public static final int DEFAULT_DOS_REQUEST_LIMIT_RELEASE_LIST = 100;
+
+ public static final String DOS_REQUEST_LIMIT_BLACKLIST =
+ DOS_REQUEST_LIMIT + "blacklist";
+ public static final int DEFAULT_DOS_REQUEST_LIMIT_BLACKLIST = 100;
+
+ public static final String DOS_REQUEST_LIMIT_CHANGELIST =
+ DOS_REQUEST_LIMIT + "change-list";
+ public static final int DEFAULT_DOS_REQUEST_LIMIT_CHANGELIST = 100;
+
+ public static final String DOS_REQUEST_LIMIT_CONTAINERS =
+ DOS_REQUEST_LIMIT + "containers";
+ public static final int DEFAULT_DOS_REQUEST_LIMIT_CONTAINERS = 200;
+
+ // Sliding Window Configuration
+
+ public static final String DOS_SLIDING_WINDOW_THRESHOLD =
+ DOS_SLIDING_WINDOW_PREFIX + "threshold.";
+
+ public static final String DOS_SLIDING_WINDOW_THRESHOLD_REGISTER_AM =
+ DOS_SLIDING_WINDOW_THRESHOLD + "register-am";
+ public static final int DEFAULT_DOS_SLIDING_WINDOW_THRESHOLD_REGISTER_AM = 10;
+
+ public static final String DOS_SLIDING_WINDOW_THRESHOLD_ALLOCATE_AM =
+ DOS_SLIDING_WINDOW_THRESHOLD + "allocate-am";
+ public static final int DEFAULT_DOS_SLIDING_WINDOW_THRESHOLD_ALLOCATE_AM =
+ 1000;
+
+ public static final String DOS_SLIDING_WINDOW_THRESHOLD_FINISH_AM =
+ DOS_SLIDING_WINDOW_THRESHOLD + "finish-am";
+ public static final int DEFAULT_DOS_SLIDING_WINDOW_THRESHOLD_FINISH_AM = 10;
+
+ public static final String DOS_SLIDING_WINDOW_THRESHOLD_ASK_LIST =
+ DOS_SLIDING_WINDOW_THRESHOLD + "ask-list";
+ public static final int DEFAULT_DOS_SLIDING_WINDOW_THRESHOLD_ASK_LIST =
+ 1000;
+
+ public static final String DOS_SLIDING_WINDOW_THRESHOLD_RELEASE_LIST =
+ DOS_SLIDING_WINDOW_THRESHOLD + "release-list";
+ public static final int DEFAULT_DOS_SLIDING_WINDOW_THRESHOLD_RELEASE_LIST =
+ 1000;
+
+ public static final String DOS_SLIDING_WINDOW_THRESHOLD_BLACKLIST =
+ DOS_SLIDING_WINDOW_THRESHOLD + "blacklist";
+ public static final int DEFAULT_DOS_SLIDING_WINDOW_THRESHOLD_BLACKLIST =
+ 100;
+
+ public static final String DOS_SLIDING_WINDOW_THRESHOLD_CHANGELIST =
+ DOS_SLIDING_WINDOW_THRESHOLD + "changelist";
+ public static final int DEFAULT_DOS_SLIDING_WINDOW_THRESHOLD_CHANGELIST =
+ 100;
+
+ public static final String DOS_SLIDING_WINDOW_THRESHOLD_NUM_CONTAINERS =
+ DOS_SLIDING_WINDOW_THRESHOLD + "num-containers";
+ public static final int DEFAULT_DOS_SLIDING_WINDOW_THRESHOLD_NUM_CONTAINERS =
+ 1000;
+
+ // Multiple requests limit
+
+ public static final String DOS_MAX_REQUESTS = DOS_PREFIX + "max-requests.";
+
+ public static final String DOS_MAX_REQUESTS_REGISTER_AM =
+ DOS_MAX_REQUESTS + "register-am";
+ public static final int DEFAULT_DOS_MAX_REQUESTS_REGISTER_AM = 50;
+
+ public static final String DOS_MAX_REQUESTS_ALLOCATE_AM =
+ DOS_MAX_REQUESTS + "allocate-am";
+ public static final int DEFAULT_DOS_MAX_REQUESTS_ALLOCATE_AM = 1000;
+
+ public static final String DOS_MAX_REQUESTS_FINISH_AM =
+ DOS_MAX_REQUESTS + "finish-am";
+ public static final int DEFAULT_DOS_MAX_REQUESTS_FINISH_AM = 50;
+
+ public static final String DOS_MAX_REQUESTS_ASK_LIST =
+ DOS_MAX_REQUESTS + "ask-list";
+ public static final int DEFAULT_DOS_MAX_REQUESTS_ASK_LIST = 1000;
+
+ public static final String DOS_MAX_REQUESTS_RELEASE_LIST =
+ DOS_MAX_REQUESTS + "release-list";
+ public static final int DEFAULT_DOS_MAX_REQUESTS_RELEASE_LIST = 1000;
+
+ public static final String DOS_MAX_REQUESTS_BLACKLIST =
+ DOS_MAX_REQUESTS + "blacklist";
+ public static final int DEFAULT_DOS_MAX_REQUESTS_BLACKLIST = 100;
+
+ public static final String DOS_MAX_REQUESTS_CHANGELIST =
+ DOS_MAX_REQUESTS + "changelist";
+ public static final int DEFAULT_DOS_MAX_REQUESTS_CHANGELIST = 100;
+
+ public static final String DOS_MAX_REQUESTS_NUM_CONTAINERS =
+ DOS_MAX_REQUESTS + "num-containers";
+ public static final int DEFAULT_DOS_MAX_REQUESTS_NUM_CONTAINERS = 100000;
+
public YarnConfiguration() {
super();
}
diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/PreventDoSAction.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/PreventDoSAction.java
new file mode 100644
index 0000000..036d0ff
--- /dev/null
+++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/PreventDoSAction.java
@@ -0,0 +1,63 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.server.preventdos;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * In case of potential DoS attack, this class takes action against the client.
+ *
+ */
+public class PreventDoSAction {
+
+ private static final Logger LOG =
+ LoggerFactory.getLogger(PreventDoSAction.class);
+
+ private PreventDoSPolicy policy;
+
+ public PreventDoSAction(PreventDoSPolicy policy) {
+ this.policy = policy;
+ }
+
+ /**
+ * In case of potential DoS attack, this method takes action against the
+ * client.
+ *
+ * Ignore: It logs the cause, this option is used in case of testing.
+ *
+ * Reject: It rejects the request by throwing an exception.
+ *
+ */
+ public void prevent(PreventDoSCause cause) throws PreventDoSAttackException {
+
+ switch (policy) {
+
+ case IGNORE: {
+ LOG.warn("Possible DoS attack: " + cause);
+ break;
+ }
+
+ case REJECT: {
+ throw new PreventDoSAttackException("Possible DoS attack: " + cause);
+ }
+
+ }
+ }
+}
\ No newline at end of file
diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/PreventDoSAttackException.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/PreventDoSAttackException.java
new file mode 100644
index 0000000..02682f4
--- /dev/null
+++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/PreventDoSAttackException.java
@@ -0,0 +1,52 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.server.preventdos;
+
+import org.apache.hadoop.classification.InterfaceAudience.Public;
+import org.apache.hadoop.classification.InterfaceStability.Stable;
+import org.apache.hadoop.yarn.exceptions.YarnException;
+
+/**
+ * Exception thrown by the {@code AMPDoSRequestInterceptor} if the application
+ * code is trying to DoS attack the ResourceManager.
+ *
+ */
+@Public
+@Stable
+public class PreventDoSAttackException extends YarnException {
+
+ private static final long serialVersionUID = 1L;
+
+ public PreventDoSAttackException() {
+ super();
+ }
+
+ public PreventDoSAttackException(String message) {
+ super(message);
+ }
+
+ public PreventDoSAttackException(Throwable cause) {
+ super(cause);
+ }
+
+ public PreventDoSAttackException(String message, Throwable cause) {
+ super(message, cause);
+ }
+
+}
\ No newline at end of file
diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/PreventDoSCause.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/PreventDoSCause.java
new file mode 100644
index 0000000..f93c86e
--- /dev/null
+++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/PreventDoSCause.java
@@ -0,0 +1,288 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.server.preventdos;
+
+/**
+ * This class groups the information about the Cause of Potential DoS attack.
+ *
+ *
+ * It contains:
+ *
+ *
{@link TypeCheck}
+ *
The method called by the user
+ *
The parameter that exceed the check
+ *
The reason of the failed check
+ *
+ *
+ */
+public class PreventDoSCause {
+
+ private TypeCheck typeCheck;
+ private String method;
+ private String param;
+ private String reason;
+
+ /**
+ * Type of Check.
+ *
+ */
+ public enum TypeCheck {
+ /* Check on a single request */
+ REQUEST_LIMIT,
+ /* Check on a temporal sliding window */
+ SLIDING_WINDOW,
+ /* Check on the entire execution of the application */
+ REQUESTS_LIMIT
+ }
+
+ // Application Master Protocol possible cause
+
+ // Over a Single Request
+
+ public final static PreventDoSCause REQUEST_GET_HOST_CAUSE =
+ new PreventDoSCause(PreventDoSCause.TypeCheck.REQUEST_LIMIT,
+ "Register Application Master", "getHost", "Exceed Max Length");
+
+ public final static PreventDoSCause REQUEST_GET_RPC_PORT_CAUSE =
+ new PreventDoSCause(PreventDoSCause.TypeCheck.REQUEST_LIMIT,
+ "Register Application Master", "getRpcPort", "Invalid Port");
+
+ public final static PreventDoSCause REQUEST_GET_TRACKING_URL_CAUSE =
+ new PreventDoSCause(PreventDoSCause.TypeCheck.REQUEST_LIMIT,
+ "Register Application Master", "getTrackingUrl", "Invalid URL");
+
+ public final static PreventDoSCause REQUEST_GET_DIAGNOSTIC_CAUSE =
+ new PreventDoSCause(PreventDoSCause.TypeCheck.REQUEST_LIMIT,
+ "Finish Application Master", "getTrackingUrl", "Exceed Max Length");
+
+ public final static PreventDoSCause REQUEST_GET_TRACKING_URL_FINISH_CAUSE =
+ new PreventDoSCause(PreventDoSCause.TypeCheck.REQUEST_LIMIT,
+ "Finish Application Master", "getTrackingUrl", "Exceed Max Length");
+
+ public final static PreventDoSCause REQUEST_GET_RESPONSE_ID_CAUSE =
+ new PreventDoSCause(PreventDoSCause.TypeCheck.REQUEST_LIMIT, "Allocate",
+ "getResponseId", "Decreased number response Id");
+
+ public final static PreventDoSCause REQUEST_GET_PROGRESS_CAUSE =
+ new PreventDoSCause(PreventDoSCause.TypeCheck.REQUEST_LIMIT, "Allocate",
+ "getProgress", "Out of Range");
+
+ public final static PreventDoSCause REQUEST_GET_ASK_LIST_CAUSE =
+ new PreventDoSCause(PreventDoSCause.TypeCheck.REQUEST_LIMIT, "Allocate",
+ "getAskList", "Exceed Max Size Length");
+
+ public final static PreventDoSCause REQUEST_GET_PRIORITY_CAUSE =
+ new PreventDoSCause(PreventDoSCause.TypeCheck.REQUEST_LIMIT, "Allocate",
+ "getPriority", "Out of Range");
+
+ public final static PreventDoSCause REQUEST_GET_RESOURCE_NAME_CAUSE =
+ new PreventDoSCause(PreventDoSCause.TypeCheck.REQUEST_LIMIT, "Allocate",
+ "getResourceName", "Exceed Max Length");
+
+ public final static PreventDoSCause REQUEST_GET_CAPABILITY_CAUSE =
+ new PreventDoSCause(PreventDoSCause.TypeCheck.REQUEST_LIMIT, "Allocate",
+ "getCapability", "Out of Range");
+
+ public final static PreventDoSCause REQUEST_GET_NUM_CONTAINERS_CAUSE =
+ new PreventDoSCause(PreventDoSCause.TypeCheck.REQUEST_LIMIT, "Allocate",
+ "getNumContainers", "Out of Range");
+
+ public final static PreventDoSCause REQUEST_GET_NODE_LABEL_CAUSE =
+ new PreventDoSCause(PreventDoSCause.TypeCheck.REQUEST_LIMIT, "Allocate",
+ "getNodeLabelExpression", "Exceed Max Length");
+
+ public final static PreventDoSCause REQUEST_GET_RELEASE_LIST_CAUSE =
+ new PreventDoSCause(PreventDoSCause.TypeCheck.REQUEST_LIMIT, "Allocate",
+ "getReleaseList", "Exceed Max Size Length");
+
+ public final static PreventDoSCause REQUEST_GET_BLACKLIST_ADD_LIST_CAUSE =
+ new PreventDoSCause(PreventDoSCause.TypeCheck.REQUEST_LIMIT, "Allocate",
+ "getBlacklistAdditions", "Exceed Max Size Length");
+
+ public final static PreventDoSCause REQUEST_GET_BLACKLIST_REM_LIST_CAUSE =
+ new PreventDoSCause(PreventDoSCause.TypeCheck.REQUEST_LIMIT, "Allocate",
+ "getBlacklistRemovals", "Exceed Max Size Length");
+
+ public final static PreventDoSCause REQUEST_GET_INCREASE_LIST_CAUSE =
+ new PreventDoSCause(PreventDoSCause.TypeCheck.REQUEST_LIMIT, "Allocate",
+ "getIncreaseRequests", "Exceed Max Size Length");
+
+ public final static PreventDoSCause REQUEST_GET_DECREASE_LIST_CAUSE =
+ new PreventDoSCause(PreventDoSCause.TypeCheck.REQUEST_LIMIT, "Allocate",
+ "getDecreaseRequests", "Exceed Max Size Length");
+
+ // Over Sliding Window
+
+ public final static PreventDoSCause SLIDINGWINDOW_REGISTER_AM_CAUSE =
+ new PreventDoSCause(PreventDoSCause.TypeCheck.SLIDING_WINDOW,
+ "Register Application Master", "registerApplicationMaster",
+ "Exceed Sliding Windows Threshold");
+
+ public final static PreventDoSCause SLIDINGWINDOW_FINISH_AM_CAUSE =
+ new PreventDoSCause(PreventDoSCause.TypeCheck.SLIDING_WINDOW,
+ "Finish Application Master", "finishApplicationMaster",
+ "Exceed Sliding Windows Threshold");
+
+ public final static PreventDoSCause SLIDINGWINDOW_ALLOCATE_AM_CAUSE =
+ new PreventDoSCause(PreventDoSCause.TypeCheck.SLIDING_WINDOW, "Allocate",
+ "allocate", "Exceed Sliding Windows Threshold");
+
+ public final static PreventDoSCause SLIDINGWINDOW_GET_ASK_LIST_CAUSE =
+ new PreventDoSCause(PreventDoSCause.TypeCheck.SLIDING_WINDOW, "Allocate",
+ "getAskList", "Exceed Sliding Windows Threshold");
+
+ public final static PreventDoSCause SLIDINGWINDOW_GET_NUM_CONTAINERS_CAUSE =
+ new PreventDoSCause(PreventDoSCause.TypeCheck.SLIDING_WINDOW, "Allocate",
+ "getNumContainers", "Exceed Sliding Windows Threshold");
+
+ public final static PreventDoSCause SLIDINGWINDOW_GET_RELEASE_LIST_CAUSE =
+ new PreventDoSCause(PreventDoSCause.TypeCheck.SLIDING_WINDOW, "Allocate",
+ "getReleaseList", "Exceed Sliding Windows Threshold");
+
+ public final static PreventDoSCause SLIDINGWINDOW_GET_BLACKLIST_ADD_LIST_CAUSE =
+ new PreventDoSCause(PreventDoSCause.TypeCheck.SLIDING_WINDOW, "Allocate",
+ "getBlacklistAdditions", "Exceed Sliding Windows Threshold");
+
+ public final static PreventDoSCause SLIDINGWINDOW_GET_BLACKLIST_REM_LIST_CAUSE =
+ new PreventDoSCause(PreventDoSCause.TypeCheck.SLIDING_WINDOW, "Allocate",
+ "getBlacklistRemovals", "Exceed Sliding Windows Threshold");
+
+ public final static PreventDoSCause SLIDINGWINDOW_GET_INCREASE_LIST_CAUSE =
+ new PreventDoSCause(PreventDoSCause.TypeCheck.SLIDING_WINDOW, "Allocate",
+ "getIncreaseRequests", "Exceed Sliding Windows Threshold");
+
+ public final static PreventDoSCause SLIDINGWINDOW_GET_DECREASE_LIST_CAUSE =
+ new PreventDoSCause(PreventDoSCause.TypeCheck.SLIDING_WINDOW, "Allocate",
+ "getDecreaseRequests", "Exceed Sliding Windows Threshold");
+
+ // Over the entire execution
+
+ public final static PreventDoSCause REQUESTS_REGISTER_AM_CAUSE =
+ new PreventDoSCause(PreventDoSCause.TypeCheck.REQUESTS_LIMIT,
+ "Register Application Master", "registerApplicationMaster",
+ "Exceed Max Total Count");
+
+ public final static PreventDoSCause REQUESTS_FINISH_AM_CAUSE =
+ new PreventDoSCause(PreventDoSCause.TypeCheck.REQUESTS_LIMIT,
+ "Finish Application Master", "finishApplicationMaster",
+ "Exceed Max Total Count");
+
+ public final static PreventDoSCause REQUESTS_ALLOCATE_AM_CAUSE =
+ new PreventDoSCause(PreventDoSCause.TypeCheck.REQUESTS_LIMIT, "Allocate",
+ "allocate", "Exceed Max Total Count");
+
+ public final static PreventDoSCause REQUESTS_GET_ASK_LIST_CAUSE =
+ new PreventDoSCause(PreventDoSCause.TypeCheck.REQUESTS_LIMIT, "Allocate",
+ "getAskList", "Exceed Max Total Count");
+
+ public final static PreventDoSCause REQUESTS_GET_NUM_CONTAINERS_CAUSE =
+ new PreventDoSCause(PreventDoSCause.TypeCheck.REQUESTS_LIMIT, "Allocate",
+ "getNumContainers", "Exceed Max Total Count");
+
+ public final static PreventDoSCause REQUESTS_GET_RELEASE_LIST_CAUSE =
+ new PreventDoSCause(PreventDoSCause.TypeCheck.REQUESTS_LIMIT, "Allocate",
+ "getReleaseList", "Exceed Max Total Count");
+
+ public final static PreventDoSCause REQUESTS_GET_BLACKLIST_ADD_LIST_CAUSE =
+ new PreventDoSCause(PreventDoSCause.TypeCheck.REQUESTS_LIMIT, "Allocate",
+ "getBlacklistAdditions", "Exceed Max Total Count");
+
+ public final static PreventDoSCause REQUESTS_GET_BLACKLIST_REM_LIST_CAUSE =
+ new PreventDoSCause(PreventDoSCause.TypeCheck.REQUESTS_LIMIT, "Allocate",
+ "getBlacklistRemovals", "Exceed Max Total Count");
+
+ public final static PreventDoSCause REQUESTS_GET_INCREASE_LIST_CAUSE =
+ new PreventDoSCause(PreventDoSCause.TypeCheck.REQUESTS_LIMIT, "Allocate",
+ "getIncreaseRequests", "Exceed Max Total Count");
+
+ public final static PreventDoSCause REQUESTS_GET_DECREASE_LIST_CAUSE =
+ new PreventDoSCause(PreventDoSCause.TypeCheck.REQUESTS_LIMIT, "Allocate",
+ "getDecreaseRequests", "Exceed Max Total Count");
+
+ private PreventDoSCause(TypeCheck typeCheck, String method, String param,
+ String reason) {
+ this.typeCheck = typeCheck;
+ this.method = method;
+ this.param = param;
+ this.reason = reason;
+ }
+
+ /**
+ * @return the method
+ */
+ public String getMethod() {
+ return method;
+ }
+
+ /**
+ * @param method the method to set
+ */
+ public void setMethod(String method) {
+ this.method = method;
+ }
+
+ /**
+ * @return the param
+ */
+ public String getParam() {
+ return param;
+ }
+
+ /**
+ * @param param the param to set
+ */
+ public void setParam(String param) {
+ this.param = param;
+ }
+
+ /**
+ * @return the reason
+ */
+ public String getReason() {
+ return reason;
+ }
+
+ /**
+ * @param reason the reason to set
+ */
+ public void setReason(String reason) {
+ this.reason = reason;
+ }
+
+ /**
+ * @return the typeCheck
+ */
+ public TypeCheck getTypeCheck() {
+ return typeCheck;
+ }
+
+ /**
+ * @param typeCheck the typeCheck to set
+ */
+ public void setTypeCheck(TypeCheck typeCheck) {
+ this.typeCheck = typeCheck;
+ }
+
+ @Override
+ public String toString() {
+ return "PreventDoSCause [typeCheck=" + typeCheck + ", method=" + method
+ + ", param=" + param + ", reason=" + reason + "]";
+ }
+
+}
\ No newline at end of file
diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/PreventDoSContext.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/PreventDoSContext.java
new file mode 100644
index 0000000..983e678
--- /dev/null
+++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/PreventDoSContext.java
@@ -0,0 +1,45 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.server.preventdos;
+
+/**
+ * Context of the Prevent DoS attack
+ */
+public interface PreventDoSContext {
+
+ /**
+ * Set the configuration limits for the single request checks.
+ */
+ void setConfigurationLimitRequest();
+
+ /**
+ * Set the configuration limits for the temporal sliding window checks.
+ */
+ void setConfigurationThresholdSlidingWindow();
+
+ /**
+ * Set the configuration limits for the entire execution checks.
+ */
+ void setConfigurationLimitRequests();
+
+ /**
+ * Shut down the timer of the sliding window.
+ */
+ void shutdownTimer();
+}
\ No newline at end of file
diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/PreventDoSContextAMPImpl.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/PreventDoSContextAMPImpl.java
new file mode 100644
index 0000000..c8d566c
--- /dev/null
+++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/PreventDoSContextAMPImpl.java
@@ -0,0 +1,622 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.server.preventdos;
+
+import java.net.URI;
+import java.util.List;
+import java.util.Timer;
+import java.util.TimerTask;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.api.protocolrecords.AllocateRequest;
+import org.apache.hadoop.yarn.api.protocolrecords.FinishApplicationMasterRequest;
+import org.apache.hadoop.yarn.api.protocolrecords.RegisterApplicationMasterRequest;
+import org.apache.hadoop.yarn.api.records.ContainerId;
+import org.apache.hadoop.yarn.api.records.ContainerResourceChangeRequest;
+import org.apache.hadoop.yarn.api.records.Priority;
+import org.apache.hadoop.yarn.api.records.Resource;
+import org.apache.hadoop.yarn.api.records.ResourceBlacklistRequest;
+import org.apache.hadoop.yarn.api.records.ResourceRequest;
+import org.apache.hadoop.yarn.conf.YarnConfiguration;
+
+/**
+ * This class helps by keeping tracks of all the counters to prevent a DoS
+ * attack for a single application.
+ */
+public class PreventDoSContextAMPImpl implements PreventDoSContext {
+
+ public static final String NUM_CONTAINERS = "NumContainers";
+ public Configuration conf;
+ Timer timer = null;
+
+ // Entire job execution counters
+ private int counterRegisterRequests = 0;
+ private int counterAllocateRequests = 0;
+ private int counterFinishRequests = 0;
+ private int counterAskListRequests = 0;
+ private int counterReleaseListRequests = 0;
+ private int counterBlackListRequests = 0;
+ private int counterChangeListRequests = 0;
+ private int counterContainersRequests = 0;
+
+ // Limit for the entire job execution
+ private int limitRegisterRequests;
+ private int limitAllocateRequests;
+ private int limitFinishRequests;
+ private int limitAskListRequests;
+ private int limitReleaseListRequests;
+ private int limitBlackListRequests;
+ private int limitChangeListRequests;
+ private int limitContainersRequests;
+
+ // Limit for singular request
+ private int maxLengthRegisterHost;
+ private int maxLengthFinishDiagnostic;
+ private int maxLengthAllocateResourceName;
+ private int maxLengthAllocateNodeLabel;
+ private int lastResponseId = 0;
+ private int validRPCPort;
+ private int maxMemory;
+ private int maxVCores;
+ private int maxSizeAskList;
+ private int maxSizeReleaseList;
+ private int maxSizeBlackList;
+ private int maxSizeChangeList;
+ private int maxSizeContainers;
+ private Priority minPriority;
+ private Priority maxPriority;
+
+ private SlidingWindowCounterDoS slidingWindow = null;
+
+ private boolean truncateStringEnable;
+
+ public PreventDoSContextAMPImpl(Configuration conf) {
+ this.conf = conf;
+
+ this.truncateStringEnable =
+ conf.getBoolean(YarnConfiguration.DOS_TRUNCATE_STRING_ENABLED,
+ YarnConfiguration.DEFAULT_DOS_TRUNCATE_STRING_ENABLED);
+ setConfigurationLimitRequest();
+ setConfigurationThresholdSlidingWindow();
+ setConfigurationLimitRequests();
+ }
+
+ @Override
+ public void setConfigurationLimitRequest() {
+
+ // Single request limit for Register AM
+
+ this.maxLengthRegisterHost =
+ conf.getInt(YarnConfiguration.DOS_REQUEST_LIMIT_LENGTH_HOST,
+ YarnConfiguration.DEFAULT_DOS_REQUEST_LIMIT_LENGTH_HOST);
+ this.validRPCPort =
+ conf.getInt(YarnConfiguration.DOS_REQUEST_LIMIT_RPC_PORT,
+ YarnConfiguration.DEFAULT_DOS_REQUEST_LIMIT_RPC_PORT);
+
+ // Single request limit for Finish AM
+
+ this.maxLengthFinishDiagnostic =
+ conf.getInt(YarnConfiguration.DOS_REQUEST_LIMIT_LENGTH_DIAGNOSTIC,
+ YarnConfiguration.DEFAULT_DOS_REQUEST_LIMIT_LENGTH_DIAGNOSTIC);
+
+ // Single request limit for Allocate AM
+
+ this.maxLengthAllocateResourceName =
+ conf.getInt(YarnConfiguration.DOS_REQUEST_LIMIT_LENGTH_RESOURCE,
+ YarnConfiguration.DEFAULT_DOS_REQUEST_LIMIT_LENGTH_RESOURCE);
+ this.maxLengthAllocateNodeLabel =
+ conf.getInt(YarnConfiguration.DOS_REQUEST_LIMIT_LENGTH_NODELABEL,
+ YarnConfiguration.DEFAULT_DOS_REQUEST_LIMIT_LENGTH_NODELABEL);
+
+ this.maxVCores = conf.getInt(YarnConfiguration.DOS_REQUEST_LIMIT_MAX_VCORES,
+ YarnConfiguration.DEFAULT_DOS_REQUEST_LIMIT_MAX_VCORES);
+ this.maxMemory = conf.getInt(YarnConfiguration.DOS_REQUEST_LIMIT_MAX_MEMORY,
+ YarnConfiguration.DEFAULT_DOS_REQUEST_LIMIT_MAX_MEMORY);
+ this.maxSizeAskList =
+ conf.getInt(YarnConfiguration.DOS_REQUEST_LIMIT_ASK_LIST,
+ YarnConfiguration.DEFAULT_DOS_REQUEST_LIMIT_ASK_LIST);
+ this.maxSizeReleaseList =
+ conf.getInt(YarnConfiguration.DOS_REQUEST_LIMIT_RELEASE_LIST,
+ YarnConfiguration.DEFAULT_DOS_REQUEST_LIMIT_RELEASE_LIST);
+ this.maxSizeBlackList =
+ conf.getInt(YarnConfiguration.DOS_REQUEST_LIMIT_BLACKLIST,
+ YarnConfiguration.DEFAULT_DOS_REQUEST_LIMIT_BLACKLIST);
+ this.maxSizeChangeList =
+ conf.getInt(YarnConfiguration.DOS_REQUEST_LIMIT_CHANGELIST,
+ YarnConfiguration.DEFAULT_DOS_REQUEST_LIMIT_CHANGELIST);
+ this.maxSizeContainers =
+ conf.getInt(YarnConfiguration.DOS_REQUEST_LIMIT_CONTAINERS,
+ YarnConfiguration.DEFAULT_DOS_REQUEST_LIMIT_CONTAINERS);
+ this.minPriority = Priority.newInstance(
+ conf.getInt(YarnConfiguration.DOS_REQUEST_LIMIT_MIN_PRIORITY,
+ YarnConfiguration.DEFAULT_DOS_REQUEST_LIMIT_MIN_PRIORITY));
+ this.maxPriority = Priority.newInstance(
+ conf.getInt(YarnConfiguration.DOS_REQUEST_LIMIT_MAX_PRIORITY,
+ YarnConfiguration.DEFAULT_DOS_REQUEST_LIMIT_MAX_PRIORITY));
+ }
+
+ private String truncate(String stringToTruncate, int maxLength) {
+ if (stringToTruncate != null) {
+ return stringToTruncate.substring(0, maxLength);
+ }
+ return stringToTruncate;
+ }
+
+ /**
+ * Check if the host inside the RegisterApplicationMasterRequest is a valid
+ * input.
+ *
+ * @param host the host to validate.
+ * @return true if it is valid, false otherwise.
+ */
+ public boolean isLengthRegisterHostValid(String host) {
+ return host != null && host.length() <= this.maxLengthRegisterHost;
+ }
+
+ /**
+ * Truncate the host inside the RegisterApplicationMasterRequest to be
+ * compliant with a valid input.
+ *
+ * @param host the Host to truncate.
+ * @return the truncated host.
+ */
+ public String lengthRegisterHostTruncate(String host) {
+ return (truncateStringEnable) ? truncate(host, maxLengthRegisterHost)
+ : host;
+ }
+
+ /**
+ * Check if the RPC port inside the RegisterApplicationMasterRequest is a
+ * valid input.
+ *
+ * @param rpcPort the rpc port to validate.
+ * @return true if it is valid, false otherwise.
+ */
+ public boolean isRPCPortValid(int rpcPort) {
+ return rpcPort <= this.validRPCPort && rpcPort > 0;
+ }
+
+ /**
+ * Check if the url inside the RegisterApplicationMasterRequest is a valid
+ * address input.
+ *
+ * @param url the url port to validate.
+ * @return true if it is valid, false otherwise.
+ */
+ public boolean isURLValid(String url) {
+ // Ensure url is not null
+ if (url == null || url.isEmpty()) {
+ return false;
+ }
+ // Validate url is well formed
+ boolean hasScheme = url.contains("://");
+ URI uri = null;
+ try {
+ uri = hasScheme ? URI.create(url) : URI.create("dummyscheme://" + url);
+ } catch (IllegalArgumentException e) {
+ return false;
+ }
+ String host = uri.getHost();
+ int port = uri.getPort();
+ String path = uri.getPath();
+ if ((host == null) || (port < 0)
+ || (!hasScheme && path != null && !path.isEmpty())) {
+ return false;
+ }
+ return true;
+ }
+
+ /**
+ * Check if the diagnostic inside the FinishApplicationMasterRequest is a
+ * valid input.
+ *
+ * @param diagnostic the diagnostic to validate.
+ * @return true if it is valid, false otherwise.
+ */
+ public boolean isLengthFinishDiagnosticValid(String diagnostic) {
+ return (diagnostic != null
+ && diagnostic.length() <= this.maxLengthFinishDiagnostic);
+ }
+
+ /**
+ * Truncate the Diagnostic inside the FinishApplicationMasterRequest to be
+ * compliant with a valid input.
+ *
+ * @param diagnostic the Diagnostic to truncate.
+ * @return the truncated Diagnostic.
+ */
+ public String lengthFinishDiagnosticTruncate(String diagnostic) {
+ return (truncateStringEnable)
+ ? truncate(diagnostic, maxLengthFinishDiagnostic) : diagnostic;
+ }
+
+ /**
+ * Check if the responseId inside the AllocateRequest is a valid input.
+ *
+ * @param responseId the responseId to validate.
+ * @return true if it is valid, false otherwise.
+ */
+ public boolean isLastResponseIdValid(int responseId) {
+ if (responseId > this.lastResponseId) {
+ this.lastResponseId = responseId;
+ return true;
+ }
+ return false;
+ }
+
+ /**
+ * Check if the progress inside the AllocateRequest is a valid input.
+ *
+ * @param progress the progress to validate.
+ * @return true if it is valid, false otherwise.
+ */
+ public boolean isProgressValid(float progress) {
+ return progress >= 0 && progress <= 100;
+ }
+
+ /**
+ * Check if the size of AskList inside the AllocateRequest is a valid input.
+ *
+ * @param list the list to validate.
+ * @return true if it is valid, false otherwise.
+ */
+ public boolean isSizeAskListValid(List list) {
+ return list != null && list.size() <= this.maxSizeAskList;
+ }
+
+ /**
+ * Check if the Priority inside the AllocateRequest#ResourceRequest is a valid
+ * input.
+ *
+ * @param priority the priority to validate.
+ * @return true if it is valid, false otherwise.
+ */
+ public boolean isPriorityValid(Priority priority) {
+ return (priority != null && (priority.compareTo(minPriority) <= 0
+ && priority.compareTo(maxPriority) >= 0));
+ }
+
+ /**
+ * Check if the Resource Name inside the AllocateRequest#ResourceRequest is a
+ * valid input.
+ *
+ * @param resourceName the Resource Name to validate.
+ * @return true if it is valid, false otherwise.
+ */
+ public boolean isLengthAllocateResourceNameValid(String resourceName) {
+ return resourceName != null
+ && resourceName.length() <= this.maxLengthAllocateResourceName;
+ }
+
+ /**
+ * Truncate the Resource Name inside the AllocateRequest#ResourceRequest to be
+ * compliant with a valid input.
+ *
+ * @param resourceName the Resource Name to truncate.
+ * @return the truncated Resource Name.
+ */
+ public String lengthAllocateResourceNameTruncate(String resourceName) {
+ return (truncateStringEnable)
+ ? truncate(resourceName, maxLengthAllocateResourceName) : resourceName;
+ }
+
+ /**
+ * Check if the Resource inside the AllocateRequest#ResourceRequest is a valid
+ * input.
+ *
+ * @param resource the resource to validate.
+ * @return true if it is valid, false otherwise.
+ */
+ public boolean isCapabilityValid(Resource resource) {
+ return (resource != null
+ && (resource.getMemorySize() >= 0 && resource.getVirtualCores() >= 0
+ && resource.getMemorySize() <= this.maxMemory
+ && resource.getVirtualCores() <= this.maxVCores));
+ }
+
+ /**
+ * Check if the number of containers inside the
+ * AllocateRequest#ResourceRequest is a valid input.
+ *
+ * @param numContainers the number of containers to validate.
+ * @return true if it is valid, false otherwise.
+ */
+ public boolean isSizeContainersValid(int numContainers) {
+ return numContainers >= 0 && numContainers <= maxSizeContainers;
+ }
+
+ /**
+ * Check if the NodeLabel inside the AllocateRequest#ResourceRequest is a
+ * valid input.
+ *
+ * @param nodeLabel the NodeLabel to validate.
+ * @return true if it is valid, false otherwise.
+ */
+ public boolean isLengthAllocateNodeLabelValid(String nodeLabel) {
+ return nodeLabel != null
+ && nodeLabel.length() <= this.maxLengthAllocateNodeLabel;
+ }
+
+ /**
+ * Truncate the Node Label inside the AllocateRequest#ResourceRequest to be
+ * compliant with a valid input.
+ *
+ * @param nodeLabel the Node Label to truncate.
+ * @return the truncated Node Label.
+ */
+ public String lengthAllocateNodeLabelTruncate(String nodeLabel) {
+ return (truncateStringEnable)
+ ? truncate(nodeLabel, maxLengthAllocateNodeLabel) : nodeLabel;
+ }
+
+ /**
+ * Check if the size of ReleaseList inside the AllocateRequest is a valid
+ * input.
+ *
+ * @param list the list to validate.
+ * @return true if it is valid, false otherwise.
+ */
+ public boolean isSizeReleaseListValid(List list) {
+ return list != null && list.size() <= this.maxSizeReleaseList;
+ }
+
+ /**
+ * Check if the size of BlackList inside the AllocateRequest is a valid input.
+ *
+ * @param list the list to validate.
+ * @return true if it is valid, false otherwise.
+ */
+ public boolean isSizeBlackListValid(List list) {
+ return list != null && list.size() <= this.maxSizeBlackList;
+ }
+
+ /**
+ * Check if the size of ChangeList inside the AllocateRequest is a valid
+ * input.
+ *
+ * @param list the list to validate.
+ * @return true if it is valid, false otherwise.
+ */
+ public boolean isSizeChangeListValid(
+ List list) {
+ return list != null && list.size() <= this.maxSizeChangeList;
+ }
+
+ @Override
+ public void setConfigurationThresholdSlidingWindow() {
+
+ int slidingWindowSize =
+ conf.getInt(YarnConfiguration.DOS_SLIDING_WINDOW_SIZE,
+ YarnConfiguration.DEFAULT_DOS_SLIDING_WINDOW_SIZE);
+ int slidingWindowInterval =
+ conf.getInt(YarnConfiguration.DOS_SLIDING_WINDOW_ADVANCE_TIME_SEC,
+ YarnConfiguration.DEFAULT_DOS_SLIDING_WINDOW_ADVANCE_TIME_SEC);
+
+ slidingWindow = new SlidingWindowCounterDoS(slidingWindowSize);
+
+ // Sliding window configuration for Register AM
+
+ slidingWindow.setThreshold(RegisterApplicationMasterRequest.class.getName(),
+ conf.getInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_REGISTER_AM,
+ YarnConfiguration.DEFAULT_DOS_SLIDING_WINDOW_THRESHOLD_REGISTER_AM));
+
+ // Sliding window configuration for Finish AM
+
+ slidingWindow.setThreshold(FinishApplicationMasterRequest.class.getName(),
+ conf.getInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_FINISH_AM,
+ YarnConfiguration.DEFAULT_DOS_SLIDING_WINDOW_THRESHOLD_FINISH_AM));
+
+ // Sliding window configuration for Allocate AM
+
+ slidingWindow.setThreshold(AllocateRequest.class.getName(), conf.getInt(
+ YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_ALLOCATE_AM,
+ YarnConfiguration.DEFAULT_DOS_SLIDING_WINDOW_THRESHOLD_ALLOCATE_AM));
+
+ slidingWindow.setThreshold(ResourceRequest.class.getName(),
+ conf.getInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_ASK_LIST,
+ YarnConfiguration.DEFAULT_DOS_SLIDING_WINDOW_THRESHOLD_ASK_LIST));
+ slidingWindow.setThreshold(ContainerId.class.getName(), conf.getInt(
+ YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_RELEASE_LIST,
+ YarnConfiguration.DEFAULT_DOS_SLIDING_WINDOW_THRESHOLD_RELEASE_LIST));
+ slidingWindow.setThreshold(ResourceBlacklistRequest.class.getName(),
+ conf.getInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_BLACKLIST,
+ YarnConfiguration.DEFAULT_DOS_SLIDING_WINDOW_THRESHOLD_BLACKLIST));
+ slidingWindow.setThreshold(ContainerResourceChangeRequest.class.getName(),
+ conf.getInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_CHANGELIST,
+ YarnConfiguration.DEFAULT_DOS_SLIDING_WINDOW_THRESHOLD_CHANGELIST));
+ slidingWindow.setThreshold(ResourceRequest.class.getName() + NUM_CONTAINERS,
+ conf.getInt(
+ YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_NUM_CONTAINERS,
+ YarnConfiguration.DEFAULT_DOS_SLIDING_WINDOW_THRESHOLD_NUM_CONTAINERS));
+
+ // Start sliding window
+
+ timer = new Timer();
+ timer.scheduleAtFixedRate(new SlidingWindowsAdvanceTask(),
+ slidingWindowInterval * 1000, slidingWindowInterval * 1000);
+ }
+
+ /**
+ * Check if we can accept the increment by 1 of the request class object.
+ *
+ * @param className the Class of the request to validate.
+ * @return true if it is valid, false otherwise.
+ */
+ public boolean isSlidingWindowIncrementCountValid(String className) {
+ try {
+ slidingWindow.increaseCount(className);
+ } catch (PreventDoSAttackException e) {
+ return false;
+ }
+ return true;
+ }
+
+ /**
+ * Check if we can accept the increment of the request class object.
+ *
+ * @param className the Class of the request to validate.
+ * @return true if it is valid, false otherwise.
+ */
+ public boolean isSlidingWindowIncrementCountValid(String className,
+ int size) {
+ try {
+ slidingWindow.increaseCount(className, size);
+ } catch (PreventDoSAttackException e) {
+ return false;
+ }
+ return true;
+ }
+
+ @Override
+ public void setConfigurationLimitRequests() {
+
+ // Requests limit for Register AM
+
+ this.limitRegisterRequests =
+ conf.getInt(YarnConfiguration.DOS_MAX_REQUESTS_REGISTER_AM,
+ YarnConfiguration.DEFAULT_DOS_MAX_REQUESTS_REGISTER_AM);
+
+ // Requests limit for Finish AM
+
+ this.limitFinishRequests =
+ conf.getInt(YarnConfiguration.DOS_MAX_REQUESTS_FINISH_AM,
+ YarnConfiguration.DEFAULT_DOS_MAX_REQUESTS_FINISH_AM);
+
+ // Requests limit for Allocate AM
+
+ this.limitAllocateRequests =
+ conf.getInt(YarnConfiguration.DOS_MAX_REQUESTS_ALLOCATE_AM,
+ YarnConfiguration.DEFAULT_DOS_MAX_REQUESTS_ALLOCATE_AM);
+
+ this.limitAskListRequests =
+ conf.getInt(YarnConfiguration.DOS_MAX_REQUESTS_ASK_LIST,
+ YarnConfiguration.DEFAULT_DOS_MAX_REQUESTS_ASK_LIST);
+ this.limitReleaseListRequests =
+ conf.getInt(YarnConfiguration.DOS_MAX_REQUESTS_RELEASE_LIST,
+ YarnConfiguration.DEFAULT_DOS_MAX_REQUESTS_RELEASE_LIST);
+ this.limitBlackListRequests =
+ conf.getInt(YarnConfiguration.DOS_MAX_REQUESTS_BLACKLIST,
+ YarnConfiguration.DEFAULT_DOS_MAX_REQUESTS_BLACKLIST);
+ this.limitChangeListRequests =
+ conf.getInt(YarnConfiguration.DOS_MAX_REQUESTS_CHANGELIST,
+ YarnConfiguration.DEFAULT_DOS_MAX_REQUESTS_CHANGELIST);
+ this.limitContainersRequests =
+ conf.getInt(YarnConfiguration.DOS_MAX_REQUESTS_NUM_CONTAINERS,
+ YarnConfiguration.DEFAULT_DOS_MAX_REQUESTS_NUM_CONTAINERS);
+ }
+
+ /**
+ * Check if we can accept the increment of the counter for
+ * RegisterApplicationMasterRequest.
+ *
+ * @return true if it is valid, false otherwise.
+ */
+ public boolean isRegisterRequestsIncrementValid() {
+ return ++counterRegisterRequests <= limitRegisterRequests;
+ }
+
+ /**
+ * Check if we can accept the increment of the counter for AllocateRequest.
+ *
+ * @return true if it is valid, false otherwise.
+ */
+ public boolean isAllocateRequestsIncrementValid() {
+ return ++counterAllocateRequests <= limitAllocateRequests;
+ }
+
+ /**
+ * Check if we can accept the increment of the counter for
+ * FinishApplicationMasterRequest.
+ *
+ * @return true if it is valid, false otherwise.
+ */
+ public boolean isFinishRequestsIncrementValid() {
+ return ++counterFinishRequests <= limitFinishRequests;
+ }
+
+ /**
+ * Check if we can accept the increment of the counter for
+ * AllocateRequest#AskList.
+ *
+ * @return true if it is valid, false otherwise.
+ */
+ public boolean isAskListRequestsIncrementValid(int increment) {
+ counterAskListRequests += increment;
+ return counterAskListRequests <= limitAskListRequests;
+ }
+
+ /**
+ * Check if we can accept the increment of the counter for
+ * AllocateRequest#ReleaseList.
+ *
+ * @return true if it is valid, false otherwise.
+ */
+ public boolean isReleaseListRequestsIncrementValid(int increment) {
+ counterReleaseListRequests += increment;
+ return counterReleaseListRequests <= limitReleaseListRequests;
+ }
+
+ /**
+ * Check if we can accept the increment of the counter for
+ * AllocateRequest#BlackList.
+ *
+ * @return true if it is valid, false otherwise.
+ */
+ public boolean isBlackListRequestsIncrementValid(int increment) {
+ counterBlackListRequests += increment;
+ return counterBlackListRequests <= limitBlackListRequests;
+ }
+
+ /**
+ * Check if we can accept the increment of the counter for
+ * AllocateRequest#ChangeList.
+ *
+ * @return true if it is valid, false otherwise.
+ */
+ public boolean isChangeListRequestsIncrementValid(int increment) {
+ counterChangeListRequests += increment;
+ return counterChangeListRequests <= limitChangeListRequests;
+ }
+
+ /**
+ * Check if we can accept the increment of the counter for
+ * AllocateRequest#ResourceRequest#NumContainers.
+ *
+ * @return true if it is valid, false otherwise.
+ */
+ public boolean isContainersRequestsIncrementValid(int increment) {
+ counterContainersRequests += increment;
+ return counterContainersRequests <= limitContainersRequests;
+ }
+
+ private class SlidingWindowsAdvanceTask extends TimerTask {
+
+ public void run() {
+ slidingWindow.getCountsThenAdvanceWindow();
+ }
+ }
+
+ @Override
+ public void shutdownTimer() {
+ if (timer != null) {
+ timer.cancel();
+ }
+ }
+
+}
\ No newline at end of file
diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/PreventDoSPolicy.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/PreventDoSPolicy.java
new file mode 100644
index 0000000..4dde2be
--- /dev/null
+++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/PreventDoSPolicy.java
@@ -0,0 +1,35 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.server.preventdos;
+
+/**
+ * Admin policies of action in case of Potential DoS attack.
+ */
+public enum PreventDoSPolicy {
+
+ /*
+ * Ignore: It logs the cause, this option is used in case of testing.
+ */
+ IGNORE,
+
+ /*
+ * Reject: It rejects the request by throwing an exception.
+ */
+ REJECT,
+}
\ No newline at end of file
diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/SlidingWindowCounter.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/SlidingWindowCounter.java
new file mode 100644
index 0000000..80669fc
--- /dev/null
+++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/SlidingWindowCounter.java
@@ -0,0 +1,141 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.yarn.server.preventdos;
+
+import java.io.Serializable;
+import java.util.Map;
+
+/**
+ * This class counts objects in a sliding window fashion.
+ *
+ * It is designed 1) to give multiple "producer" threads write access to the
+ * counter, i.e. being able to increment counts of objects, and 2) to give a
+ * single "consumer" thread (e.g. {@link PeriodicSlidingWindowCounter}) read
+ * access to the counter. Whenever the consumer thread performs a read
+ * operation, this class will advance the head slot of the sliding window
+ * counter. This means that the consumer thread indirectly controls where writes
+ * of the producer threads will go to. Also, by itself this class will not
+ * advance the head slot.
+ *
+ * A note for analyzing data based on a sliding window count: During the initial
+ * windowLengthInSlots iterations, this sliding window counter will
+ * always return object counts that are equal or greater than in the previous
+ * iteration. This is the effect of the counter "loading up" at the very start
+ * of its existence. Conceptually, this is the desired behavior.
+ *
+ * To give an example, using a counter with 5 slots which for the sake of this
+ * example represent 1 minute of time each:
+ *
+ *
+ *
+ * {@code
+ * Sliding window counts of an object X over time
+ *
+ * Minute (timeline):
+ * 1 2 3 4 5 6 7 8
+ *
+ * Observed counts per minute:
+ * 1 1 1 1 0 0 0 0
+ *
+ * Counts returned by counter:
+ * 1 2 3 4 4 3 2 1
+ * }
+ *
+ *
+ * As you can see in this example, for the first
+ * windowLengthInSlots (here: the first five minutes) the counter
+ * will always return counts equal or greater than in the previous iteration (1,
+ * 2, 3, 4, 4). This initial load effect needs to be accounted for whenever you
+ * want to perform analyses such as trending topics; otherwise your analysis
+ * algorithm might falsely identify the object to be trending as the counter
+ * seems to observe continuously increasing counts. Also, note that during the
+ * initial load phase every object will exhibit increasing counts.
+ *
+ * On a high-level, the counter exhibits the following behavior: If you asked
+ * the example counter after two minutes, "how often did you count the object
+ * during the past five minutes?", then it should reply "I have counted it 2
+ * times in the past five minutes", implying that it can only account for the
+ * last two of those five minutes because the counter was not running before
+ * that time.
+ *
+ * @param The type of those objects we want to count.
+ */
+public class SlidingWindowCounter implements Serializable {
+
+ private static final long serialVersionUID = -2645063988768785810L;
+
+ private SlotBasedAccumulator objCounter;
+
+ protected int headSlot;
+ protected int tailSlot;
+ protected int windowLengthInSlots;
+
+ public SlidingWindowCounter(int windowLengthInSlots) {
+ if (windowLengthInSlots < 2) {
+ throw new IllegalArgumentException(
+ "Window length in slots must be at least two (you requested "
+ + windowLengthInSlots + ")");
+ }
+ this.windowLengthInSlots = windowLengthInSlots;
+ this.objCounter = new SlotBasedAccumulator(this.windowLengthInSlots);
+
+ this.headSlot = 0;
+ this.tailSlot = slotAfter(headSlot);
+ }
+
+ public void incrementCount(T obj, long amount) {
+ objCounter.incrementCount(obj, headSlot, amount);
+ }
+
+ public long computeTotalCount(T obj) {
+ Long count = objCounter.getCounts().get(obj);
+ if (count == null) {
+ return 0;
+ }
+ return count;
+ }
+
+ /**
+ * Return the current (total) counts of all tracked objects, then advance the
+ * window.
+ *
+ * Whenever this method is called, we consider the counts of the current
+ * sliding window to be available to and successfully processed "upstream"
+ * (i.e. by the caller). Knowing this we will start counting any subsequent
+ * objects within the next "chunk" of the sliding window.
+ *
+ * @return The current (total) counts of all tracked objects.
+ */
+ public Map getCountsThenAdvanceWindow() {
+ Map counts = objCounter.getCounts();
+ objCounter.wipeZeros();
+ objCounter.wipeSlot(tailSlot);
+ advanceHead();
+ return counts;
+ }
+
+ private void advanceHead() {
+ headSlot = tailSlot;
+ tailSlot = slotAfter(tailSlot);
+ }
+
+ private int slotAfter(int slot) {
+ return (slot + 1) % windowLengthInSlots;
+ }
+
+}
diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/SlidingWindowCounterDoS.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/SlidingWindowCounterDoS.java
new file mode 100644
index 0000000..1c5264c
--- /dev/null
+++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/SlidingWindowCounterDoS.java
@@ -0,0 +1,93 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.yarn.server.preventdos;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ *
+ *
+ */
+public final class SlidingWindowCounterDoS {
+
+ private SlidingWindowCounter slidingWindow;
+ private Map dosThreshold = new HashMap();
+
+ public SlidingWindowCounterDoS(int windowLengthInSlots) {
+ slidingWindow = new SlidingWindowCounter<>(windowLengthInSlots);
+ }
+
+ /**
+ * Set the threshold for the Sliding window.
+ *
+ * @param obj the value we want to create the counter.
+ * @param threshold the threshold for the obj.
+ */
+ public void setThreshold(T obj, long threshold) {
+ dosThreshold.put(obj, threshold);
+ }
+
+ /**
+ * Increase the counter by 1 for the value in input.
+ *
+ * @param obj the value to increase the counter.
+ * @throws PreventDoSAttackException in case the increment exceed the
+ * threshold.
+ */
+ public void increaseCount(T obj) throws PreventDoSAttackException {
+ increaseCount(obj, 1);
+ }
+
+ /**
+ * Increase the counter for the value in input.
+ *
+ * @param obj the value to increase the counter.
+ * @throws PreventDoSAttackException in case the increment exceed the
+ * threshold.
+ */
+ public void increaseCount(T obj, long amount)
+ throws PreventDoSAttackException {
+ checkThreshold(obj, slidingWindow.computeTotalCount(obj) + amount);
+ slidingWindow.incrementCount(obj, amount);
+ }
+
+ /**
+ * Check the respective counter for the value in input and it compares with
+ * threshold of the same.
+ *
+ * @param obj the value to check.
+ * @param count the current counter of the object.
+ * @throws PreventDoSAttackException in case the increment exceed the
+ * threshold.
+ */
+ private void checkThreshold(T obj, long count)
+ throws PreventDoSAttackException {
+ if (dosThreshold.get(obj) < count) {
+ throw new PreventDoSAttackException();
+ }
+ }
+
+ /**
+ * Advance the sliding window to the next slot.
+ */
+ public void getCountsThenAdvanceWindow() {
+ slidingWindow.getCountsThenAdvanceWindow();
+ }
+
+}
diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/SlotBasedAccumulator.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/SlotBasedAccumulator.java
new file mode 100644
index 0000000..76ab863
--- /dev/null
+++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/SlotBasedAccumulator.java
@@ -0,0 +1,144 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.server.preventdos;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * This class provides per-slot counts of the occurrences of objects.
+ *
+ * It can be used, for instance, as a building block for implementing sliding
+ * window counting of objects.
+ *
+ * This class has an additional incrementCount(obj, slot, amount) that allows
+ * users to increase a slot's count by any amount specified
+ *
+ * This class also has an internal totals map that is consistently updated for
+ * better performance
+ *
+ * @param The type of those objects we want to count.
+ */
+public class SlotBasedAccumulator implements Serializable {
+
+ private static final long serialVersionUID = 4858185737378394432L;
+
+ protected final Map objToCounts = new HashMap();
+ protected final Map objToTotals = new HashMap();
+ protected final int numSlots;
+
+ public SlotBasedAccumulator(int numSlots) {
+ if (numSlots <= 0) {
+ throw new IllegalArgumentException(
+ "Number of slots must be greater than zero (you requested " + numSlots
+ + ")");
+ }
+ this.numSlots = numSlots;
+ }
+
+ public void incrementCount(T obj, int slot) {
+ incrementCount(obj, slot, 1);
+ }
+
+ public void incrementCount(T obj, int slot, long amount) {
+ long[] counts = objToCounts.get(obj);
+ if (counts == null) {
+ counts = new long[this.numSlots];
+ objToCounts.put(obj, counts);
+ }
+ counts[slot] += amount;
+ updateTotal(obj, amount);
+ }
+
+ public long getCount(T obj, int slot) {
+ long[] counts = objToCounts.get(obj);
+ if (counts == null) {
+ return 0;
+ } else {
+ return counts[slot];
+ }
+ }
+
+ public Map getCounts() {
+ // Create a deep copy of the internal total map
+ Map result = new HashMap();
+ for (T obj : objToCounts.keySet()) {
+ result.put(obj, objToTotals.get(obj).longValue());
+ }
+ return result;
+ }
+
+ /**
+ * Reset the slot count of any tracked objects to zero for the given slot.
+ *
+ * @param slot
+ */
+ public void wipeSlot(int slot) {
+ for (T obj : objToCounts.keySet()) {
+ resetSlotCountToZero(obj, slot);
+ }
+ }
+
+ /**
+ * Resets the slot for obj to zero
+ *
+ * @param obj
+ * @param slot
+ */
+ private void resetSlotCountToZero(T obj, int slot) {
+ long[] counts = objToCounts.get(obj);
+ updateTotal(obj, -counts[slot]);
+ counts[slot] = 0;
+ }
+
+ /**
+ * Helper function to update the totals cache map
+ *
+ * @param obj
+ * @param diff
+ */
+ private void updateTotal(T obj, long diff) {
+ Long total = objToTotals.get(obj);
+ if (total == null) {
+ total = new Long(0);
+ }
+ long result = total + diff;
+ objToTotals.put(obj, result);
+ }
+
+ /**
+ * Remove any object from the counter whose total count is zero (to free up
+ * memory).
+ */
+ public void wipeZeros() {
+ Set objToBeRemoved = new HashSet();
+ for (T obj : objToCounts.keySet()) {
+ if (objToTotals.get(obj) == 0) {
+ objToBeRemoved.add(obj);
+ }
+ }
+ for (T obj : objToBeRemoved) {
+ objToCounts.remove(obj);
+ objToTotals.remove(obj);
+ }
+ }
+}
diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/test/java/org/apache/hadoop/yarn/server/preventdos/TestPreventDoSAction.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/test/java/org/apache/hadoop/yarn/server/preventdos/TestPreventDoSAction.java
new file mode 100644
index 0000000..77f1243
--- /dev/null
+++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/test/java/org/apache/hadoop/yarn/server/preventdos/TestPreventDoSAction.java
@@ -0,0 +1,59 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.server.preventdos;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Unit tests for PreventDoSAction.
+ */
+public class TestPreventDoSAction {
+
+ /**
+ * This test verifies if we set the PreventDoSPolicy equals to IGNORE, the
+ * PreventDoSAction does not throw an exception.
+ */
+ @Test
+ public void testIgnoreAction() {
+ PreventDoSAction action = new PreventDoSAction(PreventDoSPolicy.IGNORE);
+ try {
+ action.prevent(PreventDoSCause.REQUEST_GET_ASK_LIST_CAUSE);
+ } catch (PreventDoSAttackException e) {
+ Assert.fail();
+ }
+ }
+
+ /**
+ * This test verifies if we set the PreventDoSPolicy equals to REJECT, the
+ * PreventDoSAction throw an exception.
+ */
+ @Test
+ public void testRejectAction() {
+ PreventDoSAction action = new PreventDoSAction(PreventDoSPolicy.REJECT);
+ try {
+ action.prevent(PreventDoSCause.REQUEST_GET_ASK_LIST_CAUSE);
+ Assert.fail();
+ } catch (PreventDoSAttackException e) {
+ Assert.assertTrue(e.getMessage()
+ .contains(PreventDoSCause.REQUEST_GET_ASK_LIST_CAUSE.toString()));
+ }
+ }
+
+}
diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/test/java/org/apache/hadoop/yarn/server/preventdos/TestPreventDoSContextAMPImpl.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/test/java/org/apache/hadoop/yarn/server/preventdos/TestPreventDoSContextAMPImpl.java
new file mode 100644
index 0000000..8d3eef6
--- /dev/null
+++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/test/java/org/apache/hadoop/yarn/server/preventdos/TestPreventDoSContextAMPImpl.java
@@ -0,0 +1,676 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.server.preventdos;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.api.protocolrecords.AllocateRequest;
+import org.apache.hadoop.yarn.api.protocolrecords.FinishApplicationMasterRequest;
+import org.apache.hadoop.yarn.api.protocolrecords.RegisterApplicationMasterRequest;
+import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
+import org.apache.hadoop.yarn.api.records.ApplicationId;
+import org.apache.hadoop.yarn.api.records.ContainerId;
+import org.apache.hadoop.yarn.api.records.ContainerResourceChangeRequest;
+import org.apache.hadoop.yarn.api.records.Priority;
+import org.apache.hadoop.yarn.api.records.Resource;
+import org.apache.hadoop.yarn.api.records.ResourceBlacklistRequest;
+import org.apache.hadoop.yarn.api.records.ResourceRequest;
+import org.apache.hadoop.yarn.conf.YarnConfiguration;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Unit tests for PreventDoSContextAMPImpl.
+ */
+public class TestPreventDoSContextAMPImpl {
+
+ private static final int maxStringLength = 10;
+ private static final int maxRPCPort = 1000;
+ private static final int maxVCores = 6;
+ private static final int maxMemory = 100 * 1024;
+ private static final int maxContainers = 100;
+ private static final int maxListSize = 10;
+ private static final int minPriority = 1;
+ private static final int maxPriority = 10;
+
+ private static final int slidingWindowSize = 3;
+ private static final int slidingWindowAdvanceTime = 2;
+ private static final int thresholdNumberRequests = 3;
+ private static final int thresholdListSize = 10;
+ private static final int thresholdContainers = 50;
+
+ private static final int maxNumberRequests = 3;
+ private static final int maxListSizeRequests = 10;
+ private static final int maxContainersRequests = 50;
+
+ private void setUpRequestLimit(Configuration conf) {
+
+ // Single request limit for Register AM
+
+ conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_LENGTH_HOST,
+ maxStringLength);
+ conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_RPC_PORT, maxRPCPort);
+
+ // Single request limit for Finish AM
+
+ conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_LENGTH_DIAGNOSTIC,
+ maxStringLength);
+
+ // Single request limit for Allocate AM
+
+ conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_LENGTH_RESOURCE,
+ maxStringLength);
+ conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_LENGTH_NODELABEL,
+ maxStringLength);
+
+ conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_MAX_VCORES, maxVCores);
+ conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_MAX_MEMORY, maxMemory);
+ conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_ASK_LIST, maxListSize);
+ conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_RELEASE_LIST, maxListSize);
+ conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_BLACKLIST, maxListSize);
+ conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_CHANGELIST, maxListSize);
+ conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_CONTAINERS, maxContainers);
+ conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_MIN_PRIORITY, minPriority);
+ conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_MAX_PRIORITY, maxPriority);
+ }
+
+ private void setUpRequestsLimit(Configuration conf) {
+ // Requests limit for Register AM
+
+ conf.setInt(YarnConfiguration.DOS_MAX_REQUESTS_REGISTER_AM,
+ maxNumberRequests);
+
+ // Requests limit for Finish AM
+
+ conf.setInt(YarnConfiguration.DOS_MAX_REQUESTS_FINISH_AM,
+ maxNumberRequests);
+
+ // Requests limit for Allocate AM
+ conf.setInt(YarnConfiguration.DOS_MAX_REQUESTS_ALLOCATE_AM,
+ maxNumberRequests);
+
+ conf.setInt(YarnConfiguration.DOS_MAX_REQUESTS_ASK_LIST,
+ maxListSizeRequests);
+ conf.setInt(YarnConfiguration.DOS_MAX_REQUESTS_RELEASE_LIST,
+ maxListSizeRequests);
+ conf.setInt(YarnConfiguration.DOS_MAX_REQUESTS_BLACKLIST,
+ maxListSizeRequests);
+ conf.setInt(YarnConfiguration.DOS_MAX_REQUESTS_CHANGELIST,
+ maxListSizeRequests);
+ conf.setInt(YarnConfiguration.DOS_MAX_REQUESTS_NUM_CONTAINERS,
+ maxContainersRequests);
+ }
+
+ private void setUpSlidingWindowThreshold(Configuration conf) {
+
+ conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_SIZE, slidingWindowSize);
+ conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_ADVANCE_TIME_SEC,
+ slidingWindowAdvanceTime);
+
+ // Sliding window configuration for Register AM
+
+ conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_REGISTER_AM,
+ thresholdNumberRequests);
+
+ // Sliding window configuration for Finish AM
+
+ conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_FINISH_AM,
+ thresholdNumberRequests);
+
+ // Sliding window configuration for Allocate AM
+
+ conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_ALLOCATE_AM,
+ thresholdNumberRequests);
+
+ conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_ASK_LIST,
+ thresholdListSize);
+ conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_RELEASE_LIST,
+ thresholdListSize);
+ conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_BLACKLIST,
+ thresholdListSize);
+ conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_CHANGELIST,
+ thresholdListSize);
+
+ conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_NUM_CONTAINERS,
+ thresholdContainers);
+ }
+
+ /**
+ * This tests validates the correctness of the logic of
+ * PreventDoSContextAMPImpl for Single Request checks.
+ */
+ @Test
+ public void testSingleRequestLimit() {
+
+ Configuration conf = new YarnConfiguration();
+ setUpRequestLimit(conf);
+ PreventDoSContextAMPImpl context = new PreventDoSContextAMPImpl(conf);
+
+ // Check RegisterApplicationMasterRequest#Hostname
+
+ Assert.assertTrue(context.isLengthRegisterHostValid("HostNameOk"));
+ Assert.assertFalse(
+ context.isLengthRegisterHostValid("HostNameLongerThanLimit"));
+
+ // Check RegisterApplicationMasterRequest#RPCPort
+
+ Assert.assertTrue(context.isRPCPortValid(50));
+ Assert.assertFalse(context.isRPCPortValid(-10));
+ Assert.assertFalse(context.isRPCPortValid(1001));
+
+ // Check RegisterApplicationMasterRequest#TrackingURL
+
+ Assert.assertTrue(context.isURLValid("Address:10"));
+ Assert.assertTrue(context.isURLValid("https://Address:10"));
+ Assert.assertFalse(context.isURLValid("//..."));
+ Assert.assertFalse(context.isURLValid("BadAddress:10:10"));
+
+ // Check FinishApplicationMasterRequest#Diagnostic
+
+ Assert.assertTrue(context.isLengthFinishDiagnosticValid("Diagnostic"));
+ Assert.assertFalse(
+ context.isLengthFinishDiagnosticValid("DiagnosticLongerThanLimit"));
+
+ // Check AllocateRequest#AskList
+
+ List rrs = new ArrayList();
+ for (int i = 0; i < maxListSize - 2; i++) {
+ rrs.add(ResourceRequest.newInstance(Priority.newInstance(1), "host",
+ Resource.newInstance(1, 1), 1));
+ }
+ Assert.assertTrue(context.isSizeAskListValid(rrs));
+
+ Assert.assertFalse(context.isSizeAskListValid(null));
+
+ rrs = new ArrayList();
+ for (int i = 0; i < maxListSize + 2; i++) {
+ rrs.add(ResourceRequest.newInstance(Priority.newInstance(1), "host",
+ Resource.newInstance(1, 1), 1));
+ }
+ Assert.assertFalse(context.isSizeAskListValid(rrs));
+
+ // Check AllocateRequest#AskList#ResourceRequest#Priority
+
+ Assert.assertTrue(
+ context.isPriorityValid(Priority.newInstance(maxPriority - 1)));
+ Assert.assertFalse(
+ context.isPriorityValid(Priority.newInstance(maxPriority + 1)));
+ Assert.assertFalse(
+ context.isPriorityValid(Priority.newInstance(minPriority - 1)));
+
+ // Check AllocateRequest#AskList#ResourceRequest#ResourceName
+
+ Assert.assertTrue(context.isLengthAllocateResourceNameValid("NodeOk"));
+ Assert.assertFalse(
+ context.isLengthAllocateResourceNameValid("NodeLongerThanLimit"));
+
+ // Check AllocateRequest#AskList#ResourceRequest#Resource
+
+ Assert.assertTrue(context
+ .isCapabilityValid(Resource.newInstance(maxMemory - 2, maxVCores - 2)));
+ Assert.assertFalse(
+ context.isCapabilityValid(Resource.newInstance(-2, maxVCores - 2)));
+ Assert.assertFalse(
+ context.isCapabilityValid(Resource.newInstance(maxMemory - 2, -2)));
+ Assert.assertFalse(context
+ .isCapabilityValid(Resource.newInstance(maxMemory + 2, maxVCores - 2)));
+ Assert.assertFalse(context
+ .isCapabilityValid(Resource.newInstance(maxMemory - 2, maxVCores + 2)));
+ Assert.assertFalse(context.isCapabilityValid(Resource.newInstance(-2, -2)));
+ Assert.assertFalse(context
+ .isCapabilityValid(Resource.newInstance(maxMemory + 2, maxVCores + 2)));
+
+ // Check AllocateRequest#AskList#ResourceRequest#NumContainers
+
+ Assert.assertTrue(context.isSizeContainersValid(maxContainers - 2));
+ Assert.assertFalse(context.isSizeContainersValid(maxContainers + 2));
+ Assert.assertFalse(context.isSizeContainersValid(-2));
+
+ // Check AllocateRequest#AskList#ResourceRequest#NodeLabels
+
+ Assert.assertTrue(context.isLengthAllocateNodeLabelValid("NodeLabel"));
+ Assert.assertFalse(
+ context.isLengthAllocateNodeLabelValid("NodeLabelLongerThanLimit"));
+
+ // Check AllocateRequest#Progress
+
+ Assert.assertTrue(context.isProgressValid(20));
+ Assert.assertFalse(context.isProgressValid(-10));
+ Assert.assertFalse(context.isProgressValid(110));
+
+ // Check AllocateRequest#LastResponseId
+
+ Assert.assertTrue(context.isLastResponseIdValid(1));
+ Assert.assertTrue(context.isLastResponseIdValid(3));
+ Assert.assertFalse(context.isLastResponseIdValid(2));
+
+ // Check AllocateRequest#ReleaseList
+
+ ApplicationId appId = ApplicationId.newInstance(1, 1);
+ ApplicationAttemptId attemptId = ApplicationAttemptId.newInstance(appId, 1);
+ List containers = new ArrayList();
+ for (int i = 0; i < maxListSize - 2; i++) {
+ containers.add(ContainerId.newContainerId(attemptId, i));
+ }
+ Assert.assertTrue(context.isSizeReleaseListValid(containers));
+
+ containers = new ArrayList();
+ for (int i = 0; i < maxListSize + 2; i++) {
+ containers.add(ContainerId.newContainerId(attemptId, i));
+ }
+ Assert.assertFalse(context.isSizeReleaseListValid(containers));
+
+ // Check AllocateRequest#BlackList
+
+ List blacklists = new ArrayList();
+ for (int i = 0; i < maxListSize - 2; i++) {
+ blacklists.add("Node");
+ }
+ Assert.assertTrue(context.isSizeBlackListValid(blacklists));
+
+ blacklists = new ArrayList();
+ for (int i = 0; i < maxListSize + 2; i++) {
+ blacklists.add("Node");
+ }
+ Assert.assertFalse(context.isSizeBlackListValid(blacklists));
+
+ // Check AllocateRequest#ResourceChangeList
+
+ ContainerResourceChangeRequest crcr = ContainerResourceChangeRequest
+ .newInstance(ContainerId.newContainerId(attemptId, 1),
+ Resource.newInstance(maxMemory, maxVCores));
+ List changeResource =
+ new ArrayList();
+ for (int i = 0; i < maxListSize - 2; i++) {
+ changeResource.add(crcr);
+ }
+ Assert.assertTrue(context.isSizeChangeListValid(changeResource));
+
+ changeResource = new ArrayList();
+ for (int i = 0; i < maxListSize + 2; i++) {
+ changeResource.add(crcr);
+ }
+ Assert.assertFalse(context.isSizeChangeListValid(changeResource));
+
+ }
+
+ /**
+ * This tests validates the correctness of the logic of
+ * PreventDoSContextAMPImpl for Single Request truncation.
+ */
+ @Test
+ public void testSingleRequestLimitTruncate() {
+
+ Configuration conf = new YarnConfiguration();
+ setUpRequestLimit(conf);
+ conf.setBoolean(YarnConfiguration.DOS_TRUNCATE_STRING_ENABLED, true);
+
+ PreventDoSContextAMPImpl context = new PreventDoSContextAMPImpl(conf);
+
+ String host = "HostNameLongerThanLimit";
+ String diagnostic = "DiagnosticLongerThanLimit";
+ String nodeLabel = "NodeLableLongerThanLimit";
+ String node = "NodeLongerThanLimit";
+
+ // Truncate RegisterApplicationMasterRequest#Hostname
+
+ String newHost = context.lengthRegisterHostTruncate(host);
+ Assert.assertEquals(maxStringLength, newHost.length());
+ Assert.assertTrue(host.startsWith(newHost));
+
+ // Truncate FinishApplicationMasterRequest#Diagnostic
+
+ String newDiagnostic = context.lengthFinishDiagnosticTruncate(diagnostic);
+ Assert.assertEquals(maxStringLength, newDiagnostic.length());
+ Assert.assertTrue(diagnostic.startsWith(newDiagnostic));
+
+ // Truncate AllocateRequest#AskList#ResourceRequest#NodeLabels
+
+ String newNodeLabel = context.lengthAllocateNodeLabelTruncate(nodeLabel);
+ Assert.assertEquals(maxStringLength, newNodeLabel.length());
+ Assert.assertTrue(nodeLabel.startsWith(newNodeLabel));
+
+ // Truncate AllocateRequest#AskList#ResourceRequest#ResourceName
+
+ String newNode = context.lengthAllocateResourceNameTruncate(node);
+ Assert.assertEquals(maxStringLength, newNode.length());
+ Assert.assertTrue(node.startsWith(newNode));
+
+ }
+
+ /**
+ * This tests validates the correctness of the logic of
+ * PreventDoSContextAMPImpl for temporal sliding window checks.
+ */
+ @Test
+ public void testSlidindWindowThreshold() throws InterruptedException {
+ Configuration conf = new YarnConfiguration();
+ setUpSlidingWindowThreshold(conf);
+ PreventDoSContextAMPImpl context = new PreventDoSContextAMPImpl(conf);
+
+ /*
+ * Current Status of Sliding Window: First Slot
+ *
+ * RegisterApplicationMasterRequest | 0 | 0 | 0 |
+ *
+ * FinishApplicationMasterRequest | 0 | 0 | 0 |
+ *
+ * AllocateRequest | 0 | 0 | 0 |
+ *
+ * ContainerResourceChangeRequest | 0 | 0 | 0 |
+ *
+ * ResourceRequest | 0 | 0 | 0 |
+ *
+ * ContainerId | 0 | 0 | 0 |
+ *
+ * ResourceBlacklistRequest | 0 | 0 | 0 |
+ *
+ * NUM_CONTAINERS | 0 | 0 | 0 |
+ *
+ */
+
+ Thread.sleep(1000);
+
+ Assert.assertTrue(context.isSlidingWindowIncrementCountValid(
+ RegisterApplicationMasterRequest.class.getName()));
+ Assert.assertTrue(context.isSlidingWindowIncrementCountValid(
+ FinishApplicationMasterRequest.class.getName()));
+ Assert.assertTrue(context
+ .isSlidingWindowIncrementCountValid(AllocateRequest.class.getName()));
+ Assert.assertTrue(context.isSlidingWindowIncrementCountValid(
+ ResourceRequest.class.getName(), 3));
+ Assert.assertTrue(context.isSlidingWindowIncrementCountValid(
+ ContainerResourceChangeRequest.class.getName(), 3));
+ Assert.assertTrue(context
+ .isSlidingWindowIncrementCountValid(ContainerId.class.getName(), 3));
+ Assert.assertTrue(context.isSlidingWindowIncrementCountValid(
+ ResourceBlacklistRequest.class.getName(), 3));
+ Assert.assertTrue(context
+ .isSlidingWindowIncrementCountValid(ResourceRequest.class.getName()
+ + PreventDoSContextAMPImpl.NUM_CONTAINERS, 15));
+
+ Thread.sleep(slidingWindowAdvanceTime * 1000);
+
+ /*
+ * Current Status of Sliding Window: Second Slot
+ *
+ * RegisterApplicationMasterRequest | 1 | 0 | 0 |
+ *
+ * FinishApplicationMasterRequest | 1 | 0 | 0 |
+ *
+ * AllocateRequest | 1 | 0 | 0 |
+ *
+ * ContainerResourceChangeRequest | 3 | 0 | 0 |
+ *
+ * ResourceRequest | 3 | 0 | 0 |
+ *
+ * ContainerId | 3 | 0 | 0 |
+ *
+ * ResourceBlacklistRequest | 3 | 0 | 0 |
+ *
+ * NUM_CONTAINERS | 15 | 0 | 0 |
+ *
+ */
+
+ Assert.assertTrue(context.isSlidingWindowIncrementCountValid(
+ RegisterApplicationMasterRequest.class.getName()));
+ Assert.assertTrue(context.isSlidingWindowIncrementCountValid(
+ FinishApplicationMasterRequest.class.getName()));
+ Assert.assertTrue(context
+ .isSlidingWindowIncrementCountValid(AllocateRequest.class.getName()));
+ Assert.assertTrue(context.isSlidingWindowIncrementCountValid(
+ ResourceRequest.class.getName(), 3));
+ Assert.assertTrue(context.isSlidingWindowIncrementCountValid(
+ ContainerResourceChangeRequest.class.getName(), 3));
+ Assert.assertTrue(context
+ .isSlidingWindowIncrementCountValid(ContainerId.class.getName(), 3));
+ Assert.assertTrue(context.isSlidingWindowIncrementCountValid(
+ ResourceBlacklistRequest.class.getName(), 3));
+ Assert.assertTrue(context
+ .isSlidingWindowIncrementCountValid(ResourceRequest.class.getName()
+ + PreventDoSContextAMPImpl.NUM_CONTAINERS, 15));
+
+ Thread.sleep(slidingWindowAdvanceTime * 1000);
+
+ /*
+ * Current Status of Sliding Window: Third Slot
+ *
+ * RegisterApplicationMasterRequest | 1 | 1 | 0 |
+ *
+ * FinishApplicationMasterRequest | 1 | 1 | 0 |
+ *
+ * AllocateRequest | 1 | 1 | 0 |
+ *
+ * ContainerResourceChangeRequest | 3 | 3 | 0 |
+ *
+ * ResourceRequest | 3 | 3 | 0 |
+ *
+ * ContainerId | 3 | 3 | 0 |
+ *
+ * ResourceBlacklistRequest | 3 | 3 | 0 |
+ *
+ * NUM_CONTAINERS | 15 | 15 | 0 |
+ *
+ */
+
+ Assert.assertTrue(context.isSlidingWindowIncrementCountValid(
+ RegisterApplicationMasterRequest.class.getName()));
+ Assert.assertTrue(context.isSlidingWindowIncrementCountValid(
+ FinishApplicationMasterRequest.class.getName()));
+ Assert.assertTrue(context
+ .isSlidingWindowIncrementCountValid(AllocateRequest.class.getName()));
+ Assert.assertTrue(context.isSlidingWindowIncrementCountValid(
+ ResourceRequest.class.getName(), 3));
+ Assert.assertTrue(context.isSlidingWindowIncrementCountValid(
+ ContainerResourceChangeRequest.class.getName(), 3));
+ Assert.assertTrue(context
+ .isSlidingWindowIncrementCountValid(ContainerId.class.getName(), 3));
+ Assert.assertTrue(context.isSlidingWindowIncrementCountValid(
+ ResourceBlacklistRequest.class.getName(), 3));
+ Assert.assertTrue(context
+ .isSlidingWindowIncrementCountValid(ResourceRequest.class.getName()
+ + PreventDoSContextAMPImpl.NUM_CONTAINERS, 15));
+
+ /*
+ * Current Status of Sliding Window: Third Slot
+ *
+ * RegisterApplicationMasterRequest | 1 | 1 | 1 |
+ *
+ * FinishApplicationMasterRequest | 1 | 1 | 1 |
+ *
+ * AllocateRequest | 1 | 1 | 1 |
+ *
+ * ContainerResourceChangeRequest | 3 | 3 | 3 |
+ *
+ * ResourceRequest | 3 | 3 | 3 |
+ *
+ * ContainerId | 3 | 3 | 3 |
+ *
+ * ResourceBlacklistRequest | 3 | 3 | 3 |
+ *
+ * NUM_CONTAINERS | 15 | 15 | 15 |
+ *
+ */
+
+ Assert.assertFalse(context.isSlidingWindowIncrementCountValid(
+ RegisterApplicationMasterRequest.class.getName()));
+ Assert.assertFalse(context.isSlidingWindowIncrementCountValid(
+ FinishApplicationMasterRequest.class.getName()));
+ Assert.assertFalse(context
+ .isSlidingWindowIncrementCountValid(AllocateRequest.class.getName()));
+ Assert.assertFalse(context.isSlidingWindowIncrementCountValid(
+ ResourceRequest.class.getName(), 3));
+ Assert.assertFalse(context.isSlidingWindowIncrementCountValid(
+ ContainerResourceChangeRequest.class.getName(), 3));
+ Assert.assertFalse(context
+ .isSlidingWindowIncrementCountValid(ContainerId.class.getName(), 3));
+ Assert.assertFalse(context.isSlidingWindowIncrementCountValid(
+ ResourceBlacklistRequest.class.getName(), 3));
+ Assert.assertFalse(context
+ .isSlidingWindowIncrementCountValid(ResourceRequest.class.getName()
+ + PreventDoSContextAMPImpl.NUM_CONTAINERS, 15));
+
+ // All the previous operations fail because we exceed the limits
+
+ Thread.sleep(slidingWindowAdvanceTime * 1000);
+
+ /*
+ * Current Status of Sliding Window: First Slot
+ *
+ * RegisterApplicationMasterRequest | 0 | 1 | 1 |
+ *
+ * FinishApplicationMasterRequest | 0 | 1 | 1 |
+ *
+ * AllocateRequest | 0 | 1 | 1 |
+ *
+ * ContainerResourceChangeRequest | 0 | 3 | 3 |
+ *
+ * ResourceRequest | 0 | 3 | 3 |
+ *
+ * ContainerId | 0 | 3 | 3 |
+ *
+ * ResourceBlacklistRequest | 0 | 3 | 3 |
+ *
+ * NUM_CONTAINERS | 0 | 15 | 15 |
+ *
+ */
+
+ Assert.assertTrue(context.isSlidingWindowIncrementCountValid(
+ RegisterApplicationMasterRequest.class.getName()));
+ Assert.assertTrue(context.isSlidingWindowIncrementCountValid(
+ FinishApplicationMasterRequest.class.getName()));
+ Assert.assertTrue(context
+ .isSlidingWindowIncrementCountValid(AllocateRequest.class.getName()));
+ Assert.assertTrue(context.isSlidingWindowIncrementCountValid(
+ ResourceRequest.class.getName(), 3));
+ Assert.assertTrue(context.isSlidingWindowIncrementCountValid(
+ ContainerResourceChangeRequest.class.getName(), 3));
+ Assert.assertTrue(context
+ .isSlidingWindowIncrementCountValid(ContainerId.class.getName(), 3));
+ Assert.assertTrue(context.isSlidingWindowIncrementCountValid(
+ ResourceBlacklistRequest.class.getName(), 3));
+ Assert.assertTrue(context
+ .isSlidingWindowIncrementCountValid(ResourceRequest.class.getName()
+ + PreventDoSContextAMPImpl.NUM_CONTAINERS, 15));
+
+ /*
+ * Current Status of Sliding Window: First Slot
+ *
+ * RegisterApplicationMasterRequest | 1 | 1 | 1 |
+ *
+ * FinishApplicationMasterRequest | 1 | 1 | 1 |
+ *
+ * AllocateRequest | 1 | 1 | 1 |
+ *
+ * ContainerResourceChangeRequest | 3 | 3 | 3 |
+ *
+ * ResourceRequest | 3 | 3 | 3 |
+ *
+ * ContainerId | 3 | 3 | 3 |
+ *
+ * ResourceBlacklistRequest | 3 | 3 | 3 |
+ *
+ * NUM_CONTAINERS | 15 | 15 | 15 |
+ *
+ */
+
+ }
+
+ /**
+ * This tests validates the correctness of the logic of
+ * PreventDoSContextAMPImpl for the entire application's Requests checks.
+ */
+ @Test
+ public void testRequestsLimit() {
+ Configuration conf = new YarnConfiguration();
+ setUpRequestsLimit(conf);
+ PreventDoSContextAMPImpl context = new PreventDoSContextAMPImpl(conf);
+
+ // Check RegisterApplicationMasterRequest entire execution of the
+ // application. Limit is 3.
+ Assert.assertTrue(context.isRegisterRequestsIncrementValid());
+ Assert.assertTrue(context.isRegisterRequestsIncrementValid());
+ Assert.assertTrue(context.isRegisterRequestsIncrementValid());
+
+ // Exceed the limit
+ Assert.assertFalse(context.isRegisterRequestsIncrementValid());
+
+ // Check FinishApplicationMasterRequest entire execution of the
+ // application. Limit is 3.
+ Assert.assertTrue(context.isFinishRequestsIncrementValid());
+ Assert.assertTrue(context.isFinishRequestsIncrementValid());
+ Assert.assertTrue(context.isFinishRequestsIncrementValid());
+
+ // Exceed the limit
+ Assert.assertFalse(context.isFinishRequestsIncrementValid());
+
+ // Check AllocateRequest entire execution of the application. Limit is 3.
+ Assert.assertTrue(context.isAllocateRequestsIncrementValid());
+ Assert.assertTrue(context.isAllocateRequestsIncrementValid());
+ Assert.assertTrue(context.isAllocateRequestsIncrementValid());
+
+ // Exceed the limit
+ Assert.assertFalse(context.isAllocateRequestsIncrementValid());
+
+ // Check ResourceRequest entire execution of the application. Limit is 10.
+ Assert.assertTrue(context.isAskListRequestsIncrementValid(3));
+ Assert.assertTrue(context.isAskListRequestsIncrementValid(3));
+ Assert.assertTrue(context.isAskListRequestsIncrementValid(3));
+
+ // Exceed the limit
+ Assert.assertFalse(context.isAskListRequestsIncrementValid(3));
+
+ // Check ContainerResourceChangeRequest entire execution of the
+ // application. Limit is 10.
+ Assert.assertTrue(context.isChangeListRequestsIncrementValid(3));
+ Assert.assertTrue(context.isChangeListRequestsIncrementValid(3));
+ Assert.assertTrue(context.isChangeListRequestsIncrementValid(3));
+
+ // Exceed the limit
+ Assert.assertFalse(context.isChangeListRequestsIncrementValid(3));
+
+ // Check ResourceBlacklistRequest entire execution of the
+ // application. Limit is 10.
+ Assert.assertTrue(context.isBlackListRequestsIncrementValid(3));
+ Assert.assertTrue(context.isBlackListRequestsIncrementValid(3));
+ Assert.assertTrue(context.isBlackListRequestsIncrementValid(3));
+
+ // Exceed the limit
+ Assert.assertFalse(context.isBlackListRequestsIncrementValid(3));
+
+ // Check ContainerId entire execution of the application. Limit is 10.
+ Assert.assertTrue(context.isReleaseListRequestsIncrementValid(3));
+ Assert.assertTrue(context.isReleaseListRequestsIncrementValid(3));
+ Assert.assertTrue(context.isReleaseListRequestsIncrementValid(3));
+
+ // Exceed the limit
+ Assert.assertFalse(context.isReleaseListRequestsIncrementValid(3));
+
+ // Check NUM_CONTAINERS entire execution of the application. Limit is 50.
+ Assert.assertTrue(context.isContainersRequestsIncrementValid(15));
+ Assert.assertTrue(context.isContainersRequestsIncrementValid(15));
+ Assert.assertTrue(context.isContainersRequestsIncrementValid(15));
+
+ // Exceed the limit
+ Assert.assertFalse(context.isContainersRequestsIncrementValid(15));
+ }
+
+}
diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/test/java/org/apache/hadoop/yarn/server/preventdos/TestSlidingWindowCounter.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/test/java/org/apache/hadoop/yarn/server/preventdos/TestSlidingWindowCounter.java
new file mode 100644
index 0000000..2849cde
--- /dev/null
+++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/test/java/org/apache/hadoop/yarn/server/preventdos/TestSlidingWindowCounter.java
@@ -0,0 +1,125 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.server.preventdos;
+
+import java.util.Map;
+import org.junit.Test;
+import org.junit.Assert;
+
+/**
+ * Unit tests for SlidingWindowCounter.
+ */
+public class TestSlidingWindowCounter {
+
+ private static final int ANY_WINDOW_LENGTH_IN_SLOTS = 2;
+ private static final Object ANY_OBJECT = "ANY_OBJECT";
+
+ private SlidingWindowCounter getSlidingWindowCounter(
+ int windowLength) {
+ return new SlidingWindowCounter(windowLength);
+ }
+
+ @Test
+ public void lessThanTwoSlotsShouldThrowIAE() {
+ int[] illegalWindowLengths = new int[] { -10, -3, -2, -1, 0, 1 };
+ for (int windowLengthInSlots : illegalWindowLengths) {
+ try {
+ getSlidingWindowCounter(windowLengthInSlots);
+ Assert.fail();
+ } catch (IllegalArgumentException e) {
+
+ }
+
+ }
+ }
+
+ @Test
+ public void twoOrMoreSlotsShouldBeValid() {
+ int[] legalWindowLengths = new int[] { 2, 3, 20 };
+ for (int windowLengthInSlots : legalWindowLengths) {
+ getSlidingWindowCounter(windowLengthInSlots);
+ }
+ }
+
+ @Test
+ public void newInstanceShouldHaveEmptyCounts() {
+ // given
+ SlidingWindowCounter