/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.commons.math.linear;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;

/**
 * Sparse matrix implementation based on an open addressed map.
 * <p>
 * Only non-zero entries are stored: setting an entry to {@code 0.0} removes it
 * from the map, and positions that hold no entry read as {@code 0.0}. Each
 * {@code (row, column)} position is stored under the {@code int} key
 * {@code row * columnDimension + column}.
 * </p>
 *
 * @version $Revision$ $Date$
 */
public class SparseRealMatrix extends AbstractRealMatrix {

    /**
     * Open-addressed hash map from {@code int} keys to {@code double} values.
     * <p>
     * Collisions are resolved with a perturbed probe sequence
     * ({@code j = 5 * j + perturb + 1}, where {@code perturb} starts from the
     * non-negative part of the hash and is shifted right by
     * {@link #PERTURB_SHIFT} bits after every probe). Removing a key leaves a
     * {@link #REMOVED} tombstone so that probe chains passing through the slot
     * stay intact. Absent keys read as {@code 0.0}.
     * </p>
     * <p>
     * Every probe loop stops only when it reaches a {@link #FREE} slot, so the
     * table is rehashed whenever live entries plus tombstones exceed
     * {@link #LOAD_FACTOR} of the capacity. This guarantees that a free slot
     * always exists.
     * </p>
     */
    static class OpenIntToDoubleHashMap implements Serializable {

        private static final long serialVersionUID = 1L;
        /** Maximum fraction of slots that may be non-free (full or removed). */
        private static final float LOAD_FACTOR = 0.50f;
        /** Expected number of entries assumed by the no-argument constructor. */
        private static final int DEFAULT_EXPECTED_SIZE = 16;
        /** Factor applied to the capacity when the table grows. */
        private static final int RESIZE_MULTIPLIER = 2;
        /** Number of bits {@code perturb} is shifted right after each probe. */
        private static final int PERTURB_SHIFT = 5;
        /** Number of live ({@link #FULL}) entries. */
        private int size;
        /** Number of {@link #REMOVED} tombstones currently in the table. */
        private int removed;
        /** {@code capacity - 1}; the capacity is always a power of two. */
        private int mask;
        /** Slot keys; reset to 0 when a slot is removed. */
        private int[] keys;
        /** Slot values; reset to 0 when a slot is removed. */
        private double[] values;
        /** Slot states: {@link #FREE}, {@link #FULL} or {@link #REMOVED}. */
        protected byte[] states;
        protected static final byte FREE = 0;
        protected static final byte FULL = 1;
        protected static final byte REMOVED = 2;

        /** Builds an empty map sized for {@link #DEFAULT_EXPECTED_SIZE} entries. */
        OpenIntToDoubleHashMap() {
            this(DEFAULT_EXPECTED_SIZE);
        }

        /**
         * Builds an empty map.
         *
         * @param expectedSize number of entries the table should hold without growing
         */
        OpenIntToDoubleHashMap(int expectedSize) {
            int capacity = computeCapacity(expectedSize);
            keys = new int[capacity];
            values = new double[capacity];
            states = new byte[capacity];
            mask = capacity - 1;
        }

        /**
         * Builds an independent deep copy of {@code source}.
         *
         * @param source map to copy
         */
        OpenIntToDoubleHashMap(OpenIntToDoubleHashMap source) {
            keys = source.keys.clone();
            values = source.values.clone();
            states = source.states.clone();
            mask = source.mask;
            size = source.size;
            removed = source.removed;
        }

        /**
         * Computes the smallest power-of-two capacity that keeps
         * {@code expectedSize} entries within {@link #LOAD_FACTOR}.
         *
         * @param expectedSize expected number of entries
         * @return table capacity (1 when {@code expectedSize} is 0)
         */
        private static int computeCapacity(int expectedSize) {
            if (expectedSize == 0) {
                return 1;
            }
            int capacity = (int) Math.ceil(expectedSize / LOAD_FACTOR);
            int powerOfTwo = Integer.highestOneBit(capacity);
            if (powerOfTwo == capacity) {
                return capacity;
            }
            return nextPowerOfTwo(capacity);
        }

