Search in sources:

Example 1 with RVFDatum

Use of edu.stanford.nlp.ling.RVFDatum in the CoreNLP project by stanfordnlp.

From the class LinearClassifier, method justificationOf:

/** Print all features active for a particular datum and the weight that
   *  the classifier assigns to each class for those features.
   *
   *  @param example The datum for which features are to be printed
   *  @param pw Where to print it to
   *  @param printer If this is non-null, then it is applied to each
   *        feature to convert it to a more readable form
   *  @param sortedByFeature Whether to sort by feature names
   */
public <T> void justificationOf(Datum<L, F> example, PrintWriter pw, Function<F, T> printer, boolean sortedByFeature) {
    // Real-valued feature datums have their own dedicated justification printer.
    if (example instanceof RVFDatum<?, ?>) {
        justificationOfRVFDatum((RVFDatum<L, F>) example, pw);
        return;
    }
    NumberFormat nf = NumberFormat.getNumberInstance();
    nf.setMinimumFractionDigits(2);
    nf.setMaximumFractionDigits(2);
    if (nf instanceof DecimalFormat) {
        // Prefix positive numbers with a space so they align with negatives' '-' sign.
        ((DecimalFormat) nf).setPositivePrefix(" ");
    }
    // Determine the width of the feature column: at least as wide as "Total:",
    // but capped so one very long feature name doesn't blow up the whole table.
    int featureLength = 0;
    for (F f : example.asFeatures()) {
        // Widths must be measured on the printed form of the feature, if a
        // printer was supplied, since that is what gets rendered below.
        int length = (printer != null) ? printer.apply(f).toString().length() : f.toString().length();
        featureLength = Math.max(featureLength, length);
    }
    // make as wide as total printout
    featureLength = Math.max(featureLength, "Total:".length());
    // don't make it ridiculously wide
    featureLength = Math.min(featureLength, MAX_FEATURE_ALIGN_WIDTH);
    // Determine the width of each label column (at least 6 so formatted numbers fit).
    int labelLength = 6;
    for (L l : labels()) {
        labelLength = Math.max(labelLength, l.toString().length());
    }
    // Header row listing the classes, aligned over their weight columns.
    StringBuilder header = new StringBuilder();
    padTo(header, 0, featureLength);
    for (L l : labels()) {
        header.append(' ');
        header.append(StringUtils.pad(l, labelLength));
    }
    pw.println(header);
    // One row per active feature: the feature name, then its weight for each class.
    Collection<F> featColl = example.asFeatures();
    if (sortedByFeature) {
        featColl = ErasureUtils.sortedIfPossible(featColl);
    }
    for (F f : featColl) {
        String fStr = (printer != null) ? printer.apply(f).toString() : f.toString();
        StringBuilder line = new StringBuilder(fStr);
        padTo(line, fStr.length(), featureLength);
        for (L l : labels()) {
            String lStr = nf.format(weight(f, l));
            line.append(' ');
            line.append(lStr);
            padTo(line, lStr.length(), labelLength);
        }
        pw.println(line);
    }
    // Footer rows: per-class score totals, then the label probabilities obtained
    // by pushing the scores through a logistic transform.
    Counter<L> scores = scoresOf(example);
    printJustificationFooter(pw, "Total:", nf, featureLength, labelLength, scores::getCount);
    Distribution<L> distr = Distribution.distributionFromLogisticCounter(scores);
    printJustificationFooter(pw, "Prob:", nf, featureLength, labelLength, distr::getCount);
}

/** Append spaces to {@code sb} so that a field currently occupying
 *  {@code current} characters is padded out to {@code target} characters.
 *  Appends nothing if the field is already at or past the target width. */
private static void padTo(StringBuilder sb, int current, int target) {
    for (int s = current; s < target; s++) {
        sb.append(' ');
    }
}

/** Print one footer row (e.g. "Total:" or "Prob:") of the justification
 *  table, with one formatted value per class, aligned with the feature
 *  rows above it.
 *
 *  @param pw Where to print the row
 *  @param rowLabel Text shown in the feature column
 *  @param nf Formatter used for the per-class values
 *  @param featureLength Width of the feature column
 *  @param labelLength Width of each class column
 *  @param value Maps each label to the value printed for it
 */
private void printJustificationFooter(PrintWriter pw, String rowLabel, NumberFormat nf, int featureLength, int labelLength, java.util.function.ToDoubleFunction<L> value) {
    StringBuilder footer = new StringBuilder(rowLabel);
    padTo(footer, footer.length(), featureLength);
    for (L l : labels()) {
        footer.append(' ');
        String str = nf.format(value.applyAsDouble(l));
        footer.append(str);
        padTo(footer, str.length(), labelLength);
    }
    pw.println(footer);
}
Also used: DecimalFormat(java.text.DecimalFormat) RVFDatum(edu.stanford.nlp.ling.RVFDatum) NumberFormat(java.text.NumberFormat)

Example 2 with RVFDatum

Use of edu.stanford.nlp.ling.RVFDatum in the CoreNLP project by stanfordnlp.

From the class SimpleSentiment, method classify:

/**
   * Classify the sentiment of a raw text string. The text is run through the
   * annotation pipeline and the sentiment of its first sentence is returned.
   *
   * @param text The raw text to classify.
   * @return The predicted sentiment class of the first sentence of the text.
   * @throws IllegalArgumentException If the annotated text contains no sentences
   *         (e.g., the input is empty), rather than an opaque
   *         IndexOutOfBoundsException from {@code get(0)}.
   * @see SimpleSentiment#classify(CoreMap)
   */
public SentimentClass classify(String text) {
    Annotation ann = new Annotation(text);
    pipeline.get().annotate(ann);
    java.util.List<CoreMap> sentences = ann.get(CoreAnnotations.SentencesAnnotation.class);
    if (sentences == null || sentences.isEmpty()) {
        throw new IllegalArgumentException("Cannot classify sentiment: no sentences found in text");
    }
    CoreMap sentence = sentences.get(0);
    Counter<String> features = featurize(sentence);
    RVFDatum<SentimentClass, String> datum = new RVFDatum<>(features);
    return impl.classOf(datum);
}
Also used: SentimentClass(edu.stanford.nlp.simple.SentimentClass) RVFDatum(edu.stanford.nlp.ling.RVFDatum) CoreMap(edu.stanford.nlp.util.CoreMap) Annotation(edu.stanford.nlp.pipeline.Annotation)

Example 3 with RVFDatum

Use of edu.stanford.nlp.ling.RVFDatum in the CoreNLP project by stanfordnlp.

From the class RFSieve, method extractDatum:

