diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java index 8899ccd..f15707d 100644 --- hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java @@ -2743,6 +2743,167 @@ public static boolean areNodeLabelsEnabled( public static final String TIMELINE_XFS_OPTIONS = TIMELINE_XFS_PREFIX + "xframe-options"; + // Prevent DoS attack configuration for ApplicationMasterProtocol + + public static final String DOS_PREFIX = "yarn.dos."; + + public static final String DOS_SLIDING_WINDOW_PREFIX = + DOS_PREFIX + "sliding-windows."; + + public static final String DOS_SLIDING_WINDOW_SIZE = + DOS_SLIDING_WINDOW_PREFIX + "size"; + public static final int DEFAULT_DOS_SLIDING_WINDOW_SIZE = 10; + + public static final String DOS_SLIDING_WINDOW_ADVANCE_TIME_SEC = + DOS_SLIDING_WINDOW_PREFIX + "advance-time"; + public static final int DEFAULT_DOS_SLIDING_WINDOW_ADVANCE_TIME_SEC = 60; + + public static final String DOS_TRUNCATE_STRING_ENABLED = + DOS_PREFIX + "truncate-string.enable"; + public static final boolean DEFAULT_DOS_TRUNCATE_STRING_ENABLED = true; + + public static final String DOS_ACTION = DOS_PREFIX + "action"; + public static final String DEFAULT_DOS_ACTION = "IGNORE"; + + // Single request limit + + public static final String DOS_REQUEST_LIMIT = DOS_PREFIX + "request-limit."; + + public static final String DOS_REQUEST_LIMIT_LENGTH_HOST = + DOS_REQUEST_LIMIT + "length-host"; + public static final int DEFAULT_DOS_REQUEST_LIMIT_LENGTH_HOST = 150; + + public static final String DOS_REQUEST_LIMIT_RPC_PORT = + DOS_REQUEST_LIMIT + "rpc-port"; + public static final int DEFAULT_DOS_REQUEST_LIMIT_RPC_PORT = 10000; + + public static final String DOS_REQUEST_LIMIT_LENGTH_DIAGNOSTIC = + DOS_REQUEST_LIMIT + "length-diagnostic"; + public static final int DEFAULT_DOS_REQUEST_LIMIT_LENGTH_DIAGNOSTIC = 100; + + public static final String DOS_REQUEST_LIMIT_MIN_PRIORITY = + DOS_REQUEST_LIMIT + "min-priority"; + public static final int DEFAULT_DOS_REQUEST_LIMIT_MIN_PRIORITY = 0; + + public static final String DOS_REQUEST_LIMIT_MAX_PRIORITY = + DOS_REQUEST_LIMIT + "max-priority"; + public static final int DEFAULT_DOS_REQUEST_LIMIT_MAX_PRIORITY = 100; + + public static final String DOS_REQUEST_LIMIT_LENGTH_RESOURCE = + DOS_REQUEST_LIMIT + "length-resource"; + public static final int DEFAULT_DOS_REQUEST_LIMIT_LENGTH_RESOURCE = 150; + + public static final String DOS_REQUEST_LIMIT_LENGTH_NODELABEL = + DOS_REQUEST_LIMIT + "length-nodelabel"; + public static final int DEFAULT_DOS_REQUEST_LIMIT_LENGTH_NODELABEL = 200; + + public static final String DOS_REQUEST_LIMIT_MAX_MEMORY = + DOS_REQUEST_LIMIT + "max-memory"; + public static final int DEFAULT_DOS_REQUEST_LIMIT_MAX_MEMORY = 100 * 1024; + + public static final String DOS_REQUEST_LIMIT_MAX_VCORES = + DOS_REQUEST_LIMIT + "max-vcores"; + public static final int DEFAULT_DOS_REQUEST_LIMIT_MAX_VCORES = 24; + + public static final String DOS_REQUEST_LIMIT_ASK_LIST = + DOS_REQUEST_LIMIT + "ask-list"; + public static final int DEFAULT_DOS_REQUEST_LIMIT_ASK_LIST = 100; + + public static final String DOS_REQUEST_LIMIT_RELEASE_LIST = + DOS_REQUEST_LIMIT + "release-list"; + public static final int DEFAULT_DOS_REQUEST_LIMIT_RELEASE_LIST = 100; + + public static final String DOS_REQUEST_LIMIT_BLACKLIST = + DOS_REQUEST_LIMIT + "blacklist"; + public static final int DEFAULT_DOS_REQUEST_LIMIT_BLACKLIST = 100; + + public static final String DOS_REQUEST_LIMIT_CHANGELIST = + DOS_REQUEST_LIMIT + "change-list"; + public static final int DEFAULT_DOS_REQUEST_LIMIT_CHANGELIST = 100; + + public static final String DOS_REQUEST_LIMIT_CONTAINERS = + DOS_REQUEST_LIMIT + "containers"; + public static final int DEFAULT_DOS_REQUEST_LIMIT_CONTAINERS = 200; + + // Sliding Window Configuration + + public static final String DOS_SLIDING_WINDOW_THRESHOLD = + DOS_SLIDING_WINDOW_PREFIX + "threshold."; + + public static final String DOS_SLIDING_WINDOW_THRESHOLD_REGISTER_AM = + DOS_SLIDING_WINDOW_THRESHOLD + "register-am"; + public static final int DEFAULT_DOS_SLIDING_WINDOW_THRESHOLD_REGISTER_AM = 10; + + public static final String DOS_SLIDING_WINDOW_THRESHOLD_ALLOCATE_AM = + DOS_SLIDING_WINDOW_THRESHOLD + "allocate-am"; + public static final int DEFAULT_DOS_SLIDING_WINDOW_THRESHOLD_ALLOCATE_AM = + 1000; + + public static final String DOS_SLIDING_WINDOW_THRESHOLD_FINISH_AM = + DOS_SLIDING_WINDOW_THRESHOLD + "finish-am"; + public static final int DEFAULT_DOS_SLIDING_WINDOW_THRESHOLD_FINISH_AM = 10; + + public static final String DOS_SLIDING_WINDOW_THRESHOLD_ASK_LIST = + DOS_SLIDING_WINDOW_THRESHOLD + "ask-list"; + public static final int DEFAULT_DOS_SLIDING_WINDOW_THRESHOLD_ASK_LIST = + 1000; + + public static final String DOS_SLIDING_WINDOW_THRESHOLD_RELEASE_LIST = + DOS_SLIDING_WINDOW_THRESHOLD + "release-list"; + public static final int DEFAULT_DOS_SLIDING_WINDOW_THRESHOLD_RELEASE_LIST = + 1000; + + public static final String DOS_SLIDING_WINDOW_THRESHOLD_BLACKLIST = + DOS_SLIDING_WINDOW_THRESHOLD + "blacklist"; + public static final int DEFAULT_DOS_SLIDING_WINDOW_THRESHOLD_BLACKLIST = + 100; + + public static final String DOS_SLIDING_WINDOW_THRESHOLD_CHANGELIST = + DOS_SLIDING_WINDOW_THRESHOLD + "changelist"; + public static final int DEFAULT_DOS_SLIDING_WINDOW_THRESHOLD_CHANGELIST = + 100; + + public static final String DOS_SLIDING_WINDOW_THRESHOLD_NUM_CONTAINERS = + DOS_SLIDING_WINDOW_THRESHOLD + "num-containers"; + public static final int DEFAULT_DOS_SLIDING_WINDOW_THRESHOLD_NUM_CONTAINERS = + 1000; + + // Multiple requests limit + + public static final String DOS_MAX_REQUESTS = DOS_PREFIX + "max-requests."; + + public static final String DOS_MAX_REQUESTS_REGISTER_AM = + DOS_MAX_REQUESTS + "register-am"; + public static final int DEFAULT_DOS_MAX_REQUESTS_REGISTER_AM = 50; + + public static final String DOS_MAX_REQUESTS_ALLOCATE_AM = + DOS_MAX_REQUESTS + "allocate-am"; + public static final int DEFAULT_DOS_MAX_REQUESTS_ALLOCATE_AM = 1000; + + public static final String DOS_MAX_REQUESTS_FINISH_AM = + DOS_MAX_REQUESTS + "finish-am"; + public static final int DEFAULT_DOS_MAX_REQUESTS_FINISH_AM = 50; + + public static final String DOS_MAX_REQUESTS_ASK_LIST = + DOS_MAX_REQUESTS + "ask-list"; + public static final int DEFAULT_DOS_MAX_REQUESTS_ASK_LIST = 1000; + + public static final String DOS_MAX_REQUESTS_RELEASE_LIST = + DOS_MAX_REQUESTS + "release-list"; + public static final int DEFAULT_DOS_MAX_REQUESTS_RELEASE_LIST = 1000; + + public static final String DOS_MAX_REQUESTS_BLACKLIST = + DOS_MAX_REQUESTS + "blacklist"; + public static final int DEFAULT_DOS_MAX_REQUESTS_BLACKLIST = 100; + + public static final String DOS_MAX_REQUESTS_CHANGELIST = + DOS_MAX_REQUESTS + "changelist"; + public static final int DEFAULT_DOS_MAX_REQUESTS_CHANGELIST = 100; + + public static final String DOS_MAX_REQUESTS_NUM_CONTAINERS = + DOS_MAX_REQUESTS + "num-containers"; + public static final int DEFAULT_DOS_MAX_REQUESTS_NUM_CONTAINERS = 100000; + public YarnConfiguration() { super(); } diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/PreventDoSAction.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/PreventDoSAction.java new file mode 100644 index 0000000..036d0ff --- /dev/null +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/PreventDoSAction.java @@ -0,0 +1,63 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.server.preventdos; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * In case of potential DoS attack, this class takes action against the client. + * + */ +public class PreventDoSAction { + + private static final Logger LOG = + LoggerFactory.getLogger(PreventDoSAction.class); + + private PreventDoSPolicy policy; + + public PreventDoSAction(PreventDoSPolicy policy) { + this.policy = policy; + } + + /** + * In case of potential DoS attack, this method takes action against the + * client. + * + * Ignore: It logs the cause, this option is used in case of testing. + * + * Reject: It rejects the request by throwing an exception. + * + */ + public void prevent(PreventDoSCause cause) throws PreventDoSAttackException { + + switch (policy) { + + case IGNORE: { + LOG.warn("Possible DoS attack: " + cause); + break; + } + + case REJECT: { + throw new PreventDoSAttackException("Possible DoS attack: " + cause); + } + + } + } +} \ No newline at end of file diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/PreventDoSAttackException.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/PreventDoSAttackException.java new file mode 100644 index 0000000..02682f4 --- /dev/null +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/PreventDoSAttackException.java @@ -0,0 +1,52 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.server.preventdos; + +import org.apache.hadoop.classification.InterfaceAudience.Public; +import org.apache.hadoop.classification.InterfaceStability.Stable; +import org.apache.hadoop.yarn.exceptions.YarnException; + +/** + * Exception thrown by the {@code AMPDoSRequestInterceptor} if the application + * code is trying to DoS attack the ResourceManager. + * + */ +@Public +@Stable +public class PreventDoSAttackException extends YarnException { + + private static final long serialVersionUID = 1L; + + public PreventDoSAttackException() { + super(); + } + + public PreventDoSAttackException(String message) { + super(message); + } + + public PreventDoSAttackException(Throwable cause) { + super(cause); + } + + public PreventDoSAttackException(String message, Throwable cause) { + super(message, cause); + } + +} \ No newline at end of file diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/PreventDoSCause.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/PreventDoSCause.java new file mode 100644 index 0000000..f93c86e --- /dev/null +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/PreventDoSCause.java @@ -0,0 +1,288 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.server.preventdos; + +/** + * This class groups the information about the Cause of Potential DoS attack. + * + *

+ * It contains: + *

+ * + */ +public class PreventDoSCause { + + private TypeCheck typeCheck; + private String method; + private String param; + private String reason; + + /** + * Type of Check. + * + */ + public enum TypeCheck { + /* Check on a single request */ + REQUEST_LIMIT, + /* Check on a temporal sliding window */ + SLIDING_WINDOW, + /* Check on the entire execution of the application */ + REQUESTS_LIMIT + } + + // Application Master Protocol possible cause + + // Over a Single Request + + public final static PreventDoSCause REQUEST_GET_HOST_CAUSE = + new PreventDoSCause(PreventDoSCause.TypeCheck.REQUEST_LIMIT, + "Register Application Master", "getHost", "Exceed Max Length"); + + public final static PreventDoSCause REQUEST_GET_RPC_PORT_CAUSE = + new PreventDoSCause(PreventDoSCause.TypeCheck.REQUEST_LIMIT, + "Register Application Master", "getRpcPort", "Invalid Port"); + + public final static PreventDoSCause REQUEST_GET_TRACKING_URL_CAUSE = + new PreventDoSCause(PreventDoSCause.TypeCheck.REQUEST_LIMIT, + "Register Application Master", "getTrackingUrl", "Invalid URL"); + + public final static PreventDoSCause REQUEST_GET_DIAGNOSTIC_CAUSE = + new PreventDoSCause(PreventDoSCause.TypeCheck.REQUEST_LIMIT, + "Finish Application Master", "getTrackingUrl", "Exceed Max Length"); + + public final static PreventDoSCause REQUEST_GET_TRACKING_URL_FINISH_CAUSE = + new PreventDoSCause(PreventDoSCause.TypeCheck.REQUEST_LIMIT, + "Finish Application Master", "getTrackingUrl", "Exceed Max Length"); + + public final static PreventDoSCause REQUEST_GET_RESPONSE_ID_CAUSE = + new PreventDoSCause(PreventDoSCause.TypeCheck.REQUEST_LIMIT, "Allocate", + "getResponseId", "Decreased number response Id"); + + public final static PreventDoSCause REQUEST_GET_PROGRESS_CAUSE = + new PreventDoSCause(PreventDoSCause.TypeCheck.REQUEST_LIMIT, "Allocate", + "getProgress", "Out of Range"); + + public final static PreventDoSCause REQUEST_GET_ASK_LIST_CAUSE = + new PreventDoSCause(PreventDoSCause.TypeCheck.REQUEST_LIMIT, "Allocate", + "getAskList", "Exceed Max Size Length"); + + public final static PreventDoSCause REQUEST_GET_PRIORITY_CAUSE = + new PreventDoSCause(PreventDoSCause.TypeCheck.REQUEST_LIMIT, "Allocate", + "getPriority", "Out of Range"); + + public final static PreventDoSCause REQUEST_GET_RESOURCE_NAME_CAUSE = + new PreventDoSCause(PreventDoSCause.TypeCheck.REQUEST_LIMIT, "Allocate", + "getResourceName", "Exceed Max Length"); + + public final static PreventDoSCause REQUEST_GET_CAPABILITY_CAUSE = + new PreventDoSCause(PreventDoSCause.TypeCheck.REQUEST_LIMIT, "Allocate", + "getCapability", "Out of Range"); + + public final static PreventDoSCause REQUEST_GET_NUM_CONTAINERS_CAUSE = + new PreventDoSCause(PreventDoSCause.TypeCheck.REQUEST_LIMIT, "Allocate", + "getNumContainers", "Out of Range"); + + public final static PreventDoSCause REQUEST_GET_NODE_LABEL_CAUSE = + new PreventDoSCause(PreventDoSCause.TypeCheck.REQUEST_LIMIT, "Allocate", + "getNodeLabelExpression", "Exceed Max Length"); + + public final static PreventDoSCause REQUEST_GET_RELEASE_LIST_CAUSE = + new PreventDoSCause(PreventDoSCause.TypeCheck.REQUEST_LIMIT, "Allocate", + "getReleaseList", "Exceed Max Size Length"); + + public final static PreventDoSCause REQUEST_GET_BLACKLIST_ADD_LIST_CAUSE = + new PreventDoSCause(PreventDoSCause.TypeCheck.REQUEST_LIMIT, "Allocate", + "getBlacklistAdditions", "Exceed Max Size Length"); + + public final static PreventDoSCause REQUEST_GET_BLACKLIST_REM_LIST_CAUSE = + new PreventDoSCause(PreventDoSCause.TypeCheck.REQUEST_LIMIT, "Allocate", + "getBlacklistRemovals", "Exceed Max Size Length"); + + public final static PreventDoSCause REQUEST_GET_INCREASE_LIST_CAUSE = + new PreventDoSCause(PreventDoSCause.TypeCheck.REQUEST_LIMIT, "Allocate", + "getIncreaseRequests", "Exceed Max Size Length"); + + public final static PreventDoSCause REQUEST_GET_DECREASE_LIST_CAUSE = + new PreventDoSCause(PreventDoSCause.TypeCheck.REQUEST_LIMIT, "Allocate", + "getDecreaseRequests", "Exceed Max Size Length"); + + // Over Sliding Window + + public final static PreventDoSCause SLIDINGWINDOW_REGISTER_AM_CAUSE = + new PreventDoSCause(PreventDoSCause.TypeCheck.SLIDING_WINDOW, + "Register Application Master", "registerApplicationMaster", + "Exceed Sliding Windows Threshold"); + + public final static PreventDoSCause SLIDINGWINDOW_FINISH_AM_CAUSE = + new PreventDoSCause(PreventDoSCause.TypeCheck.SLIDING_WINDOW, + "Finish Application Master", "finishApplicationMaster", + "Exceed Sliding Windows Threshold"); + + public final static PreventDoSCause SLIDINGWINDOW_ALLOCATE_AM_CAUSE = + new PreventDoSCause(PreventDoSCause.TypeCheck.SLIDING_WINDOW, "Allocate", + "allocate", "Exceed Sliding Windows Threshold"); + + public final static PreventDoSCause SLIDINGWINDOW_GET_ASK_LIST_CAUSE = + new PreventDoSCause(PreventDoSCause.TypeCheck.SLIDING_WINDOW, "Allocate", + "getAskList", "Exceed Sliding Windows Threshold"); + + public final static PreventDoSCause SLIDINGWINDOW_GET_NUM_CONTAINERS_CAUSE = + new PreventDoSCause(PreventDoSCause.TypeCheck.SLIDING_WINDOW, "Allocate", + "getNumContainers", "Exceed Sliding Windows Threshold"); + + public final static PreventDoSCause SLIDINGWINDOW_GET_RELEASE_LIST_CAUSE = + new PreventDoSCause(PreventDoSCause.TypeCheck.SLIDING_WINDOW, "Allocate", + "getReleaseList", "Exceed Sliding Windows Threshold"); + + public final static PreventDoSCause SLIDINGWINDOW_GET_BLACKLIST_ADD_LIST_CAUSE = + new PreventDoSCause(PreventDoSCause.TypeCheck.SLIDING_WINDOW, "Allocate", + "getBlacklistAdditions", "Exceed Sliding Windows Threshold"); + + public final static PreventDoSCause SLIDINGWINDOW_GET_BLACKLIST_REM_LIST_CAUSE = + new PreventDoSCause(PreventDoSCause.TypeCheck.SLIDING_WINDOW, "Allocate", + "getBlacklistRemovals", "Exceed Sliding Windows Threshold"); + + public final static PreventDoSCause SLIDINGWINDOW_GET_INCREASE_LIST_CAUSE = + new PreventDoSCause(PreventDoSCause.TypeCheck.SLIDING_WINDOW, "Allocate", + "getIncreaseRequests", "Exceed Sliding Windows Threshold"); + + public final static PreventDoSCause SLIDINGWINDOW_GET_DECREASE_LIST_CAUSE = + new PreventDoSCause(PreventDoSCause.TypeCheck.SLIDING_WINDOW, "Allocate", + "getDecreaseRequests", "Exceed Sliding Windows Threshold"); + + // Over the entire execution + + public final static PreventDoSCause REQUESTS_REGISTER_AM_CAUSE = + new PreventDoSCause(PreventDoSCause.TypeCheck.REQUESTS_LIMIT, + "Register Application Master", "registerApplicationMaster", + "Exceed Max Total Count"); + + public final static PreventDoSCause REQUESTS_FINISH_AM_CAUSE = + new PreventDoSCause(PreventDoSCause.TypeCheck.REQUESTS_LIMIT, + "Finish Application Master", "finishApplicationMaster", + "Exceed Max Total Count"); + + public final static PreventDoSCause REQUESTS_ALLOCATE_AM_CAUSE = + new PreventDoSCause(PreventDoSCause.TypeCheck.REQUESTS_LIMIT, "Allocate", + "allocate", "Exceed Max Total Count"); + + public final static PreventDoSCause REQUESTS_GET_ASK_LIST_CAUSE = + new PreventDoSCause(PreventDoSCause.TypeCheck.REQUESTS_LIMIT, "Allocate", + "getAskList", "Exceed Max Total Count"); + + public final static PreventDoSCause REQUESTS_GET_NUM_CONTAINERS_CAUSE = + new PreventDoSCause(PreventDoSCause.TypeCheck.REQUESTS_LIMIT, "Allocate", + "getNumContainers", "Exceed Max Total Count"); + + public final static PreventDoSCause REQUESTS_GET_RELEASE_LIST_CAUSE = + new PreventDoSCause(PreventDoSCause.TypeCheck.REQUESTS_LIMIT, "Allocate", + "getReleaseList", "Exceed Max Total Count"); + + public final static PreventDoSCause REQUESTS_GET_BLACKLIST_ADD_LIST_CAUSE = + new PreventDoSCause(PreventDoSCause.TypeCheck.REQUESTS_LIMIT, "Allocate", + "getBlacklistAdditions", "Exceed Max Total Count"); + + public final static PreventDoSCause REQUESTS_GET_BLACKLIST_REM_LIST_CAUSE = + new PreventDoSCause(PreventDoSCause.TypeCheck.REQUESTS_LIMIT, "Allocate", + "getBlacklistRemovals", "Exceed Max Total Count"); + + public final static PreventDoSCause REQUESTS_GET_INCREASE_LIST_CAUSE = + new PreventDoSCause(PreventDoSCause.TypeCheck.REQUESTS_LIMIT, "Allocate", + "getIncreaseRequests", "Exceed Max Total Count"); + + public final static PreventDoSCause REQUESTS_GET_DECREASE_LIST_CAUSE = + new PreventDoSCause(PreventDoSCause.TypeCheck.REQUESTS_LIMIT, "Allocate", + "getDecreaseRequests", "Exceed Max Total Count"); + + private PreventDoSCause(TypeCheck typeCheck, String method, String param, + String reason) { + this.typeCheck = typeCheck; + this.method = method; + this.param = param; + this.reason = reason; + } + + /** + * @return the method + */ + public String getMethod() { + return method; + } + + /** + * @param method the method to set + */ + public void setMethod(String method) { + this.method = method; + } + + /** + * @return the param + */ + public String getParam() { + return param; + } + + /** + * @param param the param to set + */ + public void setParam(String param) { + this.param = param; + } + + /** + * @return the reason + */ + public String getReason() { + return reason; + } + + /** + * @param reason the reason to set + */ + public void setReason(String reason) { + this.reason = reason; + } + + /** + * @return the typeCheck + */ + public TypeCheck getTypeCheck() { + return typeCheck; + } + + /** + * @param typeCheck the typeCheck to set + */ + public void setTypeCheck(TypeCheck typeCheck) { + this.typeCheck = typeCheck; + } + + @Override + public String toString() { + return "PreventDoSCause [typeCheck=" + typeCheck + ", method=" + method + + ", param=" + param + ", reason=" + reason + "]"; + } + +} \ No newline at end of file diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/PreventDoSContext.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/PreventDoSContext.java new file mode 100644 index 0000000..983e678 --- /dev/null +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/PreventDoSContext.java @@ -0,0 +1,45 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.server.preventdos; + +/** + * Context of the Prevent DoS attack + */ +public interface PreventDoSContext { + + /** + * Set the configuration limits for the single request checks. + */ + void setConfigurationLimitRequest(); + + /** + * Set the configuration limits for the temporal sliding window checks. + */ + void setConfigurationThresholdSlidingWindow(); + + /** + * Set the configuration limits for the entire execution checks. + */ + void setConfigurationLimitRequests(); + + /** + * Shut down the timer of the sliding window. + */ + void shutdownTimer(); +} \ No newline at end of file diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/PreventDoSContextAMPImpl.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/PreventDoSContextAMPImpl.java new file mode 100644 index 0000000..c8d566c --- /dev/null +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/PreventDoSContextAMPImpl.java @@ -0,0 +1,622 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.server.preventdos; + +import java.net.URI; +import java.util.List; +import java.util.Timer; +import java.util.TimerTask; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.yarn.api.protocolrecords.AllocateRequest; +import org.apache.hadoop.yarn.api.protocolrecords.FinishApplicationMasterRequest; +import org.apache.hadoop.yarn.api.protocolrecords.RegisterApplicationMasterRequest; +import org.apache.hadoop.yarn.api.records.ContainerId; +import org.apache.hadoop.yarn.api.records.ContainerResourceChangeRequest; +import org.apache.hadoop.yarn.api.records.Priority; +import org.apache.hadoop.yarn.api.records.Resource; +import org.apache.hadoop.yarn.api.records.ResourceBlacklistRequest; +import org.apache.hadoop.yarn.api.records.ResourceRequest; +import org.apache.hadoop.yarn.conf.YarnConfiguration; + +/** + * This class helps by keeping tracks of all the counters to prevent a DoS + * attack for a single application. + */ +public class PreventDoSContextAMPImpl implements PreventDoSContext { + + public static final String NUM_CONTAINERS = "NumContainers"; + public Configuration conf; + Timer timer = null; + + // Entire job execution counters + private int counterRegisterRequests = 0; + private int counterAllocateRequests = 0; + private int counterFinishRequests = 0; + private int counterAskListRequests = 0; + private int counterReleaseListRequests = 0; + private int counterBlackListRequests = 0; + private int counterChangeListRequests = 0; + private int counterContainersRequests = 0; + + // Limit for the entire job execution + private int limitRegisterRequests; + private int limitAllocateRequests; + private int limitFinishRequests; + private int limitAskListRequests; + private int limitReleaseListRequests; + private int limitBlackListRequests; + private int limitChangeListRequests; + private int limitContainersRequests; + + // Limit for singular request + private int maxLengthRegisterHost; + private int maxLengthFinishDiagnostic; + private int maxLengthAllocateResourceName; + private int maxLengthAllocateNodeLabel; + private int lastResponseId = 0; + private int validRPCPort; + private int maxMemory; + private int maxVCores; + private int maxSizeAskList; + private int maxSizeReleaseList; + private int maxSizeBlackList; + private int maxSizeChangeList; + private int maxSizeContainers; + private Priority minPriority; + private Priority maxPriority; + + private SlidingWindowCounterDoS slidingWindow = null; + + private boolean truncateStringEnable; + + public PreventDoSContextAMPImpl(Configuration conf) { + this.conf = conf; + + this.truncateStringEnable = + conf.getBoolean(YarnConfiguration.DOS_TRUNCATE_STRING_ENABLED, + YarnConfiguration.DEFAULT_DOS_TRUNCATE_STRING_ENABLED); + setConfigurationLimitRequest(); + setConfigurationThresholdSlidingWindow(); + setConfigurationLimitRequests(); + } + + @Override + public void setConfigurationLimitRequest() { + + // Single request limit for Register AM + + this.maxLengthRegisterHost = + conf.getInt(YarnConfiguration.DOS_REQUEST_LIMIT_LENGTH_HOST, + YarnConfiguration.DEFAULT_DOS_REQUEST_LIMIT_LENGTH_HOST); + this.validRPCPort = + conf.getInt(YarnConfiguration.DOS_REQUEST_LIMIT_RPC_PORT, + YarnConfiguration.DEFAULT_DOS_REQUEST_LIMIT_RPC_PORT); + + // Single request limit for Finish AM + + this.maxLengthFinishDiagnostic = + conf.getInt(YarnConfiguration.DOS_REQUEST_LIMIT_LENGTH_DIAGNOSTIC, + YarnConfiguration.DEFAULT_DOS_REQUEST_LIMIT_LENGTH_DIAGNOSTIC); + + // Single request limit for Allocate AM + + this.maxLengthAllocateResourceName = + conf.getInt(YarnConfiguration.DOS_REQUEST_LIMIT_LENGTH_RESOURCE, + YarnConfiguration.DEFAULT_DOS_REQUEST_LIMIT_LENGTH_RESOURCE); + this.maxLengthAllocateNodeLabel = + conf.getInt(YarnConfiguration.DOS_REQUEST_LIMIT_LENGTH_NODELABEL, + YarnConfiguration.DEFAULT_DOS_REQUEST_LIMIT_LENGTH_NODELABEL); + + this.maxVCores = conf.getInt(YarnConfiguration.DOS_REQUEST_LIMIT_MAX_VCORES, + YarnConfiguration.DEFAULT_DOS_REQUEST_LIMIT_MAX_VCORES); + this.maxMemory = conf.getInt(YarnConfiguration.DOS_REQUEST_LIMIT_MAX_MEMORY, + YarnConfiguration.DEFAULT_DOS_REQUEST_LIMIT_MAX_MEMORY); + this.maxSizeAskList = + conf.getInt(YarnConfiguration.DOS_REQUEST_LIMIT_ASK_LIST, + YarnConfiguration.DEFAULT_DOS_REQUEST_LIMIT_ASK_LIST); + this.maxSizeReleaseList = + conf.getInt(YarnConfiguration.DOS_REQUEST_LIMIT_RELEASE_LIST, + YarnConfiguration.DEFAULT_DOS_REQUEST_LIMIT_RELEASE_LIST); + this.maxSizeBlackList = + conf.getInt(YarnConfiguration.DOS_REQUEST_LIMIT_BLACKLIST, + YarnConfiguration.DEFAULT_DOS_REQUEST_LIMIT_BLACKLIST); + this.maxSizeChangeList = + conf.getInt(YarnConfiguration.DOS_REQUEST_LIMIT_CHANGELIST, + YarnConfiguration.DEFAULT_DOS_REQUEST_LIMIT_CHANGELIST); + this.maxSizeContainers = + conf.getInt(YarnConfiguration.DOS_REQUEST_LIMIT_CONTAINERS, + YarnConfiguration.DEFAULT_DOS_REQUEST_LIMIT_CONTAINERS); + this.minPriority = Priority.newInstance( + conf.getInt(YarnConfiguration.DOS_REQUEST_LIMIT_MIN_PRIORITY, + YarnConfiguration.DEFAULT_DOS_REQUEST_LIMIT_MIN_PRIORITY)); + this.maxPriority = Priority.newInstance( + conf.getInt(YarnConfiguration.DOS_REQUEST_LIMIT_MAX_PRIORITY, + YarnConfiguration.DEFAULT_DOS_REQUEST_LIMIT_MAX_PRIORITY)); + } + + private String truncate(String stringToTruncate, int maxLength) { + if (stringToTruncate != null) { + return stringToTruncate.substring(0, maxLength); + } + return stringToTruncate; + } + + /** + * Check if the host inside the RegisterApplicationMasterRequest is a valid + * input. + * + * @param host the host to validate. + * @return true if it is valid, false otherwise. + */ + public boolean isLengthRegisterHostValid(String host) { + return host != null && host.length() <= this.maxLengthRegisterHost; + } + + /** + * Truncate the host inside the RegisterApplicationMasterRequest to be + * compliant with a valid input. + * + * @param host the Host to truncate. + * @return the truncated host. + */ + public String lengthRegisterHostTruncate(String host) { + return (truncateStringEnable) ? truncate(host, maxLengthRegisterHost) + : host; + } + + /** + * Check if the RPC port inside the RegisterApplicationMasterRequest is a + * valid input. + * + * @param rpcPort the rpc port to validate. + * @return true if it is valid, false otherwise. + */ + public boolean isRPCPortValid(int rpcPort) { + return rpcPort <= this.validRPCPort && rpcPort > 0; + } + + /** + * Check if the url inside the RegisterApplicationMasterRequest is a valid + * address input. + * + * @param url the url port to validate. + * @return true if it is valid, false otherwise. + */ + public boolean isURLValid(String url) { + // Ensure url is not null + if (url == null || url.isEmpty()) { + return false; + } + // Validate url is well formed + boolean hasScheme = url.contains("://"); + URI uri = null; + try { + uri = hasScheme ? URI.create(url) : URI.create("dummyscheme://" + url); + } catch (IllegalArgumentException e) { + return false; + } + String host = uri.getHost(); + int port = uri.getPort(); + String path = uri.getPath(); + if ((host == null) || (port < 0) + || (!hasScheme && path != null && !path.isEmpty())) { + return false; + } + return true; + } + + /** + * Check if the diagnostic inside the FinishApplicationMasterRequest is a + * valid input. + * + * @param diagnostic the diagnostic to validate. + * @return true if it is valid, false otherwise. + */ + public boolean isLengthFinishDiagnosticValid(String diagnostic) { + return (diagnostic != null + && diagnostic.length() <= this.maxLengthFinishDiagnostic); + } + + /** + * Truncate the Diagnostic inside the FinishApplicationMasterRequest to be + * compliant with a valid input. + * + * @param diagnostic the Diagnostic to truncate. + * @return the truncated Diagnostic. + */ + public String lengthFinishDiagnosticTruncate(String diagnostic) { + return (truncateStringEnable) + ? truncate(diagnostic, maxLengthFinishDiagnostic) : diagnostic; + } + + /** + * Check if the responseId inside the AllocateRequest is a valid input. + * + * @param responseId the responseId to validate. + * @return true if it is valid, false otherwise. + */ + public boolean isLastResponseIdValid(int responseId) { + if (responseId > this.lastResponseId) { + this.lastResponseId = responseId; + return true; + } + return false; + } + + /** + * Check if the progress inside the AllocateRequest is a valid input. + * + * @param progress the progress to validate. + * @return true if it is valid, false otherwise. + */ + public boolean isProgressValid(float progress) { + return progress >= 0 && progress <= 100; + } + + /** + * Check if the size of AskList inside the AllocateRequest is a valid input. + * + * @param list the list to validate. + * @return true if it is valid, false otherwise. + */ + public boolean isSizeAskListValid(List list) { + return list != null && list.size() <= this.maxSizeAskList; + } + + /** + * Check if the Priority inside the AllocateRequest#ResourceRequest is a valid + * input. + * + * @param priority the priority to validate. + * @return true if it is valid, false otherwise. + */ + public boolean isPriorityValid(Priority priority) { + return (priority != null && (priority.compareTo(minPriority) <= 0 + && priority.compareTo(maxPriority) >= 0)); + } + + /** + * Check if the Resource Name inside the AllocateRequest#ResourceRequest is a + * valid input. + * + * @param resourceName the Resource Name to validate. + * @return true if it is valid, false otherwise. + */ + public boolean isLengthAllocateResourceNameValid(String resourceName) { + return resourceName != null + && resourceName.length() <= this.maxLengthAllocateResourceName; + } + + /** + * Truncate the Resource Name inside the AllocateRequest#ResourceRequest to be + * compliant with a valid input. + * + * @param resourceName the Resource Name to truncate. + * @return the truncated Resource Name. + */ + public String lengthAllocateResourceNameTruncate(String resourceName) { + return (truncateStringEnable) + ? truncate(resourceName, maxLengthAllocateResourceName) : resourceName; + } + + /** + * Check if the Resource inside the AllocateRequest#ResourceRequest is a valid + * input. + * + * @param resource the resource to validate. + * @return true if it is valid, false otherwise. + */ + public boolean isCapabilityValid(Resource resource) { + return (resource != null + && (resource.getMemorySize() >= 0 && resource.getVirtualCores() >= 0 + && resource.getMemorySize() <= this.maxMemory + && resource.getVirtualCores() <= this.maxVCores)); + } + + /** + * Check if the number of containers inside the + * AllocateRequest#ResourceRequest is a valid input. + * + * @param numContainers the number of containers to validate. + * @return true if it is valid, false otherwise. + */ + public boolean isSizeContainersValid(int numContainers) { + return numContainers >= 0 && numContainers <= maxSizeContainers; + } + + /** + * Check if the NodeLabel inside the AllocateRequest#ResourceRequest is a + * valid input. + * + * @param nodeLabel the NodeLabel to validate. + * @return true if it is valid, false otherwise. + */ + public boolean isLengthAllocateNodeLabelValid(String nodeLabel) { + return nodeLabel != null + && nodeLabel.length() <= this.maxLengthAllocateNodeLabel; + } + + /** + * Truncate the Node Label inside the AllocateRequest#ResourceRequest to be + * compliant with a valid input. + * + * @param nodeLabel the Node Label to truncate. + * @return the truncated Node Label. + */ + public String lengthAllocateNodeLabelTruncate(String nodeLabel) { + return (truncateStringEnable) + ? truncate(nodeLabel, maxLengthAllocateNodeLabel) : nodeLabel; + } + + /** + * Check if the size of ReleaseList inside the AllocateRequest is a valid + * input. + * + * @param list the list to validate. + * @return true if it is valid, false otherwise. + */ + public boolean isSizeReleaseListValid(List list) { + return list != null && list.size() <= this.maxSizeReleaseList; + } + + /** + * Check if the size of BlackList inside the AllocateRequest is a valid input. + * + * @param list the list to validate. + * @return true if it is valid, false otherwise. + */ + public boolean isSizeBlackListValid(List list) { + return list != null && list.size() <= this.maxSizeBlackList; + } + + /** + * Check if the size of ChangeList inside the AllocateRequest is a valid + * input. + * + * @param list the list to validate. + * @return true if it is valid, false otherwise. + */ + public boolean isSizeChangeListValid( + List list) { + return list != null && list.size() <= this.maxSizeChangeList; + } + + @Override + public void setConfigurationThresholdSlidingWindow() { + + int slidingWindowSize = + conf.getInt(YarnConfiguration.DOS_SLIDING_WINDOW_SIZE, + YarnConfiguration.DEFAULT_DOS_SLIDING_WINDOW_SIZE); + int slidingWindowInterval = + conf.getInt(YarnConfiguration.DOS_SLIDING_WINDOW_ADVANCE_TIME_SEC, + YarnConfiguration.DEFAULT_DOS_SLIDING_WINDOW_ADVANCE_TIME_SEC); + + slidingWindow = new SlidingWindowCounterDoS(slidingWindowSize); + + // Sliding window configuration for Register AM + + slidingWindow.setThreshold(RegisterApplicationMasterRequest.class.getName(), + conf.getInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_REGISTER_AM, + YarnConfiguration.DEFAULT_DOS_SLIDING_WINDOW_THRESHOLD_REGISTER_AM)); + + // Sliding window configuration for Finish AM + + slidingWindow.setThreshold(FinishApplicationMasterRequest.class.getName(), + conf.getInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_FINISH_AM, + YarnConfiguration.DEFAULT_DOS_SLIDING_WINDOW_THRESHOLD_FINISH_AM)); + + // Sliding window configuration for Allocate AM + + slidingWindow.setThreshold(AllocateRequest.class.getName(), conf.getInt( + YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_ALLOCATE_AM, + YarnConfiguration.DEFAULT_DOS_SLIDING_WINDOW_THRESHOLD_ALLOCATE_AM)); + + slidingWindow.setThreshold(ResourceRequest.class.getName(), + conf.getInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_ASK_LIST, + YarnConfiguration.DEFAULT_DOS_SLIDING_WINDOW_THRESHOLD_ASK_LIST)); + slidingWindow.setThreshold(ContainerId.class.getName(), conf.getInt( + YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_RELEASE_LIST, + YarnConfiguration.DEFAULT_DOS_SLIDING_WINDOW_THRESHOLD_RELEASE_LIST)); + slidingWindow.setThreshold(ResourceBlacklistRequest.class.getName(), + conf.getInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_BLACKLIST, + YarnConfiguration.DEFAULT_DOS_SLIDING_WINDOW_THRESHOLD_BLACKLIST)); + slidingWindow.setThreshold(ContainerResourceChangeRequest.class.getName(), + conf.getInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_CHANGELIST, + YarnConfiguration.DEFAULT_DOS_SLIDING_WINDOW_THRESHOLD_CHANGELIST)); + slidingWindow.setThreshold(ResourceRequest.class.getName() + NUM_CONTAINERS, + conf.getInt( + YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_NUM_CONTAINERS, + YarnConfiguration.DEFAULT_DOS_SLIDING_WINDOW_THRESHOLD_NUM_CONTAINERS)); + + // Start sliding window + + timer = new Timer(); + timer.scheduleAtFixedRate(new SlidingWindowsAdvanceTask(), + slidingWindowInterval * 1000, slidingWindowInterval * 1000); + } + + /** + * Check if we can accept the increment by 1 of the request class object. + * + * @param className the Class of the request to validate. + * @return true if it is valid, false otherwise. + */ + public boolean isSlidingWindowIncrementCountValid(String className) { + try { + slidingWindow.increaseCount(className); + } catch (PreventDoSAttackException e) { + return false; + } + return true; + } + + /** + * Check if we can accept the increment of the request class object. + * + * @param className the Class of the request to validate. + * @return true if it is valid, false otherwise. + */ + public boolean isSlidingWindowIncrementCountValid(String className, + int size) { + try { + slidingWindow.increaseCount(className, size); + } catch (PreventDoSAttackException e) { + return false; + } + return true; + } + + @Override + public void setConfigurationLimitRequests() { + + // Requests limit for Register AM + + this.limitRegisterRequests = + conf.getInt(YarnConfiguration.DOS_MAX_REQUESTS_REGISTER_AM, + YarnConfiguration.DEFAULT_DOS_MAX_REQUESTS_REGISTER_AM); + + // Requests limit for Finish AM + + this.limitFinishRequests = + conf.getInt(YarnConfiguration.DOS_MAX_REQUESTS_FINISH_AM, + YarnConfiguration.DEFAULT_DOS_MAX_REQUESTS_FINISH_AM); + + // Requests limit for Allocate AM + + this.limitAllocateRequests = + conf.getInt(YarnConfiguration.DOS_MAX_REQUESTS_ALLOCATE_AM, + YarnConfiguration.DEFAULT_DOS_MAX_REQUESTS_ALLOCATE_AM); + + this.limitAskListRequests = + conf.getInt(YarnConfiguration.DOS_MAX_REQUESTS_ASK_LIST, + YarnConfiguration.DEFAULT_DOS_MAX_REQUESTS_ASK_LIST); + this.limitReleaseListRequests = + conf.getInt(YarnConfiguration.DOS_MAX_REQUESTS_RELEASE_LIST, + YarnConfiguration.DEFAULT_DOS_MAX_REQUESTS_RELEASE_LIST); + this.limitBlackListRequests = + conf.getInt(YarnConfiguration.DOS_MAX_REQUESTS_BLACKLIST, + YarnConfiguration.DEFAULT_DOS_MAX_REQUESTS_BLACKLIST); + this.limitChangeListRequests = + conf.getInt(YarnConfiguration.DOS_MAX_REQUESTS_CHANGELIST, + YarnConfiguration.DEFAULT_DOS_MAX_REQUESTS_CHANGELIST); + this.limitContainersRequests = + conf.getInt(YarnConfiguration.DOS_MAX_REQUESTS_NUM_CONTAINERS, + YarnConfiguration.DEFAULT_DOS_MAX_REQUESTS_NUM_CONTAINERS); + } + + /** + * Check if we can accept the increment of the counter for + * RegisterApplicationMasterRequest. + * + * @return true if it is valid, false otherwise. + */ + public boolean isRegisterRequestsIncrementValid() { + return ++counterRegisterRequests <= limitRegisterRequests; + } + + /** + * Check if we can accept the increment of the counter for AllocateRequest. + * + * @return true if it is valid, false otherwise. + */ + public boolean isAllocateRequestsIncrementValid() { + return ++counterAllocateRequests <= limitAllocateRequests; + } + + /** + * Check if we can accept the increment of the counter for + * FinishApplicationMasterRequest. + * + * @return true if it is valid, false otherwise. + */ + public boolean isFinishRequestsIncrementValid() { + return ++counterFinishRequests <= limitFinishRequests; + } + + /** + * Check if we can accept the increment of the counter for + * AllocateRequest#AskList. + * + * @return true if it is valid, false otherwise. + */ + public boolean isAskListRequestsIncrementValid(int increment) { + counterAskListRequests += increment; + return counterAskListRequests <= limitAskListRequests; + } + + /** + * Check if we can accept the increment of the counter for + * AllocateRequest#ReleaseList. + * + * @return true if it is valid, false otherwise. + */ + public boolean isReleaseListRequestsIncrementValid(int increment) { + counterReleaseListRequests += increment; + return counterReleaseListRequests <= limitReleaseListRequests; + } + + /** + * Check if we can accept the increment of the counter for + * AllocateRequest#BlackList. + * + * @return true if it is valid, false otherwise. + */ + public boolean isBlackListRequestsIncrementValid(int increment) { + counterBlackListRequests += increment; + return counterBlackListRequests <= limitBlackListRequests; + } + + /** + * Check if we can accept the increment of the counter for + * AllocateRequest#ChangeList. + * + * @return true if it is valid, false otherwise. + */ + public boolean isChangeListRequestsIncrementValid(int increment) { + counterChangeListRequests += increment; + return counterChangeListRequests <= limitChangeListRequests; + } + + /** + * Check if we can accept the increment of the counter for + * AllocateRequest#ResourceRequest#NumContainers. + * + * @return true if it is valid, false otherwise. + */ + public boolean isContainersRequestsIncrementValid(int increment) { + counterContainersRequests += increment; + return counterContainersRequests <= limitContainersRequests; + } + + private class SlidingWindowsAdvanceTask extends TimerTask { + + public void run() { + slidingWindow.getCountsThenAdvanceWindow(); + } + } + + @Override + public void shutdownTimer() { + if (timer != null) { + timer.cancel(); + } + } + +} \ No newline at end of file diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/PreventDoSPolicy.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/PreventDoSPolicy.java new file mode 100644 index 0000000..4dde2be --- /dev/null +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/PreventDoSPolicy.java @@ -0,0 +1,35 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.server.preventdos; + +/** + * Admin policies of action in case of Potential DoS attack. + */ +public enum PreventDoSPolicy { + + /* + * Ignore: It logs the cause, this option is used in case of testing. + */ + IGNORE, + + /* + * Reject: It rejects the request by throwing an exception. + */ + REJECT, +} \ No newline at end of file diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/SlidingWindowCounter.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/SlidingWindowCounter.java new file mode 100644 index 0000000..80669fc --- /dev/null +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/SlidingWindowCounter.java @@ -0,0 +1,141 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.yarn.server.preventdos; + +import java.io.Serializable; +import java.util.Map; + +/** + * This class counts objects in a sliding window fashion. + *

+ * It is designed 1) to give multiple "producer" threads write access to the + * counter, i.e. being able to increment counts of objects, and 2) to give a + * single "consumer" thread (e.g. {@link PeriodicSlidingWindowCounter}) read + * access to the counter. Whenever the consumer thread performs a read + * operation, this class will advance the head slot of the sliding window + * counter. This means that the consumer thread indirectly controls where writes + * of the producer threads will go to. Also, by itself this class will not + * advance the head slot. + *

