Search in sources :

Example 1 with BatchTrainer

use of edu.illinois.cs.cogcomp.lbjava.learn.BatchTrainer in project cogcomp-nlp by CogComp.

the class LearningCurveMultiDataset method getLearningCurve.

/**
 * Trains the two-level NER tagger and tracks its learning curve on held-out data.
 * <p>
 * After every training round the level being trained is evaluated on {@code testDataSet} with
 * {@link TestDiscrete} (label "O" is registered as the null label), and the model is saved
 * whenever the overall F1 improves on the best seen so far. Pre-extracted example files are
 * written next to the model file ({@code <model>.level1.prefetchedTrainData}, etc.); stale copies
 * are deleted before extraction. Level 2 is only trained when the "PredictionsLevel1" feature is
 * enabled.
 * <p>
 * Use {@code fixedNumIterations=-1} for the automatic convergence criterion: training stops after
 * 200 rounds, or once 10 rounds have passed without improvement. With a positive value, rounds
 * {@code 0..fixedNumIterations} (inclusive) are run; for any value {@code >= 0} both taggers are
 * saved again at the end, overwriting the best-round models.
 * <p>
 * NB: assuming column format
 *
 * @param trainDataSet training data
 * @param testDataSet held-out data used to pick the best round
 * @param fixedNumIterations number of rounds to run, or -1 for the automatic criterion
 * @throws IOException if the model directory path already exists as a non-directory file
 * @throws Exception propagated from feature pre-extraction and model I/O
 */