public static RVFDatum<Boolean, String> extractDatum(Mention m, Mention candidate, Document document, int mentionDist, Dictionaries dict, Properties props, String sievename) {
    try {
        boolean label = (document.goldMentions == null) ? false : document.isCoref(m, candidate);
        Counter<String> features = new ClassicCounter<>();
        CorefCluster mC = document.corefClusters.get(m.corefClusterID);
        CorefCluster aC = document.corefClusters.get(candidate.corefClusterID);
        CoreLabel mFirst = m.sentenceWords.get(m.startIndex);
        CoreLabel mLast = m.sentenceWords.get(m.endIndex - 1);
        CoreLabel mPreceding = (m.startIndex > 0) ? m.sentenceWords.get(m.startIndex - 1) : null;
        CoreLabel mFollowing = (m.endIndex < m.sentenceWords.size()) ? m.sentenceWords.get(m.endIndex) : null;
        CoreLabel aFirst = candidate.sentenceWords.get(candidate.startIndex);
        CoreLabel aLast = candidate.sentenceWords.get(candidate.endIndex - 1);
        CoreLabel aPreceding = (candidate.startIndex > 0) ? candidate.sentenceWords.get(candidate.startIndex - 1) : null;
        CoreLabel aFollowing = (candidate.endIndex < candidate.sentenceWords.size()) ? candidate.sentenceWords.get(candidate.endIndex) : null;
        ////////////////////////////////////////////////////////////////////////////////
        if (HybridCorefProperties.useBasicFeatures(props, sievename)) {
            int sentDist = m.sentNum - candidate.sentNum;
            features.incrementCount("SENTDIST", sentDist);
            features.incrementCount("MENTIONDIST", mentionDist);
            int minSentDist = sentDist;
            for (Mention a : aC.corefMentions) {
                minSentDist = Math.min(minSentDist, Math.abs(m.sentNum - a.sentNum));
            }
            features.incrementCount("MINSENTDIST", minSentDist);
            // When they are in the same sentence, divides a sentence into clauses and add such feature
            if (CorefProperties.useConstituencyParse(props)) {
                if (m.sentNum == candidate.sentNum) {
                    int clauseCount = 0;
                    Tree tree = m.contextParseTree;
                    Tree current = m.mentionSubTree;
                    while (true) {
                        current = current.ancestor(1, tree);
                        if (current.label().value().startsWith("S")) {
                            clauseCount++;
                        }
                        if (current.dominates(candidate.mentionSubTree))
                            break;
                        if (current.label().value().equals("ROOT") || current.ancestor(1, tree) == null)
                            break;
                    }
                    features.incrementCount("CLAUSECOUNT", clauseCount);
                }
            }
            if (document.docType == DocType.CONVERSATION)
                features.incrementCount("B-DOCTYPE-" + document.docType);
            if (m.headWord.get(SpeakerAnnotation.class).equalsIgnoreCase("PER0")) {
                features.incrementCount("B-SPEAKER-PER0");
            }
            if (document.docInfo != null && document.docInfo.containsKey("DOC_ID")) {
                features.incrementCount("B-DOCSOURCE-" + document.docInfo.get("DOC_ID").split("/")[1]);
            }
            features.incrementCount("M-LENGTH", m.originalSpan.size());
            features.incrementCount("A-LENGTH", candidate.originalSpan.size());
            if (m.originalSpan.size() < candidate.originalSpan.size())
                features.incrementCount("B-A-ISLONGER");
            features.incrementCount("A-SIZE", aC.getCorefMentions().size());
            features.incrementCount("M-SIZE", mC.getCorefMentions().size());
            String antRole = "A-NOROLE";
            String mRole = "M-NOROLE";
            if (m.isSubject)
                mRole = "M-SUBJ";
            if (m.isDirectObject)
                mRole = "M-DOBJ";
            if (m.isIndirectObject)
                mRole = "M-IOBJ";
            if (m.isPrepositionObject)
                mRole = "M-POBJ";
            if (candidate.isSubject)
                antRole = "A-SUBJ";
            if (candidate.isDirectObject)
                antRole = "A-DOBJ";
            if (candidate.isIndirectObject)
                antRole = "A-IOBJ";
            if (candidate.isPrepositionObject)
                antRole = "A-POBJ";
            features.incrementCount("B-" + mRole);
            features.incrementCount("B-" + antRole);
            features.incrementCount("B-" + antRole + "-" + mRole);
            if (HybridCorefProperties.combineObjectRoles(props, sievename)) {
                // combine all objects
                if (m.isDirectObject || m.isIndirectObject || m.isPrepositionObject || candidate.isDirectObject || candidate.isIndirectObject || candidate.isPrepositionObject) {
                    if (m.isDirectObject || m.isIndirectObject || m.isPrepositionObject) {
                        mRole = "M-OBJ";
                        features.incrementCount("B-M-OBJ");
                    }
                    if (candidate.isDirectObject || candidate.isIndirectObject || candidate.isPrepositionObject) {
                        antRole = "A-OBJ";
                        features.incrementCount("B-A-OBJ");
                    }
                    features.incrementCount("B-" + antRole + "-" + mRole);
                }
            }
            if (mFirst.word().toLowerCase().matches("a|an")) {
                features.incrementCount("B-M-START-WITH-INDEFINITE");
            }
            if (aFirst.word().toLowerCase().matches("a|an")) {
                features.incrementCount("B-A-START-WITH-INDEFINITE");
            }
            if (mFirst.word().equalsIgnoreCase("the")) {
                features.incrementCount("B-M-START-WITH-DEFINITE");
            }
            if (aFirst.word().equalsIgnoreCase("the")) {
                features.incrementCount("B-A-START-WITH-DEFINITE");
            }
            if (dict.indefinitePronouns.contains(m.lowercaseNormalizedSpanString())) {
                features.incrementCount("B-M-INDEFINITE-PRONOUN");
            }
            if (dict.indefinitePronouns.contains(candidate.lowercaseNormalizedSpanString())) {
                features.incrementCount("B-A-INDEFINITE-PRONOUN");
            }
            if (dict.indefinitePronouns.contains(mFirst.word().toLowerCase())) {
                features.incrementCount("B-M-INDEFINITE-ADJ");
            }
            if (dict.indefinitePronouns.contains(aFirst.word().toLowerCase())) {
                features.incrementCount("B-A-INDEFINITE-ADJ");
            }
            if (dict.reflexivePronouns.contains(m.headString)) {
                features.incrementCount("B-M-REFLEXIVE");
            }
            if (dict.reflexivePronouns.contains(candidate.headString)) {
                features.incrementCount("B-A-REFLEXIVE");
            }
            if (m.headIndex == m.endIndex - 1)
                features.incrementCount("B-M-HEADEND");
            if (m.headIndex < m.endIndex - 1) {
                CoreLabel headnext = m.sentenceWords.get(m.headIndex + 1);
                if (headnext.word().matches("that|,") || headnext.tag().startsWith("W")) {
                    features.incrementCount("B-M-HASPOSTPHRASE");
                    if (mFirst.tag().equals("DT") && mFirst.word().toLowerCase().matches("the|this|these|those"))
                        features.incrementCount("B-M-THE-HASPOSTPHRASE");
                    else if (mFirst.word().toLowerCase().matches("a|an"))
                        features.incrementCount("B-M-INDEFINITE-HASPOSTPHRASE");
                }
            }
            // shape feature from Bjorkelund & Kuhn
            StringBuilder sb = new StringBuilder();
            List<Mention> sortedMentions = new ArrayList<>(aC.corefMentions.size());
            sortedMentions.addAll(aC.corefMentions);
            Collections.sort(sortedMentions, new CorefChain.MentionComparator());
            for (Mention a : sortedMentions) {
                sb.append(a.mentionType).append("-");
            }
            features.incrementCount("B-A-SHAPE-" + sb.toString());
            sb = new StringBuilder();
            sortedMentions = new ArrayList<>(mC.corefMentions.size());
            sortedMentions.addAll(mC.corefMentions);
            Collections.sort(sortedMentions, new CorefChain.MentionComparator());
            for (Mention men : sortedMentions) {
                sb.append(men.mentionType).append("-");
            }
            features.incrementCount("B-M-SHAPE-" + sb.toString());
            if (CorefProperties.useConstituencyParse(props)) {
                sb = new StringBuilder();
                Tree mTree = m.contextParseTree;
                Tree mHead = mTree.getLeaves().get(m.headIndex).ancestor(1, mTree);
                for (Tree node : mTree.pathNodeToNode(mHead, mTree)) {
                    sb.append(node.value()).append("-");
                    if (node.value().equals("S"))
                        break;
                }
                features.incrementCount("B-M-SYNPATH-" + sb.toString());
                sb = new StringBuilder();
                Tree aTree = candidate.contextParseTree;
                Tree aHead = aTree.getLeaves().get(candidate.headIndex).ancestor(1, aTree);
                for (Tree node : aTree.pathNodeToNode(aHead, aTree)) {
                    sb.append(node.value()).append("-");
                    if (node.value().equals("S"))
                        break;
                }
                features.incrementCount("B-A-SYNPATH-" + sb.toString());
            }
            features.incrementCount("A-FIRSTAPPEAR", aC.representative.sentNum);
            features.incrementCount("M-FIRSTAPPEAR", mC.representative.sentNum);
            // document size in # of sentences
            int docSize = document.predictedMentions.size();
            features.incrementCount("A-FIRSTAPPEAR-NORMALIZED", aC.representative.sentNum / docSize);
            features.incrementCount("M-FIRSTAPPEAR-NORMALIZED", mC.representative.sentNum / docSize);
        }
        ////////////////////////////////////////////////////////////////////////////////
        if (HybridCorefProperties.useMentionDetectionFeatures(props, sievename)) {
            // bare plurals
            if (m.originalSpan.size() == 1 && m.headWord.tag().equals("NNS"))
                features.incrementCount("B-M-BAREPLURAL");
            if (candidate.originalSpan.size() == 1 && candidate.headWord.tag().equals("NNS"))
                features.incrementCount("B-A-BAREPLURAL");
            // pleonastic it
            if (CorefProperties.useConstituencyParse(props)) {
                if (RuleBasedCorefMentionFinder.isPleonastic(m, m.contextParseTree) || RuleBasedCorefMentionFinder.isPleonastic(candidate, candidate.contextParseTree)) {
                    features.incrementCount("B-PLEONASTICIT");
                }
            }
            // quantRule
            if (dict.quantifiers.contains(mFirst.word().toLowerCase(Locale.ENGLISH)))
                features.incrementCount("B-M-QUANTIFIER");
            if (dict.quantifiers.contains(aFirst.word().toLowerCase(Locale.ENGLISH)))
                features.incrementCount("B-A-QUANTIFIER");
            // starts with negation
            if (mFirst.word().toLowerCase(Locale.ENGLISH).matches("none|no|nothing|not") || aFirst.word().toLowerCase(Locale.ENGLISH).matches("none|no|nothing|not")) {
                features.incrementCount("B-NEGATIVE-START");
            }
            // parititive rule
            if (RuleBasedCorefMentionFinder.partitiveRule(m, m.sentenceWords, dict))
                features.incrementCount("B-M-PARTITIVE");
            if (RuleBasedCorefMentionFinder.partitiveRule(candidate, candidate.sentenceWords, dict))
                features.incrementCount("B-A-PARTITIVE");
            // %
            if (m.headString.equals("%"))
                features.incrementCount("B-M-HEAD%");
            if (candidate.headString.equals("%"))
                features.incrementCount("B-A-HEAD%");
            // adjective form of nations
            if (dict.isAdjectivalDemonym(m.spanToString()))
                features.incrementCount("B-M-ADJ-DEMONYM");
            if (dict.isAdjectivalDemonym(candidate.spanToString()))
                features.incrementCount("B-A-ADJ-DEMONYM");
            // ends with "etc."
            if (m.lowercaseNormalizedSpanString().endsWith("etc."))
                features.incrementCount("B-M-ETC-END");
            if (candidate.lowercaseNormalizedSpanString().endsWith("etc."))
                features.incrementCount("B-A-ETC-END");
        }
        ////////////////////////////////////////////////////////////////////////////////
        ///////    attributes, attributes agree                             ////////////
        ////////////////////////////////////////////////////////////////////////////////
        features.incrementCount("B-M-NUMBER-" + m.number);
        features.incrementCount("B-A-NUMBER-" + candidate.number);
        features.incrementCount("B-M-GENDER-" + m.gender);
        features.incrementCount("B-A-GENDER-" + candidate.gender);
        features.incrementCount("B-M-ANIMACY-" + m.animacy);
        features.incrementCount("B-A-ANIMACY-" + candidate.animacy);
        features.incrementCount("B-M-PERSON-" + m.person);
        features.incrementCount("B-A-PERSON-" + candidate.person);
        features.incrementCount("B-M-NETYPE-" + m.nerString);
        features.incrementCount("B-A-NETYPE-" + candidate.nerString);
        features.incrementCount("B-BOTH-NUMBER-" + candidate.number + "-" + m.number);
        features.incrementCount("B-BOTH-GENDER-" + candidate.gender + "-" + m.gender);
        features.incrementCount("B-BOTH-ANIMACY-" + candidate.animacy + "-" + m.animacy);
        features.incrementCount("B-BOTH-PERSON-" + candidate.person + "-" + m.person);
        features.incrementCount("B-BOTH-NETYPE-" + candidate.nerString + "-" + m.nerString);
        Set<Number> mcNumber = Generics.newHashSet();
        for (Number n : mC.numbers) {
            features.incrementCount("B-MC-NUMBER-" + n);
            mcNumber.add(n);
        }
        if (mcNumber.size() == 1) {
            features.incrementCount("B-MC-CLUSTERNUMBER-" + mcNumber.iterator().next());
        } else {
            mcNumber.remove(Number.UNKNOWN);
            if (mcNumber.size() == 1)
                features.incrementCount("B-MC-CLUSTERNUMBER-" + mcNumber.iterator().next());
            else
                features.incrementCount("B-MC-CLUSTERNUMBER-CONFLICT");
        }
        Set<Gender> mcGender = Generics.newHashSet();
        for (Gender g : mC.genders) {
            features.incrementCount("B-MC-GENDER-" + g);
            mcGender.add(g);
        }
        if (mcGender.size() == 1) {
            features.incrementCount("B-MC-CLUSTERGENDER-" + mcGender.iterator().next());
        } else {
            mcGender.remove(Gender.UNKNOWN);
            if (mcGender.size() == 1)
                features.incrementCount("B-MC-CLUSTERGENDER-" + mcGender.iterator().next());
            else
                features.incrementCount("B-MC-CLUSTERGENDER-CONFLICT");
        }
        Set<Animacy> mcAnimacy = Generics.newHashSet();
        for (Animacy a : mC.animacies) {
            features.incrementCount("B-MC-ANIMACY-" + a);
            mcAnimacy.add(a);
        }
        if (mcAnimacy.size() == 1) {
            features.incrementCount("B-MC-CLUSTERANIMACY-" + mcAnimacy.iterator().next());
        } else {
            mcAnimacy.remove(Animacy.UNKNOWN);
            if (mcAnimacy.size() == 1)
                features.incrementCount("B-MC-CLUSTERANIMACY-" + mcAnimacy.iterator().next());
            else
                features.incrementCount("B-MC-CLUSTERANIMACY-CONFLICT");
        }
        Set<String> mcNER = Generics.newHashSet();
        for (String t : mC.nerStrings) {
            features.incrementCount("B-MC-NETYPE-" + t);
            mcNER.add(t);
        }
        if (mcNER.size() == 1) {
            features.incrementCount("B-MC-CLUSTERNETYPE-" + mcNER.iterator().next());
        } else {
            mcNER.remove("O");
            if (mcNER.size() == 1)
                features.incrementCount("B-MC-CLUSTERNETYPE-" + mcNER.iterator().next());
            else
                features.incrementCount("B-MC-CLUSTERNETYPE-CONFLICT");
        }
        Set<Number> acNumber = Generics.newHashSet();
        for (Number n : aC.numbers) {
            features.incrementCount("B-AC-NUMBER-" + n);
            acNumber.add(n);
        }
        if (acNumber.size() == 1) {
            features.incrementCount("B-AC-CLUSTERNUMBER-" + acNumber.iterator().next());
        } else {
            acNumber.remove(Number.UNKNOWN);
            if (acNumber.size() == 1)
                features.incrementCount("B-AC-CLUSTERNUMBER-" + acNumber.iterator().next());
            else
                features.incrementCount("B-AC-CLUSTERNUMBER-CONFLICT");
        }
        Set<Gender> acGender = Generics.newHashSet();
        for (Gender g : aC.genders) {
            features.incrementCount("B-AC-GENDER-" + g);
            acGender.add(g);
        }
        if (acGender.size() == 1) {
            features.incrementCount("B-AC-CLUSTERGENDER-" + acGender.iterator().next());
        } else {
            acGender.remove(Gender.UNKNOWN);
            if (acGender.size() == 1)
                features.incrementCount("B-AC-CLUSTERGENDER-" + acGender.iterator().next());
            else
                features.incrementCount("B-AC-CLUSTERGENDER-CONFLICT");
        }
        Set<Animacy> acAnimacy = Generics.newHashSet();
        for (Animacy a : aC.animacies) {
            features.incrementCount("B-AC-ANIMACY-" + a);
            acAnimacy.add(a);
        }
        if (acAnimacy.size() == 1) {
            features.incrementCount("B-AC-CLUSTERANIMACY-" + acAnimacy.iterator().next());
        } else {
            acAnimacy.remove(Animacy.UNKNOWN);
            if (acAnimacy.size() == 1)
                features.incrementCount("B-AC-CLUSTERANIMACY-" + acAnimacy.iterator().next());
            else
                features.incrementCount("B-AC-CLUSTERANIMACY-CONFLICT");
        }
        Set<String> acNER = Generics.newHashSet();
        for (String t : aC.nerStrings) {
            features.incrementCount("B-AC-NETYPE-" + t);
            acNER.add(t);
        }
        if (acNER.size() == 1) {
            features.incrementCount("B-AC-CLUSTERNETYPE-" + acNER.iterator().next());
        } else {
            acNER.remove("O");
            if (acNER.size() == 1)
                features.incrementCount("B-AC-CLUSTERNETYPE-" + acNER.iterator().next());
            else
                features.incrementCount("B-AC-CLUSTERNETYPE-CONFLICT");
        }
        if (m.numbersAgree(candidate))
            features.incrementCount("B-NUMBER-AGREE");
        if (m.gendersAgree(candidate))
            features.incrementCount("B-GENDER-AGREE");
        if (m.animaciesAgree(candidate))
            features.incrementCount("B-ANIMACY-AGREE");
        if (CorefRules.entityAttributesAgree(mC, aC))
            features.incrementCount("B-ATTRIBUTES-AGREE");
        if (CorefRules.entityPersonDisagree(document, m, candidate, dict))
            features.incrementCount("B-PERSON-DISAGREE");
        ////////////////////////////////////////////////////////////////////////////////
        if (HybridCorefProperties.useDcorefRules(props, sievename)) {
            if (CorefRules.entityIWithinI(m, candidate, dict))
                features.incrementCount("B-i-within-i");
            if (CorefRules.antecedentIsMentionSpeaker(document, m, candidate, dict))
                features.incrementCount("B-ANT-IS-SPEAKER");
            if (CorefRules.entitySameSpeaker(document, m, candidate))
                features.incrementCount("B-SAME-SPEAKER");
            if (CorefRules.entitySubjectObject(m, candidate))
                features.incrementCount("B-SUBJ-OBJ");
            for (Mention a : aC.corefMentions) {
                if (CorefRules.entitySubjectObject(m, a))
                    features.incrementCount("B-CLUSTER-SUBJ-OBJ");
            }
            if (CorefRules.entityPersonDisagree(document, m, candidate, dict) && CorefRules.entitySameSpeaker(document, m, candidate))
                features.incrementCount("B-PERSON-DISAGREE-SAME-SPEAKER");
            if (CorefRules.entityIWithinI(mC, aC, dict))
                features.incrementCount("B-ENTITY-IWITHINI");
            if (CorefRules.antecedentMatchesMentionSpeakerAnnotation(m, candidate, document))
                features.incrementCount("B-ANT-IS-SPEAKER-OF-MENTION");
            Set<MentionType> mType = HybridCorefProperties.getMentionType(props, sievename);
            if (mType.contains(MentionType.PROPER) || mType.contains(MentionType.NOMINAL)) {
                if (m.headString.equals(candidate.headString))
                    features.incrementCount("B-HEADMATCH");
                if (CorefRules.entityHeadsAgree(mC, aC, m, candidate, dict))
                    features.incrementCount("B-HEADSAGREE");
                if (CorefRules.entityExactStringMatch(mC, aC, dict, document.roleSet))
                    features.incrementCount("B-EXACTSTRINGMATCH");
                if (CorefRules.entityHaveExtraProperNoun(m, candidate, new HashSet<>()))
                    features.incrementCount("B-HAVE-EXTRA-PROPER-NOUN");
                if (CorefRules.entityBothHaveProper(mC, aC))
                    features.incrementCount("B-BOTH-HAVE-PROPER");
                if (CorefRules.entityHaveDifferentLocation(m, candidate, dict))
                    features.incrementCount("B-HAVE-DIFF-LOC");
                if (CorefRules.entityHaveIncompatibleModifier(mC, aC))
                    features.incrementCount("B-HAVE-INCOMPATIBLE-MODIFIER");
                if (CorefRules.entityIsAcronym(document, mC, aC))
                    features.incrementCount("B-IS-ACRONYM");
                if (CorefRules.entityIsApposition(mC, aC, m, candidate))
                    features.incrementCount("B-IS-APPOSITION");
                if (CorefRules.entityIsPredicateNominatives(mC, aC, m, candidate))
                    features.incrementCount("B-IS-PREDICATE-NOMINATIVES");
                if (CorefRules.entityIsRoleAppositive(mC, aC, m, candidate, dict))
                    features.incrementCount("B-IS-ROLE-APPOSITIVE");
                if (CorefRules.entityNumberInLaterMention(m, candidate))
                    features.incrementCount("B-NUMBER-IN-LATER");
                if (CorefRules.entityRelaxedExactStringMatch(mC, aC, m, candidate, dict, document.roleSet))
                    features.incrementCount("B-RELAXED-EXACT-STRING-MATCH");
                if (CorefRules.entityRelaxedHeadsAgreeBetweenMentions(mC, aC, m, candidate))
                    features.incrementCount("B-RELAXED-HEAD-AGREE");
                if (CorefRules.entitySameProperHeadLastWord(m, candidate))
                    features.incrementCount("B-SAME-PROPER-HEAD");
                if (CorefRules.entitySameProperHeadLastWord(mC, aC, m, candidate))
                    features.incrementCount("B-CLUSTER-SAME-PROPER-HEAD");
                if (CorefRules.entityWordsIncluded(mC, aC, m, candidate))
                    features.incrementCount("B-WORD-INCLUSION");
            }
            if (mType.contains(MentionType.LIST)) {
                features.incrementCount("NUM-LIST-", numEntitiesInList(m));
                if (m.spanToString().contains("two") || m.spanToString().contains("2") || m.spanToString().contains("both"))
                    features.incrementCount("LIST-M-TWO");
                if (m.spanToString().contains("three") || m.spanToString().contains("3"))
                    features.incrementCount("LIST-M-THREE");
                if (candidate.spanToString().contains("two") || candidate.spanToString().contains("2") || candidate.spanToString().contains("both")) {
                    features.incrementCount("B-LIST-A-TWO");
                }
                if (candidate.spanToString().contains("three") || candidate.spanToString().contains("3")) {
                    features.incrementCount("B-LIST-A-THREE");
                }
            }
            if (mType.contains(MentionType.PRONOMINAL)) {
                if (dict.firstPersonPronouns.contains(m.headString))
                    features.incrementCount("B-M-I");
                if (dict.secondPersonPronouns.contains(m.headString))
                    features.incrementCount("B-M-YOU");
                if (dict.thirdPersonPronouns.contains(m.headString))
                    features.incrementCount("B-M-3RDPERSON");
                if (dict.possessivePronouns.contains(m.headString))
                    features.incrementCount("B-M-POSSESSIVE");
                if (dict.neutralPronouns.contains(m.headString))
                    features.incrementCount("B-M-NEUTRAL");
                if (dict.malePronouns.contains(m.headString))
                    features.incrementCount("B-M-MALE");
                if (dict.femalePronouns.contains(m.headString))
                    features.incrementCount("B-M-FEMALE");
                if (dict.firstPersonPronouns.contains(candidate.headString))
                    features.incrementCount("B-A-I");
                if (dict.secondPersonPronouns.contains(candidate.headString))
                    features.incrementCount("B-A-YOU");
                if (dict.thirdPersonPronouns.contains(candidate.headString))
                    features.incrementCount("B-A-3RDPERSON");
                if (dict.possessivePronouns.contains(candidate.headString))
                    features.incrementCount("B-A-POSSESSIVE");
                if (dict.neutralPronouns.contains(candidate.headString))
                    features.incrementCount("B-A-NEUTRAL");
                if (dict.malePronouns.contains(candidate.headString))
                    features.incrementCount("B-A-MALE");
                if (dict.femalePronouns.contains(candidate.headString))
                    features.incrementCount("B-A-FEMALE");
                features.incrementCount("B-M-GENERIC-" + m.generic);
                features.incrementCount("B-A-GENERIC-" + candidate.generic);
                if (HybridCorefPrinter.dcorefPronounSieve.skipThisMention(document, m, mC, dict)) {
                    features.incrementCount("B-SKIPTHISMENTION-true");
                }
                if (m.spanToString().equalsIgnoreCase("you") && mFollowing != null && mFollowing.word().equalsIgnoreCase("know")) {
                    features.incrementCount("B-YOUKNOW-PRECEDING-POS-" + ((mPreceding == null) ? "NULL" : mPreceding.tag()));
                    features.incrementCount("B-YOUKNOW-PRECEDING-WORD-" + ((mPreceding == null) ? "NULL" : mPreceding.word().toLowerCase()));
                    CoreLabel nextword = (m.endIndex + 1 < m.sentenceWords.size()) ? m.sentenceWords.get(m.endIndex + 1) : null;
                    features.incrementCount("B-YOUKNOW-FOLLOWING-POS-" + ((nextword == null) ? "NULL" : nextword.tag()));
                    features.incrementCount("B-YOUKNOW-FOLLOWING-WORD-" + ((nextword == null) ? "NULL" : nextword.word().toLowerCase()));
                }
                if (candidate.spanToString().equalsIgnoreCase("you") && aFollowing != null && aFollowing.word().equalsIgnoreCase("know")) {
                    features.incrementCount("B-YOUKNOW-PRECEDING-POS-" + ((aPreceding == null) ? "NULL" : aPreceding.tag()));
                    features.incrementCount("B-YOUKNOW-PRECEDING-WORD-" + ((aPreceding == null) ? "NULL" : aPreceding.word().toLowerCase()));
                    CoreLabel nextword = (candidate.endIndex + 1 < candidate.sentenceWords.size()) ? candidate.sentenceWords.get(candidate.endIndex + 1) : null;
                    features.incrementCount("B-YOUKNOW-FOLLOWING-POS-" + ((nextword == null) ? "NULL" : nextword.tag()));
                    features.incrementCount("B-YOUKNOW-FOLLOWING-WORD-" + ((nextword == null) ? "NULL" : nextword.word().toLowerCase()));
                }
            }
            // discourse match features
            if (m.person == Person.YOU && document.docType == DocType.ARTICLE && m.headWord.get(CoreAnnotations.SpeakerAnnotation.class).equals("PER0")) {
                features.incrementCount("B-DISCOURSE-M-YOU-GENERIC?");
            }
            if (candidate.generic && candidate.person == Person.YOU)
                features.incrementCount("B-DISCOURSE-A-YOU-GENERIC?");
            String mString = m.lowercaseNormalizedSpanString();
            String antString = candidate.lowercaseNormalizedSpanString();
            // I-I
            if (m.number == Number.SINGULAR && dict.firstPersonPronouns.contains(mString) && candidate.number == Number.SINGULAR && dict.firstPersonPronouns.contains(antString) && CorefRules.entitySameSpeaker(document, m, candidate)) {
                features.incrementCount("B-DISCOURSE-I-I-SAMESPEAKER");
            }
            // (speaker - I)
            if ((m.number == Number.SINGULAR && dict.firstPersonPronouns.contains(mString)) && CorefRules.antecedentIsMentionSpeaker(document, m, candidate, dict)) {
                features.incrementCount("B-DISCOURSE-SPEAKER-I");
            }
            // (I - speaker)
            if ((candidate.number == Number.SINGULAR && dict.firstPersonPronouns.contains(antString)) && CorefRules.antecedentIsMentionSpeaker(document, candidate, m, dict)) {
                features.incrementCount("B-DISCOURSE-I-SPEAKER");
            }
            // Can be iffy if more than two speakers... but still should be okay most of the time
            if (dict.secondPersonPronouns.contains(mString) && dict.secondPersonPronouns.contains(antString) && CorefRules.entitySameSpeaker(document, m, candidate)) {
                features.incrementCount("B-DISCOURSE-BOTH-YOU");
            }
            // previous I - you or previous you - I in two person conversation
            if (((m.person == Person.I && candidate.person == Person.YOU || (m.person == Person.YOU && candidate.person == Person.I)) && (m.headWord.get(CoreAnnotations.UtteranceAnnotation.class) - candidate.headWord.get(CoreAnnotations.UtteranceAnnotation.class) == 1) && document.docType == DocType.CONVERSATION)) {
                features.incrementCount("B-DISCOURSE-I-YOU");
            }
            if (dict.reflexivePronouns.contains(m.headString) && CorefRules.entitySubjectObject(m, candidate)) {
                features.incrementCount("B-DISCOURSE-REFLEXIVE");
            }
            if (m.person == Person.I && candidate.person == Person.I && !CorefRules.entitySameSpeaker(document, m, candidate)) {
                features.incrementCount("B-DISCOURSE-I-I-DIFFSPEAKER");
            }
            if (m.person == Person.YOU && candidate.person == Person.YOU && !CorefRules.entitySameSpeaker(document, m, candidate)) {
                features.incrementCount("B-DISCOURSE-YOU-YOU-DIFFSPEAKER");
            }
            if (m.person == Person.WE && candidate.person == Person.WE && !CorefRules.entitySameSpeaker(document, m, candidate)) {
                features.incrementCount("B-DISCOURSE-WE-WE-DIFFSPEAKER");
            }
        }
        ////////////////////////////////////////////////////////////////////////////////
        if (HybridCorefProperties.usePOSFeatures(props, sievename)) {
            features.incrementCount("B-LEXICAL-M-HEADPOS-" + m.headWord.tag());
            features.incrementCount("B-LEXICAL-A-HEADPOS-" + candidate.headWord.tag());
            features.incrementCount("B-LEXICAL-M-FIRSTPOS-" + mFirst.tag());
            features.incrementCount("B-LEXICAL-A-FIRSTPOS-" + aFirst.tag());
            features.incrementCount("B-LEXICAL-M-LASTPOS-" + mLast.tag());
            features.incrementCount("B-LEXICAL-A-LASTPOS-" + aLast.tag());
            features.incrementCount("B-LEXICAL-M-PRECEDINGPOS-" + ((mPreceding == null) ? "NULL" : mPreceding.tag()));
            features.incrementCount("B-LEXICAL-A-PRECEDINGPOS-" + ((aPreceding == null) ? "NULL" : aPreceding.tag()));
            features.incrementCount("B-LEXICAL-M-FOLLOWINGPOS-" + ((mFollowing == null) ? "NULL" : mFollowing.tag()));
            features.incrementCount("B-LEXICAL-A-FOLLOWINGPOS-" + ((aFollowing == null) ? "NULL" : aFollowing.tag()));
        }
        ////////////////////////////////////////////////////////////////////////////////
        if (HybridCorefProperties.useLexicalFeatures(props, sievename)) {
            features.incrementCount("B-LEXICAL-M-HEADWORD-" + m.headString.toLowerCase());
            features.incrementCount("B-LEXICAL-A-HEADWORD-" + candidate.headString.toLowerCase());
            features.incrementCount("B-LEXICAL-M-FIRSTWORD-" + mFirst.word().toLowerCase());
            features.incrementCount("B-LEXICAL-A-FIRSTWORD-" + aFirst.word().toLowerCase());
            features.incrementCount("B-LEXICAL-M-LASTWORD-" + mLast.word().toLowerCase());
            features.incrementCount("B-LEXICAL-A-LASTWORD-" + aLast.word().toLowerCase());
            features.incrementCount("B-LEXICAL-M-PRECEDINGWORD-" + ((mPreceding == null) ? "NULL" : mPreceding.word().toLowerCase()));
            features.incrementCount("B-LEXICAL-A-PRECEDINGWORD-" + ((aPreceding == null) ? "NULL" : aPreceding.word().toLowerCase()));
            features.incrementCount("B-LEXICAL-M-FOLLOWINGWORD-" + ((mFollowing == null) ? "NULL" : mFollowing.word().toLowerCase()));
            features.incrementCount("B-LEXICAL-A-FOLLOWINGWORD-" + ((aFollowing == null) ? "NULL" : aFollowing.word().toLowerCase()));
            //extra headword, modifiers lexical features
            for (String mHead : mC.heads) {
                if (!aC.heads.contains(mHead))
                    features.incrementCount("B-LEXICAL-MC-EXTRAHEAD-" + mHead);
            }
            for (String mWord : mC.words) {
                if (!aC.words.contains(mWord))
                    features.incrementCount("B-LEXICAL-MC-EXTRAWORD-" + mWord);
            }
        }
        // cosine
        if (HybridCorefProperties.useWordEmbedding(props, sievename)) {
            // dimension
            int dim = dict.vectors.entrySet().iterator().next().getValue().length;
            // distance between headword
            float[] mV = dict.vectors.get(m.headString.toLowerCase());
            float[] aV = dict.vectors.get(candidate.headString.toLowerCase());
            if (mV != null && aV != null) {
                features.incrementCount("WORDVECTOR-DIFF-HEADWORD", cosine(mV, aV));
            }
            mV = dict.vectors.get(mFirst.word().toLowerCase());
            aV = dict.vectors.get(aFirst.word().toLowerCase());
            if (mV != null && aV != null) {
                features.incrementCount("WORDVECTOR-DIFF-FIRSTWORD", cosine(mV, aV));
            }
            mV = dict.vectors.get(mLast.word().toLowerCase());
            aV = dict.vectors.get(aLast.word().toLowerCase());
            if (mV != null && aV != null) {
                features.incrementCount("WORDVECTOR-DIFF-LASTWORD", cosine(mV, aV));
            }
            if (mPreceding != null && aPreceding != null) {
                mV = dict.vectors.get(mPreceding.word().toLowerCase());
                aV = dict.vectors.get(aPreceding.word().toLowerCase());
                if (mV != null && aV != null) {
                    features.incrementCount("WORDVECTOR-DIFF-PRECEDINGWORD", cosine(mV, aV));
                }
            }
            if (mFollowing != null && aFollowing != null) {
                mV = dict.vectors.get(mFollowing.word().toLowerCase());
                aV = dict.vectors.get(aFollowing.word().toLowerCase());
                if (mV != null && aV != null) {
                    features.incrementCount("WORDVECTOR-DIFF-FOLLOWINGWORD", cosine(mV, aV));
                }
            }
            float[] aggreM = new float[dim];
            float[] aggreA = new float[dim];
            for (CoreLabel cl : m.originalSpan) {
                float[] v = dict.vectors.get(cl.word().toLowerCase());
                if (v == null)
                    continue;
                ArrayMath.pairwiseAddInPlace(aggreM, v);
            }
            for (CoreLabel cl : candidate.originalSpan) {
                float[] v = dict.vectors.get(cl.word().toLowerCase());
                if (v == null)
                    continue;
                ArrayMath.pairwiseAddInPlace(aggreA, v);
            }
            if (ArrayMath.L2Norm(aggreM) != 0 && ArrayMath.L2Norm(aggreA) != 0) {
                features.incrementCount("WORDVECTOR-AGGREGATE-DIFF", cosine(aggreM, aggreA));
            }
            int cnt = 0;
            double dist = 0;
            for (CoreLabel mcl : m.originalSpan) {
                for (CoreLabel acl : candidate.originalSpan) {
                    mV = dict.vectors.get(mcl.word().toLowerCase());
                    aV = dict.vectors.get(acl.word().toLowerCase());
                    if (mV == null || aV == null)
                        continue;
                    cnt++;
                    dist += cosine(mV, aV);
                }
            }
            features.incrementCount("WORDVECTOR-AVG-DIFF", dist / cnt);
        }
        return new RVFDatum<>(features, label);
    } catch (Exception e) {
        log.info("Datum Extraction failed in Sieve.java while processing document: " + document.docInfo.get("DOC_ID") + " part: " + document.docInfo.get("DOC_PART"));
        throw new RuntimeException(e);
    }
}
Also used : ArrayList(java.util.ArrayList) Gender(edu.stanford.nlp.coref.data.Dictionaries.Gender) Number(edu.stanford.nlp.coref.data.Dictionaries.Number) CorefChain(edu.stanford.nlp.coref.data.CorefChain) Mention(edu.stanford.nlp.coref.data.Mention) Tree(edu.stanford.nlp.trees.Tree) RVFDatum(edu.stanford.nlp.ling.RVFDatum) Animacy(edu.stanford.nlp.coref.data.Dictionaries.Animacy) MentionType(edu.stanford.nlp.coref.data.Dictionaries.MentionType) CoreLabel(edu.stanford.nlp.ling.CoreLabel) CorefCluster(edu.stanford.nlp.coref.data.CorefCluster) ClassicCounter(edu.stanford.nlp.stats.ClassicCounter) CoreAnnotations(edu.stanford.nlp.ling.CoreAnnotations) SpeakerAnnotation(edu.stanford.nlp.ling.CoreAnnotations.SpeakerAnnotation)

