Index: HAMA-765.patch =================================================================== Index: ml/src/main/java/org/apache/hama/ml/math/DenseDoubleMatrix.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/DenseDoubleMatrix.java (revision 1493881) +++ ml/src/main/java/org/apache/hama/ml/math/DenseDoubleMatrix.java (working copy) @@ -778,4 +778,37 @@ return a.subtract(b).sum(); } + @Override + /** + * {@inheritDoc} + */ + public DoubleMatrix apply(DoubleFunction fun) { + for (int r = 0; r < this.numRows; ++r) { + for (int c = 0; c < this.numColumns; ++c) { + this.set(r, c, fun.apply(this.get(r, c))); + } + } + return this; + } + + @Override + /** + * {@inheritDoc} + */ + public DoubleMatrix apply(DoubleMatrix other, DoubleDoubleFunction fun) { + if (this.numRows != other.getRowCount() + || this.numColumns != other.getColumnCount()) { + throw new IllegalArgumentException( + "Cannot apply double double function to matrices with different sizes."); + } + + for (int r = 0; r < this.numRows; ++r) { + for (int c = 0; c < this.numColumns; ++c) { + this.set(r, c, fun.apply(this.get(r, c), other.get(r, c))); + } + } + + return this; + } + } Index: ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java (revision 1493881) +++ ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java (working copy) @@ -100,11 +100,34 @@ vector[index] = value; } + /** + * {@inheritDoc} + */ + @Override + public DoubleVector apply(DoubleFunction func) { + for (int i = 0; i < vector.length; i++) { + this.vector[i] = func.apply(vector[i]); + } + return this; + } + + /** + * {@inheritDoc}} + */ + @Override + public DoubleVector apply(DoubleVector other, DoubleDoubleFunction func) { + for (int i = 0; i < vector.length; i++) { + this.vector[i] = func.apply(vector[i], other.get(i)); + } + return this; + } + /* * (non-Javadoc) * @see de.jungblut.math.DoubleVector#apply(de.jungblut.math.function. * DoubleVectorFunction) */ + @Deprecated @Override public DoubleVector apply(DoubleVectorFunction func) { DenseDoubleVector newV = new DenseDoubleVector(this.vector); @@ -119,6 +142,7 @@ * @see de.jungblut.math.DoubleVector#apply(de.jungblut.math.DoubleVector, * de.jungblut.math.function.DoubleDoubleVectorFunction) */ + @Deprecated @Override public DoubleVector apply(DoubleVector other, DoubleDoubleVectorFunction func) { DenseDoubleVector newV = (DenseDoubleVector) deepCopy(); Index: ml/src/main/java/org/apache/hama/ml/math/DoubleDoubleFunction.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/DoubleDoubleFunction.java (revision 0) +++ ml/src/main/java/org/apache/hama/ml/math/DoubleDoubleFunction.java (working copy) @@ -0,0 +1,45 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.ml.math; + +/** + * A double double function takes two arguments. A vector or matrix can apply + * the double function to each element. + * + */ +public abstract class DoubleDoubleFunction extends Function { + + /** + * Apply the function to elements to two given arguments. + * + * @param x1 + * @param x2 + * @return The result based on the calculation on two arguments. + */ + public abstract double apply(double x1, double x2); + + /** + * Apply the derivative of this function to two given arguments. + * + * @param x1 + * @param x2 + * @return The result based on the calculation on two arguments. + */ + public abstract double applyDerivative(double x1, double x2); + +} Index: ml/src/main/java/org/apache/hama/ml/math/DoubleDoubleVectorFunction.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/DoubleDoubleVectorFunction.java (revision 1493881) +++ ml/src/main/java/org/apache/hama/ml/math/DoubleDoubleVectorFunction.java (working copy) @@ -20,7 +20,10 @@ /** * A function that can be applied to two double vectors via {@link DoubleVector} * #apply({@link DoubleVector} v, {@link DoubleDoubleVectorFunction} f); + * + * This class will be replaced by {@link DoubleDoubleFunction} */ +@Deprecated public interface DoubleDoubleVectorFunction { /** Index: ml/src/main/java/org/apache/hama/ml/math/DoubleFunction.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/DoubleFunction.java (revision 0) +++ ml/src/main/java/org/apache/hama/ml/math/DoubleFunction.java (working copy) @@ -0,0 +1,43 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.ml.math; + +/** + * A double double function takes two arguments. A vector or matrix can apply + * the double function to each element. + * + */ +public abstract class DoubleFunction extends Function { + + /** + * Apply the function to element. + * + * @param elem The element that the function apply to. + * @return The result after applying the function. + */ + public abstract double apply(double value); + + /** + * Apply the gradient of the function. + * + * @param elem + * @return + */ + public abstract double applyDerivative(double value); + +} Index: ml/src/main/java/org/apache/hama/ml/math/DoubleMatrix.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/DoubleMatrix.java (revision 1493881) +++ ml/src/main/java/org/apache/hama/ml/math/DoubleMatrix.java (working copy) @@ -184,4 +184,25 @@ */ public DoubleMatrix slice(int rowOffset, int rowMax, int colOffset, int colMax); + /** + * Apply a double function f(x) onto each element of the matrix. After + * applying, each element of the current matrix will be changed from x to + * f(x). + * + * @param fun The function. + * @return The matrix itself, supply for chain operation. + */ + public DoubleMatrix apply(DoubleFunction fun); + + /** + * Apply a double double function f(x, y) onto each pair of the current matrix + * elements and given matrix. After applying, each element of the current + * matrix will be changed from x to f(x, y). + * + * @param other The matrix contributing the second argument of the function. + * @param fun The function that takes two arguments. + * @return The matrix itself, supply for chain operation. + */ + public DoubleMatrix apply(DoubleMatrix other, DoubleDoubleFunction fun); + } Index: ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java (revision 1493881) +++ ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java (working copy) @@ -58,7 +58,7 @@ * @param value the value at the index of the vector to set. */ public void set(int index, double value); - + /** * Apply a given {@link DoubleVectorFunction} to this vector and return a new * one. @@ -66,8 +66,9 @@ * @param func the function to apply. * @return a new vector with the applied function. */ + @Deprecated public DoubleVector apply(DoubleVectorFunction func); - + /** * Apply a given {@link DoubleDoubleVectorFunction} to this vector and the * other given vector. @@ -76,9 +77,29 @@ * @param func the function to apply on this and the other vector. * @return a new vector with the result of the function of the two vectors. */ + @Deprecated public DoubleVector apply(DoubleVector other, DoubleDoubleVectorFunction func); /** + * Apply a given {@link DoubleVectorFunction} to this vector and return a new + * one. + * + * @param func the function to apply. + * @return a new vector with the applied function. + */ + public DoubleVector apply(DoubleFunction func); + + /** + * Apply a given {@link DoubleDoubleVectorFunction} to this vector and the + * other given vector. + * + * @param other the other vector. + * @param func the function to apply on this and the other vector. + * @return a new vector with the result of the function of the two vectors. + */ + public DoubleVector apply(DoubleVector other, DoubleDoubleFunction func); + + /** * Adds the given {@link DoubleVector} to this vector. * * @param v the other vector. Index: ml/src/main/java/org/apache/hama/ml/math/DoubleVectorFunction.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/DoubleVectorFunction.java (revision 1493881) +++ ml/src/main/java/org/apache/hama/ml/math/DoubleVectorFunction.java (working copy) @@ -20,7 +20,10 @@ /** * A function that can be applied to a double vector via {@link DoubleVector} * #apply({@link DoubleVectorFunction} f); + * + * This class will be replaced by {@link DoubleFunction} */ +@Deprecated public interface DoubleVectorFunction { /** Index: ml/src/main/java/org/apache/hama/ml/perception/CostFunction.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/perception/CostFunction.java (revision 1493881) +++ ml/src/main/java/org/apache/hama/ml/perception/CostFunction.java (working copy) @@ -1,43 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.ml.perception; - -/** - * The common interface for cost functions. - */ -public abstract class CostFunction { - - /** - * Get the error evaluated by squared error. - * - * @param target The target value. - * @param actual The actual value. - * @return - */ - public abstract double calculate(double target, double actual); - - /** - * Get the partial derivative of squared error. - * - * @param target - * @param actual - * @return - */ - public abstract double calculateDerivative(double target, double actual); - -} Index: ml/src/main/java/org/apache/hama/ml/perception/CostFunctionFactory.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/perception/CostFunctionFactory.java (revision 1493881) +++ ml/src/main/java/org/apache/hama/ml/perception/CostFunctionFactory.java (working copy) @@ -1,41 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.ml.perception; - -/** - * The cost function factory that generates the cost function by name. - */ -public class CostFunctionFactory { - - /** - * Get the cost function according to the name. If no matched cost function is - * found, return the SquaredError by default. - * - * @param name The name of the cost function. - * @return The cost function instance. - */ - public static CostFunction getCostFunction(String name) { - if (name.equalsIgnoreCase("SquaredError")) { - return new SquaredError(); - } else if (name.equalsIgnoreCase("CrossEntropy")) { - return new CrossEntropy(); - } - throw new IllegalStateException(String.format( - "No cost function with name '%s' found.", name)); - } -} Index: ml/src/main/java/org/apache/hama/ml/perception/CrossEntropy.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/perception/CrossEntropy.java (revision 1493881) +++ ml/src/main/java/org/apache/hama/ml/perception/CrossEntropy.java (working copy) @@ -1,53 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.ml.perception; - -/** - * The cross entropy cost function. - * - *
- * cost(t, y) = - t * log(y) - (1 - t) * log(1 - y),
- * where t denotes the target value, y denotes the estimated value.
- * 
- */ -public class CrossEntropy extends CostFunction { - - @Override - public double calculate(double target, double actual) { - return -target * Math.log(actual) - (1 - target) * Math.log(1 - actual); - } - - @Override - public double calculateDerivative(double target, double actual) { - double adjustedTarget = target; - double adjustedActual = actual; - if (adjustedActual == 1) { - adjustedActual = 0.999; - } else if (actual == 0) { - adjustedActual = 0.001; - } - if (adjustedTarget == 1) { - adjustedTarget = 0.999; - } else if (adjustedTarget == 0) { - adjustedTarget = 0.001; - } - return -adjustedTarget / adjustedActual + (1 - adjustedTarget) - / (1 - adjustedActual); - } - -} Index: ml/src/main/java/org/apache/hama/ml/perception/LogisticCostFunction.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/perception/LogisticCostFunction.java (revision 1493881) +++ ml/src/main/java/org/apache/hama/ml/perception/LogisticCostFunction.java (working copy) @@ -1,53 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.ml.perception; - -/** - * The logistic cost function. - * - *
- * cost(t, y) = - t * log(y) - (1 - t) * log(1 - y),
- * where t denotes the target value, y denotes the estimated value.
- * 
- */ -public class LogisticCostFunction extends CostFunction { - - @Override - public double calculate(double target, double actual) { - return -target * Math.log(actual) - (1 - target) * Math.log(1 - actual); - } - - @Override - public double calculateDerivative(double target, double actual) { - double adjustedTarget = target; - double adjustedActual = actual; - if (adjustedActual == 1) { - adjustedActual = 0.999; - } else if (actual == 0) { - adjustedActual = 0.001; - } - if (adjustedTarget == 1) { - adjustedTarget = 0.999; - } else if (adjustedTarget == 0) { - adjustedTarget = 0.001; - } - return -adjustedTarget / adjustedActual + (1 - adjustedTarget) - / (1 - adjustedActual); - } - -} Index: ml/src/main/java/org/apache/hama/ml/perception/MultiLayerPerceptron.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/perception/MultiLayerPerceptron.java (revision 1493881) +++ ml/src/main/java/org/apache/hama/ml/perception/MultiLayerPerceptron.java (working copy) @@ -21,7 +21,10 @@ import java.util.Map; import org.apache.hadoop.fs.Path; +import org.apache.hama.ml.math.DoubleDoubleFunction; +import org.apache.hama.ml.math.DoubleFunction; import org.apache.hama.ml.math.DoubleVector; +import org.apache.hama.ml.math.FunctionFactory; /** * PerceptronBase defines the common behavior of all the concrete perceptrons. @@ -43,8 +46,8 @@ protected String costFunctionName; protected int[] layerSizeArray; - protected CostFunction costFunction; - protected SquashingFunction squashingFunction; + protected DoubleDoubleFunction costFunction; + protected DoubleFunction squashingFunction; /** * Initialize the MLP. @@ -83,10 +86,10 @@ this.layerSizeArray = layerSizeArray; this.numberOfLayers = this.layerSizeArray.length; - this.costFunction = CostFunctionFactory - .getCostFunction(this.costFunctionName); - this.squashingFunction = SquashingFunctionFactory - .getSquashingFunction(this.squashingFunctionName); + this.costFunction = FunctionFactory + .createDoubleDoubleFunction(this.costFunctionName); + this.squashingFunction = FunctionFactory + .createDoubleFunction(this.squashingFunctionName); } /** Index: ml/src/main/java/org/apache/hama/ml/perception/Sigmoid.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/perception/Sigmoid.java (revision 1493881) +++ ml/src/main/java/org/apache/hama/ml/perception/Sigmoid.java (working copy) @@ -1,38 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.ml.perception; - -/** - * The Sigmoid function - * - *
- * f(z) = 1 / (1 + e^{-z})
- * 
- */ -public class Sigmoid extends SquashingFunction { - - @Override - public double calculate(int index, double value) { - return 1.0 / (1 + Math.exp(-value)); - } - - @Override - public double calculateDerivative(double value) { - return value * (1 - value); - } -} Index: ml/src/main/java/org/apache/hama/ml/perception/SmallMultiLayerPerceptron.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/perception/SmallMultiLayerPerceptron.java (revision 1493881) +++ ml/src/main/java/org/apache/hama/ml/perception/SmallMultiLayerPerceptron.java (working copy) @@ -41,7 +41,9 @@ import org.apache.hama.bsp.BSPJob; import org.apache.hama.ml.math.DenseDoubleMatrix; import org.apache.hama.ml.math.DenseDoubleVector; +import org.apache.hama.ml.math.DoubleFunction; import org.apache.hama.ml.math.DoubleVector; +import org.apache.hama.ml.math.FunctionFactory; import org.apache.hama.ml.writable.MatrixWritable; import org.apache.hama.ml.writable.VectorWritable; import org.mortbay.log.Log; @@ -107,13 +109,30 @@ // add weights for bias this.weightMatrice[i] = new DenseDoubleMatrix(this.layerSizeArray[i] + 1, this.layerSizeArray[i + 1]); - int rowCount = this.weightMatrice[i].getRowCount(); - int colCount = this.weightMatrice[i].getColumnCount(); - for (int row = 0; row < rowCount; ++row) { - for (int col = 0; col < colCount; ++col) { - this.weightMatrice[i].set(row, col, rnd.nextDouble() - 0.5); + + this.weightMatrice[i].apply(new DoubleFunction() { + + private Random rnd = new Random(); + + @Override + public double apply(double value) { + return rnd.nextDouble() - 0.5; } - } + + @Override + public double applyDerivative(double value) { + throw new UnsupportedOperationException("Not supported"); + } + + }); + +// int rowCount = this.weightMatrice[i].getRowCount(); +// int colCount = this.weightMatrice[i].getColumnCount(); +// for (int row = 0; row < rowCount; ++row) { +// for (int col = 0; col < colCount; ++col) { +// this.weightMatrice[i].set(row, col, rnd.nextDouble() - 0.5); +// } +// } } } @@ -199,8 +218,7 @@ prevNeuronIdx, neuronIdx) * intermediateResult[prevNeuronIdx]; } // calculate via squashing function - results[neuronIdx + offset] = this.squashingFunction.calculate(0, - results[neuronIdx + offset]); + results[neuronIdx + offset] = this.squashingFunction.apply(results[neuronIdx + offset]); } return results; @@ -243,7 +261,7 @@ DenseDoubleMatrix prevWeightUpdateMatrix = this.prevWeightUpdateMatrices[this.prevWeightUpdateMatrices.length - 1]; for (int j = 0; j < delta.length; ++j) { - delta[j] = this.costFunction.calculateDerivative(trainingLabels[j], + delta[j] = this.costFunction.applyDerivative(trainingLabels[j], outputLayerOutput[j]); // add regularization term if (this.regularization != 0.0) { @@ -257,7 +275,7 @@ } delta[j] *= this.squashingFunction - .calculateDerivative(outputLayerOutput[j]); + .applyDerivative(outputLayerOutput[j]); // calculate the weight update matrix between the last hidden layer and // the output layer @@ -307,7 +325,7 @@ delta[j] += weight * nextLayerDelta[k]; } delta[j] *= this.squashingFunction - .calculateDerivative(curLayerOutput[j + 1]); + .applyDerivative(curLayerOutput[j + 1]); // calculate the weight update matrix between the previous layer and the // current layer @@ -395,10 +413,10 @@ for (int i = 0; i < numberOfLayers - 1; ++i) { this.weightMatrice[i] = (DenseDoubleMatrix) MatrixWritable.read(input); } - this.squashingFunction = SquashingFunctionFactory - .getSquashingFunction(this.squashingFunctionName); - this.costFunction = CostFunctionFactory - .getCostFunction(this.costFunctionName); + this.squashingFunction = FunctionFactory + .createDoubleFunction(this.squashingFunctionName); + this.costFunction = FunctionFactory + .createDoubleDoubleFunction(this.costFunctionName); } @Override Index: ml/src/main/java/org/apache/hama/ml/perception/SquaredError.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/perception/SquaredError.java (revision 1493881) +++ ml/src/main/java/org/apache/hama/ml/perception/SquaredError.java (working copy) @@ -1,47 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.ml.perception; - -/** - * Square error cost function. - * - *
- * cost(t, y) = 0.5 * (t - y) ˆ 2
- * 
- */ -public class SquaredError extends CostFunction { - - @Override - /** - * {@inheritDoc} - */ - public double calculate(double target, double actual) { - double diff = target - actual; - return 0.5 * diff * diff; - } - - @Override - /** - * {@inheritDoc} - */ - public double calculateDerivative(double target, double actual) { - // return target - actual; - return actual - target; - } - -} Index: ml/src/main/java/org/apache/hama/ml/perception/SquashingFunction.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/perception/SquashingFunction.java (revision 1493881) +++ ml/src/main/java/org/apache/hama/ml/perception/SquashingFunction.java (working copy) @@ -1,41 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.ml.perception; - -import org.apache.hama.ml.math.DoubleVectorFunction; - -/** - * The squashing function to activate the neurons. - * - */ -public abstract class SquashingFunction implements DoubleVectorFunction { - - /** - * Calculates the result with a given index and value of a vector. - */ - @Override - public abstract double calculate(int index, double value); - - /** - * Apply the gradient descent to each of the elements in vector. - * - * @param vector - * @return - */ - public abstract double calculateDerivative(double value); -} Index: ml/src/main/java/org/apache/hama/ml/perception/SquashingFunctionFactory.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/perception/SquashingFunctionFactory.java (revision 1493881) +++ ml/src/main/java/org/apache/hama/ml/perception/SquashingFunctionFactory.java (working copy) @@ -1,43 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.ml.perception; - -/** - * Get the squashing function according to the name. - */ -public class SquashingFunctionFactory { - - /** - * Get the squashing function instance according to the name. If no matched - * squahsing function is found, return the sigmoid squashing function by - * default. - * - * @param name The name of the squashing function. - * @return The instance of the squashing function. - */ - public static SquashingFunction getSquashingFunction(String name) { - if (name.equalsIgnoreCase("Sigmoid")) { - return new Sigmoid(); - } else if (name.equalsIgnoreCase("Tanh")) { - return new Tanh(); - } - throw new IllegalStateException(String.format( - "No squashing function with name '%s' found.", name)); - } - -} Index: ml/src/main/java/org/apache/hama/ml/perception/Tanh.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/perception/Tanh.java (revision 1493881) +++ ml/src/main/java/org/apache/hama/ml/perception/Tanh.java (working copy) @@ -1,36 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.ml.perception; - -/** - * The hyperbolic tangent function. It is used as a squashing function in - * multi-layer perceptron. - */ -public class Tanh extends SquashingFunction { - - @Override - public double calculate(int index, double value) { - return Math.tanh(value); - } - - @Override - public double calculateDerivative(double value) { - return 1 - value * value; - } - -} Index: ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleMatrix.java =================================================================== --- ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleMatrix.java (revision 0) +++ ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleMatrix.java (working copy) @@ -0,0 +1,86 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.ml.math; + +import static org.junit.Assert.assertArrayEquals; + +import org.junit.Test; + +/** + * Test case for {@link DenseDoubleMatrix} + * + */ +public class TestDenseDoubleMatrix { + + @Test + public void testDoubleFunction() { + double[][] values = new double[][] { { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 } }; + + double[][] result = new double[][] { { 2, 3, 4 }, { 5, 6, 7 }, { 8, 9, 10 } }; + + DenseDoubleMatrix mat = new DenseDoubleMatrix(values); + mat.apply(new DoubleFunction() { + + @Override + public double apply(double value) { + return value + 1; + } + + @Override + public double applyDerivative(double value) { + throw new UnsupportedOperationException(); + } + + }); + + double[][] actual = mat.getValues(); + for (int i = 0; i < actual.length; ++i) { + assertArrayEquals(result[i], actual[i], 0.0001); + } + } + + @Test + public void testDoubleDoubleFunction() { + double[][] values1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 } }; + double[][] values2 = new double[][] { { 2, 3, 4 }, { 5, 6, 7 }, { 8, 9, 10 } }; + double[][] result = new double[][] { {3, 5, 7}, {9, 11, 13}, {15, 17, 19}}; + + DenseDoubleMatrix mat1 = new DenseDoubleMatrix(values1); + DenseDoubleMatrix mat2 = new DenseDoubleMatrix(values2); + + mat1.apply(mat2, new DoubleDoubleFunction() { + + @Override + public double apply(double x1, double x2) { + return x1 + x2; + } + + @Override + public double applyDerivative(double x1, double x2) { + throw new UnsupportedOperationException(); + } + + }); + + double[][] actual = mat1.getValues(); + for (int i = 0; i < actual.length; ++i) { + assertArrayEquals(result[i], actual[i], 0.0001); + } + } + +} Index: ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleVector.java =================================================================== --- ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleVector.java (revision 0) +++ ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleVector.java (working copy) @@ -0,0 +1,80 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.ml.math; + +import static org.junit.Assert.assertArrayEquals; + +import org.junit.Test; + +/** + * Testcase for {@link DenseDoubleVector} + * + */ +public class TestDenseDoubleVector { + + @Test + public void testApplyDoubleFunction() { + double[] values = new double[] {1, 2, 3, 4, 5}; + double[] result = new double[] {2, 3, 4, 5, 6}; + + DoubleVector vec1 = new DenseDoubleVector(values); + + vec1.apply(new DoubleFunction() { + + @Override + public double apply(double value) { + return value + 1; + } + + @Override + public double applyDerivative(double value) { + throw new UnsupportedOperationException("Not supported."); + } + + }); + + assertArrayEquals(result, vec1.toArray(), 0.0001); + } + + @Test + public void testApplyDoubleDoubleFunction() { + double[] values1 = new double[] {1, 2, 3, 4, 5, 6}; + double[] values2 = new double[] {7, 8, 9, 10, 11, 12}; + double[] result = new double[] {8, 10, 12, 14, 16, 18}; + + DoubleVector vec1 = new DenseDoubleVector(values1); + DoubleVector vec2 = new DenseDoubleVector(values2); + + vec1.apply(vec2, new DoubleDoubleFunction() { + + @Override + public double apply(double x1, double x2) { + return x1 + x2; + } + + @Override + public double applyDerivative(double x1, double x2) { + throw new UnsupportedOperationException("Not supported"); + } + + }); + + assertArrayEquals(result, vec1.toArray(), 0.0001); + + } +} Index: ml/src/test/java/org/apache/hama/ml/perception/TestSmallMultiLayerPerceptron.java =================================================================== --- ml/src/test/java/org/apache/hama/ml/perception/TestSmallMultiLayerPerceptron.java (revision 1493881) +++ ml/src/test/java/org/apache/hama/ml/perception/TestSmallMultiLayerPerceptron.java (working copy) @@ -1,5 +1,5 @@ /** - * Licensed to the Apache Software Foundation (ASF) under one +c * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file