Search in sources :

Example 1 with ScorePhraseMeasures

use of edu.stanford.nlp.patterns.ConstantsAndVariables.ScorePhraseMeasures in project CoreNLP by stanfordnlp.

From the class ScorePhrasesLearnFeatWt, method getPhraseFeaturesForPattern.

/**
 * Builds the feature vector used to score the candidate phrase {@code word} when
 * evaluating patterns for the given label. Feature vectors are memoized in
 * {@code phraseScoresRaw}, so repeated calls for the same phrase are cheap.
 *
 * @param label the answer label being learned
 * @param word  the candidate phrase to featurize
 * @return a counter mapping each active {@link ScorePhraseMeasures} feature to its value
 */
Counter<ScorePhraseMeasures> getPhraseFeaturesForPattern(String label, CandidatePhrase word) {
    // Memoized: reuse the previously computed feature vector for this phrase.
    if (phraseScoresRaw.containsFirstKey(word))
        return phraseScoresRaw.getCounter(word);
    Counter<ScorePhraseMeasures> scoreslist = new ClassicCounter<>();
    // Add any features attached to the word itself.
    if (word.getFeatures() != null) {
        scoreslist.addAll(Counters.transform(word.getFeatures(), ScorePhraseMeasures::create));
    } else {
        Redwood.log(ConstantsAndVariables.extremedebug, "features are null for " + word);
    }
    if (constVars.usePatternEvalSemanticOdds) {
        double dscore = this.getDictOddsScore(word, label, 0);
        scoreslist.setCount(ScorePhraseMeasures.SEMANTICODDS, dscore);
    }
    if (constVars.usePatternEvalGoogleNgram) {
        Double gscore = getGoogleNgramScore(word);
        // A non-finite ngram score indicates corrupt frequency data; fail loudly.
        if (gscore.isInfinite() || gscore.isNaN()) {
            throw new RuntimeException("how is the google ngrams score " + gscore + " for " + word);
        }
        scoreslist.setCount(ScorePhraseMeasures.GOOGLENGRAM, gscore);
    }
    if (constVars.usePatternEvalDomainNgram) {
        Double gscore = getDomainNgramScore(word.getPhrase());
        if (gscore.isInfinite() || gscore.isNaN()) {
            throw new RuntimeException("how is the domain ngrams score " + gscore + " for " + word + " when domain raw freq is " + Data.domainNGramRawFreq.getCount(word) + " and raw freq is " + Data.rawFreq.getCount(word));
        }
        scoreslist.setCount(ScorePhraseMeasures.DOMAINNGRAM, gscore);
    }
    if (constVars.usePatternEvalWordClass) {
        // Fall back to the lowercased phrase when the exact form has no distsim cluster.
        Integer wordclass = constVars.getWordClassClusters().get(word.getPhrase());
        if (wordclass == null) {
            wordclass = constVars.getWordClassClusters().get(word.getPhrase().toLowerCase());
        }
        // NOTE(review): if wordclass is still null this creates a literal "DISTSIM-null"
        // feature — presumably acting as an "unknown cluster" bucket; confirm intended.
        scoreslist.setCount(ScorePhraseMeasures.create(ScorePhraseMeasures.DISTSIM.toString() + "-" + wordclass), 1.0);
    }
    if (constVars.usePatternEvalEditDistSame) {
        double ed = constVars.getEditDistanceScoresThisClass(label, word.getPhrase());
        assert ed <= 1 : " how come edit distance from the true class is " + ed + " for word " + word;
        scoreslist.setCount(ScorePhraseMeasures.EDITDISTSAME, ed);
    }
    if (constVars.usePatternEvalEditDistOther) {
        double ed = constVars.getEditDistanceScoresOtherClass(label, word.getPhrase());
        assert ed <= 1 : " how come edit distance from the true class is " + ed + " for word " + word;
        scoreslist.setCount(ScorePhraseMeasures.EDITDISTOTHER, ed);
    }
    if (constVars.usePatternEvalWordShape) {
        scoreslist.setCount(ScorePhraseMeasures.WORDSHAPE, this.getWordShapeScore(word.getPhrase(), label));
    }
    if (constVars.usePatternEvalWordShapeStr) {
        scoreslist.setCount(ScorePhraseMeasures.create(ScorePhraseMeasures.WORDSHAPESTR + "-" + this.wordShape(word.getPhrase())), 1.0);
    }
    if (constVars.usePatternEvalFirstCapital) {
        scoreslist.setCount(ScorePhraseMeasures.ISFIRSTCAPITAL, StringUtils.isCapitalized(word.getPhrase()) ? 1.0 : 0);
    }
    if (constVars.usePatternEvalBOW) {
        // Bag-of-words: one indicator feature per whitespace-separated token.
        for (String s : word.getPhrase().split("\\s+")) scoreslist.setCount(ScorePhraseMeasures.create(ScorePhraseMeasures.BOW + "-" + s), 1.0);
    }
    // Cache before returning so subsequent lookups hit the memo above.
    phraseScoresRaw.setCounter(word, scoreslist);
    return scoreslist;
}
Also used : java.util(java.util) ExtractPhraseFromPattern(edu.stanford.nlp.patterns.dep.ExtractPhraseFromPattern) edu.stanford.nlp.util(edu.stanford.nlp.util) ConcurrentHashCounter(edu.stanford.nlp.util.concurrent.ConcurrentHashCounter) Function(java.util.function.Function) edu.stanford.nlp.stats(edu.stanford.nlp.stats) AtomicDouble(edu.stanford.nlp.util.concurrent.AtomicDouble) SemanticGraph(edu.stanford.nlp.semgraph.SemanticGraph) RVFDatum(edu.stanford.nlp.ling.RVFDatum) Option(edu.stanford.nlp.util.ArgumentParser.Option) IndexedWord(edu.stanford.nlp.ling.IndexedWord) CoreLabel(edu.stanford.nlp.ling.CoreLabel) ScorePhraseMeasures(edu.stanford.nlp.patterns.ConstantsAndVariables.ScorePhraseMeasures) DataInstanceDep(edu.stanford.nlp.patterns.dep.DataInstanceDep) BufferedWriter(java.io.BufferedWriter) java.util.concurrent(java.util.concurrent) IOUtils(edu.stanford.nlp.io.IOUtils) Redwood(edu.stanford.nlp.util.logging.Redwood) FileWriter(java.io.FileWriter) BasicDatum(edu.stanford.nlp.ling.BasicDatum) IOException(java.io.IOException) Collectors(java.util.stream.Collectors) File(java.io.File) ExtractedPhrase(edu.stanford.nlp.patterns.dep.ExtractedPhrase) edu.stanford.nlp.classify(edu.stanford.nlp.classify) Entry(java.util.Map.Entry) ScorePhraseMeasures(edu.stanford.nlp.patterns.ConstantsAndVariables.ScorePhraseMeasures) AtomicDouble(edu.stanford.nlp.util.concurrent.AtomicDouble)

