From 1cbe59dc06ef0ad66672e9c48ce418e6173c78e5 Mon Sep 17 00:00:00 2001 From: David Arthur Date: Tue, 2 Apr 2013 10:12:27 -0400 Subject: [PATCH] KAFKA-316 Disallow MessageSets within MessageSets Add a depth variable to keep track of how nested the Messages are. If depth is greater than one, throw an InvalidMessageException. --- .../scala/kafka/message/ByteBufferMessageSet.scala | 11 +++++++---- .../kafka/message/MessageCompressionTest.scala | 4 ++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/kafka/message/ByteBufferMessageSet.scala b/core/src/main/scala/kafka/message/ByteBufferMessageSet.scala index 03590ad..4c7c797 100644 --- a/core/src/main/scala/kafka/message/ByteBufferMessageSet.scala +++ b/core/src/main/scala/kafka/message/ByteBufferMessageSet.scala @@ -59,7 +59,10 @@ object ByteBufferMessageSet { } } - def decompress(message: Message): ByteBufferMessageSet = { + def decompress(message: Message, depth: Int = 0): ByteBufferMessageSet = { + if(depth > 0) { + throw new InvalidMessageException("Compressed MessageSets cannot include nested compressed MessageSets") + } val outputStream: ByteArrayOutputStream = new ByteArrayOutputStream val inputStream: InputStream = new ByteBufferBackedInputStream(message.payload) val intermediateBuffer = new Array[Byte](1024) @@ -74,7 +77,7 @@ object ByteBufferMessageSet { val outputBuffer = ByteBuffer.allocate(outputStream.size) outputBuffer.put(outputStream.toByteArray) outputBuffer.rewind - new ByteBufferMessageSet(outputBuffer) + new ByteBufferMessageSet(outputBuffer, depth+1) } private def writeMessage(buffer: ByteBuffer, message: Message, offset: Long) { @@ -95,7 +98,7 @@ object ByteBufferMessageSet { * Option 2: Give it a list of messages along with instructions relating to serialization format. Producers will use this method. * */ -class ByteBufferMessageSet(@BeanProperty val buffer: ByteBuffer) extends MessageSet with Logging { +class ByteBufferMessageSet(@BeanProperty val buffer: ByteBuffer, val depth: Int = 0) extends MessageSet with Logging { private var shallowValidByteCount = -1 def this(compressionCodec: CompressionCodec, messages: Message*) { @@ -175,7 +178,7 @@ class ByteBufferMessageSet(@BeanProperty val buffer: ByteBuffer) extends Message innerIter = null new MessageAndOffset(newMessage, offset) case _ => - innerIter = ByteBufferMessageSet.decompress(newMessage).internalIterator() + innerIter = ByteBufferMessageSet.decompress(newMessage, depth).internalIterator() if(!innerIter.hasNext) innerIter = null makeNext() diff --git a/core/src/test/scala/unit/kafka/message/MessageCompressionTest.scala b/core/src/test/scala/unit/kafka/message/MessageCompressionTest.scala index ed22931..53710d3 100644 --- a/core/src/test/scala/unit/kafka/message/MessageCompressionTest.scala +++ b/core/src/test/scala/unit/kafka/message/MessageCompressionTest.scala @@ -42,7 +42,7 @@ class MessageCompressionTest extends JUnitSuite { assertEquals(messages, decompressed) } - @Test + @Test(expected = classOf[ InvalidMessageException ]) def testComplexCompressDecompress() { val messages = List(new Message("hi there".getBytes), new Message("I am fine".getBytes), new Message("I am not so well today".getBytes)) val message = new ByteBufferMessageSet(compressionCodec = DefaultCompressionCodec, messages = messages.slice(0, 2):_*) @@ -51,7 +51,7 @@ class MessageCompressionTest extends JUnitSuite { val decompressedMessages = complexMessage.iterator.map(_.message).toList assertEquals(messages, decompressedMessages) } - + def isSnappyAvailable(): Boolean = { try { val snappy = new org.xerial.snappy.SnappyOutputStream(new ByteArrayOutputStream()) -- 1.7.5.4