Search in sources :

Example 11 with Word2Vec

use of org.deeplearning4j.models.word2vec.Word2Vec in project deeplearning4j by deeplearning4j.

Method testWord2VecCBOWBig of the class PerformanceTests.

/**
 * Manual throughput benchmark: trains a CBOW Word2Vec model (hierarchic softmax, no negative
 * sampling) on a large Korean corpus using 8 workers and logs the training time in milliseconds.
 *
 * <p>Ignored by default because the corpus lives at a developer-local absolute path. For other
 * corpora, a DefaultTokenizerFactory with a CommonPreprocessor can be substituted for the
 * Korean tokenizer.
 */
@Ignore
@Test
public void testWord2VecCBOWBig() throws Exception {
    // NOTE(review): developer-local corpus path; adjust before running this benchmark manually.
    SentenceIterator iter = new BasicLineIterator("/home/raver119/Downloads/corpus/namuwiki_raw.txt");
    TokenizerFactory t = new KoreanTokenizerFactory();
    Word2Vec vec = new Word2Vec.Builder()
            .minWordFrequency(1)
            .iterations(5)
            .learningRate(0.025)
            .layerSize(150)
            .seed(42)
            .sampling(0)
            .negativeSample(0)
            .useHierarchicSoftmax(true)
            .windowSize(5)
            .modelUtils(new BasicModelUtils<VocabWord>())
            .useAdaGrad(false)
            .iterate(iter)
            .workers(8)
            .allowParallelTokenization(true)
            .tokenizerFactory(t)
            .elementsLearningAlgorithm(new CBOW<VocabWord>())
            .build();
    // System.nanoTime() is monotonic, unlike System.currentTimeMillis(), so the measured
    // interval is not skewed by wall-clock adjustments during a long training run.
    long startNanos = System.nanoTime();
    vec.fit();
    long elapsedMillis = (System.nanoTime() - startNanos) / 1_000_000L;
    log.info("Total execution time: {}", elapsedMillis);
}
Also used : BasicLineIterator(org.deeplearning4j.text.sentenceiterator.BasicLineIterator) KoreanTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.KoreanTokenizerFactory) TokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) BasicModelUtils(org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils) Word2Vec(org.deeplearning4j.models.word2vec.Word2Vec) CBOW(org.deeplearning4j.models.embeddings.learning.impl.elements.CBOW) SentenceIterator(org.deeplearning4j.text.sentenceiterator.SentenceIterator) Ignore(org.junit.Ignore) Test(org.junit.Test)

Example 12 with Word2Vec

use of org.deeplearning4j.models.word2vec.Word2Vec in project deeplearning4j by deeplearning4j.

Method testWord2VecPlot of the class ManualTests.

/**
 * Manual check: trains a Word2Vec model on the bundled raw_sentences.txt corpus, then sleeps for
 * a very long time and finally fails with "Not implemented".
 *
 * <p>The vocabulary-plotting step (previously via a UI server connection) is disabled, so this
 * test only exercises training.
 */
@Test
public void testWord2VecPlot() throws Exception {
    File corpusFile = new ClassPathResource("/big/raw_sentences.txt").getFile();
    SentenceIterator sentences = new BasicLineIterator(corpusFile.getAbsolutePath());

    TokenizerFactory tokenizer = new DefaultTokenizerFactory();
    tokenizer.setTokenPreProcessor(new CommonPreprocessor());

    Word2Vec word2Vec = new Word2Vec.Builder()
            .minWordFrequency(5)
            .iterations(2)
            .batchSize(1000)
            .learningRate(0.025)
            .layerSize(100)
            .seed(42)
            .sampling(0)
            .negativeSample(0)
            .windowSize(5)
            .modelUtils(new BasicModelUtils<VocabWord>())
            .useAdaGrad(false)
            .iterate(sentences)
            .workers(10)
            .tokenizerFactory(tokenizer)
            .build();
    word2Vec.fit();

    // NOTE(review): presumably this keeps the process alive for manual inspection; the test
    // then fails deliberately because plotting is not implemented.
    Thread.sleep(10000000000L);
    fail("Not implemented");
}
Also used : DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) CommonPreprocessor(org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor) BasicLineIterator(org.deeplearning4j.text.sentenceiterator.BasicLineIterator) TokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory) BasicModelUtils(org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils) Word2Vec(org.deeplearning4j.models.word2vec.Word2Vec) File(java.io.File) ClassPathResource(org.datavec.api.util.ClassPathResource) SentenceIterator(org.deeplearning4j.text.sentenceiterator.SentenceIterator) Test(org.junit.Test)

