Search in sources :

Example 26 with WordVectors

Use of org.deeplearning4j.models.embeddings.wordvectors.WordVectors in the deeplearning4j project (by deeplearning4j).

Class GloveTest, method testGloVe1.

/**
 * Fits a {@code Glove} model on the lines of {@code /big/raw_sentences.txt} and checks it.
 *
 * <p>The test asserts that "day"/"night" similarity is above 0.7, that "best"/"police"
 * similarity is below 0.5, and that "night", "year" and "week" appear among the ten nearest
 * words to "day". It then writes the model to a temporary file, loads it back as
 * {@code WordVectors}, and asserts that the "day" vector read back equals the in-memory one.
 *
 * <p>Currently disabled via {@code @Ignore}.
 *
 * @throws Exception if the corpus cannot be resolved or any model/serialization call fails
 */
@Ignore
@Test
public void testGloVe1() throws Exception {
    File corpus = new ClassPathResource("/big/raw_sentences.txt").getFile();
    SentenceIterator sentences = new BasicLineIterator(corpus.getAbsolutePath());

    // Tokenize each line on whitespace, running every token through the common preprocessor
    TokenizerFactory tokenizers = new DefaultTokenizerFactory();
    tokenizers.setTokenPreProcessor(new CommonPreprocessor());

    Glove model = new Glove.Builder()
            .iterate(sentences)
            .tokenizerFactory(tokenizers)
            .alpha(0.75)
            .learningRate(0.1)
            .epochs(45)
            .xMax(100)
            .shuffle(true)
            .symmetric(true)
            .build();
    model.fit();

    double dayNightSimilarity = model.similarity("day", "night");
    double bestPoliceSimilarity = model.similarity("best", "police");
    log.info("Day/night similarity: " + dayNightSimilarity);
    log.info("Best/police similarity: " + bestPoliceSimilarity);

    Collection<String> nearestToDay = model.wordsNearest("day", 10);
    log.info("Nearest words to 'day': " + nearestToDay);

    assertTrue(dayNightSimilarity > 0.7);
    // Ideally this similarity would be close to 0; the test only requires it to stay below 0.5
    assertTrue(bestPoliceSimilarity < 0.5);
    for (String expectedNeighbour : new String[] {"night", "year", "week"}) {
        assertTrue(nearestToDay.contains(expectedNeighbour));
    }

    // Round-trip the vectors through a text file and compare the "day" vector before and after
    File serialized = File.createTempFile("glove", "temp");
    serialized.deleteOnExit();
    INDArray dayBefore = model.getWordVectorMatrix("day").dup();
    WordVectorSerializer.writeWordVectors(model, serialized);
    WordVectors restored = WordVectorSerializer.loadTxtVectors(serialized);
    INDArray dayAfter = restored.getWordVectorMatrix("day").dup();
    assertEquals(dayBefore, dayAfter);
    serialized.delete();
}
Also used : BasicLineIterator(org.deeplearning4j.text.sentenceiterator.BasicLineIterator) TokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) ClassPathResource(org.datavec.api.util.ClassPathResource) SentenceIterator(org.deeplearning4j.text.sentenceiterator.SentenceIterator) LineSentenceIterator(org.deeplearning4j.text.sentenceiterator.LineSentenceIterator) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) CommonPreprocessor(org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor) INDArray(org.nd4j.linalg.api.ndarray.INDArray) WordVectors(org.deeplearning4j.models.embeddings.wordvectors.WordVectors) File(java.io.File) Ignore(org.junit.Ignore) Test(org.junit.Test)

Aggregations

WordVectors (org.deeplearning4j.models.embeddings.wordvectors.WordVectors): 26, Test (org.junit.Test): 26, File (java.io.File): 15, INDArray (org.nd4j.linalg.api.ndarray.INDArray): 15, ClassPathResource (org.datavec.api.util.ClassPathResource): 10, Ignore (org.junit.Ignore): 9, CommonPreprocessor (org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor): 6, DefaultTokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory): 6, TokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory): 6, InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable): 4, SentenceIterator (org.deeplearning4j.text.sentenceiterator.SentenceIterator): 4, ArrayList (java.util.ArrayList): 3, Word2Vec (org.deeplearning4j.models.word2vec.Word2Vec): 3, VocabCache (org.deeplearning4j.models.word2vec.wordstore.VocabCache): 3, InMemoryLookupCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache): 3, UimaSentenceIterator (org.deeplearning4j.text.sentenceiterator.UimaSentenceIterator): 3, VocabWord (org.deeplearning4j.models.word2vec.VocabWord): 2, BasicLineIterator (org.deeplearning4j.text.sentenceiterator.BasicLineIterator): 2, FileInputStream (java.io.FileInputStream): 1, FileOutputStream (java.io.FileOutputStream): 1