Use of edu.stanford.nlp.parser.metrics.EvaluateTreebank in project CoreNLP by stanfordnlp.
In class StringBinaryRule, method testGrammarCompaction.
/**
 * Trains a parser from the training file range, optionally compacts its
 * unary/binary grammars with {@code compactor}, optionally dumps the parser
 * as text, and finally evaluates it on the test file range.
 */
public void testGrammarCompaction() {
  // these for testing against the markov 3rd order baseline
  // the grammars are extracted from the treebank by training a parser on it
  op = new Options();
  NumberRangeFileFilter trainFilter = new NumberRangeFileFilter(trainLow, trainHigh, true);
  LexicalizedParser parser = LexicalizedParser.trainFromTreebank(path, trainFilter, op);

  if (compactor != null) {
    // step 1: gather the paths the compactor needs
    Timing.startTime();
    System.out.print("Extracting other paths...");
    allTrainPaths = extractPaths(path, trainLow, trainHigh, true);
    allTestPaths = extractPaths(path, testLow, testHigh, true);
    Timing.tick("done");

    // step 2: compact, then swap the compacted pieces back into the parser
    Timing.startTime();
    System.out.print("Compacting grammars...");
    Pair<UnaryGrammar, BinaryGrammar> originalGrammar = Generics.newPair(parser.ug, parser.bg);
    Triple<Index<String>, UnaryGrammar, BinaryGrammar> compacted =
        compactor.compactGrammar(originalGrammar, allTrainPaths, allTestPaths, parser.stateIndex);
    parser.stateIndex = compacted.first();
    parser.ug = compacted.second();
    parser.bg = compacted.third();
    Timing.tick("done.");
  }

  if (asciiOutputPath != null) {
    parser.saveParserToTextFile(asciiOutputPath);
  }

  // evaluate on the test range, bracketing the run with timestamps
  Treebank evalTreebank = op.tlpParams.testMemoryTreebank();
  NumberRangeFileFilter testFilter = new NumberRangeFileFilter(testLow, testHigh, true);
  evalTreebank.loadPath(path, testFilter);
  System.out.println("Currently " + new Date());
  new EvaluateTreebank(parser).testOnTreebank(evalTreebank);
  System.out.println("Currently " + new Date());
}
Use of edu.stanford.nlp.parser.metrics.EvaluateTreebank in project CoreNLP by stanfordnlp.
In class DVParser, method train.
/**
 * Trains the model in shuffled mini-batches for
 * {@code op.trainOptions.trainingIterations} passes over {@code sentences},
 * stopping early once {@code op.trainOptions.maxTrainTimeSeconds} (if positive)
 * has been exceeded.
 * <br>
 * Every {@code op.trainOptions.debugOutputFrequency} batches (if positive) a
 * checkpoint cycle runs: the current model is evaluated on
 * {@code testTreebank}, saved under a name derived from {@code modelPath},
 * and a status line is logged and appended to {@code resultsRecordPath}.
 *
 * @param sentences the training trees
 * @param compressedParses per-tree data handed to each training batch together with the trees
 * @param testTreebank trees to evaluate on during checkpoint cycles; may be null to skip evaluation
 * @param modelPath where checkpoints are saved; may be null to skip saving
 * @param resultsRecordPath file that checkpoint status lines are appended to; may be null
 * @throws IOException if saving a checkpoint or writing the results record fails
 */
