Use of org.deeplearning4j.spark.text.functions.TextPipeline in project deeplearning4j by deeplearning4j.
Class Glove, method train.
/**
 * Trains GloVe vectors on the given corpus.
 *
 * <p>The corpus is tokenized and turned into a vocabulary through a {@link TextPipeline},
 * word co-occurrences are counted and capped at the lookup table's max count, and then
 * {@code iterations} passes are run: each pass computes one {@code GloveChange} per
 * co-occurring word pair through Spark, collects the changes and applies them to the
 * lookup table on the driver.
 *
 * @param rdd the rdd to train
 * @return the vocab and weights
 * @throws Exception if building the vocabulary or any training step fails
 */
public Pair<VocabCache<VocabWord>, GloveWeightLookupTable> train(JavaRDD<String> rdd) throws Exception {
    // Each `train()` can use different parameters
    final JavaSparkContext sc = new JavaSparkContext(rdd.context());
    final SparkConf conf = sc.getConf();

    // NOTE(review): vectorLength, useAdaGrad, negative, window, alpha and minAlpha are read
    // here but never used below; the lookup table is configured through the GlovePerformer
    // keys and the co-occurrence pass uses the windowSize field. Confirm this is intended.
    final int vectorLength = assignVar(VECTOR_LENGTH, conf, Integer.class);
    final boolean useAdaGrad = assignVar(ADAGRAD, conf, Boolean.class);
    final double negative = assignVar(NEGATIVE, conf, Double.class);
    final int numWords = assignVar(NUM_WORDS, conf, Integer.class);
    final int window = assignVar(WINDOW, conf, Integer.class);
    final double alpha = assignVar(ALPHA, conf, Double.class);
    final double minAlpha = assignVar(MIN_ALPHA, conf, Double.class);
    final int iterations = assignVar(ITERATIONS, conf, Integer.class);
    final int nGrams = assignVar(N_GRAMS, conf, Integer.class);
    final String tokenizer = assignVar(TOKENIZER, conf, String.class);
    final String tokenPreprocessor = assignVar(TOKEN_PREPROCESSOR, conf, String.class);
    final boolean removeStop = assignVar(REMOVE_STOPWORDS, conf, Boolean.class);

    // A plain HashMap is used on purpose: a double-brace initialized map is an anonymous
    // subclass that keeps a reference to the enclosing instance, which would then travel
    // with the broadcast value.
    final Map<String, Object> tokenizerVarMap = new HashMap<>();
    tokenizerVarMap.put("numWords", numWords);
    tokenizerVarMap.put("nGrams", nGrams);
    tokenizerVarMap.put("tokenizer", tokenizer);
    tokenizerVarMap.put("tokenPreprocessor", tokenPreprocessor);
    tokenizerVarMap.put("removeStop", removeStop);

    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(tokenizerVarMap);
    TextPipeline pipeline = new TextPipeline(rdd, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();

    // Get total word count
    Long totalWordCount = pipeline.getTotalWordCount();
    VocabCache<VocabWord> vocabCache = pipeline.getVocabCache();
    JavaRDD<Pair<List<String>, AtomicLong>> sentenceWordsCountRDD = pipeline.getSentenceWordsCountRDD();
    final Pair<VocabCache<VocabWord>, Long> vocabAndNumWords = new Pair<>(vocabCache, totalWordCount);
    vocabCacheBroadcast = sc.broadcast(vocabAndNumWords.getFirst());

    final GloveWeightLookupTable gloveWeightLookupTable = new GloveWeightLookupTable.Builder()
            .cache(vocabAndNumWords.getFirst())
            .lr(conf.getDouble(GlovePerformer.ALPHA, 0.01))
            .maxCount(conf.getDouble(GlovePerformer.MAX_COUNT, 100))
            .vectorLength(conf.getInt(GlovePerformer.VECTOR_LENGTH, 300))
            .xMax(conf.getDouble(GlovePerformer.X_MAX, 0.75))
            .build();
    gloveWeightLookupTable.resetWeights();
    // Start both AdaGrad histories at one
    gloveWeightLookupTable.getBiasAdaGrad().historicalGradient = Nd4j.ones(gloveWeightLookupTable.getSyn0().rows());
    gloveWeightLookupTable.getWeightAdaGrad().historicalGradient = Nd4j.ones(gloveWeightLookupTable.getSyn0().shape());
    log.info("Created lookup table of size " + Arrays.toString(gloveWeightLookupTable.getSyn0().shape()));

    CounterMap<String, String> coOccurrenceCounts = sentenceWordsCountRDD
            .map(new CoOccurrenceCalculator(symmetric, vocabCacheBroadcast, windowSize))
            .fold(new CounterMap<String, String>(), new CoOccurrenceCounts());

    // Flatten the counter map into (word1, word2, count) triples, capping each count at maxCount
    Iterator<Pair<String, String>> pairIterator = coOccurrenceCounts.getPairIterator();
    List<Triple<String, String, Double>> counts = new ArrayList<>();
    while (pairIterator.hasNext()) {
        Pair<String, String> wordPair = pairIterator.next();
        String first = wordPair.getFirst();
        String second = wordPair.getSecond();
        if (coOccurrenceCounts.getCount(first, second) > gloveWeightLookupTable.getMaxCount()) {
            coOccurrenceCounts.setCount(first, second, gloveWeightLookupTable.getMaxCount());
        }
        counts.add(new Triple<>(first, second, coOccurrenceCounts.getCount(first, second)));
    }
    log.info("Calculated co occurrences");

    JavaRDD<Triple<String, String, Double>> parallel = sc.parallelize(counts);
    // (word1, word2, count) -> word1 => (word2, count)
    JavaPairRDD<String, Tuple2<String, Double>> pairs = parallel.mapToPair(new PairFunction<Triple<String, String, Double>, String, Tuple2<String, Double>>() {
        @Override
        public Tuple2<String, Tuple2<String, Double>> call(Triple<String, String, Double> triple) throws Exception {
            return new Tuple2<>(triple.getFirst(), new Tuple2<>(triple.getSecond(), triple.getThird()));
        }
    });
    // Resolve both words of every pair to their VocabWord through the broadcast vocabulary
    JavaPairRDD<VocabWord, Tuple2<VocabWord, Double>> pairsVocab = pairs.mapToPair(new PairFunction<Tuple2<String, Tuple2<String, Double>>, VocabWord, Tuple2<VocabWord, Double>>() {
        @Override
        public Tuple2<VocabWord, Tuple2<VocabWord, Double>> call(Tuple2<String, Tuple2<String, Double>> entry) throws Exception {
            VocabWord w1 = vocabCacheBroadcast.getValue().wordFor(entry._1());
            VocabWord w2 = vocabCacheBroadcast.getValue().wordFor(entry._2()._1());
            return new Tuple2<>(w1, new Tuple2<>(w2, entry._2()._2()));
        }
    });

    for (int i = 0; i < iterations; i++) {
        JavaRDD<GloveChange> change = pairsVocab.map(new Function<Tuple2<VocabWord, Tuple2<VocabWord, Double>>, GloveChange>() {
            @Override
            public GloveChange call(Tuple2<VocabWord, Tuple2<VocabWord, Double>> entry) throws Exception {
                VocabWord w1 = entry._1();
                VocabWord w2 = entry._2()._1();
                INDArray w1Vector = gloveWeightLookupTable.getSyn0().slice(w1.getIndex());
                INDArray w2Vector = gloveWeightLookupTable.getSyn0().slice(w2.getIndex());
                INDArray bias = gloveWeightLookupTable.getBias();
                double score = entry._2()._2();
                double xMax = gloveWeightLookupTable.getxMax();
                double maxCount = gloveWeightLookupTable.getMaxCount();
                //w1 * w2 + bias
                double prediction = Nd4j.getBlasWrapper().dot(w1Vector, w2Vector);
                prediction += bias.getDouble(w1.getIndex()) + bias.getDouble(w2.getIndex());
                double weight = FastMath.pow(Math.min(1.0, (score / maxCount)), xMax);
                // NOTE(review): when score > xMax the raw prediction is used and log(score)
                // is not subtracted; confirm this branch is intended.
                double fDiff = score > xMax ? prediction : weight * (prediction - Math.log(score));
                if (Double.isNaN(fDiff))
                    fDiff = Nd4j.EPS_THRESHOLD;
                //amount of change
                double gradient = fDiff;
                Pair<INDArray, Double> w1Update = update(
                        gloveWeightLookupTable.getWeightAdaGrad(), gloveWeightLookupTable.getBiasAdaGrad(),
                        gloveWeightLookupTable.getSyn0(), gloveWeightLookupTable.getBias(),
                        w1, w1Vector, w2Vector, gradient);
                Pair<INDArray, Double> w2Update = update(
                        gloveWeightLookupTable.getWeightAdaGrad(), gloveWeightLookupTable.getBiasAdaGrad(),
                        gloveWeightLookupTable.getSyn0(), gloveWeightLookupTable.getBias(),
                        w2, w2Vector, w1Vector, gradient);
                // NOTE(review): the weight history is passed as (w1, w2) but the bias history
                // as (w2, w1); confirm against the GloveChange constructor.
                return new GloveChange(
                        w1, w2,
                        w1Update.getFirst(), w2Update.getFirst(),
                        w1Update.getSecond(), w2Update.getSecond(),
                        fDiff,
                        gloveWeightLookupTable.getWeightAdaGrad().getHistoricalGradient().slice(w1.getIndex()),
                        gloveWeightLookupTable.getWeightAdaGrad().getHistoricalGradient().slice(w2.getIndex()),
                        gloveWeightLookupTable.getBiasAdaGrad().getHistoricalGradient().getDouble(w2.getIndex()),
                        gloveWeightLookupTable.getBiasAdaGrad().getHistoricalGradient().getDouble(w1.getIndex()));
            }
        });

        // Apply every change on the driver and accumulate the error of this pass
        List<GloveChange> gloveChanges = change.collect();
        double error = 0.0;
        for (GloveChange gloveChange : gloveChanges) {
            gloveChange.apply(gloveWeightLookupTable);
            error += gloveChange.getError();
        }

        // Reshuffle the pairs for the next pass. The collected list is copied first so the
        // shuffle does not depend on collect() handing back a modifiable list.
        List<Tuple2<VocabWord, Tuple2<VocabWord, Double>>> shuffled = new ArrayList<>(pairsVocab.collect());
        Collections.shuffle(shuffled);
        pairsVocab = sc.parallelizePairs(shuffled);
        log.info("Error at iteration " + i + " was " + error);
    }
    return new Pair<>(vocabAndNumWords.getFirst(), gloveWeightLookupTable);
}
Use of org.deeplearning4j.spark.text.functions.TextPipeline in project deeplearning4j by deeplearning4j.
Class Word2Vec, method train.
/**
 * Trains a word2vec model on the given text corpus.
 *
 * <p>The corpus is tokenized and turned into a vocabulary through a {@link TextPipeline},
 * the per-partition syn0 updates are computed through Spark and collected on the driver,
 * and the updates received for each word are averaged into a new syn0 matrix. The
 * resulting vocabulary and lookup table are stored in this instance's {@code vocab} and
 * {@code lookupTable} fields.
 *
 * @param corpusRDD training corpus
 * @throws Exception if building the vocabulary or any training step fails
 */
public void train(JavaRDD<String> corpusRDD) throws Exception {
    log.info("Start training ...");

    // repartition() returns a new RDD and leaves the receiver untouched, so the returned
    // RDD is the one that has to be used from here on.
    final JavaRDD<String> corpus = workers > 0 ? corpusRDD.repartition(workers) : corpusRDD;

    // SparkContext
    final JavaSparkContext sc = new JavaSparkContext(corpus.context());

    // Pre-defined variables
    Map<String, Object> tokenizerVarMap = getTokenizerVarMap();
    Map<String, Object> word2vecVarMap = getWord2vecVarMap();

    // Largest number of updates received for a single word
    int maxRep = 1;

    // Start Training //
    //////////////////////////////////////
    log.info("Tokenization and building VocabCache ...");
    // Processing every sentence and make a VocabCache which gets fed into a LookupCache
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(tokenizerVarMap);
    TextPipeline pipeline = new TextPipeline(corpus, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();

    // Get total word count and put into word2vec variable map
    word2vecVarMap.put("totalWordCount", pipeline.getTotalWordCount());

    // 2 RDDs: (vocab words list) and (sentence Count).Already cached
    final JavaRDD<AtomicLong> sentenceWordsCountRDD = pipeline.getSentenceCountRDD();
    final JavaRDD<List<VocabWord>> vocabWordListRDD = pipeline.getVocabWordListRDD();

    // Get vocabCache and broad-casted vocabCache
    Broadcast<VocabCache<VocabWord>> vocabCacheBroadcast = pipeline.getBroadCastVocabCache();
    final VocabCache<VocabWord> vocabCache = vocabCacheBroadcast.getValue();
    log.info("Vocab size: {}", vocabCache.numWords());

    //////////////////////////////////////
    log.info("Building Huffman Tree ...");
    // No tree is built here: per the original note it was already built earlier,
    // at the TextPipeline.buildVocabCache() call.

    //////////////////////////////////////
    log.info("Calculating cumulative sum of sentence counts ...");
    final JavaRDD<Long> sentenceCumSumCountRDD = new CountCumSum(sentenceWordsCountRDD).buildCumSum();

    //////////////////////////////////////
    log.info("Mapping to RDD(vocabWordList, cumulative sentence count) ...");
    final JavaPairRDD<List<VocabWord>, Long> vocabWordListSentenceCumSumRDD =
            vocabWordListRDD.zip(sentenceCumSumCountRDD).setName("vocabWordListSentenceCumSumRDD");

    /////////////////////////////////////
    log.info("Broadcasting word2vec variables to workers ...");
    Broadcast<Map<String, Object>> word2vecVarMapBroadcast = sc.broadcast(word2vecVarMap);
    Broadcast<double[]> expTableBroadcast = sc.broadcast(expTable);

    /////////////////////////////////////
    log.info("Training word2vec sentences ...");
    FlatMapFunction firstIterFunc = new FirstIterationFunction(word2vecVarMapBroadcast, expTableBroadcast, vocabCacheBroadcast);
    @SuppressWarnings("unchecked") JavaRDD<Pair<VocabWord, INDArray>> indexSyn0UpdateEntryRDD =
            vocabWordListSentenceCumSumRDD.mapPartitions(firstIterFunc).map(new MapToPairFunction());

    // Get all the syn0 updates into a list in driver
    List<Pair<VocabWord, INDArray>> syn0UpdateEntries = indexSyn0UpdateEntryRDD.collect();

    // Instantiate syn0
    INDArray syn0 = Nd4j.zeros(vocabCache.numWords(), layerSize);

    // Updating syn0 first pass: just add vectors obtained from different nodes
    log.info("Averaging results...");
    Map<VocabWord, AtomicInteger> updates = new HashMap<>();
    for (Pair<VocabWord, INDArray> syn0UpdateEntry : syn0UpdateEntries) {
        VocabWord word = syn0UpdateEntry.getFirst();
        syn0.getRow(word.getIndex()).addi(syn0UpdateEntry.getSecond());

        // for proper averaging we need to divide resulting sums later, by the number of additions
        AtomicInteger additions = updates.get(word);
        if (additions == null) {
            updates.put(word, new AtomicInteger(1));
        } else {
            additions.incrementAndGet();
        }
    }

    // Updating syn0 second pass: average obtained vectors
    for (Map.Entry<VocabWord, AtomicInteger> entry : updates.entrySet()) {
        int additions = entry.getValue().get();
        if (additions > 1) {
            if (additions > maxRep)
                maxRep = additions;
            syn0.getRow(entry.getKey().getIndex()).divi(additions);
        }
    }

    // NOTE(review): totals is never updated, so the reported available memory is always 0.
    long totals = 0;

    log.info("Finished calculations...");

    vocab = vocabCache;
    InMemoryLookupTable<VocabWord> inMemoryLookupTable = new InMemoryLookupTable<VocabWord>();
    Environment env = EnvironmentUtils.buildEnvironment();
    env.setNumCores(maxRep);
    env.setAvailableMemory(totals);
    update(env, Event.SPARK);
    inMemoryLookupTable.setVocab(vocabCache);
    inMemoryLookupTable.setVectorLength(layerSize);
    inMemoryLookupTable.setSyn0(syn0);
    lookupTable = inMemoryLookupTable;
    modelUtils.init(lookupTable);
}
Use of org.deeplearning4j.spark.text.functions.TextPipeline in project deeplearning4j by deeplearning4j.
Class TextPipelineTest, method testBuildVocabWordListRDD.
/**
 * Builds the vocab cache and the vocab-word list RDD for the test corpus, then checks the
 * total word count, the word counts of the first two sentences and the tokens found in
 * their vocab-word lists.
 *
 * @throws Exception if the pipeline fails
 */
@Test
public void testBuildVocabWordListRDD() throws Exception {
    JavaSparkContext sc = getContext();
    try {
        JavaRDD<String> corpusRDD = getCorpusRDD(sc);
        Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());

        TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
        pipeline.buildVocabCache();
        pipeline.buildVocabWordListRDD();

        // Collect each RDD once instead of once per assertion
        List<AtomicLong> sentenceCounts = pipeline.getSentenceCountRDD().collect();
        List<List<VocabWord>> vocabWordList = pipeline.getVocabWordListRDD().collect();
        List<VocabWord> firstSentenceVocabList = vocabWordList.get(0);
        List<VocabWord> secondSentenceVocabList = vocabWordList.get(1);

        System.out.println(Arrays.deepToString(firstSentenceVocabList.toArray()));

        // Keep only the words of the non-null entries
        List<String> firstSentenceTokenList = new ArrayList<>();
        for (VocabWord v : firstSentenceVocabList) {
            if (v != null) {
                firstSentenceTokenList.add(v.getWord());
            }
        }
        List<String> secondSentenceTokenList = new ArrayList<>();
        for (VocabWord v : secondSentenceVocabList) {
            if (v != null) {
                secondSentenceTokenList.add(v.getWord());
            }
        }

        // Expected value first, so a failure message reports expected/actual the right way round
        assertEquals(9, pipeline.getTotalWordCount(), 0);
        assertEquals(6, sentenceCounts.get(0).get());
        assertEquals(3, sentenceCounts.get(1).get());
        assertTrue(firstSentenceTokenList.containsAll(Arrays.asList("strange", "strange", "world")));
        assertTrue(secondSentenceTokenList.containsAll(Arrays.asList("flowers", "red")));
    } finally {
        // Stop the context even when an assertion above fails
        sc.stop();
    }
}
Use of org.deeplearning4j.spark.text.functions.TextPipeline in project deeplearning4j by deeplearning4j.
Class TextPipelineTest, method testZipFunction1.
/**
 * Zips the vocab-word list RDD with the cumulative sentence counts, using the tokenizer
 * settings of {@code word2vec}, and checks the first two resulting pairs: the expected
 * word lists hold 3 and 2 words while the expected cumulative counts are 6 and 9.
 *
 * @throws Exception if the pipeline fails
 */
@Test
public void testZipFunction1() throws Exception {
    JavaSparkContext sc = getContext();
    try {
        JavaRDD<String> corpusRDD = getCorpusRDD(sc);
        Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());

        TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
        pipeline.buildVocabCache();
        pipeline.buildVocabWordListRDD();

        JavaRDD<AtomicLong> sentenceCountRDD = pipeline.getSentenceCountRDD();
        JavaRDD<List<VocabWord>> vocabWordListRDD = pipeline.getVocabWordListRDD();

        JavaRDD<Long> sentenceCountCumSumRDD = new CountCumSum(sentenceCountRDD).buildCumSum();
        JavaPairRDD<List<VocabWord>, Long> vocabWordListSentenceCumSumRDD =
                vocabWordListRDD.zip(sentenceCountCumSumRDD);
        List<Tuple2<List<VocabWord>, Long>> lst = vocabWordListSentenceCumSumRDD.collect();

        // Expected value first, so a failure message reports expected/actual the right way round
        List<VocabWord> vocabWordsList1 = lst.get(0)._1();
        Long cumSumSize1 = lst.get(0)._2();
        assertEquals(3, vocabWordsList1.size());
        assertEquals("strange", vocabWordsList1.get(0).getWord());
        assertEquals("strange", vocabWordsList1.get(1).getWord());
        assertEquals("world", vocabWordsList1.get(2).getWord());
        assertEquals(6L, cumSumSize1.longValue());

        List<VocabWord> vocabWordsList2 = lst.get(1)._1();
        Long cumSumSize2 = lst.get(1)._2();
        assertEquals(2, vocabWordsList2.size());
        assertEquals("flowers", vocabWordsList2.get(0).getWord());
        assertEquals("red", vocabWordsList2.get(1).getWord());
        assertEquals(9L, cumSumSize2.longValue());
    } finally {
        // Stop the context even when an assertion above fails
        sc.stop();
    }
}
Use of org.deeplearning4j.spark.text.functions.TextPipeline in project deeplearning4j by deeplearning4j.
Class TextPipelineTest, method testZipFunction2.
/**
 * Zips the vocab-word list RDD with the cumulative sentence counts, using the tokenizer
 * settings of {@code word2vecNoStop}, and checks the first two resulting pairs: the
 * expected word lists hold 6 and 3 words and the expected cumulative counts are 6 and 9.
 *
 * @throws Exception if the pipeline fails
 */
