Search in sources :

Example 1 with ClassicCounter

use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.

the class ParentAnnotationStats method printStats.

/**
 * Prints the parent- and grandparent-annotation statistics to standard output, followed by
 * generated Java source for the {@code splittersN} arrays and the combined {@code splitters} set.
 *
 * <p>For every node label in {@code nodeRules}, each parent-annotated key in {@code pRules} whose
 * first element equals the label is scored as {@code KL(annotated || plain) * support(annotated)}.
 * The same is then done for every key in {@code pRules} (with support of at least
 * {@code SUPPCUTOFF}) against the matching grandparent-annotated keys in {@code gPRules}.
 * A key whose score reaches {@code CUTOFFS[j]} is appended to the j-th generated array.
 */
public void printStats() {
    NumberFormat nf = NumberFormat.getNumberInstance();
    nf.setMaximumFractionDigits(2);
    // One generated array declaration per cutoff; entries are appended while scoring.
    StringBuilder[] javaSB = new StringBuilder[CUTOFFS.length];
    for (int i = 0; i < CUTOFFS.length; i++) {
        javaSB[i] = new StringBuilder("  private static String[] splitters" + (i + 1) + " = new String[] {");
    }
    ClassicCounter<List<String>> allScores = new ClassicCounter<>();
    // do value of parent
    for (String node : nodeRules.keySet()) {
        List<Pair<List<String>, Double>> answers = new ArrayList<>();
        ClassicCounter<List<String>> cntr = nodeRules.get(node);
        double support = cntr.totalCount();
        System.out.println("Node " + node + " support is " + support);
        for (List<String> key : pRules.keySet()) {
            if (key.get(0).equals(node)) {
                // only do it if they match
                ClassicCounter<List<String>> cntr2 = pRules.get(key);
                double support2 = cntr2.totalCount();
                double kl = Counters.klDivergence(cntr2, cntr);
                System.out.println("KL(" + key + "||" + node + ") = " + nf.format(kl) + "\t" + "support(" + key + ") = " + support2);
                double score = kl * support2;
                answers.add(new Pair<>(key, Double.valueOf(score)));
                allScores.setCount(key, score);
            }
        }
        System.out.println("----");
        System.out.println("Sorted descending support * KL");
        answers.sort((o1, o2) -> o2.second().compareTo(o1.second()));
        for (Pair<List<String>, Double> answer : answers) {
            List<String> key = answer.first();
            double psd = answer.second();
            System.out.println(key + ": " + nf.format(psd));
            if (psd >= CUTOFFS[0]) {
                String nd = key.get(0);
                String par = key.get(1);
                for (int j = 0; j < CUTOFFS.length; j++) {
                    if (psd >= CUTOFFS[j]) {
                        javaSB[j].append("\"").append(nd).append("^");
                        javaSB[j].append(par).append("\", ");
                    }
                }
            }
        }
        System.out.println();
    }
    // do value of grandparent
    for (List<String> node : pRules.keySet()) {
        List<Pair<List<String>, Double>> answers = new ArrayList<>();
        ClassicCounter<List<String>> cntr = pRules.get(node);
        double support = cntr.totalCount();
        if (support < SUPPCUTOFF) {
            continue;
        }
        System.out.println("Node " + node + " support is " + support);
        for (List<String> key : gPRules.keySet()) {
            if (key.get(0).equals(node.get(0)) && key.get(1).equals(node.get(1))) {
                // only do it if they match
                ClassicCounter<List<String>> cntr2 = gPRules.get(key);
                double support2 = cntr2.totalCount();
                double kl = Counters.klDivergence(cntr2, cntr);
                System.out.println("KL(" + key + "||" + node + ") = " + nf.format(kl) + "\t" + "support(" + key + ") = " + support2);
                double score = kl * support2;
                answers.add(new Pair<>(key, Double.valueOf(score)));
                allScores.setCount(key, score);
            }
        }
        System.out.println("----");
        System.out.println("Sorted descending support * KL");
        answers.sort((o1, o2) -> o2.second().compareTo(o1.second()));
        for (Pair<List<String>, Double> answer : answers) {
            List<String> key = answer.first();
            double psd = answer.second();
            System.out.println(key + ": " + nf.format(psd));
            if (psd >= CUTOFFS[0]) {
                String nd = key.get(0);
                String par = key.get(1);
                String gpar = key.get(2);
                for (int j = 0; j < CUTOFFS.length; j++) {
                    if (psd >= CUTOFFS[j]) {
                        javaSB[j].append("\"").append(nd).append("^");
                        javaSB[j].append(par).append("~");
                        javaSB[j].append(gpar).append("\", ");
                    }
                }
            }
        }
        System.out.println();
    }
    System.out.println();
    System.out.println("All scores:");
    edu.stanford.nlp.util.PriorityQueue<List<String>> pq = Counters.toPriorityQueue(allScores);
    while (!pq.isEmpty()) {
        List<String> key = pq.getFirst();
        double score = pq.getPriority(key);
        pq.removeFirst();
        System.out.println(key + "\t" + score);
    }
    System.out.println("  // Automatically generated by ParentAnnotationStats -- preferably don't edit");
    for (int i = 0; i < CUTOFFS.length; i++) {
        StringBuilder sb = javaSB[i];
        int len = sb.length();
        // Drop the trailing ", " only if an entry was actually appended; an empty bucket must
        // keep its opening brace so that the output is "new String[] {};".
        if (len >= 2 && sb.charAt(len - 2) == ',' && sb.charAt(len - 1) == ' ') {
            sb.setLength(len - 2);
        }
        sb.append("};");
        System.out.println(sb);
    }
    System.out.print("  public static HashSet splitters = new HashSet(Arrays.asList(");
    for (int i = CUTOFFS.length; i > 0; i--) {
        if (i == 1) {
            System.out.print("splitters1");
        } else {
            System.out.print("selectiveSplit" + i + " ? splitters" + i + " : (");
        }
    }
    // Closes the (CUTOFFS.length - 1) ternary parentheses plus the two opened by
    // "new HashSet(Arrays.asList(".
    for (int i = CUTOFFS.length; i >= 0; i--) {
        System.out.print(")");
    }
    System.out.println(";");
}
Also used : java.util(java.util) ClassicCounter(edu.stanford.nlp.stats.ClassicCounter) NumberFormat(java.text.NumberFormat) Pair(edu.stanford.nlp.util.Pair)