public void train(List<Tree> sentences, IdentityHashMap<Tree, byte[]> compressedParses, Treebank testTreebank, String modelPath, String resultsRecordPath) throws IOException {
  // process:
  // we come up with a cost and a derivative for the model
  // we always use the gold tree as the example to train towards
  // every time through, we will look at the top N trees from
  // the LexicalizedParser and pick the best one according to
  // our model (at the start, this is essentially random)
  // we use QN to minimize the cost function for the model
  // to do this minimization, we turn all of the matrices in the
  // DVModel into one big Theta, which is the set of variables to
  // be optimized by the QN.
  Timing timing = new Timing();
  // Multiply as a long so that a large time limit cannot overflow before widening.
  long maxTrainTimeMillis = op.trainOptions.maxTrainTimeSeconds * 1000L;
  int batchCount = 0;
  int debugCycle = 0;
  double bestLabelF1 = 0.0;

  if (op.trainOptions.useContextWords) {
    for (Tree tree : sentences) {
      Trees.convertToCoreLabels(tree);
      tree.setSpans();
    }
  }

  // for AdaGrad
  double[] sumGradSquare = new double[dvModel.totalParamSize()];
  Arrays.fill(sumGradSquare, 1.0);

  // NOTE(review): when sentences.size() is an exact multiple of batchSize the
  // final batch is empty; behavior kept as-is - confirm whether that is intended.
  int numBatches = sentences.size() / op.trainOptions.batchSize + 1;
  log.info("Training on " + sentences.size() + " trees in " + numBatches + " batches");
  log.info("Times through each training batch: " + op.trainOptions.trainingIterations);
  log.info("QN iterations per batch: " + op.trainOptions.qnIterationsPerBatch);

  for (int iter = 0; iter < op.trainOptions.trainingIterations; ++iter) {
    List<Tree> shuffledSentences = new ArrayList<>(sentences);
    Collections.shuffle(shuffledSentences, dvModel.rand);
    for (int batch = 0; batch < numBatches; ++batch) {
      ++batchCount;
      // This did not help performance
      // log.info("Setting AdaGrad's sum of squares to 1...");
      // Arrays.fill(sumGradSquare, 1.0);
      log.info("======================================");
      log.info("Iteration " + iter + " batch " + batch);

      // Each batch will be of the specified batch size, except the
      // last batch will include any leftover trees at the end of
      // the list
      int startTree = batch * op.trainOptions.batchSize;
      int endTree = (batch + 1) * op.trainOptions.batchSize;
      if (endTree > shuffledSentences.size()) {
        endTree = shuffledSentences.size();
      }
      executeOneTrainingBatch(shuffledSentences.subList(startTree, endTree), compressedParses, sumGradSquare);

      long totalElapsed = timing.report();
      log.info("Finished iteration " + iter + " batch " + batch + "; total training time " + totalElapsed + " ms");
      if (maxTrainTimeMillis > 0 && totalElapsed > maxTrainTimeMillis) {
        // no need to debug output, we're done now
        break;
      }

      if (op.trainOptions.debugOutputFrequency > 0 && batchCount % op.trainOptions.debugOutputFrequency == 0) {
        log.info("Finished " + batchCount + " total batches, running evaluation cycle");
        // Time for debugging output!
        double tagF1 = 0.0;
        double labelF1 = 0.0;
        if (testTreebank != null) {
          EvaluateTreebank evaluator = new EvaluateTreebank(attachModelToLexicalizedParser());
          evaluator.testOnTreebank(testTreebank);
          labelF1 = evaluator.getLBScore();
          tagF1 = evaluator.getTagScore();
          if (labelF1 > bestLabelF1) {
            bestLabelF1 = labelF1;
          }
          log.info("Best label f1 on dev set so far: " + NF.format(bestLabelF1));
        }

        // Checkpoint name: for "*.ser.gz" paths, splice the cycle number and
        // score in front of the extension; otherwise reuse modelPath itself.
        String tempName = null;
        if (modelPath != null) {
          tempName = modelPath;
          if (modelPath.endsWith(".ser.gz")) {
            tempName = modelPath.substring(0, modelPath.length() - 7) + "-" + FILENAME.format(debugCycle) + "-" + NF.format(labelF1) + ".ser.gz";
          }
          saveModel(tempName);
        }

        String statusLine = ("CHECKPOINT:" + " iteration " + iter + " batch " + batch + " labelF1 " + NF.format(labelF1) + " tagF1 " + NF.format(tagF1) + " bestLabelF1 " + NF.format(bestLabelF1) + " model " + tempName + op.trainOptions + " word vectors: " + op.lexOptions.wordVectorFile + " numHid: " + op.lexOptions.numHid);
        log.info(statusLine);
        if (resultsRecordPath != null) {
          // append; try-with-resources closes the writer even if a write fails
          try (FileWriter fout = new FileWriter(resultsRecordPath, true)) {
            fout.write(statusLine);
            fout.write("\n");
          }
        }

        ++debugCycle;
      }
    }

    // The inner break only leaves the batch loop, so re-check the time
    // limit here to leave the iteration loop as well.
    long totalElapsed = timing.report();
    if (maxTrainTimeMillis > 0 && totalElapsed > maxTrainTimeMillis) {
      // no need to debug output, we're done now
      log.info("Max training time exceeded, exiting");
      break;
    }
  }
}
Use of edu.stanford.nlp.parser.metrics.EvaluateTreebank in project CoreNLP by stanfordnlp.
In class DVParser, method main.
/**
 * Command line entry point: trains, gradient-checks and/or evaluates a DVParser.
 * <br>
 * Options handled here (matched case-insensitively); anything else is passed
 * on to {@code LexicalizedParser.loadModel}:
 * <ul>
 * <li>{@code -parser <path>}: base parser loaded when training or gradient checking
 * <li>{@code -model <path>}: where the trained model is saved, or the model to load when not training
 * <li>{@code -treebank ...}: training treebank description
 * <li>{@code -testTreebank ...}: treebank to evaluate on
 * <li>{@code -cachedTrees <paths>}: comma-separated files of cached training trees
 * <li>{@code -train}: run training
 * <li>{@code -continueTraining <path>}: run training starting from the given model, without rule filtering
 * <li>{@code -runGradientCheck}: run the gradient check
 * <li>{@code -nofilter}: do not filter rules to the training set
 * <li>{@code -resultsRecord <path>}: file that checkpoint status lines are appended to
 * </ul>
 * An example command line for training a new parser:
 * <br>
 * nohup java -mx6g edu.stanford.nlp.parser.dvparser.DVParser -cachedTrees /scr/nlp/data/dvparser/wsj/cached.wsj.train.simple.ser.gz -train -testTreebank /afs/ir/data/linguistic-data/Treebank/3/parsed/mrg/wsj/22 2200-2219 -debugOutputFrequency 400 -nofilter -trainingThreads 5 -parser /u/nlp/data/lexparser/wsjPCFG.nocompact.simple.ser.gz -trainingIterations 40 -batchSize 25 -model /scr/nlp/data/dvparser/wsj/wsj.combine.v2.ser.gz -unkWord "*UNK*" -dvCombineCategories > /scr/nlp/data/dvparser/wsj/wsj.combine.v2.out 2>&1 &
 *
 * @throws IllegalArgumentException if a required option is missing or an
 *         option that takes a value is the last argument
 */
