Search in sources :

Example 1 with IntCounter

use of edu.stanford.nlp.stats.IntCounter in project CoreNLP by stanfordnlp.

the class PerceptronModel method outputStats.

/**
 * Logs a summary of the model and of the given training iteration.
 * <br>
 * Reports, in order: the correct/wrong transition counts from {@code result},
 * the number of known features, the number of non-zero weights, the largest
 * absolute weight, the summed length of all feature names, and the number of
 * known transitions.  It then tallies the first errors recorded in
 * {@code result} and hands off to the first-error, reorderer and transition
 * statistics reporters.
 *
 * @param result the outcome of the training iteration to report on
 */
public void outputStats(TrainingResult result) {
    log.info("While training, got " + result.numCorrect + " transitions correct and " + result.numWrong + " transitions wrong");
    log.info("Number of known features: " + featureWeights.size());
    log.info("Number of non-zero weights: " + numWeights());
    log.info("Weight values maxAbs: " + maxAbs());

    // Sum of the lengths of every feature name currently known to the model
    final int totalFeatureNameLength = featureWeights.keySet().stream().mapToInt(String::length).sum();
    log.info("Total word length: " + totalFeatureNameLength);
    log.info("Number of transitions: " + transitionIndex.size());

    // Count how often each (predicted, gold) pair was the first error in a tree
    final IntCounter<Pair<Integer, Integer>> errorTally = new IntCounter<>();
    for (Pair<Integer, Integer> errorPair : result.firstErrors) {
        errorTally.incrementCount(errorPair);
    }

    outputFirstErrors(errorTally);
    outputReordererStats(result.reorderSuccess, result.reorderFail);
    outputTransitionStats(result);
}
Also used : TwoDimensionalIntCounter(edu.stanford.nlp.stats.TwoDimensionalIntCounter) IntCounter(edu.stanford.nlp.stats.IntCounter) ParserConstraint(edu.stanford.nlp.parser.common.ParserConstraint) Pair(edu.stanford.nlp.util.Pair)

Example 2 with IntCounter

use of edu.stanford.nlp.stats.IntCounter in project CoreNLP by stanfordnlp.

the class PerceptronModel method trainTree.

/**
 * Runs the current model over one training example and collects the weight
 * updates implied by its mistakes.
 * <br>
 * The behaviour depends on {@code op.trainOptions().trainingMethod}:
 * <ul>
 * <li> BEAM / REORDER_BEAM: a beam of states is advanced one transition at a
 *      time alongside the gold state.  Whenever the best state on the new
 *      beam is not the result of the gold transition, a pair of updates is
 *      recorded.
 * <li> GOLD / EARLY_TERMINATION / REORDER_ORACLE: a single state is advanced
 *      greedily, recording one update per mispredicted transition.
 * </ul>
 * For any other training method neither branch runs and the result carries
 * no updates.  This method only accumulates {@code TrainingUpdate} objects;
 * nothing visible here applies them to the weights.
 *
 * @param example the training example; supplies the initial state
 *        ({@code initialStateFromGoldTagTree()}) and the gold transition
 *        sequence ({@code trainTransitions()})
 * @return the updates, the correct/wrong transition counts, the first error
 *         as a (predicted index, gold index) pair or null if there was none,
 *         per-class transition tallies, and the reorder success/fail flags
 * @throws IllegalArgumentException if a beam method is selected with a
 *         non-positive beam size
 */
