Example 1 with CountCumSum

Use of org.deeplearning4j.spark.text.functions.CountCumSum in project deeplearning4j by deeplearning4j.

From the class Word2Vec, method train.

/**
 * Trains a word2vec model on the given text corpus.
 *
 * @param corpusRDD training corpus
 * @throws Exception
 */
public void train(JavaRDD<String> corpusRDD) throws Exception {
    log.info("Start training ...");
    // repartition() returns a new RDD; reassign it, otherwise the call is a no-op
    if (workers > 0)
        corpusRDD = corpusRDD.repartition(workers);
    // SparkContext
    final JavaSparkContext sc = new JavaSparkContext(corpusRDD.context());
    // Pre-defined variables
    Map<String, Object> tokenizerVarMap = getTokenizerVarMap();
    Map<String, Object> word2vecVarMap = getWord2vecVarMap();
    // Variables to fill in train
    final JavaRDD<AtomicLong> sentenceWordsCountRDD;
    final JavaRDD<List<VocabWord>> vocabWordListRDD;
    final JavaPairRDD<List<VocabWord>, Long> vocabWordListSentenceCumSumRDD;
    final VocabCache<VocabWord> vocabCache;
    final JavaRDD<Long> sentenceCumSumCountRDD;
    int maxRep = 1;
    // Start Training //
    //////////////////////////////////////
    log.info("Tokenization and building VocabCache ...");
    // Process every sentence and build a VocabCache, which is later fed into the lookup table
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(tokenizerVarMap);
    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();
    // Get total word count and put into word2vec variable map
    word2vecVarMap.put("totalWordCount", pipeline.getTotalWordCount());
    // Two RDDs, both already cached: the vocab word lists and the per-sentence word counts
    sentenceWordsCountRDD = pipeline.getSentenceCountRDD();
    vocabWordListRDD = pipeline.getVocabWordListRDD();
    // Get the broadcast vocabCache and its local value
    Broadcast<VocabCache<VocabWord>> vocabCacheBroadcast = pipeline.getBroadCastVocabCache();
    vocabCache = vocabCacheBroadcast.getValue();
    log.info("Vocab size: {}", vocabCache.numWords());
    //////////////////////////////////////
    log.info("Building Huffman Tree ...");
    // Building the Huffman tree updates the code and point fields of each VocabWord in the vocabCache
    /*
        We don't need to build the tree here, since it was already built during the
        TextPipeline.buildVocabCache() call above:

        Huffman huffman = new Huffman(vocabCache.vocabWords());
        huffman.build();
        huffman.applyIndexes(vocabCache);
        */
    //////////////////////////////////////
    log.info("Calculating cumulative sum of sentence counts ...");
    sentenceCumSumCountRDD = new CountCumSum(sentenceWordsCountRDD).buildCumSum();
    //////////////////////////////////////
    log.info("Mapping to RDD(vocabWordList, cumulative sentence count) ...");
    vocabWordListSentenceCumSumRDD = vocabWordListRDD.zip(sentenceCumSumCountRDD).setName("vocabWordListSentenceCumSumRDD");
    /////////////////////////////////////
    log.info("Broadcasting word2vec variables to workers ...");
    Broadcast<Map<String, Object>> word2vecVarMapBroadcast = sc.broadcast(word2vecVarMap);
    Broadcast<double[]> expTableBroadcast = sc.broadcast(expTable);
    /////////////////////////////////////
    log.info("Training word2vec sentences ...");
    FlatMapFunction firstIterFunc = new FirstIterationFunction(word2vecVarMapBroadcast, expTableBroadcast, vocabCacheBroadcast);
    @SuppressWarnings("unchecked") JavaRDD<Pair<VocabWord, INDArray>> indexSyn0UpdateEntryRDD = vocabWordListSentenceCumSumRDD.mapPartitions(firstIterFunc).map(new MapToPairFunction());
    // Get all the syn0 updates into a list in driver
    List<Pair<VocabWord, INDArray>> syn0UpdateEntries = indexSyn0UpdateEntryRDD.collect();
    // Instantiate syn0
    INDArray syn0 = Nd4j.zeros(vocabCache.numWords(), layerSize);
    // Updating syn0 first pass: just add vectors obtained from different nodes
    log.info("Averaging results...");
    Map<VocabWord, AtomicInteger> updates = new HashMap<>();
    Map<Long, Long> updaters = new HashMap<>();
    for (Pair<VocabWord, INDArray> syn0UpdateEntry : syn0UpdateEntries) {
        syn0.getRow(syn0UpdateEntry.getFirst().getIndex()).addi(syn0UpdateEntry.getSecond());
        // for proper averaging we need to divide resulting sums later, by the number of additions
        if (updates.containsKey(syn0UpdateEntry.getFirst())) {
            updates.get(syn0UpdateEntry.getFirst()).incrementAndGet();
        } else {
            updates.put(syn0UpdateEntry.getFirst(), new AtomicInteger(1));
        }
        if (!updaters.containsKey(syn0UpdateEntry.getFirst().getVocabId())) {
            updaters.put(syn0UpdateEntry.getFirst().getVocabId(), syn0UpdateEntry.getFirst().getAffinityId());
        }
    }
    // Updating syn0 second pass: average obtained vectors
    for (Map.Entry<VocabWord, AtomicInteger> entry : updates.entrySet()) {
        if (entry.getValue().get() > 1) {
            if (entry.getValue().get() > maxRep)
                maxRep = entry.getValue().get();
            syn0.getRow(entry.getKey().getIndex()).divi(entry.getValue().get());
        }
    }
    long totals = 0;
    log.info("Finished calculations...");
    vocab = vocabCache;
    InMemoryLookupTable<VocabWord> inMemoryLookupTable = new InMemoryLookupTable<VocabWord>();
    // Heartbeat telemetry: the Environment fields are repurposed here to report
    // maxRep and totals for this Spark training run
    Environment env = EnvironmentUtils.buildEnvironment();
    env.setNumCores(maxRep);
    env.setAvailableMemory(totals);
    update(env, Event.SPARK);
    inMemoryLookupTable.setVocab(vocabCache);
    inMemoryLookupTable.setVectorLength(layerSize);
    inMemoryLookupTable.setSyn0(syn0);
    lookupTable = inMemoryLookupTable;
    modelUtils.init(lookupTable);
}
Also used: HashMap (java.util.HashMap), Map (java.util.Map), ArrayList (java.util.ArrayList), List (java.util.List), AtomicInteger (java.util.concurrent.atomic.AtomicInteger), AtomicLong (java.util.concurrent.atomic.AtomicLong), FlatMapFunction (org.apache.spark.api.java.function.FlatMapFunction), JavaSparkContext (org.apache.spark.api.java.JavaSparkContext), VocabWord (org.deeplearning4j.models.word2vec.VocabWord), VocabCache (org.deeplearning4j.models.word2vec.wordstore.VocabCache), InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable), Pair (org.deeplearning4j.berkeley.Pair), CountCumSum (org.deeplearning4j.spark.text.functions.CountCumSum), TextPipeline (org.deeplearning4j.spark.text.functions.TextPipeline), INDArray (org.nd4j.linalg.api.ndarray.INDArray), Environment (org.nd4j.linalg.heartbeat.reports.Environment)
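CountCumSum turns the per-sentence word counts into running totals, which later tell each worker where its sentences sit in the overall corpus. As a rough illustration of the underlying technique (a sketch, not the DL4J implementation; all names here are ours, and Spark 2.x Java lambda APIs are assumed), a distributed prefix sum can be computed in two passes: collect one total per partition to the driver, turn those into broadcast offsets, then run a local prefix sum inside each partition.

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;

public class PrefixSumSketch {

    /** Order-preserving cumulative sum over an RDD of counts. */
    public static JavaRDD<Long> cumSum(JavaSparkContext sc, JavaRDD<Long> counts) {
        // Pass 1: one total per partition, collected to the driver (cheap: a single value each)
        List<Long> partitionTotals = counts.mapPartitions(it -> {
            long sum = 0;
            while (it.hasNext())
                sum += it.next();
            return Collections.singletonList(sum).iterator();
        }).collect();

        // Offsets: for each partition, the sum of every earlier partition
        List<Long> offsets = new ArrayList<>();
        long running = 0;
        for (long total : partitionTotals) {
            offsets.add(running);
            running += total;
        }
        final Broadcast<List<Long>> offsetsBc = sc.broadcast(offsets);

        // Pass 2: local prefix sum inside each partition, shifted by that partition's offset
        return counts.mapPartitionsWithIndex((idx, it) -> {
            long acc = offsetsBc.getValue().get(idx);
            List<Long> out = new ArrayList<>();
            while (it.hasNext()) {
                acc += it.next();
                out.add(acc);
            }
            return out.iterator();
        }, true);
    }
}

The real CountCumSum consumes a JavaRDD<AtomicLong> (as the examples on this page show) and may differ in detail; the sketch only conveys the shape of a distributed prefix sum.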

Example 2 with CountCumSum

Use of org.deeplearning4j.spark.text.functions.CountCumSum in project deeplearning4j by deeplearning4j.

From the class TextPipelineTest, method testZipFunction1.

/**
 * Verifies the zipped (vocab word list, cumulative sentence count) pairs when stop words are removed.
 *
 * @throws Exception
 */
@Test
public void testZipFunction1() throws Exception {
    JavaSparkContext sc = getContext();
    JavaRDD<String> corpusRDD = getCorpusRDD(sc);
    //  word2vec.setRemoveStop(false);
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());
    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();
    JavaRDD<AtomicLong> sentenceCountRDD = pipeline.getSentenceCountRDD();
    JavaRDD<List<VocabWord>> vocabWordListRDD = pipeline.getVocabWordListRDD();
    CountCumSum countCumSum = new CountCumSum(sentenceCountRDD);
    JavaRDD<Long> sentenceCountCumSumRDD = countCumSum.buildCumSum();
    JavaPairRDD<List<VocabWord>, Long> vocabWordListSentenceCumSumRDD = vocabWordListRDD.zip(sentenceCountCumSumRDD);
    List<Tuple2<List<VocabWord>, Long>> lst = vocabWordListSentenceCumSumRDD.collect();
    List<VocabWord> vocabWordsList1 = lst.get(0)._1();
    Long cumSumSize1 = lst.get(0)._2();
    assertEquals(3, vocabWordsList1.size());
    assertEquals("strange", vocabWordsList1.get(0).getWord());
    assertEquals("strange", vocabWordsList1.get(1).getWord());
    assertEquals("world", vocabWordsList1.get(2).getWord());
    assertEquals(6L, cumSumSize1.longValue());
    List<VocabWord> vocabWordsList2 = lst.get(1)._1();
    Long cumSumSize2 = lst.get(1)._2();
    assertEquals(2, vocabWordsList2.size());
    assertEquals("flowers", vocabWordsList2.get(0).getWord());
    assertEquals("red", vocabWordsList2.get(1).getWord());
    assertEquals(9L, cumSumSize2.longValue());
    sc.stop();
}
Also used: VocabWord (org.deeplearning4j.models.word2vec.VocabWord), TextPipeline (org.deeplearning4j.spark.text.functions.TextPipeline), AtomicLong (java.util.concurrent.atomic.AtomicLong), Tuple2 (scala.Tuple2), CountCumSum (org.deeplearning4j.spark.text.functions.CountCumSum), JavaSparkContext (org.apache.spark.api.java.JavaSparkContext), Test (org.junit.Test)
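These tests lean on a Spark contract worth spelling out: JavaRDD.zip pairs elements positionally and requires both RDDs to have the same number of partitions, with the same number of elements in each partition. That holds above because sentenceCountCumSumRDD is derived element-for-element from the same pipeline output as vocabWordListRDD. A minimal illustration of the contract (a sketch with invented data, assuming a live JavaSparkContext sc inside a test method):

// Deriving one RDD from the other with map() keeps partitioning and element
// counts aligned, which is exactly what makes the positional zip legal.
JavaRDD<String> sentences = sc.parallelize(Arrays.asList("strange strange world", "flowers red"), 2);
JavaRDD<Integer> wordCounts = sentences.map(s -> s.split(" ").length);
JavaPairRDD<String, Integer> zipped = sentences.zip(wordCounts);
// zipped.collect() -> [("strange strange world", 3), ("flowers red", 2)]

Zipping two RDDs built independently, even from the same data, is not safe in general, since Spark gives no ordering guarantee across differently partitioned RDDs.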

Example 3 with CountCumSum

Use of org.deeplearning4j.spark.text.functions.CountCumSum in project deeplearning4j by deeplearning4j.

From the class TextPipelineTest, method testZipFunction2.

@Test
public void testZipFunction2() throws Exception {
    JavaSparkContext sc = getContext();
    JavaRDD<String> corpusRDD = getCorpusRDD(sc);
    //  word2vec.setRemoveStop(false);
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vecNoStop.getTokenizerVarMap());
    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();
    JavaRDD<AtomicLong> sentenceCountRDD = pipeline.getSentenceCountRDD();
    JavaRDD<List<VocabWord>> vocabWordListRDD = pipeline.getVocabWordListRDD();
    CountCumSum countCumSum = new CountCumSum(sentenceCountRDD);
    JavaRDD<Long> sentenceCountCumSumRDD = countCumSum.buildCumSum();
    JavaPairRDD<List<VocabWord>, Long> vocabWordListSentenceCumSumRDD = vocabWordListRDD.zip(sentenceCountCumSumRDD);
    List<Tuple2<List<VocabWord>, Long>> lst = vocabWordListSentenceCumSumRDD.collect();
    List<VocabWord> vocabWordsList1 = lst.get(0)._1();
    Long cumSumSize1 = lst.get(0)._2();
    assertEquals(6, vocabWordsList1.size());
    assertEquals("this", vocabWordsList1.get(0).getWord());
    assertEquals("is", vocabWordsList1.get(1).getWord());
    assertEquals("a", vocabWordsList1.get(2).getWord());
    assertEquals("strange", vocabWordsList1.get(3).getWord());
    assertEquals("strange", vocabWordsList1.get(4).getWord());
    assertEquals("world", vocabWordsList1.get(5).getWord());
    assertEquals(6L, cumSumSize1.longValue());
    List<VocabWord> vocabWordsList2 = lst.get(1)._1();
    Long cumSumSize2 = lst.get(1)._2();
    assertEquals(3, vocabWordsList2.size());
    assertEquals("flowers", vocabWordsList2.get(0).getWord());
    assertEquals("are", vocabWordsList2.get(1).getWord());
    assertEquals("red", vocabWordsList2.get(2).getWord());
    assertEquals(9L, cumSumSize2.longValue());
    sc.stop();
}
Also used: VocabWord (org.deeplearning4j.models.word2vec.VocabWord), TextPipeline (org.deeplearning4j.spark.text.functions.TextPipeline), AtomicLong (java.util.concurrent.atomic.AtomicLong), Tuple2 (scala.Tuple2), CountCumSum (org.deeplearning4j.spark.text.functions.CountCumSum), JavaSparkContext (org.apache.spark.api.java.JavaSparkContext), Test (org.junit.Test)
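The only difference from testZipFunction1 is the tokenizer configuration: word2vecNoStop keeps stop words, so the first sentence yields all six tokens ("this", "is", "a", "strange", "strange", "world") rather than three. The cumulative counts are still 6 and 9 in both tests, which shows that the sentence word counts feeding CountCumSum are taken before stop-word filtering.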

Example 4 with CountCumSum

Use of org.deeplearning4j.spark.text.functions.CountCumSum in project deeplearning4j by deeplearning4j.

From the class TextPipelineTest, method testCountCumSum.

@Test
public void testCountCumSum() throws Exception {
    JavaSparkContext sc = getContext();
    JavaRDD<String> corpusRDD = getCorpusRDD(sc);
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());
    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();
    JavaRDD<AtomicLong> sentenceCountRDD = pipeline.getSentenceCountRDD();
    CountCumSum countCumSum = new CountCumSum(sentenceCountRDD);
    JavaRDD<Long> sentenceCountCumSumRDD = countCumSum.buildCumSum();
    List<Long> sentenceCountCumSumList = sentenceCountCumSumRDD.collect();
    assertEquals(6L, sentenceCountCumSumList.get(0).longValue());
    assertEquals(9L, sentenceCountCumSumList.get(1).longValue());
    sc.stop();
}
Also used: AtomicLong (java.util.concurrent.atomic.AtomicLong), CountCumSum (org.deeplearning4j.spark.text.functions.CountCumSum), JavaSparkContext (org.apache.spark.api.java.JavaSparkContext), TextPipeline (org.deeplearning4j.spark.text.functions.TextPipeline), Test (org.junit.Test)
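The expected values follow directly from the shared test corpus: the first sentence contributes six words and the second three, so the running totals are 6 and then 9. To exercise CountCumSum in isolation, a sketch like the following should behave the same way (assuming only what the test above already shows: the class consumes a JavaRDD<AtomicLong> and buildCumSum() emits a JavaRDD<Long>):

// Hand-built counts standing in for pipeline.getSentenceCountRDD()
JavaRDD<AtomicLong> counts = sc.parallelize(Arrays.asList(new AtomicLong(6), new AtomicLong(3)));
JavaRDD<Long> cumSums = new CountCumSum(counts).buildCumSum();
// Expected running totals: 6, then 6 + 3 = 9
assertEquals(Arrays.asList(6L, 9L), cumSums.collect());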

Example 5 with CountCumSum

Use of org.deeplearning4j.spark.text.functions.CountCumSum in project deeplearning4j by deeplearning4j.

From the class TextPipelineTest, method testSyn0AfterFirstIteration.

@Test
public void testSyn0AfterFirstIteration() throws Exception {
    JavaSparkContext sc = getContext();
    JavaRDD<String> corpusRDD = getCorpusRDD(sc);
    //  word2vec.setRemoveStop(false);
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());
    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();
    VocabCache<VocabWord> vocabCache = pipeline.getVocabCache();
    Huffman huffman = new Huffman(vocabCache.vocabWords());
    huffman.build();
    // Get total word count and put into word2vec variable map
    Map<String, Object> word2vecVarMap = word2vec.getWord2vecVarMap();
    word2vecVarMap.put("totalWordCount", pipeline.getTotalWordCount());
    double[] expTable = word2vec.getExpTable();
    JavaRDD<AtomicLong> sentenceCountRDD = pipeline.getSentenceCountRDD();
    JavaRDD<List<VocabWord>> vocabWordListRDD = pipeline.getVocabWordListRDD();
    CountCumSum countCumSum = new CountCumSum(sentenceCountRDD);
    JavaRDD<Long> sentenceCountCumSumRDD = countCumSum.buildCumSum();
    JavaPairRDD<List<VocabWord>, Long> vocabWordListSentenceCumSumRDD = vocabWordListRDD.zip(sentenceCountCumSumRDD);
    Broadcast<Map<String, Object>> word2vecVarMapBroadcast = sc.broadcast(word2vecVarMap);
    Broadcast<double[]> expTableBroadcast = sc.broadcast(expTable);
    FirstIterationFunction firstIterationFunction = new FirstIterationFunction(word2vecVarMapBroadcast, expTableBroadcast, pipeline.getBroadCastVocabCache());
    JavaRDD<Pair<VocabWord, INDArray>> pointSyn0Vec = vocabWordListSentenceCumSumRDD.mapPartitions(firstIterationFunction).map(new MapToPairFunction());
    // Spark transformations are lazy: without an action, FirstIterationFunction never actually runs.
    // Collecting forces evaluation so the test exercises the first training iteration.
    List<Pair<VocabWord, INDArray>> syn0Updates = pointSyn0Vec.collect();
    assertTrue(syn0Updates.size() > 0);
    sc.stop();
}
Also used: VocabWord (org.deeplearning4j.models.word2vec.VocabWord), MapToPairFunction (org.deeplearning4j.spark.models.embeddings.word2vec.MapToPairFunction), FirstIterationFunction (org.deeplearning4j.spark.models.embeddings.word2vec.FirstIterationFunction), CountCumSum (org.deeplearning4j.spark.text.functions.CountCumSum), JavaSparkContext (org.apache.spark.api.java.JavaSparkContext), Pair (org.deeplearning4j.berkeley.Pair), TextPipeline (org.deeplearning4j.spark.text.functions.TextPipeline), List (java.util.List), AtomicLong (java.util.concurrent.atomic.AtomicLong), Huffman (org.deeplearning4j.models.word2vec.Huffman), Test (org.junit.Test)

Aggregations

AtomicLong (java.util.concurrent.atomic.AtomicLong): 6
JavaSparkContext (org.apache.spark.api.java.JavaSparkContext): 6
CountCumSum (org.deeplearning4j.spark.text.functions.CountCumSum): 6
TextPipeline (org.deeplearning4j.spark.text.functions.TextPipeline): 6
VocabWord (org.deeplearning4j.models.word2vec.VocabWord): 5
Test (org.junit.Test): 5
Tuple2 (scala.Tuple2): 3
Pair (org.deeplearning4j.berkeley.Pair): 2
ArrayList (java.util.ArrayList): 1
HashMap (java.util.HashMap): 1
List (java.util.List): 1
Map (java.util.Map): 1
AtomicInteger (java.util.concurrent.atomic.AtomicInteger): 1
FlatMapFunction (org.apache.spark.api.java.function.FlatMapFunction): 1
InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable): 1
Huffman (org.deeplearning4j.models.word2vec.Huffman): 1
VocabCache (org.deeplearning4j.models.word2vec.wordstore.VocabCache): 1
FirstIterationFunction (org.deeplearning4j.spark.models.embeddings.word2vec.FirstIterationFunction): 1
FirstIterationFunctionAdapter (org.deeplearning4j.spark.models.embeddings.word2vec.FirstIterationFunctionAdapter): 1
MapToPairFunction (org.deeplearning4j.spark.models.embeddings.word2vec.MapToPairFunction): 1