Search in sources :

Example 6 with TextPipeline

use of org.deeplearning4j.spark.text.functions.TextPipeline in project deeplearning4j by deeplearning4j.

the class TextPipelineTest method testCountCumSum.

/**
 * Runs the corpus through {@link TextPipeline}, feeds the resulting sentence-count RDD to
 * {@link CountCumSum} and checks the first two entries of the collected cumulative sums.
 *
 * <p>The Spark context is stopped in a {@code finally} block so that a failing assertion or
 * an exception in the pipeline does not leave a live context behind for subsequent tests.
 */
@Test
public void testCountCumSum() throws Exception {
    JavaSparkContext sc = getContext();
    try {
        JavaRDD<String> corpusRDD = getCorpusRDD(sc);
        Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());
        TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
        pipeline.buildVocabCache();
        // NOTE(review): presumably required before getSentenceCountRDD() returns data — confirm.
        pipeline.buildVocabWordListRDD();
        JavaRDD<AtomicLong> sentenceCountRDD = pipeline.getSentenceCountRDD();
        CountCumSum countCumSum = new CountCumSum(sentenceCountRDD);
        JavaRDD<Long> sentenceCountCumSumRDD = countCumSum.buildCumSum();
        List<Long> sentenceCountCumSumList = sentenceCountCumSumRDD.collect();
        long firstCumSum = sentenceCountCumSumList.get(0);
        long secondCumSum = sentenceCountCumSumList.get(1);
        // Messages carry the actual value so a failure is diagnosable without a debugger.
        assertTrue("expected cumulative sum 6 at index 0 but was " + firstCumSum, firstCumSum == 6L);
        assertTrue("expected cumulative sum 9 at index 1 but was " + secondCumSum, secondCumSum == 9L);
    } finally {
        sc.stop();
    }
}
Also used : AtomicLong(java.util.concurrent.atomic.AtomicLong) CountCumSum(org.deeplearning4j.spark.text.functions.CountCumSum) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) TextPipeline(org.deeplearning4j.spark.text.functions.TextPipeline) Test(org.junit.Test)

Example 7 with TextPipeline

use of org.deeplearning4j.spark.text.functions.TextPipeline in project deeplearning4j by deeplearning4j.

the class TextPipelineTest method testHuffman.

/**
 * Builds the vocabulary cache via {@link TextPipeline}, runs {@link Huffman} over its words,
 * applies the resulting indexes back to the cache and prints each word with its codes and
 * points to standard output.
 *
 * <p>This test makes no assertions; it only verifies that the sequence completes without
 * throwing. The Spark context is stopped in a {@code finally} block so an exception in the
 * pipeline or Huffman steps does not leak the context into subsequent tests.
 */
@Test
public void testHuffman() throws Exception {
    JavaSparkContext sc = getContext();
    try {
        JavaRDD<String> corpusRDD = getCorpusRDD(sc);
        Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());
        TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
        pipeline.buildVocabCache();
        VocabCache<VocabWord> vocabCache = pipeline.getVocabCache();
        Huffman huffman = new Huffman(vocabCache.vocabWords());
        huffman.build();
        huffman.applyIndexes(vocabCache);
        // Re-read the words after applyIndexes so the printed state reflects the Huffman pass.
        Collection<VocabWord> vocabWords = vocabCache.vocabWords();
        System.out.println("Huffman Test:");
        for (VocabWord vocabWord : vocabWords) {
            System.out.println("Word: " + vocabWord);
            System.out.println(vocabWord.getCodes());
            System.out.println(vocabWord.getPoints());
        }
    } finally {
        sc.stop();
    }
}
Also used : Huffman(org.deeplearning4j.models.word2vec.Huffman) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) TextPipeline(org.deeplearning4j.spark.text.functions.TextPipeline) Test(org.junit.Test)

Example 8 with TextPipeline

use of org.deeplearning4j.spark.text.functions.TextPipeline in project deeplearning4j by deeplearning4j.

the class TextPipelineTest method testWordFreqAccIdentifyingStopWords.

