Search in sources :

Example 1 with SentimentClass

use of edu.stanford.nlp.simple.SentimentClass in project CoreNLP by stanfordnlp.

The method classify of the class SimpleSentiment.

/**
 * Classify the sentiment of a piece of raw text.
 * The text is run through the annotation pipeline, and the first sentence of the
 * result is featurized and passed to the underlying classifier. Any further
 * sentences in the text are not consulted.
 *
 * @param text The raw text to classify.
 *
 * @return The sentiment class the classifier assigns to the first sentence.
 *
 * @see SimpleSentiment#classify(CoreMap)
 */
public SentimentClass classify(String text) {
    // Annotate the raw text so that it is split into sentences
    Annotation document = new Annotation(text);
    pipeline.get().annotate(document);
    // Only the first sentence is used for classification
    CoreMap firstSentence = document.get(CoreAnnotations.SentencesAnnotation.class).get(0);
    // Wrap the sentence's features in an (unlabelled) datum and classify it
    RVFDatum<SentimentClass, String> example = new RVFDatum<SentimentClass, String>(featurize(firstSentence));
    return impl.classOf(example);
}
Also used : SentimentClass(edu.stanford.nlp.simple.SentimentClass) RVFDatum(edu.stanford.nlp.ling.RVFDatum) CoreMap(edu.stanford.nlp.util.CoreMap) Annotation(edu.stanford.nlp.pipeline.Annotation)

Example 2 with SentimentClass

use of edu.stanford.nlp.simple.SentimentClass in project CoreNLP by stanfordnlp.

The method train of the class SimpleSentiment.

/**
 * Train a sentiment model from a set of data.
 *
 * <p>The method proceeds in four stages, each wrapped in its own log track:
 * <ol>
 *   <li>featurize every datum (in parallel) into a single dataset;</li>
 *   <li>log the number of training examples seen for each {@link SentimentClass};</li>
 *   <li>train a linear classifier on the thresholded, shuffled dataset and,
 *       if a stream was given, serialize the classifier to it;</li>
 *   <li>retrain on each of 4 folds and log the averaged accuracy and the
 *       per-label precision, recall and F1.</li>
 * </ol>
 *
 * @param data The data to train the model from.
 * @param modelLocation An optional location to save the model.
 *                      Note that this stream will be closed in this method,
 *                      and should not be written to thereafter.
 *                      A failure to write the model is logged, not thrown.
 *
 * @return A sentiment classifier, ready to use.
 */