Example 4 with RVFDatum

use of edu.stanford.nlp.ling.RVFDatum in project CoreNLP by stanfordnlp.

the class ScorePhrasesLearnFeatWt method choosedatums.

/**
 * Samples positive, negative and unknown {@link CandidatePhrase}s from the labeled data and
 * converts them into an RVF dataset used for learning phrase-scoring feature weights.
 *
 * @param forLearningPattern if true, use the pattern-oriented feature extractor
 *        ({@code getPhraseFeaturesForPattern}) instead of the phrase one
 * @param answerLabel the label whose classifier data is being built
 * @param wordsPatExtracted phrases extracted per pattern (feature source for phrase datums)
 * @param allSelectedPatterns patterns selected so far, with their weights
 * @param computeRawFreq whether to (re)compute raw phrase frequencies for the sentence batches
 * @return dataset of "true"/"false" labeled {@link RVFDatum}s after feature-count thresholding
 * @throws IOException if writing the similarity log files fails
 */
public GeneralDataset<String, ScorePhraseMeasures> choosedatums(boolean forLearningPattern, String answerLabel, TwoDimensionalCounter<CandidatePhrase, E> wordsPatExtracted, Counter<E> allSelectedPatterns, boolean computeRawFreq) throws IOException {
    // Expansion of positives/negatives only happens once: the "FirstIter" counters are
    // null exactly on the first call, and are initialized (and later reused) here.
    boolean expandNeg = false;
    if (closeToNegativesFirstIter == null) {
        closeToNegativesFirstIter = new ClassicCounter<>();
        if (constVars.expandNegativesWhenSampling)
            expandNeg = true;
    }
    boolean expandPos = false;
    if (closeToPositivesFirstIter == null) {
        closeToPositivesFirstIter = new ClassicCounter<>();
        if (constVars.expandPositivesWhenSampling)
            expandPos = true;
    }
    // Distributional-similarity clusters of known positive phrases (falling back to their
    // individual tokens for multi-word phrases), used by worker threads when expanding
    // without word vectors.
    Counter<Integer> distSimClustersOfPositive = new ClassicCounter<>();
    if ((expandPos || expandNeg) && !constVars.useWordVectorsToComputeSim) {
        for (CandidatePhrase s : CollectionUtils.union(constVars.getLearnedWords(answerLabel).keySet(), constVars.getSeedLabelDictionary().get(answerLabel))) {
            String[] toks = s.getPhrase().split("\\s+");
            Integer num = constVars.getWordClassClusters().get(s.getPhrase());
            if (num == null)
                num = constVars.getWordClassClusters().get(s.getPhrase().toLowerCase());
            if (num == null) {
                for (String tok : toks) {
                    Integer toknum = constVars.getWordClassClusters().get(tok);
                    if (toknum == null)
                        toknum = constVars.getWordClassClusters().get(tok.toLowerCase());
                    if (toknum != null) {
                        distSimClustersOfPositive.incrementCount(toknum);
                    }
                }
            } else
                distSimClustersOfPositive.incrementCount(num);
        }
    }
    //computing this regardless of expandpos and expandneg because we reject all positive words that occur in negatives (can happen in multi word phrases etc)
    Map<String, Collection<CandidatePhrase>> allPossibleNegativePhrases = getAllPossibleNegativePhrases(answerLabel);
    GeneralDataset<String, ScorePhraseMeasures> dataset = new RVFDataset<>();
    int numpos = 0;
    Set<CandidatePhrase> allNegativePhrases = new HashSet<>();
    Set<CandidatePhrase> allUnknownPhrases = new HashSet<>();
    Set<CandidatePhrase> allPositivePhrases = new HashSet<>();
    //for all sentence batches: sample positive/negative/unknown phrases in parallel
    ConstantsAndVariables.DataSentsIterator sentsIter = new ConstantsAndVariables.DataSentsIterator(constVars.batchProcessSents);
    while (sentsIter.hasNext()) {
        Pair<Map<String, DataInstance>, File> sentsf = sentsIter.next();
        Map<String, DataInstance> sents = sentsf.first();
        Redwood.log(Redwood.DBG, "Sampling datums from " + sentsf.second());
        if (computeRawFreq)
            Data.computeRawFreqIfNull(sents, PatternFactory.numWordsCompoundMax);
        List<List<String>> threadedSentIds = GetPatternsFromDataMultiClass.getThreadBatches(new ArrayList<>(sents.keySet()), constVars.numThreads);
        ExecutorService executor = Executors.newFixedThreadPool(constVars.numThreads);
        List<Future<Quintuple<Set<CandidatePhrase>, Set<CandidatePhrase>, Set<CandidatePhrase>, Counter<CandidatePhrase>, Counter<CandidatePhrase>>>> list = new ArrayList<>();
        //multi-threaded choose positive, negative and unknown
        for (List<String> keys : threadedSentIds) {
            Callable<Quintuple<Set<CandidatePhrase>, Set<CandidatePhrase>, Set<CandidatePhrase>, Counter<CandidatePhrase>, Counter<CandidatePhrase>>> task = new ChooseDatumsThread(answerLabel, sents, keys, wordsPatExtracted, allSelectedPatterns, distSimClustersOfPositive, allPossibleNegativePhrases, expandPos, expandNeg);
            Future<Quintuple<Set<CandidatePhrase>, Set<CandidatePhrase>, Set<CandidatePhrase>, Counter<CandidatePhrase>, Counter<CandidatePhrase>>> submit = executor.submit(task);
            list.add(submit);
        }
        // Now retrieve the result
        for (Future<Quintuple<Set<CandidatePhrase>, Set<CandidatePhrase>, Set<CandidatePhrase>, Counter<CandidatePhrase>, Counter<CandidatePhrase>>> future : list) {
            try {
                Quintuple<Set<CandidatePhrase>, Set<CandidatePhrase>, Set<CandidatePhrase>, Counter<CandidatePhrase>, Counter<CandidatePhrase>> result = future.get();
                allPositivePhrases.addAll(result.first());
                allNegativePhrases.addAll(result.second());
                allUnknownPhrases.addAll(result.third());
                if (expandPos)
                    for (Entry<CandidatePhrase, Double> en : result.fourth().entrySet()) closeToPositivesFirstIter.setCount(en.getKey(), en.getValue());
                if (expandNeg)
                    for (Entry<CandidatePhrase, Double> en : result.fifth().entrySet()) closeToNegativesFirstIter.setCount(en.getKey(), en.getValue());
            } catch (Exception e) {
                executor.shutdownNow();
                throw new RuntimeException(e);
            }
        }
        executor.shutdown();
    }
    //Set<CandidatePhrase> knownPositivePhrases = CollectionUtils.unionAsSet(constVars.getLearnedWords().get(answerLabel).keySet(), constVars.getSeedLabelDictionary().get(answerLabel));
    //TODO: this is kinda not nice; how is allpositivephrases different from positivephrases again?
    allPositivePhrases.addAll(constVars.getLearnedWords(answerLabel).keySet());
    //allPositivePhrases.addAll(knownPositivePhrases);
    BufferedWriter logFile = null;
    BufferedWriter logFileFeat = null;
    // try/finally so the log writers are closed on ALL paths; previously they leaked
    // when an exception was thrown between creation and the end of the method.
    try {
        if (constVars.logFileVectorSimilarity != null) {
            logFile = new BufferedWriter(new FileWriter(constVars.logFileVectorSimilarity));
            logFileFeat = new BufferedWriter(new FileWriter(constVars.logFileVectorSimilarity + "_feat"));
            if (wordVectors != null) {
                for (CandidatePhrase p : allPositivePhrases) {
                    if (wordVectors.containsKey(p.getPhrase())) {
                        logFile.write(p.getPhrase() + "-P " + ArrayUtils.toString(wordVectors.get(p.getPhrase()), " ") + "\n");
                    }
                }
            }
        }
        if (constVars.expandPositivesWhenSampling) {
            //TODO: patwtbyfrew
            //Counters.retainTop(allCloseToPositivePhrases, (int) (allCloseToPositivePhrases.size()*constVars.subSampleUnkAsPosUsingSimPercentage));
            Redwood.log("Expanding positives by adding " + Counters.toSortedString(closeToPositivesFirstIter, closeToPositivesFirstIter.size(), "%1$s:%2$f", "\t") + " phrases");
            allPositivePhrases.addAll(closeToPositivesFirstIter.keySet());
            //write log. Fixed: guard was "expandNeg" (copy-paste from the negative block
            //below), which suppressed the "-PP" log entries for expanded positives.
            if (logFile != null && wordVectors != null && expandPos) {
                for (CandidatePhrase p : closeToPositivesFirstIter.keySet()) {
                    if (wordVectors.containsKey(p.getPhrase())) {
                        logFile.write(p.getPhrase() + "-PP " + ArrayUtils.toString(wordVectors.get(p.getPhrase()), " ") + "\n");
                    }
                }
            }
        }
        if (constVars.expandNegativesWhenSampling) {
            //TODO: patwtbyfrew
            //Counters.retainTop(allCloseToPositivePhrases, (int) (allCloseToPositivePhrases.size()*constVars.subSampleUnkAsPosUsingSimPercentage));
            Redwood.log("Expanding negatives by adding " + Counters.toSortedString(closeToNegativesFirstIter, closeToNegativesFirstIter.size(), "%1$s:%2$f", "\t") + " phrases");
            allNegativePhrases.addAll(closeToNegativesFirstIter.keySet());
            //write log
            if (logFile != null && wordVectors != null && expandNeg) {
                for (CandidatePhrase p : closeToNegativesFirstIter.keySet()) {
                    if (wordVectors.containsKey(p.getPhrase())) {
                        logFile.write(p.getPhrase() + "-NN " + ArrayUtils.toString(wordVectors.get(p.getPhrase()), " ") + "\n");
                    }
                }
            }
        }
        System.out.println("all positive phrases of size " + allPositivePhrases.size() + " are  " + allPositivePhrases);
        for (CandidatePhrase candidate : allPositivePhrases) {
            Counter<ScorePhraseMeasures> feat;
            //CandidatePhrase candidate = new CandidatePhrase(l.word());
            if (forLearningPattern) {
                feat = getPhraseFeaturesForPattern(answerLabel, candidate);
            } else {
                feat = getFeatures(answerLabel, candidate, wordsPatExtracted.getCounter(candidate), allSelectedPatterns);
            }
            RVFDatum<String, ScorePhraseMeasures> datum = new RVFDatum<>(feat, "true");
            dataset.add(datum);
            numpos += 1;
            if (logFileFeat != null) {
                logFileFeat.write("POSITIVE " + candidate.getPhrase() + "\t" + Counters.toSortedByKeysString(feat, "%1$s:%2$.0f", ";", "%s") + "\n");
            }
        }
        Redwood.log(Redwood.DBG, "Number of pure negative phrases is " + allNegativePhrases.size());
        Redwood.log(Redwood.DBG, "Number of unknown phrases is " + allUnknownPhrases.size());
        if (constVars.subsampleUnkAsNegUsingSim) {
            Set<CandidatePhrase> chosenUnknown = chooseUnknownAsNegatives(allUnknownPhrases, answerLabel, allPositivePhrases, allPossibleNegativePhrases, logFile);
            Redwood.log(Redwood.DBG, "Choosing " + chosenUnknown.size() + " unknowns as negative based to their similarity to the positive phrases");
            allNegativePhrases.addAll(chosenUnknown);
        } else {
            allNegativePhrases.addAll(allUnknownPhrases);
        }
        if (allNegativePhrases.size() > numpos) {
            Redwood.log(Redwood.WARN, "Num of negative (" + allNegativePhrases.size() + ") is higher than number of positive phrases (" + numpos + ") = " + (allNegativePhrases.size() / (double) numpos) + ". " + "Capping the number by taking the first numPositives as negative. Consider decreasing perSelectRand");
            // NOTE(review): allNegativePhrases is a HashSet, so "the first numpos" is an
            // arbitrary (iteration-order-dependent) subset — confirm this is intended.
            int i = 0;
            Set<CandidatePhrase> selectedNegPhrases = new HashSet<>();
            for (CandidatePhrase p : allNegativePhrases) {
                if (i >= numpos)
                    break;
                selectedNegPhrases.add(p);
                i++;
            }
            allNegativePhrases.clear();
            allNegativePhrases = selectedNegPhrases;
        }
        System.out.println("all negative phrases are " + allNegativePhrases);
        for (CandidatePhrase negative : allNegativePhrases) {
            Counter<ScorePhraseMeasures> feat;
            //CandidatePhrase candidate = new CandidatePhrase(l.word());
            if (forLearningPattern) {
                feat = getPhraseFeaturesForPattern(answerLabel, negative);
            } else {
                feat = getFeatures(answerLabel, negative, wordsPatExtracted.getCounter(negative), allSelectedPatterns);
            }
            RVFDatum<String, ScorePhraseMeasures> datum = new RVFDatum<>(feat, "false");
            dataset.add(datum);
            if (logFile != null && wordVectors != null && wordVectors.containsKey(negative.getPhrase())) {
                logFile.write(negative.getPhrase() + "-N" + " " + ArrayUtils.toString(wordVectors.get(negative.getPhrase()), " ") + "\n");
            }
            if (logFileFeat != null)
                logFileFeat.write("NEGATIVE " + negative.getPhrase() + "\t" + Counters.toSortedByKeysString(feat, "%1$s:%2$.0f", ";", "%s") + "\n");
        }
        System.out.println("Before feature count threshold, dataset stats are ");
        dataset.summaryStatistics();
        dataset.applyFeatureCountThreshold(constVars.featureCountThreshold);
        System.out.println("AFTER feature count threshold of " + constVars.featureCountThreshold + ", dataset stats are ");
        dataset.summaryStatistics();
        Redwood.log(Redwood.DBG, "Eventually, number of positive datums:  " + numpos + " and number of negative datums: " + allNegativePhrases.size());
        return dataset;
    } finally {
        if (logFile != null) {
            logFile.close();
        }
        if (logFileFeat != null) {
            logFileFeat.close();
        }
    }
}
Also used : ScorePhraseMeasures(edu.stanford.nlp.patterns.ConstantsAndVariables.ScorePhraseMeasures) FileWriter(java.io.FileWriter) BufferedWriter(java.io.BufferedWriter) Entry(java.util.Map.Entry) ConcurrentHashCounter(edu.stanford.nlp.util.concurrent.ConcurrentHashCounter) RVFDatum(edu.stanford.nlp.ling.RVFDatum) IOException(java.io.IOException) File(java.io.File)