public static void getLearningCurve(Vector<Data> trainDataSet, Vector<Data> testDataSet, int fixedNumIterations) throws Exception {
    double bestF1Level1 = -1;
    int bestRoundLevel1 = 0;
    // Get the directory name (<configname>.model is appended in LbjTagger/Parameters.java:139)
    String modelPath = ParametersForLbjCode.currentParameters.pathToModelFile;
    int lastSlash = modelPath.lastIndexOf("/");
    // A model path without any '/' lives in the working directory; substring(0, -1) would throw.
    String modelPathDir = lastSlash >= 0 ? modelPath.substring(0, lastSlash) : ".";
    if (IOUtils.exists(modelPathDir)) {
        if (!IOUtils.isDirectory(modelPathDir)) {
            String msg = "ERROR: " + NAME + ".getLearningCurve(): model directory '" + modelPathDir + "' already exists as a (non-directory) file.";
            logger.error(msg);
            throw new IOException(msg);
        }
        logger.warn(NAME + ".getLearningCurve(): writing to existing model path '" + modelPathDir + "'...");
    } else {
        IOUtils.mkdir(modelPathDir);
    }
    NETaggerLevel1.Parameters paramLevel1 = new NETaggerLevel1.Parameters();
    paramLevel1.baseLTU = new SparseAveragedPerceptron(ParametersForLbjCode.currentParameters.learningRatePredictionsLevel1, 0, ParametersForLbjCode.currentParameters.thicknessPredictionsLevel1);
    logger.info("Level 1 classifier learning rate = " + ParametersForLbjCode.currentParameters.learningRatePredictionsLevel1 + ", thickness = " + ParametersForLbjCode.currentParameters.thicknessPredictionsLevel1);
    NETaggerLevel1 tagger1 = new NETaggerLevel1(paramLevel1, modelPath + ".level1", modelPath + ".level1.lex");
    tagger1.forget();
    // The feature check does not depend on the document, so test it once outside the loop.
    if (ParametersForLbjCode.currentParameters.featuresToUse.containsKey("PredictionsLevel1")) {
        for (Data trainData : trainDataSet) {
            PredictionsAndEntitiesConfidenceScores.getAndMarkEntities(trainData, NEWord.LabelToLookAt.GoldLabel);
            TwoLayerPredictionAggregationFeatures.setLevel1AggregationFeatures(trainData, true);
        }
    }
    // preextract the L1 test and train data, removing stale copies first.
    String trainPathL1 = modelPath + ".level1.prefetchedTrainData";
    String testPathL1 = modelPath + ".level1.prefetchedTestData";
    deletePrefetchedData(trainPathL1);
    deletePrefetchedData(testPathL1);
    logger.info("Pre-extracting the training data for Level 1 classifier, saving to " + trainPathL1);
    BatchTrainer bt1train = prefetchAndGetBatchTrainer(tagger1, trainDataSet, trainPathL1);
    logger.info("Pre-extracting the testing data for Level 1 classifier, saving to " + testPathL1);
    BatchTrainer bt1test = prefetchAndGetBatchTrainer(tagger1, testDataSet, testPathL1);
    Parser testParser1 = bt1test.getParser();
    for (int i = 0; (fixedNumIterations == -1 && i < 200 && i - bestRoundLevel1 < 10) || (fixedNumIterations > 0 && i <= fixedNumIterations); ++i) {
        bt1train.train(1);
        testParser1.reset();
        TestDiscrete simpleTest = new TestDiscrete();
        simpleTest.addNull("O");
        TestDiscrete.testDiscrete(simpleTest, tagger1, null, testParser1, true, 0);
        double f1Level1 = simpleTest.getOverallStats()[2];
        if (f1Level1 > bestF1Level1) {
            bestF1Level1 = f1Level1;
            bestRoundLevel1 = i;
            tagger1.save();
        }
        logger.info(i + " rounds.  Best so far for Level1 : (" + bestRoundLevel1 + ")=" + bestF1Level1);
    }
    logger.info("Level 1; best round : " + bestRoundLevel1 + "\tbest F1 : " + bestF1Level1);
    // trash any stale l2 prefetch data (both train AND test) before it is regenerated
    String trainPathL2 = modelPath + ".level2.prefetchedTrainData";
    String testPathL2 = modelPath + ".level2.prefetchedTestData";
    deletePrefetchedData(trainPathL2);
    deletePrefetchedData(testPathL2);
    NETaggerLevel2.Parameters paramLevel2 = new NETaggerLevel2.Parameters();
    paramLevel2.baseLTU = new SparseAveragedPerceptron(ParametersForLbjCode.currentParameters.learningRatePredictionsLevel2, 0, ParametersForLbjCode.currentParameters.thicknessPredictionsLevel2);
    NETaggerLevel2 tagger2 = new NETaggerLevel2(paramLevel2, modelPath + ".level2", modelPath + ".level2.lex");
    tagger2.forget();
    // Previously checked if PatternFeatures was in featuresToUse.
    if (ParametersForLbjCode.currentParameters.featuresToUse.containsKey("PredictionsLevel1")) {
        logger.info("Level 2 classifier learning rate = " + ParametersForLbjCode.currentParameters.learningRatePredictionsLevel2 + ", thickness = " + ParametersForLbjCode.currentParameters.thicknessPredictionsLevel2);
        double bestF1Level2 = -1;
        int bestRoundLevel2 = 0;
        logger.info("Pre-extracting the training data for Level 2 classifier, saving to " + trainPathL2);
        BatchTrainer bt2train = prefetchAndGetBatchTrainer(tagger2, trainDataSet, trainPathL2);
        logger.info("Pre-extracting the testing data for Level 2 classifier, saving to " + testPathL2);
        BatchTrainer bt2test = prefetchAndGetBatchTrainer(tagger2, testDataSet, testPathL2);
        Parser testParser2 = bt2test.getParser();
        for (int i = 0; (fixedNumIterations == -1 && i < 200 && i - bestRoundLevel2 < 10) || (fixedNumIterations > 0 && i <= fixedNumIterations); ++i) {
            logger.info("Learning level 2 classifier; round " + i);
            bt2train.train(1);
            logger.info("Testing level 2 classifier;  on prefetched data, round: " + i);
            testParser2.reset();
            TestDiscrete simpleTest = new TestDiscrete();
            simpleTest.addNull("O");
            TestDiscrete.testDiscrete(simpleTest, tagger2, null, testParser2, true, 0);
            double f1Level2 = simpleTest.getOverallStats()[2];
            if (f1Level2 > bestF1Level2) {
                bestF1Level2 = f1Level2;
                bestRoundLevel2 = i;
                tagger2.save();
            }
            logger.info(i + " rounds.  Best so far for Level2 : (" + bestRoundLevel2 + ") " + bestF1Level2);
        }
        // trash the l2 prefetch data now that level 2 training is finished
        deletePrefetchedData(trainPathL2);
        deletePrefetchedData(testPathL2);
        logger.info("Level1: bestround=" + bestRoundLevel1 + "\t F1=" + bestF1Level1 + "\t Level2: bestround=" + bestRoundLevel2 + "\t F1=" + bestF1Level2);
    }
    /*
     * This overrides the saved models, forcing the final iteration (the fixedNumIterations-th)
     * to be the one kept. Note that both layers are saved for this iteration: if the best
     * performance for one of the layers came before the final iteration, the saved model will
     * perform worse than the best one observed during training.
     */
    if (fixedNumIterations > -1) {
        tagger1.save();
        tagger2.save();
    }
}

