Use of edu.stanford.nlp.stats.IntCounter in the CoreNLP project (stanfordnlp).
Class PerceptronModel, method outputStats.
/**
 * Logs summary statistics about the model and about one training iteration.
 * <br>
 * Model facts: number of known features, number of non-zero weights, the
 * largest absolute weight value, the combined length of all feature strings,
 * and the number of transitions. Iteration facts: how many transitions were
 * right or wrong, the most common first errors, the reorderer results, and
 * the per-transition statistics.
 *
 * @param result the outcome of the training iteration being reported
 */
public void outputStats(TrainingResult result) {
  log.info("While training, got " + result.numCorrect + " transitions correct and " + result.numWrong + " transitions wrong");
  log.info("Number of known features: " + featureWeights.size());
  log.info("Number of non-zero weights: " + numWeights());
  log.info("Weight values maxAbs: " + maxAbs());

  // Sum of the lengths of every feature name currently in the model.
  int totalFeatureChars = featureWeights.keySet().stream().mapToInt(String::length).sum();
  log.info("Total word length: " + totalFeatureChars);
  log.info("Number of transitions: " + transitionIndex.size());

  // Tally how often each (predicted index, gold index) first-error pair occurred.
  IntCounter<Pair<Integer, Integer>> errorTally = new IntCounter<>();
  for (Pair<Integer, Integer> errorPair : result.firstErrors) {
    errorTally.incrementCount(errorPair);
  }

  outputFirstErrors(errorTally);
  outputReordererStats(result.reorderSuccess, result.reorderFail);
  outputTransitionStats(result);
}
Use of edu.stanford.nlp.stats.IntCounter in the CoreNLP project (stanfordnlp).
Class PerceptronModel, method trainTree.
/**
 * Runs one training pass over a single example and collects the perceptron
 * updates it implies.
 * <br>
 * The strategy is chosen by {@code op.trainOptions().trainingMethod}:
 * <ul>
 * <li>BEAM / REORDER_BEAM: beam search of width {@code beamSize}, starting
 * from the state built from the example's gold tags. BEAM stops as soon as
 * the gold state falls off the beam. REORDER_BEAM instead asks the
 * {@link ReorderingOracle} to reorder the remaining gold transitions so that
 * the model's preferred transition from the gold state is accepted.</li>
 * <li>REORDER_ORACLE / EARLY_TERMINATION / GOLD: greedy decoding that follows
 * the single best transition. On a mistake, EARLY_TERMINATION stops, GOLD
 * keeps following the gold transition, and REORDER_ORACLE tries to reorder
 * the remaining gold transitions to follow the prediction.</li>
 * </ul>
 * Any other training method matches neither branch and yields a result with
 * no updates and zero counts.
 * <br>
 * Updates are only queued in the returned result; this method itself does
 * not apply them.
 *
 * @param example the example to train on; supplies the initial state built
 *     from its gold tags, its gold transition sequence, and its binarized tree
 * @return the queued updates, the numbers of correct and wrong transitions,
 *     the first (predicted index, gold index) error or {@code null} if there
 *     was none, per-class correct/wrong transition counts, and the reorder
 *     success/fail flags
 * @throws IllegalArgumentException if a beam method is requested with a
 *     non-positive beam size
 * @throws AssertionError if, in a beam method, the predicted or gold
 *     transition is missing from {@code transitionIndex}
 */