public static void main(String[] args) throws IOException, ClassNotFoundException {
  if (args.length == 0) {
    help();
    System.exit(2);
  }

  log.info("Running DVParser with arguments:");
  for (String arg : args) {
    log.info("  " + arg);
  }
  log.info();

  String parserPath = null;
  String trainTreebankPath = null;
  FileFilter trainTreebankFilter = null;
  String cachedTrainTreesPath = null;

  boolean runGradientCheck = false;
  boolean runTraining = false;

  String testTreebankPath = null;
  FileFilter testTreebankFilter = null;

  String initialModelPath = null;
  String modelPath = null;

  boolean filter = true;

  String resultsRecordPath = null;

  List<String> unusedArgs = new ArrayList<>();

  // These parameters can be null or 0 if the model was not
  // serialized with the new parameters.  Setting the options at the
  // command line will override these defaults.
  // TODO: if/when we integrate back into the main branch and
  // rebuild models, we can get rid of this
  List<String> argsWithDefaults = new ArrayList<>(Arrays.asList(
      "-wordVectorFile", Options.LexOptions.DEFAULT_WORD_VECTOR_FILE,
      "-dvKBest", Integer.toString(TrainOptions.DEFAULT_K_BEST),
      "-batchSize", Integer.toString(TrainOptions.DEFAULT_BATCH_SIZE),
      "-trainingIterations", Integer.toString(TrainOptions.DEFAULT_TRAINING_ITERATIONS),
      "-qnIterationsPerBatch", Integer.toString(TrainOptions.DEFAULT_QN_ITERATIONS_PER_BATCH),
      "-regCost", Double.toString(TrainOptions.DEFAULT_REGCOST),
      "-learningRate", Double.toString(TrainOptions.DEFAULT_LEARNING_RATE),
      "-deltaMargin", Double.toString(TrainOptions.DEFAULT_DELTA_MARGIN),
      "-unknownNumberVector",
      "-unknownDashedWordVectors",
      "-unknownCapsVector",
      "-unknownchinesepercentvector",
      "-unknownchinesenumbervector",
      "-unknownchineseyearvector",
      "-unkWord", "*UNK*",
      "-transformMatrixType", "DIAGONAL",
      "-scalingForInit", Double.toString(TrainOptions.DEFAULT_SCALING_FOR_INIT),
      "-trainWordVectors"));
  // User arguments come after the defaults so that they are processed last.
  argsWithDefaults.addAll(Arrays.asList(args));
  args = argsWithDefaults.toArray(new String[argsWithDefaults.size()]);

  for (int argIndex = 0; argIndex < args.length; ) {
    if (args[argIndex].equalsIgnoreCase("-parser")) {
      parserPath = requiredOptionValue(args, argIndex);
      argIndex += 2;
    } else if (args[argIndex].equalsIgnoreCase("-testTreebank")) {
      Pair<String, FileFilter> treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-testTreebank");
      argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
      testTreebankPath = treebankDescription.first();
      testTreebankFilter = treebankDescription.second();
    } else if (args[argIndex].equalsIgnoreCase("-treebank")) {
      Pair<String, FileFilter> treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-treebank");
      argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
      trainTreebankPath = treebankDescription.first();
      trainTreebankFilter = treebankDescription.second();
    } else if (args[argIndex].equalsIgnoreCase("-cachedTrees")) {
      cachedTrainTreesPath = requiredOptionValue(args, argIndex);
      argIndex += 2;
    } else if (args[argIndex].equalsIgnoreCase("-runGradientCheck")) {
      runGradientCheck = true;
      argIndex++;
    } else if (args[argIndex].equalsIgnoreCase("-train")) {
      runTraining = true;
      argIndex++;
    } else if (args[argIndex].equalsIgnoreCase("-model")) {
      modelPath = requiredOptionValue(args, argIndex);
      argIndex += 2;
    } else if (args[argIndex].equalsIgnoreCase("-nofilter")) {
      filter = false;
      argIndex++;
    } else if (args[argIndex].equalsIgnoreCase("-continueTraining")) {
      runTraining = true;
      filter = false;
      initialModelPath = requiredOptionValue(args, argIndex);
      argIndex += 2;
    } else if (args[argIndex].equalsIgnoreCase("-resultsRecord")) {
      resultsRecordPath = requiredOptionValue(args, argIndex);
      argIndex += 2;
    } else {
      unusedArgs.add(args[argIndex++]);
    }
  }

  if (parserPath == null && modelPath == null) {
    throw new IllegalArgumentException("Must supply either a base parser model with -parser or a serialized DVParser with -model");
  }

  if (!runTraining && modelPath == null && !runGradientCheck) {
    throw new IllegalArgumentException("Need to either train a new model, run the gradient check or specify a model to load with -model");
  }

  String[] newArgs = unusedArgs.toArray(new String[unusedArgs.size()]);
  DVParser dvparser = null;
  LexicalizedParser lexparser = null;
  if (initialModelPath != null) {
    // continue training from an existing model
    lexparser = LexicalizedParser.loadModel(initialModelPath, newArgs);
    DVModel model = getModelFromLexicalizedParser(lexparser);
    dvparser = new DVParser(model, lexparser);
  } else if (runTraining || runGradientCheck) {
    // start a new DVParser on top of the base parser
    lexparser = LexicalizedParser.loadModel(parserPath, newArgs);
    dvparser = new DVParser(lexparser);
  } else if (modelPath != null) {
    // evaluation only: load the previously saved model
    lexparser = LexicalizedParser.loadModel(modelPath, newArgs);
    DVModel model = getModelFromLexicalizedParser(lexparser);
    dvparser = new DVParser(model, lexparser);
  }

  List<Tree> trainSentences = new ArrayList<>();
  IdentityHashMap<Tree, byte[]> trainCompressedParses = Generics.newIdentityHashMap();

  if (cachedTrainTreesPath != null) {
    for (String path : cachedTrainTreesPath.split(",")) {
      List<Pair<Tree, byte[]>> cache = IOUtils.readObjectFromFile(path);

      for (Pair<Tree, byte[]> pair : cache) {
        trainSentences.add(pair.first());
        trainCompressedParses.put(pair.first(), pair.second());
      }

      log.info("Read in " + cache.size() + " trees from " + path);
    }
  }

  if (trainTreebankPath != null) {
    // TODO: make the transformer a member of the model?
    TreeTransformer transformer = buildTrainTransformer(dvparser.getOp());

    Treebank treebank = dvparser.getOp().tlpParams.memoryTreebank();
    treebank.loadPath(trainTreebankPath, trainTreebankFilter);
    treebank = treebank.transform(transformer);
    log.info("Read in " + treebank.size() + " trees from " + trainTreebankPath);

    CacheParseHypotheses cacher = new CacheParseHypotheses(dvparser.parser);
    CacheParseHypotheses.CacheProcessor processor = new CacheParseHypotheses.CacheProcessor(cacher, lexparser, dvparser.op.trainOptions.dvKBest, transformer);
    for (Tree tree : treebank) {
      trainSentences.add(tree);
      trainCompressedParses.put(tree, processor.process(tree).second);
      // System.out.println(tree);
    }

    log.info("Finished parsing " + treebank.size() + " trees, getting " + dvparser.op.trainOptions.dvKBest + " hypotheses each");
  }

  if ((runTraining || runGradientCheck) && filter) {
    log.info("Filtering rules for the given training set");
    dvparser.dvModel.setRulesForTrainingSet(trainSentences, trainCompressedParses);
    log.info("Done filtering rules; " + dvparser.dvModel.numBinaryMatrices + " binary matrices, " + dvparser.dvModel.numUnaryMatrices + " unary matrices, " + dvparser.dvModel.wordVectors.size() + " word vectors");
  }

  // dvparser.dvModel.printAllMatrices();

  Treebank testTreebank = null;
  if (testTreebankPath != null) {
    log.info("Reading in trees from " + testTreebankPath);
    if (testTreebankFilter != null) {
      log.info("Filtering on " + testTreebankFilter);
    }
    testTreebank = dvparser.getOp().tlpParams.memoryTreebank();
    testTreebank.loadPath(testTreebankPath, testTreebankFilter);
    log.info("Read in " + testTreebank.size() + " trees for testing");
  }

  // runGradientCheck= true;
  if (runGradientCheck) {
    log.info("Running gradient check on " + trainSentences.size() + " trees");
    dvparser.runGradientCheck(trainSentences, trainCompressedParses);
  }

  if (runTraining) {
    log.info("Training the RNN parser");
    log.info("Current train options: " + dvparser.getOp().trainOptions);
    dvparser.train(trainSentences, trainCompressedParses, testTreebank, modelPath, resultsRecordPath);
    if (modelPath != null) {
      dvparser.saveModel(modelPath);
    }
  }

  if (testTreebankPath != null) {
    EvaluateTreebank evaluator = new EvaluateTreebank(dvparser.attachModelToLexicalizedParser());
    evaluator.testOnTreebank(testTreebank);
  }

  log.info("Successfully ran DVParser");
}

