Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.
In class ParentAnnotationStats, method printStats:
/**
 * Prints parent- and grandparent-annotation statistics, then emits Java source for the
 * {@code splittersN} arrays used for selective parent splitting.
 * <p>
 * Parent pass: for each node category, prints the KL divergence between the category's
 * counter in {@code pRules} (one per matching parent context) and its counter in
 * {@code nodeRules}, weighted by that parent context's total count.
 * <p>
 * Grandparent pass: the same comparison between each {@code pRules} context and its
 * matching {@code gPRules} contexts. Parent contexts whose total count is below
 * {@code SUPPCUTOFF} are skipped.
 * <p>
 * Every weighted score is also collected into a global ranking, which is printed at the end.
 * A context whose score is at least {@code CUTOFFS[j]} is written into {@code splitters(j+1)}.
 */
public void printStats() {
  NumberFormat nf = NumberFormat.getNumberInstance();
  nf.setMaximumFractionDigits(2);
  // One buffer per cutoff; each accumulates the Java source of one splittersN array.
  StringBuffer[] javaSB = new StringBuffer[CUTOFFS.length];
  for (int i = 0; i < CUTOFFS.length; i++) {
    javaSB[i] = new StringBuffer(" private static String[] splitters" + (i + 1) + " = new String[] {");
  }
  ClassicCounter<List<String>> allScores = new ClassicCounter<>();

  // do value of parent
  for (String node : nodeRules.keySet()) {
    ArrayList<Pair<List<String>, Double>> answers = Generics.newArrayList();
    ClassicCounter<List<String>> cntr = nodeRules.get(node);
    double support = cntr.totalCount();
    System.out.println("Node " + node + " support is " + support);
    for (List<String> key : pRules.keySet()) {
      if (key.get(0).equals(node)) {
        // only do it if they match
        ClassicCounter<List<String>> cntr2 = pRules.get(key);
        double support2 = cntr2.totalCount();
        double kl = Counters.klDivergence(cntr2, cntr);
        System.out.println("KL(" + key + "||" + node + ") = " + nf.format(kl) + "\t" + "support(" + key + ") = " + support2);
        double score = kl * support2;
        answers.add(new Pair<>(key, score));
        allScores.setCount(key, score);
      }
    }
    System.out.println("----");
    System.out.println("Sorted descending support * KL");
    Collections.sort(answers, (o1, o2) -> o2.second().compareTo(o1.second()));
    for (Pair<List<String>, Double> answer : answers) {
      List<String> lst = answer.first();
      double psd = answer.second();
      System.out.println(lst + ": " + nf.format(psd));
      if (psd >= CUTOFFS[0]) {
        String nd = lst.get(0);
        String par = lst.get(1);
        for (int j = 0; j < CUTOFFS.length; j++) {
          if (psd >= CUTOFFS[j]) {
            javaSB[j].append("\"").append(nd).append("^");
            javaSB[j].append(par).append("\", ");
          }
        }
      }
    }
    System.out.println();
  }

  // do value of grandparent
  for (List<String> node : pRules.keySet()) {
    ArrayList<Pair<List<String>, Double>> answers = Generics.newArrayList();
    ClassicCounter<List<String>> cntr = pRules.get(node);
    double support = cntr.totalCount();
    if (support < SUPPCUTOFF) {
      continue;
    }
    System.out.println("Node " + node + " support is " + support);
    for (List<String> key : gPRules.keySet()) {
      if (key.get(0).equals(node.get(0)) && key.get(1).equals(node.get(1))) {
        // only do it if they match
        ClassicCounter<List<String>> cntr2 = gPRules.get(key);
        double support2 = cntr2.totalCount();
        double kl = Counters.klDivergence(cntr2, cntr);
        System.out.println("KL(" + key + "||" + node + ") = " + nf.format(kl) + "\t" + "support(" + key + ") = " + support2);
        double score = kl * support2;
        answers.add(new Pair<>(key, score));
        allScores.setCount(key, score);
      }
    }
    System.out.println("----");
    System.out.println("Sorted descending support * KL");
    Collections.sort(answers, (o1, o2) -> o2.second().compareTo(o1.second()));
    for (Pair<List<String>, Double> answer : answers) {
      List<String> lst = answer.first();
      double psd = answer.second();
      System.out.println(lst + ": " + nf.format(psd));
      if (psd >= CUTOFFS[0]) {
        String nd = lst.get(0);
        String par = lst.get(1);
        String gpar = lst.get(2);
        for (int j = 0; j < CUTOFFS.length; j++) {
          if (psd >= CUTOFFS[j]) {
            javaSB[j].append("\"").append(nd).append("^");
            javaSB[j].append(par).append("~");
            javaSB[j].append(gpar).append("\", ");
          }
        }
      }
    }
    System.out.println();
  }

  System.out.println();
  System.out.println("All scores:");
  edu.stanford.nlp.util.PriorityQueue<List<String>> pq = Counters.toPriorityQueue(allScores);
  while (!pq.isEmpty()) {
    List<String> key = pq.getFirst();
    double score = pq.getPriority(key);
    pq.removeFirst();
    System.out.println(key + "\t" + score);
  }

  System.out.println(" // Automatically generated by ParentAnnotationStats -- preferably don't edit");
  for (int i = 0; i < CUTOFFS.length; i++) {
    int len = javaSB[i].length();
    if (javaSB[i].charAt(len - 1) == '{') {
      // No context reached this cutoff. Emit an empty initializer; trimming two
      // characters here would eat the " {" of the header and produce invalid Java.
      javaSB[i].append("};");
    } else {
      // Replace the ", " separator after the last entry.
      javaSB[i].replace(len - 2, len, "};");
    }
    System.out.println(javaSB[i]);
  }
  System.out.print(" public static HashSet splitters = new HashSet(Arrays.asList(");
  for (int i = CUTOFFS.length; i > 0; i--) {
    if (i == 1) {
      System.out.print("splitters1");
    } else {
      System.out.print("selectiveSplit" + i + " ? splitters" + i + " : (");
    }
  }
  // Close the CUTOFFS.length - 1 conditional parens plus "new HashSet(" and "Arrays.asList(".
  for (int i = CUTOFFS.length; i >= 0; i--) {
    System.out.print(")");
  }
  System.out.println(";");
}
Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.
In class RandomWalk, method train:
/**
 * Accumulates co-occurrence counts in both directions for every pair in {@code data}.
 * The first element of a pair is treated as the seen value and the second as the hidden value.
 * {@code hiddenToSeen[hidden]} counts seen values and {@code seenToHidden[seen]} counts hidden
 * values; a fresh counter is created the first time a key is encountered.
 *
 * @param data (seen, hidden) pairs to count
 */
