From 0d1c2e73c45c2224a54054124ef4ce4000fb9973 Mon Sep 17 00:00:00 2001 From: Yasuhiro Matsuda Date: Fri, 27 Feb 2015 16:04:46 -0800 Subject: [PATCH] new purgatory implementation --- .../main/scala/kafka/server/DelayedOperation.scala | 174 +++++++++------------ .../main/scala/kafka/server/ReplicaManager.scala | 5 +- .../scala/kafka/utils/SinglyLinkedWeakList.scala | 66 ++++++++ core/src/main/scala/kafka/utils/timer/Timer.scala | 82 ++++++++++ .../main/scala/kafka/utils/timer/TimerTask.scala | 41 +++++ .../scala/kafka/utils/timer/TimerTaskList.scala | 129 +++++++++++++++ .../main/scala/kafka/utils/timer/TimingWheel.scala | 82 ++++++++++ .../unit/kafka/server/DelayedOperationTest.scala | 45 +++--- .../unit/kafka/utils/timer/TimerTaskListTest.scala | 93 +++++++++++ .../scala/unit/kafka/utils/timer/TimerTest.scala | 111 +++++++++++++ 10 files changed, 706 insertions(+), 122 deletions(-) create mode 100644 core/src/main/scala/kafka/utils/SinglyLinkedWeakList.scala create mode 100644 core/src/main/scala/kafka/utils/timer/Timer.scala create mode 100644 core/src/main/scala/kafka/utils/timer/TimerTask.scala create mode 100644 core/src/main/scala/kafka/utils/timer/TimerTaskList.scala create mode 100644 core/src/main/scala/kafka/utils/timer/TimingWheel.scala create mode 100644 core/src/test/scala/unit/kafka/utils/timer/TimerTaskListTest.scala create mode 100644 core/src/test/scala/unit/kafka/utils/timer/TimerTest.scala diff --git a/core/src/main/scala/kafka/server/DelayedOperation.scala b/core/src/main/scala/kafka/server/DelayedOperation.scala index e317676..2f269ec 100644 --- a/core/src/main/scala/kafka/server/DelayedOperation.scala +++ b/core/src/main/scala/kafka/server/DelayedOperation.scala @@ -18,11 +18,13 @@ package kafka.server import kafka.utils._ +import kafka.utils.timer._ import kafka.metrics.KafkaMetricsGroup -import java.util +import java.lang.ref.ReferenceQueue import java.util.concurrent._ import java.util.concurrent.atomic._ + import scala.collection._ import com.yammer.metrics.core.Gauge @@ -41,7 +43,10 @@ import com.yammer.metrics.core.Gauge * * A subclass of DelayedOperation needs to provide an implementation of both onComplete() and tryComplete(). */ -abstract class DelayedOperation(delayMs: Long) extends DelayedItem(delayMs) { +abstract class DelayedOperation(delayMs: Long) extends TimerTask with Logging { + + override val expirationMs = delayMs + System.currentTimeMillis() + private val completed = new AtomicBoolean(false) /* @@ -58,6 +63,8 @@ abstract class DelayedOperation(delayMs: Long) extends DelayedItem(delayMs) { */ def forceComplete(): Boolean = { if (completed.compareAndSet(false, true)) { + // cancel the timeout timer + cancel() onComplete() true } else { @@ -89,19 +96,30 @@ abstract class DelayedOperation(delayMs: Long) extends DelayedItem(delayMs) { * This function needs to be defined in subclasses */ def tryComplete(): Boolean + + /* + * A task that runs on timeout + */ + def run(): Unit = { + forceComplete() + } } /** * A helper purgatory class for bookkeeping delayed operations with a timeout, and expiring timed out operations. */ -class DelayedOperationPurgatory[T <: DelayedOperation](purgatoryName: String, brokerId: Int = 0, purgeInterval: Int = 1000) +class DelayedOperationPurgatory[T <: DelayedOperation](purgatoryName: String, brokerId: Int = 0) extends Logging with KafkaMetricsGroup { + // timeout timer + private[this] val executor = Executors.newSingleThreadExecutor() + private[this] val timeoutTimer = new Timer(executor) + /* a list of operation watching keys */ private val watchersForKey = new Pool[Any, Watchers](Some((key: Any) => new Watchers)) /* background thread expiring operations that have timed out */ - private val expirationReaper = new ExpiredOperationReaper + private val expirationReaper = new ExpiredOperationReaper(timeoutTimer) private val metricsTags = Map("delayedOperation" -> purgatoryName) @@ -166,8 +184,13 @@ class DelayedOperationPurgatory[T <: DelayedOperation](purgatoryName: String, br return true // if it cannot be completed by now and hence is watched, add to the expire queue also - if (! operation.isCompleted()) - expirationReaper.enqueue(operation) + if (! operation.isCompleted()) { + timeoutTimer.add(operation) + if (operation.isCompleted()) { + // cancel the timer task + operation.cancel() + } + } false } @@ -196,7 +219,7 @@ class DelayedOperationPurgatory[T <: DelayedOperation](purgatoryName: String, br /** * Return the number of delayed operations in the expiry queue */ - def delayed() = expirationReaper.delayed + def delayed() = timeoutTimer.size /* * Return the watch list of the given key @@ -208,133 +231,90 @@ class DelayedOperationPurgatory[T <: DelayedOperation](purgatoryName: String, br */ def shutdown() { expirationReaper.shutdown() + executor.shutdown() } /** * A linked list of watched delayed operations based on some key */ private class Watchers { - private val operations = new util.LinkedList[T] - def watched = operations.size() + private[this] val refQueue = new ReferenceQueue[T]() + + private[this] val unreachable = new AtomicInteger(0) + + private[this] val operations = new SinglyLinkedWeakList[T](refQueue) + + def watched(): Int = operations synchronized operations.size // add the element to watch def watch(t: T) { - synchronized { - operations.add(t) - } + operations synchronized operations.add(t) } // traverse the list and try to complete some watched elements def tryCompleteWatched(): Int = { - var completed = 0 - synchronized { + refQueue synchronized { + while (refQueue.poll() != null) {} + } + + operations synchronized { + // the reference queue is drained. we can set the unreachable count to zero + unreachable.set(0) + + var completed = 0 val iter = operations.iterator() - while(iter.hasNext) { - val curr = iter.next - if (curr.isCompleted()) { + while (iter.hasNext()) { + val curr = iter.next() + if (curr == null || curr.isCompleted()) { // another thread has completed this operation, just remove it iter.remove() + } else if (curr synchronized curr.tryComplete()) { + completed += 1 + iter.remove() } else { - if(curr synchronized curr.tryComplete()) { - iter.remove() - completed += 1 - } } } + completed } - completed } - // traverse the list and purge elements that are already completed by others - def purgeCompleted(): Int = { - var purged = 0 - synchronized { - val iter = operations.iterator() - while (iter.hasNext) { - val curr = iter.next - if(curr.isCompleted()) { - iter.remove() - purged += 1 + def purge(threshold: Int): Unit = { + refQueue.synchronized { + while (refQueue.poll() != null) { + unreachable.incrementAndGet() + } + } + + if (unreachable.get > threshold) { + operations synchronized { + // the reference queue is drained. we can set the unreachable count to zero + unreachable.set(0) + + val iter = operations.iterator() + while (iter.hasNext()) { + val curr = iter.next() + if (curr == null || curr.isCompleted()) { + iter.remove() + } } } } - purged } } /** * A background reaper to expire delayed operations that have timed out */ - private class ExpiredOperationReaper extends ShutdownableThread( + private class ExpiredOperationReaper(timeoutTimer: Timer) extends ShutdownableThread( "ExpirationReaper-%d".format(brokerId), false) { - /* The queue storing all delayed operations */ - private val delayedQueue = new DelayQueue[T] - - /* - * Return the number of delayed operations kept by the reaper - */ - def delayed() = delayedQueue.size() - - /* - * Add an operation to be expired - */ - def enqueue(t: T) { - delayedQueue.add(t) - } - - /** - * Try to get the next expired event and force completing it - */ - private def expireNext() { - val curr = delayedQueue.poll(200L, TimeUnit.MILLISECONDS) - if (curr != null.asInstanceOf[T]) { - // if there is an expired operation, try to force complete it - val completedByMe: Boolean = curr synchronized { - curr.onExpiration() - curr.forceComplete() - } - if (completedByMe) - debug("Force complete expired delayed operation %s".format(curr)) - } - } - - /** - * Delete all satisfied events from the delay queue and the watcher lists - */ - private def purgeCompleted(): Int = { - var purged = 0 - - // purge the delayed queue - val iter = delayedQueue.iterator() - while (iter.hasNext) { - val curr = iter.next() - if (curr.isCompleted()) { - iter.remove() - purged += 1 - } - } - - purged - } + private[this] val tryPurge = (w: Watchers) => w.purge(100) override def doWork() { - // try to get the next expired operation and force completing it - expireNext() - // see if we need to purge the watch lists - if (DelayedOperationPurgatory.this.watched() >= purgeInterval) { - debug("Begin purging watch lists") - val purged = watchersForKey.values.map(_.purgeCompleted()).sum - debug("Purged %d elements from watch lists.".format(purged)) - } - // see if we need to purge the delayed operation queue - if (delayed() >= purgeInterval) { - debug("Begin purging delayed queue") - val purged = purgeCompleted() - debug("Purged %d operations from delayed queue.".format(purged)) - } + timeoutTimer.advanceClock(200L) + watchersForKey.values.foreach(tryPurge) } } } diff --git a/core/src/main/scala/kafka/server/ReplicaManager.scala b/core/src/main/scala/kafka/server/ReplicaManager.scala index 586cf4c..c1a2713 100644 --- a/core/src/main/scala/kafka/server/ReplicaManager.scala +++ b/core/src/main/scala/kafka/server/ReplicaManager.scala @@ -83,10 +83,9 @@ class ReplicaManager(val config: KafkaConfig, val stateChangeLogger = KafkaController.stateChangeLogger val delayedProducePurgatory = new DelayedOperationPurgatory[DelayedProduce]( - purgatoryName = "Produce", config.brokerId, config.producerPurgatoryPurgeIntervalRequests) + purgatoryName = "Produce", config.brokerId) val delayedFetchPurgatory = new DelayedOperationPurgatory[DelayedFetch]( - purgatoryName = "Fetch", config.brokerId, config.fetchPurgatoryPurgeIntervalRequests) - + purgatoryName = "Fetch", config.brokerId) newGauge( "LeaderCount", diff --git a/core/src/main/scala/kafka/utils/SinglyLinkedWeakList.scala b/core/src/main/scala/kafka/utils/SinglyLinkedWeakList.scala new file mode 100644 index 0000000..91ae190 --- /dev/null +++ b/core/src/main/scala/kafka/utils/SinglyLinkedWeakList.scala @@ -0,0 +1,66 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.utils + +import java.lang.Iterable +import java.lang.ref.{ReferenceQueue, WeakReference} +import java.util.Iterator + +class SinglyLinkedWeakList[T](queue: ReferenceQueue[T]) extends Iterable[T] { + + private[this] class Entry(element: T, var nextEntry: Entry) extends WeakReference[T](element, queue) + + private[this] var list: Entry = null + private[this] var elementCount: Int = 0 + + def add(element: T): Unit = { + list = new Entry(element, list) + elementCount += 1 + } + + def size: Int = elementCount + + def iterator(): Iterator[T] = { + + new Iterator[T] { + private[this] var currEntry: Entry = null + private[this] var prevEntry: Entry = null + private[this] var nextEntry: Entry = list + + def hasNext(): Boolean = (nextEntry != null) + + def next(): T = { + if (nextEntry == null) throw new NoSuchElementException() + prevEntry = currEntry + currEntry = nextEntry + nextEntry = currEntry.nextEntry + currEntry.get() + } + + def remove(): Unit = { + if (currEntry == null) throw new IllegalStateException() + if (prevEntry == null) + list = nextEntry + else + prevEntry.nextEntry = nextEntry + elementCount -= 1 + } + } + } +} + diff --git a/core/src/main/scala/kafka/utils/timer/Timer.scala b/core/src/main/scala/kafka/utils/timer/Timer.scala new file mode 100644 index 0000000..17244ee --- /dev/null +++ b/core/src/main/scala/kafka/utils/timer/Timer.scala @@ -0,0 +1,82 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.utils.timer + +import java.util.concurrent.{DelayQueue, ExecutorService, TimeUnit} +import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.locks.ReentrantReadWriteLock + +import kafka.utils.threadsafe + +@threadsafe +class Timer(taskExecutor: ExecutorService, tickMs: Long = 1, wheelSize: Int = 20, startMs: Long = System.currentTimeMillis) { + + private[this] val delayQueue = new DelayQueue[TimerTaskList]() + private[this] val taskCounter = new AtomicInteger(0) + private[this] val timingWheel = new TimingWheel( + tickMs = tickMs, + wheelSize = wheelSize, + startMs = startMs, + taskCounter = taskCounter, + delayQueue + ) + + // Locks used to protect data structures while ticking + private[this] val readWriteLock = new ReentrantReadWriteLock() + private[this] val readLock = readWriteLock.readLock() + private[this] val writeLock = readWriteLock.writeLock() + + def add(timerTask: TimerTask): Unit = { + readLock.lock() + try { + addTimerTaskEntry(new TimerTaskEntry(timerTask)) + } finally { + readLock.unlock() + } + } + + private def addTimerTaskEntry(timerTaskEntry: TimerTaskEntry): Unit = { + if (!timingWheel.add(timerTaskEntry)) { + // already expired + taskExecutor.submit(timerTaskEntry.timerTask) + } + } + + private[this] val reinsert = (timerTaskEntry: TimerTaskEntry) => addTimerTaskEntry(timerTaskEntry) + + def advanceClock(timeoutMs: Long): Boolean = { + var bucket = delayQueue.poll(timeoutMs, TimeUnit.MILLISECONDS) + if (bucket != null) { + writeLock.lock() + try { + while (bucket != null) { + timingWheel.advanceClock(bucket.getExpiration()) + bucket.flush(reinsert) + bucket = delayQueue.poll() + } + } finally { + writeLock.unlock() + } + true + } else { + false + } + } + + def size(): Int = taskCounter.get +} + diff --git a/core/src/main/scala/kafka/utils/timer/TimerTask.scala b/core/src/main/scala/kafka/utils/timer/TimerTask.scala new file mode 100644 index 0000000..0c528b4 --- /dev/null +++ b/core/src/main/scala/kafka/utils/timer/TimerTask.scala @@ -0,0 +1,41 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.utils.timer + +trait TimerTask extends Runnable { + + val expirationMs: Long // timestamp in millisecond + + private[this] var timerTaskEntry: TimerTaskEntry = null + + def cancel(): Unit = { + synchronized { + if (timerTaskEntry != null) timerTaskEntry.remove() + timerTaskEntry = null + } + } + + private[timer] def setTimerTaskEntry(entry: TimerTaskEntry): Unit = { + synchronized { + if (timerTaskEntry != null && timerTaskEntry != entry) { + timerTaskEntry.remove() + } + timerTaskEntry = entry + } + } + +} diff --git a/core/src/main/scala/kafka/utils/timer/TimerTaskList.scala b/core/src/main/scala/kafka/utils/timer/TimerTaskList.scala new file mode 100644 index 0000000..0ecc611 --- /dev/null +++ b/core/src/main/scala/kafka/utils/timer/TimerTaskList.scala @@ -0,0 +1,129 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.utils.timer + +import java.util.concurrent.{TimeUnit, Delayed} +import java.util.concurrent.atomic.{AtomicLong, AtomicInteger} + +import kafka.utils.{SystemTime, threadsafe} + +import scala.math._ + +@threadsafe +private[timer] class TimerTaskList(taskCounter: AtomicInteger) extends Delayed { + + // TimerTaskList forms a doubly linked cyclic list using a dummy root entry + // root.next points to the head + // root.prev points to the tail + private[this] val root = new TimerTaskEntry(null) + root.next = root + root.prev = root + + private[this] val expiration = new AtomicLong(-1L) + + // Set the bucket's expiration time + // Returns true if the expiration time is changed + def setExpiration(expirationMs: Long): Boolean = { + expiration.getAndSet(expirationMs) != expirationMs + } + + // Get the bucket's expiration time + def getExpiration(): Long = { + expiration.get() + } + + // Apply the supplied function to each of tasks in this list + def foreach(f: (TimerTask)=>Unit): Unit = { + synchronized { + var entry = root.next + while (entry ne root) { + val nextEntry = entry.next + f(entry.timerTask) + entry = nextEntry + } + } + } + + // Add a timer task entry to this list + def add(timerTaskEntry: TimerTaskEntry): Unit = { + synchronized { + // put the timer task entry to the end of the list. (root.prev points to the tail entry) + val tail = root.prev + timerTaskEntry.next = root + timerTaskEntry.prev = tail + timerTaskEntry.list = this + tail.next = timerTaskEntry + root.prev = timerTaskEntry + taskCounter.incrementAndGet() + timerTaskEntry.timerTask.setTimerTaskEntry(timerTaskEntry) + } + } + + // Remove the specified timer task entry from this list + def remove(timerTaskEntry: TimerTaskEntry): Unit = { + synchronized { + if (timerTaskEntry.list != null) { + timerTaskEntry.next.prev = timerTaskEntry.prev + timerTaskEntry.prev.next = timerTaskEntry.next + timerTaskEntry.next = null + timerTaskEntry.prev = null + timerTaskEntry.list = null + taskCounter.decrementAndGet() + } + } + } + + // Remove all task entries and apply the supplied function to each of them + def flush(f: (TimerTaskEntry)=>Unit): Unit = { + synchronized { + var head = root.next + while (head ne root) { + remove(head) + f(head) + head = root.next + } + expiration.set(-1L) + } + } + + def getDelay(unit: TimeUnit): Long = { + unit.convert(max(getExpiration - SystemTime.milliseconds, 0), TimeUnit.MILLISECONDS) + } + + def compareTo(d: Delayed): Int = { + + val other = d.asInstanceOf[TimerTaskList] + + if(getExpiration < other.getExpiration) -1 + else if(getExpiration > other.getExpiration) 1 + else 0 + } + +} + +private[timer] class TimerTaskEntry(val timerTask: TimerTask) { + + var list: TimerTaskList = null + var next: TimerTaskEntry = null + var prev: TimerTaskEntry = null + + def remove(): Unit = { + if (list != null) list.remove(this) + } + +} + diff --git a/core/src/main/scala/kafka/utils/timer/TimingWheel.scala b/core/src/main/scala/kafka/utils/timer/TimingWheel.scala new file mode 100644 index 0000000..94579ac --- /dev/null +++ b/core/src/main/scala/kafka/utils/timer/TimingWheel.scala @@ -0,0 +1,82 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.utils.timer + +import kafka.utils.nonthreadsafe + +import java.util.concurrent.DelayQueue +import java.util.concurrent.atomic.AtomicInteger + +@nonthreadsafe +private[timer] class TimingWheel(tickMs: Long, wheelSize: Int, startMs: Long, taskCounter: AtomicInteger, queue: DelayQueue[TimerTaskList]) { + + private[this] val interval = tickMs * wheelSize + private[this] val buckets = Array.tabulate[TimerTaskList](wheelSize) { _ => new TimerTaskList(taskCounter) } + + @volatile + private[this] var currentTime = startMs - (startMs % tickMs) // rounding down to multiple of tickSizeMs + private[this] var overflowWheel: TimingWheel = null + + private[this] def addOverflowWheel(): Unit = { + synchronized { + if (overflowWheel == null) { + overflowWheel = new TimingWheel( + tickMs = interval, + wheelSize = wheelSize, + startMs = currentTime, + taskCounter = taskCounter, + queue + ) + } + } + } + + def add(timerTaskEntry: TimerTaskEntry): Boolean = { + val expiration = timerTaskEntry.timerTask.expirationMs + + if (expiration < currentTime + tickMs) { + // Already expired + false + } else if (expiration < currentTime + interval) { + // Put it in an own bucket + val virtualId = expiration / tickMs + val bucket = buckets((virtualId % wheelSize.toLong).toInt) + bucket.add(timerTaskEntry) + + // Set the bucket expiration time + if (bucket.setExpiration(virtualId * tickMs)) { + // The bucket needs to be enqueued because it was an expired bucket + queue.offer(bucket) + } + true + } else { + // Out of the interval. Put it into the parent timer + if (overflowWheel == null) addOverflowWheel() + overflowWheel.add(timerTaskEntry) + } + } + + // Try to advance the clock + def advanceClock(timeMs: Long): Unit = { + if (timeMs >= currentTime + tickMs) { + currentTime = timeMs - (timeMs % tickMs) + + // Try to advance the clock of the overflow wheel if present + if (overflowWheel != null) overflowWheel.advanceClock(currentTime) + } + } +} diff --git a/core/src/test/scala/unit/kafka/server/DelayedOperationTest.scala b/core/src/test/scala/unit/kafka/server/DelayedOperationTest.scala index 7a37617..77116dc 100644 --- a/core/src/test/scala/unit/kafka/server/DelayedOperationTest.scala +++ b/core/src/test/scala/unit/kafka/server/DelayedOperationTest.scala @@ -20,17 +20,16 @@ package kafka.server import org.junit.Test import org.scalatest.junit.JUnit3Suite import junit.framework.Assert._ -import kafka.utils.TestUtils class DelayedOperationTest extends JUnit3Suite { var purgatory: DelayedOperationPurgatory[MockDelayedOperation] = null - + override def setUp() { super.setUp() - purgatory = new DelayedOperationPurgatory[MockDelayedOperation](purgatoryName = "mock", 0, 5) + purgatory = new DelayedOperationPurgatory[MockDelayedOperation](purgatoryName = "mock", 0) } - + override def tearDown() { purgatory.shutdown() super.tearDown() @@ -72,32 +71,34 @@ class DelayedOperationTest extends JUnit3Suite { def testRequestPurge() { val r1 = new MockDelayedOperation(100000L) val r2 = new MockDelayedOperation(100000L) + val r3 = new MockDelayedOperation(100000L) purgatory.tryCompleteElseWatch(r1, Array("test1")) purgatory.tryCompleteElseWatch(r2, Array("test1", "test2")) - purgatory.tryCompleteElseWatch(r1, Array("test2", "test3")) + purgatory.tryCompleteElseWatch(r3, Array("test1", "test2", "test3")) - assertEquals("Purgatory should have 5 watched elements", 5, purgatory.watched()) assertEquals("Purgatory should have 3 total delayed operations", 3, purgatory.delayed()) + assertEquals("Purgatory should have 5 watched elements", 6, purgatory.watched()) - // complete one of the operations, it should - // eventually be purged from the watch list with purge interval 5 + // complete the operations, it should immediately be purged from the delayed operation r2.completable = true r2.tryComplete() - TestUtils.waitUntilTrue(() => purgatory.watched() == 3, - "Purgatory should have 3 watched elements instead of " + purgatory.watched(), 1000L) - TestUtils.waitUntilTrue(() => purgatory.delayed() == 3, - "Purgatory should still have 3 total delayed operations instead of " + purgatory.delayed(), 1000L) + assertEquals("Purgatory should have 3 total delayed operations instead of ", 2, purgatory.delayed()) - // add two more requests, then the satisfied request should be purged from the delayed queue with purge interval 5 - purgatory.tryCompleteElseWatch(r1, Array("test1")) - purgatory.tryCompleteElseWatch(r1, Array("test1")) + r3.completable = true + r3.tryComplete() + assertEquals("Purgatory should have 2 total delayed operations instead of ", 1, purgatory.delayed()) - TestUtils.waitUntilTrue(() => purgatory.watched() == 5, - "Purgatory should have 5 watched elements instead of " + purgatory.watched(), 1000L) - TestUtils.waitUntilTrue(() => purgatory.delayed() == 4, - "Purgatory should have 4 total delayed operations instead of " + purgatory.delayed(), 1000L) + // checking a watch should purge the watch list + purgatory.checkAndComplete("test1") + assertEquals("Purgatory should have 4 watched elements instead of ", 4, purgatory.watched()) + + purgatory.checkAndComplete("test2") + assertEquals("Purgatory should have 2 watched elements instead of ", 2, purgatory.watched()) + + purgatory.checkAndComplete("test3") + assertEquals("Purgatory should have 1 watched elements instead of ", 1, purgatory.watched()) } - + class MockDelayedOperation(delayMs: Long) extends DelayedOperation(delayMs) { var completable = false @@ -124,5 +125,5 @@ class DelayedOperationTest extends JUnit3Suite { } } } - -} \ No newline at end of file + +} diff --git a/core/src/test/scala/unit/kafka/utils/timer/TimerTaskListTest.scala b/core/src/test/scala/unit/kafka/utils/timer/TimerTaskListTest.scala new file mode 100644 index 0000000..05a0165 --- /dev/null +++ b/core/src/test/scala/unit/kafka/utils/timer/TimerTaskListTest.scala @@ -0,0 +1,93 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.utils.timer + +import junit.framework.Assert._ +import java.util.concurrent.atomic._ +import org.junit.{Test, After, Before} + +class TimerTaskListTest { + + private class TestTask(val expirationMs: Long) extends TimerTask { + def run(): Unit = { } + } + + private def size(list: TimerTaskList): Int = { + var count = 0 + list.foreach(_ => count += 1) + count + } + + @Test + def testAll() { + val sharedCounter = new AtomicInteger(0) + val runCounter = new AtomicInteger(0) + val execCounter = new AtomicInteger(0) + val list1 = new TimerTaskList(sharedCounter) + val list2 = new TimerTaskList(sharedCounter) + val list3 = new TimerTaskList(sharedCounter) + + val tasks = (1 to 10).map { i => + val task = new TestTask(10L) + val prevCount = sharedCounter.get + list1.add(new TimerTaskEntry(task)) + assertEquals(prevCount + 1, sharedCounter.get) + assertEquals(i, sharedCounter.get) + task + }.toSeq + + assertEquals(tasks.size, sharedCounter.get) + + tasks.take(4).foreach { task => + val prevCount = sharedCounter.get + list2.add(new TimerTaskEntry(task)) + assertEquals(prevCount, sharedCounter.get) + } + assertEquals(10 - 4, size(list1)) + assertEquals(4, size(list2)) + + assertEquals(tasks.size, sharedCounter.get) + + tasks.drop(4).foreach { task => + val prevCount = sharedCounter.get + list3.add(new TimerTaskEntry(task)) + assertEquals(prevCount, sharedCounter.get) + } + assertEquals(0, size(list1)) + assertEquals(4, size(list2)) + assertEquals(6, size(list3)) + + assertEquals(tasks.size, sharedCounter.get) + + // cancel tasks in lists + list1.foreach { _.cancel() } + assertEquals(0, size(list1)) + assertEquals(4, size(list2)) + assertEquals(6, size(list3)) + + list2.foreach { _.cancel() } + assertEquals(0, size(list1)) + assertEquals(0, size(list2)) + assertEquals(6, size(list3)) + + list3.foreach { _.cancel() } + assertEquals(0, size(list1)) + assertEquals(0, size(list2)) + assertEquals(0, size(list3)) + } + +} diff --git a/core/src/test/scala/unit/kafka/utils/timer/TimerTest.scala b/core/src/test/scala/unit/kafka/utils/timer/TimerTest.scala new file mode 100644 index 0000000..5524cf6 --- /dev/null +++ b/core/src/test/scala/unit/kafka/utils/timer/TimerTest.scala @@ -0,0 +1,111 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.utils.timer + +import java.util.concurrent.{CountDownLatch, ExecutorService, Executors, TimeUnit} + +import junit.framework.Assert._ +import java.util.concurrent.atomic._ +import org.junit.{Test, After, Before} + +import scala.collection.mutable.ArrayBuffer +import scala.util.Random + +class TimerTest { + + private class TestTask(val expirationMs: Long, id: Int, latch: CountDownLatch, output: ArrayBuffer[Int]) extends TimerTask { + private[this] val completed = new AtomicBoolean(false) + def run(): Unit = { + if (completed.compareAndSet(false, true)) { + output.synchronized { output += id } + latch.countDown() + } + } + } + + private[this] var executor: ExecutorService = null + + @Before + def setup() { + executor = Executors.newSingleThreadExecutor() + } + + @After + def teardown(): Unit = { + executor.shutdown() + executor = null + } + + @Test + def testAlreadyExpiredTask(): Unit = { + val startTime = System.currentTimeMillis() + val timer = new Timer(taskExecutor = executor, tickMs = 1, wheelSize = 3, startMs = startTime) + val output = new ArrayBuffer[Int]() + + + val latches = (-5 until 0).map { i => + val latch = new CountDownLatch(1) + timer.add(new TestTask(startTime + i, i, latch, output)) + latch + } + + latches.take(5).foreach { latch => + assertEquals("already expired tasks should run immediately", true, latch.await(3, TimeUnit.SECONDS)) + } + + assertEquals("output of already expired tasks", Set(-5, -4, -3, -2, -1), output.toSet) + } + + @Test + def testTaskExpiration(): Unit = { + val startTime = System.currentTimeMillis() + val timer = new Timer(taskExecutor = executor, tickMs = 1, wheelSize = 3, startMs = startTime) + val output = new ArrayBuffer[Int]() + + val tasks = new ArrayBuffer[TestTask]() + val ids = new ArrayBuffer[Int]() + + val latches = + (0 until 5).map { i => + val latch = new CountDownLatch(1) + tasks += new TestTask(startTime + i, i, latch, output) + ids += i + latch + } ++ (10 until 100).map { i => + val latch = new CountDownLatch(2) + tasks += new TestTask(startTime + i, i, latch, output) + tasks += new TestTask(startTime + i, i, latch, output) + ids += i + ids += i + latch + } ++ (100 until 500).map { i => + val latch = new CountDownLatch(1) + tasks += new TestTask(startTime + i, i, latch, output) + ids += i + latch + } + + // randomly submit requests + Random.shuffle(tasks.toSeq).map { task => timer.add(task) } + + while (timer.advanceClock(1000)) {} + + latches.foreach { latch => latch.await() } + + assertEquals("output should match", ids.sorted, output.toSeq) + } +} -- 2.3.0