Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.
Example from class TextAnnotationPatterns, method suggestPhrasesTest.
/**
 * Runs a saved pattern-learning model over the current text and returns suggested phrases.
 *
 * <p>Properties are layered in this order:
 * <ol>
 *   <li>the model's properties file, with the keys in {@code removeProperties} stripped;</li>
 *   <li>{@code stopWordsFile}, used for the stop-word, English-word and common-word pattern settings;</li>
 *   <li>this object's {@code props};</li>
 *   <li>{@code testProps}.</li>
 * </ol>
 * The merged result is also copied back into {@code props}.
 *
 * <p>The model's saved patterns and words are then loaded. If {@code constVars.learn} is set,
 * {@code iterateExtractApply} is run as well. After that, every maximal run of consecutive
 * tokens whose {@code MatchedPatterns} annotation is non-null and non-empty is counted once as
 * a candidate phrase for the first label. Finally, {@code model.matchedSeedWords} is added with
 * {@code putAll}.
 *
 * @param testProps properties for this run; applied last, so they override everything else
 * @param modelPropertiesFile path of the properties file to load the model settings from
 * @param stopWordsFile file used for the stopWordsPatternFiles, englishWordsFiles and
 *     commonWordsPatternFiles settings
 * @return the extractions per label, as rendered by {@code constVars.getSetWordsAsJson}
 */
public String suggestPhrasesTest(Properties testProps, String modelPropertiesFile, String stopWordsFile) throws IllegalAccessException, InterruptedException, ExecutionException, IOException, InstantiationException, NoSuchMethodException, InvocationTargetException, ClassNotFoundException, SQLException {
  logger.info("Suggesting phrases in test");
  logger.info("test properties are " + testProps);
  Properties runProps = StringUtils.argsToPropertiesWithResolve(new String[] { "-props", modelPropertiesFile });
  String[] removeProperties = new String[] { "allPatternsDir", "storePatsForEachToken", "invertedIndexClass", "savePatternsWordsDir", "batchProcessSents", "outDir", "saveInvertedIndex", "removeOverLappingLabels", "numThreads" };
  for (String s : removeProperties) {
    // Hashtable.remove is a no-op for absent keys, so no containsKey check is needed.
    runProps.remove(s);
  }
  runProps.setProperty("stopWordsPatternFiles", stopWordsFile);
  runProps.setProperty("englishWordsFiles", stopWordsFile);
  runProps.setProperty("commonWordsPatternFiles", stopWordsFile);
  runProps.putAll(props);
  runProps.putAll(testProps);
  props.putAll(runProps);
  processText(false);
  GetPatternsFromDataMultiClass<SurfacePattern> model = new GetPatternsFromDataMultiClass<>(runProps, Data.sents, seedWords, true, humanLabelClasses);
  ArgumentParser.fillOptions(model, runProps);
  GetPatternsFromDataMultiClass.loadFromSavedPatternsWordsDir(model, runProps);
  if (model.constVars.learn) {
    model.iterateExtractApply(null, null, null);
  }
  Map<String, Counter<CandidatePhrase>> allExtractions = new HashMap<>();
  // Only for one label right now!
  String label = model.constVars.getLabels().iterator().next();
  Counter<CandidatePhrase> phraseCounts = new ClassicCounter<>();
  allExtractions.put(label, phraseCounts);
  for (Map.Entry<String, DataInstance> sent : Data.sents.entrySet()) {
    StringBuilder phrase = new StringBuilder();
    for (CoreLabel l : sent.getValue().getTokens()) {
      if (l.get(PatternsAnnotations.MatchedPatterns.class) != null && !l.get(PatternsAnnotations.MatchedPatterns.class).isEmpty()) {
        phrase.append(' ').append(l.word());
      } else {
        // An unmatched token ends the current run (if any).
        countPhraseAndReset(phraseCounts, phrase);
      }
    }
    // A matched run may extend to the last token of the sentence; count it too.
    countPhraseAndReset(phraseCounts, phrase);
  }
  // NOTE(review): putAll replaces the counter built above if matchedSeedWords also contains
  // `label`; confirm that overwriting (rather than merging) is intended.
  allExtractions.putAll(model.matchedSeedWords);
  return model.constVars.getSetWordsAsJson(allExtractions);
}

/** Counts the buffered phrase in {@code counts} if it is non-blank, then clears the buffer. */
private static void countPhraseAndReset(Counter<CandidatePhrase> counts, StringBuilder phrase) {
  String text = phrase.toString().trim();
  if (!text.isEmpty()) {
    counts.incrementCount(CandidatePhrase.createOrGet(text));
  }
  phrase.setLength(0);
}
Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.
Example from class SimpleSentiment, method train.
/**
 * Train a sentiment model from a set of data.
 *
 * <p>The data is featurized and features below a count threshold are dropped. A linear
 * classifier is then trained on the whole dataset and optionally serialized to
 * {@code modelLocation}. Finally, 4-fold cross-validation is run (each fold warm-started from
 * the full model), and average accuracy plus per-class precision, recall and F1 are logged.
 *
 * @param data The data to train the model from.
 * @param modelLocation An optional location to save the model.
 *                      Note that this stream will be closed in this method,
 *                      and should not be written to thereafter.
 *
 * @return A sentiment classifier, ready to use.
 */
