Use of edu.stanford.nlp.parser.lexparser.EvaluateTreebank in project CoreNLP by stanfordnlp.
The example below is the trainModel method of the PerceptronModel class. It trains a shift-reduce parser by averaged perceptron and uses EvaluateTreebank to score the model on a dev treebank after each training iteration.
private void trainModel(String serializedPath, Tagger tagger, Random random, List<Tree> binarizedTrees, List<List<Transition>> transitionLists, Treebank devTreebank, int nThreads, Set<String> allowedFeatures) {
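  // Track the best dev score and the iteration it came from, and optionally
  // keep the top-scoring models in a priority queue for averaging at the end.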
  double bestScore = 0.0;
  int bestIteration = 0;
  PriorityQueue<ScoredObject<PerceptronModel>> bestModels = null;
  if (op.trainOptions().averagedModels > 0) {
    bestModels = new PriorityQueue<>(op.trainOptions().averagedModels + 1, ScoredComparator.ASCENDING_COMPARATOR);
  }
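  // Shuffle a list of tree indices each iteration rather than the trees themselves.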
  List<Integer> indices = Generics.newArrayList();
  for (int i = 0; i < binarizedTrees.size(); ++i) {
    indices.add(i);
  }
  Oracle oracle = null;
  if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.ORACLE) {
    oracle = new Oracle(binarizedTrees, op.compoundUnaries, rootStates);
  }
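  // When training with multiple threads, batches run through a MulticoreWrapper
  // and the shared list of updates must be synchronized.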
  List<Update> updates = Generics.newArrayList();
  MulticoreWrapper<Integer, Pair<Integer, Integer>> wrapper = null;
  if (nThreads != 1) {
    updates = Collections.synchronizedList(updates);
    wrapper = new MulticoreWrapper<>(op.trainOptions.trainingThreads, new TrainTreeProcessor(binarizedTrees, transitionLists, updates, oracle));
  }
  IntCounter<String> featureFrequencies = null;
  if (op.trainOptions().featureFrequencyCutoff > 1) {
    featureFrequencies = new IntCounter<>();
  }
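  // Main training loop: each iteration makes one pass over the shuffled trees in batches.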
  for (int iteration = 1; iteration <= op.trainOptions.trainingIterations; ++iteration) {
    Timing trainingTimer = new Timing();
    int numCorrect = 0;
    int numWrong = 0;
    Collections.shuffle(indices, random);
    for (int start = 0; start < indices.size(); start += op.trainOptions.batchSize) {
      int end = Math.min(start + op.trainOptions.batchSize, indices.size());
      Triple<List<Update>, Integer, Integer> result = trainBatch(indices.subList(start, end), binarizedTrees, transitionLists, updates, oracle, wrapper);
      numCorrect += result.second;
      numWrong += result.third;
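      // Apply the perceptron updates from this batch: raise each feature's weight
      // for the gold transition and lower it for the predicted transition.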
      for (Update update : result.first) {
        for (String feature : update.features) {
          if (allowedFeatures != null && !allowedFeatures.contains(feature)) {
            continue;
          }
          Weight weights = featureWeights.get(feature);
          if (weights == null) {
            weights = new Weight();
            featureWeights.put(feature, weights);
          }
          weights.updateWeight(update.goldTransition, update.delta);
          weights.updateWeight(update.predictedTransition, -update.delta);
          if (featureFrequencies != null) {
            featureFrequencies.incrementCount(feature, (update.goldTransition >= 0 && update.predictedTransition >= 0) ? 2 : 1);
          }
        }
      }
      updates.clear();
    }
    trainingTimer.done("Iteration " + iteration);
    log.info("While training, got " + numCorrect + " transitions correct and " + numWrong + " transitions wrong");
    outputStats();
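    // Score this iteration's model on the dev treebank; EvaluateTreebank
    // reports labeled bracketing F1 via getLBScore().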
    double labelF1 = 0.0;
    if (devTreebank != null) {
      EvaluateTreebank evaluator = new EvaluateTreebank(op, null, new ShiftReduceParser(op, this), tagger);
      evaluator.testOnTreebank(devTreebank);
      labelF1 = evaluator.getLBScore();
      log.info("Label F1 after " + iteration + " iterations: " + labelF1);
      if (labelF1 > bestScore) {
        log.info("New best dev score (previous best " + bestScore + ")");
        bestScore = labelF1;
        bestIteration = iteration;
      } else {
        log.info("Failed to improve for " + (iteration - bestIteration) + " iteration(s) on previous best score of " + bestScore);
        if (op.trainOptions.stalledIterationLimit > 0 && (iteration - bestIteration >= op.trainOptions.stalledIterationLimit)) {
          log.info("Failed to improve for too long, stopping training");
          break;
        }
      }
      log.info();
      if (bestModels != null) {
        bestModels.add(new ScoredObject<>(new PerceptronModel(this), labelF1));
        if (bestModels.size() > op.trainOptions().averagedModels) {
          bestModels.poll();
        }
      }
    }
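    // Optionally snapshot the current model to disk after each iteration.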
    if (op.trainOptions().saveIntermediateModels && serializedPath != null && op.trainOptions.debugOutputFrequency > 0) {
      String tempName = serializedPath.substring(0, serializedPath.length() - 7) + "-" + FILENAME.format(iteration) + "-" + NF.format(labelF1) + ".ser.gz";
      ShiftReduceParser temp = new ShiftReduceParser(op, this);
      temp.saveModel(tempName);
      // TODO: we could save a cutoff version of the model,
      // especially if we also get a dev set number for it, but that
      // might be overkill
    }
  }
  if (wrapper != null) {
    wrapper.join();
  }
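  // Average the weights of the best models kept during training, either
  // cross-validating on the dev set how many to merge, or merging all of them.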
  if (bestModels != null) {
    if (op.trainOptions().cvAveragedModels && devTreebank != null) {
      List<ScoredObject<PerceptronModel>> models = Generics.newArrayList();
      while (bestModels.size() > 0) {
        models.add(bestModels.poll());
      }
      Collections.reverse(models);
      double bestF1 = 0.0;
      int bestSize = 0;
      for (int i = 1; i <= models.size(); ++i) {
        log.info("Testing with " + i + " models averaged together");
        // TODO: this is kind of ugly, would prefer a separate object
        averageScoredModels(models.subList(0, i));
        ShiftReduceParser temp = new ShiftReduceParser(op, this);
        EvaluateTreebank evaluator = new EvaluateTreebank(temp.getOp(), null, temp, tagger);
        evaluator.testOnTreebank(devTreebank);
        double labelF1 = evaluator.getLBScore();
        log.info("Label F1 for " + i + " models: " + labelF1);
        if (labelF1 > bestF1) {
          bestF1 = labelF1;
          bestSize = i;
        }
      }
      averageScoredModels(models.subList(0, bestSize));
    } else {
      averageScoredModels(bestModels);
    }
  }
  // Filter out features that fall below the frequency cutoff, then condense the model.
  if (featureFrequencies != null) {
    filterFeatures(featureFrequencies.keysAbove(op.trainOptions().featureFrequencyCutoff));
  }
  condenseFeatures();
}
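For reference, the EvaluateTreebank calls in trainModel follow a simple pattern that can be used on its own. The sketch below is a minimal, hedged example: it assumes you already have the parser's Options, a trained ShiftReduceParser, a Tagger, and a loaded dev Treebank; the helper name evaluateDevSet is hypothetical, the import paths follow the usual CoreNLP package layout, and only the constructor and methods shown in trainModel are used.

import edu.stanford.nlp.parser.lexparser.EvaluateTreebank;
import edu.stanford.nlp.parser.lexparser.Options;
import edu.stanford.nlp.parser.shiftreduce.ShiftReduceParser;
import edu.stanford.nlp.tagger.common.Tagger;
import edu.stanford.nlp.trees.Treebank;

// Hypothetical helper mirroring the dev-set evaluation in trainModel:
// parse the treebank with the given parser and return labeled bracketing F1.
static double evaluateDevSet(Options op, ShiftReduceParser parser, Tagger tagger, Treebank devTreebank) {
  // null stands in for the Lexicon argument, exactly as in trainModel above
  EvaluateTreebank evaluator = new EvaluateTreebank(op, null, parser, tagger);
  evaluator.testOnTreebank(devTreebank);
  return evaluator.getLBScore();
}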