Use of org.apache.spark.api.java.JavaSparkContext in project deeplearning4j by deeplearning4j.
In class TextPipelineTest, method testBuildVocabWordListRDD.
@Test
public void testBuildVocabWordListRDD() throws Exception {
    JavaSparkContext sc = getContext();
    JavaRDD<String> corpusRDD = getCorpusRDD(sc);
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());

    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();

    JavaRDD<AtomicLong> sentenceCountRDD = pipeline.getSentenceCountRDD();
    JavaRDD<List<VocabWord>> vocabWordListRDD = pipeline.getVocabWordListRDD();

    List<List<VocabWord>> vocabWordList = vocabWordListRDD.collect();
    List<VocabWord> firstSentenceVocabList = vocabWordList.get(0);
    List<VocabWord> secondSentenceVocabList = vocabWordList.get(1);
    System.out.println(Arrays.deepToString(firstSentenceVocabList.toArray()));

    // Gather the words that survived tokenization for each sentence.
    List<String> firstSentenceTokenList = new ArrayList<>();
    List<String> secondSentenceTokenList = new ArrayList<>();
    for (VocabWord v : firstSentenceVocabList) {
        if (v != null) {
            firstSentenceTokenList.add(v.getWord());
        }
    }
    for (VocabWord v : secondSentenceVocabList) {
        if (v != null) {
            secondSentenceTokenList.add(v.getWord());
        }
    }

    // 6 raw tokens in the first sentence plus 3 in the second = 9 in total.
    assertEquals(9, pipeline.getTotalWordCount());
    List<AtomicLong> sentenceCounts = sentenceCountRDD.collect();
    assertEquals(6, sentenceCounts.get(0).get());
    assertEquals(3, sentenceCounts.get(1).get());
    assertTrue(firstSentenceTokenList.containsAll(Arrays.asList("strange", "strange", "world")));
    assertTrue(secondSentenceTokenList.containsAll(Arrays.asList("flowers", "red")));

    sc.stop();
}
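The assertions above pin down the test fixture: two sentences whose raw token counts are 6 and 3, with stop words filtered out of the vocab word lists. The helper below is a hypothetical stand-in for getCorpusRDD(sc) (which lives in the test's setup code); the literal sentences are inferred from the assertions in this file, not taken from the actual corpus resource.

import java.util.Arrays;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

public class CorpusSketch {
    // Assumed corpus: sentence tokens reconstructed from the test assertions.
    public static JavaRDD<String> buildCorpus(JavaSparkContext sc) {
        return sc.parallelize(Arrays.asList(
                "This is a strange strange world.", // 6 tokens -> sentence count 6
                "Flowers are red."));               // 3 tokens -> running total 9
    }

    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setMaster("local[*]").setAppName("corpusSketch");
        try (JavaSparkContext sc = new JavaSparkContext(conf)) {
            System.out.println(buildCorpus(sc).collect());
        }
    }
}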
Use of org.apache.spark.api.java.JavaSparkContext in project deeplearning4j by deeplearning4j.
In class TextPipelineTest, method testZipFunction1.
/**
 * This test checks the vocab word lists produced for each sentence, zipped
 * with the sentence-count cumulative sums, when stop words are removed.
 *
 * @throws Exception
 */
@Test
public void testZipFunction1() throws Exception {
    JavaSparkContext sc = getContext();
    JavaRDD<String> corpusRDD = getCorpusRDD(sc);
    // word2vec is configured to remove stop words (contrast with testZipFunction2).
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());

    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();

    JavaRDD<AtomicLong> sentenceCountRDD = pipeline.getSentenceCountRDD();
    JavaRDD<List<VocabWord>> vocabWordListRDD = pipeline.getVocabWordListRDD();

    // Pair each sentence's vocab words with the running token count up to and
    // including that sentence.
    CountCumSum countCumSum = new CountCumSum(sentenceCountRDD);
    JavaRDD<Long> sentenceCountCumSumRDD = countCumSum.buildCumSum();
    JavaPairRDD<List<VocabWord>, Long> vocabWordListSentenceCumSumRDD = vocabWordListRDD.zip(sentenceCountCumSumRDD);
    List<Tuple2<List<VocabWord>, Long>> lst = vocabWordListSentenceCumSumRDD.collect();

    // First sentence: stop words stripped, leaving "strange strange world"; 6 raw tokens so far.
    List<VocabWord> vocabWordsList1 = lst.get(0)._1();
    Long cumSumSize1 = lst.get(0)._2();
    assertEquals(3, vocabWordsList1.size());
    assertEquals("strange", vocabWordsList1.get(0).getWord());
    assertEquals("strange", vocabWordsList1.get(1).getWord());
    assertEquals("world", vocabWordsList1.get(2).getWord());
    assertEquals(6L, (long) cumSumSize1);

    // Second sentence: "flowers red" after stop-word removal; 9 raw tokens in total.
    List<VocabWord> vocabWordsList2 = lst.get(1)._1();
    Long cumSumSize2 = lst.get(1)._2();
    assertEquals(2, vocabWordsList2.size());
    assertEquals("flowers", vocabWordsList2.get(0).getWord());
    assertEquals("red", vocabWordsList2.get(1).getWord());
    assertEquals(9L, (long) cumSumSize2);

    sc.stop();
}
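The zip at the heart of this test pairs elements positionally. JavaRDD.zip requires both RDDs to have the same number of partitions and the same number of elements per partition, which holds here because the word-list RDD and the cumulative-sum RDD both derive from the same corpus RDD. A minimal, self-contained demonstration on toy data (not the pipeline's actual RDDs):

import java.util.Arrays;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

public class ZipSketch {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setMaster("local[*]").setAppName("zipSketch");
        try (JavaSparkContext sc = new JavaSparkContext(conf)) {
            // Two RDDs with matching partitioning: 2 partitions, 1 element each.
            JavaRDD<String> sentences = sc.parallelize(Arrays.asList("strange strange world", "flowers red"), 2);
            JavaRDD<Long> cumSums = sc.parallelize(Arrays.asList(6L, 9L), 2);
            // zip pairs the i-th element of each RDD.
            JavaPairRDD<String, Long> zipped = sentences.zip(cumSums);
            System.out.println(zipped.collect()); // [(strange strange world,6), (flowers red,9)]
        }
    }
}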
Use of org.apache.spark.api.java.JavaSparkContext in project deeplearning4j by deeplearning4j.
In class TextPipelineTest, method testZipFunction2.
@Test
public void testZipFunction2() throws Exception {
    JavaSparkContext sc = getContext();
    JavaRDD<String> corpusRDD = getCorpusRDD(sc);
    // word2vecNoStop keeps stop words, so every token reaches the vocab lists.
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vecNoStop.getTokenizerVarMap());

    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();

    JavaRDD<AtomicLong> sentenceCountRDD = pipeline.getSentenceCountRDD();
    JavaRDD<List<VocabWord>> vocabWordListRDD = pipeline.getVocabWordListRDD();

    CountCumSum countCumSum = new CountCumSum(sentenceCountRDD);
    JavaRDD<Long> sentenceCountCumSumRDD = countCumSum.buildCumSum();
    JavaPairRDD<List<VocabWord>, Long> vocabWordListSentenceCumSumRDD = vocabWordListRDD.zip(sentenceCountCumSumRDD);
    List<Tuple2<List<VocabWord>, Long>> lst = vocabWordListSentenceCumSumRDD.collect();

    // First sentence: all 6 tokens survive, including the stop words.
    List<VocabWord> vocabWordsList1 = lst.get(0)._1();
    Long cumSumSize1 = lst.get(0)._2();
    assertEquals(6, vocabWordsList1.size());
    assertEquals("this", vocabWordsList1.get(0).getWord());
    assertEquals("is", vocabWordsList1.get(1).getWord());
    assertEquals("a", vocabWordsList1.get(2).getWord());
    assertEquals("strange", vocabWordsList1.get(3).getWord());
    assertEquals("strange", vocabWordsList1.get(4).getWord());
    assertEquals("world", vocabWordsList1.get(5).getWord());
    assertEquals(6L, (long) cumSumSize1);

    // Second sentence: "flowers are red" in full.
    List<VocabWord> vocabWordsList2 = lst.get(1)._1();
    Long cumSumSize2 = lst.get(1)._2();
    assertEquals(3, vocabWordsList2.size());
    assertEquals("flowers", vocabWordsList2.get(0).getWord());
    assertEquals("are", vocabWordsList2.get(1).getWord());
    assertEquals("red", vocabWordsList2.get(2).getWord());
    assertEquals(9L, (long) cumSumSize2);

    sc.stop();
}
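The only difference from testZipFunction1 is the broadcast configuration: word2vecNoStop retains stop words, so "this", "is", "a", and "are" appear in the word lists while the cumulative sums (6 and 9) stay the same. A plain-Java sketch of the effect, with an assumed stop-word list (the real tokenizer's list may differ):

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

public class StopWordSketch {
    // Hypothetical stop-word list, inferred from the tokens missing in testZipFunction1.
    static final List<String> STOP_WORDS = Arrays.asList("this", "is", "a", "are");

    static List<String> tokens(String sentence, boolean removeStop) {
        return Arrays.stream(sentence.toLowerCase().split("\\W+"))
                .filter(t -> !t.isEmpty())
                .filter(t -> !removeStop || !STOP_WORDS.contains(t))
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        String s = "This is a strange strange world.";
        System.out.println(tokens(s, true));  // [strange, strange, world]  (testZipFunction1)
        System.out.println(tokens(s, false)); // [this, is, a, strange, strange, world]  (testZipFunction2)
    }
}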
Use of org.apache.spark.api.java.JavaSparkContext in project deeplearning4j by deeplearning4j.
In class TextPipelineTest, method testCountCumSum.
@Test
public void testCountCumSum() throws Exception {
    JavaSparkContext sc = getContext();
    JavaRDD<String> corpusRDD = getCorpusRDD(sc);
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());

    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();

    JavaRDD<AtomicLong> sentenceCountRDD = pipeline.getSentenceCountRDD();
    CountCumSum countCumSum = new CountCumSum(sentenceCountRDD);
    JavaRDD<Long> sentenceCountCumSumRDD = countCumSum.buildCumSum();
    List<Long> sentenceCountCumSumList = sentenceCountCumSumRDD.collect();

    // Per-sentence token counts [6, 3] accumulate to [6, 9].
    assertEquals(6L, (long) sentenceCountCumSumList.get(0));
    assertEquals(9L, (long) sentenceCountCumSumList.get(1));

    sc.stop();
}
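CountCumSum computes this running total with Spark primitives; conceptually it is just a prefix sum over the per-sentence counts. A plain-Java sketch of the arithmetic (not the distributed implementation):

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class CumSumSketch {
    // Prefix sum: each output element is the sum of all inputs up to that index.
    static List<Long> cumSum(List<Long> counts) {
        List<Long> out = new ArrayList<>();
        long running = 0;
        for (long c : counts) {
            running += c;
            out.add(running);
        }
        return out;
    }

    public static void main(String[] args) {
        System.out.println(cumSum(Arrays.asList(6L, 3L))); // prints [6, 9]
    }
}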
Use of org.apache.spark.api.java.JavaSparkContext in project deeplearning4j by deeplearning4j.
In class TextPipelineTest, method testHuffman.
@Test
public void testHuffman() throws Exception {
    JavaSparkContext sc = getContext();
    JavaRDD<String> corpusRDD = getCorpusRDD(sc);
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());

    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    VocabCache<VocabWord> vocabCache = pipeline.getVocabCache();

    // Build the Huffman tree over the vocabulary and write the resulting
    // codes and points back onto each VocabWord.
    Huffman huffman = new Huffman(vocabCache.vocabWords());
    huffman.build();
    huffman.applyIndexes(vocabCache);

    Collection<VocabWord> vocabWords = vocabCache.vocabWords();
    System.out.println("Huffman Test:");
    for (VocabWord vocabWord : vocabWords) {
        System.out.println("Word: " + vocabWord);
        System.out.println(vocabWord.getCodes());
        System.out.println(vocabWord.getPoints());
    }

    sc.stop();
}
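The printed codes and points come from word2vec-style hierarchical softmax: a word's codes are the left/right decisions along its Huffman path and its points index the internal tree nodes on that path. As a reminder of the core property, the self-contained sketch below (a conceptual illustration, not dl4j's Huffman class) builds codes from word frequencies; a more frequent word never receives a longer code than a rarer one.

import java.util.HashMap;
import java.util.Map;
import java.util.PriorityQueue;

public class HuffmanSketch {
    static final class Node {
        final String word;        // null for internal nodes
        final long freq;
        final Node left, right;
        Node(String word, long freq, Node left, Node right) {
            this.word = word; this.freq = freq; this.left = left; this.right = right;
        }
    }

    static Map<String, String> buildCodes(Map<String, Long> freqs) {
        // Repeatedly merge the two least frequent nodes until one tree remains.
        PriorityQueue<Node> pq = new PriorityQueue<>((a, b) -> Long.compare(a.freq, b.freq));
        freqs.forEach((w, f) -> pq.add(new Node(w, f, null, null)));
        while (pq.size() > 1) {
            Node a = pq.poll(), b = pq.poll();
            pq.add(new Node(null, a.freq + b.freq, a, b));
        }
        Map<String, String> codes = new HashMap<>();
        assign(pq.poll(), "", codes);
        return codes;
    }

    static void assign(Node n, String prefix, Map<String, String> codes) {
        if (n == null) return;
        if (n.word != null) { codes.put(n.word, prefix.isEmpty() ? "0" : prefix); return; }
        assign(n.left, prefix + "0", codes);
        assign(n.right, prefix + "1", codes);
    }

    public static void main(String[] args) {
        // Frequencies matching this test corpus: "strange" occurs twice.
        Map<String, Long> freqs = new HashMap<>();
        freqs.put("strange", 2L); freqs.put("world", 1L);
        freqs.put("flowers", 1L); freqs.put("red", 1L);
        System.out.println(buildCodes(freqs)); // "strange" is never coded longer than the others
    }
}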