@SuppressWarnings({ "OptionalUsedAsFieldOrParameterType", "ConstantConditions" })
public static SimpleSentiment train(Stream<SentimentDatum> data, Optional<OutputStream> modelLocation) {
  // Some useful variables configuring how we train
  boolean useL1 = true;
  double sigma = 1.0;
  int featureCountThreshold = 5;

  // Featurize the data
  forceTrack("Featurizing");
  RVFDataset<SentimentClass, String> dataset = new RVFDataset<>();
  AtomicInteger datasize = new AtomicInteger(0);
  Counter<SentimentClass> distribution = new ClassicCounter<>();
  data.unordered().parallel().map(datum -> {
    if (datasize.incrementAndGet() % 10000 == 0) {
      log("Added " + datasize.get() + " datums");
    }
    return new RVFDatum<>(featurize(datum.asCoreMap()), datum.sentiment);
  }).forEach(x -> {
    // The stream is parallel; dataset and distribution are shared, so guard the updates.
    synchronized (dataset) {
      distribution.incrementCount(x.label());
      dataset.add(x);
    }
  });
  endTrack("Featurizing");

  // Print label distribution
  startTrack("Distribution");
  for (SentimentClass label : SentimentClass.values()) {
    log(String.format("%7d", (int) distribution.getCount(label)) + " " + label);
  }
  endTrack("Distribution");

  // Train the classifier
  forceTrack("Training");
  if (featureCountThreshold > 1) {
    dataset.applyFeatureCountThreshold(featureCountThreshold);
  }
  dataset.randomize(42L);
  LinearClassifierFactory<SentimentClass, String> factory = new LinearClassifierFactory<>();
  factory.setVerbose(true);
  try {
    factory.setMinimizerCreator(() -> {
      QNMinimizer minimizer = new QNMinimizer();
      if (useL1) {
        minimizer.useOWLQN(true, 1 / (sigma * sigma));
      } else {
        factory.setSigma(sigma);
      }
      return minimizer;
    });
  } catch (Exception ignored) {
    // NOTE(review): failures here are deliberately ignored and training continues with the
    // factory's default minimizer; confirm which exceptions this is meant to absorb.
  }
  factory.setSigma(sigma);
  LinearClassifier<SentimentClass, String> classifier = factory.trainClassifier(dataset);

  // Optionally save the model. Both the ObjectOutputStream and the caller's stream are closed
  // even if serialization fails, as promised in the Javadoc.
  modelLocation.ifPresent(stream -> {
    try (OutputStream out = stream; ObjectOutputStream oos = new ObjectOutputStream(out)) {
      oos.writeObject(classifier);
    } catch (IOException e) {
      log.err("Could not save model to stream!");
    }
  });
  endTrack("Training");

  // Evaluate the model
  forceTrack("Evaluating");
  factory.setVerbose(false);
  double sumAccuracy = 0.0;
  Counter<SentimentClass> sumP = new ClassicCounter<>();
  Counter<SentimentClass> sumR = new ClassicCounter<>();
  int numFolds = 4;
  for (int fold = 0; fold < numFolds; ++fold) {
    Pair<GeneralDataset<SentimentClass, String>, GeneralDataset<SentimentClass, String>> trainTest = dataset.splitOutFold(fold, numFolds);
    // convex objective, so this should be OK
    LinearClassifier<SentimentClass, String> foldClassifier = factory.trainClassifierWithInitialWeights(trainTest.first, classifier);
    sumAccuracy += foldClassifier.evaluateAccuracy(trainTest.second);
    for (SentimentClass label : SentimentClass.values()) {
      Pair<Double, Double> pr = foldClassifier.evaluatePrecisionAndRecall(trainTest.second, label);
      sumP.incrementCount(label, pr.first);
      // Recall goes into sumR (previously it was wrongly added to sumP).
      sumR.incrementCount(label, pr.second);
    }
  }
  DecimalFormat df = new DecimalFormat("0.000%");
  log.info("----------");
  double aveAccuracy = sumAccuracy / ((double) numFolds);
  log.info("" + numFolds + "-fold accuracy: " + df.format(aveAccuracy));
  log.info("");
  for (SentimentClass label : SentimentClass.values()) {
    double p = sumP.getCount(label) / numFolds;
    double r = sumR.getCount(label) / numFolds;
    // Avoid 0/0 = NaN when a class is never predicted or never present.
    double f1 = (p + r) > 0.0 ? 2 * p * r / (p + r) : 0.0;
    log.info(label + " (P) = " + df.format(p));
    log.info(label + " (R) = " + df.format(r));
    log.info(label + " (F1) = " + df.format(f1));
    log.info("");
  }
  log.info("----------");
  endTrack("Evaluating");

  // Return
  return new SimpleSentiment(classifier);
}
Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.
Example from class Dependencies, method getTypedDependencyChains.
/**
 * Counts typed-dependency chains found in {@code deps}.
 *
 * <p>The dependencies are first grouped by governor with {@code govToDepMap}. For each
 * governor, every chain returned by {@code getGovMaxChains(..., maxLength)} adds one count to
 * each of its non-empty prefixes (lengths 1 through the chain's full size). A prefix shared by
 * several chains is therefore counted once per chain.
 *
 * @param deps the typed dependencies to scan
 * @param maxLength forwarded to {@code getGovMaxChains}; presumably caps chain length (confirm)
 * @return counts keyed by chain prefix; keys are {@code subList} views of the returned chains
 */