@SuppressWarnings({ "OptionalUsedAsFieldOrParameterType", "ConstantConditions" })
public static SimpleSentiment train(Stream<SentimentDatum> data, Optional<OutputStream> modelLocation) {
    // Some useful variables configuring how we train
    boolean useL1 = true;
    double sigma = 1.0;
    int featureCountThreshold = 5;
    // Featurize the data
    forceTrack("Featurizing");
    RVFDataset<SentimentClass, String> dataset = new RVFDataset<>();
    AtomicInteger datasize = new AtomicInteger(0);
    Counter<SentimentClass> distribution = new ClassicCounter<>();
    data.unordered().parallel().map(datum -> {
        // Capture the incremented value once: re-reading the counter could observe
        // increments made by other worker threads and report a misleading number.
        int seen = datasize.incrementAndGet();
        if (seen % 10000 == 0) {
            log("Added " + seen + " datums");
        }
        return new RVFDatum<>(featurize(datum.asCoreMap()), datum.sentiment);
    }).forEach(x -> {
        // The stream is parallel, so updates to the shared dataset and counter are serialized
        synchronized (dataset) {
            distribution.incrementCount(x.label());
            dataset.add(x);
        }
    });
    endTrack("Featurizing");
    // Print label distribution
    startTrack("Distribution");
    for (SentimentClass label : SentimentClass.values()) {
        log(String.format("%7d", (int) distribution.getCount(label)) + "   " + label);
    }
    endTrack("Distribution");
    // Train the classifier
    forceTrack("Training");
    if (featureCountThreshold > 1) {
        dataset.applyFeatureCountThreshold(featureCountThreshold);
    }
    // Fixed seed, so the shuffle (and therefore the folds below) is repeatable
    dataset.randomize(42L);
    LinearClassifierFactory<SentimentClass, String> factory = new LinearClassifierFactory<>();
    factory.setVerbose(true);
    try {
        factory.setMinimizerCreator(() -> {
            QNMinimizer minimizer = new QNMinimizer();
            if (useL1) {
                minimizer.useOWLQN(true, 1 / (sigma * sigma));
            } else {
                factory.setSigma(sigma);
            }
            return minimizer;
        });
    } catch (Exception ignored) {
        // Deliberately best-effort: if the custom minimizer cannot be installed,
        // training continues with whatever minimizer the factory already has.
    }
    factory.setSigma(sigma);
    LinearClassifier<SentimentClass, String> classifier = factory.trainClassifier(dataset);
    // Optionally save the model
    modelLocation.ifPresent(stream -> {
        // try-with-resources closes the caller's stream on every path, including when
        // constructing the ObjectOutputStream or writing the classifier fails.
        try (OutputStream out = stream; ObjectOutputStream oos = new ObjectOutputStream(out)) {
            oos.writeObject(classifier);
        } catch (IOException e) {
            log.err("Could not save model to stream!", e);
        }
    });
    endTrack("Training");
    // Evaluate the model
    forceTrack("Evaluating");
    factory.setVerbose(false);
    double sumAccuracy = 0.0;
    Counter<SentimentClass> sumP = new ClassicCounter<>();
    Counter<SentimentClass> sumR = new ClassicCounter<>();
    int numFolds = 4;
    for (int fold = 0; fold < numFolds; ++fold) {
        Pair<GeneralDataset<SentimentClass, String>, GeneralDataset<SentimentClass, String>> trainTest = dataset.splitOutFold(fold, numFolds);
        // convex objective, so this should be OK
        LinearClassifier<SentimentClass, String> foldClassifier = factory.trainClassifierWithInitialWeights(trainTest.first, classifier);
        sumAccuracy += foldClassifier.evaluateAccuracy(trainTest.second);
        for (SentimentClass label : SentimentClass.values()) {
            Pair<Double, Double> pr = foldClassifier.evaluatePrecisionAndRecall(trainTest.second, label);
            // Precision and recall are accumulated separately
            sumP.incrementCount(label, pr.first);
            sumR.incrementCount(label, pr.second);
        }
    }
    DecimalFormat df = new DecimalFormat("0.000%");
    log.info("----------");
    double aveAccuracy = sumAccuracy / ((double) numFolds);
    log.info("" + numFolds + "-fold accuracy: " + df.format(aveAccuracy));
    log.info("");
    for (SentimentClass label : SentimentClass.values()) {
        double p = sumP.getCount(label) / numFolds;
        double r = sumR.getCount(label) / numFolds;
        // Report F1 as 0 rather than NaN when both precision and recall are 0
        double f1 = (p + r) > 0.0 ? 2 * p * r / (p + r) : 0.0;
        log.info(label + " (P)  = " + df.format(p));
        log.info(label + " (R)  = " + df.format(r));
        log.info(label + " (F1) = " + df.format(f1));
        log.info("");
    }
    log.info("----------");
    endTrack("Evaluating");
    // Return
    return new SimpleSentiment(classifier);
}
Also used : Arrays(java.util.Arrays) SentimentClass(edu.stanford.nlp.simple.SentimentClass) Document(edu.stanford.nlp.simple.Document) QNMinimizer(edu.stanford.nlp.optimization.QNMinimizer) Counter(edu.stanford.nlp.stats.Counter) StanfordCoreNLP(edu.stanford.nlp.pipeline.StanfordCoreNLP) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) RuntimeIOException(edu.stanford.nlp.io.RuntimeIOException) Pair(edu.stanford.nlp.util.Pair) ObjectOutputStream(java.io.ObjectOutputStream) StreamSupport(java.util.stream.StreamSupport) ClassicCounter(edu.stanford.nlp.stats.ClassicCounter) CoreMap(edu.stanford.nlp.util.CoreMap) RVFDatum(edu.stanford.nlp.ling.RVFDatum) OutputStream(java.io.OutputStream) CoreLabel(edu.stanford.nlp.ling.CoreLabel) CoreAnnotations(edu.stanford.nlp.ling.CoreAnnotations) Properties(java.util.Properties) IOUtils(edu.stanford.nlp.io.IOUtils) Redwood(edu.stanford.nlp.util.logging.Redwood) DecimalFormat(java.text.DecimalFormat) Util(edu.stanford.nlp.util.logging.Redwood.Util) IOException(java.io.IOException) File(java.io.File) Lazy(edu.stanford.nlp.util.Lazy) List(java.util.List) Stream(java.util.stream.Stream) Annotation(edu.stanford.nlp.pipeline.Annotation) edu.stanford.nlp.classify(edu.stanford.nlp.classify) StringUtils(edu.stanford.nlp.util.StringUtils) Optional(java.util.Optional) RedwoodConfiguration(edu.stanford.nlp.util.logging.RedwoodConfiguration) Pattern(java.util.regex.Pattern) SentimentClass(edu.stanford.nlp.simple.SentimentClass) DecimalFormat(java.text.DecimalFormat) ObjectOutputStream(java.io.ObjectOutputStream) QNMinimizer(edu.stanford.nlp.optimization.QNMinimizer) RuntimeIOException(edu.stanford.nlp.io.RuntimeIOException) IOException(java.io.IOException) RuntimeIOException(edu.stanford.nlp.io.RuntimeIOException) IOException(java.io.IOException) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) ClassicCounter(edu.stanford.nlp.stats.ClassicCounter)

Aggregations

RVFDatum (edu.stanford.nlp.ling.RVFDatum)2 Annotation (edu.stanford.nlp.pipeline.Annotation)2 SentimentClass (edu.stanford.nlp.simple.SentimentClass)2 CoreMap (edu.stanford.nlp.util.CoreMap)2 edu.stanford.nlp.classify (edu.stanford.nlp.classify)1 IOUtils (edu.stanford.nlp.io.IOUtils)1 RuntimeIOException (edu.stanford.nlp.io.RuntimeIOException)1 CoreAnnotations (edu.stanford.nlp.ling.CoreAnnotations)1 CoreLabel (edu.stanford.nlp.ling.CoreLabel)1 QNMinimizer (edu.stanford.nlp.optimization.QNMinimizer)1 StanfordCoreNLP (edu.stanford.nlp.pipeline.StanfordCoreNLP)1 Document (edu.stanford.nlp.simple.Document)1 ClassicCounter (edu.stanford.nlp.stats.ClassicCounter)1 Counter (edu.stanford.nlp.stats.Counter)1 Lazy (edu.stanford.nlp.util.Lazy)1 Pair (edu.stanford.nlp.util.Pair)1 StringUtils (edu.stanford.nlp.util.StringUtils)1 Redwood (edu.stanford.nlp.util.logging.Redwood)1 Util (edu.stanford.nlp.util.logging.Redwood.Util)1 RedwoodConfiguration (edu.stanford.nlp.util.logging.RedwoodConfiguration)1