Use of edu.stanford.nlp.ling.RVFDatum in project CoreNLP by stanfordnlp.
The class ClauseSplitter, method train.
/**
* Train a clause searcher factory. That is, train a classifier that decides which
* arcs should be split off as new clauses.
*
* @param trainingData The training data. This is a stream of triples of:
* <ol>
* <li>The sentence containing a known extraction.</li>
* <li>The span of the subject in the sentence, as a token span.</li>
* <li>The span of the object in the sentence, as a token span.</li>
* </ol>
* @param modelPath The path to save the model to. This is useful for {@link ClauseSplitter#load(String)}.
* @param trainingDataDump The path to save the training data, as a set of labeled featurized datums.
* @param featurizer The featurizer to use for this classifier.
*
* @return A factory for creating searchers from a given dependency tree.
*/
static ClauseSplitter train(Stream<Pair<CoreMap, Collection<Pair<Span, Span>>>> trainingData, Optional<File> modelPath, Optional<File> trainingDataDump, Featurizer featurizer) {
// Parse options
LinearClassifierFactory<ClauseClassifierLabel, String> factory = new LinearClassifierFactory<>();
// Generally useful objects
OpenIE openie = new OpenIE(PropertiesUtils.asProperties("splitter.nomodel", "true", "optimizefor", "GENERAL"));
WeightedDataset<ClauseClassifierLabel, String> dataset = new WeightedDataset<>();
AtomicInteger numExamplesProcessed = new AtomicInteger(0);
final Optional<PrintWriter> datasetDumpWriter = trainingDataDump.map(file -> {
try {
return new PrintWriter(new OutputStreamWriter(new GZIPOutputStream(new FileOutputStream(file))));
} catch (IOException e) {
throw new RuntimeIOException(e);
}
});
// Step 1: Loop over data
forceTrack("Training inference");
trainingData.forEach(rawExample -> {
CoreMap sentence = rawExample.first;
Collection<Pair<Span, Span>> spans = rawExample.second;
List<CoreLabel> tokens = sentence.get(CoreAnnotations.TokensAnnotation.class);
SemanticGraph tree = sentence.get(SemanticGraphCoreAnnotations.EnhancedDependenciesAnnotation.class);
ClauseSplitterSearchProblem problem = new ClauseSplitterSearchProblem(tree, true);
problem.search(fragmentAndScore -> {
List<Counter<String>> features = fragmentAndScore.second;
SentenceFragment fragment = fragmentAndScore.third.get();
Set<RelationTriple> extractions = new HashSet<>(openie.relationsInFragments(openie.entailmentsFromClause(fragment)));
Trilean correct = Trilean.FALSE;
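// Three-valued correctness label: stays FALSE when there is no extraction (or no
// gold pair) to check; becomes TRUE on an exact or NER-overlap span match, and
// UNKNOWN when a checked pair matches neither.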
RELATION_TRIPLE_LOOP: for (RelationTriple extraction : extractions) {
Span subjectGuess = Span.fromValues(extraction.subject.get(0).index() - 1, extraction.subject.get(extraction.subject.size() - 1).index());
Span objectGuess = Span.fromValues(extraction.object.get(0).index() - 1, extraction.object.get(extraction.object.size() - 1).index());
for (Pair<Span, Span> candidateGold : spans) {
Span subjectSpan = candidateGold.first;
Span objectSpan = candidateGold.second;
if ((subjectGuess.equals(subjectSpan) && objectGuess.equals(objectSpan)) || (subjectGuess.equals(objectSpan) && objectGuess.equals(subjectSpan))) {
correct = Trilean.TRUE;
break RELATION_TRIPLE_LOOP;
} else if (Util.nerOverlap(tokens, subjectSpan, subjectGuess) && Util.nerOverlap(tokens, objectSpan, objectGuess) || Util.nerOverlap(tokens, subjectSpan, objectGuess) && Util.nerOverlap(tokens, objectSpan, subjectGuess)) {
if (!correct.isTrue()) {
correct = Trilean.TRUE;
break RELATION_TRIPLE_LOOP;
}
} else {
if (!correct.isTrue()) {
correct = Trilean.UNKNOWN;
break RELATION_TRIPLE_LOOP;
}
}
}
}
if (!features.isEmpty()) {
List<Pair<Counter<String>, ClauseClassifierLabel>> decisionsToAddAsDatums = new ArrayList<>();
if (correct.isTrue()) {
for (int i = 0; i < features.size(); ++i) {
if (i == features.size() - 1) {
decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_SPLIT));
} else {
decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_INTERM));
}
}
} else if (correct.isFalse()) {
decisionsToAddAsDatums.add(Pair.makePair(features.get(features.size() - 1), ClauseClassifierLabel.NOT_A_CLAUSE));
} else if (correct.isUnknown()) {
boolean isSimpleSplit = false;
for (Counter<String> feats : features) {
if (featurizer.isSimpleSplit(feats)) {
isSimpleSplit = true;
break;
}
}
if (isSimpleSplit) {
for (int i = 0; i < features.size(); ++i) {
if (i == features.size() - 1) {
decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_SPLIT));
} else {
decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_INTERM));
}
}
}
}
for (Pair<Counter<String>, ClauseClassifierLabel> decision : decisionsToAddAsDatums) {
RVFDatum<ClauseClassifierLabel, String> datum = new RVFDatum<>(decision.first);
datum.setLabel(decision.second);
if (datasetDumpWriter.isPresent()) {
datasetDumpWriter.get().println(decision.second + "\t" + StringUtils.join(decision.first.entrySet().stream().map(entry -> entry.getKey() + "->" + entry.getValue()), ";"));
}
dataset.add(datum);
}
}
return true;
}, new LinearClassifier<>(new ClassicCounter<>()), Collections.emptyMap(), featurizer, 10000);
if (numExamplesProcessed.incrementAndGet() % 100 == 0) {
log("processed " + numExamplesProcessed + " training sentences: " + dataset.size() + " datums");
}
});
endTrack("Training inference");
// Close the file
if (datasetDumpWriter.isPresent()) {
datasetDumpWriter.get().close();
}
// Step 2: Train classifier
forceTrack("Training");
Classifier<ClauseClassifierLabel, String> fullClassifier = factory.trainClassifier(dataset);
endTrack("Training");
if (modelPath.isPresent()) {
Pair<Classifier<ClauseClassifierLabel, String>, Featurizer> toSave = Pair.makePair(fullClassifier, featurizer);
try {
IOUtils.writeObjectToFile(toSave, modelPath.get());
log("SUCCESS: wrote model to " + modelPath.get().getPath());
} catch (IOException e) {
log("ERROR: failed to save model to path: " + modelPath.get().getPath());
err(e);
}
}
// Step 3: Check accuracy of classifier
forceTrack("Training accuracy");
dataset.randomize(42L);
Util.dumpAccuracy(fullClassifier, dataset);
endTrack("Training accuracy");
// Step 4: Cross-validate
int numFolds = 5;
forceTrack(numFolds + " fold cross-validation");
for (int fold = 0; fold < numFolds; ++fold) {
forceTrack("Fold " + (fold + 1));
forceTrack("Training");
Pair<GeneralDataset<ClauseClassifierLabel, String>, GeneralDataset<ClauseClassifierLabel, String>> foldData = dataset.splitOutFold(fold, numFolds);
Classifier<ClauseClassifierLabel, String> classifier = factory.trainClassifier(foldData.first);
endTrack("Training");
forceTrack("Test");
Util.dumpAccuracy(classifier, foldData.second);
endTrack("Test");
endTrack("Fold " + (fold + 1));
}
endTrack(numFolds + " fold cross-validation");
// Step 5: Return factory
return (tree, truth) -> new ClauseSplitterSearchProblem(tree, truth, Optional.of(fullClassifier), Optional.of(featurizer));
}
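A minimal usage sketch (not from the original source): the empty training stream, the model file name, and the featurizer argument are all placeholders. It assumes ClauseSplitter acts as a BiFunction from a dependency tree and a truth flag to a ClauseSplitterSearchProblem, matching the lambda returned above.
static ClauseSplitterSearchProblem trainAndBuild(SemanticGraph tree, Featurizer featurizer) {
  // Train on an empty stream just to illustrate the call shape; real use
  // supplies (sentence, gold subject/object span) triples.
  ClauseSplitter splitter = ClauseSplitter.train(
      Stream.empty(),
      Optional.of(new File("clauseSplitter.ser.gz")),  // hypothetical model path
      Optional.empty(),                                // skip the featurized-datum dump
      featurizer);
  return splitter.apply(tree, false);                  // false: not a gold ("truth") tree
}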
Use of edu.stanford.nlp.ling.RVFDatum in project CoreNLP by stanfordnlp.
The class ScorePhrasesLearnFeatWt, method learnClassifier.
public edu.stanford.nlp.classify.Classifier learnClassifier(String label, boolean forLearningPatterns, TwoDimensionalCounter<CandidatePhrase, E> wordsPatExtracted, Counter<E> allSelectedPatterns) throws IOException, ClassNotFoundException {
phraseScoresRaw.clear();
learnedScores.clear();
if (Data.domainNGramsFile != null)
Data.loadDomainNGrams();
boolean computeRawFreq = false;
if (Data.rawFreq == null) {
Data.rawFreq = new ClassicCounter<>();
computeRawFreq = true;
}
GeneralDataset<String, ScorePhraseMeasures> dataset = choosedatums(forLearningPatterns, label, wordsPatExtracted, allSelectedPatterns, computeRawFreq);
edu.stanford.nlp.classify.Classifier classifier;
if (scoreClassifierType.equals(ClassifierType.LR)) {
LogisticClassifierFactory<String, ScorePhraseMeasures> logfactory = new LogisticClassifierFactory<>();
LogPrior lprior = new LogPrior();
lprior.setSigma(constVars.LRSigma);
classifier = logfactory.trainClassifier(dataset, lprior, false);
LogisticClassifier logcl = ((LogisticClassifier) classifier);
String l = (String) logcl.getLabelForInternalPositiveClass();
Counter<String> weights = logcl.weightsAsCounter();
if (l.equals(Boolean.FALSE.toString())) {
Counters.multiplyInPlace(weights, -1);
}
List<Pair<String, Double>> wtd = Counters.toDescendingMagnitudeSortedListWithCounts(weights);
Redwood.log(ConstantsAndVariables.minimaldebug, "The weights are " + StringUtils.join(wtd.subList(0, Math.min(wtd.size(), 600)), "\n"));
} else if (scoreClassifierType.equals(ClassifierType.SVM)) {
SVMLightClassifierFactory<String, ScorePhraseMeasures> svmcf = new SVMLightClassifierFactory<>(true);
classifier = svmcf.trainClassifier(dataset);
Set<String> labels = Generics.newHashSet(Arrays.asList("true"));
List<Triple<ScorePhraseMeasures, String, Double>> topfeatures = ((SVMLightClassifier<String, ScorePhraseMeasures>) classifier).getTopFeatures(labels, 0, true, 600, true);
Redwood.log(ConstantsAndVariables.minimaldebug, "The weights are " + StringUtils.join(topfeatures, "\n"));
} else if (scoreClassifierType.equals(ClassifierType.SHIFTLR)) {
// Convert the dataset to a basic Dataset, because ShiftParamsLR does not currently support RVFDatum
GeneralDataset<String, ScorePhraseMeasures> newdataset = new Dataset<>();
Iterator<RVFDatum<String, ScorePhraseMeasures>> iter = dataset.iterator();
while (iter.hasNext()) {
RVFDatum<String, ScorePhraseMeasures> inst = iter.next();
newdataset.add(new BasicDatum<>(inst.asFeatures(), inst.label()));
}
ShiftParamsLogisticClassifierFactory<String, ScorePhraseMeasures> factory = new ShiftParamsLogisticClassifierFactory<>();
classifier = factory.trainClassifier(newdataset);
// Print the learned weights
MultinomialLogisticClassifier<String, ScorePhraseMeasures> logcl = ((MultinomialLogisticClassifier) classifier);
Counter<ScorePhraseMeasures> weights = logcl.weightsAsGenericCounter().get("true");
List<Pair<ScorePhraseMeasures, Double>> wtd = Counters.toDescendingMagnitudeSortedListWithCounts(weights);
Redwood.log(ConstantsAndVariables.minimaldebug, "The weights are " + StringUtils.join(wtd.subList(0, Math.min(wtd.size(), 600)), "\n"));
} else if (scoreClassifierType.equals(ClassifierType.LINEAR)) {
LinearClassifierFactory<String, ScorePhraseMeasures> lcf = new LinearClassifierFactory<>();
classifier = lcf.trainClassifier(dataset);
Set<String> labels = Generics.newHashSet(Arrays.asList("true"));
List<Triple<ScorePhraseMeasures, String, Double>> topfeatures = ((LinearClassifier<String, ScorePhraseMeasures>) classifier).getTopFeatures(labels, 0, true, 600, true);
Redwood.log(ConstantsAndVariables.minimaldebug, "The weights are " + StringUtils.join(topfeatures, "\n"));
} else
throw new RuntimeException("cannot identify classifier " + scoreClassifierType);
// else if (scoreClassifierType.equals(ClassifierType.RF)) {
// ClassifierFactory wekaFactory = new WekaDatumClassifierFactory<String, ScorePhraseMeasures>("weka.classifiers.trees.RandomForest", constVars.wekaOptions);
// classifier = wekaFactory.trainClassifier(dataset);
// Classifier cls = ((WekaDatumClassifier) classifier).getClassifier();
// RandomForest rf = (RandomForest) cls;
// }
BufferedWriter w = new BufferedWriter(new FileWriter("tempscorestrainer.txt"));
System.out.println("size of learned scores is " + phraseScoresRaw.size());
for (CandidatePhrase s : phraseScoresRaw.firstKeySet()) {
w.write(s + "\t" + phraseScoresRaw.getCounter(s) + "\n");
}
w.close();
return classifier;
}
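The SHIFTLR branch above downgrades each RVFDatum to a BasicDatum because ShiftParamsLR cannot consume real-valued features. That conversion pattern, pulled out as a small standalone sketch (the generic type parameters are illustrative; it relies on GeneralDataset iterating over RVFDatum, as the loop above does):
// Drop real feature values, keeping only feature presence and the label.
static <L, F> GeneralDataset<L, F> toBasicDataset(GeneralDataset<L, F> rvfData) {
  GeneralDataset<L, F> basic = new Dataset<>();
  for (RVFDatum<L, F> datum : rvfData) {
    basic.add(new BasicDatum<>(datum.asFeatures(), datum.label()));
  }
  return basic;
}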
Use of edu.stanford.nlp.ling.RVFDatum in project CoreNLP by stanfordnlp.
The class LearnImportantFeatures, method getDatum.
private RVFDatum<String, String> getDatum(CoreLabel[] sent, int i) {
Counter<String> feat = new ClassicCounter<>();
CoreLabel l = sent[i];
String label;
if (l.get(answerClass).toString().equals(answerLabel))
label = answerLabel;
else
label = "O";
CollectionValuedMap<String, CandidatePhrase> matchedPhrases = l.get(PatternsAnnotations.MatchedPhrases.class);
if (matchedPhrases == null) {
matchedPhrases = new CollectionValuedMap<>();
matchedPhrases.add(label, CandidatePhrase.createOrGet(l.word()));
}
for (CandidatePhrase w : matchedPhrases.allValues()) {
Integer num = this.clusterIds.get(w.getPhrase());
if (num == null)
num = -1;
feat.setCount("Cluster-" + num, 1.0);
}
// feat.incrementCount("WORD-" + l.word());
// feat.incrementCount("LEMMA-" + l.lemma());
// feat.incrementCount("TAG-" + l.tag());
// A window of 0 disables the PREV-/NEXT- context features in the loops below
int window = 0;
for (int j = Math.max(0, i - window); j < i; j++) {
CoreLabel lj = sent[j];
feat.incrementCount("PREV-" + "WORD-" + lj.word());
feat.incrementCount("PREV-" + "LEMMA-" + lj.lemma());
feat.incrementCount("PREV-" + "TAG-" + lj.tag());
}
for (int j = i + 1; j < sent.length && j <= i + window; j++) {
CoreLabel lj = sent[j];
feat.incrementCount("NEXT-" + "WORD-" + lj.word());
feat.incrementCount("NEXT-" + "LEMMA-" + lj.lemma());
feat.incrementCount("NEXT-" + "TAG-" + lj.tag());
}
// System.out.println("adding " + l.word() + " as " + label);
return new RVFDatum<>(feat, label);
}
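A short sketch of how these per-token datums might be collected into a dataset (hypothetical driver code, not from the original class; sent is a tokenized sentence as a CoreLabel[]):
RVFDataset<String, String> dataset = new RVFDataset<>();
for (int i = 0; i < sent.length; i++) {
  dataset.add(getDatum(sent, i));  // one labeled, real-valued datum per token
}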
Use of edu.stanford.nlp.ling.RVFDatum in project CoreNLP by stanfordnlp.
The class SimpleSentiment, method train.
/**
* Train a sentiment model from a set of data.
*
* @param data The data to train the model from.
* @param modelLocation An optional location to save the model.
* Note that this stream will be closed in this method,
* and should not be written to thereafter.
*
* @return A sentiment classifier, ready to use.
*/
@SuppressWarnings({ "OptionalUsedAsFieldOrParameterType", "ConstantConditions" })
public static SimpleSentiment train(Stream<SentimentDatum> data, Optional<OutputStream> modelLocation) {
// Some useful variables configuring how we train
boolean useL1 = true;
double sigma = 1.0;
int featureCountThreshold = 5;
// Featurize the data
forceTrack("Featurizing");
RVFDataset<SentimentClass, String> dataset = new RVFDataset<>();
AtomicInteger datasize = new AtomicInteger(0);
Counter<SentimentClass> distribution = new ClassicCounter<>();
data.unordered().parallel().map(datum -> {
if (datasize.incrementAndGet() % 10000 == 0) {
log("Added " + datasize.get() + " datums");
}
return new RVFDatum<>(featurize(datum.asCoreMap()), datum.sentiment);
}).forEach(x -> {
synchronized (dataset) {
distribution.incrementCount(x.label());
dataset.add(x);
}
});
endTrack("Featurizing");
// Print label distribution
startTrack("Distribution");
for (SentimentClass label : SentimentClass.values()) {
log(String.format("%7d", (int) distribution.getCount(label)) + " " + label);
}
endTrack("Distribution");
// Train the classifier
forceTrack("Training");
if (featureCountThreshold > 1) {
dataset.applyFeatureCountThreshold(featureCountThreshold);
}
dataset.randomize(42L);
LinearClassifierFactory<SentimentClass, String> factory = new LinearClassifierFactory<>();
factory.setVerbose(true);
try {
factory.setMinimizerCreator(() -> {
QNMinimizer minimizer = new QNMinimizer();
if (useL1) {
minimizer.useOWLQN(true, 1 / (sigma * sigma));
} else {
factory.setSigma(sigma);
}
return minimizer;
});
} catch (Exception ignored) {
}
factory.setSigma(sigma);
LinearClassifier<SentimentClass, String> classifier = factory.trainClassifier(dataset);
// Optionally save the model
modelLocation.ifPresent(stream -> {
try {
ObjectOutputStream oos = new ObjectOutputStream(stream);
oos.writeObject(classifier);
oos.close();
} catch (IOException e) {
log.err("Could not save model to stream!");
}
});
endTrack("Training");
// Evaluate the model
forceTrack("Evaluating");
factory.setVerbose(false);
double sumAccuracy = 0.0;
Counter<SentimentClass> sumP = new ClassicCounter<>();
Counter<SentimentClass> sumR = new ClassicCounter<>();
int numFolds = 4;
for (int fold = 0; fold < numFolds; ++fold) {
Pair<GeneralDataset<SentimentClass, String>, GeneralDataset<SentimentClass, String>> trainTest = dataset.splitOutFold(fold, numFolds);
// convex objective, so this should be OK
LinearClassifier<SentimentClass, String> foldClassifier = factory.trainClassifierWithInitialWeights(trainTest.first, classifier);
sumAccuracy += foldClassifier.evaluateAccuracy(trainTest.second);
for (SentimentClass label : SentimentClass.values()) {
Pair<Double, Double> pr = foldClassifier.evaluatePrecisionAndRecall(trainTest.second, label);
sumP.incrementCount(label, pr.first);
sumR.incrementCount(label, pr.second);
}
}
DecimalFormat df = new DecimalFormat("0.000%");
log.info("----------");
double aveAccuracy = sumAccuracy / ((double) numFolds);
log.info("" + numFolds + "-fold accuracy: " + df.format(aveAccuracy));
log.info("");
for (SentimentClass label : SentimentClass.values()) {
double p = sumP.getCount(label) / numFolds;
double r = sumR.getCount(label) / numFolds;
log.info(label + " (P) = " + df.format(p));
log.info(label + " (R) = " + df.format(r));
log.info(label + " (F1) = " + df.format(2 * p * r / (p + r)));
log.info("");
}
log.info("----------");
endTrack("Evaluating");
// Return
return new SimpleSentiment(classifier);
}
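A minimal call sketch (the empty data stream and the file name are placeholders, not part of the original method):
static SimpleSentiment trainFromScratch(Stream<SentimentDatum> data) throws IOException {
  // The stream is closed inside train(), per the javadoc above, so it is
  // deliberately not wrapped in try-with-resources here.
  OutputStream os = new FileOutputStream("sentiment.ser");  // hypothetical path
  return SimpleSentiment.train(data, Optional.of(os));
}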
Use of edu.stanford.nlp.ling.RVFDatum in project CoreNLP by stanfordnlp.
The class LinearClassifierITest, method testStrMultiClassDatums.
public void testStrMultiClassDatums() throws Exception {
RVFDataset<String, String> trainData = new RVFDataset<String, String>();
List<RVFDatum<String, String>> datums = new ArrayList<RVFDatum<String, String>>();
datums.add(newDatum("alpha", new String[] { "f1", "f2" }, new Double[] { 1.0, 0.0 }));
datums.add(newDatum("beta", new String[] { "f1", "f2" }, new Double[] { 0.0, 1.0 }));
datums.add(newDatum("charlie", new String[] { "f1", "f2" }, new Double[] { 5.0, 5.0 }));
for (RVFDatum<String, String> datum : datums) trainData.add(datum);
LinearClassifierFactory<String, String> lfc = new LinearClassifierFactory<String, String>();
LinearClassifier<String, String> lc = lfc.trainClassifier(trainData);
RVFDatum<String, String> td1 = newDatum("alpha", new String[] { "f1", "f2", "f3" }, new Double[] { 2.0, 0.0, 5.5 });
// Try the obvious (should get train data with 100% acc)
for (RVFDatum<String, String> datum : datums) Assert.assertEquals(datum.label(), lc.classOf(datum));
// Test data
Assert.assertEquals(td1.label(), lc.classOf(td1));
}
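The newDatum helper is not part of this excerpt; a plausible reconstruction consistent with the calls above (hypothetical, not the test's actual code):
private static RVFDatum<String, String> newDatum(String label, String[] features, Double[] values) {
  // Pair each feature name with its real value, then label the datum.
  ClassicCounter<String> counter = new ClassicCounter<>();
  for (int i = 0; i < features.length; i++) {
    counter.setCount(features[i], values[i]);
  }
  return new RVFDatum<>(counter, label);
}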