Example 2 with ScorePhraseMeasures

use of edu.stanford.nlp.patterns.ConstantsAndVariables.ScorePhraseMeasures in project CoreNLP by stanfordnlp.

From the class ScorePhrasesLearnFeatWt, method choosedatums.

/**
 * Samples positive, negative and unknown candidate phrases from the data sentences
 * (multi-threaded, one task per sentence-id batch) and packs them into an RVF dataset
 * used to train the phrase classifier for {@code answerLabel}. Positive phrases get
 * label "true", negatives (including sub-sampled unknowns) get label "false".
 *
 * @param forLearningPattern  if true, features come from getPhraseFeaturesForPattern;
 *                            otherwise from getFeatures (which also uses pattern scores)
 * @param answerLabel         the label being learned
 * @param wordsPatExtracted   phrases extracted, keyed by the pattern that extracted them
 * @param allSelectedPatterns patterns selected so far, with their scores
 * @param computeRawFreq      whether raw phrase frequencies still need to be computed
 * @return a dataset of "true"/"false" datums over {@link ScorePhraseMeasures} features
 * @throws IOException if the similarity log files cannot be written
 */
public GeneralDataset<String, ScorePhraseMeasures> choosedatums(boolean forLearningPattern, String answerLabel, TwoDimensionalCounter<CandidatePhrase, E> wordsPatExtracted, Counter<E> allSelectedPatterns, boolean computeRawFreq) throws IOException {
    // Expansion (adding phrases similar to known positives/negatives) only happens the
    // first time through, i.e. when the "first iter" counters have not been created yet.
    boolean expandNeg = false;
    if (closeToNegativesFirstIter == null) {
        closeToNegativesFirstIter = new ClassicCounter<>();
        if (constVars.expandNegativesWhenSampling)
            expandNeg = true;
    }
    boolean expandPos = false;
    if (closeToPositivesFirstIter == null) {
        closeToPositivesFirstIter = new ClassicCounter<>();
        if (constVars.expandPositivesWhenSampling)
            expandPos = true;
    }
    // Distributional-similarity clusters of the known positive phrases (used to judge
    // similarity when word vectors are not available).
    Counter<Integer> distSimClustersOfPositive = new ClassicCounter<>();
    if ((expandPos || expandNeg) && !constVars.useWordVectorsToComputeSim) {
        for (CandidatePhrase s : CollectionUtils.union(constVars.getLearnedWords(answerLabel).keySet(), constVars.getSeedLabelDictionary().get(answerLabel))) {
            String[] toks = s.getPhrase().split("\\s+");
            Integer num = constVars.getWordClassClusters().get(s.getPhrase());
            if (num == null)
                num = constVars.getWordClassClusters().get(s.getPhrase().toLowerCase());
            if (num == null) {
                // No cluster for the whole phrase: fall back to per-token clusters.
                for (String tok : toks) {
                    Integer toknum = constVars.getWordClassClusters().get(tok);
                    if (toknum == null)
                        toknum = constVars.getWordClassClusters().get(tok.toLowerCase());
                    if (toknum != null) {
                        distSimClustersOfPositive.incrementCount(toknum);
                    }
                }
            } else
                distSimClustersOfPositive.incrementCount(num);
        }
    }
    //computing this regardless of expandpos and expandneg because we reject all positive words that occur in negatives (can happen in multi word phrases etc)
    Map<String, Collection<CandidatePhrase>> allPossibleNegativePhrases = getAllPossibleNegativePhrases(answerLabel);
    GeneralDataset<String, ScorePhraseMeasures> dataset = new RVFDataset<>();
    int numpos = 0;
    Set<CandidatePhrase> allNegativePhrases = new HashSet<>();
    Set<CandidatePhrase> allUnknownPhrases = new HashSet<>();
    Set<CandidatePhrase> allPositivePhrases = new HashSet<>();
    // For each sentence batch, choose positive/negative/unknown phrases in parallel.
    ConstantsAndVariables.DataSentsIterator sentsIter = new ConstantsAndVariables.DataSentsIterator(constVars.batchProcessSents);
    while (sentsIter.hasNext()) {
        Pair<Map<String, DataInstance>, File> sentsf = sentsIter.next();
        Map<String, DataInstance> sents = sentsf.first();
        Redwood.log(Redwood.DBG, "Sampling datums from " + sentsf.second());
        if (computeRawFreq)
            Data.computeRawFreqIfNull(sents, PatternFactory.numWordsCompoundMax);
        List<List<String>> threadedSentIds = GetPatternsFromDataMultiClass.getThreadBatches(new ArrayList<>(sents.keySet()), constVars.numThreads);
        ExecutorService executor = Executors.newFixedThreadPool(constVars.numThreads);
        List<Future<Quintuple<Set<CandidatePhrase>, Set<CandidatePhrase>, Set<CandidatePhrase>, Counter<CandidatePhrase>, Counter<CandidatePhrase>>>> list = new ArrayList<>();
        //multi-threaded choose positive, negative and unknown
        for (List<String> keys : threadedSentIds) {
            Callable<Quintuple<Set<CandidatePhrase>, Set<CandidatePhrase>, Set<CandidatePhrase>, Counter<CandidatePhrase>, Counter<CandidatePhrase>>> task = new ChooseDatumsThread(answerLabel, sents, keys, wordsPatExtracted, allSelectedPatterns, distSimClustersOfPositive, allPossibleNegativePhrases, expandPos, expandNeg);
            Future<Quintuple<Set<CandidatePhrase>, Set<CandidatePhrase>, Set<CandidatePhrase>, Counter<CandidatePhrase>, Counter<CandidatePhrase>>> submit = executor.submit(task);
            list.add(submit);
        }
        // Now retrieve the results: (positives, negatives, unknowns, near-positives, near-negatives)
        for (Future<Quintuple<Set<CandidatePhrase>, Set<CandidatePhrase>, Set<CandidatePhrase>, Counter<CandidatePhrase>, Counter<CandidatePhrase>>> future : list) {
            try {
                Quintuple<Set<CandidatePhrase>, Set<CandidatePhrase>, Set<CandidatePhrase>, Counter<CandidatePhrase>, Counter<CandidatePhrase>> result = future.get();
                allPositivePhrases.addAll(result.first());
                allNegativePhrases.addAll(result.second());
                allUnknownPhrases.addAll(result.third());
                if (expandPos)
                    for (Entry<CandidatePhrase, Double> en : result.fourth().entrySet()) closeToPositivesFirstIter.setCount(en.getKey(), en.getValue());
                if (expandNeg)
                    for (Entry<CandidatePhrase, Double> en : result.fifth().entrySet()) closeToNegativesFirstIter.setCount(en.getKey(), en.getValue());
            } catch (Exception e) {
                executor.shutdownNow();
                throw new RuntimeException(e);
            }
        }
        executor.shutdown();
    }
    //TODO: this is kinda not nice; how is allpositivephrases different from positivephrases again?
    allPositivePhrases.addAll(constVars.getLearnedWords(answerLabel).keySet());
    // NOTE(review): these writers are not closed if an exception is thrown before the
    // close() calls at the bottom; consider try-with-resources / try-finally.
    BufferedWriter logFile = null;
    BufferedWriter logFileFeat = null;
    if (constVars.logFileVectorSimilarity != null) {
        logFile = new BufferedWriter(new FileWriter(constVars.logFileVectorSimilarity));
        logFileFeat = new BufferedWriter(new FileWriter(constVars.logFileVectorSimilarity + "_feat"));
        if (wordVectors != null) {
            for (CandidatePhrase p : allPositivePhrases) {
                if (wordVectors.containsKey(p.getPhrase())) {
                    logFile.write(p.getPhrase() + "-P " + ArrayUtils.toString(wordVectors.get(p.getPhrase()), " ") + "\n");
                }
            }
        }
    }
    if (constVars.expandPositivesWhenSampling) {
        //TODO: patwtbyfrew
        Redwood.log("Expanding positives by adding " + Counters.toSortedString(closeToPositivesFirstIter, closeToPositivesFirstIter.size(), "%1$s:%2$f", "\t") + " phrases");
        allPositivePhrases.addAll(closeToPositivesFirstIter.keySet());
        //write log
        // BUGFIX: this guard previously tested expandNeg (copy-paste from the negatives
        // branch below); the positives log should be gated on expandPos.
        if (logFile != null && wordVectors != null && expandPos) {
            for (CandidatePhrase p : closeToPositivesFirstIter.keySet()) {
                if (wordVectors.containsKey(p.getPhrase())) {
                    logFile.write(p.getPhrase() + "-PP " + ArrayUtils.toString(wordVectors.get(p.getPhrase()), " ") + "\n");
                }
            }
        }
    }
    if (constVars.expandNegativesWhenSampling) {
        //TODO: patwtbyfrew
        Redwood.log("Expanding negatives by adding " + Counters.toSortedString(closeToNegativesFirstIter, closeToNegativesFirstIter.size(), "%1$s:%2$f", "\t") + " phrases");
        allNegativePhrases.addAll(closeToNegativesFirstIter.keySet());
        //write log
        if (logFile != null && wordVectors != null && expandNeg) {
            for (CandidatePhrase p : closeToNegativesFirstIter.keySet()) {
                if (wordVectors.containsKey(p.getPhrase())) {
                    logFile.write(p.getPhrase() + "-NN " + ArrayUtils.toString(wordVectors.get(p.getPhrase()), " ") + "\n");
                }
            }
        }
    }
    System.out.println("all positive phrases of size " + allPositivePhrases.size() + " are  " + allPositivePhrases);
    // Add a "true" datum for each positive phrase.
    for (CandidatePhrase candidate : allPositivePhrases) {
        Counter<ScorePhraseMeasures> feat;
        if (forLearningPattern) {
            feat = getPhraseFeaturesForPattern(answerLabel, candidate);
        } else {
            feat = getFeatures(answerLabel, candidate, wordsPatExtracted.getCounter(candidate), allSelectedPatterns);
        }
        RVFDatum<String, ScorePhraseMeasures> datum = new RVFDatum<>(feat, "true");
        dataset.add(datum);
        numpos += 1;
        if (logFileFeat != null) {
            logFileFeat.write("POSITIVE " + candidate.getPhrase() + "\t" + Counters.toSortedByKeysString(feat, "%1$s:%2$.0f", ";", "%s") + "\n");
        }
    }
    Redwood.log(Redwood.DBG, "Number of pure negative phrases is " + allNegativePhrases.size());
    Redwood.log(Redwood.DBG, "Number of unknown phrases is " + allUnknownPhrases.size());
    // Unknown phrases are treated as negatives, either sub-sampled by similarity or wholesale.
    if (constVars.subsampleUnkAsNegUsingSim) {
        Set<CandidatePhrase> chosenUnknown = chooseUnknownAsNegatives(allUnknownPhrases, answerLabel, allPositivePhrases, allPossibleNegativePhrases, logFile);
        Redwood.log(Redwood.DBG, "Choosing " + chosenUnknown.size() + " unknowns as negative based to their similarity to the positive phrases");
        allNegativePhrases.addAll(chosenUnknown);
    } else {
        allNegativePhrases.addAll(allUnknownPhrases);
    }
    // Cap negatives at the number of positives to keep the training set balanced.
    if (allNegativePhrases.size() > numpos) {
        Redwood.log(Redwood.WARN, "Num of negative (" + allNegativePhrases.size() + ") is higher than number of positive phrases (" + numpos + ") = " + (allNegativePhrases.size() / (double) numpos) + ". " + "Capping the number by taking the first numPositives as negative. Consider decreasing perSelectRand");
        int i = 0;
        Set<CandidatePhrase> selectedNegPhrases = new HashSet<>();
        for (CandidatePhrase p : allNegativePhrases) {
            if (i >= numpos)
                break;
            selectedNegPhrases.add(p);
            i++;
        }
        allNegativePhrases.clear();
        allNegativePhrases = selectedNegPhrases;
    }
    System.out.println("all negative phrases are " + allNegativePhrases);
    // Add a "false" datum for each negative phrase.
    for (CandidatePhrase negative : allNegativePhrases) {
        Counter<ScorePhraseMeasures> feat;
        if (forLearningPattern) {
            feat = getPhraseFeaturesForPattern(answerLabel, negative);
        } else {
            feat = getFeatures(answerLabel, negative, wordsPatExtracted.getCounter(negative), allSelectedPatterns);
        }
        RVFDatum<String, ScorePhraseMeasures> datum = new RVFDatum<>(feat, "false");
        dataset.add(datum);
        if (logFile != null && wordVectors != null && wordVectors.containsKey(negative.getPhrase())) {
            logFile.write(negative.getPhrase() + "-N" + " " + ArrayUtils.toString(wordVectors.get(negative.getPhrase()), " ") + "\n");
        }
        if (logFileFeat != null)
            logFileFeat.write("NEGATIVE " + negative.getPhrase() + "\t" + Counters.toSortedByKeysString(feat, "%1$s:%2$.0f", ";", "%s") + "\n");
    }
    if (logFile != null) {
        logFile.close();
    }
    if (logFileFeat != null) {
        logFileFeat.close();
    }
    System.out.println("Before feature count threshold, dataset stats are ");
    dataset.summaryStatistics();
    dataset.applyFeatureCountThreshold(constVars.featureCountThreshold);
    System.out.println("AFTER feature count threshold of " + constVars.featureCountThreshold + ", dataset stats are ");
    dataset.summaryStatistics();
    Redwood.log(Redwood.DBG, "Eventually, number of positive datums:  " + numpos + " and number of negative datums: " + allNegativePhrases.size());
    return dataset;
}
Also used : ScorePhraseMeasures(edu.stanford.nlp.patterns.ConstantsAndVariables.ScorePhraseMeasures) FileWriter(java.io.FileWriter) BufferedWriter(java.io.BufferedWriter) Entry(java.util.Map.Entry) ConcurrentHashCounter(edu.stanford.nlp.util.concurrent.ConcurrentHashCounter) RVFDatum(edu.stanford.nlp.ling.RVFDatum) IOException(java.io.IOException) File(java.io.File)

