Use of edu.umd.hooka.alignment.hmm.ATable in project Cloud9 by lintool.
The class HadoopAlign, method doAlignment:
@SuppressWarnings("deprecation")
public static void doAlignment(int mapTasks, int reduceTasks, HadoopAlignConfig hac) throws IOException {
System.out.println("Running alignment: " + hac);
FileSystem fs = FileSystem.get(hac);
Path cbtxt = new Path(hac.getRoot() + "/comp-bitext");
// fs.delete(cbtxt, true);
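// Preprocess/numberize the bitext into comp-bitext once; later runs reuse the existing compiled corpus.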
if (!fs.exists(cbtxt)) {
CorpusVocabNormalizerAndNumberizer.preprocessAndNumberizeFiles(hac, hac.getBitexts(), cbtxt);
}
System.out.println("Finished preprocessing");
int m1iters = hac.getModel1Iterations();
int hmmiters = hac.getHMMIterations();
int totalIterations = m1iters + hmmiters;
String modelType = null;
ArrayList<Double> perps = new ArrayList<Double>();
ArrayList<Double> aers = new ArrayList<Double>();
boolean hmm = false;
boolean firstHmm = true;
Path model1PosteriorsPath = null;
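// EM training loop: the first m1iters iterations train IBM Model 1; the remaining hmmiters iterations train the HMM alignment model.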
for (int iteration = 0; iteration < totalIterations; iteration++) {
long start = System.currentTimeMillis();
hac.setBoolean("ha.generate.posteriors", false);
boolean lastIteration = (iteration == totalIterations - 1);
boolean lastModel1Iteration = (iteration == m1iters - 1);
if (iteration >= m1iters)
hmm = true;
if (hmm)
modelType = "HMM";
else
modelType = "Model1";
FileSystem fileSys = FileSystem.get(hac);
String sOutputPath = modelType + ".data." + iteration;
Path outputPath = new Path(sOutputPath);
try {
// no probs in first iteration!
if (usePServer && iteration > 0)
startPServers(hac);
System.out.println("Starting iteration " + iteration + (iteration == 0 ? " (initialization)" : "") + ": " + modelType);
JobConf conf = new JobConf(hac, HadoopAlign.class);
conf.setJobName("EMTrain." + modelType + ".iter" + iteration);
conf.setInputFormat(SequenceFileInputFormat.class);
conf.set(KEY_TRAINER, MODEL1_TRAINER);
conf.set(KEY_ITERATION, Integer.toString(iteration));
conf.set("mapred.child.java.opts", "-Xmx2048m");
if (iteration == 0)
conf.set(KEY_TRAINER, MODEL1_UNIFORM_INIT);
if (hmm) {
conf.set(KEY_TRAINER, HMM_TRAINER);
if (firstHmm) {
firstHmm = false;
System.out.println("Writing default a-table...");
Path pathATable = hac.getATablePath();
fileSys.delete(pathATable, true);
DataOutputStream dos = new DataOutputStream(new BufferedOutputStream(fileSys.create(pathATable)));
int cond_values = 1;
if (!hac.isHMMHomogeneous()) {
cond_values = 100;
}
ATable at = new ATable(hac.isHMMHomogeneous(), cond_values, 100);
at.normalize();
at.write(dos);
// System.out.println(at);
dos.close();
}
}
conf.setOutputKeyClass(IntWritable.class);
conf.setOutputValueClass(PartialCountContainer.class);
conf.setMapperClass(EMapper.class);
conf.setReducerClass(EMReducer.class);
conf.setNumMapTasks(mapTasks);
conf.setNumReduceTasks(reduceTasks);
System.out.println("Running job " + conf.getJobName());
// otherwise, input is set to output of last model 1 iteration
if (model1PosteriorsPath != null) {
System.out.println("Input: " + model1PosteriorsPath);
FileInputFormat.setInputPaths(conf, model1PosteriorsPath);
} else {
System.out.println("Input: " + cbtxt);
FileInputFormat.setInputPaths(conf, cbtxt);
}
System.out.println("Output: " + outputPath);
FileOutputFormat.setOutputPath(conf, new Path(hac.getRoot() + "/" + outputPath.toString()));
fileSys.delete(new Path(hac.getRoot() + "/" + outputPath.toString()), true);
conf.setOutputFormat(SequenceFileOutputFormat.class);
RunningJob job = JobClient.runJob(conf);
Counters c = job.getCounters();
double lp = c.getCounter(CrossEntropyCounters.LOGPROB);
double wc = c.getCounter(CrossEntropyCounters.WORDCOUNT);
double ce = lp / wc / Math.log(2);
double perp = Math.pow(2.0, ce);
double aer = ComputeAER(c);
System.out.println("Iteration " + iteration + ": (" + modelType + ")\tCROSS-ENTROPY: " + ce + " PERPLEXITY: " + perp);
System.out.println("Iteration " + iteration + ": " + aer + " AER");
aers.add(aer);
perps.add(perp);
} finally {
stopPServers();
}
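// Follow-up job for this iteration: a single map task (ModelMergeMapper2, no reducers) merges the partial counts written under TTABLE_ITERATION_OUTPUT back into the model tables read by the next iteration.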
JobConf conf = new JobConf(hac, ModelMergeMapper2.class);
System.err.println("Setting " + TTABLE_ITERATION_OUTPUT + " to " + outputPath.toString());
conf.set(TTABLE_ITERATION_OUTPUT, hac.getRoot() + "/" + outputPath.toString());
conf.setJobName("EMTrain.ModelMerge");
// conf.setOutputKeyClass(LongWritable.class);
conf.setMapperClass(ModelMergeMapper2.class);
conf.setSpeculativeExecution(false);
conf.setNumMapTasks(1);
conf.setNumReduceTasks(0);
conf.setInputFormat(NullInputFormat.class);
conf.setOutputFormat(NullOutputFormat.class);
conf.set("mapred.map.child.java.opts", "-Xmx2048m");
conf.set("mapred.reduce.child.java.opts", "-Xmx2048m");
// FileInputFormat.setInputPaths(conf, root+"/dummy");
// fileSys.delete(new Path(root+"/dummy.out"), true);
// FileOutputFormat.setOutputPath(conf, new Path(root+"/dummy.out"));
// conf.setOutputFormat(SequenceFileOutputFormat.class);
System.out.println("Running job " + conf.getJobName());
System.out.println("Input: " + hac.getRoot() + "/dummy");
System.out.println("Output: " + hac.getRoot() + "/dummy.out");
JobClient.runJob(conf);
fileSys.delete(new Path(hac.getRoot() + "/" + outputPath.toString()), true);
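// On the last Model 1 iteration and on the final iteration overall, run an extra alignment job (AlignMapper). The last Model 1 output is kept as model1PosteriorsPath and becomes the input for the subsequent HMM iterations.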
if (lastIteration || lastModel1Iteration) {
//hac.setBoolean("ha.generate.posteriors", true);
conf = new JobConf(hac, HadoopAlign.class);
sOutputPath = modelType + ".data." + iteration;
outputPath = new Path(sOutputPath);
conf.setJobName(modelType + ".align");
conf.set("mapred.map.child.java.opts", "-Xmx2048m");
conf.set("mapred.reduce.child.java.opts", "-Xmx2048m");
// TODO use file cache
/*try {
if (hmm || iteration > 0) {
URI ttable = new URI(fileSys.getHomeDirectory() + Path.SEPARATOR + hac.getTTablePath().toString());
DistributedCache.addCacheFile(ttable, conf);
System.out.println("cache<-- " + ttable);
}
} catch (Exception e) { throw new RuntimeException("Caught " + e); }
*/
conf.setInputFormat(SequenceFileInputFormat.class);
conf.setOutputFormat(SequenceFileOutputFormat.class);
conf.set(KEY_TRAINER, MODEL1_TRAINER);
conf.set(KEY_ITERATION, Integer.toString(iteration));
if (hmm)
conf.set(KEY_TRAINER, HMM_TRAINER);
conf.setOutputKeyClass(Text.class);
conf.setOutputValueClass(PhrasePair.class);
conf.setMapperClass(AlignMapper.class);
conf.setReducerClass(IdentityReducer.class);
conf.setNumMapTasks(mapTasks);
conf.setNumReduceTasks(reduceTasks);
FileOutputFormat.setOutputPath(conf, new Path(hac.getRoot() + "/" + outputPath.toString()));
//if last model1 iteration, save output path, to be used as input path in later iterations
if (lastModel1Iteration) {
FileInputFormat.setInputPaths(conf, cbtxt);
model1PosteriorsPath = new Path(hac.getRoot() + "/" + outputPath.toString());
} else {
FileInputFormat.setInputPaths(conf, model1PosteriorsPath);
}
fileSys.delete(new Path(hac.getRoot() + "/" + outputPath.toString()), true);
System.out.println("Running job " + conf.getJobName());
RunningJob job = JobClient.runJob(conf);
System.out.println("GENERATED: " + model1PosteriorsPath);
Counters c = job.getCounters();
double aer = ComputeAER(c);
// System.out.println("Iteration " + iteration + ": (" + modelType + ")\tCROSS-ENTROPY: " + ce + " PERPLEXITY: " + perp);
System.out.println("Iteration " + iteration + ": " + aer + " AER");
aers.add(aer);
perps.add(0.0);
}
long end = System.currentTimeMillis();
System.out.println(modelType + " iteration " + iteration + " took " + ((end - start) / 1000) + " seconds.");
}
for (int i = 0; i < perps.size(); i++) {
System.out.print("I=" + i + "\t");
if (aers.size() > 0) {
System.out.print(aers.get(i) + "\t");
}
System.out.println(perps.get(i));
}
}
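The per-iteration figures printed above come from two aggregate counters: the summed log-probability (in nats) and the word count. A minimal sketch of that conversion as a standalone helper (the method name is illustrative, not part of Cloud9):

// Illustrative helper mirroring the arithmetic above: convert the LOGPROB and
// WORDCOUNT counters into per-word cross-entropy (bits) and perplexity.
static double perplexity(double sumLogProb, double wordCount) {
  double crossEntropyBits = sumLogProb / wordCount / Math.log(2); // nats per word -> bits per word
  return Math.pow(2.0, crossEntropyBits);
}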
Use of edu.umd.hooka.alignment.hmm.ATable in project Cloud9 by lintool.
The class HadoopAlign, method loadATable:
public static ATable loadATable(Path path, Configuration job) throws IOException {
Configuration conf = new Configuration(job);
FileSystem fileSys = FileSystem.get(conf);
DataInputStream in = new DataInputStream(new BufferedInputStream(fileSys.open(path)));
ATable at = new ATable();
at.readFields(in);
in.close();
return at;
}
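A minimal usage sketch for this loader (the path below is hypothetical; in doAlignment the table is written to hac.getATablePath() and read back from there):

// Hypothetical caller: load the HMM jump table once per task, e.g. in a
// mapper's configure()/setup(), before alignment.
Configuration conf = new Configuration();
Path aTablePath = new Path("/user/alignments/hmm.atable"); // hypothetical location
ATable at = HadoopAlign.loadATable(aTablePath, conf);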
Use of edu.umd.hooka.alignment.hmm.ATable in project Cloud9 by lintool.
The class PartialCountContainer, method readFields:
public void readFields(DataInput in) throws IOException {
type = in.readByte();
if (type == CONTENT_ATABLE) {
content = new ATable();
} else if (type == CONTENT_ARRAY) {
content = new IndexedFloatArray();
} else {
throw new RuntimeException("Bad content type!");
}
content.readFields(in);
}
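PartialCountContainer is the output value class of the EM jobs in doAlignment, so it follows the standard Hadoop Writable pattern: a leading type byte selects whether the payload is an ATable or an IndexedFloatArray. A minimal deserialization sketch, assuming a surrounding method that declares IOException (the file name is hypothetical):

// Read back a container that was written with the matching Writable write().
DataInputStream in = new DataInputStream(new FileInputStream("partial-counts.bin")); // hypothetical file
PartialCountContainer pcc = new PartialCountContainer();
pcc.readFields(in); // restores either an ATable or an IndexedFloatArray
in.close();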