
Example 1 with LinearClassifierFactory

Use of edu.stanford.nlp.classify.LinearClassifierFactory in project CoreNLP by stanfordnlp.

From the class ChineseMaxentLexicon, method finishTraining.

@Override
public void finishTraining() {
    IntCounter<String> tagCounter = new IntCounter<>();
    WeightedDataset data = new WeightedDataset(datumCounter.size());
    for (TaggedWord word : datumCounter.keySet()) {
        int count = datumCounter.getIntCount(word);
        if (trainOnLowCount && count > trainCountThreshold) {
            continue;
        }
        if (functionWordTags.containsKey(word.word())) {
            continue;
        }
        tagCounter.incrementCount(word.tag());
        if (trainByType) {
            count = 1;
        }
        data.add(new BasicDatum(featExtractor.makeFeatures(word.word()), word.tag()), count);
    }
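    // The raw (word, tag) counts are no longer needed once the dataset is built.
    datumCounter = null;
    // Smoothed prior over tags (lambda = 0.5), computed from the per-type tag counts.
    tagDist = Distribution.laplaceSmoothedDistribution(tagCounter, tagCounter.size(), 0.5);
    tagCounter = null;
    applyThresholds(data);
    verbose("Making classifier...");
    // Limited-memory quasi-Newton optimizer for the classifier's training objective.
    QNMinimizer minim = new QNMinimizer();
    LinearClassifierFactory factory = new LinearClassifierFactory(minim);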
    factory.setTol(tol);
    factory.setSigma(sigma);
    if (tuneSigma) {
        factory.setTuneSigmaHeldOut();
    }
    scorer = factory.trainClassifier(data);
    verbose("Done training.");
}
Also used: TaggedWord (edu.stanford.nlp.ling.TaggedWord), LinearClassifierFactory (edu.stanford.nlp.classify.LinearClassifierFactory), WeightedDataset (edu.stanford.nlp.classify.WeightedDataset), QNMinimizer (edu.stanford.nlp.optimization.QNMinimizer), BasicDatum (edu.stanford.nlp.ling.BasicDatum)
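The data preparation in finishTraining can be illustrated without CoreNLP. The following is a minimal, self-contained Java sketch, not CoreNLP code: the class, records, and helpers (WeightedTagDataSketch, buildExamples, smoothedDistribution) are hypothetical. It mirrors the loop above (skip high-count pairs when trainOnLowCount is set, skip function words, count each tag once per (word, tag) type, weight each example by 1 or by its token count depending on trainByType) and assumes that laplaceSmoothedDistribution applies textbook add-λ smoothing, P(t) = (c(t) + λ) / (N + λK).

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Hypothetical, self-contained sketch (not CoreNLP code) of the data preparation
// in finishTraining(): filter (word, tag) counts, pick example weights, and build
// an add-lambda smoothed prior over tags.
public class WeightedTagDataSketch {

    record TaggedWord(String word, String tag) {}

    record WeightedExample(TaggedWord item, int weight) {}

    static List<WeightedExample> buildExamples(Map<TaggedWord, Integer> counts,
                                               Set<String> functionWords,
                                               boolean trainOnLowCount,
                                               int countThreshold,
                                               boolean trainByType,
                                               Map<String, Integer> tagCounts) {
        List<WeightedExample> examples = new ArrayList<>();
        for (Map.Entry<TaggedWord, Integer> e : counts.entrySet()) {
            TaggedWord tw = e.getKey();
            int count = e.getValue();
            // Optionally keep only rare pairs (count at or below the threshold).
            if (trainOnLowCount && count > countThreshold) {
                continue;
            }
            // Function words are excluded from classifier training.
            if (functionWords.contains(tw.word())) {
                continue;
            }
            // One increment per surviving (word, tag) type, as in the original loop.
            tagCounts.merge(tw.tag(), 1, Integer::sum);
            // Type-based training weights every pair equally; otherwise use the token count.
            int weight = trainByType ? 1 : count;
            examples.add(new WeightedExample(tw, weight));
        }
        return examples;
    }

    // Textbook add-lambda smoothing: P(t) = (c(t) + lambda) / (N + lambda * K).
    static Map<String, Double> smoothedDistribution(Map<String, Integer> counts,
                                                    int numKeys, double lambda) {
        double total = 0.0;
        for (int c : counts.values()) {
            total += c;
        }
        Map<String, Double> dist = new HashMap<>();
        for (Map.Entry<String, Integer> e : counts.entrySet()) {
            dist.put(e.getKey(), (e.getValue() + lambda) / (total + lambda * numKeys));
        }
        return dist;
    }

    public static void main(String[] args) {
        // Toy counts with pinyin placeholders, for illustration only.
        Map<TaggedWord, Integer> counts = new LinkedHashMap<>();
        counts.put(new TaggedWord("yinhang", "NN"), 12);
        counts.put(new TaggedWord("xuesheng", "NN"), 3);
        counts.put(new TaggedWord("pao", "VV"), 2);
        counts.put(new TaggedWord("de", "DEG"), 500);

        Map<String, Integer> tagCounts = new HashMap<>();
        List<WeightedExample> examples =
            buildExamples(counts, Set.of("de"), false, 10, false, tagCounts);
        Map<String, Double> prior = smoothedDistribution(tagCounts, tagCounts.size(), 0.5);

        System.out.println(examples.size());   // 3
        System.out.println(prior.get("NN"));   // 0.625
        System.out.println(prior.get("VV"));   // 0.375
    }
}
```

With the toy counts in main, the function word "de" is dropped and three weighted examples remain; the per-type tag counts NN = 2 and VV = 1 give smoothed probabilities (2 + 0.5) / (3 + 1) = 0.625 and (1 + 0.5) / (3 + 1) = 0.375.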

Example 2 with LinearClassifierFactory

Use of edu.stanford.nlp.classify.LinearClassifierFactory in project CoreNLP by stanfordnlp.

From the class EntityClassifier, method train.

private static void train(List<SceneGraphImage> images, String modelPath, Embedding embeddings) throws IOException {
    RVFDataset<String, String> dataset = new RVFDataset<String, String>();
    SceneGraphSentenceMatcher sentenceMatcher = new SceneGraphSentenceMatcher(embeddings);
    for (SceneGraphImage img : images) {
        for (SceneGraphImageRegion region : img.regions) {
            SemanticGraph sg = region.getEnhancedSemanticGraph();
            SemanticGraphEnhancer.enhance(sg);
            List<Triple<IndexedWord, IndexedWord, String>> relationTriples = sentenceMatcher.getRelationTriples(region);
            for (Triple<IndexedWord, IndexedWord, String> relation : relationTriples) {
                IndexedWord w1 = sg.getNodeByIndexSafe(relation.first.index());
                if (w1 != null) {
                    dataset.add(getDatum(w1, relation.first.get(SceneGraphCoreAnnotations.GoldEntityAnnotation.class), embeddings));
                }
            }
        }
    }
    LinearClassifierFactory<String, String> classifierFactory = new LinearClassifierFactory<String, String>(new QNMinimizer(15), 1e-4, false, REG_STRENGTH);
    Classifier<String, String> classifier = classifierFactory.trainClassifier(dataset);
    IOUtils.writeObjectToFile(classifier, modelPath);
    // Accuracy on the training data itself, not on held-out data.
    System.err.println(classifier.evaluateAccuracy(dataset));
}
Also used: RVFDataset (edu.stanford.nlp.classify.RVFDataset), SceneGraphImage (edu.stanford.nlp.scenegraph.image.SceneGraphImage), QNMinimizer (edu.stanford.nlp.optimization.QNMinimizer), Triple (edu.stanford.nlp.util.Triple), LinearClassifierFactory (edu.stanford.nlp.classify.LinearClassifierFactory), SemanticGraph (edu.stanford.nlp.semgraph.SemanticGraph), IndexedWord (edu.stanford.nlp.ling.IndexedWord), SceneGraphImageRegion (edu.stanford.nlp.scenegraph.image.SceneGraphImageRegion)
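Examples 1 and 2 both give LinearClassifierFactory a convergence tolerance and a regularization width sigma (REG_STRENGTH in Example 2); in the four-argument constructor used here the arguments are presumably the minimizer, the tolerance, a useSum flag, and sigma. To show what tolerance and sigma control, here is a hedged, self-contained sketch of a two-class logistic-regression objective with a Gaussian (L2) prior, trained by plain gradient descent rather than the quasi-Newton method QNMinimizer uses. All names are hypothetical; dense double[] rows stand in for the sparse real-valued features of an RVFDataset, and the relative-change stopping rule is an assumed convergence test, not CoreNLP's.

```java
// Hypothetical sketch, not the CoreNLP API: two-class logistic regression with a
// Gaussian (L2) prior of width sigma, trained by plain gradient descent.
public class GaussianPriorLogRegSketch {

    // Numerically stable log(1 + e^t).
    static double softplus(double t) {
        return t > 0 ? t + Math.log1p(Math.exp(-t)) : Math.log1p(Math.exp(t));
    }

    static double dot(double[] w, double[] x) {
        double s = 0.0;
        for (int j = 0; j < w.length; j++) {
            s += w[j] * x[j];
        }
        return s;
    }

    // Negative log-likelihood plus ||w||^2 / (2 sigma^2); labels are 0 or 1.
    static double objective(double[][] x, int[] y, double[] w, double sigma) {
        double f = 0.0;
        for (int i = 0; i < x.length; i++) {
            double z = dot(w, x[i]);
            f += softplus(z) - y[i] * z;
        }
        for (double wj : w) {
            f += wj * wj / (2 * sigma * sigma);
        }
        return f;
    }

    // Gradient of the objective: sum_i (p_i - y_i) x_i + w / sigma^2.
    static double[] gradient(double[][] x, int[] y, double[] w, double sigma) {
        double[] g = new double[w.length];
        for (int i = 0; i < x.length; i++) {
            double p = 1.0 / (1.0 + Math.exp(-dot(w, x[i])));
            for (int j = 0; j < w.length; j++) {
                g[j] += (p - y[i]) * x[i][j];
            }
        }
        for (int j = 0; j < w.length; j++) {
            g[j] += w[j] / (sigma * sigma);
        }
        return g;
    }

    // Stops when the relative change in the objective drops below tol (assumed test).
    static double[] train(double[][] x, int[] y, double sigma, double tol,
                          double learningRate, int maxIter) {
        double[] w = new double[x[0].length];
        double prev = objective(x, y, w, sigma);
        for (int iter = 0; iter < maxIter; iter++) {
            double[] g = gradient(x, y, w, sigma);
            for (int j = 0; j < w.length; j++) {
                w[j] -= learningRate * g[j];
            }
            double cur = objective(x, y, w, sigma);
            if (Math.abs(prev - cur) / Math.max(1.0, Math.abs(cur)) < tol) {
                break;
            }
            prev = cur;
        }
        return w;
    }

    static int predict(double[] w, double[] xi) {
        return dot(w, xi) >= 0 ? 1 : 0;
    }

    // Fraction of examples predicted correctly.
    static double accuracy(double[] w, double[][] x, int[] y) {
        int correct = 0;
        for (int i = 0; i < x.length; i++) {
            if (predict(w, x[i]) == y[i]) {
                correct++;
            }
        }
        return (double) correct / x.length;
    }

    public static void main(String[] args) {
        // Each row: {constant bias feature 1.0, one real-valued feature}.
        double[][] x = {{1.0, -2.0}, {1.0, -1.0}, {1.0, 1.0}, {1.0, 2.0}};
        int[] y = {0, 0, 1, 1};
        double[] w = train(x, y, 1.0, 1e-4, 0.1, 10_000);
        System.out.println(accuracy(w, x, y)); // training-set accuracy
    }
}
```

A smaller sigma makes the w / sigma^2 penalty stronger and pulls the weights toward zero. As in Example 2, the accuracy printed at the end is measured on the training data, so it says little about generalization.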

Example 3 with LinearClassifierFactory

Use of edu.stanford.nlp.classify.LinearClassifierFactory in project CoreNLP by stanfordnlp.

From the class BoWSceneGraphParser, method train.

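/**
 * Trains a classifier using the examples in trainingFile and saves
 * it to modelPath.
 *
 * @param trainingFile Path to JSON file with images and scene graphs.
 * @param modelPath Path to which the trained classifier is written.
 * @throws IOException If reading the training data or writing the model fails.
 */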
public void train(String trainingFile, String modelPath) throws IOException {
    LinearClassifierFactory<String, String> classifierFactory = new LinearClassifierFactory<String, String>(new QNMinimizer(15), 1e-4, false, REG_STRENGTH);
    /* Create dataset. */
    Dataset<String, String> dataset = getTrainingExamples(trainingFile, true);
    /* Train the classifier. */
    Classifier<String, String> classifier = classifierFactory.trainClassifier(dataset);
    /* Save classifier to disk. */
    IOUtils.writeObjectToFile(classifier, modelPath);
}
Also used: LinearClassifierFactory (edu.stanford.nlp.classify.LinearClassifierFactory), QNMinimizer (edu.stanford.nlp.optimization.QNMinimizer)
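Examples 2 and 3 persist the trained model with IOUtils.writeObjectToFile. Assuming that helper relies on standard Java object serialization, the save-and-reload round trip can be sketched with the JDK alone. ToyModel and the writeModel/readModel helpers below are hypothetical stand-ins for a real classifier and for CoreNLP's I/O utilities.

```java
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;

// Hypothetical sketch using only the JDK; not CoreNLP's IOUtils.
public class ModelSerializationSketch {

    // Hypothetical stand-in for a trained classifier: just a weight vector.
    static final class ToyModel implements Serializable {
        private static final long serialVersionUID = 1L;
        final double[] weights;

        ToyModel(double[] weights) {
            this.weights = weights.clone();
        }
    }

    // Serialize any Serializable object to the given file.
    static void writeModel(Serializable model, Path path) throws IOException {
        try (ObjectOutputStream out = new ObjectOutputStream(Files.newOutputStream(path))) {
            out.writeObject(model);
        }
    }

    // Read back an object written by writeModel; the caller casts to the expected type.
    static Object readModel(Path path) throws IOException, ClassNotFoundException {
        try (ObjectInputStream in = new ObjectInputStream(Files.newInputStream(path))) {
            return in.readObject();
        }
    }

    public static void main(String[] args) throws Exception {
        Path path = Files.createTempFile("model", ".ser");
        writeModel(new ToyModel(new double[] {0.5, -1.25}), path);
        ToyModel restored = (ToyModel) readModel(path);
        System.out.println(Arrays.toString(restored.weights)); // [0.5, -1.25]
        Files.deleteIfExists(path);
    }
}
```

An object saved this way can only be read back by code with a compatible class definition, which is why the sketch pins serialVersionUID.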

Aggregations

LinearClassifierFactory (edu.stanford.nlp.classify.LinearClassifierFactory): 3
QNMinimizer (edu.stanford.nlp.optimization.QNMinimizer): 3
RVFDataset (edu.stanford.nlp.classify.RVFDataset): 1
WeightedDataset (edu.stanford.nlp.classify.WeightedDataset): 1
BasicDatum (edu.stanford.nlp.ling.BasicDatum): 1
IndexedWord (edu.stanford.nlp.ling.IndexedWord): 1
TaggedWord (edu.stanford.nlp.ling.TaggedWord): 1
SceneGraphImage (edu.stanford.nlp.scenegraph.image.SceneGraphImage): 1
SceneGraphImageRegion (edu.stanford.nlp.scenegraph.image.SceneGraphImageRegion): 1
SemanticGraph (edu.stanford.nlp.semgraph.SemanticGraph): 1
Triple (edu.stanford.nlp.util.Triple): 1