
Example 26 with ClassicCounter

Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.

From the class ArabicSegmenter, method evaluate:

/**
   * Evaluate accuracy when the input is gold segmented text *with* segmentation
   * markers and morphological analyses. In other words, the evaluation file has the
   * same format as the training data.
   *
   * @param pwOut Output writer for the evaluation results.
   */
private void evaluate(PrintWriter pwOut) {
    log.info("Starting evaluation...");
    boolean hasSegmentationMarkers = true;
    boolean hasTags = true;
    DocumentReaderAndWriter<CoreLabel> docReader = new ArabicDocumentReaderAndWriter(hasSegmentationMarkers, hasTags, hasDomainLabels, domain, tf);
    ObjectBank<List<CoreLabel>> lines = classifier.makeObjectBankFromFile(flags.testFile, docReader);
    PrintWriter tedEvalGoldTree = null, tedEvalParseTree = null;
    PrintWriter tedEvalGoldSeg = null, tedEvalParseSeg = null;
    if (tedEvalPrefix != null) {
        try {
            tedEvalGoldTree = new PrintWriter(tedEvalPrefix + "_gold.ftree");
            tedEvalGoldSeg = new PrintWriter(tedEvalPrefix + "_gold.segmentation");
            tedEvalParseTree = new PrintWriter(tedEvalPrefix + "_parse.ftree");
            tedEvalParseSeg = new PrintWriter(tedEvalPrefix + "_parse.segmentation");
        } catch (FileNotFoundException e) {
            System.err.printf("%s: %s%n", ArabicSegmenter.class.getName(), e.getMessage());
        }
    }
    Counter<String> labelTotal = new ClassicCounter<>();
    Counter<String> labelCorrect = new ClassicCounter<>();
    int total = 0;
    int correct = 0;
    for (List<CoreLabel> line : lines) {
        final String[] inputTokens = tedEvalSanitize(IOBUtils.IOBToString(line).replaceAll(":", "#pm#")).split(" ");
        final String[] goldTokens = tedEvalSanitize(IOBUtils.IOBToString(line, ":")).split(" ");
        line = classifier.classify(line);
        final String[] parseTokens = tedEvalSanitize(IOBUtils.IOBToString(line, ":")).split(" ");
        for (CoreLabel label : line) {
            // Do not evaluate labeling of whitespace
            String observation = label.get(CoreAnnotations.CharAnnotation.class);
            if (!observation.equals(IOBUtils.getBoundaryCharacter())) {
                total++;
                String hypothesis = label.get(CoreAnnotations.AnswerAnnotation.class);
                String reference = label.get(CoreAnnotations.GoldAnswerAnnotation.class);
                labelTotal.incrementCount(reference);
                if (hypothesis.equals(reference)) {
                    correct++;
                    labelCorrect.incrementCount(reference);
                }
            }
        }
        if (tedEvalParseSeg != null) {
            tedEvalGoldTree.printf("(root");
            tedEvalParseTree.printf("(root");
            int safeLength = inputTokens.length;
            if (inputTokens.length != goldTokens.length) {
                log.info("In generating TEDEval files: Input and gold do not have the same number of tokens");
                log.info("    (ignoring any extras)");
                log.info("  input: " + Arrays.toString(inputTokens));
                log.info("  gold: " + Arrays.toString(goldTokens));
                safeLength = Math.min(inputTokens.length, goldTokens.length);
            }
            if (inputTokens.length != parseTokens.length) {
                log.info("In generating TEDEval files: Input and parse do not have the same number of tokens");
                log.info("    (ignoring any extras)");
                log.info("  input: " + Arrays.toString(inputTokens));
                log.info("  parse: " + Arrays.toString(parseTokens));
                safeLength = Math.min(inputTokens.length, parseTokens.length);
            }
            for (int i = 0; i < safeLength; i++) {
                for (String segment : goldTokens[i].split(":")) tedEvalGoldTree.printf(" (seg %s)", segment);
                tedEvalGoldSeg.printf("%s\t%s%n", inputTokens[i], goldTokens[i]);
                for (String segment : parseTokens[i].split(":")) tedEvalParseTree.printf(" (seg %s)", segment);
                tedEvalParseSeg.printf("%s\t%s%n", inputTokens[i], parseTokens[i]);
            }
            tedEvalGoldTree.printf(")%n");
            tedEvalGoldSeg.println();
            tedEvalParseTree.printf(")%n");
            tedEvalParseSeg.println();
        }
    }
    double accuracy = ((double) correct) / ((double) total);
    accuracy *= 100.0;
    pwOut.println("EVALUATION RESULTS");
    pwOut.printf("#datums:\t%d%n", total);
    pwOut.printf("#correct:\t%d%n", correct);
    pwOut.printf("accuracy:\t%.2f%n", accuracy);
    pwOut.println("==================");
    // Output the per label accuracies
    pwOut.println("PER LABEL ACCURACIES");
    for (String refLabel : labelTotal.keySet()) {
        double nTotal = labelTotal.getCount(refLabel);
        double nCorrect = labelCorrect.getCount(refLabel);
        double acc = (nCorrect / nTotal) * 100.0;
        pwOut.printf(" %s\t%.2f%n", refLabel, acc);
    }
    if (tedEvalParseSeg != null) {
        tedEvalGoldTree.close();
        tedEvalGoldSeg.close();
        tedEvalParseTree.close();
        tedEvalParseSeg.close();
    }
}
Also used: FileNotFoundException (java.io.FileNotFoundException), CoreLabel (edu.stanford.nlp.ling.CoreLabel), ClassicCounter (edu.stanford.nlp.stats.ClassicCounter), CoreAnnotations (edu.stanford.nlp.ling.CoreAnnotations), List (java.util.List), PrintWriter (java.io.PrintWriter)
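
The reusable pattern in this example is the pair of ClassicCounters behind the per-label accuracies: one counter accumulates totals per reference label, the other accumulates matches, and the per-key ratio yields the accuracy table printed at the end. A minimal, self-contained sketch of that pattern (the label sequences and class name here are hypothetical, not CoreNLP output):

import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;

public class PerLabelAccuracySketch {
    public static void main(String[] args) {
        // Hypothetical gold and predicted label sequences of equal length.
        String[] gold = { "B", "I", "O", "B", "O" };
        String[] pred = { "B", "O", "O", "B", "O" };
        Counter<String> labelTotal = new ClassicCounter<>();
        Counter<String> labelCorrect = new ClassicCounter<>();
        for (int i = 0; i < gold.length; i++) {
            // Count every reference label; count it again only when the prediction matches.
            labelTotal.incrementCount(gold[i]);
            if (gold[i].equals(pred[i])) {
                labelCorrect.incrementCount(gold[i]);
            }
        }
        for (String label : labelTotal.keySet()) {
            double acc = 100.0 * labelCorrect.getCount(label) / labelTotal.getCount(label);
            System.out.printf("%s\t%.2f%n", label, acc);
        }
    }
}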

Example 27 with ClassicCounter

Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.

From the class PhraseTable, method readPhrasesWithTagScores:

public void readPhrasesWithTagScores(String filename, Pattern fieldDelimiterPattern, Pattern countDelimiterPattern) throws IOException {
    Timing timer = new Timing();
    timer.doing("Reading phrases: " + filename);
    BufferedReader br = IOUtils.getBufferedFileReader(filename);
    String line;
    int lineno = 0;
    while ((line = br.readLine()) != null) {
        String[] columns = fieldDelimiterPattern.split(line);
        String phrase = columns[0];
        // Pick map factory to use depending on number of tags we have
        MapFactory<String, MutableDouble> mapFactory = (columns.length < 20) ? MapFactory.<String, MutableDouble>arrayMapFactory() : MapFactory.<String, MutableDouble>linkedHashMapFactory();
        Counter<String> counts = new ClassicCounter<>(mapFactory);
        for (int i = 1; i < columns.length; i++) {
            String[] tagCount = countDelimiterPattern.split(columns[i], 2);
            if (tagCount.length == 2) {
                try {
                    counts.setCount(tagCount[0], Double.parseDouble(tagCount[1]));
                } catch (NumberFormatException ex) {
                    throw new RuntimeException("Error processing field " + i + ": '" + columns[i] + "' from (" + filename + ":" + lineno + "): " + line, ex);
                }
            } else {
                throw new RuntimeException("Error processing field " + i + ": '" + columns[i] + "' from (" + filename + ":" + lineno + "): " + line);
            }
        }
        addPhrase(phrase, null, counts);
        lineno++;
    }
    br.close();
    timer.done();
}
Also used: BufferedReader (java.io.BufferedReader), ClassicCounter (edu.stanford.nlp.stats.ClassicCounter)
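
The parsing core of this example is independent of PhraseTable: each line is a phrase followed by delimited tag:score fields, and each field is split at most once so the score keeps any embedded delimiter. A minimal sketch under that assumption (the input line and class name are hypothetical):

import java.util.regex.Pattern;

import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;

public class TagScoreSketch {
    public static void main(String[] args) {
        // Hypothetical line: a phrase, then tab-separated tag:score fields.
        String line = "take over\tVB:0.72\tNN:0.28";
        Pattern fieldDelimiter = Pattern.compile("\t");
        Pattern countDelimiter = Pattern.compile(":");
        String[] columns = fieldDelimiter.split(line);
        Counter<String> counts = new ClassicCounter<>();
        for (int i = 1; i < columns.length; i++) {
            // The limit of 2 splits only at the first ':', as in readPhrasesWithTagScores.
            String[] tagCount = countDelimiter.split(columns[i], 2);
            counts.setCount(tagCount[0], Double.parseDouble(tagCount[1]));
        }
        System.out.println(columns[0] + " -> " + counts);
    }
}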

Example 28 with ClassicCounter

Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.

From the class SparseAdaGradMinimizer, method minimize:

// Does L1 or L2 regularization using FOBOS with lazy updates, so L1 should not
// be handled in the objective.
// Alternatively, you can handle other regularization in the objective, but then,
// if the derivative is not sparse, this routine will not be very efficient.
// However, it might still be okay for CRFs.
@Override
public Counter<K> minimize(F function, Counter<K> x, int maxIterations) {
    sayln("       Batch size of: " + batchSize);
    sayln("       Data dimension of: " + function.dataSize());
    int numBatches = (function.dataSize() - 1) / this.batchSize + 1;
    sayln("       Batches per pass through data:  " + numBatches);
    sayln("       Number of passes is = " + numPasses);
    sayln("       Max iterations is = " + maxIterations);
    Counter<K> lastUpdated = new ClassicCounter<>();
    int timeStep = 0;
    Timing total = new Timing();
    total.start();
    for (int iter = 0; iter < numPasses; iter++) {
        double totalObjValue = 0;
        for (int j = 0; j < numBatches; j++) {
            int[] selectedData = getSample(function, this.batchSize);
            // the core adagrad
            Counter<K> gradient = function.derivativeAt(x, selectedData);
            totalObjValue = totalObjValue + function.valueAt(x, selectedData);
            for (K feature : gradient.keySet()) {
                double gradf = gradient.getCount(feature);
                double prevrate = eta / (Math.sqrt(sumGradSquare.getCount(feature)) + soften);
                double sgsValue = sumGradSquare.incrementCount(feature, gradf * gradf);
                double currentrate = eta / (Math.sqrt(sgsValue) + soften);
                double testupdate = x.getCount(feature) - (currentrate * gradient.getCount(feature));
                double lastUpdateTimeStep = lastUpdated.getCount(feature);
                double idleinterval = timeStep - lastUpdateTimeStep - 1;
                lastUpdated.setCount(feature, (double) timeStep);
                // does lazy update using idleinterval
                double trunc = Math.max(0.0, (Math.abs(testupdate) - (currentrate + prevrate * idleinterval) * this.lambdaL1));
                double trunc2 = trunc * Math.pow(1 - this.lambdaL2, currentrate + prevrate * idleinterval);
                double realupdate = Math.signum(testupdate) * trunc2;
                // Drop near-zero weights to keep the parameter vector sparse.
                if (Math.abs(realupdate) < EPS) {
                    x.remove(feature);
                } else {
                    x.setCount(feature, realupdate);
                }
                // reporting
                timeStep++;
                if (timeStep > maxIterations) {
                    sayln("Stochastic Optimization complete.  Stopped after max iterations");
                    break;
                }
                sayln(String.format("Iter %d \t batch: %d \t time=%.2f \t obj=%.4f", iter, timeStep, total.report() / 1000.0, totalObjValue));
            }
        }
    }
    return x;
}
Also used: ClassicCounter (edu.stanford.nlp.stats.ClassicCounter), Timing (edu.stanford.nlp.util.Timing)
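
Stripped of batching and the lazy-update bookkeeping, one AdaGrad/FOBOS step for a single feature reduces to: accumulate the squared gradient, derive the adaptive rate, take the gradient step, then soft-threshold the result toward zero for L1. A minimal sketch of that arithmetic with hypothetical numbers:

import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;

public class AdaGradStepSketch {
    public static void main(String[] args) {
        double eta = 0.1, soften = 1e-4, lambdaL1 = 0.01;
        Counter<String> sumGradSquare = new ClassicCounter<>();
        Counter<String> x = new ClassicCounter<>();
        x.setCount("f1", 0.5);
        double grad = -0.3;
        // Accumulate the squared gradient; incrementCount returns the updated total.
        double sgs = sumGradSquare.incrementCount("f1", grad * grad);
        double rate = eta / (Math.sqrt(sgs) + soften);
        // Plain AdaGrad step, then the FOBOS soft-threshold for L1.
        double testUpdate = x.getCount("f1") - rate * grad;
        double truncated = Math.max(0.0, Math.abs(testUpdate) - rate * lambdaL1);
        x.setCount("f1", Math.signum(testUpdate) * truncated);
        System.out.println("f1 -> " + x.getCount("f1"));
    }
}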

Example 29 with ClassicCounter

Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.

From the class TreebankFactoredLexiconStats, method main:

//  private static String stripTag(String tag) {
//    if (tag.startsWith("DT")) {
//      String newTag = tag.substring(2, tag.length());
//      return newTag.length() > 0 ? newTag : tag;
//    }
//    return tag;
//  }
/**
   * @param args language, treebank path, and a comma-separated list of morphological features
   */
public static void main(String[] args) {
    if (args.length != 3) {
        System.err.printf("Usage: java %s language filename features%n", TreebankFactoredLexiconStats.class.getName());
        System.exit(-1);
    }
    Language language = Language.valueOf(args[0]);
    TreebankLangParserParams tlpp = language.params;
    if (language.equals(Language.Arabic)) {
        String[] options = { "-arabicFactored" };
        tlpp.setOptionFlag(options, 0);
    } else {
        String[] options = { "-frenchFactored" };
        tlpp.setOptionFlag(options, 0);
    }
    Treebank tb = tlpp.diskTreebank();
    tb.loadPath(args[1]);
    MorphoFeatureSpecification morphoSpec = language.equals(Language.Arabic) ? new ArabicMorphoFeatureSpecification() : new FrenchMorphoFeatureSpecification();
    String[] features = args[2].trim().split(",");
    for (String feature : features) {
        morphoSpec.activate(MorphoFeatureType.valueOf(feature));
    }
    // Counters
    Counter<String> wordTagCounter = new ClassicCounter<>(30000);
    Counter<String> morphTagCounter = new ClassicCounter<>(500);
    //    Counter<String> signatureTagCounter = new ClassicCounter<String>();
    Counter<String> morphCounter = new ClassicCounter<>(500);
    Counter<String> wordCounter = new ClassicCounter<>(30000);
    Counter<String> tagCounter = new ClassicCounter<>(300);
    Counter<String> lemmaCounter = new ClassicCounter<>(25000);
    Counter<String> lemmaTagCounter = new ClassicCounter<>(25000);
    Counter<String> richTagCounter = new ClassicCounter<>(1000);
    Counter<String> reducedTagCounter = new ClassicCounter<>(500);
    Counter<String> reducedTagLemmaCounter = new ClassicCounter<>(500);
    Map<String, Set<String>> wordLemmaMap = Generics.newHashMap();
    TwoDimensionalIntCounter<String, String> lemmaReducedTagCounter = new TwoDimensionalIntCounter<>(30000);
    TwoDimensionalIntCounter<String, String> reducedTagTagCounter = new TwoDimensionalIntCounter<>(500);
    TwoDimensionalIntCounter<String, String> tagReducedTagCounter = new TwoDimensionalIntCounter<>(300);
    int numTrees = 0;
    for (Tree tree : tb) {
        for (Tree subTree : tree) {
            if (!subTree.isLeaf()) {
                tlpp.transformTree(subTree, tree);
            }
        }
        List<Label> pretermList = tree.preTerminalYield();
        List<Label> yield = tree.yield();
        assert yield.size() == pretermList.size();
        int yieldLen = yield.size();
        for (int i = 0; i < yieldLen; ++i) {
            String tag = pretermList.get(i).value();
            String word = yield.get(i).value();
            String morph = ((CoreLabel) yield.get(i)).originalText();
            // Note: if there is no lemma, then we use the surface form.
            Pair<String, String> lemmaTag = MorphoFeatureSpecification.splitMorphString(word, morph);
            String lemma = lemmaTag.first();
            String richTag = lemmaTag.second();
            // WSGDEBUG
            if (tag.contains("MW"))
                lemma += "-MWE";
            lemmaCounter.incrementCount(lemma);
            lemmaTagCounter.incrementCount(lemma + tag);
            richTagCounter.incrementCount(richTag);
            String reducedTag = morphoSpec.strToFeatures(richTag).toString();
            reducedTagCounter.incrementCount(reducedTag);
            reducedTagLemmaCounter.incrementCount(reducedTag + lemma);
            wordTagCounter.incrementCount(word + tag);
            morphTagCounter.incrementCount(morph + tag);
            morphCounter.incrementCount(morph);
            wordCounter.incrementCount(word);
            tagCounter.incrementCount(tag);
            reducedTag = reducedTag.equals("") ? "NONE" : reducedTag;
            if (wordLemmaMap.containsKey(word)) {
                wordLemmaMap.get(word).add(lemma);
            } else {
                Set<String> lemmas = Generics.newHashSet(1);
                // Record the lemma for a first-seen word; otherwise the set stays
                // empty and the word is later reported as having no lemmas.
                lemmas.add(lemma);
                wordLemmaMap.put(word, lemmas);
            }
            lemmaReducedTagCounter.incrementCount(lemma, reducedTag);
            reducedTagTagCounter.incrementCount(lemma + reducedTag, tag);
            tagReducedTagCounter.incrementCount(tag, reducedTag);
        }
        ++numTrees;
    }
    // Barf...
    System.out.println("Language: " + language.toString());
    System.out.printf("#trees:\t%d%n", numTrees);
    System.out.printf("#tokens:\t%d%n", (int) wordCounter.totalCount());
    System.out.printf("#words:\t%d%n", wordCounter.keySet().size());
    System.out.printf("#tags:\t%d%n", tagCounter.keySet().size());
    System.out.printf("#wordTagPairs:\t%d%n", wordTagCounter.keySet().size());
    System.out.printf("#lemmas:\t%d%n", lemmaCounter.keySet().size());
    System.out.printf("#lemmaTagPairs:\t%d%n", lemmaTagCounter.keySet().size());
    System.out.printf("#feattags:\t%d%n", reducedTagCounter.keySet().size());
    System.out.printf("#feattag+lemmas:\t%d%n", reducedTagLemmaCounter.keySet().size());
    System.out.printf("#richtags:\t%d%n", richTagCounter.keySet().size());
    System.out.printf("#richtag+lemma:\t%d%n", morphCounter.keySet().size());
    System.out.printf("#richtag+lemmaTagPairs:\t%d%n", morphTagCounter.keySet().size());
    // Extra
    System.out.println("==================");
    StringBuilder sbNoLemma = new StringBuilder();
    StringBuilder sbMultLemmas = new StringBuilder();
    for (Map.Entry<String, Set<String>> wordLemmas : wordLemmaMap.entrySet()) {
        String word = wordLemmas.getKey();
        Set<String> lemmas = wordLemmas.getValue();
        if (lemmas.size() == 0) {
            sbNoLemma.append("NO LEMMAS FOR WORD: " + word + "\n");
            continue;
        }
        if (lemmas.size() > 1) {
            sbMultLemmas.append("MULTIPLE LEMMAS: " + word + " " + setToString(lemmas) + "\n");
            continue;
        }
        String lemma = lemmas.iterator().next();
        Set<String> reducedTags = lemmaReducedTagCounter.getCounter(lemma).keySet();
        if (reducedTags.size() > 1) {
            System.out.printf("%s --> %s%n", word, lemma);
            for (String reducedTag : reducedTags) {
                int count = lemmaReducedTagCounter.getCount(lemma, reducedTag);
                String posTags = setToString(reducedTagTagCounter.getCounter(lemma + reducedTag).keySet());
                System.out.printf("\t%s\t%d\t%s%n", reducedTag, count, posTags);
            }
            System.out.println();
        }
    }
    System.out.println("==================");
    System.out.println(sbNoLemma.toString());
    System.out.println(sbMultLemmas.toString());
    System.out.println("==================");
    List<String> tags = new ArrayList<>(tagReducedTagCounter.firstKeySet());
    Collections.sort(tags);
    for (String tag : tags) {
        System.out.println(tag);
        Set<String> reducedTags = tagReducedTagCounter.getCounter(tag).keySet();
        for (String reducedTag : reducedTags) {
            int count = tagReducedTagCounter.getCount(tag, reducedTag);
            //        reducedTag = reducedTag.equals("") ? "NONE" : reducedTag;
            System.out.printf("\t%s\t%d%n", reducedTag, count);
        }
        System.out.println();
    }
    System.out.println("==================");
}
Also used: FrenchMorphoFeatureSpecification (edu.stanford.nlp.international.french.FrenchMorphoFeatureSpecification), Set (java.util.Set), Treebank (edu.stanford.nlp.trees.Treebank), CoreLabel (edu.stanford.nlp.ling.CoreLabel), Label (edu.stanford.nlp.ling.Label), ArrayList (java.util.ArrayList), TreebankLangParserParams (edu.stanford.nlp.parser.lexparser.TreebankLangParserParams), Language (edu.stanford.nlp.international.Language), Tree (edu.stanford.nlp.trees.Tree), MorphoFeatureSpecification (edu.stanford.nlp.international.morph.MorphoFeatureSpecification), ArabicMorphoFeatureSpecification (edu.stanford.nlp.international.arabic.ArabicMorphoFeatureSpecification), TwoDimensionalIntCounter (edu.stanford.nlp.stats.TwoDimensionalIntCounter), ClassicCounter (edu.stanford.nlp.stats.ClassicCounter), Map (java.util.Map)
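
Alongside the plain ClassicCounters, this example leans on TwoDimensionalIntCounter for joint counts such as (lemma, reduced tag). A minimal sketch of that usage, with hypothetical keys and class name:

import edu.stanford.nlp.stats.TwoDimensionalIntCounter;

public class JointCountSketch {
    public static void main(String[] args) {
        // Hypothetical (lemma, reduced tag) observations.
        TwoDimensionalIntCounter<String, String> lemmaTagCounter = new TwoDimensionalIntCounter<>(100);
        lemmaTagCounter.incrementCount("kitAb", "NN.SG");
        lemmaTagCounter.incrementCount("kitAb", "NN.PL");
        lemmaTagCounter.incrementCount("kitAb", "NN.SG");
        // getCounter(firstKey) exposes the inner counter over second keys,
        // mirroring lemmaReducedTagCounter.getCounter(lemma) above.
        for (String tag : lemmaTagCounter.getCounter("kitAb").keySet()) {
            System.out.printf("%s\t%d%n", tag, lemmaTagCounter.getCount("kitAb", tag));
        }
    }
}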

Example 30 with ClassicCounter

Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.

From the class BaseUnknownWordModelTrainer, method finishTraining:

@Override
public UnknownWordModel finishTraining() {
    if (useGT) {
        unknownGTTrainer.finishTraining();
    }
    for (Map.Entry<Label, ClassicCounter<String>> entry : c.entrySet()) {
        /* outer iteration is over tags */
        Label key = entry.getKey();
        // counts for words given a tag
        ClassicCounter<String> wc = entry.getValue();
        if (!tagHash.containsKey(key)) {
            tagHash.put(key, new ClassicCounter<>());
        }
        /* the UNKNOWN sequence is assumed to be seen once in each tag */
        // This is sort of broken, but you can regard it as a Dirichlet prior.
        tc.incrementCount(key);
        wc.setCount(unknown, 1.0);
        /* inner iteration is over words */
        for (String end : wc.keySet()) {
            // p(sig|tag)
            double prob = Math.log((wc.getCount(end)) / (tc.getCount(key)));
            tagHash.get(key).setCount(end, prob);
        //if (Test.verbose)
        //EncodingPrintWriter.out.println(tag + " rewrites as " + end + " endchar with probability " + prob,encoding);
        }
    }
    return model;
}
Also used: Label (edu.stanford.nlp.ling.Label), ClassicCounter (edu.stanford.nlp.stats.ClassicCounter), Map (java.util.Map)
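
The numeric core of finishTraining is turning raw counts into log conditional probabilities after adding a pseudo-count for the unknown symbol. A minimal sketch with hypothetical counts; normalizing by the counter's total here stands in for the trainer's per-tag count tc.getCount(key):

import edu.stanford.nlp.stats.ClassicCounter;

public class LogProbSketch {
    public static void main(String[] args) {
        // Hypothetical word-signature counts observed under a single tag.
        ClassicCounter<String> wc = new ClassicCounter<>();
        wc.incrementCount("-ing", 6.0);
        wc.incrementCount("-ed", 3.0);
        // Pseudo-count for the unseen-word symbol, as in the trainer above.
        wc.setCount("UNKNOWN", 1.0);
        double tagTotal = wc.totalCount();
        ClassicCounter<String> logProbs = new ClassicCounter<>();
        for (String sig : wc.keySet()) {
            // p(sig|tag) in log space.
            logProbs.setCount(sig, Math.log(wc.getCount(sig) / tagTotal));
        }
        System.out.println(logProbs);
    }
}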

Aggregations

ClassicCounter (edu.stanford.nlp.stats.ClassicCounter): 69 usages
CoreLabel (edu.stanford.nlp.ling.CoreLabel): 27 usages
ArrayList (java.util.ArrayList): 21 usages
CoreAnnotations (edu.stanford.nlp.ling.CoreAnnotations): 18 usages
Tree (edu.stanford.nlp.trees.Tree): 13 usages
Pair (edu.stanford.nlp.util.Pair): 11 usages
Counter (edu.stanford.nlp.stats.Counter): 10 usages
List (java.util.List): 10 usages
Mention (edu.stanford.nlp.coref.data.Mention): 8 usages
Language (edu.stanford.nlp.international.Language): 7 usages
RuntimeIOException (edu.stanford.nlp.io.RuntimeIOException): 7 usages
CoreMap (edu.stanford.nlp.util.CoreMap): 7 usages
IOUtils (edu.stanford.nlp.io.IOUtils): 6 usages
Label (edu.stanford.nlp.ling.Label): 6 usages
TreebankLangParserParams (edu.stanford.nlp.parser.lexparser.TreebankLangParserParams): 6 usages
PrintWriter (java.io.PrintWriter): 6 usages
java.util (java.util): 6 usages
HashSet (java.util.HashSet): 6 usages
RVFDatum (edu.stanford.nlp.ling.RVFDatum): 5 usages
DiskTreebank (edu.stanford.nlp.trees.DiskTreebank): 5 usages