Use of org.ejml.simple.SimpleMatrix in the CoreNLP project by stanfordnlp: class FindNearestNeighbors, method main.
/**
 * Parses a test treebank with a DVModel-augmented LexicalizedParser,
 * writes each sentence's tokens, best parse, and root node vector to the
 * output file, and then logs, for every subtree of at most maxLength
 * leaves, its 100 nearest subtrees by Frobenius distance between node
 * vectors.
 * <p>
 * Required arguments: {@code -model}, {@code -testTreebank}, {@code -output}.
 * Unrecognized arguments are passed through to the parser.
 */
public static void main(String[] args) throws Exception {
  String modelPath = null;
  String outputPath = null;
  String testTreebankPath = null;
  FileFilter testTreebankFilter = null;
  List<String> unusedArgs = new ArrayList<>();
  for (int argIndex = 0; argIndex < args.length; ) {
    if (args[argIndex].equalsIgnoreCase("-model")) {
      modelPath = args[argIndex + 1];
      argIndex += 2;
    } else if (args[argIndex].equalsIgnoreCase("-testTreebank")) {
      Pair<String, FileFilter> treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-testTreebank");
      argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
      testTreebankPath = treebankDescription.first();
      testTreebankFilter = treebankDescription.second();
    } else if (args[argIndex].equalsIgnoreCase("-output")) {
      outputPath = args[argIndex + 1];
      argIndex += 2;
    } else {
      // Anything we don't recognize is forwarded to LexicalizedParser.loadModel.
      unusedArgs.add(args[argIndex++]);
    }
  }
  if (modelPath == null) {
    throw new IllegalArgumentException("Need to specify -model");
  }
  if (testTreebankPath == null) {
    throw new IllegalArgumentException("Need to specify -testTreebank");
  }
  if (outputPath == null) {
    throw new IllegalArgumentException("Need to specify -output");
  }
  String[] newArgs = unusedArgs.toArray(new String[unusedArgs.size()]);
  LexicalizedParser lexparser = LexicalizedParser.loadModel(modelPath, newArgs);
  // testTreebankPath was verified non-null above, so load unconditionally.
  log.info("Reading in trees from " + testTreebankPath);
  if (testTreebankFilter != null) {
    log.info("Filtering on " + testTreebankFilter);
  }
  Treebank testTreebank = lexparser.getOp().tlpParams.memoryTreebank();
  testTreebank.loadPath(testTreebankPath, testTreebankFilter);
  log.info("Read in " + testTreebank.size() + " trees for testing");
  // try-with-resources guarantees the output is flushed and closed even if
  // parsing fails partway through.  The original code wrote to the
  // unbuffered FileWriter while leaving the BufferedWriter wrapping it
  // unused; all writes now go through the buffered writer.
  try (BufferedWriter bout = new BufferedWriter(new FileWriter(outputPath))) {
    log.info("Parsing " + testTreebank.size() + " trees");
    int count = 0;
    List<ParseRecord> records = Generics.newArrayList();
    for (Tree goldTree : testTreebank) {
      List<Word> tokens = goldTree.yieldWords();
      ParserQuery parserQuery = lexparser.parserQuery();
      if (!parserQuery.parse(tokens)) {
        throw new AssertionError("Could not parse: " + tokens);
      }
      if (!(parserQuery instanceof RerankingParserQuery)) {
        throw new IllegalArgumentException("Expected a LexicalizedParser with a Reranker attached");
      }
      RerankingParserQuery rpq = (RerankingParserQuery) parserQuery;
      if (!(rpq.rerankerQuery() instanceof DVModelReranker.Query)) {
        throw new IllegalArgumentException("Expected a LexicalizedParser with a DVModel attached");
      }
      DeepTree tree = ((DVModelReranker.Query) rpq.rerankerQuery()).getDeepTrees().get(0);
      // Locate the vector attached to the ROOT node of the best parse.
      SimpleMatrix rootVector = null;
      for (Map.Entry<Tree, SimpleMatrix> entry : tree.getVectors().entrySet()) {
        if (entry.getKey().label().value().equals("ROOT")) {
          rootVector = entry.getValue();
          break;
        }
      }
      if (rootVector == null) {
        throw new AssertionError("Could not find root nodevector");
      }
      bout.write(tokens + "\n");
      bout.write(tree.getTree() + "\n");
      for (int i = 0; i < rootVector.getNumElements(); ++i) {
        bout.write(" " + rootVector.get(i));
      }
      bout.write("\n\n\n");
      count++;
      if (count % 10 == 0) {
        log.info("  " + count);
      }
      records.add(new ParseRecord(tokens, goldTree, tree.getTree(), rootVector, tree.getVectors()));
    }
    log.info("  done parsing");
    // Collect every subtree (with its vector) that is short enough to compare.
    List<Pair<Tree, SimpleMatrix>> subtrees = Generics.newArrayList();
    for (ParseRecord record : records) {
      for (Map.Entry<Tree, SimpleMatrix> entry : record.nodeVectors.entrySet()) {
        if (entry.getKey().getLeaves().size() <= maxLength) {
          subtrees.add(Pair.makePair(entry.getKey(), entry.getValue()));
        }
      }
    }
    log.info("There are " + subtrees.size() + " subtrees in the set of trees");
    PriorityQueue<ScoredObject<Pair<Tree, Tree>>> bestmatches = new PriorityQueue<>(101, ScoredComparator.DESCENDING_COMPARATOR);
    for (int i = 0; i < subtrees.size(); ++i) {
      log.info(subtrees.get(i).first().yieldWords());
      log.info(subtrees.get(i).first());
      for (int j = 0; j < subtrees.size(); ++j) {
        if (i == j) {
          continue;
        }
        // TODO: look at basic category?
        double normF = subtrees.get(i).second().minus(subtrees.get(j).second()).normF();
        bestmatches.add(new ScoredObject<>(Pair.makePair(subtrees.get(i).first(), subtrees.get(j).first()), normF));
        // With a descending comparator, the head is the worst (largest
        // distance) candidate, so polling keeps only the 100 nearest.
        if (bestmatches.size() > 100) {
          bestmatches.poll();
        }
      }
      List<ScoredObject<Pair<Tree, Tree>>> ordered = Generics.newArrayList();
      while (bestmatches.size() > 0) {
        ordered.add(bestmatches.poll());
      }
      // Polling yields worst-first; reverse to report nearest matches first.
      Collections.reverse(ordered);
      for (ScoredObject<Pair<Tree, Tree>> pair : ordered) {
        log.info(" MATCHED " + pair.object().second.yieldWords() + " ... " + pair.object().second() + " with a score of " + pair.score());
      }
      log.info();
      log.info();
      bestmatches.clear();
    }
  }
}
Use of org.ejml.simple.SimpleMatrix in the CoreNLP project by stanfordnlp: class ParseAndPrintMatrices, method main.
/**
 * Parses each sentence read from {@code -input} with a DVModel-augmented
 * parser and writes one file per sentence ("sentenceN.txt", indexed from 1)
 * into the freshly created {@code -output} directory, containing the
 * tokens, the best parse tree, the word vectors, and the node vectors.
 */
