diff --git a/core/src/main/scala/kafka/api/FetchRequest.scala b/core/src/main/scala/kafka/api/FetchRequest.scala
index 21b9f0e..242a288 100644
--- a/core/src/main/scala/kafka/api/FetchRequest.scala
+++ b/core/src/main/scala/kafka/api/FetchRequest.scala
@@ -17,9 +17,10 @@
 package kafka.api
 
-import java.nio._
-import kafka.network._
-import kafka.utils._
+import java.nio.ByteBuffer
+import kafka.common.FetchRequestFormatException
+import kafka.network.Request
+import kafka.utils.Utils
 import scala.collection.mutable.{HashMap, Buffer, ListBuffer}
 
 object OffsetDetail {
@@ -101,7 +102,27 @@ case class FetchRequest( versionId: Short,
                          minBytes: Int,
                          offsetInfo: Seq[OffsetDetail] ) extends Request(RequestKeys.Fetch) {
 
+  // ensure that a topic "X" appears in at most one OffsetDetail
+  def validate() {
+    if(offsetInfo == null)
+      throw new FetchRequestFormatException("FetchRequest has null offsetInfo")
+
+    // We don't want to get fancy with groupBy's and filter's since we just want the first occurrence
+    var topics = Set[String]()
+    val iter = offsetInfo.iterator
+    while(iter.hasNext) {
+      val topic = iter.next.topic
+      if(topics.contains(topic))
+        throw new FetchRequestFormatException("FetchRequest has multiple OffsetDetails for topic: " + topic)
+      else
+        topics += topic
+    }
+  }
+
   def writeTo(buffer: ByteBuffer) {
+    // validate first
+    validate()
+
     buffer.putShort(versionId)
     buffer.putInt(correlationId)
     Utils.writeShortString(buffer, clientId, "UTF-8")
diff --git a/core/src/main/scala/kafka/api/FetchResponse.scala b/core/src/main/scala/kafka/api/FetchResponse.scala
index b800dbe..51e788e 100644
--- a/core/src/main/scala/kafka/api/FetchResponse.scala
+++ b/core/src/main/scala/kafka/api/FetchResponse.scala
@@ -76,6 +76,8 @@ case class TopicData(topic: String, partitionData: Array[PartitionData]) {
 }
 
 object FetchResponse {
+  val CurrentVersion = 1.shortValue()
+
   def readFrom(buffer: ByteBuffer): FetchResponse = {
     val versionId = buffer.getShort
     val correlationId = buffer.getInt
diff --git a/core/src/main/scala/kafka/common/ErrorMapping.scala b/core/src/main/scala/kafka/common/ErrorMapping.scala
index 3161458..491b0d7 100644
--- a/core/src/main/scala/kafka/common/ErrorMapping.scala
+++ b/core/src/main/scala/kafka/common/ErrorMapping.scala
@@ -33,13 +33,15 @@ object ErrorMapping {
   val InvalidMessageCode = 2
   val WrongPartitionCode = 3
   val InvalidFetchSizeCode = 4
+  val InvalidFetchRequestFormatCode = 5
 
   private val exceptionToCode = Map[Class[Throwable], Int](
     classOf[OffsetOutOfRangeException].asInstanceOf[Class[Throwable]] -> OffsetOutOfRangeCode,
     classOf[InvalidMessageException].asInstanceOf[Class[Throwable]] -> InvalidMessageCode,
     classOf[InvalidPartitionException].asInstanceOf[Class[Throwable]] -> WrongPartitionCode,
-    classOf[InvalidMessageSizeException].asInstanceOf[Class[Throwable]] -> InvalidFetchSizeCode
+    classOf[InvalidMessageSizeException].asInstanceOf[Class[Throwable]] -> InvalidFetchSizeCode,
+    classOf[FetchRequestFormatException].asInstanceOf[Class[Throwable]] -> InvalidFetchRequestFormatCode
   ).withDefaultValue(UnknownCode)
 
   /* invert the mapping */
diff --git a/core/src/main/scala/kafka/common/FetchRequestFormatException.scala b/core/src/main/scala/kafka/common/FetchRequestFormatException.scala
new file mode 100644
index 0000000..0bc7d4e
--- /dev/null
+++ b/core/src/main/scala/kafka/common/FetchRequestFormatException.scala
@@ -0,0 +1,21 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package kafka.common
+
+class FetchRequestFormatException(val message: String) extends RuntimeException(message) {
+  def this() = this(null)
+}
diff --git a/core/src/main/scala/kafka/server/KafkaApis.scala b/core/src/main/scala/kafka/server/KafkaApis.scala
index 233c62a..6deca6e 100644
--- a/core/src/main/scala/kafka/server/KafkaApis.scala
+++ b/core/src/main/scala/kafka/server/KafkaApis.scala
@@ -21,13 +21,13 @@ import java.io.IOException
 import java.lang.IllegalStateException
 import kafka.admin.{CreateTopicCommand, AdminUtils}
 import kafka.api._
-import kafka.common.ErrorMapping
 import kafka.log._
 import kafka.message._
 import kafka.network._
 import kafka.utils.{SystemTime, Logging}
 import org.apache.log4j.Logger
 import scala.collection.mutable.ListBuffer
+import kafka.common.{FetchRequestFormatException, ErrorMapping}
 
 /**
  * Logic to handle the various Kafka requests
@@ -72,8 +72,7 @@ class KafkaApis(val logManager: LogManager) extends Logging {
     try {
       logManager.getOrCreateLog(request.topic, partition).append(request.messages)
       trace(request.messages.sizeInBytes + " bytes written to logs.")
-    }
-    catch {
+    } catch {
       case e =>
         error("Error processing " + requestHandlerName + " on " + request.topic + ":" + partition, e)
         e match {
@@ -92,8 +91,16 @@ class KafkaApis(val logManager: LogManager) extends Logging {
     if(requestLogger.isTraceEnabled)
       requestLogger.trace("Fetch request " + fetchRequest.toString)
 
+    // validate the request
+    try {
+      fetchRequest.validate()
+    } catch {
+      case e:FetchRequestFormatException =>
+        val response = new FetchResponse(FetchResponse.CurrentVersion, fetchRequest.correlationId, Array.empty)
+        return Some(new FetchResponseSend(response, ErrorMapping.InvalidFetchRequestFormatCode))
+    }
+
     val fetchedData = new ListBuffer[TopicData]()
-    var error: Int = ErrorMapping.NoError
 
     for(offsetDetail <- fetchRequest.offsetInfo) {
       val info = new ListBuffer[PartitionData]()
@@ -101,7 +108,7 @@ class KafkaApis(val logManager: LogManager) extends Logging {
       val (partitions, offsets, fetchSizes) = (offsetDetail.partitions, offsetDetail.offsets, offsetDetail.fetchSizes)
       for( (partition, offset, fetchSize) <- (partitions, offsets, fetchSizes).zipped.map((_,_,_)) ) {
         val partitionInfo = readMessageSet(topic, partition, offset, fetchSize) match {
-          case Left(err) => error = err; new PartitionData(partition, err, offset, MessageSet.Empty)
+          case Left(err) => new PartitionData(partition, err, offset, MessageSet.Empty)
           case Right(messages) => new PartitionData(partition, ErrorMapping.NoError, offset, messages)
         }
         info.append(partitionInfo)
@@ -109,7 +116,7 @@ class KafkaApis(val logManager: LogManager) extends Logging {
       fetchedData.append(new TopicData(topic, info.toArray))
     }
     val response = new FetchResponse(FetchRequest.CurrentVersion, fetchRequest.correlationId, fetchedData.toArray )
-    Some(new FetchResponseSend(response, error))
+    Some(new FetchResponseSend(response, ErrorMapping.NoError))
   }
 
   private def readMessageSet(topic: String, partition: Int, offset: Long, maxSize: Int): Either[Int, MessageSet] = {
diff --git a/core/src/test/scala/unit/kafka/integration/PrimitiveApiTest.scala b/core/src/test/scala/unit/kafka/integration/PrimitiveApiTest.scala
index 0dbf9ff..64a6293 100644
--- a/core/src/test/scala/unit/kafka/integration/PrimitiveApiTest.scala
+++ b/core/src/test/scala/unit/kafka/integration/PrimitiveApiTest.scala
@@ -17,11 +17,12 @@ package kafka.integration
 
-import scala.collection._
 import java.io.File
+import java.nio.ByteBuffer
 import java.util.Properties
 import junit.framework.Assert._
-import kafka.common.{ErrorMapping, OffsetOutOfRangeException, InvalidPartitionException}
+import kafka.api.{OffsetDetail, FetchRequest, FetchRequestBuilder, ProducerRequest}
+import kafka.common.{FetchRequestFormatException, OffsetOutOfRangeException, InvalidPartitionException}
 import kafka.message.{DefaultCompressionCodec, NoCompressionCodec, Message, ByteBufferMessageSet}
 import kafka.producer.{ProducerData, Producer, ProducerConfig}
 import kafka.serializer.StringDecoder
@@ -29,8 +30,7 @@ import kafka.server.{KafkaRequestHandler, KafkaConfig}
 import kafka.utils.TestUtils
 import org.apache.log4j.{Level, Logger}
 import org.scalatest.junit.JUnit3Suite
-import java.nio.ByteBuffer
-import kafka.api.{FetchRequest, FetchRequestBuilder, ProducerRequest}
+import scala.collection._
 
 /**
  * End to end tests of the primitive apis against a local server
@@ -61,6 +61,23 @@ class PrimitiveApiTest extends JUnit3Suite with ProducerConsumerTestHarness with
     val deserializedRequest = FetchRequest.readFrom(serializedBuffer)
     assertEquals(request, deserializedRequest)
   }
+
+  def testFetchRequestEnforcesUniqueTopicsForOffsetDetails() {
+    val offsets = Array(
+      new OffsetDetail("topic1", Array(0, 1, 2), Array(0L, 0L, 0L), Array(1000, 1000, 1000)),
+      new OffsetDetail("topic2", Array(0, 1, 2), Array(0L, 0L, 0L), Array(1000, 1000, 1000)),
+      new OffsetDetail("topic1", Array(3, 4, 5), Array(0L, 0L, 0L), Array(1000, 1000, 1000)),
+      new OffsetDetail("topic2", Array(3, 4, 5), Array(0L, 0L, 0L), Array(1000, 1000, 1000))
+    )
+    val request = new FetchRequest( versionId = FetchRequest.CurrentVersion, correlationId = 0, clientId = "",
+                                    replicaId = -1, maxWait = -1, minBytes = -1, offsetInfo = offsets)
+    try {
+      consumer.fetch(request)
+      fail("FetchRequest should throw FetchRequestFormatException due to duplicate topics")
+    } catch {
+      case e: FetchRequestFormatException => "success"
+    }
+  }
 
   def testDefaultEncoderProducerAndFetch() {
     val topic = "test-topic"
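
Illustration (not part of the patch): a minimal standalone sketch of how the new validate() check behaves, assuming the FetchRequest and OffsetDetail constructors exercised in the test above; the object name ValidateSketch is hypothetical. A request that names the same topic in more than one OffsetDetail is rejected before any bytes are written to the wire:

import kafka.api.{FetchRequest, OffsetDetail}
import kafka.common.FetchRequestFormatException

object ValidateSketch {
  def main(args: Array[String]) {
    // Two OffsetDetail entries for "topic1": validate() should throw here.
    val offsets = Array(
      new OffsetDetail("topic1", Array(0), Array(0L), Array(1000)),
      new OffsetDetail("topic1", Array(1), Array(0L), Array(1000)))
    val request = new FetchRequest(versionId = FetchRequest.CurrentVersion, correlationId = 0, clientId = "",
                                   replicaId = -1, maxWait = -1, minBytes = -1, offsetInfo = offsets)
    try {
      request.validate()  // writeTo() calls validate() too, so serialization would also fail
      println("unexpected: request validated")
    } catch {
      case e: FetchRequestFormatException => println("rejected: " + e.message)
    }
  }
}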