+ * A note for analyzing data based on a sliding window count: During the initial + * windowLengthInSlots iterations, this sliding window counter will + * always return object counts that are equal or greater than in the previous + * iteration. This is the effect of the counter "loading up" at the very start + * of its existence. Conceptually, this is the desired behavior. + *

+ * To give an example, using a counter with 5 slots which for the sake of this + * example represent 1 minute of time each: + *

+ * + *

+ * {@code
+ * Sliding window counts of an object X over time
+ *
+ * Minute (timeline):
+ * 1    2   3   4   5   6   7   8
+ *
+ * Observed counts per minute:
+ * 1    1   1   1   0   0   0   0
+ *
+ * Counts returned by counter:
+ * 1    2   3   4   4   3   2   1
+ * }
+ * 
+ *

+ * As you can see in this example, for the first + * windowLengthInSlots (here: the first five minutes) the counter + * will always return counts equal or greater than in the previous iteration (1, + * 2, 3, 4, 4). This initial load effect needs to be accounted for whenever you + * want to perform analyses such as trending topics; otherwise your analysis + * algorithm might falsely identify the object to be trending as the counter + * seems to observe continuously increasing counts. Also, note that during the + * initial load phase every object will exhibit increasing counts. + *

+ * On a high-level, the counter exhibits the following behavior: If you asked + * the example counter after two minutes, "how often did you count the object + * during the past five minutes?", then it should reply "I have counted it 2 + * times in the past five minutes", implying that it can only account for the + * last two of those five minutes because the counter was not running before + * that time. + * + * @param The type of those objects we want to count. + */ +public class SlidingWindowCounter implements Serializable { + + private static final long serialVersionUID = -2645063988768785810L; + + private SlotBasedAccumulator objCounter; + + protected int headSlot; + protected int tailSlot; + protected int windowLengthInSlots; + + public SlidingWindowCounter(int windowLengthInSlots) { + if (windowLengthInSlots < 2) { + throw new IllegalArgumentException( + "Window length in slots must be at least two (you requested " + + windowLengthInSlots + ")"); + } + this.windowLengthInSlots = windowLengthInSlots; + this.objCounter = new SlotBasedAccumulator(this.windowLengthInSlots); + + this.headSlot = 0; + this.tailSlot = slotAfter(headSlot); + } + + public void incrementCount(T obj, long amount) { + objCounter.incrementCount(obj, headSlot, amount); + } + + public long computeTotalCount(T obj) { + Long count = objCounter.getCounts().get(obj); + if (count == null) { + return 0; + } + return count; + } + + /** + * Return the current (total) counts of all tracked objects, then advance the + * window. + *

+ * Whenever this method is called, we consider the counts of the current + * sliding window to be available to and successfully processed "upstream" + * (i.e. by the caller). Knowing this we will start counting any subsequent + * objects within the next "chunk" of the sliding window. + * + * @return The current (total) counts of all tracked objects. + */ + public Map getCountsThenAdvanceWindow() { + Map counts = objCounter.getCounts(); + objCounter.wipeZeros(); + objCounter.wipeSlot(tailSlot); + advanceHead(); + return counts; + } + + private void advanceHead() { + headSlot = tailSlot; + tailSlot = slotAfter(tailSlot); + } + + private int slotAfter(int slot) { + return (slot + 1) % windowLengthInSlots; + } + +} diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/SlidingWindowCounterDoS.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/SlidingWindowCounterDoS.java new file mode 100644 index 0000000..1c5264c --- /dev/null +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/SlidingWindowCounterDoS.java @@ -0,0 +1,93 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.yarn.server.preventdos; + +import java.util.HashMap; +import java.util.Map; + +/** + * + * + */ +public final class SlidingWindowCounterDoS { + + private SlidingWindowCounter slidingWindow; + private Map dosThreshold = new HashMap(); + + public SlidingWindowCounterDoS(int windowLengthInSlots) { + slidingWindow = new SlidingWindowCounter<>(windowLengthInSlots); + } + + /** + * Set the threshold for the Sliding window. + * + * @param obj the value we want to create the counter. + * @param threshold the threshold for the obj. + */ + public void setThreshold(T obj, long threshold) { + dosThreshold.put(obj, threshold); + } + + /** + * Increase the counter by 1 for the value in input. + * + * @param obj the value to increase the counter. + * @throws PreventDoSAttackException in case the increment exceed the + * threshold. + */ + public void increaseCount(T obj) throws PreventDoSAttackException { + increaseCount(obj, 1); + } + + /** + * Increase the counter for the value in input. + * + * @param obj the value to increase the counter. + * @throws PreventDoSAttackException in case the increment exceed the + * threshold. + */ + public void increaseCount(T obj, long amount) + throws PreventDoSAttackException { + checkThreshold(obj, slidingWindow.computeTotalCount(obj) + amount); + slidingWindow.incrementCount(obj, amount); + } + + /** + * Check the respective counter for the value in input and it compares with + * threshold of the same. + * + * @param obj the value to check. + * @param count the current counter of the object. + * @throws PreventDoSAttackException in case the increment exceed the + * threshold. + */ + private void checkThreshold(T obj, long count) + throws PreventDoSAttackException { + if (dosThreshold.get(obj) < count) { + throw new PreventDoSAttackException(); + } + } + + /** + * Advance the sliding window to the next slot. + */ + public void getCountsThenAdvanceWindow() { + slidingWindow.getCountsThenAdvanceWindow(); + } + +} diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/SlotBasedAccumulator.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/SlotBasedAccumulator.java new file mode 100644 index 0000000..76ab863 --- /dev/null +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/preventdos/SlotBasedAccumulator.java @@ -0,0 +1,144 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.server.preventdos; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +/** + * This class provides per-slot counts of the occurrences of objects. + *

+ * It can be used, for instance, as a building block for implementing sliding + * window counting of objects. + * + * This class has an additional incrementCount(obj, slot, amount) that allows + * users to increase a slot's count by any amount specified + * + * This class also has an internal totals map that is consistently updated for + * better performance + * + * @param The type of those objects we want to count. + */ +public class SlotBasedAccumulator implements Serializable { + + private static final long serialVersionUID = 4858185737378394432L; + + protected final Map objToCounts = new HashMap(); + protected final Map objToTotals = new HashMap(); + protected final int numSlots; + + public SlotBasedAccumulator(int numSlots) { + if (numSlots <= 0) { + throw new IllegalArgumentException( + "Number of slots must be greater than zero (you requested " + numSlots + + ")"); + } + this.numSlots = numSlots; + } + + public void incrementCount(T obj, int slot) { + incrementCount(obj, slot, 1); + } + + public void incrementCount(T obj, int slot, long amount) { + long[] counts = objToCounts.get(obj); + if (counts == null) { + counts = new long[this.numSlots]; + objToCounts.put(obj, counts); + } + counts[slot] += amount; + updateTotal(obj, amount); + } + + public long getCount(T obj, int slot) { + long[] counts = objToCounts.get(obj); + if (counts == null) { + return 0; + } else { + return counts[slot]; + } + } + + public Map getCounts() { + // Create a deep copy of the internal total map + Map result = new HashMap(); + for (T obj : objToCounts.keySet()) { + result.put(obj, objToTotals.get(obj).longValue()); + } + return result; + } + + /** + * Reset the slot count of any tracked objects to zero for the given slot. + * + * @param slot + */ + public void wipeSlot(int slot) { + for (T obj : objToCounts.keySet()) { + resetSlotCountToZero(obj, slot); + } + } + + /** + * Resets the slot for obj to zero + * + * @param obj + * @param slot + */ + private void resetSlotCountToZero(T obj, int slot) { + long[] counts = objToCounts.get(obj); + updateTotal(obj, -counts[slot]); + counts[slot] = 0; + } + + /** + * Helper function to update the totals cache map + * + * @param obj + * @param diff + */ + private void updateTotal(T obj, long diff) { + Long total = objToTotals.get(obj); + if (total == null) { + total = new Long(0); + } + long result = total + diff; + objToTotals.put(obj, result); + } + + /** + * Remove any object from the counter whose total count is zero (to free up + * memory). + */ + public void wipeZeros() { + Set objToBeRemoved = new HashSet(); + for (T obj : objToCounts.keySet()) { + if (objToTotals.get(obj) == 0) { + objToBeRemoved.add(obj); + } + } + for (T obj : objToBeRemoved) { + objToCounts.remove(obj); + objToTotals.remove(obj); + } + } +} diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/test/java/org/apache/hadoop/yarn/server/preventdos/TestPreventDoSAction.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/test/java/org/apache/hadoop/yarn/server/preventdos/TestPreventDoSAction.java new file mode 100644 index 0000000..77f1243 --- /dev/null +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/test/java/org/apache/hadoop/yarn/server/preventdos/TestPreventDoSAction.java @@ -0,0 +1,59 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.server.preventdos; + +import org.junit.Assert; +import org.junit.Test; + +/** + * Unit tests for PreventDoSAction. + */ +public class TestPreventDoSAction { + + /** + * This test verifies if we set the PreventDoSPolicy equals to IGNORE, the + * PreventDoSAction does not throw an exception. + */ + @Test + public void testIgnoreAction() { + PreventDoSAction action = new PreventDoSAction(PreventDoSPolicy.IGNORE); + try { + action.prevent(PreventDoSCause.REQUEST_GET_ASK_LIST_CAUSE); + } catch (PreventDoSAttackException e) { + Assert.fail(); + } + } + + /** + * This test verifies if we set the PreventDoSPolicy equals to REJECT, the + * PreventDoSAction throw an exception. + */ + @Test + public void testRejectAction() { + PreventDoSAction action = new PreventDoSAction(PreventDoSPolicy.REJECT); + try { + action.prevent(PreventDoSCause.REQUEST_GET_ASK_LIST_CAUSE); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage() + .contains(PreventDoSCause.REQUEST_GET_ASK_LIST_CAUSE.toString())); + } + } + +} diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/test/java/org/apache/hadoop/yarn/server/preventdos/TestPreventDoSContextAMPImpl.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/test/java/org/apache/hadoop/yarn/server/preventdos/TestPreventDoSContextAMPImpl.java new file mode 100644 index 0000000..8d3eef6 --- /dev/null +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/test/java/org/apache/hadoop/yarn/server/preventdos/TestPreventDoSContextAMPImpl.java @@ -0,0 +1,676 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.server.preventdos; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.yarn.api.protocolrecords.AllocateRequest; +import org.apache.hadoop.yarn.api.protocolrecords.FinishApplicationMasterRequest; +import org.apache.hadoop.yarn.api.protocolrecords.RegisterApplicationMasterRequest; +import org.apache.hadoop.yarn.api.records.ApplicationAttemptId; +import org.apache.hadoop.yarn.api.records.ApplicationId; +import org.apache.hadoop.yarn.api.records.ContainerId; +import org.apache.hadoop.yarn.api.records.ContainerResourceChangeRequest; +import org.apache.hadoop.yarn.api.records.Priority; +import org.apache.hadoop.yarn.api.records.Resource; +import org.apache.hadoop.yarn.api.records.ResourceBlacklistRequest; +import org.apache.hadoop.yarn.api.records.ResourceRequest; +import org.apache.hadoop.yarn.conf.YarnConfiguration; +import org.junit.Assert; +import org.junit.Test; + +/** + * Unit tests for PreventDoSContextAMPImpl. + */ +public class TestPreventDoSContextAMPImpl { + + private static final int maxStringLength = 10; + private static final int maxRPCPort = 1000; + private static final int maxVCores = 6; + private static final int maxMemory = 100 * 1024; + private static final int maxContainers = 100; + private static final int maxListSize = 10; + private static final int minPriority = 1; + private static final int maxPriority = 10; + + private static final int slidingWindowSize = 3; + private static final int slidingWindowAdvanceTime = 2; + private static final int thresholdNumberRequests = 3; + private static final int thresholdListSize = 10; + private static final int thresholdContainers = 50; + + private static final int maxNumberRequests = 3; + private static final int maxListSizeRequests = 10; + private static final int maxContainersRequests = 50; + + private void setUpRequestLimit(Configuration conf) { + + // Single request limit for Register AM + + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_LENGTH_HOST, + maxStringLength); + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_RPC_PORT, maxRPCPort); + + // Single request limit for Finish AM + + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_LENGTH_DIAGNOSTIC, + maxStringLength); + + // Single request limit for Allocate AM + + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_LENGTH_RESOURCE, + maxStringLength); + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_LENGTH_NODELABEL, + maxStringLength); + + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_MAX_VCORES, maxVCores); + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_MAX_MEMORY, maxMemory); + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_ASK_LIST, maxListSize); + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_RELEASE_LIST, maxListSize); + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_BLACKLIST, maxListSize); + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_CHANGELIST, maxListSize); + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_CONTAINERS, maxContainers); + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_MIN_PRIORITY, minPriority); + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_MAX_PRIORITY, maxPriority); + } + + private void setUpRequestsLimit(Configuration conf) { + // Requests limit for Register AM + + conf.setInt(YarnConfiguration.DOS_MAX_REQUESTS_REGISTER_AM, + maxNumberRequests); + + // Requests limit for Finish AM + + conf.setInt(YarnConfiguration.DOS_MAX_REQUESTS_FINISH_AM, + maxNumberRequests); + + // Requests limit for Allocate AM + conf.setInt(YarnConfiguration.DOS_MAX_REQUESTS_ALLOCATE_AM, + maxNumberRequests); + + conf.setInt(YarnConfiguration.DOS_MAX_REQUESTS_ASK_LIST, + maxListSizeRequests); + conf.setInt(YarnConfiguration.DOS_MAX_REQUESTS_RELEASE_LIST, + maxListSizeRequests); + conf.setInt(YarnConfiguration.DOS_MAX_REQUESTS_BLACKLIST, + maxListSizeRequests); + conf.setInt(YarnConfiguration.DOS_MAX_REQUESTS_CHANGELIST, + maxListSizeRequests); + conf.setInt(YarnConfiguration.DOS_MAX_REQUESTS_NUM_CONTAINERS, + maxContainersRequests); + } + + private void setUpSlidingWindowThreshold(Configuration conf) { + + conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_SIZE, slidingWindowSize); + conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_ADVANCE_TIME_SEC, + slidingWindowAdvanceTime); + + // Sliding window configuration for Register AM + + conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_REGISTER_AM, + thresholdNumberRequests); + + // Sliding window configuration for Finish AM + + conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_FINISH_AM, + thresholdNumberRequests); + + // Sliding window configuration for Allocate AM + + conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_ALLOCATE_AM, + thresholdNumberRequests); + + conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_ASK_LIST, + thresholdListSize); + conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_RELEASE_LIST, + thresholdListSize); + conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_BLACKLIST, + thresholdListSize); + conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_CHANGELIST, + thresholdListSize); + + conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_NUM_CONTAINERS, + thresholdContainers); + } + + /** + * This tests validates the correctness of the logic of + * PreventDoSContextAMPImpl for Single Request checks. + */ + @Test + public void testSingleRequestLimit() { + + Configuration conf = new YarnConfiguration(); + setUpRequestLimit(conf); + PreventDoSContextAMPImpl context = new PreventDoSContextAMPImpl(conf); + + // Check RegisterApplicationMasterRequest#Hostname + + Assert.assertTrue(context.isLengthRegisterHostValid("HostNameOk")); + Assert.assertFalse( + context.isLengthRegisterHostValid("HostNameLongerThanLimit")); + + // Check RegisterApplicationMasterRequest#RPCPort + + Assert.assertTrue(context.isRPCPortValid(50)); + Assert.assertFalse(context.isRPCPortValid(-10)); + Assert.assertFalse(context.isRPCPortValid(1001)); + + // Check RegisterApplicationMasterRequest#TrackingURL + + Assert.assertTrue(context.isURLValid("Address:10")); + Assert.assertTrue(context.isURLValid("https://Address:10")); + Assert.assertFalse(context.isURLValid("//...")); + Assert.assertFalse(context.isURLValid("BadAddress:10:10")); + + // Check FinishApplicationMasterRequest#Diagnostic + + Assert.assertTrue(context.isLengthFinishDiagnosticValid("Diagnostic")); + Assert.assertFalse( + context.isLengthFinishDiagnosticValid("DiagnosticLongerThanLimit")); + + // Check AllocateRequest#AskList + + List rrs = new ArrayList(); + for (int i = 0; i < maxListSize - 2; i++) { + rrs.add(ResourceRequest.newInstance(Priority.newInstance(1), "host", + Resource.newInstance(1, 1), 1)); + } + Assert.assertTrue(context.isSizeAskListValid(rrs)); + + Assert.assertFalse(context.isSizeAskListValid(null)); + + rrs = new ArrayList(); + for (int i = 0; i < maxListSize + 2; i++) { + rrs.add(ResourceRequest.newInstance(Priority.newInstance(1), "host", + Resource.newInstance(1, 1), 1)); + } + Assert.assertFalse(context.isSizeAskListValid(rrs)); + + // Check AllocateRequest#AskList#ResourceRequest#Priority + + Assert.assertTrue( + context.isPriorityValid(Priority.newInstance(maxPriority - 1))); + Assert.assertFalse( + context.isPriorityValid(Priority.newInstance(maxPriority + 1))); + Assert.assertFalse( + context.isPriorityValid(Priority.newInstance(minPriority - 1))); + + // Check AllocateRequest#AskList#ResourceRequest#ResourceName + + Assert.assertTrue(context.isLengthAllocateResourceNameValid("NodeOk")); + Assert.assertFalse( + context.isLengthAllocateResourceNameValid("NodeLongerThanLimit")); + + // Check AllocateRequest#AskList#ResourceRequest#Resource + + Assert.assertTrue(context + .isCapabilityValid(Resource.newInstance(maxMemory - 2, maxVCores - 2))); + Assert.assertFalse( + context.isCapabilityValid(Resource.newInstance(-2, maxVCores - 2))); + Assert.assertFalse( + context.isCapabilityValid(Resource.newInstance(maxMemory - 2, -2))); + Assert.assertFalse(context + .isCapabilityValid(Resource.newInstance(maxMemory + 2, maxVCores - 2))); + Assert.assertFalse(context + .isCapabilityValid(Resource.newInstance(maxMemory - 2, maxVCores + 2))); + Assert.assertFalse(context.isCapabilityValid(Resource.newInstance(-2, -2))); + Assert.assertFalse(context + .isCapabilityValid(Resource.newInstance(maxMemory + 2, maxVCores + 2))); + + // Check AllocateRequest#AskList#ResourceRequest#NumContainers + + Assert.assertTrue(context.isSizeContainersValid(maxContainers - 2)); + Assert.assertFalse(context.isSizeContainersValid(maxContainers + 2)); + Assert.assertFalse(context.isSizeContainersValid(-2)); + + // Check AllocateRequest#AskList#ResourceRequest#NodeLabels + + Assert.assertTrue(context.isLengthAllocateNodeLabelValid("NodeLabel")); + Assert.assertFalse( + context.isLengthAllocateNodeLabelValid("NodeLabelLongerThanLimit")); + + // Check AllocateRequest#Progress + + Assert.assertTrue(context.isProgressValid(20)); + Assert.assertFalse(context.isProgressValid(-10)); + Assert.assertFalse(context.isProgressValid(110)); + + // Check AllocateRequest#LastResponseId + + Assert.assertTrue(context.isLastResponseIdValid(1)); + Assert.assertTrue(context.isLastResponseIdValid(3)); + Assert.assertFalse(context.isLastResponseIdValid(2)); + + // Check AllocateRequest#ReleaseList + + ApplicationId appId = ApplicationId.newInstance(1, 1); + ApplicationAttemptId attemptId = ApplicationAttemptId.newInstance(appId, 1); + List containers = new ArrayList(); + for (int i = 0; i < maxListSize - 2; i++) { + containers.add(ContainerId.newContainerId(attemptId, i)); + } + Assert.assertTrue(context.isSizeReleaseListValid(containers)); + + containers = new ArrayList(); + for (int i = 0; i < maxListSize + 2; i++) { + containers.add(ContainerId.newContainerId(attemptId, i)); + } + Assert.assertFalse(context.isSizeReleaseListValid(containers)); + + // Check AllocateRequest#BlackList + + List blacklists = new ArrayList(); + for (int i = 0; i < maxListSize - 2; i++) { + blacklists.add("Node"); + } + Assert.assertTrue(context.isSizeBlackListValid(blacklists)); + + blacklists = new ArrayList(); + for (int i = 0; i < maxListSize + 2; i++) { + blacklists.add("Node"); + } + Assert.assertFalse(context.isSizeBlackListValid(blacklists)); + + // Check AllocateRequest#ResourceChangeList + + ContainerResourceChangeRequest crcr = ContainerResourceChangeRequest + .newInstance(ContainerId.newContainerId(attemptId, 1), + Resource.newInstance(maxMemory, maxVCores)); + List changeResource = + new ArrayList(); + for (int i = 0; i < maxListSize - 2; i++) { + changeResource.add(crcr); + } + Assert.assertTrue(context.isSizeChangeListValid(changeResource)); + + changeResource = new ArrayList(); + for (int i = 0; i < maxListSize + 2; i++) { + changeResource.add(crcr); + } + Assert.assertFalse(context.isSizeChangeListValid(changeResource)); + + } + + /** + * This tests validates the correctness of the logic of + * PreventDoSContextAMPImpl for Single Request truncation. + */ + @Test + public void testSingleRequestLimitTruncate() { + + Configuration conf = new YarnConfiguration(); + setUpRequestLimit(conf); + conf.setBoolean(YarnConfiguration.DOS_TRUNCATE_STRING_ENABLED, true); + + PreventDoSContextAMPImpl context = new PreventDoSContextAMPImpl(conf); + + String host = "HostNameLongerThanLimit"; + String diagnostic = "DiagnosticLongerThanLimit"; + String nodeLabel = "NodeLableLongerThanLimit"; + String node = "NodeLongerThanLimit"; + + // Truncate RegisterApplicationMasterRequest#Hostname + + String newHost = context.lengthRegisterHostTruncate(host); + Assert.assertEquals(maxStringLength, newHost.length()); + Assert.assertTrue(host.startsWith(newHost)); + + // Truncate FinishApplicationMasterRequest#Diagnostic + + String newDiagnostic = context.lengthFinishDiagnosticTruncate(diagnostic); + Assert.assertEquals(maxStringLength, newDiagnostic.length()); + Assert.assertTrue(diagnostic.startsWith(newDiagnostic)); + + // Truncate AllocateRequest#AskList#ResourceRequest#NodeLabels + + String newNodeLabel = context.lengthAllocateNodeLabelTruncate(nodeLabel); + Assert.assertEquals(maxStringLength, newNodeLabel.length()); + Assert.assertTrue(nodeLabel.startsWith(newNodeLabel)); + + // Truncate AllocateRequest#AskList#ResourceRequest#ResourceName + + String newNode = context.lengthAllocateResourceNameTruncate(node); + Assert.assertEquals(maxStringLength, newNode.length()); + Assert.assertTrue(node.startsWith(newNode)); + + } + + /** + * This tests validates the correctness of the logic of + * PreventDoSContextAMPImpl for temporal sliding window checks. + */ + @Test + public void testSlidindWindowThreshold() throws InterruptedException { + Configuration conf = new YarnConfiguration(); + setUpSlidingWindowThreshold(conf); + PreventDoSContextAMPImpl context = new PreventDoSContextAMPImpl(conf); + + /* + * Current Status of Sliding Window: First Slot + * + * RegisterApplicationMasterRequest | 0 | 0 | 0 | + * + * FinishApplicationMasterRequest | 0 | 0 | 0 | + * + * AllocateRequest | 0 | 0 | 0 | + * + * ContainerResourceChangeRequest | 0 | 0 | 0 | + * + * ResourceRequest | 0 | 0 | 0 | + * + * ContainerId | 0 | 0 | 0 | + * + * ResourceBlacklistRequest | 0 | 0 | 0 | + * + * NUM_CONTAINERS | 0 | 0 | 0 | + * + */ + + Thread.sleep(1000); + + Assert.assertTrue(context.isSlidingWindowIncrementCountValid( + RegisterApplicationMasterRequest.class.getName())); + Assert.assertTrue(context.isSlidingWindowIncrementCountValid( + FinishApplicationMasterRequest.class.getName())); + Assert.assertTrue(context + .isSlidingWindowIncrementCountValid(AllocateRequest.class.getName())); + Assert.assertTrue(context.isSlidingWindowIncrementCountValid( + ResourceRequest.class.getName(), 3)); + Assert.assertTrue(context.isSlidingWindowIncrementCountValid( + ContainerResourceChangeRequest.class.getName(), 3)); + Assert.assertTrue(context + .isSlidingWindowIncrementCountValid(ContainerId.class.getName(), 3)); + Assert.assertTrue(context.isSlidingWindowIncrementCountValid( + ResourceBlacklistRequest.class.getName(), 3)); + Assert.assertTrue(context + .isSlidingWindowIncrementCountValid(ResourceRequest.class.getName() + + PreventDoSContextAMPImpl.NUM_CONTAINERS, 15)); + + Thread.sleep(slidingWindowAdvanceTime * 1000); + + /* + * Current Status of Sliding Window: Second Slot + * + * RegisterApplicationMasterRequest | 1 | 0 | 0 | + * + * FinishApplicationMasterRequest | 1 | 0 | 0 | + * + * AllocateRequest | 1 | 0 | 0 | + * + * ContainerResourceChangeRequest | 3 | 0 | 0 | + * + * ResourceRequest | 3 | 0 | 0 | + * + * ContainerId | 3 | 0 | 0 | + * + * ResourceBlacklistRequest | 3 | 0 | 0 | + * + * NUM_CONTAINERS | 15 | 0 | 0 | + * + */ + + Assert.assertTrue(context.isSlidingWindowIncrementCountValid( + RegisterApplicationMasterRequest.class.getName())); + Assert.assertTrue(context.isSlidingWindowIncrementCountValid( + FinishApplicationMasterRequest.class.getName())); + Assert.assertTrue(context + .isSlidingWindowIncrementCountValid(AllocateRequest.class.getName())); + Assert.assertTrue(context.isSlidingWindowIncrementCountValid( + ResourceRequest.class.getName(), 3)); + Assert.assertTrue(context.isSlidingWindowIncrementCountValid( + ContainerResourceChangeRequest.class.getName(), 3)); + Assert.assertTrue(context + .isSlidingWindowIncrementCountValid(ContainerId.class.getName(), 3)); + Assert.assertTrue(context.isSlidingWindowIncrementCountValid( + ResourceBlacklistRequest.class.getName(), 3)); + Assert.assertTrue(context + .isSlidingWindowIncrementCountValid(ResourceRequest.class.getName() + + PreventDoSContextAMPImpl.NUM_CONTAINERS, 15)); + + Thread.sleep(slidingWindowAdvanceTime * 1000); + + /* + * Current Status of Sliding Window: Third Slot + * + * RegisterApplicationMasterRequest | 1 | 1 | 0 | + * + * FinishApplicationMasterRequest | 1 | 1 | 0 | + * + * AllocateRequest | 1 | 1 | 0 | + * + * ContainerResourceChangeRequest | 3 | 3 | 0 | + * + * ResourceRequest | 3 | 3 | 0 | + * + * ContainerId | 3 | 3 | 0 | + * + * ResourceBlacklistRequest | 3 | 3 | 0 | + * + * NUM_CONTAINERS | 15 | 15 | 0 | + * + */ + + Assert.assertTrue(context.isSlidingWindowIncrementCountValid( + RegisterApplicationMasterRequest.class.getName())); + Assert.assertTrue(context.isSlidingWindowIncrementCountValid( + FinishApplicationMasterRequest.class.getName())); + Assert.assertTrue(context + .isSlidingWindowIncrementCountValid(AllocateRequest.class.getName())); + Assert.assertTrue(context.isSlidingWindowIncrementCountValid( + ResourceRequest.class.getName(), 3)); + Assert.assertTrue(context.isSlidingWindowIncrementCountValid( + ContainerResourceChangeRequest.class.getName(), 3)); + Assert.assertTrue(context + .isSlidingWindowIncrementCountValid(ContainerId.class.getName(), 3)); + Assert.assertTrue(context.isSlidingWindowIncrementCountValid( + ResourceBlacklistRequest.class.getName(), 3)); + Assert.assertTrue(context + .isSlidingWindowIncrementCountValid(ResourceRequest.class.getName() + + PreventDoSContextAMPImpl.NUM_CONTAINERS, 15)); + + /* + * Current Status of Sliding Window: Third Slot + * + * RegisterApplicationMasterRequest | 1 | 1 | 1 | + * + * FinishApplicationMasterRequest | 1 | 1 | 1 | + * + * AllocateRequest | 1 | 1 | 1 | + * + * ContainerResourceChangeRequest | 3 | 3 | 3 | + * + * ResourceRequest | 3 | 3 | 3 | + * + * ContainerId | 3 | 3 | 3 | + * + * ResourceBlacklistRequest | 3 | 3 | 3 | + * + * NUM_CONTAINERS | 15 | 15 | 15 | + * + */ + + Assert.assertFalse(context.isSlidingWindowIncrementCountValid( + RegisterApplicationMasterRequest.class.getName())); + Assert.assertFalse(context.isSlidingWindowIncrementCountValid( + FinishApplicationMasterRequest.class.getName())); + Assert.assertFalse(context + .isSlidingWindowIncrementCountValid(AllocateRequest.class.getName())); + Assert.assertFalse(context.isSlidingWindowIncrementCountValid( + ResourceRequest.class.getName(), 3)); + Assert.assertFalse(context.isSlidingWindowIncrementCountValid( + ContainerResourceChangeRequest.class.getName(), 3)); + Assert.assertFalse(context + .isSlidingWindowIncrementCountValid(ContainerId.class.getName(), 3)); + Assert.assertFalse(context.isSlidingWindowIncrementCountValid( + ResourceBlacklistRequest.class.getName(), 3)); + Assert.assertFalse(context + .isSlidingWindowIncrementCountValid(ResourceRequest.class.getName() + + PreventDoSContextAMPImpl.NUM_CONTAINERS, 15)); + + // All the previous operations fail because we exceed the limits + + Thread.sleep(slidingWindowAdvanceTime * 1000); + + /* + * Current Status of Sliding Window: First Slot + * + * RegisterApplicationMasterRequest | 0 | 1 | 1 | + * + * FinishApplicationMasterRequest | 0 | 1 | 1 | + * + * AllocateRequest | 0 | 1 | 1 | + * + * ContainerResourceChangeRequest | 0 | 3 | 3 | + * + * ResourceRequest | 0 | 3 | 3 | + * + * ContainerId | 0 | 3 | 3 | + * + * ResourceBlacklistRequest | 0 | 3 | 3 | + * + * NUM_CONTAINERS | 0 | 15 | 15 | + * + */ + + Assert.assertTrue(context.isSlidingWindowIncrementCountValid( + RegisterApplicationMasterRequest.class.getName())); + Assert.assertTrue(context.isSlidingWindowIncrementCountValid( + FinishApplicationMasterRequest.class.getName())); + Assert.assertTrue(context + .isSlidingWindowIncrementCountValid(AllocateRequest.class.getName())); + Assert.assertTrue(context.isSlidingWindowIncrementCountValid( + ResourceRequest.class.getName(), 3)); + Assert.assertTrue(context.isSlidingWindowIncrementCountValid( + ContainerResourceChangeRequest.class.getName(), 3)); + Assert.assertTrue(context + .isSlidingWindowIncrementCountValid(ContainerId.class.getName(), 3)); + Assert.assertTrue(context.isSlidingWindowIncrementCountValid( + ResourceBlacklistRequest.class.getName(), 3)); + Assert.assertTrue(context + .isSlidingWindowIncrementCountValid(ResourceRequest.class.getName() + + PreventDoSContextAMPImpl.NUM_CONTAINERS, 15)); + + /* + * Current Status of Sliding Window: First Slot + * + * RegisterApplicationMasterRequest | 1 | 1 | 1 | + * + * FinishApplicationMasterRequest | 1 | 1 | 1 | + * + * AllocateRequest | 1 | 1 | 1 | + * + * ContainerResourceChangeRequest | 3 | 3 | 3 | + * + * ResourceRequest | 3 | 3 | 3 | + * + * ContainerId | 3 | 3 | 3 | + * + * ResourceBlacklistRequest | 3 | 3 | 3 | + * + * NUM_CONTAINERS | 15 | 15 | 15 | + * + */ + + } + + /** + * This tests validates the correctness of the logic of + * PreventDoSContextAMPImpl for the entire application's Requests checks. + */ + @Test + public void testRequestsLimit() { + Configuration conf = new YarnConfiguration(); + setUpRequestsLimit(conf); + PreventDoSContextAMPImpl context = new PreventDoSContextAMPImpl(conf); + + // Check RegisterApplicationMasterRequest entire execution of the + // application. Limit is 3. + Assert.assertTrue(context.isRegisterRequestsIncrementValid()); + Assert.assertTrue(context.isRegisterRequestsIncrementValid()); + Assert.assertTrue(context.isRegisterRequestsIncrementValid()); + + // Exceed the limit + Assert.assertFalse(context.isRegisterRequestsIncrementValid()); + + // Check FinishApplicationMasterRequest entire execution of the + // application. Limit is 3. + Assert.assertTrue(context.isFinishRequestsIncrementValid()); + Assert.assertTrue(context.isFinishRequestsIncrementValid()); + Assert.assertTrue(context.isFinishRequestsIncrementValid()); + + // Exceed the limit + Assert.assertFalse(context.isFinishRequestsIncrementValid()); + + // Check AllocateRequest entire execution of the application. Limit is 3. + Assert.assertTrue(context.isAllocateRequestsIncrementValid()); + Assert.assertTrue(context.isAllocateRequestsIncrementValid()); + Assert.assertTrue(context.isAllocateRequestsIncrementValid()); + + // Exceed the limit + Assert.assertFalse(context.isAllocateRequestsIncrementValid()); + + // Check ResourceRequest entire execution of the application. Limit is 10. + Assert.assertTrue(context.isAskListRequestsIncrementValid(3)); + Assert.assertTrue(context.isAskListRequestsIncrementValid(3)); + Assert.assertTrue(context.isAskListRequestsIncrementValid(3)); + + // Exceed the limit + Assert.assertFalse(context.isAskListRequestsIncrementValid(3)); + + // Check ContainerResourceChangeRequest entire execution of the + // application. Limit is 10. + Assert.assertTrue(context.isChangeListRequestsIncrementValid(3)); + Assert.assertTrue(context.isChangeListRequestsIncrementValid(3)); + Assert.assertTrue(context.isChangeListRequestsIncrementValid(3)); + + // Exceed the limit + Assert.assertFalse(context.isChangeListRequestsIncrementValid(3)); + + // Check ResourceBlacklistRequest entire execution of the + // application. Limit is 10. + Assert.assertTrue(context.isBlackListRequestsIncrementValid(3)); + Assert.assertTrue(context.isBlackListRequestsIncrementValid(3)); + Assert.assertTrue(context.isBlackListRequestsIncrementValid(3)); + + // Exceed the limit + Assert.assertFalse(context.isBlackListRequestsIncrementValid(3)); + + // Check ContainerId entire execution of the application. Limit is 10. + Assert.assertTrue(context.isReleaseListRequestsIncrementValid(3)); + Assert.assertTrue(context.isReleaseListRequestsIncrementValid(3)); + Assert.assertTrue(context.isReleaseListRequestsIncrementValid(3)); + + // Exceed the limit + Assert.assertFalse(context.isReleaseListRequestsIncrementValid(3)); + + // Check NUM_CONTAINERS entire execution of the application. Limit is 50. + Assert.assertTrue(context.isContainersRequestsIncrementValid(15)); + Assert.assertTrue(context.isContainersRequestsIncrementValid(15)); + Assert.assertTrue(context.isContainersRequestsIncrementValid(15)); + + // Exceed the limit + Assert.assertFalse(context.isContainersRequestsIncrementValid(15)); + } + +} diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/test/java/org/apache/hadoop/yarn/server/preventdos/TestSlidingWindowCounter.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/test/java/org/apache/hadoop/yarn/server/preventdos/TestSlidingWindowCounter.java new file mode 100644 index 0000000..2849cde --- /dev/null +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/test/java/org/apache/hadoop/yarn/server/preventdos/TestSlidingWindowCounter.java @@ -0,0 +1,125 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.server.preventdos; + +import java.util.Map; +import org.junit.Test; +import org.junit.Assert; + +/** + * Unit tests for SlidingWindowCounter. + */ +public class TestSlidingWindowCounter { + + private static final int ANY_WINDOW_LENGTH_IN_SLOTS = 2; + private static final Object ANY_OBJECT = "ANY_OBJECT"; + + private SlidingWindowCounter getSlidingWindowCounter( + int windowLength) { + return new SlidingWindowCounter(windowLength); + } + + @Test + public void lessThanTwoSlotsShouldThrowIAE() { + int[] illegalWindowLengths = new int[] { -10, -3, -2, -1, 0, 1 }; + for (int windowLengthInSlots : illegalWindowLengths) { + try { + getSlidingWindowCounter(windowLengthInSlots); + Assert.fail(); + } catch (IllegalArgumentException e) { + + } + + } + } + + @Test + public void twoOrMoreSlotsShouldBeValid() { + int[] legalWindowLengths = new int[] { 2, 3, 20 }; + for (int windowLengthInSlots : legalWindowLengths) { + getSlidingWindowCounter(windowLengthInSlots); + } + } + + @Test + public void newInstanceShouldHaveEmptyCounts() { + // given + SlidingWindowCounter counter = + getSlidingWindowCounter(ANY_WINDOW_LENGTH_IN_SLOTS); + + // when + Map counts = counter.getCountsThenAdvanceWindow(); + + // then + Assert.assertTrue(counts.isEmpty()); + } + + @Test + public void testCounterWithSimulatedRuns() { + Object[][] dataProvider = + new Object[][] { + { 2, new int[] { 3, 2, 0, 0, 1, 0, 0, 0 }, + new long[] { 3, 5, 2, 0, 1, 1, 0, 0 } }, + { 3, new int[] { 3, 2, 0, 0, 1, 0, 0, 0 }, + new long[] { 3, 5, 5, 2, 1, 1, 1, 0 } }, + { 4, new int[] { 3, 2, 0, 0, 1, 0, 0, 0 }, + new long[] { 3, 5, 5, 5, 3, 1, 1, 1 } }, + { 5, new int[] { 3, 2, 0, 0, 1, 0, 0, 0 }, + new long[] { 3, 5, 5, 5, 6, 3, 1, 1 } }, + { 5, new int[] { 3, 11, 5, 13, 7, 17, 0, 3, 50, 600, 7000 }, + new long[] { 3, 14, 19, 32, 39, 53, 42, 40, 77, 670, + 7653 } }, }; + for (Object[] data : dataProvider) { + int windowLengthInSlots = (int) data[0]; + int[] incrementsPerIteration = (int[]) data[1]; + long[] expCountsPerIteration = (long[]) data[2]; + + // given + SlidingWindowCounter counter = + getSlidingWindowCounter(windowLengthInSlots); + ; + int numIterations = incrementsPerIteration.length; + + for (int i = 0; i < numIterations; i++) { + int numIncrements = incrementsPerIteration[i]; + long expCounts = expCountsPerIteration[i]; + // Objects are absent if they were zero both this iteration + // and the last -- if only this one, we need to report zero. + boolean expAbsent = ((expCounts == 0) + && ((i == 0) || (expCountsPerIteration[i - 1] == 0))); + + // given (for this iteration) + for (int j = 0; j < numIncrements; j++) { + counter.incrementCount(ANY_OBJECT, 1); + } + + // when (for this iteration) + Map counts = counter.getCountsThenAdvanceWindow(); + + // then (for this iteration) + if (expAbsent) { + Assert.assertFalse(counts.keySet().contains(ANY_OBJECT)); + } else { + Assert.assertEquals(expCounts, counts.get(ANY_OBJECT).longValue()); + } + } + } + } + +} \ No newline at end of file diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/test/java/org/apache/hadoop/yarn/server/preventdos/TestSlotBasedAccumulator.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/test/java/org/apache/hadoop/yarn/server/preventdos/TestSlotBasedAccumulator.java new file mode 100644 index 0000000..21b4289 --- /dev/null +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/test/java/org/apache/hadoop/yarn/server/preventdos/TestSlotBasedAccumulator.java @@ -0,0 +1,208 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.server.preventdos; + +import java.util.Map; +import org.junit.Test; +import org.junit.Assert; + +/** + * Unit tests for SlotBasedAccumulator. + */ +public class TestSlotBasedAccumulator { + + private static final int ANY_NUM_SLOTS = 1; + private static final int ANY_SLOT = 0; + private static final Object ANY_OBJECT = "ANY_OBJECT"; + + private SlotBasedAccumulator getSlotBasedAccumulator(int numSlots) { + return new SlotBasedAccumulator(numSlots); + } + + @Test + public void negativeOrZeroNumSlotsShouldThrowIAE() { + int[] illegalNumSlotsData = new int[] { -10, -3, -2, -1, 0 }; + for (int numSlots : illegalNumSlotsData) { + try { + getSlotBasedAccumulator(numSlots); + Assert.fail(); + } catch (IllegalArgumentException e) { + + } + } + + } + + @Test + public void positiveNumSlotsShouldBeOk() { + int[] legalNumSlotsData = new int[] { 1, 2, 3, 20 }; + for (int numSlots : legalNumSlotsData) { + getSlotBasedAccumulator(numSlots); + } + } + + @Test + public void newInstanceShouldHaveEmptyCounts() { + // given + SlotBasedAccumulator counter = + getSlotBasedAccumulator(ANY_NUM_SLOTS); + + // when + Map counts = counter.getCounts(); + + // then + Assert.assertTrue(counts.isEmpty()); + } + + @Test + public void shouldReturnNonEmptyCountsWhenAtLeastOneObjectWasCounted() { + // given + SlotBasedAccumulator counter = + getSlotBasedAccumulator(ANY_NUM_SLOTS); + counter.incrementCount(ANY_OBJECT, ANY_SLOT); + + // when + Map counts = counter.getCounts(); + + // then + Assert.assertFalse(counts.isEmpty()); + + // additional tests that go beyond what this test is primarily about + Assert.assertEquals(1, counts.size()); + Assert.assertEquals(1L, counts.get(ANY_OBJECT).longValue()); + } + + public void shouldIncrementCount() { + String[] objects = new String[] { "foo", "bar" }; + int[] expCounts = new int[] { 3, 2 }; + // given + SlotBasedAccumulator counter = + getSlotBasedAccumulator(ANY_NUM_SLOTS); + + // when + for (int i = 0; i < objects.length; i++) { + Object obj = objects[i]; + int numIncrements = expCounts[i]; + for (int j = 0; j < numIncrements; j++) { + counter.incrementCount(obj, ANY_SLOT); + } + } + + // then + for (int i = 0; i < objects.length; i++) { + Assert.assertEquals(expCounts[i], counter.getCount(objects[i], ANY_SLOT)); + } + Assert.assertEquals(0, counter.getCount("nonexistentObject", ANY_SLOT)); + } + + @Test + public void shouldReturnZeroForNonexistentObject() { + // given + SlotBasedAccumulator counter = + getSlotBasedAccumulator(ANY_NUM_SLOTS); + + // when + counter.incrementCount("somethingElse", ANY_SLOT); + + // then + Assert.assertEquals(0, counter.getCount("nonexistentObject", ANY_SLOT)); + } + + @Test + public void shouldIncrementCountOnlyOneSlotAtATime() { + // given + int numSlots = 3; + Object obj = Long.valueOf(10); + SlotBasedAccumulator counter = getSlotBasedAccumulator(numSlots); + + // when (empty) + // then + Assert.assertEquals(0, counter.getCount(obj, 0)); + Assert.assertEquals(0, counter.getCount(obj, 1)); + Assert.assertEquals(0, counter.getCount(obj, 2)); + + // when + counter.incrementCount(obj, 1); + + // then + Assert.assertEquals(0, counter.getCount(obj, 0)); + Assert.assertEquals(1, counter.getCount(obj, 1)); + Assert.assertEquals(0, counter.getCount(obj, 2)); + } + + @Test + public void wipeSlotShouldSetAllCountsInSlotToZero() { + // given + SlotBasedAccumulator counter = + getSlotBasedAccumulator(ANY_NUM_SLOTS); + Object countWasOne = "countWasOne"; + Object countWasThree = "countWasThree"; + counter.incrementCount(countWasOne, ANY_SLOT); + counter.incrementCount(countWasThree, ANY_SLOT); + counter.incrementCount(countWasThree, ANY_SLOT); + counter.incrementCount(countWasThree, ANY_SLOT); + + // when + counter.wipeSlot(ANY_SLOT); + + // then + Assert.assertEquals(0, counter.getCount(countWasOne, ANY_SLOT)); + Assert.assertEquals(0, counter.getCount(countWasThree, ANY_SLOT)); + } + + @Test + public void wipeZerosShouldRemoveAnyObjectsWithZeroTotalCount() { + // given + SlotBasedAccumulator counter = getSlotBasedAccumulator(2); + int wipeSlot = 0; + int otherSlot = 1; + Object willBeRemoved = "willBeRemoved"; + Object willContinueToBeTracked = "willContinueToBeTracked"; + counter.incrementCount(willBeRemoved, wipeSlot); + counter.incrementCount(willContinueToBeTracked, wipeSlot); + counter.incrementCount(willContinueToBeTracked, otherSlot); + + // when + counter.wipeSlot(wipeSlot); + counter.wipeZeros(); + + // then + Assert.assertFalse(counter.getCounts().keySet().contains(willBeRemoved)); + Assert.assertTrue( + counter.getCounts().keySet().contains(willContinueToBeTracked)); + } + + @Test + public void performanceTest() { + int numRounds = 100; + int numSlots = 1000; + int numObjects = 100; + SlotBasedAccumulator counter = getSlotBasedAccumulator(numSlots); + // Wipe and then increment each slot, simulating a sliding window + while (--numRounds > 0) { + for (int slot = 0; slot < numSlots; ++slot) { + counter.wipeSlot(slot); + for (int s = 0; s < numObjects; ++s) { + counter.incrementCount("test" + s, slot); + } + counter.wipeZeros(); + } + } + } +} \ No newline at end of file diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/amrmproxy/AMPDoSRequestInterceptor.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/amrmproxy/AMPDoSRequestInterceptor.java new file mode 100644 index 0000000..8b9c06c --- /dev/null +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/amrmproxy/AMPDoSRequestInterceptor.java @@ -0,0 +1,501 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.server.nodemanager.amrmproxy; + +import java.io.IOException; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.yarn.api.protocolrecords.AllocateRequest; +import org.apache.hadoop.yarn.api.protocolrecords.AllocateResponse; +import org.apache.hadoop.yarn.api.protocolrecords.FinishApplicationMasterRequest; +import org.apache.hadoop.yarn.api.protocolrecords.FinishApplicationMasterResponse; +import org.apache.hadoop.yarn.api.protocolrecords.RegisterApplicationMasterRequest; +import org.apache.hadoop.yarn.api.protocolrecords.RegisterApplicationMasterResponse; +import org.apache.hadoop.yarn.api.records.ContainerId; +import org.apache.hadoop.yarn.api.records.ContainerResourceChangeRequest; +import org.apache.hadoop.yarn.api.records.ResourceBlacklistRequest; +import org.apache.hadoop.yarn.api.records.ResourceRequest; +import org.apache.hadoop.yarn.conf.YarnConfiguration; +import org.apache.hadoop.yarn.exceptions.YarnException; +import org.apache.hadoop.yarn.server.preventdos.PreventDoSAction; +import org.apache.hadoop.yarn.server.preventdos.PreventDoSAttackException; +import org.apache.hadoop.yarn.server.preventdos.PreventDoSCause; +import org.apache.hadoop.yarn.server.preventdos.PreventDoSContextAMPImpl; +import org.apache.hadoop.yarn.server.preventdos.PreventDoSPolicy; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; + +/** + *

+ * The AMPDoSRequestInterceptor runs on the NodeManager and is modeled as an + * AMRMProxy request interceptor. It is responsible for the + * following: + *

+ *
    + *
  • Check every request's value. One example of Denial of Service attacks is + * to attempt to "crash" it by giving wrong input parameters, hereby preventing + * uses for legitimate users.
  • + *
  • Check every request's value size. One example of Denial of Service + * attacks is to attempt to “fill” the server’s data structures, hereby + * preventing uses for legitimate users’ calls.
  • + *
  • Check and count all the requests. One example of Denial of Service + * attacks is to attempt to “flood” the server, hereby preventing legitimate + * user traffic.
  • + *
+ */ +public final class AMPDoSRequestInterceptor extends AbstractRequestInterceptor { + + private static final Logger LOG = + LoggerFactory.getLogger(AMPDoSRequestInterceptor.class); + + private PreventDoSContextAMPImpl context; + private PreventDoSAction action; + + @Override + public void init(AMRMProxyApplicationContext appContext) { + super.init(appContext); + Configuration conf = this.getConf(); + PreventDoSPolicy policy = + PreventDoSPolicy.valueOf(conf.get(YarnConfiguration.DOS_ACTION, + YarnConfiguration.DEFAULT_DOS_ACTION)); + + LOG.info("The policy for prevent DoS is " + policy); + + action = new PreventDoSAction(policy); + + context = new PreventDoSContextAMPImpl(conf); + + } + + /** + * Check and validate the correctness of the RegisterApplicationMasterRequest. + * + * @param request registration request + * @return Register Response + * @throws YarnException YarnException + * @throws IOException IOException + */ + @Override + public RegisterApplicationMasterResponse registerApplicationMaster( + final RegisterApplicationMasterRequest request) + throws YarnException, IOException { + + Preconditions.checkNotNull(request, + "RegisterApplicationMasterRequest is null."); + + // Check value of a single request + + // Host -> Size of String + if (!context.isLengthRegisterHostValid(request.getHost())) { + request.setHost(context.lengthRegisterHostTruncate(request.getHost())); + action.prevent(PreventDoSCause.REQUEST_GET_HOST_CAUSE); + } + + // RPC Port -> Valid port number + if (!context.isRPCPortValid(request.getRpcPort())) { + action.prevent(PreventDoSCause.REQUEST_GET_RPC_PORT_CAUSE); + } + + // Tracking URL -> Valid URL expression + if (!context.isURLValid(request.getTrackingUrl())) { + action.prevent(PreventDoSCause.REQUEST_GET_TRACKING_URL_CAUSE); + } + + // Increase the sliding window counter + if (!context.isSlidingWindowIncrementCountValid( + RegisterApplicationMasterRequest.class.getName())) { + action.prevent(PreventDoSCause.SLIDINGWINDOW_REGISTER_AM_CAUSE); + } + + // Increase requests counter + if (!context.isRegisterRequestsIncrementValid()) { + action.prevent(PreventDoSCause.REQUESTS_REGISTER_AM_CAUSE); + } + + LOG.debug("Forwarding registration request to the next interceptor."); + return getNextInterceptor().registerApplicationMaster(request); + } + + /** + * Check and validate the correctness of the FinishApplicationMasterRequest. + * + * @param request finishing request + * @return Finish Response + * @throws YarnException YarnException + * @throws IOException IOException + */ + @Override + public FinishApplicationMasterResponse finishApplicationMaster( + final FinishApplicationMasterRequest request) + throws YarnException, IOException { + + Preconditions.checkNotNull(request, + "FinishApplicationMasterResponse is null."); + + // Check value of a single request + + // Diagnostic -> Size of String + if (!context.isLengthFinishDiagnosticValid(request.getDiagnostics())) { + request.setDiagnostics( + context.lengthFinishDiagnosticTruncate(request.getDiagnostics())); + action.prevent(PreventDoSCause.REQUEST_GET_DIAGNOSTIC_CAUSE); + } + + // Tracking URL -> Valid URL expression + if (!context.isURLValid(request.getTrackingUrl())) { + action.prevent(PreventDoSCause.REQUEST_GET_TRACKING_URL_FINISH_CAUSE); + } + + // Increase the sliding window counter + if (!context.isSlidingWindowIncrementCountValid( + FinishApplicationMasterRequest.class.getName())) { + action.prevent(PreventDoSCause.SLIDINGWINDOW_FINISH_AM_CAUSE); + } + + // Increase requests counter + if (!context.isAllocateRequestsIncrementValid()) { + action.prevent(PreventDoSCause.REQUESTS_FINISH_AM_CAUSE); + } + + LOG.debug("Forwarding finish application request to the next interceptor"); + return getNextInterceptor().finishApplicationMaster(request); + } + + /** + * Check and validate the correctness of the AllocateRequest. + * + * @param request allocation request + * @return Allocate Response + * @throws YarnException YarnException + * @throws IOException IOException + */ + @Override + public AllocateResponse allocate(final AllocateRequest request) + throws YarnException, IOException { + Preconditions.checkNotNull(request, "AllocateRequest is null."); + + allocateLimitRequestCheck(request); + + allocateThresholdSlidingWindowCheck(request); + + allocateLimitRequestsCheck(request); + + LOG.debug("Forwarding allocate request to the next interceptor."); + return getNextInterceptor().allocate(request); + } + + /** + * Check and validate the correctness of the single AllocateRequest. + */ + private void allocateLimitRequestCheck(AllocateRequest request) + throws PreventDoSAttackException { + + // Check String and numerical values + + // ResponseId -> Greater than previous one + if (!context.isLastResponseIdValid(request.getResponseId())) { + action.prevent(PreventDoSCause.REQUEST_GET_RESPONSE_ID_CAUSE); + } + + // App Progress -> Between 0 and 100 + if (!context.isProgressValid(request.getProgress())) { + action.prevent(PreventDoSCause.REQUEST_GET_PROGRESS_CAUSE); + } + + if (request.getAskList() != null) { + + // AskList -> Size list + if (!context.isSizeAskListValid(request.getAskList())) { + action.prevent(PreventDoSCause.REQUEST_GET_ASK_LIST_CAUSE); + } + + // Check String and numerical values for ResourceRequest + for (ResourceRequest rr : request.getAskList()) { + + // Priority -> Valid priority value + if (!context.isPriorityValid(rr.getPriority())) { + action.prevent(PreventDoSCause.REQUEST_GET_PRIORITY_CAUSE); + } + + // ResourceName -> Size of String + if (!context.isLengthAllocateResourceNameValid(rr.getResourceName())) { + rr.setResourceName( + context.lengthAllocateResourceNameTruncate(rr.getResourceName())); + action.prevent(PreventDoSCause.REQUEST_GET_RESOURCE_NAME_CAUSE); + } + + // Capacity -> Within a range + if (!context.isCapabilityValid(rr.getCapability())) { + action.prevent(PreventDoSCause.REQUEST_GET_CAPABILITY_CAUSE); + } + + // Containers -> Within a range + if (!context.isSizeContainersValid(rr.getNumContainers())) { + action.prevent(PreventDoSCause.REQUEST_GET_NUM_CONTAINERS_CAUSE); + } + + // Label Expression -> Size of String + if (!context + .isLengthAllocateNodeLabelValid(rr.getNodeLabelExpression())) { + rr.setNodeLabelExpression(context + .lengthAllocateNodeLabelTruncate(rr.getNodeLabelExpression())); + action.prevent(PreventDoSCause.REQUEST_GET_NODE_LABEL_CAUSE); + } + } + } + + // ReleaseList -> Size list + if (!context.isSizeReleaseListValid(request.getReleaseList())) { + action.prevent(PreventDoSCause.REQUEST_GET_RELEASE_LIST_CAUSE); + } + + // Check Resource Blacklist Request + if (request.getResourceBlacklistRequest() != null) { + + if (request.getResourceBlacklistRequest() + .getBlacklistAdditions() != null) { + + // Blacklist Additions -> Size list + if (!context.isSizeBlackListValid( + request.getResourceBlacklistRequest().getBlacklistAdditions())) { + action.prevent(PreventDoSCause.REQUEST_GET_BLACKLIST_ADD_LIST_CAUSE); + } + } + if (request.getResourceBlacklistRequest() + .getBlacklistRemovals() != null) { + + // Blacklist Removals -> Size list + if (!context.isSizeBlackListValid( + request.getResourceBlacklistRequest().getBlacklistRemovals())) { + action.prevent(PreventDoSCause.REQUEST_GET_BLACKLIST_REM_LIST_CAUSE); + } + } + } + + if (request.getIncreaseRequests() != null) { + // Increase Request -> Size list + if (!context.isSizeChangeListValid(request.getIncreaseRequests())) { + action.prevent(PreventDoSCause.REQUEST_GET_INCREASE_LIST_CAUSE); + } + } + + if (request.getDecreaseRequests() != null) { + // Decrease Request -> Size list + if (!context.isSizeChangeListValid(request.getDecreaseRequests())) { + action.prevent(PreventDoSCause.REQUEST_GET_DECREASE_LIST_CAUSE); + } + } + } + + /** + * Check and validate the correctness of the AllocateRequest over short amount + * of time with sliding window. + */ + private void allocateThresholdSlidingWindowCheck(AllocateRequest request) + throws PreventDoSAttackException { + + // Increase the sliding window counter of AllocateRequest + if (!context + .isSlidingWindowIncrementCountValid(AllocateRequest.class.getName())) { + action.prevent(PreventDoSCause.SLIDINGWINDOW_ALLOCATE_AM_CAUSE); + } + + if (request.getAskList() != null) { + + // Increase the sliding window counter of AskList + if (!context.isSlidingWindowIncrementCountValid( + ResourceRequest.class.getName(), request.getAskList().size())) { + action.prevent(PreventDoSCause.SLIDINGWINDOW_GET_ASK_LIST_CAUSE); + } + + for (ResourceRequest rr : request.getAskList()) { + + // Increase the sliding window counter of NumContainers + if (!context.isSlidingWindowIncrementCountValid( + ResourceRequest.class.getName() + + PreventDoSContextAMPImpl.NUM_CONTAINERS, + rr.getNumContainers())) { + action + .prevent(PreventDoSCause.SLIDINGWINDOW_GET_NUM_CONTAINERS_CAUSE); + } + } + } + + if (request.getReleaseList() != null) { + + // Increase the sliding window counter of ReleaseList + if (!context.isSlidingWindowIncrementCountValid( + ContainerId.class.getName(), request.getReleaseList().size())) { + action.prevent(PreventDoSCause.SLIDINGWINDOW_GET_RELEASE_LIST_CAUSE); + } + } + + if (request.getResourceBlacklistRequest() != null) { + if (request.getResourceBlacklistRequest() + .getBlacklistAdditions() != null) { + + // Increase the sliding window counter of BlackList Request Increase + if (!context.isSlidingWindowIncrementCountValid( + ResourceBlacklistRequest.class.getName(), + request.getResourceBlacklistRequest().getBlacklistAdditions() + .size())) { + action.prevent( + PreventDoSCause.SLIDINGWINDOW_GET_BLACKLIST_ADD_LIST_CAUSE); + } + } + + if (request.getResourceBlacklistRequest() + .getBlacklistRemovals() != null) { + + // Increase the sliding window counter of BlackList Request Decrease + if (!context.isSlidingWindowIncrementCountValid( + ResourceBlacklistRequest.class.getName(), request + .getResourceBlacklistRequest().getBlacklistRemovals().size())) { + action.prevent( + PreventDoSCause.SLIDINGWINDOW_GET_BLACKLIST_REM_LIST_CAUSE); + } + } + } + + if (request.getIncreaseRequests() != null) { + + // Increase the sliding window counter of IncreaseRequest + if (!context.isSlidingWindowIncrementCountValid( + ContainerResourceChangeRequest.class.getName(), + request.getIncreaseRequests().size())) { + action.prevent(PreventDoSCause.SLIDINGWINDOW_GET_INCREASE_LIST_CAUSE); + } + } + + if (request.getDecreaseRequests() != null) { + + // Increase the sliding window counter of DecreaseRequest + if (!context.isSlidingWindowIncrementCountValid( + ContainerResourceChangeRequest.class.getName(), + request.getDecreaseRequests().size())) { + action.prevent(PreventDoSCause.SLIDINGWINDOW_GET_DECREASE_LIST_CAUSE); + } + } + } + + /** + * Check and validate the correctness of the AllocateRequest over the entire + * execution of the application. + */ + private void allocateLimitRequestsCheck(AllocateRequest request) + throws PreventDoSAttackException { + + // Increase requests counter + if (!context.isAllocateRequestsIncrementValid()) { + action.prevent(PreventDoSCause.REQUESTS_ALLOCATE_AM_CAUSE); + } + + if (request.getAskList() != null) { + + // Increase requests counters for AskList + if (!context + .isAskListRequestsIncrementValid(request.getAskList().size())) { + action.prevent(PreventDoSCause.REQUESTS_GET_ASK_LIST_CAUSE); + } + + for (ResourceRequest rr : request.getAskList()) { + + // Increase requests counters for NumContainers + if (!context + .isContainersRequestsIncrementValid(rr.getNumContainers())) { + action.prevent(PreventDoSCause.REQUESTS_GET_NUM_CONTAINERS_CAUSE); + } + } + } + + if (request.getReleaseList() != null) { + + // Increase requests counters for ReleaseList + if (!context.isReleaseListRequestsIncrementValid( + request.getReleaseList().size())) { + action.prevent(PreventDoSCause.REQUESTS_GET_RELEASE_LIST_CAUSE); + } + } + + if (request.getResourceBlacklistRequest() != null) { + if (request.getResourceBlacklistRequest() + .getBlacklistAdditions() != null) { + + // Increase requests counters for ResourceBlacklist + if (!context.isBlackListRequestsIncrementValid(request + .getResourceBlacklistRequest().getBlacklistAdditions().size())) { + action.prevent(PreventDoSCause.REQUESTS_GET_BLACKLIST_ADD_LIST_CAUSE); + } + } + if (request.getResourceBlacklistRequest() + .getBlacklistRemovals() != null) { + + // Increase requests counters for ResourceBlacklist + if (!context.isBlackListRequestsIncrementValid(request + .getResourceBlacklistRequest().getBlacklistRemovals().size())) { + action.prevent(PreventDoSCause.REQUESTS_GET_BLACKLIST_REM_LIST_CAUSE); + } + } + } + + if (request.getIncreaseRequests() != null) { + + // Increase requests counter for IncreaseRequests + if (!context.isChangeListRequestsIncrementValid( + request.getIncreaseRequests().size())) { + action.prevent(PreventDoSCause.REQUESTS_GET_INCREASE_LIST_CAUSE); + } + } + + if (request.getDecreaseRequests() != null) { + + // Increase requests counter for DecreaseRequests + if (!context.isChangeListRequestsIncrementValid( + request.getDecreaseRequests().size())) { + action.prevent(PreventDoSCause.REQUESTS_GET_DECREASE_LIST_CAUSE); + } + } + } + + @Override + public void shutdown() { + if (context != null) { + context.shutdownTimer(); + } + } + + @VisibleForTesting + public void setAction(PreventDoSAction action) { + this.action = action; + } + + @VisibleForTesting + public void setContext(PreventDoSContextAMPImpl context) { + this.context = context; + } + + @Override + public void setNextInterceptor(RequestInterceptor next) { + super.setNextInterceptor(next); + } + +} diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/amrmproxy/MockDefaultInterceptor.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/amrmproxy/MockDefaultInterceptor.java new file mode 100644 index 0000000..827c8a8 --- /dev/null +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/amrmproxy/MockDefaultInterceptor.java @@ -0,0 +1,107 @@ +package org.apache.hadoop.yarn.server.nodemanager.amrmproxy; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; + +import org.apache.commons.lang.NotImplementedException; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.yarn.api.protocolrecords.AllocateRequest; +import org.apache.hadoop.yarn.api.protocolrecords.AllocateResponse; +import org.apache.hadoop.yarn.api.protocolrecords.FinishApplicationMasterRequest; +import org.apache.hadoop.yarn.api.protocolrecords.FinishApplicationMasterResponse; +import org.apache.hadoop.yarn.api.protocolrecords.RegisterApplicationMasterRequest; +import org.apache.hadoop.yarn.api.protocolrecords.RegisterApplicationMasterResponse; +import org.apache.hadoop.yarn.api.records.ApplicationAccessType; +import org.apache.hadoop.yarn.api.records.Container; +import org.apache.hadoop.yarn.api.records.NMToken; +import org.apache.hadoop.yarn.api.records.NodeReport; +import org.apache.hadoop.yarn.api.records.Resource; +import org.apache.hadoop.yarn.exceptions.YarnException; +import org.apache.hadoop.yarn.server.api.protocolrecords.DistributedSchedulingAllocateRequest; +import org.apache.hadoop.yarn.server.api.protocolrecords.DistributedSchedulingAllocateResponse; +import org.apache.hadoop.yarn.server.api.protocolrecords.RegisterDistributedSchedulingAMResponse; + +/** + * Mock Class For DefaultRequestInterceptor. It implements only the + * ApplicationMasterProtocol APIs. + */ +public class MockDefaultInterceptor implements RequestInterceptor { + + @Override + public RegisterDistributedSchedulingAMResponse registerApplicationMasterForDistributedScheduling( + RegisterApplicationMasterRequest request) + throws YarnException, IOException { + throw new NotImplementedException(); + } + + @Override + public DistributedSchedulingAllocateResponse allocateForDistributedScheduling( + DistributedSchedulingAllocateRequest request) + throws YarnException, IOException { + throw new NotImplementedException(); + } + + @Override + public RegisterApplicationMasterResponse registerApplicationMaster( + RegisterApplicationMasterRequest request) + throws YarnException, IOException { + return RegisterApplicationMasterResponse.newInstance( + Resource.newInstance(512, 1), Resource.newInstance(512000, 1024), + Collections. emptyMap(), + ByteBuffer.wrap("fake_key".getBytes()), + Collections. emptyList(), "default", + Collections. emptyList()); + } + + @Override + public FinishApplicationMasterResponse finishApplicationMaster( + FinishApplicationMasterRequest request) + throws YarnException, IOException { + return FinishApplicationMasterResponse.newInstance(false); + } + + @Override + public AllocateResponse allocate(AllocateRequest request) + throws YarnException, IOException { + return AllocateResponse.newInstance(0, null, null, + new ArrayList(), null, null, 1, null, null); + } + + @Override + public void setConf(Configuration conf) { + throw new NotImplementedException(); + } + + @Override + public Configuration getConf() { + throw new NotImplementedException(); + } + + @Override + public void init(AMRMProxyApplicationContext ctx) { + throw new NotImplementedException(); + } + + @Override + public void shutdown() { + throw new NotImplementedException(); + } + + @Override + public void setNextInterceptor(RequestInterceptor nextInterceptor) { + throw new NotImplementedException(); + } + + @Override + public RequestInterceptor getNextInterceptor() { + throw new NotImplementedException(); + } + + @Override + public AMRMProxyApplicationContext getApplicationContext() { + throw new NotImplementedException(); + } + +} diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/amrmproxy/TestAMPDoSRequestInterceptor.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/amrmproxy/TestAMPDoSRequestInterceptor.java new file mode 100644 index 0000000..a530e07 --- /dev/null +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/amrmproxy/TestAMPDoSRequestInterceptor.java @@ -0,0 +1,1004 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.server.nodemanager.amrmproxy; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.yarn.api.protocolrecords.AllocateRequest; +import org.apache.hadoop.yarn.api.protocolrecords.FinishApplicationMasterRequest; +import org.apache.hadoop.yarn.api.protocolrecords.RegisterApplicationMasterRequest; +import org.apache.hadoop.yarn.api.records.ApplicationAttemptId; +import org.apache.hadoop.yarn.api.records.ApplicationId; +import org.apache.hadoop.yarn.api.records.ContainerId; +import org.apache.hadoop.yarn.api.records.ContainerResourceChangeRequest; +import org.apache.hadoop.yarn.api.records.FinalApplicationStatus; +import org.apache.hadoop.yarn.api.records.Priority; +import org.apache.hadoop.yarn.api.records.Resource; +import org.apache.hadoop.yarn.api.records.ResourceBlacklistRequest; +import org.apache.hadoop.yarn.api.records.ResourceRequest; +import org.apache.hadoop.yarn.conf.YarnConfiguration; +import org.apache.hadoop.yarn.exceptions.YarnException; +import org.apache.hadoop.yarn.server.preventdos.PreventDoSAction; +import org.apache.hadoop.yarn.server.preventdos.PreventDoSAttackException; +import org.apache.hadoop.yarn.server.preventdos.PreventDoSCause; +import org.apache.hadoop.yarn.server.preventdos.PreventDoSContextAMPImpl; +import org.apache.hadoop.yarn.server.preventdos.PreventDoSPolicy; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +/** + * Unit tests for AMPDoSRequestInterceptor. + */ +public class TestAMPDoSRequestInterceptor { + + private static final int maxStringLength = 10; + private static final int maxRPCPort = 1000; + private static final int maxVCores = 6; + private static final int maxMemory = 100 * 1024; + private static final int maxContainers = 50; + private static final int maxListSize = 20; + private static final int minPriority = 1; + private static final int maxPriority = 10; + + private static final int slidingWindowSize = 3; + private static final int slidingWindowAdvanceTime = 1; + private static final int thresholdNumberRequests = 30; + private static final int thresholdListSize = 1000; + private static final int thresholdContainers = 1000; + + private static final int maxNumberRequests = 10; + private static final int maxListSizeRequests = 100; + private static final int maxContainersRequests = 100; + private static final int rounds = 5; + + private AMPDoSRequestInterceptor ampDoSRequestInterceptor; + + private void setUpRequestLimit(Configuration conf) { + + // Single request configuration for Register AM + + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_LENGTH_HOST, + maxStringLength); + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_RPC_PORT, maxRPCPort); + + // Single request configuration for Finish AM + + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_LENGTH_DIAGNOSTIC, + maxStringLength); + + // Single request configuration for Allocate AM + + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_LENGTH_RESOURCE, + maxStringLength); + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_LENGTH_NODELABEL, + maxStringLength); + + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_MAX_VCORES, maxVCores); + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_MAX_MEMORY, maxMemory); + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_ASK_LIST, maxListSize); + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_RELEASE_LIST, maxListSize); + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_BLACKLIST, maxListSize); + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_CHANGELIST, maxListSize); + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_CONTAINERS, maxContainers); + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_MIN_PRIORITY, minPriority); + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_MAX_PRIORITY, maxPriority); + } + + private void setUpRequestsLimit(Configuration conf) { + // Lifetime configuration for Register AM + + conf.setInt(YarnConfiguration.DOS_MAX_REQUESTS_REGISTER_AM, + maxNumberRequests); + + // Lifetime configuration for Finish AM + + conf.setInt(YarnConfiguration.DOS_MAX_REQUESTS_FINISH_AM, + maxNumberRequests); + + // Lifetime configuration for Allocate AM + conf.setInt(YarnConfiguration.DOS_MAX_REQUESTS_ALLOCATE_AM, + maxNumberRequests); + + conf.setInt(YarnConfiguration.DOS_MAX_REQUESTS_ASK_LIST, + maxListSizeRequests); + conf.setInt(YarnConfiguration.DOS_MAX_REQUESTS_RELEASE_LIST, + maxListSizeRequests); + conf.setInt(YarnConfiguration.DOS_MAX_REQUESTS_BLACKLIST, + maxListSizeRequests); + conf.setInt(YarnConfiguration.DOS_MAX_REQUESTS_CHANGELIST, + maxListSizeRequests); + conf.setInt(YarnConfiguration.DOS_MAX_REQUESTS_NUM_CONTAINERS, + maxContainersRequests); + } + + private void setUpSlidingWindowThreshold(Configuration conf) { + + conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_SIZE, slidingWindowSize); + conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_ADVANCE_TIME_SEC, + slidingWindowAdvanceTime); + + // Sliding window configuration for Register AM + + conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_REGISTER_AM, + thresholdNumberRequests); + + // Sliding window configuration for Finish AM + + conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_FINISH_AM, + thresholdNumberRequests); + + // Sliding window configuration for Allocate AM + + conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_ALLOCATE_AM, + thresholdNumberRequests); + + conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_ASK_LIST, + thresholdListSize); + conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_RELEASE_LIST, + thresholdListSize); + conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_BLACKLIST, + thresholdListSize); + conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_CHANGELIST, + thresholdListSize); + + conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_NUM_CONTAINERS, + thresholdContainers); + } + + @Before + public void setUp() { + Configuration conf = new YarnConfiguration(); + setUpRequestLimit(conf); + setUpSlidingWindowThreshold(conf); + setUpRequestsLimit(conf); + + PreventDoSAction action = new PreventDoSAction(PreventDoSPolicy.REJECT); + PreventDoSContextAMPImpl context = new PreventDoSContextAMPImpl(conf); + + ampDoSRequestInterceptor = new AMPDoSRequestInterceptor(); + ampDoSRequestInterceptor.setContext(context); + ampDoSRequestInterceptor.setAction(action); + RequestInterceptor next = new MockDefaultInterceptor(); + ampDoSRequestInterceptor.setNextInterceptor(next); + } + + /** + * This tests validates the correctness of the logic of + * AMPDoSRequestInterceptor for Single + * {@link RegisterApplicationMasterRequest} checks. + */ + @Test + public void testRequestLimitRegisterApplicationMaster() + throws YarnException, IOException { + + // Execution with valid input + + RegisterApplicationMasterRequest request = RegisterApplicationMasterRequest + .newInstance("hostname", 10, "https://Address:10"); + try { + ampDoSRequestInterceptor.registerApplicationMaster(request); + } catch (Exception e) { + Assert.fail(); + } + + // Execution with invalid hostname + + request = RegisterApplicationMasterRequest.newInstance("hostnameToLong", 10, + "https://Address:10"); + try { + ampDoSRequestInterceptor.registerApplicationMaster(request); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage() + .contains(PreventDoSCause.REQUEST_GET_HOST_CAUSE.toString())); + } + + // Execution with invalid RPC Port + + request = RegisterApplicationMasterRequest.newInstance("hostname", -10, + "https://Address:10"); + try { + ampDoSRequestInterceptor.registerApplicationMaster(request); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage() + .contains(PreventDoSCause.REQUEST_GET_RPC_PORT_CAUSE.toString())); + } + + // Execution with invalid tracking URL + + request = RegisterApplicationMasterRequest.newInstance("hostname", 10, + "BadAddress:10:10"); + try { + ampDoSRequestInterceptor.registerApplicationMaster(request); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage() + .contains(PreventDoSCause.REQUEST_GET_TRACKING_URL_CAUSE.toString())); + } + } + + /** + * This tests validates the correctness of the logic of + * AMPDoSRequestInterceptor for Single {@link FinishApplicationMasterRequest} + * checks. + */ + @Test + public void testRequestLimitFinishApplicationMaster() + throws YarnException, IOException { + + // Execution with valid input + + FinishApplicationMasterRequest request = FinishApplicationMasterRequest + .newInstance(FinalApplicationStatus.SUCCEEDED, "Diagnostic", + "https://Address:10"); + try { + ampDoSRequestInterceptor.finishApplicationMaster(request); + } catch (Exception e) { + Assert.fail(); + } + + // Execution with invalid Diagnostic + + request = FinishApplicationMasterRequest.newInstance( + FinalApplicationStatus.SUCCEEDED, "DiagnosticTooLong", + "https://Address:10"); + try { + ampDoSRequestInterceptor.finishApplicationMaster(request); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage() + .contains(PreventDoSCause.REQUEST_GET_DIAGNOSTIC_CAUSE.toString())); + } + + // Execution with invalid Tracking URL + + request = FinishApplicationMasterRequest.newInstance( + FinalApplicationStatus.SUCCEEDED, "Diagnostic", "BadAddress:10:10"); + try { + ampDoSRequestInterceptor.finishApplicationMaster(request); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage().contains( + PreventDoSCause.REQUEST_GET_TRACKING_URL_FINISH_CAUSE.toString())); + } + } + + /** + * This tests validates the correctness of the logic of + * AMPDoSRequestInterceptor for Single {@link AllocateRequest} checks. + */ + @Test + public void testRequestLimitAllocate() throws YarnException, IOException { + + int responseId = 2; + float progress = 1; + + List resourceAsk = new ArrayList(); + for (int i = 0; i < 2; i++) { + resourceAsk.add(ResourceRequest.newInstance(Priority.newInstance(1), + "host", Resource.newInstance(1, 1), 1, false, "label")); + } + + ApplicationId appId = ApplicationId.newInstance(1, 1); + ApplicationAttemptId attemptId = ApplicationAttemptId.newInstance(appId, 1); + List containersToBeReleased = new ArrayList(); + for (int i = 0; i < 2; i++) { + containersToBeReleased.add(ContainerId.newContainerId(attemptId, i)); + } + + List additions = new ArrayList(); + for (int i = 0; i < 2; i++) { + additions.add("Node"); + } + List removals = new ArrayList(); + for (int i = 0; i < 2; i++) { + removals.add("Node"); + } + ResourceBlacklistRequest resourceBlacklistRequest = + ResourceBlacklistRequest.newInstance(additions, removals); + + ContainerResourceChangeRequest crcr = ContainerResourceChangeRequest + .newInstance(ContainerId.newContainerId(attemptId, 1), + Resource.newInstance(maxMemory, maxVCores)); + List increaseRequests = + new ArrayList(); + for (int i = 0; i < 2; i++) { + increaseRequests.add(crcr); + } + + List decreaseRequests = + new ArrayList(); + for (int i = 0; i < 2; i++) { + decreaseRequests.add(crcr); + } + + // Execution with valid input + + AllocateRequest request = AllocateRequest.newInstance(responseId, progress, + resourceAsk, containersToBeReleased, resourceBlacklistRequest, + increaseRequests, decreaseRequests); + + try { + ampDoSRequestInterceptor.allocate(request); + } catch (Exception e) { + Assert.fail(); + } + + // Execution with invalid ResponseId + + request = AllocateRequest.newInstance(responseId - 1, progress, resourceAsk, + containersToBeReleased, resourceBlacklistRequest, increaseRequests, + decreaseRequests); + try { + ampDoSRequestInterceptor.allocate(request); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage() + .contains(PreventDoSCause.REQUEST_GET_RESPONSE_ID_CAUSE.toString())); + } + + // Execution with invalid Progress + + responseId++; + request = AllocateRequest.newInstance(responseId, -10, resourceAsk, + containersToBeReleased, resourceBlacklistRequest, increaseRequests, + decreaseRequests); + try { + ampDoSRequestInterceptor.allocate(request); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage() + .contains(PreventDoSCause.REQUEST_GET_PROGRESS_CAUSE.toString())); + } + + // Execution with invalid AskList Size + + responseId++; + List resourceAskExceed = new ArrayList(); + for (int i = 0; i < maxListSize + 1; i++) { + resourceAskExceed.add(ResourceRequest.newInstance(Priority.newInstance(1), + "host", Resource.newInstance(1, 1), 1, false, "label")); + } + request = AllocateRequest.newInstance(responseId, progress, + resourceAskExceed, containersToBeReleased, resourceBlacklistRequest, + increaseRequests, decreaseRequests); + try { + ampDoSRequestInterceptor.allocate(request); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage() + .contains(PreventDoSCause.REQUEST_GET_ASK_LIST_CAUSE.toString())); + } + + // Execution with invalid Priority + + responseId++; + resourceAskExceed = new ArrayList(); + for (int i = 0; i < 2; i++) { + resourceAskExceed.add( + ResourceRequest.newInstance(Priority.newInstance(maxPriority + 1), + "host", Resource.newInstance(1, 1), 1, false, "label")); + } + request = AllocateRequest.newInstance(responseId, progress, + resourceAskExceed, containersToBeReleased, resourceBlacklistRequest, + increaseRequests, decreaseRequests); + try { + ampDoSRequestInterceptor.allocate(request); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage() + .contains(PreventDoSCause.REQUEST_GET_PRIORITY_CAUSE.toString())); + } + + // Execution with invalid ResourceName + + responseId++; + resourceAskExceed = new ArrayList(); + for (int i = 0; i < 2; i++) { + resourceAskExceed.add(ResourceRequest.newInstance(Priority.newInstance(1), + "hostNameToLong", Resource.newInstance(1, 1), 1, false, "label")); + } + request = AllocateRequest.newInstance(responseId, progress, + resourceAskExceed, containersToBeReleased, resourceBlacklistRequest, + increaseRequests, decreaseRequests); + try { + ampDoSRequestInterceptor.allocate(request); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage().contains( + PreventDoSCause.REQUEST_GET_RESOURCE_NAME_CAUSE.toString())); + } + + // Execution with invalid Resource + + responseId++; + resourceAskExceed = new ArrayList(); + for (int i = 0; i < 2; i++) { + resourceAskExceed.add(ResourceRequest.newInstance(Priority.newInstance(1), + "host", Resource.newInstance(-1, 1), 1, false, "label")); + } + request = AllocateRequest.newInstance(responseId, progress, + resourceAskExceed, containersToBeReleased, resourceBlacklistRequest, + increaseRequests, decreaseRequests); + try { + ampDoSRequestInterceptor.allocate(request); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage() + .contains(PreventDoSCause.REQUEST_GET_CAPABILITY_CAUSE.toString())); + } + + // Execution with invalid NumContainers + + responseId++; + resourceAskExceed = new ArrayList(); + for (int i = 0; i < 2; i++) { + resourceAskExceed + .add(ResourceRequest.newInstance(Priority.newInstance(1), "host", + Resource.newInstance(1, 1), maxContainers + 1, false, "label")); + } + request = AllocateRequest.newInstance(responseId, progress, + resourceAskExceed, containersToBeReleased, resourceBlacklistRequest, + increaseRequests, decreaseRequests); + try { + ampDoSRequestInterceptor.allocate(request); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage().contains( + PreventDoSCause.REQUEST_GET_NUM_CONTAINERS_CAUSE.toString())); + } + + // Execution with invalid NodeLabel + + responseId++; + resourceAskExceed = new ArrayList(); + for (int i = 0; i < 2; i++) { + resourceAskExceed + .add(ResourceRequest.newInstance(Priority.newInstance(1), "host", + Resource.newInstance(1, 1), 1, false, "labelLongerThanLimit")); + } + request = AllocateRequest.newInstance(responseId, progress, + resourceAskExceed, containersToBeReleased, resourceBlacklistRequest, + increaseRequests, decreaseRequests); + try { + ampDoSRequestInterceptor.allocate(request); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage() + .contains(PreventDoSCause.REQUEST_GET_NODE_LABEL_CAUSE.toString())); + } + + // Execution with invalid ReleaseList size + + responseId++; + + List containersToBeReleasedExceed = + new ArrayList(); + for (int i = 0; i < maxListSize + 1; i++) { + containersToBeReleasedExceed + .add(ContainerId.newContainerId(attemptId, i)); + } + request = AllocateRequest.newInstance(responseId, progress, resourceAsk, + containersToBeReleasedExceed, resourceBlacklistRequest, + increaseRequests, decreaseRequests); + try { + ampDoSRequestInterceptor.allocate(request); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage() + .contains(PreventDoSCause.REQUEST_GET_RELEASE_LIST_CAUSE.toString())); + } + + // Execution with invalid BlackList Additional size + + responseId++; + + List additionsExceed = new ArrayList(); + for (int i = 0; i < maxListSize + 1; i++) { + additionsExceed.add("Node"); + } + + ResourceBlacklistRequest resourceBlacklistRequestExceed = + ResourceBlacklistRequest.newInstance(additionsExceed, removals); + + request = AllocateRequest.newInstance(responseId, progress, resourceAsk, + containersToBeReleased, resourceBlacklistRequestExceed, + increaseRequests, decreaseRequests); + try { + ampDoSRequestInterceptor.allocate(request); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage().contains( + PreventDoSCause.REQUEST_GET_BLACKLIST_ADD_LIST_CAUSE.toString())); + } + + // Execution with invalid BlackList Removal size + + responseId++; + + List removalsExceed = new ArrayList(); + for (int i = 0; i < maxListSize + 1; i++) { + removalsExceed.add("Node"); + } + resourceBlacklistRequestExceed = + ResourceBlacklistRequest.newInstance(additions, removalsExceed); + + request = AllocateRequest.newInstance(responseId, progress, resourceAsk, + containersToBeReleased, resourceBlacklistRequestExceed, + increaseRequests, decreaseRequests); + try { + ampDoSRequestInterceptor.allocate(request); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage().contains( + PreventDoSCause.REQUEST_GET_BLACKLIST_REM_LIST_CAUSE.toString())); + } + + // Execution with invalid Increase ResourceRequest size + + responseId++; + + List increaseRequestsExceed = + new ArrayList(); + for (int i = 0; i < maxListSize + 1; i++) { + increaseRequestsExceed.add(crcr); + } + + request = AllocateRequest.newInstance(responseId, progress, resourceAsk, + containersToBeReleased, resourceBlacklistRequest, + increaseRequestsExceed, decreaseRequests); + try { + ampDoSRequestInterceptor.allocate(request); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage().contains( + PreventDoSCause.REQUEST_GET_INCREASE_LIST_CAUSE.toString())); + } + + // Execution with invalid Decrease ResourceRequest size + + responseId++; + + List decreaseRequestsExceed = + new ArrayList(); + for (int i = 0; i < maxListSize + 1; i++) { + decreaseRequestsExceed.add(crcr); + } + + request = AllocateRequest.newInstance(responseId, progress, resourceAsk, + containersToBeReleased, resourceBlacklistRequest, increaseRequests, + decreaseRequestsExceed); + try { + ampDoSRequestInterceptor.allocate(request); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage().contains( + PreventDoSCause.REQUEST_GET_DECREASE_LIST_CAUSE.toString())); + } + } + + /** + * This tests validates the correctness of the logic of + * AMPDoSRequestInterceptor for {@link RegisterApplicationMasterRequest} + * checks for the entire execution. + */ + @Test + public void testRequestsLimitRegisterApplicationMaster() + throws YarnException, IOException { + + // Check RegisterApplicationMasterRequest entire execution of the + // application. Limit is 10. + RegisterApplicationMasterRequest request = RegisterApplicationMasterRequest + .newInstance("hostname", 10, "https://Address:10"); + for (int i = 0; i < maxNumberRequests; i++) { + try { + ampDoSRequestInterceptor.registerApplicationMaster(request); + } catch (Exception e) { + Assert.fail(); + } + } + + // Exceed the limit + try { + ampDoSRequestInterceptor.registerApplicationMaster(request); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage() + .contains(PreventDoSCause.REQUESTS_REGISTER_AM_CAUSE.toString())); + } + } + + /** + * This tests validates the correctness of the logic of + * AMPDoSRequestInterceptor for {@link FinishApplicationMasterRequest} checks + * for the entire execution. + */ + @Test + public void testRequestsLimitFinishApplicationMaster() + throws YarnException, IOException { + + // Check FinishApplicationMasterRequest entire execution of the + // application. Limit is 10. + FinishApplicationMasterRequest request = FinishApplicationMasterRequest + .newInstance(FinalApplicationStatus.SUCCEEDED, "Diagnostic", + "https://Address:10"); + for (int i = 0; i < maxNumberRequests; i++) { + try { + ampDoSRequestInterceptor.finishApplicationMaster(request); + } catch (Exception e) { + Assert.fail(); + } + } + + // Exceed the limit + try { + ampDoSRequestInterceptor.finishApplicationMaster(request); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage() + .contains(PreventDoSCause.REQUESTS_FINISH_AM_CAUSE.toString())); + } + } + + /** + * This tests validates the correctness of the logic of + * AMPDoSRequestInterceptor for {@link AllocateRequest} checks for the entire + * execution. + */ + @Test + public void testRequestsLimitAllocateApplicationMaster() + throws YarnException, IOException { + int responseId = 2; + float progress = 1; + + AllocateRequest request = AllocateRequest.newInstance(responseId, progress, + null, null, null, null, null); + + // Check AllocateRequest entire execution of the application. Limit is 10. + for (int i = 0; i < maxNumberRequests; i++) { + try { + ampDoSRequestInterceptor.allocate(request); + request.setResponseId(++responseId); + } catch (Exception e) { + Assert.fail(); + } + } + + // Exceed the limit + try { + ampDoSRequestInterceptor.allocate(request); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage() + .contains(PreventDoSCause.REQUESTS_ALLOCATE_AM_CAUSE.toString())); + } + } + + /** + * This tests validates the correctness of the logic of + * AMPDoSRequestInterceptor for {@link AllocateRequest#getAskList()} checks + * for the entire execution. + */ + @Test + public void testRequestsLimitAllocateApplicationMasterAskList() + throws YarnException, IOException { + int responseId = 2; + float progress = 1; + + List resourceAsk = new ArrayList(); + for (int i = 0; i < maxListSizeRequests / rounds; i++) { + resourceAsk.add(ResourceRequest.newInstance(Priority.newInstance(1), + "host", Resource.newInstance(1, 1), 1, false, "label")); + } + + AllocateRequest request = AllocateRequest.newInstance(responseId, progress, + resourceAsk, null, null, null, null); + + // Check AllocateRequest#AskList entire execution of the application. Limit + // is 100. + for (int i = 0; i < rounds; i++) { + try { + ampDoSRequestInterceptor.allocate(request); + request.setResponseId(++responseId); + } catch (Exception e) { + Assert.fail(); + } + } + + // Exceed the limit + try { + ampDoSRequestInterceptor.allocate(request); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage() + .contains(PreventDoSCause.REQUESTS_GET_ASK_LIST_CAUSE.toString())); + } + } + + /** + * This tests validates the correctness of the logic of + * AMPDoSRequestInterceptor for {@link ResourceRequest#getNumContainers()} + * checks for the entire execution. + */ + @Test + public void testRequestsLimitAllocateApplicationMasterNumContainers() + throws YarnException, IOException { + int responseId = 2; + float progress = 1; + + List resourceAsk = new ArrayList(); + resourceAsk.add(ResourceRequest.newInstance(Priority.newInstance(1), "host", + Resource.newInstance(1, 1), maxContainersRequests / rounds, false, + "label")); + + AllocateRequest request = AllocateRequest.newInstance(responseId, progress, + resourceAsk, null, null, null, null); + + // Check AllocateRequest#NumContainers entire execution of the application. + // Limit + // is 100. + for (int i = 0; i < rounds; i++) { + try { + ampDoSRequestInterceptor.allocate(request); + request.setResponseId(++responseId); + } catch (Exception e) { + Assert.fail(); + } + } + + // Exceed the limit + try { + ampDoSRequestInterceptor.allocate(request); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage().contains( + PreventDoSCause.REQUESTS_GET_NUM_CONTAINERS_CAUSE.toString())); + } + } + + /** + * This tests validates the correctness of the logic of + * AMPDoSRequestInterceptor for {@link AllocateRequest#getReleaseList()} + * checks for the entire execution. + */ + @Test + public void testRequestsLimitAllocateApplicationMasterReleaseList() + throws YarnException, IOException { + int responseId = 2; + float progress = 1; + + ApplicationId appId = ApplicationId.newInstance(1, 1); + ApplicationAttemptId attemptId = ApplicationAttemptId.newInstance(appId, 1); + List containersToBeReleased = new ArrayList(); + for (int i = 0; i < maxListSizeRequests / rounds; i++) { + containersToBeReleased.add(ContainerId.newContainerId(attemptId, i)); + } + + AllocateRequest request = AllocateRequest.newInstance(responseId, progress, + null, containersToBeReleased, null, null, null); + + // Check AllocateRequest#ReleaseList entire execution of the application. + // Limit is 100. + for (int i = 0; i < rounds; i++) { + try { + ampDoSRequestInterceptor.allocate(request); + request.setResponseId(++responseId); + } catch (Exception e) { + Assert.fail(); + } + } + + // Exceed the limit + try { + ampDoSRequestInterceptor.allocate(request); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage().contains( + PreventDoSCause.REQUESTS_GET_RELEASE_LIST_CAUSE.toString())); + } + } + + /** + * This tests validates the correctness of the logic of + * AMPDoSRequestInterceptor for + * {@link ResourceBlacklistRequest#getBlacklistAdditions()} checks for the + * entire execution. + */ + @Test + public void testRequestsLimitAllocateApplicationMasterAddBlacklist() + throws YarnException, IOException { + int responseId = 2; + float progress = 1; + + List additions = new ArrayList(); + for (int i = 0; i < maxListSizeRequests / rounds; i++) { + additions.add("Node"); + } + List removals = new ArrayList(); + + ResourceBlacklistRequest resourceBlacklistRequest = + ResourceBlacklistRequest.newInstance(additions, removals); + + AllocateRequest request = AllocateRequest.newInstance(responseId, progress, + null, null, resourceBlacklistRequest, null, null); + + // Check AllocateRequest#BlackListAdd entire execution of the application. + // Limit is 100. + for (int i = 0; i < rounds; i++) { + try { + ampDoSRequestInterceptor.allocate(request); + request.setResponseId(++responseId); + } catch (Exception e) { + Assert.fail(); + } + } + + // Exceed the limit + try { + ampDoSRequestInterceptor.allocate(request); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage().contains( + PreventDoSCause.REQUESTS_GET_BLACKLIST_ADD_LIST_CAUSE.toString())); + } + } + + /** + * This tests validates the correctness of the logic of + * AMPDoSRequestInterceptor for + * {@link ResourceBlacklistRequest#getBlacklistRemovals()} checks for the + * entire execution. + */ + @Test + public void testRequestsLimitAllocateApplicationMasterRemBlacklist() + throws YarnException, IOException { + int responseId = 2; + float progress = 1; + + List removals = new ArrayList(); + for (int i = 0; i < maxListSizeRequests / rounds; i++) { + removals.add("Node"); + } + List additions = new ArrayList(); + + ResourceBlacklistRequest resourceBlacklistRequest = + ResourceBlacklistRequest.newInstance(additions, removals); + + AllocateRequest request = AllocateRequest.newInstance(responseId, progress, + null, null, resourceBlacklistRequest, null, null); + + // Check AllocateRequest#BlackListRem entire execution of the application. + // Limit is 100. + for (int i = 0; i < rounds; i++) { + try { + ampDoSRequestInterceptor.allocate(request); + request.setResponseId(++responseId); + } catch (Exception e) { + Assert.fail(); + } + } + + // Exceed the limit + try { + ampDoSRequestInterceptor.allocate(request); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage().contains( + PreventDoSCause.REQUESTS_GET_BLACKLIST_REM_LIST_CAUSE.toString())); + } + } + + /** + * This tests validates the correctness of the logic of + * AMPDoSRequestInterceptor for {@link AllocateRequest#getIncreaseRequests()} + * checks for the entire execution. + */ + @Test + public void testRequestsLimitAllocateApplicationMasterIncrease() + throws YarnException, IOException { + int responseId = 2; + float progress = 1; + + ApplicationId appId = ApplicationId.newInstance(1, 1); + ApplicationAttemptId attemptId = ApplicationAttemptId.newInstance(appId, 1); + ContainerResourceChangeRequest crcr = ContainerResourceChangeRequest + .newInstance(ContainerId.newContainerId(attemptId, 1), + Resource.newInstance(maxMemory, maxVCores)); + List increaseRequests = + new ArrayList(); + for (int i = 0; i < maxListSizeRequests / rounds; i++) { + increaseRequests.add(crcr); + } + + AllocateRequest request = AllocateRequest.newInstance(responseId, progress, + null, null, null, increaseRequests, null); + + // Check AllocateRequest#ChangeResourceRequestIncrease entire execution of + // the application. Limit is 100. + for (int i = 0; i < rounds; i++) { + try { + ampDoSRequestInterceptor.allocate(request); + request.setResponseId(++responseId); + } catch (Exception e) { + Assert.fail(); + } + } + + // Exceed the limit + try { + ampDoSRequestInterceptor.allocate(request); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage().contains( + PreventDoSCause.REQUESTS_GET_INCREASE_LIST_CAUSE.toString())); + } + } + + /** + * This tests validates the correctness of the logic of + * AMPDoSRequestInterceptor for {@link AllocateRequest#getDecreaseRequests()} + * checks for the entire execution. + */ + @Test + public void testRequestsLimitAllocateApplicationMasterDecrease() + throws YarnException, IOException { + int responseId = 2; + float progress = 1; + + ApplicationId appId = ApplicationId.newInstance(1, 1); + ApplicationAttemptId attemptId = ApplicationAttemptId.newInstance(appId, 1); + ContainerResourceChangeRequest crcr = ContainerResourceChangeRequest + .newInstance(ContainerId.newContainerId(attemptId, 1), + Resource.newInstance(maxMemory, maxVCores)); + List decreaseRequests = + new ArrayList(); + for (int i = 0; i < maxListSizeRequests / rounds; i++) { + decreaseRequests.add(crcr); + } + + AllocateRequest request = AllocateRequest.newInstance(responseId, progress, + null, null, null, null, decreaseRequests); + + // Check AllocateRequest#ChangeResourceRequestDecrease entire execution of + // the application. Limit is 100. + for (int i = 0; i < rounds; i++) { + try { + ampDoSRequestInterceptor.allocate(request); + request.setResponseId(++responseId); + } catch (Exception e) { + Assert.fail(); + } + } + + // Exceed the limit + try { + ampDoSRequestInterceptor.allocate(request); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage().contains( + PreventDoSCause.REQUESTS_GET_DECREASE_LIST_CAUSE.toString())); + } + } + +} diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/amrmproxy/TestAMPDoSRequestInterceptorSlidingWindow.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/amrmproxy/TestAMPDoSRequestInterceptorSlidingWindow.java new file mode 100644 index 0000000..0bfaf34 --- /dev/null +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/amrmproxy/TestAMPDoSRequestInterceptorSlidingWindow.java @@ -0,0 +1,966 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.server.nodemanager.amrmproxy; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.yarn.api.protocolrecords.AllocateRequest; +import org.apache.hadoop.yarn.api.protocolrecords.FinishApplicationMasterRequest; +import org.apache.hadoop.yarn.api.protocolrecords.RegisterApplicationMasterRequest; +import org.apache.hadoop.yarn.api.records.ApplicationAttemptId; +import org.apache.hadoop.yarn.api.records.ApplicationId; +import org.apache.hadoop.yarn.api.records.ContainerId; +import org.apache.hadoop.yarn.api.records.ContainerResourceChangeRequest; +import org.apache.hadoop.yarn.api.records.FinalApplicationStatus; +import org.apache.hadoop.yarn.api.records.Priority; +import org.apache.hadoop.yarn.api.records.Resource; +import org.apache.hadoop.yarn.api.records.ResourceBlacklistRequest; +import org.apache.hadoop.yarn.api.records.ResourceRequest; +import org.apache.hadoop.yarn.conf.YarnConfiguration; +import org.apache.hadoop.yarn.exceptions.YarnException; +import org.apache.hadoop.yarn.server.preventdos.PreventDoSAction; +import org.apache.hadoop.yarn.server.preventdos.PreventDoSAttackException; +import org.apache.hadoop.yarn.server.preventdos.PreventDoSCause; +import org.apache.hadoop.yarn.server.preventdos.PreventDoSContextAMPImpl; +import org.apache.hadoop.yarn.server.preventdos.PreventDoSPolicy; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +/** + * Unit tests for AMPDoSRequestInterceptor for SlidingWindow. + */ +public class TestAMPDoSRequestInterceptorSlidingWindow { + + private static final int maxStringLength = 10; + private static final int maxRPCPort = 1000; + private static final int maxVCores = 6; + private static final int maxMemory = 100 * 1024; + private static final int maxContainers = 50; + private static final int maxListSize = 50; + private static final int minPriority = 1; + private static final int maxPriority = 10; + + private static final int slidingWindowSize = 3; + private static final int slidingWindowAdvanceTime = 1; + private static final int thresholdNumberRequests = 3; + private static final int thresholdListSize = 100; + private static final int thresholdContainers = 100; + private static final int round = 3; + + private static final int maxNumberRequests = 100; + private static final int maxListSizeRequests = 1000; + private static final int maxContainersRequests = 1000; + + private AMPDoSRequestInterceptor ampDoSRequestInterceptor; + + private void setUpRequestLimit(Configuration conf) { + + // Single request configuration for Register AM + + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_LENGTH_HOST, + maxStringLength); + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_RPC_PORT, maxRPCPort); + + // Single request configuration for Finish AM + + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_LENGTH_DIAGNOSTIC, + maxStringLength); + + // Single request configuration for Allocate AM + + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_LENGTH_RESOURCE, + maxStringLength); + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_LENGTH_NODELABEL, + maxStringLength); + + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_MAX_VCORES, maxVCores); + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_MAX_MEMORY, maxMemory); + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_ASK_LIST, maxListSize); + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_RELEASE_LIST, maxListSize); + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_BLACKLIST, maxListSize); + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_CHANGELIST, maxListSize); + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_CONTAINERS, maxContainers); + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_MIN_PRIORITY, minPriority); + conf.setInt(YarnConfiguration.DOS_REQUEST_LIMIT_MAX_PRIORITY, maxPriority); + } + + private void setUpRequestsLimit(Configuration conf) { + // Lifetime configuration for Register AM + + conf.setInt(YarnConfiguration.DOS_MAX_REQUESTS_REGISTER_AM, + maxNumberRequests); + + // Lifetime configuration for Finish AM + + conf.setInt(YarnConfiguration.DOS_MAX_REQUESTS_FINISH_AM, + maxNumberRequests); + + // Lifetime configuration for Allocate AM + conf.setInt(YarnConfiguration.DOS_MAX_REQUESTS_ALLOCATE_AM, + maxNumberRequests); + + conf.setInt(YarnConfiguration.DOS_MAX_REQUESTS_ASK_LIST, + maxListSizeRequests); + conf.setInt(YarnConfiguration.DOS_MAX_REQUESTS_RELEASE_LIST, + maxListSizeRequests); + conf.setInt(YarnConfiguration.DOS_MAX_REQUESTS_BLACKLIST, + maxListSizeRequests); + conf.setInt(YarnConfiguration.DOS_MAX_REQUESTS_CHANGELIST, + maxListSizeRequests); + conf.setInt(YarnConfiguration.DOS_MAX_REQUESTS_NUM_CONTAINERS, + maxContainersRequests); + } + + private void setUpSlidingWindowThreshold(Configuration conf) { + + conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_SIZE, slidingWindowSize); + conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_ADVANCE_TIME_SEC, + slidingWindowAdvanceTime); + + // Sliding window configuration for Register AM + + conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_REGISTER_AM, + thresholdNumberRequests); + + // Sliding window configuration for Finish AM + + conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_FINISH_AM, + thresholdNumberRequests); + + // Sliding window configuration for Allocate AM + + conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_ALLOCATE_AM, + thresholdNumberRequests); + + conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_ASK_LIST, + thresholdListSize); + conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_RELEASE_LIST, + thresholdListSize); + conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_BLACKLIST, + thresholdListSize); + conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_CHANGELIST, + thresholdListSize); + + conf.setInt(YarnConfiguration.DOS_SLIDING_WINDOW_THRESHOLD_NUM_CONTAINERS, + thresholdContainers); + } + + @Before + public void setUp() { + Configuration conf = new YarnConfiguration(); + setUpRequestLimit(conf); + setUpSlidingWindowThreshold(conf); + setUpRequestsLimit(conf); + + PreventDoSAction action = new PreventDoSAction(PreventDoSPolicy.REJECT); + PreventDoSContextAMPImpl context = new PreventDoSContextAMPImpl(conf); + + ampDoSRequestInterceptor = new AMPDoSRequestInterceptor(); + ampDoSRequestInterceptor.setContext(context); + ampDoSRequestInterceptor.setAction(action); + RequestInterceptor next = new MockDefaultInterceptor(); + ampDoSRequestInterceptor.setNextInterceptor(next); + } + + /** + * This tests validates the correctness of the logic of + * AMPDoSRequestInterceptor for temporal sliding window checks of + * {@link RegisterApplicationMasterRequest}. + */ + @Test + public void testSlidingWindowThresholdRegisterApplicationMaster() + throws YarnException, IOException { + + RegisterApplicationMasterRequest request = RegisterApplicationMasterRequest + .newInstance("hostname", 10, "https://Address:10"); + + /* + * Current Status of Sliding Window: First Slot + * + * RegisterApplicationMasterRequest | 0 | 0 | 0 | + * + */ + + for (int i = 0; i < thresholdNumberRequests; i++) { + try { + Thread.sleep(slidingWindowAdvanceTime * 1000); + ampDoSRequestInterceptor.registerApplicationMaster(request); + } catch (Exception e) { + Assert.fail(); + } + } + + /* + * Current Status of Sliding Window: Third Slot + * + * RegisterApplicationMasterRequest | 1 | 1 | 1 | + * + */ + + try { + ampDoSRequestInterceptor.registerApplicationMaster(request); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage().contains( + PreventDoSCause.SLIDINGWINDOW_REGISTER_AM_CAUSE.toString())); + } + + // The previous operation fails because we exceed the limit + + try { + Thread.sleep(slidingWindowAdvanceTime * 1000); + + /* + * Current Status of Sliding Window: First Slot + * + * RegisterApplicationMasterRequest | 0 | 1 | 1 | + * + */ + + ampDoSRequestInterceptor.registerApplicationMaster(request); + + /* + * Current Status of Sliding Window: First Slot + * + * RegisterApplicationMasterRequest | 1 | 1 | 1 | + * + */ + + } catch (Exception e) { + Assert.fail(); + } + } + + /** + * This tests validates the correctness of the logic of + * AMPDoSRequestInterceptor for temporal sliding window checks of + * {@link FinishApplicationMasterRequest}. + */ + @Test + public void testSlidingWindowThresholdFinishApplicationMaster() + throws YarnException, IOException { + + FinishApplicationMasterRequest request = FinishApplicationMasterRequest + .newInstance(FinalApplicationStatus.SUCCEEDED, "Diagnostic", + "https://Address:10"); + + /* + * Current Status of Sliding Window: First Slot + * + * FinishApplicationMasterRequest | 0 | 0 | 0 | + * + */ + + for (int i = 0; i < round + 1; i++) { + try { + Thread.sleep(slidingWindowAdvanceTime * 1000); + ampDoSRequestInterceptor.finishApplicationMaster(request); + } catch (Exception e) { + Assert.fail(); + } + } + + /* + * Current Status of Sliding Window: Third Slot + * + * FinishApplicationMasterRequest | 1 | 1 | 1 | + * + */ + + try { + ampDoSRequestInterceptor.finishApplicationMaster(request); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage() + .contains(PreventDoSCause.SLIDINGWINDOW_FINISH_AM_CAUSE.toString())); + } + + // The previous operation fails because we exceed the limit + + try { + Thread.sleep(slidingWindowAdvanceTime * 1000); + + /* + * Current Status of Sliding Window: First Slot + * + * FinishApplicationMasterRequest | 0 | 1 | 1 | + * + */ + + ampDoSRequestInterceptor.finishApplicationMaster(request); + + /* + * Current Status of Sliding Window: First Slot + * + * FinishApplicationMasterRequest | 1 | 1 | 1 | + * + */ + + } catch (Exception e) { + Assert.fail(); + } + } + + /** + * This tests validates the correctness of the logic of + * AMPDoSRequestInterceptor for temporal sliding window checks of + * {@link AllocateRequest}. + */ + @Test + public void testSlidingWindowThresholdAllocateApplicationMaster() + throws YarnException, IOException { + int responseId = 2; + float progress = 1; + + AllocateRequest request = AllocateRequest.newInstance(responseId, progress, + null, null, null, null, null); + + /* + * Current Status of Sliding Window: First Slot + * + * AllocateRequest | 0 | 0 | 0 | + * + */ + for (int i = 0; i < round + 1; i++) { + try { + Thread.sleep(slidingWindowAdvanceTime * 1000); + ampDoSRequestInterceptor.allocate(request); + request.setResponseId(++responseId); + } catch (Exception e) { + Assert.fail(); + } + } + + /* + * Current Status of Sliding Window: Third Slot + * + * AllocateRequest | 1 | 1 | 1 | + * + */ + + try { + ampDoSRequestInterceptor.allocate(request); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage().contains( + PreventDoSCause.SLIDINGWINDOW_ALLOCATE_AM_CAUSE.toString())); + } + + // The previous operation fails because we exceed the limit + + try { + Thread.sleep(slidingWindowAdvanceTime * 1000); + request.setResponseId(++responseId); + + /* + * Current Status of Sliding Window: First Slot + * + * AllocateRequest | 0 | 1 | 1 | + * + */ + + ampDoSRequestInterceptor.allocate(request); + + /* + * Current Status of Sliding Window: First Slot + * + * AllocateRequest | 1 | 1 | 1 | + * + */ + + } catch (Exception e) { + Assert.fail(); + } + } + + /** + * This tests validates the correctness of the logic of + * AMPDoSRequestInterceptor for temporal sliding window checks of + * {@link AllocateRequest#getAskList()}. + */ + @Test + public void testSlidingWindowThresholdAllocateApplicationMasterAskList() + throws YarnException, IOException, InterruptedException { + int responseId = 2; + float progress = 1; + + List resourceAsk = new ArrayList(); + for (int i = 0; i < thresholdListSize / (round - 1); i++) { + resourceAsk.add(ResourceRequest.newInstance(Priority.newInstance(1), + "host", Resource.newInstance(1, 1), 1, false, "label")); + } + + AllocateRequest request = AllocateRequest.newInstance(responseId, progress, + resourceAsk, null, null, null, null); + + /* + * Current Status of Sliding Window: First Slot + * + * ResourceRequest | 0 | 0 | 0 | + * + */ + for (int i = 0; i < round - 1; i++) { + try { + Thread.sleep(slidingWindowAdvanceTime * 1000); + ampDoSRequestInterceptor.allocate(request); + request.setResponseId(++responseId); + } catch (Exception e) { + Assert.fail(); + } + } + + /* + * Current Status of Sliding Window: Third Slot + * + * ResourceRequest | 0 | 50 | 50 | + * + */ + + try { + Thread.sleep(slidingWindowAdvanceTime * 1000); + ampDoSRequestInterceptor.allocate(request); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage().contains( + PreventDoSCause.SLIDINGWINDOW_GET_ASK_LIST_CAUSE.toString())); + } + + // The previous operation fails because we exceed the limit + + try { + Thread.sleep(slidingWindowAdvanceTime * 1000); + request.setResponseId(++responseId); + + /* + * Current Status of Sliding Window: Second Slot + * + * ResourceRequest | 0 | 0 | 50 | + * + */ + + ampDoSRequestInterceptor.allocate(request); + + /* + * Current Status of Sliding Window: Second Slot + * + * ResourceRequest | 0 | 50 | 50 | + * + */ + + } catch (Exception e) { + Assert.fail(); + } + } + + /** + * This tests validates the correctness of the logic of + * AMPDoSRequestInterceptor for temporal sliding window checks of + * {@link ResourceRequest#getNumContainers()}. + */ + @Test + public void testSlidingWindowThresholdAllocateApplicationMasterNumContainers() + throws YarnException, IOException, InterruptedException { + int responseId = 2; + float progress = 1; + + List resourceAsk = new ArrayList(); + resourceAsk.add(ResourceRequest.newInstance(Priority.newInstance(1), "host", + Resource.newInstance(1, 1), thresholdContainers / (round - 1), false, + "label")); + + AllocateRequest request = AllocateRequest.newInstance(responseId, progress, + resourceAsk, null, null, null, null); + + /* + * Current Status of Sliding Window: First Slot + * + * NUM_CONTAINERS | 0 | 0 | 0 | + * + */ + for (int i = 0; i < (round - 1); i++) { + try { + Thread.sleep(slidingWindowAdvanceTime * 1000); + ampDoSRequestInterceptor.allocate(request); + request.setResponseId(++responseId); + } catch (Exception e) { + Assert.fail(); + } + } + + /* + * Current Status of Sliding Window: Third Slot + * + * NUM_CONTAINERS | 0 | 50 | 50 | + * + */ + + try { + Thread.sleep(slidingWindowAdvanceTime * 1000); + ampDoSRequestInterceptor.allocate(request); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage().contains( + PreventDoSCause.SLIDINGWINDOW_GET_NUM_CONTAINERS_CAUSE.toString())); + } + + // The previous operation fails because we exceed the limit + + try { + Thread.sleep(slidingWindowAdvanceTime * 1000); + request.setResponseId(++responseId); + + /* + * Current Status of Sliding Window: Second Slot + * + * NUM_CONTAINERS | 0 | 0 | 50 | + * + */ + + ampDoSRequestInterceptor.allocate(request); + + /* + * Current Status of Sliding Window: Second Slot + * + * NUM_CONTAINERS | 0 | 50 | 50 | + * + */ + + } catch (Exception e) { + Assert.fail(); + } + } + + /** + * This tests validates the correctness of the logic of + * AMPDoSRequestInterceptor for temporal sliding window checks of + * {@link AllocateRequest#getReleaseList()}. + */ + @Test + public void testSlidingWindowThresholdAllocateApplicationMasterReleaseList() + throws YarnException, IOException, InterruptedException { + int responseId = 2; + float progress = 1; + + ApplicationId appId = ApplicationId.newInstance(1, 1); + ApplicationAttemptId attemptId = ApplicationAttemptId.newInstance(appId, 1); + List containersToBeReleased = new ArrayList(); + for (int i = 0; i < thresholdListSize / (round - 1); i++) { + containersToBeReleased.add(ContainerId.newContainerId(attemptId, i)); + } + + AllocateRequest request = AllocateRequest.newInstance(responseId, progress, + null, containersToBeReleased, null, null, null); + + /* + * Current Status of Sliding Window: First Slot + * + * ContainerId | 0 | 0 | 0 | + * + */ + for (int i = 0; i < round - 1; i++) { + try { + Thread.sleep(slidingWindowAdvanceTime * 1000); + ampDoSRequestInterceptor.allocate(request); + request.setResponseId(++responseId); + } catch (Exception e) { + Assert.fail(); + } + } + + /* + * Current Status of Sliding Window: Third Slot + * + * ContainerId | 0 | 50 | 50 | + * + */ + + try { + Thread.sleep(slidingWindowAdvanceTime * 1000); + ampDoSRequestInterceptor.allocate(request); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage().contains( + PreventDoSCause.SLIDINGWINDOW_GET_RELEASE_LIST_CAUSE.toString())); + } + + // The previous operation fails because we exceed the limit + + try { + Thread.sleep(slidingWindowAdvanceTime * 1000); + request.setResponseId(++responseId); + + /* + * Current Status of Sliding Window: Second Slot + * + * ContainerId | 0 | 0 | 50 | + * + */ + + ampDoSRequestInterceptor.allocate(request); + + /* + * Current Status of Sliding Window: Second Slot + * + * ContainerId | 0 | 50 | 50 | + * + */ + + } catch (Exception e) { + Assert.fail(); + } + } + + /** + * This tests validates the correctness of the logic of + * AMPDoSRequestInterceptor for temporal sliding window checks of + * {@link ResourceBlacklistRequest#getBlacklistAdditions()}. + */ + @Test + public void testSlidingWindowThresholdAllocateApplicationMasterAddBlacklist() + throws YarnException, IOException, InterruptedException { + int responseId = 2; + float progress = 1; + + List additions = new ArrayList(); + for (int i = 0; i < thresholdListSize / (round - 1); i++) { + additions.add("Node"); + } + List removals = new ArrayList(); + + ResourceBlacklistRequest resourceBlacklistRequest = + ResourceBlacklistRequest.newInstance(additions, removals); + + AllocateRequest request = AllocateRequest.newInstance(responseId, progress, + null, null, resourceBlacklistRequest, null, null); + + /* + * Current Status of Sliding Window: First Slot + * + * ResourceBlacklistRequest | 0 | 0 | 0 | + * + */ + for (int i = 0; i < round - 1; i++) { + try { + Thread.sleep(slidingWindowAdvanceTime * 1000); + ampDoSRequestInterceptor.allocate(request); + request.setResponseId(++responseId); + } catch (Exception e) { + Assert.fail(); + } + } + + /* + * Current Status of Sliding Window: Third Slot + * + * ResourceBlacklistRequest | 0 | 50 | 50 | + * + */ + + try { + Thread.sleep(slidingWindowAdvanceTime * 1000); + ampDoSRequestInterceptor.allocate(request); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage() + .contains(PreventDoSCause.SLIDINGWINDOW_GET_BLACKLIST_ADD_LIST_CAUSE + .toString())); + } + + // The previous operation fails because we exceed the limit + + try { + Thread.sleep(slidingWindowAdvanceTime * 1000); + request.setResponseId(++responseId); + + /* + * Current Status of Sliding Window: Second Slot + * + * ResourceBlacklistRequest | 0 | 0 | 50 | + * + */ + + ampDoSRequestInterceptor.allocate(request); + + /* + * Current Status of Sliding Window: Second Slot + * + * ResourceBlacklistRequest | 0 | 50 | 50 | + * + */ + + } catch (Exception e) { + Assert.fail(); + } + } + + /** + * This tests validates the correctness of the logic of + * AMPDoSRequestInterceptor for temporal sliding window checks of + * {@link ResourceBlacklistRequest#getBlacklistRemovals()}. + */ + @Test + public void testSlidingWindowThresholdAllocateApplicationMasterRemBlacklist() + throws YarnException, IOException, InterruptedException { + int responseId = 2; + float progress = 1; + + List removals = new ArrayList(); + for (int i = 0; i < thresholdListSize / (round - 1); i++) { + removals.add("Node "); + } + List additions = new ArrayList(); + + ResourceBlacklistRequest resourceBlacklistRequest = + ResourceBlacklistRequest.newInstance(additions, removals); + + AllocateRequest request = AllocateRequest.newInstance(responseId, progress, + null, null, resourceBlacklistRequest, null, null); + + /* + * Current Status of Sliding Window: First Slot + * + * ResourceBlacklistRequest | 0 | 0 | 0 | + * + */ + for (int i = 0; i < round - 1; i++) { + try { + Thread.sleep(slidingWindowAdvanceTime * 1000); + ampDoSRequestInterceptor.allocate(request); + request.setResponseId(++responseId); + } catch (Exception e) { + Assert.fail(); + } + } + + /* + * Current Status of Sliding Window: Third Slot + * + * ResourceBlacklistRequest | 0 | 50 | 50 | + * + */ + + try { + Thread.sleep(slidingWindowAdvanceTime * 1000); + ampDoSRequestInterceptor.allocate(request); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage() + .contains(PreventDoSCause.SLIDINGWINDOW_GET_BLACKLIST_REM_LIST_CAUSE + .toString())); + } + + // The previous operation fails because we exceed the limit + + try { + Thread.sleep(slidingWindowAdvanceTime * 1000); + request.setResponseId(++responseId); + + /* + * Current Status of Sliding Window: Second Slot + * + * ResourceBlacklistRequest | 0 | 0 | 50 | + * + */ + + ampDoSRequestInterceptor.allocate(request); + + /* + * Current Status of Sliding Window: Second Slot + * + * ResourceBlacklistRequest | 0 | 50 | 50 | + * + */ + + } catch (Exception e) { + Assert.fail(); + } + } + + /** + * This tests validates the correctness of the logic of + * AMPDoSRequestInterceptor for temporal sliding window checks of + * {@link AllocateRequest#getIncreaseRequests()}. + */ + @Test + public void testSlidingWindowThresholdAllocateApplicationMasterIncrease() + throws YarnException, IOException, InterruptedException { + int responseId = 2; + float progress = 1; + + ApplicationId appId = ApplicationId.newInstance(1, 1); + ApplicationAttemptId attemptId = ApplicationAttemptId.newInstance(appId, 1); + ContainerResourceChangeRequest crcr = ContainerResourceChangeRequest + .newInstance(ContainerId.newContainerId(attemptId, 1), + Resource.newInstance(maxMemory, maxVCores)); + List increaseRequests = + new ArrayList(); + for (int i = 0; i < thresholdListSize / (round - 1); i++) { + increaseRequests.add(crcr); + } + + AllocateRequest request = AllocateRequest.newInstance(responseId, progress, + null, null, null, increaseRequests, null); + + /* + * Current Status of Sliding Window: First Slot + * + * ContainerResourceChangeRequest | 0 | 0 | 0 | + * + */ + for (int i = 0; i < round - 1; i++) { + try { + Thread.sleep(slidingWindowAdvanceTime * 1000); + ampDoSRequestInterceptor.allocate(request); + request.setResponseId(++responseId); + } catch (Exception e) { + Assert.fail(); + } + } + + /* + * Current Status of Sliding Window: Third Slot + * + * ContainerResourceChangeRequest | 0 | 50 | 50 | + * + */ + + try { + Thread.sleep(slidingWindowAdvanceTime * 1000); + ampDoSRequestInterceptor.allocate(request); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage().contains( + PreventDoSCause.SLIDINGWINDOW_GET_INCREASE_LIST_CAUSE.toString())); + } + + // The previous operation fails because we exceed the limit + + try { + Thread.sleep(slidingWindowAdvanceTime * 1000); + request.setResponseId(++responseId); + + /* + * Current Status of Sliding Window: Second Slot + * + * ContainerResourceChangeRequest | 0 | 0 | 50 | + * + */ + + ampDoSRequestInterceptor.allocate(request); + + /* + * Current Status of Sliding Window: Second Slot + * + * ContainerResourceChangeRequest | 0 | 50 | 50 | + * + */ + + } catch (Exception e) { + Assert.fail(); + } + } + + /** + * This tests validates the correctness of the logic of + * AMPDoSRequestInterceptor for temporal sliding window checks of + * {@link AllocateRequest#getDecreaseRequests()}. + */ + @Test + public void testSlidingWindowThresholdAllocateApplicationMasterDecrease() + throws YarnException, IOException, InterruptedException { + int responseId = 2; + float progress = 1; + + ApplicationId appId = ApplicationId.newInstance(1, 1); + ApplicationAttemptId attemptId = ApplicationAttemptId.newInstance(appId, 1); + ContainerResourceChangeRequest crcr = ContainerResourceChangeRequest + .newInstance(ContainerId.newContainerId(attemptId, 1), + Resource.newInstance(maxMemory, maxVCores)); + List decreaseRequests = + new ArrayList(); + for (int i = 0; i < thresholdListSize / (round - 1); i++) { + decreaseRequests.add(crcr); + } + + AllocateRequest request = AllocateRequest.newInstance(responseId, progress, + null, null, null, null, decreaseRequests); + + /* + * Current Status of Sliding Window: First Slot + * + * ContainerResourceChangeRequest | 0 | 0 | 0 | + * + */ + for (int i = 0; i < round - 1; i++) { + try { + Thread.sleep(slidingWindowAdvanceTime * 1000); + ampDoSRequestInterceptor.allocate(request); + request.setResponseId(++responseId); + } catch (Exception e) { + Assert.fail(); + } + } + + /* + * Current Status of Sliding Window: Third Slot + * + * ContainerResourceChangeRequest | 0 | 50 | 50 | + * + */ + + try { + Thread.sleep(slidingWindowAdvanceTime * 1000); + ampDoSRequestInterceptor.allocate(request); + Assert.fail(); + } catch (PreventDoSAttackException e) { + Assert.assertTrue(e.getMessage().contains( + PreventDoSCause.SLIDINGWINDOW_GET_DECREASE_LIST_CAUSE.toString())); + } + + // The previous operation fails because we exceed the limit + + try { + Thread.sleep(slidingWindowAdvanceTime * 1000); + request.setResponseId(++responseId); + + /* + * Current Status of Sliding Window: Second Slot + * + * ContainerResourceChangeRequest | 0 | 0 | 50 | + * + */ + + ampDoSRequestInterceptor.allocate(request); + + /* + * Current Status of Sliding Window: Second Slot + * + * ContainerResourceChangeRequest | 0 | 50 | 50 | + * + */ + } catch (Exception e) { + Assert.fail(); + } + } + +}