diff --git a/core/src/main/scala/kafka/network/SocketServer.scala b/core/src/main/scala/kafka/network/SocketServer.scala index 216245d..dcfca3f 100644 --- a/core/src/main/scala/kafka/network/SocketServer.scala +++ b/core/src/main/scala/kafka/network/SocketServer.scala @@ -52,14 +52,14 @@ class SocketServer(val brokerId: Int, def startup() { for(i <- 0 until numProcessorThreads) { processors(i) = new Processor(i, time, maxRequestSize, requestChannel) - Utils.newThread("kafka-processor-%d-%d".format(port, i), processors(i), false).start() + Utils.newThread("kafka-network-thread-%d-%d".format(port, i), processors(i), false).start() } // register the processor threads for notification of responses requestChannel.addResponseListener((id:Int) => processors(id).wakeup()) // start accepting connections this.acceptor = new Acceptor(host, port, processors, sendBufferSize, recvBufferSize) - Utils.newThread("kafka-acceptor", acceptor, false).start() + Utils.newThread("kafka-socket-acceptor", acceptor, false).start() acceptor.awaitStartup info("Started") } @@ -265,6 +265,7 @@ private[kafka] class Processor(val id: Int, } } debug("Closing selector.") + closeAll() swallowError(selector.close()) shutdownComplete() } @@ -314,6 +315,17 @@ private[kafka] class Processor(val id: Int, key.attach(null) swallowError(key.cancel()) } + + /* + * Close all open connections + */ + private def closeAll() { + val iter = this.selector.keys().iterator() + while (iter.hasNext) { + val key = iter.next() + close(key) + } + } /** * Queue up a new connection for reading diff --git a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala index 4ff6f55..c3b1ac4 100644 --- a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala +++ b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala @@ -5,7 +5,7 @@ * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software @@ -31,7 +31,6 @@ import kafka.message.ByteBufferMessageSet import java.nio.channels.SelectionKey import kafka.utils.TestUtils - class SocketServerTest extends JUnitSuite { val server: SocketServer = new SocketServer(0, @@ -52,7 +51,7 @@ class SocketServerTest extends JUnitSuite { outgoing.flush() } - def receiveResponse(socket: Socket): Array[Byte] = { + def receiveResponse(socket: Socket): Array[Byte] = { val incoming = new DataInputStream(socket.getInputStream) val len = incoming.readInt() val response = new Array[Byte](len) @@ -98,7 +97,7 @@ class SocketServerTest extends JUnitSuite { assertEquals(serializedBytes.toSeq, receiveResponse(socket).toSeq) } - @Test(expected=classOf[IOException]) + @Test(expected = classOf[IOException]) def tooBigRequestIsRejected() { val tooManyBytes = new Array[Byte](server.maxRequestSize + 1) new Random().nextBytes(tooManyBytes) @@ -108,22 +107,10 @@ class SocketServerTest extends JUnitSuite { } @Test - def testSocketSelectionKeyState() { + def testNullResponse() { val socket = connect() - val correlationId = -1 - val clientId = SyncProducerConfig.DefaultClientId - val ackTimeoutMs = SyncProducerConfig.DefaultAckTimeoutMs - val ack: Short = 0 - val emptyRequest = - new ProducerRequest(correlationId, clientId, ack, ackTimeoutMs, collection.mutable.Map[TopicAndPartition, ByteBufferMessageSet]()) - - val byteBuffer = ByteBuffer.allocate(emptyRequest.sizeInBytes) - emptyRequest.writeTo(byteBuffer) - byteBuffer.rewind() - val serializedBytes = new Array[Byte](byteBuffer.remaining) - byteBuffer.get(serializedBytes) - - sendRequest(socket, 0, serializedBytes) + val bytes = new Array[Byte](40) + sendRequest(socket, 0, bytes) val request = server.requestChannel.receiveRequest // Since the response is not sent yet, the selection key should not be readable. @@ -135,7 +122,16 @@ class SocketServerTest extends JUnitSuite { Assert.assertTrue( TestUtils.waitUntilTrue( () => { (request.requestKey.asInstanceOf[SelectionKey].interestOps & SelectionKey.OP_READ) == SelectionKey.OP_READ }, - 5000) - ) + 5000)) + } + + @Test(expected = classOf[SocketException]) + def testSocketsCloseOnShutdown() { + // open a connection and then shutdown the server + val socket = connect() + server.shutdown() + // doing a subsequent send should throw an exception as the connection should be closed. + val bytes = new Array[Byte](10) + sendRequest(socket, 0, bytes) } }