Example 13 with Word2Vec

use of org.deeplearning4j.models.word2vec.Word2Vec in project deeplearning4j by deeplearning4j.

Method testFullModelSerialization of the class WordVectorSerializerTest.

/**
 * Round-trips a trained Word2Vec model through {@code WordVectorSerializer.writeFullModel} and
 * {@code loadFullModel}. It checks that the configuration, exp table, vocabulary order,
 * syn0/syn1/syn1Neg rows and historical gradients of the restored model match the original.
 * It then continues training the restored model and asserts that the "day" and "night" vectors
 * changed, but stay similar to the original model (arraysSimilarity above 0.70).
 */
@Test
public void testFullModelSerialization() throws Exception {
    File inputFile = new ClassPathResource("/big/raw_sentences.txt").getFile();
    SentenceIterator iter = UimaSentenceIterator.createWithPath(inputFile.getAbsolutePath());
    // Split on white spaces in the line to get words
    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());
    InMemoryLookupCache cache = new InMemoryLookupCache(false);
    WeightLookupTable table = new InMemoryLookupTable.Builder().vectorLength(100).useAdaGrad(false).negative(5.0).cache(cache).lr(0.025f).build();
    // The builder above produces an InMemoryLookupTable; cast once instead of at every use.
    InMemoryLookupTable originalTable = (InMemoryLookupTable) table;
    Word2Vec vec = new Word2Vec.Builder().minWordFrequency(5).iterations(1).epochs(1).layerSize(100).lookupTable(table).stopWords(new ArrayList<String>()).useAdaGrad(false).negativeSample(5).vocabCache(cache).seed(42).windowSize(5).iterate(iter).tokenizerFactory(t).build();
    assertEquals(new ArrayList<String>(), vec.getStopWords());
    vec.fit();
    Collection<String> lst = vec.wordsNearest("day", 10);
    System.out.println(lst);

    // Use a unique temp file instead of "tempModel.txt" in the working directory, so that
    // parallel or aborted runs cannot collide on, or leave behind, a shared file.
    File modelFile = File.createTempFile("tempModel", ".txt");
    modelFile.deleteOnExit();
    WordVectorSerializer.writeFullModel(vec, modelFile.getAbsolutePath());
    assertTrue(modelFile.exists());
    assertTrue(modelFile.length() > 0);
    Word2Vec vec2 = WordVectorSerializer.loadFullModel(modelFile.getAbsolutePath());
    assertTrue("loadFullModel returned null", vec2 != null);
    assertEquals(vec.getConfiguration(), vec2.getConfiguration());
    assertTrue(ArrayUtils.isEquals(originalTable.getExpTable(), ((InMemoryLookupTable) vec2.getLookupTable()).getExpTable()));
    InMemoryLookupTable restoredTable = (InMemoryLookupTable) vec2.lookupTable();
    assertEquals(cache.wordAtIndex(1), restoredTable.getVocab().wordAtIndex(1));
    assertEquals(cache.wordAtIndex(7), restoredTable.getVocab().wordAtIndex(7));
    assertEquals(cache.wordAtIndex(15), restoredTable.getVocab().wordAtIndex(15));

    // Sanity checks that INDArray equality distinguishes different arrays and accepts equal ones.
    double[] array1 = new double[] { 0.323232325, 0.65756575, 0.12315, 0.12312315, 0.1232135, 0.12312315, 0.4343423425, 0.15 };
    double[] array2 = new double[] { 0.423232325, 0.25756575, 0.12375, 0.12311315, 0.1232035, 0.12318315, 0.4343493425, 0.25 };
    assertNotEquals(Nd4j.create(array1), Nd4j.create(array2));
    assertEquals(Nd4j.create(array1), Nd4j.create(array1));

    INDArray rSyn0_1 = restoredTable.getSyn0().slice(1);
    INDArray oSyn0_1 = originalTable.getSyn0().slice(1);
    assertEquals(oSyn0_1, rSyn0_1);

    // Compare every vocabulary word's weight rows to make sure syn0/syn1 order was preserved.
    int cnt = 0;
    for (VocabWord word : cache.vocabWords()) {
        INDArray rSyn0 = restoredTable.getSyn0().slice(word.getIndex());
        INDArray oSyn0 = originalTable.getSyn0().slice(word.getIndex());
        assertEquals(rSyn0, oSyn0);
        assertEquals(1.0, arraysSimilarity(rSyn0, oSyn0), 0.001);
        INDArray rSyn1 = restoredTable.getSyn1().slice(word.getIndex());
        INDArray oSyn1 = originalTable.getSyn1().slice(word.getIndex());
        assertEquals(rSyn1, oSyn1);
        // The similarity check is skipped for the word visited at position 222, whose syn1 row
        // was reported to be all zeros. NOTE(review): cnt is the iteration position, not
        // word.getIndex(); confirm this still targets the intended word if iteration order changes.
        if (cnt != 222) {
            assertEquals(1.0, arraysSimilarity(rSyn1, oSyn1), 0.001);
        }
        if (originalTable.getSyn1Neg() != null) {
            INDArray rSyn1Neg = restoredTable.getSyn1Neg().slice(word.getIndex());
            INDArray oSyn1Neg = originalTable.getSyn1Neg().slice(word.getIndex());
            assertEquals(rSyn1Neg, oSyn1Neg);
        }
        assertEquals(word.getHistoricalGradient(), restoredTable.getVocab().wordFor(word.getWord()).getHistoricalGradient());
        cnt++;
    }

    // At this point the whole model is assumed to be transferred, so training can continue on it.
    iter = UimaSentenceIterator.createWithPath(inputFile.getAbsolutePath());
    vec2.setTokenizerFactory(t);
    vec2.setSentenceIterator(iter);
    vec2.fit();
    INDArray day1 = vec.getWordVectorMatrix("day");
    INDArray day2 = vec2.getWordVectorMatrix("day");
    INDArray night1 = vec.getWordVectorMatrix("night");
    INDArray night2 = vec2.getWordVectorMatrix("night");
    double simD = arraysSimilarity(day1, day2);
    double simN = arraysSimilarity(night1, night2);
    logger.info("Vec1 day: " + day1);
    logger.info("Vec2 day: " + day2);
    logger.info("Vec1 night: " + night1);
    logger.info("Vec2 night: " + night2);
    logger.info("Day/day cross-model similarity: " + simD);
    logger.info("Night/night cross-model similarity: " + simN);
    logger.info("Vec1 day/night similiraty: " + vec.similarity("day", "night"));
    logger.info("Vec2 day/night similiraty: " + vec2.similarity("day", "night"));
    // Further training must have changed the vectors...
    assertNotEquals(1.0, simD, 0.001);
    assertNotEquals(1.0, simN, 0.001);
    // ...but they should still be close to the original model's vectors.
    assertTrue(simD > 0.70);
    assertTrue(simN > 0.70);
    modelFile.delete();
}
Also used : TokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) ClassPathResource(org.datavec.api.util.ClassPathResource) SentenceIterator(org.deeplearning4j.text.sentenceiterator.SentenceIterator) UimaSentenceIterator(org.deeplearning4j.text.sentenceiterator.UimaSentenceIterator) InMemoryLookupCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache) CommonPreprocessor(org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor) InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Word2Vec(org.deeplearning4j.models.word2vec.Word2Vec) WeightLookupTable(org.deeplearning4j.models.embeddings.WeightLookupTable) File(java.io.File) Test(org.junit.Test)

