Use of org.deeplearning4j.spark.text.functions.TextPipeline in project deeplearning4j: class TextPipelineTest, method testTokenizer.
@Test
public void testTokenizer() throws Exception {
    JavaSparkContext sc = getContext();
    JavaRDD<String> corpusRDD = getCorpusRDD(sc);
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());

    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    JavaRDD<List<String>> tokenizedRDD = pipeline.tokenize();

    assertEquals(2, tokenizedRDD.count());
    assertEquals(Arrays.asList("this", "is", "a", "strange", "strange", "world"), tokenizedRDD.first());

    sc.stop();
}
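From the assertions, the corpus evidently holds two sentences. A minimal sketch of a helper that would produce such an RDD; the actual getCorpusRDD in the test base class may build it differently, e.g. from a resource file:

// Hypothetical stand-in for the test's getCorpusRDD helper; the real one
// may load the sentences from a test resource instead.
private JavaRDD<String> getCorpusRDD(JavaSparkContext sc) {
    return sc.parallelize(Arrays.asList(
            "This is a strange strange world.",
            "Flowers are red."));
}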
Use of org.deeplearning4j.spark.text.functions.TextPipeline in project deeplearning4j: class TextPipelineTest, method testWordFreqAccIdentifyStopWords.
@Test
public void testWordFreqAccIdentifyStopWords() throws Exception {
    JavaSparkContext sc = getContext();
    JavaRDD<String> corpusRDD = getCorpusRDD(sc);
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());

    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    JavaRDD<List<String>> tokenizedRDD = pipeline.tokenize();
    JavaRDD<Pair<List<String>, AtomicLong>> sentenceWordsCountRDD =
            pipeline.updateAndReturnAccumulatorVal(tokenizedRDD);

    Counter<String> wordFreqCounter = pipeline.getWordFreqAcc().value();
    assertEquals(wordFreqCounter.getCount("STOP"), 4, 0);
    assertEquals(wordFreqCounter.getCount("strange"), 2, 0);
    assertEquals(wordFreqCounter.getCount("flowers"), 1, 0);
    assertEquals(wordFreqCounter.getCount("world"), 1, 0);
    assertEquals(wordFreqCounter.getCount("red"), 1, 0);

    List<Pair<List<String>, AtomicLong>> ret = sentenceWordsCountRDD.collect();
    assertEquals(ret.get(0).getFirst(), Arrays.asList("this", "is", "a", "strange", "strange", "world"));
    assertEquals(ret.get(1).getFirst(), Arrays.asList("flowers", "are", "red"));
    assertEquals(ret.get(0).getSecond().get(), 6);
    assertEquals(ret.get(1).getSecond().get(), 3);

    sc.stop();
}
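The "STOP" count of 4 reflects the four stop words in the corpus ("this", "is", "a", "are") being folded into a single marker during frequency counting. A minimal local sketch of that folding, assuming a Berkeley-style Counter with incrementCount; the pipeline's actual accumulator logic may differ:

// Illustration only: fold stop words into one "STOP" entry while counting.
Set<String> stopWords = new HashSet<>(Arrays.asList("this", "is", "a", "are"));
Counter<String> counter = new Counter<>();
List<String> tokens = Arrays.asList("this", "is", "a", "strange", "strange", "world",
        "flowers", "are", "red");
for (String tok : tokens) {
    counter.incrementCount(stopWords.contains(tok) ? "STOP" : tok, 1.0);
}
// counter.getCount("STOP") == 4.0, matching the assertion above.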
Use of org.deeplearning4j.spark.text.functions.TextPipeline in project deeplearning4j: class TextPipelineTest, method testWordFreqAccNotIdentifyingStopWords.
@Test
public void testWordFreqAccNotIdentifyingStopWords() throws Exception {
    JavaSparkContext sc = getContext();
    JavaRDD<String> corpusRDD = getCorpusRDD(sc);
    // word2vec.setRemoveStop(false) is superseded by the word2vecNoStop
    // instance below, which is configured with stop-word removal disabled.
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vecNoStop.getTokenizerVarMap());

    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    JavaRDD<List<String>> tokenizedRDD = pipeline.tokenize();
    pipeline.updateAndReturnAccumulatorVal(tokenizedRDD);

    Counter<String> wordFreqCounter = pipeline.getWordFreqAcc().value();
    assertEquals(wordFreqCounter.getCount("is"), 1, 0);
    assertEquals(wordFreqCounter.getCount("this"), 1, 0);
    assertEquals(wordFreqCounter.getCount("are"), 1, 0);
    assertEquals(wordFreqCounter.getCount("a"), 1, 0);
    assertEquals(wordFreqCounter.getCount("strange"), 2, 0);
    assertEquals(wordFreqCounter.getCount("flowers"), 1, 0);
    assertEquals(wordFreqCounter.getCount("world"), 1, 0);
    assertEquals(wordFreqCounter.getCount("red"), 1, 0);

    sc.stop();
}
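With stop-word removal disabled, all nine corpus tokens keep their own entries, and the eight asserted counts sum to 9 ("strange" contributes 2). Assuming the Counter exposes a totalCount() aggregate, as the Berkeley-style Counter used elsewhere in deeplearning4j does, a quick consistency check could be added:

// Consistency check (assumes Counter.totalCount() is available):
// nine tokens in two sentences, none folded into "STOP".
assertEquals(9, wordFreqCounter.totalCount(), 0);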
Use of org.deeplearning4j.spark.text.functions.TextPipeline in project deeplearning4j: class TextPipelineTest, method testSyn0AfterFirstIteration.
@Test
public void testSyn0AfterFirstIteration() throws Exception {
    JavaSparkContext sc = getContext();
    JavaRDD<String> corpusRDD = getCorpusRDD(sc);
    // word2vec.setRemoveStop(false);
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());

    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();

    VocabCache<VocabWord> vocabCache = pipeline.getVocabCache();
    Huffman huffman = new Huffman(vocabCache.vocabWords());
    huffman.build();

    // Get total word count and put into word2vec variable map
    Map<String, Object> word2vecVarMap = word2vec.getWord2vecVarMap();
    word2vecVarMap.put("totalWordCount", pipeline.getTotalWordCount());
    double[] expTable = word2vec.getExpTable();

    JavaRDD<AtomicLong> sentenceCountRDD = pipeline.getSentenceCountRDD();
    JavaRDD<List<VocabWord>> vocabWordListRDD = pipeline.getVocabWordListRDD();
    CountCumSum countCumSum = new CountCumSum(sentenceCountRDD);
    JavaRDD<Long> sentenceCountCumSumRDD = countCumSum.buildCumSum();
    JavaPairRDD<List<VocabWord>, Long> vocabWordListSentenceCumSumRDD =
            vocabWordListRDD.zip(sentenceCountCumSumRDD);

    Broadcast<Map<String, Object>> word2vecVarMapBroadcast = sc.broadcast(word2vecVarMap);
    Broadcast<double[]> expTableBroadcast = sc.broadcast(expTable);

    FirstIterationFunction firstIterationFunction = new FirstIterationFunction(
            word2vecVarMapBroadcast, expTableBroadcast, pipeline.getBroadCastVocabCache());
    JavaRDD<Pair<VocabWord, INDArray>> pointSyn0Vec =
            vocabWordListSentenceCumSumRDD.mapPartitions(firstIterationFunction).map(new MapToPairFunction());
    // Spark transformations are lazy: without an action, the first
    // iteration never actually runs.
    assertTrue(pointSyn0Vec.count() > 0);

    sc.stop();
}
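CountCumSum converts the per-sentence word counts into a running total, so zip-ing it against vocabWordListRDD gives each sentence its cumulative word offset into the corpus. A local sketch of the same arithmetic for per-sentence counts like the [6, 3] seen earlier; the Spark version computes this per partition, but the result is equivalent:

// Local equivalent of CountCumSum: counts [6, 3] -> cumulative sums [6, 9].
long[] sentenceCounts = {6, 3};
long[] cumSum = new long[sentenceCounts.length];
long running = 0;
for (int i = 0; i < sentenceCounts.length; i++) {
    running += sentenceCounts[i];
    cumSum[i] = running;
}
// Each (sentence, cumulative count) pair then feeds FirstIterationFunction.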
Use of org.deeplearning4j.spark.text.functions.TextPipeline in project deeplearning4j: class TextPipelineTest, method testFirstIteration.
@Test
public void testFirstIteration() throws Exception {
    JavaSparkContext sc = getContext();
    JavaRDD<String> corpusRDD = getCorpusRDD(sc);
    // word2vec.setRemoveStop(false);
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());

    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();

    VocabCache<VocabWord> vocabCache = pipeline.getVocabCache();
    /*
    Huffman huffman = new Huffman(vocabCache.vocabWords());
    huffman.build();
    huffman.applyIndexes(vocabCache);
    */
    VocabWord token = vocabCache.tokenFor("strange");
    VocabWord word = vocabCache.wordFor("strange");
    log.info("Strange token: " + token);
    log.info("Strange word: " + word);

    // Get total word count and put into word2vec variable map
    Map<String, Object> word2vecVarMap = word2vec.getWord2vecVarMap();
    word2vecVarMap.put("totalWordCount", pipeline.getTotalWordCount());
    double[] expTable = word2vec.getExpTable();

    JavaRDD<AtomicLong> sentenceCountRDD = pipeline.getSentenceCountRDD();
    JavaRDD<List<VocabWord>> vocabWordListRDD = pipeline.getVocabWordListRDD();
    CountCumSum countCumSum = new CountCumSum(sentenceCountRDD);
    JavaRDD<Long> sentenceCountCumSumRDD = countCumSum.buildCumSum();
    JavaPairRDD<List<VocabWord>, Long> vocabWordListSentenceCumSumRDD =
            vocabWordListRDD.zip(sentenceCountCumSumRDD);

    Broadcast<Map<String, Object>> word2vecVarMapBroadcast = sc.broadcast(word2vecVarMap);
    Broadcast<double[]> expTableBroadcast = sc.broadcast(expTable);

    Iterator<Tuple2<List<VocabWord>, Long>> iterator = vocabWordListSentenceCumSumRDD.collect().iterator();
    FirstIterationFunctionAdapter firstIterationFunction = new FirstIterationFunctionAdapter(
            word2vecVarMapBroadcast, expTableBroadcast, pipeline.getBroadCastVocabCache());
    Iterable<Map.Entry<VocabWord, INDArray>> ret = firstIterationFunction.call(iterator);
    assertTrue(ret.iterator().hasNext());

    sc.stop();
}
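The returned entries pair each VocabWord with an INDArray holding its syn0 row after one training pass. A small hedged sketch of inspecting them, assuming only the VocabWord.getWord() accessor and the usual INDArray toString:

// Inspect the (word, syn0 vector) pairs produced by the first iteration.
for (Map.Entry<VocabWord, INDArray> entry : ret) {
    log.info("word: " + entry.getKey().getWord() + ", syn0: " + entry.getValue());
}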