use of edu.stanford.nlp.util.ScoredObject in project CoreNLP by stanfordnlp.
Example usage from the class PerceptronModel, in the method trainTree.
/**
 * Runs one training pass over the gold tree at {@code index}, appending any
 * perceptron weight updates produced by mistakes to {@code updates}.
 * The behavior depends on {@code op.trainOptions().trainingMethod}:
 * ORACLE follows the model's predictions and consults the oracle;
 * BEAM / REORDER_BEAM run beam search against the gold transition sequence;
 * REORDER_ORACLE / EARLY_TERMINATION / GOLD train greedily against the gold
 * transition sequence.
 *
 * @param index which tree (and matching transition list) to train on
 * @param binarizedTrees the gold binarized trees
 * @param transitionLists the gold transition sequences, parallel to the trees
 * @param updates output list; each recorded Update pairs the features with
 *                the gold transition index to raise and/or the predicted
 *                transition index to lower
 * @param oracle only consulted for the ORACLE training method
 * @return a pair of (number of correct transitions, number of wrong transitions)
 */
private Pair<Integer, Integer> trainTree(int index, List<Tree> binarizedTrees, List<List<Transition>> transitionLists, List<Update> updates, Oracle oracle) {
  int numCorrect = 0;
  int numWrong = 0;
  Tree tree = binarizedTrees.get(index);
  // The reorderer is only needed by the two REORDER_* training methods.
  ReorderingOracle reorderer = null;
  if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_ORACLE || op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM) {
    reorderer = new ReorderingOracle(op);
  }
  // NOTE(review): the original comment here was truncated upstream
  // ("...it under control."); its full intent is unknown.
  if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.ORACLE) {
    // ORACLE: repeatedly take the model's highest-scoring legal transition,
    // asking the oracle at each state whether that transition was acceptable.
    State state = ShiftReduceParser.initialStateFromGoldTagTree(tree);
    while (!state.isFinished()) {
      List<String> features = featureFactory.featurize(state);
      ScoredObject<Integer> prediction = findHighestScoringTransition(state, features, true);
      if (prediction == null) {
        throw new AssertionError("Did not find a legal transition");
      }
      int predictedNum = prediction.object();
      Transition predicted = transitionIndex.get(predictedNum);
      OracleTransition gold = oracle.goldTransition(index, state);
      if (gold.isCorrect(predicted)) {
        numCorrect++;
        // Even a "correct" prediction gets a positive update for the exact
        // gold transition when the two differ.
        if (gold.transition != null && !gold.transition.equals(predicted)) {
          int transitionNum = transitionIndex.indexOf(gold.transition);
          if (transitionNum < 0) {
            // only possible when the parser has gone off the rails?
            // NOTE(review): this continue skips the state update at the
            // bottom of the loop, so the same state is featurized again on
            // the next iteration -- confirm this cannot loop forever.
            continue;
          }
          updates.add(new Update(features, transitionNum, -1, 1.0f));
        }
      } else {
        numWrong++;
        int transitionNum = -1;
        if (gold.transition != null) {
          transitionNum = transitionIndex.indexOf(gold.transition);
          // TODO: this can theoretically result in a -1 gold
          // transition if the transition exists, but is a
          // CompoundUnaryTransition which only exists because the
          // parser is wrong. Do we want to add those transitions?
        }
        updates.add(new Update(features, transitionNum, predictedNum, 1.0f));
      }
      // Advance along the model's prediction (not the gold transition).
      state = predicted.apply(state);
    }
  } else if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.BEAM || op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM) {
    if (op.trainOptions().beamSize <= 0) {
      throw new IllegalArgumentException("Illegal beam size " + op.trainOptions().beamSize);
    }
    // Work on a copy of the gold transitions; REORDER_BEAM may mutate it.
    List<Transition> transitions = Generics.newLinkedList(transitionLists.get(index));
    PriorityQueue<State> agenda = new PriorityQueue<>(op.trainOptions().beamSize + 1, ScoredComparator.ASCENDING_COMPARATOR);
    State goldState = ShiftReduceParser.initialStateFromGoldTagTree(tree);
    agenda.add(goldState);
    // NOTE(review): transitionCount is never read in this method.
    int transitionCount = 0;
    while (transitions.size() > 0) {
      Transition goldTransition = transitions.get(0);
      Transition highestScoringTransitionFromGoldState = null;
      double highestScoreFromGoldState = 0.0;
      PriorityQueue<State> newAgenda = new PriorityQueue<>(op.trainOptions().beamSize + 1, ScoredComparator.ASCENDING_COMPARATOR);
      State highestScoringState = null;
      State highestCurrentState = null;
      // Expand every state on the agenda by its best-scoring transitions.
      for (State currentState : agenda) {
        boolean isGoldState = (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM && goldState.areTransitionsEqual(currentState));
        List<String> features = featureFactory.featurize(currentState);
        Collection<ScoredObject<Integer>> stateTransitions = findHighestScoringTransitions(currentState, features, true, op.trainOptions().beamSize, null);
        for (ScoredObject<Integer> transition : stateTransitions) {
          State newState = transitionIndex.get(transition.object()).apply(currentState, transition.score());
          newAgenda.add(newState);
          // Keep the agenda at beamSize by evicting the lowest-scoring state
          // (the comparator is ascending, so poll() removes the worst).
          if (newAgenda.size() > op.trainOptions().beamSize) {
            newAgenda.poll();
          }
          if (highestScoringState == null || highestScoringState.score() < newState.score()) {
            highestScoringState = newState;
            highestCurrentState = currentState;
          }
          // For REORDER_BEAM, remember the best transition proposed from the
          // gold state; it may be used to reorder the remaining gold sequence.
          if (isGoldState && (highestScoringTransitionFromGoldState == null || transition.score() > highestScoreFromGoldState)) {
            highestScoringTransitionFromGoldState = transitionIndex.get(transition.object());
            highestScoreFromGoldState = transition.score();
          }
        }
      }
      // NOTE(review): the original comment here was truncated upstream; the
      // surviving fragment suggests this handles the case where a finished
      // state (eg one with ROOT) isn't on the agenda so it stops.
      if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM && highestScoringTransitionFromGoldState == null) {
        break;
      }
      State newGoldState = goldTransition.apply(goldState, 0.0);
      // On a mistake, record two updates: lower the beam's last (wrong)
      // transition and raise the correct gold transition.
      if (!newGoldState.areTransitionsEqual(highestScoringState)) {
        ++numWrong;
        List<String> goldFeatures = featureFactory.featurize(goldState);
        int lastTransition = transitionIndex.indexOf(highestScoringState.transitions.peek());
        updates.add(new Update(featureFactory.featurize(highestCurrentState), -1, lastTransition, 1.0f));
        updates.add(new Update(goldFeatures, transitionIndex.indexOf(goldTransition), -1, 1.0f));
        if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.BEAM) {
          // If the correct state has fallen off the agenda, break
          if (!ShiftReduceUtils.findStateOnAgenda(newAgenda, newGoldState)) {
            break;
          } else {
            transitions.remove(0);
          }
        } else if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM) {
          if (!ShiftReduceUtils.findStateOnAgenda(newAgenda, newGoldState)) {
            // Try to reorder the remaining gold transitions so training can
            // continue from the best transition out of the gold state.
            if (!reorderer.reorder(goldState, highestScoringTransitionFromGoldState, transitions)) {
              break;
            }
            newGoldState = highestScoringTransitionFromGoldState.apply(goldState);
            if (!ShiftReduceUtils.findStateOnAgenda(newAgenda, newGoldState)) {
              break;
            }
          } else {
            transitions.remove(0);
          }
        }
      } else {
        ++numCorrect;
        transitions.remove(0);
      }
      goldState = newGoldState;
      agenda = newAgenda;
    }
  } else if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_ORACLE || op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.EARLY_TERMINATION || op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.GOLD) {
    // Greedy training against the gold transition sequence.
    State state = ShiftReduceParser.initialStateFromGoldTagTree(tree);
    List<Transition> transitions = transitionLists.get(index);
    // Copy so the three recovery strategies below may consume/mutate it.
    transitions = Generics.newLinkedList(transitions);
    boolean keepGoing = true;
    while (transitions.size() > 0 && keepGoing) {
      Transition transition = transitions.get(0);
      int transitionNum = transitionIndex.indexOf(transition);
      List<String> features = featureFactory.featurize(state);
      int predictedNum = findHighestScoringTransition(state, features, false).object();
      Transition predicted = transitionIndex.get(predictedNum);
      if (transitionNum == predictedNum) {
        transitions.remove(0);
        state = transition.apply(state);
        numCorrect++;
      } else {
        numWrong++;
        // TODO: allow weighted features, weighted training, etc
        updates.add(new Update(features, transitionNum, predictedNum, 1.0f));
        // The three methods differ only in how they recover from a mistake.
        switch(op.trainOptions().trainingMethod) {
          case EARLY_TERMINATION:
            // Stop training on this tree at the first mistake.
            keepGoing = false;
            break;
          case GOLD:
            // Ignore the mistake and keep following the gold sequence.
            transitions.remove(0);
            state = transition.apply(state);
            break;
          case REORDER_ORACLE:
            // Follow the predicted transition if the remaining gold
            // transitions can be reordered to still reach the gold tree.
            keepGoing = reorderer.reorder(state, predicted, transitions);
            if (keepGoing) {
              state = predicted.apply(state);
            }
            break;
          default:
            throw new IllegalArgumentException("Unexpected method " + op.trainOptions().trainingMethod);
        }
      }
    }
  }
  return Pair.makePair(numCorrect, numWrong);
}
use of edu.stanford.nlp.util.ScoredObject in project CoreNLP by stanfordnlp.
Example usage from the class LexicalizedParserQuery, in the method getKBestParses.
/**
 * Return the k best parses of the sentence most recently parsed.
 *
 * NB: The dependency parser does not implement a k-best method
 * and the factored parser's method seems to be broken and therefore
 * this method always returns a list of size 1 if either of these
 * two parsers was used.
 *
 * @param k The maximum number of parses to return (honored only on the
 *          PCFG-parser path)
 * @return A list of scored trees, or null if the sentence was skipped
 * @throws NoSuchParseException If no previously successfully parsed
 *                              sentence
 */
