Index: ml/src/main/java/org/apache/hama/ml/perception/CostFunction.java
===================================================================
--- ml/src/main/java/org/apache/hama/ml/perception/CostFunction.java (revision 0)
+++ ml/src/main/java/org/apache/hama/ml/perception/CostFunction.java (working copy)
@@ -0,0 +1,44 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml.perception;
+
+/**
+ * The common base class for cost functions.
+ */
+public abstract class CostFunction {
+
+ /**
+ * Calculate the error between the target value and the actual value.
+ *
+ * @param target The target value.
+ * @param actual The actual value.
+ * @return The cost for the given pair of values.
+ */
+ public abstract double calculate(double target, double actual);
+
+ /**
+ * Calculate the partial derivative of the cost with respect to the actual
+ * value.
+ *
+ * @param target The target value.
+ * @param actual The actual value.
+ * @return The partial derivative.
+ */
+ public abstract double calculateDerivative(double target, double actual);
+
+}
Index: ml/src/main/java/org/apache/hama/ml/perception/CostFunctionFactory.java
===================================================================
--- ml/src/main/java/org/apache/hama/ml/perception/CostFunctionFactory.java (revision 0)
+++ ml/src/main/java/org/apache/hama/ml/perception/CostFunctionFactory.java (working copy)
@@ -0,0 +1,40 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml.perception;
+
+/**
+ * The cost function factory that generates the cost function by name.
+ */
+public class CostFunctionFactory {
+
+ /**
+ * Get the cost function according to the name. If no matching cost function
+ * is found, return the SquaredError by default.
+ *
+ * @param name The name of the cost function.
+ * @return The cost function instance.
+ */
+ public static CostFunction getCostFunction(String name) {
+ if (name.equalsIgnoreCase("SquaredError")) {
+ return new SquaredError();
+ } else if (name.equalsIgnoreCase("LogisticError")) {
+ return new LogisticCostFunction();
+ }
+ return new SquaredError();
+ }
+}
Index: ml/src/main/java/org/apache/hama/ml/perception/LogisticCostFunction.java
===================================================================
--- ml/src/main/java/org/apache/hama/ml/perception/LogisticCostFunction.java (revision 0)
+++ ml/src/main/java/org/apache/hama/ml/perception/LogisticCostFunction.java (working copy)
@@ -0,0 +1,53 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml.perception;
+
+/**
+ * The logistic cost function.
+ *
+ * <pre>
+ * cost(t, y) = - t * log(y) - (1 - t) * log(1 - y),
+ * where t denotes the target value, y denotes the estimated value.
+ * </pre>
+ */
+public class LogisticCostFunction extends CostFunction {
+
+ @Override
+ public double calculate(double target, double actual) {
+ return -target * Math.log(actual) - (1 - target) * Math.log(1 - actual);
+ }
+
+ @Override
+ public double calculateDerivative(double target, double actual) {
+ // clamp the values away from 0 and 1 to avoid division by zero below
+ double adjustedTarget = target;
+ double adjustedActual = actual;
+ if (adjustedActual == 1) {
+ adjustedActual = 0.999;
+ } else if (adjustedActual == 0) {
+ adjustedActual = 0.001;
+ }
+ if (adjustedTarget == 1) {
+ adjustedTarget = 0.999;
+ } else if (adjustedTarget == 0) {
+ adjustedTarget = 0.001;
+ }
+ return -adjustedTarget / adjustedActual + (1 - adjustedTarget)
+ / (1 - adjustedActual);
+ }
+
+}
Index: ml/src/main/java/org/apache/hama/ml/perception/MLPMessage.java
===================================================================
--- ml/src/main/java/org/apache/hama/ml/perception/MLPMessage.java (revision 0)
+++ ml/src/main/java/org/apache/hama/ml/perception/MLPMessage.java (working copy)
@@ -0,0 +1,41 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml.perception;
+
+import org.apache.hadoop.io.Writable;
+
+/**
+ * MLPMessage is used to hold the parameters that need to be sent between the
+ * tasks.
+ */
+public abstract class MLPMessage implements Writable {
+ protected boolean terminated;
+
+ public MLPMessage(boolean terminated) {
+ setTerminated(terminated);
+ }
+
+ public void setTerminated(boolean terminated) {
+ this.terminated = terminated;
+ }
+
+ public boolean isTerminated() {
+ return terminated;
+ }
+
+}
Index: ml/src/main/java/org/apache/hama/ml/perception/MultiLayerPerceptron.java
===================================================================
--- ml/src/main/java/org/apache/hama/ml/perception/MultiLayerPerceptron.java (revision 0)
+++ ml/src/main/java/org/apache/hama/ml/perception/MultiLayerPerceptron.java (working copy)
@@ -0,0 +1,158 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml.perception;
+
+import java.io.IOException;
+import java.util.Map;
+
+import org.apache.hadoop.fs.Path;
+import org.apache.hama.ml.math.DoubleVector;
+
+/**
+ * MultiLayerPerceptron defines the common behavior of all the concrete
+ * multi-layer perceptrons.
+ */
+public abstract class MultiLayerPerceptron {
+
+ /* The trainer for the model */
+ protected PerceptronTrainer trainer;
+ /* The file path that contains the model meta-data */
+ protected String modelPath;
+
+ /* Model meta-data */
+ protected String MLPType;
+ protected double learningRate;
+ protected boolean regularization;
+ protected double momentum;
+ protected int numberOfLayers;
+ protected String squashingFunctionName;
+ protected String costFunctionName;
+ protected int[] layerSizeArray;
+
+ protected CostFunction costFunction;
+ protected SquashingFunction squashingFunction;
+
+ /**
+ * Initialize the MLP.
+ *
+ * @param learningRate A larger learning rate makes the MLP learn more
+ * aggressively.
+ * @param regularization Turning on regularization makes the MLP less likely
+ * to overfit.
+ * @param momentum Momentum makes previous weight updates influence the
+ * current update.
+ * @param squashingFunctionName The name of the squashing function.
+ * @param costFunctionName The name of the cost function.
+ * @param layerSizeArray The number of neurons for each layer. Note that the
+ * actual size of each layer is one more than the specified size, due to
+ * the bias neuron.
+ */
+ public MultiLayerPerceptron(double learningRate, boolean regularization,
+ double momentum, String squashingFunctionName, String costFunctionName,
+ int[] layerSizeArray) {
+ this.learningRate = learningRate;
+ this.regularization = regularization;
+ this.momentum = momentum;
+ this.squashingFunctionName = squashingFunctionName;
+ this.costFunctionName = costFunctionName;
+ this.layerSizeArray = layerSizeArray;
+ this.numberOfLayers = this.layerSizeArray.length;
+
+ this.costFunction = CostFunctionFactory
+ .getCostFunction(this.costFunctionName);
+ this.squashingFunction = SquashingFunctionFactory
+ .getSquashingFunction(this.squashingFunctionName);
+ }
+
+ /**
+ * Initialize a multi-layer perceptron from an existing model.
+ *
+ * @param modelPath Location of the existing model meta-data.
+ */
+ public MultiLayerPerceptron(String modelPath) {
+ this.modelPath = modelPath;
+ }
+
+ /**
+ * Train the model with the given data. This method invokes a perceptron
+ * training BSP task to train the model. It then writes the model to
+ * modelPath.
+ *
+ * @param dataInputPath The path of the data.
+ * @param trainingParams Extra parameters for training.
+ */
+ public abstract void train(Path dataInputPath,
+ Map<String, String> trainingParams);
+ * f(z) = 1 / (1 + e^{-z})
+ * </pre>
+ */
+public class Sigmoid extends SquashingFunction {
+
+ @Override
+ public double calculate(int index, double value) {
+ return 1.0 / (1 + Math.exp(-value));
+ }
+
+ @Override
+ public double calculateDerivative(double value) {
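+ // assumes 'value' is the sigmoid output f(z), so the derivative is
+ // f'(z) = f(z) * (1 - f(z))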
+ return value * (1 - value);
+ }
+}
Index: ml/src/main/java/org/apache/hama/ml/perception/SmallMLPMessage.java
===================================================================
--- ml/src/main/java/org/apache/hama/ml/perception/SmallMLPMessage.java (revision 0)
+++ ml/src/main/java/org/apache/hama/ml/perception/SmallMLPMessage.java (working copy)
@@ -0,0 +1,86 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml.perception;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import org.apache.hama.ml.math.DenseDoubleMatrix;
+import org.apache.hama.ml.writable.MatrixWritable;
+
+/**
+ * SmallMLPMessage is used to exchange information for the
+ * {@link SmallMultiLayerPerceptron}. It sends the whole set of updated
+ * weight matrices from one task to another.
+ */
+public class SmallMLPMessage extends MLPMessage {
+
+ private int owner; // the ID of the task that created the message
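+ // the updated weight matrices, presumably one per pair of adjacent layers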
+ private DenseDoubleMatrix[] weightUpdatedMatrices;
+ private int numOfMatrices;
+
+ public SmallMLPMessage(int owner, boolean terminated, DenseDoubleMatrix[] mat) {
+ super(terminated);
+ this.owner = owner;
+ this.weightUpdatedMatrices = mat;
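+ // record the matrix count so it can be serialized before the matrices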
+ this.numOfMatrices = this.weightUpdatedMatrices == null ? 0
+ : this.weightUpdatedMatrices.length;
+ }
+
+ /**
+ * Get the ID of the task that created the message.
+ *
+ * @return The owner task ID.
+ */
+ public int getOwner() {
+ return owner;
+ }
+
+ /**
+ * Get the updated weight matrices.
+ *
+ * @return The updated weight matrices.
+ */
+ public DenseDoubleMatrix[] getWeightsUpdatedMatrices() {
+ return this.weightUpdatedMatrices;
+ }
+
+ @Override
+ public void readFields(DataInput input) throws IOException {
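+ // fields must be read in exactly the order they are written in write()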
+ this.owner = input.readInt();
+ this.terminated = input.readBoolean();
+ this.numOfMatrices = input.readInt();
+ this.weightUpdatedMatrices = new DenseDoubleMatrix[this.numOfMatrices];
+ for (int i = 0; i < this.numOfMatrices; ++i) {
+ this.weightUpdatedMatrices[i] = (DenseDoubleMatrix) MatrixWritable
+ .read(input);
+ }
+ }
+
+ @Override
+ public void write(DataOutput output) throws IOException {
+ output.writeInt(this.owner);
+ output.writeBoolean(this.terminated);
+ output.writeInt(this.numOfMatrices);
+ for (int i = 0; i < this.numOfMatrices; ++i) {
+ MatrixWritable.write(this.weightUpdatedMatrices[i], output);
+ }
+ }
+
+}
Index: ml/src/main/java/org/apache/hama/ml/perception/SmallMLPTrainer.java
===================================================================
--- ml/src/main/java/org/apache/hama/ml/perception/SmallMLPTrainer.java (revision 0)
+++ ml/src/main/java/org/apache/hama/ml/perception/SmallMLPTrainer.java (working copy)
@@ -0,0 +1,320 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml.perception;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.BitSet;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hama.bsp.BSPPeer;
+import org.apache.hama.bsp.sync.SyncException;
+import org.apache.hama.ml.math.DenseDoubleMatrix;
+import org.apache.hama.ml.writable.VectorWritable;
+
+/**
+ * The perceptron trainer for the small-scale MLP.
+ */
+public class SmallMLPTrainer extends PerceptronTrainer {
+
+ private static final Log LOG = LogFactory.getLog(SmallMLPTrainer.class);
+ /* used by the master only, to check whether all slaves have finished reading */
+ private BitSet statusSet;
+
+ private int numTrainingInstanceRead = 0;
+ /* once the reader reaches EOF, the training procedure terminates */
+ private boolean terminateTraining = false;
+
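+ // the local in-memory copy of the perceptron model being trained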
+ private SmallMultiLayerPerceptron inMemoryPerceptron;
+
+ private int[] layerSizeArray;
+
+ @Override
+ protected void extraSetup(
+ BSPPeer<LongWritable, VectorWritable, NullWritable, MLPMessage> peer)
+ * cost(t, y) = 0.5 * (t - y)^2
+ * </pre>
+ */
+public class SquaredError extends CostFunction {
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public double calculate(double target, double actual) {
+ double diff = target - actual;
+ return 0.5 * diff * diff;
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public double calculateDerivative(double target, double actual) {
+ return target - actual;
+ }
+
+}
Index: ml/src/main/java/org/apache/hama/ml/perception/SquashingFunction.java
===================================================================
--- ml/src/main/java/org/apache/hama/ml/perception/SquashingFunction.java (revision 0)
+++ ml/src/main/java/org/apache/hama/ml/perception/SquashingFunction.java (working copy)
@@ -0,0 +1,41 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml.perception;
+
+import org.apache.hama.ml.math.DoubleVectorFunction;
+
+/**
+ * The squashing function used to activate the neurons.
+ */
+public abstract class SquashingFunction implements DoubleVectorFunction {
+
+ /**
+ * Calculates the result with a given index and value of a vector.
+ */
+ @Override
+ public abstract double calculate(int index, double value);
+
+ /**
+ * Calculate the derivative of the squashing function, expressed in terms of
+ * the function's output value.
+ *
+ * @param value The output of the squashing function.
+ * @return The derivative.
+ */
+ public abstract double calculateDerivative(double value);
+}
Index: ml/src/main/java/org/apache/hama/ml/perception/SquashingFunctionFactory.java
===================================================================
--- ml/src/main/java/org/apache/hama/ml/perception/SquashingFunctionFactory.java (revision 0)
+++ ml/src/main/java/org/apache/hama/ml/perception/SquashingFunctionFactory.java (working copy)
@@ -0,0 +1,42 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml.perception;
+
+/**
+ * The squashing function factory that generates the squashing function by
+ * name.
+ */
+public class SquashingFunctionFactory {
+
+ /**
+ * Get the squashing function instance according to the name.
+ * If no matching squashing function is found, return the Sigmoid by default.
+ *
+ * @param name The name of the squashing function.
+ * @return The instance of the squashing function.
+ */
+ public static SquashingFunction getSquashingFunction(String name) {
+ if (name.equalsIgnoreCase("Sigmoid")) {
+ return new Sigmoid();
+ } else if (name.equalsIgnoreCase("Tanh")) {
+ return new Tanh();
+ }
+ return new Sigmoid();
+ }
+
+}
Index: ml/src/main/java/org/apache/hama/ml/perception/Tanh.java
===================================================================
--- ml/src/main/java/org/apache/hama/ml/perception/Tanh.java (revision 0)
+++ ml/src/main/java/org/apache/hama/ml/perception/Tanh.java (working copy)
@@ -0,0 +1,36 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml.perception;
+
+/**
+ * The hyperbolic tangent function. It is used as a squashing function in the
+ * multi-layer perceptron.
+ */
+public class Tanh extends SquashingFunction {
+
+ @Override
+ public double calculate(int index, double value) {
+ return Math.tanh(value);
+ }
+
+ @Override
+ public double calculateDerivative(double value) {
+ // 'value' is the tanh output, so the derivative is 1 - tanh(z)^2
+ return 1 - value * value;
+ }
+
+}
Index: ml/src/test/java/org/apache/hama/ml/perception/TestSmallMLPMessage.java
===================================================================
--- ml/src/test/java/org/apache/hama/ml/perception/TestSmallMLPMessage.java (revision 0)
+++ ml/src/test/java/org/apache/hama/ml/perception/TestSmallMLPMessage.java (working copy)
@@ -0,0 +1,87 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml.perception;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+import java.io.IOException;
+import java.net.URI;
+import java.net.URISyntaxException;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hama.ml.math.DenseDoubleMatrix;
+import org.junit.Test;
+
+/**
+ * Test the functionality of SmallMLPMessage.
+ */
+public class TestSmallMLPMessage {
+
+ @Test
+ public void testReadWrite() {
+ int owner = 101;
+ double[][] mat = { { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 } };
+
+ double[][] mat2 = { { 10, 20 }, { 30, 40 }, { 50, 60 } };
+
+ double[][][] mats = { mat, mat2 };
+
+ DenseDoubleMatrix[] matrices = new DenseDoubleMatrix[] {
+ new DenseDoubleMatrix(mat), new DenseDoubleMatrix(mat2) };
+
+ SmallMLPMessage message = new SmallMLPMessage(owner, true, matrices);
+
+ Configuration conf = new Configuration();
+ String strPath = "testSmallMLPMessage";
+ Path path = new Path(strPath);
+ try {
+ FileSystem fs = FileSystem.get(new URI(strPath), conf);
+ FSDataOutputStream out = fs.create(path, true);
+ message.write(out);
+ out.close();
+
+ FSDataInputStream in = fs.open(path);
+ SmallMLPMessage outMessage = new SmallMLPMessage(0, false, null);
+ outMessage.readFields(in);
+
+ assertEquals(owner, outMessage.getOwner());
+ DenseDoubleMatrix[] outMatrices = outMessage.getWeightsUpdatedMatrices();
+ // check each matrix
+ for (int i = 0; i < outMatrices.length; ++i) {
+ double[][] outMat = outMatrices[i].getValues();
+ for (int j = 0; j < outMat.length; ++j) {
+ assertArrayEquals(mats[i][j], outMat[j], 0.0001);
+ }
+ }
+
+ fs.delete(path, true);
+ } catch (IOException e) {
+ e.printStackTrace();
+ } catch (URISyntaxException e) {
+ e.printStackTrace();
+ }
+
+ }
+}
Index: ml/src/test/java/org/apache/hama/ml/perception/TestSmallMultiLayerPerceptron.java
===================================================================
--- ml/src/test/java/org/apache/hama/ml/perception/TestSmallMultiLayerPerceptron.java (revision 0)
+++ ml/src/test/java/org/apache/hama/ml/perception/TestSmallMultiLayerPerceptron.java (working copy)
@@ -0,0 +1,283 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml.perception;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+import java.io.IOException;
+import java.net.URI;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Random;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.WritableUtils;
+import org.apache.hama.ml.math.DenseDoubleMatrix;
+import org.apache.hama.ml.math.DenseDoubleVector;
+import org.apache.hama.ml.math.DoubleMatrix;
+import org.apache.hama.ml.math.DoubleVector;
+import org.apache.hama.ml.writable.MatrixWritable;
+import org.apache.hama.ml.writable.VectorWritable;
+import org.junit.Test;
+
+public class TestSmallMultiLayerPerceptron {
+
+ /**
+ * Write and read the parameters of MLP.
+ */
+ @Test
+ public void testWriteReadMLP() {
+ String modelPath = "src/test/resources/perception/sampleModel.data";
+ double learningRate = 0.5;
+ boolean regularization = false; // no regularization
+ double momentum = 0; // no momentum
+ String squashingFunctionName = "Sigmoid";
+ String costFunctionName = "SquaredError";
+ int[] layerSizeArray = new int[] { 3, 2, 2, 3 };
+ MultiLayerPerceptron mlp = new SmallMultiLayerPerceptron(learningRate,
+ regularization, momentum, squashingFunctionName, costFunctionName,
+ layerSizeArray);
+ try {
+ mlp.writeModelToFile(modelPath);
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+
+ try {
+ // read the meta-data
+ Configuration conf = new Configuration();
+ FileSystem fs = FileSystem.get(conf);
+ mlp = new SmallMultiLayerPerceptron(modelPath);
+ assertEquals("SmallMLP", mlp.getMLPType());
+ assertEquals(learningRate, mlp.getLearningRate(), 0.001);
+ assertEquals(regularization, mlp.isRegularization());
+ assertEquals(layerSizeArray.length, mlp.getNumberOfLayers());
+ assertEquals(momentum, mlp.getMomentum(), 0.001);
+ assertEquals(squashingFunctionName, mlp.getSquashingFunctionName());
+ assertEquals(costFunctionName, mlp.getCostFunctionName());
+ assertArrayEquals(layerSizeArray, mlp.getLayerSizeArray());
+ // delete test file
+ fs.delete(new Path(modelPath), true);
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+
+ /**
+ * Test the output of an example MLP.
+ */
+ @Test
+ public void testOutput() {
+ // write the MLP meta-data manually
+ String modelPath = "src/test/resources/perception/sampleModel.data";
+ Configuration conf = new Configuration();
+ try {
+ FileSystem fs = FileSystem.get(conf);
+ FSDataOutputStream output = fs.create(new Path(modelPath));
+
+ String MLPType = "SmallMLP";
+ double learningRate = 0.5;
+ boolean regularization = false;
+ double momentum = 0;
+ String squashingFunctionName = "Sigmoid";
+ String costFunctionName = "SquaredError";
+ int[] layerSizeArray = new int[] { 3, 2, 3, 3 };
+ int numberOfLayers = layerSizeArray.length;
+
+ WritableUtils.writeString(output, MLPType);
+ output.writeDouble(learningRate);
+ output.writeBoolean(regularization);
+ output.writeDouble(momentum);
+ output.writeInt(numberOfLayers);
+ WritableUtils.writeString(output, squashingFunctionName);
+ WritableUtils.writeString(output, costFunctionName);
+
+ // write the number of neurons for each layer
+ for (int i = 0; i < numberOfLayers; ++i) {
+ output.writeInt(layerSizeArray[i]);
+ }
+
+ double[][] matrix01 = { // 4 by 2
+ { 0.5, 0.2 }, { 0.1, 0.1 }, { 0.2, 0.5 }, { 0.1, 0.5 } };
+
+ double[][] matrix12 = { // 3 by 3
+ { 0.1, 0.2, 0.5 }, { 0.2, 0.5, 0.2 }, { 0.5, 0.5, 0.1 } };
+
+ double[][] matrix23 = { // 4 by 3
+ { 0.2, 0.5, 0.2 }, { 0.5, 0.1, 0.5 }, { 0.1, 0.2, 0.1 },
+ { 0.1, 0.2, 0.5 } };
+
+ DoubleMatrix[] matrices = { new DenseDoubleMatrix(matrix01),
+ new DenseDoubleMatrix(matrix12), new DenseDoubleMatrix(matrix23) };
+ for (DoubleMatrix mat : matrices) {
+ MatrixWritable.write(mat, output);
+ }
+ output.close();
+
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+
+ // initialize the MLP with the existing model meta-data and get the output
+ MultiLayerPerceptron mlp = new SmallMultiLayerPerceptron(modelPath);
+ DoubleVector input = new DenseDoubleVector(new double[] { 1, 2, 3 });
+ try {
+ DoubleVector result = mlp.output(input);
+ assertArrayEquals(new double[] { 0.6636557, 0.7009963, 0.7213835 },
+ result.toArray(), 0.0001);
+ } catch (Exception e1) {
+ e1.printStackTrace();
+ }
+
+ // delete meta-data
+ try {
+ FileSystem fs = FileSystem.get(conf);
+ fs.delete(new Path(modelPath), true);
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+
+ }
+
+ /**
+ * Test the MLP on the XOR problem, training one instance at a time.
+ */
+ @Test
+ public void testSingleInstanceTraining() {
+ // generate training data
+ DoubleVector[] trainingData = new DenseDoubleVector[] {
+ new DenseDoubleVector(new double[] { 0, 0, 0 }),
+ new DenseDoubleVector(new double[] { 0, 1, 1 }),
+ new DenseDoubleVector(new double[] { 1, 0, 1 }),
+ new DenseDoubleVector(new double[] { 1, 1, 0 }) };
+
+ // set parameters
+ double learningRate = 0.6;
+ boolean regularization = false; // no regularization
+ double momentum = 0; // no momentum
+ String squashingFunctionName = "Sigmoid";
+ String costFunctionName = "SquaredError";
+ int[] layerSizeArray = new int[] { 2, 5, 1 };
+ SmallMultiLayerPerceptron mlp = new SmallMultiLayerPerceptron(learningRate,
+ regularization, momentum, squashingFunctionName, costFunctionName,
+ layerSizeArray);
+
+ try {
+ // train with randomly picked single instances
+ Random rnd = new Random();
+ for (int i = 0; i < 30000; ++i) {
+ DenseDoubleMatrix[] weightUpdates = mlp
+ .trainByInstance(trainingData[rnd.nextInt(4)]);
+ mlp.updateWeightMatrices(weightUpdates);
+ }
+
+ // System.out.printf("Weight matrices: %s\n",
+ // mlp.weightsToString(mlp.getWeightMatrices()));
+ for (int i = 0; i < trainingData.length; ++i) {
+ DenseDoubleVector testVec = (DenseDoubleVector) trainingData[i]
+ .slice(2);
+ assertEquals(trainingData[i].toArray()[2], mlp.output(testVec)
+ .toArray()[0], 0.2);
+ }
+ } catch (Exception e) {
+ e.printStackTrace();
+ }
+ }
+
+ /**
+ * Test training the MLP on the XOR problem.
+ */
+ @Test
+ public void testTrainingByXOR() {
+ // write in some training instances
+ Configuration conf = new Configuration();
+ String strDataPath = "src/test/resources/perception/xor";
+ Path dataPath = new Path(strDataPath);
+
+ // generate training data
+ DoubleVector[] trainingData = new DenseDoubleVector[] {
+ new DenseDoubleVector(new double[] { 0, 0, 0 }),
+ new DenseDoubleVector(new double[] { 0, 1, 1 }),
+ new DenseDoubleVector(new double[] { 1, 0, 1 }),
+ new DenseDoubleVector(new double[] { 1, 1, 0 }) };
+
+ try {
+ URI uri = new URI(strDataPath);
+ FileSystem fs = FileSystem.get(uri, conf);
+ fs.delete(dataPath, true);
+ if (!fs.exists(dataPath)) {
+ fs.createNewFile(dataPath);
+ SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf,
+ dataPath, LongWritable.class, VectorWritable.class);
+
+ for (int i = 0; i < 1000; ++i) {
+ VectorWritable vecWritable = new VectorWritable(trainingData[i % 4]);
+ writer.append(new LongWritable(i), vecWritable);
+ }
+ writer.close();
+ }
+
+ } catch (Exception e) {
+ e.printStackTrace();
+ }
+
+ // begin training
+ String modelPath = "src/test/resources/xorModel.data";
+ double learningRate = 0.6;
+ boolean regularization = false; // no regularization
+ double momentum = 0; // no momentum
+ String squashingFunctionName = "Tanh";
+ String costFunctionName = "SquaredError";
+ int[] layerSizeArray = new int[] { 2, 5, 1 };
+ SmallMultiLayerPerceptron mlp = new SmallMultiLayerPerceptron(learningRate,
+ regularization, momentum, squashingFunctionName, costFunctionName,
+ layerSizeArray);
+
+ Map<String, String> trainingParams = new HashMap<String, String>();