Use of edu.stanford.nlp.parser.lexparser.LexicalizedParser in project CoreNLP by stanfordnlp.
The class ParseAndPrintMatrices, method main.
public static void main(String[] args) throws IOException {
  String modelPath = null;
  String outputPath = null;
  String inputPath = null;
  String testTreebankPath = null;
  FileFilter testTreebankFilter = null;
  List<String> unusedArgs = Generics.newArrayList();
  for (int argIndex = 0; argIndex < args.length; ) {
    if (args[argIndex].equalsIgnoreCase("-model")) {
      modelPath = args[argIndex + 1];
      argIndex += 2;
    } else if (args[argIndex].equalsIgnoreCase("-output")) {
      outputPath = args[argIndex + 1];
      argIndex += 2;
    } else if (args[argIndex].equalsIgnoreCase("-input")) {
      inputPath = args[argIndex + 1];
      argIndex += 2;
    } else if (args[argIndex].equalsIgnoreCase("-testTreebank")) {
      Pair<String, FileFilter> treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-testTreebank");
      argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
      testTreebankPath = treebankDescription.first();
      testTreebankFilter = treebankDescription.second();
    } else {
      // Anything unrecognized is passed through to the parser loader
      unusedArgs.add(args[argIndex++]);
    }
  }
  String[] newArgs = unusedArgs.toArray(new String[unusedArgs.size()]);
  LexicalizedParser parser = LexicalizedParser.loadModel(modelPath, newArgs);
  DVModel model = DVParser.getModelFromLexicalizedParser(parser);
  File outputFile = new File(outputPath);
  FileSystem.checkNotExistsOrFail(outputFile);
  FileSystem.mkdirOrFail(outputFile);
  int count = 0;
  if (inputPath != null) {
    Reader input = new BufferedReader(new FileReader(inputPath));
    DocumentPreprocessor processor = new DocumentPreprocessor(input);
    for (List<HasWord> sentence : processor) {
      // index from 1
      count++;
      ParserQuery pq = parser.parserQuery();
      if (!(pq instanceof RerankingParserQuery)) {
        throw new IllegalArgumentException("Expected a RerankingParserQuery");
      }
      RerankingParserQuery rpq = (RerankingParserQuery) pq;
      if (!rpq.parse(sentence)) {
        throw new RuntimeException("Unparsable sentence: " + sentence);
      }
      RerankerQuery reranker = rpq.rerankerQuery();
      if (!(reranker instanceof DVModelReranker.Query)) {
        throw new IllegalArgumentException("Expected a DVModelReranker");
      }
      DeepTree deepTree = ((DVModelReranker.Query) reranker).getDeepTrees().get(0);
      IdentityHashMap<Tree, SimpleMatrix> vectors = deepTree.getVectors();
      for (Map.Entry<Tree, SimpleMatrix> entry : vectors.entrySet()) {
        log.info(entry.getKey() + " " + entry.getValue());
      }
      // One output file per sentence: the sentence text, its best tree,
      // the word vectors, then the node vectors from the root down
      FileWriter fout = new FileWriter(outputPath + File.separator + "sentence" + count + ".txt");
      BufferedWriter bout = new BufferedWriter(fout);
      bout.write(SentenceUtils.listToString(sentence));
      bout.newLine();
      bout.write(deepTree.getTree().toString());
      bout.newLine();
      for (HasWord word : sentence) {
        outputMatrix(bout, model.getWordVector(word.word()));
      }
      Tree rootTree = findRootTree(vectors);
      outputTreeMatrices(bout, rootTree, vectors);
      bout.flush();
      fout.close();
    }
  }
}
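For reference, below is a minimal standalone sketch of the same vector-extraction pipeline. The VectorDemo class name and the model path are hypothetical, and it assumes the serialized parser carries a DVModelReranker, so that parserQuery() returns a RerankingParserQuery as the method above requires.

import java.util.IdentityHashMap;
import java.util.List;

import org.ejml.simple.SimpleMatrix;

import edu.stanford.nlp.ling.HasWord;
import edu.stanford.nlp.ling.SentenceUtils;
import edu.stanford.nlp.parser.dvparser.DVModelReranker;
import edu.stanford.nlp.parser.lexparser.LexicalizedParser;
import edu.stanford.nlp.parser.lexparser.RerankingParserQuery;
import edu.stanford.nlp.trees.DeepTree;
import edu.stanford.nlp.trees.Tree;

public class VectorDemo {
  public static void main(String[] args) {
    // Hypothetical path: any serialized LexicalizedParser with a DVModel reranker attached
    LexicalizedParser parser = LexicalizedParser.loadModel("/path/to/dvmodel.ser.gz");
    List<HasWord> sentence = SentenceUtils.toWordList("This", "is", "a", "test", ".");
    RerankingParserQuery rpq = (RerankingParserQuery) parser.parserQuery();
    if (rpq.parse(sentence)) {
      // The first deep tree is the reranker's best-scoring hypothesis
      DeepTree deepTree = ((DVModelReranker.Query) rpq.rerankerQuery()).getDeepTrees().get(0);
      IdentityHashMap<Tree, SimpleMatrix> vectors = deepTree.getVectors();
      System.out.println(vectors.size() + " node vectors for: " + SentenceUtils.listToString(sentence));
    }
  }
}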
Use of edu.stanford.nlp.parser.lexparser.LexicalizedParser in project CoreNLP by stanfordnlp.
The class DVParser, method saveModel.
public void saveModel(String filename) {
  log.info("Saving serialized model to " + filename);
  LexicalizedParser newParser = attachModelToLexicalizedParser();
  newParser.saveParserToSerialized(filename);
  log.info("... done");
}
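A hedged usage sketch: because saveModel attaches the DVModel to a copy of the underlying LexicalizedParser before serializing, the saved file reloads through the standard loader. The SaveModelDemo class and the path argument are placeholders.

import edu.stanford.nlp.parser.dvparser.DVParser;
import edu.stanford.nlp.parser.lexparser.LexicalizedParser;

public class SaveModelDemo {
  // Serialize a trained DVParser, then reload it as a plain
  // LexicalizedParser; the saved file embeds the DVModelReranker
  public static LexicalizedParser saveAndReload(DVParser dvparser, String path) {
    dvparser.saveModel(path);
    return LexicalizedParser.loadModel(path);
  }
}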
Use of edu.stanford.nlp.parser.lexparser.LexicalizedParser in project CoreNLP by stanfordnlp.
The class DVParser, method attachModelToLexicalizedParser.
public LexicalizedParser attachModelToLexicalizedParser() {
  LexicalizedParser newParser = LexicalizedParser.copyLexicalizedParser(parser);
  DVModelReranker reranker = new DVModelReranker(dvModel);
  newParser.reranker = reranker;
  return newParser;
}
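For context, a sketch of putting the returned parser to work; the AttachDemo class is hypothetical, and it uses the parse(String) convenience method inherited from ParserGrammar, which tokenizes the text before parsing.

import edu.stanford.nlp.parser.dvparser.DVParser;
import edu.stanford.nlp.parser.lexparser.LexicalizedParser;
import edu.stanford.nlp.trees.Tree;

public class AttachDemo {
  // Parse plain text with the reranking parser produced above; the
  // attached DVModelReranker rescores the k-best list before the
  // single best tree is returned
  public static Tree parseWithReranker(DVParser dvparser, String sentence) {
    LexicalizedParser reranking = dvparser.attachModelToLexicalizedParser();
    return reranking.parse(sentence);
  }
}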
Use of edu.stanford.nlp.parser.lexparser.LexicalizedParser in project CoreNLP by stanfordnlp.
The class DVParser, method main.
/**
 * An example command line for training a new parser:
 * <br>
 * nohup java -mx6g edu.stanford.nlp.parser.dvparser.DVParser -cachedTrees /scr/nlp/data/dvparser/wsj/cached.wsj.train.simple.ser.gz -train -testTreebank /afs/ir/data/linguistic-data/Treebank/3/parsed/mrg/wsj/22 2200-2219 -debugOutputFrequency 400 -nofilter -trainingThreads 5 -parser /u/nlp/data/lexparser/wsjPCFG.nocompact.simple.ser.gz -trainingIterations 40 -batchSize 25 -model /scr/nlp/data/dvparser/wsj/wsj.combine.v2.ser.gz -unkWord "*UNK*" -dvCombineCategories > /scr/nlp/data/dvparser/wsj/wsj.combine.v2.out 2>&1 &
 */
public static void main(String[] args) throws IOException, ClassNotFoundException {
  if (args.length == 0) {
    help();
    System.exit(2);
  }
  log.info("Running DVParser with arguments:");
  for (String arg : args) {
    log.info(" " + arg);
  }
  log.info();
  String parserPath = null;
  String trainTreebankPath = null;
  FileFilter trainTreebankFilter = null;
  String cachedTrainTreesPath = null;
  boolean runGradientCheck = false;
  boolean runTraining = false;
  String testTreebankPath = null;
  FileFilter testTreebankFilter = null;
  String initialModelPath = null;
  String modelPath = null;
  boolean filter = true;
  String resultsRecordPath = null;
  List<String> unusedArgs = new ArrayList<>();
  // These parameters can be null or 0 if the model was not
  // serialized with the new parameters. Setting the options at the
  // command line will override these defaults.
  // TODO: if/when we integrate back into the main branch and
  // rebuild models, we can get rid of this
  List<String> argsWithDefaults = new ArrayList<>(Arrays.asList(new String[] {
      "-wordVectorFile", Options.LexOptions.DEFAULT_WORD_VECTOR_FILE,
      "-dvKBest", Integer.toString(TrainOptions.DEFAULT_K_BEST),
      "-batchSize", Integer.toString(TrainOptions.DEFAULT_BATCH_SIZE),
      "-trainingIterations", Integer.toString(TrainOptions.DEFAULT_TRAINING_ITERATIONS),
      "-qnIterationsPerBatch", Integer.toString(TrainOptions.DEFAULT_QN_ITERATIONS_PER_BATCH),
      "-regCost", Double.toString(TrainOptions.DEFAULT_REGCOST),
      "-learningRate", Double.toString(TrainOptions.DEFAULT_LEARNING_RATE),
      "-deltaMargin", Double.toString(TrainOptions.DEFAULT_DELTA_MARGIN),
      "-unknownNumberVector",
      "-unknownDashedWordVectors",
      "-unknownCapsVector",
      "-unknownchinesepercentvector",
      "-unknownchinesenumbervector",
      "-unknownchineseyearvector",
      "-unkWord", "*UNK*",
      "-transformMatrixType", "DIAGONAL",
      "-scalingForInit", Double.toString(TrainOptions.DEFAULT_SCALING_FOR_INIT),
      "-trainWordVectors" }));
  argsWithDefaults.addAll(Arrays.asList(args));
  args = argsWithDefaults.toArray(new String[argsWithDefaults.size()]);
  for (int argIndex = 0; argIndex < args.length; ) {
    if (args[argIndex].equalsIgnoreCase("-parser")) {
      parserPath = args[argIndex + 1];
      argIndex += 2;
    } else if (args[argIndex].equalsIgnoreCase("-testTreebank")) {
      Pair<String, FileFilter> treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-testTreebank");
      argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
      testTreebankPath = treebankDescription.first();
      testTreebankFilter = treebankDescription.second();
    } else if (args[argIndex].equalsIgnoreCase("-treebank")) {
      Pair<String, FileFilter> treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-treebank");
      argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
      trainTreebankPath = treebankDescription.first();
      trainTreebankFilter = treebankDescription.second();
    } else if (args[argIndex].equalsIgnoreCase("-cachedTrees")) {
      cachedTrainTreesPath = args[argIndex + 1];
      argIndex += 2;
    } else if (args[argIndex].equalsIgnoreCase("-runGradientCheck")) {
      runGradientCheck = true;
      argIndex++;
    } else if (args[argIndex].equalsIgnoreCase("-train")) {
      runTraining = true;
      argIndex++;
    } else if (args[argIndex].equalsIgnoreCase("-model")) {
      modelPath = args[argIndex + 1];
      argIndex += 2;
    } else if (args[argIndex].equalsIgnoreCase("-nofilter")) {
      filter = false;
      argIndex++;
    } else if (args[argIndex].equalsIgnoreCase("-continueTraining")) {
      runTraining = true;
      filter = false;
      initialModelPath = args[argIndex + 1];
      argIndex += 2;
    } else if (args[argIndex].equalsIgnoreCase("-resultsRecord")) {
      resultsRecordPath = args[argIndex + 1];
      argIndex += 2;
    } else {
      unusedArgs.add(args[argIndex++]);
    }
  }
  if (parserPath == null && modelPath == null) {
    throw new IllegalArgumentException("Must supply either a base parser model with -parser or a serialized DVParser with -model");
  }
  if (!runTraining && modelPath == null && !runGradientCheck) {
    throw new IllegalArgumentException("Need to either train a new model, run the gradient check or specify a model to load with -model");
  }
  String[] newArgs = unusedArgs.toArray(new String[unusedArgs.size()]);
  DVParser dvparser = null;
  LexicalizedParser lexparser = null;
  if (initialModelPath != null) {
    lexparser = LexicalizedParser.loadModel(initialModelPath, newArgs);
    DVModel model = getModelFromLexicalizedParser(lexparser);
    dvparser = new DVParser(model, lexparser);
  } else if (runTraining || runGradientCheck) {
    lexparser = LexicalizedParser.loadModel(parserPath, newArgs);
    dvparser = new DVParser(lexparser);
  } else if (modelPath != null) {
    lexparser = LexicalizedParser.loadModel(modelPath, newArgs);
    DVModel model = getModelFromLexicalizedParser(lexparser);
    dvparser = new DVParser(model, lexparser);
  }
  List<Tree> trainSentences = new ArrayList<>();
  IdentityHashMap<Tree, byte[]> trainCompressedParses = Generics.newIdentityHashMap();
  if (cachedTrainTreesPath != null) {
    for (String path : cachedTrainTreesPath.split(",")) {
      List<Pair<Tree, byte[]>> cache = IOUtils.readObjectFromFile(path);
      for (Pair<Tree, byte[]> pair : cache) {
        trainSentences.add(pair.first());
        trainCompressedParses.put(pair.first(), pair.second());
      }
      log.info("Read in " + cache.size() + " trees from " + path);
    }
  }
  if (trainTreebankPath != null) {
    // TODO: make the transformer a member of the model?
    TreeTransformer transformer = buildTrainTransformer(dvparser.getOp());
    Treebank treebank = dvparser.getOp().tlpParams.memoryTreebank();
    treebank.loadPath(trainTreebankPath, trainTreebankFilter);
    treebank = treebank.transform(transformer);
    log.info("Read in " + treebank.size() + " trees from " + trainTreebankPath);
    CacheParseHypotheses cacher = new CacheParseHypotheses(dvparser.parser);
    CacheParseHypotheses.CacheProcessor processor = new CacheParseHypotheses.CacheProcessor(cacher, lexparser, dvparser.op.trainOptions.dvKBest, transformer);
    for (Tree tree : treebank) {
      trainSentences.add(tree);
      trainCompressedParses.put(tree, processor.process(tree).second);
      //System.out.println(tree);
    }
    log.info("Finished parsing " + treebank.size() + " trees, getting " + dvparser.op.trainOptions.dvKBest + " hypotheses each");
  }
  if ((runTraining || runGradientCheck) && filter) {
    log.info("Filtering rules for the given training set");
    dvparser.dvModel.setRulesForTrainingSet(trainSentences, trainCompressedParses);
    log.info("Done filtering rules; " + dvparser.dvModel.numBinaryMatrices + " binary matrices, " + dvparser.dvModel.numUnaryMatrices + " unary matrices, " + dvparser.dvModel.wordVectors.size() + " word vectors");
  }
  //dvparser.dvModel.printAllMatrices();
  Treebank testTreebank = null;
  if (testTreebankPath != null) {
    log.info("Reading in trees from " + testTreebankPath);
    if (testTreebankFilter != null) {
      log.info("Filtering on " + testTreebankFilter);
    }
    testTreebank = dvparser.getOp().tlpParams.memoryTreebank();
    testTreebank.loadPath(testTreebankPath, testTreebankFilter);
    log.info("Read in " + testTreebank.size() + " trees for testing");
  }
  // runGradientCheck= true;
  if (runGradientCheck) {
    log.info("Running gradient check on " + trainSentences.size() + " trees");
    dvparser.runGradientCheck(trainSentences, trainCompressedParses);
  }
  if (runTraining) {
    log.info("Training the RNN parser");
    log.info("Current train options: " + dvparser.getOp().trainOptions);
    dvparser.train(trainSentences, trainCompressedParses, testTreebank, modelPath, resultsRecordPath);
    if (modelPath != null) {
      dvparser.saveModel(modelPath);
    }
  }
  if (testTreebankPath != null) {
    EvaluateTreebank evaluator = new EvaluateTreebank(dvparser.attachModelToLexicalizedParser());
    evaluator.testOnTreebank(testTreebank);
  }
  log.info("Successfully ran DVParser");
}
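The javadoc above shows a full training run; as a complement, here is a minimal sketch of the evaluation-only path through this main method, with placeholder paths and a hypothetical EvalOnlyDemo wrapper. With -model and -testTreebank but neither -train nor -runGradientCheck, the model is loaded and handed straight to EvaluateTreebank.

public class EvalOnlyDemo {
  public static void main(String[] args) throws Exception {
    // Hypothetical paths: a previously saved DVParser model and a treebank directory
    edu.stanford.nlp.parser.dvparser.DVParser.main(new String[] {
        "-model", "/path/to/wsj.dvmodel.ser.gz",
        "-testTreebank", "/path/to/treebank/dir"
    });
  }
}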
Use of edu.stanford.nlp.parser.lexparser.LexicalizedParser in project CoreNLP by stanfordnlp.
The class UpdateParserOptions, method main.
public static void main(String[] args) {
  String input = null;
  String output = null;
  List<String> extraArgs = Generics.newArrayList();
  for (int argIndex = 0; argIndex < args.length; ) {
    if (args[argIndex].equalsIgnoreCase("-input")) {
      input = args[argIndex + 1];
      argIndex += 2;
    } else if (args[argIndex].equalsIgnoreCase("-output")) {
      output = args[argIndex + 1];
      argIndex += 2;
    } else {
      extraArgs.add(args[argIndex++]);
    }
  }
  LexicalizedParser parser = LexicalizedParser.loadModel(input, extraArgs);
  parser.saveParserToSerialized(output);
}
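A hedged usage sketch: the unrecognized arguments collected in extraArgs are handed to loadModel, which applies them as parser options before the model is written back out. The UpdateOptionsDemo class, both paths, and the -maxLength setting below are placeholders; the import for UpdateParserOptions is omitted since it depends on the class's package in the repository.

public class UpdateOptionsDemo {
  public static void main(String[] args) {
    // Reserialize an existing model with a changed option; UpdateParserOptions
    // is the class shown above, assumed visible from this package
    UpdateParserOptions.main(new String[] {
        "-input", "/path/to/old.ser.gz",
        "-output", "/path/to/updated.ser.gz",
        "-maxLength", "80"
    });
  }
}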