public static Counter<List<TypedDependency>> getTypedDependencyChains(List<TypedDependency> deps, int maxLength) {
  Map<IndexedWord, List<TypedDependency>> depsByGovernor = govToDepMap(deps);
  Counter<List<TypedDependency>> chainCounts = new ClassicCounter<>();
  for (IndexedWord governor : depsByGovernor.keySet()) {
    for (List<TypedDependency> longest : getGovMaxChains(depsByGovernor, governor, maxLength)) {
      int fullLength = longest.size();
      for (int end = 1; end <= fullLength; ++end) {
        chainCounts.incrementCount(longest.subList(0, end));
      }
    }
  }
  return chainCounts;
}
Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.
Example from class GenerateTrees, method readGrammar.
/**
 * Reads a grammar description into {@code terminals}, {@code nonTerminals} and {@code tsurgeons}.
 *
 * <p>Lines are trimmed. Blank lines and lines starting with {@code #} are skipped. A line whose
 * upper-cased text is the name of a {@code Section} switches the current section; parsing
 * starts in {@code TERMINALS}.
 *
 * <p>Switching to {@code TSURGEON} immediately reads one operation from {@code bin} via
 * {@code Tsurgeon.getOperationFromReader} and adds it to {@code tsurgeons}. For any other line,
 * split on runs of spaces into {@code A b c ...}:
 * <ul>
 *   <li>in {@code TERMINALS}, each of {@code b, c, ...} gets one count under {@code A};</li>
 *   <li>in {@code NONTERMINALS}, the list {@code [b, c, ...]} gets one count under {@code A}.</li>
 * </ul>
 *
 * @param bin the reader to consume; it is not closed by this method
 * @throws RuntimeIOException if reading from {@code bin} fails
 * @throws RuntimeException if a non-header line follows a tsurgeon operation in the TSURGEON section
 */
public void readGrammar(BufferedReader bin) {
  try {
    String line;
    Section section = Section.TERMINALS;
    while ((line = bin.readLine()) != null) {
      line = line.trim();
      if (line.isEmpty() || line.charAt(0) == '#') {
        // skip blank lines and comments
        continue;
      }
      Section newSection = null;
      try {
        // Locale.ROOT keeps header matching independent of the default locale
        // (e.g. "i" upper-cases to a dotted capital I under a Turkish locale).
        newSection = Section.valueOf(line.toUpperCase(java.util.Locale.ROOT));
      } catch (IllegalArgumentException e) {
        // never mind, not an enum: treat the line as grammar content
      }
      if (newSection != null) {
        section = newSection;
        if (section == Section.TSURGEON) {
          // Deliberately outside the try above, so an IllegalArgumentException raised while
          // reading the operation is not mistaken for "not a section header".
          // Per the original author: this reads the tregex pattern until it has eaten a blank
          // line, then reads tsurgeon until it has eaten another blank line.
          Pair<TregexPattern, TsurgeonPattern> operation = Tsurgeon.getOperationFromReader(bin, compiler);
          tsurgeons.add(operation);
        }
        continue;
      }
      String[] pieces = line.split(" +");
      switch (section) {
        case TSURGEON:
          throw new RuntimeException("Found a non-empty line in a tsurgeon section after reading the operation");
        case TERMINALS: {
          Counter<String> productions = terminals.get(pieces[0]);
          if (productions == null) {
            productions = new ClassicCounter<>();
            terminals.put(pieces[0], productions);
          }
          for (int i = 1; i < pieces.length; ++i) {
            productions.incrementCount(pieces[i]);
          }
          break;
        }
        case NONTERMINALS: {
          Counter<List<String>> productions = nonTerminals.get(pieces[0]);
          if (productions == null) {
            productions = new ClassicCounter<>();
            nonTerminals.put(pieces[0], productions);
          }
          String[] sublist = Arrays.copyOfRange(pieces, 1, pieces.length);
          productions.incrementCount(Arrays.asList(sublist));
          break;
        }
      }
    }
  } catch (IOException e) {
    throw new RuntimeIOException(e);
  }
}
Aggregations