diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java index 0b150c2..1d2ae32 100644 --- hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java @@ -2565,6 +2565,152 @@ public static boolean areNodeLabelsEnabled( public static final String TIMELINE_CSRF_METHODS_TO_IGNORE = TIMELINE_CSRF_PREFIX + "methods-to-ignore"; + // Prevent DoS attack configuration + + public static final String DOS_PREFIX = "dos."; + + public static final String DOS_SLIDING_WINDOW = + DOS_PREFIX + "sliding-windows."; + + public static final String DOS_SLIDING_WINDOW_SIZE = + DOS_SLIDING_WINDOW + "size"; + public static final int DEFAULT_DOS_SLIDING_WINDOW_SIZE = 10; + + public static final String DOS_SLIDING_WINDOW_ADVANCE_TIME_SEC = + DOS_SLIDING_WINDOW + "advance-time"; + public static final int DEFAULT_DOS_SLIDING_WINDOW_ADVANCE_TIME_SEC = 60; + + public static final String DOS_THRESHOLD = DOS_PREFIX + "threshold."; + + public static final String DOS_THREASHOLD_REGISTER_AM = + DOS_THRESHOLD + "register-am"; + public static final int DEFAULT_DOS_THREASHOLD_REGISTER_AM = 10; + + public static final String DOS_THREASHOLD_ALLOCATE_AM = + DOS_THRESHOLD + "allocate-am"; + public static final int DEFAULT_DOS_THREASHOLD_ALLOCATE_AM = 1000; + + public static final String DOS_THREASHOLD_FINISH_AM = + DOS_THRESHOLD + "finish-am"; + public static final int DEFAULT_DOS_THREASHOLD_FINISH_AM = 10; + + public static final String DOS_THREASHOLD_ASKLIST_AM = + DOS_THRESHOLD + "ask-am"; + public static final int DEFAULT_DOS_THREASHOLD_ASKLIST_AM = 1000; + + public static final String DOS_THREASHOLD_RELEASE_AM = + DOS_THRESHOLD + "release-am"; + public static final int DEFAULT_DOS_THREASHOLD_RELEASE_AM = 1000; + + public static final String DOS_THREASHOLD_BLACKLIST_INCREASE_AM = + DOS_THRESHOLD + "blacklist-increase-am"; + public static final int DEFAULT_DOS_THREASHOLD_BLACKLIST_INCREASE_AM = 100; + + public static final String DOS_THREASHOLD_BLACKLIST_DECREASE_AM = + DOS_THRESHOLD + "blacklist-decrease-am"; + public static final int DEFAULT_DOS_THREASHOLD_BLACKLIST_DECREASE_AM = 100; + + public static final String DOS_THREASHOLD_RESOURCE_INCREASE_AM = + DOS_THRESHOLD + "resource-increase-am"; + public static final int DEFAULT_DOS_THREASHOLD_RESOURCE_INCREASE_AM = 100; + + public static final String DOS_THREASHOLD_RESOURCE_DECREASE_AM = + DOS_THRESHOLD + "resource-decrease-am"; + public static final int DEFAULT_DOS_THREASHOLD_RESOURCE_DECREASE_AM = 100; + + public static final String DOS_THREASHOLD_RESOURCE_NUM_CONTAINERS = + DOS_THRESHOLD + "resource-decrease-am"; + public static final int DEFAULT_DOS_THREASHOLD_RESOURCE_NUM_CONTAINERS = 1000; + + public static final String DOS_LIMIT = DOS_PREFIX + "limit."; + + public static final String DOS_LIMIT_LENGTH_HOST = + DOS_LIMIT + "length-host"; + public static final int DEFAULT_DOS_LIMIT_LENGTH_HOST = 150; + + public static final String DOS_LIMIT_RPC_PORT = DOS_LIMIT + "rpc-port"; + public static final int DEFAULT_DOS_LIMIT_RPC_PORT = 10000; + + public static final String DOS_LIMIT_LENGTH_DIAGNOSTIC = + DOS_LIMIT + "length-diagnostic"; + public static final int DEFAULT_DOS_LIMIT_LENGTH_DIAGNOSTIC = 100; + + public static final String DOS_LIMIT_MIN_PRIORITY = + DOS_LIMIT + "min-priority"; + public static final int DEFAULT_DOS_LIMIT_MIN_PRIORITY = 0; + + public static final String DOS_LIMIT_MAX_PRIORITY = + DOS_LIMIT + "max-priority"; + public static final int DEFAULT_DOS_LIMIT_MAX_PRIORITY = 100; + + public static final String DOS_LIMIT_LENGTH_RESOURCENAME = + DOS_LIMIT + "length-host"; + public static final int DEFAULT_DOS_LIMIT_LENGTH_RESOURCENAME = 150; + + public static final String DOS_LIMIT_LENGTH_NODELABEL = + DOS_LIMIT + "length-nodelabel"; + public static final int DEFAULT_DOS_LIMIT_LENGTH_NODELABEL = 200; + + public static final String DOS_LIMIT_MAX_MEMORY = DOS_LIMIT + "max-memory"; + public static final int DEFAULT_DOS_LIMIT_MAX_MEMORY = 100 * 1024; + + public static final String DOS_LIMIT_MAX_VCORES = DOS_LIMIT + "max-vcores"; + public static final int DEFAULT_DOS_LIMIT_MAX_VCORES = 24; + + public static final String DOS_MAX_SIZE = DOS_PREFIX + "max-size."; + + public static final String DOS_MAX_SIZE_ASKLIST = DOS_MAX_SIZE + "ask-list"; + public static final int DEFAULT_DOS_MAX_SIZE_ASKLIST = 100; + + public static final String DOS_MAX_SIZE_RELEASELIST = + DOS_MAX_SIZE + "release-list"; + public static final int DEFAULT_DOS_MAX_SIZE_RELEASELIST = 100; + + public static final String DOS_MAX_SIZE_BLACKLIST = + DOS_MAX_SIZE + "black-list"; + public static final int DEFAULT_DOS_MAX_SIZE_BLACKLIST = 100; + + public static final String DOS_MAX_SIZE_CHANGELIST = + DOS_MAX_SIZE + "change-list"; + public static final int DEFAULT_DOS_MAX_SIZE_CHANGELIST = 100; + + public static final String DOS_MAX_SIZE_CONTAINERS = + DOS_MAX_SIZE + "containers"; + public static final int DEFAULT_DOS_MAX_SIZE_CONTAINERS = 200; + + public static final String DOS_MAX_REQUESTS = DOS_PREFIX + "max-requests."; + + public static final String DOS_MAX_REQUESTS_REGISTER = + DOS_MAX_REQUESTS + "register"; + public static final int DEFAULT_DOS_MAX_REQUESTS_REGISTER = 50; + + public static final String DOS_MAX_REQUESTS_ALLOCATE = + DOS_MAX_REQUESTS + "allocate"; + public static final int DEFAULT_DOS_MAX_REQUESTS_ALLOCATE = 1000; + + public static final String DOS_MAX_REQUESTS_FINISH = + DOS_MAX_REQUESTS + "finish"; + public static final int DEFAULT_DOS_MAX_REQUESTS_FINISH = 50; + + public static final String DOS_MAX_REQUESTS_ASKLIST = + DOS_MAX_REQUESTS + "finish";; + public static final int DEFAULT_DOS_MAX_REQUESTS_ASKLIST = 1000; + + public static final String DOS_MAX_REQUESTS_RELEASELIST = + DOS_MAX_REQUESTS + "releaselist"; + public static final int DEFAULT_DOS_MAX_REQUESTS_RELEASELIST = 1000; + + public static final String DOS_MAX_REQUESTS_BLACKLIST = + DOS_MAX_REQUESTS + "blacklist"; + public static final int DEFAULT_DOS_MAX_REQUESTS_BLACKLIST = 100; + + public static final String DOS_MAX_REQUESTS_CHANGELIST = + DOS_MAX_REQUESTS + "changelist"; + public static final int DEFAULT_DOS_MAX_REQUESTS_CHANGELIST = 100; + + public static final String DOS_MAX_REQUESTS_CONTAINERS = + DOS_MAX_REQUESTS + "containers"; + public static final int DEFAULT_DOS_MAX_REQUESTS_CONTAINERS = 100000; public YarnConfiguration() { super(); diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/dos/PreventDoSAttackException.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/dos/PreventDoSAttackException.java new file mode 100644 index 0000000..5f1b39e --- /dev/null +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/dos/PreventDoSAttackException.java @@ -0,0 +1,46 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.server.dos; + +import org.apache.hadoop.classification.InterfaceAudience.Public; +import org.apache.hadoop.classification.InterfaceStability.Stable; + +@Public +@Stable +public class PreventDoSAttackException extends Exception { + + private static final long serialVersionUID = 1L; + + public PreventDoSAttackException() { + super(); + } + + public PreventDoSAttackException(String message) { + super(message); + } + + public PreventDoSAttackException(Throwable cause) { + super(cause); + } + + public PreventDoSAttackException(String message, Throwable cause) { + super(message, cause); + } + +} \ No newline at end of file diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/dos/SlidingWindowCounter.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/dos/SlidingWindowCounter.java new file mode 100644 index 0000000..601e987 --- /dev/null +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/dos/SlidingWindowCounter.java @@ -0,0 +1,131 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.yarn.server.dos; + +import java.io.Serializable; +import java.util.Map; + +/** + * This class counts objects in a sliding window fashion. + *

+ * It is designed 1) to give multiple "producer" threads write access to the + * counter, i.e. being able to increment counts of objects, and 2) to give a + * single "consumer" thread (e.g. {@link PeriodicSlidingWindowCounter}) read + * access to the counter. Whenever the consumer thread performs a read + * operation, this class will advance the head slot of the sliding window + * counter. This means that the consumer thread indirectly controls where writes + * of the producer threads will go to. Also, by itself this class will not + * advance the head slot. + *