private TrainingResult trainTree(TrainingExample example) {
    int numCorrect = 0;
    int numWrong = 0;
    // NOTE(review): tree is not referenced again anywhere in this method
    Tree tree = example.binarizedTree;
    List<TrainingUpdate> updates = Generics.newArrayList();
    // (predicted transition index, gold transition index) of the first mistake, if any
    Pair<Integer, Integer> firstError = null;
    // tallies keyed by transition class: correct ones, and (gold class, predicted class) for wrong ones
    IntCounter<Class<? extends Transition>> correctTransitions = new IntCounter<>();
    TwoDimensionalIntCounter<Class<? extends Transition>, Class<? extends Transition>> wrongTransitions = new TwoDimensionalIntCounter<>();
    // only the two REORDER_* methods need a reorderer; it stays null otherwise
    ReorderingOracle reorderer = null;
    if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_ORACLE || op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM) {
        reorderer = new ReorderingOracle(op, rootOnlyStates);
    }
    // used as 0/1 flags for this tree: reorderFail is only set if no reorder has succeeded so far
    int reorderSuccess = 0;
    int reorderFail = 0;
    if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.BEAM || op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM) {
        if (op.trainOptions().beamSize <= 0) {
            throw new IllegalArgumentException("Illegal beam size " + op.trainOptions().beamSize);
        }
        // capacity is beamSize + 1 because a state is added before the queue is trimmed back to beamSize
        PriorityQueue<State> agenda = new PriorityQueue<>(op.trainOptions().beamSize + 1, ScoredComparator.ASCENDING_COMPARATOR);
        State goldState = example.initialStateFromGoldTagTree();
        // NOTE(review): this list is consumed with remove(0) below; presumably trainTransitions() returns a fresh copy - confirm
        List<Transition> transitions = example.trainTransitions();
        agenda.add(goldState);
        // each iteration advances every state on the agenda by one transition
        while (transitions.size() > 0) {
            Transition goldTransition = transitions.get(0);
            Transition highestScoringTransitionFromGoldState = null;
            double highestScoreFromGoldState = 0.0;
            PriorityQueue<State> newAgenda = new PriorityQueue<>(op.trainOptions().beamSize + 1, ScoredComparator.ASCENDING_COMPARATOR);
            State highestScoringState = null;
            // keep track of the state in the current agenda which leads
            // to the highest score on the next agenda.  this will be
            // trained down assuming it is not the correct state
            State highestCurrentState = null;
            for (State currentState : agenda) {
                // TODO: can maybe speed this part up, although it doesn't seem like a critical part of the runtime
                boolean isGoldState = goldState.areTransitionsEqual(currentState);
                List<String> features = featureFactory.featurize(currentState);
                Collection<ScoredObject<Integer>> stateTransitions = findHighestScoringTransitions(currentState, features, true, op.trainOptions().beamSize, null);
                for (ScoredObject<Integer> transition : stateTransitions) {
                    State newState = transitionIndex.get(transition.object()).apply(currentState, transition.score());
                    newAgenda.add(newState);
                    // trim the beam: poll() drops the head of the queue, presumably the
                    // lowest scoring state given ASCENDING_COMPARATOR
                    if (newAgenda.size() > op.trainOptions().beamSize) {
                        newAgenda.poll();
                    }
                    if (highestScoringState == null || highestScoringState.score() < newState.score()) {
                        highestScoringState = newState;
                        highestCurrentState = currentState;
                    }
                    // separately remember the best transition proposed from the gold state itself
                    if (isGoldState && (highestScoringTransitionFromGoldState == null || transition.score() > highestScoreFromGoldState)) {
                        highestScoringTransitionFromGoldState = transitionIndex.get(transition.object());
                        highestScoreFromGoldState = transition.score();
                    }
                }
            }
            // NOTE(review): the original comment here was cut off; its surviving tail read
            // "state (eg one with ROOT) isn't on the agenda so it stops."
            // What the check below does: under REORDER_BEAM, if no transition was scored
            // from a state matching the gold state, stop training on this tree.
            if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM && highestScoringTransitionFromGoldState == null) {
                break;
            }
            // no state on the agenda produced any transition at all: dump diagnostics and give up on this tree
            if (highestScoringState == null) {
                System.err.println("Unable to find a best transition!");
                System.err.println("Previous agenda:");
                for (State state : agenda) {
                    System.err.println(state);
                }
                System.err.println("Gold transitions:");
                System.err.println(example.transitions);
                break;
            }
            State newGoldState = goldTransition.apply(goldState, 0.0);
            // NOTE(review): highestScoringTransitionFromGoldState is only null-checked above for
            // REORDER_BEAM; under BEAM this line (and the getClass() call below) would throw a
            // NullPointerException if it were null - confirm that cannot happen
            if (firstError == null && !highestScoringTransitionFromGoldState.equals(goldTransition)) {
                int predictedIndex = transitionIndex.indexOf(highestScoringTransitionFromGoldState);
                int goldIndex = transitionIndex.indexOf(goldTransition);
                if (predictedIndex < 0) {
                    throw new AssertionError("Predicted transition not in the index: " + highestScoringTransitionFromGoldState);
                }
                if (goldIndex < 0) {
                    throw new AssertionError("Gold transition not in the index: " + goldTransition);
                }
                firstError = new Pair<>(predictedIndex, goldIndex);
            }
            // if the best new state is the gold state there is nothing to fix;
            // otherwise, down the last transition, up the correct
            if (!newGoldState.areTransitionsEqual(highestScoringState)) {
                ++numWrong;
                wrongTransitions.incrementCount(goldTransition.getClass(), highestScoringTransitionFromGoldState.getClass());
                List<String> goldFeatures = featureFactory.featurize(goldState);
                int lastTransition = transitionIndex.indexOf(highestScoringState.transitions.peek());
                // two one-sided updates; -1 appears to mean "no transition" for that slot
                // (argument order taken from the greedy branch: features, gold, predicted, learningRate)
                updates.add(new TrainingUpdate(featureFactory.featurize(highestCurrentState), -1, lastTransition, learningRate));
                updates.add(new TrainingUpdate(goldFeatures, transitionIndex.indexOf(goldTransition), -1, learningRate));
                if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.BEAM) {
                    // If the correct state has fallen off the agenda, break
                    if (!ShiftReduceUtils.findStateOnAgenda(newAgenda, newGoldState)) {
                        break;
                    } else {
                        transitions.remove(0);
                    }
                } else if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM) {
                    if (!ShiftReduceUtils.findStateOnAgenda(newAgenda, newGoldState)) {
                        // the gold state fell off the agenda: ask the reorderer whether the
                        // predicted transition can be followed instead (it is handed the
                        // remaining transitions list)
                        if (!reorderer.reorder(goldState, highestScoringTransitionFromGoldState, transitions)) {
                            if (reorderSuccess == 0)
                                reorderFail = 1;
                            break;
                        }
                        // continue from the state reached by the predicted transition,
                        // provided that state actually made it onto the new agenda
                        newGoldState = highestScoringTransitionFromGoldState.apply(goldState);
                        if (!ShiftReduceUtils.findStateOnAgenda(newAgenda, newGoldState)) {
                            if (reorderSuccess == 0)
                                reorderFail = 1;
                            break;
                        }
                        reorderSuccess = 1;
                    } else {
                        transitions.remove(0);
                    }
                }
            } else {
                ++numCorrect;
                correctTransitions.incrementCount(goldTransition.getClass());
                transitions.remove(0);
            }
            goldState = newGoldState;
            agenda = newAgenda;
        }
    } else if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_ORACLE || op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.EARLY_TERMINATION || op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.GOLD) {
        // greedy training: follow a single state and compare its best transition with the gold one
        State state = example.initialStateFromGoldTagTree();
        List<Transition> transitions = example.trainTransitions();
        boolean keepGoing = true;
        while (transitions.size() > 0 && keepGoing) {
            Transition gold = transitions.get(0);
            int goldNum = transitionIndex.indexOf(gold);
            List<String> features = featureFactory.featurize(state);
            int predictedNum = findHighestScoringTransition(state, features, false).object();
            Transition predicted = transitionIndex.get(predictedNum);
            if (goldNum == predictedNum) {
                transitions.remove(0);
                state = gold.apply(state);
                numCorrect++;
                correctTransitions.incrementCount(gold.getClass());
            } else {
                numWrong++;
                wrongTransitions.incrementCount(gold.getClass(), predicted.getClass());
                if (firstError == null) {
                    firstError = new Pair<>(predictedNum, goldNum);
                }
                // TODO: allow weighted features, weighted training, etc
                updates.add(new TrainingUpdate(features, goldNum, predictedNum, learningRate));
                // what happens after a mistake depends on the training method
                switch(op.trainOptions().trainingMethod) {
                    case EARLY_TERMINATION:
                        // stop at the first mistake
                        keepGoing = false;
                        break;
                    case GOLD:
                        // carry on along the gold path regardless of the prediction
                        transitions.remove(0);
                        state = gold.apply(state);
                        break;
                    case REORDER_ORACLE:
                        // follow the predicted transition only if the reorderer reports success;
                        // otherwise the loop ends
                        keepGoing = reorderer.reorder(state, predicted, transitions);
                        if (keepGoing) {
                            state = predicted.apply(state);
                            reorderSuccess = 1;
                        } else if (reorderSuccess == 0) {
                            reorderFail = 1;
                        }
                        break;
                    default:
                        throw new IllegalArgumentException("Unexpected method " + op.trainOptions().trainingMethod);
                }
            }
        }
    }
    return new TrainingResult(updates, numCorrect, numWrong, firstError, correctTransitions, wrongTransitions, reorderSuccess, reorderFail);
}
Also used : ScoredObject(edu.stanford.nlp.util.ScoredObject) Tree(edu.stanford.nlp.trees.Tree) ArrayList(java.util.ArrayList) List(java.util.List) TwoDimensionalIntCounter(edu.stanford.nlp.stats.TwoDimensionalIntCounter) IntCounter(edu.stanford.nlp.stats.IntCounter) Pair(edu.stanford.nlp.util.Pair) PriorityQueue(java.util.PriorityQueue) TwoDimensionalIntCounter(edu.stanford.nlp.stats.TwoDimensionalIntCounter) ParserConstraint(edu.stanford.nlp.parser.common.ParserConstraint)

