Index: ml/src/test/java/org/apache/hama/ml/kmeans/TestKMeansBSP.java =================================================================== --- ml/src/test/java/org/apache/hama/ml/kmeans/TestKMeansBSP.java (revision 1555701) +++ ml/src/test/java/org/apache/hama/ml/kmeans/TestKMeansBSP.java (working copy) @@ -18,70 +18,126 @@ package org.apache.hama.ml.kmeans; import java.io.BufferedWriter; +import java.io.IOException; import java.io.OutputStreamWriter; import java.util.HashMap; import junit.framework.TestCase; import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.fs.FSDataOutputStream; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; +import org.apache.hama.HamaConfiguration; import org.apache.hama.bsp.BSPJob; import org.apache.hama.commons.math.DoubleVector; public class TestKMeansBSP extends TestCase { + public static final String TMP_OUTPUT = "/tmp/clustering/"; public void testRunJob() throws Exception { - Configuration conf = new Configuration(); - Path in = new Path("/tmp/clustering/in/in.txt"); - Path out = new Path("/tmp/clustering/out/"); + Configuration conf = new HamaConfiguration(); FileSystem fs = FileSystem.get(conf); - Path center = null; + if (fs.exists(new Path(TMP_OUTPUT))) { + fs.delete(new Path(TMP_OUTPUT), true); + } - try { - center = new Path(in.getParent(), "center/cen.seq"); + // test for bspTaskNum 1 to 4 + for (int i = 1; i < 5; i++) { + try { + test(conf, fs, i); + } finally { + fs.delete(new Path(TMP_OUTPUT), true); + } + } + } - Path centerOut = new Path(out, "center/center_output.seq"); - conf.set(KMeansBSP.CENTER_IN_PATH, center.toString()); - conf.set(KMeansBSP.CENTER_OUT_PATH, centerOut.toString()); - int iterations = 10; - conf.setInt(KMeansBSP.MAX_ITERATIONS_KEY, iterations); - int k = 1; + /** + * Test + * + * Create 101 input vectors of dimension two + * + * Input vectors: (0,0) (1,1) (2,2) ... (100,100) + * + * k = 1, maxIterations = 10 + * + * Resulting center should be (50,50) + */ + private void test(Configuration conf, FileSystem fs, int numBspTask) + throws IOException, InterruptedException, ClassNotFoundException { - FSDataOutputStream create = fs.create(in); - BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(create)); - StringBuilder sb = new StringBuilder(); + Path in = new Path(TMP_OUTPUT + "in"); + Path out = new Path(TMP_OUTPUT + "out"); + Path centerIn = new Path(TMP_OUTPUT + "center/center_input.seq"); + Path centerOut = new Path(TMP_OUTPUT + "center/center_output.seq"); + conf.set(KMeansBSP.CENTER_IN_PATH, centerIn.toString()); + conf.set(KMeansBSP.CENTER_OUT_PATH, centerOut.toString()); - for (int i = 0; i < 100; i++) { - sb.append(i); - sb.append('\t'); - sb.append(i); - sb.append('\n'); - } + int k = 1; + int iterations = 10; + conf.setInt(KMeansBSP.MAX_ITERATIONS_KEY, iterations); - bw.write(sb.toString()); - bw.close(); + in = generateInputText(k, conf, fs, in, centerIn, out, numBspTask); - in = KMeansBSP.prepareInputText(k, conf, in, center, out, fs, false); + BSPJob job = KMeansBSP.createJob(conf, in, out, true); + job.setNumBspTask(numBspTask); - BSPJob job = KMeansBSP.createJob(conf, in, out, true); + // just submit the job + boolean result = job.waitForCompletion(true); - // just submit the job - boolean result = job.waitForCompletion(true); + assertEquals(true, result); - assertEquals(true, result); + HashMap centerMap = KMeansBSP.readClusterCenters( + conf, out, centerOut, fs); + System.out.println(centerMap); - HashMap centerMap = KMeansBSP.readClusterCenters( - conf, out, centerOut, fs); - System.out.println(centerMap); - assertEquals(1, centerMap.size()); - DoubleVector doubleVector = centerMap.get(0); - assertTrue(doubleVector.get(0) >= 50 && doubleVector.get(0) < 51); - assertTrue(doubleVector.get(1) >= 50 && doubleVector.get(1) < 51); - } finally { - fs.delete(new Path("/tmp/clustering"), true); + assertEquals(1, centerMap.size()); // because k = 1 + + DoubleVector doubleVector = centerMap.get(0); + assertEquals(Double.valueOf(50), doubleVector.get(0)); + assertEquals(Double.valueOf(50), doubleVector.get(1)); + } + + private Path generateInputText(int k, Configuration conf, FileSystem fs, + Path in, Path centerIn, Path out, int numBspTask) throws IOException { + int totalNumberOfPoints = 100; + int interval = totalNumberOfPoints / numBspTask; + Path parts = new Path(in, "parts"); + + for (int part = 0; part < numBspTask; part++) { + Path partIn = new Path(parts, "part" + part + "/input.txt"); + BufferedWriter bw = new BufferedWriter(new OutputStreamWriter( + fs.create(partIn))); + + int start = interval * part; + int end = start + interval - 1; + if ((numBspTask - 1) == part) { + end = totalNumberOfPoints; + } + System.out + .println("Partition " + part + ": from " + start + " to " + end); + + for (int i = start; i <= end; i++) { + bw.append(i + "\t" + i + "\n"); + } + bw.flush(); + bw.close(); + + // Convert input text to sequence file + Path seqFile = null; + if (part == 0) { + seqFile = KMeansBSP.prepareInputText(k, conf, partIn, centerIn, out, + fs, false); + } else { + seqFile = KMeansBSP.prepareInputText(0, conf, partIn, new Path(centerIn + + "_empty.seq"), out, fs, false); + } + + fs.moveFromLocalFile(seqFile, new Path(parts, "part" + part + ".seq")); + fs.delete(seqFile.getParent(), true); + fs.delete(partIn.getParent(), true); } + + return parts; } } Index: ml/src/main/java/org/apache/hama/ml/kmeans/KMeansBSP.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/kmeans/KMeansBSP.java (revision 1555701) +++ ml/src/main/java/org/apache/hama/ml/kmeans/KMeansBSP.java (working copy) @@ -112,7 +112,8 @@ try { distanceMeasurer = ReflectionUtils.newInstance(distanceClass); } catch (ClassNotFoundException e) { - throw new RuntimeException("Wrong DistanceMeasurer implementation " + distanceClass + " provided"); + throw new RuntimeException("Wrong DistanceMeasurer implementation " + + distanceClass + " provided"); } } else { distanceMeasurer = new EuclidianDistance(); @@ -244,8 +245,8 @@ // add the vector to the center newCenterArray[lowestDistantCenter] = newCenterArray[lowestDistantCenter] .addUnsafe(key); - summationCount[lowestDistantCenter]++; } + summationCount[lowestDistantCenter]++; } private int getNearestCenter(DoubleVector key) { @@ -514,7 +515,7 @@ fs.delete(out, true); if (fs.exists(center)) - fs.delete(out, true); + fs.delete(center, true); if (fs.exists(in)) fs.delete(in, true); Index: examples/src/main/java/org/apache/hama/examples/Kmeans.java =================================================================== --- examples/src/main/java/org/apache/hama/examples/Kmeans.java (revision 1555701) +++ examples/src/main/java/org/apache/hama/examples/Kmeans.java (working copy) @@ -67,10 +67,11 @@ Path out = new Path(args[1]); FileSystem fs = FileSystem.get(conf); Path center = null; - if (fs.isFile(in)) + if (fs.isFile(in)) { center = new Path(in.getParent(), "center/cen.seq"); - else + } else { center = new Path(in, "center/cen.seq"); + } Path centerOut = new Path(out, "center/center_output.seq"); conf.set(KMeansBSP.CENTER_IN_PATH, center.toString()); conf.set(KMeansBSP.CENTER_OUT_PATH, centerOut.toString()); @@ -84,12 +85,18 @@ int dimension = Integer.parseInt(args[6]); System.out.println("N: " + count + " Dimension: " + dimension + " Iterations: " + iterations); + if (!fs.isFile(in)) { + in = new Path(in, "input.seq"); + } // prepare the input, like deleting old versions and creating centers KMeansBSP.prepareInput(count, k, dimension, conf, in, center, out, fs); } else { + if (!fs.isFile(in)) { + System.out.println("Cannot read text input file: " + in.toString()); + return; + } // Set the last argument to TRUE if first column is required to be the key - KMeansBSP.prepareInputText(k, conf, in, center, out, fs, true); - in = new Path(in.getParent(), "textinput/in.seq"); + in = KMeansBSP.prepareInputText(k, conf, in, center, out, fs, true); } BSPJob job = KMeansBSP.createJob(conf, in, out, true); Index: core/src/main/java/org/apache/hama/pipes/util/SequenceFileDumper.java =================================================================== --- core/src/main/java/org/apache/hama/pipes/util/SequenceFileDumper.java (revision 1555701) +++ core/src/main/java/org/apache/hama/pipes/util/SequenceFileDumper.java (working copy) @@ -32,6 +32,7 @@ import org.apache.commons.logging.LogFactory; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.NullWritable; import org.apache.hadoop.io.SequenceFile; import org.apache.hadoop.io.Writable; import org.apache.hama.HamaConfiguration; @@ -123,8 +124,18 @@ sub = Integer.parseInt(cmdLine.getOptionValue("substring")); } - Writable key = (Writable) reader.getKeyClass().newInstance(); - Writable value = (Writable) reader.getValueClass().newInstance(); + Writable key; + if (reader.getKeyClass() != NullWritable.class) { + key = (Writable) reader.getKeyClass().newInstance(); + } else { + key = NullWritable.get(); + } + Writable value; + if (reader.getValueClass() != NullWritable.class) { + value = (Writable) reader.getValueClass().newInstance(); + } else { + value = NullWritable.get(); + } writer.append("Key class: ") .append(String.valueOf(reader.getKeyClass()))