Use of edu.stanford.nlp.util.ScoredObject in the project CoreNLP by stanfordnlp:
the class FindNearestNeighbors, method main.
/**
 * Loads a LexicalizedParser with an attached DVModel reranker, parses a test
 * treebank, writes each sentence's root node vector to the output file, and
 * then logs, for each subtree, its nearest-neighbor subtrees as measured by
 * the Frobenius norm of the difference between their node vectors.
 *
 * Required flags: -model, -testTreebank, -output.  Unrecognized flags are
 * passed through to the parser.
 */
public static void main(String[] args) throws Exception {
  String modelPath = null;
  String outputPath = null;
  String testTreebankPath = null;
  FileFilter testTreebankFilter = null;
  List<String> unusedArgs = new ArrayList<>();
  for (int argIndex = 0; argIndex < args.length; ) {
    if (args[argIndex].equalsIgnoreCase("-model")) {
      modelPath = args[argIndex + 1];
      argIndex += 2;
    } else if (args[argIndex].equalsIgnoreCase("-testTreebank")) {
      Pair<String, FileFilter> treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-testTreebank");
      argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
      testTreebankPath = treebankDescription.first();
      testTreebankFilter = treebankDescription.second();
    } else if (args[argIndex].equalsIgnoreCase("-output")) {
      outputPath = args[argIndex + 1];
      argIndex += 2;
    } else {
      // unknown arguments are forwarded to LexicalizedParser.loadModel
      unusedArgs.add(args[argIndex++]);
    }
  }
  if (modelPath == null) {
    throw new IllegalArgumentException("Need to specify -model");
  }
  if (testTreebankPath == null) {
    throw new IllegalArgumentException("Need to specify -testTreebank");
  }
  if (outputPath == null) {
    throw new IllegalArgumentException("Need to specify -output");
  }
  String[] newArgs = unusedArgs.toArray(new String[unusedArgs.size()]);
  LexicalizedParser lexparser = LexicalizedParser.loadModel(modelPath, newArgs);
  // testTreebankPath was verified non-null above, so no extra guard is needed
  log.info("Reading in trees from " + testTreebankPath);
  if (testTreebankFilter != null) {
    log.info("Filtering on " + testTreebankFilter);
  }
  Treebank testTreebank = lexparser.getOp().tlpParams.memoryTreebank();
  testTreebank.loadPath(testTreebankPath, testTreebankFilter);
  log.info("Read in " + testTreebank.size() + " trees for testing");
  List<ParseRecord> records = Generics.newArrayList();
  // try-with-resources guarantees the output file is flushed and closed even
  // if parsing throws; all writes go through the buffered writer (the
  // original created a BufferedWriter but then wrote to the raw FileWriter)
  try (BufferedWriter bout = new BufferedWriter(new FileWriter(outputPath))) {
    log.info("Parsing " + testTreebank.size() + " trees");
    int count = 0;
    for (Tree goldTree : testTreebank) {
      List<Word> tokens = goldTree.yieldWords();
      ParserQuery parserQuery = lexparser.parserQuery();
      if (!parserQuery.parse(tokens)) {
        throw new AssertionError("Could not parse: " + tokens);
      }
      if (!(parserQuery instanceof RerankingParserQuery)) {
        throw new IllegalArgumentException("Expected a LexicalizedParser with a Reranker attached");
      }
      RerankingParserQuery rpq = (RerankingParserQuery) parserQuery;
      if (!(rpq.rerankerQuery() instanceof DVModelReranker.Query)) {
        throw new IllegalArgumentException("Expected a LexicalizedParser with a DVModel attached");
      }
      DeepTree tree = ((DVModelReranker.Query) rpq.rerankerQuery()).getDeepTrees().get(0);
      // find the node vector attached to the ROOT node of the best parse
      SimpleMatrix rootVector = null;
      for (Map.Entry<Tree, SimpleMatrix> entry : tree.getVectors().entrySet()) {
        if (entry.getKey().label().value().equals("ROOT")) {
          rootVector = entry.getValue();
          break;
        }
      }
      if (rootVector == null) {
        throw new AssertionError("Could not find root nodevector");
      }
      bout.write(tokens + "\n");
      bout.write(tree.getTree() + "\n");
      for (int i = 0; i < rootVector.getNumElements(); ++i) {
        bout.write(" " + rootVector.get(i));
      }
      bout.write("\n\n\n");
      count++;
      if (count % 10 == 0) {
        log.info(" " + count);
      }
      records.add(new ParseRecord(tokens, goldTree, tree.getTree(), rootVector, tree.getVectors()));
    }
    log.info(" done parsing");
  }
  // collect every subtree with at most maxLength leaves, paired with its vector
  List<Pair<Tree, SimpleMatrix>> subtrees = Generics.newArrayList();
  for (ParseRecord record : records) {
    for (Map.Entry<Tree, SimpleMatrix> entry : record.nodeVectors.entrySet()) {
      if (entry.getKey().getLeaves().size() <= maxLength) {
        subtrees.add(Pair.makePair(entry.getKey(), entry.getValue()));
      }
    }
  }
  log.info("There are " + subtrees.size() + " subtrees in the set of trees");
  // bounded priority queue: keeps the 100 closest (smallest normF) matches
  // per subtree; poll() evicts the current farthest match
  PriorityQueue<ScoredObject<Pair<Tree, Tree>>> bestmatches = new PriorityQueue<>(101, ScoredComparator.DESCENDING_COMPARATOR);
  for (int i = 0; i < subtrees.size(); ++i) {
    log.info(subtrees.get(i).first().yieldWords());
    log.info(subtrees.get(i).first());
    for (int j = 0; j < subtrees.size(); ++j) {
      if (i == j) {
        continue;
      }
      // TODO: look at basic category?
      double normF = subtrees.get(i).second().minus(subtrees.get(j).second()).normF();
      bestmatches.add(new ScoredObject<>(Pair.makePair(subtrees.get(i).first(), subtrees.get(j).first()), normF));
      if (bestmatches.size() > 100) {
        bestmatches.poll();
      }
    }
    // drain the queue and reverse so the closest matches are logged first
    List<ScoredObject<Pair<Tree, Tree>>> ordered = Generics.newArrayList();
    while (bestmatches.size() > 0) {
      ordered.add(bestmatches.poll());
    }
    Collections.reverse(ordered);
    for (ScoredObject<Pair<Tree, Tree>> pair : ordered) {
      log.info(" MATCHED " + pair.object().second.yieldWords() + " ... " + pair.object().second() + " with a score of " + pair.score());
    }
    log.info();
    log.info();
    bestmatches.clear();
  }
}
Use of edu.stanford.nlp.util.ScoredObject in the project CoreNLP by stanfordnlp:
the class CharniakScoredParsesReaderWriter, method stringToParses.
/**
 * Convert a string representing scored parses (in the Charniak parser
 * output format) to a list of scored parse trees.
 *
 * @param parseStr scored parses in Charniak parser output format
 * @return the first list of scored parse trees found in the input, or
 *         null if the input contained no trees
 * @throws RuntimeException wrapping any IOException from the underlying reader
 */