@Override
public List<ScoredObject<Tree>> getKBestParses(int k) {
  // Nothing to return if parsing was skipped for this sentence.
  if (parseSkipped) {
    return null;
  }
  if (bparser != null && parseSucceeded) {
    // The getKGoodParses seems to be broken, so just return the best parse
    Tree binaryTree = bparser.getBestParse();
    Tree tree = debinarizer.transformTree(binaryTree);
    if (op.nodePrune) {
      NodePruner np = new NodePruner(pparser, debinarizer);
      tree = np.prune(tree);
    }
    tree = subcategoryStripper.transformTree(tree);
    restoreOriginalWords(tree);
    // NOTE(review): the score comes from dparser even though the tree came
    // from bparser -- looks suspicious; confirm this is intended.
    double score = dparser.getBestScore();
    ScoredObject<Tree> so = new ScoredObject<>(tree, score);
    List<ScoredObject<Tree>> trees = new ArrayList<>(1);
    trees.add(so);
    return trees;
  } else if (pparser != null && pparser.hasParse() && fallbackToPCFG) {
    // The PCFG parser supports genuine k-best extraction.
    return this.getKBestPCFGParses(k);
  } else if (dparser != null && dparser.hasParse()) {
    // && fallbackToDG
    // The dependency parser doesn't support k-best parse extraction, so just
    // return the best parse
    Tree tree = this.getBestDependencyParse(true);
    double score = dparser.getBestScore();
    ScoredObject<Tree> so = new ScoredObject<>(tree, score);
    List<ScoredObject<Tree>> trees = new ArrayList<>(1);
    trees.add(so);
    return trees;
  } else {
    throw new NoSuchParseException();
  }
}
use of edu.stanford.nlp.util.ScoredObject in project CoreNLP by stanfordnlp.
Example usage from the class EvaluateTreebank, in the method processResults.
/**
 * Prints the parse of one sentence and runs every configured evaluation of
 * it against the gold tree.
 *
 * @param pq a ParserQuery that has already parsed the sentence
 * @param goldTree the gold standard tree for the same sentence
 * @param pwErr writer for diagnostics and per-evaluation output
 * @param pwOut writer for the parses themselves (typically System.out)
 * @param pwFileOut optional writer for 1-best output to a file (may be null)
 * @param pwStats writer for the k-best entropy statistics
 * @param treePrint formatter used for printed trees
 */
