Example 6 with VocabCache

use of org.deeplearning4j.models.word2vec.wordstore.VocabCache in project deeplearning4j by deeplearning4j.

the class Word2Vec method train.

/**
     * Trains a word2vec model on the given text corpus.
     *
     * @param corpusRDD training corpus
     * @throws Exception
     */
public void train(JavaRDD<String> corpusRDD) throws Exception {
    log.info("Start training ...");
    if (workers > 0)
        // repartition() returns a new RDD; the result must be reassigned or the call is a no-op
        corpusRDD = corpusRDD.repartition(workers);
    // SparkContext
    final JavaSparkContext sc = new JavaSparkContext(corpusRDD.context());
    // Pre-defined variables
    Map<String, Object> tokenizerVarMap = getTokenizerVarMap();
    Map<String, Object> word2vecVarMap = getWord2vecVarMap();
    // Variables to fill in train
    final JavaRDD<AtomicLong> sentenceWordsCountRDD;
    final JavaRDD<List<VocabWord>> vocabWordListRDD;
    final JavaPairRDD<List<VocabWord>, Long> vocabWordListSentenceCumSumRDD;
    final VocabCache<VocabWord> vocabCache;
    final JavaRDD<Long> sentenceCumSumCountRDD;
    int maxRep = 1;
    // Start Training //
    //////////////////////////////////////
    log.info("Tokenization and building VocabCache ...");
    // Process every sentence and build a VocabCache, which is then fed into a LookupCache
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(tokenizerVarMap);
    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();
    // Get total word count and put into word2vec variable map
    word2vecVarMap.put("totalWordCount", pipeline.getTotalWordCount());
    // Two RDDs: (vocab word list) and (sentence word count), both already cached
    sentenceWordsCountRDD = pipeline.getSentenceCountRDD();
    vocabWordListRDD = pipeline.getVocabWordListRDD();
    // Get the vocabCache and its broadcast handle
    Broadcast<VocabCache<VocabWord>> vocabCacheBroadcast = pipeline.getBroadCastVocabCache();
    vocabCache = vocabCacheBroadcast.getValue();
    log.info("Vocab size: {}", vocabCache.numWords());
    //////////////////////////////////////
    log.info("Building Huffman Tree ...");
    // Building the Huffman tree updates the code and point fields of each VocabWord in the vocabCache
    /*
        We don't need to build tree here, since it was built earlier, at TextPipeline.buildVocabCache() call.
        
        Huffman huffman = new Huffman(vocabCache.vocabWords());
        huffman.build();
        huffman.applyIndexes(vocabCache);
        */
    //////////////////////////////////////
    log.info("Calculating cumulative sum of sentence counts ...");
    sentenceCumSumCountRDD = new CountCumSum(sentenceWordsCountRDD).buildCumSum();
    //////////////////////////////////////
    log.info("Mapping to RDD(vocabWordList, cumulative sentence count) ...");
    vocabWordListSentenceCumSumRDD = vocabWordListRDD.zip(sentenceCumSumCountRDD).setName("vocabWordListSentenceCumSumRDD");
    /////////////////////////////////////
    log.info("Broadcasting word2vec variables to workers ...");
    Broadcast<Map<String, Object>> word2vecVarMapBroadcast = sc.broadcast(word2vecVarMap);
    Broadcast<double[]> expTableBroadcast = sc.broadcast(expTable);
    /////////////////////////////////////
    log.info("Training word2vec sentences ...");
    FlatMapFunction firstIterFunc = new FirstIterationFunction(word2vecVarMapBroadcast, expTableBroadcast, vocabCacheBroadcast);
    @SuppressWarnings("unchecked") JavaRDD<Pair<VocabWord, INDArray>> indexSyn0UpdateEntryRDD = vocabWordListSentenceCumSumRDD.mapPartitions(firstIterFunc).map(new MapToPairFunction());
    // Get all the syn0 updates into a list in driver
    List<Pair<VocabWord, INDArray>> syn0UpdateEntries = indexSyn0UpdateEntryRDD.collect();
    // Instantiate syn0
    INDArray syn0 = Nd4j.zeros(vocabCache.numWords(), layerSize);
    // Updating syn0 first pass: just add vectors obtained from different nodes
    log.info("Averaging results...");
    Map<VocabWord, AtomicInteger> updates = new HashMap<>();
    Map<Long, Long> updaters = new HashMap<>();
    for (Pair<VocabWord, INDArray> syn0UpdateEntry : syn0UpdateEntries) {
        syn0.getRow(syn0UpdateEntry.getFirst().getIndex()).addi(syn0UpdateEntry.getSecond());
        // for proper averaging we need to divide the resulting sums later by the number of additions
        if (updates.containsKey(syn0UpdateEntry.getFirst())) {
            updates.get(syn0UpdateEntry.getFirst()).incrementAndGet();
        } else
            updates.put(syn0UpdateEntry.getFirst(), new AtomicInteger(1));
        if (!updaters.containsKey(syn0UpdateEntry.getFirst().getVocabId())) {
            updaters.put(syn0UpdateEntry.getFirst().getVocabId(), syn0UpdateEntry.getFirst().getAffinityId());
        }
    }
    // Updating syn0 second pass: average obtained vectors
    for (Map.Entry<VocabWord, AtomicInteger> entry : updates.entrySet()) {
        if (entry.getValue().get() > 1) {
            if (entry.getValue().get() > maxRep)
                maxRep = entry.getValue().get();
            syn0.getRow(entry.getKey().getIndex()).divi(entry.getValue().get());
        }
    }
    long totals = 0;
    log.info("Finished calculations...");
    vocab = vocabCache;
    InMemoryLookupTable<VocabWord> inMemoryLookupTable = new InMemoryLookupTable<VocabWord>();
    // Report run stats via the heartbeat mechanism (maxRep and totals are smuggled through the Environment fields)
    Environment env = EnvironmentUtils.buildEnvironment();
    env.setNumCores(maxRep);
    env.setAvailableMemory(totals);
    update(env, Event.SPARK);
    inMemoryLookupTable.setVocab(vocabCache);
    inMemoryLookupTable.setVectorLength(layerSize);
    inMemoryLookupTable.setSyn0(syn0);
    lookupTable = inMemoryLookupTable;
    modelUtils.init(lookupTable);
}
Also used : HashMap(java.util.HashMap) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) FlatMapFunction(org.apache.spark.api.java.function.FlatMapFunction) ArrayList(java.util.ArrayList) List(java.util.List) CountCumSum(org.deeplearning4j.spark.text.functions.CountCumSum) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) Pair(org.deeplearning4j.berkeley.Pair) TextPipeline(org.deeplearning4j.spark.text.functions.TextPipeline) AtomicLong(java.util.concurrent.atomic.AtomicLong) INDArray(org.nd4j.linalg.api.ndarray.INDArray) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) VocabCache(org.deeplearning4j.models.word2vec.wordstore.VocabCache) Environment(org.nd4j.linalg.heartbeat.reports.Environment) Map(java.util.Map)
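
A minimal driver sketch for the train() method above, assuming a local Spark master and a made-up corpus path; the builder options mirror those used in the testConcepts example below (Example 8).

SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("word2vecTrain");
JavaSparkContext sc = new JavaSparkContext(sparkConf);
// Any JavaRDD<String> of raw sentences works; this path is hypothetical
JavaRDD<String> corpus = sc.textFile("data/raw_sentences.txt");
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());
Word2Vec word2Vec = new Word2Vec.Builder()
        .tokenizerFactory(t)
        .minWordFrequency(5)
        .layerSize(150)
        .windowSize(5)
        .build();
word2Vec.train(corpus);
sc.stop();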

Example 7 with VocabCache

use of org.deeplearning4j.models.word2vec.wordstore.VocabCache in project deeplearning4j by deeplearning4j.

the class GloveTest method testGlove.

@Test
public void testGlove() throws Exception {
    Glove glove = new Glove(true, 5, 100);
    JavaRDD<String> corpus = sc.textFile(new ClassPathResource("raw_sentences.txt").getFile().getAbsolutePath())
            .map(new Function<String, String>() {

                @Override
                public String call(String s) throws Exception {
                    return s.toLowerCase();
                }
            });
    Pair<VocabCache<VocabWord>, GloveWeightLookupTable> table = glove.train(corpus);
    WordVectors vectors = WordVectorSerializer.fromPair(new Pair<>((InMemoryLookupTable) table.getSecond(), (VocabCache) table.getFirst()));
    Collection<String> words = vectors.wordsNearest("day", 20);
    assertTrue(words.contains("week"));
}
Also used : GloveWeightLookupTable(org.deeplearning4j.models.glove.GloveWeightLookupTable) ClassPathResource(org.datavec.api.util.ClassPathResource) InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) VocabCache(org.deeplearning4j.models.word2vec.wordstore.VocabCache) WordVectors(org.deeplearning4j.models.embeddings.wordvectors.WordVectors) Test(org.junit.Test) BaseSparkTest(org.deeplearning4j.spark.text.BaseSparkTest)
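
A hedged follow-up sketch: persisting the trained pair with the same writeWordVectors(lookupTable, file) overload used in Example 8, then reloading it. The temp-file name is arbitrary; since GloveWeightLookupTable extends InMemoryLookupTable, the overload applies.

File out = File.createTempFile("glove", ".txt");
out.deleteOnExit();
// table.getSecond() is the GloveWeightLookupTable returned by glove.train()
WordVectorSerializer.writeWordVectors(table.getSecond(), out);
WordVectors restored = WordVectorSerializer.loadTxtVectors(out);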

Example 8 with VocabCache

use of org.deeplearning4j.models.word2vec.wordstore.VocabCache in project deeplearning4j by deeplearning4j.

the class Word2VecTest method testConcepts.

@Test
public void testConcepts() throws Exception {
    // These are all default values for word2vec
    SparkConf sparkConf = new SparkConf().setMaster("local[8]").setAppName("sparktest");
    // Set SparkContext
    JavaSparkContext sc = new JavaSparkContext(sparkConf);
    // Path of data part-00000
    String dataPath = new ClassPathResource("raw_sentences.txt").getFile().getAbsolutePath();
    //        dataPath = "/ext/Temp/part-00000";
    //        String dataPath = new ClassPathResource("spark_word2vec_test.txt").getFile().getAbsolutePath();
    // Read in data
    JavaRDD<String> corpus = sc.textFile(dataPath);
    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());
    Word2Vec word2Vec = new Word2Vec.Builder()
            .setNGrams(1)
            .tokenizerFactory(t)
            .seed(42L)
            .negative(10)
            .useAdaGrad(false)
            .layerSize(150)
            .windowSize(5)
            .learningRate(0.025)
            .minLearningRate(0.0001)
            .iterations(1)
            .batchSize(100)
            .minWordFrequency(5)
            .stopWords(Arrays.asList("three"))
            .useUnknown(true)
            .build();
    word2Vec.train(corpus);
    //word2Vec.setModelUtils(new FlatModelUtils());
    System.out.println("UNK: " + word2Vec.getWordVectorMatrix("UNK"));
    InMemoryLookupTable<VocabWord> table = (InMemoryLookupTable<VocabWord>) word2Vec.lookupTable();
    double sim = word2Vec.similarity("day", "night");
    System.out.println("day/night similarity: " + sim);
    /*
        System.out.println("Hornjo: " + word2Vec.getWordVectorMatrix("hornjoserbsce"));
        System.out.println("carro: " + word2Vec.getWordVectorMatrix("carro"));
        
        Collection<String> portu = word2Vec.wordsNearest("carro", 10);
        printWords("carro", portu, word2Vec);
        
        portu = word2Vec.wordsNearest("davi", 10);
        printWords("davi", portu, word2Vec);
        
        System.out.println("---------------------------------------");
        */
    Collection<String> words = word2Vec.wordsNearest("day", 10);
    printWords("day", words, word2Vec);
    assertTrue(words.contains("night"));
    assertTrue(words.contains("week"));
    assertTrue(words.contains("year"));
    sim = word2Vec.similarity("two", "four");
    System.out.println("two/four similarity: " + sim);
    words = word2Vec.wordsNearest("two", 10);
    printWords("two", words, word2Vec);
    // three should be absent due to stopWords
    assertFalse(words.contains("three"));
    assertTrue(words.contains("five"));
    assertTrue(words.contains("four"));
    sc.stop();
    // test serialization
    File tempFile = File.createTempFile("temp", "tmp");
    tempFile.deleteOnExit();
    int idx1 = word2Vec.vocab().wordFor("day").getIndex();
    INDArray array1 = word2Vec.getWordVectorMatrix("day").dup();
    VocabWord word1 = word2Vec.vocab().elementAtIndex(0);
    WordVectorSerializer.writeWordVectors(word2Vec.getLookupTable(), tempFile);
    WordVectors vectors = WordVectorSerializer.loadTxtVectors(tempFile);
    VocabWord word2 = ((VocabCache<VocabWord>) vectors.vocab()).elementAtIndex(0);
    VocabWord wordIT = ((VocabCache<VocabWord>) vectors.vocab()).wordFor("it");
    int idx2 = vectors.vocab().wordFor("day").getIndex();
    INDArray array2 = vectors.getWordVectorMatrix("day").dup();
    System.out.println("word 'i': " + word2);
    System.out.println("word 'it': " + wordIT);
    assertEquals(idx1, idx2);
    assertEquals(word1, word2);
    assertEquals(array1, array2);
}
Also used : TokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) ClassPathResource(org.datavec.api.util.ClassPathResource) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) CommonPreprocessor(org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor) InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VocabCache(org.deeplearning4j.models.word2vec.wordstore.VocabCache) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) WordVectors(org.deeplearning4j.models.embeddings.wordvectors.WordVectors) SparkConf(org.apache.spark.SparkConf) File(java.io.File) Test(org.junit.Test)

Example 9 with VocabCache

use of org.deeplearning4j.models.word2vec.wordstore.VocabCache in project deeplearning4j by deeplearning4j.

the class WordVectorSerializer method loadTxtVectors.

/**
     * This method can be used to load a previously saved model from an InputStream (such as an HDFS stream).
     *
     * Deprecation note: please consider using readWord2VecModel() or loadStaticModel() instead.
     *
     * @param stream InputStream that contains the previously serialized model
     * @param skipFirstLine set this to TRUE if the first line contains a csv header, FALSE otherwise
     * @return the deserialized model as WordVectors
     * @throws IOException
     */
@Deprecated
public static WordVectors loadTxtVectors(@NonNull InputStream stream, boolean skipFirstLine) throws IOException {
    AbstractCache<VocabWord> cache = new AbstractCache.Builder<VocabWord>().build();
    BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
    String line = "";
    List<INDArray> arrays = new ArrayList<>();
    if (skipFirstLine)
        reader.readLine();
    while ((line = reader.readLine()) != null) {
        String[] split = line.split(" ");
        String word = split[0].replaceAll(whitespaceReplacement, " ");
        VocabWord word1 = new VocabWord(1.0, word);
        word1.setIndex(cache.numWords());
        cache.addToken(word1);
        cache.addWordToIndex(word1.getIndex(), word);
        cache.putVocabWord(word);
        float[] vector = new float[split.length - 1];
        for (int i = 1; i < split.length; i++) {
            vector[i - 1] = Float.parseFloat(split[i]);
        }
        INDArray row = Nd4j.create(vector);
        arrays.add(row);
    }
    InMemoryLookupTable<VocabWord> lookupTable = (InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>()
            .vectorLength(arrays.get(0).columns())
            .cache(cache)
            .build();
    INDArray syn = Nd4j.vstack(arrays);
    Nd4j.clearNans(syn);
    lookupTable.setSyn0(syn);
    return fromPair(Pair.makePair((InMemoryLookupTable) lookupTable, (VocabCache) cache));
}
Also used : ArrayList(java.util.ArrayList) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) AbstractCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache) InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VocabCache(org.deeplearning4j.models.word2vec.wordstore.VocabCache)
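
A minimal usage sketch for the method above, assuming the plain-text layout the parser expects: one word per line followed by its space-separated vector components. The two-word model content and the ByteArrayInputStream / StandardCharsets usage are made up for illustration.

String model = "day 0.1 0.2 0.3\n" + "night 0.4 0.5 0.6\n";
InputStream stream = new ByteArrayInputStream(model.getBytes(StandardCharsets.UTF_8));
WordVectors vectors = WordVectorSerializer.loadTxtVectors(stream, false);
INDArray day = vectors.getWordVectorMatrix("day");

Per the deprecation note, newer code should prefer readWord2VecModel() or loadStaticModel() over this method.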

Example 10 with VocabCache

use of org.deeplearning4j.models.word2vec.wordstore.VocabCache in project deeplearning4j by deeplearning4j.

the class VocabCacheExporter method export.

@Override
public void export(JavaRDD<ExportContainer<VocabWord>> rdd) {
    // beware: generally this is a VERY bad idea, but it will work fine for testing purposes
    List<ExportContainer<VocabWord>> list = rdd.collect();
    if (vocabCache == null)
        vocabCache = new AbstractCache<>();
    INDArray syn0 = null;
    // just roll through list
    for (ExportContainer<VocabWord> element : list) {
        VocabWord word = element.getElement();
        INDArray weights = element.getArray();
        if (syn0 == null)
            syn0 = Nd4j.create(list.size(), weights.length());
        vocabCache.addToken(word);
        vocabCache.addWordToIndex(word.getIndex(), word.getLabel());
        syn0.getRow(word.getIndex()).assign(weights);
    }
    if (lookupTable == null)
        lookupTable = new InMemoryLookupTable.Builder<VocabWord>().cache(vocabCache).vectorLength(syn0.columns()).build();
    lookupTable.setSyn0(syn0);
    // this is bad & dirty, but we don't really need anything else for testing :)
    word2Vec = WordVectorSerializer.fromPair(Pair.<InMemoryLookupTable, VocabCache>makePair(lookupTable, vocabCache));
}
Also used : InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VocabCache(org.deeplearning4j.models.word2vec.wordstore.VocabCache) ExportContainer(org.deeplearning4j.spark.models.sequencevectors.export.ExportContainer) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) AbstractCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache)

Aggregations

VocabCache (org.deeplearning4j.models.word2vec.wordstore.VocabCache) 10
VocabWord (org.deeplearning4j.models.word2vec.VocabWord) 9
INDArray (org.nd4j.linalg.api.ndarray.INDArray) 8
InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) 7
ArrayList (java.util.ArrayList) 4
JavaSparkContext (org.apache.spark.api.java.JavaSparkContext) 3
ClassPathResource (org.datavec.api.util.ClassPathResource) 3
WordVectors (org.deeplearning4j.models.embeddings.wordvectors.WordVectors) 3
Word2Vec (org.deeplearning4j.models.word2vec.Word2Vec) 3
TokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory) 3
Test (org.junit.Test) 3
File (java.io.File) 2
AtomicInteger (java.util.concurrent.atomic.AtomicInteger) 2
AtomicLong (java.util.concurrent.atomic.AtomicLong) 2
SparkConf (org.apache.spark.SparkConf) 2
CounterMap (org.deeplearning4j.berkeley.CounterMap) 2
Pair (org.deeplearning4j.berkeley.Pair) 2
GloveWeightLookupTable (org.deeplearning4j.models.glove.GloveWeightLookupTable) 2
StaticWord2Vec (org.deeplearning4j.models.word2vec.StaticWord2Vec) 2
AbstractCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache) 2