Use of edu.stanford.nlp.classify.Dataset in project CoreNLP by stanfordnlp.
Class SingletonPredictor, method generateFeatureVectors.
/**
* Generate the training features from the CoNLL input file.
* @return Dataset of feature vectors
* @throws Exception
*/
private static GeneralDataset<String, String> generateFeatureVectors(Properties props) throws Exception {
  GeneralDataset<String, String> dataset = new Dataset<>();
  Dictionaries dict = new Dictionaries(props);
  MentionExtractor mentionExtractor = new CoNLLMentionExtractor(dict, props, new Semantics(dict));
  Document document;
  while ((document = mentionExtractor.nextDoc()) != null) {
    setTokenIndices(document);
    document.extractGoldCorefClusters();
    Map<Integer, CorefCluster> entities = document.goldCorefClusters;
    // Generate features for coreferent mentions with class label 1
    for (CorefCluster entity : entities.values()) {
      for (Mention mention : entity.getCorefMentions()) {
        // Ignore verbal mentions
        if (mention.headWord.tag().startsWith("V"))
          continue;
        IndexedWord head = mention.dependency.getNodeByIndexSafe(mention.headWord.index());
        if (head == null)
          continue;
        ArrayList<String> feats = mention.getSingletonFeatures(dict);
        dataset.add(new BasicDatum<>(feats, "1"));
      }
    }
    // Generate features for singletons with class label 0
    ArrayList<CoreLabel> gold_heads = new ArrayList<>();
    for (Mention gold_men : document.allGoldMentions.values()) {
      gold_heads.add(gold_men.headWord);
    }
    for (Mention predicted_men : document.allPredictedMentions.values()) {
      SemanticGraph dep = predicted_men.dependency;
      IndexedWord head = dep.getNodeByIndexSafe(predicted_men.headWord.index());
      if (head == null)
        continue;
      // Ignore verbal mentions
      if (predicted_men.headWord.tag().startsWith("V"))
        continue;
      // If the mention is in the gold set, it is not a singleton, so skip it
      if (gold_heads.contains(predicted_men.headWord))
        continue;
      dataset.add(new BasicDatum<>(predicted_men.getSingletonFeatures(dict), "0"));
    }
  }
  dataset.summaryStatistics();
  return dataset;
}
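The GeneralDataset built here holds string-valued features with binary labels ("1" for coreferent mentions, "0" for singletons), so it can be handed directly to one of the classifier factories in edu.stanford.nlp.classify. Below is a minimal sketch of that downstream step; it is not the project's own training code. The helper name trainSingletonModel is hypothetical and is assumed to live in the same class so that the private generateFeatureVectors is reachable, and LinearClassifierFactory is chosen only for illustration.

// Hypothetical helper, assumed to sit next to generateFeatureVectors in the same class.
// LinearClassifierFactory is an illustrative choice; any edu.stanford.nlp.classify factory
// that accepts a GeneralDataset<String, String> would work the same way.
private static LinearClassifier<String, String> trainSingletonModel(Properties props) throws Exception {
  GeneralDataset<String, String> trainData = generateFeatureVectors(props);
  LinearClassifierFactory<String, String> factory = new LinearClassifierFactory<>();
  return factory.trainClassifier(trainData);
}

Note that summaryStatistics() above only reports dataset counts; it does not alter the feature vectors, so training can follow it directly.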
Use of edu.stanford.nlp.classify.Dataset in project CoreNLP by stanfordnlp.
Class SingletonPredictor, method generateFeatureVectors.
/**
* Generate the training features from the CoNLL input file.
* @return Dataset of feature vectors
* @throws Exception
*/
private static GeneralDataset<String, String> generateFeatureVectors(Properties props) throws Exception {
  GeneralDataset<String, String> dataset = new Dataset<>();
  Dictionaries dict = new Dictionaries(props);
  DocumentMaker docMaker = new DocumentMaker(props, dict);
  Document document;
  while ((document = docMaker.nextDoc()) != null) {
    setTokenIndices(document);
    Map<Integer, CorefCluster> entities = document.goldCorefClusters;
    // Generate features for coreferent mentions with class label 1
    for (CorefCluster entity : entities.values()) {
      for (Mention mention : entity.getCorefMentions()) {
        // Ignore verbal mentions
        if (mention.headWord.tag().startsWith("V"))
          continue;
        IndexedWord head = mention.enhancedDependency.getNodeByIndexSafe(mention.headWord.index());
        if (head == null)
          continue;
        ArrayList<String> feats = mention.getSingletonFeatures(dict);
        dataset.add(new BasicDatum<>(feats, "1"));
      }
    }
    // Generate features for singletons with class label 0
    ArrayList<CoreLabel> gold_heads = new ArrayList<>();
    for (Mention gold_men : document.goldMentionsByID.values()) {
      gold_heads.add(gold_men.headWord);
    }
    for (Mention predicted_men : document.predictedMentionsByID.values()) {
      SemanticGraph dep = predicted_men.enhancedDependency;
      IndexedWord head = dep.getNodeByIndexSafe(predicted_men.headWord.index());
      if (head == null || !dep.vertexSet().contains(head))
        continue;
      // Ignore verbal mentions
      if (predicted_men.headWord.tag().startsWith("V"))
        continue;
      // If the mention is in the gold set, it is not a singleton, so skip it
      if (gold_heads.contains(predicted_men.headWord))
        continue;
      dataset.add(new BasicDatum<>(predicted_men.getSingletonFeatures(dict), "0"));
    }
  }
  dataset.summaryStatistics();
  return dataset;
}
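At classification time the same mention-level feature extractor can feed a model trained on this dataset: a mention's singleton features become a BasicDatum and the probability assigned to label "1" can be read off. A hedged sketch under that assumption; the method name scoreSingletonMention and the classifier argument are hypothetical and not part of the snippet above.

// Hypothetical scoring helper; the classifier is assumed to have been trained on the
// "0"/"1"-labeled dataset produced by generateFeatureVectors above.
private static double scoreSingletonMention(LinearClassifier<String, String> classifier,
                                            Mention mention, Dictionaries dict) {
  ArrayList<String> feats = mention.getSingletonFeatures(dict);
  // No label at test time; only the features matter for scoring.
  Datum<String, String> datum = new BasicDatum<>(feats);
  return classifier.probabilityOf(datum).getCount("1");
}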
Use of edu.stanford.nlp.classify.Dataset in project CoreNLP by stanfordnlp.
Class BoWSceneGraphParser, method getTrainingExamples.
/**
* Generate training examples.
*
* @param trainingFile Path to JSON file with training images and scene graphs.
* @param sampleNeg Whether to sample the same number of negative examples as positive examples.
* @return Dataset to train a classifier.
* @throws IOException
*/
public Dataset<String, String> getTrainingExamples(String trainingFile, boolean sampleNeg) throws IOException {
  Dataset<String, String> dataset = new Dataset<String, String>();
  Dataset<String, String> negDataset = new Dataset<String, String>();
  /* Load images. */
  List<SceneGraphImage> images = loadImages(trainingFile);
  for (SceneGraphImage image : images) {
    for (SceneGraphImageRegion region : image.regions) {
      SemanticGraph sg = region.getEnhancedSemanticGraph();
      SemanticGraphEnhancer.processQuanftificationModifiers(sg);
      SemanticGraphEnhancer.collapseCompounds(sg);
      SemanticGraphEnhancer.collapseParticles(sg);
      SemanticGraphEnhancer.resolvePronouns(sg);
      Set<Integer> entityPairs = Generics.newHashSet();
      List<Triple<IndexedWord, IndexedWord, String>> relationTriples = this.sentenceMatcher.getRelationTriples(region);
      for (Triple<IndexedWord, IndexedWord, String> triple : relationTriples) {
        IndexedWord iw1 = sg.getNodeByIndexSafe(triple.first.index());
        IndexedWord iw2 = sg.getNodeByIndexSafe(triple.second.index());
        if (iw1 != null && iw2 != null && (!enforceSubtree || SceneGraphUtils.inSameSubTree(sg, iw1, iw2))) {
          entityClassifer.predictEntity(iw1, this.embeddings);
          entityClassifer.predictEntity(iw2, this.embeddings);
          BoWExample example = new BoWExample(iw1, iw2, sg);
          dataset.add(example.extractFeatures(featureSets), triple.third);
        }
        entityPairs.add((triple.first.index() << 4) + triple.second.index());
      }
      /* Add negative examples. */
      List<IndexedWord> entities = EntityExtractor.extractEntities(sg);
      List<IndexedWord> attributes = EntityExtractor.extractAttributes(sg);
      for (IndexedWord e : entities) {
        entityClassifer.predictEntity(e, this.embeddings);
      }
      for (IndexedWord a : attributes) {
        entityClassifer.predictEntity(a, this.embeddings);
      }
      for (IndexedWord e1 : entities) {
        for (IndexedWord e2 : entities) {
          if (e1.index() == e2.index()) {
            continue;
          }
          int entityPair = (e1.index() << 4) + e2.index();
          if (!entityPairs.contains(entityPair) && (!enforceSubtree || SceneGraphUtils.inSameSubTree(sg, e1, e2))) {
            BoWExample example = new BoWExample(e1, e2, sg);
            negDataset.add(example.extractFeatures(featureSets), NONE_RELATION);
          }
        }
      }
      for (IndexedWord e : entities) {
        for (IndexedWord a : attributes) {
          int entityPair = (e.index() << 4) + a.index();
          if (!entityPairs.contains(entityPair) && (!enforceSubtree || SceneGraphUtils.inSameSubTree(sg, e, a))) {
            BoWExample example = new BoWExample(e, a, sg);
            negDataset.add(example.extractFeatures(featureSets), NONE_RELATION);
          }
        }
      }
    }
  }
  /* Sample from negative examples to make the training set more balanced. */
  if (sampleNeg && dataset.size() < negDataset.size()) {
    negDataset = negDataset.getRandomSubDataset(dataset.size() * 1.0 / negDataset.size(), 42);
  }
  dataset.addAll(negDataset);
  return dataset;
}
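The returned Dataset mixes the gold relation labels with a balanced sample of NONE_RELATION negatives, so training a relation classifier from it is a single factory call. A minimal sketch follows; trainRelationModel is a hypothetical name, and LinearClassifierFactory is used only for illustration rather than as the project's confirmed choice.

// Hypothetical training step, assumed to sit in BoWSceneGraphParser next to getTrainingExamples.
public LinearClassifier<String, String> trainRelationModel(String trainingFile) throws IOException {
  // Positive relation triples plus a balanced sample of NONE_RELATION negatives.
  Dataset<String, String> trainData = getTrainingExamples(trainingFile, true);
  LinearClassifierFactory<String, String> factory = new LinearClassifierFactory<>();
  return factory.trainClassifier(trainData);
}

A candidate pair of graph nodes would then be labeled from the same bag-of-words features, for example with something like classifier.classOf(new BasicDatum<>(example.extractFeatures(featureSets))) for a new BoWExample.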