Example 6 with AbstractCache

use of org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache in project deeplearning4j by deeplearning4j.

the class VocabConstructor method buildJointVocabulary.

/**
     * This method scans all sources passed through the builder and returns all words as a vocabulary.
     * If a target VocabCache was set during instance creation, it will be filled as well.
     *
     * @param resetCounters    if true, element frequencies are reset to zero after the vocabulary is built
     * @param buildHuffmanTree if true, a Huffman tree is built over the resulting vocabulary
     * @return the resulting vocabulary cache
     */
public VocabCache<T> buildJointVocabulary(boolean resetCounters, boolean buildHuffmanTree) {
    long lastTime = System.currentTimeMillis();
    long lastSequences = 0;
    long lastElements = 0;
    long startTime = lastTime;
    long startWords = 0;
    AtomicLong parsedCount = new AtomicLong(0);
    if (resetCounters && buildHuffmanTree)
        throw new IllegalStateException("You can't reset counters and build Huffman tree at the same time!");
    if (cache == null)
        cache = new AbstractCache.Builder<T>().build();
    log.debug("Target vocab size before building: [" + cache.numWords() + "]");
    final AtomicLong loopCounter = new AtomicLong(0);
    AbstractCache<T> topHolder = new AbstractCache.Builder<T>().minElementFrequency(0).build();
    int cnt = 0;
    int numProc = Runtime.getRuntime().availableProcessors();
    int numThreads = Math.max(numProc / 2, 2);
    ExecutorService executorService = new ThreadPoolExecutor(numThreads, numThreads, 0L, TimeUnit.MILLISECONDS, new LinkedTransferQueue<Runnable>());
    final AtomicLong execCounter = new AtomicLong(0);
    final AtomicLong finCounter = new AtomicLong(0);
    for (VocabSource<T> source : sources) {
        SequenceIterator<T> iterator = source.getIterator();
        iterator.reset();
        log.debug("Trying source iterator: [" + cnt + "]");
        log.debug("Target vocab size before building: [" + cache.numWords() + "]");
        cnt++;
        AbstractCache<T> tempHolder = new AbstractCache.Builder<T>().build();
        List<Long> timesHasNext = new ArrayList<>();
        List<Long> timesNext = new ArrayList<>();
        int sequences = 0;
        long time3 = 0;
        while (iterator.hasMoreSequences()) {
            Sequence<T> document = iterator.nextSequence();
            seqCount.incrementAndGet();
            parsedCount.addAndGet(document.size());
            tempHolder.incrementTotalDocCount();
            execCounter.incrementAndGet();
            VocabRunnable runnable = new VocabRunnable(tempHolder, document, finCounter, loopCounter);
            executorService.execute(runnable);
            // if we're not in parallel mode - wait till this runnable finishes
            if (!allowParallelBuilder) {
                while (execCounter.get() != finCounter.get()) LockSupport.parkNanos(1000);
            }
            while (execCounter.get() - finCounter.get() > numProc) {
                try {
                    Thread.sleep(1);
                } catch (Exception e) {
                }
            }
            sequences++;
            if (seqCount.get() % 100000 == 0) {
                long currentTime = System.currentTimeMillis();
                long currentSequences = seqCount.get();
                long currentElements = parsedCount.get();
                double seconds = (currentTime - lastTime) / (double) 1000;
                //                    Collections.sort(timesHasNext);
                //                    Collections.sort(timesNext);
                double seqPerSec = (currentSequences - lastSequences) / seconds;
                double elPerSec = (currentElements - lastElements) / seconds;
                //                    log.info("Document time: {} us; hasNext time: {} us", timesNext.get(timesNext.size() / 2), timesHasNext.get(timesHasNext.size() / 2));
                log.info("Sequences checked: [{}]; Current vocabulary size: [{}]; Sequences/sec: {}; Words/sec: {};", seqCount.get(), tempHolder.numWords(), String.format("%.2f", seqPerSec), String.format("%.2f", elPerSec));
                lastTime = currentTime;
                lastElements = currentElements;
                lastSequences = currentSequences;
            //                    timesHasNext.clear();
            //                    timesNext.clear();
            }
            // Firing the scavenger loop
            if (enableScavenger && loopCounter.get() >= 2000000 && tempHolder.numWords() > 10000000) {
                log.info("Starting scavenger...");
                while (execCounter.get() != finCounter.get()) {
                    try {
                        Thread.sleep(2);
                    } catch (Exception e) {
                    }
                }
                filterVocab(tempHolder, Math.max(1, source.getMinWordFrequency() / 2));
                loopCounter.set(0);
            }
        //                timesNext.add((time2 - time1) / 1000L);
        //                timesHasNext.add((time1 - time3) / 1000L);
        //                time3 = System.nanoTime();
        }
        // block until all threads are finished
        log.debug("Waiting until all processes stop...");
        while (execCounter.get() != finCounter.get()) {
            try {
                Thread.sleep(2);
            } catch (Exception e) {
            }
        }
        // apply minWordFrequency set for this source
        log.debug("Vocab size before truncation: [" + tempHolder.numWords() + "],  NumWords: [" + tempHolder.totalWordOccurrences() + "], sequences parsed: [" + seqCount.get() + "], counter: [" + parsedCount.get() + "]");
        if (source.getMinWordFrequency() > 0) {
            filterVocab(tempHolder, source.getMinWordFrequency());
        }
        log.debug("Vocab size after truncation: [" + tempHolder.numWords() + "],  NumWords: [" + tempHolder.totalWordOccurrences() + "], sequences parsed: [" + seqCount.get() + "], counter: [" + parsedCount.get() + "]");
        // at this moment we're ready to transfer
        topHolder.importVocabulary(tempHolder);
    }
    // at this moment we have a vocabulary full of words, and we have to reset counters before transferring everything back to the VocabCache
    //topHolder.resetWordCounters();
    System.gc();
    System.gc();
    try {
        Thread.sleep(1000);
    } catch (Exception e) {
    //
    }
    cache.importVocabulary(topHolder);
    // adding UNK word
    if (unk != null) {
        log.info("Adding UNK element to vocab...");
        unk.setSpecial(true);
        cache.addToken(unk);
    }
    if (resetCounters) {
        for (T element : cache.vocabWords()) {
            element.setElementFrequency(0);
        }
        cache.updateWordsOccurencies();
    }
    if (buildHuffmanTree) {
        Huffman huffman = new Huffman(cache.vocabWords());
        huffman.build();
        huffman.applyIndexes(cache);
        if (limit > 0) {
            LinkedBlockingQueue<String> labelsToRemove = new LinkedBlockingQueue<>();
            for (T element : cache.vocabWords()) {
                if (element.getIndex() > limit && !element.isSpecial() && !element.isLabel())
                    labelsToRemove.add(element.getLabel());
            }
            for (String label : labelsToRemove) {
                cache.removeElement(label);
            }
        }
    }
    executorService.shutdown();
    System.gc();
    System.gc();
    try {
        Thread.sleep(1000);
    } catch (Exception e) {
    //
    }
    long endSequences = seqCount.get();
    long endTime = System.currentTimeMillis();
    double seconds = (endTime - startTime) / (double) 1000;
    double seqPerSec = endSequences / seconds;
    log.info("Sequences checked: [{}], Current vocabulary size: [{}]; Sequences/sec: [{}];", seqCount.get(), cache.numWords(), String.format("%.2f", seqPerSec));
    return cache;
}
Also used : AbstractCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache) AtomicLong(java.util.concurrent.atomic.AtomicLong) Huffman(org.deeplearning4j.models.word2vec.Huffman)
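
For orientation, here is a minimal usage sketch (not taken from the deeplearning4j sources) showing how buildJointVocabulary is typically driven. The wiring through SentenceTransformer and AbstractSequenceIterator, and the Builder methods addSource and setTargetVocabCache, are assumed from the API used elsewhere in these examples; the file path and the minWordFrequency value are placeholders.

SentenceIterator sentences = new BasicLineIterator(new File("/big/raw_sentences.txt"));
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());

// turn raw sentences into Sequence<VocabWord> instances
SentenceTransformer transformer = new SentenceTransformer.Builder()
        .iterator(sentences)
        .tokenizerFactory(t)
        .build();
AbstractSequenceIterator<VocabWord> sequenceIterator =
        new AbstractSequenceIterator.Builder<>(transformer).build();

AbstractCache<VocabWord> targetCache = new AbstractCache.Builder<VocabWord>().build();

VocabConstructor<VocabWord> constructor = new VocabConstructor.Builder<VocabWord>()
        // minWordFrequency of 5 is an arbitrary example value
        .addSource(sequenceIterator, 5)
        .setTargetVocabCache(targetCache)
        .build();

// keep the element counters (resetCounters = false) and build the Huffman tree
VocabCache<VocabWord> vocab = constructor.buildJointVocabulary(false, true);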

Example 7 with AbstractCache

use of org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache in project deeplearning4j by deeplearning4j.

the class ParagraphVectorsTest method testParagraphVectorsDM.

@Test
public void testParagraphVectorsDM() throws Exception {
    ClassPathResource resource = new ClassPathResource("/big/raw_sentences.txt");
    File file = resource.getFile();
    SentenceIterator iter = new BasicLineIterator(file);
    AbstractCache<VocabWord> cache = new AbstractCache.Builder<VocabWord>().build();
    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());
    LabelsSource source = new LabelsSource("DOC_");
    ParagraphVectors vec = new ParagraphVectors.Builder().minWordFrequency(1).iterations(2).seed(119).epochs(3).layerSize(100).learningRate(0.025).labelsSource(source).windowSize(5).iterate(iter).trainWordVectors(true).vocabCache(cache).tokenizerFactory(t).negativeSample(0).useHierarchicSoftmax(true).sampling(0).workers(1).usePreciseWeightInit(true).sequenceLearningAlgorithm(new DM<VocabWord>()).build();
    vec.fit();
    int cnt1 = cache.wordFrequency("day");
    int cnt2 = cache.wordFrequency("me");
    assertNotEquals(1, cnt1);
    assertNotEquals(1, cnt2);
    assertNotEquals(cnt1, cnt2);
    double simDN = vec.similarity("day", "night");
    log.info("day/night similariry: {}", simDN);
    double similarity1 = vec.similarity("DOC_9835", "DOC_12492");
    log.info("9835/12492 similarity: " + similarity1);
    //        assertTrue(similarity1 > 0.2d);
    double similarity2 = vec.similarity("DOC_3720", "DOC_16392");
    log.info("3720/16392 similarity: " + similarity2);
    //      assertTrue(similarity2 > 0.2d);
    double similarity3 = vec.similarity("DOC_6347", "DOC_3720");
    log.info("6347/3720 similarity: " + similarity3);
    //        assertTrue(similarity3 > 0.6d);
    double similarityX = vec.similarity("DOC_3720", "DOC_9852");
    log.info("3720/9852 similarity: " + similarityX);
    assertTrue(similarityX < 0.5d);
    // testing DM inference now
    INDArray original = vec.getWordVectorMatrix("DOC_16392").dup();
    INDArray inferredA1 = vec.inferVector("This is my work");
    INDArray inferredB1 = vec.inferVector("This is my work .");
    double cosAO1 = Transforms.cosineSim(inferredA1.dup(), original.dup());
    double cosAB1 = Transforms.cosineSim(inferredA1.dup(), inferredB1.dup());
    log.info("Cos O/A: {}", cosAO1);
    log.info("Cos A/B: {}", cosAB1);
}
Also used : BasicLineIterator(org.deeplearning4j.text.sentenceiterator.BasicLineIterator) TokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) DM(org.deeplearning4j.models.embeddings.learning.impl.sequence.DM) AbstractCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache) ClassPathResource(org.datavec.api.util.ClassPathResource) SentenceIterator(org.deeplearning4j.text.sentenceiterator.SentenceIterator) FileSentenceIterator(org.deeplearning4j.text.sentenceiterator.FileSentenceIterator) AggregatingSentenceIterator(org.deeplearning4j.text.sentenceiterator.AggregatingSentenceIterator) CommonPreprocessor(org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor) INDArray(org.nd4j.linalg.api.ndarray.INDArray) LabelsSource(org.deeplearning4j.text.documentiterator.LabelsSource) File(java.io.File) Test(org.junit.Test)
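
As a smaller illustration, AbstractCache can also be populated and queried directly, independent of ParagraphVectors. This sketch uses VocabCache methods that appear in the examples on this page plus containsWord from the VocabCache interface; the frequencies are made-up values.

AbstractCache<VocabWord> vocab = new AbstractCache.Builder<VocabWord>().build();

// VocabWord(frequency, label); the frequencies here are arbitrary example values
vocab.addToken(new VocabWord(3.0, "day"));
vocab.addToken(new VocabWord(1.0, "night"));

log.info("numWords: {}", vocab.numWords());                        // 2
log.info("frequency of 'day': {}", vocab.wordFrequency("day"));    // 3
log.info("contains 'night': {}", vocab.containsWord("night"));     // true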

Example 8 with AbstractCache

use of org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache in project deeplearning4j by deeplearning4j.

the class ParagraphVectorsTest method testParagraphVectorsWithWordVectorsModelling1.

@Test
public void testParagraphVectorsWithWordVectorsModelling1() throws Exception {
    ClassPathResource resource = new ClassPathResource("/big/raw_sentences.txt");
    File file = resource.getFile();
    SentenceIterator iter = new BasicLineIterator(file);
    //        InMemoryLookupCache cache = new InMemoryLookupCache(false);
    AbstractCache<VocabWord> cache = new AbstractCache.Builder<VocabWord>().build();
    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());
    LabelsSource source = new LabelsSource("DOC_");
    ParagraphVectors vec = new ParagraphVectors.Builder().minWordFrequency(1).iterations(3).epochs(1).layerSize(100).learningRate(0.025).labelsSource(source).windowSize(5).iterate(iter).trainWordVectors(true).vocabCache(cache).tokenizerFactory(t).sampling(0).build();
    vec.fit();
    int cnt1 = cache.wordFrequency("day");
    int cnt2 = cache.wordFrequency("me");
    assertNotEquals(1, cnt1);
    assertNotEquals(1, cnt2);
    assertNotEquals(cnt1, cnt2);
    /*
            We have a few lines that contain closely related words.
            These sentences should be pretty close to each other in vector space.
         */
    // line 3721: This is my way .
    // line 6348: This is my case .
    // line 9836: This is my house .
    // line 12493: This is my world .
    // line 16393: This is my work .
    // this is a special sentence that has nothing in common with the previous ones
    // line 9853: We now have one .
    assertTrue(vec.hasWord("DOC_3720"));
    double similarityD = vec.similarity("day", "night");
    log.info("day/night similarity: " + similarityD);
    double similarityW = vec.similarity("way", "work");
    log.info("way/work similarity: " + similarityW);
    double similarityH = vec.similarity("house", "world");
    log.info("house/world similarity: " + similarityH);
    double similarityC = vec.similarity("case", "way");
    log.info("case/way similarity: " + similarityC);
    double similarity1 = vec.similarity("DOC_9835", "DOC_12492");
    log.info("9835/12492 similarity: " + similarity1);
    //        assertTrue(similarity1 > 0.7d);
    double similarity2 = vec.similarity("DOC_3720", "DOC_16392");
    log.info("3720/16392 similarity: " + similarity2);
    //        assertTrue(similarity2 > 0.7d);
    double similarity3 = vec.similarity("DOC_6347", "DOC_3720");
    log.info("6347/3720 similarity: " + similarity3);
    //        assertTrue(similarity2 > 0.7d);
    // likelihood in this case should be significantly lower
    // however, since corpus is small, and weight initialization is random-based, sometimes this test CAN fail
    double similarityX = vec.similarity("DOC_3720", "DOC_9852");
    log.info("3720/9852 similarity: " + similarityX);
    assertTrue(similarityX < 0.5d);
    double sim119 = vec.similarityToLabel("This is my case .", "DOC_6347");
    double sim120 = vec.similarityToLabel("This is my case .", "DOC_3720");
    log.info("1/2: " + sim119 + "/" + sim120);
//assertEquals(similarity3, sim119, 0.001);
}
Also used : BasicLineIterator(org.deeplearning4j.text.sentenceiterator.BasicLineIterator) TokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) AbstractCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache) ClassPathResource(org.datavec.api.util.ClassPathResource) SentenceIterator(org.deeplearning4j.text.sentenceiterator.SentenceIterator) FileSentenceIterator(org.deeplearning4j.text.sentenceiterator.FileSentenceIterator) AggregatingSentenceIterator(org.deeplearning4j.text.sentenceiterator.AggregatingSentenceIterator) CommonPreprocessor(org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor) LabelsSource(org.deeplearning4j.text.documentiterator.LabelsSource) File(java.io.File) Test(org.junit.Test)
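
An illustrative follow-up (not part of the test above): once vec.fit() has run, the model can be probed with nearest-neighbour queries, using the same wordsNearest call that the SequenceVectors example below relies on, and document vectors can be compared against raw text via similarityToLabel.

Collection<String> neighbours = vec.wordsNearest("day", 10);
log.info("nearest to 'day': {}", neighbours);

double simToDoc = vec.similarityToLabel("This is my world .", "DOC_12492");
log.info("'This is my world .' vs DOC_12492: {}", simToDoc);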

Example 9 with AbstractCache

use of org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache in project deeplearning4j by deeplearning4j.

the class SequenceVectorsTest method testDeepWalk.

@Test
@Ignore
public void testDeepWalk() throws Exception {
    Heartbeat.getInstance().disableHeartbeat();
    AbstractCache<Blogger> vocabCache = new AbstractCache.Builder<Blogger>().build();
    Graph<Blogger, Double> graph = buildGraph();
    GraphWalker<Blogger> walker = new PopularityWalker.Builder<>(graph).setNoEdgeHandling(NoEdgeHandling.RESTART_ON_DISCONNECTED).setWalkLength(40).setWalkDirection(WalkDirection.FORWARD_UNIQUE).setRestartProbability(0.05).setPopularitySpread(10).setPopularityMode(PopularityMode.MAXIMUM).setSpreadSpectrum(SpreadSpectrum.PROPORTIONAL).build();
    /*
        GraphWalker<Blogger> walker = new RandomWalker.Builder<Blogger>(graph)
                .setNoEdgeHandling(NoEdgeHandling.RESTART_ON_DISCONNECTED)
                .setWalkLength(40)
                .setWalkDirection(WalkDirection.RANDOM)
                .setRestartProbability(0.05)
                .build();
        */
    GraphTransformer<Blogger> graphTransformer = new GraphTransformer.Builder<>(graph).setGraphWalker(walker).shuffleOnReset(true).setVocabCache(vocabCache).build();
    Blogger blogger = graph.getVertex(0).getValue();
    assertEquals(119, blogger.getElementFrequency(), 0.001);
    logger.info("Blogger: " + blogger);
    AbstractSequenceIterator<Blogger> sequenceIterator = new AbstractSequenceIterator.Builder<>(graphTransformer).build();
    WeightLookupTable<Blogger> lookupTable = new InMemoryLookupTable.Builder<Blogger>().lr(0.025).vectorLength(150).useAdaGrad(false).cache(vocabCache).seed(42).build();
    lookupTable.resetWeights(true);
    SequenceVectors<Blogger> vectors = new SequenceVectors.Builder<Blogger>(new VectorsConfiguration()).lookupTable(lookupTable).iterate(sequenceIterator).vocabCache(vocabCache).batchSize(1000).iterations(1).epochs(10).resetModel(false).trainElementsRepresentation(true).trainSequencesRepresentation(false).elementsLearningAlgorithm(new SkipGram<Blogger>()).learningRate(0.025).layerSize(150).sampling(0).negativeSample(0).windowSize(4).workers(6).seed(42).build();
    vectors.fit();
    vectors.setModelUtils(new FlatModelUtils());
    //     logger.info("12: " + Arrays.toString(vectors.getWordVector("12")));
    double sim = vectors.similarity("12", "72");
    Collection<String> list = vectors.wordsNearest("12", 20);
    logger.info("12->72: " + sim);
    printWords("12", list, vectors);
    assertTrue(sim > 0.10);
    assertFalse(Double.isNaN(sim));
}
Also used : VectorsConfiguration(org.deeplearning4j.models.embeddings.loader.VectorsConfiguration) AbstractCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache) AbstractSequenceIterator(org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator) FlatModelUtils(org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils) Ignore(org.junit.Ignore) Test(org.junit.Test)
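
A short follow-up sketch (not in the original test): the learned vertex embeddings can be pulled out as INDArrays and compared directly, reusing the getWordVectorMatrix and Transforms.cosineSim calls seen in the ParagraphVectors example above. The vertex labels "12" and "72" are the ones asserted on in the test.

INDArray v12 = vectors.getWordVectorMatrix("12").dup();
INDArray v72 = vectors.getWordVectorMatrix("72").dup();

// should agree (up to rounding) with vectors.similarity("12", "72")
double cos = Transforms.cosineSim(v12, v72);
logger.info("cosine(12, 72): {}", cos);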

Example 10 with AbstractCache

use of org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache in project deeplearning4j by deeplearning4j.

the class SparkSequenceVectors method buildShallowVocabCache.

/**
     * This method builds the shallow vocabulary and its Huffman tree.
     *
     * @param counter element counts keyed by element id
     * @return the resulting shallow vocabulary cache, with Huffman indexes applied
     */
protected VocabCache<ShallowSequenceElement> buildShallowVocabCache(Counter<Long> counter) {
    // TODO: need simplified cache here, that will operate on Long instead of string labels
    VocabCache<ShallowSequenceElement> vocabCache = new AbstractCache<>();
    for (Long id : counter.keySet()) {
        ShallowSequenceElement shallowElement = new ShallowSequenceElement(counter.getCount(id), id);
        vocabCache.addToken(shallowElement);
    }
    // building huffman tree
    Huffman huffman = new Huffman(vocabCache.vocabWords());
    huffman.build();
    huffman.applyIndexes(vocabCache);
    return vocabCache;
}
Also used : ShallowSequenceElement(org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement) Huffman(org.deeplearning4j.models.word2vec.Huffman) AbstractCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache)
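
The same Huffman pattern applies to a string-labelled cache as well. A minimal sketch with made-up frequencies; more frequent elements end up with shorter Huffman codes and lower indexes.

AbstractCache<VocabWord> cache = new AbstractCache.Builder<VocabWord>().build();
cache.addToken(new VocabWord(10.0, "the"));
cache.addToken(new VocabWord(4.0, "quick"));
cache.addToken(new VocabWord(1.0, "fox"));

Huffman huffman = new Huffman(cache.vocabWords());
huffman.build();
// writes the Huffman codes, points and indexes back into the cache
huffman.applyIndexes(cache);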

Aggregations

AbstractCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache)21 VocabWord (org.deeplearning4j.models.word2vec.VocabWord)17 Test (org.junit.Test)12 BasicLineIterator (org.deeplearning4j.text.sentenceiterator.BasicLineIterator)11 ClassPathResource (org.datavec.api.util.ClassPathResource)9 InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable)9 INDArray (org.nd4j.linalg.api.ndarray.INDArray)9 TokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory)8 File (java.io.File)7 ArrayList (java.util.ArrayList)7 AbstractSequenceIterator (org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator)7 CommonPreprocessor (org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor)7 DefaultTokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory)7 SentenceTransformer (org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer)6 VocabConstructor (org.deeplearning4j.models.word2vec.wordstore.VocabConstructor)4 Pair (org.deeplearning4j.berkeley.Pair)3 VectorsConfiguration (org.deeplearning4j.models.embeddings.loader.VectorsConfiguration)3 LabelsSource (org.deeplearning4j.text.documentiterator.LabelsSource)3 AggregatingSentenceIterator (org.deeplearning4j.text.sentenceiterator.AggregatingSentenceIterator)3 FileSentenceIterator (org.deeplearning4j.text.sentenceiterator.FileSentenceIterator)3