Use of edu.stanford.nlp.ie.machinereading.structure.RelationMentionFactory in the project CoreNLP by stanfordnlp.
Class RothResultsByRelation, method printResults.
/**
 * Prints a per-relation report of the extractor's predictions against the gold standard.
 *
 * <p>The method reassigns {@code featureFactory} and {@code mentionFactory}, calls
 * {@code ResultsPrinter.align} on the two sentence lists, and collects every gold relation
 * mention together with the type of the relation found in the extractor output for the same
 * two arguments. The relations are then sorted by (argument entity types, how often their
 * low-level dependency path occurs for that entity-type pair, gold type, predicted type,
 * sentence text) and printed.
 *
 * <p>Only relations whose argument types form one of the pairs PEOPLE-PEOPLE,
 * PEOPLE-LOCATION, LOCATION-LOCATION, ORGANIZATION-LOCATION or PEOPLE-ORGANIZATION, and whose
 * dependency path is non-empty, are written. For each one the output contains the predicted
 * label, the relation itself, the dependency path, the {@code dependency_path_words} features
 * and the {@code surface_path_POS} feature.
 *
 * @param pw writer that receives the report
 * @param goldStandard sentences carrying the gold relation mentions
 * @param extractorOutput sentences carrying the predicted relation mentions; indexed in
 *     parallel with {@code goldStandard}
 */
@Override
public void printResults(PrintWriter pw, List<CoreMap> goldStandard, List<CoreMap> extractorOutput) {
  featureFactory = MachineReading.makeRelationFeatureFactory(MachineReadingProperties.relationFeatureFactoryClass, MachineReadingProperties.relationFeatures, false);
  // generic mentions work well in this domain
  mentionFactory = new RelationMentionFactory();
  ResultsPrinter.align(goldStandard, extractorOutput);

  // Gather all gold relations and the type predicted for the same argument pair.
  List<RelationMention> relations = new ArrayList<>();
  final Map<RelationMention, String> predictions = new HashMap<>();
  for (int i = 0; i < goldStandard.size(); i++) {
    List<RelationMention> goldRelations = AnnotationUtils.getAllRelations(mentionFactory, goldStandard.get(i), true);
    relations.addAll(goldRelations);
    for (RelationMention rel : goldRelations) {
      predictions.put(rel, AnnotationUtils.getRelation(mentionFactory, extractorOutput.get(i), rel.getArg(0), rel.getArg(1)).getType());
    }
  }

  // Extract the dependency path exactly once per relation. The maps are keyed by object
  // identity so that each mention always gets the path computed from that very mention,
  // independent of how RelationMention defines equals/hashCode.
  final Map<RelationMention, String> pathByRelation = new java.util.IdentityHashMap<>();
  final Counter<Pair<Pair<String, String>, String>> pathCounts = new ClassicCounter<>();
  for (RelationMention rel : relations) {
    String path = featureFactory.getFeature(rel, "dependency_path_lowlevel");
    pathByRelation.put(rel, path);
    pathCounts.incrementCount(new Pair<>(new Pair<>(rel.getArg(0).getType(), rel.getArg(1).getType()), path));
  }

  // Once all counts are final, remember for every relation how frequent its
  // (entity types, path) combination is, so the comparator only does map lookups.
  final Map<RelationMention, Double> pathFrequency = new java.util.IdentityHashMap<>();
  for (RelationMention rel : relations) {
    pathFrequency.put(rel, pathCounts.getCount(new Pair<>(new Pair<>(rel.getArg(0).getType(), rel.getArg(1).getType()), pathByRelation.get(rel))));
  }

  // Orders relations by entity-type pair, then path frequency (rare paths first), then gold
  // type, then predicted type, and finally by sentence text.
  class RelComp implements Comparator<RelationMention> {
    @Override
    public int compare(RelationMention rel1, RelationMention rel2) {
      int entComp = (rel1.getArg(0).getType() + rel1.getArg(1).getType()).compareTo(rel2.getArg(0).getType() + rel2.getArg(1).getType());
      if (entComp != 0) {
        return entComp;
      }

      double pathCount1 = pathFrequency.get(rel1);
      double pathCount2 = pathFrequency.get(rel2);
      if (pathCount1 < pathCount2) {
        return -1;
      }
      if (pathCount1 > pathCount2) {
        return 1;
      }

      int typeComp = rel1.getType().compareTo(rel2.getType());
      if (typeComp != 0) {
        return typeComp;
      }

      int predictionComp = predictions.get(rel1).compareTo(predictions.get(rel2));
      if (predictionComp != 0) {
        return predictionComp;
      }

      return rel1.getSentence().get(CoreAnnotations.TextAnnotation.class).compareTo(rel2.getSentence().get(CoreAnnotations.TextAnnotation.class));
    }
  }
  Collections.sort(relations, new RelComp());

  for (RelationMention rel : relations) {
    String type1 = rel.getArg(0).getType();
    String type2 = rel.getArg(1).getType();

    // Only these entity-type combinations are reported.
    boolean reportedPair =
        (type1.equals("PEOPLE") && type2.equals("PEOPLE"))
            || (type1.equals("PEOPLE") && type2.equals("LOCATION"))
            || (type1.equals("LOCATION") && type2.equals("LOCATION"))
            || (type1.equals("ORGANIZATION") && type2.equals("LOCATION"))
            || (type1.equals("PEOPLE") && type2.equals("ORGANIZATION"));
    if (!reportedPair) {
      continue;
    }

    // Relations without a dependency path are skipped.
    String path = pathByRelation.get(rel);
    if (path.equals("")) {
      continue;
    }

    pw.println("\nLABEL: " + predictions.get(rel));
    pw.println(rel);
    pw.println(path);
    pw.println(featureFactory.getFeatures(rel, "dependency_path_words"));
    pw.println(featureFactory.getFeature(rel, "surface_path_POS"));
  }
}
Aggregations