public void train(Collection<Pair<?, ?>> data) {
  for (Pair<?, ?> pair : data) {
    Object observed = pair.first();
    Object latent = pair.second();
    hiddenToSeen.computeIfAbsent(latent, unused -> new ClassicCounter()).incrementCount(observed);
    seenToHidden.computeIfAbsent(observed, unused -> new ClassicCounter()).incrementCount(latent);
  }
}
Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.
In class SisterAnnotationStats, method main:
/**
 * Calculate sister annotation statistics suitable for doing
 * selective sister splitting in the PCFGParser inside the
 * FactoredParser.
 * <p>
 * Reads the treebank at {@code args[0]}, applies a {@code SisterAnnotationStats} visitor
 * to every tree and prints the resulting statistics. If no arguments are given, prints a
 * usage line and returns.
 *
 * @param args path to the Treebank, optionally followed by the file encoding
 *     (defaults to {@code UTF-8})
 */
public static void main(String[] args) {
  if (args.length < 1) {
    System.out.println("Usage: SisterAnnotationStats treebankPath [encoding]");
    return;
  }
  String encoding = args.length > 1 ? args[1] : "UTF-8";
  SisterAnnotationStats pas = new SisterAnnotationStats();
  Treebank treebank = new DiskTreebank(in -> new PennTreeReader(in, new LabeledScoredTreeFactory(new StringLabelFactory()), new BobChrisTreeNormalizer()), encoding);
  treebank.loadPath(args[0]);
  treebank.apply(pas);
  pas.printStats();
}
Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.
In class SisterAnnotationStats, method sisterCounters:
/**
 * Records counts for node {@code t}, whose parent is {@code p}, keyed by {@code t}'s label.
 * <p>
 * Ensures the label has a counter in {@code nodeRules} and a map in each of {@code leftRules}
 * and {@code rightRules}. It then increments the {@code nodeRules} counter for the value
 * returned by {@code kidLabels(t)}. Finally it hands the left and right sister label lists to
 * {@code sideCounters} together with {@code leftRules} / {@code rightRules}.
 *
 * @param t the node being counted
 * @param p the parent of {@code t}
 */
protected void sisterCounters(Tree t, Tree p) {
  List kids = kidLabels(t);
  List leftSisters = leftSisterLabels(t, p);
  List rightSisters = rightSisterLabels(t, p);
  String category = t.label().value();

  // Lazily create the per-label containers the first time this label is seen.
  nodeRules.putIfAbsent(category, new ClassicCounter());
  rightRules.putIfAbsent(category, new HashMap());
  leftRules.putIfAbsent(category, new HashMap());

  ClassicCounter rewriteCounts = (ClassicCounter) nodeRules.get(category);
  rewriteCounts.incrementCount(kids);

  sideCounters(category, kids, leftSisters, leftRules);
  sideCounters(category, kids, rightSisters, rightRules);
}
Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.
In class SisterAnnotationStats, method printStats:
/**
 * Prints sister-annotation statistics and emits Java source for the {@code sisterSplitN}
 * arrays used for selective sister splitting.
 * <p>
 * For each node label, and for each left sister (annotated {@code label=l=sister}) and right
 * sister ({@code label=r=sister}) seen with it, prints the KL divergence between the
 * sister-conditioned counter and the label's counter in {@code nodeRules}. The divergence is
 * weighted by the sister counter's total count.
 * <p>
 * All annotated labels are then listed in descending score order. Each label is written into
 * exactly one array {@code sisterSplit(k+1)}: walking the sorted scores, {@code k} starts at
 * the last cutoff and moves down until {@code score >= CUTOFFS[k]}. Labels below
 * {@code CUTOFFS[0]} are omitted. (This presumes {@code CUTOFFS} is ascending — TODO confirm.)
 */