public void processResults(ParserQuery pq, Tree goldTree, PrintWriter pwErr, PrintWriter pwOut, PrintWriter pwFileOut, PrintWriter pwStats, TreePrint treePrint) {
  if (pq.saidMemMessage()) {
    saidMemMessage = true;
  }
  Tree tree;
  List<? extends HasWord> sentence = pq.originalSentence();
  try {
    tree = pq.getBestParse();
  } catch (NoSuchParseException e) {
    // No parse for this sentence: leave tree null; printing and the
    // evaluation sections below both handle the null case.
    tree = null;
  }
  // Fetch the PCFG k-best list once up front if any consumer will need it.
  List<ScoredObject<Tree>> kbestPCFGTrees = null;
  if (tree != null && kbestPCFG > 0) {
    kbestPCFGTrees = pq.getKBestPCFGParses(kbestPCFG);
  }
  // combo parse goes to pwOut (System.out)
  if (op.testOptions.verbose) {
    pwOut.println("ComboParser best");
    Tree ot = tree;
    // In verbose mode, make sure the printed tree is rooted at the start symbol.
    if (ot != null && !op.tlpParams.treebankLanguagePack().isStartSymbol(ot.value())) {
      ot = ot.treeFactory().newTreeNode(op.tlpParams.treebankLanguagePack().startSymbol(), Collections.singletonList(ot));
    }
    treePrint.printTree(ot, pwOut);
  } else {
    treePrint.printTree(tree, pwOut);
  }
  // print various statistics
  if (tree != null) {
    if (op.testOptions.printAllBestParses) {
      // Print every parse tied for the best PCFG score.
      List<ScoredObject<Tree>> parses = pq.getBestPCFGParses();
      int sz = parses.size();
      if (sz > 1) {
        pwOut.println("There were " + sz + " best PCFG parses with score " + parses.get(0).score() + '.');
        Tree transGoldTree = collinizer.transformTree(goldTree);
        int iii = 0;
        for (ScoredObject<Tree> sot : parses) {
          iii++;
          Tree tb = sot.object();
          Tree tbd = debinarizer.transformTree(tb);
          tbd = subcategoryStripper.transformTree(tbd);
          pq.restoreOriginalWords(tbd);
          pwOut.println("PCFG Parse #" + iii + " with score " + tbd.score());
          tbd.pennPrint(pwOut);
          Tree tbtr = collinizer.transformTree(tbd);
          // pwOut.println("Tree size = " + tbtr.size() + "; depth = " + tbtr.depth());
          kGoodLB.evaluate(tbtr, transGoldTree, pwErr);
        }
      }
    } else if (op.testOptions.printPCFGkBest > 0 && op.testOptions.outputkBestEquivocation == null) {
      // Huang and Chiang (2006) Algorithm 3 output from the PCFG parser
      List<ScoredObject<Tree>> trees = kbestPCFGTrees.subList(0, op.testOptions.printPCFGkBest);
      Tree transGoldTree = collinizer.transformTree(goldTree);
      int i = 0;
      for (ScoredObject<Tree> tp : trees) {
        i++;
        pwOut.println("PCFG Parse #" + i + " with score " + tp.score());
        Tree tbd = tp.object();
        tbd.pennPrint(pwOut);
        Tree tbtr = collinizer.transformTree(tbd);
        kGoodLB.evaluate(tbtr, transGoldTree, pwErr);
      }
    } else if (op.testOptions.printFactoredKGood > 0 && pq.hasFactoredParse()) {
      // Chart parser (factored) n-best list
      // DZ: debug n best trees
      List<ScoredObject<Tree>> trees = pq.getKGoodFactoredParses(op.testOptions.printFactoredKGood);
      Tree transGoldTree = collinizer.transformTree(goldTree);
      int ii = 0;
      for (ScoredObject<Tree> tp : trees) {
        ii++;
        pwOut.println("Factored Parse #" + ii + " with score " + tp.score());
        Tree tbd = tp.object();
        tbd.pennPrint(pwOut);
        Tree tbtr = collinizer.transformTree(tbd);
        kGoodLB.evaluate(tbtr, transGoldTree, pwOut);
      }
    } else if (pwFileOut != null) {
      // 1-best output
      pwFileOut.println(tree.toString());
    }
    // Print the derivational entropy
    if (op.testOptions.outputkBestEquivocation != null && op.testOptions.printPCFGkBest > 0) {
      List<ScoredObject<Tree>> trees = kbestPCFGTrees.subList(0, op.testOptions.printPCFGkBest);
      double[] logScores = new double[trees.size()];
      int treeId = 0;
      for (ScoredObject<Tree> kBestTree : trees) logScores[treeId++] = kBestTree.score();
      // Re-normalize
      double entropy = 0.0;
      double denom = ArrayMath.logSum(logScores);
      for (double logScore : logScores) {
        double logPr = logScore - denom;
        entropy += Math.exp(logPr) * logPr * LN_TO_LOG2;
      }
      // Convert to bits
      entropy *= -1;
      pwStats.printf("%f\t%d\t%d\n", entropy, trees.size(), sentence.size());
    }
  }
  // Perform various evaluations specified by the user
  if (tree != null) {
    // Strip subcategories and remove punctuation for evaluation
    tree = subcategoryStripper.transformTree(tree);
    Tree treeFact = collinizer.transformTree(tree);
    // Setup the gold tree
    if (op.testOptions.verbose) {
      pwOut.println("Correct parse");
      treePrint.printTree(goldTree, pwOut);
    }
    Tree transGoldTree = collinizer.transformTree(goldTree);
    if (transGoldTree != null)
      transGoldTree = subcategoryStripper.transformTree(transGoldTree);
    // Can't do evaluation in these two cases
    if (transGoldTree == null) {
      pwErr.println("Couldn't transform gold tree for evaluation, skipping eval. Gold tree was:");
      goldTree.pennPrint(pwErr);
      numSkippedEvals++;
      return;
    } else if (treeFact == null) {
      pwErr.println("Couldn't transform hypothesis tree for evaluation, skipping eval. Tree was:");
      tree.pennPrint(pwErr);
      numSkippedEvals++;
      return;
    } else if (treeFact.yield().size() != transGoldTree.yield().size()) {
      // A gold/parsed yield mismatch also makes bracket scoring meaningless:
      // warn, count the skip, and bail out.
      List<Label> fYield = treeFact.yield();
      List<Label> gYield = transGoldTree.yield();
      pwErr.println("WARNING: Evaluation could not be performed due to gold/parsed yield mismatch.");
      pwErr.printf(" sizes: gold: %d (transf) %d (orig); parsed: %d (transf) %d (orig).%n", gYield.size(), goldTree.yield().size(), fYield.size(), tree.yield().size());
      pwErr.println(" gold: " + SentenceUtils.listToString(gYield, true));
      pwErr.println(" pars: " + SentenceUtils.listToString(fYield, true));
      numSkippedEvals++;
      return;
    }
    if (topKEvals.size() > 0) {
      // NOTE(review): kbestPCFGTrees is only populated when kbestPCFG > 0;
      // presumably topKEvals is empty otherwise -- confirm against the
      // class's configuration code.
      List<Tree> transGuesses = new ArrayList<>();
      int kbest = Math.min(op.testOptions.evalPCFGkBest, kbestPCFGTrees.size());
      for (ScoredObject<Tree> guess : kbestPCFGTrees.subList(0, kbest)) {
        transGuesses.add(collinizer.transformTree(guess.object()));
      }
      for (BestOfTopKEval eval : topKEvals) {
        eval.evaluate(transGuesses, transGoldTree, pwErr);
      }
    }
    // PCFG eval
    Tree treePCFG = pq.getBestPCFGParse();
    if (treePCFG != null) {
      Tree treePCFGeval = collinizer.transformTree(treePCFG);
      if (pcfgLB != null) {
        pcfgLB.evaluate(treePCFGeval, transGoldTree, pwErr);
      }
      if (pcfgChildSpecific != null) {
        pcfgChildSpecific.evaluate(treePCFGeval, transGoldTree, pwErr);
      }
      if (pcfgLA != null) {
        pcfgLA.evaluate(treePCFGeval, transGoldTree, pwErr);
      }
      if (pcfgCB != null) {
        pcfgCB.evaluate(treePCFGeval, transGoldTree, pwErr);
      }
      if (pcfgDA != null) {
        // Re-index the leaves after Collinization, stripping traces, etc.
        treePCFGeval.indexLeaves(true);
        transGoldTree.indexLeaves(true);
        pcfgDA.evaluate(treePCFGeval, transGoldTree, pwErr);
      }
      if (pcfgTA != null) {
        pcfgTA.evaluate(treePCFGeval, transGoldTree, pwErr);
      }
      if (pcfgLL != null && pq.getPCFGParser() != null) {
        pcfgLL.recordScore(pq.getPCFGParser(), pwErr);
      }
      if (pcfgRUO != null) {
        pcfgRUO.evaluate(treePCFGeval, transGoldTree, pwErr);
      }
      if (pcfgCUO != null) {
        pcfgCUO.evaluate(treePCFGeval, transGoldTree, pwErr);
      }
      if (pcfgCatE != null) {
        pcfgCatE.evaluate(treePCFGeval, transGoldTree, pwErr);
      }
    }
    // Dependency eval
    // todo: is treeDep really useful here, or should we really use depDAEval tree (debinarized) throughout? We use it for parse, and it sure seems like we could use it for tag eval, but maybe not factDA?
    Tree treeDep = pq.getBestDependencyParse(false);
    if (treeDep != null) {
      Tree goldTreeB = binarizerOnly.transformTree(goldTree);
      // Deep-copy the gold tree before indexing/percolating so the original
      // is left untouched for the other evaluations.
      Tree goldTreeEval = goldTree.deepCopy();
      goldTreeEval.indexLeaves(true);
      goldTreeEval.percolateHeads(op.langpack().headFinder());
      Tree depDAEval = pq.getBestDependencyParse(true);
      depDAEval.indexLeaves(true);
      depDAEval.percolateHeadIndices();
      if (depDA != null) {
        depDA.evaluate(depDAEval, goldTreeEval, pwErr);
      }
      if (depTA != null) {
        Tree undoneTree = debinarizer.transformTree(treeDep);
        undoneTree = subcategoryStripper.transformTree(undoneTree);
        pq.restoreOriginalWords(undoneTree);
        // pwErr.println("subcategoryStripped tree: " + undoneTree.toStructureDebugString());
        depTA.evaluate(undoneTree, goldTree, pwErr);
      }
      if (depLL != null && pq.getDependencyParser() != null) {
        depLL.recordScore(pq.getDependencyParser(), pwErr);
      }
      // Fall back to the dependency parse when no factored parse exists.
      Tree factTreeB;
      if (pq.hasFactoredParse()) {
        factTreeB = pq.getBestFactoredParse();
      } else {
        factTreeB = treeDep;
      }
      if (factDA != null) {
        factDA.evaluate(factTreeB, goldTreeB, pwErr);
      }
    }
    // Factored parser (1best) eval
    if (factLB != null) {
      factLB.evaluate(treeFact, transGoldTree, pwErr);
    }
    if (factChildSpecific != null) {
      factChildSpecific.evaluate(treeFact, transGoldTree, pwErr);
    }
    if (factLA != null) {
      factLA.evaluate(treeFact, transGoldTree, pwErr);
    }
    if (factTA != null) {
      factTA.evaluate(tree, boundaryRemover.transformTree(goldTree), pwErr);
    }
    if (factLL != null && pq.getFactoredParser() != null) {
      factLL.recordScore(pq.getFactoredParser(), pwErr);
    }
    if (factCB != null) {
      factCB.evaluate(treeFact, transGoldTree, pwErr);
    }
    for (Eval eval : evals) {
      eval.evaluate(treeFact, transGoldTree, pwErr);
    }
    for (ParserQueryEval eval : parserQueryEvals) {
      eval.evaluate(pq, transGoldTree, pwErr);
    }
    if (op.testOptions.evalb) {
      // empty out scores just in case
      nanScores(tree);
      EvalbFormatWriter.writeEVALBline(treeFact, transGoldTree);
    }
  }
  pwErr.println();
}
use of edu.stanford.nlp.util.ScoredObject in project CoreNLP by stanfordnlp.
Example usage from the class RerankingParserQuery, in the method getBestPCFGParses.
/**
 * Returns the leading run of reranked trees that are tied for the best
 * (first) score. Always returns at least one tree.
 *
 * @return the scored trees whose score equals the top tree's score
 * @throws AssertionError if no reranked trees are available
 */
