Search in sources :

Example 1 with Dataset

Use of edu.stanford.nlp.classify.Dataset in the CoreNLP project by stanfordnlp.

From the class SingletonPredictor, the method generateFeatureVectors:

/**
 * Generates the training feature vectors from the CoNLL input file.
 *
 * <p>Mentions belonging to a gold coreference cluster are added with class
 * label "1"; predicted mentions whose head word does not appear among the
 * gold mention heads are treated as singletons and added with class label
 * "0". Verbal mentions and mentions whose head word cannot be resolved in
 * the dependency graph are skipped in both passes.
 *
 * @param props properties used to construct the dictionaries and the CoNLL
 *     mention extractor
 * @return dataset of feature vectors, one datum per retained mention
 * @throws Exception if reading or processing the CoNLL input fails
 */
private static GeneralDataset<String, String> generateFeatureVectors(Properties props) throws Exception {
    GeneralDataset<String, String> dataset = new Dataset<>();
    Dictionaries dict = new Dictionaries(props);
    MentionExtractor mentionExtractor = new CoNLLMentionExtractor(dict, props, new Semantics(dict));
    Document document;
    while ((document = mentionExtractor.nextDoc()) != null) {
        setTokenIndices(document);
        document.extractGoldCorefClusters();
        Map<Integer, CorefCluster> entities = document.goldCorefClusters;
        // Generate features for coreferent mentions with class label 1.
        for (CorefCluster entity : entities.values()) {
            for (Mention mention : entity.getCorefMentions()) {
                // Ignore verbal mentions (cheap check, done before the graph lookup).
                if (mention.headWord.tag().startsWith("V")) {
                    continue;
                }
                // Skip mentions whose head is absent from the dependency graph.
                IndexedWord head = mention.dependency.getNodeByIndexSafe(mention.headWord.index());
                if (head == null) {
                    continue;
                }
                ArrayList<String> feats = mention.getSingletonFeatures(dict);
                dataset.add(new BasicDatum<>(feats, "1"));
            }
        }
        // Generate features for singletons with class label 0.
        ArrayList<CoreLabel> goldHeads = new ArrayList<>();
        for (Mention goldMention : document.allGoldMentions.values()) {
            goldHeads.add(goldMention.headWord);
        }
        for (Mention predictedMention : document.allPredictedMentions.values()) {
            // Ignore verbal mentions (cheap check, done before the graph lookup).
            if (predictedMention.headWord.tag().startsWith("V")) {
                continue;
            }
            // Skip mentions whose head is absent from the dependency graph.
            SemanticGraph dep = predictedMention.dependency;
            IndexedWord head = dep.getNodeByIndexSafe(predictedMention.headWord.index());
            if (head == null) {
                continue;
            }
            // If the mention is in the gold set, it is not a singleton, so ignore it.
            // NOTE(review): ArrayList.contains makes this pass O(gold * predicted);
            // a HashSet of gold heads would be preferable if java.util.Set is
            // imported in this file — confirm before changing.
            if (goldHeads.contains(predictedMention.headWord)) {
                continue;
            }
            dataset.add(new BasicDatum<>(predictedMention.getSingletonFeatures(dict), "0"));
        }
    }
    dataset.summaryStatistics();
    return dataset;
}
Also used : Dataset(edu.stanford.nlp.classify.Dataset) GeneralDataset(edu.stanford.nlp.classify.GeneralDataset) ArrayList(java.util.ArrayList) CoreLabel(edu.stanford.nlp.ling.CoreLabel) SemanticGraph(edu.stanford.nlp.semgraph.SemanticGraph) IndexedWord(edu.stanford.nlp.ling.IndexedWord)

Example 2 with Dataset

Use of edu.stanford.nlp.classify.Dataset in the CoreNLP project by stanfordnlp.

From the class SingletonPredictor, the method generateFeatureVectors:

/**
 * Generates the training feature vectors from the CoNLL input file.
 *
 * <p>Mentions belonging to a gold coreference cluster are added with class
 * label "1"; predicted mentions whose head word does not appear among the
 * gold mention heads are treated as singletons and added with class label
 * "0". Verbal mentions and mentions whose head word cannot be resolved in
 * the enhanced dependency graph are skipped in both passes.
 *
 * @param props properties used to construct the dictionaries and the
 *     document maker
 * @return dataset of feature vectors, one datum per retained mention
 * @throws Exception if reading or processing the CoNLL input fails
 */