+ * A note for analyzing data based on a sliding window count: During the initial + * windowLengthInSlots iterations, this sliding window counter will + * always return object counts that are equal or greater than in the previous + * iteration. This is the effect of the counter "loading up" at the very start + * of its existence. Conceptually, this is the desired behavior. + *

+ * To give an example, using a counter with 5 slots which for the sake of this + * example represent 1 minute of time each: + *

+ * + *

+ * {@code
+ * Sliding window counts of an object X over time
+ *
+ * Minute (timeline):
+ * 1    2   3   4   5   6   7   8
+ *
+ * Observed counts per minute:
+ * 1    1   1   1   0   0   0   0
+ *
+ * Counts returned by counter:
+ * 1    2   3   4   4   3   2   1
+ * }
+ * 
+ *

+ * As you can see in this example, for the first + * windowLengthInSlots (here: the first five minutes) the counter + * will always return counts equal or greater than in the previous iteration (1, + * 2, 3, 4, 4). This initial load effect needs to be accounted for whenever you + * want to perform analyses such as trending topics; otherwise your analysis + * algorithm might falsely identify the object to be trending as the counter + * seems to observe continuously increasing counts. Also, note that during the + * initial load phase every object will exhibit increasing counts. + *

+ * On a high-level, the counter exhibits the following behavior: If you asked + * the example counter after two minutes, + * "how often did you count the object during the past five minutes?", then it + * should reply "I have counted it 2 times in the past five minutes", implying + * that it can only account for the last two of those five minutes because the + * counter was not running before that time. + * + * @param The type of those objects we want to count. + */ +public class SlidingWindowCounter implements Serializable { + + private static final long serialVersionUID = -2645063988768785810L; + + private SlotBasedCounter objCounter; + protected int headSlot; + protected int tailSlot; + protected int windowLengthInSlots; + + public SlidingWindowCounter(int windowLengthInSlots) { + if (windowLengthInSlots < 2) { + throw new IllegalArgumentException( + "Window length in slots must be at least two (you requested " + + windowLengthInSlots + ")"); + } + this.windowLengthInSlots = windowLengthInSlots; + this.objCounter = new SlotBasedCounter(this.windowLengthInSlots); + + this.headSlot = 0; + this.tailSlot = slotAfter(headSlot); + } + + public void incrementCount(T obj) { + objCounter.incrementCount(obj, headSlot); + } + + /** + * Return the current (total) counts of all tracked objects, then advance the + * window. + *

+ * Whenever this method is called, we consider the counts of the current + * sliding window to be available to and successfully processed "upstream" + * (i.e. by the caller). Knowing this we will start counting any subsequent + * objects within the next "chunk" of the sliding window. + * + * @return The current (total) counts of all tracked objects. + */ + public Map getCountsThenAdvanceWindow() { + Map counts = objCounter.getCounts(); + objCounter.wipeZeros(); + objCounter.wipeSlot(tailSlot); + advanceHead(); + return counts; + } + + private void advanceHead() { + headSlot = tailSlot; + tailSlot = slotAfter(tailSlot); + } + + private int slotAfter(int slot) { + return (slot + 1) % windowLengthInSlots; + } +} diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/dos/SlidingWindowCounterDos.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/dos/SlidingWindowCounterDos.java new file mode 100644 index 0000000..fb4a2cd --- /dev/null +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/dos/SlidingWindowCounterDos.java @@ -0,0 +1,49 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.yarn.server.dos; + +public final class SlidingWindowCounterDos extends SlidingWindowCounter { + + protected SlotBasedCounterDoS objCounter; + + private static final long serialVersionUID = -2645063988768785810L; + + public SlidingWindowCounterDos(int windowLengthInSlots) { + super(windowLengthInSlots); + objCounter = new SlotBasedCounterDoS(this.windowLengthInSlots); + } + + public void setThreshold(T obj, long threshold) { + objCounter.setThreshold(obj, threshold); + } + + public void increaseCount(T obj, long amount) + throws PreventDoSAttackException { + objCounter.increaseCount(obj, headSlot, amount); + } + + public void incrementCountRequest(T obj) throws PreventDoSAttackException { + super.incrementCount(obj); + checkThreshold(obj); + } + + private void checkThreshold(T obj) throws PreventDoSAttackException { + objCounter.checkThreshold(obj, headSlot); + } + +} diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/dos/SlotBasedCounter.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/dos/SlotBasedCounter.java new file mode 100644 index 0000000..4606607 --- /dev/null +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/dos/SlotBasedCounter.java @@ -0,0 +1,120 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.yarn.server.dos; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +/** + * This class provides per-slot counts of the occurrences of objects. + *

+ * It can be used, for instance, as a building block for implementing sliding + * window counting of objects. + * + * @param The type of those objects we want to count. + */ +public class SlotBasedCounter implements Serializable { + + private static final long serialVersionUID = 4858185737378394432L; + + protected final Map objToCounts = new HashMap(); + protected final int numSlots; + + public SlotBasedCounter(int numSlots) { + if (numSlots <= 0) { + throw new IllegalArgumentException( + "Number of slots must be greater than zero (you requested " + numSlots + + ")"); + } + this.numSlots = numSlots; + } + + public void incrementCount(T obj, int slot) { + long[] counts = objToCounts.get(obj); + if (counts == null) { + counts = new long[this.numSlots]; + objToCounts.put(obj, counts); + } + counts[slot]++; + } + + public long getCount(T obj, int slot) { + long[] counts = objToCounts.get(obj); + if (counts == null) { + return 0; + } else { + return counts[slot]; + } + } + + public Map getCounts() { + Map result = new HashMap(); + for (T obj : objToCounts.keySet()) { + result.put(obj, computeTotalCount(obj)); + } + return result; + } + + private long computeTotalCount(T obj) { + long[] curr = objToCounts.get(obj); + long total = 0; + for (long l : curr) { + total += l; + } + return total; + } + + /** + * Reset the slot count of any tracked objects to zero for the given slot. + * + * @param slot + */ + public void wipeSlot(int slot) { + for (T obj : objToCounts.keySet()) { + resetSlotCountToZero(obj, slot); + } + } + + private void resetSlotCountToZero(T obj, int slot) { + long[] counts = objToCounts.get(obj); + counts[slot] = 0; + } + + private boolean shouldBeRemovedFromCounter(T obj) { + return computeTotalCount(obj) == 0; + } + + /** + * Remove any object from the counter whose total count is zero (to free up + * memory). + */ + public void wipeZeros() { + Set objToBeRemoved = new HashSet(); + for (T obj : objToCounts.keySet()) { + if (shouldBeRemovedFromCounter(obj)) { + objToBeRemoved.add(obj); + } + } + for (T obj : objToBeRemoved) { + objToCounts.remove(obj); + } + } +} diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/dos/SlotBasedCounterDoS.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/dos/SlotBasedCounterDoS.java new file mode 100644 index 0000000..d251a45 --- /dev/null +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/dos/SlotBasedCounterDoS.java @@ -0,0 +1,61 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.yarn.server.dos; + +import java.util.HashMap; +import java.util.Map; + +public class SlotBasedCounterDoS extends SlotBasedCounter { + + private static final long serialVersionUID = -3578001566939303847L; + private Map dosThreshold = new HashMap(); + + public SlotBasedCounterDoS(int numSlots) { + super(numSlots); + } + + public void increaseCount(T obj, int slot, long amount) + throws PreventDoSAttackException { + long[] counts = objToCounts.get(obj); + if (counts == null) { + counts = new long[this.numSlots]; + objToCounts.put(obj, counts); + } + counts[slot] += amount; + checkThreshold(obj, counts[slot]); + } + + public void setThreshold(T obj, long threshold) { + dosThreshold.put(obj, threshold); + } + + public void checkThreshold(T obj, int slot) throws PreventDoSAttackException { + long[] counts = objToCounts.get(obj); + if (counts == null) { + return; + } + checkThreshold(obj, counts[slot]); + } + + private void checkThreshold(T obj, long count) + throws PreventDoSAttackException { + if (dosThreshold.get(obj) < count) { + throw new PreventDoSAttackException(); + } + } +} diff --git hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/amrmproxy/AMPDoSRequestInterceptor.java hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/amrmproxy/AMPDoSRequestInterceptor.java new file mode 100644 index 0000000..fe833b6 --- /dev/null +++ hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/amrmproxy/AMPDoSRequestInterceptor.java @@ -0,0 +1,641 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.server.nodemanager.amrmproxy; + +import java.io.IOException; +import java.util.Timer; +import java.util.TimerTask; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.security.UserGroupInformation; +import org.apache.hadoop.yarn.api.protocolrecords.AllocateRequest; +import org.apache.hadoop.yarn.api.protocolrecords.AllocateResponse; +import org.apache.hadoop.yarn.api.protocolrecords.FinishApplicationMasterRequest; +import org.apache.hadoop.yarn.api.protocolrecords.FinishApplicationMasterResponse; +import org.apache.hadoop.yarn.api.protocolrecords.RegisterApplicationMasterRequest; +import org.apache.hadoop.yarn.api.protocolrecords.RegisterApplicationMasterResponse; +import org.apache.hadoop.yarn.api.records.ContainerId; +import org.apache.hadoop.yarn.api.records.ContainerResourceChangeRequest; +import org.apache.hadoop.yarn.api.records.Priority; +import org.apache.hadoop.yarn.api.records.ResourceBlacklistRequest; +import org.apache.hadoop.yarn.api.records.ResourceRequest; +import org.apache.hadoop.yarn.conf.YarnConfiguration; +import org.apache.hadoop.yarn.exceptions.YarnException; +import org.apache.hadoop.yarn.exceptions.YarnRuntimeException; +import org.apache.hadoop.yarn.server.dos.PreventDoSAttackException; +import org.apache.hadoop.yarn.server.dos.SlidingWindowCounterDos; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public final class AMPDoSRequestInterceptor extends AbstractRequestInterceptor { + + private static final Logger LOG = + LoggerFactory.getLogger(AMPDoSRequestInterceptor.class); + private UserGroupInformation user = null; + Timer timer = null; + + private SlidingWindowCounterDos slidingWindow = null; + + // Threshold and limit for singular request + + private int registerHostMaxLength; + private int validRPCPort; + private int allocateDiagnosticMaxLength; + private int lastResponseId = 0; + private Priority minPriority; + private Priority maxPriority; + private int allocateResourceNameMaxLength; + private int allocateNodeLabelMaxLength; + private int maxMemory; + private int maxVCores; + private int askListMaxSize; + private int releaseListMaxSize; + private int blackListMaxSize; + private int changeListMaxSize; + private int containersMaxSize; + + // Lifetime counters + + private int lifetimeCounterRegisterRequest = 0; + private int lifetimeCounterAllocateRequest = 0; + private int lifetimeCounterFinishRequest = 0; + private int lifetimeCounterAskList = 0; + private int lifetimeCounterReleaseList = 0; + private int lifetimeCounterBlackList = 0; + private int lifetimeCounterChangeList = 0; + private int lifetimeCounterContainers = 0; + + // Limit for lifetime counters + + private int maxRegisterRequests; + private int maxAllocateRequests; + private int maxFinishRequests; + private int maxAskList; + private int maxReleaseList; + private int maxBlackList; + private int maxChangeList; + private int maxContainers; + + @Override + public void init(AMRMProxyApplicationContext appContext) { + super.init(appContext); + final Configuration conf = this.getConf(); + try { + user = UserGroupInformation.createProxyUser( + appContext.getApplicationAttemptId().toString(), + UserGroupInformation.getCurrentUser()); + user.addToken(appContext.getAMRMToken()); + + } catch (IOException e) { + String message = + "Error while creating of RM app master service proxy for attemptId:" + + appContext.getApplicationAttemptId().toString(); + if (user != null) { + message += ", user: " + user; + } + + LOG.info(message); + throw new YarnRuntimeException(message, e); + } catch (Exception e) { + throw new YarnRuntimeException(e); + } + + createSlidingWindow(conf); + setLimits(conf); + setLifetimeLimits(conf); + } + + private void createSlidingWindow(Configuration conf) { + + // Sliding window configurations + + int slidingWindowSize = + conf.getInt(YarnConfiguration.DOS_SLIDING_WINDOW_SIZE, + YarnConfiguration.DEFAULT_DOS_SLIDING_WINDOW_SIZE); + int slidingWindowInterval = + conf.getInt(YarnConfiguration.DOS_SLIDING_WINDOW_ADVANCE_TIME_SEC, + YarnConfiguration.DEFAULT_DOS_SLIDING_WINDOW_ADVANCE_TIME_SEC); + + slidingWindow = new SlidingWindowCounterDos(slidingWindowSize); + + // Sliding window configuration for Register AM + + slidingWindow.setThreshold(RegisterApplicationMasterRequest.class.getName(), + conf.getInt(YarnConfiguration.DOS_THREASHOLD_REGISTER_AM, + YarnConfiguration.DEFAULT_DOS_THREASHOLD_REGISTER_AM)); + + // Sliding window configuration for Finish AM + + slidingWindow.setThreshold(RegisterApplicationMasterRequest.class.getName(), + conf.getInt(YarnConfiguration.DOS_THREASHOLD_FINISH_AM, + YarnConfiguration.DEFAULT_DOS_THREASHOLD_FINISH_AM)); + + // Sliding window configuration for Allocate AM + + slidingWindow.setThreshold(RegisterApplicationMasterRequest.class.getName(), + conf.getInt(YarnConfiguration.DOS_THREASHOLD_ALLOCATE_AM, + YarnConfiguration.DEFAULT_DOS_THREASHOLD_ALLOCATE_AM)); + + slidingWindow.setThreshold(ResourceRequest.class.getName(), + conf.getInt(YarnConfiguration.DOS_THREASHOLD_ASKLIST_AM, + YarnConfiguration.DEFAULT_DOS_THREASHOLD_ASKLIST_AM)); + slidingWindow.setThreshold(ContainerId.class.getName(), + conf.getInt(YarnConfiguration.DOS_THREASHOLD_RELEASE_AM, + YarnConfiguration.DEFAULT_DOS_THREASHOLD_RELEASE_AM)); + slidingWindow.setThreshold( + ResourceBlacklistRequest.class.getName() + "Increase", + conf.getInt(YarnConfiguration.DOS_THREASHOLD_BLACKLIST_INCREASE_AM, + YarnConfiguration.DEFAULT_DOS_THREASHOLD_BLACKLIST_INCREASE_AM)); + slidingWindow.setThreshold( + ResourceBlacklistRequest.class.getName() + "Decrease", + conf.getInt(YarnConfiguration.DOS_THREASHOLD_BLACKLIST_DECREASE_AM, + YarnConfiguration.DEFAULT_DOS_THREASHOLD_BLACKLIST_DECREASE_AM)); + slidingWindow.setThreshold( + ContainerResourceChangeRequest.class.getName() + "Increase", + conf.getInt(YarnConfiguration.DOS_THREASHOLD_RESOURCE_INCREASE_AM, + YarnConfiguration.DEFAULT_DOS_THREASHOLD_RESOURCE_INCREASE_AM)); + slidingWindow.setThreshold( + ContainerResourceChangeRequest.class.getName() + "Decrease", + conf.getInt(YarnConfiguration.DOS_THREASHOLD_RESOURCE_DECREASE_AM, + YarnConfiguration.DEFAULT_DOS_THREASHOLD_RESOURCE_DECREASE_AM)); + slidingWindow.setThreshold( + ResourceRequest.class.getName() + "NumContainers", + conf.getInt(YarnConfiguration.DOS_THREASHOLD_RESOURCE_NUM_CONTAINERS, + YarnConfiguration.DEFAULT_DOS_THREASHOLD_RESOURCE_NUM_CONTAINERS)); + + timer = new Timer(); + timer.scheduleAtFixedRate(new SlidingWindowsAdvanceTask(), + slidingWindowInterval * 1000, slidingWindowInterval * 1000); + } + + private void setLimits(Configuration conf) { + this.registerHostMaxLength = + conf.getInt(YarnConfiguration.DOS_LIMIT_LENGTH_HOST, + YarnConfiguration.DEFAULT_DOS_LIMIT_LENGTH_HOST); + this.validRPCPort = conf.getInt(YarnConfiguration.DOS_LIMIT_RPC_PORT, + YarnConfiguration.DEFAULT_DOS_LIMIT_RPC_PORT); + this.allocateDiagnosticMaxLength = + conf.getInt(YarnConfiguration.DOS_LIMIT_LENGTH_DIAGNOSTIC, + YarnConfiguration.DEFAULT_DOS_LIMIT_LENGTH_DIAGNOSTIC); + this.minPriority = Priority + .newInstance(conf.getInt(YarnConfiguration.DOS_LIMIT_MIN_PRIORITY, + YarnConfiguration.DEFAULT_DOS_LIMIT_MIN_PRIORITY)); + this.maxPriority = Priority + .newInstance(conf.getInt(YarnConfiguration.DOS_LIMIT_MAX_PRIORITY, + YarnConfiguration.DEFAULT_DOS_LIMIT_MAX_PRIORITY)); + this.allocateResourceNameMaxLength = + conf.getInt(YarnConfiguration.DOS_LIMIT_LENGTH_RESOURCENAME, + YarnConfiguration.DEFAULT_DOS_LIMIT_LENGTH_RESOURCENAME); + this.allocateNodeLabelMaxLength = + conf.getInt(YarnConfiguration.DOS_LIMIT_LENGTH_NODELABEL, + YarnConfiguration.DEFAULT_DOS_LIMIT_LENGTH_NODELABEL); + this.maxVCores = conf.getInt(YarnConfiguration.DOS_LIMIT_MAX_VCORES, + YarnConfiguration.DEFAULT_DOS_LIMIT_MAX_VCORES); + this.maxMemory = conf.getInt(YarnConfiguration.DOS_LIMIT_MAX_MEMORY, + YarnConfiguration.DEFAULT_DOS_LIMIT_MAX_MEMORY); + this.askListMaxSize = conf.getInt(YarnConfiguration.DOS_MAX_SIZE_ASKLIST, + YarnConfiguration.DEFAULT_DOS_MAX_SIZE_ASKLIST); + this.releaseListMaxSize = + conf.getInt(YarnConfiguration.DOS_MAX_SIZE_RELEASELIST, + YarnConfiguration.DEFAULT_DOS_MAX_SIZE_RELEASELIST); + this.blackListMaxSize = + conf.getInt(YarnConfiguration.DOS_MAX_SIZE_BLACKLIST, + YarnConfiguration.DEFAULT_DOS_MAX_SIZE_BLACKLIST); + this.changeListMaxSize = + conf.getInt(YarnConfiguration.DOS_MAX_SIZE_CHANGELIST, + YarnConfiguration.DEFAULT_DOS_MAX_SIZE_CHANGELIST); + this.containersMaxSize = + conf.getInt(YarnConfiguration.DOS_MAX_SIZE_CONTAINERS, + YarnConfiguration.DEFAULT_DOS_MAX_SIZE_CONTAINERS); + } + + private void setLifetimeLimits(Configuration conf) { + this.maxRegisterRequests = + conf.getInt(YarnConfiguration.DOS_MAX_REQUESTS_REGISTER, + YarnConfiguration.DEFAULT_DOS_MAX_REQUESTS_REGISTER); + this.maxAllocateRequests = + conf.getInt(YarnConfiguration.DOS_MAX_REQUESTS_ALLOCATE, + YarnConfiguration.DEFAULT_DOS_MAX_REQUESTS_ALLOCATE); + this.maxFinishRequests = + conf.getInt(YarnConfiguration.DOS_MAX_REQUESTS_FINISH, + YarnConfiguration.DEFAULT_DOS_MAX_REQUESTS_FINISH); + this.maxAskList = conf.getInt(YarnConfiguration.DOS_MAX_REQUESTS_ASKLIST, + YarnConfiguration.DEFAULT_DOS_MAX_REQUESTS_ASKLIST); + this.maxReleaseList = + conf.getInt(YarnConfiguration.DOS_MAX_REQUESTS_RELEASELIST, + YarnConfiguration.DEFAULT_DOS_MAX_REQUESTS_RELEASELIST); + this.maxBlackList = + conf.getInt(YarnConfiguration.DOS_MAX_REQUESTS_BLACKLIST, + YarnConfiguration.DEFAULT_DOS_MAX_REQUESTS_BLACKLIST); + this.maxChangeList = + conf.getInt(YarnConfiguration.DOS_MAX_REQUESTS_CHANGELIST, + YarnConfiguration.DEFAULT_DOS_MAX_REQUESTS_CHANGELIST); + this.maxContainers = + conf.getInt(YarnConfiguration.DOS_MAX_REQUESTS_CONTAINERS, + YarnConfiguration.DEFAULT_DOS_MAX_REQUESTS_CONTAINERS); + } + + private boolean validURL(String url) { + // TODO we can use UrlValidator from Apache Commons + return true; + } + + @Override + public RegisterApplicationMasterResponse registerApplicationMaster( + final RegisterApplicationMasterRequest request) + throws YarnException, IOException { + + // Increase lifetime counter + + lifetimeCounterRegisterRequest++; + + if (lifetimeCounterRegisterRequest >= maxRegisterRequests) { + LOG.error("Possible DoS attack Register"); + // TODO Reject? + } + + // Increase the sliding window counter + + try { + slidingWindow.incrementCountRequest( + RegisterApplicationMasterRequest.class.getName()); + + } catch (PreventDoSAttackException e) { + LOG.error("Possible DoS attack Register"); + // TODO Reject? + } + + // Check String and numerical values + + // Host -> Size of String + if (request.getHost().length() > this.registerHostMaxLength) { + LOG.error("Possible DoS attack Register"); + // TODO String truncated + } + + // RPC Port -> Valid port number + if (request.getRpcPort() > this.validRPCPort) { + LOG.error("Possible DoS attack Register"); + // TODO Reject? + } + + // Tracking URL -> Valid URL expression + if (!validURL(request.getTrackingUrl())) { + LOG.error("Possible DoS attack Register"); + // TODO Reject? + } + + LOG.info("Forwarding registration request to the next interceptor."); + return getNextInterceptor().registerApplicationMaster(request); + } + + @Override + public FinishApplicationMasterResponse finishApplicationMaster( + final FinishApplicationMasterRequest request) + throws YarnException, IOException { + + // Increase lifetime counter + + lifetimeCounterFinishRequest++; + + if (lifetimeCounterFinishRequest >= maxFinishRequests) { + LOG.error("Possible DoS attack Register"); + // TODO Reject? + } + + // Increase the sliding window counter + try { + slidingWindow.incrementCountRequest( + FinishApplicationMasterRequest.class.getName()); + + } catch (PreventDoSAttackException e) { + LOG.info("Possible DoS attack Finish"); + // TODO Reject + } + + // Check String and numerical values + + // Diagnostic -> Size of String + if (request.getDiagnostics().length() > this.allocateDiagnosticMaxLength) { + LOG.error("Possible DoS attack Finish"); + // TODO String truncated + } + + // Tracking URL -> Valid URL expression + if (!validURL(request.getTrackingUrl())) { + LOG.error("Possible DoS attack Finish"); + // TODO Reject? + } + + LOG.info("Forwarding finish application request to the next interceptor"); + return getNextInterceptor().finishApplicationMaster(request); + } + + @Override + public AllocateResponse allocate(final AllocateRequest request) + throws YarnException, IOException { + + // Increase lifetime counter + + lifetimeCounterAllocateRequest++; + + if (lifetimeCounterAllocateRequest >= maxAllocateRequests) { + LOG.error("Possible DoS attack Register"); + // TODO Reject? + } + + // Increase the sliding window counters + + try { + + // Increase the sliding window counter of AllocateRequest + + slidingWindow.incrementCountRequest(AllocateRequest.class.getName()); + + if (request.getAskList() != null) { + + // Check size list of singular request + + if (request.getAskList().size() >= this.askListMaxSize) { + LOG.error("Possible DoS attack Allocate"); + // TODO Reject? + } + + // Increase lifetime counters + + lifetimeCounterAskList += request.getAskList().size(); + + if (lifetimeCounterAskList >= maxAskList) { + LOG.error("Possible DoS attack Register"); + // TODO Reject? + } + + // Increase the sliding window counter of askList + + slidingWindow.increaseCount(ResourceRequest.class.getName(), + request.getAskList().size()); + + for (ResourceRequest rr : request.getAskList()) { + + // Check String and numerical values for ResourceRequest + + // Priority -> Valid priority value + + if (rr.getPriority().compareTo(minPriority) < 0 + || rr.getPriority().compareTo(maxPriority) > 0) { + LOG.error("Possible DoS attack Allocate"); + // TODO Reject? + } + + // ResourceName -> Size of String + + if (rr.getResourceName() + .length() > this.allocateResourceNameMaxLength) { + LOG.error("Possible DoS attack Allocate"); + // TODO String truncated + } + + // Capacity -> Within a range + + int memory = rr.getCapability().getMemory(); + int vcores = rr.getCapability().getVirtualCores(); + + if (memory < 0 || vcores < 0 || memory > this.maxMemory + || vcores > this.maxVCores) { + LOG.error("Possible DoS attack Allocate"); + // TODO Reject? + } + + // Containers -> Within a range + + if (rr.getNumContainers() < 0 + || rr.getNumContainers() > containersMaxSize) { + LOG.error("Possible DoS attack Allocate"); + // TODO Reject? + } + + // Increase lifetime counters + + lifetimeCounterContainers += rr.getNumContainers(); + + if (lifetimeCounterContainers >= maxContainers) { + LOG.error("Possible DoS attack Register"); + // TODO Reject? + } + + // Increase the sliding window counter of NumContainers + + slidingWindow.increaseCount( + ResourceRequest.class.getName() + "NumContainers", + rr.getNumContainers()); + + // Label Expression -> Size of String + + if (rr.getNodeLabelExpression() + .length() > this.allocateNodeLabelMaxLength) { + LOG.error("Possible DoS attack Allocate"); + // TODO String truncated + } + } + } + + if (request.getReleaseList() != null) { + + // Check size list of singular request + + if (request.getReleaseList().size() >= this.releaseListMaxSize) { + LOG.error("Possible DoS attack Allocate"); + // TODO Reject? + } + + // Increase the sliding window counter of ReleaseList + + slidingWindow.increaseCount(ContainerId.class.getName(), + request.getReleaseList().size()); + + // Increase lifetime counters + + lifetimeCounterReleaseList += request.getReleaseList().size(); + + if (lifetimeCounterReleaseList >= maxReleaseList) { + LOG.error("Possible DoS attack Register"); + // TODO Reject? + } + + } + + if (request.getResourceBlacklistRequest() != null) { + if (request.getResourceBlacklistRequest() + .getBlacklistAdditions() != null) { + + // Check size list of singular request + + if (request.getResourceBlacklistRequest().getBlacklistAdditions() + .size() >= this.blackListMaxSize) { + LOG.error("Possible DoS attack Allocate"); + // TODO Reject? + } + + // Increase the sliding window counter of BlackList Request + + slidingWindow.increaseCount( + ResourceBlacklistRequest.class.getName() + "Increase", + request.getResourceBlacklistRequest().getBlacklistAdditions() + .size()); + + // Increase lifetime counters + + lifetimeCounterBlackList += request.getResourceBlacklistRequest() + .getBlacklistAdditions().size(); + + if (lifetimeCounterBlackList >= maxBlackList) { + LOG.error("Possible DoS attack Register"); + // TODO Reject? + } + } + if (request.getResourceBlacklistRequest() + .getBlacklistRemovals() != null) { + + // Check size list of singular request + + if (request.getResourceBlacklistRequest().getBlacklistRemovals() + .size() >= this.blackListMaxSize) { + LOG.error("Possible DoS attack Allocate"); + // TODO Reject? + } + + // Increase the sliding window counter of BlackList Request + + slidingWindow.increaseCount( + ResourceBlacklistRequest.class.getName() + "Decrease", request + .getResourceBlacklistRequest().getBlacklistRemovals().size()); + + // Increase lifetime counters + + lifetimeCounterBlackList += request.getResourceBlacklistRequest() + .getBlacklistRemovals().size(); + + if (lifetimeCounterBlackList >= maxBlackList) { + LOG.error("Possible DoS attack Register"); + // TODO Reject? + } + + } + } + + if (request.getIncreaseRequests() != null) { + + // Check size list of singular request + + if (request.getIncreaseRequests().size() >= this.changeListMaxSize) { + LOG.error("Possible DoS attack Allocate"); + // TODO Reject? + } + + // Increase the sliding window counter of IncreaseRequest + + slidingWindow.increaseCount( + ContainerResourceChangeRequest.class.getName() + "Increase", + request.getIncreaseRequests().size()); + + // Increase lifetime counter + + lifetimeCounterChangeList += request.getIncreaseRequests().size(); + + if (lifetimeCounterChangeList >= maxChangeList) { + LOG.error("Possible DoS attack Register"); + // TODO Reject? + } + + } + + if (request.getDecreaseRequests() != null) { + + // Check size list of singular request + + if (request.getDecreaseRequests().size() >= this.changeListMaxSize) { + LOG.error("Possible DoS attack Allocate"); + // TODO Reject? + } + + // Increase the sliding window counter of DecreaseRequest + + slidingWindow.increaseCount( + ContainerResourceChangeRequest.class.getName() + "Decrease", + request.getDecreaseRequests().size()); + + // Increase lifetime counter + + lifetimeCounterChangeList += request.getDecreaseRequests().size(); + + if (lifetimeCounterChangeList >= maxChangeList) { + LOG.error("Possible DoS attack Register"); + // TODO Reject? + } + } + + } catch (PreventDoSAttackException e) { + LOG.info("Possible DoS attack Allocate"); + // TODO Reject? + } + + // Check String and numerical values + + // ResponseId -> Greater than previous one + + if (request.getResponseId() > lastResponseId) { + lastResponseId = request.getResponseId(); + } else { + LOG.error("Possible DoS attack Allocate"); + // TODO Reject? + } + + // App Progress -> Between 0 and 100 + + if (request.getProgress() < 0 || request.getProgress() > 100) { + LOG.error("Possible DoS attack Allocate"); + // TODO Reject? + } + + LOG.info("Forwarding allocate request to the next interceptor."); + return getNextInterceptor().allocate(request); + } + + @Override + public void setNextInterceptor(RequestInterceptor next) { + throw new YarnRuntimeException( + "setNextInterceptor is being called on DefaultRequestInterceptor," + + "which should be the last one in the chain " + + "Check if the interceptor pipeline configuration is correct"); + } + + private class SlidingWindowsAdvanceTask extends TimerTask { + + public void run() { + slidingWindow.getCountsThenAdvanceWindow(); + } + } + + @Override + public void shutdown() { + if (timer != null) { + timer.cancel(); + } + } +}