Index: ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java (revision 1514730) +++ ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java (working copy) @@ -245,6 +245,22 @@ return v; } + @Override + public DoubleVector multiply(DoubleMatrix matrix) { + Preconditions.checkArgument(this.vector.length == matrix.getRowCount(), + "Dimension mismatch when multiply a vector to a matrix."); + return this.multiplyUnsafe(matrix); + } + + @Override + public DoubleVector multiplyUnsafe(DoubleMatrix matrix) { + DoubleVector vec = new DenseDoubleVector(matrix.getColumnCount()); + for (int i = 0; i < vec.getDimension(); ++i) { + vec.set(i, this.multiplyUnsafe(matrix.getColumnVector(i)).sum()); + } + return vec; + } + /* * (non-Javadoc) * @see de.jungblut.math.DoubleVector#divide(double) @@ -356,12 +372,12 @@ public DoubleVector slice(int length) { return slice(0, length - 1); } - + @Override public DoubleVector sliceUnsafe(int length) { return sliceUnsafe(0, length - 1); } - + /* * (non-Javadoc) * @see de.jungblut.math.DoubleVector#slice(int, int) @@ -373,7 +389,7 @@ return sliceUnsafe(start, end); } - + /** * {@inheritDoc} */ Index: ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java (revision 1514730) +++ ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java (working copy) @@ -183,6 +183,23 @@ public DoubleVector multiply(DoubleVector vector); /** + * Validates the input and multiplies the given {@link DoubleMatrix} with this + * vector. + * + * @param matrix + * @return + */ + public DoubleVector multiply(DoubleMatrix matrix); + + /** + * Multiplies the given {@link DoubleMatrix} with this vector. + * + * @param matrix + * @return + */ + public DoubleVector multiplyUnsafe(DoubleMatrix matrix); + + /** * Divides this vector by the given scalar. (= vector/scalar). * * @param scalar the given scalar. @@ -243,13 +260,14 @@ public double dot(DoubleVector vector); /** - * Validates the input and slices this vector from index 0 to the given length. + * Validates the input and slices this vector from index 0 to the given + * length. * * @param length must be > 0 and smaller than the dimension of the vector. * @return a new vector that is only length long. */ public DoubleVector slice(int length); - + /** * Slices this vector from index 0 to the given length. * Index: ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleVector.java =================================================================== --- ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleVector.java (revision 1514730) +++ ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleVector.java (working copy) @@ -185,4 +185,24 @@ DoubleVector vec = new DenseDoubleVector(arr1); vec.slice(4, 3); } + + @Test + public void testVectorMultiplyMatrix() { + DoubleVector vec = new DenseDoubleVector(new double[]{1, 2, 3}); + DoubleMatrix mat = new DenseDoubleMatrix(new double[][] { + {1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12} + }); + double[] expectedRes = new double[] {38, 44, 50, 56}; + + assertArrayEquals(expectedRes, vec.multiply(mat).toArray(), 0.000001); + } + + @Test(expected = IllegalArgumentException.class) + public void testVectorMultiplyMatrixAbnormal() { + DoubleVector vec = new DenseDoubleVector(new double[]{1, 2, 3}); + DoubleMatrix mat = new DenseDoubleMatrix(new double[][] { + {1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}, {13, 14, 15, 16} + }); + vec.multiply(mat); + } }