private static GeneralDataset<String, String> generateFeatureVectors(Properties props) throws Exception {
    GeneralDataset<String, String> dataset = new Dataset<>();
    Dictionaries dict = new Dictionaries(props);
    DocumentMaker docMaker = new DocumentMaker(props, dict);
    Document document;
    while ((document = docMaker.nextDoc()) != null) {
        setTokenIndices(document);
        Map<Integer, CorefCluster> entities = document.goldCorefClusters;
        // Generate features for coreferent mentions with class label 1.
        for (CorefCluster entity : entities.values()) {
            for (Mention mention : entity.getCorefMentions()) {
                // Ignore verbal mentions (cheap check, done before the graph lookup).
                if (mention.headWord.tag().startsWith("V")) {
                    continue;
                }
                // Skip mentions whose head is absent from the dependency graph.
                IndexedWord head = mention.enhancedDependency.getNodeByIndexSafe(mention.headWord.index());
                if (head == null) {
                    continue;
                }
                ArrayList<String> feats = mention.getSingletonFeatures(dict);
                dataset.add(new BasicDatum<>(feats, "1"));
            }
        }
        // Generate features for singletons with class label 0.
        ArrayList<CoreLabel> goldHeads = new ArrayList<>();
        for (Mention goldMention : document.goldMentionsByID.values()) {
            goldHeads.add(goldMention.headWord);
        }
        for (Mention predictedMention : document.predictedMentionsByID.values()) {
            // Ignore verbal mentions (cheap check, done before the graph lookup).
            if (predictedMention.headWord.tag().startsWith("V")) {
                continue;
            }
            // Skip mentions whose head is absent from the dependency graph.
            // getNodeByIndexSafe only ever returns a vertex of the graph (or
            // null), so the former "!dep.vertexSet().contains(head)" clause
            // was redundant and has been dropped.
            SemanticGraph dep = predictedMention.enhancedDependency;
            IndexedWord head = dep.getNodeByIndexSafe(predictedMention.headWord.index());
            if (head == null) {
                continue;
            }
            // If the mention is in the gold set, it is not a singleton, so ignore it.
            // NOTE(review): ArrayList.contains makes this pass O(gold * predicted);
            // a HashSet of gold heads would be preferable if java.util.Set is
            // imported in this file — confirm before changing.
            if (goldHeads.contains(predictedMention.headWord)) {
                continue;
            }
            dataset.add(new BasicDatum<>(predictedMention.getSingletonFeatures(dict), "0"));
        }
    }
    dataset.summaryStatistics();
    return dataset;
}
Also used : Dictionaries(edu.stanford.nlp.coref.data.Dictionaries) GeneralDataset(edu.stanford.nlp.classify.GeneralDataset) Dataset(edu.stanford.nlp.classify.Dataset) ArrayList(java.util.ArrayList) Document(edu.stanford.nlp.coref.data.Document) CoreLabel(edu.stanford.nlp.ling.CoreLabel) DocumentMaker(edu.stanford.nlp.coref.data.DocumentMaker) CorefCluster(edu.stanford.nlp.coref.data.CorefCluster) Mention(edu.stanford.nlp.coref.data.Mention) SemanticGraph(edu.stanford.nlp.semgraph.SemanticGraph) IndexedWord(edu.stanford.nlp.ling.IndexedWord)

Example 3 with Dataset

Use of edu.stanford.nlp.classify.Dataset in the CoreNLP project by stanfordnlp.

From the class BoWSceneGraphParser, the method getTrainingExamples:

/**
 * Generates training examples.
 *
 * <p>For each region of each training image, relation triples matched in the
 * sentence yield positive examples; all remaining entity-entity and
 * entity-attribute pairs yield negative ({@code NONE_RELATION}) examples.
 *
 * @param trainingFile Path to JSON file with training images and scene graphs.
 * @param sampleNeg Whether to sample the same number of negative examples as
 *     positive examples.
 * @return Dataset to train a classifier.
 * @throws IOException if the training file cannot be read
 */