Example 3 with IntCounter

use of edu.stanford.nlp.stats.IntCounter in project CoreNLP by stanfordnlp.

the class PerceptronModel method outputFirstErrors.

/**
 * Logs up to the 9 most frequent first errors made by the model during
 * training, most frequent first, each with the number of times it occurred.
 * <br>
 * Works on a copy so that the original counter is unchanged.  Does nothing
 * if the counter is null or empty.
 *
 * @param firstErrors counts of (predicted transition index, gold transition
 *        index) pairs, both indices referring to {@code transitionIndex}
 */
private void outputFirstErrors(IntCounter<Pair<Integer, Integer>> firstErrors) {
    if (firstErrors == null || firstErrors.size() == 0) {
        return;
    }
    IntCounter<Pair<Integer, Integer>> firstErrorCopy = new IntCounter<>(firstErrors);
    log.info("Most common transition errors: gold -> predicted");
    for (int i = 0; i < 9 && firstErrorCopy.size() > 0; ++i) {
        Pair<Integer, Integer> mostCommon = firstErrorCopy.argmax();
        int count = firstErrorCopy.max();
        if (count <= 0) {
            // Every remaining entry has already been reported and zeroed out
            // (in case the counter keeps zero-count keys), so stop here
            // instead of printing errors that "happened 0 times".
            break;
        }
        // Zero this entry so the next iteration finds the next most common error
        firstErrorCopy.decrementCount(mostCommon, count);
        // The pair is stored as (predicted, gold)
        Transition predicted = transitionIndex.get(mostCommon.first());
        Transition gold = transitionIndex.get(mostCommon.second());
        // Report the count captured before the decrement; asking the counter
        // for max() at this point would give the count of the next entry.
        log.info("  # " + (i + 1) + ": " + gold + " -> " + predicted + " happened " + count + " times");
    }
}
Also used : TwoDimensionalIntCounter(edu.stanford.nlp.stats.TwoDimensionalIntCounter) IntCounter(edu.stanford.nlp.stats.IntCounter) ParserConstraint(edu.stanford.nlp.parser.common.ParserConstraint) Pair(edu.stanford.nlp.util.Pair)