public static void main(String[] args) throws IOException {
  String modelPath = null;
  String outputPath = null;
  String inputPath = null;
  String testTreebankPath = null;
  FileFilter testTreebankFilter = null;
  List<String> unusedArgs = Generics.newArrayList();
  for (int argIndex = 0; argIndex < args.length; ) {
    if (args[argIndex].equalsIgnoreCase("-model")) {
      modelPath = args[argIndex + 1];
      argIndex += 2;
    } else if (args[argIndex].equalsIgnoreCase("-output")) {
      outputPath = args[argIndex + 1];
      argIndex += 2;
    } else if (args[argIndex].equalsIgnoreCase("-input")) {
      inputPath = args[argIndex + 1];
      argIndex += 2;
    } else if (args[argIndex].equalsIgnoreCase("-testTreebank")) {
      Pair<String, FileFilter> treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-testTreebank");
      argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
      testTreebankPath = treebankDescription.first();
      testTreebankFilter = treebankDescription.second();
    } else {
      // Anything we don't recognize is forwarded to LexicalizedParser.loadModel.
      unusedArgs.add(args[argIndex++]);
    }
  }
  String[] newArgs = unusedArgs.toArray(new String[unusedArgs.size()]);
  LexicalizedParser parser = LexicalizedParser.loadModel(modelPath, newArgs);
  DVModel model = DVParser.getModelFromLexicalizedParser(parser);
  File outputFile = new File(outputPath);
  // Refuse to clobber an existing path, then create the output directory.
  FileSystem.checkNotExistsOrFail(outputFile);
  FileSystem.mkdirOrFail(outputFile);
  int count = 0;
  if (inputPath != null) {
    // try-with-resources closes the input file, which the original leaked.
    try (Reader input = new BufferedReader(new FileReader(inputPath))) {
      DocumentPreprocessor processor = new DocumentPreprocessor(input);
      for (List<HasWord> sentence : processor) {
        // index from 1
        count++;
        ParserQuery pq = parser.parserQuery();
        if (!(pq instanceof RerankingParserQuery)) {
          throw new IllegalArgumentException("Expected a RerankingParserQuery");
        }
        RerankingParserQuery rpq = (RerankingParserQuery) pq;
        if (!rpq.parse(sentence)) {
          throw new RuntimeException("Unparsable sentence: " + sentence);
        }
        RerankerQuery reranker = rpq.rerankerQuery();
        if (!(reranker instanceof DVModelReranker.Query)) {
          throw new IllegalArgumentException("Expected a DVModelReranker");
        }
        DeepTree deepTree = ((DVModelReranker.Query) reranker).getDeepTrees().get(0);
        IdentityHashMap<Tree, SimpleMatrix> vectors = deepTree.getVectors();
        for (Map.Entry<Tree, SimpleMatrix> entry : vectors.entrySet()) {
          log.info(entry.getKey() + " " + entry.getValue());
        }
        // Closing the BufferedWriter flushes it and closes the underlying
        // FileWriter; the original flushed bout but closed only fout.
        try (BufferedWriter bout = new BufferedWriter(new FileWriter(outputPath + File.separator + "sentence" + count + ".txt"))) {
          bout.write(SentenceUtils.listToString(sentence));
          bout.newLine();
          bout.write(deepTree.getTree().toString());
          bout.newLine();
          for (HasWord word : sentence) {
            outputMatrix(bout, model.getWordVector(word.word()));
          }
          Tree rootTree = findRootTree(vectors);
          outputTreeMatrices(bout, rootTree, vectors);
        }
      }
    }
  }
}
Use of org.ejml.simple.SimpleMatrix in the CoreNLP project by stanfordnlp: class DVModel, method randomTransformMatrix.
/**
 * Create a random transform matrix based on the initialization
 * parameters.  This will be numRows x numCols big.  These can be
 * plugged into either unary or binary transform matrices.
 *
 * @return a freshly initialized numRows x numCols matrix
 * @throws IllegalArgumentException if the configured transform matrix
 *         type is not recognized
 */
