Search in sources :

Example 1 with SkipGram

use of org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram in project deeplearning4j by deeplearning4j.

the class Word2VecTests method testW2VnegativeOnRestore.

/**
 * Trains a skip-gram Word2Vec model with negative sampling (hierarchic softmax switched off),
 * writes it to a temporary file, and restores it twice - once with each value of the boolean
 * second argument of {@code readWord2VecModel}. After every restore the configuration is
 * checked and the restored model is trained again on the same corpus.
 */
@Test
public void testW2VnegativeOnRestore() throws Exception {
    // Corpus source and tokenization shared by the original and the restored models.
    SentenceIterator sentences = new BasicLineIterator(inputFile.getAbsolutePath());
    TokenizerFactory tokenizers = new DefaultTokenizerFactory();
    tokenizers.setTokenPreProcessor(new CommonPreprocessor());

    Word2Vec trained = new Word2Vec.Builder()
            .minWordFrequency(1)
            .iterations(3)
            .batchSize(64)
            .layerSize(100)
            .stopWords(new ArrayList<String>())
            .seed(42)
            .learningRate(0.025)
            .minLearningRate(0.001)
            .sampling(0)
            .elementsLearningAlgorithm(new SkipGram<VocabWord>())
            .negativeSample(10)
            .epochs(1)
            .windowSize(5)
            .useHierarchicSoftmax(false)
            .allowParallelTokenization(true)
            .modelUtils(new FlatModelUtils<VocabWord>())
            .iterate(sentences)
            .tokenizerFactory(tokenizers)
            .build();
    assertEquals(false, trained.getConfiguration().isUseHierarchicSoftmax());

    log.info("Fit 1");
    trained.fit();

    // Persist the trained model; the file is removed when the JVM exits.
    File modelFile = File.createTempFile("temp", "file");
    modelFile.deleteOnExit();
    WordVectorSerializer.writeWord2VecModel(trained, modelFile);

    // Restore with the flag set: the FlatModelUtils choice and the
    // parallel-tokenization setting are expected to come back.
    Word2Vec reloaded = restoreWithoutHierarchicSoftmax(modelFile, true, tokenizers, sentences);
    assertTrue(reloaded.getModelUtils() instanceof FlatModelUtils);
    assertTrue(reloaded.getConfiguration().isAllowParallelTokenization());
    log.info("Fit 2");
    reloaded.fit();

    // Restore with the flag cleared: BasicModelUtils is expected instead.
    reloaded = restoreWithoutHierarchicSoftmax(modelFile, false, tokenizers, sentences);
    assertTrue(reloaded.getModelUtils() instanceof BasicModelUtils);
    log.info("Fit 3");
    reloaded.fit();
}

/**
 * Rewinds the sentence iterator, reads the model back from {@code modelFile}, attaches the
 * tokenizer factory and sentence iterator to it, and asserts that hierarchic softmax is
 * still disabled in the restored configuration.
 *
 * @param modelFile  file previously written by {@code writeWord2VecModel}
 * @param readerFlag value passed through as the second argument of {@code readWord2VecModel}
 * @param tokenizers tokenizer factory to attach to the restored model
 * @param sentences  sentence iterator to rewind and attach to the restored model
 * @return the restored model, ready for another {@code fit()}
 */
private Word2Vec restoreWithoutHierarchicSoftmax(File modelFile, boolean readerFlag,
        TokenizerFactory tokenizers, SentenceIterator sentences) throws Exception {
    sentences.reset();
    Word2Vec reloaded = WordVectorSerializer.readWord2VecModel(modelFile, readerFlag);
    reloaded.setTokenizerFactory(tokenizers);
    reloaded.setSentenceIterator(sentences);
    assertEquals(false, reloaded.getConfiguration().isUseHierarchicSoftmax());
    return reloaded;
}
Also used : BasicLineIterator(org.deeplearning4j.text.sentenceiterator.BasicLineIterator) TokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) SkipGram(org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram) SentenceIterator(org.deeplearning4j.text.sentenceiterator.SentenceIterator) UimaSentenceIterator(org.deeplearning4j.text.sentenceiterator.UimaSentenceIterator) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) CommonPreprocessor(org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor) BasicModelUtils(org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils) FlatModelUtils(org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils) File(java.io.File) Test(org.junit.Test)

Example 2 with SkipGram

use of org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram in project deeplearning4j by deeplearning4j.

the class Word2VecTests method testRunWord2Vec.

/**
 * Trains a skip-gram Word2Vec model on the test corpus and checks nearest-neighbour lookups
 * for "day", the day/night similarity range, and that rows returned by
 * {@code getWordVectors} match the per-word vectors. The model is also written out through
 * {@code writeFullModel} and {@code writeWordVectors}.
 */
