import java.util.ArrayList;
import java.io.IOException;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.ByteWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.SequenceFile.Writer;
import org.apache.hadoop.io.SequenceFile.CompressionType;


/**
 * Micro-benchmark that measures how fast several threads can append records
 * to separate {@link SequenceFile}s concurrently.
 *
 * <p>{@link #main(String[])} runs 30 trials; in each trial it benchmarks
 * 1 through 8 concurrent writer threads and, once the first six trials have
 * been discarded, prints one tab-separated line per measurement:
 * thread count, then average writes per second per thread.
 */
public class TestWriteConcurrency {
  /** Number of key/value pairs each writer thread appends per run. */
  static final int NUM_WRITES = 1000*1000;

  /** Directory under which each worker writes its file, named by thread index. */
  private static final String TEST_DIR = "/dev/shm/toddtest/";

  /**
   * Worker that creates one uncompressed {@link SequenceFile}, appends
   * {@link #NUM_WRITES} default-constructed {@link ByteWritable} pairs to it,
   * and records the achieved throughput.
   *
   * <p>The result fields are written by the worker and are only read by
   * {@link #doBench} after {@link Thread#join()} has returned, which is what
   * makes the unsynchronized reads safe.
   */
  public static class WriterThread extends Thread {

    private String testPath;
    /** Measured throughput; stays at -1 until a run completes successfully. */
    private double writesPerSecond;
    /** Exception that aborted the run, or null if none was caught. */
    private Exception failure;
    private FileSystem fs;
    private Configuration conf;

    /**
     * @param conf     configuration handed to the sequence-file writer
     * @param fs       file system the file is created on
     * @param testPath path of the file this thread writes
     */
    public WriterThread(Configuration conf, FileSystem fs, String testPath) {
      this.conf = conf;
      this.fs = fs;
      this.testPath = testPath;
      this.writesPerSecond = -1;
    }

    /**
     * Writes the file and stores the throughput. The timed region covers the
     * append loop and the final {@code close()}. Any exception is kept in
     * {@code failure} rather than thrown, so the caller can report it.
     */
    @Override
    public void run() {
      try {
        Writer writer = SequenceFile.createWriter(
          fs, conf, new Path(testPath),
          ByteWritable.class, ByteWritable.class, CompressionType.NONE);
        boolean closed = false;
        try {
          ByteWritable v = new ByteWritable();

          // nanoTime is monotonic, so wall-clock adjustments cannot produce a
          // negative or distorted elapsed time.
          long startNanos = System.nanoTime();
          for (int i = 0; i < NUM_WRITES; i++) {
            writer.append(v, v);
          }
          writer.close();
          closed = true;
          long elapsedNanos = System.nanoTime() - startNanos;

          writesPerSecond = NUM_WRITES / (elapsedNanos / 1e9);
        } finally {
          if (!closed) {
            // A write (or the close itself) failed: release the writer on a
            // best-effort basis without masking the original exception.
            try {
              writer.close();
            } catch (IOException ignored) {
              // The primary failure is already propagating.
            }
          }
        }
      } catch (Exception e) {
        failure = e;
      }
    }
  }

  /**
   * Runs {@code numThreads} writer threads concurrently and returns the mean
   * of their individual throughputs.
   *
   * @param conf       configuration passed to every writer
   * @param fs         file system the files are created on
   * @param numThreads number of concurrent writer threads
   * @return average writes per second per thread (0 when {@code numThreads} is 0)
   * @throws RuntimeException if any worker did not produce a valid result; the
   *         worker's exception, when there is one, is attached as the cause
   * @throws Exception if the calling thread is interrupted while joining
   */
  public static double doBench(Configuration conf, FileSystem fs, int numThreads) throws Exception {
    ArrayList<WriterThread> thrs = new ArrayList<WriterThread>(numThreads);
    for (int i = 0; i < numThreads; i++) {
      WriterThread t = new WriterThread(conf, fs, TEST_DIR + i);
      t.start();
      thrs.add(t);
    }

    // Wait for every worker before looking at results, so a failure in one
    // thread does not leave the others running in the background.
    for (WriterThread t : thrs) {
      t.join();
    }

    double averageWritesPerSecond = 0;
    for (WriterThread t : thrs) {
      if (t.failure != null || t.writesPerSecond < 0) {
        throw new RuntimeException("Bad writes per second!", t.failure);
      }

      averageWritesPerSecond += t.writesPerSecond / (double)numThreads;
    }

    return averageWritesPerSecond;
  }

  /**
   * Entry point: runs 30 trials over 1..8 threads and prints the results of
   * every trial after the first six.
   */
  public static void main(String[] args) throws Exception {
    final Configuration c = new Configuration();
    final FileSystem fs = FileSystem.get(c);

    for (int trial = 0; trial < 30; trial++) {
      for (int num_thr=1; num_thr <= 8; num_thr++) {
        double writesPerSecond = doBench(c, fs, num_thr);
        // Results of trials 0-5 are computed but not printed.
        if (trial > 5) {
          System.out.printf("%d\t%.1f\n", num_thr, writesPerSecond);
        }
      }
    }
  }
}
