diff --git a/perf/src/main/scala/kafka/perf/ProducerPerformance.scala b/perf/src/main/scala/kafka/perf/ProducerPerformance.scala
index ad2ac26..d48c04a 100644
--- a/perf/src/main/scala/kafka/perf/ProducerPerformance.scala
+++ b/perf/src/main/scala/kafka/perf/ProducerPerformance.scala
@@ -25,6 +25,7 @@ import kafka.message.{CompressionCodec, Message}
 import java.text.SimpleDateFormat
 import kafka.serializer._
 import java.util._
+import java.nio.ByteBuffer
 import collection.immutable.List
 import kafka.utils.{VerifiableProperties, Logging}
 import kafka.metrics.KafkaMetricsReporter
@@ -115,9 +116,15 @@ object ProducerPerformance extends Logging {
     val csvMetricsReporterEnabledOpt = parser.accepts("csv-reporter-enabled", "If set, the CSV metrics reporter will be enabled")
     val metricsDirectoryOpt = parser.accepts("metrics-dir", "If csv-reporter-enable is set, and this parameter is" +
             "set, the csv metrics will be outputed here")
-      .withRequiredArg
-      .describedAs("metrics dictory")
-      .ofType(classOf[java.lang.String])
+            .withRequiredArg
+            .describedAs("metrics dictory")
+            .ofType(classOf[java.lang.String])
+    val messageKeyRangeOpt = parser.accepts("message-key-range", "If set, the keys of the sent messages would range from 0 to this specified value; otherwise it will use the messageId as the key")
+            .withOptionalArg
+            .describedAs("message key range")
+            .ofType(classOf[java.lang.Integer])
+            .defaultsTo(4)
+
 
     val options = parser.parse(args : _*)
     for(arg <- List(topicsOpt, brokerListOpt, numMessagesOpt)) {
@@ -142,8 +149,9 @@ object ProducerPerformance extends Logging {
     val compressionCodec = CompressionCodec.getCompressionCodec(options.valueOf(compressionCodecOpt).intValue)
     val seqIdMode = options.has(initialMessageIdOpt)
     var initialMessageId: Int = 0
-    if(seqIdMode)
+    if (seqIdMode)
       initialMessageId = options.valueOf(initialMessageIdOpt).intValue()
+    var messageKeyRange = options.valueOf(messageKeyRangeOpt).intValue()
     val producerRequestTimeoutMs = options.valueOf(producerRequestTimeoutMsOpt).intValue()
     val producerRequestRequiredAcks = options.valueOf(producerRequestRequiredAcksOpt).intValue()
     val producerNumRetries = options.valueOf(producerNumRetriesOpt).intValue()
@@ -189,11 +197,12 @@ object ProducerPerformance extends Logging {
     props.put("message.send.max.retries", config.producerNumRetries.toString)
     props.put("retry.backoff.ms", config.producerRetryBackoffMs.toString)
     props.put("serializer.class", classOf[DefaultEncoder].getName.toString)
-    props.put("key.serializer.class", classOf[NullEncoder[Long]].getName.toString)
+    props.put("key.serializer.class", classOf[DefaultEncoder].getName.toString)
+    props.put("partitioner.class", "kafka.producer.ByteArrayPartitioner")
 
     
     val producerConfig = new ProducerConfig(props)
-    val producer = new Producer[Long, Array[Byte]](producerConfig)
+    val producer = new Producer[Array[Byte], Array[Byte]](producerConfig)
     val seqIdNumDigit = 10   // no. of digits for max int value
 
     val messagesPerThread = config.numMessages / config.numThreads
@@ -204,9 +213,13 @@ object ProducerPerformance extends Logging {
     private val messageIdLabel = "MessageID"
     private val threadIdLabel  = "ThreadID"
     private val topicLabel     = "Topic"
+    private val keyLabel       = "Key"
     private var leftPaddedSeqId : String = ""
 
-    private def generateMessageWithSeqId(topic: String, msgId: Long, msgSize: Int): Array[Byte] = {
+    private var keySelect : Long = 0
+
+
+    private def generateMessageWithSeqId(topic: String, key: Long, msgId: Long, msgSize: Int): Array[Byte] = {
       // Each thread gets a unique range of sequential no. for its ids.
       // Eg. 1000 msg in 10 threads => 100 msg per thread
       // thread 0 IDs :   0 ~  99
@@ -217,6 +230,8 @@ object ProducerPerformance extends Logging {
 
       val msgHeader = topicLabel      + SEP +
               topic           + SEP +
+              keyLabel        + SEP +
+              key             + SEP +
               threadIdLabel   + SEP +
               threadId        + SEP +
               messageIdLabel  + SEP +
@@ -227,16 +242,19 @@ object ProducerPerformance extends Logging {
       return seqMsgString.getBytes()
     }
 
-    private def generateProducerData(topic: String, messageId: Long): (KeyedMessage[Long, Array[Byte]], Int) = {
+    private def generateProducerData(topic: String, messageId: Long): (KeyedMessage[Array[Byte], Array[Byte]], Int) = {
       val msgSize = if(config.isFixSize) config.messageSize else 1 + rand.nextInt(config.messageSize)
+      val key = keySelect
+      keySelect = (keySelect + 1) % config.messageKeyRange
       val message =
         if(config.seqIdMode) {
           val seqId = config.initialMessageId + (messagesPerThread * threadId) + messageId
-          generateMessageWithSeqId(topic, seqId, msgSize)
+          generateMessageWithSeqId(topic, key, seqId, msgSize)
         } else {
           new Array[Byte](msgSize)
         }
-      (new KeyedMessage[Long, Array[Byte]](topic, messageId, message), message.length)
+      val keyBytes = key.toString.getBytes()
+      (new KeyedMessage[Array[Byte], Array[Byte]](topic, keyBytes, message), message.length)
     }
 
     override def run {
