Usage of edu.illinois.cs.cogcomp.depparse.core.LabeledChuLiuEdmondsDecoder in the cogcomp-nlp project by CogComp.
Class MainClass, method train:
/**
 * Trains a labeled dependency-parsing model and writes it to disk.
 *
 * @param trainFile      path of the training data, read via {@code getStructuredData}
 * @param configFilePath SL configuration file used to build the learner parameters
 * @param modelFile      destination path for the serialized model
 * @return the trained model (also saved to {@code modelFile})
 * @throws Exception if configuration loading, data reading, training or saving fails
 */
private static SLModel train(String trainFile, String configFilePath, String modelFile) throws Exception {
    SLModel trained = new SLModel();
    SLParameters slParams = new SLParameters();
    slParams.loadConfigFile(configFilePath);

    // Fresh lexicon; reserve a feature for unseen words while new features are still accepted.
    trained.lm = new Lexiconer(true);
    if (trained.lm.isAllowNewFeatures()) {
        trained.lm.addFeature("W:unknownword");
    }

    trained.featureGenerator = new LabeledDepFeatureGenerator(trained.lm);
    LabeledChuLiuEdmondsDecoder decoder = new LabeledChuLiuEdmondsDecoder(trained.featureGenerator);
    trained.infSolver = decoder;

    // Reading the data goes through the decoder; its dependency-relation dictionary is saved afterwards.
    SLProblem trainingData = getStructuredData(trainFile, decoder);
    decoder.saveDepRelDict();

    Learner learner = LearnerFactory.getLearner(trained.infSolver, trained.featureGenerator, slParams);
    learner.runWhenReportingProgress((weights, solver) -> printMemoryUsage());
    trained.wv = learner.train(trainingData);
    printMemoryUsage();

    // Disallow new features before persisting the model.
    trained.lm.setAllowNewFeatures(false);
    trained.saveModel(modelFile);
    return trained;
}
Usage of edu.illinois.cs.cogcomp.depparse.core.LabeledChuLiuEdmondsDecoder in the cogcomp-nlp project by CogComp.
Class DepAnnotator, method initialize:
/**
 * Loads the dependency-parsing model named by {@code DepConfigurator.MODEL_NAME} from the classpath.
 *
 * <p>The model resource is copied to a temporary file first, because SL cannot load a model from a
 * stream or from inside a jar. The temporary file is always removed afterwards; if deletion fails,
 * a warning is logged and deletion is retried at JVM exit.
 *
 * @param rm resource manager supplying the model name
 * @throws RuntimeException if the model cannot be located, copied or loaded; the underlying
 *     exception is attached as the cause
 */
@Override
public void initialize(ResourceManager rm) {
    // TODO Ugly hack: SL doesn't accept streams and can't create a file from inside a jar
    File dest = new File(TEMP_MODEL_FILE_NAME);
    String modelName = rm.getString(DepConfigurator.MODEL_NAME.key);
    try {
        URL fileURL = IOUtils.lsResources(DepAnnotator.class, modelName).get(0);
        logger.info("Loading {} into temp file: {}", modelName, TEMP_MODEL_FILE_NAME);
        FileUtils.copyURLToFile(fileURL, dest);
        model = SLModel.loadModel(TEMP_MODEL_FILE_NAME);
        ((LabeledChuLiuEdmondsDecoder) model.infSolver).loadDepRelDict();
    } catch (IOException | ClassNotFoundException | URISyntaxException e) {
        // Fail loudly with the real cause instead of leaving the annotator without a model.
        throw new RuntimeException("Could not load dependency parser model " + modelName, e);
    } finally {
        // Cleanup is best-effort: the file may never have been created (early failure), and a
        // leftover temp file must not mask the load result or the original exception.
        if (dest.exists() && !dest.delete()) {
            logger.warn("Could not delete temporary model file {}", TEMP_MODEL_FILE_NAME);
            dest.deleteOnExit();
        }
    }
}
Usage of edu.illinois.cs.cogcomp.depparse.core.LabeledChuLiuEdmondsDecoder in the cogcomp-nlp project by CogComp.
Class MainClass, method test:
/**
 * Evaluates a saved model on a test file and prints accuracy and parsing-time statistics.
 *
 * <p>For each sentence, {@code evaluate} is called three times: undirected, directed/unlabeled, and
 * labeled. The first component of each result is summed as the number correct. The second component
 * of the directed/unlabeled result is summed as the shared denominator. Parsing time covers only the
 * {@code getBestStructure} calls and is reported in milliseconds.
 *
 * @param modelPath    path of the serialized model
 * @param testDataPath path of the test data, read via {@code getStructuredData}
 * @param updateMatrix forwarded to the labeled evaluation; when true, the matrix is printed at the end
 * @throws Exception if the model or the test data cannot be loaded
 */
private static void test(String modelPath, String testDataPath, boolean updateMatrix) throws Exception {
    SLModel model = SLModel.loadModel(modelPath);
    LabeledChuLiuEdmondsDecoder decoder = (LabeledChuLiuEdmondsDecoder) model.infSolver;
    decoder.loadDepRelDict();
    SLProblem sp = getStructuredData(testDataPath, decoder);

    int numSentences = sp.size();
    if (numSentences == 0) {
        // Nothing to evaluate; the averages below would divide by zero.
        System.out.println("No test instances found in " + testDataPath);
        return;
    }

    double correctUndirected = 0.0;
    double correctDirectedUnlabeled = 0.0;
    double correctLabeled = 0.0;
    double total = 0.0;
    long totalNanos = 0L;
    long totalLength = 0L;
    for (int i = 0; i < sp.instanceList.size(); i++) {
        DepInst sent = (DepInst) sp.instanceList.get(i);
        totalLength += sent.size();
        DepStruct gold = (DepStruct) sp.goldStructureList.get(i);

        // nanoTime is monotonic and fine-grained; currentTimeMillis mis-measures sub-millisecond parses.
        long start = System.nanoTime();
        DepStruct prediction = (DepStruct) model.infSolver.getBestStructure(model.wv, sent);
        totalNanos += System.nanoTime() - start;

        IntPair undirected = evaluate(sent, gold, prediction, false, false, false);
        IntPair directedUnlabeled = evaluate(sent, gold, prediction, true, false, false);
        IntPair labeled = evaluate(sent, gold, prediction, true, true, updateMatrix);
        correctUndirected += undirected.getFirst();
        correctDirectedUnlabeled += directedUnlabeled.getFirst();
        correctLabeled += labeled.getFirst();
        total += directedUnlabeled.getSecond();
    }

    long totalTime = totalNanos / 1_000_000L;
    double averageLength = (double) totalLength / numSentences;
    double averageTime = totalNanos / 1_000_000.0 / numSentences;
    System.out.println("Parsing time taken for " + numSentences + " sentences with average length " + averageLength + ": " + totalTime);
    System.out.println("Average parsing time " + averageTime);
    System.out.println("undirected acc " + correctUndirected);
    System.out.println("directed unlabeled acc " + correctDirectedUnlabeled);
    System.out.println("labeled acc " + correctLabeled);
    System.out.println("total " + total);
    System.out.println("%age correct undirected " + (correctUndirected / total));
    System.out.println("%age correct directed & unlabeled " + (correctDirectedUnlabeled / total));
    System.out.println("%age correct labeled " + (correctLabeled / total));
    if (updateMatrix) {
        printMatrix();
    }
    System.out.println("Done with testing!");
}
Aggregations