@Test
public void testZipFunction2() throws Exception {
    JavaSparkContext sc = getContext();
    try {
        JavaRDD<String> corpusRDD = getCorpusRDD(sc);
        Broadcast<Map<String, Object>> broadcastTokenizerVarMap =
                sc.broadcast(word2vecNoStop.getTokenizerVarMap());

        TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
        pipeline.buildVocabCache();
        pipeline.buildVocabWordListRDD();

        JavaRDD<AtomicLong> sentenceCountRDD = pipeline.getSentenceCountRDD();
        JavaRDD<List<VocabWord>> vocabWordListRDD = pipeline.getVocabWordListRDD();

        JavaRDD<Long> sentenceCountCumSumRDD = new CountCumSum(sentenceCountRDD).buildCumSum();
        JavaPairRDD<List<VocabWord>, Long> vocabWordListSentenceCumSumRDD =
                vocabWordListRDD.zip(sentenceCountCumSumRDD);
        List<Tuple2<List<VocabWord>, Long>> lst = vocabWordListSentenceCumSumRDD.collect();

        // Expected value first, so a failure message reports expected/actual the right way round
        List<VocabWord> vocabWordsList1 = lst.get(0)._1();
        Long cumSumSize1 = lst.get(0)._2();
        assertEquals(6, vocabWordsList1.size());
        assertEquals("this", vocabWordsList1.get(0).getWord());
        assertEquals("is", vocabWordsList1.get(1).getWord());
        assertEquals("a", vocabWordsList1.get(2).getWord());
        assertEquals("strange", vocabWordsList1.get(3).getWord());
        assertEquals("strange", vocabWordsList1.get(4).getWord());
        assertEquals("world", vocabWordsList1.get(5).getWord());
        assertEquals(6L, cumSumSize1.longValue());

        List<VocabWord> vocabWordsList2 = lst.get(1)._1();
        Long cumSumSize2 = lst.get(1)._2();
        assertEquals(3, vocabWordsList2.size());
        assertEquals("flowers", vocabWordsList2.get(0).getWord());
        assertEquals("are", vocabWordsList2.get(1).getWord());
        assertEquals("red", vocabWordsList2.get(2).getWord());
        assertEquals(9L, cumSumSize2.longValue());
    } finally {
        // Stop the context even when an assertion above fails
        sc.stop();
    }
}
Aggregations