Search in sources:

Example 1 with ClassicCounter

Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.

Class ParentAnnotationStats, method printStats.

/**
 * Prints parent- and grandparent-annotation split statistics and emits Java source text
 * for the generated {@code splittersN} arrays.
 *
 * <p>Pass 1 compares, for every node label, each parent-conditioned rewrite distribution
 * ({@code pRules}) with the label's overall distribution ({@code nodeRules}). Pass 2
 * compares each grandparent-conditioned distribution ({@code gPRules}) with the matching
 * parent-conditioned one ({@code pRules}). Each context is scored as
 * {@code Counters.klDivergence(...) * support} and printed in descending order. A context
 * whose score is at least {@code CUTOFFS[j]} is appended to the j-th generated array.
 *
 * <p>NOTE(review): this is a web-scraped excerpt. Closing braces and some statements are
 * missing, so the code is not compilable as shown.
 */
public void printStats() {
    NumberFormat nf = NumberFormat.getNumberInstance();
    // System.out.println("Node rules");
    // System.out.println(nodeRules);
    // System.out.println("Parent rules");
    // System.out.println(pRules);
    // System.out.println("Grandparent rules");
    // System.out.println(gPRules);
    // Store java code for selSplit: buffer i holds the source text of array "splitters(i+1)".
    StringBuffer[] javaSB = new StringBuffer[CUTOFFS.length];
    for (int i = 0; i < CUTOFFS.length; i++) {
        javaSB[i] = new StringBuffer("  private static String[] splitters" + (i + 1) + " = new String[] {");
    // Score of every parent and grandparent context, for the final "All scores" listing.
    ClassicCounter<List<String>> allScores = new ClassicCounter<>();
    // do value of parent: keys of pRules are [node, parent]
    for (String node : nodeRules.keySet()) {
        ArrayList<Pair<List<String>, Double>> answers = Generics.newArrayList();
        ClassicCounter<List<String>> cntr = nodeRules.get(node);
        double support = (cntr.totalCount());
        System.out.println("Node " + node + " support is " + support);
        for (List<String> key : pRules.keySet()) {
            if (key.get(0).equals(node)) {
                // only do it if they match
                ClassicCounter<List<String>> cntr2 = pRules.get(key);
                double support2 = (cntr2.totalCount());
                // Divergence of the parent-conditioned distribution from the node's overall one.
                double kl = Counters.klDivergence(cntr2, cntr);
                System.out.println("KL(" + key + "||" + node + ") = " + nf.format(kl) + "\t" + "support(" + key + ") = " + support2);
                // Weight the divergence by how often this context was observed.
                double score = kl * support2;
                // NOTE(review): new Double(double) is deprecated since Java 9; prefer Double.valueOf / autoboxing.
                answers.add(new Pair<>(key, new Double(score)));
                allScores.setCount(key, score);
        System.out.println("Sorted descending support * KL");
        // Descending by score.
        Collections.sort(answers, (o1, o2) -> o2.second().compareTo(o1.second()));
        for (Pair<List<String>, Double> answer : answers) {
            // NOTE(review): answer is already a typed Pair; this raw cast and the Double cast are redundant.
            Pair p = (Pair) answer;
            double psd = ((Double) p.second()).doubleValue();
            System.out.println(p.first() + ": " + nf.format(psd));
            // Gate on CUTOFFS[0]; presumably the smallest cutoff (CUTOFFS ascending) — TODO confirm.
            if (psd >= CUTOFFS[0]) {
                List lst = (List) p.first();
                // NOTE(review): nd is not used in the visible lines.
                String nd = (String) lst.get(0);
                String par = (String) lst.get(1);
                for (int j = 0; j < CUTOFFS.length; j++) {
                    if (psd >= CUTOFFS[j]) {
                        // NOTE(review): as shown, only a closing quote is appended (no opening quote),
                        // so the generated array text would not be valid Java; confirm against full source.
                        javaSB[j].append(par).append("\", ");
    // do value of grandparent: keys of gPRules are [node, parent, grandparent]
    for (List<String> node : pRules.keySet()) {
        ArrayList<Pair<List<String>, Double>> answers = Generics.newArrayList();
        ClassicCounter<List<String>> cntr = pRules.get(node);
        double support = (cntr.totalCount());
        // Low-support parent contexts are presumably skipped; this if's body is missing from the excerpt.
        if (support < SUPPCUTOFF) {
        System.out.println("Node " + node + " support is " + support);
        for (List<String> key : gPRules.keySet()) {
            // Match on both node label and parent label.
            if (key.get(0).equals(node.get(0)) && key.get(1).equals(node.get(1))) {
                // only do it if they match
                ClassicCounter<List<String>> cntr2 = gPRules.get(key);
                double support2 = (cntr2.totalCount());
                double kl = Counters.klDivergence(cntr2, cntr);
                System.out.println("KL(" + key + "||" + node + ") = " + nf.format(kl) + "\t" + "support(" + key + ") = " + support2);
                double score = kl * support2;
                // NOTE(review): new Double(double) is deprecated since Java 9.
                answers.add(Pair.makePair(key, new Double(score)));
                allScores.setCount(key, score);
        System.out.println("Sorted descending support * KL");
        Collections.sort(answers, (o1, o2) -> o2.second().compareTo(o1.second()));
        for (Pair<List<String>, Double> answer : answers) {
            Pair p = (Pair) answer;
            double psd = ((Double) p.second()).doubleValue();
            System.out.println(p.first() + ": " + nf.format(psd));
            if (psd >= CUTOFFS[0]) {
                List lst = (List) p.first();
                // NOTE(review): nd and par are not used in the visible lines.
                String nd = (String) lst.get(0);
                String par = (String) lst.get(1);
                String gpar = (String) lst.get(2);
                for (int j = 0; j < CUTOFFS.length; j++) {
                    if (psd >= CUTOFFS[j]) {
                        // NOTE(review): same missing-opening-quote concern as the parent pass.
                        javaSB[j].append(gpar).append("\", ");
    System.out.println("All scores:");
    // Print every scored context in priority-queue order.
    edu.stanford.nlp.util.PriorityQueue<List<String>> pq = Counters.toPriorityQueue(allScores);
    // NOTE(review): no visible statement removes from pq; unless a removal line was dropped
    // from this excerpt (or getFirst() removes), this loop never terminates.
    while (!pq.isEmpty()) {
        List<String> key = pq.getFirst();
        double score = pq.getPriority(key);
        System.out.println(key + "\t" + score);
    System.out.println("  // Automatically generated by ParentAnnotationStats -- preferably don't edit");
    // Replace the trailing ", " separator with "};".
    // NOTE(review): if nothing reached CUTOFFS[i], the buffer still ends in " {", which becomes
    // "new String[]};" — not valid Java.
    for (int i = 0; i < CUTOFFS.length; i++) {
        int len = javaSB[i].length();
        javaSB[i].replace(len - 2, len, "};");
    // Emit a nested conditional choosing splittersN by the selectiveSplitN flags, highest N first.
    System.out.print("  public static HashSet splitters = new HashSet(Arrays.asList(");
    for (int i = CUTOFFS.length; i > 0; i--) {
        if (i == 1) {
        } else {
            System.out.print("selectiveSplit" + i + " ? splitters" + i + " : (");
    // need to print extra one to close other things open
    for (int i = CUTOFFS.length; i >= 0; i--) {
Also used: java.util, ClassicCounter (edu.stanford.nlp.stats.ClassicCounter), NumberFormat (java.text.NumberFormat), Pair (edu.stanford.nlp.util.Pair)

Example 2 with ClassicCounter

Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.

Class RandomWalk, method train.

/**
 * Ensures a counter exists for each seen and hidden value in {@code data}.
 *
 * <p>For every pair, an empty {@code ClassicCounter} is created in {@code hiddenToSeen} for
 * the hidden value and in {@code seenToHidden} for the seen value if none exists yet. The
 * statements that then update those counters are not part of this excerpt. Presumably they
 * increment the cross counts — TODO confirm.
 *
 * @param data pairs whose first element is the seen value and whose second is the hidden value
 */
public void train(Collection<Pair<?, ?>> data) {
    for (Pair p : data) {
        Object seen = p.first();
        Object hidden = p.second();
        // NOTE(review): assuming hiddenToSeen/seenToHidden are java.util.Maps, containsKey(...)
        // (or computeIfAbsent) is the direct form; new ClassicCounter() is a raw type.
        if (!hiddenToSeen.keySet().contains(hidden)) {
            hiddenToSeen.put(hidden, new ClassicCounter());
        if (!seenToHidden.keySet().contains(seen)) {
            seenToHidden.put(seen, new ClassicCounter());
Also used: ClassicCounter (edu.stanford.nlp.stats.ClassicCounter), Pair (edu.stanford.nlp.util.Pair)

Example 3 with ClassicCounter

Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.

Class SisterAnnotationStats, method main.

/**
 * Calculates sister-annotation statistics suitable for selective sister splitting
 * in the PCFGParser inside the FactoredParser.
 *
 * @param args the path to the treebank, optionally followed by the treebank file
 *     encoding (default UTF-8)
 */
public static void main(String[] args) {
    // NOTE(review): leftover sanity check — prints the KL divergence of a counter with itself
    // before any real work; consider removing.
    ClassicCounter<String> c = new ClassicCounter<>();
    c.setCount("A", 0);
    c.setCount("B", 1);
    double d = Counters.klDivergence(c, c);
    System.out.println("KL Divergence: " + d);
    // Optional second argument overrides the treebank file encoding.
    String encoding = "UTF-8";
    if (args.length > 1) {
        encoding = args[1];
    if (args.length < 1) {
        // NOTE(review): the usage text names ParentAnnotationStats although this is
        // SisterAnnotationStats.main, and it omits the optional encoding argument.
        System.out.println("Usage: ParentAnnotationStats treebankPath");
    } else {
        SisterAnnotationStats pas = new SisterAnnotationStats();
        // Penn-format trees, string labels, normalized with BobChrisTreeNormalizer.
        // Loading args[0] and running the statistics are not visible in this excerpt.
        Treebank treebank = new DiskTreebank(in -> new PennTreeReader(in, new LabeledScoredTreeFactory(new StringLabelFactory()), new BobChrisTreeNormalizer()), encoding);
Also used: StringLabelFactory (edu.stanford.nlp.ling.StringLabelFactory), ClassicCounter (edu.stanford.nlp.stats.ClassicCounter)

Example 4 with ClassicCounter

Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.

Class SisterAnnotationStats, method sisterCounters.

/**
 * Counts the rewrite of {@code t} under its label and passes its left and right sister
 * labels to {@code sideCounters}.
 *
 * <p>The rewrite is whatever {@code kidLabels(t)} returns. The sister lists come from
 * {@code leftSisterLabels(t, p)} and {@code rightSisterLabels(t, p)}. Per-label containers
 * in {@code nodeRules}, {@code rightRules} and {@code leftRules} are created on first use.
 *
 * @param t the node whose rewrite is counted
 * @param p presumably the parent of {@code t}, used to locate its sisters — TODO confirm
 */
protected void sisterCounters(Tree t, Tree p) {
    List rewrite = kidLabels(t);
    List left = leftSisterLabels(t, p);
    List right = rightSisterLabels(t, p);
    String label = t.label().value();
    // Lazily create per-label containers.
    // NOTE(review): raw types throughout; Map.computeIfAbsent would fold check-and-put.
    if (!nodeRules.containsKey(label)) {
        nodeRules.put(label, new ClassicCounter());
    if (!rightRules.containsKey(label)) {
        rightRules.put(label, new HashMap());
    if (!leftRules.containsKey(label)) {
        leftRules.put(label, new HashMap());
    ((ClassicCounter) nodeRules.get(label)).incrementCount(rewrite);
    sideCounters(label, rewrite, left, leftRules);
    sideCounters(label, rewrite, right, rightRules);
Also used: ClassicCounter (edu.stanford.nlp.stats.ClassicCounter)

Example 5 with ClassicCounter

Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.

Class SisterAnnotationStats, method printStats.

/**
 * Prints sister-annotation split statistics and emits Java source text for the generated
 * {@code sisterSplitN} arrays.
 *
 * <p>For every node label, each left-sister-conditioned ({@code leftRules}) and each
 * right-sister-conditioned ({@code rightRules}) rewrite distribution is compared with the
 * label's overall distribution ({@code nodeRules}) via {@code Counters.klDivergence}. Each
 * annotated label ({@code label=l=sister} or {@code label=r=sister}) is scored as KL * support.
 * The results are printed per label and then globally, in descending score order.
 *
 * <p>NOTE(review): this is a web-scraped excerpt. Closing braces and some statements are
 * missing, so the code is not compilable as shown.
 */
public void printStats() {
    NumberFormat nf = NumberFormat.getNumberInstance();
    // System.out.println("Node rules");
    // System.out.println(nodeRules);
    // System.out.println("Parent rules");
    // System.out.println(pRules);
    // System.out.println("Grandparent rules");
    // System.out.println(gPRules);
    // Store java code for selSplit: buffer i holds the source text of array "sisterSplit(i+1)".
    StringBuffer[] javaSB = new StringBuffer[CUTOFFS.length];
    for (int i = 0; i < CUTOFFS.length; i++) {
        javaSB[i] = new StringBuffer("  private static String[] sisterSplit" + (i + 1) + " = new String[] {");
    /* topScores collects every (annotated label, score) pair across all labels;
     * it is sorted after the main loop. */
    ArrayList topScores = new ArrayList();
    for (Object o : nodeRules.keySet()) {
        ArrayList answers = new ArrayList();
        String label = (String) o;
        ClassicCounter cntr = (ClassicCounter) nodeRules.get(label);
        double support = (cntr.totalCount());
        System.out.println("Node " + label + " support is " + support);
        // Left-sister contexts.
        for (Object o4 : ((HashMap) leftRules.get(label)).keySet()) {
            String sis = (String) o4;
            ClassicCounter cntr2 = (ClassicCounter) ((HashMap) leftRules.get(label)).get(sis);
            double support2 = (cntr2.totalCount());
            /* alternative 1: use full distribution to calculate score */
            double kl = Counters.klDivergence(cntr2, cntr);
            /* alternative 2 (rejected): hold out the test-context data, i.e. subtract cntr2's
             * counts from cntr and compute KL of cntr2 against the remainder. This doesn't
             * work because it can lead to zero-probability data points, hence infinite
             * divergence. */
            String annotatedLabel = label + "=l=" + sis;
            System.out.println("KL(" + annotatedLabel + "||" + label + ") = " + nf.format(kl) + "\t" + "support(" + sis + ") = " + support2);
            // NOTE(review): new Double(double) is deprecated since Java 9; the score is also computed twice.
            answers.add(new Pair(annotatedLabel, new Double(kl * support2)));
            topScores.add(new Pair(annotatedLabel, new Double(kl * support2)));
        // Right-sister contexts; scored exactly like the left-sister ones.
        for (Object o3 : ((HashMap) rightRules.get(label)).keySet()) {
            String sis = (String) o3;
            ClassicCounter cntr2 = (ClassicCounter) ((HashMap) rightRules.get(label)).get(sis);
            double support2 = (cntr2.totalCount());
            double kl = Counters.klDivergence(cntr2, cntr);
            String annotatedLabel = label + "=r=" + sis;
            System.out.println("KL(" + annotatedLabel + "||" + label + ") = " + nf.format(kl) + "\t" + "support(" + sis + ") = " + support2);
            answers.add(new Pair(annotatedLabel, new Double(kl * support2)));
            topScores.add(new Pair(annotatedLabel, new Double(kl * support2)));
        // Rank this label's annotated contexts by descending score.
        System.out.println("Sorted descending support * KL");
        Collections.sort(answers, (o1, o2) -> {
            Pair p1 = (Pair) o1;
            Pair p2 = (Pair) o2;
            Double p12 = (Double) p1.second();
            Double p22 = (Double) p2.second();
            return p22.compareTo(p12);
        for (Object answer : answers) {
            Pair p = (Pair) answer;
            double psd = ((Double) p.second()).doubleValue();
            System.out.println(p.first() + ": " + nf.format(psd));
            // Gate on CUTOFFS[0]; presumably the smallest cutoff (CUTOFFS ascending) — TODO confirm.
            if (psd >= CUTOFFS[0]) {
                String annotatedLabel = (String) p.first();
                // NOTE(review): the statement that records annotatedLabel for each passed cutoff
                // is missing from this excerpt.
                for (double CUTOFF : CUTOFFS) {
                    if (psd >= CUTOFF) {
    // Global ranking across all labels, descending by score.
    Collections.sort(topScores, (o1, o2) -> {
        Pair p1 = (Pair) o1;
        Pair p2 = (Pair) o2;
        Double p12 = (Double) p1.second();
        Double p22 = (Double) p2.second();
        return p22.compareTo(p12);
    // NOTE(review): outString is not used in the visible lines, so this header is never printed as shown.
    String outString = "All enriched categories, sorted by score\n";
    for (Object topScore : topScores) {
        Pair p = (Pair) topScore;
        double psd = ((Double) p.second()).doubleValue();
        System.out.println(p.first() + ": " + nf.format(psd));
    System.out.println("  // Automatically generated by SisterAnnotationStats -- preferably don't edit");
    // Walk the descending topScores with k starting at the highest cutoff index. When an entry
    // falls below CUTOFFS[k], j is decremented so the same entry is re-examined on the next
    // iteration. Lowering k (and the k == 0 branch) happens in lines missing from this excerpt
    // — TODO confirm.
    int k = CUTOFFS.length - 1;
    for (int j = 0; j < topScores.size(); j++) {
        Pair p = (Pair) topScores.get(j);
        double psd = ((Double) p.second()).doubleValue();
        if (psd < CUTOFFS[k]) {
            if (k == 0) {
            } else {
                // messy but should do it
                j -= 1;
    // Replace the trailing ", " separator with "};".
    // NOTE(review): if nothing was appended to buffer i, it still ends in " {", which becomes
    // "new String[]};" — not valid Java.
    for (int i = 0; i < CUTOFFS.length; i++) {
        int len = javaSB[i].length();
        javaSB[i].replace(len - 2, len, "};");
    // Emit a nested conditional choosing sisterSplitN by the selectiveSisterSplitN flags, highest N first.
    System.out.print("  public static String[] sisterSplit = ");
    for (int i = CUTOFFS.length; i > 0; i--) {
        if (i == 1) {
        } else {
            System.out.print("selectiveSisterSplit" + i + " ? sisterSplit" + i + " : (");
    // need to print extra one to close other things open
    for (int i = CUTOFFS.length; i >= 0; i--) {
Also used: ClassicCounter (edu.stanford.nlp.stats.ClassicCounter), NumberFormat (java.text.NumberFormat), Pair (edu.stanford.nlp.util.Pair)


Related classes (with counts where the listing provides them): ClassicCounter (edu.stanford.nlp.stats.ClassicCounter) 69, CoreLabel (edu.stanford.nlp.ling.CoreLabel) 27, ArrayList (java.util.ArrayList) 21, CoreAnnotations (edu.stanford.nlp.ling.CoreAnnotations) 18, Tree (edu.stanford.nlp.trees.Tree) 13, Pair (edu.stanford.nlp.util.Pair) 11, Counter (edu.stanford.nlp.stats.Counter) 10, List (java.util.List) 10, Mention, Language, RuntimeIOException, CoreMap (edu.stanford.nlp.util.CoreMap) 7, IOUtils, Label (edu.stanford.nlp.ling.Label) 6, TreebankLangParserParams (edu.stanford.nlp.parser.lexparser.TreebankLangParserParams) 6, PrintWriter, java.util (java.util) 6, HashSet (java.util.HashSet) 6, RVFDatum (edu.stanford.nlp.ling.RVFDatum) 5, DiskTreebank (edu.stanford.nlp.trees.DiskTreebank) 5