Index: core/src/test/scala/unit/kafka/message/CompressionUtilsTest.scala =================================================================== --- core/src/test/scala/unit/kafka/message/CompressionUtilsTest.scala (revision 1198602) +++ core/src/test/scala/unit/kafka/message/CompressionUtilsTest.scala (working copy) @@ -20,6 +20,7 @@ import kafka.utils.TestUtils import org.scalatest.junit.JUnitSuite import org.junit.Test +import junit.framework.Assert._ class CompressionUtilTest extends JUnitSuite { @@ -55,4 +56,20 @@ TestUtils.checkEquals(messages.iterator, TestUtils.getMessageIterator(decompressedMessages.iterator)) } + + @Test + def testSnappyCompressDecompressExplicit() { + + val messages = List[Message](new Message("hi there".getBytes), new Message("I am fine".getBytes), new Message("I am not so well today".getBytes)) + + val message = CompressionUtils.compress(messages,SnappyCompressionCodec) + + assertEquals(message.compressionCodec,SnappyCompressionCodec) + + val decompressedMessages = CompressionUtils.decompress(message) + + TestUtils.checkLength(decompressedMessages.iterator,3) + + TestUtils.checkEquals(messages.iterator, TestUtils.getMessageIterator(decompressedMessages.iterator)) + } } Index: core/src/main/scala/kafka/message/CompressionCodec.scala =================================================================== --- core/src/main/scala/kafka/message/CompressionCodec.scala (revision 1198602) +++ core/src/main/scala/kafka/message/CompressionCodec.scala (working copy) @@ -22,6 +22,7 @@ codec match { case 0 => NoCompressionCodec case 1 => GZIPCompressionCodec + case 2 => SnappyCompressionCodec case _ => throw new kafka.common.UnknownCodecException("%d is an unknown compression codec".format(codec)) } } @@ -33,4 +34,6 @@ case object GZIPCompressionCodec extends CompressionCodec { val codec = 1 } +case object SnappyCompressionCodec extends CompressionCodec { val codec = 2 } + case object NoCompressionCodec extends CompressionCodec { val codec = 0 } Index: 
core/src/main/scala/kafka/message/CompressionUtils.scala =================================================================== --- core/src/main/scala/kafka/message/CompressionUtils.scala (revision 1198602) +++ core/src/main/scala/kafka/message/CompressionUtils.scala (working copy) @@ -20,125 +20,143 @@ import java.io.ByteArrayOutputStream import java.io.IOException import java.io.InputStream -import java.util.zip.GZIPInputStream -import java.util.zip.GZIPOutputStream import java.nio.ByteBuffer import org.apache.log4j.Logger -object CompressionUtils { - private val logger = Logger.getLogger(getClass) +abstract sealed class CompressionFacade(inputStream: InputStream, outputStream: ByteArrayOutputStream) { + def close() = { + if (inputStream != null) inputStream.close() + if (outputStream != null) outputStream.close() + } + def read(a: Array[Byte]): Int + def write(a: Array[Byte]) +} - def compress(messages: Iterable[Message]): Message = compress(messages, DefaultCompressionCodec) +class GZIPCompression(inputStream: InputStream, outputStream: ByteArrayOutputStream) extends CompressionFacade(inputStream,outputStream) { + import java.util.zip.GZIPInputStream + import java.util.zip.GZIPOutputStream + val gzipIn:GZIPInputStream = if (inputStream == null) null else new GZIPInputStream(inputStream) + val gzipOut:GZIPOutputStream = if (outputStream == null) null else new GZIPOutputStream(outputStream) - def compress(messages: Iterable[Message], compressionCodec: CompressionCodec):Message = compressionCodec match { - case DefaultCompressionCodec => - val outputStream:ByteArrayOutputStream = new ByteArrayOutputStream() - val gzipOutput:GZIPOutputStream = new GZIPOutputStream(outputStream) - if(logger.isDebugEnabled) - logger.debug("Allocating message byte buffer of size = " + MessageSet.messageSetSize(messages)) + override def close() { + if (gzipIn != null) gzipIn.close() + if (gzipOut != null) gzipOut.close() + super.close() + } - val messageByteBuffer = 
ByteBuffer.allocate(MessageSet.messageSetSize(messages)) - messages.foreach(m => m.serializeTo(messageByteBuffer)) - messageByteBuffer.rewind + override def write(a: Array[Byte]) = { + gzipOut.write(a) + } - try { - gzipOutput.write(messageByteBuffer.array) - } catch { - case e: IOException => logger.error("Error while writing to the GZIP output stream", e) - if(gzipOutput != null) gzipOutput.close(); - if(outputStream != null) outputStream.close() - throw e - } finally { - if(gzipOutput != null) gzipOutput.close() - if(outputStream != null) outputStream.close() - } + override def read(a: Array[Byte]): Int = { + gzipIn.read(a) + } +} - val oneCompressedMessage:Message = new Message(outputStream.toByteArray, compressionCodec) - oneCompressedMessage - case GZIPCompressionCodec => - val outputStream:ByteArrayOutputStream = new ByteArrayOutputStream() - val gzipOutput:GZIPOutputStream = new GZIPOutputStream(outputStream) - if(logger.isDebugEnabled) - logger.debug("Allocating message byte buffer of size = " + MessageSet.messageSetSize(messages)) +class SnappyCompression(inputStream: InputStream,outputStream: ByteArrayOutputStream) extends CompressionFacade(inputStream,outputStream) { + import org.xerial.snappy.{SnappyInputStream} + import org.xerial.snappy.{SnappyOutputStream} + + val snappyIn:SnappyInputStream = if (inputStream == null) null else new SnappyInputStream(inputStream) + val snappyOut:SnappyOutputStream = if (outputStream == null) null else new SnappyOutputStream(outputStream) - val messageByteBuffer = ByteBuffer.allocate(MessageSet.messageSetSize(messages)) - messages.foreach(m => m.serializeTo(messageByteBuffer)) - messageByteBuffer.rewind + override def close() = { + if (snappyIn != null) snappyIn.close() + if (snappyOut != null) snappyOut.close() + super.close() + } - try { - gzipOutput.write(messageByteBuffer.array) - } catch { - case e: IOException => logger.error("Error while writing to the GZIP output stream", e) - if(gzipOutput != null) - 
gzipOutput.close() - if(outputStream != null) - outputStream.close() - throw e - } finally { - if(gzipOutput != null) - gzipOutput.close() - if(outputStream != null) - outputStream.close() - } + override def write(a: Array[Byte]) = { + snappyOut.write(a) + } - val oneCompressedMessage:Message = new Message(outputStream.toByteArray, compressionCodec) - oneCompressedMessage + override def read(a: Array[Byte]): Int = { + snappyIn.read(a) + } + +} + +object CompressionFactory { + def apply(compressionCodec: CompressionCodec, stream: ByteArrayOutputStream): CompressionFacade = compressionCodec match { + case GZIPCompressionCodec => new GZIPCompression(null,stream) + case SnappyCompressionCodec => new SnappyCompression(null,stream) case _ => throw new kafka.common.UnknownCodecException("Unknown Codec: " + compressionCodec) } + def apply(compressionCodec: CompressionCodec, stream: InputStream): CompressionFacade = compressionCodec match { + case GZIPCompressionCodec => new GZIPCompression(stream,null) + case SnappyCompressionCodec => new SnappyCompression(stream,null) + case _ => + throw new kafka.common.UnknownCodecException("Unknown Codec: " + compressionCodec) + } +} - def decompress(message: Message): ByteBufferMessageSet = message.compressionCodec match { - case DefaultCompressionCodec => - val outputStream:ByteArrayOutputStream = new ByteArrayOutputStream - val inputStream:InputStream = new ByteBufferBackedInputStream(message.payload) - val gzipIn:GZIPInputStream = new GZIPInputStream(inputStream) - val intermediateBuffer = new Array[Byte](1024) +object CompressionUtils { + private val logger = Logger.getLogger(getClass) - try { - Stream.continually(gzipIn.read(intermediateBuffer)).takeWhile(_ > 0).foreach { dataRead => - outputStream.write(intermediateBuffer, 0, dataRead) - } - }catch { - case e: IOException => logger.error("Error while reading from the GZIP input stream", e) - if(gzipIn != null) gzipIn.close - if(outputStream != null) outputStream.close - throw e 
- } finally { - if(gzipIn != null) gzipIn.close - if(outputStream != null) outputStream.close - } + //specify the codec which is the default when DefaultCompressionCodec is used + private var defaultCodec: CompressionCodec = GZIPCompressionCodec - val outputBuffer = ByteBuffer.allocate(outputStream.size) - outputBuffer.put(outputStream.toByteArray) - outputBuffer.rewind - val outputByteArray = outputStream.toByteArray - new ByteBufferMessageSet(outputBuffer) - case GZIPCompressionCodec => - val outputStream:ByteArrayOutputStream = new ByteArrayOutputStream - val inputStream:InputStream = new ByteBufferBackedInputStream(message.payload) - val gzipIn:GZIPInputStream = new GZIPInputStream(inputStream) - val intermediateBuffer = new Array[Byte](1024) + def compress(messages: Iterable[Message], compressionCodec: CompressionCodec = DefaultCompressionCodec):Message = { + val outputStream:ByteArrayOutputStream = new ByteArrayOutputStream() + + if(logger.isDebugEnabled) + logger.debug("Allocating message byte buffer of size = " + MessageSet.messageSetSize(messages)) - try { - Stream.continually(gzipIn.read(intermediateBuffer)).takeWhile(_ > 0).foreach { dataRead => - outputStream.write(intermediateBuffer, 0, dataRead) - } - }catch { - case e: IOException => logger.error("Error while reading from the GZIP input stream", e) - if(gzipIn != null) gzipIn.close - if(outputStream != null) outputStream.close - throw e - } finally { - if(gzipIn != null) gzipIn.close - if(outputStream != null) outputStream.close + var cf: CompressionFacade = null + + if (compressionCodec == DefaultCompressionCodec) + cf = CompressionFactory(defaultCodec,outputStream) + else + cf = CompressionFactory(compressionCodec,outputStream) + + val messageByteBuffer = ByteBuffer.allocate(MessageSet.messageSetSize(messages)) + messages.foreach(m => m.serializeTo(messageByteBuffer)) + messageByteBuffer.rewind + + try { + cf.write(messageByteBuffer.array) + } catch { + case e: IOException => logger.error("Error 
while writing to the compression output stream", e) + cf.close() + throw e + } finally { + cf.close() + } + + val oneCompressedMessage:Message = new Message(outputStream.toByteArray, compressionCodec) + oneCompressedMessage + } + + def decompress(message: Message): ByteBufferMessageSet = { + val outputStream:ByteArrayOutputStream = new ByteArrayOutputStream + val inputStream:InputStream = new ByteBufferBackedInputStream(message.payload) + + val intermediateBuffer = new Array[Byte](1024) + + var cf: CompressionFacade = null + + if (message.compressionCodec == DefaultCompressionCodec) + cf = CompressionFactory(defaultCodec,inputStream) + else + cf = CompressionFactory(message.compressionCodec,inputStream) + + try { + Stream.continually(cf.read(intermediateBuffer)).takeWhile(_ > 0).foreach { dataRead => + outputStream.write(intermediateBuffer, 0, dataRead) } + }catch { + case e: IOException => logger.error("Error while reading from the compression input stream", e) + cf.close() + throw e + } finally { + cf.close() + } - val outputBuffer = ByteBuffer.allocate(outputStream.size) - outputBuffer.put(outputStream.toByteArray) - outputBuffer.rewind - val outputByteArray = outputStream.toByteArray - new ByteBufferMessageSet(outputBuffer) - case _ => - throw new kafka.common.UnknownCodecException("Unknown Codec: " + message.compressionCodec) + val outputBuffer = ByteBuffer.allocate(outputStream.size) + outputBuffer.put(outputStream.toByteArray) + outputBuffer.rewind + val outputByteArray = outputStream.toByteArray + new ByteBufferMessageSet(outputBuffer) + } } Index: project/build/KafkaProject.scala =================================================================== --- project/build/KafkaProject.scala (revision 1198602) +++ project/build/KafkaProject.scala (working copy) @@ -42,7 +42,7 @@ class CoreKafkaProject(info: ProjectInfo) extends DefaultProject(info) - with IdeaProject with CoreDependencies with TestDependencies { + with IdeaProject with CoreDependencies with TestDependencies with
CompressionDependencies { val corePackageAction = packageAllAction //The issue is going from log4j 1.2.14 to 1.2.15, the developers added some features which required @@ -225,5 +225,9 @@ val log4j = "log4j" % "log4j" % "1.2.15" val jopt = "net.sf.jopt-simple" % "jopt-simple" % "3.2" } + + trait CompressionDependencies { + val snappy = "org.xerial.snappy" % "snappy-java" % "1.0.4.1" + } }