        /**
         * @param i positive value that is not a power of two
         * @return the smallest power of two greater than {@code i}
         */
        private static int nextPowerOfTwo(int i) {
            return Integer.highestOneBit(i) << 1;
        }

        /**
         * Looks up the value stored for a key.
         *
         * @param key key to look up
         * @return the stored value, or {@code 0.0} when the key is absent
         */
        public double get(int key) {
            final int hash = hashOf(key);
            int index = hash & mask;
            if (containsKey(key, index)) {
                return values[index];
            }
            if (states[index] == FREE) {
                return 0.0;
            }
            int j = index;
            for (int perturb = perturb(hash); states[index] != FREE; perturb >>= PERTURB_SHIFT) {
                j = probe(perturb, j);
                index = j & mask;
                if (containsKey(key, index)) {
                    return values[index];
                }
            }
            return 0.0;
        }

        /**
         * @param hash key hash
         * @return initial perturbation: the hash with its sign bit cleared
         */
        private static int perturb(int hash) {
            return hash & 0x7fffffff;
        }

        /**
         * Finds the slot where {@code key} lives or should be inserted in this table.
         *
         * @param key key to place
         * @return see {@link #findInsertionIndex(int[], byte[], int, int)}
         */
        private int findInsertionIndex(int key) {
            return findInsertionIndex(keys, states, key, mask);
        }

        /**
         * Finds the slot where {@code key} lives or should be inserted.
         * <p>
         * Walks the same probe sequence as {@link #get(int)} and
         * {@link #remove(int)}. The first tombstone on the chain is reused, but
         * only after checking that the key does not appear later on the chain.
         * </p>
         *
         * @param keys slot keys
         * @param states slot states
         * @param key key to place
         * @param mask {@code capacity - 1}
         * @return the slot index if the key is absent, or
         *         {@code changeIndexSign(index)} (a negative value) if the key
         *         already occupies slot {@code index}
         */
        private static int findInsertionIndex(int[] keys, byte[] states,
            int key, int mask) {
            final int hash = hashOf(key);
            int index = hash & mask;
            if (states[index] == FREE) {
                return index;
            } else if (states[index] == FULL && keys[index] == key) {
                return changeIndexSign(index);
            }

            // perturb and j are shared by both loops below so that, together,
            // they follow exactly the probe sequence used by get() and remove().
            int perturb = perturb(hash);
            int j = index;
            if (states[index] == FULL) {
                while (true) {
                    j = probe(perturb, j);
                    index = j & mask;
                    perturb >>= PERTURB_SHIFT;
                    if (states[index] != FULL || keys[index] == key) {
                        break;
                    }
                }
            }
            if (states[index] == FREE) {
                return index;
            } else if (states[index] == FULL) {
                // Due to the loop exit condition, a FULL slot here holds key.
                return changeIndexSign(index);
            }

            // index is the first tombstone on the chain. Reuse it unless the key
            // turns up further along the same chain.
            final int firstRemoved = index;
            while (true) {
                j = probe(perturb, j);
                index = j & mask;
                if (states[index] == FREE) {
                    return firstRemoved;
                } else if (states[index] == FULL && keys[index] == key) {
                    return changeIndexSign(index);
                }
                perturb >>= PERTURB_SHIFT;
            }
        }

        /**
         * Computes the next raw probe position, {@code 5 * j + perturb + 1}.
         * The caller masks the result to get a slot index.
         */
        private static int probe(int perturb, int j) {
            return (j << 2) + j + perturb + 1;
        }

        /** Encodes or decodes an "already present at index" result; self-inverse. */
        private static int changeIndexSign(int index) {
            return -index - 1;
        }

        /** @return number of stored entries */
        public int size() {
            return size;
        }

