Usage of edu.stanford.nlp.scenegraph.image.SceneGraphImage in the stanfordnlp/CoreNLP project:
the main method of the EntityClassifier class.
/**
 * Trains an entity classifier from a file of JSON-serialized scene graph images.
 *
 * @param args args[0]: training file, one SceneGraphImage JSON object per line;
 *             args[1]: path the trained model is written to;
 *             args[2]: path to the word embeddings file.
 * @throws IOException if the training file or embeddings cannot be read.
 */
public static void main(String[] args) throws IOException {
String filename = args[0];
String modelPath = args[1];
String embeddingsPath = args[2];
// Fix: use the named embeddingsPath variable instead of args[2] directly
// (previously the variable was declared but unused).
Embedding embeddings = new Embedding(embeddingsPath);
List<SceneGraphImage> images = Generics.newLinkedList();
// try-with-resources so the reader is closed even if parsing throws.
try (BufferedReader reader = IOUtils.readerFromString(filename)) {
for (String line = reader.readLine(); line != null; line = reader.readLine()) {
SceneGraphImage img = SceneGraphImage.readFromJSON(line);
// readFromJSON returns null for lines it cannot parse; skip them.
if (img == null) {
continue;
}
images.add(img);
}
}
train(images, modelPath, embeddings);
}
Usage of edu.stanford.nlp.scenegraph.image.SceneGraphImage in the stanfordnlp/CoreNLP project:
the main method of the KNNSceneGraphParser class.
/**
 * Entry point for the KNN scene graph parser.
 *
 * Evaluation mode (default): args[0] = test file of SceneGraphImage JSON lines,
 * args[1] = model path, args[2] = training-images file, args[3]/args[4] = output
 * paths for predicted/gold smatch strings. Parses every region of every test
 * image, prints each predicted image as JSON to stdout, and reports per-region
 * and macro-averaged F1 to stderr.
 *
 * Training mode: args[2] == "-train", args[0] = training file, args[1] = model
 * output path.
 *
 * @throws IOException if any input file cannot be read or output cannot be written.
 */
public static void main(String[] args) throws IOException {
// NOTE(review): this branch also runs when args.length < 3, yet it reads
// args[0]..args[4], so invoking with fewer than 5 args fails with an
// ArrayIndexOutOfBoundsException — confirm the intended CLI contract.
if (args.length < 3 || !args[2].equals("-train")) {
KNNSceneGraphParser parser = new KNNSceneGraphParser(args[1]);
Map<Integer, SceneGraphImage> trainImages = parser.loadImages(args[2]);
SceneGraphEvaluation evaluation = new SceneGraphEvaluation();
double count = 0.0;
double f1Sum = 0.0;
// Fix: close the reader and both writers (PrintWriter buffers output, so
// unclosed writers could silently lose the tail of the smatch files).
try (BufferedReader reader = IOUtils.readerFromString(args[0]);
PrintWriter predWriter = IOUtils.getPrintWriter(args[3]);
PrintWriter goldWriter = IOUtils.getPrintWriter(args[4])) {
for (String line = reader.readLine(); line != null; line = reader.readLine()) {
SceneGraphImage img = SceneGraphImage.readFromJSON(line);
// Fix: readFromJSON returns null on unparseable lines (as handled in
// EntityClassifier.main); previously this NPE'd on img.regions.
if (img == null) {
continue;
}
for (SceneGraphImageRegion region : img.regions) {
count += 1.0;
SceneGraphImageRegion predicted = parser.parse(region.tokens, trainImages);
Triple<Double, Double, Double> scores = evaluation.evaluate(predicted, region);
evaluation.toSmatchString(predicted, region, predWriter, goldWriter);
// Build a minimal SceneGraphImage holding only this region's prediction.
SceneGraphImage predictedImg = new SceneGraphImage();
predictedImg.id = img.id;
predictedImg.url = img.url;
predictedImg.height = img.height;
predictedImg.width = img.width;
// Collect the indices of every object this region references.
Set<Integer> objectIds = Generics.newHashSet();
for (SceneGraphImageAttribute attr : region.attributes) {
objectIds.add(img.objects.indexOf(attr.subject));
}
for (SceneGraphImageRelationship reln : region.relationships) {
objectIds.add(img.objects.indexOf(reln.subject));
objectIds.add(img.objects.indexOf(reln.object));
}
predictedImg.objects = Generics.newArrayList();
for (Integer objectId : objectIds) {
predictedImg.objects.add(img.objects.get(objectId));
}
SceneGraphImageRegion newRegion = new SceneGraphImageRegion();
newRegion.phrase = region.phrase;
newRegion.x = region.x;
newRegion.y = region.y;
newRegion.h = region.h;
newRegion.w = region.w;
newRegion.attributes = Generics.newHashSet();
newRegion.relationships = Generics.newHashSet();
predictedImg.regions = Generics.newArrayList();
predictedImg.regions.add(newRegion);
// Clone attributes/relationships so their back-references point at the
// new image and region rather than the original ones.
predictedImg.attributes = Generics.newLinkedList();
for (SceneGraphImageAttribute attr : region.attributes) {
SceneGraphImageAttribute attrCopy = attr.clone();
attrCopy.region = newRegion;
attrCopy.image = predictedImg;
predictedImg.addAttribute(attrCopy);
}
predictedImg.relationships = Generics.newLinkedList();
for (SceneGraphImageRelationship reln : region.relationships) {
SceneGraphImageRelationship relnCopy = reln.clone();
relnCopy.image = predictedImg;
relnCopy.region = newRegion;
predictedImg.addRelationship(relnCopy);
}
System.out.println(predictedImg.toJSON());
System.err.printf("Prec: %f, Recall: %f, F1: %f%n", scores.first, scores.second, scores.third);
f1Sum += scores.third;
}
}
}
System.err.println("#########################################################");
// Fix: avoid printing NaN when the test set contained no regions.
System.err.printf("Macro-averaged F1: %f%n", count > 0.0 ? f1Sum / count : 0.0);
System.err.println("#########################################################");
} else {
KNNSceneGraphParser parser = new KNNSceneGraphParser(null);
parser.train(args[0], args[1]);
}
}
Usage of edu.stanford.nlp.scenegraph.image.SceneGraphImage in the stanfordnlp/CoreNLP project:
the parse method of the KNNSceneGraphParser class.
/**
 * Finds the nearest-neighbor training region for a token sequence.
 *
 * Featurizes the tokens as a bag of words, classifies the resulting datum,
 * and decodes the predicted label (encoded as "imageId_regionIndex" at
 * training time) back into a region of a training image.
 *
 * @param tokens the tokens of the phrase to parse
 * @param trainImages training images keyed by image id
 * @return the matched training region, or null if the predicted image id
 *         is not present in trainImages
 */