@Test
public void testRunWord2Vec() throws Exception {
    SentenceIterator sentences = new BasicLineIterator(inputFile.getAbsolutePath());
    TokenizerFactory tokenizers = new DefaultTokenizerFactory();
    tokenizers.setTokenPreProcessor(new CommonPreprocessor());

    Word2Vec model = new Word2Vec.Builder()
            .minWordFrequency(1)
            .iterations(3)
            .batchSize(64)
            .layerSize(100)
            .stopWords(new ArrayList<String>())
            .seed(42)
            .learningRate(0.025)
            .minLearningRate(0.001)
            .sampling(0)
            .elementsLearningAlgorithm(new SkipGram<VocabWord>())
            .epochs(1)
            .windowSize(5)
            .allowParallelTokenization(true)
            .modelUtils(new BasicModelUtils<VocabWord>())
            .iterate(sentences)
            .tokenizerFactory(tokenizers)
            .build();
    // The empty stop-word list handed to the builder must be reported back unchanged.
    assertEquals(new ArrayList<String>(), model.getStopWords());

    model.fit();

    File fullModelFile = File.createTempFile("temp", "temp");
    fullModelFile.deleteOnExit();
    WordVectorSerializer.writeFullModel(model, fullModelFile.getAbsolutePath());

    // Neighbours that must show up among the ten words nearest to "day".
    String[] expectedNeighbours = {"week", "night", "year"};

    Collection<String> nearest = model.wordsNearest("day", 10);
    printWords("day", nearest, model);
    assertEquals(10, nearest.size());

    double dayNightSimilarity = model.similarity("day", "night");
    log.info("Day/night similarity: " + dayNightSimilarity);
    assertTrue(dayNightSimilarity < 1.0);
    assertTrue(dayNightSimilarity > 0.4);

    for (String neighbour : expectedNeighbours) {
        assertTrue(nearest.contains(neighbour));
    }
    assertFalse(nearest.contains(null));

    // Repeat the same lookup and expect the same neighbours again.
    nearest = model.wordsNearest("day", 10);
    printWords("day", nearest, model);
    for (String neighbour : expectedNeighbours) {
        assertTrue(nearest.contains(neighbour));
    }

    // Best-effort removal of cache.ser from the working directory; the result is ignored.
    new File("cache.ser").delete();

    ArrayList<String> labels = new ArrayList<>();
    labels.add("day");
    labels.add("night");
    labels.add("week");

    // Row i of the combined matrix must equal the vector of the i-th requested word.
    INDArray combined = model.getWordVectors(labels);
    for (int row = 0; row < labels.size(); row++) {
        assertEquals(combined.getRow(row), model.getWordVectorMatrix(labels.get(row)));
    }

    WordVectorSerializer.writeWordVectors(model, pathToWriteto);
}
Also used : BasicLineIterator(org.deeplearning4j.text.sentenceiterator.BasicLineIterator) TokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) SkipGram(org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram) ArrayList(java.util.ArrayList) SentenceIterator(org.deeplearning4j.text.sentenceiterator.SentenceIterator) UimaSentenceIterator(org.deeplearning4j.text.sentenceiterator.UimaSentenceIterator) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) CommonPreprocessor(org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor) BasicModelUtils(org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils) INDArray(org.nd4j.linalg.api.ndarray.INDArray) File(java.io.File) Test(org.junit.Test)

Example 3 with SkipGram

use of org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram in project deeplearning4j by deeplearning4j.

the class ParagraphVectorsTest method testDirectInference.

/**
 * Trains a skip-gram Word2Vec model (hierarchic softmax enabled) on two aggregated sentence
 * sources, builds a DM-based ParagraphVectors instance on top of those existing word vectors
 * without calling {@code fit()} on it, infers vectors for two sentences and logs their
 * cosine similarity. The test makes no assertions; it passes as long as nothing throws.
 */