public static List<ScoredObject<Tree>> stringToParses(String parseStr) {
  // try-with-resources closes the reader even if reading/iteration throws
  // (the original only closed it on the success path)
  try (BufferedReader br = new BufferedReader(new StringReader(parseStr))) {
    Iterable<List<ScoredObject<Tree>>> trees = readScoredTrees("", br);
    Iterator<List<ScoredObject<Tree>>> iter = trees.iterator();
    return (iter != null && iter.hasNext()) ? iter.next() : null;
  } catch (IOException ex) {
    throw new RuntimeException(ex);
  }
}
Use of edu.stanford.nlp.util.ScoredObject in the project CoreNLP by stanfordnlp:
the class CharniakScoredParsesReaderWriter, method printScoredTrees.
/**
 * Print scored parse trees in the format used by the Charniak parser.
 *
 * @param trees trees to output
 * @param filename file to output to
 * @throws RuntimeException wrapping any IOException while opening the file
 */
public static void printScoredTrees(Iterable<List<ScoredObject<Tree>>> trees, String filename) {
  // try-with-resources closes (and flushes) the writer even if an exception
  // is thrown partway through (the original leaked it on that path)
  try (PrintWriter pw = IOUtils.getPrintWriter(filename)) {
    int i = 0;
    for (List<ScoredObject<Tree>> treeList : trees) {
      printScoredTrees(pw, i, treeList);
      i++;
    }
  } catch (IOException ex) {
    throw new RuntimeException(ex);
  }
}
Use of edu.stanford.nlp.util.ScoredObject in the project CoreNLP by stanfordnlp:
the class PerceptronModel, method findHighestScoringTransitions.
/**
 * Scores every known transition from the given state and returns the
 * numTransitions highest scoring ones.
 *
 * @param state the current parser state
 * @param features the features extracted for this state
 * @param requireLegal if true, only transitions legal in this state are kept
 * @param numTransitions maximum number of transitions to return
 * @param constraints parser constraints used for the legality check (may be null)
 * @return up to numTransitions scored transition indices (unordered collection)
 */
