Example 21 with AbstractCache

Use of org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache in project deeplearning4j by deeplearning4j.

From the export method of class VocabCacheExporter.

@Override
public void export(JavaRDD<ExportContainer<VocabWord>> rdd) {
    // Beware: collect() pulls the entire RDD onto the driver. That is generally a VERY bad idea, but it works fine for testing purposes.
    List<ExportContainer<VocabWord>> list = rdd.collect();
    if (vocabCache == null)
        vocabCache = new AbstractCache<>();
    INDArray syn0 = null;
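    // Walk the collected list, registering each word in the vocabulary and copying its weights into row word.getIndex() of syn0.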
    for (ExportContainer<VocabWord> element : list) {
        VocabWord word = element.getElement();
        INDArray weights = element.getArray();
        if (syn0 == null)
            syn0 = Nd4j.create(list.size(), weights.length());
        vocabCache.addToken(word);
        vocabCache.addWordToIndex(word.getIndex(), word.getLabel());
        syn0.getRow(word.getIndex()).assign(weights);
    }
    if (lookupTable == null)
        lookupTable = new InMemoryLookupTable.Builder<VocabWord>().cache(vocabCache).vectorLength(syn0.columns()).build();
    lookupTable.setSyn0(syn0);
    // this is bad & dirty, but we don't really need anything else for testing :)
    word2Vec = WordVectorSerializer.fromPair(Pair.<InMemoryLookupTable, VocabCache>makePair(lookupTable, vocabCache));
}
Also used:
InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable)
INDArray (org.nd4j.linalg.api.ndarray.INDArray)
VocabCache (org.deeplearning4j.models.word2vec.wordstore.VocabCache)
ExportContainer (org.deeplearning4j.spark.models.sequencevectors.export.ExportContainer)
VocabWord (org.deeplearning4j.models.word2vec.VocabWord)
AbstractCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache)
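
The row-assignment pattern used by export can be sketched in plain Java, without Spark or DL4J on the classpath. This is only an illustration under stated assumptions: Entry and ExportSketch are hypothetical names, Entry stands in for ExportContainer<VocabWord>, and a double[][] stands in for the INDArray syn0. The guards for an empty list and for out-of-range indices are additions for the sketch. The original method has neither, so with an empty RDD its syn0 would still be null when syn0.columns() is called.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical stand-in for ExportContainer<VocabWord>: a label, a row index, and a weight vector.
final class Entry {
    final String label;
    final int index;
    final double[] weights;

    Entry(String label, int index, double[] weights) {
        this.label = label;
        this.index = index;
        this.weights = weights;
    }
}

final class ExportSketch {
    // Builds a dense matrix in which row i holds the weights of the entry whose index is i.
    static double[][] toMatrix(List<Entry> entries) {
        if (entries.isEmpty()) {
            return new double[0][0]; // assumption: an empty input yields an empty matrix
        }
        int rows = entries.size();
        int cols = entries.get(0).weights.length;
        double[][] syn0 = new double[rows][cols];
        for (Entry e : entries) {
            // As in the original loop, indices are expected to lie in 0..rows-1.
            if (e.index < 0 || e.index >= rows) {
                throw new IllegalArgumentException("index out of range: " + e.index);
            }
            if (e.weights.length != cols) {
                throw new IllegalArgumentException("vector length mismatch for " + e.label);
            }
            System.arraycopy(e.weights, 0, syn0[e.index], 0, cols);
        }
        return syn0;
    }

    public static void main(String[] args) {
        List<Entry> entries = new ArrayList<>();
        entries.add(new Entry("cat", 1, new double[] {0.3, 0.4}));
        entries.add(new Entry("dog", 0, new double[] {0.1, 0.2}));
        double[][] syn0 = toMatrix(entries);
        System.out.println(Arrays.toString(syn0[0])); // [0.1, 0.2]
        System.out.println(Arrays.toString(syn0[1])); // [0.3, 0.4]
    }
}
```

The entries arrive in arbitrary order, as they would from a collected RDD, and each one lands in the row given by its own index rather than by its position in the list.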

Aggregations

AbstractCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache): 21
VocabWord (org.deeplearning4j.models.word2vec.VocabWord): 17
Test (org.junit.Test): 12
BasicLineIterator (org.deeplearning4j.text.sentenceiterator.BasicLineIterator): 11
ClassPathResource (org.datavec.api.util.ClassPathResource): 9
InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable): 9
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 9
TokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory): 8
File (java.io.File): 7
ArrayList (java.util.ArrayList): 7
AbstractSequenceIterator (org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator): 7
CommonPreprocessor (org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor): 7
DefaultTokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory): 7
SentenceTransformer (org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer): 6
VocabConstructor (org.deeplearning4j.models.word2vec.wordstore.VocabConstructor): 4
Pair (org.deeplearning4j.berkeley.Pair): 3
VectorsConfiguration (org.deeplearning4j.models.embeddings.loader.VectorsConfiguration): 3
LabelsSource (org.deeplearning4j.text.documentiterator.LabelsSource): 3
AggregatingSentenceIterator (org.deeplearning4j.text.sentenceiterator.AggregatingSentenceIterator): 3
FileSentenceIterator (org.deeplearning4j.text.sentenceiterator.FileSentenceIterator): 3