import org.apache.lucene.util.PriorityQueue;

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.Map;
import java.util.Set;
import java.util.ArrayList;

public class ConcurrentLRUCache extends ConcurrentHashMap {

  final AtomicLong counter = new AtomicLong(0);
  final AtomicInteger size = new AtomicInteger(0);
  final int lowWaterMark;
  final int highWaterMark;
  final int acceptableWaterMark;
  final AtomicBoolean latch = new AtomicBoolean(false);
  boolean cleaning = false;
  volatile long oldestEntry;

  public ConcurrentLRUCache(int initialCapacity, float loadFactor, int concurrencyLevel,
                            int lowWaterMark, int highWaterMark, int acceptableWaterMark)
  {
    super(initialCapacity, loadFactor, concurrencyLevel);
    this.lowWaterMark = lowWaterMark;
    this.highWaterMark = highWaterMark;
    this.acceptableWaterMark = acceptableWaterMark;
  }


  public static class CacheValue {
    private volatile long lastAccess;
    private long scratch;
    private final Object val;
    public CacheValue(Object val, long lastAccess) {
      this.val = val;
      this.lastAccess = lastAccess;
    }
  }

  public Object cacheGet(Object key) {
    Object o = super.get(key);
    if (o == null) {
      return null;
    }
    CacheValue cv = (CacheValue)o;
    cv.lastAccess = counter.getAndIncrement();
    return cv.val;
  }

  public Object cachePut(Object key, Object val) {
    CacheValue cv = new CacheValue(val, counter.getAndIncrement());
    Object o = super.put(key, cv);
    if (o != null) {
      return ((CacheValue)o).val;
    }

    int sz = size.incrementAndGet();

    // non-volatile piggybacked read for "cleaning"
    if (sz > highWaterMark  && !cleaning) {
      // make sure that only one thread tries to clean
      if (latch.compareAndSet(false, true)) {
        doClean();
      }
    }

    return null;
  }