private Collection<ScoredObject<Integer>> findHighestScoringTransitions(State state, List<String> features, boolean requireLegal, int numTransitions, List<ParserConstraint> constraints) {
  // Accumulate each feature's weight contribution into a per-transition score.
  float[] scores = new float[transitionIndex.size()];
  for (String feature : features) {
    Weight weight = featureWeights.get(feature);
    if (weight != null) {
      // features not in our index are simply skipped
      weight.score(scores);
    }
  }
  // Bounded priority queue: once it holds more than numTransitions entries,
  // poll() evicts the current lowest scoring candidate.
  PriorityQueue<ScoredObject<Integer>> best = new PriorityQueue<>(numTransitions + 1, ScoredComparator.ASCENDING_COMPARATOR);
  for (int transition = 0; transition < scores.length; ++transition) {
    if (requireLegal && !transitionIndex.get(transition).isLegal(state, constraints)) {
      continue;
    }
    best.add(new ScoredObject<>(transition, scores[transition]));
    if (best.size() > numTransitions) {
      best.poll();
    }
  }
  return best;
}
Use of edu.stanford.nlp.util.ScoredObject in the project CoreNLP by stanfordnlp:
the class PerceptronModel, method trainTree.
/**
 * Runs one perceptron training pass over a single tree and collects the
 * weight updates it would make (the returned updates are applied by the
 * caller, not here).
 *
 * Depending on op.trainOptions().trainingMethod this either runs beam
 * search training (BEAM / REORDER_BEAM) or a greedy pass along the gold
 * transition sequence (GOLD / EARLY_TERMINATION / REORDER_ORACLE).
 *
 * @param example the binarized training tree together with its gold
 *                transition sequence
 * @return accumulated updates plus bookkeeping: correct/wrong counts, the
 *         first (predicted, gold) error pair, per-transition-class
 *         counters, and reorder success/failure flags
 */