Example 3 with ScorePhraseMeasures

use of edu.stanford.nlp.patterns.ConstantsAndVariables.ScorePhraseMeasures in project CoreNLP by stanfordnlp.

From the class ScorePhrasesLearnFeatWt, method learnClassifier.

/**
 * Trains the phrase-scoring classifier for {@code label} on a dataset sampled by
 * {@link #choosedatums}, using the classifier family selected by
 * {@code scoreClassifierType} (LR, SVM, SHIFTLR or LINEAR), and logs the top weights.
 * Also dumps the raw per-phrase feature counters to "tempscorestrainer.txt" for debugging.
 *
 * @param label               the label being learned
 * @param forLearningPatterns whether the features are for pattern learning
 * @param wordsPatExtracted   phrases extracted per pattern
 * @param allSelectedPatterns patterns selected so far, with their scores
 * @return the trained classifier
 * @throws IOException            if the debug dump or domain ngrams cannot be read/written
 * @throws ClassNotFoundException if loading serialized resources fails
 * @throws RuntimeException       if {@code scoreClassifierType} is unrecognized
 */
public edu.stanford.nlp.classify.Classifier learnClassifier(String label, boolean forLearningPatterns, TwoDimensionalCounter<CandidatePhrase, E> wordsPatExtracted, Counter<E> allSelectedPatterns) throws IOException, ClassNotFoundException {
    // Start from a clean slate: cached feature vectors and scores are label-specific.
    phraseScoresRaw.clear();
    learnedScores.clear();
    if (Data.domainNGramsFile != null)
        Data.loadDomainNGrams();
    boolean computeRawFreq = false;
    if (Data.rawFreq == null) {
        Data.rawFreq = new ClassicCounter<>();
        computeRawFreq = true;
    }
    GeneralDataset<String, ScorePhraseMeasures> dataset = choosedatums(forLearningPatterns, label, wordsPatExtracted, allSelectedPatterns, computeRawFreq);
    edu.stanford.nlp.classify.Classifier classifier;
    if (scoreClassifierType.equals(ClassifierType.LR)) {
        LogisticClassifierFactory<String, ScorePhraseMeasures> logfactory = new LogisticClassifierFactory<>();
        LogPrior lprior = new LogPrior();
        lprior.setSigma(constVars.LRSigma);
        classifier = logfactory.trainClassifier(dataset, lprior, false);
        LogisticClassifier logcl = ((LogisticClassifier) classifier);
        String l = (String) logcl.getLabelForInternalPositiveClass();
        Counter<String> weights = logcl.weightsAsCounter();
        // Flip the sign so weights are always reported w.r.t. the "true" class.
        if (l.equals(Boolean.FALSE.toString())) {
            Counters.multiplyInPlace(weights, -1);
        }
        List<Pair<String, Double>> wtd = Counters.toDescendingMagnitudeSortedListWithCounts(weights);
        Redwood.log(ConstantsAndVariables.minimaldebug, "The weights are " + StringUtils.join(wtd.subList(0, Math.min(wtd.size(), 600)), "\n"));
    } else if (scoreClassifierType.equals(ClassifierType.SVM)) {
        SVMLightClassifierFactory<String, ScorePhraseMeasures> svmcf = new SVMLightClassifierFactory<>(true);
        classifier = svmcf.trainClassifier(dataset);
        Set<String> labels = Generics.newHashSet(Arrays.asList("true"));
        List<Triple<ScorePhraseMeasures, String, Double>> topfeatures = ((SVMLightClassifier<String, ScorePhraseMeasures>) classifier).getTopFeatures(labels, 0, true, 600, true);
        Redwood.log(ConstantsAndVariables.minimaldebug, "The weights are " + StringUtils.join(topfeatures, "\n"));
    } else if (scoreClassifierType.equals(ClassifierType.SHIFTLR)) {
        //change the dataset to basic dataset because currently ShiftParamsLR doesn't support RVFDatum
        GeneralDataset<String, ScorePhraseMeasures> newdataset = new Dataset<>();
        Iterator<RVFDatum<String, ScorePhraseMeasures>> iter = dataset.iterator();
        while (iter.hasNext()) {
            RVFDatum<String, ScorePhraseMeasures> inst = iter.next();
            newdataset.add(new BasicDatum<>(inst.asFeatures(), inst.label()));
        }
        ShiftParamsLogisticClassifierFactory<String, ScorePhraseMeasures> factory = new ShiftParamsLogisticClassifierFactory<>();
        classifier = factory.trainClassifier(newdataset);
        //print weights
        MultinomialLogisticClassifier<String, ScorePhraseMeasures> logcl = ((MultinomialLogisticClassifier) classifier);
        Counter<ScorePhraseMeasures> weights = logcl.weightsAsGenericCounter().get("true");
        List<Pair<ScorePhraseMeasures, Double>> wtd = Counters.toDescendingMagnitudeSortedListWithCounts(weights);
        Redwood.log(ConstantsAndVariables.minimaldebug, "The weights are " + StringUtils.join(wtd.subList(0, Math.min(wtd.size(), 600)), "\n"));
    } else if (scoreClassifierType.equals(ClassifierType.LINEAR)) {
        LinearClassifierFactory<String, ScorePhraseMeasures> lcf = new LinearClassifierFactory<>();
        classifier = lcf.trainClassifier(dataset);
        Set<String> labels = Generics.newHashSet(Arrays.asList("true"));
        List<Triple<ScorePhraseMeasures, String, Double>> topfeatures = ((LinearClassifier<String, ScorePhraseMeasures>) classifier).getTopFeatures(labels, 0, true, 600, true);
        Redwood.log(ConstantsAndVariables.minimaldebug, "The weights are " + StringUtils.join(topfeatures, "\n"));
    } else
        throw new RuntimeException("cannot identify classifier " + scoreClassifierType);
    //    else if (scoreClassifierType.equals(ClassifierType.RF)) {
    //      ClassifierFactory wekaFactory = new WekaDatumClassifierFactory<String, ScorePhraseMeasures>("weka.classifiers.trees.RandomForest", constVars.wekaOptions);
    //      classifier = wekaFactory.trainClassifier(dataset);
    //      Classifier cls = ((WekaDatumClassifier) classifier).getClassifier();
    //      RandomForest rf = (RandomForest) cls;
    //    }
    // try-with-resources so the debug dump is closed even if a write fails
    // (the original leaked the writer on exception).
    try (BufferedWriter w = new BufferedWriter(new FileWriter("tempscorestrainer.txt"))) {
        System.out.println("size of learned scores is " + phraseScoresRaw.size());
        for (CandidatePhrase s : phraseScoresRaw.firstKeySet()) {
            w.write(s + "\t" + phraseScoresRaw.getCounter(s) + "\n");
        }
    }
    return classifier;
}
Also used : ScorePhraseMeasures(edu.stanford.nlp.patterns.ConstantsAndVariables.ScorePhraseMeasures) edu.stanford.nlp.classify(edu.stanford.nlp.classify) FileWriter(java.io.FileWriter) BufferedWriter(java.io.BufferedWriter) RVFDatum(edu.stanford.nlp.ling.RVFDatum) AtomicDouble(edu.stanford.nlp.util.concurrent.AtomicDouble)

