Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.
The class SisterAnnotationStats, method main.
/**
 * Calculates sister annotation statistics suitable for doing
 * selective sister splitting in the PCFGParser inside the
 * FactoredParser.
 *
 * @param args The first argument is the path to the treebank;
 *             an optional second argument gives the treebank encoding (default: UTF-8)
 */
public static void main(String[] args) {
  // Quick sanity check of the Counters API: the KL divergence of a
  // distribution with itself should be 0
  ClassicCounter<String> c = new ClassicCounter<>();
  c.setCount("A", 0);
  c.setCount("B", 1);
  double d = Counters.klDivergence(c, c);
  System.out.println("KL Divergence: " + d);
  String encoding = "UTF-8";
  if (args.length > 1) {
    encoding = args[1];
  }
  if (args.length < 1) {
    System.out.println("Usage: SisterAnnotationStats treebankPath");
  } else {
    SisterAnnotationStats pas = new SisterAnnotationStats();
    Treebank treebank = new DiskTreebank(in -> new PennTreeReader(in, new LabeledScoredTreeFactory(new StringLabelFactory()), new BobChrisTreeNormalizer()), encoding);
    treebank.loadPath(args[0]);
    treebank.apply(pas);
    pas.printStats();
  }
}
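The klDivergence call at the top of main is only a sanity check of the Counters API. For readers unfamiliar with it, here is a minimal sketch of how ClassicCounter and Counters.klDivergence interact; the class name KlSanityCheck is hypothetical and not part of CoreNLP, and the snippet assumes it is compiled against the CoreNLP jar:

import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counters;

// Hypothetical demo class, not part of CoreNLP.
public class KlSanityCheck {
  public static void main(String[] args) {
    ClassicCounter<String> p = new ClassicCounter<>();
    p.setCount("A", 3.0);
    p.setCount("B", 1.0);
    ClassicCounter<String> q = new ClassicCounter<>();
    q.setCount("A", 1.0);
    q.setCount("B", 1.0);
    // A distribution diverges from itself by exactly 0
    System.out.println(Counters.klDivergence(p, p));
    // The further the second distribution is from the first, the larger the divergence
    System.out.println(Counters.klDivergence(p, q));
  }
}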
Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.
The class SisterAnnotationStats, method sisterCounters.
protected void sisterCounters(Tree t, Tree p) {
  List rewrite = kidLabels(t);
  List left = leftSisterLabels(t, p);
  List right = rightSisterLabels(t, p);
  String label = t.label().value();
  // lazily initialize the per-label rewrite counter and the left/right sister maps
  if (!nodeRules.containsKey(label)) {
    nodeRules.put(label, new ClassicCounter());
  }
  if (!rightRules.containsKey(label)) {
    rightRules.put(label, new HashMap());
  }
  if (!leftRules.containsKey(label)) {
    leftRules.put(label, new HashMap());
  }
  // count this rewrite for the node itself and for each left/right sister context
  ((ClassicCounter) nodeRules.get(label)).incrementCount(rewrite);
  sideCounters(label, rewrite, left, leftRules);
  sideCounters(label, rewrite, right, rightRules);
}
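The maps above are raw-typed, which is why the casts are needed throughout. As a point of comparison, here is a hedged sketch of the same map-of-counters bookkeeping with generics; SisterRuleTable, leftRulesTyped, and count are hypothetical names, not part of SisterAnnotationStats:

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import edu.stanford.nlp.stats.ClassicCounter;

public class SisterRuleTable {
  // parent label -> sister label -> counter over child-label rewrites
  private final Map<String, Map<String, ClassicCounter<List<String>>>> leftRulesTyped = new HashMap<>();

  public void count(String label, String sister, List<String> rewrite) {
    leftRulesTyped
        .computeIfAbsent(label, k -> new HashMap<>())
        .computeIfAbsent(sister, k -> new ClassicCounter<>())
        .incrementCount(rewrite);
  }
}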
Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.
The class SisterAnnotationStats, method printStats.
public void printStats() {
  NumberFormat nf = NumberFormat.getNumberInstance();
  nf.setMaximumFractionDigits(2);
  // System.out.println("Node rules");
  // System.out.println(nodeRules);
  // System.out.println("Parent rules");
  // System.out.println(pRules);
  // System.out.println("Grandparent rules");
  // System.out.println(gPRules);
  // Store Java code for selSplit
  StringBuffer[] javaSB = new StringBuffer[CUTOFFS.length];
  for (int i = 0; i < CUTOFFS.length; i++) {
    javaSB[i] = new StringBuffer(" private static String[] sisterSplit" + (i + 1) + " = new String[] {");
  }
  // topScores contains all enriched categories, to be sorted later
  ArrayList topScores = new ArrayList();
  for (Object o : nodeRules.keySet()) {
    ArrayList answers = new ArrayList();
    String label = (String) o;
    ClassicCounter cntr = (ClassicCounter) nodeRules.get(label);
    double support = cntr.totalCount();
    System.out.println("Node " + label + " support is " + support);
    for (Object o4 : ((HashMap) leftRules.get(label)).keySet()) {
      String sis = (String) o4;
      ClassicCounter cntr2 = (ClassicCounter) ((HashMap) leftRules.get(label)).get(sis);
      double support2 = cntr2.totalCount();
      /* alternative 1: use the full distribution to calculate the score */
      double kl = Counters.klDivergence(cntr2, cntr);
      /* alternative 2: hold out test-context data to calculate the score.
       * This doesn't work because it can lead to zero-probability
       * data points and hence infinite divergence. */
      // Counter tempCounter = new Counter();
      // tempCounter.addCounter(cntr2);
      // for (Iterator i = tempCounter.seenSet().iterator(); i.hasNext(); ) {
      //   Object o = i.next();
      //   tempCounter.setCount(o, -1 * tempCounter.countOf(o));
      // }
      // tempCounter.addCounter(cntr);
      // double kl = cntr2.klDivergence(tempCounter);
      /* alternative 2 ends here */
      String annotatedLabel = label + "=l=" + sis;
      System.out.println("KL(" + annotatedLabel + "||" + label + ") = " + nf.format(kl) + "\t" + "support(" + sis + ") = " + support2);
      answers.add(new Pair(annotatedLabel, kl * support2));
      topScores.add(new Pair(annotatedLabel, kl * support2));
    }
    for (Object o3 : ((HashMap) rightRules.get(label)).keySet()) {
      String sis = (String) o3;
      ClassicCounter cntr2 = (ClassicCounter) ((HashMap) rightRules.get(label)).get(sis);
      double support2 = cntr2.totalCount();
      double kl = Counters.klDivergence(cntr2, cntr);
      String annotatedLabel = label + "=r=" + sis;
      System.out.println("KL(" + annotatedLabel + "||" + label + ") = " + nf.format(kl) + "\t" + "support(" + sis + ") = " + support2);
      answers.add(new Pair(annotatedLabel, kl * support2));
      topScores.add(new Pair(annotatedLabel, kl * support2));
    }
    System.out.println("----");
    System.out.println("Sorted descending support * KL");
    Collections.sort(answers, (o1, o2) -> {
      Double s1 = (Double) ((Pair) o1).second();
      Double s2 = (Double) ((Pair) o2).second();
      return s2.compareTo(s1);
    });
    for (Object answer : answers) {
      Pair p = (Pair) answer;
      double psd = ((Double) p.second()).doubleValue();
      System.out.println(p.first() + ": " + nf.format(psd));
      // Per-label code generation is disabled; the arrays are filled from the
      // global topScores list below instead.
      // if (psd >= CUTOFFS[0]) {
      //   String annotatedLabel = (String) p.first();
      //   for (double CUTOFF : CUTOFFS) {
      //     if (psd >= CUTOFF) {
      //       javaSB[j].append("\"").append(annotatedLabel).append("\",");
      //     }
      //   }
      // }
    }
    System.out.println();
  }
  Collections.sort(topScores, (o1, o2) -> {
    Double s1 = (Double) ((Pair) o1).second();
    Double s2 = (Double) ((Pair) o2).second();
    return s2.compareTo(s1);
  });
  System.out.println("All enriched categories, sorted by score");
  for (Object topScore : topScores) {
    Pair p = (Pair) topScore;
    double psd = ((Double) p.second()).doubleValue();
    System.out.println(p.first() + ": " + nf.format(psd));
  }
  System.out.println();
  System.out.println(" // Automatically generated by SisterAnnotationStats -- preferably don't edit");
  // Walk down topScores, dropping to the next lower cutoff band whenever the
  // score falls below the current cutoff (CUTOFFS is assumed sorted ascending)
  int k = CUTOFFS.length - 1;
  for (int j = 0; j < topScores.size(); j++) {
    Pair p = (Pair) topScores.get(j);
    double psd = ((Double) p.second()).doubleValue();
    if (psd < CUTOFFS[k]) {
      if (k == 0) {
        break;
      } else {
        k--;
        // retry this item against the next cutoff
        j -= 1;
        continue;
      }
    }
    javaSB[k].append("\"").append(p.first());
    javaSB[k].append("\",");
  }
  for (int i = 0; i < CUTOFFS.length; i++) {
    int len = javaSB[i].length();
    javaSB[i].replace(len - 2, len, "};");
    System.out.println(javaSB[i]);
  }
  System.out.print(" public static String[] sisterSplit = ");
  for (int i = CUTOFFS.length; i > 0; i--) {
    if (i == 1) {
      System.out.print("sisterSplit1");
    } else {
      System.out.print("selectiveSisterSplit" + i + " ? sisterSplit" + i + " : (");
    }
  }
  // print one ')' for each '(' opened above, i.e. one per cutoff beyond the first
  for (int i = CUTOFFS.length; i > 1; i--) {
    System.out.print(")");
  }
  System.out.println(";");
}
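To make the string assembly above concrete, the emitted code has roughly the following shape. The category names here are made up, and the example assumes CUTOFFS.length == 3 with ascending cutoffs, so each array holds the categories whose support * KL score falls in that cutoff band:

// Illustrative output only; the labels are hypothetical.
private static String[] sisterSplit1 = new String[] {"ADVP=r=PP"};
private static String[] sisterSplit2 = new String[] {"VP=r=NP"};
private static String[] sisterSplit3 = new String[] {"NP=l=VP"};

public static String[] sisterSplit = selectiveSisterSplit3 ? sisterSplit3 : (selectiveSisterSplit2 ? sisterSplit2 : (sisterSplit1));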
Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.
The class GrammarCompactor, method convertGraphsToGrammar.
/**
 * Converts the minimized transducer graphs, together with the existing rules,
 * into a new grammar pair.
 *
 * @param graphs a Set of TransducerGraph objects, one per category
 * @param unaryRules a Set of UnaryRule objects that we need to add
 * @param binaryRules a Set of BinaryRule objects that we need to add
 * @return a new Pair of UnaryGrammar and BinaryGrammar
 */
protected Pair<UnaryGrammar, BinaryGrammar> convertGraphsToGrammar(Set<TransducerGraph> graphs, Set<UnaryRule> unaryRules, Set<BinaryRule> binaryRules) {
  // first go through all the existing rules and re-index them with the new state index
  newStateIndex = new HashIndex<>();
  for (UnaryRule rule : unaryRules) {
    String parent = stateIndex.get(rule.parent);
    rule.parent = newStateIndex.addToIndex(parent);
    String child = stateIndex.get(rule.child);
    rule.child = newStateIndex.addToIndex(child);
  }
  for (BinaryRule rule : binaryRules) {
    String parent = stateIndex.get(rule.parent);
    rule.parent = newStateIndex.addToIndex(parent);
    String leftChild = stateIndex.get(rule.leftChild);
    rule.leftChild = newStateIndex.addToIndex(leftChild);
    String rightChild = stateIndex.get(rule.rightChild);
    rule.rightChild = newStateIndex.addToIndex(rightChild);
  }
  // now go through the graphs and add the rules
  for (TransducerGraph graph : graphs) {
    Object startNode = graph.getStartNode();
    for (Arc arc : graph.getArcs()) {
      // TODO: make sure these are the strings we're looking for
      String source = arc.getSourceNode().toString();
      String target = arc.getTargetNode().toString();
      Object input = arc.getInput();
      String inputString = input.toString();
      double output = ((Double) arc.getOutput()).doubleValue();
      if (source.equals(startNode)) {
        // make a UnaryRule
        UnaryRule ur = new UnaryRule(newStateIndex.addToIndex(target), newStateIndex.addToIndex(inputString), smartNegate(output));
        unaryRules.add(ur);
      } else if (inputString.equals(END) || inputString.equals(EPSILON)) {
        // make a UnaryRule
        UnaryRule ur = new UnaryRule(newStateIndex.addToIndex(target), newStateIndex.addToIndex(source), smartNegate(output));
        unaryRules.add(ur);
      } else {
        // make a BinaryRule
        // figure out whether the input was generated on the left or right
        int length = inputString.length();
        char leftOrRight = inputString.charAt(length - 1);
        inputString = inputString.substring(0, length - 1);
        BinaryRule br;
        if (leftOrRight == '<' || leftOrRight == '[') {
          br = new BinaryRule(newStateIndex.addToIndex(target), newStateIndex.addToIndex(inputString), newStateIndex.addToIndex(source), smartNegate(output));
        } else if (leftOrRight == '>' || leftOrRight == ']') {
          br = new BinaryRule(newStateIndex.addToIndex(target), newStateIndex.addToIndex(source), newStateIndex.addToIndex(inputString), smartNegate(output));
        } else {
          throw new RuntimeException("Arc input is in unexpected format: " + arc);
        }
        binaryRules.add(br);
      }
    }
  }
  // by now, the unaryRules and binaryRules Sets hold the old untouched rules plus the new scored rules
  ClassicCounter<String> symbolCounter = new ClassicCounter<>();
  if (outputType == RAW_COUNTS) {
    // count parent symbol occurrences so we can normalize to relative frequencies below
    for (UnaryRule rule : unaryRules) {
      symbolCounter.incrementCount(newStateIndex.get(rule.parent), rule.score);
    }
    for (BinaryRule rule : binaryRules) {
      symbolCounter.incrementCount(newStateIndex.get(rule.parent), rule.score);
    }
  }
  // now we put the rules into the grammars
  // this should be smaller than the last one
  int numStates = newStateIndex.size();
  int numRules = 0;
  UnaryGrammar ug = new UnaryGrammar(newStateIndex);
  BinaryGrammar bg = new BinaryGrammar(newStateIndex);
  for (UnaryRule rule : unaryRules) {
    if (outputType == RAW_COUNTS) {
      double count = symbolCounter.getCount(newStateIndex.get(rule.parent));
      rule.score = (float) Math.log(rule.score / count);
    }
    ug.addRule(rule);
    numRules++;
  }
  for (BinaryRule rule : binaryRules) {
    if (outputType == RAW_COUNTS) {
      double count = symbolCounter.getCount(newStateIndex.get(rule.parent));
      rule.score = (float) Math.log((rule.score - op.trainOptions.ruleDiscount) / count);
    }
    bg.addRule(rule);
    numRules++;
  }
  if (verbose) {
    System.out.println("Number of minimized rules: " + numRules);
    System.out.println("Number of minimized states: " + numStates);
  }
  ug.purgeRules();
  bg.splitRules();
  return new Pair<>(ug, bg);
}
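When outputType is RAW_COUNTS, rule.score initially holds a raw count, and the two loops above turn it into a log relative frequency against the parent symbol's total (for binary rules, after subtracting op.trainOptions.ruleDiscount). A small worked sketch of that normalization, with made-up numbers:

// Hypothetical counts: the parent symbol occurs 8 times in total across
// all rules, and this particular rule 2 times (stored in rule.score).
double parentTotal = 8.0;
double ruleCount = 2.0;
// log conditional probability of the rule given its parent:
float logProb = (float) Math.log(ruleCount / parentTotal); // log(0.25) ~ -1.39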
Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.
The class ChineseCorefBenchmarkSlowITest, method getCorefResults.
private static Counter<String> getCorefResults(String resultsString) throws IOException {
  Counter<String> results = new ClassicCounter<>();
  BufferedReader r = new BufferedReader(new StringReader(resultsString));
  for (String line; (line = r.readLine()) != null; ) {
    // each metric is reported on its own line; pull the matching scores into the counter
    Matcher m1 = MENTION_PATTERN.matcher(line);
    if (m1.matches()) {
      results.setCount(MENTION_TP, Double.parseDouble(m1.group(1)));
      results.setCount(MENTION_F1, Double.parseDouble(m1.group(2)));
    }
    Matcher m2 = MUC_PATTERN.matcher(line);
    if (m2.matches()) {
      results.setCount(MUC_TP, Double.parseDouble(m2.group(1)));
      results.setCount(MUC_F1, Double.parseDouble(m2.group(2)));
    }
    Matcher m3 = BCUBED_PATTERN.matcher(line);
    if (m3.matches()) {
      results.setCount(BCUBED_TP, Double.parseDouble(m3.group(1)));
      results.setCount(BCUBED_F1, Double.parseDouble(m3.group(2)));
    }
    Matcher m4 = CEAFM_PATTERN.matcher(line);
    if (m4.matches()) {
      results.setCount(CEAFM_TP, Double.parseDouble(m4.group(1)));
      results.setCount(CEAFM_F1, Double.parseDouble(m4.group(2)));
    }
    Matcher m5 = CEAFE_PATTERN.matcher(line);
    if (m5.matches()) {
      results.setCount(CEAFE_TP, Double.parseDouble(m5.group(1)));
      results.setCount(CEAFE_F1, Double.parseDouble(m5.group(2)));
    }
    Matcher m6 = BLANC_PATTERN.matcher(line);
    if (m6.matches()) {
      results.setCount(BLANC_F1, Double.parseDouble(m6.group(1)));
    }
    Matcher m7 = CONLL_PATTERN.matcher(line);
    if (m7.matches()) {
      results.setCount(CONLL_SCORE, Double.parseDouble(m7.group(1)));
    }
  }
  return results;
}