Example 14 with Word2Vec

use of org.deeplearning4j.models.word2vec.Word2Vec in project deeplearning4j by deeplearning4j.

Method testOutputStream of the class WordVectorSerializerTest.

/**
 * Verifies two round-trips of a trained model:
 * <ul>
 *   <li>Word vectors written to an {@code OutputStream} in text form are read back by
 *       {@code loadTxtVectors} with the same "day" vector.</li>
 *   <li>A model written with {@code writeWord2VecModel} can be read back with
 *       {@code readWord2VecModel} without error.</li>
 * </ul>
 */
@Test
public void testOutputStream() throws Exception {
    File file = File.createTempFile("tmp_ser", "ssa");
    file.deleteOnExit();
    File inputFile = new ClassPathResource("/big/raw_sentences.txt").getFile();
    SentenceIterator iter = new BasicLineIterator(inputFile);
    // Split on white spaces in the line to get words
    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());
    InMemoryLookupCache cache = new InMemoryLookupCache(false);
    WeightLookupTable table = new InMemoryLookupTable.Builder().vectorLength(100).useAdaGrad(false).negative(5.0).cache(cache).lr(0.025f).build();
    Word2Vec vec = new Word2Vec.Builder().minWordFrequency(5).iterations(1).epochs(1).layerSize(100).lookupTable(table).stopWords(new ArrayList<String>()).useAdaGrad(false).negativeSample(5).vocabCache(cache).seed(42).windowSize(5).iterate(iter).tokenizerFactory(t).build();
    assertEquals(new ArrayList<String>(), vec.getStopWords());
    vec.fit();
    INDArray day1 = vec.getWordVectorMatrix("day");
    // Close the stream deterministically so its contents are on disk before the file is re-read,
    // and the file descriptor is not leaked if the serializer does not close it itself.
    try (FileOutputStream out = new FileOutputStream(file)) {
        WordVectorSerializer.writeWordVectors(vec, out);
    }
    WordVectors vec2 = WordVectorSerializer.loadTxtVectors(file);
    INDArray day2 = vec2.getWordVectorMatrix("day");
    assertEquals(day1, day2);
    File tempFile = File.createTempFile("tetsts", "Fdfs");
    tempFile.deleteOnExit();
    WordVectorSerializer.writeWord2VecModel(vec, tempFile);
    // Only checks that reading the model back does not throw.
    // NOTE(review): the restored model's contents are not compared yet.
    WordVectorSerializer.readWord2VecModel(tempFile);
}
Also used : BasicLineIterator(org.deeplearning4j.text.sentenceiterator.BasicLineIterator) TokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) ClassPathResource(org.datavec.api.util.ClassPathResource) SentenceIterator(org.deeplearning4j.text.sentenceiterator.SentenceIterator) UimaSentenceIterator(org.deeplearning4j.text.sentenceiterator.UimaSentenceIterator) InMemoryLookupCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache) CommonPreprocessor(org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Word2Vec(org.deeplearning4j.models.word2vec.Word2Vec) FileOutputStream(java.io.FileOutputStream) WeightLookupTable(org.deeplearning4j.models.embeddings.WeightLookupTable) WordVectors(org.deeplearning4j.models.embeddings.wordvectors.WordVectors) File(java.io.File) Test(org.junit.Test)

Example 15 with Word2Vec

use of org.deeplearning4j.models.word2vec.Word2Vec in project deeplearning4j by deeplearning4j.

Method fromPair of the class WordVectorSerializer.

/**
 * Wraps an already-loaded lookup table and vocabulary in a new {@link Word2Vec} instance.
 *
 * @param pair
 *            the lookup table (first element) and vocabulary cache (second element) to attach
 * @return a new Word2Vec whose lookup table and vocab come from {@code pair}, configured with a
 *         fresh {@link BasicModelUtils}
 */
public static Word2Vec fromPair(Pair<InMemoryLookupTable, VocabCache> pair) {
    InMemoryLookupTable lookupTable = pair.getFirst();
    VocabCache vocabulary = pair.getSecond();

    Word2Vec word2Vec = new Word2Vec();
    word2Vec.setLookupTable(lookupTable);
    word2Vec.setVocab(vocabulary);
    word2Vec.setModelUtils(new BasicModelUtils());
    return word2Vec;
}
Also used : BasicModelUtils(org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils) StaticWord2Vec(org.deeplearning4j.models.word2vec.StaticWord2Vec) Word2Vec(org.deeplearning4j.models.word2vec.Word2Vec)

Aggregations

Word2Vec (org.deeplearning4j.models.word2vec.Word2Vec)19 INDArray (org.nd4j.linalg.api.ndarray.INDArray)13 TokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory)12 VocabWord (org.deeplearning4j.models.word2vec.VocabWord)11 Test (org.junit.Test)11 SentenceIterator (org.deeplearning4j.text.sentenceiterator.SentenceIterator)10 DefaultTokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory)10 CommonPreprocessor (org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor)9 File (java.io.File)8 ClassPathResource (org.datavec.api.util.ClassPathResource)8 StaticWord2Vec (org.deeplearning4j.models.word2vec.StaticWord2Vec)8 ArrayList (java.util.ArrayList)7 BasicLineIterator (org.deeplearning4j.text.sentenceiterator.BasicLineIterator)7 InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable)6 GZIPInputStream (java.util.zip.GZIPInputStream)5 UimaSentenceIterator (org.deeplearning4j.text.sentenceiterator.UimaSentenceIterator)5 ZipFile (java.util.zip.ZipFile)4 BasicModelUtils (org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils)4 AtomicInteger (java.util.concurrent.atomic.AtomicInteger)3 ZipEntry (java.util.zip.ZipEntry)3