/**
 * Tokenizes the corpus, updates the pipeline's accumulators and checks the resulting
 * word-frequency counter: the tokens "is", "this", "are" and "a" must have a count of zero,
 * "STOP" a count of four, and the remaining content words their expected counts.
 *
 * <p>Assertions use JUnit's {@code (expected, actual, delta)} order so failure messages
 * report the values the right way round. The Spark context is stopped in a {@code finally}
 * block so a failing assertion does not leak it into subsequent tests.
 */
@Test
public void testWordFreqAccIdentifyingStopWords() throws Exception {
    JavaSparkContext sc = getContext();
    try {
        //  word2vec.setRemoveStop(false);
        JavaRDD<String> corpusRDD = getCorpusRDD(sc);
        Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());
        TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
        JavaRDD<List<String>> tokenizedRDD = pipeline.tokenize();
        pipeline.updateAndReturnAccumulatorVal(tokenizedRDD);
        Counter<String> wordFreqCounter = pipeline.getWordFreqAcc().value();
        // NOTE(review): presumably stop words are folded into the "STOP" token — confirm.
        assertEquals(0, wordFreqCounter.getCount("is"), 0);
        assertEquals(0, wordFreqCounter.getCount("this"), 0);
        assertEquals(0, wordFreqCounter.getCount("are"), 0);
        assertEquals(0, wordFreqCounter.getCount("a"), 0);
        assertEquals(4, wordFreqCounter.getCount("STOP"), 0);
        assertEquals(2, wordFreqCounter.getCount("strange"), 0);
        assertEquals(1, wordFreqCounter.getCount("flowers"), 0);
        assertEquals(1, wordFreqCounter.getCount("world"), 0);
        assertEquals(1, wordFreqCounter.getCount("red"), 0);
    } finally {
        sc.stop();
    }
}
Also used : JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) TextPipeline(org.deeplearning4j.spark.text.functions.TextPipeline) Test(org.junit.Test)

Example 9 with TextPipeline

use of org.deeplearning4j.spark.text.functions.TextPipeline in project deeplearning4j by deeplearning4j.

the class TextPipelineTest method testBuildVocabCache.

/**
 * Builds the vocabulary cache via {@link TextPipeline#buildVocabCache()} and checks that it
 * holds five words and that "red", "flowers", "world" and "strange" are present with the
 * expected element frequencies.
 *
 * <p>Assertions use JUnit's {@code (expected, actual[, delta])} order so failure messages
 * report the values the right way round. The Spark context is stopped in a {@code finally}
 * block so a failing assertion does not leak it into subsequent tests.
 */
@Test
public void testBuildVocabCache() throws Exception {
    JavaSparkContext sc = getContext();
    try {
        JavaRDD<String> corpusRDD = getCorpusRDD(sc);
        Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());
        TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
        pipeline.buildVocabCache();
        VocabCache<VocabWord> vocabCache = pipeline.getVocabCache();
        assertTrue(vocabCache != null);
        log.info("VocabWords: " + vocabCache.words());
        assertEquals(5, vocabCache.numWords());
        VocabWord redVocab = vocabCache.tokenFor("red");
        VocabWord flowerVocab = vocabCache.tokenFor("flowers");
        VocabWord worldVocab = vocabCache.tokenFor("world");
        VocabWord strangeVocab = vocabCache.tokenFor("strange");
        log.info("Red word: " + redVocab);
        log.info("Flower word: " + flowerVocab);
        log.info("World word: " + worldVocab);
        log.info("Strange word: " + strangeVocab);
        assertEquals("red", redVocab.getWord());
        assertEquals(1, redVocab.getElementFrequency(), 0);
        assertEquals("flowers", flowerVocab.getWord());
        assertEquals(1, flowerVocab.getElementFrequency(), 0);
        assertEquals("world", worldVocab.getWord());
        assertEquals(1, worldVocab.getElementFrequency(), 0);
        assertEquals("strange", strangeVocab.getWord());
        assertEquals(2, strangeVocab.getElementFrequency(), 0);
    } finally {
        sc.stop();
    }
}
Also used : VocabWord(org.deeplearning4j.models.word2vec.VocabWord) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) TextPipeline(org.deeplearning4j.spark.text.functions.TextPipeline) Test(org.junit.Test)

