Search in sources :

Example 1 with ClustererDoc

use of edu.stanford.nlp.coref.statistical.ClustererDataLoader.ClustererDoc in project CoreNLP by stanfordnlp.

In the class Clusterer, the method doTraining:

/**
 * Trains the clustering model for RETRAIN_ITERATIONS iterations, periodically
 * evaluating it on the training documents and checkpointing the best, latest,
 * and every-10th-iteration weights under
 * {@code StatisticalCorefTrainer.clusteringModelsPath + modelName}.
 *
 * @param modelName name of the model; used as the output subdirectory name
 * @throws RuntimeException if setup (config/progress files, data loading) or
 *         model serialization fails; the underlying cause is preserved
 */
public void doTraining(String modelName) {
    // Hand-tuned initial weights for the pairwise and singleton feature groups.
    classifier.setWeight("bias", -0.3);
    classifier.setWeight("anaphorSeen", -1);
    classifier.setWeight("max-ranking", 1);
    classifier.setWeight("bias-single", -0.3);
    classifier.setWeight("anaphorSeen-single", -1);
    classifier.setWeight("max-ranking-single", 1);
    String outputPath = StatisticalCorefTrainer.clusteringModelsPath + modelName + "/";
    File outDir = new File(outputPath);
    if (!outDir.exists()) {
        outDir.mkdir();
    }
    PrintWriter progressWriter;
    List<ClustererDoc> trainDocs;
    try {
        // Record the configuration used for this run; try-with-resources closes
        // the writer even if fieldValues(...) throws.
        try (PrintWriter configWriter = new PrintWriter(outputPath + "config", "UTF-8")) {
            configWriter.print(StatisticalCorefTrainer.fieldValues(this));
        }
        progressWriter = new PrintWriter(outputPath + "progress", "UTF-8");
        Redwood.log("scoref.train", "Loading training data");
        StatisticalCorefTrainer.setDataPath("dev");
        trainDocs = ClustererDataLoader.loadDocuments(MAX_DOCS);
    } catch (Exception e) {
        throw new RuntimeException("Error setting up training", e);
    }
    try {
        double bestTrainScore = 0;
        // Rolling buffer of example lists gathered by running the current policy.
        List<List<Pair<CandidateAction, CandidateAction>>> examples = new ArrayList<>();
        for (int iteration = 0; iteration < RETRAIN_ITERATIONS; iteration++) {
            Redwood.log("scoref.train", "ITERATION " + iteration);
            classifier.printWeightVector(null);
            Redwood.log("scoref.train", "");
            try {
                classifier.writeWeights(outputPath + "model");
                classifier.printWeightVector(IOUtils.getPrintWriter(outputPath + "weights"));
            } catch (Exception e) {
                // Preserve the cause instead of throwing a bare RuntimeException.
                throw new RuntimeException("Error writing model to " + outputPath, e);
            }
            long start = System.currentTimeMillis();
            Collections.shuffle(trainDocs, random);
            // Keep only the most recent BUFFER_SIZE_MULTIPLIER * |trainDocs| example lists.
            examples = examples.subList(Math.max(0, examples.size() - BUFFER_SIZE_MULTIPLIER * trainDocs.size()), examples.size());
            trainPolicy(examples);
            if (iteration % EVAL_FREQUENCY == 0) {
                double trainScore = evaluatePolicy(trainDocs, true);
                if (trainScore > bestTrainScore) {
                    bestTrainScore = trainScore;
                    writeModel("best", outputPath);
                }
                if (iteration % 10 == 0) {
                    writeModel("iter_" + iteration, outputPath);
                }
                writeModel("last", outputPath);
                double timeElapsed = (System.currentTimeMillis() - start) / 1000.0;
                // Hit rates for the cost ("ff"), score, and feature caches.
                double ffhr = State.ffHits / (double) (State.ffHits + State.ffMisses);
                double shr = State.sHits / (double) (State.sHits + State.sMisses);
                double fhr = featuresCacheHits / (double) (featuresCacheHits + featuresCacheMisses);
                Redwood.log("scoref.train", modelName);
                Redwood.log("scoref.train", String.format("Best train: %.4f", bestTrainScore));
                Redwood.log("scoref.train", String.format("Time elapsed: %.2f", timeElapsed));
                Redwood.log("scoref.train", String.format("Cost hit rate: %.4f", ffhr));
                Redwood.log("scoref.train", String.format("Score hit rate: %.4f", shr));
                Redwood.log("scoref.train", String.format("Features hit rate: %.4f", fhr));
                Redwood.log("scoref.train", "");
                // NOTE(review): the doubled " " below is kept to preserve the original
                // progress-file format (it looks like an empty slot for a validation
                // score) -- confirm before changing.
                progressWriter.write(iteration + " " + trainScore + " " + " " + timeElapsed + " " + ffhr + " " + shr + " " + fhr + "\n");
                progressWriter.flush();
            }
            // Gather fresh examples by running the policy with a decayed expert rate.
            for (ClustererDoc trainDoc : trainDocs) {
                examples.add(runPolicy(trainDoc, Math.pow(EXPERT_DECAY, (iteration + 1))));
            }
        }
    } finally {
        // Close the progress log even if an iteration throws.
        progressWriter.close();
    }
}
Also used : ClustererDoc(edu.stanford.nlp.coref.statistical.ClustererDataLoader.ClustererDoc) ArrayList(java.util.ArrayList) ArrayList(java.util.ArrayList) List(java.util.List) File(java.io.File) PrintWriter(java.io.PrintWriter)