@Override
public List<ScoredObject<Tree>> getBestPCFGParses() {
  if (scoredTrees == null || scoredTrees.isEmpty()) {
    throw new AssertionError();
  }
  List<ScoredObject<Tree>> equalTrees = Generics.newArrayList();
  double score = scoredTrees.get(0).score();
  int treePos = 0;
  // BUG FIX: the original loop never incremented treePos, so it re-added
  // scoredTrees.get(0) forever whenever the list was non-empty. Advance the
  // position after each tied tree is collected.
  while (treePos < scoredTrees.size() && scoredTrees.get(treePos).score() == score) {
    equalTrees.add(scoredTrees.get(treePos));
    treePos++;
  }
  return equalTrees;
}
use of edu.stanford.nlp.util.ScoredObject in project CoreNLP by stanfordnlp.
Example usage from the class DVParser, in the method getTopParsesForOneTree.
/**
 * Reparses the yield of {@code tree} with the given parser and returns its
 * top {@code dvKBest} PCFG parses, each optionally passed through
 * {@code transformer}.
 *
 * @param parser the LexicalizedParser used to reparse the sentence
 * @param dvKBest how many k-best parses to request
 * @param tree the source tree whose words (minus the trailing boundary
 *             symbol) are reparsed
 * @param transformer applied to each parse if non-null
 * @return the (possibly transformed) k-best parses, or null if the tree has
 *         no real words or the parse attempt fails
 */
public static List<Tree> getTopParsesForOneTree(LexicalizedParser parser, int dvKBest, Tree tree, TreeTransformer transformer) {
  ParserQuery pq = parser.parserQuery();
  List<Word> words = tree.yieldWords();
  // A yield of size <= 1 holds nothing but the boundary symbol.
  if (words.size() <= 1) {
    return null;
  }
  // Drop the trailing boundary symbol before reparsing.
  List<Word> sentence = words.subList(0, words.size() - 1);
  if (!pq.parse(sentence)) {
    log.info("Failed to use the given parser to reparse sentence \"" + sentence + "\"");
    return null;
  }
  List<Tree> results = new ArrayList<>();
  for (ScoredObject<Tree> scored : pq.getKBestPCFGParses(dvKBest)) {
    Tree parse = scored.object();
    results.add(transformer == null ? parse : transformer.transformTree(parse));
  }
  return results;
}
Aggregations