@Test
public void testDirectInference() throws Exception {
    ClassPathResource sentencesResource = new ClassPathResource("/big/raw_sentences.txt");
    ClassPathResource mixedResource = new ClassPathResource("/paravec");

    // One iterator over both sources: a line-based file and the /paravec resource.
    SentenceIterator sentences = new AggregatingSentenceIterator.Builder()
            .addSentenceIterator(new BasicLineIterator(sentencesResource.getFile()))
            .addSentenceIterator(new FileSentenceIterator(mixedResource.getFile()))
            .build();

    TokenizerFactory tokenizers = new DefaultTokenizerFactory();
    tokenizers.setTokenPreProcessor(new CommonPreprocessor());

    Word2Vec pretrained = new Word2Vec.Builder()
            .minWordFrequency(1)
            .batchSize(250)
            .iterations(1)
            .epochs(3)
            .learningRate(0.025)
            .layerSize(150)
            .minLearningRate(0.001)
            .elementsLearningAlgorithm(new SkipGram<VocabWord>())
            .useHierarchicSoftmax(true)
            .windowSize(5)
            .iterate(sentences)
            .tokenizerFactory(tokenizers)
            .build();
    pretrained.fit();

    // ParagraphVectors is only built here, never fitted: inference relies on the
    // word vectors supplied through useExistingWordVectors.
    ParagraphVectors paragraphModel = new ParagraphVectors.Builder()
            .tokenizerFactory(tokenizers)
            .iterations(10)
            .useHierarchicSoftmax(true)
            .trainWordVectors(true)
            .useExistingWordVectors(pretrained)
            .negativeSample(0)
            .sequenceLearningAlgorithm(new DM<VocabWord>())
            .build();

    INDArray firstInferred = paragraphModel.inferVector("This text is pretty awesome");
    INDArray secondInferred = paragraphModel.inferVector("Fantastic process of crazy things happening inside just for history purposes");
    log.info("vec1/vec2: {}", Transforms.cosineSim(firstInferred, secondInferred));
}
Also used : BasicLineIterator(org.deeplearning4j.text.sentenceiterator.BasicLineIterator) TokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) SkipGram(org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram) DM(org.deeplearning4j.models.embeddings.learning.impl.sequence.DM) ClassPathResource(org.datavec.api.util.ClassPathResource) SentenceIterator(org.deeplearning4j.text.sentenceiterator.SentenceIterator) FileSentenceIterator(org.deeplearning4j.text.sentenceiterator.FileSentenceIterator) AggregatingSentenceIterator(org.deeplearning4j.text.sentenceiterator.AggregatingSentenceIterator) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) CommonPreprocessor(org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor) AggregatingSentenceIterator(org.deeplearning4j.text.sentenceiterator.AggregatingSentenceIterator) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Word2Vec(org.deeplearning4j.models.word2vec.Word2Vec) FileSentenceIterator(org.deeplearning4j.text.sentenceiterator.FileSentenceIterator) Test(org.junit.Test)

Example 4 with SkipGram

use of org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram in project deeplearning4j by deeplearning4j.

the class ParagraphVectorsTest method testParagraphVectorsOverExistingWordVectorsModel.

