Index: core/src/test/scala/unit/kafka/message/ByteBufferMessageSetTest.scala
===================================================================
--- core/src/test/scala/unit/kafka/message/ByteBufferMessageSetTest.scala	(revision 1160947)
+++ core/src/test/scala/unit/kafka/message/ByteBufferMessageSetTest.scala	(working copy)
@@ -29,12 +29,20 @@
 
   @Test
   def testValidBytes() {
-    val messages = new ByteBufferMessageSet(NoCompressionCodec, new Message("hello".getBytes()), new Message("there".getBytes()))
-    val buffer = ByteBuffer.allocate(messages.sizeInBytes.toInt + 2)
-    buffer.put(messages.serialized)
-    buffer.putShort(4)
-    val messagesPlus = new ByteBufferMessageSet(buffer)
-    assertEquals("Adding invalid bytes shouldn't change byte count", messages.validBytes, messagesPlus.validBytes)
+    {
+      val messages = new ByteBufferMessageSet(NoCompressionCodec, new Message("hello".getBytes()), new Message("there".getBytes()))
+      val buffer = ByteBuffer.allocate(messages.sizeInBytes.toInt + 2)
+      buffer.put(messages.serialized)
+      buffer.putShort(4)
+      val messagesPlus = new ByteBufferMessageSet(buffer)
+      assertEquals("Adding invalid bytes shouldn't change byte count", messages.validBytes, messagesPlus.validBytes)
+    }
+
+    // test valid bytes on empty ByteBufferMessageSet
+    {
+      assertEquals("Valid bytes on an empty ByteBufferMessageSet should return 0", 0,
+        MessageSet.Empty.asInstanceOf[ByteBufferMessageSet].validBytes)
+    }
   }
 
   @Test
Index: core/src/main/scala/kafka/message/ByteBufferMessageSet.scala
===================================================================
--- core/src/main/scala/kafka/message/ByteBufferMessageSet.scala	(revision 1160947)
+++ core/src/main/scala/kafka/message/ByteBufferMessageSet.scala	(working copy)
@@ -40,7 +40,6 @@
   private val logger = Logger.getLogger(getClass())
   private var validByteCount = -1L
   private var shallowValidByteCount = -1L
-  private var deepValidByteCount = -1L
 
   def this(compressionCodec: CompressionCodec, messages: Message*) {
     this(MessageSet.createByteBuffer(compressionCodec, messages:_*), 0L, ErrorMapping.NoError)
@@ -58,9 +57,9 @@
 
   def serialized(): ByteBuffer = buffer
 
-  def validBytes: Long = deepValidBytes
-
-  def shallowValidBytes: Long = {
+  def validBytes: Long = shallowValidBytes
+
+  private def shallowValidBytes: Long = {
     if(shallowValidByteCount < 0) {
       val iter = deepIterator
       while(iter.hasNext) {
@@ -68,18 +67,10 @@
         shallowValidByteCount = messageAndOffset.offset
       }
     }
-    shallowValidByteCount - initialOffset
+    if(shallowValidByteCount < initialOffset) 0
+    else (shallowValidByteCount - initialOffset)
   }
 
-  def deepValidBytes: Long = {
-    if (deepValidByteCount < 0) {
-      val iter = deepIterator
-      while (iter.hasNext)
-        iter.next
-    }
-    deepValidByteCount
-  }
-
   /** Write the messages in this set to the given channel */
   def writeTo(channel: WritableByteChannel, offset: Long, size: Long): Long =
     channel.write(buffer.duplicate)
@@ -98,7 +89,6 @@
 
     def makeNextOuter: MessageAndOffset = {
       if (topIter.remaining < 4) {
-        deepValidByteCount = currValidBytes
         return allDone()
       }
       val size = topIter.getInt()
@@ -109,7 +99,6 @@
         logger.trace("size of data = " + size)
       }
       if(size < 0 || topIter.remaining < size) {
-        deepValidByteCount = currValidBytes
         if (currValidBytes == 0 || size < 0)
           throw new InvalidMessageSizeException("invalid message size: " + size + " only received bytes: " +
             topIter.remaining + " at " + currValidBytes + "( possible causes (1) a single message larger than " +
Index: core/src/main/scala/kafka/consumer/PartitionTopicInfo.scala
===================================================================
--- core/src/main/scala/kafka/consumer/PartitionTopicInfo.scala	(revision 1160947)
+++ core/src/main/scala/kafka/consumer/PartitionTopicInfo.scala	(working copy)
@@ -59,7 +59,7 @@
    * @return the number of valid bytes
    */
   def enqueue(messages: ByteBufferMessageSet, fetchOffset: Long): Long = {
-    val size = messages.shallowValidBytes
+    val size = messages.validBytes
     if(size > 0) {
       // update fetched offset to the compressed data chunk size, not the decompressed message set size
       if(logger.isTraceEnabled)