/**
 * Deletes the pre-extracted data file at {@code filePath} if it exists. A failed deletion is
 * logged as a warning instead of being silently ignored, since stale pre-extracted data could
 * otherwise go unnoticed.
 *
 * @param filePath path of the file to remove
 */
private static void deletePrefetchedData(String filePath) {
    File file = new File(filePath);
    if (file.exists() && !file.delete()) {
        logger.warn(NAME + ".getLearningCurve(): could not delete '" + filePath + "'");
    }
}
Also used : NETaggerLevel2(edu.illinois.cs.cogcomp.ner.LbjFeatures.NETaggerLevel2) NETaggerLevel1(edu.illinois.cs.cogcomp.ner.LbjFeatures.NETaggerLevel1) TestDiscrete(edu.illinois.cs.cogcomp.lbjava.classify.TestDiscrete) IOException(java.io.IOException) Parser(edu.illinois.cs.cogcomp.lbjava.parse.Parser) BatchTrainer(edu.illinois.cs.cogcomp.lbjava.learn.BatchTrainer) File(java.io.File) SparseAveragedPerceptron(edu.illinois.cs.cogcomp.lbjava.learn.SparseAveragedPerceptron)

Example 2 with BatchTrainer

use of edu.illinois.cs.cogcomp.lbjava.learn.BatchTrainer in project cogcomp-nlp by CogComp.

the class BIOTester method train_nom_classifier.

/**
 * Trains the head nominal classifier.
 * <p>
 * Features are pre-extracted to {@code <modelFileName>.ex}, and the lexicon is saved to
 * {@code <modelFileName>.lex}. The learned model is written to {@code <modelFileName>.lc} only
 * when {@code modelLoc} is non-null.
 *
 * @param train_parser the parser containing all training examples; it is reset before returning
 * @param modelLoc the expected model file destination. Supports null, in which case a temporary
 *     location {@code tmp/bio_classifier_<id>} is derived from the {@link BIOReader}'s id.
 * @return the trained classifier
 * @throws IllegalArgumentException if {@code modelLoc} is null and {@code train_parser} is not a
 *     {@link BIOReader}, so no model location can be derived
 */
