use of edu.stanford.nlp.trees.DeepTree in project CoreNLP by stanfordnlp.
the class DVParserCostAndGradient method calculate.
/**
 * Fills {@code value} and {@code derivative} for the parameter vector {@code theta}.
 * <p>
 * Every tree in {@code trainingBatch} is scored through a {@link MulticoreWrapper}, yielding a
 * (gold, best hypothesis) pair of {@link DeepTree}s. For each pair where the best hypothesis
 * outscores the gold tree by more than 0.00001, the score gap is added to the cost and both
 * trees are backpropagated into separate accumulators. The cost and gradient are then divided
 * by the batch size, and an L2 term weighted by {@code op.trainOptions.regCost} is added to both.
 *
 * @param theta flattened parameters handed to {@code dvModel.vectorToParams}; its length is also
 *              the length of the gradient that is produced
 */
public void calculate(double[] theta) {
// Hand theta to the model before anything is scored.
dvModel.vectorToParams(theta);
double localValue = 0.0;
double[] localDerivative = new double[theta.length];
// Accumulators come in pairs: the "G" maps receive the backprop results of the gold trees,
// the "B" maps those of the best-scoring hypothesis trees (see the two backpropDerivative calls).
TwoDimensionalMap<String, String, SimpleMatrix> binaryW_dfsG, binaryW_dfsB;
binaryW_dfsG = TwoDimensionalMap.treeMap();
binaryW_dfsB = TwoDimensionalMap.treeMap();
TwoDimensionalMap<String, String, SimpleMatrix> binaryScoreDerivativesG, binaryScoreDerivativesB;
binaryScoreDerivativesG = TwoDimensionalMap.treeMap();
binaryScoreDerivativesB = TwoDimensionalMap.treeMap();
Map<String, SimpleMatrix> unaryW_dfsG, unaryW_dfsB;
unaryW_dfsG = new TreeMap<>();
unaryW_dfsB = new TreeMap<>();
Map<String, SimpleMatrix> unaryScoreDerivativesG, unaryScoreDerivativesB;
unaryScoreDerivativesG = new TreeMap<>();
unaryScoreDerivativesB = new TreeMap<>();
Map<String, SimpleMatrix> wordVectorDerivativesG = new TreeMap<>();
Map<String, SimpleMatrix> wordVectorDerivativesB = new TreeMap<>();
// One fresh accumulator per binary transform, sized like the transform itself; the matching
// score-derivative accumulator is a 1 x numRows row vector.
for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : dvModel.binaryTransform) {
int numRows = entry.getValue().numRows();
int numCols = entry.getValue().numCols();
binaryW_dfsG.put(entry.getFirstKey(), entry.getSecondKey(), new SimpleMatrix(numRows, numCols));
binaryW_dfsB.put(entry.getFirstKey(), entry.getSecondKey(), new SimpleMatrix(numRows, numCols));
binaryScoreDerivativesG.put(entry.getFirstKey(), entry.getSecondKey(), new SimpleMatrix(1, numRows));
binaryScoreDerivativesB.put(entry.getFirstKey(), entry.getSecondKey(), new SimpleMatrix(1, numRows));
}
// Same allocation scheme for the unary transforms.
for (Map.Entry<String, SimpleMatrix> entry : dvModel.unaryTransform.entrySet()) {
int numRows = entry.getValue().numRows();
int numCols = entry.getValue().numCols();
unaryW_dfsG.put(entry.getKey(), new SimpleMatrix(numRows, numCols));
unaryW_dfsB.put(entry.getKey(), new SimpleMatrix(numRows, numCols));
unaryScoreDerivativesG.put(entry.getKey(), new SimpleMatrix(1, numRows));
unaryScoreDerivativesB.put(entry.getKey(), new SimpleMatrix(1, numRows));
}
// Word-vector accumulators are only allocated when word vectors are being trained; otherwise
// the two maps stay empty and are left out of the flattened gradient further down.
if (op.trainOptions.trainWordVectors) {
for (Map.Entry<String, SimpleMatrix> entry : dvModel.wordVectors.entrySet()) {
int numRows = entry.getValue().numRows();
int numCols = entry.getValue().numCols();
wordVectorDerivativesG.put(entry.getKey(), new SimpleMatrix(numRows, numCols));
wordVectorDerivativesB.put(entry.getKey(), new SimpleMatrix(numRows, numCols));
}
}
// Some optimization methods print a line without a terminating newline, so our
// debugging statements may be misaligned
Timing scoreTiming = new Timing();
scoreTiming.doing("Scoring trees");
int treeNum = 0;
// A ScoringProcessor is run over every training tree by a wrapper configured with
// op.trainOptions.trainingThreads; join() is called before any result is read.
MulticoreWrapper<Tree, Pair<DeepTree, DeepTree>> wrapper = new MulticoreWrapper<>(op.trainOptions.trainingThreads, new ScoringProcessor());
for (Tree tree : trainingBatch) {
wrapper.put(tree);
}
wrapper.join();
scoreTiming.done();
// Drain the results one (gold, best) pair at a time.
while (wrapper.peek()) {
Pair<DeepTree, DeepTree> result = wrapper.poll();
DeepTree goldTree = result.first;
DeepTree bestTree = result.second;
StringBuilder treeDebugLine = new StringBuilder();
Formatter formatter = new Formatter(treeDebugLine);
// No update is needed when the two scores are within 0.00001 of each other or the gold
// tree already outscores the best hypothesis.
boolean isDone = (Math.abs(bestTree.getScore() - goldTree.getScore()) <= 0.00001 || goldTree.getScore() > bestTree.getScore());
String done = isDone ? "done" : "";
formatter.format("Tree %6d Highest tree: %12.4f Correct tree: %12.4f %s", treeNum, bestTree.getScore(), goldTree.getScore(), done);
log.info(treeDebugLine.toString());
if (!isDone) {
// if the gold tree is better than the best hypothesis tree by
// a large enough margin, then the score difference will be 0
// and we ignore the tree
double valueDelta = bestTree.getScore() - goldTree.getScore();
//double valueDelta = Math.max(0.0, - scoreGold + bestScore);
localValue += valueDelta;
// get the context words for this tree - should be the same
// for either goldTree or bestTree
List<String> words = getContextWords(goldTree.getTree());
// The derivatives affected by this tree are only based on the
// nodes present in this tree, eg not all matrix derivatives
// will be affected by this tree
backpropDerivative(goldTree.getTree(), words, goldTree.getVectors(), binaryW_dfsG, unaryW_dfsG, binaryScoreDerivativesG, unaryScoreDerivativesG, wordVectorDerivativesG);
backpropDerivative(bestTree.getTree(), words, bestTree.getVectors(), binaryW_dfsB, unaryW_dfsB, binaryScoreDerivativesB, unaryScoreDerivativesB, wordVectorDerivativesB);
}
++treeNum;
}
// Flatten each family of accumulators into a vector of theta.length. The word-vector
// derivatives are appended only when word vectors are being trained.
double[] localDerivativeGood;
double[] localDerivativeB;
if (op.trainOptions.trainWordVectors) {
localDerivativeGood = NeuralUtils.paramsToVector(theta.length, binaryW_dfsG.valueIterator(), unaryW_dfsG.values().iterator(), binaryScoreDerivativesG.valueIterator(), unaryScoreDerivativesG.values().iterator(), wordVectorDerivativesG.values().iterator());
localDerivativeB = NeuralUtils.paramsToVector(theta.length, binaryW_dfsB.valueIterator(), unaryW_dfsB.values().iterator(), binaryScoreDerivativesB.valueIterator(), unaryScoreDerivativesB.values().iterator(), wordVectorDerivativesB.values().iterator());
} else {
localDerivativeGood = NeuralUtils.paramsToVector(theta.length, binaryW_dfsG.valueIterator(), unaryW_dfsG.values().iterator(), binaryScoreDerivativesG.valueIterator(), unaryScoreDerivativesG.values().iterator());
localDerivativeB = NeuralUtils.paramsToVector(theta.length, binaryW_dfsB.valueIterator(), unaryW_dfsB.values().iterator(), binaryScoreDerivativesB.valueIterator(), unaryScoreDerivativesB.values().iterator());
}
// highest - correct: the same sign convention as valueDelta (best score minus gold score)
for (int i = 0; i < localDerivativeGood.length; i++) {
localDerivative[i] = localDerivativeB[i] - localDerivativeGood[i];
}
// TODO: this is where we would combine multiple costs if we had parallelized the calculation
value = localValue;
derivative = localDerivative;
// normalizing by training batch size
value = (1.0 / trainingBatch.size()) * value;
ArrayMath.multiplyInPlace(derivative, (1.0 / trainingBatch.size()));
// add regularization to cost: regCost * 0.5 * (sum of squared parameters)
double[] currentParams = dvModel.paramsToVector();
double regCost = 0;
for (double currentParam : currentParams) {
regCost += currentParam * currentParam;
}
regCost = op.trainOptions.regCost * 0.5 * regCost;
value += regCost;
// add regularization to gradient: currentParams is scaled by regCost in place and then
// added element-wise onto derivative
ArrayMath.multiplyInPlace(currentParams, op.trainOptions.regCost);
ArrayMath.pairwiseAddInPlace(derivative, currentParams);
}
use of edu.stanford.nlp.trees.DeepTree in project CoreNLP by stanfordnlp.
the class FindNearestNeighbors method main.
/**
 * Command-line entry point. Parses every tree of a test treebank with a parser that has a
 * DVModel reranker attached, writes each sentence's tokens, top-ranked tree and root vector to
 * the output file, and then logs, for every collected subtree with at most {@code maxLength}
 * leaves, up to 100 other subtrees ranked by the Frobenius norm of the vector difference.
 *
 * <p>Recognized flags (all required): {@code -model <path>}, {@code -testTreebank <path ...>},
 * {@code -output <path>}. Any other argument is forwarded to
 * {@code LexicalizedParser.loadModel}.
 *
 * @param args command-line arguments as described above
 * @throws IllegalArgumentException if a required flag is absent, {@code -model} or
 *         {@code -output} is given without a value, or the loaded parser does not produce a
 *         {@code RerankingParserQuery} backed by a {@code DVModelReranker.Query}
 * @throws AssertionError if a sentence cannot be parsed or no node labelled ROOT has a vector
 * @throws Exception for I/O failures while reading the treebank or writing the output
 */
