Use of org.ejml.simple.SimpleMatrix in project CoreNLP by stanfordnlp.
From the class NeuralUtilsTest, method testCosine.
/**
 * Verifies {@code NeuralUtils.dot} and {@code NeuralUtils.cosine} on one pair of
 * five-element vectors, first built as 1x5 matrices and then again after both
 * have been transposed. The same expected values are asserted in both layouts.
 */
@Test
public void testCosine() {
  final double expectedDot = 0.35000000000000003;
  final double expectedCosine = 0.6363636363636364;
  final double tolerance = 1e-5;

  // Each vector is built from its own single-row array literal.
  SimpleMatrix ascendingRow = new SimpleMatrix(new double[][] { { 0.1, 0.2, 0.3, 0.4, 0.5 } });
  SimpleMatrix descendingRow = new SimpleMatrix(new double[][] { { 0.5, 0.4, 0.3, 0.2, 0.1 } });

  assertEquals(expectedDot, NeuralUtils.dot(ascendingRow, descendingRow), tolerance);
  assertEquals(expectedCosine, NeuralUtils.cosine(ascendingRow, descendingRow), tolerance);

  // Repeat the same checks on the transposed matrices.
  SimpleMatrix ascendingColumn = ascendingRow.transpose();
  SimpleMatrix descendingColumn = descendingRow.transpose();

  assertEquals(expectedDot, NeuralUtils.dot(ascendingColumn, descendingColumn), tolerance);
  assertEquals(expectedCosine, NeuralUtils.cosine(ascendingColumn, descendingColumn), tolerance);
}
Use of org.ejml.simple.SimpleMatrix in project CoreNLP by stanfordnlp.
From the class Embedding, method loadWordVectors.
/**
 * Reads a file of raw word vectors and stores each vector in {@code wordVectors},
 * keyed by its word. Nothing is returned; the map is filled as a side effect.
 * <br>
 * The file should be in the format <br>
 * <code>WORD X1 X2 X3 ...</code> <br>
 * with fields separated by whitespace. Leading and trailing whitespace on a
 * line is ignored, and blank (empty or whitespace-only) lines are skipped.
 * <br>
 * If {@code embeddingSize} is not positive when a line is read, it is set to
 * the number of vector entries on that line. If a vector in the file is
 * smaller than {@code embeddingSize}, an exception is thrown. If a vector is
 * larger, it is truncated to {@code embeddingSize} entries and a warning is
 * logged (only once per call).
 * <br>
 * The tokens {@code UNKNOWN}, {@code UUUNKKK}, {@code UNK}, {@code *UNKNOWN*}
 * and {@code <unk>} are stored under {@code UNKNOWN_WORD}; {@code <s>} is
 * stored under {@code START_WORD} and {@code </s>} under {@code END_WORD}.
 *
 * @param wordVectorFile name of the word vector file, handed to
 *        {@code IOUtils.readLines} with encoding {@code utf-8}
 * @throws RuntimeException if a line has fewer vector entries than {@code embeddingSize}
 * @throws NumberFormatException if a vector entry cannot be parsed as a double
 */
private void loadWordVectors(String wordVectorFile) {
  log.info("# Loading embedding ...\n word vector file = " + wordVectorFile);
  boolean warned = false;
  int numWords = 0;
  for (String line : IOUtils.readLines(wordVectorFile, "utf-8")) {
    // Trim first: split("\\s+") on a line with leading whitespace produces an
    // empty first token (so the real word would be parsed as a number), and on
    // a whitespace-only line it produces an empty array.
    String trimmed = line.trim();
    if (trimmed.isEmpty()) {
      // Blank lines carry no vector; skip them instead of failing on them.
      continue;
    }
    String[] lineSplit = trimmed.split("\\s+");
    String word = lineSplit[0];
    // check for unknown token
    if (word.equals("UNKNOWN") || word.equals("UUUNKKK") || word.equals("UNK") || word.equals("*UNKNOWN*") || word.equals("<unk>")) {
      word = UNKNOWN_WORD;
    }
    // check for start token
    if (word.equals("<s>")) {
      word = START_WORD;
    }
    // check for end token
    if (word.equals("</s>")) {
      word = END_WORD;
    }
    // the other entries will all be entries in the word vector
    int dimOfWords = lineSplit.length - 1;
    if (embeddingSize <= 0) {
      // No size configured yet: adopt the size of the current line.
      embeddingSize = dimOfWords;
      log.info(" detected embedding size = " + dimOfWords);
    }
    if (dimOfWords > embeddingSize) {
      if (!warned) {
        warned = true;
        log.info("WARNING: Dimensionality of numHid parameter and word vectors do not match, deleting word vector dimensions to fit!");
      }
      // Truncate: only the first embeddingSize entries are kept below.
      dimOfWords = embeddingSize;
    } else if (dimOfWords < embeddingSize) {
      throw new RuntimeException("Word vectors file has dimension too small for requested numHid of " + embeddingSize);
    }
    // Build a dimOfWords x 1 column vector from entries 1..dimOfWords of the line.
    double[][] vec = new double[dimOfWords][1];
    for (int i = 1; i <= dimOfWords; i++) {
      vec[i - 1][0] = Double.parseDouble(lineSplit[i]);
    }
    SimpleMatrix vector = new SimpleMatrix(vec);
    wordVectors.put(word, vector);
    numWords++;
  }
  log.info(" num words = " + numWords);
}
Aggregations