diff --git a/core/src/main/scala/kafka/controller/KafkaController.scala b/core/src/main/scala/kafka/controller/KafkaController.scala
index 103f6cf..2a18c28 100755
--- a/core/src/main/scala/kafka/controller/KafkaController.scala
+++ b/core/src/main/scala/kafka/controller/KafkaController.scala
@@ -329,8 +329,8 @@ class KafkaController(val config : KafkaConfig, zkUtils: ZkUtils, val brokerStat
       partitionStateMachine.registerListeners()
       replicaStateMachine.registerListeners()
       initializeControllerContext()
-      replicaStateMachine.startup()
       partitionStateMachine.startup()
+      replicaStateMachine.startup()
       // register the partition change listeners for all existing topics on failover
       controllerContext.allTopics.foreach(topic => partitionStateMachine.registerPartitionChangeListener(topic))
       info("Broker %d is ready to serve as the new controller with epoch %d".format(config.brokerId, epoch))
diff --git a/core/src/main/scala/kafka/controller/PartitionStateMachine.scala b/core/src/main/scala/kafka/controller/PartitionStateMachine.scala
index ec03b84..1f6e739 100755
--- a/core/src/main/scala/kafka/controller/PartitionStateMachine.scala
+++ b/core/src/main/scala/kafka/controller/PartitionStateMachine.scala
@@ -67,7 +67,9 @@ class PartitionStateMachine(controller: KafkaController) extends Logging {
     // set started flag
     hasStarted.set(true)
     // try to move partitions to online state
-    triggerOnlinePartitionStateChange()
+    // NOTE(review): lock wrapper left commented out here, yet ReplicaStateMachine.startup
+    // wraps the identical call in inLock(controllerContext.controllerLock) — confirm intent.
+    triggerOnlinePartitionStateChange()
 
     info("Started partition state machine with initial state -> " + partitionState.toString())
   }
@@ -105,6 +107,16 @@ class PartitionStateMachine(controller: KafkaController) extends Logging {
     info("Stopped partition state machine")
   }
 
+  // Returns true iff the broker request batch is empty (newBatch() succeeds).
+  def isBatchClear(): Boolean = {
+    try {
+      brokerRequestBatch.newBatch()
+      true
+    } catch {
+      case _: IllegalStateException => false
+    }
+  }
+
   /**
    * This API invokes the OnlinePartition state change on all partitions in either the NewPartition or OfflinePartition
    * state. This is called on a successful controller election and on broker changes
@@ -120,6 +132,7 @@ class PartitionStateMachine(controller: KafkaController) extends Logging {
           handleStateChange(topicAndPartition.topic, topicAndPartition.partition, OnlinePartition, controller.offlinePartitionSelector,
                             (new CallbackBuilder).build)
       }
+      Thread.sleep(200) // FIXME(debug): artificially widens the race window for the test; remove before merge
       brokerRequestBatch.sendRequestsToBrokers(controller.epoch)
     } catch {
       case e: Throwable => error("Error while moving some partitions to the online state", e)
@@ -141,13 +154,18 @@ class PartitionStateMachine(controller: KafkaController) extends Logging {
                          callbacks: Callbacks = (new CallbackBuilder).build) {
     info("Invoking state change to %s for partitions %s".format(targetState, partitions.mkString(",")))
     try {
+      error("### Begin") // FIXME(debug): remove tracing before merge
+      error(new Exception) // FIXME(debug): logs a stack trace for call-site tracing only; remove before merge
       brokerRequestBatch.newBatch()
       partitions.foreach { topicAndPartition =>
         handleStateChange(topicAndPartition.topic, topicAndPartition.partition, targetState, leaderSelector, callbacks)
       }
+      Thread.sleep(200) // FIXME(debug): artificially widens the race window for the test; remove before merge
+      error("### End") // FIXME(debug): remove tracing before merge
+      error(new Exception) // FIXME(debug): logs a stack trace for call-site tracing only; remove before merge
       brokerRequestBatch.sendRequestsToBrokers(controller.epoch)
     }catch {
-      case e: Throwable => error("Error while moving some partitions to %s state".format(targetState), e)
+      case e: Throwable => error("Error while moving some partitions to %s state, controllerId %d".format(targetState, controllerId), e)
       // TODO: It is not enough to bail out and log an error, it is important to trigger state changes for those partitions
     }
   }
diff --git a/core/src/main/scala/kafka/controller/ReplicaStateMachine.scala b/core/src/main/scala/kafka/controller/ReplicaStateMachine.scala
index 2fd8b95..749e54f 100755
--- a/core/src/main/scala/kafka/controller/ReplicaStateMachine.scala
+++ b/core/src/main/scala/kafka/controller/ReplicaStateMachine.scala
@@ -68,7 +68,9 @@ class ReplicaStateMachine(controller: KafkaController) extends Logging {
     // set started flag
     hasStarted.set(true)
     // move all Online replicas to Online
-    handleStateChanges(controllerContext.allLiveReplicas(), OnlineReplica)
+    inLock(controllerContext.controllerLock) {
+      handleStateChanges(controllerContext.allLiveReplicas(), OnlineReplica)
+    }
 
     info("Started replica state machine with initial state -> " + replicaState.toString())
   }
diff --git a/core/src/test/scala/unit/kafka/controller/ControllerFailoverTest.scala b/core/src/test/scala/unit/kafka/controller/ControllerFailoverTest.scala
index 91ac1f6..523977d 100644
--- a/core/src/test/scala/unit/kafka/controller/ControllerFailoverTest.scala
+++ b/core/src/test/scala/unit/kafka/controller/ControllerFailoverTest.scala
@@ -20,15 +20,14 @@ package kafka.controller
 import java.util.Properties
 import java.util.concurrent.LinkedBlockingQueue
 
-import kafka.api.RequestOrResponse
 import kafka.common.TopicAndPartition
 import kafka.integration.KafkaServerTestHarness
 import kafka.server.{KafkaConfig, KafkaServer}
 import kafka.utils._
 import org.apache.kafka.common.metrics.Metrics
-import org.apache.kafka.common.requests.{AbstractRequestResponse, AbstractRequest}
 import org.apache.kafka.common.utils.SystemTime
 import org.apache.log4j.{Level, Logger}
+import org.junit.Assert._
 import org.junit.{After, Before, Test}
 
 import scala.collection.mutable
@@ -58,11 +57,75 @@ class ControllerFailoverTest extends KafkaServerTestHarness with Logging {
     this.metrics.close()
   }
 
+
+  @Test
+  def testStateMachineRace() {
+    log.setLevel(Level.INFO)
+    var controller: KafkaServer = this.servers.head
+    // Find the current controller
+    val epochMap: mutable.Map[Int, Int] = mutable.Map.empty
+    for (server <- this.servers) {
+      epochMap += (server.config.brokerId -> server.kafkaController.epoch)
+      if (server.kafkaController.isActive()) {
+        controller = server
+      }
+    }
+    // Create topic with one partition
+    kafka.admin.AdminUtils.createTopic(controller.zkUtils, topic, 1, 1)
+    val topicPartition = TopicAndPartition(topic, 0) // use the created topic; a hard-coded name risks an infinite poll loop below
+    var partitions = controller.kafkaController.partitionStateMachine.partitionsInState(OnlinePartition)
+    while (!partitions.contains(topicPartition)) {
+      partitions = controller.kafkaController.partitionStateMachine.partitionsInState(OnlinePartition)
+      Thread.sleep(100)
+    }
+    controller.kafkaController.partitionStateMachine.shutdown()
+    Thread.sleep(100)
+    for (server <- this.servers) {
+      if (!server.kafkaController.isActive()) {
+        server.shutdown()
+      }
+    }
+    info("Wait until broker has shutdown and removed from controller list")
+    while (controller.kafkaController.controllerContext.liveBrokerIds.size > 1) {
+      Thread.sleep(100)
+    }
+    Thread.sleep(1000)
+    assertTrue(controller.kafkaController.partitionStateMachine.isBatchClear())
+    info("Asserted")
+  }
+
+  @Test
+  def testStartupRace() {
+    log.setLevel(Level.INFO)
+    var controller: KafkaServer = this.servers.head
+    // Find the current controller
+    val epochMap: mutable.Map[Int, Int] = mutable.Map.empty
+    for (server <- this.servers) {
+      epochMap += (server.config.brokerId -> server.kafkaController.epoch)
+      if (server.kafkaController.isActive()) {
+        controller = server
+      }
+    }
+    // Create topic with one partition
+    //controller.kafkaController.partitionStateMachine.shutdown()
+    kafka.admin.AdminUtils.createTopic(controller.zkUtils, topic, 1, 1)
+    //kafka.admin.AdminUtils.createTopic(controller.zkUtils, "test-topic", 1, 1)
+    val topicPartition = TopicAndPartition(topic, 0) // use the created topic; a hard-coded name risks an infinite poll loop below
+    var partitions = controller.kafkaController.partitionStateMachine.partitionsInState(OnlinePartition)
+    while (!partitions.contains(topicPartition)) {
+      partitions = controller.kafkaController.partitionStateMachine.partitionsInState(OnlinePartition)
+      Thread.sleep(100)
+    }
+    controller.kafkaController.partitionStateMachine.startup()
+
+    Thread.sleep(1000)
+  }
+
   /**
    * See @link{https://issues.apache.org/jira/browse/KAFKA-2300}
    * for the background of this test case
    */
-  @Test
+  @Test
   def testMetadataUpdate() {
     log.setLevel(Level.INFO)
     var controller: KafkaServer = this.servers.head;
@@ -151,6 +214,7 @@ class ControllerFailoverTest extends KafkaServerTestHarness with Logging {
   }
 }
 
+
 class MockChannelManager(private val controllerContext: ControllerContext, config: KafkaConfig, metrics: Metrics)
   extends ControllerChannelManager(controllerContext, config, new SystemTime, metrics) {
 