private SimpleMatrix randomTransformMatrix() {
  SimpleMatrix matrix;
  // Near-identity initializations draw uniform noise from a range 100x
  // narrower than the fully random initialization, then add the identity.
  double smallRange = 1.0 / Math.sqrt((double) numCols * 100.0);
  switch(op.trainOptions.transformMatrixType) {
  case DIAGONAL:
    matrix = SimpleMatrix.random(numRows, numCols, -smallRange, smallRange, rand).plus(identity);
    break;
  case RANDOM:
    double fullRange = 1.0 / Math.sqrt((double) numCols);
    matrix = SimpleMatrix.random(numRows, numCols, -fullRange, fullRange, rand);
    break;
  case OFF_DIAGONAL:
    matrix = SimpleMatrix.random(numRows, numCols, -smallRange, smallRange, rand).plus(identity);
    // Perturb numCols randomly chosen entries by -1, 0, or 1.
    for (int i = 0; i < numCols; ++i) {
      // Row indices must range over [0, numRows); the original drew the
      // row from nextInt(numCols), which is wrong for non-square matrices.
      int x = rand.nextInt(numRows);
      int y = rand.nextInt(numCols);
      // -1, 0, or 1
      int scale = rand.nextInt(3) - 1;
      matrix.set(x, y, matrix.get(x, y) + scale);
    }
    break;
  case RANDOM_ZEROS:
    matrix = SimpleMatrix.random(numRows, numCols, -smallRange, smallRange, rand).plus(identity);
    // Zero out numCols randomly chosen entries.
    for (int i = 0; i < numCols; ++i) {
      // Same row-bound fix as OFF_DIAGONAL above.
      int x = rand.nextInt(numRows);
      int y = rand.nextInt(numCols);
      matrix.set(x, y, 0.0);
    }
    break;
  default:
    throw new IllegalArgumentException("Unexpected matrix initialization type " + op.trainOptions.transformMatrixType);
  }
  return matrix;
}
Use of org.ejml.simple.SimpleMatrix in the CoreNLP project by stanfordnlp: class DVModel, method randomContextMatrix.
/**
 * Creates a random context matrix.  This will be numRows x
 * 2*numCols big.  These can be appended to the end of either a
 * unary or binary transform matrix to get the transform matrix
 * which uses context words.
 *
 * @return a freshly initialized numRows x 2*numCols matrix
 */