public static bio_classifier_nom train_nom_classifier(Parser train_parser, String modelLoc) {
    bio_classifier_nom classifier = new bio_classifier_nom();
    train_parser.reset();
    BatchTrainer trainer = new BatchTrainer(classifier, train_parser);
    String modelFileName;
    if (modelLoc != null) {
        modelFileName = modelLoc;
    } else if (train_parser instanceof BIOReader) {
        modelFileName = "tmp/bio_classifier_" + ((BIOReader) train_parser).id;
    } else {
        throw new IllegalArgumentException("modelLoc is null and train_parser is not a BIOReader; cannot derive a model location");
    }
    classifier.setLexiconLocation(modelFileName + ".lex");
    Learner preExtractLearner = trainer.preExtract(modelFileName + ".ex", true, Lexicon.CountPolicy.none);
    preExtractLearner.saveLexicon();
    Lexicon lexicon = preExtractLearner.getLexicon();
    classifier.setLexicon(lexicon);
    // preExtract iterates over train_parser; rewind explicitly so the example count does not
    // depend on where preExtract left the parser.
    train_parser.reset();
    int examples = 0;
    while (train_parser.next() != null) {
        examples++;
    }
    train_parser.reset();
    classifier.initialize(examples, lexicon.size());
    for (Object example = train_parser.next(); example != null; example = train_parser.next()) {
        classifier.learn(example);
    }
    train_parser.reset();
    classifier.doneWithRound();
    classifier.doneLearning();
    if (modelLoc != null) {
        classifier.setModelLocation(modelFileName + ".lc");
        classifier.saveModel();
    }
    return classifier;
}
Also used : BatchTrainer(edu.illinois.cs.cogcomp.lbjava.learn.BatchTrainer) Lexicon(edu.illinois.cs.cogcomp.lbjava.learn.Lexicon) Learner(edu.illinois.cs.cogcomp.lbjava.learn.Learner) LbjGen.bio_classifier_nom(org.cogcomp.md.LbjGen.bio_classifier_nom)

Example 3 with BatchTrainer

use of edu.illinois.cs.cogcomp.lbjava.learn.BatchTrainer in project cogcomp-nlp by CogComp.

the class Quantifier method trainOnAll.

/**
 * Trains the quantities classifier for 45 rounds on {@code dataDir + "/allData.txt"}, then saves
 * it to the model/lexicon files derived from {@code modelName} ({@code .lc} / {@code .lex}).
 */
public void trainOnAll() {
    String modelFile = modelName + ".lc";
    String lexiconFile = modelName + ".lex";
    String dataFile = dataDir + "/allData.txt";
    QuantitiesClassifier quantitiesClassifier = new QuantitiesClassifier(modelFile, lexiconFile);
    BatchTrainer batchTrainer = new BatchTrainer(quantitiesClassifier, new QuantitiesDataReader(dataFile, "train"));
    batchTrainer.train(45);
    quantitiesClassifier.save();
}
Also used : BatchTrainer(edu.illinois.cs.cogcomp.lbjava.learn.BatchTrainer)

Example 4 with BatchTrainer

use of edu.illinois.cs.cogcomp.lbjava.learn.BatchTrainer in project cogcomp-nlp by CogComp.

the class Main method train.

/**
 * Trains the preposition SRL classifier for 20 rounds on the "train" split read from
 * {@code dataDir}, then saves it to the model/lexicon files derived from {@code modelName}
 * ({@code .lc} / {@code .lex}). The {@code modelsDir} directory is created first if it does not
 * exist.
 * <p>
 * The data reader is closed even if training or saving throws.
 */
public void train() {
    if (!IOUtils.exists(modelsDir)) {
        IOUtils.mkdir(modelsDir);
    }
    Learner classifier = new PrepSRLClassifier(modelName + ".lc", modelName + ".lex");
    Parser trainDataReader = new PrepSRLDataReader(dataDir, "train");
    try {
        // NOTE(review): the third BatchTrainer argument (1000) is presumably the progress-output
        // interval in examples — confirm against the LBJava BatchTrainer docs.
        BatchTrainer trainer = new BatchTrainer(classifier, trainDataReader, 1000);
        trainer.train(20);
        classifier.save();
    } finally {
        // Release the reader's resources regardless of whether training succeeded.
        trainDataReader.close();
    }
}
Also used : BatchTrainer(edu.illinois.cs.cogcomp.lbjava.learn.BatchTrainer) ConstrainedPrepSRLClassifier(edu.illinois.cs.cogcomp.prepsrl.inference.ConstrainedPrepSRLClassifier) PrepSRLDataReader(edu.illinois.cs.cogcomp.prepsrl.data.PrepSRLDataReader) Learner(edu.illinois.cs.cogcomp.lbjava.learn.Learner) Parser(edu.illinois.cs.cogcomp.lbjava.parse.Parser)

