
Example 6 with RVFDatum

Use of edu.stanford.nlp.ling.RVFDatum in project CoreNLP by stanfordnlp.

From the class ClauseSplitter, method train.

/**
   * Train a clause searcher factory. That is, train a classifier that decides which arcs
   * should begin new clauses.
   *
   * @param trainingData The training data. This is a stream of triples of:
   *                     <ol>
   *                       <li>The sentence containing a known extraction.</li>
   *                       <li>The span of the subject in the sentence, as a token span.</li>
   *                       <li>The span of the object in the sentence, as a token span.</li>
   *                     </ol>
   * @param modelPath The path to save the model to. This is useful for {@link ClauseSplitter#load(String)}.
   * @param trainingDataDump The path to save the training data, as a set of labeled featurized datums.
   * @param featurizer The featurizer to use for this classifier.
   *
   * @return A factory for creating searchers from a given dependency tree.
   */
static ClauseSplitter train(Stream<Pair<CoreMap, Collection<Pair<Span, Span>>>> trainingData, Optional<File> modelPath, Optional<File> trainingDataDump, Featurizer featurizer) {
    // Parse options
    LinearClassifierFactory<ClauseClassifierLabel, String> factory = new LinearClassifierFactory<>();
    // Generally useful objects
    OpenIE openie = new OpenIE(PropertiesUtils.asProperties("splitter.nomodel", "true", "optimizefor", "GENERAL"));
    WeightedDataset<ClauseClassifierLabel, String> dataset = new WeightedDataset<>();
    AtomicInteger numExamplesProcessed = new AtomicInteger(0);
    final Optional<PrintWriter> datasetDumpWriter = trainingDataDump.map(file -> {
        try {
            return new PrintWriter(new OutputStreamWriter(new GZIPOutputStream(new FileOutputStream(file))));
        } catch (IOException e) {
            throw new RuntimeIOException(e);
        }
    });
    // Step 1: Loop over data
    forceTrack("Training inference");
    trainingData.forEach(rawExample -> {
        CoreMap sentence = rawExample.first;
        Collection<Pair<Span, Span>> spans = rawExample.second;
        List<CoreLabel> tokens = sentence.get(CoreAnnotations.TokensAnnotation.class);
        SemanticGraph tree = sentence.get(SemanticGraphCoreAnnotations.EnhancedDependenciesAnnotation.class);
        ClauseSplitterSearchProblem problem = new ClauseSplitterSearchProblem(tree, true);
        problem.search(fragmentAndScore -> {
            List<Counter<String>> features = fragmentAndScore.second;
            SentenceFragment fragment = fragmentAndScore.third.get();
            Set<RelationTriple> extractions = new HashSet<>(openie.relationsInFragments(openie.entailmentsFromClause(fragment)));
            Trilean correct = Trilean.FALSE;
            RELATION_TRIPLE_LOOP: for (RelationTriple extraction : extractions) {
                // Token indices are 1-based; convert to 0-based, end-exclusive spans
                Span subjectGuess = Span.fromValues(extraction.subject.get(0).index() - 1, extraction.subject.get(extraction.subject.size() - 1).index());
                Span objectGuess = Span.fromValues(extraction.object.get(0).index() - 1, extraction.object.get(extraction.object.size() - 1).index());
                for (Pair<Span, Span> candidateGold : spans) {
                    Span subjectSpan = candidateGold.first;
                    Span objectSpan = candidateGold.second;
                    if ((subjectGuess.equals(subjectSpan) && objectGuess.equals(objectSpan)) || (subjectGuess.equals(objectSpan) && objectGuess.equals(subjectSpan))) {
                        correct = Trilean.TRUE;
                        break RELATION_TRIPLE_LOOP;
                    } else if (Util.nerOverlap(tokens, subjectSpan, subjectGuess) && Util.nerOverlap(tokens, objectSpan, objectGuess) || Util.nerOverlap(tokens, subjectSpan, objectGuess) && Util.nerOverlap(tokens, objectSpan, subjectGuess)) {
                        if (!correct.isTrue()) {
                            correct = Trilean.TRUE;
                            break RELATION_TRIPLE_LOOP;
                        }
                    } else {
                        // Keep looking: a later gold pair may still match this extraction,
                        // so mark UNKNOWN without aborting the outer loop
                        if (!correct.isTrue()) {
                            correct = Trilean.UNKNOWN;
                        }
                    }
                }
            }
            if (!features.isEmpty()) {
                List<Pair<Counter<String>, ClauseClassifierLabel>> decisionsToAddAsDatums = new ArrayList<>();
                if (correct.isTrue()) {
                    for (int i = 0; i < features.size(); ++i) {
                        if (i == features.size() - 1) {
                            decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_SPLIT));
                        } else {
                            decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_INTERM));
                        }
                    }
                } else if (correct.isFalse()) {
                    decisionsToAddAsDatums.add(Pair.makePair(features.get(features.size() - 1), ClauseClassifierLabel.NOT_A_CLAUSE));
                } else if (correct.isUnknown()) {
                    boolean isSimpleSplit = false;
                    for (Counter<String> feats : features) {
                        if (featurizer.isSimpleSplit(feats)) {
                            isSimpleSplit = true;
                            break;
                        }
                    }
                    if (isSimpleSplit) {
                        for (int i = 0; i < features.size(); ++i) {
                            if (i == features.size() - 1) {
                                decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_SPLIT));
                            } else {
                                decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_INTERM));
                            }
                        }
                    }
                }
                for (Pair<Counter<String>, ClauseClassifierLabel> decision : decisionsToAddAsDatums) {
                    RVFDatum<ClauseClassifierLabel, String> datum = new RVFDatum<>(decision.first);
                    datum.setLabel(decision.second);
                    if (datasetDumpWriter.isPresent()) {
                        datasetDumpWriter.get().println(decision.second + "\t" + StringUtils.join(decision.first.entrySet().stream().map(entry -> entry.getKey() + "->" + entry.getValue()), ";"));
                    }
                    dataset.add(datum);
                }
            }
            return true;
        }, new LinearClassifier<>(new ClassicCounter<>()), Collections.emptyMap(), featurizer, 10000);
        if (numExamplesProcessed.incrementAndGet() % 100 == 0) {
            log("processed " + numExamplesProcessed + " training sentences: " + dataset.size() + " datums");
        }
    });
    endTrack("Training inference");
    // Close the file
    if (datasetDumpWriter.isPresent()) {
        datasetDumpWriter.get().close();
    }
    // Step 2: Train classifier
    forceTrack("Training");
    Classifier<ClauseClassifierLabel, String> fullClassifier = factory.trainClassifier(dataset);
    endTrack("Training");
    if (modelPath.isPresent()) {
        Pair<Classifier<ClauseClassifierLabel, String>, Featurizer> toSave = Pair.makePair(fullClassifier, featurizer);
        try {
            IOUtils.writeObjectToFile(toSave, modelPath.get());
            log("SUCCESS: wrote model to " + modelPath.get().getPath());
        } catch (IOException e) {
            log("ERROR: failed to save model to path: " + modelPath.get().getPath());
            err(e);
        }
    }
    // Step 3: Check accuracy of classifier
    forceTrack("Training accuracy");
    dataset.randomize(42L);
    Util.dumpAccuracy(fullClassifier, dataset);
    endTrack("Training accuracy");
    // Step 4: Cross-validate
    int numFolds = 5;
    forceTrack(numFolds + " fold cross-validation");
    for (int fold = 0; fold < numFolds; ++fold) {
        forceTrack("Fold " + (fold + 1));
        forceTrack("Training");
        Pair<GeneralDataset<ClauseClassifierLabel, String>, GeneralDataset<ClauseClassifierLabel, String>> foldData = dataset.splitOutFold(fold, numFolds);
        Classifier<ClauseClassifierLabel, String> classifier = factory.trainClassifier(foldData.first);
        endTrack("Training");
        forceTrack("Test");
        Util.dumpAccuracy(classifier, foldData.second);
        endTrack("Test");
        endTrack("Fold " + (fold + 1));
    }
    endTrack(numFolds + " fold cross-validation");
    // Step 5: Return factory
    return (tree, truth) -> new ClauseSplitterSearchProblem(tree, truth, Optional.of(fullClassifier), Optional.of(featurizer));
}
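
The RVFDatum-specific part of this method is small: wrap a feature Counter in an RVFDatum, set a label, and add it to the dataset. Below is a minimal standalone sketch of that pattern, assuming CoreNLP is on the classpath; it uses hypothetical String labels and feature names in place of ClauseClassifierLabel and the Featurizer output, and a plain RVFDataset instead of the WeightedDataset above.

import edu.stanford.nlp.classify.LinearClassifier;
import edu.stanford.nlp.classify.LinearClassifierFactory;
import edu.stanford.nlp.classify.RVFDataset;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;

public class RVFDatumTrainingSketch {
    public static void main(String[] args) {
        RVFDataset<String, String> dataset = new RVFDataset<>();
        // Hypothetical real-valued features for two search decisions
        Counter<String> split = new ClassicCounter<>();
        split.setCount("edge=nsubj", 1.0);
        split.setCount("depth", 2.0);
        Counter<String> keep = new ClassicCounter<>();
        keep.setCount("edge=amod", 1.0);
        // Same pattern as above: construct from the Counter, then setLabel(...)
        RVFDatum<String, String> d1 = new RVFDatum<>(split);
        d1.setLabel("CLAUSE_SPLIT");
        RVFDatum<String, String> d2 = new RVFDatum<>(keep);
        d2.setLabel("NOT_A_CLAUSE");
        dataset.add(d1);
        dataset.add(d2);
        LinearClassifier<String, String> classifier =
                new LinearClassifierFactory<String, String>().trainClassifier(dataset);
        System.out.println(classifier.classOf(d1));  // expect CLAUSE_SPLIT
    }
}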
Also used : CoreLabel(edu.stanford.nlp.ling.CoreLabel) java.util(java.util) CoreAnnotations(edu.stanford.nlp.ling.CoreAnnotations) IOUtils(edu.stanford.nlp.io.IOUtils) edu.stanford.nlp.util(edu.stanford.nlp.util) BiFunction(java.util.function.BiFunction) Redwood(edu.stanford.nlp.util.logging.Redwood) Util(edu.stanford.nlp.util.logging.Redwood.Util) Span(edu.stanford.nlp.ie.machinereading.structure.Span) Counter(edu.stanford.nlp.stats.Counter) Stream(java.util.stream.Stream) java.io(java.io) SemanticGraphCoreAnnotations(edu.stanford.nlp.semgraph.SemanticGraphCoreAnnotations) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) edu.stanford.nlp.classify(edu.stanford.nlp.classify) RelationTriple(edu.stanford.nlp.ie.util.RelationTriple) RuntimeIOException(edu.stanford.nlp.io.RuntimeIOException) SemanticGraph(edu.stanford.nlp.semgraph.SemanticGraph) GZIPOutputStream(java.util.zip.GZIPOutputStream) ClassicCounter(edu.stanford.nlp.stats.ClassicCounter) RVFDatum(edu.stanford.nlp.ling.RVFDatum) ClauseSplitterSearchProblem(edu.stanford.nlp.naturalli.ClauseSplitterSearchProblem)

Example 7 with RVFDatum

Use of edu.stanford.nlp.ling.RVFDatum in project CoreNLP by stanfordnlp.

From the class ScorePhrasesLearnFeatWt, method learnClassifier.

public edu.stanford.nlp.classify.Classifier learnClassifier(String label, boolean forLearningPatterns, TwoDimensionalCounter<CandidatePhrase, E> wordsPatExtracted, Counter<E> allSelectedPatterns) throws IOException, ClassNotFoundException {
    phraseScoresRaw.clear();
    learnedScores.clear();
    if (Data.domainNGramsFile != null)
        Data.loadDomainNGrams();
    boolean computeRawFreq = false;
    if (Data.rawFreq == null) {
        Data.rawFreq = new ClassicCounter<>();
        computeRawFreq = true;
    }
    GeneralDataset<String, ScorePhraseMeasures> dataset = choosedatums(forLearningPatterns, label, wordsPatExtracted, allSelectedPatterns, computeRawFreq);
    edu.stanford.nlp.classify.Classifier classifier;
    if (scoreClassifierType.equals(ClassifierType.LR)) {
        LogisticClassifierFactory<String, ScorePhraseMeasures> logfactory = new LogisticClassifierFactory<>();
        LogPrior lprior = new LogPrior();
        lprior.setSigma(constVars.LRSigma);
        classifier = logfactory.trainClassifier(dataset, lprior, false);
        LogisticClassifier logcl = ((LogisticClassifier) classifier);
        String l = (String) logcl.getLabelForInternalPositiveClass();
        Counter<String> weights = logcl.weightsAsCounter();
        if (l.equals(Boolean.FALSE.toString())) {
            Counters.multiplyInPlace(weights, -1);
        }
        List<Pair<String, Double>> wtd = Counters.toDescendingMagnitudeSortedListWithCounts(weights);
        Redwood.log(ConstantsAndVariables.minimaldebug, "The weights are " + StringUtils.join(wtd.subList(0, Math.min(wtd.size(), 600)), "\n"));
    } else if (scoreClassifierType.equals(ClassifierType.SVM)) {
        SVMLightClassifierFactory<String, ScorePhraseMeasures> svmcf = new SVMLightClassifierFactory<>(true);
        classifier = svmcf.trainClassifier(dataset);
        Set<String> labels = Generics.newHashSet(Arrays.asList("true"));
        List<Triple<ScorePhraseMeasures, String, Double>> topfeatures = ((SVMLightClassifier<String, ScorePhraseMeasures>) classifier).getTopFeatures(labels, 0, true, 600, true);
        Redwood.log(ConstantsAndVariables.minimaldebug, "The weights are " + StringUtils.join(topfeatures, "\n"));
    } else if (scoreClassifierType.equals(ClassifierType.SHIFTLR)) {
        // Convert to a basic Dataset, since ShiftParamsLR doesn't currently support RVFDatum
        GeneralDataset<String, ScorePhraseMeasures> newdataset = new Dataset<>();
        for (RVFDatum<String, ScorePhraseMeasures> inst : dataset) {
            newdataset.add(new BasicDatum<>(inst.asFeatures(), inst.label()));
        }
        ShiftParamsLogisticClassifierFactory<String, ScorePhraseMeasures> factory = new ShiftParamsLogisticClassifierFactory<>();
        classifier = factory.trainClassifier(newdataset);
        // Print the learned weights
        MultinomialLogisticClassifier<String, ScorePhraseMeasures> logcl = ((MultinomialLogisticClassifier) classifier);
        Counter<ScorePhraseMeasures> weights = logcl.weightsAsGenericCounter().get("true");
        List<Pair<ScorePhraseMeasures, Double>> wtd = Counters.toDescendingMagnitudeSortedListWithCounts(weights);
        Redwood.log(ConstantsAndVariables.minimaldebug, "The weights are " + StringUtils.join(wtd.subList(0, Math.min(wtd.size(), 600)), "\n"));
    } else if (scoreClassifierType.equals(ClassifierType.LINEAR)) {
        LinearClassifierFactory<String, ScorePhraseMeasures> lcf = new LinearClassifierFactory<>();
        classifier = lcf.trainClassifier(dataset);
        Set<String> labels = Generics.newHashSet(Arrays.asList("true"));
        List<Triple<ScorePhraseMeasures, String, Double>> topfeatures = ((LinearClassifier<String, ScorePhraseMeasures>) classifier).getTopFeatures(labels, 0, true, 600, true);
        Redwood.log(ConstantsAndVariables.minimaldebug, "The weights are " + StringUtils.join(topfeatures, "\n"));
    } else
        throw new RuntimeException("cannot identify classifier " + scoreClassifierType);
    //    else if (scoreClassifierType.equals(ClassifierType.RF)) {
    //      ClassifierFactory wekaFactory = new WekaDatumClassifierFactory<String, ScorePhraseMeasures>("weka.classifiers.trees.RandomForest", constVars.wekaOptions);
    //      classifier = wekaFactory.trainClassifier(dataset);
    //      Classifier cls = ((WekaDatumClassifier) classifier).getClassifier();
    //      RandomForest rf = (RandomForest) cls;
    //    }
    // Dump raw phrase scores for inspection; try-with-resources closes the writer even on error
    try (BufferedWriter w = new BufferedWriter(new FileWriter("tempscorestrainer.txt"))) {
        System.out.println("size of learned scores is " + phraseScoresRaw.size());
        for (CandidatePhrase s : phraseScoresRaw.firstKeySet()) {
            w.write(s + "\t" + phraseScoresRaw.getCounter(s) + "\n");
        }
    }
    return classifier;
}
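
One detail worth flagging in the SHIFTLR branch above: converting an RVFDatum to a BasicDatum via asFeatures() keeps only the feature names, so the real values are silently dropped. A small sketch of the difference, with hypothetical feature names (an illustration of the conversion, not CoreNLP code):

import edu.stanford.nlp.ling.BasicDatum;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;

public class RvfToBasicSketch {
    public static void main(String[] args) {
        Counter<String> feats = new ClassicCounter<>();
        feats.setCount("tfidf", 0.73);         // hypothetical real-valued feature
        feats.setCount("isCapitalized", 1.0);
        RVFDatum<String, String> rvf = new RVFDatum<>(feats, "true");
        // Same conversion as in learnClassifier: the 0.73 weight is lost here
        BasicDatum<String, String> basic = new BasicDatum<>(rvf.asFeatures(), rvf.label());
        System.out.println(rvf.asFeaturesCounter());  // feature -> value map
        System.out.println(basic.asFeatures());       // feature names only
    }
}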
Also used : ScorePhraseMeasures(edu.stanford.nlp.patterns.ConstantsAndVariables.ScorePhraseMeasures) edu.stanford.nlp.classify(edu.stanford.nlp.classify) FileWriter(java.io.FileWriter) BufferedWriter(java.io.BufferedWriter) RVFDatum(edu.stanford.nlp.ling.RVFDatum) AtomicDouble(edu.stanford.nlp.util.concurrent.AtomicDouble)

Example 8 with RVFDatum

Use of edu.stanford.nlp.ling.RVFDatum in project CoreNLP by stanfordnlp.

From the class LearnImportantFeatures, method getDatum.

private RVFDatum<String, String> getDatum(CoreLabel[] sent, int i) {
    Counter<String> feat = new ClassicCounter<>();
    CoreLabel l = sent[i];
    String label;
    if (l.get(answerClass).toString().equals(answerLabel))
        label = answerLabel;
    else
        label = "O";
    CollectionValuedMap<String, CandidatePhrase> matchedPhrases = l.get(PatternsAnnotations.MatchedPhrases.class);
    if (matchedPhrases == null) {
        matchedPhrases = new CollectionValuedMap<>();
        matchedPhrases.add(label, CandidatePhrase.createOrGet(l.word()));
    }
    for (CandidatePhrase w : matchedPhrases.allValues()) {
        Integer num = this.clusterIds.get(w.getPhrase());
        if (num == null)
            num = -1;
        feat.setCount("Cluster-" + num, 1.0);
    }
    // feat.incrementCount("WORD-" + l.word());
    // feat.incrementCount("LEMMA-" + l.lemma());
    // feat.incrementCount("TAG-" + l.tag());
    // Context window size; with window = 0, the PREV-/NEXT- loops below add no features
    int window = 0;
    for (int j = Math.max(0, i - window); j < i; j++) {
        CoreLabel lj = sent[j];
        feat.incrementCount("PREV-" + "WORD-" + lj.word());
        feat.incrementCount("PREV-" + "LEMMA-" + lj.lemma());
        feat.incrementCount("PREV-" + "TAG-" + lj.tag());
    }
    for (int j = i + 1; j < sent.length && j <= i + window; j++) {
        CoreLabel lj = sent[j];
        feat.incrementCount("NEXT-" + "WORD-" + lj.word());
        feat.incrementCount("NEXT-" + "LEMMA-" + lj.lemma());
        feat.incrementCount("NEXT-" + "TAG-" + lj.tag());
    }
    // System.out.println("adding " + l.word() + " as " + label);
    return new RVFDatum<>(feat, label);
}
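
In contrast to Example 6, which calls setLabel after construction, getDatum uses the two-argument constructor that attaches the label up front. A minimal illustration of the same pattern, with hypothetical feature names modeled on the ones built above:

import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;

public class GetDatumPatternSketch {
    public static void main(String[] args) {
        Counter<String> feat = new ClassicCounter<>();
        feat.setCount("Cluster-42", 1.0);      // hypothetical cluster-id feature
        feat.incrementCount("PREV-WORD-the");  // context feature, as in getDatum
        RVFDatum<String, String> datum = new RVFDatum<>(feat, "O");
        System.out.println(datum.label() + "\t" + datum.asFeaturesCounter());
    }
}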
Also used : CoreLabel(edu.stanford.nlp.ling.CoreLabel) ClassicCounter(edu.stanford.nlp.stats.ClassicCounter) RVFDatum(edu.stanford.nlp.ling.RVFDatum) PatternsAnnotations(edu.stanford.nlp.patterns.PatternsAnnotations) CandidatePhrase(edu.stanford.nlp.patterns.CandidatePhrase)

Example 9 with RVFDatum

Use of edu.stanford.nlp.ling.RVFDatum in project CoreNLP by stanfordnlp.

From the class SimpleSentiment, method train.

/**
   * Train a sentiment model from a set of data.
   *
   * @param data The data to train the model from.
   * @param modelLocation An optional location to save the model.
   *                      Note that this stream will be closed in this method,
   *                      and should not be written to thereafter.
   *
   * @return A sentiment classifier, ready to use.
   */
@SuppressWarnings({ "OptionalUsedAsFieldOrParameterType", "ConstantConditions" })
public static SimpleSentiment train(Stream<SentimentDatum> data, Optional<OutputStream> modelLocation) {
    // Some useful variables configuring how we train
    boolean useL1 = true;
    double sigma = 1.0;
    int featureCountThreshold = 5;
    // Featurize the data
    forceTrack("Featurizing");
    RVFDataset<SentimentClass, String> dataset = new RVFDataset<>();
    AtomicInteger datasize = new AtomicInteger(0);
    Counter<SentimentClass> distribution = new ClassicCounter<>();
    data.unordered().parallel().map(datum -> {
        if (datasize.incrementAndGet() % 10000 == 0) {
            log("Added " + datasize.get() + " datums");
        }
        return new RVFDatum<>(featurize(datum.asCoreMap()), datum.sentiment);
    }).forEach(x -> {
        synchronized (dataset) {
            distribution.incrementCount(x.label());
            dataset.add(x);
        }
    });
    endTrack("Featurizing");
    // Print label distribution
    startTrack("Distribution");
    for (SentimentClass label : SentimentClass.values()) {
        log(String.format("%7d", (int) distribution.getCount(label)) + "   " + label);
    }
    endTrack("Distribution");
    // Train the classifier
    forceTrack("Training");
    if (featureCountThreshold > 1) {
        dataset.applyFeatureCountThreshold(featureCountThreshold);
    }
    dataset.randomize(42L);
    LinearClassifierFactory<SentimentClass, String> factory = new LinearClassifierFactory<>();
    factory.setVerbose(true);
    try {
        factory.setMinimizerCreator(() -> {
            QNMinimizer minimizer = new QNMinimizer();
            if (useL1) {
                minimizer.useOWLQN(true, 1 / (sigma * sigma));
            } else {
                factory.setSigma(sigma);
            }
            return minimizer;
        });
    } catch (Exception ignored) {
    }
    factory.setSigma(sigma);
    LinearClassifier<SentimentClass, String> classifier = factory.trainClassifier(dataset);
    // Optionally save the model
    modelLocation.ifPresent(stream -> {
        try (ObjectOutputStream oos = new ObjectOutputStream(stream)) {
            oos.writeObject(classifier);
        } catch (IOException e) {
            log.err("Could not save model to stream!");
        }
    });
    endTrack("Training");
    // Evaluate the model
    forceTrack("Evaluating");
    factory.setVerbose(false);
    double sumAccuracy = 0.0;
    Counter<SentimentClass> sumP = new ClassicCounter<>();
    Counter<SentimentClass> sumR = new ClassicCounter<>();
    int numFolds = 4;
    for (int fold = 0; fold < numFolds; ++fold) {
        Pair<GeneralDataset<SentimentClass, String>, GeneralDataset<SentimentClass, String>> trainTest = dataset.splitOutFold(fold, numFolds);
        // convex objective, so this should be OK
        LinearClassifier<SentimentClass, String> foldClassifier = factory.trainClassifierWithInitialWeights(trainTest.first, classifier);
        sumAccuracy += foldClassifier.evaluateAccuracy(trainTest.second);
        for (SentimentClass label : SentimentClass.values()) {
            Pair<Double, Double> pr = foldClassifier.evaluatePrecisionAndRecall(trainTest.second, label);
            sumP.incrementCount(label, pr.first);
            sumR.incrementCount(label, pr.second);
        }
    }
    DecimalFormat df = new DecimalFormat("0.000%");
    log.info("----------");
    double aveAccuracy = sumAccuracy / ((double) numFolds);
    log.info("" + numFolds + "-fold accuracy: " + df.format(aveAccuracy));
    log.info("");
    for (SentimentClass label : SentimentClass.values()) {
        double p = sumP.getCount(label) / numFolds;
        double r = sumR.getCount(label) / numFolds;
        log.info(label + " (P)  = " + df.format(p));
        log.info(label + " (R)  = " + df.format(r));
        log.info(label + " (F1) = " + df.format(2 * p * r / (p + r)));
        log.info("");
    }
    log.info("----------");
    endTrack("Evaluating");
    // Return
    return new SimpleSentiment(classifier);
}
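
The optional model above is written with plain Java serialization, so loading it back is symmetric. A minimal sketch, assuming a hypothetical model path and that the stream holds exactly one serialized LinearClassifier:

import edu.stanford.nlp.classify.LinearClassifier;
import edu.stanford.nlp.simple.SentimentClass;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;

public class LoadSentimentModelSketch {
    @SuppressWarnings("unchecked")
    public static LinearClassifier<SentimentClass, String> load(String path)
            throws IOException, ClassNotFoundException {
        try (ObjectInputStream ois = new ObjectInputStream(new FileInputStream(path))) {
            // Mirrors oos.writeObject(classifier) in train() above
            return (LinearClassifier<SentimentClass, String>) ois.readObject();
        }
    }
}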
Also used : Arrays(java.util.Arrays) SentimentClass(edu.stanford.nlp.simple.SentimentClass) Document(edu.stanford.nlp.simple.Document) QNMinimizer(edu.stanford.nlp.optimization.QNMinimizer) Counter(edu.stanford.nlp.stats.Counter) StanfordCoreNLP(edu.stanford.nlp.pipeline.StanfordCoreNLP) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) RuntimeIOException(edu.stanford.nlp.io.RuntimeIOException) Pair(edu.stanford.nlp.util.Pair) ObjectOutputStream(java.io.ObjectOutputStream) StreamSupport(java.util.stream.StreamSupport) ClassicCounter(edu.stanford.nlp.stats.ClassicCounter) CoreMap(edu.stanford.nlp.util.CoreMap) RVFDatum(edu.stanford.nlp.ling.RVFDatum) OutputStream(java.io.OutputStream) CoreLabel(edu.stanford.nlp.ling.CoreLabel) CoreAnnotations(edu.stanford.nlp.ling.CoreAnnotations) Properties(java.util.Properties) IOUtils(edu.stanford.nlp.io.IOUtils) Redwood(edu.stanford.nlp.util.logging.Redwood) DecimalFormat(java.text.DecimalFormat) Util(edu.stanford.nlp.util.logging.Redwood.Util) IOException(java.io.IOException) File(java.io.File) Lazy(edu.stanford.nlp.util.Lazy) List(java.util.List) Stream(java.util.stream.Stream) Annotation(edu.stanford.nlp.pipeline.Annotation) edu.stanford.nlp.classify(edu.stanford.nlp.classify) StringUtils(edu.stanford.nlp.util.StringUtils) Optional(java.util.Optional) RedwoodConfiguration(edu.stanford.nlp.util.logging.RedwoodConfiguration) Pattern(java.util.regex.Pattern)

Example 10 with RVFDatum

Use of edu.stanford.nlp.ling.RVFDatum in project CoreNLP by stanfordnlp.

From the class LinearClassifierITest, method testStrMultiClassDatums.

public void testStrMultiClassDatums() throws Exception {
    RVFDataset<String, String> trainData = new RVFDataset<String, String>();
    List<RVFDatum<String, String>> datums = new ArrayList<RVFDatum<String, String>>();
    datums.add(newDatum("alpha", new String[] { "f1", "f2" }, new Double[] { 1.0, 0.0 }));
    ;
    datums.add(newDatum("beta", new String[] { "f1", "f2" }, new Double[] { 0.0, 1.0 }));
    datums.add(newDatum("charlie", new String[] { "f1", "f2" }, new Double[] { 5.0, 5.0 }));
    for (RVFDatum<String, String> datum : datums) trainData.add(datum);
    LinearClassifierFactory<String, String> lfc = new LinearClassifierFactory<String, String>();
    LinearClassifier<String, String> lc = lfc.trainClassifier(trainData);
    RVFDatum<String, String> td1 = newDatum("alpha", new String[] { "f1", "f2", "f3" }, new Double[] { 2.0, 0.0, 5.5 });
    // Try the obvious (should get train data with 100% acc)
    for (RVFDatum<String, String> datum : datums) Assert.assertEquals(datum.label(), lc.classOf(datum));
    // Held-out test datum; the unseen feature f3 contributes nothing to the score
    Assert.assertEquals(td1.label(), lc.classOf(td1));
}
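
The newDatum helper is defined elsewhere in LinearClassifierITest and not shown on this page. A plausible reconstruction, assuming it pairs each feature name with the value at the same index (ClassicCounter import assumed):

private static RVFDatum<String, String> newDatum(String label, String[] features, Double[] values) {
    ClassicCounter<String> counter = new ClassicCounter<>();
    for (int i = 0; i < features.length; i++) {
        counter.setCount(features[i], values[i]);
    }
    return new RVFDatum<>(counter, label);
}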
Also used : ArrayList(java.util.ArrayList) RVFDatum(edu.stanford.nlp.ling.RVFDatum)

Aggregations

RVFDatum (edu.stanford.nlp.ling.RVFDatum): 11 usages
CoreLabel (edu.stanford.nlp.ling.CoreLabel): 5 usages
ClassicCounter (edu.stanford.nlp.stats.ClassicCounter): 5 usages
edu.stanford.nlp.classify: 4 usages
IOUtils (edu.stanford.nlp.io.IOUtils): 3 usages
RuntimeIOException (edu.stanford.nlp.io.RuntimeIOException): 3 usages
CoreAnnotations (edu.stanford.nlp.ling.CoreAnnotations): 3 usages
Counter (edu.stanford.nlp.stats.Counter): 3 usages
Redwood (edu.stanford.nlp.util.logging.Redwood): 3 usages
Util (edu.stanford.nlp.util.logging.Redwood.Util): 3 usages
File (java.io.File): 3 usages
ArrayList (java.util.ArrayList): 3 usages
AtomicInteger (java.util.concurrent.atomic.AtomicInteger): 3 usages
Span (edu.stanford.nlp.ie.machinereading.structure.Span): 2 usages
ScorePhraseMeasures (edu.stanford.nlp.patterns.ConstantsAndVariables.ScorePhraseMeasures): 2 usages
Annotation (edu.stanford.nlp.pipeline.Annotation): 2 usages
SentimentClass (edu.stanford.nlp.simple.SentimentClass): 2 usages
edu.stanford.nlp.util: 2 usages
CoreMap (edu.stanford.nlp.util.CoreMap): 2 usages
RedwoodConfiguration (edu.stanford.nlp.util.logging.RedwoodConfiguration): 2 usages