From 9b51760b25784657df62ea6c66044cb08f9e2d61 Mon Sep 17 00:00:00 2001
From: Dave Crighton <davicrig@li-4af0294c-2cb5-11b2-a85c-9b674f3accb4.ibm.com>
Date: Tue, 25 Jul 2023 11:59:49 +0100
Subject: [PATCH] defensive code for rogue packets during server restart

---
 .../authenticator/SaslClientAuthenticator.java       | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java
index e502f80d5c..140a88dcf7 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java
@@ -135,6 +135,8 @@ public class SaslClientAuthenticator implements Authenticator {
     public static boolean isReserved(int correlationId) {
         return correlationId >= MIN_RESERVED_CORRELATION_ID;
     }
+    
+    private static final int MAX_SIZE = 10*1024*1024;
 
     private final Subject subject;
     private final String servicePrincipal;
@@ -236,6 +238,7 @@ public class SaslClientAuthenticator implements Authenticator {
      */
     @SuppressWarnings("fallthrough")
     public void authenticate() throws IOException {
+    	try {
         if (netOutBuffer != null && !flushNetOutBufferAndUpdateInterestOps())
             return;
 
@@ -321,6 +324,11 @@ public class SaslClientAuthenticator implements Authenticator {
                 // Should never get here since exception would have been propagated earlier
                 throw new IllegalStateException("SASL handshake has already failed");
         }
+    	} catch(OutOfMemoryError e) {
+    		//This is almost certainly caused by unexpected packets inflating the buffer size
+    		//so protect against this by failing.
+    		 throw new IllegalSaslStateException("Memory allocation failure during SASL handshake");
+    	}
     }
 
     private void sendHandshakeRequest(short version) throws IOException {
@@ -471,7 +479,7 @@ public class SaslClientAuthenticator implements Authenticator {
     }
 
     private byte[] receiveResponseOrToken() throws IOException {
-        if (netInBuffer == null) netInBuffer = new NetworkReceive(node);
+        if (netInBuffer == null) netInBuffer = new NetworkReceive(MAX_SIZE, node);
         netInBuffer.readFrom(transportLayer);
         byte[] serverPacket = null;
         if (netInBuffer.complete()) {
@@ -566,7 +574,7 @@ public class SaslClientAuthenticator implements Authenticator {
 
     private AbstractResponse receiveKafkaResponse() throws IOException {
         if (netInBuffer == null)
-            netInBuffer = new NetworkReceive(node);
+            netInBuffer = new NetworkReceive(MAX_SIZE, node);
         NetworkReceive receive = netInBuffer;
         try {
             byte[] responseBytes = receiveResponseOrToken();
-- 
2.41.0