Example 5 with BatchTrainer

use of edu.illinois.cs.cogcomp.lbjava.learn.BatchTrainer in project cogcomp-nlp by CogComp.

the class ACERelationTester method test_cv_gold.

/*
 * Five-fold cross-validation of the constrained relation classifier on gold mentions.
 *
 * For each fold, a fresh relation_classifier is trained on the TRAIN split. Features and the
 * lexicon are pre-extracted under models/. The classifier is then wrapped in an
 * ACERelationConstrainedClassifier and evaluated on the TEST split.
 *
 * When the prediction for a relation disagrees with the (reversed) prediction for its
 * source/target-swapped counterpart, the reversed label wins if its raw classifier score is
 * more than 0.005 higher.
 *
 * Aggregate counts, precision, recall and fine/coarse F1 are printed to stdout.
 */
public static void test_cv_gold() {
    int fineCorrect = 0;
    int labeledCount = 0;
    int predictedCount = 0;
    int coarseCorrect = 0;
    for (int fold = 0; fold < 5; fold++) {
        fine_relation_label goldLabeler = new fine_relation_label();
        ACEMentionReader trainReader = IOHelper.readFiveFold(fold, "TRAIN");
        relation_classifier classifier = new relation_classifier();
        classifier.setLexiconLocation("models/relation_classifier_fold_" + fold + ".lex");
        BatchTrainer trainer = new BatchTrainer(classifier, trainReader);
        Learner preExtractLearner = trainer.preExtract("models/relation_classifier_fold_" + fold + ".ex", true, Lexicon.CountPolicy.none);
        preExtractLearner.saveLexicon();
        classifier.setLexicon(preExtractLearner.getLexicon());
        int exampleCount = trainReader.relations_bi.size();
        classifier.initialize(exampleCount, preExtractLearner.getLexicon().size());
        for (Relation trainRelation : trainReader.relations_bi) {
            classifier.learn(trainRelation);
        }
        classifier.doneWithRound();
        classifier.doneLearning();

        ACERelationConstrainedClassifier constrained = new ACERelationConstrainedClassifier(classifier);
        ACEMentionReader testReader = IOHelper.readFiveFold(fold, "TEST");
        for (Relation relation : testReader.relations_bi) {
            String predicted = constrained.discreteValue(relation);
            String gold = goldLabeler.discreteValue(relation);
            Relation reversed = new Relation("TO_TEST", relation.getTarget(), relation.getSource(), 1.0f);
            String reversedPredicted = constrained.discreteValue(reversed);
            String reversedAsForward = ACEMentionReader.getOppoName(reversedPredicted);
            if (!predicted.equals(reversedAsForward)) {
                // Inconsistent pair: keep whichever direction the raw classifier is more sure of.
                double forwardScore = scoreForLabel(classifier.scores(relation), predicted);
                double reversedScore = scoreForLabel(classifier.scores((Object) reversed), reversedPredicted);
                if (reversedScore - forwardScore > 0.005) {
                    predicted = reversedAsForward;
                }
            }
            boolean predictedRelated = !predicted.equals("NOT_RELATED");
            if (predictedRelated) {
                predictedCount++;
            }
            if (!gold.equals("NOT_RELATED")) {
                labeledCount++;
            }
            if (predicted.equals(gold) && predictedRelated) {
                fineCorrect++;
            }
            if (getCoarseType(predicted).equals(getCoarseType(gold)) && predictedRelated) {
                coarseCorrect++;
            }
        }
        classifier.forget();
    }
    System.out.println("Total labeled: " + labeledCount);
    System.out.println("Total predicted: " + predictedCount);
    System.out.println("Total correct: " + fineCorrect);
    System.out.println("Total coarse correct: " + coarseCorrect);
    double precision = 100.0 * fineCorrect / predictedCount;
    double recall = 100.0 * fineCorrect / labeledCount;
    double fineF1 = 2 * precision * recall / (precision + recall);
    // Coarse F1 shares the fine F1 denominator (predicted + labeled), so rescale by the ratio
    // of coarse-correct to fine-correct counts.
    double coarseF1 = fineF1 * coarseCorrect / fineCorrect;
    System.out.println("Precision: " + precision);
    System.out.println("Recall: " + recall);
    System.out.println("Fine Type F1: " + fineF1);
    System.out.println("Coarse Type F1: " + coarseF1);
}

