diff --git a/core/src/main/scala/kafka/server/RequestPurgatory.scala b/core/src/main/scala/kafka/server/RequestPurgatory.scala index 8fb9865..7e486be 100644 --- a/core/src/main/scala/kafka/server/RequestPurgatory.scala +++ b/core/src/main/scala/kafka/server/RequestPurgatory.scala @@ -207,6 +207,7 @@ abstract class RequestPurgatory[T <: DelayedRequest, R](brokerId: Int = 0, purge private val delayed = new DelayQueue[T] private val running = new AtomicBoolean(true) private val shutdownLatch = new CountDownLatch(1) + private val needsFullPurge = new AtomicBoolean(false) /* The count of elements in the delay queue that are unsatisfied */ private [kafka] val unsatisfied = new AtomicInteger(0) @@ -218,16 +219,20 @@ abstract class RequestPurgatory[T <: DelayedRequest, R](brokerId: Int = 0, purge while(running.get) { try { val curr = pollExpired() - curr synchronized { - expire(curr) + if (curr != null) { + curr synchronized { + expire(curr) + } } - } catch { - case ie: InterruptedException => + if (needsFullPurge.get) { val purged = purgeSatisfied() debug("Purged %d requests from delay queue.".format(purged)) val numPurgedFromWatchers = watchersForKey.values.map(_.purgeSatisfied()).sum debug("Purged %d (watcher) requests.".format(numPurgedFromWatchers)) - case e: Exception => + needsFullPurge.set(false) + } + } catch { + case e: Throwable => error("Error in long poll expiry thread: ", e) } } @@ -241,7 +246,7 @@ abstract class RequestPurgatory[T <: DelayedRequest, R](brokerId: Int = 0, purge } def forcePurge() { - expirationThread.interrupt() + needsFullPurge.set(true) } /** Shutdown the expiry thread*/ @@ -261,7 +266,9 @@ abstract class RequestPurgatory[T <: DelayedRequest, R](brokerId: Int = 0, purge */ private def pollExpired(): T = { while(true) { - val curr = delayed.take() + val curr = delayed.poll(500L, TimeUnit.MILLISECONDS) + if (curr == null) + return null.asInstanceOf[T] val updated = curr.satisfied.compareAndSet(false, true) if(updated) { unsatisfied.getAndDecrement()