From 87c18de2f8ac3efa033cb5d0a01befde13d498a3 Mon Sep 17 00:00:00 2001
From: Ewen Cheslack-Postava
Date: Tue, 14 Oct 2014 23:04:55 -0700
Subject: [PATCH] KAFKA-1196 WIP Ensure FetchResponses don't exceed 2GB limit.

---
 core/src/main/scala/kafka/api/FetchResponse.scala | 53 ++++++++++----
 core/src/main/scala/kafka/server/KafkaApis.scala  | 24 ++++---
 .../unit/kafka/integration/PrimitiveApiTest.scala | 84 +++++++++++++++++++++-
 3 files changed, 137 insertions(+), 24 deletions(-)

diff --git a/core/src/main/scala/kafka/api/FetchResponse.scala b/core/src/main/scala/kafka/api/FetchResponse.scala
index 8d085a1..c12cce4 100644
--- a/core/src/main/scala/kafka/api/FetchResponse.scala
+++ b/core/src/main/scala/kafka/api/FetchResponse.scala
@@ -20,11 +20,13 @@ package kafka.api
 import java.nio.ByteBuffer
 import java.nio.channels.GatheringByteChannel
 
-import kafka.common.{TopicAndPartition, ErrorMapping}
+import kafka.common.{KafkaException, TopicAndPartition, ErrorMapping}
 import kafka.message.{MessageSet, ByteBufferMessageSet}
 import kafka.network.{MultiSend, Send}
 import kafka.api.ApiUtils._
 
+import scala.collection.mutable
+
 object FetchResponsePartitionData {
   def readFrom(buffer: ByteBuffer): FetchResponsePartitionData = {
     val error = buffer.getShort
@@ -152,22 +154,45 @@ object FetchResponse {
 
 case class FetchResponse(correlationId: Int,
-                         data: Map[TopicAndPartition, FetchResponsePartitionData])
+                         availableData: Map[TopicAndPartition, FetchResponsePartitionData])
   extends RequestOrResponse() {
 
-  /**
-   * Partitions the data into a map of maps (one for each topic).
-   */
-  lazy val dataGroupedByTopic = data.groupBy(_._1.topic)
+  private def groupByTopic(data: collection.Map[TopicAndPartition, FetchResponsePartitionData]) = data.groupBy(_._1.topic)
 
-  val sizeInBytes =
-    FetchResponse.headerSize +
-    dataGroupedByTopic.foldLeft(0) ((folded, curr) => {
-      val topicData = TopicData(curr._1, curr._2.map {
-        case (topicAndPartition, partitionData) => (topicAndPartition.partition, partitionData)
+  private def partialDataSizeInBytes(dataGroupedByTopic: collection.Map[String, collection.Map[TopicAndPartition, FetchResponsePartitionData]]): Long =
+    FetchResponse.headerSize.toLong +
+    dataGroupedByTopic.foldLeft(0L) ((folded, curr) => {
+      val topicData = TopicData(curr._1, curr._2.map {
+        case (topicAndPartition, partitionData) => (topicAndPartition.partition, partitionData)
+      }.toMap)
+      folded + topicData.sizeInBytes.toLong
       })
-      folded + topicData.sizeInBytes
-    })
+
+  /* The data we can actually send. This filters the total available data down to a subset that we can
+   * fit (serialized) under the 2GB limit, since Ints are used all over the place to specify sizes. Without
+   * this filtering, when too much data is available across multiple topics/partitions and the consumer's
+   * per-TopicAndPartition fetch sizes are not low enough, the aggregate data size can exceed the limit and
+   * make it impossible to make forward progress. Simple randomization is used to avoid starvation.
+   */
+  val (data, dataGroupedByTopic, sizeInBytes, exceededMessageSizeLimit) = {
+    val maximumLength = Int.MaxValue.toLong - 4 /* for FetchResponseSend size */
+    var dataGrouped = groupByTopic(availableData)
+    var computedSize = partialDataSizeInBytes(dataGrouped)
+    if (computedSize <= maximumLength) {
+      (availableData, dataGrouped, computedSize.toInt, false)
+    } else {
+      val result = mutable.Map[TopicAndPartition, FetchResponsePartitionData](availableData.toSeq:_*)
+      val removeOrder = collection.mutable.Queue.apply(scala.util.Random.shuffle(availableData.keys).toSeq:_*)
+      do {
+        val removeKey = removeOrder.dequeue()
+        // Maintain a result, just remove the set of messages that would have been returned
+        result.put(removeKey, result(removeKey).copy(messages=MessageSet.Empty))
+        dataGrouped = groupByTopic(result)
+        computedSize = partialDataSizeInBytes(dataGrouped)
+      } while(computedSize > maximumLength)
+      (result.toMap, dataGrouped, computedSize.toInt, true)
+    }
+  }
 
   /*
    * FetchResponse uses [sendfile](http://man7.org/linux/man-pages/man2/sendfile.2.html)
@@ -217,7 +242,7 @@ class FetchResponseSend(val fetchResponse: FetchResponse) extends Send {
 
   val sends = new MultiSend(fetchResponse.dataGroupedByTopic.toList.map {
     case(topic, data) => new TopicDataSend(TopicData(topic,
-                                                     data.map{case(topicAndPartition, message) => (topicAndPartition.partition, message)}))
+                                                     data.toMap.map{case(topicAndPartition, message) => (topicAndPartition.partition, message)}))
   }) {
     val expectedBytesToWrite = fetchResponse.sizeInBytes - FetchResponse.headerSize
   }
diff --git a/core/src/main/scala/kafka/server/KafkaApis.scala b/core/src/main/scala/kafka/server/KafkaApis.scala
index 85498b4..2802f97 100644
--- a/core/src/main/scala/kafka/server/KafkaApis.scala
+++ b/core/src/main/scala/kafka/server/KafkaApis.scala
@@ -303,16 +303,24 @@ class KafkaApis(val requestChannel: RequestChannel,
    */
   def handleFetchRequest(request: RequestChannel.Request) {
     val fetchRequest = request.requestObj.asInstanceOf[FetchRequest]
-    val dataRead = replicaManager.readMessageSets(fetchRequest)
-    // if the fetch request comes from the follower,
-    // update its corresponding log end offset
+    // Look up requested data
+    val dataRead = replicaManager.readMessageSets(fetchRequest)
+    val response = new FetchResponse(fetchRequest.correlationId, dataRead.mapValues(_.data))
+    val consumedData =
+      if (response.exceededMessageSizeLimit)
+        dataRead.filterKeys(topicAndPartition => response.data.contains(topicAndPartition))
+      else
+        dataRead
+
+    // If the fetch request comes from the follower, update its corresponding log end offset.
+    // Can safely use all output from readMessageSets since these offsets come from the fetch request itself, not from the (possibly trimmed) response
     if(fetchRequest.isFromFollower)
       recordFollowerLogEndOffsets(fetchRequest.replicaId, dataRead.mapValues(_.offset))
 
     // check if this fetch request can be satisfied right away
-    val bytesReadable = dataRead.values.map(_.data.messages.sizeInBytes).sum
-    val errorReadingData = dataRead.values.foldLeft(false)((errorIncurred, dataAndOffset) =>
+    val bytesReadable = consumedData.values.map(_.data.messages.sizeInBytes).sum
+    val errorReadingData = consumedData.values.foldLeft(false)((errorIncurred, dataAndOffset) =>
       errorIncurred || (dataAndOffset.data.error != ErrorMapping.NoError))
     // send the data immediately if 1) fetch request does not want to wait
     //                              2) fetch request does not require any data
@@ -321,10 +329,10 @@ class KafkaApis(val requestChannel: RequestChannel,
     if(fetchRequest.maxWait <= 0 ||
        fetchRequest.numPartitions <= 0 ||
        bytesReadable >= fetchRequest.minBytes ||
-       errorReadingData) {
+       errorReadingData ||
+       response.exceededMessageSizeLimit) {
       debug("Returning fetch response %s for fetch request with correlation id %d to client %s"
-        .format(dataRead.values.map(_.data.error).mkString(","), fetchRequest.correlationId, fetchRequest.clientId))
-      val response = new FetchResponse(fetchRequest.correlationId, dataRead.mapValues(_.data))
+        .format(consumedData.values.map(_.data.error).mkString(","), fetchRequest.correlationId, fetchRequest.clientId))
       requestChannel.sendResponse(new RequestChannel.Response(request, new FetchResponseSend(response)))
     } else {
       debug("Putting fetch request with correlation id %d from client %s into purgatory".format(fetchRequest.correlationId,
diff --git a/core/src/test/scala/unit/kafka/integration/PrimitiveApiTest.scala b/core/src/test/scala/unit/kafka/integration/PrimitiveApiTest.scala
index a5386a0..a314199 100644
--- a/core/src/test/scala/unit/kafka/integration/PrimitiveApiTest.scala
+++ b/core/src/test/scala/unit/kafka/integration/PrimitiveApiTest.scala
@@ -29,14 +29,15 @@
 import org.scalatest.junit.JUnit3Suite
 import scala.collection._
 import kafka.admin.AdminUtils
 import kafka.common.{TopicAndPartition, ErrorMapping, UnknownTopicOrPartitionException, OffsetOutOfRangeException}
-import kafka.utils.{StaticPartitioner, TestUtils, Utils}
+import kafka.utils.{Logging, StaticPartitioner, TestUtils, Utils}
 import kafka.serializer.StringEncoder
 import java.util.Properties
 
 /**
  * End to end tests of the primitive apis against a local server
  */
-class PrimitiveApiTest extends JUnit3Suite with ProducerConsumerTestHarness with ZooKeeperTestHarness {
+
+class PrimitiveApiTest extends JUnit3Suite with ProducerConsumerTestHarness with ZooKeeperTestHarness with Logging {
   val requestHandlerLogger = Logger.getLogger(classOf[KafkaRequestHandler])
   val port = TestUtils.choosePort()
 
@@ -267,6 +268,85 @@
   }
 
   /**
+   * Tests a consumer that requests enough data that each partition's data on its own would fit in a
+   * response, but combined the data would overflow the 2GB limit for the response. In that case, the
+   * server should return only a subset of the data it could have returned.
+   */
+  def overflowReadMessageSet(forcePurgatory: Boolean) {
+    // Total data (topics * messages per topic * message size) must be > 2GB, here 600 MB * 4 topics
+    val messageSize = 32 * 1024
+    val messagesPerTopic = 600 * 32
+    val bytesPerTopic = messagesPerTopic * messageSize
+
+    val topics = List("test1", "test2", "test3", "test4")
+    assertTrue("Total data should be > 2GB", bytesPerTopic.toLong * topics.size > Int.MaxValue)
+    createSimpleTopicsAndAwaitLeader(zkClient, topics)
+
+    val msgBytes = new Array[Byte](messageSize)
+    val msg = new String(msgBytes)
+    for(topic <- topics) {
+      for(i <- 1 to messagesPerTopic)
+        producer.send(new KeyedMessage[String, String](topic, msg))
+    }
+
+    val purgatoryTimeout = 5000
+
+    // First request should give partial results -- 3 topics/partitions fit
+    val initialPartitionMessageCounts = {
+      val builder = new FetchRequestBuilder()
+      if (forcePurgatory)
+        builder.maxWait(purgatoryTimeout).minBytes(Int.MaxValue)
+      for (topic <- topics)
+        builder.addFetch(topic, 0, 0, bytesPerTopic + (1024 * 1024))
+      val request = builder.build()
+      val started = System.currentTimeMillis()
+      val response = consumer.fetch(request)
+      val finished = System.currentTimeMillis()
+
+      val messageCounts = topics.map(topic => topic -> response.messageSet(topic, 0).size).toMap
+      System.out.println("Initial # of messages: %s".format(messageCounts.toString))
+      System.out.println("Retrieval took %d ms".format(finished-started))
+      val initialMessages = messageCounts.values.sum
+      assertEquals("First fetch request should return 3 partitions with all messages", messagesPerTopic * 3, initialMessages)
+      messageCounts
+    }
+
+    // Second request should complete the rest of the messages (all remaining in one topic/partition)
+    val remainingPartitionMessageCounts = {
+      val builder = new FetchRequestBuilder()
+      if (forcePurgatory)
+        builder.maxWait(purgatoryTimeout).minBytes(Int.MaxValue)
+      for (topic <- topics)
+        builder.addFetch(topic, 0, initialPartitionMessageCounts(topic), bytesPerTopic + (1024 * 1024))
+      val request = builder.build()
+      val started = System.currentTimeMillis()
+      val response = consumer.fetch(request)
+      val finished = System.currentTimeMillis()
+
+      val messageCounts = topics.map(topic => topic -> response.messageSet(topic, 0).size).toMap
+      System.out.println("Remaining # of messages: %s".format(messageCounts.toString))
+      System.out.println("Retrieval took %d ms".format(finished-started))
+      val remainingMessages = messageCounts.values.sum
+      assertEquals("Second fetch request should return the last remaining partition of messages", messagesPerTopic, remainingMessages)
+      if (forcePurgatory)
+        assertTrue("Second fetch should take at least as long as the purgatory timeout", finished-started > purgatoryTimeout)
+      messageCounts
+    }
+
+    for(topic <- topics)
+      assertEquals("Each topic should have received exactly its total number of messages", messagesPerTopic, initialPartitionMessageCounts(topic) + remainingPartitionMessageCounts(topic))
+  }
+
+  // Ideally these could be run in the same test, but that's causing OOM errors currently.
+  def testOverflowReadMessageSet() {
+    overflowReadMessageSet(false)
+  }
+
+  def testOverflowReadMessageSetPurgatory() {
+    overflowReadMessageSet(true)
+  }
+
+  /**
    * For testing purposes, just create these topics each with one partition and one replica for
    * which the provided broker should the leader for. Create and wait for broker to lead. Simple.
    */
-- 
2.1.2
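
Reviewer note: to make the trimming behaviour easy to poke at outside the broker, below is a minimal, self-contained sketch of the same idea FetchResponse uses above: keep every requested partition in the result, but empty out randomly chosen partitions until the estimated serialized size fits under the limit. The `TrimSketch` object, `PartitionData` case class, `sizeOf` estimate and the 2000-byte limit are simplified stand-ins invented for this sketch, not the real Kafka types or constants.

```scala
import scala.collection.mutable
import scala.util.Random

object TrimSketch {
  // Simplified stand-in for FetchResponsePartitionData: only the payload size matters here.
  case class PartitionData(bytes: Array[Byte])

  // Stand-in for partialDataSizeInBytes: a fixed per-partition overhead plus the payload length.
  def sizeOf(data: Map[String, PartitionData]): Long =
    data.foldLeft(0L) { case (total, (_, pd)) => total + 64L + pd.bytes.length }

  /** Empty out randomly chosen partitions until the estimated size fits under maxBytes.
    * Returns the (possibly trimmed) data and whether anything had to be dropped. */
  def trim(available: Map[String, PartitionData], maxBytes: Long): (Map[String, PartitionData], Boolean) = {
    if (sizeOf(available) <= maxBytes) {
      (available, false)
    } else {
      val result = mutable.Map(available.toSeq: _*)
      val removeOrder = mutable.Queue(Random.shuffle(available.keys.toSeq): _*)
      // Same shape as the do/while in FetchResponse: drop payloads one partition at a time.
      do {
        val key = removeOrder.dequeue()
        result.put(key, PartitionData(new Array[Byte](0))) // keep the key, drop the payload
      } while (sizeOf(result.toMap) > maxBytes)
      (result.toMap, true)
    }
  }

  def main(args: Array[String]): Unit = {
    val available = (1 to 4).map(i => s"topic$i-0" -> PartitionData(new Array[Byte](600))).toMap
    val (trimmed, exceeded) = trim(available, maxBytes = 2000L)
    println(s"exceeded=$exceeded, kept=${trimmed.count(_._2.bytes.nonEmpty)} of ${available.size} payloads")
  }
}
```

Like the patch, this keeps an entry for every requested partition (so the consumer sees an empty message set rather than a missing partition) and relies on the random removal order so the same partitions are not starved on every fetch.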
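A quick way to see why partialDataSizeInBytes folds with 0L and why the test compares bytesPerTopic.toLong * topics.size against Int.MaxValue: summing per-topic sizes as Int silently wraps once the total passes 2^31 - 1, which is exactly the condition this patch guards against. The snippet below mirrors the test's 4 topics of roughly 600 MB each; it is just illustrative arithmetic, not broker code.

```scala
object OverflowSketch extends App {
  val bytesPerTopic = 600L * 1024 * 1024    // ~600 MB per topic, as in overflowReadMessageSet
  val topics = 4

  val asInt  = bytesPerTopic.toInt * topics // Int arithmetic wraps: prints a negative number
  val asLong = bytesPerTopic * topics       // Long arithmetic: ~2.4 GB, over Int.MaxValue

  println(s"Int total:  $asInt")
  println(s"Long total: $asLong (Int.MaxValue = ${Int.MaxValue})")
}
```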