Example 31 with ClassicCounter

Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp: the class SisterAnnotationStats, method main.

/**
   * Calculate sister annotation statistics suitable for doing
   * selective sister splitting in the PCFGParser inside the
   * FactoredParser.
   *
   * @param args Path to the Treebank, and optionally a character encoding (defaults to UTF-8)
   */
public static void main(String[] args) {
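    // Quick self-check of the Counters API (unrelated to the statistics below):
    // the KL divergence of a counter with itself should come out as 0.0.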
    ClassicCounter<String> c = new ClassicCounter<>();
    c.setCount("A", 0);
    c.setCount("B", 1);
    double d = Counters.klDivergence(c, c);
    System.out.println("KL Divergence: " + d);
    String encoding = "UTF-8";
    if (args.length > 1) {
        encoding = args[1];
    }
    if (args.length < 1) {
        System.out.println("Usage: ParentAnnotationStats treebankPath");
    } else {
        SisterAnnotationStats pas = new SisterAnnotationStats();
        Treebank treebank = new DiskTreebank(in -> new PennTreeReader(in, new LabeledScoredTreeFactory(new StringLabelFactory()), new BobChrisTreeNormalizer()), encoding);
        treebank.loadPath(args[0]);
        treebank.apply(pas);
        pas.printStats();
    }
}
Also used: StringLabelFactory(edu.stanford.nlp.ling.StringLabelFactory) ClassicCounter(edu.stanford.nlp.stats.ClassicCounter)
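
For reference, here is a minimal standalone sketch of the setCount/klDivergence calls used above, relying only on the ClassicCounter and Counters methods that appear in this example (the class name KlDivergenceDemo is made up for illustration):

import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counters;

public class KlDivergenceDemo {
    public static void main(String[] args) {
        ClassicCounter<String> p = new ClassicCounter<>();
        p.setCount("A", 2);
        p.setCount("B", 2);
        ClassicCounter<String> q = new ClassicCounter<>();
        q.setCount("A", 3);
        q.setCount("B", 1);
        // KL(p || p) is 0.0; KL(q || p) is positive because the two distributions differ
        System.out.println("KL(p||p) = " + Counters.klDivergence(p, p));
        System.out.println("KL(q||p) = " + Counters.klDivergence(q, p));
    }
}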

Example 32 with ClassicCounter

Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp: the class SisterAnnotationStats, method sisterCounters.

protected void sisterCounters(Tree t, Tree p) {
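    // Tally how node t rewrites (its children's labels), keyed by t's own label and,
    // separately, by each of t's left and right sisters under parent p.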
    List rewrite = kidLabels(t);
    List left = leftSisterLabels(t, p);
    List right = rightSisterLabels(t, p);
    String label = t.label().value();
    if (!nodeRules.containsKey(label)) {
        nodeRules.put(label, new ClassicCounter());
    }
    if (!rightRules.containsKey(label)) {
        rightRules.put(label, new HashMap());
    }
    if (!leftRules.containsKey(label)) {
        leftRules.put(label, new HashMap());
    }
    ((ClassicCounter) nodeRules.get(label)).incrementCount(rewrite);
    sideCounters(label, rewrite, left, leftRules);
    sideCounters(label, rewrite, right, rightRules);
}
Also used: ClassicCounter(edu.stanford.nlp.stats.ClassicCounter)
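
The method above relies on the common "map of counters" idiom, written with raw types; a generified sketch of the same bookkeeping (the class RewriteTally and its members are illustrative only, not CoreNLP API):

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import edu.stanford.nlp.stats.ClassicCounter;

public class RewriteTally {
    // one counter of observed rewrites per parent-node label
    private final Map<String, ClassicCounter<List<String>>> nodeRules = new HashMap<>();

    public void count(String label, List<String> rewrite) {
        // create the counter for this label on first use, then bump this rewrite's count
        nodeRules.computeIfAbsent(label, k -> new ClassicCounter<>()).incrementCount(rewrite);
    }
}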

Example 33 with ClassicCounter

Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp: the class SisterAnnotationStats, method printStats.

public void printStats() {
    NumberFormat nf = NumberFormat.getNumberInstance();
    nf.setMaximumFractionDigits(2);
    // System.out.println("Node rules");
    // System.out.println(nodeRules);
    // System.out.println("Parent rules");
    // System.out.println(pRules);
    // System.out.println("Grandparent rules");
    // System.out.println(gPRules);
    // Store java code for selSplit
    StringBuffer[] javaSB = new StringBuffer[CUTOFFS.length];
    for (int i = 0; i < CUTOFFS.length; i++) {
        javaSB[i] = new StringBuffer("  private static String[] sisterSplit" + (i + 1) + " = new String[] {");
    }
    // topScores contains all enriched categories, to be sorted later
    ArrayList topScores = new ArrayList();
    for (Object o : nodeRules.keySet()) {
        ArrayList answers = new ArrayList();
        String label = (String) o;
        ClassicCounter cntr = (ClassicCounter) nodeRules.get(label);
        double support = (cntr.totalCount());
        System.out.println("Node " + label + " support is " + support);
        for (Object o4 : ((HashMap) leftRules.get(label)).keySet()) {
            String sis = (String) o4;
            ClassicCounter cntr2 = (ClassicCounter) ((HashMap) leftRules.get(label)).get(sis);
            double support2 = (cntr2.totalCount());
            /* alternative 1: use full distribution to calculate score */
            double kl = Counters.klDivergence(cntr2, cntr);
            /* alternative 2: hold out test-context data to calculate score */
            /* this doesn't work because it can lead to zero-probability
         * data points hence infinite divergence */
            // 	Counter tempCounter = new Counter();
            // 	tempCounter.addCounter(cntr2);
            // 	for(Iterator i = tempCounter.seenSet().iterator(); i.hasNext();) {
            // 	  Object o = i.next();
            // 	  tempCounter.setCount(o,-1*tempCounter.countOf(o));
            // 	}
            // 	System.out.println(tempCounter); //debugging
            // 	tempCounter.addCounter(cntr);
            // 	System.out.println(tempCounter); //debugging
            // 	System.out.println(cntr);
            // 	double kl = cntr2.klDivergence(tempCounter);
            /* alternative 2 ends here */
            String annotatedLabel = label + "=l=" + sis;
            System.out.println("KL(" + annotatedLabel + "||" + label + ") = " + nf.format(kl) + "\t" + "support(" + sis + ") = " + support2);
            answers.add(new Pair(annotatedLabel, new Double(kl * support2)));
            topScores.add(new Pair(annotatedLabel, new Double(kl * support2)));
        }
        for (Object o3 : ((HashMap) rightRules.get(label)).keySet()) {
            String sis = (String) o3;
            ClassicCounter cntr2 = (ClassicCounter) ((HashMap) rightRules.get(label)).get(sis);
            double support2 = (cntr2.totalCount());
            double kl = Counters.klDivergence(cntr2, cntr);
            String annotatedLabel = label + "=r=" + sis;
            System.out.println("KL(" + annotatedLabel + "||" + label + ") = " + nf.format(kl) + "\t" + "support(" + sis + ") = " + support2);
            answers.add(new Pair(annotatedLabel, new Double(kl * support2)));
            topScores.add(new Pair(annotatedLabel, new Double(kl * support2)));
        }
        // upto
        System.out.println("----");
        System.out.println("Sorted descending support * KL");
        Collections.sort(answers, (o1, o2) -> {
            Pair p1 = (Pair) o1;
            Pair p2 = (Pair) o2;
            Double p12 = (Double) p1.second();
            Double p22 = (Double) p2.second();
            return p22.compareTo(p12);
        });
        for (Object answer : answers) {
            Pair p = (Pair) answer;
            double psd = ((Double) p.second()).doubleValue();
            System.out.println(p.first() + ": " + nf.format(psd));
            if (psd >= CUTOFFS[0]) {
                String annotatedLabel = (String) p.first();
                for (double CUTOFF : CUTOFFS) {
                    if (psd >= CUTOFF) {
                    //javaSB[j].append("\"").append(annotatedLabel);
                    //javaSB[j].append("\",");
                    }
                }
            }
        }
        System.out.println();
    }
    Collections.sort(topScores, (o1, o2) -> {
        Pair p1 = (Pair) o1;
        Pair p2 = (Pair) o2;
        Double p12 = (Double) p1.second();
        Double p22 = (Double) p2.second();
        return p22.compareTo(p12);
    });
    System.out.println("All enriched categories, sorted by score");
    for (Object topScore : topScores) {
        Pair p = (Pair) topScore;
        double psd = ((Double) p.second()).doubleValue();
        System.out.println(p.first() + ": " + nf.format(psd));
    }
    System.out.println();
    System.out.println("  // Automatically generated by SisterAnnotationStats -- preferably don't edit");
    int k = CUTOFFS.length - 1;
    for (int j = 0; j < topScores.size(); j++) {
        Pair p = (Pair) topScores.get(j);
        double psd = ((Double) p.second()).doubleValue();
        if (psd < CUTOFFS[k]) {
            if (k == 0) {
                break;
            } else {
                k--;
                // messy but should do it
                j -= 1;
                continue;
            }
        }
        javaSB[k].append("\"").append(p.first());
        javaSB[k].append("\",");
    }
    for (int i = 0; i < CUTOFFS.length; i++) {
        int len = javaSB[i].length();
        javaSB[i].replace(len - 2, len, "};");
        System.out.println(javaSB[i]);
    }
    System.out.print("  public static String[] sisterSplit = ");
    for (int i = CUTOFFS.length; i > 0; i--) {
        if (i == 1) {
            System.out.print("sisterSplit1");
        } else {
            System.out.print("selectiveSisterSplit" + i + " ? sisterSplit" + i + " : (");
        }
    }
    // need to print extra one to close other things open
    for (int i = CUTOFFS.length; i >= 0; i--) {
        System.out.print(")");
    }
    System.out.println(";");
}
Also used: ClassicCounter(edu.stanford.nlp.stats.ClassicCounter) NumberFormat(java.text.NumberFormat) Pair(edu.stanford.nlp.util.Pair)
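
The core of printStats is the score kl * support computed for each annotated category, followed by a descending sort on that score. A compact sketch of just that step, using only the Counter methods shown above (the class SisterScoring and its signature are hypothetical):

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.util.Pair;

class SisterScoring {
    static List<Pair<String, Double>> score(String label, ClassicCounter<List<String>> nodeCounter,
            Map<String, ClassicCounter<List<String>>> sisterCounters, String side) {
        List<Pair<String, Double>> scored = new ArrayList<>();
        for (Map.Entry<String, ClassicCounter<List<String>>> e : sisterCounters.entrySet()) {
            // divergence of the sister-conditioned rewrite distribution from the unconditioned one
            double kl = Counters.klDivergence(e.getValue(), nodeCounter);
            double support = e.getValue().totalCount();
            scored.add(new Pair<>(label + "=" + side + "=" + e.getKey(), kl * support));
        }
        // highest score first, mirroring the descending sorts in printStats
        scored.sort((a, b) -> Double.compare(b.second(), a.second()));
        return scored;
    }
}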

Example 34 with ClassicCounter

Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp: the class GrammarCompactor, method convertGraphsToGrammar.

/**
   * @param graphs      a Map from String categories to TransducerGraph objects
   * @param unaryRules  is a Set of UnaryRule objects that we need to add
   * @param binaryRules is a Set of BinaryRule objects that we need to add
   * @return a new Pair of UnaryGrammar, BinaryGrammar
   */
protected Pair<UnaryGrammar, BinaryGrammar> convertGraphsToGrammar(Set<TransducerGraph> graphs, Set<UnaryRule> unaryRules, Set<BinaryRule> binaryRules) {
    // first go through all the existing rules and re-number them using the new state index
    newStateIndex = new HashIndex<>();
    for (UnaryRule rule : unaryRules) {
        String parent = stateIndex.get(rule.parent);
        rule.parent = newStateIndex.addToIndex(parent);
        String child = stateIndex.get(rule.child);
        rule.child = newStateIndex.addToIndex(child);
    }
    for (BinaryRule rule : binaryRules) {
        String parent = stateIndex.get(rule.parent);
        rule.parent = newStateIndex.addToIndex(parent);
        String leftChild = stateIndex.get(rule.leftChild);
        rule.leftChild = newStateIndex.addToIndex(leftChild);
        String rightChild = stateIndex.get(rule.rightChild);
        rule.rightChild = newStateIndex.addToIndex(rightChild);
    }
    // now go through the graphs and add the rules
    for (TransducerGraph graph : graphs) {
        Object startNode = graph.getStartNode();
        for (Arc arc : graph.getArcs()) {
            // TODO: make sure these are the strings we're looking for
            String source = arc.getSourceNode().toString();
            String target = arc.getTargetNode().toString();
            Object input = arc.getInput();
            String inputString = input.toString();
            double output = ((Double) arc.getOutput()).doubleValue();
            if (source.equals(startNode)) {
                // make a UnaryRule
                UnaryRule ur = new UnaryRule(newStateIndex.addToIndex(target), newStateIndex.addToIndex(inputString), smartNegate(output));
                unaryRules.add(ur);
            } else if (inputString.equals(END) || inputString.equals(EPSILON)) {
                // make a UnaryRule
                UnaryRule ur = new UnaryRule(newStateIndex.addToIndex(target), newStateIndex.addToIndex(source), smartNegate(output));
                unaryRules.add(ur);
            } else {
                // make a BinaryRule
                // figure out whether the input was generated on the left or right
                int length = inputString.length();
                char leftOrRight = inputString.charAt(length - 1);
                inputString = inputString.substring(0, length - 1);
                BinaryRule br;
                if (leftOrRight == '<' || leftOrRight == '[') {
                    br = new BinaryRule(newStateIndex.addToIndex(target), newStateIndex.addToIndex(inputString), newStateIndex.addToIndex(source), smartNegate(output));
                } else if (leftOrRight == '>' || leftOrRight == ']') {
                    br = new BinaryRule(newStateIndex.addToIndex(target), newStateIndex.addToIndex(source), newStateIndex.addToIndex(inputString), smartNegate(output));
                } else {
                    throw new RuntimeException("Arc input is in unexpected format: " + arc);
                }
                binaryRules.add(br);
            }
        }
    }
    // by now, the unaryRules and binaryRules Sets have old untouched and new rules with scores
    ClassicCounter<String> symbolCounter = new ClassicCounter<>();
    if (outputType == RAW_COUNTS) {
        // so we count parent symbol occurrences
        for (UnaryRule rule : unaryRules) {
            symbolCounter.incrementCount(newStateIndex.get(rule.parent), rule.score);
        }
        for (BinaryRule rule : binaryRules) {
            symbolCounter.incrementCount(newStateIndex.get(rule.parent), rule.score);
        }
    }
    // now we put the rules in the grammars
    // this should be smaller than last one
    int numStates = newStateIndex.size();
    int numRules = 0;
    UnaryGrammar ug = new UnaryGrammar(newStateIndex);
    BinaryGrammar bg = new BinaryGrammar(newStateIndex);
    for (UnaryRule rule : unaryRules) {
        if (outputType == RAW_COUNTS) {
            double count = symbolCounter.getCount(newStateIndex.get(rule.parent));
            rule.score = (float) Math.log(rule.score / count);
        }
        ug.addRule(rule);
        numRules++;
    }
    for (BinaryRule rule : binaryRules) {
        if (outputType == RAW_COUNTS) {
            double count = symbolCounter.getCount(newStateIndex.get(rule.parent));
            rule.score = (float) Math.log((rule.score - op.trainOptions.ruleDiscount) / count);
        }
        bg.addRule(rule);
        numRules++;
    }
    if (verbose) {
        System.out.println("Number of minimized rules: " + numRules);
        System.out.println("Number of minimized states: " + newStateIndex.size());
    }
    ug.purgeRules();
    bg.splitRules();
    return new Pair<>(ug, bg);
}
Also used: Arc(edu.stanford.nlp.fsm.TransducerGraph.Arc) TransducerGraph(edu.stanford.nlp.fsm.TransducerGraph) ClassicCounter(edu.stanford.nlp.stats.ClassicCounter) Pair(edu.stanford.nlp.util.Pair)
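
The RAW_COUNTS branch above is a plain count-then-normalize pass: sum the raw rule scores per parent symbol, then convert each rule score into a log relative frequency. A stripped-down sketch of that idea (the flat parents/rawCounts representation is invented for illustration; real rules carry more structure):

import java.util.List;
import edu.stanford.nlp.stats.ClassicCounter;

class ScoreNormalizer {
    static double[] toLogProbs(List<String> parents, double[] rawCounts) {
        ClassicCounter<String> parentTotals = new ClassicCounter<>();
        for (int i = 0; i < parents.size(); i++) {
            // accumulate the total mass observed under each parent symbol
            parentTotals.incrementCount(parents.get(i), rawCounts[i]);
        }
        double[] logProbs = new double[rawCounts.length];
        for (int i = 0; i < parents.size(); i++) {
            // log relative frequency, matching rule.score = log(score / count) above
            logProbs[i] = Math.log(rawCounts[i] / parentTotals.getCount(parents.get(i)));
        }
        return logProbs;
    }
}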

Example 35 with ClassicCounter

Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp: the class ChineseCorefBenchmarkSlowITest, method getCorefResults.

private static Counter<String> getCorefResults(String resultsString) throws IOException {
    Counter<String> results = new ClassicCounter<String>();
    BufferedReader r = new BufferedReader(new StringReader(resultsString));
    for (String line; (line = r.readLine()) != null; ) {
        Matcher m1 = MENTION_PATTERN.matcher(line);
        if (m1.matches()) {
            results.setCount(MENTION_TP, Double.parseDouble(m1.group(1)));
            results.setCount(MENTION_F1, Double.parseDouble(m1.group(2)));
        }
        Matcher m2 = MUC_PATTERN.matcher(line);
        if (m2.matches()) {
            results.setCount(MUC_TP, Double.parseDouble(m2.group(1)));
            results.setCount(MUC_F1, Double.parseDouble(m2.group(2)));
        }
        Matcher m3 = BCUBED_PATTERN.matcher(line);
        if (m3.matches()) {
            results.setCount(BCUBED_TP, Double.parseDouble(m3.group(1)));
            results.setCount(BCUBED_F1, Double.parseDouble(m3.group(2)));
        }
        Matcher m4 = CEAFM_PATTERN.matcher(line);
        if (m4.matches()) {
            results.setCount(CEAFM_TP, Double.parseDouble(m4.group(1)));
            results.setCount(CEAFM_F1, Double.parseDouble(m4.group(2)));
        }
        Matcher m5 = CEAFE_PATTERN.matcher(line);
        if (m5.matches()) {
            results.setCount(CEAFE_TP, Double.parseDouble(m5.group(1)));
            results.setCount(CEAFE_F1, Double.parseDouble(m5.group(2)));
        }
        Matcher m6 = BLANC_PATTERN.matcher(line);
        if (m6.matches()) {
            results.setCount(BLANC_F1, Double.parseDouble(m6.group(1)));
        }
        Matcher m7 = CONLL_PATTERN.matcher(line);
        if (m7.matches()) {
            results.setCount(CONLL_SCORE, Double.parseDouble(m7.group(1)));
        }
    }
    return results;
}
Also used: Matcher(java.util.regex.Matcher) ClassicCounter(edu.stanford.nlp.stats.ClassicCounter) BufferedReader(java.io.BufferedReader) StringReader(java.io.StringReader)
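
Once parsed, the returned Counter can be probed with getCount when asserting on the benchmark output; a hedged usage sketch (the 50.0 threshold is made up, and CONLL_SCORE is the same key constant the test sets above):

// inside the benchmark test, after running coref and capturing its output
Counter<String> results = getCorefResults(resultsString);
double conll = results.getCount(CONLL_SCORE);
if (conll < 50.0) {
    // a real test would use a JUnit assertion here instead
    throw new AssertionError("CoNLL score regressed: " + conll);
}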

Aggregations

Types used together with ClassicCounter across the indexed examples, with the number of examples each appears in:

ClassicCounter (edu.stanford.nlp.stats.ClassicCounter): 69
CoreLabel (edu.stanford.nlp.ling.CoreLabel): 27
ArrayList (java.util.ArrayList): 21
CoreAnnotations (edu.stanford.nlp.ling.CoreAnnotations): 18
Tree (edu.stanford.nlp.trees.Tree): 13
Pair (edu.stanford.nlp.util.Pair): 11
Counter (edu.stanford.nlp.stats.Counter): 10
List (java.util.List): 10
Mention (edu.stanford.nlp.coref.data.Mention): 8
Language (edu.stanford.nlp.international.Language): 7
RuntimeIOException (edu.stanford.nlp.io.RuntimeIOException): 7
CoreMap (edu.stanford.nlp.util.CoreMap): 7
IOUtils (edu.stanford.nlp.io.IOUtils): 6
Label (edu.stanford.nlp.ling.Label): 6
TreebankLangParserParams (edu.stanford.nlp.parser.lexparser.TreebankLangParserParams): 6
PrintWriter (java.io.PrintWriter): 6
java.util (java.util): 6
HashSet (java.util.HashSet): 6
RVFDatum (edu.stanford.nlp.ling.RVFDatum): 5
DiskTreebank (edu.stanford.nlp.trees.DiskTreebank): 5