public void printStats() {
  NumberFormat nf = NumberFormat.getNumberInstance();
  nf.setMaximumFractionDigits(2);
  // One buffer per cutoff; each accumulates the Java source of one sisterSplitN array.
  StringBuffer[] javaSB = new StringBuffer[CUTOFFS.length];
  for (int i = 0; i < CUTOFFS.length; i++) {
    javaSB[i] = new StringBuffer(" private static String[] sisterSplit" + (i + 1) + " = new String[] {");
  }
  // topScores contains all enriched categories, to be sorted later
  List<Pair<String, Double>> topScores = new ArrayList<>();
  // Left sisters are processed before right sisters; the tag becomes part of the annotated label.
  String[] sideTags = {"=l=", "=r="};

  for (Object o : nodeRules.keySet()) {
    String label = (String) o;
    ClassicCounter cntr = (ClassicCounter) nodeRules.get(label);
    double support = cntr.totalCount();
    System.out.println("Node " + label + " support is " + support);
    List<Pair<String, Double>> answers = new ArrayList<>();
    HashMap[] sisterRulesBySide = {(HashMap) leftRules.get(label), (HashMap) rightRules.get(label)};
    for (int side = 0; side < sideTags.length; side++) {
      HashMap sisterRules = sisterRulesBySide[side];
      for (Object s : sisterRules.keySet()) {
        String sis = (String) s;
        ClassicCounter cntr2 = (ClassicCounter) sisterRules.get(sis);
        double support2 = cntr2.totalCount();
        // Uses the full conditional distribution. Holding the sister-context data out of
        // cntr instead was tried and abandoned: it can leave zero-probability events and
        // hence infinite divergence.
        double kl = Counters.klDivergence(cntr2, cntr);
        String annotatedLabel = label + sideTags[side] + sis;
        System.out.println("KL(" + annotatedLabel + "||" + label + ") = " + nf.format(kl) + "\t" + "support(" + sis + ") = " + support2);
        double score = kl * support2;
        answers.add(new Pair<>(annotatedLabel, score));
        topScores.add(new Pair<>(annotatedLabel, score));
      }
    }
    System.out.println("----");
    System.out.println("Sorted descending support * KL");
    Collections.sort(answers, (o1, o2) -> o2.second().compareTo(o1.second()));
    for (Pair<String, Double> answer : answers) {
      double psd = answer.second();
      System.out.println(answer.first() + ": " + nf.format(psd));
    }
    System.out.println();
  }

  Collections.sort(topScores, (o1, o2) -> o2.second().compareTo(o1.second()));
  for (Pair<String, Double> entry : topScores) {
    double psd = entry.second();
    System.out.println(entry.first() + ": " + nf.format(psd));
  }
  System.out.println();

  System.out.println(" // Automatically generated by SisterAnnotationStats -- preferably don't edit");
  // topScores is sorted descending, so the bucket index only ever moves down.
  int k = CUTOFFS.length - 1;
  for (Pair<String, Double> entry : topScores) {
    double psd = entry.second();
    while (k >= 0 && psd < CUTOFFS[k]) {
      k--;
    }
    if (k < 0) {
      break; // this and every remaining score is below CUTOFFS[0]
    }
    javaSB[k].append("\"").append(entry.first()).append("\", ");
  }
  for (StringBuffer sb : javaSB) {
    int len = sb.length();
    if (sb.charAt(len - 1) == '{') {
      // Nothing landed in this bucket: emit an empty initializer instead of mangling the header.
      sb.append("};");
    } else {
      // Replace only the ", " separator after the last entry, keeping its closing quote.
      sb.replace(len - 2, len, "};");
    }
    System.out.println(sb);
  }

  System.out.print(" public static String[] sisterSplit = ");
  for (int i = CUTOFFS.length; i > 0; i--) {
    if (i == 1) {
      System.out.print("sisterSplit1");
    } else {
      System.out.print("selectiveSisterSplit" + i + " ? sisterSplit" + i + " : (");
    }
  }
  // Exactly one "(" was opened for each of sisterSplit2..sisterSplitN.
  for (int i = 1; i < CUTOFFS.length; i++) {
    System.out.print(")");
  }
  System.out.println(";");
}
Aggregations