Usage of edu.stanford.nlp.scenegraph.image.SceneGraphImageAttribute in the CoreNLP project by stanfordnlp.
Class KNNSceneGraphParser, method main:
/**
 * Command-line entry point for training or evaluating the parser.
 *
 * <p>Training mode is selected when at least three arguments are given and {@code args[2]} is
 * {@code -train}. A parser is then constructed with a {@code null} argument and
 * {@code train(args[0], args[1])} is invoked.
 *
 * <p>Otherwise evaluation mode runs and five arguments are required:
 * <ol>
 *   <li>{@code args[0]}: a file read line by line, each line a JSON-serialized
 *       {@link SceneGraphImage} whose regions are parsed and scored;</li>
 *   <li>{@code args[1]}: passed to the {@code KNNSceneGraphParser} constructor;</li>
 *   <li>{@code args[2]}: passed to {@code loadImages}; the result is handed to {@code parse};</li>
 *   <li>{@code args[3]}: file receiving the predicted side of {@code toSmatchString};</li>
 *   <li>{@code args[4]}: file receiving the gold side of {@code toSmatchString}.</li>
 * </ol>
 *
 * <p>For every region, precision/recall/F1 are printed to stderr and a single-region image is
 * printed to stdout as JSON. The macro-averaged F1 over all regions is printed at the end; it is
 * NaN if the input contained no regions.
 *
 * @param args command-line arguments as described above
 * @throws IOException if an input cannot be read or an output file cannot be opened
 * @throws IllegalArgumentException if evaluation mode is selected with fewer than five arguments
 */
public static void main(String[] args) throws IOException {
  boolean trainMode = args.length >= 3 && args[2].equals("-train");
  if (trainMode) {
    KNNSceneGraphParser parser = new KNNSceneGraphParser(null);
    parser.train(args[0], args[1]);
    return;
  }
  if (args.length < 5) {
    throw new IllegalArgumentException(
        "Evaluation mode expects 5 arguments (pass -train as the third argument to train), got "
            + args.length);
  }

  KNNSceneGraphParser parser = new KNNSceneGraphParser(args[1]);
  Map<Integer, SceneGraphImage> trainImages = parser.loadImages(args[2]);
  int regionCount = 0;
  double f1Sum = 0.0;

  // The writers must be closed so that any buffered Smatch output actually reaches the files;
  // try-with-resources also releases everything if parsing throws part-way through.
  try (BufferedReader reader = IOUtils.readerFromString(args[0]);
       PrintWriter predWriter = IOUtils.getPrintWriter(args[3]);
       PrintWriter goldWriter = IOUtils.getPrintWriter(args[4])) {
    SceneGraphEvaluation evaluation = new SceneGraphEvaluation();
    for (String line = reader.readLine(); line != null; line = reader.readLine()) {
      SceneGraphImage img = SceneGraphImage.readFromJSON(line);
      for (SceneGraphImageRegion region : img.regions) {
        regionCount++;
        SceneGraphImageRegion predicted = parser.parse(region.tokens, trainImages);
        Triple<Double, Double, Double> scores = evaluation.evaluate(predicted, region);
        evaluation.toSmatchString(predicted, region, predWriter, goldWriter);
        System.out.println(buildSingleRegionImage(img, region).toJSON());
        System.err.printf("Prec: %f, Recall: %f, F1: %f%n", scores.first, scores.second, scores.third);
        f1Sum += scores.third;
      }
    }
  }

  System.err.println("#########################################################");
  // double / int: yields NaN (not an exception) when no regions were seen.
  System.err.printf("Macro-averaged F1: %f%n", f1Sum / regionCount);
  System.err.println("#########################################################");
}