Example 2 with ClassicCounter

use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.

the class RandomWalk method train.

/**
 * Accumulates co-occurrence counts from the given pairs.
 *
 * <p>For each pair, the first element is treated as the "seen" value and the second as the
 * "hidden" value. The count of {@code seen} is incremented in the counter stored under
 * {@code hidden} in {@code hiddenToSeen}, and the count of {@code hidden} is incremented in the
 * counter stored under {@code seen} in {@code seenToHidden}. Missing counters are created on
 * first use.
 *
 * @param data the (seen, hidden) pairs to count
 */
public void train(Collection<Pair<?, ?>> data) {
    for (Pair<?, ?> p : data) {
        Object seen = p.first();
        Object hidden = p.second();
        // computeIfAbsent creates the counter on first use and returns it in a single lookup.
        hiddenToSeen.computeIfAbsent(hidden, k -> new ClassicCounter<>()).incrementCount(seen);
        seenToHidden.computeIfAbsent(seen, k -> new ClassicCounter<>()).incrementCount(hidden);
    }
}
Also used : ClassicCounter(edu.stanford.nlp.stats.ClassicCounter) Pair(edu.stanford.nlp.util.Pair)

Example 3 with ClassicCounter

use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.

the class SisterAnnotationStats method main.

/**
 * Calculate sister annotation statistics suitable for doing
 * selective sister splitting in the PCFGParser inside the
 * FactoredParser.
 *
 * @param args {@code treebankPath [encoding]}: the path to the Treebank, optionally followed
 *     by the character encoding used to read it (defaults to UTF-8)
 */
public static void main(String[] args) {
    // Sanity check of the KL divergence routine on a tiny counter; the result is just printed.
    ClassicCounter<String> c = new ClassicCounter<>();
    c.setCount("A", 0);
    c.setCount("B", 1);
    double d = Counters.klDivergence(c, c);
    System.out.println("KL Divergence: " + d);
    String encoding = "UTF-8";
    if (args.length > 1) {
        encoding = args[1];
    }
    if (args.length < 1) {
        System.out.println("Usage: SisterAnnotationStats treebankPath [encoding]");
    } else {
        SisterAnnotationStats pas = new SisterAnnotationStats();
        Treebank treebank = new DiskTreebank(in -> new PennTreeReader(in, new LabeledScoredTreeFactory(new StringLabelFactory()), new BobChrisTreeNormalizer()), encoding);
        treebank.loadPath(args[0]);
        treebank.apply(pas);
        pas.printStats();
    }
}
Also used : StringLabelFactory(edu.stanford.nlp.ling.StringLabelFactory) ClassicCounter(edu.stanford.nlp.stats.ClassicCounter)

