diff --git a/core/src/main/scala/kafka/server/AbstractFetcherThread.scala b/core/src/main/scala/kafka/server/AbstractFetcherThread.scala
index 30f51257b344a85826a80c284f7f38e87db728a6..37fe8eb5fd4c7d4f8826bba9f6627f284f208f99 100755
--- a/core/src/main/scala/kafka/server/AbstractFetcherThread.scala
+++ b/core/src/main/scala/kafka/server/AbstractFetcherThread.scala
@@ -145,10 +145,9 @@ abstract class AbstractFetcherThread(name: String,
               case Errors.NONE =>
                 try {
                   val messages = partitionData.toByteBufferMessageSet
+                  val validBytes = messages.validBytes
                   val newOffset = messages.shallowIterator.toSeq.lastOption match {
                     case Some(m) =>
-                      partitionStates.updateAndMoveToEnd(topicPartition, new PartitionFetchState(m.nextOffset))
-                      fetcherStats.byteRate.mark(messages.validBytes)
                       m.nextOffset
                     case None =>
                       currentPartitionFetchState.offset
@@ -157,6 +156,12 @@ abstract class AbstractFetcherThread(name: String,
                   fetcherLagStats.getAndMaybePut(topic, partitionId).lag = Math.max(0L, partitionData.highWatermark - newOffset)
                   // Once we hand off the partition data to the subclass, we can't mess with it any more in this thread
                   processPartitionData(topicPartition, currentPartitionFetchState.offset, partitionData)
+
+                  if (validBytes > 0) {
+                    // Update partitionStates only if there is no exception during processPartitionData
+                    partitionStates.updateAndMoveToEnd(topicPartition, new PartitionFetchState(newOffset))
+                    fetcherStats.byteRate.mark(validBytes)
+                  }
                 } catch {
                   case ime: CorruptRecordException =>
                     // we log the error and continue. This ensures two things
diff --git a/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala b/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala
index 1cd2496166d0d383c0e1fca0e1ea17568a6e983a..65c6fdafdfabde84e92975ad08ab2ef2e4b1cd1d 100644
--- a/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala
+++ b/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala
@@ -19,11 +19,12 @@ package kafka.server
 
 import com.yammer.metrics.Metrics
 import kafka.cluster.BrokerEndPoint
-import kafka.message.ByteBufferMessageSet
+import kafka.message.{ByteBufferMessageSet, Message, NoCompressionCodec}
 import kafka.server.AbstractFetcherThread.{FetchRequest, PartitionData}
 import kafka.utils.TestUtils
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.protocol.Errors
+import org.apache.kafka.common.utils.Utils
 import org.junit.Assert.{assertFalse, assertTrue}
 import org.junit.{Before, Test}
 
@@ -123,4 +124,77 @@ class AbstractFetcherThreadTest {
       new DummyFetchRequest(partitionMap.map { case (k, v) => (k, v.offset) }.toMap)
   }
 
+
+  @Test
+  def testFetchRequestCorruptedMessageException() {
+    val partition = new TopicPartition("topic", 0)
+    val fetcherThread = new TestFetcherThread("test", "client", new BrokerEndPoint(0, "localhost", 9092))
+
+    fetcherThread.start()
+
+    // add one partition for fetching
+    fetcherThread.addPartitions(Map(partition -> 0L))
+
+    // wait until fetcherThread finishes the work
+    TestUtils.waitUntilTrue(() => (fetcherThread.fetchCount > 3),
+      "Failed waiting for fetcherThread to finish the work")
+
+    fetcherThread.shutdown()
+
+    assertTrue(fetcherThread.logEndOffset == 2)
+  }
+
+  class TestFetcherThread(name: String,
+                          clientId: String,
+                          sourceBroker: BrokerEndPoint)
+    extends DummyFetcherThread(name, clientId, sourceBroker) {
+
+    var logEndOffset = 0L
+    var fetchCount = 0
+    val normalPartitionDataSet = List(new NormalPartitionData(Seq(0)), new NormalPartitionData(Seq(1)))
+
+    override def processPartitionData(topicAndPartition: TopicPartition,
+                                      fetchOffset: Long,
+                                      partitionData: PartitionData): Unit = {
+      if (fetchOffset != logEndOffset)
+        throw new RuntimeException(
+          "Offset mismatch for partition %s: fetched offset = %d, log end offset = %d."
+            .format(topicAndPartition, fetchOffset, logEndOffset))
+
+      val messages = partitionData.toByteBufferMessageSet
+      for (messageAndOffset <- messages.shallowIterator) {
+        val m = messageAndOffset.message
+        m.ensureValid()
+        logEndOffset = messageAndOffset.nextOffset
+      }
+    }
+
+    override protected def fetch(fetchRequest: DummyFetchRequest): Seq[(TopicPartition, DummyPartitionData)] = {
+      if (fetchCount == 0) {
+        fetchCount += 1
+        fetchRequest.offsets.mapValues(_ => new CorruptedPartitionData).toSeq
+      } else {
+        fetchCount += 1
+        fetchRequest.offsets.map {
+          case (k, v) => (k, normalPartitionDataSet(v.toInt))
+        }.toSeq
+      }
+    }
+  }
+
+  class CorruptedPartitionData extends DummyPartitionData {
+    override def toByteBufferMessageSet: ByteBufferMessageSet = {
+      val corruptedMessage = new Message("hello".getBytes)
+      val badChecksum: Int = (corruptedMessage.checksum + 1 % Int.MaxValue).toInt
+      Utils.writeUnsignedInt(corruptedMessage.buffer, Message.CrcOffset, badChecksum)
+      new ByteBufferMessageSet(NoCompressionCodec, corruptedMessage)
+    }
+  }
+
+  class NormalPartitionData(offsetSeq: Seq[Long]) extends DummyPartitionData {
+    override def toByteBufferMessageSet: ByteBufferMessageSet = {
+      new ByteBufferMessageSet(NoCompressionCodec, offsetSeq, new Message("hello".getBytes))
+    }
+  }
+
 }
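
The behavioral core of the patch is the reordering in AbstractFetcherThread's fetch-response handling: the partition's fetch state and the byte-rate metric are now updated only after processPartitionData returns normally, so a CorruptRecordException leaves the fetch offset unchanged and the same offset is retried on the next fetch, which is what the new testFetchRequestCorruptedMessageException exercises. Below is a minimal standalone sketch of that commit-after-success ordering; FetchCommitOrderSketch, FetchState and handleFetch are hypothetical stand-ins, not the real Kafka classes.

    import scala.collection.mutable

    // Hypothetical, simplified stand-ins for the fetcher's per-partition state
    // bookkeeping; not the real Kafka classes.
    object FetchCommitOrderSketch {
      final case class FetchState(offset: Long)

      val states = mutable.Map.empty[String, FetchState]

      // Hand the fetched data to the processor first; advance the committed
      // fetch state only if processing completed without throwing.
      def handleFetch(partition: String, newOffset: Long, validBytes: Long)(process: () => Unit): Unit = {
        try {
          process()                                    // may throw, e.g. on a corrupt record
          if (validBytes > 0)
            states(partition) = FetchState(newOffset)  // commit only after success
        } catch {
          case _: RuntimeException =>
            ()  // state untouched; the same offset is fetched again on the next round
        }
      }
    }

Calling handleFetch with a processor that throws leaves states unchanged, mirroring how the patched fetcher re-fetches from the old offset after a corrupt-record error instead of silently skipping past it.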