private TrainingResult trainTree(TrainingExample example) {
int numCorrect = 0;
int numWrong = 0;
// NOTE(review): 'tree' appears unused in this method -- TODO confirm and remove
Tree tree = example.binarizedTree;
List<TrainingUpdate> updates = Generics.newArrayList();
// first (predictedIndex, goldIndex) mistake made on this tree, if any
Pair<Integer, Integer> firstError = null;
IntCounter<Class<? extends Transition>> correctTransitions = new IntCounter<>();
TwoDimensionalIntCounter<Class<? extends Transition>, Class<? extends Transition>> wrongTransitions = new TwoDimensionalIntCounter<>();
// only the reordering training methods need an oracle
ReorderingOracle reorderer = null;
if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_ORACLE || op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM) {
reorderer = new ReorderingOracle(op, rootOnlyStates);
}
// flags (0/1) recording whether the oracle ever reordered successfully,
// or failed to reorder before any success
int reorderSuccess = 0;
int reorderFail = 0;
if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.BEAM || op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM) {
if (op.trainOptions().beamSize <= 0) {
throw new IllegalArgumentException("Illegal beam size " + op.trainOptions().beamSize);
}
// the agenda is a size-bounded queue of candidate states; the lowest
// scoring state is polled off whenever it grows past beamSize
PriorityQueue<State> agenda = new PriorityQueue<>(op.trainOptions().beamSize + 1, ScoredComparator.ASCENDING_COMPARATOR);
State goldState = example.initialStateFromGoldTagTree();
List<Transition> transitions = example.trainTransitions();
agenda.add(goldState);
// one iteration per gold transition: expand every agenda state by its
// top-scoring transitions to build the next agenda
while (transitions.size() > 0) {
Transition goldTransition = transitions.get(0);
Transition highestScoringTransitionFromGoldState = null;
double highestScoreFromGoldState = 0.0;
PriorityQueue<State> newAgenda = new PriorityQueue<>(op.trainOptions().beamSize + 1, ScoredComparator.ASCENDING_COMPARATOR);
State highestScoringState = null;
// keep track of the state in the current agenda which leads
// to the highest score on the next agenda. this will be
// trained down assuming it is not the correct state
State highestCurrentState = null;
for (State currentState : agenda) {
// TODO: can maybe speed this part up, although it doesn't seem like a critical part of the runtime
boolean isGoldState = goldState.areTransitionsEqual(currentState);
List<String> features = featureFactory.featurize(currentState);
Collection<ScoredObject<Integer>> stateTransitions = findHighestScoringTransitions(currentState, features, true, op.trainOptions().beamSize, null);
for (ScoredObject<Integer> transition : stateTransitions) {
State newState = transitionIndex.get(transition.object()).apply(currentState, transition.score());
newAgenda.add(newState);
if (newAgenda.size() > op.trainOptions().beamSize) {
newAgenda.poll();
}
if (highestScoringState == null || highestScoringState.score() < newState.score()) {
highestScoringState = newState;
highestCurrentState = currentState;
}
// remember the best transition taken from the gold state itself
if (isGoldState && (highestScoringTransitionFromGoldState == null || transition.score() > highestScoreFromGoldState)) {
highestScoringTransitionFromGoldState = transitionIndex.get(transition.object());
highestScoreFromGoldState = transition.score();
}
}
}
// stop if no transition was scored from the gold state -- presumably the gold
// state (eg one with ROOT) isn't on the agenda so it stops.
if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM && highestScoringTransitionFromGoldState == null) {
break;
}
// defensive check: nothing was legal / scorable from any agenda state
if (highestScoringState == null) {
System.err.println("Unable to find a best transition!");
System.err.println("Previous agenda:");
for (State state : agenda) {
System.err.println(state);
}
System.err.println("Gold transitions:");
System.err.println(example.transitions);
break;
}
State newGoldState = goldTransition.apply(goldState, 0.0);
// record the first place the prediction from the gold state diverged from gold
if (firstError == null && !highestScoringTransitionFromGoldState.equals(goldTransition)) {
int predictedIndex = transitionIndex.indexOf(highestScoringTransitionFromGoldState);
int goldIndex = transitionIndex.indexOf(goldTransition);
if (predictedIndex < 0) {
throw new AssertionError("Predicted transition not in the index: " + highestScoringTransitionFromGoldState);
}
if (goldIndex < 0) {
throw new AssertionError("Gold transition not in the index: " + goldTransition);
}
firstError = new Pair<>(predictedIndex, goldIndex);
}
// if the best state on the new agenda is not the gold state:
// otherwise, down the last transition, up the correct
if (!newGoldState.areTransitionsEqual(highestScoringState)) {
++numWrong;
wrongTransitions.incrementCount(goldTransition.getClass(), highestScoringTransitionFromGoldState.getClass());
List<String> goldFeatures = featureFactory.featurize(goldState);
int lastTransition = transitionIndex.indexOf(highestScoringState.transitions.peek());
// down-weight the wrongly chosen last transition, up-weight the gold one
updates.add(new TrainingUpdate(featureFactory.featurize(highestCurrentState), -1, lastTransition, learningRate));
updates.add(new TrainingUpdate(goldFeatures, transitionIndex.indexOf(goldTransition), -1, learningRate));
if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.BEAM) {
// If the correct state has fallen off the agenda, break
if (!ShiftReduceUtils.findStateOnAgenda(newAgenda, newGoldState)) {
break;
} else {
transitions.remove(0);
}
} else if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM) {
// gold state fell off the agenda: ask the oracle to rewrite the
// remaining gold transitions to accommodate the predicted one
if (!ShiftReduceUtils.findStateOnAgenda(newAgenda, newGoldState)) {
if (!reorderer.reorder(goldState, highestScoringTransitionFromGoldState, transitions)) {
if (reorderSuccess == 0)
reorderFail = 1;
break;
}
newGoldState = highestScoringTransitionFromGoldState.apply(goldState);
// the reordered gold state must itself survive on the agenda
if (!ShiftReduceUtils.findStateOnAgenda(newAgenda, newGoldState)) {
if (reorderSuccess == 0)
reorderFail = 1;
break;
}
reorderSuccess = 1;
} else {
transitions.remove(0);
}
}
} else {
// the beam's best state matches gold: count it and advance
++numCorrect;
correctTransitions.incrementCount(goldTransition.getClass());
transitions.remove(0);
}
goldState = newGoldState;
agenda = newAgenda;
}
} else if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_ORACLE || op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.EARLY_TERMINATION || op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.GOLD) {
// greedy (non-beam) training: walk the gold transition sequence and
// compare the highest scoring transition to the gold one at each step
State state = example.initialStateFromGoldTagTree();
List<Transition> transitions = example.trainTransitions();
boolean keepGoing = true;
while (transitions.size() > 0 && keepGoing) {
Transition gold = transitions.get(0);
int goldNum = transitionIndex.indexOf(gold);
List<String> features = featureFactory.featurize(state);
int predictedNum = findHighestScoringTransition(state, features, false).object();
Transition predicted = transitionIndex.get(predictedNum);
if (goldNum == predictedNum) {
transitions.remove(0);
state = gold.apply(state);
numCorrect++;
correctTransitions.incrementCount(gold.getClass());
} else {
numWrong++;
wrongTransitions.incrementCount(gold.getClass(), predicted.getClass());
if (firstError == null) {
firstError = new Pair<>(predictedNum, goldNum);
}
// TODO: allow weighted features, weighted training, etc
updates.add(new TrainingUpdate(features, goldNum, predictedNum, learningRate));
// how to proceed after a mistake depends on the training method
switch(op.trainOptions().trainingMethod) {
case EARLY_TERMINATION:
// stop training on this tree at the first mistake
keepGoing = false;
break;
case GOLD:
// continue along the gold sequence regardless of the mistake
transitions.remove(0);
state = gold.apply(state);
break;
case REORDER_ORACLE:
// let the oracle rewrite the remaining gold transitions to
// accommodate the predicted transition, if possible
keepGoing = reorderer.reorder(state, predicted, transitions);
if (keepGoing) {
state = predicted.apply(state);
reorderSuccess = 1;
} else if (reorderSuccess == 0) {
reorderFail = 1;
}
break;
default:
throw new IllegalArgumentException("Unexpected method " + op.trainOptions().trainingMethod);
}
}
}
}
return new TrainingResult(updates, numCorrect, numWrong, firstError, correctTransitions, wrongTransitions, reorderSuccess, reorderFail);
}
Aggregations