Example 4 with ScorePhraseMeasures

use of edu.stanford.nlp.patterns.ConstantsAndVariables.ScorePhraseMeasures in project CoreNLP by stanfordnlp.

From the class ScorePhrasesLearnFeatWt, method getFeatures.

/**
 * Builds the feature vector used to score the candidate phrase {@code word} for the
 * given label when learning phrases (as opposed to patterns). Feature vectors are
 * memoized in {@code phraseScoresRaw}.
 *
 * @param label               the answer label being learned
 * @param word                the candidate phrase to featurize
 * @param patThatExtractedWord patterns that extracted this phrase, with counts
 * @param allSelectedPatterns patterns selected so far, with their scores
 * @return a counter mapping each active {@link ScorePhraseMeasures} feature to its value
 */
Counter<ScorePhraseMeasures> getFeatures(String label, CandidatePhrase word, Counter<E> patThatExtractedWord, Counter<E> allSelectedPatterns) {
    // Memoized: reuse the previously computed feature vector for this phrase.
    if (phraseScoresRaw.containsFirstKey(word))
        return phraseScoresRaw.getCounter(word);
    Counter<ScorePhraseMeasures> scoreslist = new ClassicCounter<>();
    // Add any features attached to the word itself.
    if (word.getFeatures() != null) {
        scoreslist.addAll(Counters.transform(word.getFeatures(), ScorePhraseMeasures::create));
    } else {
        Redwood.log(ConstantsAndVariables.extremedebug, "features are null for " + word);
    }
    if (constVars.usePhraseEvalPatWtByFreq) {
        double tfscore = getPatTFIDFScore(word, patThatExtractedWord, allSelectedPatterns);
        scoreslist.setCount(ScorePhraseMeasures.PATWTBYFREQ, tfscore);
    }
    if (constVars.usePhraseEvalSemanticOdds) {
        double dscore = this.getDictOddsScore(word, label, 0);
        scoreslist.setCount(ScorePhraseMeasures.SEMANTICODDS, dscore);
    }
    if (constVars.usePhraseEvalGoogleNgram) {
        Double gscore = getGoogleNgramScore(word);
        // A non-finite ngram score indicates corrupt frequency data; fail loudly.
        if (gscore.isInfinite() || gscore.isNaN()) {
            throw new RuntimeException("how is the google ngrams score " + gscore + " for " + word);
        }
        scoreslist.setCount(ScorePhraseMeasures.GOOGLENGRAM, gscore);
    }
    if (constVars.usePhraseEvalDomainNgram) {
        Double gscore = getDomainNgramScore(word.getPhrase());
        if (gscore.isInfinite() || gscore.isNaN()) {
            throw new RuntimeException("how is the domain ngrams score " + gscore + " for " + word + " when domain raw freq is " + Data.domainNGramRawFreq.getCount(word) + " and raw freq is " + Data.rawFreq.getCount(word));
        }
        scoreslist.setCount(ScorePhraseMeasures.DOMAINNGRAM, gscore);
    }
    if (constVars.usePhraseEvalWordClass) {
        // Fall back to the lowercased phrase when the exact form has no distsim cluster.
        Integer wordclass = constVars.getWordClassClusters().get(word.getPhrase());
        if (wordclass == null) {
            wordclass = constVars.getWordClassClusters().get(word.getPhrase().toLowerCase());
        }
        scoreslist.setCount(ScorePhraseMeasures.create(ScorePhraseMeasures.DISTSIM.toString() + "-" + wordclass), 1.0);
    }
    if (constVars.usePhraseEvalWordVector) {
        Map<String, double[]> sims = getSimilarities(word.getPhrase());
        if (sims == null) {
            //TODO: make more efficient
            // Similarities not cached yet: compute them for this single word.
            Map<String, Collection<CandidatePhrase>> allPossibleNegativePhrases = getAllPossibleNegativePhrases(label);
            Set<CandidatePhrase> knownPositivePhrases = CollectionUtils.unionAsSet(constVars.getLearnedWords(label).keySet(), constVars.getSeedLabelDictionary().get(label));
            computeSimWithWordVectors(Arrays.asList(word), knownPositivePhrases, allPossibleNegativePhrases, label);
            sims = getSimilarities(word.getPhrase());
        }
        assert sims != null : " Why are there no similarities for " + word;
        double avgPosSim = sims.get(label)[Similarities.AVGSIM.ordinal()];
        double maxPosSim = sims.get(label)[Similarities.MAXSIM.ordinal()];
        // Aggregate similarity to every OTHER label: item-weighted average and overall max.
        double sumNeg = 0, maxNeg = Double.MIN_VALUE;
        double allNumItems = 0;
        for (Entry<String, double[]> simEn : sims.entrySet()) {
            if (simEn.getKey().equals(label))
                continue;
            double numItems = simEn.getValue()[Similarities.NUMITEMS.ordinal()];
            sumNeg += simEn.getValue()[Similarities.AVGSIM.ordinal()] * numItems;
            allNumItems += numItems;
            double maxNegLabel = simEn.getValue()[Similarities.MAXSIM.ordinal()];
            if (maxNeg < maxNegLabel)
                maxNeg = maxNegLabel;
        }
        double avgNegSim = sumNeg / allNumItems;
        scoreslist.setCount(ScorePhraseMeasures.WORDVECPOSSIMAVG, avgPosSim);
        scoreslist.setCount(ScorePhraseMeasures.WORDVECPOSSIMMAX, maxPosSim);
        scoreslist.setCount(ScorePhraseMeasures.WORDVECNEGSIMAVG, avgNegSim);
        // BUGFIX: this previously set WORDVECNEGSIMAVG again, clobbering the average
        // computed on the previous line with the max. Use the MAX key, mirroring the
        // POSSIMAVG/POSSIMMAX pair above.
        scoreslist.setCount(ScorePhraseMeasures.WORDVECNEGSIMMAX, maxNeg);
    }
    if (constVars.usePhraseEvalEditDistSame) {
        double ed = constVars.getEditDistanceScoresThisClass(label, word.getPhrase());
        assert ed <= 1 : " how come edit distance from the true class is " + ed + " for word " + word;
        scoreslist.setCount(ScorePhraseMeasures.EDITDISTSAME, ed);
    }
    if (constVars.usePhraseEvalEditDistOther) {
        double ed = constVars.getEditDistanceScoresOtherClass(label, word.getPhrase());
        assert ed <= 1 : " how come edit distance from the true class is " + ed + " for word " + word;
        scoreslist.setCount(ScorePhraseMeasures.EDITDISTOTHER, ed);
    }
    if (constVars.usePhraseEvalWordShape) {
        scoreslist.setCount(ScorePhraseMeasures.WORDSHAPE, this.getWordShapeScore(word.getPhrase(), label));
    }
    if (constVars.usePhraseEvalWordShapeStr) {
        scoreslist.setCount(ScorePhraseMeasures.create(ScorePhraseMeasures.WORDSHAPESTR + "-" + this.wordShape(word.getPhrase())), 1.0);
    }
    if (constVars.usePhraseEvalFirstCapital) {
        scoreslist.setCount(ScorePhraseMeasures.ISFIRSTCAPITAL, StringUtils.isCapitalized(word.getPhrase()) ? 1.0 : 0);
    }
    if (constVars.usePhraseEvalBOW) {
        // Bag-of-words: one indicator feature per whitespace-separated token.
        for (String s : word.getPhrase().split("\\s+")) scoreslist.setCount(ScorePhraseMeasures.create(ScorePhraseMeasures.BOW + "-" + s), 1.0);
    }
    // Cache before returning so subsequent lookups hit the memo above.
    phraseScoresRaw.setCounter(word, scoreslist);
    return scoreslist;
}
Also used : java.util(java.util) ExtractPhraseFromPattern(edu.stanford.nlp.patterns.dep.ExtractPhraseFromPattern) edu.stanford.nlp.util(edu.stanford.nlp.util) ConcurrentHashCounter(edu.stanford.nlp.util.concurrent.ConcurrentHashCounter) Function(java.util.function.Function) edu.stanford.nlp.stats(edu.stanford.nlp.stats) AtomicDouble(edu.stanford.nlp.util.concurrent.AtomicDouble) SemanticGraph(edu.stanford.nlp.semgraph.SemanticGraph) RVFDatum(edu.stanford.nlp.ling.RVFDatum) Option(edu.stanford.nlp.util.ArgumentParser.Option) IndexedWord(edu.stanford.nlp.ling.IndexedWord) CoreLabel(edu.stanford.nlp.ling.CoreLabel) ScorePhraseMeasures(edu.stanford.nlp.patterns.ConstantsAndVariables.ScorePhraseMeasures) DataInstanceDep(edu.stanford.nlp.patterns.dep.DataInstanceDep) BufferedWriter(java.io.BufferedWriter) java.util.concurrent(java.util.concurrent) IOUtils(edu.stanford.nlp.io.IOUtils) Redwood(edu.stanford.nlp.util.logging.Redwood) FileWriter(java.io.FileWriter) BasicDatum(edu.stanford.nlp.ling.BasicDatum) IOException(java.io.IOException) Collectors(java.util.stream.Collectors) File(java.io.File) ExtractedPhrase(edu.stanford.nlp.patterns.dep.ExtractedPhrase) edu.stanford.nlp.classify(edu.stanford.nlp.classify) Entry(java.util.Map.Entry) ScorePhraseMeasures(edu.stanford.nlp.patterns.ConstantsAndVariables.ScorePhraseMeasures) AtomicDouble(edu.stanford.nlp.util.concurrent.AtomicDouble)

