Example 11 with BatchTrainer

use of edu.illinois.cs.cogcomp.lbjava.learn.BatchTrainer in project cogcomp-nlp by CogComp.

the class ACERelationTester method test_ts_predicted.

public static void test_ts_predicted() {
    int total_correct = 0;
    int total_labeled = 0;
    int total_predicted = 0;
    int total_coarse_correct = 0;
    ACEMentionReader train_parser = IOHelper.serializeDataIn("relation-extraction/preprocess/reader/all");
    relation_classifier classifier = new relation_classifier();
    classifier.setLexiconLocation("models/relation_classifier_all");
    BatchTrainer trainer = new BatchTrainer(classifier, train_parser);
    Learner preExtractLearner = trainer.preExtract("models/relation_classifier_all", true, Lexicon.CountPolicy.none);
    preExtractLearner.saveLexicon();
    Lexicon lexicon = preExtractLearner.getLexicon();
    classifier.setLexicon(lexicon);
    int examples = train_parser.relations_bi.size();
    classifier.initialize(examples, preExtractLearner.getLexicon().size());
    for (Relation r : train_parser.relations_bi) {
        classifier.learn(r);
    }
    classifier.doneWithRound();
    classifier.doneLearning();
    ACERelationConstrainedClassifier constrainedClassifier = new ACERelationConstrainedClassifier(classifier);
    PredictedMentionReader predictedMentionReader = new PredictedMentionReader("data/partition_with_dev/dev");
    total_labeled = predictedMentionReader.size_of_gold_relations;
    for (Object o = predictedMentionReader.next(); o != null; o = predictedMentionReader.next()) {
        Relation r = (Relation) o;
        String gold_label = r.getAttribute("RelationSubtype");
        String predicted_label = constrainedClassifier.discreteValue(r);
        if (!predicted_label.equals("NOT_RELATED")) {
            total_predicted++;
        }
        if (!gold_label.equals("NOT_RELATED")) {
            if (gold_label.equals(predicted_label)) {
                total_correct++;
            }
            if (getCoarseType(gold_label).equals(getCoarseType(predicted_label))) {
                total_coarse_correct++;
            }
        }
    }
    System.out.println("Total labeled: " + total_labeled);
    System.out.println("Total predicted: " + total_predicted);
    System.out.println("Total correct: " + total_correct);
    System.out.println("Total coarse correct: " + total_coarse_correct);
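    // Guard each ratio so an empty prediction or gold set yields 0 instead of NaN.
    double p = total_predicted == 0 ? 0.0 : total_correct * 100.0 / total_predicted;
    double r = total_labeled == 0 ? 0.0 : total_correct * 100.0 / total_labeled;
    double f = (p + r) == 0.0 ? 0.0 : 2 * p * r / (p + r);
    System.out.println("Precision: " + p);
    System.out.println("Recall: " + r);
    System.out.println("Fine Type F1: " + f);
    // Coarse F1 is approximated by scaling the fine F1 by the coarse/fine correct ratio.
    double coarse_f = total_correct == 0 ? 0.0 : f * total_coarse_correct / total_correct;
    System.out.println("Coarse Type F1: " + coarse_f);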
}
Also used : Relation(edu.illinois.cs.cogcomp.core.datastructures.textannotation.Relation) BatchTrainer(edu.illinois.cs.cogcomp.lbjava.learn.BatchTrainer) Lexicon(edu.illinois.cs.cogcomp.lbjava.learn.Lexicon) LbjGen.relation_classifier(org.cogcomp.re.LbjGen.relation_classifier) Learner(edu.illinois.cs.cogcomp.lbjava.learn.Learner)
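
The metric arithmetic at the end of test_ts_predicted can be isolated from the LBJava calls. The following is a minimal sketch, not code from cogcomp-nlp: the class name PrfSketch and the helpers percent and f1 are hypothetical, and the counts in main are made up for illustration. It assumes a ratio with a zero denominator should be reported as 0.

```java
// Hypothetical helper, not part of cogcomp-nlp: zero-safe precision, recall and F1 in percent.
final class PrfSketch {

    /** Returns 100 * numerator / denominator, or 0.0 when the denominator is zero. */
    static double percent(int numerator, int denominator) {
        return denominator == 0 ? 0.0 : 100.0 * numerator / denominator;
    }

    /** Harmonic mean of precision and recall; 0.0 when both are zero. */
    static double f1(double precision, double recall) {
        double sum = precision + recall;
        return sum == 0.0 ? 0.0 : 2.0 * precision * recall / sum;
    }

    public static void main(String[] args) {
        // Made-up counts, standing in for total_correct, total_predicted and total_labeled.
        int correct = 30;
        int predicted = 40;
        int labeled = 60;
        double p = percent(correct, predicted);
        double r = percent(correct, labeled);
        System.out.println("Precision: " + p);
        System.out.println("Recall: " + r);
        System.out.println("F1: " + f1(p, r));
    }
}
```

Keeping the guards in one place means a run that predicts no relations at all prints zeros rather than NaN.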

Example 12 with BatchTrainer

use of edu.illinois.cs.cogcomp.lbjava.learn.BatchTrainer in project cogcomp-nlp by CogComp.

the class RelationExtractionTest method testSimpleTraining.

@Test
public void testSimpleTraining() {
    File modelDir = null;
    try {
        Datastore ds = new Datastore(new ResourceConfigurator().getDefaultConfig());
        modelDir = ds.getDirectory("org.cogcomp.re", "ACE_TEST_DOCS", 1.1, false);
    } catch (Exception e) {
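        // Without the test documents modelDir stays null and the next statement would throw a NullPointerException.
        throw new AssertionError("Could not fetch ACE_TEST_DOCS from the datastore", e);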
    }
    ACEMentionReader train_parser = new ACEMentionReader(modelDir.getPath() + File.separator + "ACE_TEST_DOCS", "relations_bi");
    relation_classifier classifier = new relation_classifier();
    classifier.setLexiconLocation("src/test/tmp.lex");
    BatchTrainer trainer = new BatchTrainer(classifier, train_parser);
    Learner preExtractLearner = trainer.preExtract("src/test/tmp.ex", true, Lexicon.CountPolicy.none);
    preExtractLearner.saveLexicon();
    Lexicon lexicon = preExtractLearner.getLexicon();
    classifier.setLexicon(lexicon);
    int examples = train_parser.relations_bi.size();
    classifier.initialize(examples, preExtractLearner.getLexicon().size());
    for (Relation r : train_parser.relations_bi) {
        classifier.learn(r);
    }
    classifier.doneWithRound();
    classifier.doneLearning();
    train_parser.reset();
    int correct = 0;
    for (Relation r : train_parser.relations_bi) {
        String tag = classifier.discreteValue(r);
        if (tag.equals(r.getAttribute("RelationSubtype"))) {
            correct++;
        }
    }
    assertTrue(correct > 0);
}
Also used : Relation(edu.illinois.cs.cogcomp.core.datastructures.textannotation.Relation) BatchTrainer(edu.illinois.cs.cogcomp.lbjava.learn.BatchTrainer) Datastore(org.cogcomp.Datastore) Lexicon(edu.illinois.cs.cogcomp.lbjava.learn.Lexicon) LbjGen.relation_classifier(org.cogcomp.re.LbjGen.relation_classifier) ResourceConfigurator(edu.illinois.cs.cogcomp.core.resources.ResourceConfigurator) File(java.io.File) ACEMentionReader(org.cogcomp.re.ACEMentionReader) Learner(edu.illinois.cs.cogcomp.lbjava.learn.Learner) Test(org.junit.Test)

Example 13 with BatchTrainer

use of edu.illinois.cs.cogcomp.lbjava.learn.BatchTrainer in project cogcomp-nlp by CogComp.

the class Quantifier method train.

public void train() {
    QuantitiesClassifier classifier = new QuantitiesClassifier(modelName + ".lc", modelName + ".lex");
    QuantitiesDataReader trainReader = new QuantitiesDataReader(dataDir + "/train.txt", "train");
    BatchTrainer trainer = new BatchTrainer(classifier, trainReader);
    trainer.train(45);
    classifier.save();
}
Also used : BatchTrainer(edu.illinois.cs.cogcomp.lbjava.learn.BatchTrainer)

Example 14 with BatchTrainer

use of edu.illinois.cs.cogcomp.lbjava.learn.BatchTrainer in project cogcomp-nlp by CogComp.

the class ClassifierComparison method printConstrainedClassifierPerformance.

public static void printConstrainedClassifierPerformance(Parser parser) {
    List<Pair<Classifier, EvaluateDiscrete>> classifiers = new ArrayList<>();
    LocalCommaClassifier learner = new LocalCommaClassifier();
    EvaluateDiscrete unconstrainedPerformance = new EvaluateDiscrete();
    learner.setLTU(new SparseAveragedPerceptron(0.003, 0, 3.5));
    classifiers.add(new Pair<Classifier, EvaluateDiscrete>(new SubstitutePairConstrainedCommaClassifier(), new EvaluateDiscrete()));
    classifiers.add(new Pair<Classifier, EvaluateDiscrete>(new LocativePairConstrainedCommaClassifier(), new EvaluateDiscrete()));
    classifiers.add(new Pair<Classifier, EvaluateDiscrete>(new ListCommasConstrainedCommaClassifier(), new EvaluateDiscrete()));
    classifiers.add(new Pair<Classifier, EvaluateDiscrete>(new OxfordCommaConstrainedCommaClassifier(), new EvaluateDiscrete()));
    int k = 5;
    parser.reset();
    FoldParser foldParser = new FoldParser(parser, k, SplitPolicy.sequential, 0, false);
    for (int i = 0; i < k; foldParser.setPivot(++i)) {
        foldParser.setFromPivot(false);
        foldParser.reset();
        learner.forget();
        BatchTrainer bt = new BatchTrainer(learner, foldParser);
        Lexicon lexicon = bt.preExtract(null);
        learner.setLexicon(lexicon);
        bt.train(250);
        learner.save();
        foldParser.setFromPivot(true);
        foldParser.reset();
        unconstrainedPerformance.reportAll(EvaluateDiscrete.evaluateDiscrete(learner, learner.getLabeler(), foldParser));
        for (Pair<Classifier, EvaluateDiscrete> pair : classifiers) {
            foldParser.reset();
            pair.getSecond().reportAll(EvaluateDiscrete.evaluateDiscrete(pair.getFirst(), learner.getLabeler(), foldParser));
        }
    }
    for (Pair<Classifier, EvaluateDiscrete> pair : classifiers) {
        System.out.println(pair.getFirst().name + " " + pair.getSecond().getOverallStats()[2]);
    }
}
Also used : ListCommasConstrainedCommaClassifier(edu.illinois.cs.cogcomp.comma.lbj.ListCommasConstrainedCommaClassifier) OxfordCommaConstrainedCommaClassifier(edu.illinois.cs.cogcomp.comma.lbj.OxfordCommaConstrainedCommaClassifier) LocativePairConstrainedCommaClassifier(edu.illinois.cs.cogcomp.comma.lbj.LocativePairConstrainedCommaClassifier) Lexicon(edu.illinois.cs.cogcomp.lbjava.learn.Lexicon) ArrayList(java.util.ArrayList) Classifier(edu.illinois.cs.cogcomp.lbjava.classify.Classifier) StructuredCommaClassifier(edu.illinois.cs.cogcomp.comma.sl.StructuredCommaClassifier) SubstitutePairConstrainedCommaClassifier(edu.illinois.cs.cogcomp.comma.lbj.SubstitutePairConstrainedCommaClassifier) LocalCommaClassifier(edu.illinois.cs.cogcomp.comma.lbj.LocalCommaClassifier) BatchTrainer(edu.illinois.cs.cogcomp.lbjava.learn.BatchTrainer) EvaluateDiscrete(edu.illinois.cs.cogcomp.comma.utils.EvaluateDiscrete) SparseAveragedPerceptron(edu.illinois.cs.cogcomp.lbjava.learn.SparseAveragedPerceptron) Pair(edu.illinois.cs.cogcomp.core.datastructures.Pair) FoldParser(edu.illinois.cs.cogcomp.lbjava.parse.FoldParser)

Example 15 with BatchTrainer

use of edu.illinois.cs.cogcomp.lbjava.learn.BatchTrainer in project cogcomp-nlp by CogComp.

the class ClassifierComparison method localCVal.

public static EvaluateDiscrete localCVal(boolean trainOnGold, boolean testOnGold, Parser parser, int learningRounds, double learningRate, double threshold, double thickness, boolean testOnTrain) {
    int k = 5;
    LocalCommaClassifier learner = new LocalCommaClassifier();
    learner.setLTU(new SparseAveragedPerceptron(learningRate, threshold, thickness));
    parser.reset();
    final FoldParser foldParser = new FoldParser(parser, k, SplitPolicy.sequential, 0, false);
    EvaluateDiscrete performanceRecord = new EvaluateDiscrete();
    for (int i = 0; i < k; foldParser.setPivot(++i)) {
        foldParser.setFromPivot(false);
        foldParser.reset();
        learner.forget();
        BatchTrainer bt = new BatchTrainer(learner, foldParser);
        Comma.useGoldFeatures(trainOnGold);
        Lexicon lexicon = bt.preExtract(null);
        learner.setLexicon(lexicon);
        bt.train(learningRounds);
        if (!testOnTrain)
            foldParser.setFromPivot(true);
        foldParser.reset();
        Comma.useGoldFeatures(testOnGold);
        EvaluateDiscrete currentPerformance = EvaluateDiscrete.evaluateDiscrete(learner, learner.getLabeler(), foldParser);
        performanceRecord.reportAll(currentPerformance);
    }
    // System.out.println(performanceRecord.getOverallStats()[2]);
    performanceRecord.printPerformance(System.out);
    // performanceRecord.printConfusion(System.out);
    return performanceRecord;
}
Also used : BatchTrainer(edu.illinois.cs.cogcomp.lbjava.learn.BatchTrainer) EvaluateDiscrete(edu.illinois.cs.cogcomp.comma.utils.EvaluateDiscrete) Lexicon(edu.illinois.cs.cogcomp.lbjava.learn.Lexicon) SparseAveragedPerceptron(edu.illinois.cs.cogcomp.lbjava.learn.SparseAveragedPerceptron) LocalCommaClassifier(edu.illinois.cs.cogcomp.comma.lbj.LocalCommaClassifier) FoldParser(edu.illinois.cs.cogcomp.lbjava.parse.FoldParser)
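
The cross-validation loops in Examples 14 and 15 train on every fold except the pivot fold, then switch to the pivot fold for evaluation. The sketch below illustrates that pattern in plain Java without LBJava. It is an assumption-laden illustration, not FoldParser's implementation: the class SequentialFoldSketch and the method select are hypothetical, and the rule that assigns example i to a fold is only one plausible reading of a sequential split into contiguous blocks, so the real fold boundaries may differ.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical illustration of k-fold selection with a pivot fold; not LBJava code.
final class SequentialFoldSketch {

    /**
     * Returns the examples of fold `pivot` when fromPivot is true,
     * or the examples of all other folds when fromPivot is false.
     */
    static <T> List<T> select(List<T> examples, int k, int pivot, boolean fromPivot) {
        List<T> out = new ArrayList<>();
        int n = examples.size();
        for (int i = 0; i < n; i++) {
            // Assumed rule: contiguous blocks, example i belongs to fold floor(i * k / n).
            int fold = (int) ((long) i * k / n);
            if ((fold == pivot) == fromPivot) {
                out.add(examples.get(i));
            }
        }
        return out;
    }

    public static void main(String[] args) {
        List<Integer> data = new ArrayList<>();
        for (int i = 0; i < 10; i++) {
            data.add(i);
        }
        int k = 5;
        for (int pivot = 0; pivot < k; pivot++) {
            List<Integer> train = select(data, k, pivot, false); // mirrors setFromPivot(false)
            List<Integer> test = select(data, k, pivot, true);   // mirrors setFromPivot(true)
            System.out.println("fold " + pivot + ": train=" + train + " test=" + test);
        }
    }
}
```

In the examples above the same roles are played by foldParser.setFromPivot(false) before bt.train(...) and foldParser.setFromPivot(true) before evaluation.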

Aggregations

BatchTrainer (edu.illinois.cs.cogcomp.lbjava.learn.BatchTrainer) 18
Lexicon (edu.illinois.cs.cogcomp.lbjava.learn.Lexicon) 11
Learner (edu.illinois.cs.cogcomp.lbjava.learn.Learner) 9
Relation (edu.illinois.cs.cogcomp.core.datastructures.textannotation.Relation) 5
SparseAveragedPerceptron (edu.illinois.cs.cogcomp.lbjava.learn.SparseAveragedPerceptron) 4
Parser (edu.illinois.cs.cogcomp.lbjava.parse.Parser) 4
File (java.io.File) 4
LbjGen.relation_classifier (org.cogcomp.re.LbjGen.relation_classifier) 4
LocalCommaClassifier (edu.illinois.cs.cogcomp.comma.lbj.LocalCommaClassifier) 2
EvaluateDiscrete (edu.illinois.cs.cogcomp.comma.utils.EvaluateDiscrete) 2
TestDiscrete (edu.illinois.cs.cogcomp.lbjava.classify.TestDiscrete) 2
FoldParser (edu.illinois.cs.cogcomp.lbjava.parse.FoldParser) 2
NETaggerLevel1 (edu.illinois.cs.cogcomp.ner.LbjFeatures.NETaggerLevel1) 2
NETaggerLevel2 (edu.illinois.cs.cogcomp.ner.LbjFeatures.NETaggerLevel2) 2
IOException (java.io.IOException) 2
ListCommasConstrainedCommaClassifier (edu.illinois.cs.cogcomp.comma.lbj.ListCommasConstrainedCommaClassifier) 1
LocativePairConstrainedCommaClassifier (edu.illinois.cs.cogcomp.comma.lbj.LocativePairConstrainedCommaClassifier) 1
OxfordCommaConstrainedCommaClassifier (edu.illinois.cs.cogcomp.comma.lbj.OxfordCommaConstrainedCommaClassifier) 1
SubstitutePairConstrainedCommaClassifier (edu.illinois.cs.cogcomp.comma.lbj.SubstitutePairConstrainedCommaClassifier) 1
StructuredCommaClassifier (edu.illinois.cs.cogcomp.comma.sl.StructuredCommaClassifier) 1