Example 5 with RVFDatum

use of edu.stanford.nlp.ling.RVFDatum in project CoreNLP by stanfordnlp.

the class LinearClassifierITest method testStrMultiClassDatums.

/**
 * Trains a linear classifier on three linearly-separable string-labeled RVF datums
 * and checks that it (a) classifies its own training data perfectly and
 * (b) classifies a held-out datum (which includes an unseen feature f3) correctly.
 */
public void testStrMultiClassDatums() throws Exception {
    RVFDataset<String, String> trainData = new RVFDataset<>();
    List<RVFDatum<String, String>> datums = new ArrayList<>();
    datums.add(newDatum("alpha", new String[] { "f1", "f2" }, new Double[] { 1.0, 0.0 }));
    datums.add(newDatum("beta", new String[] { "f1", "f2" }, new Double[] { 0.0, 1.0 }));
    datums.add(newDatum("charlie", new String[] { "f1", "f2" }, new Double[] { 5.0, 5.0 }));
    for (RVFDatum<String, String> datum : datums) {
        trainData.add(datum);
    }
    LinearClassifierFactory<String, String> lfc = new LinearClassifierFactory<>();
    LinearClassifier<String, String> lc = lfc.trainClassifier(trainData);
    // Parameterized type (was a raw RVFDatum, causing an unchecked classOf call below).
    // Feature f3 is unseen at training time; the classifier should ignore it.
    RVFDatum<String, String> td1 = newDatum("alpha", new String[] { "f1", "f2", "f3" }, new Double[] { 2.0, 0.0, 5.5 });
    // Try the obvious (should get train data with 100% acc)
    for (RVFDatum<String, String> datum : datums) {
        Assert.assertEquals(datum.label(), lc.classOf(datum));
    }
    // Test data
    Assert.assertEquals(td1.label(), lc.classOf(td1));
}
Also used : ArrayList(java.util.ArrayList) RVFDatum(edu.stanford.nlp.ling.RVFDatum)