public Dataset<String, String> getTrainingExamples(String trainingFile, boolean sampleNeg) throws IOException {
    Dataset<String, String> dataset = new Dataset<String, String>();
    Dataset<String, String> negDataset = new Dataset<String, String>();
    /* Load images. */
    List<SceneGraphImage> images = loadImages(trainingFile);
    for (SceneGraphImage image : images) {
        for (SceneGraphImageRegion region : image.regions) {
            SemanticGraph sg = region.getEnhancedSemanticGraph();
            // Normalize the graph. ("Quanftification" is the API's own spelling.)
            SemanticGraphEnhancer.processQuanftificationModifiers(sg);
            SemanticGraphEnhancer.collapseCompounds(sg);
            SemanticGraphEnhancer.collapseParticles(sg);
            SemanticGraphEnhancer.resolvePronouns(sg);
            /* Token-index pairs already covered by a positive example, encoded
             * as (firstIndex << 16) + secondIndex. The former 4-bit shift
             * collided for indices >= 16 — e.g. pairs (1,17) and (2,1) both
             * encoded to 33 — which wrongly suppressed negative examples.
             * 16 bits is safe: token indices are far below 2^16. */
            Set<Integer> entityPairs = Generics.newHashSet();
            List<Triple<IndexedWord, IndexedWord, String>> relationTriples = this.sentenceMatcher.getRelationTriples(region);
            for (Triple<IndexedWord, IndexedWord, String> triple : relationTriples) {
                IndexedWord iw1 = sg.getNodeByIndexSafe(triple.first.index());
                IndexedWord iw2 = sg.getNodeByIndexSafe(triple.second.index());
                if (iw1 != null && iw2 != null && (!enforceSubtree || SceneGraphUtils.inSameSubTree(sg, iw1, iw2))) {
                    entityClassifer.predictEntity(iw1, this.embeddings);
                    entityClassifer.predictEntity(iw2, this.embeddings);
                    BoWExample example = new BoWExample(iw1, iw2, sg);
                    dataset.add(example.extractFeatures(featureSets), triple.third);
                }
                entityPairs.add((triple.first.index() << 16) + triple.second.index());
            }
            /* Add negative examples for every uncovered pair. */
            List<IndexedWord> entities = EntityExtractor.extractEntities(sg);
            List<IndexedWord> attributes = EntityExtractor.extractAttributes(sg);
            for (IndexedWord e : entities) {
                entityClassifer.predictEntity(e, this.embeddings);
            }
            for (IndexedWord a : attributes) {
                entityClassifer.predictEntity(a, this.embeddings);
            }
            // Entity-entity pairs (ordered, excluding self-pairs).
            for (IndexedWord e1 : entities) {
                for (IndexedWord e2 : entities) {
                    if (e1.index() == e2.index()) {
                        continue;
                    }
                    int entityPair = (e1.index() << 16) + e2.index();
                    if (!entityPairs.contains(entityPair) && (!enforceSubtree || SceneGraphUtils.inSameSubTree(sg, e1, e2))) {
                        BoWExample example = new BoWExample(e1, e2, sg);
                        negDataset.add(example.extractFeatures(featureSets), NONE_RELATION);
                    }
                }
            }
            // Entity-attribute pairs.
            for (IndexedWord e : entities) {
                for (IndexedWord a : attributes) {
                    int entityPair = (e.index() << 16) + a.index();
                    if (!entityPairs.contains(entityPair) && (!enforceSubtree || SceneGraphUtils.inSameSubTree(sg, e, a))) {
                        BoWExample example = new BoWExample(e, a, sg);
                        negDataset.add(example.extractFeatures(featureSets), NONE_RELATION);
                    }
                }
            }
        }
    }
    /* Sample from negative examples to make the training set
     * more balanced. */
    if (sampleNeg && dataset.size() < negDataset.size()) {
        negDataset = negDataset.getRandomSubDataset(dataset.size() * 1.0 / negDataset.size(), 42);
    }
    dataset.addAll(negDataset);
    return dataset;
}
Also used : SceneGraphImage(edu.stanford.nlp.scenegraph.image.SceneGraphImage) Dataset(edu.stanford.nlp.classify.Dataset) Triple(edu.stanford.nlp.util.Triple) SemanticGraph(edu.stanford.nlp.semgraph.SemanticGraph) IndexedWord(edu.stanford.nlp.ling.IndexedWord) SceneGraphImageRegion(edu.stanford.nlp.scenegraph.image.SceneGraphImageRegion)

Aggregations

Dataset (edu.stanford.nlp.classify.Dataset)3 IndexedWord (edu.stanford.nlp.ling.IndexedWord)3 SemanticGraph (edu.stanford.nlp.semgraph.SemanticGraph)3 GeneralDataset (edu.stanford.nlp.classify.GeneralDataset)2 CoreLabel (edu.stanford.nlp.ling.CoreLabel)2 ArrayList (java.util.ArrayList)2 CorefCluster (edu.stanford.nlp.coref.data.CorefCluster)1 Dictionaries (edu.stanford.nlp.coref.data.Dictionaries)1 Document (edu.stanford.nlp.coref.data.Document)1 DocumentMaker (edu.stanford.nlp.coref.data.DocumentMaker)1 Mention (edu.stanford.nlp.coref.data.Mention)1 SceneGraphImage (edu.stanford.nlp.scenegraph.image.SceneGraphImage)1 SceneGraphImageRegion (edu.stanford.nlp.scenegraph.image.SceneGraphImageRegion)1 Triple (edu.stanford.nlp.util.Triple)1