Index: CHANGES.txt
===================================================================
--- CHANGES.txt (revision 1561388)
+++ CHANGES.txt (working copy)
@@ -24,6 +24,7 @@
IMPROVEMENTS
+ HAMA-859: Leverage commons cli2 to parse the input argument for NeuralNetwork Example (Yexi Jiang)
HAMA-853: Refactor Outgoing message manager (edwardyoon)
HAMA-852: Add MessageClass property in BSPJob (Martin Illecker)
HAMA-843: Message communication overhead between master aggregation and vertex computation supersteps (edwardyoon)
Index: core/pom.xml
===================================================================
--- core/pom.xml (revision 1561388)
+++ core/pom.xml (working copy)
@@ -51,10 +51,6 @@
<artifactId>commons-logging</artifactId>
</dependency>
<dependency>
- <groupId>commons-cli</groupId>
- <artifactId>commons-cli</artifactId>
- </dependency>
- <dependency>
<groupId>commons-configuration</groupId>
<artifactId>commons-configuration</artifactId>
@@ -135,6 +131,10 @@
<groupId>org.apache.zookeeper</groupId>
<artifactId>zookeeper</artifactId>
</dependency>
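+ <!-- commons-cli2 ships as a commons-cli artifact republished under the
+ Mahout groupId; its version is managed in the parent pom -->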
+ <dependency>
+ <groupId>org.apache.mahout.commons</groupId>
+ <artifactId>commons-cli</artifactId>
+ </dependency>
Index: examples/src/main/java/org/apache/hama/examples/NeuralNetwork.java
===================================================================
--- examples/src/main/java/org/apache/hama/examples/NeuralNetwork.java (revision 1561388)
+++ examples/src/main/java/org/apache/hama/examples/NeuralNetwork.java (working copy)
@@ -23,194 +23,288 @@
import java.io.OutputStreamWriter;
import java.net.URI;
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.cli2.util.HelpFormatter;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hama.HamaConfiguration;
import org.apache.hama.commons.math.DenseDoubleVector;
import org.apache.hama.commons.math.DoubleVector;
import org.apache.hama.commons.math.FunctionFactory;
+import org.apache.hama.examples.util.ParserUtil;
import org.apache.hama.ml.ann.SmallLayeredNeuralNetwork;
+import com.google.common.io.Closeables;
+
/**
* The example of using {@link SmallLayeredNeuralNetwork}, including the
* training phase and labeling phase.
*/
public class NeuralNetwork {
+ // either train or label
+ private static String mode;
- public static void main(String[] args) throws Exception {
- if (args.length < 3) {
- printUsage();
- return;
- }
- String mode = args[0];
- if (mode.equalsIgnoreCase("label")) {
- if (args.length < 4) {
- printUsage();
- return;
- }
- HamaConfiguration conf = new HamaConfiguration();
+ // arguments for labeling
+ private static String featureDataPath;
+ private static String resultDataPath;
+ private static String modelPath;
- String featureDataPath = args[1];
- String resultDataPath = args[2];
- String modelPath = args[3];
+ // arguments for training
+ private static String trainingDataPath;
+ private static int featureDimension;
+ private static int labelDimension;
+ private static List<Integer> hiddenLayerDimension;
+ private static int iterations;
+ private static double learningRate;
+ private static double momentumWeight;
+ private static double regularizationWeight;
- SmallLayeredNeuralNetwork ann = new SmallLayeredNeuralNetwork(modelPath);
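+ /**
+ * Parses the command line arguments via commons-cli2. A sketch of the two
+ * expected invocations (paths and dimensions are placeholders):
+ *
+ * neuralnets -train -tp TRAINING_DATA -mp MODEL -fd 8 -ld 1
+ * neuralnets -label -fp UNLABELED_DATA -rp RESULT -mp MODEL
+ *
+ * @return true if the arguments were parsed successfully
+ */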
+ public static boolean parseArgs(String[] args) {
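+ // cli2 splits the work across three builders: DefaultOptionBuilder
+ // defines the options, ArgumentBuilder the values they accept, and
+ // GroupBuilder the sets of options that belong together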
+ DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder();
+ GroupBuilder groupBuilder = new GroupBuilder();
+ ArgumentBuilder argumentBuilder = new ArgumentBuilder();
- // process data in streaming approach
- FileSystem fs = FileSystem.get(new URI(featureDataPath), conf);
- BufferedReader br = new BufferedReader(new InputStreamReader(
- fs.open(new Path(featureDataPath))));
- Path outputPath = new Path(resultDataPath);
- if (fs.exists(outputPath)) {
- fs.delete(outputPath, true);
- }
- BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(
- fs.create(outputPath)));
+ // the feature data (unlabeled data) path argument
+ Option featureDataPathOption = optionBuilder
+ .withLongName("feature-data-path")
+ .withShortName("fp")
+ .withDescription("the path of the feature data (unlabeled data).")
+ .withArgument(
+ argumentBuilder.withName("path").withMinimum(1).withMaximum(1)
+ .create()).withRequired(true).create();
- String line = null;
+ // the result data path argument
+ Option resultDataPathOption = optionBuilder
+ .withLongName("result-data-path")
+ .withShortName("rp")
+ .withDescription("the path to store the result.")
+ .withArgument(
+ argumentBuilder.withName("path").withMinimum(1).withMaximum(1)
+ .create()).withRequired(true).create();
- while ((line = br.readLine()) != null) {
- if (line.trim().length() == 0) {
- continue;
- }
- String[] tokens = line.trim().split(",");
- double[] vals = new double[tokens.length];
- for (int i = 0; i < tokens.length; ++i) {
- vals[i] = Double.parseDouble(tokens[i]);
- }
- DoubleVector instance = new DenseDoubleVector(vals);
- DoubleVector result = ann.getOutput(instance);
- double[] arrResult = result.toArray();
- StringBuilder sb = new StringBuilder();
- for (int i = 0; i < arrResult.length; ++i) {
- sb.append(arrResult[i]);
- if (i != arrResult.length - 1) {
- sb.append(",");
- } else {
- sb.append("\n");
- }
- }
- bw.write(sb.toString());
- }
+ // the path to store the model
+ Option modelPathOption = optionBuilder
+ .withLongName("model-data-path")
+ .withShortName("mp")
+ .withDescription("the path to store the trained model.")
+ .withArgument(
+ argumentBuilder.withName("path").withMinimum(1).withMaximum(1)
+ .create()).withRequired(true).create();
- br.close();
- bw.close();
- } else if (mode.equals("train")) {
- if (args.length < 5) {
- printUsage();
- return;
- }
+ // the path of the training data
+ Option trainingDataPathOption = optionBuilder
+ .withLongName("training-data-path")
+ .withShortName("tp")
+ .withDescription("the path to store the trained model.")
+ .withArgument(
+ argumentBuilder.withName("path").withMinimum(1).withMaximum(1)
+ .create()).withRequired(true).create();
- String trainingDataPath = args[1];
- String trainedModelPath = args[2];
+ // the dimension of the features
+ Option featureDimensionOption = optionBuilder
+ .withLongName("feature dimension")
+ .withShortName("fd")
+ .withDescription("the dimension of the features.")
+ .withArgument(
+ argumentBuilder.withName("dimension").withMinimum(1).withMaximum(1)
+ .create()).withRequired(true).create();
- int featureDimension = Integer.parseInt(args[3]);
- int labelDimension = Integer.parseInt(args[4]);
+ // the dimension of the hidden layers, at most two hidden layers
+ Option hiddenLayerOption = optionBuilder
+ .withLongName("hidden layer dimension(s)")
+ .withShortName("hd")
+ .withDescription("the dimension of the hidden layer(s).")
+ .withArgument(
+ argumentBuilder.withName("dimension").withMinimum(0).withMaximum(2)
+ .create()).withRequired(true).create();
- int iteration = 1000;
- double learningRate = 0.4;
- double momemtumWeight = 0.2;
- double regularizationWeight = 0.01;
+ // the dimension of the labels
+ Option labelDimensionOption = optionBuilder
+ .withLongName("label dimension")
+ .withShortName("ld")
+ .withDescription("the dimension of the label(s).")
+ .withArgument(
+ argumentBuilder.withName("dimension").withMinimum(1).withMaximum(1)
+ .create()).withRequired(true).create();
- // parse parameters
- if (args.length >= 6) {
- try {
- iteration = Integer.parseInt(args[5]);
- System.out.printf("Iteration: %d\n", iteration);
- } catch (NumberFormatException e) {
- System.err
- .println("MAX_ITERATION format invalid. It should be a positive number.");
- return;
- }
- }
- if (args.length >= 7) {
- try {
- learningRate = Double.parseDouble(args[6]);
- System.out.printf("Learning rate: %f\n", learningRate);
- } catch (NumberFormatException e) {
- System.err
- .println("LEARNING_RATE format invalid. It should be a positive double in range (0, 1.0)");
- return;
- }
- }
- if (args.length >= 8) {
- try {
- momemtumWeight = Double.parseDouble(args[7]);
- System.out.printf("Momemtum weight: %f\n", momemtumWeight);
- } catch (NumberFormatException e) {
- System.err
- .println("MOMEMTUM_WEIGHT format invalid. It should be a positive double in range (0, 1.0)");
- return;
- }
- }
- if (args.length >= 9) {
- try {
- regularizationWeight = Double.parseDouble(args[8]);
- System.out
- .printf("Regularization weight: %f\n", regularizationWeight);
- } catch (NumberFormatException e) {
- System.err
- .println("REGULARIZATION_WEIGHT format invalid. It should be a positive double in range (0, 1.0)");
- return;
- }
- }
+ // the number of iterations for training
+ Option iterationOption = optionBuilder
+ .withLongName("iterations")
+ .withShortName("itr")
+ .withDescription("the iterations for training.")
+ .withArgument(
+ argumentBuilder.withName("iterations").withMinimum(1)
+ .withMaximum(1).withDefault(1000).create()).create();
- // train the model
- SmallLayeredNeuralNetwork ann = new SmallLayeredNeuralNetwork();
- ann.setLearningRate(learningRate);
- ann.setMomemtumWeight(momemtumWeight);
- ann.setRegularizationWeight(regularizationWeight);
- ann.addLayer(featureDimension, false,
- FunctionFactory.createDoubleFunction("Sigmoid"));
- ann.addLayer(featureDimension, false,
- FunctionFactory.createDoubleFunction("Sigmoid"));
- ann.addLayer(labelDimension, true,
- FunctionFactory.createDoubleFunction("Sigmoid"));
- ann.setCostFunction(FunctionFactory
- .createDoubleDoubleFunction("CrossEntropy"));
- ann.setModelPath(trainedModelPath);
+ // the learning rate
+ Option learningRateOption = optionBuilder
+ .withLongName("learning-rate")
+ .withShortName("l")
+ .withDescription("the learning rate for training, default 0.1.")
+ .withArgument(
+ argumentBuilder.withName("learning-rate").withMinimum(1)
+ .withMaximum(1).withDefault(0.1).create()).create();
- Map<String, String> trainingParameters = new HashMap<String, String>();
- trainingParameters.put("tasks", "5");
- trainingParameters.put("training.max.iterations", "" + iteration);
- trainingParameters.put("training.batch.size", "300");
- trainingParameters.put("convergence.check.interval", "1000");
- ann.train(new Path(trainingDataPath), trainingParameters);
+ // the momentum weight
+ Option momentumWeightOption = optionBuilder
+ .withLongName("momentum-weight")
+ .withShortName("m")
+ .withDescription("the momentum weight for training, default 0.1.")
+ .withArgument(
+ argumentBuilder.withName("momentum weight").withMinimum(1)
+ .withMaximum(1).withDefault(0.1).create()).create();
+
+ // the regularization weight
+ Option regularizationWeightOption = optionBuilder
+ .withLongName("regularization-weight")
+ .withShortName("r")
+ .withDescription("the regularization weight for training, default 0.")
+ .withArgument(
+ argumentBuilder.withName("regularization weight").withMinimum(1)
+ .withMaximum(1).withDefault(0).create()).create();
+
+ // the parameters related to train mode
+ Group trainModeGroup = groupBuilder.withOption(trainingDataPathOption)
+ .withOption(modelPathOption).withOption(featureDimensionOption)
+ .withOption(labelDimensionOption).withOption(hiddenLayerOption)
+ .withOption(iterationOption).withOption(learningRateOption)
+ .withOption(momentumWeightOption)
+ .withOption(regularizationWeightOption).create();
+
+ // the parameters related to label mode
+ Group labelModeGroup = groupBuilder.withOption(modelPathOption)
+ .withOption(featureDataPathOption).withOption(resultDataPathOption)
+ .create();
+
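+ // -train and -label act as parent options: cli2 validates only the
+ // child group of the mode that is actually present on the command line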
+ Option trainModeOption = optionBuilder.withLongName("train")
+ .withShortName("train").withDescription("the train mode")
+ .withChildren(trainModeGroup).create();
+
+ Option labelModeOption = optionBuilder.withLongName("label")
+ .withShortName("label").withChildren(labelModeGroup)
+ .withDescription("the label mode").create();
+
+ Group normalGroup = groupBuilder.withOption(trainModeOption)
+ .withOption(labelModeOption).create();
+
+ Parser parser = new Parser();
+ parser.setGroup(normalGroup);
+ parser.setHelpFormatter(new HelpFormatter());
+ parser.setHelpTrigger("--help");
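+ // parseAndHelp prints the help screen and returns null when parsing
+ // fails or the help trigger is given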
+ CommandLine cli = parser.parseAndHelp(args);
+ if (cli == null) {
+ return false;
}
+ // get the arguments
+ boolean hasTrainMode = cli.hasOption(trainModeOption);
+ boolean hasLabelMode = cli.hasOption(labelModeOption);
+ if (hasTrainMode && hasLabelMode) {
+ return false;
+ }
+
+ mode = hasTrainMode ? "train" : "label";
+ if (mode.equals("train")) {
+ trainingDataPath = ParserUtil.getString(cli, trainingDataPathOption);
+ modelPath = ParserUtil.getString(cli, modelPathOption);
+ featureDimension = ParserUtil.getInteger(cli, featureDimensionOption);
+ labelDimension = ParserUtil.getInteger(cli, labelDimensionOption);
+ hiddenLayerDimension = ParserUtil.getInts(cli, hiddenLayerOption);
+ iterations = ParserUtil.getInteger(cli, iterationOption);
+ learningRate = ParserUtil.getDouble(cli, learningRateOption);
+ momentumWeight = ParserUtil.getDouble(cli, momentumWeightOption);
+ regularizationWeight = ParserUtil.getDouble(cli,
+ regularizationWeightOption);
+ } else {
+ featureDataPath = ParserUtil.getString(cli, featureDataPathOption);
+ modelPath = ParserUtil.getString(cli, modelPathOption);
+ resultDataPath = ParserUtil.getString(cli, resultDataPathOption);
+ }
+
+ return true;
}
- private static void printUsage() {
- System.out
- .println("USAGE: | [ ]");
- System.out
- .println("\tMODE\t- train: train the model with given training data.");
- System.out
- .println("\t\t- label: obtain the result by feeding the features to the neural network.");
- System.out
- .println("\tINPUT_PATH\tin 'train' mode, it is the path of the training data; in 'label' mode, it is the path of the to be evaluated data that lacks the label.");
- System.out
- .println("\tOUTPUT_PATH\tin 'train' mode, it is where the trained model is stored; in 'label' mode, it is where the labeled data is stored.");
- System.out.println("\n\tConditional Parameters:");
- System.out
- .println("\tMODEL_PATH\tonly required in 'label' mode. It specifies where to load the trained neural network model.");
- System.out
- .println("\tMAX_ITERATION\tonly used in 'train' mode. It specifies how many iterations for the neural network to run. Default is 0.01.");
- System.out
- .println("\tLEARNING_RATE\tonly used to 'train' mode. It specifies the degree of aggregation for learning, usually in range (0, 1.0). Default is 0.1.");
- System.out
- .println("\tMOMEMTUM_WEIGHT\tonly used to 'train' mode. It specifies the weight of momemtum. Default is 0.");
- System.out
- .println("\tREGULARIZATION_WEIGHT\tonly required in 'train' model. It specifies the weight of reqularization.");
- System.out.println("\nExample:");
- System.out
- .println("Train a neural network with with feature dimension 8, label dimension 1 and default setting:\n\tneuralnets train hdfs://localhost:30002/training_data hdfs://localhost:30002/model 8 1");
- System.out
- .println("Train a neural network with with feature dimension 8, label dimension 1 and specify learning rate as 0.1, momemtum rate as 0.2, and regularization weight as 0.01:\n\tneuralnets.train hdfs://localhost:30002/training_data hdfs://localhost:30002/model 8 1 0.1 0.2 0.01");
- System.out
- .println("Label the data with trained model:\n\tneuralnets evaluate hdfs://localhost:30002/unlabeled_data hdfs://localhost:30002/result hdfs://localhost:30002/model");
+ public static void main(String[] args) throws Exception {
+ if (parseArgs(args)) {
+ if (mode.equals("label")) {
+ HamaConfiguration conf = new HamaConfiguration();
+ SmallLayeredNeuralNetwork ann = new SmallLayeredNeuralNetwork(modelPath);
+
+ // process data in streaming approach
+ FileSystem fs = FileSystem.get(new URI(featureDataPath), conf);
+ BufferedReader br = new BufferedReader(new InputStreamReader(
+ fs.open(new Path(featureDataPath))));
+ Path outputPath = new Path(resultDataPath);
+ if (fs.exists(outputPath)) {
+ fs.delete(outputPath, true);
+ }
+ BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(
+ fs.create(outputPath)));
+
+ String line = null;
+
+ while ((line = br.readLine()) != null) {
+ if (line.trim().length() == 0) {
+ continue;
+ }
+ String[] tokens = line.trim().split(",");
+ double[] vals = new double[tokens.length];
+ for (int i = 0; i < tokens.length; ++i) {
+ vals[i] = Double.parseDouble(tokens[i]);
+ }
+ DoubleVector instance = new DenseDoubleVector(vals);
+ DoubleVector result = ann.getOutput(instance);
+ double[] arrResult = result.toArray();
+ StringBuilder sb = new StringBuilder();
+ for (int i = 0; i < arrResult.length; ++i) {
+ sb.append(arrResult[i]);
+ if (i != arrResult.length - 1) {
+ sb.append(",");
+ } else {
+ sb.append("\n");
+ }
+ }
+ bw.write(sb.toString());
+ }
+
+ Closeables.close(br, true);
+ Closeables.close(bw, true);
+ } else { // train the model
+ SmallLayeredNeuralNetwork ann = new SmallLayeredNeuralNetwork();
+ ann.setLearningRate(learningRate);
+ ann.setMomemtumWeight(momentumWeight);
+ ann.setRegularizationWeight(regularizationWeight);
+ ann.addLayer(featureDimension, false,
+ FunctionFactory.createDoubleFunction("Sigmoid"));
+ if (hiddenLayerDimension != null) {
+ for (int dimension : hiddenLayerDimension) {
+ ann.addLayer(dimension, false,
+ FunctionFactory.createDoubleFunction("Sigmoid"));
+ }
+ }
+ ann.addLayer(labelDimension, true,
+ FunctionFactory.createDoubleFunction("Sigmoid"));
+ ann.setCostFunction(FunctionFactory
+ .createDoubleDoubleFunction("CrossEntropy"));
+ ann.setModelPath(modelPath);
+
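+ // training is delegated to a Hama BSP job; the parameters are handed
+ // to the trainer as plain strings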
+ Map<String, String> trainingParameters = new HashMap<String, String>();
+ trainingParameters.put("tasks", "5");
+ trainingParameters.put("training.max.iterations", "" + iterations);
+ trainingParameters.put("training.batch.size", "300");
+ trainingParameters.put("convergence.check.interval", "1000");
+ ann.train(new Path(trainingDataPath), trainingParameters);
+ }
+ }
}
}
Index: examples/src/main/java/org/apache/hama/examples/util/ParserUtil.java
===================================================================
--- examples/src/main/java/org/apache/hama/examples/util/ParserUtil.java (revision 0)
+++ examples/src/main/java/org/apache/hama/examples/util/ParserUtil.java (working copy)
@@ -0,0 +1,132 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.examples.util;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Option;
+
+/**
+ * Facilitates command line argument parsing by converting the raw values
+ * of commons-cli2 options into typed Java values.
+ *
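+ * A minimal usage sketch (assuming a {@code Parser} configured as in
+ * {@code NeuralNetwork#parseArgs}):
+ *
+ * <pre>
+ * CommandLine cli = parser.parseAndHelp(args);
+ * Integer iterations = ParserUtil.getInteger(cli, iterationOption);
+ * </pre>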
+ */
+public class ParserUtil {
+
+ /**
+ * Parse and return the string parameter.
+ *
+ * @param cli the parsed command line
+ * @param option the option whose value is requested
+ * @return the value as a string, or null if the option is absent
+ */
+ public static String getString(CommandLine cli, Option option) {
+ Object val = cli.getValue(option);
+ if (val != null) {
+ return val.toString();
+ }
+ return null;
+ }
+
+ /**
+ * Parse and return the integer parameter.
+ *
+ * @param cli the parsed command line
+ * @param option the option whose value is requested
+ * @return the value as an Integer, or null if the option is absent
+ */
+ public static Integer getInteger(CommandLine cli, Option option) {
+ Object val = cli.getValue(option);
+ if (val != null) {
+ return Integer.parseInt(val.toString());
+ }
+ return null;
+ }
+
+ /**
+ * Parse and return the long parameter.
+ *
+ * @param cli the parsed command line
+ * @param option the option whose value is requested
+ * @return the value as a Long, or null if the option is absent
+ */
+ public static Long getLong(CommandLine cli, Option option) {
+ Object val = cli.getValue(option);
+ if (val != null) {
+ return Long.parseLong(val.toString());
+ }
+ return null;
+ }
+
+ /**
+ * Parse and return the double parameter.
+ *
+ * @param cli the parsed command line
+ * @param option the option whose value is requested
+ * @return the value as a Double, or null if the option is absent
+ */
+ public static Double getDouble(CommandLine cli, Option option) {
+ Object val = cli.getValue(option);
+ if (val != null) {
+ return Double.parseDouble(val.toString());
+ }
+ return null;
+ }
+
+ /**
+ * Parse and return the boolean parameter. If the parameter is set, it is
+ * true, otherwise it is false.
+ *
+ * @param cli the parsed command line
+ * @param option the option to test for
+ * @return true if the option is present on the command line
+ */
+ public static boolean getBoolean(CommandLine cli, Option option) {
+ return cli.hasOption(option);
+ }
+
+ /**
+ * Parse and return the values of a multi-valued option as strings.
+ * @param cli the parsed command line
+ * @param option the option whose values are requested
+ * @return the values as a list of strings
+ */
+ public static List<String> getStrings(CommandLine cli, Option option) {
+ List<String> list = new ArrayList<String>();
+ for (Object obj : cli.getValues(option)) {
+ list.add(obj.toString());
+ }
+ return list;
+ }
+
+ /**
+ * Parse and return the values of a multi-valued option as integers.
+ * @param cli the parsed command line
+ * @param option the option whose values are requested
+ * @return the values as a list of integers
+ */
+ public static List<Integer> getInts(CommandLine cli, Option option) {
+ List<Integer> list = new ArrayList<Integer>();
+ for (String str : getStrings(cli, option)) {
+ list.add(Integer.parseInt(str));
+ }
+ return list;
+ }
+}
+
Index: examples/src/test/java/org/apache/hama/examples/NeuralNetworkTest.java
===================================================================
--- examples/src/test/java/org/apache/hama/examples/NeuralNetworkTest.java (revision 1561388)
+++ examples/src/test/java/org/apache/hama/examples/NeuralNetworkTest.java (working copy)
@@ -23,8 +23,6 @@
import java.util.ArrayList;
import java.util.List;
-import junit.framework.TestCase;
-
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
@@ -33,32 +31,34 @@
import org.apache.hama.HamaConfiguration;
import org.apache.hama.commons.io.VectorWritable;
import org.apache.hama.commons.math.DenseDoubleVector;
+import org.junit.Before;
+import org.junit.Test;
/**
* Test the functionality of NeuralNetwork Example.
*
*/
-public class NeuralNetworkTest extends TestCase {
+public class NeuralNetworkTest {
private Configuration conf = new HamaConfiguration();
private FileSystem fs;
private String MODEL_PATH = "/tmp/neuralnets.model";
private String RESULT_PATH = "/tmp/neuralnets.txt";
private String SEQTRAIN_DATA = "/tmp/test-neuralnets.data";
-
- @Override
- protected void setUp() throws Exception {
- super.setUp();
+
+ @Before
+ public void setup() throws Exception {
fs = FileSystem.get(conf);
}
+ @Test
public void testNeuralnetsLabeling() throws IOException {
this.neuralNetworkTraining();
String dataPath = "src/test/resources/neuralnets_classification_test.txt";
- String mode = "label";
+ String mode = "-label";
try {
NeuralNetwork
- .main(new String[] { mode, dataPath, RESULT_PATH, MODEL_PATH });
+ .main(new String[] { mode, "-fp", dataPath, "-rp", RESULT_PATH, "-mp", MODEL_PATH });
// compare results with ground-truth
BufferedReader groundTruthReader = new BufferedReader(new FileReader(
@@ -98,7 +98,7 @@
}
private void neuralNetworkTraining() {
- String mode = "train";
+ String mode = "-train";
String strTrainingDataPath = "src/test/resources/neuralnets_classification_training.txt";
int featureDimension = 8;
int labelDimension = 1;
@@ -130,8 +130,9 @@
}
try {
- NeuralNetwork.main(new String[] { mode, SEQTRAIN_DATA,
- MODEL_PATH, "" + featureDimension, "" + labelDimension });
+ NeuralNetwork.main(new String[] { mode, "-tp", SEQTRAIN_DATA, "-mp",
+ MODEL_PATH, "-fd", "" + featureDimension, "-hd", "" + featureDimension,
+ "-ld", "" + labelDimension, "-itr", "3000", "-m", "0.2", "-l", "0.2" });
} catch (Exception e) {
e.printStackTrace();
}
Index: pom.xml
===================================================================
--- pom.xml (revision 1561388)
+++ pom.xml (working copy)
@@ -88,6 +88,7 @@
1.1.1
1.2
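+ <!-- version of the commons-cli2 build republished under org.apache.mahout.commons -->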
+ <commons-cli2.version>2.0-mahout</commons-cli2.version>
1.7
2.6
3.0.1
@@ -276,7 +277,12 @@
<artifactId>jackson-mapper-asl</artifactId>
<version>1.9.2</version>
- </dependency>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.mahout.commons</groupId>
+ <artifactId>commons-cli</artifactId>
+ <version>${commons-cli2.version}</version>
+ </dependency>