/**
 * Builds an image that contains only {@code region} of {@code img}.
 *
 * <p>The result carries:
 * <ul>
 *   <li>{@code img}'s id, url, height and width;</li>
 *   <li>the objects of {@code img} referenced by the region's attributes and relationships;</li>
 *   <li>clones of those attributes and relationships, attached to the new image and to a fresh
 *       region that copies the original's phrase and bounding box.</li>
 * </ul>
 *
 * <p>NOTE(review): the clones come from the gold {@code region}, not from the parser's
 * prediction. Confirm this is what should be printed.
 *
 * @param img the image that owns {@code region} and its objects
 * @param region the region to isolate
 * @return a new single-region image
 */
private static SceneGraphImage buildSingleRegionImage(SceneGraphImage img, SceneGraphImageRegion region) {
  SceneGraphImage regionImg = new SceneGraphImage();
  regionImg.id = img.id;
  regionImg.url = img.url;
  regionImg.height = img.height;
  regionImg.width = img.width;

  // Indices into img.objects of every object mentioned by the region, deduplicated.
  Set<Integer> objectIds = Generics.newHashSet();
  for (SceneGraphImageAttribute attr : region.attributes) {
    objectIds.add(img.objects.indexOf(attr.subject));
  }
  for (SceneGraphImageRelationship reln : region.relationships) {
    objectIds.add(img.objects.indexOf(reln.subject));
    objectIds.add(img.objects.indexOf(reln.object));
  }
  regionImg.objects = Generics.newArrayList();
  for (Integer objectId : objectIds) {
    // NOTE(review): indexOf yields -1 for an object missing from img.objects, making get() throw.
    regionImg.objects.add(img.objects.get(objectId));
  }

  SceneGraphImageRegion newRegion = new SceneGraphImageRegion();
  newRegion.phrase = region.phrase;
  newRegion.x = region.x;
  newRegion.y = region.y;
  newRegion.h = region.h;
  newRegion.w = region.w;
  newRegion.attributes = Generics.newHashSet();
  newRegion.relationships = Generics.newHashSet();
  regionImg.regions = Generics.newArrayList();
  regionImg.regions.add(newRegion);

  regionImg.attributes = Generics.newLinkedList();
  for (SceneGraphImageAttribute attr : region.attributes) {
    SceneGraphImageAttribute attrCopy = attr.clone();
    attrCopy.region = newRegion;
    attrCopy.image = regionImg;
    regionImg.addAttribute(attrCopy);
  }
  regionImg.relationships = Generics.newLinkedList();
  for (SceneGraphImageRelationship reln : region.relationships) {
    SceneGraphImageRelationship relnCopy = reln.clone();
    relnCopy.image = regionImg;
    relnCopy.region = newRegion;
    regionImg.addRelationship(relnCopy);
  }
  return regionImg;
}
Usage of edu.stanford.nlp.scenegraph.image.SceneGraphImageAttribute in the CoreNLP project by stanfordnlp.
Class SceneGraphImageFilter, method countAll:
/**
 * Accumulates lemma-gloss counts over every image in {@code images}.
 *
 * <p>Counters updated:
 * <ul>
 *   <li>each attribute adds its subject gloss to {@code entityCounter} and its attribute gloss to
 *       {@code attributeCounter};</li>
 *   <li>each relationship adds its subject and object glosses to {@code entityCounter} and its
 *       predicate gloss to {@code relationCounter}.</li>
 * </ul>
 *
 * @param images the images whose attributes and relationships are tallied
 */
private static void countAll(List<SceneGraphImage> images) {
  for (SceneGraphImage image : images) {
    // (subject, attribute) pairs.
    for (SceneGraphImageAttribute attribute : image.attributes) {
      entityCounter.incrementCount(attribute.subjectLemmaGloss());
      attributeCounter.incrementCount(attribute.attributeLemmaGloss());
    }
    // (subject, predicate, object) triples; both endpoints count as entities.
    for (SceneGraphImageRelationship relationship : image.relationships) {
      entityCounter.incrementCount(relationship.subjectLemmaGloss());
      relationCounter.incrementCount(relationship.predicateLemmaGloss());
      entityCounter.incrementCount(relationship.objectLemmaGloss());
    }
  }
}
Aggregations