Example 4 with IntCounter

use of edu.stanford.nlp.stats.IntCounter in project CoreNLP by stanfordnlp.

the class ChineseWordFeatureExtractor method applyFeatureCountThreshold.

/**
 * Counts how often each feature produced by {@code makeFeatures} occurs over
 * the given data, stores the features selected by
 * {@code IntCounter.keysAbove(thresh)} in {@code threshedFeatures}, and logs
 * how many distinct features were dropped.
 *
 * @param data the items to extract features from
 * @param thresh the count threshold passed to {@code keysAbove}
 */
public void applyFeatureCountThreshold(Collection<String> data, int thresh) {
    // Parameterized rather than raw so that the keys are statically known to be Strings
    IntCounter<String> featureCounts = new IntCounter<>();
    for (String datum : data) {
        for (String feat : makeFeatures(datum)) {
            featureCounts.incrementCount(feat);
        }
    }
    threshedFeatures = featureCounts.keysAbove(thresh);
    log.info((featureCounts.size() - threshedFeatures.size()) + " word features removed due to thresholding.");
}
Also used : IntCounter(edu.stanford.nlp.stats.IntCounter)

Aggregations

IntCounter (edu.stanford.nlp.stats.IntCounter): 4, ParserConstraint (edu.stanford.nlp.parser.common.ParserConstraint): 3, TwoDimensionalIntCounter (edu.stanford.nlp.stats.TwoDimensionalIntCounter): 3, Pair (edu.stanford.nlp.util.Pair): 3, Tree (edu.stanford.nlp.trees.Tree): 1, ScoredObject (edu.stanford.nlp.util.ScoredObject): 1, ArrayList (java.util.ArrayList): 1, List (java.util.List): 1, PriorityQueue (java.util.PriorityQueue): 1