Use of org.ejml.simple.SimpleMatrix in project CoreNLP by stanfordnlp.
The class SentimentCostAndGradient, method scaleAndRegularize.
private static double scaleAndRegularize(TwoDimensionalMap<String, String, SimpleMatrix> derivatives,
                                         TwoDimensionalMap<String, String, SimpleMatrix> currentMatrices,
                                         double scale, double regCost, boolean dropBiasColumn) {
  // the regularization cost
  double cost = 0.0;
  for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : currentMatrices) {
    SimpleMatrix D = derivatives.get(entry.getFirstKey(), entry.getSecondKey());
    SimpleMatrix regMatrix = entry.getValue();
    if (dropBiasColumn) {
      regMatrix = new SimpleMatrix(regMatrix);
      regMatrix.insertIntoThis(0, regMatrix.numCols() - 1, new SimpleMatrix(regMatrix.numRows(), 1));
    }
    D = D.scale(scale).plus(regMatrix.scale(regCost));
    derivatives.put(entry.getFirstKey(), entry.getSecondKey(), D);
    cost += regMatrix.elementMult(regMatrix).elementSum() * regCost / 2.0;
  }
  return cost;
}
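scaleAndRegularize scales each accumulated derivative by scale, adds the gradient of the L2 penalty (regCost times the parameter matrix), and returns the penalty itself, regCost / 2 times the sum of squared entries. When dropBiasColumn is true, the last column of a copy is overwritten with zeros so the bias is neither regularized nor counted in the cost. The following is a minimal, self-contained sketch of that bias-zeroing and cost computation, not CoreNLP code; the matrix values are chosen only for illustration:

import org.ejml.simple.SimpleMatrix;

public class RegularizationSketch {
  public static void main(String[] args) {
    // 2x3 parameter matrix whose last column is a bias column
    SimpleMatrix w = new SimpleMatrix(new double[][] { { 1, 2, 5 }, { 3, 4, 6 } });
    double regCost = 0.001;

    // Zero the bias column on a copy so it is excluded from the L2 penalty,
    // mirroring the dropBiasColumn branch above
    SimpleMatrix reg = new SimpleMatrix(w);
    reg.insertIntoThis(0, reg.numCols() - 1, new SimpleMatrix(reg.numRows(), 1));

    // (regCost / 2) * sum of squared entries, as in the cost accumulation above
    double cost = reg.elementMult(reg).elementSum() * regCost / 2.0;
    System.out.println("regularization cost = " + cost); // 0.0005 * (1 + 4 + 9 + 16) = 0.015
  }
}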
Use of org.ejml.simple.SimpleMatrix in project CoreNLP by stanfordnlp.
The class SentimentModel, method randomTransformMatrix.
SimpleMatrix randomTransformMatrix() {
  SimpleMatrix binary = new SimpleMatrix(numHid, numHid * 2 + 1);
  // bias column values are initialized zero
  binary.insertIntoThis(0, 0, randomTransformBlock());
  binary.insertIntoThis(0, numHid, randomTransformBlock());
  return binary.scale(op.trainOptions.scalingForInit);
}
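The returned matrix has the layout [W_left | W_right | bias]: two numHid x numHid blocks for the left and right child vectors plus a final zero bias column. Below is a self-contained sketch of how a matrix with this layout is applied by stacking the two children and a trailing 1. The random ranges stand in for randomTransformBlock, and the real forward pass in SentimentCostAndGradient additionally adds a tensor term and a tanh nonlinearity, which this sketch omits:

import java.util.Random;
import org.ejml.simple.SimpleMatrix;

public class TransformSketch {
  public static void main(String[] args) {
    int numHid = 4;
    Random rand = new Random(42);

    // Same layout as randomTransformMatrix(): [W_left | W_right | bias]
    SimpleMatrix binary = new SimpleMatrix(numHid, numHid * 2 + 1);
    binary.insertIntoThis(0, 0, SimpleMatrix.random_DDRM(numHid, numHid, -0.5, 0.5, rand));
    binary.insertIntoThis(0, numHid, SimpleMatrix.random_DDRM(numHid, numHid, -0.5, 0.5, rand));

    // Stack the two child vectors plus a trailing 1 for the bias column,
    // then multiply: parent = binary * [left; right; 1]
    SimpleMatrix left = SimpleMatrix.random_DDRM(numHid, 1, -1, 1, rand);
    SimpleMatrix right = SimpleMatrix.random_DDRM(numHid, 1, -1, 1, rand);
    SimpleMatrix childrenWithBias = new SimpleMatrix(numHid * 2 + 1, 1);
    childrenWithBias.insertIntoThis(0, 0, left);
    childrenWithBias.insertIntoThis(numHid, 0, right);
    childrenWithBias.set(numHid * 2, 0, 1.0);

    SimpleMatrix parent = binary.mult(childrenWithBias); // numHid x 1
    System.out.println("parent vector:\n" + parent);
  }
}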
Use of org.ejml.simple.SimpleMatrix in project CoreNLP by stanfordnlp.
The class ConvertMatlabModel, method main.
public static void main(String[] args) throws IOException {
  String basePath = "/user/socherr/scr/projects/semComp/RNTN/src/params/";
  int numSlices = 25;
  boolean useEscapedParens = false;
  for (int argIndex = 0; argIndex < args.length; ) {
    if (args[argIndex].equalsIgnoreCase("-slices")) {
      numSlices = Integer.parseInt(args[argIndex + 1]);
      argIndex += 2;
    } else if (args[argIndex].equalsIgnoreCase("-path")) {
      basePath = args[argIndex + 1];
      argIndex += 2;
    } else if (args[argIndex].equalsIgnoreCase("-useEscapedParens")) {
      useEscapedParens = true;
      argIndex += 1;
    } else {
      log.info("Unknown argument " + args[argIndex]);
      System.exit(2);
    }
  }
  SimpleMatrix[] slices = new SimpleMatrix[numSlices];
  for (int i = 0; i < numSlices; ++i) {
    slices[i] = loadMatrix(basePath + "bin/Wt_" + (i + 1) + ".bin", basePath + "Wt_" + (i + 1) + ".txt");
  }
  SimpleTensor tensor = new SimpleTensor(slices);
  log.info("W tensor size: " + tensor.numRows() + "x" + tensor.numCols() + "x" + tensor.numSlices());
  SimpleMatrix W = loadMatrix(basePath + "bin/W.bin", basePath + "W.txt");
  log.info("W matrix size: " + W.numRows() + "x" + W.numCols());
  SimpleMatrix Wcat = loadMatrix(basePath + "bin/Wcat.bin", basePath + "Wcat.txt");
  log.info("W cat size: " + Wcat.numRows() + "x" + Wcat.numCols());
  SimpleMatrix combinedWV = loadMatrix(basePath + "bin/Wv.bin", basePath + "Wv.txt");
  log.info("Word matrix size: " + combinedWV.numRows() + "x" + combinedWV.numCols());
  File vocabFile = new File(basePath + "vocab_1.txt");
  if (!vocabFile.exists()) {
    vocabFile = new File(basePath + "words.txt");
  }
  List<String> lines = Generics.newArrayList();
  for (String line : IOUtils.readLines(vocabFile)) {
    lines.add(line.trim());
  }
  log.info("Lines in vocab file: " + lines.size());
  Map<String, SimpleMatrix> wordVectors = Generics.newTreeMap();
  for (int i = 0; i < lines.size() && i < combinedWV.numCols(); ++i) {
    String[] pieces = lines.get(i).split(" +");
    if (pieces.length == 0 || pieces.length > 1) {
      continue;
    }
    wordVectors.put(pieces[0], combinedWV.extractMatrix(0, numSlices, i, i + 1));
    if (pieces[0].equals("UNK")) {
      wordVectors.put(SentimentModel.UNKNOWN_WORD, wordVectors.get("UNK"));
    }
  }
  // If there is no ",", we first try the HTML-escaped form, then fall back
  // to "." as better than just a random word vector.
  // Same for "``" and ";"
  copyWordVector(wordVectors, "&#44;", ",");
  copyWordVector(wordVectors, ".", ",");
  copyWordVector(wordVectors, "&#59;", ";");
  copyWordVector(wordVectors, ".", ";");
  copyWordVector(wordVectors, "&#96;&#96;", "``");
  copyWordVector(wordVectors, "''", "``");
  if (useEscapedParens) {
    replaceWordVector(wordVectors, "(", "-LRB-");
    replaceWordVector(wordVectors, ")", "-RRB-");
  }
  RNNOptions op = new RNNOptions();
  op.numHid = numSlices;
  op.lowercaseWordVectors = false;
  if (Wcat.numRows() == 2) {
    op.classNames = new String[] { "Negative", "Positive" };
    // TODO: set to null once old models are updated
    op.equivalenceClasses = new int[][] { { 0 }, { 1 } };
    op.numClasses = 2;
  }
  if (!wordVectors.containsKey(SentimentModel.UNKNOWN_WORD)) {
    wordVectors.put(SentimentModel.UNKNOWN_WORD, SimpleMatrix.random_DDRM(numSlices, 1, -0.00001, 0.00001, new Random()));
  }
  SentimentModel model = SentimentModel.modelFromMatrices(W, Wcat, tensor, wordVectors, op);
  model.saveSerialized("matlab.ser.gz");
}
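The helpers copyWordVector and replaceWordVector are not shown in this snippet. The sketch below is what they plausibly do, inferred from the call sites above; the actual implementations in ConvertMatlabModel may differ in details such as logging:

import java.util.Map;
import org.ejml.simple.SimpleMatrix;

class WordVectorHelpers {
  // Copy the vector stored under "source" to "target", but only when
  // "target" has no vector yet and "source" actually has one
  static void copyWordVector(Map<String, SimpleMatrix> wordVectors, String source, String target) {
    if (wordVectors.containsKey(target) || !wordVectors.containsKey(source)) {
      return;
    }
    wordVectors.put(target, new SimpleMatrix(wordVectors.get(source)));
  }

  // Overwrite "target" with a copy of the vector under "source" when both exist,
  // e.g. reusing the "(" vector for the escaped token "-LRB-"
  static void replaceWordVector(Map<String, SimpleMatrix> wordVectors, String source, String target) {
    if (!wordVectors.containsKey(source) || !wordVectors.containsKey(target)) {
      return;
    }
    wordVectors.put(target, new SimpleMatrix(wordVectors.get(source)));
  }
}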
Use of org.ejml.simple.SimpleMatrix in project CoreNLP by stanfordnlp.
The class SentimentModel, method readWordVectors.
void readWordVectors() {
  Embedding embedding = new Embedding(op.wordVectors, op.numHid);
  this.wordVectors = Generics.newTreeMap();
  // for (String word : rawWordVectors.keySet()) {
  for (String word : embedding.keySet()) {
    // TODO: factor out unknown word vector code from DVParser
    wordVectors.put(word, embedding.get(word));
  }
  String unkWord = op.unkWord;
  SimpleMatrix unknownWordVector = wordVectors.get(unkWord);
  wordVectors.put(UNKNOWN_WORD, unknownWordVector);
  if (unknownWordVector == null) {
    throw new RuntimeException("Unknown word vector not specified in the word vector file");
  }
}
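After readWordVectors() finishes, the map is guaranteed to contain an entry for UNKNOWN_WORD (otherwise the method throws). Here is a self-contained sketch of the lookup pattern this enables, falling back to the unknown-word vector for out-of-vocabulary tokens; the key string and vector size are placeholders rather than the CoreNLP constants:

import java.util.Map;
import java.util.Random;
import java.util.TreeMap;
import org.ejml.simple.SimpleMatrix;

public class UnknownWordLookup {
  // placeholder key; the real constant lives in SentimentModel.UNKNOWN_WORD
  static final String UNKNOWN_WORD = "*UNK*";

  public static void main(String[] args) {
    Map<String, SimpleMatrix> wordVectors = new TreeMap<>();
    wordVectors.put("good", SimpleMatrix.random_DDRM(3, 1, -1, 1, new Random(42)));
    wordVectors.put(UNKNOWN_WORD, new SimpleMatrix(3, 1));

    System.out.println(lookup(wordVectors, "good"));       // the stored vector
    System.out.println(lookup(wordVectors, "zyzzogeton")); // falls back to the UNKNOWN_WORD vector
  }

  // Fall back to the unknown-word vector when a token has no entry
  static SimpleMatrix lookup(Map<String, SimpleMatrix> wordVectors, String word) {
    SimpleMatrix vector = wordVectors.get(word);
    return (vector != null) ? vector : wordVectors.get(UNKNOWN_WORD);
  }
}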
Use of org.ejml.simple.SimpleMatrix in project CoreNLP by stanfordnlp.
The class SentimentModel, method randomClassificationMatrix.
/**
 * Returns matrices of the right size for either binary or unary (terminal) classification
 */
SimpleMatrix randomClassificationMatrix() {
  SimpleMatrix score = new SimpleMatrix(numClasses, numHid + 1);
  double range = 1.0 / (Math.sqrt((double) numHid));
  score.insertIntoThis(0, 0, SimpleMatrix.random_DDRM(numClasses, numHid, -range, range, rand));
  // bias column goes from 0 to 1 initially
  score.insertIntoThis(0, numHid, SimpleMatrix.random_DDRM(numClasses, 1, 0.0, 1.0, rand));
  return score.scale(op.trainOptions.scalingForInit);
}
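The numClasses x (numHid + 1) shape means the classifier is applied to a node vector with a trailing 1 appended, so the last column acts as a bias. The following is a self-contained, EJML-only sketch of that prediction step followed by a softmax; CoreNLP's own prediction code goes through its neural-network utilities, so this only illustrates the shape convention:

import java.util.Random;
import org.ejml.simple.SimpleMatrix;

public class ClassificationSketch {
  public static void main(String[] args) {
    int numHid = 5;
    int numClasses = 5;
    Random rand = new Random(42);

    // Same shape as randomClassificationMatrix(): weight block plus a bias column
    double range = 1.0 / Math.sqrt((double) numHid);
    SimpleMatrix score = new SimpleMatrix(numClasses, numHid + 1);
    score.insertIntoThis(0, 0, SimpleMatrix.random_DDRM(numClasses, numHid, -range, range, rand));
    score.insertIntoThis(0, numHid, SimpleMatrix.random_DDRM(numClasses, 1, 0.0, 1.0, rand));

    // Append a trailing 1 to the node vector so the last column acts as a bias
    SimpleMatrix nodeVector = SimpleMatrix.random_DDRM(numHid, 1, -1, 1, rand);
    SimpleMatrix withBias = new SimpleMatrix(numHid + 1, 1);
    withBias.insertIntoThis(0, 0, nodeVector);
    withBias.set(numHid, 0, 1.0);

    // Raw class scores, then a softmax to turn them into probabilities
    SimpleMatrix scores = score.mult(withBias);
    SimpleMatrix expScores = scores.elementExp();
    SimpleMatrix probabilities = expScores.divide(expScores.elementSum());
    System.out.println("class probabilities:\n" + probabilities);
  }
}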