Use of org.deeplearning4j.models.embeddings.wordvectors.WordVectors in project deeplearning4j by deeplearning4j.
The class GloveTest, method testGloVe1.
/**
 * Trains a GloVe model on {@code /big/raw_sentences.txt}, checks a few word similarities and the
 * nearest neighbours of "day", then verifies that the vector for "day" is unchanged after being
 * written with {@code WordVectorSerializer.writeWordVectors} and read back with
 * {@code WordVectorSerializer.loadTxtVectors}.
 *
 * @throws Exception propagated from resource loading, training, or temp-file I/O
 */
@Ignore
@Test
public void testGloVe1() throws Exception {
    File inputFile = new ClassPathResource("/big/raw_sentences.txt").getFile();
    SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
    // Split on white spaces in the line to get words
    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());

    Glove glove = new Glove.Builder()
            .iterate(iter)
            .tokenizerFactory(t)
            .alpha(0.75)
            .learningRate(0.1)
            .epochs(45)
            .xMax(100)
            .shuffle(true)
            .symmetric(true)
            .build();
    glove.fit();

    double simD = glove.similarity("day", "night");
    double simP = glove.similarity("best", "police");
    log.info("Day/night similarity: " + simD);
    log.info("Best/police similarity: " + simP);

    Collection<String> words = glove.wordsNearest("day", 10);
    log.info("Nearest words to 'day': " + words);

    assertTrue("day/night similarity should exceed 0.7 but was " + simD, simD > 0.7);
    // actually simP should be somewhere at 0
    assertTrue("best/police similarity should be below 0.5 but was " + simP, simP < 0.5);
    assertTrue("'night' missing from nearest words to 'day': " + words, words.contains("night"));
    assertTrue("'year' missing from nearest words to 'day': " + words, words.contains("year"));
    assertTrue("'week' missing from nearest words to 'day': " + words, words.contains("week"));

    // Serialization round trip: the reloaded vector for "day" must equal the trained one.
    File tempFile = File.createTempFile("glove", "temp");
    // Fallback in case the explicit delete below does not succeed.
    tempFile.deleteOnExit();
    try {
        INDArray day1 = glove.getWordVectorMatrix("day").dup();
        WordVectorSerializer.writeWordVectors(glove, tempFile);
        WordVectors vectors = WordVectorSerializer.loadTxtVectors(tempFile);
        INDArray day2 = vectors.getWordVectorMatrix("day").dup();
        assertEquals(day1, day2);
    } finally {
        // Remove the temp file even when serialization or the assertion fails.
        tempFile.delete();
    }
}
Aggregations