Example 4 with ClassicCounter

use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.

the class SisterAnnotationStats method sisterCounters.

/**
 * Records the statistics for one tree node.
 *
 * <p>Increments, in the {@code nodeRules} counter for the node's label, the count of the list
 * returned by {@code kidLabels(t)}, and passes that list together with the left and right
 * sister label lists to {@code sideCounters}. Entries for the label are created in
 * {@code nodeRules}, {@code rightRules} and {@code leftRules} if they do not exist yet.
 *
 * @param t the node being counted
 * @param p the tree passed to {@code leftSisterLabels}/{@code rightSisterLabels} alongside
 *     {@code t} (presumably its parent -- confirm against the caller)
 */
protected void sisterCounters(Tree t, Tree p) {
    final List kids = kidLabels(t);
    final List leftSisters = leftSisterLabels(t, p);
    final List rightSisters = rightSisterLabels(t, p);
    final String category = t.label().value();

    // Make sure every table has an entry for this category before it is used below.
    final boolean knownNode = nodeRules.containsKey(category);
    if (!knownNode) {
        nodeRules.put(category, new ClassicCounter());
    }
    final boolean knownRight = rightRules.containsKey(category);
    if (!knownRight) {
        rightRules.put(category, new HashMap());
    }
    final boolean knownLeft = leftRules.containsKey(category);
    if (!knownLeft) {
        leftRules.put(category, new HashMap());
    }

    final ClassicCounter rewriteCounts = (ClassicCounter) nodeRules.get(category);
    rewriteCounts.incrementCount(kids);

    sideCounters(category, kids, leftSisters, leftRules);
    sideCounters(category, kids, rightSisters, rightRules);
}
Also used : ClassicCounter(edu.stanford.nlp.stats.ClassicCounter)

Example 5 with ClassicCounter

use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.

the class SisterAnnotationStats method printStats.

/**
 * Prints the sister-annotation statistics to standard output, followed by generated Java source
 * for the {@code sisterSplitN} arrays and the combined {@code sisterSplit} expression.
 *
 * <p>For every label in {@code nodeRules}, each left and right sister recorded in
 * {@code leftRules}/{@code rightRules} is scored as {@code KL(annotated || plain) * support}.
 * Scores are printed per label and then globally, both in descending order. Each annotated label
 * is finally placed in the generated array of the last cutoff in {@code CUTOFFS} that its score
 * reaches, scanning from the end of {@code CUTOFFS} towards the start.
 */
public void printStats() {
    NumberFormat nf = NumberFormat.getNumberInstance();
    nf.setMaximumFractionDigits(2);
    // One generated array declaration per cutoff.
    StringBuilder[] javaSB = new StringBuilder[CUTOFFS.length];
    for (int i = 0; i < CUTOFFS.length; i++) {
        javaSB[i] = new StringBuilder("  private static String[] sisterSplit" + (i + 1) + " = new String[] {");
    }
    // topScores contains all enriched categories, to be sorted later
    List<Pair<String, Double>> topScores = new ArrayList<>();
    for (Object o : nodeRules.keySet()) {
        List<Pair<String, Double>> answers = new ArrayList<>();
        String label = (String) o;
        ClassicCounter cntr = (ClassicCounter) nodeRules.get(label);
        double support = cntr.totalCount();
        System.out.println("Node " + label + " support is " + support);
        scoreSisters(label, cntr, (HashMap) leftRules.get(label), "=l=", nf, answers, topScores);
        scoreSisters(label, cntr, (HashMap) rightRules.get(label), "=r=", nf, answers, topScores);
        System.out.println("----");
        System.out.println("Sorted descending support * KL");
        answers.sort((p1, p2) -> p2.second().compareTo(p1.second()));
        for (Pair<String, Double> answer : answers) {
            System.out.println(answer.first() + ": " + nf.format(answer.second().doubleValue()));
        }
        System.out.println();
    }
    topScores.sort((p1, p2) -> p2.second().compareTo(p1.second()));
    for (Pair<String, Double> topScore : topScores) {
        System.out.println(topScore.first() + ": " + nf.format(topScore.second().doubleValue()));
    }
    System.out.println();
    System.out.println("  // Automatically generated by SisterAnnotationStats -- preferably don't edit");
    // topScores is in descending order, so k only ever moves towards 0 while walking the list.
    int k = CUTOFFS.length - 1;
    for (Pair<String, Double> topScore : topScores) {
        double psd = topScore.second();
        while (k > 0 && psd < CUTOFFS[k]) {
            k--;
        }
        if (psd < CUTOFFS[k]) {
            // Below the first cutoff: neither this entry nor any later one qualifies.
            break;
        }
        javaSB[k].append("\"").append(topScore.first());
        javaSB[k].append("\", ");
    }
    for (int i = 0; i < CUTOFFS.length; i++) {
        StringBuilder sb = javaSB[i];
        int len = sb.length();
        // Drop the trailing ", " only if an entry was appended, so that the closing quote of the
        // last entry and the opening brace of an empty array are both preserved.
        if (len >= 2 && sb.charAt(len - 2) == ',' && sb.charAt(len - 1) == ' ') {
            sb.setLength(len - 2);
        }
        sb.append("};");
        System.out.println(sb);
    }
    System.out.print("  public static String[] sisterSplit = ");
    for (int i = CUTOFFS.length; i > 0; i--) {
        if (i == 1) {
            System.out.print("sisterSplit1");
        } else {
            System.out.print("selectiveSisterSplit" + i + " ? sisterSplit" + i + " : (");
        }
    }
    // Close exactly the parentheses opened above: one per ternary, i.e. CUTOFFS.length - 1.
    for (int i = 1; i < CUTOFFS.length; i++) {
        System.out.print(")");
    }
    System.out.println(";");
}