private TrainingResult trainTree(TrainingExample example) {
int numCorrect = 0;
int numWrong = 0;
// NOTE(review): `tree` is not read anywhere in this method.
Tree tree = example.binarizedTree;
// Perceptron updates queued for the caller.
List<TrainingUpdate> updates = Generics.newArrayList();
// (predicted transition index, gold transition index) of the first mistake; null if none was made.
Pair<Integer, Integer> firstError = null;
// Per-transition-class tallies: correct predictions, and (gold class, predicted class) confusions.
IntCounter<Class<? extends Transition>> correctTransitions = new IntCounter<>();
TwoDimensionalIntCounter<Class<? extends Transition>, Class<? extends Transition>> wrongTransitions = new TwoDimensionalIntCounter<>();
// The reordering oracle is only created for the two reorder-based methods.
ReorderingOracle reorderer = null;
if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_ORACLE || op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM) {
reorderer = new ReorderingOracle(op, rootOnlyStates);
}
// 0/1 flags. reorderSuccess becomes 1 once any reorder succeeds.
// reorderFail becomes 1 only when a reorder fails before any success.
// Such a failure always ends training on this example, so at most one flag ends up set.
int reorderSuccess = 0;
int reorderFail = 0;
if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.BEAM || op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM) {
if (op.trainOptions().beamSize <= 0) {
throw new IllegalArgumentException("Illegal beam size " + op.trainOptions().beamSize);
}
// The current beam. Its capacity is beamSize + 1 because each candidate is added before the beam is trimmed.
// Assuming ASCENDING_COMPARATOR orders by increasing score, poll() evicts the lowest-scoring state.
PriorityQueue<State> agenda = new PriorityQueue<>(op.trainOptions().beamSize + 1, ScoredComparator.ASCENDING_COMPARATOR);
// The state reached by following the gold transitions.
State goldState = example.initialStateFromGoldTagTree();
// Remaining gold transitions; consumed from the front as training advances.
List<Transition> transitions = example.trainTransitions();
agenda.add(goldState);
while (transitions.size() > 0) {
Transition goldTransition = transitions.get(0);
// Best-scoring transition taken from the gold state (stays null if the gold state produced no candidates).
Transition highestScoringTransitionFromGoldState = null;
double highestScoreFromGoldState = 0.0;
PriorityQueue<State> newAgenda = new PriorityQueue<>(op.trainOptions().beamSize + 1, ScoredComparator.ASCENDING_COMPARATOR);
// Best-scoring state produced anywhere on the next beam.
State highestScoringState = null;
// keep track of the state in the current agenda which leads
// to the highest score on the next agenda. this will be
// trained down assuming it is not the correct state
State highestCurrentState = null;
for (State currentState : agenda) {
// TODO: can maybe speed this part up, although it doesn't seem like a critical part of the runtime
boolean isGoldState = goldState.areTransitionsEqual(currentState);
List<String> features = featureFactory.featurize(currentState);
// Candidate successors of this state. beamSize is passed through, presumably as a cap on how many are returned.
Collection<ScoredObject<Integer>> stateTransitions = findHighestScoringTransitions(currentState, features, true, op.trainOptions().beamSize, null);
for (ScoredObject<Integer> transition : stateTransitions) {
State newState = transitionIndex.get(transition.object()).apply(currentState, transition.score());
newAgenda.add(newState);
// Trim the next beam back to at most beamSize states.
if (newAgenda.size() > op.trainOptions().beamSize) {
newAgenda.poll();
}
if (highestScoringState == null || highestScoringState.score() < newState.score()) {
highestScoringState = newState;
highestCurrentState = currentState;
}
if (isGoldState && (highestScoringTransitionFromGoldState == null || transition.score() > highestScoreFromGoldState)) {
highestScoringTransitionFromGoldState = transitionIndex.get(transition.object());
highestScoreFromGoldState = transition.score();
}
}
}
// In REORDER_BEAM the gold state may have produced no candidate transition at all, so there is nothing to compare against; stop.
// Plain BEAM does not need this check: it already stops (findStateOnAgenda check below) once the correct
// state (eg one with ROOT) isn't on the agenda so it stops.
if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM && highestScoringTransitionFromGoldState == null) {
break;
}
// No state on the beam produced any successor: dump diagnostics and give up on this example.
if (highestScoringState == null) {
System.err.println("Unable to find a best transition!");
System.err.println("Previous agenda:");
for (State state : agenda) {
System.err.println(state);
}
System.err.println("Gold transitions:");
System.err.println(example.transitions);
break;
}
// Successor of the gold state under the gold transition (scored 0.0).
State newGoldState = goldTransition.apply(goldState, 0.0);
// Record the first time the best transition from the gold state differs from the gold transition.
// NOTE(review): in plain BEAM mode, if the gold state produced no candidates,
// highestScoringTransitionFromGoldState is null here and this (and the
// getClass() call below) would throw NullPointerException. Presumably this cannot happen; confirm.
if (firstError == null && !highestScoringTransitionFromGoldState.equals(goldTransition)) {
int predictedIndex = transitionIndex.indexOf(highestScoringTransitionFromGoldState);
int goldIndex = transitionIndex.indexOf(goldTransition);
if (predictedIndex < 0) {
throw new AssertionError("Predicted transition not in the index: " + highestScoringTransitionFromGoldState);
}
if (goldIndex < 0) {
throw new AssertionError("Gold transition not in the index: " + goldTransition);
}
firstError = new Pair<>(predictedIndex, goldIndex);
}
// otherwise, down the last transition, up the correct
// i.e. if the best state on the new beam is not the gold successor:
// - penalize the transition that produced it, using the features of highestCurrentState;
// - reward the gold transition, using the features of the gold state.
// In TrainingUpdate the 2nd argument is the gold index and the 3rd the predicted index (compare the
// greedy branch below). -1 presumably means "no transition" on that side.
if (!newGoldState.areTransitionsEqual(highestScoringState)) {
++numWrong;
wrongTransitions.incrementCount(goldTransition.getClass(), highestScoringTransitionFromGoldState.getClass());
List<String> goldFeatures = featureFactory.featurize(goldState);
int lastTransition = transitionIndex.indexOf(highestScoringState.transitions.peek());
updates.add(new TrainingUpdate(featureFactory.featurize(highestCurrentState), -1, lastTransition, learningRate));
updates.add(new TrainingUpdate(goldFeatures, transitionIndex.indexOf(goldTransition), -1, learningRate));
if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.BEAM) {
// If the correct state has fallen off the agenda, break
// (early update); otherwise keep following the gold path.
if (!ShiftReduceUtils.findStateOnAgenda(newAgenda, newGoldState)) {
break;
} else {
transitions.remove(0);
}
} else if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM) {
// The gold successor fell off the beam: try to make the model's choice from the gold state
// the new gold step by reordering the remaining transitions. The list is passed in and,
// presumably, modified in place. Stop if the reorder fails or the resulting state is not
// on the beam either.
if (!ShiftReduceUtils.findStateOnAgenda(newAgenda, newGoldState)) {
if (!reorderer.reorder(goldState, highestScoringTransitionFromGoldState, transitions)) {
if (reorderSuccess == 0)
reorderFail = 1;
break;
}
newGoldState = highestScoringTransitionFromGoldState.apply(goldState);
if (!ShiftReduceUtils.findStateOnAgenda(newAgenda, newGoldState)) {
if (reorderSuccess == 0)
reorderFail = 1;
break;
}
reorderSuccess = 1;
} else {
// Gold successor is still on the beam: consume the gold transition.
transitions.remove(0);
}
}
} else {
// The best state on the new beam is the gold successor.
++numCorrect;
correctTransitions.incrementCount(goldTransition.getClass());
transitions.remove(0);
}
// Advance both the gold path and the beam.
goldState = newGoldState;
agenda = newAgenda;
}
} else if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_ORACLE || op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.EARLY_TERMINATION || op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.GOLD) {
// Greedy training: compare the model's single best transition against the next gold transition.
State state = example.initialStateFromGoldTagTree();
List<Transition> transitions = example.trainTransitions();
boolean keepGoing = true;
while (transitions.size() > 0 && keepGoing) {
Transition gold = transitions.get(0);
int goldNum = transitionIndex.indexOf(gold);
List<String> features = featureFactory.featurize(state);
int predictedNum = findHighestScoringTransition(state, features, false).object();
Transition predicted = transitionIndex.get(predictedNum);
if (goldNum == predictedNum) {
transitions.remove(0);
state = gold.apply(state);
numCorrect++;
correctTransitions.incrementCount(gold.getClass());
} else {
numWrong++;
wrongTransitions.incrementCount(gold.getClass(), predicted.getClass());
if (firstError == null) {
firstError = new Pair<>(predictedNum, goldNum);
}
// TODO: allow weighted features, weighted training, etc
// One update carrying both the gold (goldNum) and predicted (predictedNum) transition indices.
updates.add(new TrainingUpdate(features, goldNum, predictedNum, learningRate));
switch(op.trainOptions().trainingMethod) {
case EARLY_TERMINATION:
// Stop training on this example at the first mistake.
keepGoing = false;
break;
case GOLD:
// Ignore the mistake and continue along the gold path.
transitions.remove(0);
state = gold.apply(state);
break;
case REORDER_ORACLE:
// Follow the prediction if the oracle can reorder the remaining gold transitions around it; otherwise stop.
keepGoing = reorderer.reorder(state, predicted, transitions);
if (keepGoing) {
state = predicted.apply(state);
reorderSuccess = 1;
} else if (reorderSuccess == 0) {
reorderFail = 1;
}
break;
default:
// Unreachable: the enclosing condition admits only the three methods above.
throw new IllegalArgumentException("Unexpected method " + op.trainOptions().trainingMethod);
}
}
}
}
return new TrainingResult(updates, numCorrect, numWrong, firstError, correctTransitions, wrongTransitions, reorderSuccess, reorderFail);
}
Use of edu.stanford.nlp.stats.IntCounter in the CoreNLP project (stanfordnlp).
Class PerceptronModel, method outputFirstErrors.
/**
 * Output the (up to) 9 most common first transition errors made by the model during training.
 * <br>
 * Works on a copy, so the caller's counter is left unchanged.
 *
 * @param firstErrors counts of (predicted transition index, gold transition index) pairs;
 *     {@code null} or empty means there is nothing to report
 */