public SceneGraphImageRegion parse(List<CoreLabel> tokens, Map<Integer, SceneGraphImage> trainImages) throws IOException {
// Bag-of-words featurization of the input phrase.
Counter<String> bagOfWords = new ClassicCounter<String>();
for (CoreLabel tok : tokens) {
bagOfWords.incrementCount(tok.word());
}
String predictedLabel = this.classifier.classOf(new RVFDatum<String, String>(bagOfWords));
// The label encodes which training image/region the neighbor came from.
String[] parts = predictedLabel.split("_");
SceneGraphImage match = trainImages.get(Integer.parseInt(parts[0]));
if (match == null) {
return null;
}
return match.regions.get(Integer.parseInt(parts[1]));
}
Usage of edu.stanford.nlp.scenegraph.image.SceneGraphImage in the stanfordnlp/CoreNLP project:
the train method of the KNNSceneGraphParser class.
/**
 * Trains a 1-nearest-neighbor classifier over bag-of-words features of every
 * region in the training images and serializes it to disk.
 *
 * Each region becomes one datum whose label is "imageId_regionIndex", which
 * parse() later splits to recover the matching training region.
 *
 * @param trainFile path to the training-images file (read via loadImages)
 * @param modelPath path the serialized classifier is written to
 * @throws IOException if the training file cannot be read or the model cannot be written
 */
private void train(String trainFile, String modelPath) throws IOException {
Map<Integer, SceneGraphImage> images = loadImages(trainFile);
KNNClassifierFactory<String, String> classifierFactory = new KNNClassifierFactory<String, String>(1, false, false);
List<RVFDatum<String, String>> dataset = Generics.newLinkedList();
// Iterate values() directly instead of keySet() + get(): the map key was
// never used (the label uses img.id), and this avoids a second lookup.
for (SceneGraphImage img : images.values()) {
if (img == null) {
continue;
}
for (int i = 0, sz = img.regions.size(); i < sz; i++) {
SceneGraphImageRegion region = img.regions.get(i);
// Bag-of-words features over the region's tokens.
Counter<String> features = new ClassicCounter<String>();
for (CoreLabel token : region.tokens) {
features.incrementCount(token.word());
}
RVFDatum<String, String> datum = new RVFDatum<String, String>(features, String.format("%d_%d", img.id, i));
dataset.add(datum);
}
}
KNNClassifier<String, String> classifier = classifierFactory.train(dataset);
IOUtils.writeObjectToFile(classifier, modelPath);
}
Usage of edu.stanford.nlp.scenegraph.image.SceneGraphImage in the stanfordnlp/CoreNLP project:
the countAll method of the SceneGraphImageFilter class.
/**
 * Tallies entity, attribute, and relation lemma glosses across all images
 * into the class-level counters (entityCounter, attributeCounter,
 * relationCounter).
 *
 * @param images the scene graph images to count over
 */
private static void countAll(List<SceneGraphImage> images) {
for (SceneGraphImage img : images) {
for (SceneGraphImageAttribute attr : img.attributes) {
entityCounter.incrementCount(attr.subjectLemmaGloss());
attributeCounter.incrementCount(attr.attributeLemmaGloss());
}
// Loop variable renamed from the misleading "attr" to "reln": these
// elements are relationships, not attributes.
for (SceneGraphImageRelationship reln : img.relationships) {
entityCounter.incrementCount(reln.subjectLemmaGloss());
relationCounter.incrementCount(reln.predicateLemmaGloss());
entityCounter.incrementCount(reln.objectLemmaGloss());
}
}
}
Aggregations