Use of edu.stanford.nlp.parser.lexparser.EvaluateTreebank in project CoreNLP by stanfordnlp.
The class CrossValidateTestOptions, method main.
public static void main(String[] args) throws IOException, ClassNotFoundException {
String dvmodelFile = null;
String lexparserFile = null;
String testTreebankPath = null;
FileFilter testTreebankFilter = null;
List<String> unusedArgs = new ArrayList<>();
for (int argIndex = 0; argIndex < args.length; ) {
if (args[argIndex].equalsIgnoreCase("-lexparser")) {
lexparserFile = args[argIndex + 1];
argIndex += 2;
} else if (args[argIndex].equalsIgnoreCase("-testTreebank")) {
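// -testTreebank takes a path plus optional trailing sub-arguments (e.g. a
// file range like "2200-2219", as in the DVParser command line below);
// numSubArgs is assumed to count the tokens after the flag that do not
// start with '-', so argIndex skips the flag and all of its sub-arguments.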
Pair<String, FileFilter> treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-testTreebank");
argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
testTreebankPath = treebankDescription.first();
testTreebankFilter = treebankDescription.second();
} else {
unusedArgs.add(args[argIndex++]);
}
}
log.info("Loading lexparser from: " + lexparserFile);
String[] newArgs = unusedArgs.toArray(new String[unusedArgs.size()]);
LexicalizedParser lexparser = LexicalizedParser.loadModel(lexparserFile, newArgs);
log.info("... done");
Treebank testTreebank = null;
if (testTreebankPath != null) {
log.info("Reading in trees from " + testTreebankPath);
if (testTreebankFilter != null) {
log.info("Filtering on " + testTreebankFilter);
}
testTreebank = lexparser.getOp().tlpParams.memoryTreebank();
testTreebank.loadPath(testTreebankPath, testTreebankFilter);
log.info("Read in " + testTreebank.size() + " trees for testing");
}
double[] labelResults = new double[weights.length];
double[] tagResults = new double[weights.length];
for (int i = 0; i < weights.length; ++i) {
lexparser.getOp().baseParserWeight = weights[i];
EvaluateTreebank evaluator = new EvaluateTreebank(lexparser);
evaluator.testOnTreebank(testTreebank);
labelResults[i] = evaluator.getLBScore();
tagResults[i] = evaluator.getTagScore();
}
for (int i = 0; i < weights.length; ++i) {
log.info("LexicalizedParser weight " + weights[i] + ": labeled " + labelResults[i] + " tag " + tagResults[i]);
}
}
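Note that main reads a weights array that the excerpt never declares; it is presumably a field of CrossValidateTestOptions holding the candidate values for baseParserWeight. A minimal sketch of such a field (these particular values are hypothetical, not taken from CoreNLP):

// hypothetical field; the real class defines its own candidate weights
public static final double[] weights = { 0.0, 0.25, 0.5, 0.75, 1.0 };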
Use of edu.stanford.nlp.parser.lexparser.EvaluateTreebank in project CoreNLP by stanfordnlp.
The class DVParser, method main.
/**
* An example command line for training a new parser:
* <br>
* nohup java -mx6g edu.stanford.nlp.parser.dvparser.DVParser -cachedTrees /scr/nlp/data/dvparser/wsj/cached.wsj.train.simple.ser.gz -train -testTreebank /afs/ir/data/linguistic-data/Treebank/3/parsed/mrg/wsj/22 2200-2219 -debugOutputFrequency 400 -nofilter -trainingThreads 5 -parser /u/nlp/data/lexparser/wsjPCFG.nocompact.simple.ser.gz -trainingIterations 40 -batchSize 25 -model /scr/nlp/data/dvparser/wsj/wsj.combine.v2.ser.gz -unkWord "*UNK*" -dvCombineCategories > /scr/nlp/data/dvparser/wsj/wsj.combine.v2.out 2>&1 &
*/
public static void main(String[] args) throws IOException, ClassNotFoundException {
if (args.length == 0) {
help();
System.exit(2);
}
log.info("Running DVParser with arguments:");
for (String arg : args) {
log.info(" " + arg);
}
log.info();
String parserPath = null;
String trainTreebankPath = null;
FileFilter trainTreebankFilter = null;
String cachedTrainTreesPath = null;
boolean runGradientCheck = false;
boolean runTraining = false;
String testTreebankPath = null;
FileFilter testTreebankFilter = null;
String initialModelPath = null;
String modelPath = null;
boolean filter = true;
String resultsRecordPath = null;
List<String> unusedArgs = new ArrayList<>();
// These parameters can be null or 0 if the model was not
// serialized with the new parameters. Setting the options at the
// command line will override these defaults.
// TODO: if/when we integrate back into the main branch and
// rebuild models, we can get rid of this
List<String> argsWithDefaults = new ArrayList<>(Arrays.asList(new String[] {
"-wordVectorFile", Options.LexOptions.DEFAULT_WORD_VECTOR_FILE,
"-dvKBest", Integer.toString(TrainOptions.DEFAULT_K_BEST),
"-batchSize", Integer.toString(TrainOptions.DEFAULT_BATCH_SIZE),
"-trainingIterations", Integer.toString(TrainOptions.DEFAULT_TRAINING_ITERATIONS),
"-qnIterationsPerBatch", Integer.toString(TrainOptions.DEFAULT_QN_ITERATIONS_PER_BATCH),
"-regCost", Double.toString(TrainOptions.DEFAULT_REGCOST),
"-learningRate", Double.toString(TrainOptions.DEFAULT_LEARNING_RATE),
"-deltaMargin", Double.toString(TrainOptions.DEFAULT_DELTA_MARGIN),
"-unknownNumberVector",
"-unknownDashedWordVectors",
"-unknownCapsVector",
"-unknownchinesepercentvector",
"-unknownchinesenumbervector",
"-unknownchineseyearvector",
"-unkWord", "*UNK*",
"-transformMatrixType", "DIAGONAL",
"-scalingForInit", Double.toString(TrainOptions.DEFAULT_SCALING_FOR_INIT),
"-trainWordVectors" }));
argsWithDefaults.addAll(Arrays.asList(args));
args = argsWithDefaults.toArray(new String[argsWithDefaults.size()]);
for (int argIndex = 0; argIndex < args.length; ) {
if (args[argIndex].equalsIgnoreCase("-parser")) {
parserPath = args[argIndex + 1];
argIndex += 2;
} else if (args[argIndex].equalsIgnoreCase("-testTreebank")) {
Pair<String, FileFilter> treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-testTreebank");
argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
testTreebankPath = treebankDescription.first();
testTreebankFilter = treebankDescription.second();
} else if (args[argIndex].equalsIgnoreCase("-treebank")) {
Pair<String, FileFilter> treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-treebank");
argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
trainTreebankPath = treebankDescription.first();
trainTreebankFilter = treebankDescription.second();
} else if (args[argIndex].equalsIgnoreCase("-cachedTrees")) {
cachedTrainTreesPath = args[argIndex + 1];
argIndex += 2;
} else if (args[argIndex].equalsIgnoreCase("-runGradientCheck")) {
runGradientCheck = true;
argIndex++;
} else if (args[argIndex].equalsIgnoreCase("-train")) {
runTraining = true;
argIndex++;
} else if (args[argIndex].equalsIgnoreCase("-model")) {
modelPath = args[argIndex + 1];
argIndex += 2;
} else if (args[argIndex].equalsIgnoreCase("-nofilter")) {
filter = false;
argIndex++;
} else if (args[argIndex].equalsIgnoreCase("-continueTraining")) {
runTraining = true;
filter = false;
initialModelPath = args[argIndex + 1];
argIndex += 2;
} else if (args[argIndex].equalsIgnoreCase("-resultsRecord")) {
resultsRecordPath = args[argIndex + 1];
argIndex += 2;
} else {
unusedArgs.add(args[argIndex++]);
}
}
if (parserPath == null && modelPath == null) {
throw new IllegalArgumentException("Must supply either a base parser model with -parser or a serialized DVParser with -model");
}
if (!runTraining && modelPath == null && !runGradientCheck) {
throw new IllegalArgumentException("Need to either train a new model, run the gradient check or specify a model to load with -model");
}
String[] newArgs = unusedArgs.toArray(new String[unusedArgs.size()]);
DVParser dvparser = null;
LexicalizedParser lexparser = null;
if (initialModelPath != null) {
lexparser = LexicalizedParser.loadModel(initialModelPath, newArgs);
DVModel model = getModelFromLexicalizedParser(lexparser);
dvparser = new DVParser(model, lexparser);
} else if (runTraining || runGradientCheck) {
lexparser = LexicalizedParser.loadModel(parserPath, newArgs);
dvparser = new DVParser(lexparser);
} else if (modelPath != null) {
lexparser = LexicalizedParser.loadModel(modelPath, newArgs);
DVModel model = getModelFromLexicalizedParser(lexparser);
dvparser = new DVParser(model, lexparser);
}
List<Tree> trainSentences = new ArrayList<>();
IdentityHashMap<Tree, byte[]> trainCompressedParses = Generics.newIdentityHashMap();
if (cachedTrainTreesPath != null) {
for (String path : cachedTrainTreesPath.split(",")) {
List<Pair<Tree, byte[]>> cache = IOUtils.readObjectFromFile(path);
for (Pair<Tree, byte[]> pair : cache) {
trainSentences.add(pair.first());
trainCompressedParses.put(pair.first(), pair.second());
}
log.info("Read in " + cache.size() + " trees from " + path);
}
}
if (trainTreebankPath != null) {
// TODO: make the transformer a member of the model?
TreeTransformer transformer = buildTrainTransformer(dvparser.getOp());
Treebank treebank = dvparser.getOp().tlpParams.memoryTreebank();
treebank.loadPath(trainTreebankPath, trainTreebankFilter);
treebank = treebank.transform(transformer);
log.info("Read in " + treebank.size() + " trees from " + trainTreebankPath);
CacheParseHypotheses cacher = new CacheParseHypotheses(dvparser.parser);
CacheParseHypotheses.CacheProcessor processor = new CacheParseHypotheses.CacheProcessor(cacher, lexparser, dvparser.op.trainOptions.dvKBest, transformer);
for (Tree tree : treebank) {
trainSentences.add(tree);
trainCompressedParses.put(tree, processor.process(tree).second);
//System.out.println(tree);
}
log.info("Finished parsing " + treebank.size() + " trees, getting " + dvparser.op.trainOptions.dvKBest + " hypotheses each");
}
if ((runTraining || runGradientCheck) && filter) {
log.info("Filtering rules for the given training set");
dvparser.dvModel.setRulesForTrainingSet(trainSentences, trainCompressedParses);
log.info("Done filtering rules; " + dvparser.dvModel.numBinaryMatrices + " binary matrices, " + dvparser.dvModel.numUnaryMatrices + " unary matrices, " + dvparser.dvModel.wordVectors.size() + " word vectors");
}
//dvparser.dvModel.printAllMatrices();
Treebank testTreebank = null;
if (testTreebankPath != null) {
log.info("Reading in trees from " + testTreebankPath);
if (testTreebankFilter != null) {
log.info("Filtering on " + testTreebankFilter);
}
testTreebank = dvparser.getOp().tlpParams.memoryTreebank();
testTreebank.loadPath(testTreebankPath, testTreebankFilter);
log.info("Read in " + testTreebank.size() + " trees for testing");
}
// runGradientCheck= true;
if (runGradientCheck) {
log.info("Running gradient check on " + trainSentences.size() + " trees");
dvparser.runGradientCheck(trainSentences, trainCompressedParses);
}
if (runTraining) {
log.info("Training the RNN parser");
log.info("Current train options: " + dvparser.getOp().trainOptions);
dvparser.train(trainSentences, trainCompressedParses, testTreebank, modelPath, resultsRecordPath);
if (modelPath != null) {
dvparser.saveModel(modelPath);
}
}
if (testTreebankPath != null) {
EvaluateTreebank evaluator = new EvaluateTreebank(dvparser.attachModelToLexicalizedParser());
evaluator.testOnTreebank(testTreebank);
}
log.info("Successfully ran DVParser");
}
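The -cachedTrees files read by this method are serialized List<Pair<Tree, byte[]>> objects pairing each gold tree with its compressed k-best parses. A minimal sketch of how such a cache file might be produced, reusing the processor from the -treebank branch above (the output path is hypothetical, and IOUtils.writeObjectToFile is assumed to serialize any Serializable object to a .ser.gz file):

List<Pair<Tree, byte[]>> cache = new ArrayList<>();
for (Tree tree : treebank) {
// pair each gold tree with its compressed top-k hypotheses
cache.add(processor.process(tree));
}
IOUtils.writeObjectToFile(cache, "/path/to/cached.train.ser.gz");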
Use of edu.stanford.nlp.parser.lexparser.EvaluateTreebank in project CoreNLP by stanfordnlp.
The class DVParser, method train.
public void train(List<Tree> sentences, IdentityHashMap<Tree, byte[]> compressedParses, Treebank testTreebank, String modelPath, String resultsRecordPath) throws IOException {
// process:
// we come up with a cost and a derivative for the model
// we always use the gold tree as the example to train towards
// every time through, we will look at the top N trees from
// the LexicalizedParser and pick the best one according to
// our model (at the start, this is essentially random)
// we use QN to minimize the cost function for the model
// to do this minimization, we turn all of the matrices in the
// DVModel into one big Theta, which is the set of variables to
// be optimized by the QN.
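// A sketch of the flattening described above (assumed, not the verbatim
// implementation): theta is a single double[] made by concatenating every
// binary matrix, unary matrix and word vector in the DVModel, so the QN
// minimizer can treat cost(theta) as an ordinary function of one vector.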
Timing timing = new Timing();
long maxTrainTimeMillis = op.trainOptions.maxTrainTimeSeconds * 1000;
int batchCount = 0;
int debugCycle = 0;
double bestLabelF1 = 0.0;
if (op.trainOptions.useContextWords) {
for (Tree tree : sentences) {
Trees.convertToCoreLabels(tree);
tree.setSpans();
}
}
// for AdaGrad
double[] sumGradSquare = new double[dvModel.totalParamSize()];
Arrays.fill(sumGradSquare, 1.0);
int numBatches = sentences.size() / op.trainOptions.batchSize + 1;
log.info("Training on " + sentences.size() + " trees in " + numBatches + " batches");
log.info("Times through each training batch: " + op.trainOptions.trainingIterations);
log.info("QN iterations per batch: " + op.trainOptions.qnIterationsPerBatch);
for (int iter = 0; iter < op.trainOptions.trainingIterations; ++iter) {
List<Tree> shuffledSentences = new ArrayList<>(sentences);
Collections.shuffle(shuffledSentences, dvModel.rand);
for (int batch = 0; batch < numBatches; ++batch) {
++batchCount;
// This did not help performance
//log.info("Setting AdaGrad's sum of squares to 1...");
//Arrays.fill(sumGradSquare, 1.0);
log.info("======================================");
log.info("Iteration " + iter + " batch " + batch);
// Each batch will be of the specified batch size, except the
// last batch will include any leftover trees at the end of
// the list
int startTree = batch * op.trainOptions.batchSize;
int endTree = (batch + 1) * op.trainOptions.batchSize;
if (endTree > shuffledSentences.size()) {
endTree = shuffledSentences.size();
}
executeOneTrainingBatch(shuffledSentences.subList(startTree, endTree), compressedParses, sumGradSquare);
long totalElapsed = timing.report();
log.info("Finished iteration " + iter + " batch " + batch + "; total training time " + totalElapsed + " ms");
if (maxTrainTimeMillis > 0 && totalElapsed > maxTrainTimeMillis) {
// no need to debug output, we're done now
break;
}
if (op.trainOptions.debugOutputFrequency > 0 && batchCount % op.trainOptions.debugOutputFrequency == 0) {
log.info("Finished " + batchCount + " total batches, running evaluation cycle");
// Time for debugging output!
double tagF1 = 0.0;
double labelF1 = 0.0;
if (testTreebank != null) {
EvaluateTreebank evaluator = new EvaluateTreebank(attachModelToLexicalizedParser());
evaluator.testOnTreebank(testTreebank);
labelF1 = evaluator.getLBScore();
tagF1 = evaluator.getTagScore();
if (labelF1 > bestLabelF1) {
bestLabelF1 = labelF1;
}
log.info("Best label f1 on dev set so far: " + NF.format(bestLabelF1));
}
String tempName = null;
if (modelPath != null) {
tempName = modelPath;
if (modelPath.endsWith(".ser.gz")) {
tempName = modelPath.substring(0, modelPath.length() - 7) + "-" + FILENAME.format(debugCycle) + "-" + NF.format(labelF1) + ".ser.gz";
}
saveModel(tempName);
}
String statusLine = ("CHECKPOINT:" + " iteration " + iter + " batch " + batch + " labelF1 " + NF.format(labelF1) + " tagF1 " + NF.format(tagF1) + " bestLabelF1 " + NF.format(bestLabelF1) + " model " + tempName + " " + op.trainOptions + " word vectors: " + op.lexOptions.wordVectorFile + " numHid: " + op.lexOptions.numHid);
log.info(statusLine);
if (resultsRecordPath != null) {
// append
try (FileWriter fout = new FileWriter(resultsRecordPath, true)) {
fout.write(statusLine);
fout.write("\n");
}
}
++debugCycle;
}
}
long totalElapsed = timing.report();
if (maxTrainTimeMillis > 0 && totalElapsed > maxTrainTimeMillis) {
// no need to debug output, we're done now
log.info("Max training time exceeded, exiting");
break;
}
}
}
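The sumGradSquare array above is AdaGrad's running sum of squared gradients, filled with 1.0 so that early updates are not divided by values near zero. The per-parameter update that executeOneTrainingBatch presumably performs is the standard AdaGrad rule; a minimal self-contained sketch (theta, gradient and learningRate are hypothetical local names):

// standard AdaGrad update over a flat parameter vector
for (int i = 0; i < theta.length; ++i) {
sumGradSquare[i] += gradient[i] * gradient[i];
theta[i] -= learningRate * gradient[i] / Math.sqrt(sumGradSquare[i]);
}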
Use of edu.stanford.nlp.parser.lexparser.EvaluateTreebank in project CoreNLP by stanfordnlp.
The class ShiftReduceParser, method main.
public static void main(String[] args) {
List<String> remainingArgs = Generics.newArrayList();
List<Pair<String, FileFilter>> trainTreebankPath = null;
Pair<String, FileFilter> testTreebankPath = null;
Pair<String, FileFilter> devTreebankPath = null;
String serializedPath = null;
String tlppClass = null;
String continueTraining = null;
for (int argIndex = 0; argIndex < args.length; ) {
if (args[argIndex].equalsIgnoreCase("-trainTreebank")) {
if (trainTreebankPath == null) {
trainTreebankPath = Generics.newArrayList();
}
trainTreebankPath.add(ArgUtils.getTreebankDescription(args, argIndex, "-trainTreebank"));
argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
} else if (args[argIndex].equalsIgnoreCase("-testTreebank")) {
testTreebankPath = ArgUtils.getTreebankDescription(args, argIndex, "-testTreebank");
argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
} else if (args[argIndex].equalsIgnoreCase("-devTreebank")) {
devTreebankPath = ArgUtils.getTreebankDescription(args, argIndex, "-devTreebank");
argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
} else if (args[argIndex].equalsIgnoreCase("-serializedPath") || args[argIndex].equalsIgnoreCase("-model")) {
serializedPath = args[argIndex + 1];
argIndex += 2;
} else if (args[argIndex].equalsIgnoreCase("-tlpp")) {
tlppClass = args[argIndex + 1];
argIndex += 2;
} else if (args[argIndex].equalsIgnoreCase("-continueTraining")) {
continueTraining = args[argIndex + 1];
argIndex += 2;
} else {
remainingArgs.add(args[argIndex]);
++argIndex;
}
}
String[] newArgs = new String[remainingArgs.size()];
newArgs = remainingArgs.toArray(newArgs);
if (trainTreebankPath == null && serializedPath == null) {
throw new IllegalArgumentException("Must specify a treebank to train from with -trainTreebank or a parser to load with -serializedPath");
}
ShiftReduceParser parser = null;
if (trainTreebankPath != null) {
log.info("Training ShiftReduceParser");
log.info("Initial arguments:");
log.info(" " + StringUtils.join(args));
if (continueTraining != null) {
parser = ShiftReduceParser.loadModel(continueTraining, ArrayUtils.concatenate(FORCE_TAGS, newArgs));
} else {
ShiftReduceOptions op = buildTrainingOptions(tlppClass, newArgs);
parser = new ShiftReduceParser(op);
}
parser.train(trainTreebankPath, devTreebankPath, serializedPath);
parser.saveModel(serializedPath);
}
if (serializedPath != null && parser == null) {
parser = ShiftReduceParser.loadModel(serializedPath, ArrayUtils.concatenate(FORCE_TAGS, newArgs));
}
if (testTreebankPath != null) {
log.info("Loading test trees from " + testTreebankPath.first());
Treebank testTreebank = parser.op.tlpParams.memoryTreebank();
testTreebank.loadPath(testTreebankPath.first(), testTreebankPath.second());
log.info("Loaded " + testTreebank.size() + " trees");
EvaluateTreebank evaluator = new EvaluateTreebank(parser.op, null, parser);
evaluator.testOnTreebank(testTreebank);
// log.info("Input tree: " + tree);
// log.info("Debinarized tree: " + query.getBestParse());
// log.info("Parsed binarized tree: " + query.getBestBinarizedParse());
// log.info("Predicted transition sequence: " + query.getBestTransitionSequence());
}
}
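For an evaluate-only run, only the -model and -testTreebank branches of this method matter; a minimal programmatic sketch of that path (paths hypothetical, and assuming parser.op is accessible exactly as it is used in the snippet):

ShiftReduceParser parser = ShiftReduceParser.loadModel("/path/to/englishSR.ser.gz");
Treebank testTreebank = parser.op.tlpParams.memoryTreebank();
testTreebank.loadPath("/path/to/wsj/23", null);
EvaluateTreebank evaluator = new EvaluateTreebank(parser.op, null, parser);
evaluator.testOnTreebank(testTreebank);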
Use of edu.stanford.nlp.parser.lexparser.EvaluateTreebank in project CoreNLP by stanfordnlp.
The class CombineDVModels, method main.
public static void main(String[] args) throws IOException, ClassNotFoundException {
String modelPath = null;
List<String> baseModelPaths = null;
String testTreebankPath = null;
FileFilter testTreebankFilter = null;
List<String> unusedArgs = new ArrayList<>();
for (int argIndex = 0; argIndex < args.length; ) {
if (args[argIndex].equalsIgnoreCase("-model")) {
modelPath = args[argIndex + 1];
argIndex += 2;
} else if (args[argIndex].equalsIgnoreCase("-testTreebank")) {
Pair<String, FileFilter> treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-testTreebank");
argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
testTreebankPath = treebankDescription.first();
testTreebankFilter = treebankDescription.second();
} else if (args[argIndex].equalsIgnoreCase("-baseModels")) {
argIndex++;
baseModelPaths = new ArrayList<>();
while (argIndex < args.length && args[argIndex].charAt(0) != '-') {
baseModelPaths.add(args[argIndex++]);
}
if (baseModelPaths.size() == 0) {
throw new IllegalArgumentException("Found an argument -baseModels with no actual models named");
}
} else {
unusedArgs.add(args[argIndex++]);
}
}
String[] newArgs = unusedArgs.toArray(new String[unusedArgs.size()]);
LexicalizedParser underlyingParser = null;
Options options = null;
LexicalizedParser combinedParser = null;
if (baseModelPaths != null) {
List<DVModel> dvparsers = new ArrayList<>();
for (String baseModelPath : baseModelPaths) {
log.info("Loading serialized DVParser from " + baseModelPath);
LexicalizedParser dvparser = LexicalizedParser.loadModel(baseModelPath);
Reranker reranker = dvparser.reranker;
if (!(reranker instanceof DVModelReranker)) {
throw new IllegalArgumentException("Expected parsers with DVModel embedded");
}
dvparsers.add(((DVModelReranker) reranker).getModel());
if (underlyingParser == null) {
underlyingParser = dvparser;
options = underlyingParser.getOp();
// TODO: other parser's options?
options.setOptions(newArgs);
}
log.info("... done");
}
combinedParser = LexicalizedParser.copyLexicalizedParser(underlyingParser);
CombinedDVModelReranker reranker = new CombinedDVModelReranker(options, dvparsers);
combinedParser.reranker = reranker;
combinedParser.saveParserToSerialized(modelPath);
} else {
throw new IllegalArgumentException("Need to specify -model to load an already prepared CombinedParser");
}
Treebank testTreebank = null;
if (testTreebankPath != null) {
log.info("Reading in trees from " + testTreebankPath);
if (testTreebankFilter != null) {
log.info("Filtering on " + testTreebankFilter);
}
testTreebank = combinedParser.getOp().tlpParams.memoryTreebank();
testTreebank.loadPath(testTreebankPath, testTreebankFilter);
log.info("Read in " + testTreebank.size() + " trees for testing");
EvaluateTreebank evaluator = new EvaluateTreebank(combinedParser.getOp(), null, combinedParser);
evaluator.testOnTreebank(testTreebank);
}
}
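As CrossValidateTestOptions shows above, EvaluateTreebank exposes its scores after testOnTreebank has run; to log them here one might add, after the testOnTreebank call (a sketch, reusing getLBScore and getTagScore from the first snippet):

log.info("Combined model labeled F1: " + evaluator.getLBScore());
log.info("Combined model tagging score: " + evaluator.getTagScore());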