Index: core/src/main/scala/kafka/server/KafkaServerStartable.scala =================================================================== --- core/src/main/scala/kafka/server/KafkaServerStartable.scala (revision 1179439) +++ core/src/main/scala/kafka/server/KafkaServerStartable.scala (working copy) @@ -22,6 +22,7 @@ import kafka.producer.{ProducerData, ProducerConfig, Producer} import kafka.message.Message import org.apache.log4j.Logger +import java.util.concurrent.CountDownLatch import scala.collection.Map @@ -79,6 +80,7 @@ private val producer = new Producer[Null, Message](producerConfig) + var threadList = List[MirroringThread]() private def isTopicAllowed(topic: String) = { if (consumerConfig.mirrorTopicsWhitelist.nonEmpty) @@ -121,50 +123,89 @@ if (topicMap.nonEmpty) { if (consumerConnector != null) consumerConnector.shutdown() + + /** + * Before starting new consumer threads for the updated set of topics, + * shutdown the existing mirroring threads. Since the consumer connector + * is already shutdown, the mirroring threads should finish their task almost + * instantaneously. If they don't, this points to an error that needs to be looked + * into, and further mirroring should stop + */ + threadList.foreach(_.shutdown) + consumerConnector = Consumer.create(consumerConfig) val topicMessageStreams = consumerConnector.createMessageStreams(topicMap) - var threadList = List[Thread]() for ((topic, streamList) <- topicMessageStreams) for (i <- 0 until streamList.length) - threadList ::= Utils.newThread("kafka-embedded-consumer-%s-%d".format(topic, i), new Runnable() { - def run() { - logger.info("Starting consumer thread %d for topic %s".format(i, topic)) + threadList ::= new MirroringThread(streamList(i), topic, i) - try { - for (message <- streamList(i)) { - val pd = new ProducerData[Null, Message](topic, message) - producer.send(pd) - } - } - catch { - case e => - logger.fatal(e + Utils.stackTrace(e)) - logger.fatal(topic + " stream " + i + " unexpectedly exited") - } - } - }, false) - - for (thread <- threadList) - thread.start() + threadList.foreach(_.start) } else - logger.info("Not starting consumer threads (mirror topic list is empty)") + logger.info("Not starting mirroring threads (mirror topic list is empty)") } def startup() { topicEventWatcher = new ZookeeperTopicEventWatcher(consumerConfig, this) /* - * consumer threads are (re-)started upon topic events (which includes an - * initial startup event which lists the current topics) - */ - } + * consumer threads are (re-)started upon topic events (which includes an + * initial startup event which lists the current topics) + */ + } def shutdown() { - producer.close() + // first shutdown the topic watcher to prevent creating new consumer streams + if (topicEventWatcher != null) + topicEventWatcher.shutdown() + logger.info("Stopped the ZK watcher for new topics, now stopping the Kafka consumers") + // stop pulling more data for mirroring if (consumerConnector != null) consumerConnector.shutdown() - if (topicEventWatcher != null) - topicEventWatcher.shutdown() + logger.info("Stopped the kafka consumer threads for existing topics, now stopping the existing mirroring threads") + // wait for all mirroring threads to stop + threadList.foreach(_.shutdown) + logger.info("Stopped all existing mirroring threads, now stopping the producer") + // only then, shutdown the producer + producer.close() + logger.info("Successfully shutdown this Kafka mirror") } + + class MirroringThread(val stream: KafkaMessageStream[Message], val topic: String, val threadId: Int) extends Thread { + val shutdownComplete = new CountDownLatch(1) + val name = "kafka-embedded-consumer-%s-%d".format(topic, threadId) + this.setDaemon(false) + this.setName(name) + + private val logger = Logger.getLogger(name) + + override def run = { + logger.info("Starting mirroring thread %s for topic %s and stream %d".format(name, topic, threadId)) + + try { + for (message <- stream) { + val pd = new ProducerData[Null, Message](topic, message) + producer.send(pd) + } + + shutdownComplete.countDown + logger.info("Stopped mirroring thread %s for topic %s and stream %d".format(name, topic, threadId)) + } + catch { + case e => + logger.fatal(e + Utils.stackTrace(e)) + logger.fatal(topic + " stream " + threadId + " unexpectedly exited") + } + } + + def shutdown = { + try { + shutdownComplete.await + }catch { + case e: InterruptedException => logger.fatal("Shutdown of thread " + name + " interrupted. " + + "Mirroring thread might leak data!") + } + } + } } +