Usage of edu.stanford.nlp.simple.SentimentClass in the CoreNLP project by stanfordnlp.
Class SimpleSentiment, method classify.
/**
 * Predicts the sentiment of a piece of raw text.
 *
 * <p>The text is annotated with the pipeline obtained from {@code pipeline.get()}.
 * Only the <em>first</em> sentence of the result is featurized and scored; any later
 * sentences in {@code text} do not affect the prediction.
 * NOTE(review): text that yields no sentences will fail on the {@code get(0)} lookup,
 * so callers are expected to pass non-empty text.
 *
 * @param text The raw text to classify.
 * @return The sentiment class predicted for the first sentence of {@code text}.
 * @see SimpleSentiment#classify(CoreMap)
 */
public SentimentClass classify(String text) {
  Annotation document = new Annotation(text);
  pipeline.get().annotate(document);
  // Only the leading sentence is used for classification.
  CoreMap firstSentence = document.get(CoreAnnotations.SentencesAnnotation.class).get(0);
  RVFDatum<SentimentClass, String> example = new RVFDatum<>(featurize(firstSentence));
  return impl.classOf(example);
}
Usage of edu.stanford.nlp.simple.SentimentClass in the CoreNLP project by stanfordnlp.
Class SimpleSentiment, method train.
/**
 * Train a sentiment model from a set of data.
 *
 * <p>The data is featurized on a parallel stream. The label distribution is logged.
 * A feature-count threshold of {@code featureCountThreshold} is applied, and the dataset
 * is shuffled with a fixed seed. A linear classifier is then trained with a
 * {@link QNMinimizer}, using OWL-QN with penalty {@code 1 / sigma^2} when {@code useL1}
 * is set.
 *
 * <p>If {@code modelLocation} is present, the trained classifier is written to it with
 * Java serialization. Finally, {@code numFolds}-fold cross-validation accuracy and
 * per-label precision, recall and F1 are logged. Cross-validation is for reporting only:
 * the returned model is the one trained on the full dataset.
 *
 * @param data The data to train the model from.
 * @param modelLocation An optional location to save the model.
 *                      Note that this stream will be closed in this method,
 *                      and should not be written to thereafter.
 *
 * @return A sentiment classifier, ready to use.
 */
@SuppressWarnings({ "OptionalUsedAsFieldOrParameterType", "ConstantConditions" })
public static SimpleSentiment train(Stream<SentimentDatum> data, Optional<OutputStream> modelLocation) {
  // Some useful variables configuring how we train
  boolean useL1 = true;
  double sigma = 1.0;
  int featureCountThreshold = 5;

  // Featurize the data
  forceTrack("Featurizing");
  RVFDataset<SentimentClass, String> dataset = new RVFDataset<>();
  AtomicInteger datasize = new AtomicInteger(0);
  Counter<SentimentClass> distribution = new ClassicCounter<>();
  data.unordered().parallel().map(datum -> {
    // Log the value returned by incrementAndGet. Re-reading the shared counter could
    // report a different count than the one that triggered this log line.
    int count = datasize.incrementAndGet();
    if (count % 10000 == 0) {
      log("Added " + count + " datums");
    }
    return new RVFDatum<>(featurize(datum.asCoreMap()), datum.sentiment);
  }).forEach(x -> {
    // forEach on a parallel stream may run on several threads, so all updates to the
    // shared dataset and counter are serialized on the dataset's monitor.
    synchronized (dataset) {
      distribution.incrementCount(x.label());
      dataset.add(x);
    }
  });
  endTrack("Featurizing");

  // Print label distribution
  startTrack("Distribution");
  for (SentimentClass label : SentimentClass.values()) {
    log(String.format("%7d", (int) distribution.getCount(label)) + " " + label);
  }
  endTrack("Distribution");

  // Train the classifier
  forceTrack("Training");
  if (featureCountThreshold > 1) {
    dataset.applyFeatureCountThreshold(featureCountThreshold);
  }
  // Fixed seed so the shuffle, and therefore the cross-validation folds below, is reproducible.
  dataset.randomize(42L);
  LinearClassifierFactory<SentimentClass, String> factory = new LinearClassifierFactory<>();
  factory.setVerbose(true);
  try {
    factory.setMinimizerCreator(() -> {
      QNMinimizer minimizer = new QNMinimizer();
      if (useL1) {
        minimizer.useOWLQN(true, 1 / (sigma * sigma));
      } else {
        factory.setSigma(sigma);
      }
      return minimizer;
    });
  } catch (Exception e) {
    // Best effort: training continues with the factory's existing minimizer settings,
    // but the failure is no longer silent.
    log.err("Could not configure the minimizer: " + e);
  }
  // NOTE(review): setSigma is applied unconditionally here, including when useL1 is set.
  // Presumably the L1 run therefore also carries a sigma-based prior; confirm this is intended.
  factory.setSigma(sigma);
  LinearClassifier<SentimentClass, String> classifier = factory.trainClassifier(dataset);

  // Optionally save the model. try-with-resources closes the caller's stream on every
  // path, as documented, including when the ObjectOutputStream header or the write fails.
  modelLocation.ifPresent(stream -> {
    try (OutputStream out = stream;
         ObjectOutputStream oos = new ObjectOutputStream(out)) {
      oos.writeObject(classifier);
    } catch (IOException e) {
      log.err("Could not save model to stream! " + e);
    }
  });
  endTrack("Training");

  // Evaluate the model with k-fold cross-validation (reporting only)
  forceTrack("Evaluating");
  factory.setVerbose(false);
  double sumAccuracy = 0.0;
  Counter<SentimentClass> sumP = new ClassicCounter<>();
  Counter<SentimentClass> sumR = new ClassicCounter<>();
  int numFolds = 4;
  for (int fold = 0; fold < numFolds; ++fold) {
    Pair<GeneralDataset<SentimentClass, String>, GeneralDataset<SentimentClass, String>> trainTest = dataset.splitOutFold(fold, numFolds);
    // Warm-start from the full-data weights. The objective is convex, so the optimum
    // reached should not depend on the starting point.
    LinearClassifier<SentimentClass, String> foldClassifier = factory.trainClassifierWithInitialWeights(trainTest.first, classifier);
    sumAccuracy += foldClassifier.evaluateAccuracy(trainTest.second);
    for (SentimentClass label : SentimentClass.values()) {
      Pair<Double, Double> pr = foldClassifier.evaluatePrecisionAndRecall(trainTest.second, label);
      sumP.incrementCount(label, pr.first);
      // Recall accumulates into sumR. It previously went into sumP, which inflated
      // precision and left recall and F1 at zero.
      sumR.incrementCount(label, pr.second);
    }
  }
  DecimalFormat df = new DecimalFormat("0.000%");
  log.info("----------");
  double aveAccuracy = sumAccuracy / ((double) numFolds);
  log.info("" + numFolds + "-fold accuracy: " + df.format(aveAccuracy));
  log.info("");
  for (SentimentClass label : SentimentClass.values()) {
    double p = sumP.getCount(label) / numFolds;
    double r = sumR.getCount(label) / numFolds;
    // F1 is defined as 0 when P and R are both 0, rather than formatting 0/0 = NaN.
    double f1 = (p + r == 0.0) ? 0.0 : 2 * p * r / (p + r);
    log.info(label + " (P) = " + df.format(p));
    log.info(label + " (R) = " + df.format(r));
    log.info(label + " (F1) = " + df.format(f1));
    log.info("");
  }
  log.info("----------");
  endTrack("Evaluating");

  // Return
  return new SimpleSentiment(classifier);
}
Aggregations