Index: ml/src/main/java/org/apache/hama/ml/distance/CosineDistance.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/distance/CosineDistance.java (revision 1498967) +++ ml/src/main/java/org/apache/hama/ml/distance/CosineDistance.java (working copy) @@ -50,7 +50,7 @@ double lengthSquaredv1 = vec1.pow(2).sum(); double lengthSquaredv2 = vec2.pow(2).sum(); - double dotProduct = vec2.dot(vec1); + double dotProduct = vec2.dotUnsafe(vec1); double denominator = Math.sqrt(lengthSquaredv1) * Math.sqrt(lengthSquaredv2); Index: ml/src/main/java/org/apache/hama/ml/distance/EuclidianDistance.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/distance/EuclidianDistance.java (revision 1498967) +++ ml/src/main/java/org/apache/hama/ml/distance/EuclidianDistance.java (working copy) @@ -36,7 +36,7 @@ @Override public double measureDistance(DoubleVector vec1, DoubleVector vec2) { - return Math.sqrt(vec2.subtract(vec1).pow(2).sum()); + return Math.sqrt(vec2.subtractUnsafe(vec1).pow(2).sum()); } } Index: ml/src/main/java/org/apache/hama/ml/kmeans/KMeansBSP.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/kmeans/KMeansBSP.java (revision 1498967) +++ ml/src/main/java/org/apache/hama/ml/kmeans/KMeansBSP.java (working copy) @@ -162,7 +162,7 @@ if (oldCenter == null) { msgCenters[msg.getCenterIndex()] = newCenter; } else { - msgCenters[msg.getCenterIndex()] = oldCenter.add(newCenter); + msgCenters[msg.getCenterIndex()] = oldCenter.addUnsafe(newCenter); } } // divide by how often we globally summed vectors @@ -177,7 +177,7 @@ for (int i = 0; i < msgCenters.length; i++) { final DoubleVector oldCenter = centers[i]; if (msgCenters[i] != null) { - double calculateError = oldCenter.subtract(msgCenters[i]).abs().sum(); + double calculateError = oldCenter.subtractUnsafe(msgCenters[i]).abs().sum(); if (calculateError > 0.0d) { centers[i] = msgCenters[i]; convergedCounter++; @@ -241,7 +241,7 @@ } else { // add the vector to the center newCenterArray[lowestDistantCenter] = newCenterArray[lowestDistantCenter] - .add(key); + .addUnsafe(key); summationCount[lowestDistantCenter]++; } } Index: ml/src/main/java/org/apache/hama/ml/math/DenseDoubleMatrix.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/DenseDoubleMatrix.java (revision 1498967) +++ ml/src/main/java/org/apache/hama/ml/math/DenseDoubleMatrix.java (working copy) @@ -21,6 +21,8 @@ import java.util.HashSet; import java.util.Random; +import com.google.common.base.Preconditions; + /** * Dense double matrix implementation, internally uses two dimensional double * arrays. @@ -384,7 +386,7 @@ * @see de.jungblut.math.DoubleMatrix#multiply(de.jungblut.math.DoubleMatrix) */ @Override - public final DoubleMatrix multiply(DoubleMatrix other) { + public final DoubleMatrix multiplyUnsafe(DoubleMatrix other) { DenseDoubleMatrix matrix = new DenseDoubleMatrix(this.getRowCount(), other.getColumnCount()); @@ -412,7 +414,7 @@ * ) */ @Override - public final DoubleMatrix multiplyElementWise(DoubleMatrix other) { + public final DoubleMatrix multiplyElementWiseUnsafe(DoubleMatrix other) { DenseDoubleMatrix matrix = new DenseDoubleMatrix(this.numRows, this.numColumns); @@ -431,7 +433,7 @@ * de.jungblut.math.DoubleMatrix#multiplyVector(de.jungblut.math.DoubleVector) */ @Override - public final DoubleVector multiplyVector(DoubleVector v) { + public final DoubleVector multiplyVectorUnsafe(DoubleVector v) { DoubleVector vector = new DenseDoubleVector(this.getRowCount()); for (int row = 0; row < numRows; row++) { double sum = 0.0d; @@ -494,7 +496,7 @@ * @see de.jungblut.math.DoubleMatrix#subtract(de.jungblut.math.DoubleMatrix) */ @Override - public DoubleMatrix subtract(DoubleMatrix other) { + public DoubleMatrix subtractUnsafe(DoubleMatrix other) { DoubleMatrix m = new DenseDoubleMatrix(this.numRows, this.numColumns); for (int i = 0; i < numRows; i++) { for (int j = 0; j < numColumns; j++) { @@ -509,7 +511,7 @@ * @see de.jungblut.math.DoubleMatrix#subtract(de.jungblut.math.DoubleVector) */ @Override - public DenseDoubleMatrix subtract(DoubleVector vec) { + public DenseDoubleMatrix subtractUnsafe(DoubleVector vec) { DenseDoubleMatrix cop = new DenseDoubleMatrix(this.getRowCount(), this.getColumnCount()); for (int i = 0; i < this.getColumnCount(); i++) { @@ -775,7 +777,7 @@ * Just a absolute error function. */ public static double error(DenseDoubleMatrix a, DenseDoubleMatrix b) { - return a.subtract(b).sum(); + return a.subtractUnsafe(b).sum(); } @Override @@ -795,20 +797,91 @@ /** * {@inheritDoc} */ - public DoubleMatrix applyToElements(DoubleMatrix other, DoubleDoubleFunction fun) { - if (this.numRows != other.getRowCount() - || this.numColumns != other.getColumnCount()) { - throw new IllegalArgumentException( - "Cannot apply double double function to matrices with different sizes."); - } - + public DoubleMatrix applyToElements(DoubleMatrix other, + DoubleDoubleFunction fun) { + Preconditions + .checkArgument(this.numRows == other.getRowCount() + && this.numColumns == other.getColumnCount(), + "Cannot apply double double function to matrices with different sizes."); + for (int r = 0; r < this.numRows; ++r) { for (int c = 0; c < this.numColumns; ++c) { this.set(r, c, fun.apply(this.get(r, c), other.get(r, c))); } } - + return this; } + /* + * (non-Javadoc) + * @see + * org.apache.hama.ml.math.DoubleMatrix#safeMultiply(org.apache.hama.ml.math + * .DoubleMatrix) + */ + @Override + public DoubleMatrix multiply(DoubleMatrix other) { + Preconditions + .checkArgument( + this.numColumns == other.getRowCount(), + String + .format( + "Matrix with size [%d, %d] cannot multiple matrix with size [%d, %d]", + this.numRows, this.numColumns, other.getRowCount(), + other.getColumnCount())); + + return this.multiplyUnsafe(other); + } + + /* + * (non-Javadoc) + * @see + * org.apache.hama.ml.math.DoubleMatrix#safeMultiplyElementWise(org.apache + * .hama.ml.math.DoubleMatrix) + */ + @Override + public DoubleMatrix multiplyElementWise(DoubleMatrix other) { + Preconditions.checkArgument(this.numRows != other.getRowCount() + || this.numColumns != other.getColumnCount(), + "Matrices with different dimensions cannot be multiplied elementwise."); + return this.multiplyElementWiseUnsafe(other); + } + + /* + * (non-Javadoc) + * @see + * org.apache.hama.ml.math.DoubleMatrix#safeMultiplyVector(org.apache.hama + * .ml.math.DoubleVector) + */ + @Override + public DoubleVector multiplyVector(DoubleVector v) { + Preconditions.checkArgument(this.numColumns == v.getDimension(), + "Dimension mismatch."); + return this.multiplyVectorUnsafe(v); + } + + /* + * (non-Javadoc) + * @see org.apache.hama.ml.math.DoubleMatrix#subtract(org.apache.hama.ml.math. + * DoubleMatrix) + */ + @Override + public DoubleMatrix subtract(DoubleMatrix other) { + Preconditions.checkArgument(this.numRows == other.getRowCount() + && this.numColumns == other.getColumnCount(), "Dimension mismatch."); + return subtractUnsafe(other); + } + + /* + * (non-Javadoc) + * @see org.apache.hama.ml.math.DoubleMatrix#subtract(org.apache.hama.ml.math. + * DoubleVector) + */ + @Override + public DoubleMatrix subtract(DoubleVector vec) { + Preconditions.checkArgument(this.numColumns == vec.getDimension(), + "Dimension mismatch."); + return null; + } + } Index: ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java (revision 1498967) +++ ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java (working copy) @@ -24,6 +24,7 @@ import java.util.Iterator; import java.util.List; +import com.google.common.base.Preconditions; import com.google.common.collect.AbstractIterator; /** @@ -112,16 +113,17 @@ } /** - * {@inheritDoc}} + * {@inheritDoc} */ @Override - public DoubleVector applyToElements(DoubleVector other, DoubleDoubleFunction func) { + public DoubleVector applyToElements(DoubleVector other, + DoubleDoubleFunction func) { for (int i = 0; i < vector.length; i++) { this.vector[i] = func.apply(vector[i], other.get(i)); } return this; } - + /* * (non-Javadoc) * @see de.jungblut.math.DoubleVector#apply(de.jungblut.math.function. @@ -157,7 +159,7 @@ * @see de.jungblut.math.DoubleVector#add(de.jungblut.math.DoubleVector) */ @Override - public final DoubleVector add(DoubleVector v) { + public final DoubleVector addUnsafe(DoubleVector v) { DenseDoubleVector newv = new DenseDoubleVector(v.getLength()); for (int i = 0; i < v.getLength(); i++) { newv.set(i, this.get(i) + v.get(i)); @@ -183,7 +185,7 @@ * @see de.jungblut.math.DoubleVector#subtract(de.jungblut.math.DoubleVector) */ @Override - public final DoubleVector subtract(DoubleVector v) { + public final DoubleVector subtractUnsafe(DoubleVector v) { DoubleVector newv = new DenseDoubleVector(v.getLength()); for (int i = 0; i < v.getLength(); i++) { newv.set(i, this.get(i) - v.get(i)); @@ -235,7 +237,7 @@ * @see de.jungblut.math.DoubleVector#multiply(de.jungblut.math.DoubleVector) */ @Override - public DoubleVector multiply(DoubleVector vector) { + public DoubleVector multiplyUnsafe(DoubleVector vector) { DoubleVector v = new DenseDoubleVector(this.getLength()); for (int i = 0; i < v.getLength(); i++) { v.set(i, this.get(i) * vector.get(i)); @@ -338,10 +340,10 @@ * @see de.jungblut.math.DoubleVector#dot(de.jungblut.math.DoubleVector) */ @Override - public double dot(DoubleVector s) { + public double dotUnsafe(DoubleVector vector) { double dotProduct = 0.0d; for (int i = 0; i < getLength(); i++) { - dotProduct += this.get(i) * s.get(i); + dotProduct += this.get(i) * vector.get(i); } return dotProduct; } @@ -652,4 +654,54 @@ return null; } + /* + * (non-Javadoc) + * @see org.apache.hama.ml.math.DoubleVector#safeAdd(org.apache.hama.ml.math. + * DoubleVector) + */ + @Override + public DoubleVector add(DoubleVector vector) { + Preconditions.checkArgument(this.vector.length == vector.getDimension(), + "Dimensions of two vectors do not equal."); + return this.addUnsafe(vector); + } + + /* + * (non-Javadoc) + * @see + * org.apache.hama.ml.math.DoubleVector#safeSubtract(org.apache.hama.ml.math + * .DoubleVector) + */ + @Override + public DoubleVector subtract(DoubleVector vector) { + Preconditions.checkArgument(this.vector.length == vector.getDimension(), + "Dimensions of two vectors do not equal."); + return this.subtractUnsafe(vector); + } + + /* + * (non-Javadoc) + * @see + * org.apache.hama.ml.math.DoubleVector#safeMultiplay(org.apache.hama.ml.math + * .DoubleVector) + */ + @Override + public DoubleVector multiplay(DoubleVector vector) { + Preconditions.checkArgument(this.vector.length == vector.getDimension(), + "Dimensions of two vectors do not equal."); + return this.multiplyUnsafe(vector); + } + + /* + * (non-Javadoc) + * @see org.apache.hama.ml.math.DoubleVector#safeDot(org.apache.hama.ml.math. + * DoubleVector) + */ + @Override + public double dot(DoubleVector vector) { + Preconditions.checkArgument(this.vector.length == vector.getDimension(), + "Dimensions of two vectors do not equal."); + return this.dotUnsafe(vector); + } + } Index: ml/src/main/java/org/apache/hama/ml/math/DoubleMatrix.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/DoubleMatrix.java (revision 1498967) +++ ml/src/main/java/org/apache/hama/ml/math/DoubleMatrix.java (working copy) @@ -80,18 +80,47 @@ /** * Multiplies this matrix with the given other matrix. + * + * @param other the other matrix. + * @return */ + public DoubleMatrix multiplyUnsafe(DoubleMatrix other); + + /** + * Validates the input and multiplies this matrix with the given other matrix. + * + * @param other the other matrix. + * @return + */ public DoubleMatrix multiply(DoubleMatrix other); /** * Multiplies this matrix per element with a given matrix. */ + public DoubleMatrix multiplyElementWiseUnsafe(DoubleMatrix other); + + /** + * Validates the input and multiplies this matrix per element with a given + * matrix. + * + * @param other the other matrix + * @return + */ public DoubleMatrix multiplyElementWise(DoubleMatrix other); /** * Multiplies this matrix with a given vector v. The returning vector contains * the sum of the rows. */ + public DoubleVector multiplyVectorUnsafe(DoubleVector v); + + /** + * Multiplies this matrix with a given vector v. The returning vector contains + * the sum of the rows. + * + * @param v the vector + * @return + */ public DoubleVector multiplyVector(DoubleVector v); /** @@ -114,12 +143,29 @@ /** * Subtracts this matrix by the given other matrix. */ + public DoubleMatrix subtractUnsafe(DoubleMatrix other); + + /** + * Validates the input and subtracts this matrix by the given other matrix. + * + * @param other + * @return + */ public DoubleMatrix subtract(DoubleMatrix other); /** * Subtracts each element in a column by the related element in the given * vector. */ + public DoubleMatrix subtractUnsafe(DoubleVector vec); + + /** + * Validates and subtracts each element in a column by the related element in + * the given vector. + * + * @param vec + * @return + */ public DoubleMatrix subtract(DoubleVector vec); /** @@ -203,6 +249,7 @@ * @param fun The function that takes two arguments. * @return The matrix itself, supply for chain operation. */ - public DoubleMatrix applyToElements(DoubleMatrix other, DoubleDoubleFunction fun); + public DoubleMatrix applyToElements(DoubleMatrix other, + DoubleDoubleFunction fun); } Index: ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java (revision 1498967) +++ ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java (working copy) @@ -58,7 +58,7 @@ * @param value the value at the index of the vector to set. */ public void set(int index, double value); - + /** * Apply a given {@link DoubleVectorFunction} to this vector and return a new * one. @@ -68,7 +68,7 @@ */ @Deprecated public DoubleVector apply(DoubleVectorFunction func); - + /** * Apply a given {@link DoubleDoubleVectorFunction} to this vector and the * other given vector. @@ -97,17 +97,26 @@ * @param func the function to apply on this and the other vector. * @return a new vector with the result of the function of the two vectors. */ - public DoubleVector applyToElements(DoubleVector other, DoubleDoubleFunction func); + public DoubleVector applyToElements(DoubleVector other, + DoubleDoubleFunction func); /** * Adds the given {@link DoubleVector} to this vector. * - * @param v the other vector. + * @param vector the other vector. * @return a new vector with the sum of both vectors at each element index. */ - public DoubleVector add(DoubleVector v); + public DoubleVector addUnsafe(DoubleVector vector); /** + * Validates the input and adds the given {@link DoubleVector} to this vector. + * + * @param vector the other vector. + * @return a new vector with the sum of both vectors at each element index. + */ + public DoubleVector add(DoubleVector vector); + + /** * Adds the given scalar to this vector. * * @param scalar the scalar. @@ -118,12 +127,21 @@ /** * Subtracts this vector by the given {@link DoubleVector}. * - * @param v the other vector. + * @param vector the other vector. * @return a new vector with the difference of both vectors. */ - public DoubleVector subtract(DoubleVector v); + public DoubleVector subtractUnsafe(DoubleVector vector); /** + * Validates the input and subtracts this vector by the given + * {@link DoubleVector}. + * + * @param vector the other vector. + * @return a new vector with the difference of both vectors. + */ + public DoubleVector subtract(DoubleVector vector); + + /** * Subtracts the given scalar to this vector. (vector - scalar). * * @param scalar the scalar. @@ -153,9 +171,18 @@ * @param vector the other vector. * @return a new vector with the result of the operation. */ - public DoubleVector multiply(DoubleVector vector); + public DoubleVector multiplyUnsafe(DoubleVector vector); /** + * Validates the input and multiplies the given {@link DoubleVector} with this + * vector. + * + * @param vector the other vector. + * @return a new vector with the result of the operation. + */ + public DoubleVector multiplay(DoubleVector vector); + + /** * Divides this vector by the given scalar. (= vector/scalar). * * @param scalar the given scalar. @@ -201,12 +228,21 @@ /** * Calculates the dot product between this vector and the given vector. * - * @param s the given vector s. + * @param vector the given vector. * @return the dot product as a double. */ - public double dot(DoubleVector s); + public double dotUnsafe(DoubleVector vector); /** + * Validates the input and calculates the dot product between this vector and + * the given vector. + * + * @param vector the given vector. + * @return the dot product as a double. + */ + public double dot(DoubleVector vector); + + /** * Slices this vector from index 0 to the given length. * * @param length must be > 0 and smaller than the dimension of the vector. Index: ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java (revision 1498967) +++ ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java (working copy) @@ -38,7 +38,7 @@ @Override public double applyHypothesis(DoubleVector theta, DoubleVector x) { - return theta.dot(x); + return theta.dotUnsafe(x); } @Override Index: ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java (revision 1498967) +++ ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java (working copy) @@ -53,7 +53,7 @@ DoubleVector x) { return BigDecimal.valueOf(1).divide( BigDecimal.valueOf(1d).add( - BigDecimal.valueOf(Math.exp(-1d * theta.dot(x)))), + BigDecimal.valueOf(Math.exp(-1d * theta.dotUnsafe(x)))), MathContext.DECIMAL128); } Index: ml/src/main/java/org/apache/hama/ml/writable/VectorWritable.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/writable/VectorWritable.java (revision 1498967) +++ ml/src/main/java/org/apache/hama/ml/writable/VectorWritable.java (working copy) @@ -119,7 +119,7 @@ } public static int compareVector(DoubleVector a, DoubleVector o) { - DoubleVector subtract = a.subtract(o); + DoubleVector subtract = a.subtractUnsafe(o); return (int) subtract.sum(); }