/**
 * Returns the score of the last entry in {@code scores} whose value equals {@code label}, or
 * 0.0 when no entry matches.
 */
private static double scoreForLabel(ScoreSet scores, String label) {
    double result = 0.0;
    for (Score candidate : scores.toArray()) {
        if (candidate.value.equals(label)) {
            result = candidate.score;
        }
    }
    return result;
}
Also used : Lexicon(edu.illinois.cs.cogcomp.lbjava.learn.Lexicon) Learner(edu.illinois.cs.cogcomp.lbjava.learn.Learner) Relation(edu.illinois.cs.cogcomp.core.datastructures.textannotation.Relation) Score(edu.illinois.cs.cogcomp.lbjava.classify.Score) BatchTrainer(edu.illinois.cs.cogcomp.lbjava.learn.BatchTrainer) LbjGen.fine_relation_label(org.cogcomp.re.LbjGen.fine_relation_label) LbjGen.relation_classifier(org.cogcomp.re.LbjGen.relation_classifier) ScoreSet(edu.illinois.cs.cogcomp.lbjava.classify.ScoreSet)

Aggregations

BatchTrainer (edu.illinois.cs.cogcomp.lbjava.learn.BatchTrainer)18 Lexicon (edu.illinois.cs.cogcomp.lbjava.learn.Lexicon)11 Learner (edu.illinois.cs.cogcomp.lbjava.learn.Learner)9 Relation (edu.illinois.cs.cogcomp.core.datastructures.textannotation.Relation)5 SparseAveragedPerceptron (edu.illinois.cs.cogcomp.lbjava.learn.SparseAveragedPerceptron)4 Parser (edu.illinois.cs.cogcomp.lbjava.parse.Parser)4 File (java.io.File)4 LbjGen.relation_classifier (org.cogcomp.re.LbjGen.relation_classifier)4 LocalCommaClassifier (edu.illinois.cs.cogcomp.comma.lbj.LocalCommaClassifier)2 EvaluateDiscrete (edu.illinois.cs.cogcomp.comma.utils.EvaluateDiscrete)2 TestDiscrete (edu.illinois.cs.cogcomp.lbjava.classify.TestDiscrete)2 FoldParser (edu.illinois.cs.cogcomp.lbjava.parse.FoldParser)2 NETaggerLevel1 (edu.illinois.cs.cogcomp.ner.LbjFeatures.NETaggerLevel1)2 NETaggerLevel2 (edu.illinois.cs.cogcomp.ner.LbjFeatures.NETaggerLevel2)2 IOException (java.io.IOException)2 ListCommasConstrainedCommaClassifier (edu.illinois.cs.cogcomp.comma.lbj.ListCommasConstrainedCommaClassifier)1 LocativePairConstrainedCommaClassifier (edu.illinois.cs.cogcomp.comma.lbj.LocativePairConstrainedCommaClassifier)1 OxfordCommaConstrainedCommaClassifier (edu.illinois.cs.cogcomp.comma.lbj.OxfordCommaConstrainedCommaClassifier)1 SubstitutePairConstrainedCommaClassifier (edu.illinois.cs.cogcomp.comma.lbj.SubstitutePairConstrainedCommaClassifier)1 StructuredCommaClassifier (edu.illinois.cs.cogcomp.comma.sl.StructuredCommaClassifier)1