
Example 1 with AbstractCache

use of org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache in project deeplearning4j by deeplearning4j.

From the class VocabConstructor, the method buildJointVocabulary:

/**
     * This method scans all sources passed through the builder and returns all words as a vocab.
     * If a target VocabCache was set during instance creation, it will be filled as well.
     *
     * @return the constructed vocabulary cache
     */
public VocabCache<T> buildJointVocabulary(boolean resetCounters, boolean buildHuffmanTree) {
    long lastTime = System.currentTimeMillis();
    long lastSequences = 0;
    long lastElements = 0;
    long startTime = lastTime;
    long startWords = 0;
    AtomicLong parsedCount = new AtomicLong(0);
    if (resetCounters && buildHuffmanTree)
        throw new IllegalStateException("You can't reset counters and build Huffman tree at the same time!");
    if (cache == null)
        cache = new AbstractCache.Builder<T>().build();
    log.debug("Target vocab size before building: [" + cache.numWords() + "]");
    final AtomicLong loopCounter = new AtomicLong(0);
    AbstractCache<T> topHolder = new AbstractCache.Builder<T>().minElementFrequency(0).build();
    int cnt = 0;
    int numProc = Runtime.getRuntime().availableProcessors();
    int numThreads = Math.max(numProc / 2, 2);
    ExecutorService executorService = new ThreadPoolExecutor(numThreads, numThreads, 0L, TimeUnit.MILLISECONDS, new LinkedTransferQueue<Runnable>());
    final AtomicLong execCounter = new AtomicLong(0);
    final AtomicLong finCounter = new AtomicLong(0);
    for (VocabSource<T> source : sources) {
        SequenceIterator<T> iterator = source.getIterator();
        iterator.reset();
        log.debug("Trying source iterator: [" + cnt + "]");
        log.debug("Target vocab size before building: [" + cache.numWords() + "]");
        cnt++;
        AbstractCache<T> tempHolder = new AbstractCache.Builder<T>().build();
        List<Long> timesHasNext = new ArrayList<>();
        List<Long> timesNext = new ArrayList<>();
        int sequences = 0;
        long time3 = 0;
        while (iterator.hasMoreSequences()) {
            Sequence<T> document = iterator.nextSequence();
            seqCount.incrementAndGet();
            parsedCount.addAndGet(document.size());
            tempHolder.incrementTotalDocCount();
            execCounter.incrementAndGet();
            VocabRunnable runnable = new VocabRunnable(tempHolder, document, finCounter, loopCounter);
            executorService.execute(runnable);
            // if we're not in parallel mode - wait till this runnable finishes
            if (!allowParallelBuilder) {
                while (execCounter.get() != finCounter.get()) LockSupport.parkNanos(1000);
            }
            while (execCounter.get() - finCounter.get() > numProc) {
                try {
                    Thread.sleep(1);
                } catch (Exception e) {
                }
            }
            sequences++;
            if (seqCount.get() % 100000 == 0) {
                long currentTime = System.currentTimeMillis();
                long currentSequences = seqCount.get();
                long currentElements = parsedCount.get();
                double seconds = (currentTime - lastTime) / (double) 1000;
                //                    Collections.sort(timesHasNext);
                //                    Collections.sort(timesNext);
                double seqPerSec = (currentSequences - lastSequences) / seconds;
                double elPerSec = (currentElements - lastElements) / seconds;
                //                    log.info("Document time: {} us; hasNext time: {} us", timesNext.get(timesNext.size() / 2), timesHasNext.get(timesHasNext.size() / 2));
                log.info("Sequences checked: [{}]; Current vocabulary size: [{}]; Sequences/sec: {}; Words/sec: {};", seqCount.get(), tempHolder.numWords(), String.format("%.2f", seqPerSec), String.format("%.2f", elPerSec));
                lastTime = currentTime;
                lastElements = currentElements;
                lastSequences = currentSequences;
            //                    timesHasNext.clear();
            //                    timesNext.clear();
            }
            // Firing the scavenger loop
            if (enableScavenger && loopCounter.get() >= 2000000 && tempHolder.numWords() > 10000000) {
                log.info("Starting scavenger...");
                while (execCounter.get() != finCounter.get()) {
                    try {
                        Thread.sleep(2);
                    } catch (Exception e) {
                    }
                }
                filterVocab(tempHolder, Math.max(1, source.getMinWordFrequency() / 2));
                loopCounter.set(0);
            }
        //                timesNext.add((time2 - time1) / 1000L);
        //                timesHasNext.add((time1 - time3) / 1000L);
        //                time3 = System.nanoTime();
        }
        // block until all threads are finished
        log.debug("Waiting until all processes stop...");
        while (execCounter.get() != finCounter.get()) {
            try {
                Thread.sleep(2);
            } catch (Exception e) {
            }
        }
        // apply minWordFrequency set for this source
        log.debug("Vocab size before truncation: [" + tempHolder.numWords() + "],  NumWords: [" + tempHolder.totalWordOccurrences() + "], sequences parsed: [" + seqCount.get() + "], counter: [" + parsedCount.get() + "]");
        if (source.getMinWordFrequency() > 0) {
            filterVocab(tempHolder, source.getMinWordFrequency());
        }
        log.debug("Vocab size after truncation: [" + tempHolder.numWords() + "],  NumWords: [" + tempHolder.totalWordOccurrences() + "], sequences parsed: [" + seqCount.get() + "], counter: [" + parsedCount.get() + "]");
        // at this moment we're ready to transfer
        topHolder.importVocabulary(tempHolder);
    }
    // at this moment, we have a vocabulary full of words, and we have to reset counters before transferring everything back to the VocabCache
    //topHolder.resetWordCounters();
    System.gc();
    System.gc();
    try {
        Thread.sleep(1000);
    } catch (Exception e) {
    //
    }
    cache.importVocabulary(topHolder);
    // adding UNK word
    if (unk != null) {
        log.info("Adding UNK element to vocab...");
        unk.setSpecial(true);
        cache.addToken(unk);
    }
    if (resetCounters) {
        for (T element : cache.vocabWords()) {
            element.setElementFrequency(0);
        }
        cache.updateWordsOccurencies();
    }
    if (buildHuffmanTree) {
        Huffman huffman = new Huffman(cache.vocabWords());
        huffman.build();
        huffman.applyIndexes(cache);
        if (limit > 0) {
            LinkedBlockingQueue<String> labelsToRemove = new LinkedBlockingQueue<>();
            for (T element : cache.vocabWords()) {
                if (element.getIndex() > limit && !element.isSpecial() && !element.isLabel())
                    labelsToRemove.add(element.getLabel());
            }
            for (String label : labelsToRemove) {
                cache.removeElement(label);
            }
        }
    }
    executorService.shutdown();
    System.gc();
    System.gc();
    try {
        Thread.sleep(1000);
    } catch (Exception e) {
    //
    }
    long endSequences = seqCount.get();
    long endTime = System.currentTimeMillis();
    double seconds = (endTime - startTime) / (double) 1000;
    double seqPerSec = endSequences / seconds;
    log.info("Sequences checked: [{}], Current vocabulary size: [{}]; Sequences/sec: [{}];", seqCount.get(), cache.numWords(), String.format("%.2f", seqPerSec));
    return cache;
}
Also used : AbstractCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache) AtomicLong(java.util.concurrent.atomic.AtomicLong) Huffman(org.deeplearning4j.models.word2vec.Huffman)
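
For context, here is a minimal sketch of how buildJointVocabulary() is typically driven from user code. It assumes the VocabConstructor.Builder methods addSource(...) and setTargetVocabCache(...) together with the SentenceTransformer/AbstractSequenceIterator helpers (none of which appear verbatim above), and the corpus path is hypothetical:

SentenceIterator lines = new BasicLineIterator(new File("/tmp/raw_sentences.txt")); // hypothetical path
TokenizerFactory tokenizer = new DefaultTokenizerFactory();
tokenizer.setTokenPreProcessor(new CommonPreprocessor());
// Turn raw sentences into Sequence<VocabWord> objects for the constructor.
SentenceTransformer transformer = new SentenceTransformer.Builder()
        .iterator(lines)
        .tokenizerFactory(tokenizer)
        .build();
AbstractSequenceIterator<VocabWord> sequences =
        new AbstractSequenceIterator.Builder<>(transformer).build();
// Target cache that buildJointVocabulary() will fill.
AbstractCache<VocabWord> cache = new AbstractCache.Builder<VocabWord>().build();
VocabConstructor<VocabWord> constructor = new VocabConstructor.Builder<VocabWord>()
        .addSource(sequences, 5) // per-source minWordFrequency
        .setTargetVocabCache(cache)
        .build();
// resetCounters and buildHuffmanTree are mutually exclusive (see the check at the top of the method).
constructor.buildJointVocabulary(false, true);
log.info("Vocab size: {}", cache.numWords());

The per-source frequency passed to addSource(...) is the value that the truncation step above (filterVocab(tempHolder, source.getMinWordFrequency())) applies before the temporary holder is merged into the target cache.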

Example 2 with AbstractCache

use of org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache in project deeplearning4j by deeplearning4j.

From the class ParagraphVectorsTest, the method testParagraphVectorsDM:

@Test
public void testParagraphVectorsDM() throws Exception {
    ClassPathResource resource = new ClassPathResource("/big/raw_sentences.txt");
    File file = resource.getFile();
    SentenceIterator iter = new BasicLineIterator(file);
    AbstractCache<VocabWord> cache = new AbstractCache.Builder<VocabWord>().build();
    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());
    LabelsSource source = new LabelsSource("DOC_");
    ParagraphVectors vec = new ParagraphVectors.Builder().minWordFrequency(1).iterations(2).seed(119).epochs(3).layerSize(100).learningRate(0.025).labelsSource(source).windowSize(5).iterate(iter).trainWordVectors(true).vocabCache(cache).tokenizerFactory(t).negativeSample(0).useHierarchicSoftmax(true).sampling(0).workers(1).usePreciseWeightInit(true).sequenceLearningAlgorithm(new DM<VocabWord>()).build();
    vec.fit();
    int cnt1 = cache.wordFrequency("day");
    int cnt2 = cache.wordFrequency("me");
    assertNotEquals(1, cnt1);
    assertNotEquals(1, cnt2);
    assertNotEquals(cnt1, cnt2);
    double simDN = vec.similarity("day", "night");
    log.info("day/night similariry: {}", simDN);
    double similarity1 = vec.similarity("DOC_9835", "DOC_12492");
    log.info("9835/12492 similarity: " + similarity1);
    //        assertTrue(similarity1 > 0.2d);
    double similarity2 = vec.similarity("DOC_3720", "DOC_16392");
    log.info("3720/16392 similarity: " + similarity2);
    //      assertTrue(similarity2 > 0.2d);
    double similarity3 = vec.similarity("DOC_6347", "DOC_3720");
    log.info("6347/3720 similarity: " + similarity3);
    //        assertTrue(similarity3 > 0.6d);
    double similarityX = vec.similarity("DOC_3720", "DOC_9852");
    log.info("3720/9852 similarity: " + similarityX);
    assertTrue(similarityX < 0.5d);
    // testing DM inference now
    INDArray original = vec.getWordVectorMatrix("DOC_16392").dup();
    INDArray inferredA1 = vec.inferVector("This is my work");
    INDArray inferredB1 = vec.inferVector("This is my work .");
    double cosAO1 = Transforms.cosineSim(inferredA1.dup(), original.dup());
    double cosAB1 = Transforms.cosineSim(inferredA1.dup(), inferredB1.dup());
    log.info("Cos O/A: {}", cosAO1);
    log.info("Cos A/B: {}", cosAB1);
}
Also used : BasicLineIterator(org.deeplearning4j.text.sentenceiterator.BasicLineIterator) TokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) DM(org.deeplearning4j.models.embeddings.learning.impl.sequence.DM) AbstractCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache) ClassPathResource(org.datavec.api.util.ClassPathResource) SentenceIterator(org.deeplearning4j.text.sentenceiterator.SentenceIterator) FileSentenceIterator(org.deeplearning4j.text.sentenceiterator.FileSentenceIterator) AggregatingSentenceIterator(org.deeplearning4j.text.sentenceiterator.AggregatingSentenceIterator) CommonPreprocessor(org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor) INDArray(org.nd4j.linalg.api.ndarray.INDArray) LabelsSource(org.deeplearning4j.text.documentiterator.LabelsSource) File(java.io.File) Test(org.junit.Test)
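
The AbstractCache filled by fit() can also be inspected directly. A short sketch reusing cache and vec from the test above; the VocabCache accessors (numWords(), totalWordOccurrences(), containsWord()) and wordsNearest() are assumed to be available as in other deeplearning4j examples:

// Reuses `cache` and `vec` from the test above.
log.info("Vocabulary size: {}", cache.numWords());
log.info("Total word occurrences: {}", cache.totalWordOccurrences());
log.info("Contains 'day': {}", cache.containsWord("day"));
// trainWordVectors(true) means plain word vectors were trained alongside the DOC_* labels,
// so nearest-neighbour queries over words are possible as well.
Collection<String> nearest = vec.wordsNearest("day", 10);
log.info("Nearest to 'day': {}", nearest);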

Example 3 with AbstractCache

use of org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache in project deeplearning4j by deeplearning4j.

From the class ParagraphVectorsTest, the method testParagraphVectorsWithWordVectorsModelling1:

@Test
public void testParagraphVectorsWithWordVectorsModelling1() throws Exception {
    ClassPathResource resource = new ClassPathResource("/big/raw_sentences.txt");
    File file = resource.getFile();
    SentenceIterator iter = new BasicLineIterator(file);
    //        InMemoryLookupCache cache = new InMemoryLookupCache(false);
    AbstractCache<VocabWord> cache = new AbstractCache.Builder<VocabWord>().build();
    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());
    LabelsSource source = new LabelsSource("DOC_");
    ParagraphVectors vec = new ParagraphVectors.Builder().minWordFrequency(1).iterations(3).epochs(1).layerSize(100).learningRate(0.025).labelsSource(source).windowSize(5).iterate(iter).trainWordVectors(true).vocabCache(cache).tokenizerFactory(t).sampling(0).build();
    vec.fit();
    int cnt1 = cache.wordFrequency("day");
    int cnt2 = cache.wordFrequency("me");
    assertNotEquals(1, cnt1);
    assertNotEquals(1, cnt2);
    assertNotEquals(cnt1, cnt2);
    /*
            We have a few lines that contain fairly similar words.
            These sentences should be fairly close to each other in vector space.
         */
    // line 3721: This is my way .
    // line 6348: This is my case .
    // line 9836: This is my house .
    // line 12493: This is my world .
    // line 16393: This is my work .
    // this is a special sentence that has nothing in common with the previous sentences
    // line 9853: We now have one .
    assertTrue(vec.hasWord("DOC_3720"));
    double similarityD = vec.similarity("day", "night");
    log.info("day/night similarity: " + similarityD);
    double similarityW = vec.similarity("way", "work");
    log.info("way/work similarity: " + similarityW);
    double similarityH = vec.similarity("house", "world");
    log.info("house/world similarity: " + similarityH);
    double similarityC = vec.similarity("case", "way");
    log.info("case/way similarity: " + similarityC);
    double similarity1 = vec.similarity("DOC_9835", "DOC_12492");
    log.info("9835/12492 similarity: " + similarity1);
    //        assertTrue(similarity1 > 0.7d);
    double similarity2 = vec.similarity("DOC_3720", "DOC_16392");
    log.info("3720/16392 similarity: " + similarity2);
    //        assertTrue(similarity2 > 0.7d);
    double similarity3 = vec.similarity("DOC_6347", "DOC_3720");
    log.info("6347/3720 similarity: " + similarity3);
    //        assertTrue(similarity2 > 0.7d);
    // the similarity in this case should be significantly lower;
    // however, since the corpus is small and weight initialization is random, this test can occasionally fail
    double similarityX = vec.similarity("DOC_3720", "DOC_9852");
    log.info("3720/9852 similarity: " + similarityX);
    assertTrue(similarityX < 0.5d);
    double sim119 = vec.similarityToLabel("This is my case .", "DOC_6347");
    double sim120 = vec.similarityToLabel("This is my case .", "DOC_3720");
    log.info("1/2: " + sim119 + "/" + sim120);
//assertEquals(similarity3, sim119, 0.001);
}
Also used : BasicLineIterator(org.deeplearning4j.text.sentenceiterator.BasicLineIterator) TokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) AbstractCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache) ClassPathResource(org.datavec.api.util.ClassPathResource) SentenceIterator(org.deeplearning4j.text.sentenceiterator.SentenceIterator) FileSentenceIterator(org.deeplearning4j.text.sentenceiterator.FileSentenceIterator) AggregatingSentenceIterator(org.deeplearning4j.text.sentenceiterator.AggregatingSentenceIterator) CommonPreprocessor(org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor) LabelsSource(org.deeplearning4j.text.documentiterator.LabelsSource) File(java.io.File) Test(org.junit.Test)
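
A model trained like this is usually persisted for later reuse. A minimal sketch, assuming the WordVectorSerializer.writeParagraphVectors(...)/readParagraphVectors(...) pair (not shown in the examples above) and a hypothetical file path:

// Assumed API: writeParagraphVectors / readParagraphVectors; the path is hypothetical.
File modelFile = new File("/tmp/paravec_model.zip");
WordVectorSerializer.writeParagraphVectors(vec, modelFile);
ParagraphVectors restored = WordVectorSerializer.readParagraphVectors(modelFile);
restored.setTokenizerFactory(t); // only needed if inferVector() will be called on raw text afterwards
double sim = restored.similarity("DOC_9835", "DOC_12492");
log.info("9835/12492 similarity after reload: {}", sim);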

Example 4 with AbstractCache

use of org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache in project deeplearning4j by deeplearning4j.

From the class SequenceVectorsTest, the method testDeepWalk:

@Test
@Ignore
public void testDeepWalk() throws Exception {
    Heartbeat.getInstance().disableHeartbeat();
    AbstractCache<Blogger> vocabCache = new AbstractCache.Builder<Blogger>().build();
    Graph<Blogger, Double> graph = buildGraph();
    GraphWalker<Blogger> walker = new PopularityWalker.Builder<>(graph).setNoEdgeHandling(NoEdgeHandling.RESTART_ON_DISCONNECTED).setWalkLength(40).setWalkDirection(WalkDirection.FORWARD_UNIQUE).setRestartProbability(0.05).setPopularitySpread(10).setPopularityMode(PopularityMode.MAXIMUM).setSpreadSpectrum(SpreadSpectrum.PROPORTIONAL).build();
    /*
        GraphWalker<Blogger> walker = new RandomWalker.Builder<Blogger>(graph)
                .setNoEdgeHandling(NoEdgeHandling.RESTART_ON_DISCONNECTED)
                .setWalkLength(40)
                .setWalkDirection(WalkDirection.RANDOM)
                .setRestartProbability(0.05)
                .build();
        */
    GraphTransformer<Blogger> graphTransformer = new GraphTransformer.Builder<>(graph).setGraphWalker(walker).shuffleOnReset(true).setVocabCache(vocabCache).build();
    Blogger blogger = graph.getVertex(0).getValue();
    assertEquals(119, blogger.getElementFrequency(), 0.001);
    logger.info("Blogger: " + blogger);
    AbstractSequenceIterator<Blogger> sequenceIterator = new AbstractSequenceIterator.Builder<>(graphTransformer).build();
    WeightLookupTable<Blogger> lookupTable = new InMemoryLookupTable.Builder<Blogger>().lr(0.025).vectorLength(150).useAdaGrad(false).cache(vocabCache).seed(42).build();
    lookupTable.resetWeights(true);
    SequenceVectors<Blogger> vectors = new SequenceVectors.Builder<Blogger>(new VectorsConfiguration()).lookupTable(lookupTable).iterate(sequenceIterator).vocabCache(vocabCache).batchSize(1000).iterations(1).epochs(10).resetModel(false).trainElementsRepresentation(true).trainSequencesRepresentation(false).elementsLearningAlgorithm(new SkipGram<Blogger>()).learningRate(0.025).layerSize(150).sampling(0).negativeSample(0).windowSize(4).workers(6).seed(42).build();
    vectors.fit();
    vectors.setModelUtils(new FlatModelUtils());
    //     logger.info("12: " + Arrays.toString(vectors.getWordVector("12")));
    double sim = vectors.similarity("12", "72");
    Collection<String> list = vectors.wordsNearest("12", 20);
    logger.info("12->72: " + sim);
    printWords("12", list, vectors);
    assertTrue(sim > 0.10);
    assertFalse(Double.isNaN(sim));
}
Also used : VectorsConfiguration(org.deeplearning4j.models.embeddings.loader.VectorsConfiguration) AbstractCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache) AbstractSequenceIterator(org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator) FlatModelUtils(org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils) Ignore(org.junit.Ignore) Test(org.junit.Test)
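
The commented-out block inside the test shows a plain RandomWalker alternative; as a sketch, it can be dropped in for the PopularityWalker without changing the rest of the pipeline:

// Drop-in replacement for the PopularityWalker above, taken from the commented-out variant.
GraphWalker<Blogger> randomWalker = new RandomWalker.Builder<Blogger>(graph)
        .setNoEdgeHandling(NoEdgeHandling.RESTART_ON_DISCONNECTED)
        .setWalkLength(40)
        .setWalkDirection(WalkDirection.RANDOM)
        .setRestartProbability(0.05)
        .build();
// The transformer is built exactly as before, just with the new walker.
GraphTransformer<Blogger> graphTransformer = new GraphTransformer.Builder<>(graph)
        .setGraphWalker(randomWalker)
        .shuffleOnReset(true)
        .setVocabCache(vocabCache)
        .build();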

Example 5 with AbstractCache

use of org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache in project deeplearning4j by deeplearning4j.

From the class WordVectorSerializer, the method loadFullModel:

/**
     * This method loads a full w2v model, previously saved with a writeFullModel() call.
     *
     * Deprecation note: please consider using readWord2VecModel() or loadStaticModel() instead.
     *
     * @param path path to a previously stored w2v JSON model
     * @return Word2Vec instance
     */
@Deprecated
public static Word2Vec loadFullModel(@NonNull String path) throws FileNotFoundException {
    /*
            // TODO: implementation is in progress
            We need to restore:
                     1. WeightLookupTable, including syn0 and syn1 matrices
                     2. VocabCache + mark it as SPECIAL, to avoid accidental word removals
         */
    BasicLineIterator iterator = new BasicLineIterator(new File(path));
    // first 3 lines should be processed separately
    String confJson = iterator.nextSentence();
    log.info("Word2Vec conf. JSON: " + confJson);
    VectorsConfiguration configuration = VectorsConfiguration.fromJson(confJson);
    // we don't actually need expTable, since it produces identical results on subsequent runs as long as the expTable size is not modified
    String eTable = iterator.nextSentence();
    double[] expTable;
    String nTable = iterator.nextSentence();
    if (configuration.getNegative() > 0) {
    // TODO: we should probably parse negTable, but it isn't required until vocab changes are introduced; with a predefined vocab it will produce the exact same nTable (the same goes for expTable).
    }
    /*
                Since we're restoring the vocab from a previously serialized model, minWordFrequency has already been applied to its vocabulary, so it should NOT be truncated again.
                That's why minWordFrequency is set to the configuration value, while SPECIAL is applied to each word to avoid truncation.
         */
    VocabularyHolder holder = new VocabularyHolder.Builder().minWordFrequency(configuration.getMinWordFrequency()).hugeModelExpected(configuration.isHugeModelExpected()).scavengerActivationThreshold(configuration.getScavengerActivationThreshold()).scavengerRetentionDelay(configuration.getScavengerRetentionDelay()).build();
    AtomicInteger counter = new AtomicInteger(0);
    AbstractCache<VocabWord> vocabCache = new AbstractCache.Builder<VocabWord>().build();
    while (iterator.hasNext()) {
        //    log.info("got line: " + iterator.nextSentence());
        String wordJson = iterator.nextSentence();
        VocabularyWord word = VocabularyWord.fromJson(wordJson);
        word.setSpecial(true);
        VocabWord vw = new VocabWord(word.getCount(), word.getWord());
        vw.setIndex(counter.getAndIncrement());
        vw.setIndex(word.getHuffmanNode().getIdx());
        vw.setCodeLength(word.getHuffmanNode().getLength());
        vw.setPoints(arrayToList(word.getHuffmanNode().getPoint(), word.getHuffmanNode().getLength()));
        vw.setCodes(arrayToList(word.getHuffmanNode().getCode(), word.getHuffmanNode().getLength()));
        vocabCache.addToken(vw);
        vocabCache.addWordToIndex(vw.getIndex(), vw.getLabel());
        vocabCache.putVocabWord(vw.getWord());
    }
    // at this moment the vocab is restored, and it's time to rebuild the Huffman tree
    // since the word counters are equal, the Huffman tree will be identical too
    //holder.updateHuffmanCodes();
    // we definitely don't need the UNK word in this scenario
    //        holder.transferBackToVocabCache(vocabCache, false);
    // now, it's time to transfer the syn0/syn1/syn1Neg values
    InMemoryLookupTable lookupTable = (InMemoryLookupTable) new InMemoryLookupTable.Builder().negative(configuration.getNegative()).useAdaGrad(configuration.isUseAdaGrad()).lr(configuration.getLearningRate()).cache(vocabCache).vectorLength(configuration.getLayersSize()).build();
    // we create all arrays
    lookupTable.resetWeights(true);
    iterator.reset();
    // we should skip 3 lines from file
    iterator.nextSentence();
    iterator.nextSentence();
    iterator.nextSentence();
    // now, for each word from vocabHolder we'll just transfer actual values
    while (iterator.hasNext()) {
        String wordJson = iterator.nextSentence();
        VocabularyWord word = VocabularyWord.fromJson(wordJson);
        // syn0 transfer
        INDArray syn0 = lookupTable.getSyn0().getRow(vocabCache.indexOf(word.getWord()));
        syn0.assign(Nd4j.create(word.getSyn0()));
        // syn1 transfer
        // syn1 values are normally accessed via tree points, but since our goal is just deserialization, we can copy them row by row
        INDArray syn1 = lookupTable.getSyn1().getRow(vocabCache.indexOf(word.getWord()));
        syn1.assign(Nd4j.create(word.getSyn1()));
        // syn1Neg transfer
        if (configuration.getNegative() > 0) {
            INDArray syn1Neg = lookupTable.getSyn1Neg().getRow(vocabCache.indexOf(word.getWord()));
            syn1Neg.assign(Nd4j.create(word.getSyn1Neg()));
        }
    }
    Word2Vec vec = new Word2Vec.Builder(configuration).vocabCache(vocabCache).lookupTable(lookupTable).resetModel(false).build();
    vec.setModelUtils(new BasicModelUtils());
    return vec;
}
Also used : BasicLineIterator(org.deeplearning4j.text.sentenceiterator.BasicLineIterator) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) AbstractCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache) VocabularyHolder(org.deeplearning4j.models.word2vec.wordstore.VocabularyHolder) InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) BasicModelUtils(org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) VocabularyWord(org.deeplearning4j.models.word2vec.wordstore.VocabularyWord) StaticWord2Vec(org.deeplearning4j.models.word2vec.StaticWord2Vec) Word2Vec(org.deeplearning4j.models.word2vec.Word2Vec) ZipFile(java.util.zip.ZipFile)
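
Per the deprecation note above, newer code would reach for readWord2VecModel() or loadStaticModel() instead. A minimal sketch; the method names come from the Javadoc above, while the File overloads and paths are assumptions:

// Full model: vocabulary plus weights, usable for queries and further training.
Word2Vec w2v = WordVectorSerializer.readWord2VecModel(new File("/tmp/w2v_model.zip")); // hypothetical path
log.info("day/night similarity: {}", w2v.similarity("day", "night"));
// Memory-friendly, read-only variant intended for lookups only.
WordVectors staticVectors = WordVectorSerializer.loadStaticModel(new File("/tmp/w2v_model.zip"));
double[] dayVector = staticVectors.getWordVector("day");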

Aggregations

AbstractCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache): 21 uses
VocabWord (org.deeplearning4j.models.word2vec.VocabWord): 17 uses
Test (org.junit.Test): 12 uses
BasicLineIterator (org.deeplearning4j.text.sentenceiterator.BasicLineIterator): 11 uses
ClassPathResource (org.datavec.api.util.ClassPathResource): 9 uses
InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable): 9 uses
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 9 uses
TokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory): 8 uses
File (java.io.File): 7 uses
ArrayList (java.util.ArrayList): 7 uses
AbstractSequenceIterator (org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator): 7 uses
CommonPreprocessor (org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor): 7 uses
DefaultTokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory): 7 uses
SentenceTransformer (org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer): 6 uses
VocabConstructor (org.deeplearning4j.models.word2vec.wordstore.VocabConstructor): 4 uses
Pair (org.deeplearning4j.berkeley.Pair): 3 uses
VectorsConfiguration (org.deeplearning4j.models.embeddings.loader.VectorsConfiguration): 3 uses
LabelsSource (org.deeplearning4j.text.documentiterator.LabelsSource): 3 uses
AggregatingSentenceIterator (org.deeplearning4j.text.sentenceiterator.AggregatingSentenceIterator): 3 uses
FileSentenceIterator (org.deeplearning4j.text.sentenceiterator.FileSentenceIterator): 3 uses