Search in sources:

Example 1 with RelationMentionFactory

Use of edu.stanford.nlp.ie.machinereading.structure.RelationMentionFactory in project CoreNLP by stanfordnlp.

In class RothResultsByRelation, method printResults:

/**
 * Prints, for each gold relation, the relation type the extractor predicted between the
 * same two arguments, followed by that relation's dependency-path features.
 *
 * <p>Gold relations are sorted so that relations with the same pair of argument entity
 * types are adjacent. Within a type pair they are ordered by how often their
 * (argument types, low-level dependency path) combination occurs among all gold relations
 * (rarest first), then by gold type, predicted type, and finally sentence text. Only
 * relations whose argument types pass {@link #isReportedTypePair} and whose low-level
 * dependency path is non-empty are printed.
 *
 * <p>Side effects: (re)initialises the {@code featureFactory} and {@code mentionFactory}
 * fields and calls {@code ResultsPrinter.align(goldStandard, extractorOutput)} before
 * reading either list.
 *
 * @param pw destination of the report
 * @param goldStandard gold-annotated sentences
 * @param extractorOutput extractor-annotated sentences, read at the same index as the
 *     corresponding {@code goldStandard} sentence
 */
@Override
public void printResults(PrintWriter pw, List<CoreMap> goldStandard, List<CoreMap> extractorOutput) {
    featureFactory = MachineReading.makeRelationFeatureFactory(MachineReadingProperties.relationFeatureFactoryClass, MachineReadingProperties.relationFeatures, false);
    // generic mentions work well in this domain
    mentionFactory = new RelationMentionFactory();
    ResultsPrinter.align(goldStandard, extractorOutput);

    // One gold relation together with everything the sort and the report need.
    // The dependency path is extracted once here. It is not re-extracted inside the
    // comparator, which would run O(n log n) times.
    class RelationRow {
        final RelationMention rel;
        final String prediction;
        final String type1;
        final String type2;
        final String path;
        double pathCount;

        RelationRow(RelationMention rel, String prediction) {
            this.rel = rel;
            this.prediction = prediction;
            this.type1 = rel.getArg(0).getType();
            this.type2 = rel.getArg(1).getType();
            this.path = featureFactory.getFeature(rel, "dependency_path_lowlevel");
        }

        /** Key identifying this row's (argument types, low-level dependency path) combination. */
        Pair<Pair<String, String>, String> pathKey() {
            return new Pair<>(new Pair<>(type1, type2), path);
        }
    }

    // Pair every gold relation with the extractor's label for the same two arguments.
    // The label is stored on the row rather than in a map keyed by RelationMention,
    // so gold relations that happen to compare equal cannot overwrite each other's prediction.
    List<RelationRow> rows = new ArrayList<>();
    for (int i = 0; i < goldStandard.size(); i++) {
        for (RelationMention rel : AnnotationUtils.getAllRelations(mentionFactory, goldStandard.get(i), true)) {
            String prediction = AnnotationUtils.getRelation(mentionFactory, extractorOutput.get(i), rel.getArg(0), rel.getArg(1)).getType();
            rows.add(new RelationRow(rel, prediction));
        }
    }

    // How often each (argument types, dependency path) combination occurs among the gold relations.
    final Counter<Pair<Pair<String, String>, String>> pathCounts = new ClassicCounter<>();
    for (RelationRow row : rows) {
        pathCounts.incrementCount(row.pathKey());
    }
    for (RelationRow row : rows) {
        row.pathCount = pathCounts.getCount(row.pathKey());
    }

    Collections.sort(rows, new Comparator<RelationRow>() {
        @Override
        public int compare(RelationRow a, RelationRow b) {
            // Group by argument entity types. Each argument is compared separately:
            // comparing the concatenation would treat e.g. ("A", "BC") and ("AB", "C")
            // as the same group.
            int cmp = a.type1.compareTo(b.type1);
            if (cmp != 0) {
                return cmp;
            }
            cmp = a.type2.compareTo(b.type2);
            if (cmp != 0) {
                return cmp;
            }
            // Rarer dependency paths first.
            cmp = Double.compare(a.pathCount, b.pathCount);
            if (cmp != 0) {
                return cmp;
            }
            cmp = a.rel.getType().compareTo(b.rel.getType());
            if (cmp != 0) {
                return cmp;
            }
            cmp = a.prediction.compareTo(b.prediction);
            if (cmp != 0) {
                return cmp;
            }
            return a.rel.getSentence().get(CoreAnnotations.TextAnnotation.class).compareTo(b.rel.getSentence().get(CoreAnnotations.TextAnnotation.class));
        }
    });

    for (RelationRow row : rows) {
        if (!isReportedTypePair(row.type1, row.type2) || row.path.isEmpty()) {
            continue;
        }
        pw.println("\nLABEL: " + row.prediction);
        pw.println(row.rel);
        pw.println(row.path);
        pw.println(featureFactory.getFeatures(row.rel, "dependency_path_words"));
        pw.println(featureFactory.getFeature(row.rel, "surface_path_POS"));
    }
}

/**
 * Returns whether relations whose arguments have the given entity types are included in
 * the {@link #printResults} report.
 *
 * <p>Accepted (first, second) pairs are: (PEOPLE, PEOPLE), (PEOPLE, LOCATION),
 * (PEOPLE, ORGANIZATION), (LOCATION, LOCATION) and (ORGANIZATION, LOCATION).
 *
 * @param type1 entity type of the first argument (must be non-null)
 * @param type2 entity type of the second argument
 * @return {@code true} if the pair is one of the accepted combinations
 */
private static boolean isReportedTypePair(String type1, String type2) {
    if (type1.equals("PEOPLE")) {
        return "PEOPLE".equals(type2) || "LOCATION".equals(type2) || "ORGANIZATION".equals(type2);
    }
    if (type1.equals("LOCATION") || type1.equals("ORGANIZATION")) {
        return "LOCATION".equals(type2);
    }
    return false;
}
Also used: RelationMentionFactory (edu.stanford.nlp.ie.machinereading.structure.RelationMentionFactory), RelationMention (edu.stanford.nlp.ie.machinereading.structure.RelationMention), HashMap (java.util.HashMap), ArrayList (java.util.ArrayList), Comparator (java.util.Comparator), ClassicCounter (edu.stanford.nlp.stats.ClassicCounter), Pair (edu.stanford.nlp.util.Pair)

Aggregations

RelationMention (edu.stanford.nlp.ie.machinereading.structure.RelationMention): 1; RelationMentionFactory (edu.stanford.nlp.ie.machinereading.structure.RelationMentionFactory): 1; ClassicCounter (edu.stanford.nlp.stats.ClassicCounter): 1; Pair (edu.stanford.nlp.util.Pair): 1; ArrayList (java.util.ArrayList): 1; Comparator (java.util.Comparator): 1; HashMap (java.util.HashMap): 1