Use of org.deeplearning4j.spark.text.functions.CountCumSum in project deeplearning4j by deeplearning4j.
The class Word2Vec, method train.
/**
 * Trains a word2vec model on the given text corpus.
 *
 * @param corpusRDD training corpus
 * @throws Exception if training fails
 */
public void train(JavaRDD<String> corpusRDD) throws Exception {
    log.info("Start training ...");
    // repartition() returns a new RDD, so the result must be reassigned
    if (workers > 0)
        corpusRDD = corpusRDD.repartition(workers);
    // SparkContext
    final JavaSparkContext sc = new JavaSparkContext(corpusRDD.context());
    // Pre-defined variables
    Map<String, Object> tokenizerVarMap = getTokenizerVarMap();
    Map<String, Object> word2vecVarMap = getWord2vecVarMap();
    // Variables filled in during training
    final JavaRDD<AtomicLong> sentenceWordsCountRDD;
    final JavaRDD<List<VocabWord>> vocabWordListRDD;
    final JavaPairRDD<List<VocabWord>, Long> vocabWordListSentenceCumSumRDD;
    final VocabCache<VocabWord> vocabCache;
    final JavaRDD<Long> sentenceCumSumCountRDD;
    int maxRep = 1;
    // Start Training //
    //////////////////////////////////////
    log.info("Tokenization and building VocabCache ...");
    // Process every sentence and build a VocabCache, which is then fed into a LookupCache
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(tokenizerVarMap);
    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();
    // Get the total word count and put it into the word2vec variable map
    word2vecVarMap.put("totalWordCount", pipeline.getTotalWordCount());
    // Two RDDs: (vocab word lists) and (sentence counts). Both are already cached
    sentenceWordsCountRDD = pipeline.getSentenceCountRDD();
    vocabWordListRDD = pipeline.getVocabWordListRDD();
    // Get the vocabCache and the broadcast vocabCache
    Broadcast<VocabCache<VocabWord>> vocabCacheBroadcast = pipeline.getBroadCastVocabCache();
    vocabCache = vocabCacheBroadcast.getValue();
    log.info("Vocab size: {}", vocabCache.numWords());
    //////////////////////////////////////
    log.info("Building Huffman Tree ...");
    // Building the Huffman tree updates the code and point of each VocabWord in the vocabCache
    /*
    We don't need to build the tree here, since it was already built during the
    TextPipeline.buildVocabCache() call:
    Huffman huffman = new Huffman(vocabCache.vocabWords());
    huffman.build();
    huffman.applyIndexes(vocabCache);
    */
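    // (the Huffman codes and points drive the hierarchical-softmax step of word2vec training)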
    //////////////////////////////////////
    log.info("Calculating cumulative sum of sentence counts ...");
    sentenceCumSumCountRDD = new CountCumSum(sentenceWordsCountRDD).buildCumSum();
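    // e.g. per-sentence word counts [6, 3] become cumulative offsets [6, 9],
    // as exercised by testCountCumSum below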
    //////////////////////////////////////
    log.info("Mapping to RDD(vocabWordList, cumulative sentence count) ...");
    vocabWordListSentenceCumSumRDD =
            vocabWordListRDD.zip(sentenceCumSumCountRDD).setName("vocabWordListSentenceCumSumRDD");
    /////////////////////////////////////
    log.info("Broadcasting word2vec variables to workers ...");
    Broadcast<Map<String, Object>> word2vecVarMapBroadcast = sc.broadcast(word2vecVarMap);
    Broadcast<double[]> expTableBroadcast = sc.broadcast(expTable);
    /////////////////////////////////////
    log.info("Training word2vec sentences ...");
    FlatMapFunction firstIterFunc = new FirstIterationFunction(word2vecVarMapBroadcast, expTableBroadcast, vocabCacheBroadcast);
    @SuppressWarnings("unchecked")
    JavaRDD<Pair<VocabWord, INDArray>> indexSyn0UpdateEntryRDD =
            vocabWordListSentenceCumSumRDD.mapPartitions(firstIterFunc).map(new MapToPairFunction());
    // Collect all the syn0 updates into a list on the driver
    List<Pair<VocabWord, INDArray>> syn0UpdateEntries = indexSyn0UpdateEntryRDD.collect();
    // Instantiate syn0
    INDArray syn0 = Nd4j.zeros(vocabCache.numWords(), layerSize);
    // Updating syn0, first pass: just add up the vectors obtained from the different nodes
    log.info("Averaging results...");
    Map<VocabWord, AtomicInteger> updates = new HashMap<>();
    Map<Long, Long> updaters = new HashMap<>();
    for (Pair<VocabWord, INDArray> syn0UpdateEntry : syn0UpdateEntries) {
        syn0.getRow(syn0UpdateEntry.getFirst().getIndex()).addi(syn0UpdateEntry.getSecond());
        // for proper averaging, the resulting sums are divided later by the number of additions
        if (updates.containsKey(syn0UpdateEntry.getFirst())) {
            updates.get(syn0UpdateEntry.getFirst()).incrementAndGet();
        } else {
            updates.put(syn0UpdateEntry.getFirst(), new AtomicInteger(1));
        }
        if (!updaters.containsKey(syn0UpdateEntry.getFirst().getVocabId())) {
            updaters.put(syn0UpdateEntry.getFirst().getVocabId(), syn0UpdateEntry.getFirst().getAffinityId());
        }
    }
    // Updating syn0, second pass: average the accumulated vectors
    for (Map.Entry<VocabWord, AtomicInteger> entry : updates.entrySet()) {
        if (entry.getValue().get() > 1) {
            if (entry.getValue().get() > maxRep)
                maxRep = entry.getValue().get();
            syn0.getRow(entry.getKey().getIndex()).divi(entry.getValue().get());
        }
    }
    long totals = 0;
    log.info("Finished calculations...");
    vocab = vocabCache;
    InMemoryLookupTable<VocabWord> inMemoryLookupTable = new InMemoryLookupTable<>();
    Environment env = EnvironmentUtils.buildEnvironment();
    env.setNumCores(maxRep);
    env.setAvailableMemory(totals);
    update(env, Event.SPARK);
    // Expose the averaged vectors through an in-memory lookup table
    inMemoryLookupTable.setVocab(vocabCache);
    inMemoryLookupTable.setVectorLength(layerSize);
    inMemoryLookupTable.setSyn0(syn0);
    lookupTable = inMemoryLookupTable;
    modelUtils.init(lookupTable);
}
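As an aside, here is a minimal, self-contained sketch of the sum-then-average merge performed by the two passes above. The class name and the update values are invented for illustration; Nd4j.zeros, getRow, addi and divi are the same ND4J calls used in train().

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class AveragingSketch {
    public static void main(String[] args) {
        // Hypothetical setup: 2 vocab words, vector length 3
        INDArray syn0 = Nd4j.zeros(2, 3);
        // Three updates arrive for row 0 (e.g. from different partitions), one for row 1
        int[] rows = {0, 0, 0, 1};
        INDArray[] updates = {
                Nd4j.create(new double[] {1, 2, 3}),
                Nd4j.create(new double[] {3, 2, 1}),
                Nd4j.create(new double[] {2, 2, 2}),
                Nd4j.create(new double[] {5, 5, 5})
        };
        int[] counts = new int[2];
        // First pass: sum all updates into the corresponding row
        for (int i = 0; i < rows.length; i++) {
            syn0.getRow(rows[i]).addi(updates[i]);
            counts[rows[i]]++;
        }
        // Second pass: divide only rows that received more than one update
        for (int row = 0; row < counts.length; row++) {
            if (counts[row] > 1)
                syn0.getRow(row).divi(counts[row]);
        }
        System.out.println(syn0); // row 0 -> [2, 2, 2], row 1 -> [5, 5, 5]
    }
}

As in the second pass of train(), a row is divided only when it was updated more than once, so singly-updated rows keep their summed value unchanged.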
Use of org.deeplearning4j.spark.text.functions.CountCumSum in project deeplearning4j by deeplearning4j.
The class TextPipelineTest, method testZipFunction1.
/**
 * This test checks the zipped (vocab word list, cumulative sentence count) pairs
 * produced when stop words are removed from the corpus
 *
 * @throws Exception
 */
@Test
public void testZipFunction1() throws Exception {
    JavaSparkContext sc = getContext();
    JavaRDD<String> corpusRDD = getCorpusRDD(sc);
    // word2vec.setRemoveStop(false);
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());
    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();
    JavaRDD<AtomicLong> sentenceCountRDD = pipeline.getSentenceCountRDD();
    JavaRDD<List<VocabWord>> vocabWordListRDD = pipeline.getVocabWordListRDD();
    CountCumSum countCumSum = new CountCumSum(sentenceCountRDD);
    JavaRDD<Long> sentenceCountCumSumRDD = countCumSum.buildCumSum();
    JavaPairRDD<List<VocabWord>, Long> vocabWordListSentenceCumSumRDD = vocabWordListRDD.zip(sentenceCountCumSumRDD);
    List<Tuple2<List<VocabWord>, Long>> lst = vocabWordListSentenceCumSumRDD.collect();

    List<VocabWord> vocabWordsList1 = lst.get(0)._1();
    Long cumSumSize1 = lst.get(0)._2();
    assertEquals(3, vocabWordsList1.size());
    assertEquals("strange", vocabWordsList1.get(0).getWord());
    assertEquals("strange", vocabWordsList1.get(1).getWord());
    assertEquals("world", vocabWordsList1.get(2).getWord());
    assertEquals(6L, cumSumSize1.longValue());

    List<VocabWord> vocabWordsList2 = lst.get(1)._1();
    Long cumSumSize2 = lst.get(1)._2();
    assertEquals(2, vocabWordsList2.size());
    assertEquals("flowers", vocabWordsList2.get(0).getWord());
    assertEquals("red", vocabWordsList2.get(1).getWord());
    assertEquals(9L, cumSumSize2.longValue());

    sc.stop();
}
Use of org.deeplearning4j.spark.text.functions.CountCumSum in project deeplearning4j by deeplearning4j.
The class TextPipelineTest, method testZipFunction2.
@Test
public void testZipFunction2() throws Exception {
    JavaSparkContext sc = getContext();
    JavaRDD<String> corpusRDD = getCorpusRDD(sc);
    // Unlike testZipFunction1, stop words are kept: word2vecNoStop does not remove them
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vecNoStop.getTokenizerVarMap());
    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();
    JavaRDD<AtomicLong> sentenceCountRDD = pipeline.getSentenceCountRDD();
    JavaRDD<List<VocabWord>> vocabWordListRDD = pipeline.getVocabWordListRDD();
    CountCumSum countCumSum = new CountCumSum(sentenceCountRDD);
    JavaRDD<Long> sentenceCountCumSumRDD = countCumSum.buildCumSum();
    JavaPairRDD<List<VocabWord>, Long> vocabWordListSentenceCumSumRDD = vocabWordListRDD.zip(sentenceCountCumSumRDD);
    List<Tuple2<List<VocabWord>, Long>> lst = vocabWordListSentenceCumSumRDD.collect();

    List<VocabWord> vocabWordsList1 = lst.get(0)._1();
    Long cumSumSize1 = lst.get(0)._2();
    assertEquals(6, vocabWordsList1.size());
    assertEquals("this", vocabWordsList1.get(0).getWord());
    assertEquals("is", vocabWordsList1.get(1).getWord());
    assertEquals("a", vocabWordsList1.get(2).getWord());
    assertEquals("strange", vocabWordsList1.get(3).getWord());
    assertEquals("strange", vocabWordsList1.get(4).getWord());
    assertEquals("world", vocabWordsList1.get(5).getWord());
    assertEquals(6L, cumSumSize1.longValue());

    List<VocabWord> vocabWordsList2 = lst.get(1)._1();
    Long cumSumSize2 = lst.get(1)._2();
    assertEquals(3, vocabWordsList2.size());
    assertEquals("flowers", vocabWordsList2.get(0).getWord());
    assertEquals("are", vocabWordsList2.get(1).getWord());
    assertEquals("red", vocabWordsList2.get(2).getWord());
    assertEquals(9L, cumSumSize2.longValue());

    sc.stop();
}
Use of org.deeplearning4j.spark.text.functions.CountCumSum in project deeplearning4j by deeplearning4j.
The class TextPipelineTest, method testCountCumSum.
@Test
public void testCountCumSum() throws Exception {
    JavaSparkContext sc = getContext();
    JavaRDD<String> corpusRDD = getCorpusRDD(sc);
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());
    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();
    JavaRDD<AtomicLong> sentenceCountRDD = pipeline.getSentenceCountRDD();
    CountCumSum countCumSum = new CountCumSum(sentenceCountRDD);
    JavaRDD<Long> sentenceCountCumSumRDD = countCumSum.buildCumSum();
    List<Long> sentenceCountCumSumList = sentenceCountCumSumRDD.collect();
    assertEquals(6L, (long) sentenceCountCumSumList.get(0));
    assertEquals(9L, (long) sentenceCountCumSumList.get(1));
    sc.stop();
}
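The cumulative sum that CountCumSum computes over the RDD is an ordinary prefix sum over the per-sentence word counts. A minimal local sketch without Spark (the class name is invented; the input mirrors the counts asserted in the test above):

import java.util.Arrays;
import java.util.List;

public class CumSumSketch {
    public static void main(String[] args) {
        // Per-sentence word counts, as in the test corpus: 6 words, then 3 words
        List<Long> counts = Arrays.asList(6L, 3L);
        long running = 0;
        long[] cumSum = new long[counts.size()];
        for (int i = 0; i < counts.size(); i++) {
            running += counts.get(i);
            cumSum[i] = running;
        }
        System.out.println(Arrays.toString(cumSum)); // [6, 9]
    }
}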
Use of org.deeplearning4j.spark.text.functions.CountCumSum in project deeplearning4j by deeplearning4j.
The class TextPipelineTest, method testSyn0AfterFirstIteration.
@Test
public void testSyn0AfterFirstIteration() throws Exception {
    JavaSparkContext sc = getContext();
    JavaRDD<String> corpusRDD = getCorpusRDD(sc);
    // word2vec.setRemoveStop(false);
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());
    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();
    VocabCache<VocabWord> vocabCache = pipeline.getVocabCache();
    Huffman huffman = new Huffman(vocabCache.vocabWords());
    huffman.build();
    // Get the total word count and put it into the word2vec variable map
    Map<String, Object> word2vecVarMap = word2vec.getWord2vecVarMap();
    word2vecVarMap.put("totalWordCount", pipeline.getTotalWordCount());
    double[] expTable = word2vec.getExpTable();
    JavaRDD<AtomicLong> sentenceCountRDD = pipeline.getSentenceCountRDD();
    JavaRDD<List<VocabWord>> vocabWordListRDD = pipeline.getVocabWordListRDD();
    CountCumSum countCumSum = new CountCumSum(sentenceCountRDD);
    JavaRDD<Long> sentenceCountCumSumRDD = countCumSum.buildCumSum();
    JavaPairRDD<List<VocabWord>, Long> vocabWordListSentenceCumSumRDD = vocabWordListRDD.zip(sentenceCountCumSumRDD);
    Broadcast<Map<String, Object>> word2vecVarMapBroadcast = sc.broadcast(word2vecVarMap);
    Broadcast<double[]> expTableBroadcast = sc.broadcast(expTable);
    FirstIterationFunction firstIterationFunction = new FirstIterationFunction(word2vecVarMapBroadcast, expTableBroadcast, pipeline.getBroadCastVocabCache());
    JavaRDD<Pair<VocabWord, INDArray>> pointSyn0Vec =
            vocabWordListSentenceCumSumRDD.mapPartitions(firstIterationFunction).map(new MapToPairFunction());
    // Spark transformations are lazy: collect() forces the first training iteration to actually run
    pointSyn0Vec.collect();
    sc.stop();
}