/**
 * Returns the value following the option at {@code argIndex}.
 *
 * @throws IllegalArgumentException if the option is the last argument and so has no value
 */
private static String requiredOptionValue(String[] args, int argIndex) {
  if (argIndex + 1 >= args.length) {
    throw new IllegalArgumentException("Option " + args[argIndex] + " requires an argument");
  }
  return args[argIndex + 1];
}
Use of edu.stanford.nlp.parser.metrics.EvaluateTreebank in project CoreNLP by stanfordnlp.
In class CrossValidateTestOptions, method main.
/**
 * Loads a parser, then evaluates it on a test treebank once for each value
 * in {@code weights}, setting {@code baseParserWeight} to that value before
 * each run, and logs the labeled and tag scores per weight.
 * <br>
 * Options handled here (matched case-insensitively); anything else is passed
 * on to {@code LexicalizedParser.loadModel}:
 * <ul>
 * <li>{@code -lexparser <path>}: the parser model to load (required)
 * <li>{@code -testTreebank ...}: the treebank to evaluate on (required)
 * </ul>
 *
 * @throws IllegalArgumentException if a required option is missing or
 *         {@code -lexparser} is the last argument
 */
public static void main(String[] args) throws IOException, ClassNotFoundException {
  String lexparserFile = null;
  String testTreebankPath = null;
  FileFilter testTreebankFilter = null;
  List<String> unusedArgs = new ArrayList<>();

  for (int argIndex = 0; argIndex < args.length; ) {
    if (args[argIndex].equalsIgnoreCase("-lexparser")) {
      if (argIndex + 1 >= args.length) {
        throw new IllegalArgumentException("Option " + args[argIndex] + " requires an argument");
      }
      lexparserFile = args[argIndex + 1];
      argIndex += 2;
    } else if (args[argIndex].equalsIgnoreCase("-testTreebank")) {
      Pair<String, FileFilter> treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-testTreebank");
      argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
      testTreebankPath = treebankDescription.first();
      testTreebankFilter = treebankDescription.second();
    } else {
      unusedArgs.add(args[argIndex++]);
    }
  }

  // Fail fast with a clear message: otherwise null is passed on to the
  // model loader or to every evaluation run below.
  if (lexparserFile == null) {
    throw new IllegalArgumentException("Must supply a parser model with -lexparser");
  }
  if (testTreebankPath == null) {
    throw new IllegalArgumentException("Must supply a treebank to evaluate on with -testTreebank");
  }

  log.info("Loading lexparser from: " + lexparserFile);
  String[] newArgs = unusedArgs.toArray(new String[unusedArgs.size()]);
  LexicalizedParser lexparser = LexicalizedParser.loadModel(lexparserFile, newArgs);
  log.info("... done");

  log.info("Reading in trees from " + testTreebankPath);
  if (testTreebankFilter != null) {
    log.info("Filtering on " + testTreebankFilter);
  }
  Treebank testTreebank = lexparser.getOp().tlpParams.memoryTreebank();
  testTreebank.loadPath(testTreebankPath, testTreebankFilter);
  log.info("Read in " + testTreebank.size() + " trees for testing");

  double[] labelResults = new double[weights.length];
  double[] tagResults = new double[weights.length];

  for (int i = 0; i < weights.length; ++i) {
    lexparser.getOp().baseParserWeight = weights[i];
    EvaluateTreebank evaluator = new EvaluateTreebank(lexparser);
    evaluator.testOnTreebank(testTreebank);
    labelResults[i] = evaluator.getLBScore();
    tagResults[i] = evaluator.getTagScore();
  }

  // Report only after all runs have finished, so the summary lines stay together.
  for (int i = 0; i < weights.length; ++i) {
    log.info("LexicalizedParser weight " + weights[i] + ": labeled " + labelResults[i] + " tag " + tagResults[i]);
  }
}
Use of edu.stanford.nlp.parser.metrics.EvaluateTreebank in project CoreNLP by stanfordnlp.
In class PerceptronModel, method evaluate.
/**
 * Wraps this model in a throwaway {@code ShiftReduceParser}, evaluates it on
 * {@code devTreebank}, and logs the resulting score after {@code message}.
 *
 * @param tagger the tagger handed to the evaluator
 * @param devTreebank the trees to evaluate on
 * @param message prefix for the logged score line
 * @return the evaluator's {@code getLBScore()} value
 */
private double evaluate(Tagger tagger, Treebank devTreebank, String message) {
  final ShiftReduceParser scratchParser = new ShiftReduceParser(op, this);
  final EvaluateTreebank scorer = new EvaluateTreebank(
      scratchParser.getOp(),
      null,
      scratchParser,
      tagger,
      scratchParser.getExtraEvals(),
      scratchParser.getParserQueryEvals());
  scorer.testOnTreebank(devTreebank);

  final double score = scorer.getLBScore();
  log.info(message + ": " + score);
  return score;
}
Aggregations