From acd4c80066c12de2631b8e462ac87ffab2b765eb Mon Sep 17 00:00:00 2001
From: Onur Karaman
Date: Sat, 25 Apr 2015 15:19:39 -0700
Subject: [PATCH] add heartbeat to coordinator

---
 .../clients/consumer/internals/Coordinator.java    |   3 +-
 .../org/apache/kafka/common/protocol/Errors.java   |  10 +-
 .../main/scala/kafka/coordinator/Consumer.scala    |  45 ++
 .../kafka/coordinator/ConsumerCoordinator.scala    | 623 +++++++++++++--------
 .../scala/kafka/coordinator/ConsumerRegistry.scala |  52 --
 .../scala/kafka/coordinator/DelayedHeartbeat.scala |  33 +-
 .../scala/kafka/coordinator/DelayedJoinGroup.scala |  31 +-
 .../scala/kafka/coordinator/DelayedRebalance.scala |  37 +-
 core/src/main/scala/kafka/coordinator/Group.scala  | 131 +++++
 .../scala/kafka/coordinator/GroupRegistry.scala    |  79 ---
 .../scala/kafka/coordinator/HeartbeatBucket.scala  |  36 --
 .../kafka/coordinator/PartitionAssignor.scala      | 121 ++++
 .../scala/kafka/server/DelayedOperationKey.scala   |   6 -
 core/src/main/scala/kafka/server/KafkaApis.scala   |  18 +-
 core/src/main/scala/kafka/server/KafkaServer.scala |   2 +-
 .../main/scala/kafka/server/OffsetManager.scala    |   2 +-
 .../scala/unit/kafka/coordinator/GroupTest.scala   | 155 +++++
 .../kafka/coordinator/PartitionAssignorTest.scala  | 251 +++++++++
 18 files changed, 1145 insertions(+), 490 deletions(-)
 create mode 100644 core/src/main/scala/kafka/coordinator/Consumer.scala
 delete mode 100644 core/src/main/scala/kafka/coordinator/ConsumerRegistry.scala
 create mode 100644 core/src/main/scala/kafka/coordinator/Group.scala
 delete mode 100644 core/src/main/scala/kafka/coordinator/GroupRegistry.scala
 delete mode 100644 core/src/main/scala/kafka/coordinator/HeartbeatBucket.scala
 create mode 100644 core/src/main/scala/kafka/coordinator/PartitionAssignor.scala
 create mode 100644 core/src/test/scala/unit/kafka/coordinator/GroupTest.scala
 create mode 100644 core/src/test/scala/unit/kafka/coordinator/PartitionAssignorTest.scala

diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Coordinator.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Coordinator.java
index e55ab11..b2764df 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Coordinator.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Coordinator.java
@@ -96,7 +96,7 @@ public final class Coordinator {
         this.time = time;
         this.client = client;
         this.generation = -1;
-        this.consumerId = "";
+        this.consumerId = JoinGroupRequest.UNKNOWN_CONSUMER_ID;
         this.groupId = groupId;
         this.metadata = metadata;
         this.consumerCoordinator = null;
@@ -132,6 +132,7 @@ public final class Coordinator {
             // TODO: needs to handle disconnects and errors, should not just throw exceptions
             Errors.forCode(response.errorCode()).maybeThrow();
             this.consumerId = response.consumerId();
+            this.generation = response.generationId();
 
             // set the flag to refresh last committed offsets
             this.subscriptions.needRefreshCommits();
diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/Errors.java b/clients/src/main/java/org/apache/kafka/common/protocol/Errors.java
index 36aa412..5b898c8 100644
--- a/clients/src/main/java/org/apache/kafka/common/protocol/Errors.java
+++ b/clients/src/main/java/org/apache/kafka/common/protocol/Errors.java
@@ -69,7 +69,15 @@ public enum Errors {
     INVALID_REQUIRED_ACKS(21,
             new InvalidRequiredAcksException("Produce request specified an invalid value for required acks.")),
     ILLEGAL_GENERATION(22,
-            new ApiException("Specified consumer generation id is not valid."));
ApiException("Specified consumer generation id is not valid.")), + INCONSISTENT_PARTITION_ASSIGNMENT_STRATEGY(23, + new ApiException("The request partition assignment strategy does not match that of the group.")), + UNKNOWN_PARTITION_ASSIGNMENT_STRATEGY(24, + new ApiException("The request partition assignment strategy is unknown to the broker.")), + UNKNOWN_CONSUMER_ID(25, + new ApiException("The coordinator is not aware of this consumer.")), + INVALID_SESSION_TIMEOUT(26, + new ApiException("The session timeout is not within an acceptable range.")); private static Map, Errors> classToError = new HashMap, Errors>(); private static Map codeToError = new HashMap(); diff --git a/core/src/main/scala/kafka/coordinator/Consumer.scala b/core/src/main/scala/kafka/coordinator/Consumer.scala new file mode 100644 index 0000000..956f11c --- /dev/null +++ b/core/src/main/scala/kafka/coordinator/Consumer.scala @@ -0,0 +1,45 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.coordinator + +import kafka.common.TopicAndPartition +import kafka.utils.nonthreadsafe + +/** + * A consumer contains the following metadata: + * + * Heartbeat metadata: + * 1. negotiated heartbeat session timeout. + * 2. recorded number of timed-out heartbeats. + * 3. timestamp of the latest heartbeat + * + * Subscription metadata: + * 1. subscribed topics + * 2. assigned partitions for the subscribed topics. 
+ */ +@nonthreadsafe +private[coordinator] class Consumer(val consumerId: String, + val groupId: String, + var topics: Set[String], + val sessionTimeoutMs: Int) { + + var numExpiredHeartbeat = 0 + var awaitingRebalance = false + var assignedTopicPartitions = Set.empty[TopicAndPartition] + var latestHeartbeat: Long = -1 +} diff --git a/core/src/main/scala/kafka/coordinator/ConsumerCoordinator.scala b/core/src/main/scala/kafka/coordinator/ConsumerCoordinator.scala index 456b602..a7ebe6f 100644 --- a/core/src/main/scala/kafka/coordinator/ConsumerCoordinator.scala +++ b/core/src/main/scala/kafka/coordinator/ConsumerCoordinator.scala @@ -16,79 +16,82 @@ */ package kafka.coordinator -import org.apache.kafka.common.protocol.Errors - import kafka.common.TopicAndPartition import kafka.server._ import kafka.utils._ +import kafka.utils.CoreUtils.{inReadLock, inWriteLock} +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.JoinGroupRequest -import scala.collection.mutable.HashMap +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.locks.ReentrantReadWriteLock -import org.I0Itec.zkclient.{IZkChildListener, ZkClient} -import org.apache.kafka.common.requests.JoinGroupRequest +import collection.mutable + +import org.I0Itec.zkclient.{IZkDataListener, ZkClient} /** - * Kafka coordinator handles consumer group and consumer offset management. + * ConsumerCoordinator handles consumer group and consumer offset management. * - * Each Kafka server instantiates a coordinator, which is responsible for a set of + * Each Kafka server instantiates a coordinator which is responsible for a set of * consumer groups; the consumer groups are assigned to coordinators based on their * group names. */ class ConsumerCoordinator(val config: KafkaConfig, - val zkClient: ZkClient) extends Logging { - - this.logIdent = "[Kafka Coordinator " + config.brokerId + "]: " + val zkClient: ZkClient, + val offsetManager: OffsetManager) extends Logging { - /* zookeeper listener for topic-partition changes */ - private val topicPartitionChangeListeners = new HashMap[String, TopicPartitionChangeListener] + this.logIdent = "[ConsumerCoordinator " + config.brokerId + "]: " - /* the consumer group registry cache */ - // TODO: access to this map needs to be synchronized - private val consumerGroupRegistries = new HashMap[String, GroupRegistry] + private val MinSessionTimeoutMs = 6000 + private val MaxSessionTimeoutMs = 30000 - /* the list of subscribed groups per topic */ - // TODO: access to this map needs to be synchronized - private val consumerGroupsPerTopic = new HashMap[String, List[String]] - - /* the delayed operation purgatory for heartbeat-based failure detection */ private var heartbeatPurgatory: DelayedOperationPurgatory[DelayedHeartbeat] = null - - /* the delayed operation purgatory for handling join-group requests */ private var joinGroupPurgatory: DelayedOperationPurgatory[DelayedJoinGroup] = null - - /* the delayed operation purgatory for preparing rebalance process */ private var rebalancePurgatory: DelayedOperationPurgatory[DelayedRebalance] = null - /* latest consumer heartbeat bucket's end timestamp in milliseconds */ - private var latestHeartbeatBucketEndMs: Long = SystemTime.milliseconds + private val isLoadingGroupMetadata = new AtomicBoolean(true) /** - * Start-up logic executed at the same time when the server starts up. 
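+ * coordinatorLock guards the coordinator-level metadata maps declared below
+ * (groups, groupsPerTopic, topicPartitionCounts, topicPartitionChangeListeners).
+ * A minimal sketch of the intended nesting, mirroring how updateConsumer and
+ * removeConsumer are invoked while the group lock is held (illustrative only):
+ * {{{
+ * group synchronized {
+ *   inWriteLock(coordinatorLock) {
+ *     // mutate the coordinator-level maps here
+ *   }
+ * }
+ * }}}
+ *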
+ * NOTE: If a group lock and coordinatorLock are simultaneously needed, + * be sure to acquire the group lock before coordinatorLock to prevent deadlock */ - def startup() { + private val coordinatorLock = new ReentrantReadWriteLock() - // Initialize consumer group registries and heartbeat bucket metadata - latestHeartbeatBucketEndMs = SystemTime.milliseconds + /** + * These should be guarded by coordinatorLock + */ + private val groups = new mutable.HashMap[String, Group] + private val groupsPerTopic = new mutable.HashMap[String, Set[String]] + private val topicPartitionCounts = new mutable.HashMap[String, Int] + private val topicPartitionChangeListeners = new mutable.HashMap[String, TopicPartitionChangeListener] - // Initialize purgatories for delayed heartbeat, join-group and rebalance operations - heartbeatPurgatory = new DelayedOperationPurgatory[DelayedHeartbeat](purgatoryName = "Heartbeat", brokerId = config.brokerId) - joinGroupPurgatory = new DelayedOperationPurgatory[DelayedJoinGroup](purgatoryName = "JoinGroup", brokerId = config.brokerId) - rebalancePurgatory = new DelayedOperationPurgatory[DelayedRebalance](purgatoryName = "Rebalance", brokerId = config.brokerId) + /** + * Startup logic executed at the same time when the server starts up. + */ + def startup() { + info("Starting up.") + heartbeatPurgatory = new DelayedOperationPurgatory[DelayedHeartbeat]("Heartbeat", config.brokerId) + joinGroupPurgatory = new DelayedOperationPurgatory[DelayedJoinGroup]("JoinGroup", config.brokerId) + rebalancePurgatory = new DelayedOperationPurgatory[DelayedRebalance]("Rebalance", config.brokerId) + loadGroupMetadata() + } + private def loadGroupMetadata() { + info("Loading group metadata.") + isLoadingGroupMetadata.set(false) } /** - * Shut-down logic executed at the same time when server shuts down, - * ordering of actions should be reversed from the start-up process + * Shutdown logic executed at the same time when server shuts down. + * Ordering of actions should be reversed from the start-up process. 
* */ def shutdown() { - + info("Shutting down.") // De-register all Zookeeper listeners for topic-partition changes - for (topic <- topicPartitionChangeListeners.keys) { - deregisterTopicChangeListener(topic) - } + topicPartitionChangeListeners.keys.foreach(deregisterTopicPartitionChangeListener) topicPartitionChangeListeners.clear() // Shutdown purgatories for delayed heartbeat, join-group and rebalance operations @@ -96,252 +99,408 @@ class ConsumerCoordinator(val config: KafkaConfig, joinGroupPurgatory.shutdown() rebalancePurgatory.shutdown() - // Clean up consumer group registries metadata - consumerGroupRegistries.clear() - consumerGroupsPerTopic.clear() + // Clean up consumer group metadata + groups.clear() + groupsPerTopic.clear() + topicPartitionCounts.clear() } - /** - * Process a join-group request from a consumer to join as a new group member - */ - def consumerJoinGroup(groupId: String, - consumerId: String, - topics: List[String], - sessionTimeoutMs: Int, - partitionAssignmentStrategy: String, - responseCallback:(List[TopicAndPartition], Int, Short) => Unit ) { - - // if the group does not exist yet, create one - if (!consumerGroupRegistries.contains(groupId)) - createNewGroup(groupId, partitionAssignmentStrategy) - - val groupRegistry = consumerGroupRegistries(groupId) - - // if the consumer id is unknown or it does exists in - // the group yet, register this consumer to the group - if (consumerId.equals(JoinGroupRequest.UNKNOWN_CONSUMER_ID)) { - createNewConsumer(groupId, groupRegistry.generateNextConsumerId, topics, sessionTimeoutMs) - } else if (!groupRegistry.memberRegistries.contains(consumerId)) { - createNewConsumer(groupId, consumerId, topics, sessionTimeoutMs) + def joinGroup(groupId: String, + consumerId: String, + topics: Set[String], + sessionTimeoutMs: Int, + partitionAssignmentStrategy: String, + responseCallback:(Set[TopicAndPartition], String, Int, Short) => Unit) { + if (isLoadingGroupMetadata.get) { + responseCallback(Set.empty, consumerId, 0, Errors.CONSUMER_COORDINATOR_NOT_AVAILABLE.code) + } else if (!isCoordinatorForGroup(groupId)) { + responseCallback(Set.empty, consumerId, 0, Errors.NOT_COORDINATOR_FOR_CONSUMER.code) + } else if (!PartitionAssignor.strategies.contains(partitionAssignmentStrategy)) { + responseCallback(Set.empty, consumerId, 0, Errors.UNKNOWN_PARTITION_ASSIGNMENT_STRATEGY.code) + } else if (sessionTimeoutMs < MinSessionTimeoutMs || sessionTimeoutMs > MaxSessionTimeoutMs) { + responseCallback(Set.empty, consumerId, 0, Errors.INVALID_SESSION_TIMEOUT.code) + } else { + val group = inReadLock(coordinatorLock) { + groups.get(groupId).orNull + } + if (group == null) { + if (consumerId != JoinGroupRequest.UNKNOWN_CONSUMER_ID) { + responseCallback(Set.empty, consumerId, 0, Errors.UNKNOWN_CONSUMER_ID.code) + } else { + val group = addGroup(groupId, partitionAssignmentStrategy) + doJoinGroup(group, consumerId, topics, sessionTimeoutMs, partitionAssignmentStrategy, responseCallback) + } + } else { + doJoinGroup(group, consumerId, topics, sessionTimeoutMs, partitionAssignmentStrategy, responseCallback) + } } + } - // add a delayed join-group operation to the purgatory - // TODO + private def doJoinGroup(group: Group, + consumerId: String, + topics: Set[String], + sessionTimeoutMs: Int, + partitionAssignmentStrategy: String, + responseCallback:(Set[TopicAndPartition], String, Int, Short) => Unit) { + group synchronized { + if (group.is(Dead)) { + responseCallback(Set.empty, consumerId, 0, Errors.UNKNOWN_CONSUMER_ID.code) + } else if 
(partitionAssignmentStrategy != group.partitionAssignmentStrategy) { + responseCallback(Set.empty, consumerId, 0, Errors.INCONSISTENT_PARTITION_ASSIGNMENT_STRATEGY.code) + } else if (consumerId != JoinGroupRequest.UNKNOWN_CONSUMER_ID && !group.has(consumerId)) { + responseCallback(Set.empty, consumerId, 0, Errors.UNKNOWN_CONSUMER_ID.code) + } else if (group.has(consumerId) && group.is(Stable) && topics == group.get(consumerId).topics) { + // TODO: consumer sent a JoinGroupRequest for no reason. How should we handle unexpected consumer requests? + val consumer = group.get(consumerId) + scheduleHeartbeatExpiration(group, consumer) + responseCallback(consumer.assignedTopicPartitions, consumerId, group.generationId, Errors.NONE.code) + } else { + // if the consumer id is unknown, register this consumer to the group + val consumer = if (consumerId == JoinGroupRequest.UNKNOWN_CONSUMER_ID) { + val generatedConsumerId = group.generateNextConsumerId + val consumer = addConsumer(generatedConsumerId, topics, sessionTimeoutMs, group) + maybePrepareRebalance(group) + consumer + } else { + val consumer = group.get(consumerId) + // existing consumer changed its subscribed topics + if (topics != consumer.topics) { + updateConsumer(group, consumer, topics) + maybePrepareRebalance(group) + consumer + } else consumer // existing consumer rejoining a group due to rebalance + } + + consumer.awaitingRebalance = true + + val delayedJoinGroup = new DelayedJoinGroup(this, group, consumer, 2 * MaxSessionTimeoutMs, responseCallback) + val consumerGroupKey = ConsumerGroupKey(group.groupId) + joinGroupPurgatory.tryCompleteElseWatch(delayedJoinGroup, Seq(consumerGroupKey)) + + if (group.is(PreparingRebalance)) + rebalancePurgatory.checkAndComplete(consumerGroupKey) + } + } + } - // if the current group is under rebalance process, - // check if the delayed rebalance operation can be finished - // TODO + def heartbeat(groupId: String, + consumerId: String, + generationId: Int, + responseCallback: Short => Unit) { + if (isLoadingGroupMetadata.get) { + responseCallback(Errors.CONSUMER_COORDINATOR_NOT_AVAILABLE.code) + } else if (!isCoordinatorForGroup(groupId)) { + responseCallback(Errors.NOT_COORDINATOR_FOR_CONSUMER.code) + } else { + val group = inReadLock(coordinatorLock) { + groups.get(groupId).orNull + } + if (group == null) { + responseCallback(Errors.UNKNOWN_CONSUMER_ID.code) + } else { + group synchronized { + if (group.is(Dead)) { + responseCallback(Errors.UNKNOWN_CONSUMER_ID.code) + } else if (!group.has(consumerId)) { + responseCallback(Errors.UNKNOWN_CONSUMER_ID.code) + } else if (generationId != group.generationId) { + responseCallback(Errors.ILLEGAL_GENERATION.code) + } else { + val consumer = group.get(consumerId) + scheduleHeartbeatExpiration(group, consumer) + responseCallback(Errors.NONE.code) + } + } + } + } + } - // TODO -------------------------------------------------------------- - // TODO: this is just a stub for new consumer testing, - // TODO: needs to be replaced with the logic above - // TODO -------------------------------------------------------------- - // just return all the partitions of the subscribed topics - val partitionIdsPerTopic = ZkUtils.getPartitionsForTopics(zkClient, topics) - val partitions = partitionIdsPerTopic.flatMap{ case (topic, partitionIds) => - partitionIds.map(partition => { - TopicAndPartition(topic, partition) - }) - }.toList + private def scheduleHeartbeatExpiration(group: Group, consumer: Consumer) { + consumer.latestHeartbeat = SystemTime.milliseconds + val 
consumerKey = ConsumerKey(consumer.groupId, consumer.consumerId) + // TODO: can we fix DelayedOperationPurgatory to remove keys in watchersForKey with empty watchers list? + heartbeatPurgatory.checkAndComplete(consumerKey) + val heartbeatDeadline = consumer.latestHeartbeat + consumer.sessionTimeoutMs + val delayedHeartbeat = new DelayedHeartbeat(this, group, consumer, heartbeatDeadline, consumer.sessionTimeoutMs) + heartbeatPurgatory.tryCompleteElseWatch(delayedHeartbeat, Seq(consumerKey)) + } - responseCallback(partitions, 1 /* generation id */, Errors.NONE.code) + private def addGroup(groupId: String, partitionAssignmentStrategy: String) = { + inWriteLock(coordinatorLock) { + groups.getOrElseUpdate(groupId, new Group(groupId, partitionAssignmentStrategy)) + } + } - info("Handled join-group from consumer " + consumerId + " to group " + groupId) + private def removeGroup(group: Group) { + group.transitionTo(Dead) + info("Group %s generation %s is dead".format(group.groupId, group.generationId)) + inWriteLock(coordinatorLock) { + groups.remove(group.groupId) + } } - /** - * Process a heartbeat request from a consumer - */ - def consumerHeartbeat(groupId: String, - consumerId: String, - generationId: Int, - responseCallback: Short => Unit) { - - // check that the group already exists - // TODO - - // check that the consumer has already registered for the group - // TODO - - // check if the consumer generation id is correct - // TODO - - // remove the consumer from its current heartbeat bucket, and add it back to the corresponding bucket - // TODO - - // create the heartbeat response, if partition rebalance is triggered set the corresponding error code - // TODO - - info("Handled heartbeat of consumer " + consumerId + " from group " + groupId) - - // TODO -------------------------------------------------------------- - // TODO: this is just a stub for new consumer testing, - // TODO: needs to be replaced with the logic above - // TODO -------------------------------------------------------------- - // check if the consumer already exist, if yes return OK, - // otherwise return illegal generation error - if (consumerGroupRegistries.contains(groupId) - && consumerGroupRegistries(groupId).memberRegistries.contains(consumerId)) - responseCallback(Errors.NONE.code) - else - responseCallback(Errors.ILLEGAL_GENERATION.code) + private def addConsumer(consumerId: String, + topics: Set[String], + sessionTimeoutMs: Int, + group: Group) = { + val consumer = new Consumer(consumerId, group.groupId, topics, sessionTimeoutMs) + inWriteLock(coordinatorLock) { + topics.foreach(topic => associateGroupWithTopic(group, topic)) + } + group.add(consumerId, consumer) + consumer + } + + private def removeConsumer(group: Group, consumer: Consumer) { + trace("Consumer %s in group %s has failed".format(consumer.consumerId, group.groupId)) + group.remove(consumer.consumerId) + val remainingTopicsForGroup = group.topics + inWriteLock(coordinatorLock) { + consumer.topics.foreach { topic => + // nobody else in the group was interested in the topic, so dissociate the group from the topic + if (!remainingTopicsForGroup.contains(topic)) + dissociateGroupFromTopic(group, topic) + } + } } /** - * Create a new consumer + * Update a consumer when their subscribed topics change. 
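+ *
+ * A worked example (hypothetical topic names):
+ * {{{
+ * // given consumer.topics == Set("a", "b"):
+ * updateConsumer(group, consumer, Set("b", "c"))
+ * // unsubscribedTopics = Set("a")    -> group dissociated from "a" if no other member wants it
+ * // newlySubscribedTopics = Set("c") -> group associated with "c"
+ * }}}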
*/ - private def createNewConsumer(groupId: String, - consumerId: String, - topics: List[String], - sessionTimeoutMs: Int) { - debug("Registering consumer " + consumerId + " for group " + groupId) - - // create the new consumer registry entry - val consumerRegistry = new ConsumerRegistry(groupId, consumerId, topics, sessionTimeoutMs) - - consumerGroupRegistries(groupId).memberRegistries.put(consumerId, consumerRegistry) - - // check if the partition assignment strategy is consistent with the group - // TODO + private def updateConsumer(group: Group, consumer: Consumer, topics: Set[String]) { + val unsubscribedTopics = consumer.topics -- topics + val newlySubscribedTopics = topics -- consumer.topics + consumer.topics = topics + val remainingTopicsForGroup = group.topics + inWriteLock(coordinatorLock) { + unsubscribedTopics.foreach { topic => + // nobody else in the group was interested in the topic, so dissociate the group from the topic + if (!remainingTopicsForGroup.contains(topic)) + dissociateGroupFromTopic(group, topic) + } + newlySubscribedTopics.foreach(topic => associateGroupWithTopic(group, topic)) + } + } - // add the group to the subscribed topics - // TODO + private def associateGroupWithTopic(group: Group, topic: String) { + val currentGroupsForTopic = groupsPerTopic.getOrElse(topic, Set.empty) + groupsPerTopic.put(topic, currentGroupsForTopic + group.groupId) + topicPartitionCounts.getOrElseUpdate(topic, getTopicPartitionCountFromZK(topic)) + registerTopicPartitionChangeListener(topic) + } - // schedule heartbeat tasks for the consumer - // TODO + private def dissociateGroupFromTopic(group: Group, topic: String) { + val remainingGroupsForTopic = groupsPerTopic(topic) - group.groupId + // no other group cares about the topic, so erase all state associated with the topic + if (remainingGroupsForTopic.isEmpty) { + groupsPerTopic.remove(topic) + topicPartitionCounts.remove(topic) + deregisterTopicPartitionChangeListener(topic) + } else { + groupsPerTopic.put(topic, remainingGroupsForTopic) + } + } - // add the member registry entry to the group - // TODO + private def maybePrepareRebalance(group: Group) { + if (group.canRebalance) + prepareRebalance(group) + } - // start preparing group partition rebalance - // TODO + private def prepareRebalance(group: Group) { + group.transitionTo(PreparingRebalance) + group.generationId += 1 + info("Preparing to rebalance group %s generation %s".format(group.groupId, group.generationId)) - info("Registered consumer " + consumerId + " for group " + groupId) + val rebalanceTimeout = group.rebalanceTimeout + val delayedRebalance = new DelayedRebalance(this, group, rebalanceTimeout) + val consumerGroupKey = ConsumerGroupKey(group.groupId) + rebalancePurgatory.tryCompleteElseWatch(delayedRebalance, Seq(consumerGroupKey)) } - /** - * Create a new consumer group in the registry - */ - private def createNewGroup(groupId: String, partitionAssignmentStrategy: String) { - debug("Creating new group " + groupId) - - val groupRegistry = new GroupRegistry(groupId, partitionAssignmentStrategy) + private def rebalance(group: Group) { + group.transitionTo(Rebalancing) + info("Rebalancing group %s generation %s".format(group.groupId, group.generationId)) - consumerGroupRegistries.put(groupId, groupRegistry) + val assignedPartitionsPerConsumer = reassignPartitions(group) + trace("Rebalance for group %s generation %s has assigned partitions: %s" + .format(group.groupId, group.generationId, assignedPartitionsPerConsumer)) - info("Created new group registry " + 
groupId) + group.transitionTo(Stable) + info("Stabilized group %s generation %s".format(group.groupId, group.generationId)) + val consumerGroupKey = ConsumerGroupKey(group.groupId) + joinGroupPurgatory.checkAndComplete(consumerGroupKey) } - /** - * Callback invoked when a consumer's heartbeat has expired - */ - private def onConsumerHeartbeatExpired(groupId: String, consumerId: String) { - - // if the consumer does not exist in group registry anymore, do nothing - // TODO + private def onRebalanceFailure(group: Group, failedConsumers: List[Consumer]) { + failedConsumers.foreach { failedConsumer => + removeConsumer(group, failedConsumer) + // TODO: cut the socket connection to the consumer + } - // record heartbeat failure - // TODO + if (group.isEmpty) + removeGroup(group) + } + private def onConsumerHeartbeatExpired(group: Group, consumer: Consumer) { + consumer.numExpiredHeartbeat += 1 // if the maximum failures has been reached, mark consumer as failed - // TODO + // TODO: figure out a value for maximum failures + if (consumer.numExpiredHeartbeat > 0) + onConsumerFailure(group, consumer) } - /** - * Callback invoked when a consumer is marked as failed - */ - private def onConsumerFailure(groupId: String, consumerId: String) { - - // remove the consumer from its group registry metadata - // TODO - - // cut the socket connection to the consumer - // TODO: howto ?? - - // if the group has no consumer members any more, remove the group - // otherwise start preparing group partition rebalance - // TODO - + private def onConsumerFailure(group: Group, consumer: Consumer) { + removeConsumer(group, consumer) + maybePrepareRebalance(group) } - /** - * Prepare partition rebalance for the group - */ - private def prepareRebalance(groupId: String) { - - // try to change the group state to PrepareRebalance + private def isCoordinatorForGroup(groupId: String) = offsetManager.leaderIsLocal(offsetManager.partitionFor(groupId)) - // add a task to the delayed rebalance purgatory + private def reassignPartitions(group: Group) = { + val assignor = PartitionAssignor.createInstance(group.partitionAssignmentStrategy) + val topicsPerConsumer = group.topicsPerConsumer + val partitionsPerTopic = inReadLock(coordinatorLock) { + topicPartitionCounts.toMap + } + val assignedPartitionsPerConsumer = assignor.assign(topicsPerConsumer, partitionsPerTopic) + assignedPartitionsPerConsumer.foreach { case (consumerId, partitions) => + group.get(consumerId).assignedTopicPartitions = partitions + } + assignedPartitionsPerConsumer + } - // TODO + private def registerTopicPartitionChangeListener(topic: String) { + inWriteLock(coordinatorLock) { + if (!topicPartitionChangeListeners.contains(topic)) { + val listener = new TopicPartitionChangeListener(config) + topicPartitionChangeListeners.put(topic, listener) + zkClient.subscribeDataChanges(ZkUtils.getTopicPath(topic), listener) + } + } } - /** - * Start partition rebalance for the group - */ - private def startRebalance(groupId: String) { + private def deregisterTopicPartitionChangeListener(topic: String) { + inWriteLock(coordinatorLock) { + val listener = topicPartitionChangeListeners(topic) + zkClient.unsubscribeDataChanges(ZkUtils.getTopicPath(topic), listener) + topicPartitionChangeListeners.remove(topic) + } + } - // try to change the group state to UnderRebalance + private def getTopicPartitionCountFromZK(topic: String) = { + val topicData = ZkUtils.getPartitionAssignmentForTopics(zkClient, Seq(topic)) + topicData(topic).size + } - // compute new assignment based on the 
strategy + def tryCompleteJoinGroup(group: Group, forceComplete: () => Boolean) = { + group synchronized { + if (group.is(Stable)) + forceComplete() + else false + } + } - // send back the join-group response + def onExpirationJoinGroup() {} + def onCompleteJoinGroup(group: Group, + consumer: Consumer, + responseCallback:(Set[TopicAndPartition], String, Int, Short) => Unit) { + group synchronized { + consumer.awaitingRebalance = false + scheduleHeartbeatExpiration(group, consumer) + responseCallback(consumer.assignedTopicPartitions, consumer.consumerId, group.generationId, Errors.NONE.code) + } + } - // TODO + def tryCompleteRebalance(group: Group, forceComplete: () => Boolean) = { + group synchronized { + if (group.allConsumersRejoined) + forceComplete() + else false + } } - /** - * Fail current partition rebalance for the group - */ + def onExpirationRebalance() {} + def onCompleteRebalance(group: Group) { + group synchronized { + val failedConsumers = group.nonRejoinedConsumers + if (group.isEmpty || !failedConsumers.isEmpty) + onRebalanceFailure(group, failedConsumers) + if (!group.is(Dead)) + rebalance(group) + } + } - /** - * Register ZK listeners for topic-partition changes - */ - private def registerTopicChangeListener(topic: String) = { - if (!topicPartitionChangeListeners.contains(topic)) { - val listener = new TopicPartitionChangeListener(config) - topicPartitionChangeListeners.put(topic, listener) - ZkUtils.makeSurePersistentPathExists(zkClient, ZkUtils.getTopicPath(topic)) - zkClient.subscribeChildChanges(ZkUtils.getTopicPath(topic), listener) + def tryCompleteHeartbeat(group: Group, consumer: Consumer, heartbeatDeadline: Long, forceComplete: () => Boolean) = { + group synchronized { + if (shouldKeepConsumerAlive(consumer, heartbeatDeadline)) + forceComplete() + else false } } - /** - * De-register ZK listeners for topic-partition changes - */ - private def deregisterTopicChangeListener(topic: String) = { - val listener = topicPartitionChangeListeners.get(topic).get - zkClient.unsubscribeChildChanges(ZkUtils.getTopicPath(topic), listener) - topicPartitionChangeListeners.remove(topic) + def onExpirationHeartbeat(group: Group, consumer: Consumer, heartbeatDeadline: Long) { + group synchronized { + if (!shouldKeepConsumerAlive(consumer, heartbeatDeadline)) + onConsumerHeartbeatExpired(group, consumer) + } } + def onCompleteHeartbeat() {} + + private def shouldKeepConsumerAlive(consumer: Consumer, heartbeatDeadline: Long) = + consumer.awaitingRebalance || consumer.latestHeartbeat > heartbeatDeadline - consumer.sessionTimeoutMs + /** * Zookeeper listener that catch topic-partition changes */ - class TopicPartitionChangeListener(val config: KafkaConfig) extends IZkChildListener with Logging { - - this.logIdent = "[TopicChangeListener on coordinator " + config.brokerId + "]: " - - /** - * Try to trigger a rebalance for each group subscribed in the changed topic - * - * @throws Exception - * On any error. 
- */ - def handleChildChange(parentPath: String , curChilds: java.util.List[String]) { - debug("Fired for path %s with children %s".format(parentPath, curChilds)) - - // get the topic - val topic = parentPath.split("/").last + class TopicPartitionChangeListener(val config: KafkaConfig) extends IZkDataListener with Logging { + this.logIdent = "[TopicPartitionChangeListener on Coordinator " + config.brokerId + "]: " + + override def handleDataChange(dataPath: String, data: Object) { + info("Handling data change for path: %s data: %s".format(dataPath, data)) + val topic = topicFromDataPath(dataPath) + val numPartitions = getTopicPartitionCountFromZK(topic) + + val groupsToRebalance = inWriteLock(coordinatorLock) { + // This condition exists because a consumer can leave between the data change and the coordinatorLock above + if (topicPartitionCounts.contains(topic)) { + topicPartitionCounts.put(topic, numPartitions) + groupsPerTopic(topic).map(groupId => groups(groupId)) + } + else Set.empty[Group] + } + prepareRebalanceAll(groupsToRebalance) + } - // get groups that subscribed to this topic - val groups = consumerGroupsPerTopic.get(topic).get + override def handleDataDeleted(dataPath: String) { + info("Handling data delete for path: %s".format(dataPath)) + val topic = topicFromDataPath(dataPath) + val groupsToRebalance = inWriteLock(coordinatorLock) { + // This condition exists because a consumer can leave between the data delete and the coordinatorLock above + if (topicPartitionCounts.contains(topic)) { + topicPartitionCounts.put(topic, 0) + groupsPerTopic(topic).map(groupId => groups(groupId)) + } + else Set.empty[Group] + } + prepareRebalanceAll(groupsToRebalance) + } - for (groupId <- groups) { - prepareRebalance(groupId) + private def prepareRebalanceAll(groupsToRebalance: Set[Group]) { + groupsToRebalance.foreach { group => + group synchronized { + /** + * This condition exists because a consumer can leave between the coordinatorLock above and the prepareRebalance below. + * We can't wrap this loop in the coordinatorLock because we must preserve the ordering of nested locks to prevent deadlock. + */ + maybePrepareRebalance(group) + } } } + + private def topicFromDataPath(dataPath: String) = { + val nodes = dataPath.split("/") + nodes.last + } } } - - diff --git a/core/src/main/scala/kafka/coordinator/ConsumerRegistry.scala b/core/src/main/scala/kafka/coordinator/ConsumerRegistry.scala deleted file mode 100644 index 2f57970..0000000 --- a/core/src/main/scala/kafka/coordinator/ConsumerRegistry.scala +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package kafka.coordinator - -import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger} -import java.util.HashMap - -/** - * Consumer registry metadata contains the following metadata: - * - * Heartbeat metadata: - * 1. negotiated heartbeat session timeout. - * 2. recorded number of timed-out heartbeats. - * 3. associated heartbeat bucket in the purgatory. - * - * Subscription metadata: - * 1. subscribed topic list - * 2. assigned partitions for the subscribed topics. - */ -class ConsumerRegistry(val groupId: String, - val consumerId: String, - val topics: List[String], - val sessionTimeoutMs: Int) { - - /* number of expired heartbeat recorded */ - val numExpiredHeartbeat = new AtomicInteger(0) - - /* flag indicating if join group request is received */ - val joinGroupReceived = new AtomicBoolean(false) - - /* assigned partitions per subscribed topic */ - val assignedPartitions = new HashMap[String, List[Int]] - - /* associated heartbeat bucket */ - var currentHeartbeatBucket = null - -} diff --git a/core/src/main/scala/kafka/coordinator/DelayedHeartbeat.scala b/core/src/main/scala/kafka/coordinator/DelayedHeartbeat.scala index 6a6bc7b..b3360cc 100644 --- a/core/src/main/scala/kafka/coordinator/DelayedHeartbeat.scala +++ b/core/src/main/scala/kafka/coordinator/DelayedHeartbeat.scala @@ -20,29 +20,16 @@ package kafka.coordinator import kafka.server.DelayedOperation /** - * Delayed heartbeat operations that are added to the purgatory for session-timeout checking - * - * These operations will always be expired. Once it has expired, all its - * currently contained consumers are marked as heartbeat timed out. + * Delayed heartbeat operations that are added to the purgatory for session timeout checking. + * Heartbeats are paused during rebalance. */ -class DelayedHeartbeat(sessionTimeout: Long, - bucket: HeartbeatBucket, - expireCallback: (String, String) => Unit) +private[coordinator] class DelayedHeartbeat(consumerCoordinator: ConsumerCoordinator, + group: Group, + consumer: Consumer, + heartbeatDeadline: Long, + sessionTimeout: Long) extends DelayedOperation(sessionTimeout) { - - /* this function should never be called */ - override def tryComplete(): Boolean = { - - throw new IllegalStateException("Delayed heartbeat purgatory should never try to complete any bucket") - } - - override def onExpiration() { - // TODO - } - - /* mark all consumers within the heartbeat as heartbeat timed out */ - override def onComplete() { - for (registry <- bucket.consumerRegistryList) - expireCallback(registry.groupId, registry.consumerId) - } + override def tryComplete(): Boolean = consumerCoordinator.tryCompleteHeartbeat(group, consumer, heartbeatDeadline, forceComplete) + override def onExpiration() = consumerCoordinator.onExpirationHeartbeat(group, consumer, heartbeatDeadline) + override def onComplete() = consumerCoordinator.onCompleteHeartbeat() } diff --git a/core/src/main/scala/kafka/coordinator/DelayedJoinGroup.scala b/core/src/main/scala/kafka/coordinator/DelayedJoinGroup.scala index df60cbc..8f57d38 100644 --- a/core/src/main/scala/kafka/coordinator/DelayedJoinGroup.scala +++ b/core/src/main/scala/kafka/coordinator/DelayedJoinGroup.scala @@ -17,6 +17,7 @@ package kafka.coordinator +import kafka.common.TopicAndPartition import kafka.server.DelayedOperation /** @@ -26,23 +27,13 @@ import kafka.server.DelayedOperation * join-group operations will be completed by sending back the response with the * calculated partition assignment. 
*/ -class DelayedJoinGroup(sessionTimeout: Long, - consumerRegistry: ConsumerRegistry, - responseCallback: () => Unit) extends DelayedOperation(sessionTimeout) { - - /* always successfully complete the operation once called */ - override def tryComplete(): Boolean = { - forceComplete() - } - - override def onExpiration() { - // TODO - } - - /* always assume the partition is already assigned as this delayed operation should never time-out */ - override def onComplete() { - - // TODO - responseCallback - } -} \ No newline at end of file +private[coordinator] class DelayedJoinGroup(consumerCoordinator: ConsumerCoordinator, + group: Group, + consumer: Consumer, + sessionTimeout: Long, + responseCallback:(Set[TopicAndPartition], String, Int, Short) => Unit) + extends DelayedOperation(sessionTimeout) { + override def tryComplete(): Boolean = consumerCoordinator.tryCompleteJoinGroup(group, forceComplete) + override def onExpiration() = consumerCoordinator.onExpirationJoinGroup() + override def onComplete() = consumerCoordinator.onCompleteJoinGroup(group, consumer, responseCallback) +} diff --git a/core/src/main/scala/kafka/coordinator/DelayedRebalance.scala b/core/src/main/scala/kafka/coordinator/DelayedRebalance.scala index 8defa2e..12e77ec 100644 --- a/core/src/main/scala/kafka/coordinator/DelayedRebalance.scala +++ b/core/src/main/scala/kafka/coordinator/DelayedRebalance.scala @@ -18,7 +18,6 @@ package kafka.coordinator import kafka.server.DelayedOperation -import java.util.concurrent.atomic.AtomicBoolean /** @@ -31,36 +30,12 @@ import java.util.concurrent.atomic.AtomicBoolean * the group are marked as failed, and complete this operation to proceed rebalance with * the rest of the group. */ -class DelayedRebalance(sessionTimeout: Long, - groupRegistry: GroupRegistry, - rebalanceCallback: String => Unit, - failureCallback: (String, String) => Unit) +private[coordinator] class DelayedRebalance(consumerCoordinator: ConsumerCoordinator, + group: Group, + sessionTimeout: Long) extends DelayedOperation(sessionTimeout) { - val allConsumersJoinedGroup = new AtomicBoolean(false) - - /* check if all known consumers have requested to re-join group */ - override def tryComplete(): Boolean = { - allConsumersJoinedGroup.set(groupRegistry.memberRegistries.values.foldLeft - (true) ((agg, cur) => agg && cur.joinGroupReceived.get())) - - if (allConsumersJoinedGroup.get()) - forceComplete() - else - false - } - - override def onExpiration() { - // TODO - } - - /* mark consumers that have not re-joined group as failed and proceed to rebalance the rest of the group */ - override def onComplete() { - groupRegistry.memberRegistries.values.foreach(consumerRegistry => - if (!consumerRegistry.joinGroupReceived.get()) - failureCallback(groupRegistry.groupId, consumerRegistry.consumerId) - ) - - rebalanceCallback(groupRegistry.groupId) - } + override def tryComplete(): Boolean = consumerCoordinator.tryCompleteRebalance(group, forceComplete) + override def onExpiration() = consumerCoordinator.onExpirationRebalance() + override def onComplete() = consumerCoordinator.onCompleteRebalance(group) } diff --git a/core/src/main/scala/kafka/coordinator/Group.scala b/core/src/main/scala/kafka/coordinator/Group.scala new file mode 100644 index 0000000..63817c4 --- /dev/null +++ b/core/src/main/scala/kafka/coordinator/Group.scala @@ -0,0 +1,131 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kafka.coordinator
+
+import kafka.utils.nonthreadsafe
+
+import java.util.UUID
+
+import collection.mutable
+
+private[coordinator] sealed trait GroupState { def state: Byte }
+
+/**
+ * Consumer group is preparing to rebalance
+ *
+ * action: respond to heartbeats with an ILLEGAL_GENERATION error code
+ * transition: some consumers have joined by the timeout => Rebalancing
+ *             all consumers have left the group => Dead
+ */
+private[coordinator] case object PreparingRebalance extends GroupState { val state: Byte = 1 }
+
+/**
+ * Consumer group is rebalancing
+ *
+ * action: compute the group's partition assignment
+ *         send the join-group response with new partition assignment when rebalance is complete
+ * transition: partition assignment has been computed => Stable
+ */
+private[coordinator] case object Rebalancing extends GroupState { val state: Byte = 2 }
+
+/**
+ * Consumer group is stable
+ *
+ * action: respond to consumer heartbeats normally
+ * transition: consumer failure detected via heartbeat => PreparingRebalance
+ *             consumer join-group received => PreparingRebalance
+ *             zookeeper topic watcher fired => PreparingRebalance
+ */
+private[coordinator] case object Stable extends GroupState { val state: Byte = 3 }
+
+/**
+ * Consumer group has no more members
+ *
+ * action: none
+ * transition: none
+ */
+private[coordinator] case object Dead extends GroupState { val state: Byte = 4 }
+
+
+/**
+ * A group contains the following metadata:
+ *
+ * Membership metadata:
+ * 1. Consumers registered in this group
+ * 2. Partition assignment strategy for this group
+ *
+ * State metadata:
+ * 1. group state
+ * 2. generation id
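+ *
+ * A sketch of the legal transition sequences (illustrative):
+ * {{{
+ * val group = new Group("groupId", "range") // a new group starts out Stable
+ * group.transitionTo(PreparingRebalance)    // join-group, consumer failure, or topic change
+ * group.transitionTo(Rebalancing)           // the assignment is being computed
+ * group.transitionTo(Stable)                // assignment handed out; heartbeats resume
+ * // or, from PreparingRebalance, once every consumer has left:
+ * // group.transitionTo(Dead)
+ * }}}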
+ */
+@nonthreadsafe
+private[coordinator] class Group(val groupId: String,
+                                 val partitionAssignmentStrategy: String) {
+
+  private val validPreviousStates: Map[GroupState, Set[GroupState]] =
+    Map(Dead -> Set(PreparingRebalance),
+        Stable -> Set(Rebalancing),
+        PreparingRebalance -> Set(Stable),
+        Rebalancing -> Set(PreparingRebalance))
+
+  private val consumers = new mutable.HashMap[String, Consumer]
+  private var state: GroupState = Stable
+  var generationId = 0
+
+  def is(groupState: GroupState) = state == groupState
+  def has(consumerId: String) = consumers.contains(consumerId)
+  def get(consumerId: String) = consumers(consumerId)
+
+  def add(consumerId: String, consumer: Consumer) {
+    consumers.put(consumerId, consumer)
+  }
+
+  def remove(consumerId: String) {
+    consumers.remove(consumerId)
+  }
+
+  def isEmpty = consumers.isEmpty
+
+  def topicsPerConsumer = consumers.mapValues(_.topics).toMap
+
+  def topics = consumers.values.flatMap(_.topics).toSet
+
+  def allConsumersRejoined = consumers.values.forall(_.awaitingRebalance)
+
+  def nonRejoinedConsumers = consumers.values.filter(!_.awaitingRebalance).toList
+
+  def rebalanceTimeout = consumers.values.foldLeft(0) { (timeout, consumer) =>
+    timeout.max(consumer.sessionTimeoutMs)
+  }
+
+  // TODO: decide if ids should be predictable or random
+  def generateNextConsumerId = UUID.randomUUID().toString
+
+  def canRebalance = state == Stable
+
+  def transitionTo(groupState: GroupState) {
+    assertValidTransition(groupState)
+    state = groupState
+  }
+
+  private def assertValidTransition(targetState: GroupState) {
+    if (!validPreviousStates(targetState).contains(state))
+      throw new IllegalStateException("Group %s should be in one of the states %s before moving to %s state. Instead it is in %s state"
+        .format(groupId, validPreviousStates(targetState).mkString(","), targetState, state))
+  }
+}
\ No newline at end of file
diff --git a/core/src/main/scala/kafka/coordinator/GroupRegistry.scala b/core/src/main/scala/kafka/coordinator/GroupRegistry.scala
deleted file mode 100644
index 94ef582..0000000
--- a/core/src/main/scala/kafka/coordinator/GroupRegistry.scala
+++ /dev/null
@@ -1,79 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -package kafka.coordinator - -import scala.collection.mutable -import java.util.concurrent.atomic.AtomicInteger - -sealed trait GroupStates { def state: Byte } - -/** - * Consumer group is preparing start rebalance - * - * action: respond consumer heartbeat with error code, - * transition: all known consumers has re-joined group => UnderRebalance - */ -case object PrepareRebalance extends GroupStates { val state: Byte = 1 } - -/** - * Consumer group is under rebalance - * - * action: send the join-group response with new assignment - * transition: all consumers has heartbeat with the new generation id => Fetching - * new consumer join-group received => PrepareRebalance - */ -case object UnderRebalance extends GroupStates { val state: Byte = 2 } - -/** - * Consumer group is fetching data - * - * action: respond consumer heartbeat normally - * transition: consumer failure detected via heartbeat => PrepareRebalance - * consumer join-group received => PrepareRebalance - * zookeeper watcher fired => PrepareRebalance - */ -case object Fetching extends GroupStates { val state: Byte = 3 } - -case class GroupState() { - @volatile var currentState: Byte = PrepareRebalance.state -} - -/* Group registry contains the following metadata of a registered group in the coordinator: - * - * Membership metadata: - * 1. List of consumers registered in this group - * 2. Partition assignment strategy for this group - * - * State metadata: - * 1. Current group state - * 2. Current group generation id - */ -class GroupRegistry(val groupId: String, - val partitionAssignmentStrategy: String) { - - val memberRegistries = new mutable.HashMap[String, ConsumerRegistry]() - - val state: GroupState = new GroupState() - - val generationId = new AtomicInteger(1) - - val nextConsumerId = new AtomicInteger(1) - - def generateNextConsumerId = groupId + "-" + nextConsumerId.getAndIncrement -} - diff --git a/core/src/main/scala/kafka/coordinator/HeartbeatBucket.scala b/core/src/main/scala/kafka/coordinator/HeartbeatBucket.scala deleted file mode 100644 index 821e26e..0000000 --- a/core/src/main/scala/kafka/coordinator/HeartbeatBucket.scala +++ /dev/null @@ -1,36 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package kafka.coordinator - -import scala.collection.mutable - -/** - * A bucket of consumers that are scheduled for heartbeat expiration. - * - * The motivation behind this is to avoid expensive fine-grained per-consumer - * heartbeat expiration but use coarsen-grained methods that group consumers - * with similar deadline together. This will result in some consumers not - * being expired for heartbeats in time but is tolerable. 
- */ -class HeartbeatBucket(val startMs: Long, endMs: Long) { - - /* The list of consumers that are contained in this bucket */ - val consumerRegistryList = new mutable.HashSet[ConsumerRegistry] - - // TODO -} diff --git a/core/src/main/scala/kafka/coordinator/PartitionAssignor.scala b/core/src/main/scala/kafka/coordinator/PartitionAssignor.scala new file mode 100644 index 0000000..2f2adc5 --- /dev/null +++ b/core/src/main/scala/kafka/coordinator/PartitionAssignor.scala @@ -0,0 +1,121 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.coordinator + +import kafka.common.TopicAndPartition +import kafka.utils.CoreUtils + +private[coordinator] trait PartitionAssignor { + /** + * Assigns partitions to consumers in a group. + * @return A mapping from consumer to assigned partitions. + */ + def assign(topicsPerConsumer: Map[String, Set[String]], + partitionsPerTopic: Map[String, Int]): Map[String, Set[TopicAndPartition]] + + protected def fill[K, V](vsPerK: Map[K, Set[V]], expectedKs: Set[K]): Map[K, Set[V]] = { + val unfilledKs = expectedKs -- vsPerK.keySet + vsPerK ++ unfilledKs.map(k => k -> Set.empty[V]) + } + + protected def aggregate[K, V](pairs: Seq[(K, V)]): Map[K, Set[V]] = { + pairs + .groupBy { case (k, v) => k } + .map { case (k, kvPairs) => (k, kvPairs.map(_._2).toSet) } + } + + protected def invert[K, V](vsPerK: Map[K, Set[V]]): Map[V, Set[K]] = { + val vkPairs = vsPerK.toSeq.flatMap { case (k, vs) => vs.map(v => (v, k)) } + aggregate(vkPairs) + } +} + +private[coordinator] object PartitionAssignor { + val strategies = Set("range", "roundrobin") + + def createInstance(strategy: String) = strategy match { + case "roundrobin" => new RoundRobinAssignor() + case _ => new RangeAssignor() + } +} + +/** + * The roundrobin assignor lays out all the available partitions and all the available consumers. It + * then proceeds to do a roundrobin assignment from partition to consumer. If the subscriptions of all consumer + * instances are identical, then the partitions will be uniformly distributed. (i.e., the partition ownership counts + * will be within a delta of exactly one across all consumers.) + * + * roundrobin assignment is allowed only if the set of subscribed topics is identical for every consumer within the group. 
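+ *
+ * A worked example (hypothetical names): consumers C0 and C1 both subscribe to topics t0
+ * and t1, each with three partitions. The partitions are laid out as t0p0, t0p1, t0p2,
+ * t1p0, t1p1, t1p2 and dealt out in turn:
+ * {{{
+ * new RoundRobinAssignor().assign(
+ *   Map("C0" -> Set("t0", "t1"), "C1" -> Set("t0", "t1")),
+ *   Map("t0" -> 3, "t1" -> 3))
+ * // => C0: [t0p0, t0p2, t1p1], C1: [t0p1, t1p0, t1p2]
+ * }}}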
+ */ +private[coordinator] class RoundRobinAssignor extends PartitionAssignor { + override def assign(topicsPerConsumer: Map[String, Set[String]], + partitionsPerTopic: Map[String, Int]): Map[String, Set[TopicAndPartition]] = { + val consumersHaveIdenticalTopics = topicsPerConsumer.values.toSet.size == 1 + require(consumersHaveIdenticalTopics, + "Round-robin assignment is allowed only if all consumers in the group subscribe to the same topics") + val consumers = topicsPerConsumer.keys.toSeq.sorted + val topics = topicsPerConsumer.head._2 + val consumerAssignor = CoreUtils.circularIterator(consumers) + + val allTopicPartitions = topics.toSeq.flatMap { topic => + val numPartitionsForTopic = partitionsPerTopic(topic) + (0 until numPartitionsForTopic).map(partition => TopicAndPartition(topic, partition)) + } + + val consumerPartitionPairs = allTopicPartitions.map { topicAndPartition => + val consumer = consumerAssignor.next() + (consumer, topicAndPartition) + } + fill(aggregate(consumerPartitionPairs), topicsPerConsumer.keySet) + } +} + +/** + * The range assignor works on a per-topic basis. For each topic, we lay out the available partitions in numeric order + * and the consumers in lexicographic order. We then divide the number of partitions by the total number of + * consumers to determine the number of partitions to assign to each consumer. If it does not evenly + * divide, then the first few consumers will have one extra partition. For example, suppose there are two consumers C1 + * and C2, and there are five available partitions (p0, p1, p2, p3, p4). Each consumer + * will get at least two partitions and the first consumer will get one extra partition. So the assignment will be: + * p0 -> C1, p1 -> C1, p2 -> C1, p3 -> C2, p4 -> C2 + */ +private[coordinator] class RangeAssignor extends PartitionAssignor { + override def assign(topicsPerConsumer: Map[String, Set[String]], + partitionsPerTopic: Map[String, Int]): Map[String, Set[TopicAndPartition]] = { + val consumersPerTopic = invert(topicsPerConsumer) + val consumerPartitionPairs = consumersPerTopic.toSeq.flatMap { case (topic, consumersForTopic) => + val numPartitionsForTopic = partitionsPerTopic(topic) + + val numPartitionsPerConsumer = numPartitionsForTopic / consumersForTopic.size + val consumersWithExtraPartition = numPartitionsForTopic % consumersForTopic.size + + consumersForTopic.toSeq.sorted.zipWithIndex.flatMap { case (consumerForTopic, consumerIndex) => + val startPartition = numPartitionsPerConsumer * consumerIndex + consumerIndex.min(consumersWithExtraPartition) + val numPartitions = numPartitionsPerConsumer + (if (consumerIndex + 1 > consumersWithExtraPartition) 0 else 1) + + /* + * Range-partition the sorted partitions to consumers for better locality. + * The first few consumers pick up an extra partition, if any. 
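+        * For the scaladoc example above (hypothetical numbers): 5 partitions, 2 consumers =>
+        * numPartitionsPerConsumer = 2, consumersWithExtraPartition = 1, so consumer index 0
+        * gets startPartition = 0 and numPartitions = 3 (p0, p1, p2), while consumer index 1
+        * gets startPartition = 3 and numPartitions = 2 (p3, p4).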
+        */
+        (startPartition until startPartition + numPartitions)
+          .map(partition => (consumerForTopic, TopicAndPartition(topic, partition)))
+      }
+    }
+    fill(aggregate(consumerPartitionPairs), topicsPerConsumer.keySet)
+  }
+}
diff --git a/core/src/main/scala/kafka/server/DelayedOperationKey.scala b/core/src/main/scala/kafka/server/DelayedOperationKey.scala
index b673e43..c122bde 100644
--- a/core/src/main/scala/kafka/server/DelayedOperationKey.scala
+++ b/core/src/main/scala/kafka/server/DelayedOperationKey.scala
@@ -38,12 +38,6 @@ case class TopicPartitionOperationKey(topic: String, partition: Int) extends Del
   override def keyLabel = "%s-%d".format(topic, partition)
 }
 
-/* used by bucketized delayed-heartbeat operations */
-case class TTimeMsKey(time: Long) extends DelayedOperationKey {
-
-  override def keyLabel = "%d".format(time)
-}
-
 /* used by delayed-join-group operations */
 case class ConsumerKey(groupId: String, consumerId: String) extends DelayedOperationKey {
diff --git a/core/src/main/scala/kafka/server/KafkaApis.scala b/core/src/main/scala/kafka/server/KafkaApis.scala
index b4004aa..9abca4e 100644
--- a/core/src/main/scala/kafka/server/KafkaApis.scala
+++ b/core/src/main/scala/kafka/server/KafkaApis.scala
@@ -499,17 +499,19 @@ class KafkaApis(val requestChannel: RequestChannel,
     val respHeader = new ResponseHeader(request.header.correlationId)
 
     // the callback for sending a join-group response
-    def sendResponseCallback(partitions: List[TopicAndPartition], generationId: Int, errorCode: Short) {
+    def sendResponseCallback(partitions: Set[TopicAndPartition], consumerId: String, generationId: Int, errorCode: Short) {
       val partitionList = partitions.map(tp => new TopicPartition(tp.topic, tp.partition)).toBuffer
-      val responseBody = new JoinGroupResponse(errorCode, generationId, joinGroupRequest.consumerId, partitionList)
+      val responseBody = new JoinGroupResponse(errorCode, generationId, consumerId, partitionList)
+      trace("Sending join group response %s for correlation id %d to client %s."
+        .format(responseBody, request.header.correlationId, request.header.clientId))
       requestChannel.sendResponse(new RequestChannel.Response(request, new BoundedByteBufferSend(respHeader, responseBody)))
     }
 
     // let the coordinator to handle join-group
-    coordinator.consumerJoinGroup(
+    coordinator.joinGroup(
       joinGroupRequest.groupId(),
       joinGroupRequest.consumerId(),
-      joinGroupRequest.topics().toList,
+      joinGroupRequest.topics().toSet,
       joinGroupRequest.sessionTimeout(),
       joinGroupRequest.strategy(),
       sendResponseCallback)
@@ -521,12 +523,14 @@ class KafkaApis(val requestChannel: RequestChannel,
 
     // the callback for sending a heartbeat response
     def sendResponseCallback(errorCode: Short) {
-      val response = new HeartbeatResponse(errorCode)
-      requestChannel.sendResponse(new RequestChannel.Response(request, new BoundedByteBufferSend(respHeader, response)))
+      val responseBody = new HeartbeatResponse(errorCode)
+      trace("Sending heartbeat response %s for correlation id %d to client %s."
+        .format(responseBody, request.header.correlationId, request.header.clientId))
+      requestChannel.sendResponse(new RequestChannel.Response(request, new BoundedByteBufferSend(respHeader, responseBody)))
     }
 
     // let the coordinator to handle heartbeat
-    coordinator.consumerHeartbeat(
+    coordinator.heartbeat(
       heartbeatRequest.groupId(),
       heartbeatRequest.consumerId(),
       heartbeatRequest.groupGenerationId(),
diff --git a/core/src/main/scala/kafka/server/KafkaServer.scala b/core/src/main/scala/kafka/server/KafkaServer.scala
index c63f4ba..1d44107 100755
--- a/core/src/main/scala/kafka/server/KafkaServer.scala
+++ b/core/src/main/scala/kafka/server/KafkaServer.scala
@@ -141,7 +141,7 @@ class KafkaServer(val config: KafkaConfig, time: Time = SystemTime) extends Logg
       kafkaController.startup()
 
       /* start kafka coordinator */
-      consumerCoordinator = new ConsumerCoordinator(config, zkClient)
+      consumerCoordinator = new ConsumerCoordinator(config, zkClient, offsetManager)
       consumerCoordinator.startup()
 
       /* start processing requests */
diff --git a/core/src/main/scala/kafka/server/OffsetManager.scala b/core/src/main/scala/kafka/server/OffsetManager.scala
index 18680ce..0ae44e7 100755
--- a/core/src/main/scala/kafka/server/OffsetManager.scala
+++ b/core/src/main/scala/kafka/server/OffsetManager.scala
@@ -427,7 +427,7 @@ class OffsetManager(val config: OffsetManagerConfig,
     hw
   }
 
-  private def leaderIsLocal(partition: Int) = { getHighWatermark(partition) != -1L }
+  def leaderIsLocal(partition: Int) = { getHighWatermark(partition) != -1L }
 
   /**
    * When this broker becomes a follower for an offsets topic partition clear out the cache for groups that belong to
diff --git a/core/src/test/scala/unit/kafka/coordinator/GroupTest.scala b/core/src/test/scala/unit/kafka/coordinator/GroupTest.scala
new file mode 100644
index 0000000..4fdac12
--- /dev/null
+++ b/core/src/test/scala/unit/kafka/coordinator/GroupTest.scala
@@ -0,0 +1,155 @@
+package kafka.coordinator
+
+import junit.framework.Assert._
+import org.junit.{Before, Test}
+import org.scalatest.junit.JUnitSuite
+
+/**
+ * Test group state transitions
+ */
+class GroupTest extends JUnitSuite {
+  var group: Group = null
+
+  @Before
+  def setUp() {
+    group = new Group("test", "range")
+  }
+
+  @Test
+  def testCanRebalanceWhenStable() {
+    assertTrue(group.canRebalance)
+  }
+
+  @Test
+  def testCannotRebalanceWhenPreparingRebalance() {
+    group.transitionTo(PreparingRebalance)
+    assertFalse(group.canRebalance)
+  }
+
+  @Test
+  def testCannotRebalanceWhenRebalancing() {
+    group.transitionTo(PreparingRebalance)
+    group.transitionTo(Rebalancing)
+    assertFalse(group.canRebalance)
+  }
+
+  @Test
+  def testCannotRebalanceWhenDead() {
+    group.transitionTo(PreparingRebalance)
+    group.transitionTo(Dead)
+    assertFalse(group.canRebalance)
+  }
+
+  @Test
+  def testStableToPreparingRebalanceTransition() {
+    group.transitionTo(PreparingRebalance)
+    assertState(group, PreparingRebalance)
+  }
+
+  @Test
+  def testPreparingRebalanceToRebalancingTransition() {
+    group.transitionTo(PreparingRebalance)
+    group.transitionTo(Rebalancing)
+    assertState(group, Rebalancing)
+  }
+
+  @Test
+  def testPreparingRebalanceToDeadTransition() {
+    group.transitionTo(PreparingRebalance)
+    group.transitionTo(Dead)
+    assertState(group, Dead)
+  }
+
+  @Test
+  def testRebalancingToStableTransition() {
+    group.transitionTo(PreparingRebalance)
+    group.transitionTo(Rebalancing)
+    group.transitionTo(Stable)
+    assertState(group, Stable)
+  }
+
+  @Test(expected = classOf[IllegalStateException])
+  def testStableToStableIllegalTransition() {
+    group.transitionTo(Stable)
+  }
+
+  @Test(expected = classOf[IllegalStateException])
+  def testStableToRebalancingIllegalTransition() {
+    group.transitionTo(Rebalancing)
+  }
+
+  @Test(expected = classOf[IllegalStateException])
+  def testStableToDeadIllegalTransition() {
+    group.transitionTo(Dead)
+  }
+
+  @Test(expected = classOf[IllegalStateException])
+  def testPreparingRebalanceToPreparingRebalanceIllegalTransition() {
+    group.transitionTo(PreparingRebalance)
+    group.transitionTo(PreparingRebalance)
+  }
+
+  @Test(expected = classOf[IllegalStateException])
+  def testPreparingRebalanceToStableIllegalTransition() {
+    group.transitionTo(PreparingRebalance)
+    group.transitionTo(Stable)
+  }
+
+  @Test(expected = classOf[IllegalStateException])
+  def testRebalancingToRebalancingIllegalTransition() {
+    group.transitionTo(PreparingRebalance)
+    group.transitionTo(Rebalancing)
+    group.transitionTo(Rebalancing)
+  }
+
+  @Test(expected = classOf[IllegalStateException])
+  def testRebalancingToPreparingRebalanceIllegalTransition() {
+    group.transitionTo(PreparingRebalance)
+    group.transitionTo(Rebalancing)
+    group.transitionTo(PreparingRebalance)
+  }
+
+  @Test(expected = classOf[IllegalStateException])
+  def testRebalancingToDeadIllegalTransition() {
+    group.transitionTo(PreparingRebalance)
+    group.transitionTo(Rebalancing)
+    group.transitionTo(Dead)
+  }
+
+  @Test(expected = classOf[IllegalStateException])
+  def testDeadToDeadIllegalTransition() {
+    group.transitionTo(PreparingRebalance)
+    group.transitionTo(Dead)
+    group.transitionTo(Dead)
+  }
+
+  @Test(expected = classOf[IllegalStateException])
+  def testDeadToStableIllegalTransition() {
+    group.transitionTo(PreparingRebalance)
+    group.transitionTo(Dead)
+    group.transitionTo(Stable)
+  }
+
+  @Test(expected = classOf[IllegalStateException])
+  def testDeadToPreparingRebalanceIllegalTransition() {
+    group.transitionTo(PreparingRebalance)
+    group.transitionTo(Dead)
+    group.transitionTo(PreparingRebalance)
+  }
+
+  @Test(expected = classOf[IllegalStateException])
+  def testDeadToRebalancingIllegalTransition() {
+    group.transitionTo(PreparingRebalance)
+    group.transitionTo(Dead)
+    group.transitionTo(Rebalancing)
+  }
+
+  private def assertState(group: Group, targetState: GroupState) {
+    val states: Set[GroupState] = Set(Stable, PreparingRebalance, Rebalancing, Dead)
+    val otherStates = states - targetState
+    otherStates.foreach { otherState =>
+      assertFalse(group.is(otherState))
+    }
+    assertTrue(group.is(targetState))
+  }
+}
diff --git a/core/src/test/scala/unit/kafka/coordinator/PartitionAssignorTest.scala b/core/src/test/scala/unit/kafka/coordinator/PartitionAssignorTest.scala
new file mode 100644
index 0000000..2e7efa2
--- /dev/null
+++ b/core/src/test/scala/unit/kafka/coordinator/PartitionAssignorTest.scala
@@ -0,0 +1,251 @@
+package kafka.coordinator
+
+import kafka.common.TopicAndPartition
+import junit.framework.Assert._
+import org.junit.Test
+import org.scalatest.junit.JUnitSuite
+
+class PartitionAssignorTest extends JUnitSuite {
+
+  @Test
+  def testRangeAssignorOneConsumerNoTopic() {
+    val consumer = "consumer"
+    val assignor = new RangeAssignor()
+    val topicsPerConsumer = Map(consumer -> Set.empty[String])
+    val partitionsPerTopic = Map.empty[String, Int]
+    val actual = assignor.assign(topicsPerConsumer, partitionsPerTopic)
+    val expected = Map(consumer -> Set.empty[TopicAndPartition])
+    assertEquals(expected, actual)
+  }
+
+  @Test
+  def testRangeAssignorOneConsumerNonexistentTopic() {
+    val topic = "topic"
+    val consumer = "consumer"
+    val assignor = new RangeAssignor()
+    val topicsPerConsumer = Map(consumer -> Set(topic))
+    val partitionsPerTopic = Map(topic -> 0)
+    val actual = assignor.assign(topicsPerConsumer, partitionsPerTopic)
+    val expected = Map(consumer -> Set.empty[TopicAndPartition])
+    assertEquals(expected, actual)
+  }
+
+  @Test
+  def testRangeAssignorOneConsumerOneTopic() {
+    val topic = "topic"
+    val consumer = "consumer"
+    val numPartitions = 3
+    val assignor = new RangeAssignor()
+    val topicsPerConsumer = Map(consumer -> Set(topic))
+    val partitionsPerTopic = Map(topic -> numPartitions)
+    val actual = assignor.assign(topicsPerConsumer, partitionsPerTopic)
+    val expected = Map(consumer -> topicAndPartitions(topic, numPartitions))
+    assertEquals(expected, actual)
+  }
+
+  @Test
+  def testRangeAssignorOnlyAssignsPartitionsFromSubscribedTopics() {
+    val subscribedTopic = "topic"
+    val otherTopic = "other"
+    val consumer = "consumer"
+    val subscribedTopicNumPartitions = 3
+    val otherTopicNumPartitions = 3
+    val assignor = new RangeAssignor()
+    val topicsPerConsumer = Map(consumer -> Set(subscribedTopic))
+    val partitionsPerTopic = Map(subscribedTopic -> subscribedTopicNumPartitions, otherTopic -> otherTopicNumPartitions)
+    val actual = assignor.assign(topicsPerConsumer, partitionsPerTopic)
+    val expected = Map(consumer -> topicAndPartitions(subscribedTopic, subscribedTopicNumPartitions))
+    assertEquals(expected, actual)
+  }
+
+  @Test
+  def testRangeAssignorOneConsumerMultipleTopics() {
+    val topic1 = "topic1"
+    val topic2 = "topic2"
+    val consumer = "consumer"
+    val numTopic1Partitions = 1
+    val numTopic2Partitions = 2
+    val assignor = new RangeAssignor()
+    val topicsPerConsumer = Map(consumer -> Set(topic1, topic2))
+    val partitionsPerTopic = Map(topic1 -> numTopic1Partitions, topic2 -> numTopic2Partitions)
+    val actual = assignor.assign(topicsPerConsumer, partitionsPerTopic)
+    val expected = Map(consumer -> (topicAndPartitions(topic1, numTopic1Partitions) ++ topicAndPartitions(topic2, numTopic2Partitions)))
+    assertEquals(expected, actual)
+  }
+
+  @Test
+  def testRangeAssignorTwoConsumersOneTopicOnePartition() {
+    val topic = "topic"
+    val consumer1 = "consumer1"
+    val consumer2 = "consumer2"
+    val numPartitions = 1
+    val assignor = new RangeAssignor()
+    val topicsPerConsumer = Map(consumer1 -> Set(topic), consumer2 -> Set(topic))
+    val partitionsPerTopic = Map(topic -> numPartitions)
+    val actual = assignor.assign(topicsPerConsumer, partitionsPerTopic)
+    val expected = Map(consumer1 -> topicAndPartitions(topic, numPartitions), consumer2 -> Set.empty)
+    assertEquals(expected, actual)
+  }
+
+  @Test
+  def testRangeAssignorTwoConsumersOneTopicTwoPartitions() {
+    val topic = "topic"
+    val consumer1 = "consumer1"
+    val consumer2 = "consumer2"
+    val numPartitions = 2
+    val assignor = new RangeAssignor()
+    val topicsPerConsumer = Map(consumer1 -> Set(topic), consumer2 -> Set(topic))
+    val partitionsPerTopic = Map(topic -> numPartitions)
+    val actual = assignor.assign(topicsPerConsumer, partitionsPerTopic)
+    val expected = Map(consumer1 -> topicAndPartitions(topic, 1), consumer2 -> topicAndPartitions(topic, 1, start = 1))
+    assertEquals(expected, actual)
+  }
+
+  @Test
+  def testRangeAssignorMultipleConsumersMixedTopics() {
+    val topic1 = "topic1"
+    val topic2 = "topic2"
+    val consumer1 = "consumer1"
+    val consumer2 = "consumer2"
+    val consumer3 = "consumer3"
+    val numTopic1Partitions = 3
+    val numTopic2Partitions = 2
+    val assignor = new RangeAssignor()
+    val topicsPerConsumer = Map(consumer1 -> Set(topic1), consumer2 -> Set(topic1, topic2), consumer3 -> Set(topic1))
+    val partitionsPerTopic = Map(topic1 -> numTopic1Partitions, topic2 -> numTopic2Partitions)
+    val actual = assignor.assign(topicsPerConsumer, partitionsPerTopic)
+    val expected = Map(consumer1 -> topicAndPartitions(topic1, 1),
+      consumer2 -> (topicAndPartitions(topic1, 1, start = 1) ++ topicAndPartitions(topic2, numTopic2Partitions)),
+      consumer3 -> topicAndPartitions(topic1, 1, start = 2))
+    assertEquals(expected, actual)
+  }
+
+  @Test
+  def testRoundRobinAssignorOneConsumerNoTopic() {
+    val consumer = "consumer"
+    val assignor = new RoundRobinAssignor()
+    val topicsPerConsumer = Map(consumer -> Set.empty[String])
+    val partitionsPerTopic = Map.empty[String, Int]
+    val actual = assignor.assign(topicsPerConsumer, partitionsPerTopic)
+    val expected = Map(consumer -> Set.empty[TopicAndPartition])
+    assertEquals(expected, actual)
+  }
+
+  @Test
+  def testRoundRobinAssignorOneConsumerNonexistentTopic() {
+    val topic = "topic"
+    val consumer = "consumer"
+    val assignor = new RoundRobinAssignor()
+    val topicsPerConsumer = Map(consumer -> Set(topic))
+    val partitionsPerTopic = Map(topic -> 0)
+    val actual = assignor.assign(topicsPerConsumer, partitionsPerTopic)
+    val expected = Map(consumer -> Set.empty[TopicAndPartition])
+    assertEquals(expected, actual)
+  }
+
+  @Test
+  def testRoundRobinAssignorOneConsumerOneTopic() {
+    val topic = "topic"
+    val consumer = "consumer"
+    val numPartitions = 3
+    val assignor = new RoundRobinAssignor()
+    val topicsPerConsumer = Map(consumer -> Set(topic))
+    val partitionsPerTopic = Map(topic -> numPartitions)
+    val actual = assignor.assign(topicsPerConsumer, partitionsPerTopic)
+    val expected = Map(consumer -> topicAndPartitions(topic, numPartitions))
+    assertEquals(expected, actual)
+  }
+
+  @Test
+  def testRoundRobinAssignorOnlyAssignsPartitionsFromSubscribedTopics() {
+    val subscribedTopic = "topic"
+    val otherTopic = "other"
+    val consumer = "consumer"
+    val subscribedTopicNumPartitions = 3
+    val otherTopicNumPartitions = 3
+    val assignor = new RoundRobinAssignor()
+    val topicsPerConsumer = Map(consumer -> Set(subscribedTopic))
+    val partitionsPerTopic = Map(subscribedTopic -> subscribedTopicNumPartitions, otherTopic -> otherTopicNumPartitions)
+    val actual = assignor.assign(topicsPerConsumer, partitionsPerTopic)
+    val expected = Map(consumer -> topicAndPartitions(subscribedTopic, subscribedTopicNumPartitions))
+    assertEquals(expected, actual)
+  }
+
+  @Test
+  def testRoundRobinAssignorOneConsumerMultipleTopics() {
+    val topic1 = "topic1"
+    val topic2 = "topic2"
+    val consumer = "consumer"
+    val numTopic1Partitions = 1
+    val numTopic2Partitions = 2
+    val assignor = new RoundRobinAssignor()
+    val topicsPerConsumer = Map(consumer -> Set(topic1, topic2))
+    val partitionsPerTopic = Map(topic1 -> numTopic1Partitions, topic2 -> numTopic2Partitions)
+    val actual = assignor.assign(topicsPerConsumer, partitionsPerTopic)
+    val expected = Map(consumer -> (topicAndPartitions(topic1, numTopic1Partitions) ++ topicAndPartitions(topic2, numTopic2Partitions)))
+    assertEquals(expected, actual)
+  }
+
+  @Test
+  def testRoundRobinAssignorTwoConsumersOneTopicOnePartition() {
+    val topic = "topic"
+    val consumer1 = "consumer1"
+    val consumer2 = "consumer2"
+    val numPartitions = 1
+    val assignor = new RoundRobinAssignor()
+    val topicsPerConsumer = Map(consumer1 -> Set(topic), consumer2 -> Set(topic))
+    val partitionsPerTopic = Map(topic -> numPartitions)
+    val actual = assignor.assign(topicsPerConsumer, partitionsPerTopic)
+    val expected = Map(consumer1 -> topicAndPartitions(topic, numPartitions), consumer2 -> Set.empty)
+    assertEquals(expected, actual)
+  }
+
+  @Test
+  def testRoundRobinAssignorTwoConsumersOneTopicTwoPartitions() {
+    val topic = "topic"
+    val consumer1 = "consumer1"
+    val consumer2 = "consumer2"
+    val numPartitions = 2
+    val assignor = new RoundRobinAssignor()
+    val topicsPerConsumer = Map(consumer1 -> Set(topic), consumer2 -> Set(topic))
+    val partitionsPerTopic = Map(topic -> numPartitions)
+    val actual = assignor.assign(topicsPerConsumer, partitionsPerTopic)
+    val expected = Map(consumer1 -> topicAndPartitions(topic, 1), consumer2 -> topicAndPartitions(topic, 1, start = 1))
+    assertEquals(expected, actual)
+  }
+
+  @Test(expected = classOf[IllegalArgumentException])
+  def testRoundRobinAssignorCannotAssignWithMixedTopics() {
+    val topic1 = "topic1"
+    val topic2 = "topic2"
+    val consumer1 = "consumer1"
+    val consumer2 = "consumer2"
+    val consumer3 = "consumer3"
+    val numTopic1Partitions = 3
+    val numTopic2Partitions = 2
+    val assignor = new RoundRobinAssignor()
+    val topicsPerConsumer = Map(consumer1 -> Set(topic1), consumer2 -> Set(topic1, topic2), consumer3 -> Set(topic1))
+    val partitionsPerTopic = Map(topic1 -> numTopic1Partitions, topic2 -> numTopic2Partitions)
+    assignor.assign(topicsPerConsumer, partitionsPerTopic)
+  }
+
+  @Test
+  def testRoundRobinAssignorTwoConsumersTwoTopicsSixPartitions() {
+    val topic1 = "topic1"
+    val topic2 = "topic2"
+    val consumer1 = "consumer1"
+    val consumer2 = "consumer2"
+    val numTopic1Partitions = 3
+    val numTopic2Partitions = 3
+    val assignor = new RoundRobinAssignor()
+    val topicsPerConsumer = Map(consumer1 -> Set(topic1, topic2), consumer2 -> Set(topic1, topic2))
+    val partitionsPerTopic = Map(topic1 -> numTopic1Partitions, topic2 -> numTopic2Partitions)
+    val actual = assignor.assign(topicsPerConsumer, partitionsPerTopic)
+    val expected = Map(consumer1 -> (topicAndPartitions(topic1, 1) ++ topicAndPartitions(topic1, 1, start = 2) ++ topicAndPartitions(topic2, 1, start = 1)),
+      consumer2 -> (topicAndPartitions(topic1, 1, start = 1) ++ topicAndPartitions(topic2, 1) ++ topicAndPartitions(topic2, 1, start = 2)))
+    assertEquals(expected, actual)
+  }
+
+  private def topicAndPartitions(topic: String, partitions: Int, start: Int = 0) =
+    (start until start + partitions).map(partition => TopicAndPartition(topic, partition)).toSet
+}
-- 
1.7.12.4
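
Note (placed after the signature, so `git am` ignores it): the RangeAssignor
scaladoc above explains the per-topic division arithmetic, and RoundRobinAssignor
deals partitions out in turn. The standalone sketch below reproduces both for the
documented C1/C2, five-partition example. It is illustration only, not part of the
patch: `AssignorSketch`, its inputs, and the use of Iterator.continually as a
stand-in for CoreUtils.circularIterator are all assumptions of this note.

  object AssignorSketch extends App {
    // Range assignment arithmetic, mirroring RangeAssignor.assign: each consumer
    // gets numPartitions / consumers.size partitions, and the first
    // (numPartitions % consumers.size) consumers each pick up one extra.
    val consumers = Seq("C1", "C2").sorted  // lexicographic order, as in the scaladoc
    val numPartitions = 5

    val base = numPartitions / consumers.size
    val extra = numPartitions % consumers.size

    val range = consumers.zipWithIndex.map { case (consumer, i) =>
      val start = base * i + math.min(i, extra)
      val count = base + (if (i + 1 > extra) 0 else 1)
      consumer -> (start until start + count).toList
    }.toMap
    println(range)       // e.g. Map(C1 -> List(0, 1, 2), C2 -> List(3, 4))

    // Round-robin: lay out all partitions and deal them to consumers in turn.
    val circular = Iterator.continually(consumers).flatMap(_.iterator)
    val roundRobin = (0 until numPartitions)
      .map(partition => circular.next() -> partition)
      .groupBy(_._1)
      .map { case (consumer, pairs) => consumer -> pairs.map(_._2).toList }
    println(roundRobin)  // e.g. Map(C1 -> List(0, 2, 4), C2 -> List(1, 3))
  }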
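
A second post-signature sketch: GroupTest above pins down the group state machine
one transition at a time. The edge set it implies is summarized below.
`GroupStateSketch` and its `legal` map are inferred from the tests, not code from
the patch; only the state names mirror the patch's GroupState objects.

  object GroupStateSketch extends App {
    sealed trait GroupState
    case object Stable extends GroupState
    case object PreparingRebalance extends GroupState
    case object Rebalancing extends GroupState
    case object Dead extends GroupState

    // Legal transitions exercised by GroupTest; everything else throws.
    val legal: Map[GroupState, Set[GroupState]] = Map(
      Stable             -> Set(PreparingRebalance),
      PreparingRebalance -> Set(Rebalancing, Dead),
      Rebalancing        -> Set(Stable),
      Dead               -> Set.empty[GroupState]
    )

    def transition(from: GroupState, to: GroupState): GroupState =
      if (legal(from)(to)) to
      else throw new IllegalStateException(s"illegal group transition: $from -> $to")

    // Happy path from GroupTest: Stable -> PreparingRebalance -> Rebalancing -> Stable.
    List(PreparingRebalance, Rebalancing, Stable).foldLeft(Stable: GroupState)(transition)
    println("rebalance round trip is legal")
  }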