private SimpleMatrix randomContextMatrix() {
  // Start from two scaled identity blocks laid side by side: one for the
  // left context word, one for the right context word.
  SimpleMatrix scaledIdentity = identity.scale(op.trainOptions.scalingForInit * 0.1);
  SimpleMatrix context = new SimpleMatrix(numRows, numCols * 2);
  context.insertIntoThis(0, 0, scaledIdentity);
  context.insertIntoThis(0, numCols, scaledIdentity);
  // Add small uniform noise to every entry.
  double range = 1.0 / Math.sqrt((double) numCols * 100.0);
  return context.plus(SimpleMatrix.random(numRows, numCols * 2, -range, range, rand));
}
Use of org.ejml.simple.SimpleMatrix in the CoreNLP project by stanfordnlp: class DVModel, method filterRulesForBatch.
/**
 * Restricts the model to the given sets of binary rules, unary rules,
 * and words, discarding any transform/score matrices or word vectors
 * not mentioned.  Transform and score matrices are kept or dropped in
 * tandem; a rule with only one of the two indicates a corrupted model.
 *
 * @param binaryRules the (left, right) binary rules to keep
 * @param unaryRules the unary rules to keep
 * @param words the words whose vectors should be kept
 * @throws AssertionError if a rule has a transform without a score or
 *         vice versa
 */
public void filterRulesForBatch(TwoDimensionalSet<String, String> binaryRules, Set<String> unaryRules, Set<String> words) {
  TwoDimensionalMap<String, String, SimpleMatrix> newBinaryTransforms = TwoDimensionalMap.treeMap();
  TwoDimensionalMap<String, String, SimpleMatrix> newBinaryScores = TwoDimensionalMap.treeMap();
  for (Pair<String, String> binaryRule : binaryRules) {
    SimpleMatrix transform = binaryTransform.get(binaryRule.first(), binaryRule.second());
    if (transform != null) {
      newBinaryTransforms.put(binaryRule.first(), binaryRule.second(), transform);
    }
    SimpleMatrix score = binaryScore.get(binaryRule.first(), binaryRule.second());
    if (score != null) {
      newBinaryScores.put(binaryRule.first(), binaryRule.second(), score);
    }
    // A rule must have both a transform and a score, or neither.
    if ((transform == null) != (score == null)) {
      throw new AssertionError("Binary rule " + binaryRule + " has a transform or a score but not both");
    }
  }
  binaryTransform = newBinaryTransforms;
  binaryScore = newBinaryScores;
  numBinaryMatrices = binaryTransform.size();
  Map<String, SimpleMatrix> newUnaryTransforms = Generics.newTreeMap();
  Map<String, SimpleMatrix> newUnaryScores = Generics.newTreeMap();
  for (String unaryRule : unaryRules) {
    SimpleMatrix transform = unaryTransform.get(unaryRule);
    if (transform != null) {
      newUnaryTransforms.put(unaryRule, transform);
    }
    SimpleMatrix score = unaryScore.get(unaryRule);
    if (score != null) {
      newUnaryScores.put(unaryRule, score);
    }
    if ((transform == null) != (score == null)) {
      throw new AssertionError("Unary rule " + unaryRule + " has a transform or a score but not both");
    }
  }
  unaryTransform = newUnaryTransforms;
  unaryScore = newUnaryScores;
  numUnaryMatrices = unaryTransform.size();
  // Keep only the vectors for words in the requested set.
  Map<String, SimpleMatrix> newWordVectors = Generics.newTreeMap();
  for (String word : words) {
    SimpleMatrix wordVector = wordVectors.get(word);
    if (wordVector != null) {
      newWordVectors.put(word, wordVector);
    }
  }
  wordVectors = newWordVectors;
}
Aggregations