Index: ml/src/main/java/org/apache/hama/ml/distance/CosineDistance.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/distance/CosineDistance.java (revision 1500180) +++ ml/src/main/java/org/apache/hama/ml/distance/CosineDistance.java (working copy) @@ -50,7 +50,7 @@ double lengthSquaredv1 = vec1.pow(2).sum(); double lengthSquaredv2 = vec2.pow(2).sum(); - double dotProduct = vec2.dot(vec1); + double dotProduct = vec2.dotUnsafe(vec1); double denominator = Math.sqrt(lengthSquaredv1) * Math.sqrt(lengthSquaredv2); Index: ml/src/main/java/org/apache/hama/ml/distance/EuclidianDistance.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/distance/EuclidianDistance.java (revision 1500180) +++ ml/src/main/java/org/apache/hama/ml/distance/EuclidianDistance.java (working copy) @@ -36,7 +36,7 @@ @Override public double measureDistance(DoubleVector vec1, DoubleVector vec2) { - return Math.sqrt(vec2.subtract(vec1).pow(2).sum()); + return Math.sqrt(vec2.subtractUnsafe(vec1).pow(2).sum()); } } Index: ml/src/main/java/org/apache/hama/ml/kmeans/KMeansBSP.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/kmeans/KMeansBSP.java (revision 1500180) +++ ml/src/main/java/org/apache/hama/ml/kmeans/KMeansBSP.java (working copy) @@ -162,7 +162,7 @@ if (oldCenter == null) { msgCenters[msg.getCenterIndex()] = newCenter; } else { - msgCenters[msg.getCenterIndex()] = oldCenter.add(newCenter); + msgCenters[msg.getCenterIndex()] = oldCenter.addUnsafe(newCenter); } } // divide by how often we globally summed vectors @@ -177,7 +177,7 @@ for (int i = 0; i < msgCenters.length; i++) { final DoubleVector oldCenter = centers[i]; if (msgCenters[i] != null) { - double calculateError = oldCenter.subtract(msgCenters[i]).abs().sum(); + double calculateError = oldCenter.subtractUnsafe(msgCenters[i]).abs().sum(); if (calculateError > 0.0d) { centers[i] = msgCenters[i]; convergedCounter++; @@ -241,7 +241,7 @@ } else { // add the vector to the center newCenterArray[lowestDistantCenter] = newCenterArray[lowestDistantCenter] - .add(key); + .addUnsafe(key); summationCount[lowestDistantCenter]++; } } Index: ml/src/main/java/org/apache/hama/ml/math/DenseDoubleMatrix.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/DenseDoubleMatrix.java (revision 1500180) +++ ml/src/main/java/org/apache/hama/ml/math/DenseDoubleMatrix.java (working copy) @@ -21,6 +21,8 @@ import java.util.HashSet; import java.util.Random; +import com.google.common.base.Preconditions; + /** * Dense double matrix implementation, internally uses two dimensional double * arrays. @@ -384,7 +386,7 @@ * @see de.jungblut.math.DoubleMatrix#multiply(de.jungblut.math.DoubleMatrix) */ @Override - public final DoubleMatrix multiply(DoubleMatrix other) { + public final DoubleMatrix multiplyUnsafe(DoubleMatrix other) { DenseDoubleMatrix matrix = new DenseDoubleMatrix(this.getRowCount(), other.getColumnCount()); @@ -412,7 +414,7 @@ * ) */ @Override - public final DoubleMatrix multiplyElementWise(DoubleMatrix other) { + public final DoubleMatrix multiplyElementWiseUnsafe(DoubleMatrix other) { DenseDoubleMatrix matrix = new DenseDoubleMatrix(this.numRows, this.numColumns); @@ -431,7 +433,7 @@ * de.jungblut.math.DoubleMatrix#multiplyVector(de.jungblut.math.DoubleVector) */ @Override - public final DoubleVector multiplyVector(DoubleVector v) { + public final DoubleVector multiplyVectorUnsafe(DoubleVector v) { DoubleVector vector = new DenseDoubleVector(this.getRowCount()); for (int row = 0; row < numRows; row++) { double sum = 0.0d; @@ -494,7 +496,7 @@ * @see de.jungblut.math.DoubleMatrix#subtract(de.jungblut.math.DoubleMatrix) */ @Override - public DoubleMatrix subtract(DoubleMatrix other) { + public DoubleMatrix subtractUnsafe(DoubleMatrix other) { DoubleMatrix m = new DenseDoubleMatrix(this.numRows, this.numColumns); for (int i = 0; i < numRows; i++) { for (int j = 0; j < numColumns; j++) { @@ -509,7 +511,7 @@ * @see de.jungblut.math.DoubleMatrix#subtract(de.jungblut.math.DoubleVector) */ @Override - public DenseDoubleMatrix subtract(DoubleVector vec) { + public DenseDoubleMatrix subtractUnsafe(DoubleVector vec) { DenseDoubleMatrix cop = new DenseDoubleMatrix(this.getRowCount(), this.getColumnCount()); for (int i = 0; i < this.getColumnCount(); i++) { @@ -523,7 +525,7 @@ * @see de.jungblut.math.DoubleMatrix#divide(de.jungblut.math.DoubleVector) */ @Override - public DoubleMatrix divide(DoubleVector vec) { + public DoubleMatrix divideUnsafe(DoubleVector vec) { DoubleMatrix cop = new DenseDoubleMatrix(this.getRowCount(), this.getColumnCount()); for (int i = 0; i < this.getColumnCount(); i++) { @@ -532,12 +534,22 @@ return cop; } + /** + * {@inheritDoc} + */ + @Override + public DoubleMatrix divide(DoubleVector vec) { + Preconditions.checkArgument(this.getColumnCount() == vec.getDimension(), + "Dimension mismatch."); + return this.divideUnsafe(vec); + } + /* * (non-Javadoc) * @see de.jungblut.math.DoubleMatrix#divide(de.jungblut.math.DoubleMatrix) */ @Override - public DoubleMatrix divide(DoubleMatrix other) { + public DoubleMatrix divideUnsafe(DoubleMatrix other) { DoubleMatrix m = new DenseDoubleMatrix(this.numRows, this.numColumns); for (int i = 0; i < numRows; i++) { for (int j = 0; j < numColumns; j++) { @@ -547,6 +559,13 @@ return m; } + @Override + public DoubleMatrix divide(DoubleMatrix other) { + Preconditions.checkArgument(this.getRowCount() == other.getRowCount() + && this.getColumnCount() == other.getColumnCount()); + return divideUnsafe(other); + } + /* * (non-Javadoc) * @see de.jungblut.math.DoubleMatrix#divide(double) @@ -775,7 +794,7 @@ * Just a absolute error function. */ public static double error(DenseDoubleMatrix a, DenseDoubleMatrix b) { - return a.subtract(b).sum(); + return a.subtractUnsafe(b).sum(); } @Override @@ -795,20 +814,91 @@ /** * {@inheritDoc} */ - public DoubleMatrix applyToElements(DoubleMatrix other, DoubleDoubleFunction fun) { - if (this.numRows != other.getRowCount() - || this.numColumns != other.getColumnCount()) { - throw new IllegalArgumentException( - "Cannot apply double double function to matrices with different sizes."); - } - + public DoubleMatrix applyToElements(DoubleMatrix other, + DoubleDoubleFunction fun) { + Preconditions + .checkArgument(this.numRows == other.getRowCount() + && this.numColumns == other.getColumnCount(), + "Cannot apply double double function to matrices with different sizes."); + for (int r = 0; r < this.numRows; ++r) { for (int c = 0; c < this.numColumns; ++c) { this.set(r, c, fun.apply(this.get(r, c), other.get(r, c))); } } - + return this; } + /* + * (non-Javadoc) + * @see + * org.apache.hama.ml.math.DoubleMatrix#safeMultiply(org.apache.hama.ml.math + * .DoubleMatrix) + */ + @Override + public DoubleMatrix multiply(DoubleMatrix other) { + Preconditions + .checkArgument( + this.numColumns == other.getRowCount(), + String + .format( + "Matrix with size [%d, %d] cannot multiple matrix with size [%d, %d]", + this.numRows, this.numColumns, other.getRowCount(), + other.getColumnCount())); + + return this.multiplyUnsafe(other); + } + + /* + * (non-Javadoc) + * @see + * org.apache.hama.ml.math.DoubleMatrix#safeMultiplyElementWise(org.apache + * .hama.ml.math.DoubleMatrix) + */ + @Override + public DoubleMatrix multiplyElementWise(DoubleMatrix other) { + Preconditions.checkArgument(this.numRows == other.getRowCount() + && this.numColumns == other.getColumnCount(), + "Matrices with different dimensions cannot be multiplied elementwise."); + return this.multiplyElementWiseUnsafe(other); + } + + /* + * (non-Javadoc) + * @see + * org.apache.hama.ml.math.DoubleMatrix#safeMultiplyVector(org.apache.hama + * .ml.math.DoubleVector) + */ + @Override + public DoubleVector multiplyVector(DoubleVector v) { + Preconditions.checkArgument(this.numColumns == v.getDimension(), + "Dimension mismatch."); + return this.multiplyVectorUnsafe(v); + } + + /* + * (non-Javadoc) + * @see org.apache.hama.ml.math.DoubleMatrix#subtract(org.apache.hama.ml.math. + * DoubleMatrix) + */ + @Override + public DoubleMatrix subtract(DoubleMatrix other) { + Preconditions.checkArgument(this.numRows == other.getRowCount() + && this.numColumns == other.getColumnCount(), "Dimension mismatch."); + return subtractUnsafe(other); + } + + /* + * (non-Javadoc) + * @see org.apache.hama.ml.math.DoubleMatrix#subtract(org.apache.hama.ml.math. + * DoubleVector) + */ + @Override + public DoubleMatrix subtract(DoubleVector vec) { + Preconditions.checkArgument(this.numColumns == vec.getDimension(), + "Dimension mismatch."); + return null; + } + } Index: ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java (revision 1500180) +++ ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java (working copy) @@ -24,6 +24,7 @@ import java.util.Iterator; import java.util.List; +import com.google.common.base.Preconditions; import com.google.common.collect.AbstractIterator; /** @@ -112,16 +113,17 @@ } /** - * {@inheritDoc}} + * {@inheritDoc} */ @Override - public DoubleVector applyToElements(DoubleVector other, DoubleDoubleFunction func) { + public DoubleVector applyToElements(DoubleVector other, + DoubleDoubleFunction func) { for (int i = 0; i < vector.length; i++) { this.vector[i] = func.apply(vector[i], other.get(i)); } return this; } - + /* * (non-Javadoc) * @see de.jungblut.math.DoubleVector#apply(de.jungblut.math.function. @@ -157,7 +159,7 @@ * @see de.jungblut.math.DoubleVector#add(de.jungblut.math.DoubleVector) */ @Override - public final DoubleVector add(DoubleVector v) { + public final DoubleVector addUnsafe(DoubleVector v) { DenseDoubleVector newv = new DenseDoubleVector(v.getLength()); for (int i = 0; i < v.getLength(); i++) { newv.set(i, this.get(i) + v.get(i)); @@ -183,7 +185,7 @@ * @see de.jungblut.math.DoubleVector#subtract(de.jungblut.math.DoubleVector) */ @Override - public final DoubleVector subtract(DoubleVector v) { + public final DoubleVector subtractUnsafe(DoubleVector v) { DoubleVector newv = new DenseDoubleVector(v.getLength()); for (int i = 0; i < v.getLength(); i++) { newv.set(i, this.get(i) - v.get(i)); @@ -235,7 +237,7 @@ * @see de.jungblut.math.DoubleVector#multiply(de.jungblut.math.DoubleVector) */ @Override - public DoubleVector multiply(DoubleVector vector) { + public DoubleVector multiplyUnsafe(DoubleVector vector) { DoubleVector v = new DenseDoubleVector(this.getLength()); for (int i = 0; i < v.getLength(); i++) { v.set(i, this.get(i) * vector.get(i)); @@ -338,10 +340,10 @@ * @see de.jungblut.math.DoubleVector#dot(de.jungblut.math.DoubleVector) */ @Override - public double dot(DoubleVector s) { + public double dotUnsafe(DoubleVector vector) { double dotProduct = 0.0d; for (int i = 0; i < getLength(); i++) { - dotProduct += this.get(i) * s.get(i); + dotProduct += this.get(i) * vector.get(i); } return dotProduct; } @@ -652,4 +654,54 @@ return null; } + /* + * (non-Javadoc) + * @see org.apache.hama.ml.math.DoubleVector#safeAdd(org.apache.hama.ml.math. + * DoubleVector) + */ + @Override + public DoubleVector add(DoubleVector vector) { + Preconditions.checkArgument(this.vector.length == vector.getDimension(), + "Dimensions of two vectors do not equal."); + return this.addUnsafe(vector); + } + + /* + * (non-Javadoc) + * @see + * org.apache.hama.ml.math.DoubleVector#safeSubtract(org.apache.hama.ml.math + * .DoubleVector) + */ + @Override + public DoubleVector subtract(DoubleVector vector) { + Preconditions.checkArgument(this.vector.length == vector.getDimension(), + "Dimensions of two vectors do not equal."); + return this.subtractUnsafe(vector); + } + + /* + * (non-Javadoc) + * @see + * org.apache.hama.ml.math.DoubleVector#safeMultiplay(org.apache.hama.ml.math + * .DoubleVector) + */ + @Override + public DoubleVector multiply(DoubleVector vector) { + Preconditions.checkArgument(this.vector.length == vector.getDimension(), + "Dimensions of two vectors do not equal."); + return this.multiplyUnsafe(vector); + } + + /* + * (non-Javadoc) + * @see org.apache.hama.ml.math.DoubleVector#safeDot(org.apache.hama.ml.math. + * DoubleVector) + */ + @Override + public double dot(DoubleVector vector) { + Preconditions.checkArgument(this.vector.length == vector.getDimension(), + "Dimensions of two vectors do not equal."); + return this.dotUnsafe(vector); + } + } Index: ml/src/main/java/org/apache/hama/ml/math/DoubleMatrix.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/DoubleMatrix.java (revision 1500180) +++ ml/src/main/java/org/apache/hama/ml/math/DoubleMatrix.java (working copy) @@ -80,18 +80,47 @@ /** * Multiplies this matrix with the given other matrix. + * + * @param other the other matrix. + * @return */ + public DoubleMatrix multiplyUnsafe(DoubleMatrix other); + + /** + * Validates the input and multiplies this matrix with the given other matrix. + * + * @param other the other matrix. + * @return + */ public DoubleMatrix multiply(DoubleMatrix other); /** * Multiplies this matrix per element with a given matrix. */ + public DoubleMatrix multiplyElementWiseUnsafe(DoubleMatrix other); + + /** + * Validates the input and multiplies this matrix per element with a given + * matrix. + * + * @param other the other matrix + * @return + */ public DoubleMatrix multiplyElementWise(DoubleMatrix other); /** * Multiplies this matrix with a given vector v. The returning vector contains * the sum of the rows. */ + public DoubleVector multiplyVectorUnsafe(DoubleVector v); + + /** + * Multiplies this matrix with a given vector v. The returning vector contains + * the sum of the rows. + * + * @param v the vector + * @return + */ public DoubleVector multiplyVector(DoubleVector v); /** @@ -114,23 +143,58 @@ /** * Subtracts this matrix by the given other matrix. */ + public DoubleMatrix subtractUnsafe(DoubleMatrix other); + + /** + * Validates the input and subtracts this matrix by the given other matrix. + * + * @param other + * @return + */ public DoubleMatrix subtract(DoubleMatrix other); /** * Subtracts each element in a column by the related element in the given * vector. */ + public DoubleMatrix subtractUnsafe(DoubleVector vec); + + /** + * Validates and subtracts each element in a column by the related element in + * the given vector. + * + * @param vec + * @return + */ public DoubleMatrix subtract(DoubleVector vec); /** * Divides each element in a column by the related element in the given * vector. */ + public DoubleMatrix divideUnsafe(DoubleVector vec); + + /** + * Validates and divides each element in a column by the related element in + * the given vector. + * + * @param vec + * @return + */ public DoubleMatrix divide(DoubleVector vec); /** * Divides this matrix by the given other matrix. (Per element division). */ + public DoubleMatrix divideUnsafe(DoubleMatrix other); + + /** + * Validates and divides this matrix by the given other matrix. (Per element + * division). + * + * @param other + * @return + */ public DoubleMatrix divide(DoubleMatrix other); /** @@ -203,6 +267,7 @@ * @param fun The function that takes two arguments. * @return The matrix itself, supply for chain operation. */ - public DoubleMatrix applyToElements(DoubleMatrix other, DoubleDoubleFunction fun); + public DoubleMatrix applyToElements(DoubleMatrix other, + DoubleDoubleFunction fun); } Index: ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java (revision 1500180) +++ ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java (working copy) @@ -58,7 +58,7 @@ * @param value the value at the index of the vector to set. */ public void set(int index, double value); - + /** * Apply a given {@link DoubleVectorFunction} to this vector and return a new * one. @@ -68,7 +68,7 @@ */ @Deprecated public DoubleVector apply(DoubleVectorFunction func); - + /** * Apply a given {@link DoubleDoubleVectorFunction} to this vector and the * other given vector. @@ -97,17 +97,26 @@ * @param func the function to apply on this and the other vector. * @return a new vector with the result of the function of the two vectors. */ - public DoubleVector applyToElements(DoubleVector other, DoubleDoubleFunction func); + public DoubleVector applyToElements(DoubleVector other, + DoubleDoubleFunction func); /** * Adds the given {@link DoubleVector} to this vector. * - * @param v the other vector. + * @param vector the other vector. * @return a new vector with the sum of both vectors at each element index. */ - public DoubleVector add(DoubleVector v); + public DoubleVector addUnsafe(DoubleVector vector); /** + * Validates the input and adds the given {@link DoubleVector} to this vector. + * + * @param vector the other vector. + * @return a new vector with the sum of both vectors at each element index. + */ + public DoubleVector add(DoubleVector vector); + + /** * Adds the given scalar to this vector. * * @param scalar the scalar. @@ -118,12 +127,21 @@ /** * Subtracts this vector by the given {@link DoubleVector}. * - * @param v the other vector. + * @param vector the other vector. * @return a new vector with the difference of both vectors. */ - public DoubleVector subtract(DoubleVector v); + public DoubleVector subtractUnsafe(DoubleVector vector); /** + * Validates the input and subtracts this vector by the given + * {@link DoubleVector}. + * + * @param vector the other vector. + * @return a new vector with the difference of both vectors. + */ + public DoubleVector subtract(DoubleVector vector); + + /** * Subtracts the given scalar to this vector. (vector - scalar). * * @param scalar the scalar. @@ -153,6 +171,15 @@ * @param vector the other vector. * @return a new vector with the result of the operation. */ + public DoubleVector multiplyUnsafe(DoubleVector vector); + + /** + * Validates the input and multiplies the given {@link DoubleVector} with this + * vector. + * + * @param vector the other vector. + * @return a new vector with the result of the operation. + */ public DoubleVector multiply(DoubleVector vector); /** @@ -201,12 +228,21 @@ /** * Calculates the dot product between this vector and the given vector. * - * @param s the given vector s. + * @param vector the given vector. * @return the dot product as a double. */ - public double dot(DoubleVector s); + public double dotUnsafe(DoubleVector vector); /** + * Validates the input and calculates the dot product between this vector and + * the given vector. + * + * @param vector the given vector. + * @return the dot product as a double. + */ + public double dot(DoubleVector vector); + + /** * Slices this vector from index 0 to the given length. * * @param length must be > 0 and smaller than the dimension of the vector. Index: ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java (revision 1500180) +++ ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java (working copy) @@ -38,7 +38,7 @@ @Override public double applyHypothesis(DoubleVector theta, DoubleVector x) { - return theta.dot(x); + return theta.dotUnsafe(x); } @Override Index: ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java (revision 1500180) +++ ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java (working copy) @@ -53,7 +53,7 @@ DoubleVector x) { return BigDecimal.valueOf(1).divide( BigDecimal.valueOf(1d).add( - BigDecimal.valueOf(Math.exp(-1d * theta.dot(x)))), + BigDecimal.valueOf(Math.exp(-1d * theta.dotUnsafe(x)))), MathContext.DECIMAL128); } Index: ml/src/main/java/org/apache/hama/ml/writable/VectorWritable.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/writable/VectorWritable.java (revision 1500180) +++ ml/src/main/java/org/apache/hama/ml/writable/VectorWritable.java (working copy) @@ -119,7 +119,7 @@ } public static int compareVector(DoubleVector a, DoubleVector o) { - DoubleVector subtract = a.subtract(o); + DoubleVector subtract = a.subtractUnsafe(o); return (int) subtract.sum(); } Index: ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleMatrix.java =================================================================== --- ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleMatrix.java (revision 1500180) +++ ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleMatrix.java (working copy) @@ -19,6 +19,8 @@ import static org.junit.Assert.assertArrayEquals; +import java.util.Arrays; + import org.junit.Test; /** @@ -57,12 +59,14 @@ @Test public void testDoubleDoubleFunction() { double[][] values1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 } }; - double[][] values2 = new double[][] { { 2, 3, 4 }, { 5, 6, 7 }, { 8, 9, 10 } }; - double[][] result = new double[][] { {3, 5, 7}, {9, 11, 13}, {15, 17, 19}}; + double[][] values2 = new double[][] { { 2, 3, 4 }, { 5, 6, 7 }, + { 8, 9, 10 } }; + double[][] result = new double[][] { { 3, 5, 7 }, { 9, 11, 13 }, + { 15, 17, 19 } }; DenseDoubleMatrix mat1 = new DenseDoubleMatrix(values1); DenseDoubleMatrix mat2 = new DenseDoubleMatrix(values2); - + mat1.applyToElements(mat2, new DoubleDoubleFunction() { @Override @@ -83,4 +87,153 @@ } } + @Test + public void testMultiplyNormal() { + double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } }; + double[][] mat2 = new double[][] { { 6, 5 }, { 4, 3 }, { 2, 1 } }; + double[][] expMat = new double[][] { { 20, 14 }, { 56, 41 } }; + DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1); + DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2); + DoubleMatrix actMatrix = matrix1.multiply(matrix2); + for (int r = 0; r < actMatrix.getRowCount(); ++r) { + assertArrayEquals(expMat[r], actMatrix.getRowVector(r).toArray(), + 0.000001); + } + } + + @Test(expected = IllegalArgumentException.class) + public void testMultiplyAbnormal() { + double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } }; + double[][] mat2 = new double[][] { { 6, 5 }, { 4, 3 } }; + DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1); + DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2); + matrix1.multiply(matrix2); + } + + @Test + public void testMultiplyElementWiseNormal() { + double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } }; + double[][] mat2 = new double[][] { { 6, 5, 4 }, { 3, 2, 1 } }; + double[][] expMat = new double[][] { { 6, 10, 12 }, { 12, 10, 6 } }; + DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1); + DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2); + DoubleMatrix actMatrix = matrix1.multiplyElementWise(matrix2); + for (int r = 0; r < actMatrix.getRowCount(); ++r) { + assertArrayEquals(expMat[r], actMatrix.getRowVector(r).toArray(), + 0.000001); + } + } + + @Test(expected = IllegalArgumentException.class) + public void testMultiplyElementWiseAbnormal() { + double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } }; + double[][] mat2 = new double[][] { { 6, 5 }, { 4, 3 } }; + DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1); + DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2); + matrix1.multiplyElementWise(matrix2); + } + + @Test + public void testMultiplyVectorNormal() { + double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } }; + double[] mat2 = new double[] { 6, 5, 4 }; + double[] expVec = new double[] { 28, 73 }; + DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1); + DoubleVector vector2 = new DenseDoubleVector(mat2); + DoubleVector actVec = matrix1.multiplyVector(vector2); + assertArrayEquals(expVec, actVec.toArray(), 0.000001); + } + + @Test(expected = IllegalArgumentException.class) + public void testMultiplyVectorAbnormal() { + double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } }; + double[] vec2 = new double[] { 6, 5 }; + DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1); + DoubleVector vector2 = new DenseDoubleVector(vec2); + matrix1.multiplyVector(vector2); + } + + @Test + public void testSubtractNormal() { + double[][] mat1 = new double[][] { + {1, 2, 3}, + {4, 5, 6} + }; + double[][] mat2 = new double[][] { + {6, 5, 4}, + {3, 2, 1} + }; + double[][] expMat = new double[][] { + {-5, -3, -1}, + {1, 3, 5} + }; + DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1); + DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2); + DoubleMatrix actMatrix = matrix1.subtract(matrix2); + for (int r = 0; r < actMatrix.getRowCount(); ++r) { + assertArrayEquals(expMat[r], actMatrix.getRowVector(r).toArray(), 0.000001); + } + } + + @Test(expected = IllegalArgumentException.class) + public void testSubtractAbnormal() { + double[][] mat1 = new double[][] { + {1, 2, 3}, + {4, 5, 6} + }; + double[][] mat2 = new double[][] { + {6, 5}, + {4, 3} + }; + DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1); + DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2); + matrix1.subtract(matrix2); + } + + @Test + public void testDivideVectorNormal() { + double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } }; + double[] mat2 = new double[] { 6, 5, 4 }; + double[][] expVec = new double[][] { {1.0 / 6, 2.0 / 5, 3.0 / 4}, {4.0 / 6, 5.0 / 5, 6.0 / 4} }; + DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1); + DoubleVector vector2 = new DenseDoubleVector(mat2); + DoubleMatrix expMat = new DenseDoubleMatrix(expVec); + DoubleMatrix actMat = matrix1.divide(vector2); + for (int r = 0; r < actMat.getRowCount(); ++r) { + assertArrayEquals(expMat.getRowVector(r).toArray(), actMat.getRowVector(r).toArray(), 0.000001); + } + } + + @Test(expected = IllegalArgumentException.class) + public void testDivideVectorAbnormal() { + double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } }; + double[] vec2 = new double[] { 6, 5 }; + DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1); + DoubleVector vector2 = new DenseDoubleVector(vec2); + matrix1.divide(vector2); + } + + @Test + public void testDivideNormal() { + double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } }; + double[][] mat2 = new double[][] { { 6, 5, 4 }, { 3, 2, 1 } }; + double[][] expMat = new double[][] { { 1.0 / 6, 2.0 / 5, 3.0 / 4 }, { 4.0 / 3, 5.0 / 2, 6.0 / 1 } }; + DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1); + DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2); + DoubleMatrix actMatrix = matrix1.divide(matrix2); + for (int r = 0; r < actMatrix.getRowCount(); ++r) { + assertArrayEquals(expMat[r], actMatrix.getRowVector(r).toArray(), + 0.000001); + } + } + + @Test(expected = IllegalArgumentException.class) + public void testDivideAbnormal() { + double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } }; + double[][] mat2 = new double[][] { { 6, 5 }, { 4, 3 } }; + DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1); + DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2); + matrix1.divide(matrix2); + } + } Index: ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleVector.java =================================================================== --- ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleVector.java (revision 1500180) +++ ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleVector.java (working copy) @@ -18,8 +18,11 @@ package org.apache.hama.ml.math; import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; /** * Testcase for {@link DenseDoubleVector} @@ -77,4 +80,79 @@ assertArrayEquals(result, vec1.toArray(), 0.0001); } + + @Test + public void testAddNormal() { + double[] arr1 = new double[] {1, 2, 3}; + double[] arr2 = new double[] {4, 5, 6}; + DoubleVector vec1 = new DenseDoubleVector(arr1); + DoubleVector vec2 = new DenseDoubleVector(arr2); + double[] arrExp = new double[] {5, 7, 9}; + assertArrayEquals(arrExp, vec1.add(vec2).toArray(), 0.000001); + } + + @Test(expected = IllegalArgumentException.class) + public void testAddAbnormal() { + double[] arr1 = new double[] {1, 2, 3}; + double[] arr2 = new double[] {4, 5}; + DoubleVector vec1 = new DenseDoubleVector(arr1); + DoubleVector vec2 = new DenseDoubleVector(arr2); + vec1.add(vec2); + } + + @Test + public void testSubtractNormal() { + double[] arr1 = new double[] {1, 2, 3}; + double[] arr2 = new double[] {4, 5, 6}; + DoubleVector vec1 = new DenseDoubleVector(arr1); + DoubleVector vec2 = new DenseDoubleVector(arr2); + double[] arrExp = new double[] {-3, -3, -3}; + assertArrayEquals(arrExp, vec1.subtract(vec2).toArray(), 0.000001); + } + + @Test(expected = IllegalArgumentException.class) + public void testSubtractAbnormal() { + double[] arr1 = new double[] {1, 2, 3}; + double[] arr2 = new double[] {4, 5}; + DoubleVector vec1 = new DenseDoubleVector(arr1); + DoubleVector vec2 = new DenseDoubleVector(arr2); + vec1.subtract(vec2); + } + + @Test + public void testMultiplyNormal() { + double[] arr1 = new double[] {1, 2, 3}; + double[] arr2 = new double[] {4, 5, 6}; + DoubleVector vec1 = new DenseDoubleVector(arr1); + DoubleVector vec2 = new DenseDoubleVector(arr2); + double[] arrExp = new double[] {4, 10, 18}; + assertArrayEquals(arrExp, vec1.multiply(vec2).toArray(), 0.000001); + } + + @Test(expected = IllegalArgumentException.class) + public void testMultiplyAbnormal() { + double[] arr1 = new double[] {1, 2, 3}; + double[] arr2 = new double[] {4, 5}; + DoubleVector vec1 = new DenseDoubleVector(arr1); + DoubleVector vec2 = new DenseDoubleVector(arr2); + vec1.multiply(vec2); + } + + @Test + public void testDotNormal() { + double[] arr1 = new double[] {1, 2, 3}; + double[] arr2 = new double[] {4, 5, 6}; + DoubleVector vec1 = new DenseDoubleVector(arr1); + DoubleVector vec2 = new DenseDoubleVector(arr2); + assertEquals(32.0, vec1.dot(vec2), 0.000001); + } + + @Test(expected = IllegalArgumentException.class) + public void testDotAbnormal() { + double[] arr1 = new double[] {1, 2, 3}; + double[] arr2 = new double[] {4, 5}; + DoubleVector vec1 = new DenseDoubleVector(arr1); + DoubleVector vec2 = new DenseDoubleVector(arr2); + vec1.add(vec2); + } }