  private void doClean() {
    // if we want to keep at least 1000 entries, then timestamps of
    // current through current-1000 are guaranteed not to be the oldest!
    // Also, if we want to remove 500 entries, then
    // oldestEntry through oldestEntry+500 are guaranteed to be
    // removed.

    long timeCurrent = counter.get();
    int sz = size.get();

    int numRemoved = 0;
    int numKept = 0;
    long newestEntry = timeCurrent;
    long newNewestEntry = -1;
    long oldestEntry = this.oldestEntry;
    long newOldestEntry = Integer.MAX_VALUE;

    int wantToKeep = lowWaterMark;
    int wantToRemove = sz - lowWaterMark;

    cleaning = true;  // piggyback non-volatile write on a following volatile write
    this.oldestEntry = oldestEntry;  // volatile write to ensure visibility of other vars

    Map.Entry[] eset = new Map.Entry[sz];
    int eSize = 0;

    for (Map.Entry e : (Set<Map.Entry>)super.entrySet()) {
      CacheValue cv = (CacheValue)e.getValue();

      // also set scratch to lastAccess to avoid more volatile reads
      long lastAccess = cv.scratch = cv.lastAccess;

      // since wantToKeep is likely to be bigger than wantToRemove, check it first
      if (newestEntry - lastAccess < wantToKeep) {
        // this entry is guaranteed not to be in the bottom
        // group, so do nothing.
        numKept++;
      } else if (lastAccess - oldestEntry <= wantToRemove // entry in bottom group?
              || numKept >= wantToKeep  // if we have enough entries, discard the rest
              )
      {
        // this entry is guaranteed to be in the bottom group
        // so immediately remove it from the map.
        super.remove(e.getKey());
        numRemoved++;
      } else {
        // This entry *could* be in the bottom group.
        // Collect these entries to avoid another full pass... this is wasted
        // effort if enough entries are normally removed in this first pass.
        // An alternate impl could make a full second pass.
        if (eSize < eset.length-1) {
          eset[eSize++] = e;
          newNewestEntry = Math.max(lastAccess, newNewestEntry);
          newOldestEntry = Math.min(lastAccess, newOldestEntry);
        }
      }
    }

    // TODO: allow this to be customized in the constructor?
    int numPasses=1; // maximum number of linear passes over the data
    
    // if we didn't remove enough entries, then make more passes
    // over the values we collected, with updated min and max values.
    while (sz - numRemoved >= acceptableWaterMark && --numPasses>=0) {

      oldestEntry = newOldestEntry == Integer.MAX_VALUE ? oldestEntry : newOldestEntry;
      newOldestEntry = Integer.MAX_VALUE;
      newestEntry = newNewestEntry;
      newNewestEntry = -1;
      wantToKeep = lowWaterMark - numKept;
      wantToRemove = sz - lowWaterMark - numRemoved;

      for (int i=eSize-1; i>=0; i--) {
        Map.Entry e = eset[i];
        CacheValue cv = (CacheValue)e.getValue();
        long lastAccess = cv.scratch;

        if (newestEntry - lastAccess < wantToKeep) {
          // this entry is guaranteed not to be in the bottom
          // group, so do nothing but remove it from the eset.
          numKept++;
          // remove the entry by moving the last element to it's position
          eset[i] = eset[eSize-1];
          eSize--;
        } else if (lastAccess - oldestEntry <= wantToRemove) { // entry in bottom group?

          // this entry is guaranteed to be in the bottom group
          // so immediately remove it from the map.
          super.remove(e.getKey());
          numRemoved++;

          // remove the entry by moving the last element to it's position
          eset[i] = eset[eSize-1];
          eSize--;     
        } else {
          // This entry *could* be in the bottom group, so keep it in the eset,
          // and update the stats.
          newNewestEntry = Math.max(lastAccess, newNewestEntry);
          newOldestEntry = Math.min(lastAccess, newOldestEntry);
        }
      }
    }


    // if we still didn't remove enough entries, then make another pass while
    // inserting into a priority queue
    if (sz - numRemoved >= acceptableWaterMark) {

      oldestEntry = newOldestEntry == Integer.MAX_VALUE ? oldestEntry : newOldestEntry;
      newOldestEntry = Integer.MAX_VALUE;
      newestEntry = newNewestEntry;
      newNewestEntry = -1;
      wantToKeep = lowWaterMark - numKept;
      wantToRemove = sz - lowWaterMark - numRemoved;
      
      PQueue queue = new PQueue(wantToRemove);

      for (int i=eSize-1; i>=0; i--) {
        Map.Entry e = eset[i];
        CacheValue cv = (CacheValue)e.getValue();
        long lastAccess = cv.scratch;

        if (newestEntry - lastAccess < wantToKeep) {
          // this entry is guaranteed not to be in the bottom
          // group, so do nothing but remove it from the eset.
          numKept++;
          // removal not necessary on last pass.
          // eset[i] = eset[eSize-1];
          // eSize--;
        } else if (lastAccess - oldestEntry <= wantToRemove) {  // entry in bottom group?
          // this entry is guaranteed to be in the bottom group
          // so immediately remove it.
          super.remove(e.getKey());
          numRemoved++;

          // removal not necessary on last pass.
          // eset[i] = eset[eSize-1];
          // eSize--;
        } else {
          // This entry *could* be in the bottom group.
          // add it to the priority queue
          Object o = queue.insertWithOverflow(e);

          // everything in the priority queue will be removed, so keep track of
          // the lowest value that comes back out of the queue.
          if (o == null) {
            // if the queue isn't full, keep track of the maximum value
            // we have put into it.
            newOldestEntry = Math.max(lastAccess, newOldestEntry);
          } else {
            // otherwise, keep track of the *minimum* value that comes back out
            long access = lastAccess;
            if (o != e) {
              access = ((CacheValue)(((Map.Entry)o).getValue())).scratch;
            }
            newOldestEntry = Math.min(access, newOldestEntry);
          }
        }
      }

      // Now delete everything in the priority queue.
      // avoid pop() since order doesn't matter.
      for (Object o : queue.getValues()) {
        if (o==null) continue;
        Map.Entry e = (Map.Entry)o;
        super.remove(e.getKey());
        numRemoved++;
      }
    }

    size.addAndGet(-numRemoved);
    this.oldestEntry = oldestEntry;
    cleaning = false;  // non-volatile write piggybacking on following volatile write
    latch.set(false);
  }


  private static class PQueue extends PriorityQueue {
    PQueue(int maxSz) {
      super.initialize(maxSz);
    }

    Object[] getValues() { return heap; }

    protected boolean lessThan(Object a, Object b) {
      Map.Entry<?,CacheValue> a1 = (Map.Entry<?,CacheValue>)a;
      Map.Entry<?,CacheValue> b1 = (Map.Entry<?,CacheValue>)b;
      return a1.getValue().scratch > b1.getValue().scratch;
    }
  }

}
