diff --git a/core/src/main/scala/kafka/producer/ProducerConfig.scala b/core/src/main/scala/kafka/producer/ProducerConfig.scala
index e559187..235b228 100644
--- a/core/src/main/scala/kafka/producer/ProducerConfig.scala
+++ b/core/src/main/scala/kafka/producer/ProducerConfig.scala
@@ -113,5 +113,15 @@ class ProducerConfig private (val props: VerifiableProperties)
 
   val producerRetryBackoffMs = props.getInt("producer.retry.backoff.ms", 100)
 
+  /**
+   * The producer generally refreshes the topic metadata from brokers when there is a failure
+   * (partition missing, leader not available...). It will also poll regularly (default: every 10min
+   * so 600000ms). If you set this to a negative value, metadata will only get refreshed on failure.
+   * If you set this to zero, the metadata will get refreshed after each message sent (not recommended)
+   * Important note: the refresh happen only AFTER the message is sent, so if the producer never sends
+   * a message the metadata is never refreshed
+   */
+  val topicMetadataRefreshIntervalMs = props.getInt("producer.metadata.refresh.interval.ms", 600000)
+
   validate(this)
 }
diff --git a/core/src/main/scala/kafka/producer/async/DefaultEventHandler.scala b/core/src/main/scala/kafka/producer/async/DefaultEventHandler.scala
index 4f04862..5452bf3 100644
--- a/core/src/main/scala/kafka/producer/async/DefaultEventHandler.scala
+++ b/core/src/main/scala/kafka/producer/async/DefaultEventHandler.scala
@@ -21,13 +21,12 @@ import kafka.common._
 import kafka.message.{NoCompressionCodec, Message, ByteBufferMessageSet}
 import kafka.producer._
 import kafka.serializer.Encoder
-import kafka.utils.{Utils, Logging}
+import kafka.utils.{Utils, Logging, SystemTime}
 import scala.collection.{Seq, Map}
-import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.collection.mutable.{ArrayBuffer, HashMap, Set}
 import java.util.concurrent.atomic._
 import kafka.api.{TopicMetadata, ProducerRequest}
 
-
 class DefaultEventHandler[K,V](config: ProducerConfig,
                                private val partitioner: Partitioner[K],
                                private val encoder: Encoder[V],
@@ -43,6 +42,10 @@ class DefaultEventHandler[K,V](config: ProducerConfig,
 
   private val lock = new Object()
 
+  private val topicMetadataRefreshInterval = config.topicMetadataRefreshIntervalMs
+  private var lastTopicMetadataRefresh = 0L
+  private val topicMetadataToRefresh = Set.empty[String]
+
   private val producerStats = ProducerStatsRegistry.getProducerStats(config.clientId)
   private val producerTopicStats = ProducerTopicStatsRegistry.getProducerTopicStats(config.clientId)
 
@@ -58,12 +61,21 @@ class DefaultEventHandler[K,V](config: ProducerConfig,
       var outstandingProduceRequests = serializedData
       var remainingRetries = config.producerRetries + 1
       while (remainingRetries > 0 && outstandingProduceRequests.size > 0) {
+        topicMetadataToRefresh ++= outstandingProduceRequests.map(_.topic)
         outstandingProduceRequests = dispatchSerializedData(outstandingProduceRequests)
+        if (topicMetadataRefreshInterval >= 0 &&
+            SystemTime.milliseconds - lastTopicMetadataRefresh > topicMetadataRefreshInterval) {
+          Utils.swallowError(brokerPartitionInfo.updateInfo(topicMetadataToRefresh.toSet))
+          topicMetadataToRefresh.clear
+          lastTopicMetadataRefresh = SystemTime.milliseconds
+        }
         if (outstandingProduceRequests.size > 0)  {
           // back off and update the topic metadata cache before attempting another send operation
           Thread.sleep(config.producerRetryBackoffMs)
           // get topics of the outstanding produce requests and refresh metadata for those
-          Utils.swallowError(brokerPartitionInfo.updateInfo(outstandingProduceRequests.map(_.topic).toSet))
+          // unless we just refreshed
+          if (topicMetadataToRefresh.nonEmpty)
+              Utils.swallowError(brokerPartitionInfo.updateInfo(outstandingProduceRequests.map(_.topic).toSet))
           remainingRetries -= 1
           producerStats.resendRate.mark()
         }
@@ -133,9 +145,7 @@ class DefaultEventHandler[K,V](config: ProducerConfig,
     try {
       for (message <- messages) {
         val topicPartitionsList = getPartitionListForTopic(message)
-        val totalNumPartitions = topicPartitionsList.length
-
-        val partitionIndex = getPartition(message.key, totalNumPartitions)
+        val partitionIndex = getPartition(message.key, topicPartitionsList)
         val brokerPartition = topicPartitionsList(partitionIndex)
 
         // postpone the failure until the send operation, so that requests for other brokers are handled correctly
@@ -184,17 +194,24 @@ class DefaultEventHandler[K,V](config: ProducerConfig,
    * Retrieves the partition id and throws an UnknownTopicOrPartitionException if
    * the value of partition is not between 0 and numPartitions-1
    * @param key the partition key
-   * @param numPartitions the total number of available partitions
+   * @param topicPartitionList the list of available partitions
    * @return the partition id
    */
-  private def getPartition(key: K, numPartitions: Int): Int = {
+  private def getPartition(key: K, topicPartitionList: Seq[PartitionAndLeader]): Int = {
+    val numPartitions = topicPartitionList.size
     if(numPartitions <= 0)
       throw new UnknownTopicOrPartitionException("Invalid number of partitions: " + numPartitions +
         "\n Valid values are > 0")
     val partition =
-      if(key == null)
-        Utils.abs(partitionCounter.getAndIncrement()) % numPartitions
-      else
+      if(key == null) {
+        // If the key is null, we don't really need a partitioner so we just send to the next
+        // available partition
+        val availablePartitions = topicPartitionList.filter(_.leaderBrokerIdOpt.isDefined)
+        if (availablePartitions.isEmpty)
+          throw new UnknownTopicOrPartitionException("No leader for any partition")
+        val index = Utils.abs(partitionCounter.getAndIncrement()) % availablePartitions.size
+        availablePartitions(index).partitionId
+      } else
         partitioner.partition(key, numPartitions)
     if(partition < 0 || partition >= numPartitions)
       throw new UnknownTopicOrPartitionException("Invalid partition id : " + partition +