private void outputFirstErrors(IntCounter<Pair<Integer, Integer>> firstErrors) {
  if (firstErrors == null || firstErrors.size() == 0) {
    return;
  }
  final int maxErrorsToReport = 9;
  IntCounter<Pair<Integer, Integer>> firstErrorCopy = new IntCounter<>(firstErrors);
  log.info("Most common transition errors: gold -> predicted");
  for (int i = 0; i < maxErrorsToReport && firstErrorCopy.size() > 0; ++i) {
    Pair<Integer, Integer> mostCommon = firstErrorCopy.argmax();
    int count = firstErrorCopy.max();
    // decrementCount leaves exhausted keys in the counter with a count of 0,
    // so size() never shrinks; stop once only exhausted entries remain.
    if (count <= 0) {
      break;
    }
    // Zero this entry so the next argmax() finds the next most common error.
    firstErrorCopy.decrementCount(mostCommon, count);
    Transition predicted = transitionIndex.get(mostCommon.first());
    Transition gold = transitionIndex.get(mostCommon.second());
    // Report this entry's own count (saved before decrementing). The previous code
    // reported firstErrorCopy.max() after the decrement, i.e. the next entry's count.
    log.info(" # " + (i + 1) + ": " + gold + " -> " + predicted + " happened " + count + " times");
  }
}
Use of edu.stanford.nlp.stats.IntCounter in the CoreNLP project (stanfordnlp).
Class ChineseWordFeatureExtractor, method applyFeatureCountThreshold.
/**
 * Counts every feature produced by {@code makeFeatures} over {@code data} and stores in
 * {@code threshedFeatures} the features selected by {@code IntCounter.keysAbove(thresh)}.
 * Logs how many distinct features were dropped.
 *
 * @param data the words to extract features from
 * @param thresh the count threshold passed to {@code keysAbove}
 */
public void applyFeatureCountThreshold(Collection<String> data, int thresh) {
  // Typed counter (was a raw IntCounter), so keysAbove returns a typed key set
  // and the unchecked conversions go away.
  IntCounter<String> featureCounts = new IntCounter<>();
  for (String datum : data) {
    for (String feat : makeFeatures(datum)) {
      featureCounts.incrementCount(feat);
    }
  }
  threshedFeatures = featureCounts.keysAbove(thresh);
  log.info((featureCounts.size() - threshedFeatures.size()) + " word features removed due to thresholding.");
}
Aggregations