Example 10 with TextPipeline

use of org.deeplearning4j.spark.text.functions.TextPipeline in project deeplearning4j by deeplearning4j.

the class TextPipelineTest method testFilterMinWordAddVocab.

/**
 * Tokenizes the corpus, updates the accumulators, passes the word-frequency counter to
 * {@link TextPipeline#filterMinWordAddVocab} and checks that the resulting vocabulary cache
 * contains "red", "flowers", "world" and "strange" with the expected element frequencies.
 *
 * <p>Assertions use JUnit's {@code (expected, actual[, delta])} order so failure messages
 * report the values the right way round. The Spark context is stopped in a {@code finally}
 * block so a failing assertion does not leak it into subsequent tests.
 */
@Test
public void testFilterMinWordAddVocab() throws Exception {
    JavaSparkContext sc = getContext();
    try {
        JavaRDD<String> corpusRDD = getCorpusRDD(sc);
        Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());
        TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
        JavaRDD<List<String>> tokenizedRDD = pipeline.tokenize();
        pipeline.updateAndReturnAccumulatorVal(tokenizedRDD);
        Counter<String> wordFreqCounter = pipeline.getWordFreqAcc().value();
        pipeline.filterMinWordAddVocab(wordFreqCounter);
        VocabCache<VocabWord> vocabCache = pipeline.getVocabCache();
        assertTrue(vocabCache != null);
        VocabWord redVocab = vocabCache.tokenFor("red");
        VocabWord flowerVocab = vocabCache.tokenFor("flowers");
        VocabWord worldVocab = vocabCache.tokenFor("world");
        VocabWord strangeVocab = vocabCache.tokenFor("strange");
        assertEquals("red", redVocab.getWord());
        assertEquals(1, redVocab.getElementFrequency(), 0);
        assertEquals("flowers", flowerVocab.getWord());
        assertEquals(1, flowerVocab.getElementFrequency(), 0);
        assertEquals("world", worldVocab.getWord());
        assertEquals(1, worldVocab.getElementFrequency(), 0);
        assertEquals("strange", strangeVocab.getWord());
        assertEquals(2, strangeVocab.getElementFrequency(), 0);
    } finally {
        sc.stop();
    }
}
Also used : VocabWord(org.deeplearning4j.models.word2vec.VocabWord) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) TextPipeline(org.deeplearning4j.spark.text.functions.TextPipeline) Test(org.junit.Test)

Aggregations

JavaSparkContext (org.apache.spark.api.java.JavaSparkContext)15 TextPipeline (org.deeplearning4j.spark.text.functions.TextPipeline)15 Test (org.junit.Test)13 VocabWord (org.deeplearning4j.models.word2vec.VocabWord)10 AtomicLong (java.util.concurrent.atomic.AtomicLong)8 CountCumSum (org.deeplearning4j.spark.text.functions.CountCumSum)6 Pair (org.deeplearning4j.berkeley.Pair)4 Tuple2 (scala.Tuple2)4 Huffman (org.deeplearning4j.models.word2vec.Huffman)2 VocabCache (org.deeplearning4j.models.word2vec.wordstore.VocabCache)2 INDArray (org.nd4j.linalg.api.ndarray.INDArray)2 ArrayList (java.util.ArrayList)1 HashMap (java.util.HashMap)1 List (java.util.List)1 Map (java.util.Map)1 AtomicInteger (java.util.concurrent.atomic.AtomicInteger)1 SparkConf (org.apache.spark.SparkConf)1 FlatMapFunction (org.apache.spark.api.java.function.FlatMapFunction)1 CounterMap (org.deeplearning4j.berkeley.CounterMap)1 Triple (org.deeplearning4j.berkeley.Triple)1