public static void main(String[] args) throws Exception {
  String modelPath = null;
  String outputPath = null;
  String testTreebankPath = null;
  FileFilter testTreebankFilter = null;
  List<String> unusedArgs = new ArrayList<>();
  for (int argIndex = 0; argIndex < args.length; ) {
    if (args[argIndex].equalsIgnoreCase("-model")) {
      // Fail with a readable message instead of indexing past the end of args.
      if (argIndex + 1 >= args.length) {
        throw new IllegalArgumentException("Missing value for " + args[argIndex]);
      }
      modelPath = args[argIndex + 1];
      argIndex += 2;
    } else if (args[argIndex].equalsIgnoreCase("-testTreebank")) {
      Pair<String, FileFilter> treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-testTreebank");
      argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
      testTreebankPath = treebankDescription.first();
      testTreebankFilter = treebankDescription.second();
    } else if (args[argIndex].equalsIgnoreCase("-output")) {
      if (argIndex + 1 >= args.length) {
        throw new IllegalArgumentException("Missing value for " + args[argIndex]);
      }
      outputPath = args[argIndex + 1];
      argIndex += 2;
    } else {
      unusedArgs.add(args[argIndex++]);
    }
  }
  if (modelPath == null) {
    throw new IllegalArgumentException("Need to specify -model");
  }
  if (testTreebankPath == null) {
    throw new IllegalArgumentException("Need to specify -testTreebank");
  }
  if (outputPath == null) {
    throw new IllegalArgumentException("Need to specify -output");
  }

  String[] newArgs = unusedArgs.toArray(new String[unusedArgs.size()]);
  LexicalizedParser lexparser = LexicalizedParser.loadModel(modelPath, newArgs);

  // testTreebankPath is known to be non-null here, so the treebank is always loaded.
  log.info("Reading in trees from " + testTreebankPath);
  if (testTreebankFilter != null) {
    log.info("Filtering on " + testTreebankFilter);
  }
  Treebank testTreebank = lexparser.getOp().tlpParams.memoryTreebank();
  testTreebank.loadPath(testTreebankPath, testTreebankFilter);
  log.info("Read in " + testTreebank.size() + " trees for testing");

  // try-with-resources closes (and thereby flushes) the writer even when parsing throws.
  // All output goes through the BufferedWriter rather than the raw FileWriter.
  try (BufferedWriter out = new BufferedWriter(new FileWriter(outputPath))) {
    log.info("Parsing " + testTreebank.size() + " trees");
    int count = 0;
    List<ParseRecord> records = Generics.newArrayList();
    for (Tree goldTree : testTreebank) {
      List<Word> tokens = goldTree.yieldWords();
      ParserQuery parserQuery = lexparser.parserQuery();
      if (!parserQuery.parse(tokens)) {
        throw new AssertionError("Could not parse: " + tokens);
      }
      if (!(parserQuery instanceof RerankingParserQuery)) {
        throw new IllegalArgumentException("Expected a LexicalizedParser with a Reranker attached");
      }
      RerankingParserQuery rpq = (RerankingParserQuery) parserQuery;
      if (!(rpq.rerankerQuery() instanceof DVModelReranker.Query)) {
        throw new IllegalArgumentException("Expected a LexicalizedParser with a DVModel attached");
      }
      // Only the first DeepTree returned by the reranker query is used.
      DeepTree tree = ((DVModelReranker.Query) rpq.rerankerQuery()).getDeepTrees().get(0);

      // The root vector is the vector stored for the first node whose label value is "ROOT".
      SimpleMatrix rootVector = null;
      for (Map.Entry<Tree, SimpleMatrix> entry : tree.getVectors().entrySet()) {
        if (entry.getKey().label().value().equals("ROOT")) {
          rootVector = entry.getValue();
          break;
        }
      }
      if (rootVector == null) {
        throw new AssertionError("Could not find root nodevector");
      }

      out.write(tokens + "\n");
      out.write(tree.getTree() + "\n");
      for (int i = 0; i < rootVector.getNumElements(); ++i) {
        out.write(" " + rootVector.get(i));
      }
      out.write("\n\n\n");

      count++;
      if (count % 10 == 0) {
        log.info(" " + count);
      }
      records.add(new ParseRecord(tokens, goldTree, tree.getTree(), rootVector, tree.getVectors()));
    }
    log.info(" done parsing");

    // Collect every (node, vector) pair whose node spans at most maxLength leaves.
    List<Pair<Tree, SimpleMatrix>> subtrees = Generics.newArrayList();
    for (ParseRecord record : records) {
      for (Map.Entry<Tree, SimpleMatrix> entry : record.nodeVectors.entrySet()) {
        if (entry.getKey().getLeaves().size() <= maxLength) {
          subtrees.add(Pair.makePair(entry.getKey(), entry.getValue()));
        }
      }
    }
    log.info("There are " + subtrees.size() + " subtrees in the set of trees");

    // The queue is capped at 100 entries by polling whenever it grows past that size.
    // NOTE(review): presumably the head under DESCENDING_COMPARATOR is the largest distance,
    // so polling keeps the 100 closest matches - confirm against ScoredComparator.
    PriorityQueue<ScoredObject<Pair<Tree, Tree>>> bestmatches = new PriorityQueue<>(101, ScoredComparator.DESCENDING_COMPARATOR);
    for (int i = 0; i < subtrees.size(); ++i) {
      Pair<Tree, SimpleMatrix> query = subtrees.get(i);
      log.info(query.first().yieldWords());
      log.info(query.first());
      for (int j = 0; j < subtrees.size(); ++j) {
        if (i == j) {
          continue;
        }
        Pair<Tree, SimpleMatrix> candidate = subtrees.get(j);
        // TODO: look at basic category?
        double normF = query.second().minus(candidate.second()).normF();
        bestmatches.add(new ScoredObject<>(Pair.makePair(query.first(), candidate.first()), normF));
        if (bestmatches.size() > 100) {
          bestmatches.poll();
        }
      }
      // Drain the queue and reverse the drained order before logging.
      List<ScoredObject<Pair<Tree, Tree>>> ordered = Generics.newArrayList();
      while (bestmatches.size() > 0) {
        ordered.add(bestmatches.poll());
      }
      Collections.reverse(ordered);
      for (ScoredObject<Pair<Tree, Tree>> pair : ordered) {
        log.info(" MATCHED " + pair.object().second().yieldWords() + " ... " + pair.object().second() + " with a score of " + pair.score());
      }
      log.info();
      log.info();
      bestmatches.clear();
    }
    /*
    for (int i = 0; i < records.size(); ++i) {
      if (i % 10 == 0) {
        log.info(" " + i);
      }
      List<ScoredObject<ParseRecord>> scored = Generics.newArrayList();
      for (int j = 0; j < records.size(); ++j) {
        if (i == j) continue;
        double score = 0.0;
        int matches = 0;
        for (Map.Entry<Tree, SimpleMatrix> first : records.get(i).nodeVectors.entrySet()) {
          for (Map.Entry<Tree, SimpleMatrix> second : records.get(j).nodeVectors.entrySet()) {
            String firstBasic = dvparser.dvModel.basicCategory(first.getKey().label().value());
            String secondBasic = dvparser.dvModel.basicCategory(second.getKey().label().value());
            if (firstBasic.equals(secondBasic)) {
              ++matches;
              double normF = first.getValue().minus(second.getValue()).normF();
              score += normF * normF;
            }
          }
        }
        if (matches == 0) {
          score = Double.POSITIVE_INFINITY;
        } else {
          score = score / matches;
        }
        //double score = records.get(i).vector.minus(records.get(j).vector).normF();
        scored.add(new ScoredObject<ParseRecord>(records.get(j), score));
      }
      Collections.sort(scored, ScoredComparator.ASCENDING_COMPARATOR);
      out.write(records.get(i).sentence.toString() + "\n");
      for (int j = 0; j < numNeighbors; ++j) {
        out.write(" " + scored.get(j).score() + ": " + scored.get(j).object().sentence + "\n");
      }
      out.write("\n\n");
    }
    log.info();
    */
  }
}
use of edu.stanford.nlp.trees.DeepTree in project CoreNLP by stanfordnlp.
the class ParseAndPrintMatrices method main.
/**
 * Command-line entry point. Parses each sentence of a text file with a parser that has a
 * DVModel reranker attached and writes one {@code sentenceN.txt} file per sentence (N counted
 * from 1) into a newly created output directory. Each file contains the sentence, its
 * top-ranked tree, the word vector of every token, and the matrices emitted by
 * {@code outputTreeMatrices} for the tree found by {@code findRootTree}.
 *
 * <p>Recognized flags: {@code -model <path>} (required), {@code -output <dir>} (required; must
 * not already exist), {@code -input <file>} (optional; without it only the output directory is
 * created), {@code -testTreebank <path ...>} (accepted but unused). Any other argument is
 * forwarded to {@code LexicalizedParser.loadModel}.
 *
 * @param args command-line arguments as described above
 * @throws IllegalArgumentException if {@code -model} or {@code -output} is missing, a flag is
 *         given without its value, or the parser is not a reranking parser with a DVModelReranker
 * @throws RuntimeException if a sentence cannot be parsed
 * @throws IOException if reading the input or writing an output file fails
 */