/*
    Builds a Word2Vec model first and then reuses its vocabulary and weights for ParagraphVectors.
    Per the original author's note, this test is not needed on CI (Travis); run it manually
    when hunting for problems.
*/
@Test
public void testParagraphVectorsOverExistingWordVectorsModel() throws Exception {
    // Word2Vec is trained over several sources so the vocabulary covers everything used later.
    ClassPathResource sentencesResource = new ClassPathResource("/big/raw_sentences.txt");
    ClassPathResource mixedResource = new ClassPathResource("/paravec");
    SentenceIterator sentences = new AggregatingSentenceIterator.Builder()
            .addSentenceIterator(new BasicLineIterator(sentencesResource.getFile()))
            .addSentenceIterator(new FileSentenceIterator(mixedResource.getFile()))
            .build();

    TokenizerFactory tokenizers = new DefaultTokenizerFactory();
    tokenizers.setTokenPreProcessor(new CommonPreprocessor());

    Word2Vec pretrained = new Word2Vec.Builder()
            .minWordFrequency(1)
            .batchSize(250)
            .iterations(1)
            .epochs(3)
            .learningRate(0.025)
            .layerSize(150)
            .minLearningRate(0.001)
            .elementsLearningAlgorithm(new SkipGram<VocabWord>())
            .useHierarchicSoftmax(true)
            .windowSize(5)
            .iterate(sentences)
            .tokenizerFactory(tokenizers)
            .build();
    pretrained.fit();

    // Snapshot of "day" (vocab entry and a copy of its vector) before ParagraphVectors runs.
    VocabWord dayInWord2Vec = pretrained.getVocab().tokenFor("day");
    INDArray dayVectorBefore = pretrained.getWordVectorMatrix("day").dup();

    // The Word2Vec model is ready; the labeled documents below are used to fit ParagraphVectors.
    FileLabelAwareIterator labeledDocs = new FileLabelAwareIterator.Builder()
            .addSourceFolder(new ClassPathResource("/paravec/labeled").getFile())
            .build();

    // Documents from this iterator are the ones to be classified.
    FileLabelAwareIterator unlabeledDocs = new FileLabelAwareIterator.Builder()
            .addSourceFolder(new ClassPathResource("/paravec/unlabeled").getFile())
            .build();

    // Classifier built on top of the pre-built Word2Vec model; word vectors are not retrained.
    ParagraphVectors classifier = new ParagraphVectors.Builder()
            .iterate(labeledDocs)
            .learningRate(0.025)
            .minLearningRate(0.001)
            .iterations(5)
            .epochs(1)
            .layerSize(150)
            .tokenizerFactory(tokenizers)
            .sequenceLearningAlgorithm(new DBOW<VocabWord>())
            .useHierarchicSoftmax(true)
            .trainWordVectors(false)
            .useExistingWordVectors(pretrained)
            .build();
    classifier.fit();

    // "day" must sit at the same vocabulary index in both models.
    VocabWord dayInParagraphVectors = classifier.getVocab().tokenFor("day");
    assertEquals(dayInWord2Vec.getIndex(), dayInParagraphVectors.getIndex());

    // NOTE: an assertion that the Word2Vec day/night similarity exceeds 0.5 used to live here
    // and is currently disabled.

    // The vector for "day" must stay close to what Word2Vec produced.
    INDArray dayVectorAfter = classifier.getWordVectorMatrix("day").dup();
    double crossDay = arraysSimilarity(dayVectorBefore, dayVectorAfter);
    log.info("Day1: " + dayVectorBefore);
    log.info("Day2: " + dayVectorAfter);
    log.info("Cross-Day similarity: " + crossDay);
    log.info("Cross-Day similiarity 2: " + Transforms.cosineSim(dayVectorBefore, dayVectorAfter));
    assertTrue(crossDay > 0.9d);

    // NOTE: a randomized cross-vocabulary check (1000 random indices; same word at the same
    // index in both vocab caches and cosine similarity of the two vectors >= 0.9) is
    // currently disabled.

    log.info("Zfinance: " + classifier.getWordVectorMatrix("Zfinance"));
    log.info("Zhealth: " + classifier.getWordVectorMatrix("Zhealth"));
    log.info("Zscience: " + classifier.getWordVectorMatrix("Zscience"));

    // Classify the first unlabeled document and log its similarity to the top three labels.
    LabelledDocument document = unlabeledDocs.nextDocument();
    log.info("Results for document '" + document.getLabel() + "'");
    List<String> candidateLabels = new ArrayList<>(classifier.predictSeveral(document, 3));
    for (int i = 0; i < candidateLabels.size(); i++) {
        String candidate = candidateLabels.get(i);
        double similarity = classifier.similarityToLabel(document, candidate);
        log.info("Similarity to [" + candidate + "] is [" + similarity + "]");
    }

    String topPrediction = classifier.predict(document);
    assertEquals("Zfinance", topPrediction);
}
Also used : BasicLineIterator(org.deeplearning4j.text.sentenceiterator.BasicLineIterator) TokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) SkipGram(org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram) FileLabelAwareIterator(org.deeplearning4j.text.documentiterator.FileLabelAwareIterator) ArrayList(java.util.ArrayList) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) ClassPathResource(org.datavec.api.util.ClassPathResource) SentenceIterator(org.deeplearning4j.text.sentenceiterator.SentenceIterator) FileSentenceIterator(org.deeplearning4j.text.sentenceiterator.FileSentenceIterator) AggregatingSentenceIterator(org.deeplearning4j.text.sentenceiterator.AggregatingSentenceIterator) LabelledDocument(org.deeplearning4j.text.documentiterator.LabelledDocument) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) CommonPreprocessor(org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor) AggregatingSentenceIterator(org.deeplearning4j.text.sentenceiterator.AggregatingSentenceIterator) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Word2Vec(org.deeplearning4j.models.word2vec.Word2Vec) FileSentenceIterator(org.deeplearning4j.text.sentenceiterator.FileSentenceIterator) Test(org.junit.Test)

Aggregations

SkipGram (org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram)4 BasicLineIterator (org.deeplearning4j.text.sentenceiterator.BasicLineIterator)4 SentenceIterator (org.deeplearning4j.text.sentenceiterator.SentenceIterator)4 CommonPreprocessor (org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor)4 DefaultTokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory)4 TokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory)4 Test (org.junit.Test)4 INDArray (org.nd4j.linalg.api.ndarray.INDArray)3 File (java.io.File)2 ArrayList (java.util.ArrayList)2 ClassPathResource (org.datavec.api.util.ClassPathResource)2 BasicModelUtils (org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils)2 Word2Vec (org.deeplearning4j.models.word2vec.Word2Vec)2 AggregatingSentenceIterator (org.deeplearning4j.text.sentenceiterator.AggregatingSentenceIterator)2 FileSentenceIterator (org.deeplearning4j.text.sentenceiterator.FileSentenceIterator)2 UimaSentenceIterator (org.deeplearning4j.text.sentenceiterator.UimaSentenceIterator)2 DM (org.deeplearning4j.models.embeddings.learning.impl.sequence.DM)1 FlatModelUtils (org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils)1 VocabWord (org.deeplearning4j.models.word2vec.VocabWord)1 FileLabelAwareIterator (org.deeplearning4j.text.documentiterator.FileLabelAwareIterator)1