/**
 * Scores every sister of {@code label} found in {@code sisterRules}, prints one line per sister,
 * and adds an (annotated label, KL * support) pair to both {@code answers} and {@code topScores}.
 *
 * <p>The score uses the full distribution. Holding out the test-context data instead was tried
 * and does not work: it can produce zero-probability data points and hence infinite divergence.
 * The counters are raw because the element type of the rule tables is not visible here.
 */
@SuppressWarnings({"rawtypes", "unchecked"})
private void scoreSisters(String label, ClassicCounter cntr, HashMap sisterRules, String sideTag, NumberFormat nf, List<Pair<String, Double>> answers, List<Pair<String, Double>> topScores) {
    for (Object o : sisterRules.keySet()) {
        String sis = (String) o;
        ClassicCounter cntr2 = (ClassicCounter) sisterRules.get(sis);
        double support2 = cntr2.totalCount();
        double kl = Counters.klDivergence(cntr2, cntr);
        String annotatedLabel = label + sideTag + sis;
        System.out.println("KL(" + annotatedLabel + "||" + label + ") = " + nf.format(kl) + "\t" + "support(" + sis + ") = " + support2);
        Double score = Double.valueOf(kl * support2);
        // Separate Pair instances, as before, so the two lists never share an element.
        answers.add(new Pair<>(annotatedLabel, score));
        topScores.add(new Pair<>(annotatedLabel, score));
    }
}
Also used : ClassicCounter(edu.stanford.nlp.stats.ClassicCounter) NumberFormat(java.text.NumberFormat) Pair(edu.stanford.nlp.util.Pair)

Aggregations

ClassicCounter (edu.stanford.nlp.stats.ClassicCounter)69 CoreLabel (edu.stanford.nlp.ling.CoreLabel)27 ArrayList (java.util.ArrayList)21 CoreAnnotations (edu.stanford.nlp.ling.CoreAnnotations)18 Tree (edu.stanford.nlp.trees.Tree)13 Pair (edu.stanford.nlp.util.Pair)11 Counter (edu.stanford.nlp.stats.Counter)10 List (java.util.List)10 Mention (edu.stanford.nlp.coref.data.Mention)8 Language (edu.stanford.nlp.international.Language)7 RuntimeIOException (edu.stanford.nlp.io.RuntimeIOException)7 CoreMap (edu.stanford.nlp.util.CoreMap)7 IOUtils (edu.stanford.nlp.io.IOUtils)6 Label (edu.stanford.nlp.ling.Label)6 TreebankLangParserParams (edu.stanford.nlp.parser.lexparser.TreebankLangParserParams)6 PrintWriter (java.io.PrintWriter)6 java.util (java.util)6 HashSet (java.util.HashSet)6 RVFDatum (edu.stanford.nlp.ling.RVFDatum)5 DiskTreebank (edu.stanford.nlp.trees.DiskTreebank)5