        /**
         * Removes a key, leaving a tombstone in its slot.
         *
         * @param key key to remove
         * @return the removed value, or {@code 0.0} when the key was absent
         */
        public double remove(int key) {
            final int hash = hashOf(key);
            int index = hash & mask;
            if (containsKey(key, index)) {
                final double previous = values[index];
                doRemove(index);
                return previous;
            }
            if (states[index] == FREE) {
                return 0.0;
            }
            int j = index;
            for (int perturb = perturb(hash); states[index] != FREE; perturb >>= PERTURB_SHIFT) {
                j = probe(perturb, j);
                index = j & mask;
                if (containsKey(key, index)) {
                    final double previous = values[index];
                    doRemove(index);
                    return previous;
                }
            }
            return 0.0;
        }

        /**
         * Checks whether slot {@code index} holds {@code key}. Removed and free
         * slots have key 0, so key 0 matches only a FULL slot.
         */
        private boolean containsKey(int key, int index) {
            return (key != 0 || states[index] == FULL) && keys[index] == key;
        }

        /** Turns the FULL slot {@code index} into a tombstone. */
        private void doRemove(int index) {
            keys[index] = 0;
            states[index] = REMOVED;
            values[index] = 0;
            --size;
            ++removed;
        }

        /**
         * Stores a value, replacing any previous value for the key.
         *
         * @param key key to store
         * @param value value to associate with the key
         * @return the previous value, or {@code 0.0} when the key was absent
         */
        public double put(int key, double value) {
            int index = findInsertionIndex(key);
            if (index < 0) {
                // Key already present: overwrite its value in place.
                index = changeIndexSign(index);
                final double previous = values[index];
                values[index] = value;
                return previous;
            }
            if (states[index] == REMOVED) {
                // Reusing a tombstone instead of consuming a free slot.
                --removed;
            }
            keys[index] = key;
            states[index] = FULL;
            values[index] = value;
            ++size;
            if (shouldRehash()) {
                rehash();
            }
            return 0.0;
        }

        /**
         * Rebuilds the table without tombstones.
         * <p>
         * The capacity grows when live entries alone exceed half of the allowed
         * load. Otherwise the rebuild keeps the same capacity and only purges
         * tombstones. Either way, a sizeable fraction of the slots becomes
         * available again, so the cost of the rebuild is spread over the
         * insertions that follow.
         * </p>
         */
        private void rehash() {
            final int oldLength = states.length;
            final int[] oldKeys = keys;
            final double[] oldValues = values;
            final byte[] oldStates = states;

            final int newLength = (size > oldLength * LOAD_FACTOR / 2)
                    ? RESIZE_MULTIPLIER * oldLength
                    : oldLength;
            final int[] newKeys = new int[newLength];
            final double[] newValues = new double[newLength];
            final byte[] newStates = new byte[newLength];
            final int newMask = newLength - 1;
            for (int i = 0; i < oldLength; ++i) {
                if (oldStates[i] == FULL) {
                    final int key = oldKeys[i];
                    final int index = findInsertionIndex(newKeys, newStates, key, newMask);
                    newKeys[index] = key;
                    newValues[index] = oldValues[i];
                    newStates[index] = FULL;
                }
            }
            mask = newMask;
            keys = newKeys;
            values = newValues;
            states = newStates;
            removed = 0;
        }

        /**
         * @return true when live entries plus tombstones exceed the load factor,
         *         which would eventually exhaust the free slots that end probing
         */
        private boolean shouldRehash() {
            return size + removed > (mask + 1) * LOAD_FACTOR;
        }

        /** Spreads the key bits so that the masked low bits depend on the high bits too. */
        private static int hashOf(int h) {
            h ^= ((h >>> 20) ^ (h >>> 12));
            return h ^ (h >>> 7) ^ (h >>> 4);
        }

