Index: ml/src/main/java/org/apache/hama/ml/math/CrossEntropy.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/CrossEntropy.java (revision 0) +++ ml/src/main/java/org/apache/hama/ml/math/CrossEntropy.java (working copy) @@ -0,0 +1,53 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.ml.math; + +/** + * The cross entropy cost function. + * + *
+ * cost(t, y) = - t * log(y) - (1 - t) * log(1 - y), + * where t denotes the target value, y denotes the estimated value. + *+ */ +public class CrossEntropy extends DoubleDoubleFunction { + + @Override + public double apply(double target, double actual) { + return -target * Math.log(actual) - (1 - target) * Math.log(1 - actual); + } + + @Override + public double applyDerivative(double target, double actual) { + double adjustedTarget = target; + double adjustedActual = actual; + if (adjustedActual == 1) { + adjustedActual = 0.999; + } else if (actual == 0) { + adjustedActual = 0.001; + } + if (adjustedTarget == 1) { + adjustedTarget = 0.999; + } else if (adjustedTarget == 0) { + adjustedTarget = 0.001; + } + return -adjustedTarget / adjustedActual + (1 - adjustedTarget) + / (1 - adjustedActual); + } + +} Index: ml/src/main/java/org/apache/hama/ml/math/DenseDoubleMatrix.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/DenseDoubleMatrix.java (revision 1493881) +++ ml/src/main/java/org/apache/hama/ml/math/DenseDoubleMatrix.java (working copy) @@ -778,4 +778,37 @@ return a.subtract(b).sum(); } + @Override + /** + * {@inheritDoc} + */ + public DoubleMatrix applyToElements(DoubleFunction fun) { + for (int r = 0; r < this.numRows; ++r) { + for (int c = 0; c < this.numColumns; ++c) { + this.set(r, c, fun.apply(this.get(r, c))); + } + } + return this; + } + + @Override + /** + * {@inheritDoc} + */ + public DoubleMatrix applyToElements(DoubleMatrix other, DoubleDoubleFunction fun) { + if (this.numRows != other.getRowCount() + || this.numColumns != other.getColumnCount()) { + throw new IllegalArgumentException( + "Cannot apply double double function to matrices with different sizes."); + } + + for (int r = 0; r < this.numRows; ++r) { + for (int c = 0; c < this.numColumns; ++c) { + this.set(r, c, fun.apply(this.get(r, c), other.get(r, c))); + } + } + + return this; + } + } Index: ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java (revision 1493881) +++ ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java (working copy) @@ -100,11 +100,34 @@ vector[index] = value; } + /** + * {@inheritDoc} + */ + @Override + public DoubleVector apply(DoubleFunction func) { + for (int i = 0; i < vector.length; i++) { + this.vector[i] = func.apply(vector[i]); + } + return this; + } + + /** + * {@inheritDoc}} + */ + @Override + public DoubleVector apply(DoubleVector other, DoubleDoubleFunction func) { + for (int i = 0; i < vector.length; i++) { + this.vector[i] = func.apply(vector[i], other.get(i)); + } + return this; + } + /* * (non-Javadoc) * @see de.jungblut.math.DoubleVector#apply(de.jungblut.math.function. * DoubleVectorFunction) */ + @Deprecated @Override public DoubleVector apply(DoubleVectorFunction func) { DenseDoubleVector newV = new DenseDoubleVector(this.vector); @@ -119,6 +142,7 @@ * @see de.jungblut.math.DoubleVector#apply(de.jungblut.math.DoubleVector, * de.jungblut.math.function.DoubleDoubleVectorFunction) */ + @Deprecated @Override public DoubleVector apply(DoubleVector other, DoubleDoubleVectorFunction func) { DenseDoubleVector newV = (DenseDoubleVector) deepCopy(); Index: ml/src/main/java/org/apache/hama/ml/math/DoubleDoubleFunction.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/DoubleDoubleFunction.java (revision 0) +++ ml/src/main/java/org/apache/hama/ml/math/DoubleDoubleFunction.java (working copy) @@ -0,0 +1,45 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.ml.math; + +/** + * A double double function takes two arguments. A vector or matrix can apply + * the double function to each element. + * + */ +public abstract class DoubleDoubleFunction extends Function { + + /** + * Apply the function to elements to two given arguments. + * + * @param x1 + * @param x2 + * @return The result based on the calculation on two arguments. + */ + public abstract double apply(double x1, double x2); + + /** + * Apply the derivative of this function to two given arguments. + * + * @param x1 + * @param x2 + * @return The result based on the calculation on two arguments. + */ + public abstract double applyDerivative(double x1, double x2); + +} Index: ml/src/main/java/org/apache/hama/ml/math/DoubleDoubleVectorFunction.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/DoubleDoubleVectorFunction.java (revision 1493881) +++ ml/src/main/java/org/apache/hama/ml/math/DoubleDoubleVectorFunction.java (working copy) @@ -20,7 +20,10 @@ /** * A function that can be applied to two double vectors via {@link DoubleVector} * #apply({@link DoubleVector} v, {@link DoubleDoubleVectorFunction} f); + * + * This class will be replaced by {@link DoubleDoubleFunction} */ +@Deprecated public interface DoubleDoubleVectorFunction { /** Index: ml/src/main/java/org/apache/hama/ml/math/DoubleFunction.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/DoubleFunction.java (revision 0) +++ ml/src/main/java/org/apache/hama/ml/math/DoubleFunction.java (working copy) @@ -0,0 +1,43 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.ml.math; + +/** + * A double double function takes two arguments. A vector or matrix can apply + * the double function to each element. + * + */ +public abstract class DoubleFunction extends Function { + + /** + * Apply the function to element. + * + * @param elem The element that the function apply to. + * @return The result after applying the function. + */ + public abstract double apply(double value); + + /** + * Apply the gradient of the function. + * + * @param elem + * @return + */ + public abstract double applyDerivative(double value); + +} Index: ml/src/main/java/org/apache/hama/ml/math/DoubleMatrix.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/DoubleMatrix.java (revision 1493881) +++ ml/src/main/java/org/apache/hama/ml/math/DoubleMatrix.java (working copy) @@ -184,4 +184,25 @@ */ public DoubleMatrix slice(int rowOffset, int rowMax, int colOffset, int colMax); + /** + * Apply a double function f(x) onto each element of the matrix. After + * applying, each element of the current matrix will be changed from x to + * f(x). + * + * @param fun The function. + * @return The matrix itself, supply for chain operation. + */ + public DoubleMatrix applyToElements(DoubleFunction fun); + + /** + * Apply a double double function f(x, y) onto each pair of the current matrix + * elements and given matrix. After applying, each element of the current + * matrix will be changed from x to f(x, y). + * + * @param other The matrix contributing the second argument of the function. + * @param fun The function that takes two arguments. + * @return The matrix itself, supply for chain operation. + */ + public DoubleMatrix applyToElements(DoubleMatrix other, DoubleDoubleFunction fun); + } Index: ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java (revision 1493881) +++ ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java (working copy) @@ -58,7 +58,7 @@ * @param value the value at the index of the vector to set. */ public void set(int index, double value); - + /** * Apply a given {@link DoubleVectorFunction} to this vector and return a new * one. @@ -66,8 +66,9 @@ * @param func the function to apply. * @return a new vector with the applied function. */ + @Deprecated public DoubleVector apply(DoubleVectorFunction func); - + /** * Apply a given {@link DoubleDoubleVectorFunction} to this vector and the * other given vector. @@ -76,9 +77,29 @@ * @param func the function to apply on this and the other vector. * @return a new vector with the result of the function of the two vectors. */ + @Deprecated public DoubleVector apply(DoubleVector other, DoubleDoubleVectorFunction func); /** + * Apply a given {@link DoubleVectorFunction} to this vector and return a new + * one. + * + * @param func the function to apply. + * @return a new vector with the applied function. + */ + public DoubleVector apply(DoubleFunction func); + + /** + * Apply a given {@link DoubleDoubleVectorFunction} to this vector and the + * other given vector. + * + * @param other the other vector. + * @param func the function to apply on this and the other vector. + * @return a new vector with the result of the function of the two vectors. + */ + public DoubleVector apply(DoubleVector other, DoubleDoubleFunction func); + + /** * Adds the given {@link DoubleVector} to this vector. * * @param v the other vector. Index: ml/src/main/java/org/apache/hama/ml/math/DoubleVectorFunction.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/DoubleVectorFunction.java (revision 1493881) +++ ml/src/main/java/org/apache/hama/ml/math/DoubleVectorFunction.java (working copy) @@ -20,7 +20,10 @@ /** * A function that can be applied to a double vector via {@link DoubleVector} * #apply({@link DoubleVectorFunction} f); + * + * This class will be replaced by {@link DoubleFunction} */ +@Deprecated public interface DoubleVectorFunction { /** Index: ml/src/main/java/org/apache/hama/ml/math/Function.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/Function.java (revision 0) +++ ml/src/main/java/org/apache/hama/ml/math/Function.java (working copy) @@ -0,0 +1,33 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.ml.math; + +/** + * A generic function. + * + */ +public abstract class Function { + /** + * Get the name of the function. + * + * @return The name of the function. + */ + final public String getFunctionName() { + return this.getClass().getSimpleName(); + } +} Index: ml/src/main/java/org/apache/hama/ml/math/FunctionFactory.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/FunctionFactory.java (revision 0) +++ ml/src/main/java/org/apache/hama/ml/math/FunctionFactory.java (working copy) @@ -0,0 +1,61 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.ml.math; + +/** + * Factory to create the functions. + * + */ +public class FunctionFactory { + + /** + * Create a double function with specified name. + * + * @param functionName + * @return + */ + public static DoubleFunction createDoubleFunction(String functionName) { + if (functionName.equals(Sigmoid.class.getSimpleName())) { + return new Sigmoid(); + } else if (functionName.equals(Tanh.class.getSimpleName())) { + return new Tanh(); + } + + throw new IllegalArgumentException(String.format( + "No double function with name '%s' exists.", functionName)); + } + + /** + * Create a double double function with specified name. + * + * @param functionName + * @return + */ + public static DoubleDoubleFunction createDoubleDoubleFunction( + String functionName) { + if (functionName.equals(SquaredError.class.getSimpleName())) { + return new SquaredError(); + } else if (functionName.equals(CrossEntropy.class.getSimpleName())) { + return new CrossEntropy(); + } + + throw new IllegalArgumentException(String.format( + "No double double function with name '%s' exists.", functionName)); + } + +} Index: ml/src/main/java/org/apache/hama/ml/math/Sigmoid.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/Sigmoid.java (revision 0) +++ ml/src/main/java/org/apache/hama/ml/math/Sigmoid.java (working copy) @@ -0,0 +1,39 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.ml.math; + +/** + * The Sigmoid function + * + *
+ * f(x) = 1 / (1 + e^{-x})
+ *
+ */
+public class Sigmoid extends DoubleFunction {
+
+ @Override
+ public double apply(double value) {
+ return 1.0 / (1 + Math.exp(-value));
+ }
+
+ @Override
+ public double applyDerivative(double value) {
+ return value * (1 - value);
+ }
+
+}
Index: ml/src/main/java/org/apache/hama/ml/math/SquaredError.java
===================================================================
--- ml/src/main/java/org/apache/hama/ml/math/SquaredError.java (revision 0)
+++ ml/src/main/java/org/apache/hama/ml/math/SquaredError.java (working copy)
@@ -0,0 +1,47 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml.math;
+
+/**
+ * Square error cost function.
+ *
+ * + * cost(t, y) = 0.5 * (t - y) ˆ 2 + *+ */ +public class SquaredError extends DoubleDoubleFunction { + + @Override + /** + * {@inheritDoc} + */ + public double apply(double target, double actual) { + double diff = target - actual; + return 0.5 * diff * diff; + } + + @Override + /** + * {@inheritDoc} + */ + public double applyDerivative(double target, double actual) { + // return target - actual; + return actual - target; + } + +} Index: ml/src/main/java/org/apache/hama/ml/math/Tanh.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/Tanh.java (revision 0) +++ ml/src/main/java/org/apache/hama/ml/math/Tanh.java (working copy) @@ -0,0 +1,36 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.ml.math; + +/** + * Tanh function. + * + */ +public class Tanh extends DoubleFunction { + + @Override + public double apply(double value) { + return Math.tanh(value); + } + + @Override + public double applyDerivative(double value) { + return 1 - value * value; + } + +} Index: ml/src/main/java/org/apache/hama/ml/perception/CostFunction.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/perception/CostFunction.java (revision 1493881) +++ ml/src/main/java/org/apache/hama/ml/perception/CostFunction.java (working copy) @@ -1,43 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.ml.perception; - -/** - * The common interface for cost functions. - */ -public abstract class CostFunction { - - /** - * Get the error evaluated by squared error. - * - * @param target The target value. - * @param actual The actual value. - * @return - */ - public abstract double calculate(double target, double actual); - - /** - * Get the partial derivative of squared error. - * - * @param target - * @param actual - * @return - */ - public abstract double calculateDerivative(double target, double actual); - -} Index: ml/src/main/java/org/apache/hama/ml/perception/CostFunctionFactory.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/perception/CostFunctionFactory.java (revision 1493881) +++ ml/src/main/java/org/apache/hama/ml/perception/CostFunctionFactory.java (working copy) @@ -1,41 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.ml.perception; - -/** - * The cost function factory that generates the cost function by name. - */ -public class CostFunctionFactory { - - /** - * Get the cost function according to the name. If no matched cost function is - * found, return the SquaredError by default. - * - * @param name The name of the cost function. - * @return The cost function instance. - */ - public static CostFunction getCostFunction(String name) { - if (name.equalsIgnoreCase("SquaredError")) { - return new SquaredError(); - } else if (name.equalsIgnoreCase("CrossEntropy")) { - return new CrossEntropy(); - } - throw new IllegalStateException(String.format( - "No cost function with name '%s' found.", name)); - } -} Index: ml/src/main/java/org/apache/hama/ml/perception/CrossEntropy.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/perception/CrossEntropy.java (revision 1493881) +++ ml/src/main/java/org/apache/hama/ml/perception/CrossEntropy.java (working copy) @@ -1,53 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.ml.perception; - -/** - * The cross entropy cost function. - * - *
- * cost(t, y) = - t * log(y) - (1 - t) * log(1 - y), - * where t denotes the target value, y denotes the estimated value. - *- */ -public class CrossEntropy extends CostFunction { - - @Override - public double calculate(double target, double actual) { - return -target * Math.log(actual) - (1 - target) * Math.log(1 - actual); - } - - @Override - public double calculateDerivative(double target, double actual) { - double adjustedTarget = target; - double adjustedActual = actual; - if (adjustedActual == 1) { - adjustedActual = 0.999; - } else if (actual == 0) { - adjustedActual = 0.001; - } - if (adjustedTarget == 1) { - adjustedTarget = 0.999; - } else if (adjustedTarget == 0) { - adjustedTarget = 0.001; - } - return -adjustedTarget / adjustedActual + (1 - adjustedTarget) - / (1 - adjustedActual); - } - -} Index: ml/src/main/java/org/apache/hama/ml/perception/LogisticCostFunction.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/perception/LogisticCostFunction.java (revision 1493881) +++ ml/src/main/java/org/apache/hama/ml/perception/LogisticCostFunction.java (working copy) @@ -1,53 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.ml.perception; - -/** - * The logistic cost function. - * - *
- * cost(t, y) = - t * log(y) - (1 - t) * log(1 - y), - * where t denotes the target value, y denotes the estimated value. - *- */ -public class LogisticCostFunction extends CostFunction { - - @Override - public double calculate(double target, double actual) { - return -target * Math.log(actual) - (1 - target) * Math.log(1 - actual); - } - - @Override - public double calculateDerivative(double target, double actual) { - double adjustedTarget = target; - double adjustedActual = actual; - if (adjustedActual == 1) { - adjustedActual = 0.999; - } else if (actual == 0) { - adjustedActual = 0.001; - } - if (adjustedTarget == 1) { - adjustedTarget = 0.999; - } else if (adjustedTarget == 0) { - adjustedTarget = 0.001; - } - return -adjustedTarget / adjustedActual + (1 - adjustedTarget) - / (1 - adjustedActual); - } - -} Index: ml/src/main/java/org/apache/hama/ml/perception/MultiLayerPerceptron.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/perception/MultiLayerPerceptron.java (revision 1493881) +++ ml/src/main/java/org/apache/hama/ml/perception/MultiLayerPerceptron.java (working copy) @@ -21,7 +21,10 @@ import java.util.Map; import org.apache.hadoop.fs.Path; +import org.apache.hama.ml.math.DoubleDoubleFunction; +import org.apache.hama.ml.math.DoubleFunction; import org.apache.hama.ml.math.DoubleVector; +import org.apache.hama.ml.math.FunctionFactory; /** * PerceptronBase defines the common behavior of all the concrete perceptrons. @@ -43,8 +46,8 @@ protected String costFunctionName; protected int[] layerSizeArray; - protected CostFunction costFunction; - protected SquashingFunction squashingFunction; + protected DoubleDoubleFunction costFunction; + protected DoubleFunction squashingFunction; /** * Initialize the MLP. @@ -83,10 +86,10 @@ this.layerSizeArray = layerSizeArray; this.numberOfLayers = this.layerSizeArray.length; - this.costFunction = CostFunctionFactory - .getCostFunction(this.costFunctionName); - this.squashingFunction = SquashingFunctionFactory - .getSquashingFunction(this.squashingFunctionName); + this.costFunction = FunctionFactory + .createDoubleDoubleFunction(this.costFunctionName); + this.squashingFunction = FunctionFactory + .createDoubleFunction(this.squashingFunctionName); } /** Index: ml/src/main/java/org/apache/hama/ml/perception/Sigmoid.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/perception/Sigmoid.java (revision 1493881) +++ ml/src/main/java/org/apache/hama/ml/perception/Sigmoid.java (working copy) @@ -1,38 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.ml.perception; - -/** - * The Sigmoid function - * - *
- * f(z) = 1 / (1 + e^{-z})
- *
- */
-public class Sigmoid extends SquashingFunction {
-
- @Override
- public double calculate(int index, double value) {
- return 1.0 / (1 + Math.exp(-value));
- }
-
- @Override
- public double calculateDerivative(double value) {
- return value * (1 - value);
- }
-}
Index: ml/src/main/java/org/apache/hama/ml/perception/SmallMultiLayerPerceptron.java
===================================================================
--- ml/src/main/java/org/apache/hama/ml/perception/SmallMultiLayerPerceptron.java (revision 1493881)
+++ ml/src/main/java/org/apache/hama/ml/perception/SmallMultiLayerPerceptron.java (working copy)
@@ -41,7 +41,9 @@
import org.apache.hama.bsp.BSPJob;
import org.apache.hama.ml.math.DenseDoubleMatrix;
import org.apache.hama.ml.math.DenseDoubleVector;
+import org.apache.hama.ml.math.DoubleFunction;
import org.apache.hama.ml.math.DoubleVector;
+import org.apache.hama.ml.math.FunctionFactory;
import org.apache.hama.ml.writable.MatrixWritable;
import org.apache.hama.ml.writable.VectorWritable;
import org.mortbay.log.Log;
@@ -102,18 +104,34 @@
private void initializeWeightMatrix() {
this.weightMatrice = new DenseDoubleMatrix[this.numberOfLayers - 1];
// each layer contains one bias neuron
- Random rnd = new Random();
for (int i = 0; i < this.numberOfLayers - 1; ++i) {
// add weights for bias
this.weightMatrice[i] = new DenseDoubleMatrix(this.layerSizeArray[i] + 1,
this.layerSizeArray[i + 1]);
- int rowCount = this.weightMatrice[i].getRowCount();
- int colCount = this.weightMatrice[i].getColumnCount();
- for (int row = 0; row < rowCount; ++row) {
- for (int col = 0; col < colCount; ++col) {
- this.weightMatrice[i].set(row, col, rnd.nextDouble() - 0.5);
+
+ this.weightMatrice[i].applyToElements(new DoubleFunction() {
+
+ private Random rnd = new Random();
+
+ @Override
+ public double apply(double value) {
+ return rnd.nextDouble() - 0.5;
}
- }
+
+ @Override
+ public double applyDerivative(double value) {
+ throw new UnsupportedOperationException("Not supported");
+ }
+
+ });
+
+// int rowCount = this.weightMatrice[i].getRowCount();
+// int colCount = this.weightMatrice[i].getColumnCount();
+// for (int row = 0; row < rowCount; ++row) {
+// for (int col = 0; col < colCount; ++col) {
+// this.weightMatrice[i].set(row, col, rnd.nextDouble() - 0.5);
+// }
+// }
}
}
@@ -199,8 +217,7 @@
prevNeuronIdx, neuronIdx) * intermediateResult[prevNeuronIdx];
}
// calculate via squashing function
- results[neuronIdx + offset] = this.squashingFunction.calculate(0,
- results[neuronIdx + offset]);
+ results[neuronIdx + offset] = this.squashingFunction.apply(results[neuronIdx + offset]);
}
return results;
@@ -243,7 +260,7 @@
DenseDoubleMatrix prevWeightUpdateMatrix = this.prevWeightUpdateMatrices[this.prevWeightUpdateMatrices.length - 1];
for (int j = 0; j < delta.length; ++j) {
- delta[j] = this.costFunction.calculateDerivative(trainingLabels[j],
+ delta[j] = this.costFunction.applyDerivative(trainingLabels[j],
outputLayerOutput[j]);
// add regularization term
if (this.regularization != 0.0) {
@@ -257,7 +274,7 @@
}
delta[j] *= this.squashingFunction
- .calculateDerivative(outputLayerOutput[j]);
+ .applyDerivative(outputLayerOutput[j]);
// calculate the weight update matrix between the last hidden layer and
// the output layer
@@ -307,7 +324,7 @@
delta[j] += weight * nextLayerDelta[k];
}
delta[j] *= this.squashingFunction
- .calculateDerivative(curLayerOutput[j + 1]);
+ .applyDerivative(curLayerOutput[j + 1]);
// calculate the weight update matrix between the previous layer and the
// current layer
@@ -395,10 +412,10 @@
for (int i = 0; i < numberOfLayers - 1; ++i) {
this.weightMatrice[i] = (DenseDoubleMatrix) MatrixWritable.read(input);
}
- this.squashingFunction = SquashingFunctionFactory
- .getSquashingFunction(this.squashingFunctionName);
- this.costFunction = CostFunctionFactory
- .getCostFunction(this.costFunctionName);
+ this.squashingFunction = FunctionFactory
+ .createDoubleFunction(this.squashingFunctionName);
+ this.costFunction = FunctionFactory
+ .createDoubleDoubleFunction(this.costFunctionName);
}
@Override
Index: ml/src/main/java/org/apache/hama/ml/perception/SquaredError.java
===================================================================
--- ml/src/main/java/org/apache/hama/ml/perception/SquaredError.java (revision 1493881)
+++ ml/src/main/java/org/apache/hama/ml/perception/SquaredError.java (working copy)
@@ -1,47 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.hama.ml.perception;
-
-/**
- * Square error cost function.
- *
- * - * cost(t, y) = 0.5 * (t - y) ˆ 2 - *- */ -public class SquaredError extends CostFunction { - - @Override - /** - * {@inheritDoc} - */ - public double calculate(double target, double actual) { - double diff = target - actual; - return 0.5 * diff * diff; - } - - @Override - /** - * {@inheritDoc} - */ - public double calculateDerivative(double target, double actual) { - // return target - actual; - return actual - target; - } - -} Index: ml/src/main/java/org/apache/hama/ml/perception/SquashingFunction.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/perception/SquashingFunction.java (revision 1493881) +++ ml/src/main/java/org/apache/hama/ml/perception/SquashingFunction.java (working copy) @@ -1,41 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.ml.perception; - -import org.apache.hama.ml.math.DoubleVectorFunction; - -/** - * The squashing function to activate the neurons. - * - */ -public abstract class SquashingFunction implements DoubleVectorFunction { - - /** - * Calculates the result with a given index and value of a vector. - */ - @Override - public abstract double calculate(int index, double value); - - /** - * Apply the gradient descent to each of the elements in vector. - * - * @param vector - * @return - */ - public abstract double calculateDerivative(double value); -} Index: ml/src/main/java/org/apache/hama/ml/perception/SquashingFunctionFactory.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/perception/SquashingFunctionFactory.java (revision 1493881) +++ ml/src/main/java/org/apache/hama/ml/perception/SquashingFunctionFactory.java (working copy) @@ -1,43 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.ml.perception; - -/** - * Get the squashing function according to the name. - */ -public class SquashingFunctionFactory { - - /** - * Get the squashing function instance according to the name. If no matched - * squahsing function is found, return the sigmoid squashing function by - * default. - * - * @param name The name of the squashing function. - * @return The instance of the squashing function. - */ - public static SquashingFunction getSquashingFunction(String name) { - if (name.equalsIgnoreCase("Sigmoid")) { - return new Sigmoid(); - } else if (name.equalsIgnoreCase("Tanh")) { - return new Tanh(); - } - throw new IllegalStateException(String.format( - "No squashing function with name '%s' found.", name)); - } - -} Index: ml/src/main/java/org/apache/hama/ml/perception/Tanh.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/perception/Tanh.java (revision 1493881) +++ ml/src/main/java/org/apache/hama/ml/perception/Tanh.java (working copy) @@ -1,36 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.ml.perception; - -/** - * The hyperbolic tangent function. It is used as a squashing function in - * multi-layer perceptron. - */ -public class Tanh extends SquashingFunction { - - @Override - public double calculate(int index, double value) { - return Math.tanh(value); - } - - @Override - public double calculateDerivative(double value) { - return 1 - value * value; - } - -} Index: ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleMatrix.java =================================================================== --- ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleMatrix.java (revision 0) +++ ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleMatrix.java (working copy) @@ -0,0 +1,86 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.ml.math; + +import static org.junit.Assert.assertArrayEquals; + +import org.junit.Test; + +/** + * Test case for {@link DenseDoubleMatrix} + * + */ +public class TestDenseDoubleMatrix { + + @Test + public void testDoubleFunction() { + double[][] values = new double[][] { { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 } }; + + double[][] result = new double[][] { { 2, 3, 4 }, { 5, 6, 7 }, { 8, 9, 10 } }; + + DenseDoubleMatrix mat = new DenseDoubleMatrix(values); + mat.applyToElements(new DoubleFunction() { + + @Override + public double apply(double value) { + return value + 1; + } + + @Override + public double applyDerivative(double value) { + throw new UnsupportedOperationException(); + } + + }); + + double[][] actual = mat.getValues(); + for (int i = 0; i < actual.length; ++i) { + assertArrayEquals(result[i], actual[i], 0.0001); + } + } + + @Test + public void testDoubleDoubleFunction() { + double[][] values1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 } }; + double[][] values2 = new double[][] { { 2, 3, 4 }, { 5, 6, 7 }, { 8, 9, 10 } }; + double[][] result = new double[][] { {3, 5, 7}, {9, 11, 13}, {15, 17, 19}}; + + DenseDoubleMatrix mat1 = new DenseDoubleMatrix(values1); + DenseDoubleMatrix mat2 = new DenseDoubleMatrix(values2); + + mat1.applyToElements(mat2, new DoubleDoubleFunction() { + + @Override + public double apply(double x1, double x2) { + return x1 + x2; + } + + @Override + public double applyDerivative(double x1, double x2) { + throw new UnsupportedOperationException(); + } + + }); + + double[][] actual = mat1.getValues(); + for (int i = 0; i < actual.length; ++i) { + assertArrayEquals(result[i], actual[i], 0.0001); + } + } + +} Index: ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleVector.java =================================================================== --- ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleVector.java (revision 0) +++ ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleVector.java (working copy) @@ -0,0 +1,80 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.ml.math; + +import static org.junit.Assert.assertArrayEquals; + +import org.junit.Test; + +/** + * Testcase for {@link DenseDoubleVector} + * + */ +public class TestDenseDoubleVector { + + @Test + public void testApplyDoubleFunction() { + double[] values = new double[] {1, 2, 3, 4, 5}; + double[] result = new double[] {2, 3, 4, 5, 6}; + + DoubleVector vec1 = new DenseDoubleVector(values); + + vec1.apply(new DoubleFunction() { + + @Override + public double apply(double value) { + return value + 1; + } + + @Override + public double applyDerivative(double value) { + throw new UnsupportedOperationException("Not supported."); + } + + }); + + assertArrayEquals(result, vec1.toArray(), 0.0001); + } + + @Test + public void testApplyDoubleDoubleFunction() { + double[] values1 = new double[] {1, 2, 3, 4, 5, 6}; + double[] values2 = new double[] {7, 8, 9, 10, 11, 12}; + double[] result = new double[] {8, 10, 12, 14, 16, 18}; + + DoubleVector vec1 = new DenseDoubleVector(values1); + DoubleVector vec2 = new DenseDoubleVector(values2); + + vec1.apply(vec2, new DoubleDoubleFunction() { + + @Override + public double apply(double x1, double x2) { + return x1 + x2; + } + + @Override + public double applyDerivative(double x1, double x2) { + throw new UnsupportedOperationException("Not supported"); + } + + }); + + assertArrayEquals(result, vec1.toArray(), 0.0001); + + } +} Index: ml/src/test/java/org/apache/hama/ml/math/TestFunctionFactory.java =================================================================== --- ml/src/test/java/org/apache/hama/ml/math/TestFunctionFactory.java (revision 0) +++ ml/src/test/java/org/apache/hama/ml/math/TestFunctionFactory.java (working copy) @@ -0,0 +1,71 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.ml.math; + +import static org.junit.Assert.assertEquals; + +import org.junit.Test; + +/** + * Test case for {@link FunctionFactory} + * + */ +public class TestFunctionFactory { + + @Test + public void testCreateDoubleFunction() { + double input = 0.8; + + String sigmoidName = "Sigmoid"; + DoubleFunction sigmoidFunction = FunctionFactory + .createDoubleFunction(sigmoidName); + assertEquals(sigmoidName, sigmoidFunction.getFunctionName()); + + double sigmoidExcepted = 0.68997448; + assertEquals(sigmoidExcepted, sigmoidFunction.apply(input), 0.000001); + + String tanhName = "Tanh"; + DoubleFunction tanhFunction = FunctionFactory.createDoubleFunction(tanhName); + assertEquals(tanhName, tanhFunction.getFunctionName()); + + double tanhExpected = 0.66403677; + assertEquals(tanhExpected, tanhFunction.apply(input), 0.00001); + } + + @Test + public void testCreateDoubleDoubleFunction() { + double target = 0.5; + double output = 0.8; + + String squaredErrorName = "SquaredError"; + DoubleDoubleFunction squaredErrorFunction = FunctionFactory.createDoubleDoubleFunction(squaredErrorName); + assertEquals(squaredErrorName, squaredErrorFunction.getFunctionName()); + + double squaredErrorExpected = 0.045; + + assertEquals(squaredErrorExpected, squaredErrorFunction.apply(target, output), 0.000001); + + String crossEntropyName = "CrossEntropy"; + DoubleDoubleFunction crossEntropyFunction = FunctionFactory.createDoubleDoubleFunction(crossEntropyName); + assertEquals(crossEntropyName, crossEntropyFunction.getFunctionName()); + + double crossEntropyExpected = 0.91629; + assertEquals(crossEntropyExpected, crossEntropyFunction.apply(target, output), 0.000001); + } + +} Index: ml/src/test/java/org/apache/hama/ml/perception/TestSmallMultiLayerPerceptron.java =================================================================== --- ml/src/test/java/org/apache/hama/ml/perception/TestSmallMultiLayerPerceptron.java (revision 1493881) +++ ml/src/test/java/org/apache/hama/ml/perception/TestSmallMultiLayerPerceptron.java (working copy) @@ -1,5 +1,5 @@ /** - * Licensed to the Apache Software Foundation (ASF) under one +c * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file