Index: core/src/test/scala/unit/kafka/network/SocketServerTest.scala IDEA additional info: Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP <+>UTF-8 =================================================================== --- core/src/test/scala/unit/kafka/network/SocketServerTest.scala (revision ee1267b127f3081db491fa1bf9a287084c324e36) +++ core/src/test/scala/unit/kafka/network/SocketServerTest.scala (revision ) @@ -17,19 +17,21 @@ package kafka.network; -import java.net._ import java.io._ -import org.junit._ -import org.scalatest.junit.JUnitSuite +import java.net.Socket +import java.nio.ByteBuffer +import java.nio.channels.SelectionKey import java.util.Random + import junit.framework.Assert._ -import kafka.producer.SyncProducerConfig import kafka.api.ProducerRequest -import java.nio.ByteBuffer import kafka.common.TopicAndPartition import kafka.message.ByteBufferMessageSet -import java.nio.channels.SelectionKey +import kafka.producer.SyncProducerConfig import kafka.utils.TestUtils +import org.junit._ +import org.scalatest.junit.JUnitSuite + import scala.collection.Map class SocketServerTest extends JUnitSuite { @@ -43,10 +45,12 @@ recvBufferSize = 300000, maxRequestSize = 50, maxConnectionsPerIp = 5, - connectionsMaxIdleMs = 60*1000, + connectionsMaxIdleMs = 10, maxConnectionsPerIpOverrides = Map.empty[String,Int]) server.startup() + def connect(s:SocketServer = server) = new Socket("localhost", s.port) + def sendRequest(socket: Socket, id: Short, request: Array[Byte]) { val outgoing = new DataOutputStream(socket.getOutputStream) outgoing.writeInt(request.length + 2) @@ -73,15 +77,7 @@ channel.sendResponse(new RequestChannel.Response(request.processor, request, send)) } - def connect(s:SocketServer = server) = new Socket("localhost", s.port) - - @After - def cleanup() { - server.shutdown() - } - @Test - def simpleRequest() { - val socket = connect() + def newRequestBytes: Array[Byte] = { val correlationId = -1 val clientId = SyncProducerConfig.DefaultClientId val ackTimeoutMs = SyncProducerConfig.DefaultAckTimeoutMs @@ -94,7 +90,17 @@ byteBuffer.rewind() val serializedBytes = new Array[Byte](byteBuffer.remaining) byteBuffer.get(serializedBytes) + serializedBytes + } + @After + def cleanup() { + server.shutdown() + } + @Test + def simpleRequest() { + val socket = connect() + val serializedBytes = newRequestBytes sendRequest(socket, 0, serializedBytes) processRequest(server.requestChannel) assertEquals(serializedBytes.toSeq, receiveResponse(socket).toSeq) @@ -175,5 +181,27 @@ conn.setSoTimeout(3000) assertEquals(-1, conn.getInputStream.read()) overrideServer.shutdown() + } + + @Test + def testSocketsCloseAfterMaxIdle() { + //do a request response cycle + val socket = connect(server) + val serializedBytes = newRequestBytes + + sendRequest(socket, 0, serializedBytes) + processRequest(server.requestChannel) + receiveResponse(socket) + + // then wait for remaining to max idle time + some more to allow a epoll cycle + Thread.sleep(server.connectionsMaxIdleMs * 15 /10 + 450) + // doing a subsequent send should throw an exception as the connection should be closed + try { + sendRequest(socket, 0, serializedBytes) + sendRequest(socket, 0, serializedBytes) + fail("Expecting I/O error") + } catch { + case (e: IOException) => //good + } } } Index: core/src/main/scala/kafka/network/SocketServer.scala IDEA additional info: Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP <+>UTF-8 =================================================================== --- core/src/main/scala/kafka/network/SocketServer.scala (revision ee1267b127f3081db491fa1bf9a287084c324e36) +++ core/src/main/scala/kafka/network/SocketServer.scala (revision ) @@ -306,7 +306,7 @@ private val newConnections = new ConcurrentLinkedQueue[SocketChannel]() private val connectionsMaxIdleNanos = connectionsMaxIdleMs * 1000 * 1000 private var currentTimeNanos = SystemTime.nanoseconds - private val lruConnections = new util.LinkedHashMap[SelectionKey, Long] + private val lruConnections = new util.LinkedHashMap[SelectionKey, Long](16, .75F, true) private var nextIdleCloseCheckTime = currentTimeNanos + connectionsMaxIdleNanos override def run() {