Example 2 with ClustererDoc

use of edu.stanford.nlp.coref.statistical.ClustererDataLoader.ClustererDoc in project CoreNLP by stanfordnlp.

In the class Clusterer, the method evaluatePolicy:

/**
 * Runs the current greedy clustering policy over the given documents and
 * returns its B-cubed F1 score.
 *
 * @param docs documents to cluster and score
 * @param training whether these are training documents (affects logging only)
 * @return the B-cubed F1 of the resulting clusterings
 */
private double evaluatePolicy(List<ClustererDoc> docs, boolean training) {
    // Switch into evaluation mode while the policy runs.
    isTraining = 0;
    B3Evaluator b3 = new B3Evaluator();
    for (ClustererDoc clustererDoc : docs) {
        State state = new State(clustererDoc);
        // Greedily apply the best-scoring action until the document is fully clustered.
        while (!state.isComplete()) {
            state.doBestAction(classifier);
        }
        state.updateEvaluator(b3);
    }
    // Restore training mode.
    isTraining = 1;
    double f1 = b3.getF1();
    String dataset = training ? "train" : "validate";
    Redwood.log("scoref.train", String.format("B3 F1 score on %s: %.4f", dataset, f1));
    return f1;
}
Also used : ClustererDoc(edu.stanford.nlp.coref.statistical.ClustererDataLoader.ClustererDoc) B3Evaluator(edu.stanford.nlp.coref.statistical.EvalUtils.B3Evaluator)

Example 3 with ClustererDoc

use of edu.stanford.nlp.coref.statistical.ClustererDataLoader.ClustererDoc in project CoreNLP by stanfordnlp.

In the class ClusteringCorefAlgorithm, the method runCoref:

/**
 * Runs statistical clustering-based coreference on a document: extracts
 * pairwise examples, scores them with the classification, ranking, and
 * anaphoricity models, and merges the clusters chosen by the clusterer.
 *
 * @param document the document to annotate with coreference clusters
 */
@Override
public void runCoref(Document document) {
    Map<Pair<Integer, Integer>, Boolean> pairs = CorefUtils.getUnlabeledMentionPairs(document);
    // No candidate mention pairs: nothing to cluster.
    if (pairs.size() == 0) {
        return;
    }
    Compressor<String> featureCompressor = new Compressor<>();
    DocumentExamples extracted = extractor.extract(0, document, pairs, featureCompressor);
    // Per-pair scores from the pairwise models, plus per-mention anaphoricity scores.
    Counter<Pair<Integer, Integer>> classificationScores = new ClassicCounter<>();
    Counter<Pair<Integer, Integer>> rankingScores = new ClassicCounter<>();
    Counter<Integer> anaphoricityScores = new ClassicCounter<>();
    for (Example ex : extracted.examples) {
        CorefUtils.checkForInterrupt();
        Pair<Integer, Integer> pair = new Pair<>(ex.mentionId1, ex.mentionId2);
        classificationScores.incrementCount(pair, classificationModel.predict(ex, extracted.mentionFeatures, featureCompressor));
        rankingScores.incrementCount(pair, rankingModel.predict(ex, extracted.mentionFeatures, featureCompressor));
        // Score each anaphor mention at most once.
        if (!anaphoricityScores.containsKey(ex.mentionId2)) {
            anaphoricityScores.incrementCount(ex.mentionId2, anaphoricityModel.predict(new Example(ex, false), extracted.mentionFeatures, featureCompressor));
        }
    }
    // Mention id -> mention type name, needed by the clusterer's features.
    Map<Integer, String> mentionTypes = document.predictedMentionsByID.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().mentionType.toString()));
    ClustererDoc clustererDoc = new ClustererDoc(0, classificationScores, rankingScores, anaphoricityScores, pairs, null, mentionTypes);
    for (Pair<Integer, Integer> merge : clusterer.getClusterMerges(clustererDoc)) {
        CorefUtils.mergeCoreferenceClusters(merge, document);
    }
}
Also used : ClustererDoc(edu.stanford.nlp.coref.statistical.ClustererDataLoader.ClustererDoc) ClassicCounter(edu.stanford.nlp.stats.ClassicCounter) Pair(edu.stanford.nlp.util.Pair)

Aggregations

ClustererDoc (edu.stanford.nlp.coref.statistical.ClustererDataLoader.ClustererDoc)3 B3Evaluator (edu.stanford.nlp.coref.statistical.EvalUtils.B3Evaluator)1 ClassicCounter (edu.stanford.nlp.stats.ClassicCounter)1 Pair (edu.stanford.nlp.util.Pair)1 File (java.io.File)1 PrintWriter (java.io.PrintWriter)1 ArrayList (java.util.ArrayList)1 List (java.util.List)1