Index: ml/src/main/java/org/apache/hama/ml/ann/AbstractLayeredNeuralNetwork.java
===================================================================
--- ml/src/main/java/org/apache/hama/ml/ann/AbstractLayeredNeuralNetwork.java	(revision 1549613)
+++ ml/src/main/java/org/apache/hama/ml/ann/AbstractLayeredNeuralNetwork.java	(working copy)
@@ -31,6 +31,7 @@
 import org.apache.hama.commons.math.FunctionFactory;
 
 import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
 
 /**
  * AbstractLayeredNeuralNetwork defines the general operations for derivative
@@ -66,7 +67,7 @@
   protected LearningStyle learningStyle;
 
   public static enum TrainingMethod {
-    GRADIATE_DESCENT
+    GRADIENT_DESCENT
   }
 
   public static enum LearningStyle {
@@ -77,7 +78,7 @@
   public AbstractLayeredNeuralNetwork() {
     this.regularizationWeight = DEFAULT_REGULARIZATION_WEIGHT;
     this.momentumWeight = DEFAULT_MOMENTUM_WEIGHT;
-    this.trainingMethod = TrainingMethod.GRADIATE_DESCENT;
+    this.trainingMethod = TrainingMethod.GRADIENT_DESCENT;
     this.learningStyle = LearningStyle.SUPERVISED;
   }
 
@@ -229,7 +230,7 @@
 
     // read layer size list
     int numLayers = input.readInt();
-    this.layerSizeList = new ArrayList();
+    this.layerSizeList = Lists.newArrayList();
     for (int i = 0; i < numLayers; ++i) {
       this.layerSizeList.add(input.readInt());
     }
Index: ml/src/main/java/org/apache/hama/ml/ann/NeuralNetwork.java
===================================================================
--- ml/src/main/java/org/apache/hama/ml/ann/NeuralNetwork.java	(revision 1549613)
+++ ml/src/main/java/org/apache/hama/ml/ann/NeuralNetwork.java	(working copy)
@@ -39,6 +39,7 @@
 import org.apache.hama.ml.util.FeatureTransformer;
 
 import com.google.common.base.Preconditions;
+import com.google.common.io.Closeables;
 
 /**
  * NeuralNetwork defines the general operations for all the derivative models.
@@ -85,7 +86,7 @@
    */
   public void setLearningRate(double learningRate) {
     Preconditions.checkArgument(learningRate > 0,
-        "Learning rate must larger than 0.");
+        "Learning rate must be larger than 0.");
     this.learningRate = learningRate;
   }
 
@@ -144,13 +145,16 @@
     Preconditions.checkArgument(this.modelPath != null,
         "Model path has not been set.");
     Configuration conf = new Configuration();
+    FSDataInputStream is = null;
     try {
       URI uri = new URI(this.modelPath);
       FileSystem fs = FileSystem.get(uri, conf);
-      FSDataInputStream is = new FSDataInputStream(fs.open(new Path(modelPath)));
+      is = new FSDataInputStream(fs.open(new Path(modelPath)));
       this.readFields(is);
     } catch (URISyntaxException e) {
       e.printStackTrace();
+    } finally {
+      Closeables.close(is, false);
     }
   }
 
@@ -164,10 +168,17 @@
     Preconditions.checkArgument(this.modelPath != null,
         "Model path has not been set.");
     Configuration conf = new Configuration();
-    FileSystem fs = FileSystem.get(conf);
-    FSDataOutputStream stream = fs.create(new Path(this.modelPath), true);
-    this.write(stream);
-    stream.close();
+    FSDataOutputStream stream = null;
+    try {
+      URI uri = new URI(this.modelPath);
+      FileSystem fs = FileSystem.get(uri, conf);
+      stream = fs.create(new Path(this.modelPath), true);
+      this.write(stream);
+    } catch (URISyntaxException e) {
+      e.printStackTrace();
+    } finally {
+      Closeables.close(stream, false);
+    }
   }
 
   /**
@@ -215,7 +226,7 @@
     Constructor[] constructors = featureTransformerCls
         .getDeclaredConstructors();
     Constructor constructor = constructors[0];
-    
+
     try {
       this.featureTransformer = (FeatureTransformer) constructor
           .newInstance(new Object[] {});
Index: ml/src/main/java/org/apache/hama/ml/ann/SmallLayeredNeuralNetwork.java
===================================================================
--- ml/src/main/java/org/apache/hama/ml/ann/SmallLayeredNeuralNetwork.java	(revision 1549613)
+++ ml/src/main/java/org/apache/hama/ml/ann/SmallLayeredNeuralNetwork.java	(working copy)
@@ -23,8 +23,8 @@
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
-import java.util.Random;
 
+import org.apache.commons.lang.math.RandomUtils;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.LongWritable;
@@ -43,6 +43,7 @@
 import org.mortbay.log.Log;
 
 import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
 
 /**
  * SmallLayeredNeuralNetwork defines the general operations for derivative
@@ -70,10 +71,10 @@
   protected int finalLayerIdx;
 
   public SmallLayeredNeuralNetwork() {
-    this.layerSizeList = new ArrayList();
-    this.weightMatrixList = new ArrayList();
-    this.prevWeightUpdatesList = new ArrayList();
-    this.squashingFunctionList = new ArrayList();
+    this.layerSizeList = Lists.newArrayList();
+    this.weightMatrixList = Lists.newArrayList();
+    this.prevWeightUpdatesList = Lists.newArrayList();
+    this.squashingFunctionList = Lists.newArrayList();
   }
 
   public SmallLayeredNeuralNetwork(String modelPath) {
@@ -86,7 +87,8 @@
    */
   public int addLayer(int size, boolean isFinalLayer,
       DoubleFunction squashingFunction) {
-    Preconditions.checkArgument(size > 0, "Size of layer must larger than 0.");
+    Preconditions.checkArgument(size > 0,
+        "Size of layer must be larger than 0.");
     if (!isFinalLayer) {
       size += 1;
     }
@@ -107,11 +109,10 @@
     int col = sizePrevLayer;
     DoubleMatrix weightMatrix = new DenseDoubleMatrix(row, col);
     // initialize weights
-    final Random rnd = new Random();
     weightMatrix.applyToElements(new DoubleFunction() {
       @Override
       public double apply(double value) {
-        return rnd.nextDouble() - 0.5;
+        return RandomUtils.nextDouble() - 0.5;
       }
 
       @Override
@@ -138,6 +139,10 @@
     }
   }
 
+  /**
+   * Set the previous weight update matrices.
+   * @param prevUpdates the previous weight update matrices
+   */
   void setPrevWeightMatrices(DoubleMatrix[] prevUpdates) {
     this.prevWeightUpdatesList.clear();
     for (DoubleMatrix prevUpdate : prevUpdates) {
@@ -176,8 +181,8 @@
    */
   public void setWeightMatrices(DoubleMatrix[] matrices) {
     this.weightMatrixList = new ArrayList();
-    for (int i = 0; i < matrices.length; ++i) {
-      this.weightMatrixList.add(matrices[i]);
+    for (DoubleMatrix matrix : matrices) {
+      this.weightMatrixList.add(matrix);
     }
   }
 
@@ -197,8 +202,9 @@
 
   public void setWeightMatrix(int index, DoubleMatrix matrix) {
     Preconditions.checkArgument(
-        0 <= index && index < this.weightMatrixList.size(),
-        String.format("index [%d] out of range.", index));
+        0 <= index && index < this.weightMatrixList.size(), String.format(
+            "index [%d] should be in range [%d, %d).", index, 0,
+            this.weightMatrixList.size()));
     this.weightMatrixList.set(index, matrix);
   }
 
@@ -208,7 +214,7 @@
     // read squash functions
     int squashingFunctionSize = input.readInt();
-    this.squashingFunctionList = new ArrayList();
+    this.squashingFunctionList = Lists.newArrayList();
     for (int i = 0; i < squashingFunctionSize; ++i) {
       this.squashingFunctionList.add(FunctionFactory
           .createDoubleFunction(WritableUtils.readString(input)));
     }
@@ -216,8 +222,8 @@
 
     // read weights and construct matrices of previous updates
     int numOfMatrices = input.readInt();
-    this.weightMatrixList = new ArrayList();
-    this.prevWeightUpdatesList = new ArrayList();
+    this.weightMatrixList = Lists.newArrayList();
+    this.prevWeightUpdatesList = Lists.newArrayList();
     for (int i = 0; i < numOfMatrices; ++i) {
       DoubleMatrix matrix = MatrixWritable.read(input);
       this.weightMatrixList.add(matrix);
@@ -257,8 +263,8 @@
    */
   @Override
   public DoubleVector getOutput(DoubleVector instance) {
-    Preconditions.checkArgument(this.layerSizeList.get(0) == instance
-        .getDimension() + 1, String.format(
+    Preconditions.checkArgument(this.layerSizeList.get(0) - 1 == instance
+        .getDimension(), String.format(
         "The dimension of input instance should be %d.",
         this.layerSizeList.get(0) - 1));
     // transform the features to another space
@@ -336,8 +342,6 @@
   public DoubleMatrix[] trainByInstance(DoubleVector trainingInstance) {
     DoubleVector transformedVector = this.featureTransformer
         .transform(trainingInstance.sliceUnsafe(this.layerSizeList.get(0) - 1));
-
-    int inputDimension = this.layerSizeList.get(0) - 1;
     int outputDimension;
@@ -389,11 +393,12 @@
     calculateTrainingError(labels,
         output.deepCopy().sliceUnsafe(1, output.getDimension() - 1));
 
-    if (this.trainingMethod.equals(TrainingMethod.GRADIATE_DESCENT)) {
+    if (this.trainingMethod.equals(TrainingMethod.GRADIENT_DESCENT)) {
       return this.trainByInstanceGradientDescent(labels, internalResults);
+    } else {
+      throw new IllegalArgumentException(
+          "Training method is not supported.");
     }
-    throw new IllegalArgumentException(
-        String.format("Training method is not supported."));
   }
 
@@ -483,9 +488,6 @@
           * squashingFunction.applyDerivative(curLayerOutput.get(i)));
     }
 
-    // System.out.printf("Delta layer: %d, %s\n", curLayerIdx,
-    // delta.toString());
-
     // update weights
     for (int i = 0; i < weightUpdateMatrix.getRowCount(); ++i) {
       for (int j = 0; j < weightUpdateMatrix.getColumnCount(); ++j) {
@@ -495,9 +497,6 @@
       }
     }
 
-    // System.out.printf("Weight Layer %d, %s\n", curLayerIdx,
-    // weightUpdateMatrix.toString());
-
     return delta;
   }
 
@@ -556,9 +555,7 @@
   protected void calculateTrainingError(DoubleVector labels,
       DoubleVector output) {
     DoubleVector errors = labels.deepCopy().applyToElements(output,
         this.costFunction);
-    // System.out.printf("Labels: %s\tOutput: %s\n", labels, output);
     this.trainingError = errors.sum();
-    // System.out.printf("Training error: %s\n", errors);
   }
 
   /**
Index: ml/src/main/java/org/apache/hama/ml/ann/SmallLayeredNeuralNetworkMessage.java
===================================================================
--- ml/src/main/java/org/apache/hama/ml/ann/SmallLayeredNeuralNetworkMessage.java	(revision 1549613)
+++ ml/src/main/java/org/apache/hama/ml/ann/SmallLayeredNeuralNetworkMessage.java	(working copy)
@@ -78,12 +78,12 @@
     } else {
       output.writeBoolean(true);
     }
-    for (int i = 0; i < curMatrices.length; ++i) {
-      MatrixWritable.write(curMatrices[i], output);
+    for (DoubleMatrix matrix : curMatrices) {
+      MatrixWritable.write(matrix, output);
     }
     if (prevMatrices != null) {
-      for (int i = 0; i < prevMatrices.length; ++i) {
-        MatrixWritable.write(prevMatrices[i], output);
+      for (DoubleMatrix matrix : prevMatrices) {
+        MatrixWritable.write(matrix, output);
       }
     }
   }
Index: ml/src/test/java/org/apache/hama/ml/ann/TestSmallLayeredNeuralNetwork.java
===================================================================
--- ml/src/test/java/org/apache/hama/ml/ann/TestSmallLayeredNeuralNetwork.java	(revision 1549613)
+++ ml/src/test/java/org/apache/hama/ml/ann/TestSmallLayeredNeuralNetwork.java	(working copy)
@@ -103,7 +103,7 @@
     assertEquals(momentumWeight, annCopy.getMomemtumWeight(), 0.000001);
     assertEquals(regularizationWeight, annCopy.getRegularizationWeight(),
        0.000001);
-    assertEquals(TrainingMethod.GRADIATE_DESCENT, annCopy.getTrainingMethod());
+    assertEquals(TrainingMethod.GRADIENT_DESCENT, annCopy.getTrainingMethod());
    assertEquals(LearningStyle.UNSUPERVISED, annCopy.getLearningStyle());
 
     // compare weights
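
Note on the stream handling in NeuralNetwork.java: readFromModel() and
writeModelToFile() now close the HDFS stream via Guava's
Closeables.close(stream, false) inside a finally block, so the stream is
released even when readFields()/write() throws. A minimal, self-contained
sketch of the same pattern; the local FileOutputStream and the file name
"model.bin" are illustrative stand-ins for fs.create(new Path(this.modelPath),
true), not code from this patch:

  import java.io.FileOutputStream;
  import java.io.IOException;
  import java.io.OutputStream;

  import com.google.common.io.Closeables;

  public class CloseablesSketch {
    public static void main(String[] args) throws IOException {
      OutputStream stream = null;
      try {
        stream = new FileOutputStream("model.bin"); // cf. fs.create(...)
        stream.write(new byte[] { 1, 2, 3 });       // cf. this.write(stream)
      } finally {
        // false = do not swallow an IOException thrown by close();
        // a null stream is ignored, so this is safe if create() failed
        Closeables.close(stream, false);
      }
    }
  }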
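
For context, a hypothetical end-to-end use of the API touched above. The layer
sizes, the "Sigmoid"/"SquaredError" function names, and the XOR-style instance
are illustrative assumptions, not code from this patch:

  import org.apache.hama.commons.math.DenseDoubleVector;
  import org.apache.hama.commons.math.DoubleMatrix;
  import org.apache.hama.commons.math.DoubleVector;
  import org.apache.hama.commons.math.FunctionFactory;
  import org.apache.hama.ml.ann.SmallLayeredNeuralNetwork;

  public class SmallLayeredNeuralNetworkSketch {
    public static void main(String[] args) {
      SmallLayeredNeuralNetwork ann = new SmallLayeredNeuralNetwork();
      ann.setLearningRate(0.1);
      ann.setCostFunction(FunctionFactory
          .createDoubleDoubleFunction("SquaredError"));
      // addLayer() adds one bias neuron to every non-final layer (size += 1)
      ann.addLayer(2, false, FunctionFactory.createDoubleFunction("Sigmoid"));
      ann.addLayer(4, false, FunctionFactory.createDoubleFunction("Sigmoid"));
      ann.addLayer(1, true, FunctionFactory.createDoubleFunction("Sigmoid"));

      // trainByInstance() takes the features followed by the labels and
      // returns the per-layer weight update matrices for this one instance
      DoubleVector instance = new DenseDoubleVector(new double[] { 0, 1, 1 });
      DoubleMatrix[] updates = ann.trainByInstance(instance);

      // after the getOutput() fix, the input dimension must be exactly
      // layerSizeList.get(0) - 1, i.e. the layer size minus the bias neuron
      DoubleVector output = ann.getOutput(
          new DenseDoubleVector(new double[] { 0, 1 }));
      System.out.println(updates.length + " update matrices, output " + output);
    }
  }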