Use of edu.stanford.nlp.coref.statistical.ClustererDataLoader.ClustererDoc in project CoreNLP by stanfordnlp.
The Clusterer class, method doTraining:
public void doTraining(String modelName) {
  // Initialize the linear model with hand-set starting weights.
  classifier.setWeight("bias", -0.3);
  classifier.setWeight("anaphorSeen", -1);
  classifier.setWeight("max-ranking", 1);
  classifier.setWeight("bias-single", -0.3);
  classifier.setWeight("anaphorSeen-single", -1);
  classifier.setWeight("max-ranking-single", 1);

  String outputPath = StatisticalCorefTrainer.clusteringModelsPath + modelName + "/";
  File outDir = new File(outputPath);
  if (!outDir.exists()) {
    outDir.mkdir();
  }

  PrintWriter progressWriter;
  List<ClustererDoc> trainDocs;
  try {
    // Record the training configuration, open the progress log, and load the data.
    PrintWriter configWriter = new PrintWriter(outputPath + "config", "UTF-8");
    configWriter.print(StatisticalCorefTrainer.fieldValues(this));
    configWriter.close();
    progressWriter = new PrintWriter(outputPath + "progress", "UTF-8");

    Redwood.log("scoref.train", "Loading training data");
    StatisticalCorefTrainer.setDataPath("dev");
    trainDocs = ClustererDataLoader.loadDocuments(MAX_DOCS);
  } catch (Exception e) {
    throw new RuntimeException("Error setting up training", e);
  }

  double bestTrainScore = 0;
  List<List<Pair<CandidateAction, CandidateAction>>> examples = new ArrayList<>();
  for (int iteration = 0; iteration < RETRAIN_ITERATIONS; iteration++) {
    Redwood.log("scoref.train", "ITERATION " + iteration);
    classifier.printWeightVector(null);
    Redwood.log("scoref.train", "");
    try {
      classifier.writeWeights(outputPath + "model");
      classifier.printWeightVector(IOUtils.getPrintWriter(outputPath + "weights"));
    } catch (Exception e) {
      throw new RuntimeException(e);
    }

    long start = System.currentTimeMillis();
    Collections.shuffle(trainDocs, random);
    // Keep only the most recent example lists before retraining the policy.
    examples = examples.subList(
        Math.max(0, examples.size() - BUFFER_SIZE_MULTIPLIER * trainDocs.size()),
        examples.size());
    trainPolicy(examples);

    if (iteration % EVAL_FREQUENCY == 0) {
      double trainScore = evaluatePolicy(trainDocs, true);
      if (trainScore > bestTrainScore) {
        bestTrainScore = trainScore;
        writeModel("best", outputPath);
      }
      if (iteration % 10 == 0) {
        writeModel("iter_" + iteration, outputPath);
      }
      writeModel("last", outputPath);

      double timeElapsed = (System.currentTimeMillis() - start) / 1000.0;
      double ffhr = State.ffHits / (double) (State.ffHits + State.ffMisses);
      double shr = State.sHits / (double) (State.sHits + State.sMisses);
      double fhr = featuresCacheHits / (double) (featuresCacheHits + featuresCacheMisses);
      Redwood.log("scoref.train", modelName);
      Redwood.log("scoref.train", String.format("Best train: %.4f", bestTrainScore));
      Redwood.log("scoref.train", String.format("Time elapsed: %.2f", timeElapsed));
      Redwood.log("scoref.train", String.format("Cost hit rate: %.4f", ffhr));
      Redwood.log("scoref.train", String.format("Score hit rate: %.4f", shr));
      Redwood.log("scoref.train", String.format("Features hit rate: %.4f", fhr));
      Redwood.log("scoref.train", "");
      progressWriter.write(iteration + " " + trainScore + " " + " " + timeElapsed + " "
          + ffhr + " " + shr + " " + fhr + "\n");
      progressWriter.flush();
    }

    // Gather fresh training examples by running the policy on every document,
    // with the expert-mixing rate decaying as EXPERT_DECAY^(iteration + 1).
    for (ClustererDoc trainDoc : trainDocs) {
      examples.add(runPolicy(trainDoc, Math.pow(EXPERT_DECAY, (iteration + 1))));
    }
  }
  progressWriter.close();
}
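The subList call above acts as a rolling buffer: before each retraining pass, only the most recent BUFFER_SIZE_MULTIPLIER * trainDocs.size() per-document example lists are kept. The following self-contained sketch illustrates that truncation pattern in isolation; the class name, constant value, and sample data are hypothetical stand-ins rather than CoreNLP API, and the demo copies the sublist for simplicity where the original keeps a view.

import java.util.ArrayList;
import java.util.List;

// Standalone illustration of the rolling-buffer truncation used in doTraining.
public class RollingBufferDemo {
  private static final int BUFFER_SIZE_MULTIPLIER = 2;  // hypothetical value

  public static void main(String[] args) {
    int numDocs = 3;
    List<String> examples = new ArrayList<>();
    for (int iteration = 0; iteration < 5; iteration++) {
      // Keep only the newest BUFFER_SIZE_MULTIPLIER * numDocs entries,
      // mirroring how doTraining trims its example buffer each iteration.
      examples = new ArrayList<>(examples.subList(
          Math.max(0, examples.size() - BUFFER_SIZE_MULTIPLIER * numDocs),
          examples.size()));
      for (int d = 0; d < numDocs; d++) {
        examples.add("iter" + iteration + "-doc" + d);
      }
      System.out.println("after iteration " + iteration + ": " + examples);
    }
  }
}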
Use of edu.stanford.nlp.coref.statistical.ClustererDataLoader.ClustererDoc in project CoreNLP by stanfordnlp.
The Clusterer class, method evaluatePolicy:
private double evaluatePolicy(List<ClustererDoc> docs, boolean training) {
  isTraining = 0;
  B3Evaluator evaluator = new B3Evaluator();
  for (ClustererDoc doc : docs) {
    State currentState = new State(doc);
    while (!currentState.isComplete()) {
      currentState.doBestAction(classifier);
    }
    currentState.updateEvaluator(evaluator);
  }
  isTraining = 1;
  double score = evaluator.getF1();
  Redwood.log("scoref.train", String.format("B3 F1 score on %s: %.4f",
      training ? "train" : "validate", score));
  return score;
}
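evaluatePolicy rolls each document forward by repeatedly taking the classifier's highest-scoring action until no decisions remain, then scores the finished clustering. The sketch below shows that greedy-rollout loop against a minimal hypothetical state interface; GreedyState and its methods are illustrative stand-ins, not the CoreNLP State or B3Evaluator API.

import java.util.List;

// Hypothetical interface capturing the shape of the greedy rollout in evaluatePolicy.
interface GreedyState {
  boolean isComplete();   // have all clustering decisions been made?
  void doBestAction();    // apply the current highest-scoring action
  double score();         // evaluation metric for the finished state
}

class GreedyRollout {
  // Roll every state to completion greedily and average the resulting scores.
  static double evaluate(List<GreedyState> states) {
    double total = 0;
    for (GreedyState state : states) {
      while (!state.isComplete()) {
        state.doBestAction();
      }
      total += state.score();
    }
    return states.isEmpty() ? 0 : total / states.size();
  }
}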
Use of edu.stanford.nlp.coref.statistical.ClustererDataLoader.ClustererDoc in project CoreNLP by stanfordnlp.
The ClusteringCorefAlgorithm class, method runCoref:
@Override
public void runCoref(Document document) {
  Map<Pair<Integer, Integer>, Boolean> mentionPairs = CorefUtils.getUnlabeledMentionPairs(document);
  if (mentionPairs.size() == 0) {
    return;
  }
  Compressor<String> compressor = new Compressor<>();
  DocumentExamples examples = extractor.extract(0, document, mentionPairs, compressor);

  Counter<Pair<Integer, Integer>> classificationScores = new ClassicCounter<>();
  Counter<Pair<Integer, Integer>> rankingScores = new ClassicCounter<>();
  Counter<Integer> anaphoricityScores = new ClassicCounter<>();
  for (Example example : examples.examples) {
    CorefUtils.checkForInterrupt();
    Pair<Integer, Integer> mentionPair = new Pair<>(example.mentionId1, example.mentionId2);
    classificationScores.incrementCount(mentionPair,
        classificationModel.predict(example, examples.mentionFeatures, compressor));
    rankingScores.incrementCount(mentionPair,
        rankingModel.predict(example, examples.mentionFeatures, compressor));
    if (!anaphoricityScores.containsKey(example.mentionId2)) {
      anaphoricityScores.incrementCount(example.mentionId2,
          anaphoricityModel.predict(new Example(example, false), examples.mentionFeatures, compressor));
    }
  }

  ClustererDoc doc = new ClustererDoc(0, classificationScores, rankingScores, anaphoricityScores,
      mentionPairs, null,
      document.predictedMentionsByID.entrySet().stream()
          .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().mentionType.toString())));
  for (Pair<Integer, Integer> mentionPair : clusterer.getClusterMerges(doc)) {
    CorefUtils.mergeCoreferenceClusters(mentionPair, document);
  }
}
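The ClustererDoc constructed above bundles the pairwise classification, ranking, and anaphoricity scores together with a map from mention id to mention-type string, built with a stream collector over document.predictedMentionsByID. The snippet below isolates that Collectors.toMap idiom; the Mention class and sample data here are made-up stand-ins for illustration, not CoreNLP types.

import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;

// Standalone illustration of the Collectors.toMap idiom used in runCoref to
// build the mention-id -> mention-type map passed to the ClustererDoc constructor.
public class MentionTypeMapDemo {
  enum MentionType { PRONOMINAL, NOMINAL, PROPER }

  static class Mention {
    final MentionType mentionType;
    Mention(MentionType mentionType) { this.mentionType = mentionType; }
  }

  public static void main(String[] args) {
    Map<Integer, Mention> predictedMentionsByID = new HashMap<>();
    predictedMentionsByID.put(3, new Mention(MentionType.PRONOMINAL));
    predictedMentionsByID.put(7, new Mention(MentionType.PROPER));

    // Same shape as the call in runCoref:
    // entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey,
    //     e -> e.getValue().mentionType.toString()))
    Map<Integer, String> mentionTypes = predictedMentionsByID.entrySet().stream()
        .collect(Collectors.toMap(Map.Entry::getKey,
            e -> e.getValue().mentionType.toString()));

    System.out.println(mentionTypes); // e.g. {3=PRONOMINAL, 7=PROPER}
  }
}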