        /**
         * Restores the map and recounts its tombstones. Streams written before
         * tombstones were tracked do not contain the count, so it would
         * otherwise read as 0.
         */
        private void readObject(ObjectInputStream stream)
                throws IOException, ClassNotFoundException {
            stream.defaultReadObject();
            int tombstones = 0;
            for (byte state : states) {
                if (state == REMOVED) {
                    ++tombstones;
                }
            }
            removed = tombstones;
            if (shouldRehash()) {
                rehash();
            }
        }
    }

    private static final long serialVersionUID = -4601548152403366499L;

    /**
     * Number of distinct {@code int} keys. Positions are keyed modulo 2^32, so
     * keys stay unique only while the matrix has at most this many entries.
     */
    private static final long MAX_ENTRIES = 1L << 32;

    private final int rowDimension;
    private final int columnDimension;
    /** Non-zero entries keyed by {@link #computeKey(int, int)}. */
    private OpenIntToDoubleHashMap entries;

    /**
     * Builds an all-zero sparse matrix.
     *
     * @param rowDimension number of rows
     * @param columnDimension number of columns
     * @throws IllegalArgumentException if {@code rowDimension * columnDimension}
     *         exceeds 2^32, in which case distinct positions would share a key
     */
    public SparseRealMatrix(int rowDimension, int columnDimension) {
        super(rowDimension, columnDimension);
        if ((long) rowDimension * columnDimension > MAX_ENTRIES) {
            throw new IllegalArgumentException("matrix dimensions " + rowDimension + "x"
                    + columnDimension + " exceed the " + MAX_ENTRIES
                    + " entries addressable by int keys");
        }
        this.rowDimension = rowDimension;
        this.columnDimension = columnDimension;
        this.entries = new OpenIntToDoubleHashMap();
    }

    /**
     * Copy constructor; the new matrix does not share storage with the source.
     *
     * @param sparseRealMatrix matrix to copy
     */
    public SparseRealMatrix(SparseRealMatrix sparseRealMatrix) {
        this.rowDimension = sparseRealMatrix.rowDimension;
        this.columnDimension = sparseRealMatrix.columnDimension;
        this.entries = new OpenIntToDoubleHashMap(sparseRealMatrix.entries);
    }

    /** {@inheritDoc} */
    @Override
    public RealMatrix copy() {
        return new SparseRealMatrix(this);
    }

    /** {@inheritDoc} */
    @Override
    public RealMatrix createMatrix(int rowDimension, int columnDimension)
            throws IllegalArgumentException {
        return new SparseRealMatrix(rowDimension, columnDimension);
    }

    /** {@inheritDoc} */
    @Override
    public int getColumnDimension() {
        return this.columnDimension;
    }

    /**
     * Not supported: a dense copy would defeat the purpose of a sparse matrix.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public double[][] getData() {
        throw new UnsupportedOperationException(
                "Not supported because sparse matrices are usually chosen when " +
                "the matrix to be modeled is too large to fit into memory, trying " +
                "to build the data array from this may result in an OutOfMemoryError.");
    }

    /** {@inheritDoc} */
    @Override
    public double getEntry(int row, int column) throws MatrixIndexException {
        checkRowIndex(row);
        checkColumnIndex(column);
        return entries.get(computeKey(row, column));
    }

    /** {@inheritDoc} */
    @Override
    public int getRowDimension() {
        return this.rowDimension;
    }

    /**
     * {@inheritDoc}
     * <p>
     * Setting an entry to {@code 0.0} removes it from the underlying map
     * instead of storing an explicit zero.
     * </p>
     */
    @Override
    public void setEntry(int row, int column, double value)
            throws MatrixIndexException {
        checkRowIndex(row);
        checkColumnIndex(column);
        if (value == 0.0D) {
            entries.remove(computeKey(row, column));
        } else {
            entries.put(computeKey(row, column), value);
        }
    }

    /**
     * Maps a position to its key in {@link #entries}.
     * <p>
     * The row-major index may overflow {@code int}. The wrapped value is still
     * unique per position because the constructor rejects matrices with more
     * than 2^32 entries.
     * </p>
     */
    private int computeKey(int row, int column) {
        return row * columnDimension + column;
    }
}
