Use of org.apache.spark.api.java.JavaSparkContext in project deeplearning4j by deeplearning4j.
In class Glove, method train:
/**
* Train on the corpus
* @param rdd the RDD of raw text to train on
* @return the vocab cache and the GloVe weight lookup table
*/
public Pair<VocabCache<VocabWord>, GloveWeightLookupTable> train(JavaRDD<String> rdd) throws Exception {
// Each `train()` can use different parameters
final JavaSparkContext sc = new JavaSparkContext(rdd.context());
final SparkConf conf = sc.getConf();
final int vectorLength = assignVar(VECTOR_LENGTH, conf, Integer.class);
final boolean useAdaGrad = assignVar(ADAGRAD, conf, Boolean.class);
final double negative = assignVar(NEGATIVE, conf, Double.class);
final int numWords = assignVar(NUM_WORDS, conf, Integer.class);
final int window = assignVar(WINDOW, conf, Integer.class);
final double alpha = assignVar(ALPHA, conf, Double.class);
final double minAlpha = assignVar(MIN_ALPHA, conf, Double.class);
final int iterations = assignVar(ITERATIONS, conf, Integer.class);
final int nGrams = assignVar(N_GRAMS, conf, Integer.class);
final String tokenizer = assignVar(TOKENIZER, conf, String.class);
final String tokenPreprocessor = assignVar(TOKEN_PREPROCESSOR, conf, String.class);
final boolean removeStop = assignVar(REMOVE_STOPWORDS, conf, Boolean.class);
Map<String, Object> tokenizerVarMap = new HashMap<String, Object>() {
{
put("numWords", numWords);
put("nGrams", nGrams);
put("tokenizer", tokenizer);
put("tokenPreprocessor", tokenPreprocessor);
put("removeStop", removeStop);
}
};
Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(tokenizerVarMap);
TextPipeline pipeline = new TextPipeline(rdd, broadcastTokenizerVarMap);
pipeline.buildVocabCache();
pipeline.buildVocabWordListRDD();
// Get total word count
Long totalWordCount = pipeline.getTotalWordCount();
VocabCache<VocabWord> vocabCache = pipeline.getVocabCache();
JavaRDD<Pair<List<String>, AtomicLong>> sentenceWordsCountRDD = pipeline.getSentenceWordsCountRDD();
final Pair<VocabCache<VocabWord>, Long> vocabAndNumWords = new Pair<>(vocabCache, totalWordCount);
vocabCacheBroadcast = sc.broadcast(vocabAndNumWords.getFirst());
final GloveWeightLookupTable gloveWeightLookupTable = new GloveWeightLookupTable.Builder()
.cache(vocabAndNumWords.getFirst())
.lr(conf.getDouble(GlovePerformer.ALPHA, 0.01))
.maxCount(conf.getDouble(GlovePerformer.MAX_COUNT, 100))
.vectorLength(conf.getInt(GlovePerformer.VECTOR_LENGTH, 300))
.xMax(conf.getDouble(GlovePerformer.X_MAX, 0.75))
.build();
gloveWeightLookupTable.resetWeights();
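// seed the AdaGrad histories with ones so the very first updates divide by a unit accumulated gradient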
gloveWeightLookupTable.getBiasAdaGrad().historicalGradient = Nd4j.ones(gloveWeightLookupTable.getSyn0().rows());
gloveWeightLookupTable.getWeightAdaGrad().historicalGradient = Nd4j.ones(gloveWeightLookupTable.getSyn0().shape());
log.info("Created lookup table of size " + Arrays.toString(gloveWeightLookupTable.getSyn0().shape()));
CounterMap<String, String> coOccurrenceCounts = sentenceWordsCountRDD.map(new CoOccurrenceCalculator(symmetric, vocabCacheBroadcast, windowSize)).fold(new CounterMap<String, String>(), new CoOccurrenceCounts());
Iterator<Pair<String, String>> pair2 = coOccurrenceCounts.getPairIterator();
List<Triple<String, String, Double>> counts = new ArrayList<>();
while (pair2.hasNext()) {
Pair<String, String> next = pair2.next();
if (coOccurrenceCounts.getCount(next.getFirst(), next.getSecond()) > gloveWeightLookupTable.getMaxCount()) {
coOccurrenceCounts.setCount(next.getFirst(), next.getSecond(), gloveWeightLookupTable.getMaxCount());
}
counts.add(new Triple<>(next.getFirst(), next.getSecond(), coOccurrenceCounts.getCount(next.getFirst(), next.getSecond())));
}
log.info("Calculated co occurrences");
JavaRDD<Triple<String, String, Double>> parallel = sc.parallelize(counts);
JavaPairRDD<String, Tuple2<String, Double>> pairs = parallel.mapToPair(new PairFunction<Triple<String, String, Double>, String, Tuple2<String, Double>>() {
@Override
public Tuple2<String, Tuple2<String, Double>> call(Triple<String, String, Double> stringStringDoubleTriple) throws Exception {
return new Tuple2<>(stringStringDoubleTriple.getFirst(), new Tuple2<>(stringStringDoubleTriple.getSecond(), stringStringDoubleTriple.getThird()));
}
});
JavaPairRDD<VocabWord, Tuple2<VocabWord, Double>> pairsVocab = pairs.mapToPair(new PairFunction<Tuple2<String, Tuple2<String, Double>>, VocabWord, Tuple2<VocabWord, Double>>() {
@Override
public Tuple2<VocabWord, Tuple2<VocabWord, Double>> call(Tuple2<String, Tuple2<String, Double>> stringTuple2Tuple2) throws Exception {
VocabWord w1 = vocabCacheBroadcast.getValue().wordFor(stringTuple2Tuple2._1());
VocabWord w2 = vocabCacheBroadcast.getValue().wordFor(stringTuple2Tuple2._2()._1());
return new Tuple2<>(w1, new Tuple2<>(w2, stringTuple2Tuple2._2()._2()));
}
});
for (int i = 0; i < iterations; i++) {
JavaRDD<GloveChange> change = pairsVocab.map(new Function<Tuple2<VocabWord, Tuple2<VocabWord, Double>>, GloveChange>() {
@Override
public GloveChange call(Tuple2<VocabWord, Tuple2<VocabWord, Double>> vocabWordTuple2Tuple2) throws Exception {
VocabWord w1 = vocabWordTuple2Tuple2._1();
VocabWord w2 = vocabWordTuple2Tuple2._2()._1();
INDArray w1Vector = gloveWeightLookupTable.getSyn0().slice(w1.getIndex());
INDArray w2Vector = gloveWeightLookupTable.getSyn0().slice(w2.getIndex());
INDArray bias = gloveWeightLookupTable.getBias();
double score = vocabWordTuple2Tuple2._2()._2();
double xMax = gloveWeightLookupTable.getxMax();
double maxCount = gloveWeightLookupTable.getMaxCount();
//w1 * w2 + bias
double prediction = Nd4j.getBlasWrapper().dot(w1Vector, w2Vector);
prediction += bias.getDouble(w1.getIndex()) + bias.getDouble(w2.getIndex());
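// GloVe weighting term: clamp score/maxCount to 1, raise it to xMax, then weight the gap between the prediction and log(score)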
double weight = FastMath.pow(Math.min(1.0, (score / maxCount)), xMax);
double fDiff = score > xMax ? prediction : weight * (prediction - Math.log(score));
if (Double.isNaN(fDiff))
fDiff = Nd4j.EPS_THRESHOLD;
//amount of change
double gradient = fDiff;
Pair<INDArray, Double> w1Update = update(gloveWeightLookupTable.getWeightAdaGrad(), gloveWeightLookupTable.getBiasAdaGrad(), gloveWeightLookupTable.getSyn0(), gloveWeightLookupTable.getBias(), w1, w1Vector, w2Vector, gradient);
Pair<INDArray, Double> w2Update = update(gloveWeightLookupTable.getWeightAdaGrad(), gloveWeightLookupTable.getBiasAdaGrad(), gloveWeightLookupTable.getSyn0(), gloveWeightLookupTable.getBias(), w2, w2Vector, w1Vector, gradient);
return new GloveChange(w1, w2, w1Update.getFirst(), w2Update.getFirst(), w1Update.getSecond(), w2Update.getSecond(), fDiff, gloveWeightLookupTable.getWeightAdaGrad().getHistoricalGradient().slice(w1.getIndex()), gloveWeightLookupTable.getWeightAdaGrad().getHistoricalGradient().slice(w2.getIndex()), gloveWeightLookupTable.getBiasAdaGrad().getHistoricalGradient().getDouble(w2.getIndex()), gloveWeightLookupTable.getBiasAdaGrad().getHistoricalGradient().getDouble(w1.getIndex()));
}
});
List<GloveChange> gloveChanges = change.collect();
double error = 0.0;
for (GloveChange change2 : gloveChanges) {
change2.apply(gloveWeightLookupTable);
error += change2.getError();
}
List<Tuple2<VocabWord, Tuple2<VocabWord, Double>>> l = pairsVocab.collect();
Collections.shuffle(l);
pairsVocab = sc.parallelizePairs(l);
log.info("Error at iteration " + i + " was " + error);
}
return new Pair<>(vocabAndNumWords.getFirst(), gloveWeightLookupTable);
}
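A minimal usage sketch for the method above, assuming an already-constructed Glove instance (construction omitted) and a SparkConf that already carries the GloVe and tokenizer settings read via assignVar; the corpus path and the variables glove, conf and log are illustrative, not taken from the source.
// hedged sketch: glove, conf and log are assumed to exist in the surrounding scope
JavaSparkContext sc = new JavaSparkContext(conf); // conf carries VECTOR_LENGTH, WINDOW, ITERATIONS, ...
JavaRDD<String> corpus = sc.textFile("hdfs:///path/to/corpus.txt"); // hypothetical path
Pair<VocabCache<VocabWord>, GloveWeightLookupTable> model = glove.train(corpus);
VocabCache<VocabWord> vocab = model.getFirst();
GloveWeightLookupTable table = model.getSecond();
log.info("Learned {} word vectors of length {}", vocab.numWords(), table.getSyn0().columns());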
Use of org.apache.spark.api.java.JavaSparkContext in project deeplearning4j by deeplearning4j.
In class SparkParagraphVectors, method fitLabelledDocuments:
/**
* This method builds a ParagraphVectors model from a JavaRDD<LabelledDocument>.
* The documents may be either tokenized or non-tokenized.
*
* @param documentsRdd RDD of labelled documents to train on
*/
public void fitLabelledDocuments(JavaRDD<LabelledDocument> documentsRdd) {
validateConfiguration();
broadcastEnvironment(new JavaSparkContext(documentsRdd.context()));
JavaRDD<Sequence<VocabWord>> sequenceRDD = documentsRdd.map(new DocumentSequenceConvertFunction(configurationBroadcast));
super.fitSequences(sequenceRDD);
}
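A hedged usage sketch for fitLabelledDocuments, assuming a configured SparkParagraphVectors instance named paragraphVectors and a JavaSparkContext sc; the LabelledDocument accessors used below (setContent, addLabel) are an assumption and may be named differently in other versions.
// build a tiny RDD of labelled documents; the pipeline tokenizes the raw content on the workers
List<LabelledDocument> documents = new ArrayList<>();
LabelledDocument document = new LabelledDocument();
document.setContent("raw text of the first document"); // assumed accessor
document.addLabel("DOC_0"); // assumed accessor; some versions expose setLabel instead
documents.add(document);
JavaRDD<LabelledDocument> documentsRdd = sc.parallelize(documents);
paragraphVectors.fitLabelledDocuments(documentsRdd);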
Use of org.apache.spark.api.java.JavaSparkContext in project deeplearning4j by deeplearning4j.
In class SparkSequenceVectors, method fitSequences:
/**
* Base training entry point
*
* @param corpus RDD of sequences to train on
*/
public void fitSequences(JavaRDD<Sequence<T>> corpus) {
/**
* Basically all we want for the base implementation here is three things:
* a) build the vocabulary
* b) build the Huffman tree
* c) do the training
*
* All classes extending SeqVec (like DeepWalk or Word2Vec) just build their RDD<Sequence<T>>
* and call this method for training, instead of implementing their own routines.
*/
validateConfiguration();
if (ela == null) {
try {
ela = (SparkElementsLearningAlgorithm) Class.forName(configuration.getElementsLearningAlgorithm()).newInstance();
} catch (Exception e) {
throw new RuntimeException(e);
}
}
if (workers > 1) {
log.info("Repartitioning corpus to {} parts...", workers);
corpus = corpus.repartition(workers); // repartition returns a new RDD, so reassign for it to take effect
}
if (storageLevel != null)
corpus.persist(storageLevel);
final JavaSparkContext sc = new JavaSparkContext(corpus.context());
// this will only have an effect if it wasn't already called in an extension class
broadcastEnvironment(sc);
Counter<Long> finalCounter;
long numberOfSequences = 0;
/**
* Here we set up the parameter server configuration (falling back to a default one) and decide whether auto-discovery mode is used
*/
if (paramServerConfiguration == null)
paramServerConfiguration = VoidConfiguration.builder().faultToleranceStrategy(FaultToleranceStrategy.NONE).numberOfShards(2).unicastPort(40123).multicastPort(40124).build();
isAutoDiscoveryMode = paramServerConfiguration.getShardAddresses() == null || paramServerConfiguration.getShardAddresses().isEmpty();
Broadcast<VoidConfiguration> paramServerConfigurationBroadcast = null;
if (isAutoDiscoveryMode) {
log.info("Trying auto discovery mode...");
elementsFreqAccumExtra = corpus.context().accumulator(new ExtraCounter<Long>(), new ExtraElementsFrequenciesAccumulator());
ExtraCountFunction<T> elementsCounter = new ExtraCountFunction<>(elementsFreqAccumExtra, configuration.isTrainSequenceVectors());
JavaRDD<Pair<Sequence<T>, Long>> countedCorpus = corpus.map(elementsCounter);
// just to trigger map function, since we need huffman tree before proceeding
numberOfSequences = countedCorpus.count();
finalCounter = elementsFreqAccumExtra.value();
ExtraCounter<Long> spareReference = (ExtraCounter<Long>) finalCounter;
// getting list of available hosts
Set<NetworkInformation> availableHosts = spareReference.getNetworkInformation();
log.info("availableHosts: {}", availableHosts);
if (availableHosts.size() > 1) {
// now we have to pick N shards and optionally N backup nodes, and pass them within configuration bean
NetworkOrganizer organizer = new NetworkOrganizer(availableHosts, paramServerConfiguration.getNetworkMask());
paramServerConfiguration.setShardAddresses(organizer.getSubset(paramServerConfiguration.getNumberOfShards()));
// backup shards are optional
if (paramServerConfiguration.getFaultToleranceStrategy() != FaultToleranceStrategy.NONE) {
paramServerConfiguration.setBackupAddresses(organizer.getSubset(paramServerConfiguration.getNumberOfShards(), paramServerConfiguration.getShardAddresses()));
}
} else {
// for single host (aka driver-only, aka spark-local) just run on loopback interface
paramServerConfiguration.setShardAddresses(Arrays.asList("127.0.0.1:" + paramServerConfiguration.getUnicastPort()));
paramServerConfiguration.setFaultToleranceStrategy(FaultToleranceStrategy.NONE);
}
log.info("Got Shards so far: {}", paramServerConfiguration.getShardAddresses());
// update ps configuration with real values where required
paramServerConfiguration.setNumberOfShards(paramServerConfiguration.getShardAddresses().size());
paramServerConfiguration.setUseHS(configuration.isUseHierarchicSoftmax());
paramServerConfiguration.setUseNS(configuration.getNegative() > 0);
paramServerConfigurationBroadcast = sc.broadcast(paramServerConfiguration);
} else {
// update ps configuration with real values where required
paramServerConfiguration.setNumberOfShards(paramServerConfiguration.getShardAddresses().size());
paramServerConfiguration.setUseHS(configuration.isUseHierarchicSoftmax());
paramServerConfiguration.setUseNS(configuration.getNegative() > 0);
paramServerConfigurationBroadcast = sc.broadcast(paramServerConfiguration);
// set up freqs accumulator
elementsFreqAccum = corpus.context().accumulator(new Counter<Long>(), new ElementsFrequenciesAccumulator());
CountFunction<T> elementsCounter = new CountFunction<>(configurationBroadcast, paramServerConfigurationBroadcast, elementsFreqAccum, configuration.isTrainSequenceVectors());
// count all sequence elements and their sum
JavaRDD<Pair<Sequence<T>, Long>> countedCorpus = corpus.map(elementsCounter);
// just to trigger map function, since we need huffman tree before proceeding
numberOfSequences = countedCorpus.count();
// now we grab counter, which contains frequencies for all SequenceElements in corpus
finalCounter = elementsFreqAccum.value();
}
long numberOfElements = (long) finalCounter.totalCount();
long numberOfUniqueElements = finalCounter.size();
log.info("Total number of sequences: {}; Total number of elements entries: {}; Total number of unique elements: {}", numberOfSequences, numberOfElements, numberOfUniqueElements);
/*
build an RDD of reduced SequenceElements: temporarily drop the labels and stick to numerical values
(index or hashcode), so we can reduce the driver memory footprint
*/
// build huffman tree, and update original RDD with huffman encoding info
shallowVocabCache = buildShallowVocabCache(finalCounter);
shallowVocabCacheBroadcast = sc.broadcast(shallowVocabCache);
// FIXME: probably we need to reconsider this approach
JavaRDD<T> vocabRDD = corpus.flatMap(new VocabRddFunctionFlat<T>(configurationBroadcast, paramServerConfigurationBroadcast)).distinct();
vocabRDD.count();
/**
* Now we initialize the Shards with values. That call should be started from the driver, which acts as either a Client or a Shard in standalone mode.
*/
VoidParameterServer.getInstance().init(paramServerConfiguration, new RoutedTransport(), ela.getTrainingDriver());
VoidParameterServer.getInstance().initializeSeqVec(configuration.getLayersSize(), (int) numberOfUniqueElements, 119, configuration.getLayersSize() / paramServerConfiguration.getNumberOfShards(), paramServerConfiguration.isUseHS(), paramServerConfiguration.isUseNS());
// proceed to training
// also, training function is the place where we invoke ParameterServer
TrainingFunction<T> trainer = new TrainingFunction<>(shallowVocabCacheBroadcast, configurationBroadcast, paramServerConfigurationBroadcast);
PartitionTrainingFunction<T> partitionTrainer = new PartitionTrainingFunction<>(shallowVocabCacheBroadcast, configurationBroadcast, paramServerConfigurationBroadcast);
if (configuration != null)
for (int e = 0; e < configuration.getEpochs(); e++) corpus.foreachPartition(partitionTrainer);
//corpus.foreach(trainer);
// we're transferring vectors to ExportContainer
JavaRDD<ExportContainer<T>> exportRdd = vocabRDD.map(new DistributedFunction<T>(paramServerConfigurationBroadcast, configurationBroadcast, shallowVocabCacheBroadcast));
// at this particular moment training should be pretty much done, and we're good to go for export
if (exporter != null)
exporter.export(exportRdd);
// unpersist, if we persisted the corpus earlier
if (storageLevel != null)
corpus.unpersist();
log.info("Training finish, starting cleanup...");
VoidParameterServer.getInstance().shutdown();
}
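A hedged sketch of how an extension class might feed this entry point, assuming a JavaSparkContext sc; Sequence.addElement and the VocabWord(frequency, word) constructor are assumed to match the versions used in this codebase.
// assemble a toy corpus of sequences on the driver, then parallelize it
Sequence<VocabWord> sequence = new Sequence<>();
sequence.addElement(new VocabWord(1.0, "first"));
sequence.addElement(new VocabWord(1.0, "second"));
sequence.addElement(new VocabWord(1.0, "third"));
JavaRDD<Sequence<VocabWord>> toyCorpus = sc.parallelize(Collections.singletonList(sequence));
// a subclass (e.g. SparkWord2Vec) would typically call this from its own fit() method
fitSequences(toyCorpus);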
Use of org.apache.spark.api.java.JavaSparkContext in project deeplearning4j by deeplearning4j.
In class Word2Vec, method train:
/**
* Trains a word2vec model on the given text corpus
*
* @param corpusRDD training corpus
* @throws Exception
*/
public void train(JavaRDD<String> corpusRDD) throws Exception {
log.info("Start training ...");
if (workers > 0)
corpusRDD = corpusRDD.repartition(workers); // repartition returns a new RDD, so reassign for it to take effect
// SparkContext
final JavaSparkContext sc = new JavaSparkContext(corpusRDD.context());
// Pre-defined variables
Map<String, Object> tokenizerVarMap = getTokenizerVarMap();
Map<String, Object> word2vecVarMap = getWord2vecVarMap();
// Variables to fill in train
final JavaRDD<AtomicLong> sentenceWordsCountRDD;
final JavaRDD<List<VocabWord>> vocabWordListRDD;
final JavaPairRDD<List<VocabWord>, Long> vocabWordListSentenceCumSumRDD;
final VocabCache<VocabWord> vocabCache;
final JavaRDD<Long> sentenceCumSumCountRDD;
int maxRep = 1;
// Start Training //
//////////////////////////////////////
log.info("Tokenization and building VocabCache ...");
// Process every sentence and build a VocabCache, which gets fed into a LookupCache
Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(tokenizerVarMap);
TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
pipeline.buildVocabCache();
pipeline.buildVocabWordListRDD();
// Get total word count and put into word2vec variable map
word2vecVarMap.put("totalWordCount", pipeline.getTotalWordCount());
// Two RDDs: (vocab word list) and (sentence count). Already cached
sentenceWordsCountRDD = pipeline.getSentenceCountRDD();
vocabWordListRDD = pipeline.getVocabWordListRDD();
// Get the vocabCache and the broadcast vocabCache
Broadcast<VocabCache<VocabWord>> vocabCacheBroadcast = pipeline.getBroadCastVocabCache();
vocabCache = vocabCacheBroadcast.getValue();
log.info("Vocab size: {}", vocabCache.numWords());
//////////////////////////////////////
log.info("Building Huffman Tree ...");
// Building the Huffman tree updates the codes and points of each VocabWord in the vocabCache
/*
We don't need to build the tree here, since it was already built during the TextPipeline.buildVocabCache() call.
Huffman huffman = new Huffman(vocabCache.vocabWords());
huffman.build();
huffman.applyIndexes(vocabCache);
*/
//////////////////////////////////////
log.info("Calculating cumulative sum of sentence counts ...");
sentenceCumSumCountRDD = new CountCumSum(sentenceWordsCountRDD).buildCumSum();
//////////////////////////////////////
log.info("Mapping to RDD(vocabWordList, cumulative sentence count) ...");
vocabWordListSentenceCumSumRDD = vocabWordListRDD.zip(sentenceCumSumCountRDD).setName("vocabWordListSentenceCumSumRDD");
/////////////////////////////////////
log.info("Broadcasting word2vec variables to workers ...");
Broadcast<Map<String, Object>> word2vecVarMapBroadcast = sc.broadcast(word2vecVarMap);
Broadcast<double[]> expTableBroadcast = sc.broadcast(expTable);
/////////////////////////////////////
log.info("Training word2vec sentences ...");
FlatMapFunction firstIterFunc = new FirstIterationFunction(word2vecVarMapBroadcast, expTableBroadcast, vocabCacheBroadcast);
@SuppressWarnings("unchecked") JavaRDD<Pair<VocabWord, INDArray>> indexSyn0UpdateEntryRDD = vocabWordListSentenceCumSumRDD.mapPartitions(firstIterFunc).map(new MapToPairFunction());
// Get all the syn0 updates into a list in driver
List<Pair<VocabWord, INDArray>> syn0UpdateEntries = indexSyn0UpdateEntryRDD.collect();
// Instantiate syn0
INDArray syn0 = Nd4j.zeros(vocabCache.numWords(), layerSize);
// Updating syn0 first pass: just add vectors obtained from different nodes
log.info("Averaging results...");
Map<VocabWord, AtomicInteger> updates = new HashMap<>();
Map<Long, Long> updaters = new HashMap<>();
for (Pair<VocabWord, INDArray> syn0UpdateEntry : syn0UpdateEntries) {
syn0.getRow(syn0UpdateEntry.getFirst().getIndex()).addi(syn0UpdateEntry.getSecond());
// for proper averaging we later divide the resulting sums by the number of additions
if (updates.containsKey(syn0UpdateEntry.getFirst())) {
updates.get(syn0UpdateEntry.getFirst()).incrementAndGet();
} else
updates.put(syn0UpdateEntry.getFirst(), new AtomicInteger(1));
if (!updaters.containsKey(syn0UpdateEntry.getFirst().getVocabId())) {
updaters.put(syn0UpdateEntry.getFirst().getVocabId(), syn0UpdateEntry.getFirst().getAffinityId());
}
}
// Updating syn0 second pass: average obtained vectors
for (Map.Entry<VocabWord, AtomicInteger> entry : updates.entrySet()) {
if (entry.getValue().get() > 1) {
if (entry.getValue().get() > maxRep)
maxRep = entry.getValue().get();
syn0.getRow(entry.getKey().getIndex()).divi(entry.getValue().get());
}
}
long totals = 0;
log.info("Finished calculations...");
vocab = vocabCache;
InMemoryLookupTable<VocabWord> inMemoryLookupTable = new InMemoryLookupTable<VocabWord>();
Environment env = EnvironmentUtils.buildEnvironment();
env.setNumCores(maxRep);
env.setAvailableMemory(totals);
update(env, Event.SPARK);
inMemoryLookupTable.setVocab(vocabCache);
inMemoryLookupTable.setVectorLength(layerSize);
inMemoryLookupTable.setSyn0(syn0);
lookupTable = inMemoryLookupTable;
modelUtils.init(lookupTable);
}
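A minimal usage sketch for train, assuming a configured Spark Word2Vec instance named word2Vec (builder options such as tokenizer, layer size and window are set elsewhere and vary by version) and a JavaSparkContext sc.
List<String> sentences = Arrays.asList(
"the quick brown fox jumps over the lazy dog",
"distributed training splits the corpus across spark workers");
JavaRDD<String> corpusRDD = sc.parallelize(sentences);
word2Vec.train(corpusRDD); // construction of word2Vec omitted
// after train() returns, the averaged syn0 is available through the instance's InMemoryLookupTable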
Use of org.apache.spark.api.java.JavaSparkContext in project deeplearning4j by deeplearning4j.
In class TextPipeline, method setup:
private void setup() {
// Set up accumulators and broadcast stopwords
this.sc = new JavaSparkContext(corpusRDD.context());
this.wordFreqAcc = sc.accumulator(new Counter<String>(), new WordFreqAccumulator());
this.stopWordBroadCast = sc.broadcast(stopWords);
}
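For context, setup() relies on the classic (pre-AccumulatorV2) Spark accumulator API plus a broadcast variable. Below is a self-contained sketch of the same pattern using the built-in Integer accumulator to count non-stopword tokens in an arbitrary JavaRDD<String> named corpusRDD; the stop word list is illustrative.
final JavaSparkContext sc = new JavaSparkContext(corpusRDD.context()); // reuse the RDD's underlying SparkContext
final Broadcast<List<String>> stopWordsBroadcast = sc.broadcast(Arrays.asList("a", "the", "of"));
final Accumulator<Integer> tokenCount = sc.accumulator(0); // built-in Integer accumulator
corpusRDD.foreach(new VoidFunction<String>() {
@Override
public void call(String line) throws Exception {
for (String token : line.split("\\s+")) {
if (!stopWordsBroadcast.value().contains(token.toLowerCase()))
tokenCount.add(1);
}
}
});
System.out.println("Counted " + tokenCount.value() + " non-stopword tokens");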