Aggregations

RVFDatum (edu.stanford.nlp.ling.RVFDatum)11 CoreLabel (edu.stanford.nlp.ling.CoreLabel)5 ClassicCounter (edu.stanford.nlp.stats.ClassicCounter)5 edu.stanford.nlp.classify (edu.stanford.nlp.classify)4 IOUtils (edu.stanford.nlp.io.IOUtils)3 RuntimeIOException (edu.stanford.nlp.io.RuntimeIOException)3 CoreAnnotations (edu.stanford.nlp.ling.CoreAnnotations)3 Counter (edu.stanford.nlp.stats.Counter)3 Redwood (edu.stanford.nlp.util.logging.Redwood)3 Util (edu.stanford.nlp.util.logging.Redwood.Util)3 File (java.io.File)3 ArrayList (java.util.ArrayList)3 AtomicInteger (java.util.concurrent.atomic.AtomicInteger)3 Span (edu.stanford.nlp.ie.machinereading.structure.Span)2 ScorePhraseMeasures (edu.stanford.nlp.patterns.ConstantsAndVariables.ScorePhraseMeasures)2 Annotation (edu.stanford.nlp.pipeline.Annotation)2 SentimentClass (edu.stanford.nlp.simple.SentimentClass)2 edu.stanford.nlp.util (edu.stanford.nlp.util)2 CoreMap (edu.stanford.nlp.util.CoreMap)2 RedwoodConfiguration (edu.stanford.nlp.util.logging.RedwoodConfiguration)2