Aggregations

RVFDatum (edu.stanford.nlp.ling.RVFDatum)4 ScorePhraseMeasures (edu.stanford.nlp.patterns.ConstantsAndVariables.ScorePhraseMeasures)4 BufferedWriter (java.io.BufferedWriter)4 FileWriter (java.io.FileWriter)4 edu.stanford.nlp.classify (edu.stanford.nlp.classify)3 AtomicDouble (edu.stanford.nlp.util.concurrent.AtomicDouble)3 ConcurrentHashCounter (edu.stanford.nlp.util.concurrent.ConcurrentHashCounter)3 File (java.io.File)3 IOException (java.io.IOException)3 Entry (java.util.Map.Entry)3 IOUtils (edu.stanford.nlp.io.IOUtils)2 BasicDatum (edu.stanford.nlp.ling.BasicDatum)2 CoreLabel (edu.stanford.nlp.ling.CoreLabel)2 IndexedWord (edu.stanford.nlp.ling.IndexedWord)2 DataInstanceDep (edu.stanford.nlp.patterns.dep.DataInstanceDep)2 ExtractPhraseFromPattern (edu.stanford.nlp.patterns.dep.ExtractPhraseFromPattern)2 ExtractedPhrase (edu.stanford.nlp.patterns.dep.ExtractedPhrase)2 SemanticGraph (edu.stanford.nlp.semgraph.SemanticGraph)2 edu.stanford.nlp.stats (edu.stanford.nlp.stats)2 edu.stanford.nlp.util (edu.stanford.nlp.util)2