Index: ml/src/main/java/org/apache/hama/ml/math/DenseDoubleMatrix.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/DenseDoubleMatrix.java (revision 1498967) +++ ml/src/main/java/org/apache/hama/ml/math/DenseDoubleMatrix.java (working copy) @@ -795,20 +795,68 @@ /** * {@inheritDoc} */ - public DoubleMatrix applyToElements(DoubleMatrix other, DoubleDoubleFunction fun) { + public DoubleMatrix applyToElements(DoubleMatrix other, + DoubleDoubleFunction fun) { if (this.numRows != other.getRowCount() || this.numColumns != other.getColumnCount()) { throw new IllegalArgumentException( "Cannot apply double double function to matrices with different sizes."); } - + for (int r = 0; r < this.numRows; ++r) { for (int c = 0; c < this.numColumns; ++c) { this.set(r, c, fun.apply(this.get(r, c), other.get(r, c))); } } - + return this; } + /* + * (non-Javadoc) + * @see + * org.apache.hama.ml.math.DoubleMatrix#safeMultiply(org.apache.hama.ml.math + * .DoubleMatrix) + */ + @Override + public DoubleMatrix safeMultiply(DoubleMatrix other) { + if (this.numColumns != other.getRowCount()) { + throw new IllegalArgumentException( + String + .format( + "Matrix with size [%d, %d] cannot multiple matrix with size [%d, %d]", + this.numRows, this.numColumns, other.getRowCount(), + other.getColumnCount())); + } + return this.multiply(other); + } + + /* + * (non-Javadoc) + * @see + * org.apache.hama.ml.math.DoubleMatrix#safeMultiplyElementWise(org.apache + * .hama.ml.math.DoubleMatrix) + */ + @Override + public DoubleMatrix safeMultiplyElementWise(DoubleMatrix other) { + if (this.numRows != other.getRowCount() || this.numColumns != other.getColumnCount()) { + throw new IllegalArgumentException("Matrices with different dimensions cannot be multiplied elementwise."); + } + return this.multiplyElementWise(other); + } + + /* + * (non-Javadoc) + * @see + * org.apache.hama.ml.math.DoubleMatrix#safeMultiplyVector(org.apache.hama + * .ml.math.DoubleVector) + */ + @Override + public DoubleVector safeMultiplyVector(DoubleVector v) { + if (this.numColumns != v.getDimension()) { + throw new IllegalArgumentException("Dimension mismatch."); + } + return this.multiplyVector(v); + } + } Index: ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java (revision 1498967) +++ ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java (working copy) @@ -112,16 +112,17 @@ } /** - * {@inheritDoc}} + * {@inheritDoc} */ @Override - public DoubleVector applyToElements(DoubleVector other, DoubleDoubleFunction func) { + public DoubleVector applyToElements(DoubleVector other, + DoubleDoubleFunction func) { for (int i = 0; i < vector.length; i++) { this.vector[i] = func.apply(vector[i], other.get(i)); } return this; } - + /* * (non-Javadoc) * @see de.jungblut.math.DoubleVector#apply(de.jungblut.math.function. @@ -338,10 +339,10 @@ * @see de.jungblut.math.DoubleVector#dot(de.jungblut.math.DoubleVector) */ @Override - public double dot(DoubleVector s) { + public double dot(DoubleVector vector) { double dotProduct = 0.0d; for (int i = 0; i < getLength(); i++) { - dotProduct += this.get(i) * s.get(i); + dotProduct += this.get(i) * vector.get(i); } return dotProduct; } @@ -652,4 +653,62 @@ return null; } + /* + * (non-Javadoc) + * @see org.apache.hama.ml.math.DoubleVector#safeAdd(org.apache.hama.ml.math. + * DoubleVector) + */ + @Override + public DoubleVector safeAdd(DoubleVector vector) { + if (this.vector.length != vector.getDimension()) { + throw new IllegalArgumentException( + "Dimensions of two vectors do not equal."); + } + return this.add(vector); + } + + /* + * (non-Javadoc) + * @see + * org.apache.hama.ml.math.DoubleVector#safeSubtract(org.apache.hama.ml.math + * .DoubleVector) + */ + @Override + public DoubleVector safeSubtract(DoubleVector vector) { + if (this.vector.length != vector.getDimension()) { + throw new IllegalArgumentException( + "Dimension of two vectors do not equal."); + } + return this.subtract(vector); + } + + /* + * (non-Javadoc) + * @see + * org.apache.hama.ml.math.DoubleVector#safeMultiplay(org.apache.hama.ml.math + * .DoubleVector) + */ + @Override + public DoubleVector safeMultiplay(DoubleVector vector) { + if (this.vector.length != vector.getDimension()) { + throw new IllegalArgumentException( + "Dimension of two vectors do not equal."); + } + return this.multiply(vector); + } + + /* + * (non-Javadoc) + * @see org.apache.hama.ml.math.DoubleVector#safeDot(org.apache.hama.ml.math. + * DoubleVector) + */ + @Override + public double safeDot(DoubleVector vector) { + if (this.vector.length != vector.getDimension()) { + throw new IllegalArgumentException( + "Dimension of two vectors do not equal."); + } + return this.dot(vector); + } + } Index: ml/src/main/java/org/apache/hama/ml/math/DoubleMatrix.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/DoubleMatrix.java (revision 1498967) +++ ml/src/main/java/org/apache/hama/ml/math/DoubleMatrix.java (working copy) @@ -80,21 +80,50 @@ /** * Multiplies this matrix with the given other matrix. + * + * @param other the other matrix. + * @return */ public DoubleMatrix multiply(DoubleMatrix other); /** + * Validates the input and multiplies this matrix with the given other matrix. + * + * @param other the other matrix. + * @return + */ + public DoubleMatrix safeMultiply(DoubleMatrix other); + + /** * Multiplies this matrix per element with a given matrix. */ public DoubleMatrix multiplyElementWise(DoubleMatrix other); /** + * Validates the input and multiplies this matrix per element with a given + * matrix. + * + * @param other the other matrix + * @return + */ + public DoubleMatrix safeMultiplyElementWise(DoubleMatrix other); + + /** * Multiplies this matrix with a given vector v. The returning vector contains * the sum of the rows. */ public DoubleVector multiplyVector(DoubleVector v); /** + * Multiplies this matrix with a given vector v. The returning vector contains + * the sum of the rows. + * + * @param v the vector + * @return + */ + public DoubleVector safeMultiplyVector(DoubleVector v); + + /** * Transposes this matrix. */ public DoubleMatrix transpose(); @@ -203,6 +232,7 @@ * @param fun The function that takes two arguments. * @return The matrix itself, supply for chain operation. */ - public DoubleMatrix applyToElements(DoubleMatrix other, DoubleDoubleFunction fun); + public DoubleMatrix applyToElements(DoubleMatrix other, + DoubleDoubleFunction fun); } Index: ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java (revision 1498967) +++ ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java (working copy) @@ -58,7 +58,7 @@ * @param value the value at the index of the vector to set. */ public void set(int index, double value); - + /** * Apply a given {@link DoubleVectorFunction} to this vector and return a new * one. @@ -68,7 +68,7 @@ */ @Deprecated public DoubleVector apply(DoubleVectorFunction func); - + /** * Apply a given {@link DoubleDoubleVectorFunction} to this vector and the * other given vector. @@ -97,17 +97,26 @@ * @param func the function to apply on this and the other vector. * @return a new vector with the result of the function of the two vectors. */ - public DoubleVector applyToElements(DoubleVector other, DoubleDoubleFunction func); + public DoubleVector applyToElements(DoubleVector other, + DoubleDoubleFunction func); /** * Adds the given {@link DoubleVector} to this vector. * - * @param v the other vector. + * @param vector the other vector. * @return a new vector with the sum of both vectors at each element index. */ - public DoubleVector add(DoubleVector v); + public DoubleVector add(DoubleVector vector); /** + * Validates the input and adds the given {@link DoubleVector} to this vector. + * + * @param vector the other vector. + * @return a new vector with the sum of both vectors at each element index. + */ + public DoubleVector safeAdd(DoubleVector vector); + + /** * Adds the given scalar to this vector. * * @param scalar the scalar. @@ -118,12 +127,21 @@ /** * Subtracts this vector by the given {@link DoubleVector}. * - * @param v the other vector. + * @param vector the other vector. * @return a new vector with the difference of both vectors. */ - public DoubleVector subtract(DoubleVector v); + public DoubleVector subtract(DoubleVector vector); /** + * Validates the input and subtracts this vector by the given + * {@link DoubleVector}. + * + * @param vector the other vector. + * @return a new vector with the difference of both vectors. + */ + public DoubleVector safeSubtract(DoubleVector vector); + + /** * Subtracts the given scalar to this vector. (vector - scalar). * * @param scalar the scalar. @@ -156,6 +174,15 @@ public DoubleVector multiply(DoubleVector vector); /** + * Validates the input and multiplies the given {@link DoubleVector} with this + * vector. + * + * @param vector the other vector. + * @return a new vector with the result of the operation. + */ + public DoubleVector safeMultiplay(DoubleVector vector); + + /** * Divides this vector by the given scalar. (= vector/scalar). * * @param scalar the given scalar. @@ -201,12 +228,21 @@ /** * Calculates the dot product between this vector and the given vector. * - * @param s the given vector s. + * @param vector the given vector. * @return the dot product as a double. */ - public double dot(DoubleVector s); + public double dot(DoubleVector vector); /** + * Validates the input and calculates the dot product between this vector and + * the given vector. + * + * @param vector the given vector. + * @return the dot product as a double. + */ + public double safeDot(DoubleVector vector); + + /** * Slices this vector from index 0 to the given length. * * @param length must be > 0 and smaller than the dimension of the vector.