public static void main(String[] args) throws IOException {
  String modelPath = null;
  String outputPath = null;
  String inputPath = null;
  List<String> unusedArgs = Generics.newArrayList();
  for (int argIndex = 0; argIndex < args.length; ) {
    if (args[argIndex].equalsIgnoreCase("-model")) {
      // Fail with a readable message instead of indexing past the end of args.
      if (argIndex + 1 >= args.length) {
        throw new IllegalArgumentException("Missing value for " + args[argIndex]);
      }
      modelPath = args[argIndex + 1];
      argIndex += 2;
    } else if (args[argIndex].equalsIgnoreCase("-output")) {
      if (argIndex + 1 >= args.length) {
        throw new IllegalArgumentException("Missing value for " + args[argIndex]);
      }
      outputPath = args[argIndex + 1];
      argIndex += 2;
    } else if (args[argIndex].equalsIgnoreCase("-input")) {
      if (argIndex + 1 >= args.length) {
        throw new IllegalArgumentException("Missing value for " + args[argIndex]);
      }
      inputPath = args[argIndex + 1];
      argIndex += 2;
    } else if (args[argIndex].equalsIgnoreCase("-testTreebank")) {
      // The option and its sub-arguments are still parsed and consumed so they are not
      // forwarded to the parser, but nothing in this method reads the resulting description.
      ArgUtils.getTreebankDescription(args, argIndex, "-testTreebank");
      argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
    } else {
      unusedArgs.add(args[argIndex++]);
    }
  }
  // Without these checks a missing -output surfaces as a NullPointerException from new File(null).
  if (modelPath == null) {
    throw new IllegalArgumentException("Need to specify -model");
  }
  if (outputPath == null) {
    throw new IllegalArgumentException("Need to specify -output");
  }

  String[] newArgs = unusedArgs.toArray(new String[unusedArgs.size()]);
  LexicalizedParser parser = LexicalizedParser.loadModel(modelPath, newArgs);
  DVModel model = DVParser.getModelFromLexicalizedParser(parser);

  File outputFile = new File(outputPath);
  FileSystem.checkNotExistsOrFail(outputFile);
  FileSystem.mkdirOrFail(outputFile);

  int count = 0;
  if (inputPath != null) {
    // The reader is closed when the loop finishes or fails.
    try (Reader input = new BufferedReader(new FileReader(inputPath))) {
      DocumentPreprocessor processor = new DocumentPreprocessor(input);
      for (List<HasWord> sentence : processor) {
        // index from 1
        count++;
        ParserQuery pq = parser.parserQuery();
        if (!(pq instanceof RerankingParserQuery)) {
          throw new IllegalArgumentException("Expected a RerankingParserQuery");
        }
        RerankingParserQuery rpq = (RerankingParserQuery) pq;
        if (!rpq.parse(sentence)) {
          throw new RuntimeException("Unparsable sentence: " + sentence);
        }
        RerankerQuery reranker = rpq.rerankerQuery();
        if (!(reranker instanceof DVModelReranker.Query)) {
          throw new IllegalArgumentException("Expected a DVModelReranker");
        }
        // Only the first DeepTree returned by the reranker query is written out.
        DeepTree deepTree = ((DVModelReranker.Query) reranker).getDeepTrees().get(0);
        IdentityHashMap<Tree, SimpleMatrix> vectors = deepTree.getVectors();
        for (Map.Entry<Tree, SimpleMatrix> entry : vectors.entrySet()) {
          log.info(entry.getKey() + " " + entry.getValue());
        }

        String sentencePath = outputPath + File.separator + "sentence" + count + ".txt";
        // Closing the BufferedWriter flushes it and closes the underlying FileWriter, even
        // when one of the output helpers throws.
        try (BufferedWriter bout = new BufferedWriter(new FileWriter(sentencePath))) {
          bout.write(SentenceUtils.listToString(sentence));
          bout.newLine();
          bout.write(deepTree.getTree().toString());
          bout.newLine();
          for (HasWord word : sentence) {
            outputMatrix(bout, model.getWordVector(word.word()));
          }
          Tree rootTree = findRootTree(vectors);
          outputTreeMatrices(bout, rootTree, vectors);
        }
      }
    }
  }
}
use of edu.stanford.nlp.trees.DeepTree in project CoreNLP by stanfordnlp.
the class DVParserCostAndGradient method getHighestScoringTree.
/**
 * Picks the highest-scoring hypothesis stored in {@code topParses} for the given tree.
 * <p>
 * Each hypothesis is scored with {@code score}; when {@code lambda} is non-zero,
 * {@code op.trainOptions.deltaMargin * lambda * getMargin(tree, hypothesis)} is added to that
 * score. On ties the earliest hypothesis wins, because a later one must be strictly greater to
 * replace it.
 *
 * @param tree   key used to look up the hypothesis list in {@code topParses}
 * @param lambda multiplier for the margin term; {@code 0} disables the margin entirely
 * @return a {@link DeepTree} holding the winning hypothesis, the node vectors filled in while
 *         scoring it, and its (margin-adjusted) score
 * @throws AssertionError if no hypotheses are stored for {@code tree}
 */
public DeepTree getHighestScoringTree(Tree tree, double lambda) {
  List<Tree> candidates = topParses.get(tree);
  if (candidates == null || candidates.size() == 0) {
    throw new AssertionError("Failed to get any hypothesis trees for " + tree);
  }

  Tree winner = null;
  IdentityHashMap<Tree, SimpleMatrix> winnerVectors = null;
  double winnerScore = Double.NEGATIVE_INFINITY;

  for (Tree candidate : candidates) {
    IdentityHashMap<Tree, SimpleMatrix> candidateVectors = new IdentityHashMap<>();
    double rawScore = score(candidate, candidateVectors);
    //TODO: RS: Play around with this parameter to prevent blowing up of scores
    double marginTerm = (lambda != 0)
        ? op.trainOptions.deltaMargin * lambda * getMargin(tree, candidate)
        : 0;
    double adjustedScore = rawScore + marginTerm;

    // Keep the current winner unless this is the first candidate or a strictly better one.
    if (winner != null && !(adjustedScore > winnerScore)) {
      continue;
    }
    winner = candidate;
    winnerScore = adjustedScore;
    winnerVectors = candidateVectors;
  }

  return new DeepTree(winner, winnerVectors, winnerScore);
}
Aggregations