Use of org.apache.spark.SparkConf in project deeplearning4j by deeplearning4j.
The class TestKryoWarning, method testKryoMessageMLNIncorrectConfig.
@Test
@Ignore
public void testKryoMessageMLNIncorrectConfig() {
    // Should print a warning: Kryo serialization is enabled, but no ND4J Kryo registrator is configured
    SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest")
                    .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
    doTestMLN(sparkConf);
}
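For contrast, a correctly configured counterpart would also register an ND4J Kryo registrator, which is what silences the warning. A minimal sketch, assuming the nd4j-kryo dependency and its org.nd4j.Nd4jRegistrator class are on the classpath (not shown in the snippet above):

// Sketch of a "correct" configuration; the registrator class is assumed from the nd4j-kryo module
SparkConf correctConf = new SparkConf().setMaster("local[*]").setAppName("sparktest")
                .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
                .set("spark.kryo.registrator", "org.nd4j.Nd4jRegistrator");
doTestMLN(correctConf);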
Use of org.apache.spark.SparkConf in project deeplearning4j by deeplearning4j.
The class Glove, method train.
/**
 * Train on the corpus
 * @param rdd the RDD of raw sentences to train on
 * @return the vocab cache and the trained weight lookup table
 */
public Pair<VocabCache<VocabWord>, GloveWeightLookupTable> train(JavaRDD<String> rdd) throws Exception {
    // Each train() call can use different parameters, read from the Spark configuration
    final JavaSparkContext sc = new JavaSparkContext(rdd.context());
    final SparkConf conf = sc.getConf();
    final int vectorLength = assignVar(VECTOR_LENGTH, conf, Integer.class);
    final boolean useAdaGrad = assignVar(ADAGRAD, conf, Boolean.class);
    final double negative = assignVar(NEGATIVE, conf, Double.class);
    final int numWords = assignVar(NUM_WORDS, conf, Integer.class);
    final int window = assignVar(WINDOW, conf, Integer.class);
    final double alpha = assignVar(ALPHA, conf, Double.class);
    final double minAlpha = assignVar(MIN_ALPHA, conf, Double.class);
    final int iterations = assignVar(ITERATIONS, conf, Integer.class);
    final int nGrams = assignVar(N_GRAMS, conf, Integer.class);
    final String tokenizer = assignVar(TOKENIZER, conf, String.class);
    final String tokenPreprocessor = assignVar(TOKEN_PREPROCESSOR, conf, String.class);
    final boolean removeStop = assignVar(REMOVE_STOPWORDS, conf, Boolean.class);
    Map<String, Object> tokenizerVarMap = new HashMap<String, Object>() {
        {
            put("numWords", numWords);
            put("nGrams", nGrams);
            put("tokenizer", tokenizer);
            put("tokenPreprocessor", tokenPreprocessor);
            put("removeStop", removeStop);
        }
    };
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(tokenizerVarMap);
    TextPipeline pipeline = new TextPipeline(rdd, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();
    // Get the total word count and the vocabulary built by the pipeline
    Long totalWordCount = pipeline.getTotalWordCount();
    VocabCache<VocabWord> vocabCache = pipeline.getVocabCache();
    JavaRDD<Pair<List<String>, AtomicLong>> sentenceWordsCountRDD = pipeline.getSentenceWordsCountRDD();
    final Pair<VocabCache<VocabWord>, Long> vocabAndNumWords = new Pair<>(vocabCache, totalWordCount);
    vocabCacheBroadcast = sc.broadcast(vocabAndNumWords.getFirst());
    final GloveWeightLookupTable gloveWeightLookupTable = new GloveWeightLookupTable.Builder()
                    .cache(vocabAndNumWords.getFirst())
                    .lr(conf.getDouble(GlovePerformer.ALPHA, 0.01))
                    .maxCount(conf.getDouble(GlovePerformer.MAX_COUNT, 100))
                    .vectorLength(conf.getInt(GlovePerformer.VECTOR_LENGTH, 300))
                    .xMax(conf.getDouble(GlovePerformer.X_MAX, 0.75))
                    .build();
    gloveWeightLookupTable.resetWeights();
    gloveWeightLookupTable.getBiasAdaGrad().historicalGradient = Nd4j.ones(gloveWeightLookupTable.getSyn0().rows());
    gloveWeightLookupTable.getWeightAdaGrad().historicalGradient = Nd4j.ones(gloveWeightLookupTable.getSyn0().shape());
    log.info("Created lookup table of size " + Arrays.toString(gloveWeightLookupTable.getSyn0().shape()));
    CounterMap<String, String> coOccurrenceCounts = sentenceWordsCountRDD
                    .map(new CoOccurrenceCalculator(symmetric, vocabCacheBroadcast, windowSize))
                    .fold(new CounterMap<String, String>(), new CoOccurrenceCounts());
    // Clip co-occurrence counts at maxCount and collect them as (word1, word2, count) triples
    Iterator<Pair<String, String>> pair2 = coOccurrenceCounts.getPairIterator();
    List<Triple<String, String, Double>> counts = new ArrayList<>();
    while (pair2.hasNext()) {
        Pair<String, String> next = pair2.next();
        if (coOccurrenceCounts.getCount(next.getFirst(), next.getSecond()) > gloveWeightLookupTable.getMaxCount()) {
            coOccurrenceCounts.setCount(next.getFirst(), next.getSecond(), gloveWeightLookupTable.getMaxCount());
        }
        counts.add(new Triple<>(next.getFirst(), next.getSecond(),
                        coOccurrenceCounts.getCount(next.getFirst(), next.getSecond())));
    }
    log.info("Calculated co-occurrences");
    JavaRDD<Triple<String, String, Double>> parallel = sc.parallelize(counts);
    // Key each co-occurrence triple by its first word...
    JavaPairRDD<String, Tuple2<String, Double>> pairs = parallel.mapToPair(
                    new PairFunction<Triple<String, String, Double>, String, Tuple2<String, Double>>() {

                        @Override
                        public Tuple2<String, Tuple2<String, Double>> call(Triple<String, String, Double> stringStringDoubleTriple)
                                        throws Exception {
                            return new Tuple2<>(stringStringDoubleTriple.getFirst(),
                                            new Tuple2<>(stringStringDoubleTriple.getSecond(), stringStringDoubleTriple.getThird()));
                        }
                    });
    // ...then map both words to their VocabWord entries from the broadcast vocab cache
    JavaPairRDD<VocabWord, Tuple2<VocabWord, Double>> pairsVocab = pairs.mapToPair(
                    new PairFunction<Tuple2<String, Tuple2<String, Double>>, VocabWord, Tuple2<VocabWord, Double>>() {

                        @Override
                        public Tuple2<VocabWord, Tuple2<VocabWord, Double>> call(Tuple2<String, Tuple2<String, Double>> stringTuple2Tuple2)
                                        throws Exception {
                            VocabWord w1 = vocabCacheBroadcast.getValue().wordFor(stringTuple2Tuple2._1());
                            VocabWord w2 = vocabCacheBroadcast.getValue().wordFor(stringTuple2Tuple2._2()._1());
                            return new Tuple2<>(w1, new Tuple2<>(w2, stringTuple2Tuple2._2()._2()));
                        }
                    });
    for (int i = 0; i < iterations; i++) {
        // Compute per-pair weight and bias updates on the executors...
        JavaRDD<GloveChange> change = pairsVocab.map(
                        new Function<Tuple2<VocabWord, Tuple2<VocabWord, Double>>, GloveChange>() {

                            @Override
                            public GloveChange call(Tuple2<VocabWord, Tuple2<VocabWord, Double>> vocabWordTuple2Tuple2)
                                            throws Exception {
                                VocabWord w1 = vocabWordTuple2Tuple2._1();
                                VocabWord w2 = vocabWordTuple2Tuple2._2()._1();
                                INDArray w1Vector = gloveWeightLookupTable.getSyn0().slice(w1.getIndex());
                                INDArray w2Vector = gloveWeightLookupTable.getSyn0().slice(w2.getIndex());
                                INDArray bias = gloveWeightLookupTable.getBias();
                                double score = vocabWordTuple2Tuple2._2()._2();
                                double xMax = gloveWeightLookupTable.getxMax();
                                double maxCount = gloveWeightLookupTable.getMaxCount();
                                // prediction = w1 . w2 + b1 + b2
                                double prediction = Nd4j.getBlasWrapper().dot(w1Vector, w2Vector);
                                prediction += bias.getDouble(w1.getIndex()) + bias.getDouble(w2.getIndex());
                                // GloVe weighting function: min(1, score / maxCount)^xMax
                                double weight = FastMath.pow(Math.min(1.0, (score / maxCount)), xMax);
                                double fDiff = score > xMax ? prediction : weight * (prediction - Math.log(score));
                                if (Double.isNaN(fDiff))
                                    fDiff = Nd4j.EPS_THRESHOLD;
                                // amount of change
                                double gradient = fDiff;
                                Pair<INDArray, Double> w1Update = update(gloveWeightLookupTable.getWeightAdaGrad(),
                                                gloveWeightLookupTable.getBiasAdaGrad(), gloveWeightLookupTable.getSyn0(),
                                                gloveWeightLookupTable.getBias(), w1, w1Vector, w2Vector, gradient);
                                Pair<INDArray, Double> w2Update = update(gloveWeightLookupTable.getWeightAdaGrad(),
                                                gloveWeightLookupTable.getBiasAdaGrad(), gloveWeightLookupTable.getSyn0(),
                                                gloveWeightLookupTable.getBias(), w2, w2Vector, w1Vector, gradient);
                                return new GloveChange(w1, w2, w1Update.getFirst(), w2Update.getFirst(),
                                                w1Update.getSecond(), w2Update.getSecond(), fDiff,
                                                gloveWeightLookupTable.getWeightAdaGrad().getHistoricalGradient().slice(w1.getIndex()),
                                                gloveWeightLookupTable.getWeightAdaGrad().getHistoricalGradient().slice(w2.getIndex()),
                                                gloveWeightLookupTable.getBiasAdaGrad().getHistoricalGradient().getDouble(w2.getIndex()),
                                                gloveWeightLookupTable.getBiasAdaGrad().getHistoricalGradient().getDouble(w1.getIndex()));
                            }
                        });
        // ...then collect the changes and apply them to the lookup table on the driver
        List<GloveChange> gloveChanges = change.collect();
        double error = 0.0;
        for (GloveChange change2 : gloveChanges) {
            change2.apply(gloveWeightLookupTable);
            error += change2.getError();
        }
        // Reshuffle the training pairs between iterations
        List<Tuple2<VocabWord, Tuple2<VocabWord, Double>>> l = pairsVocab.collect();
        Collections.shuffle(l);
        pairsVocab = sc.parallelizePairs(l);
        log.info("Error at iteration " + i + " was " + error);
    }
    return new Pair<>(vocabAndNumWords.getFirst(), gloveWeightLookupTable);
}
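The core of the per-pair update above is GloVe's weighted least-squares term: the prediction w1 . w2 + b1 + b2 is compared against log(count), scaled by the weighting function min(1, count / maxCount)^xMax. A minimal standalone sketch of that computation (a hypothetical helper, not part of the class):

// Hypothetical helper illustrating the error term computed inside the map() above
static double gloveErrorTerm(double dotProduct, double bias1, double bias2,
                             double coOccurrenceCount, double maxCount, double xMax) {
    double prediction = dotProduct + bias1 + bias2;
    // Weighting function f(X) = min(1, X / maxCount)^xMax down-weights rare pairs
    double weight = Math.pow(Math.min(1.0, coOccurrenceCount / maxCount), xMax);
    return weight * (prediction - Math.log(coOccurrenceCount));
}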
Use of org.apache.spark.SparkConf in project deeplearning4j by deeplearning4j.
The class Word2VecTest, method testSparkW2VonBiggerCorpus.
@Ignore
@Test
public void testSparkW2VonBiggerCorpus() throws Exception {
    SparkConf sparkConf = new SparkConf().setMaster("local[8]").setAppName("sparktest")
                    .set("spark.driver.maxResultSize", "4g")
                    .set("spark.driver.memory", "8g")
                    .set("spark.executor.memory", "8g");
    // Set up the SparkContext
    JavaSparkContext sc = new JavaSparkContext(sparkConf);
    // Path of data part-00000
    //String dataPath = new ClassPathResource("/big/raw_sentences.txt").getFile().getAbsolutePath();
    //String dataPath = "/ext/Temp/SampleRussianCorpus.txt";
    String dataPath = new ClassPathResource("spark_word2vec_test.txt").getFile().getAbsolutePath();
    // Read in data
    JavaRDD<String> corpus = sc.textFile(dataPath);
    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new LowCasePreProcessor());
    Word2Vec word2Vec = new Word2Vec.Builder().setNGrams(1).tokenizerFactory(t).seed(42L)
                    .negative(3).useAdaGrad(false).layerSize(100).windowSize(5)
                    .learningRate(0.025).minLearningRate(0.0001).iterations(1)
                    .batchSize(100).minWordFrequency(5).useUnknown(true).build();
    word2Vec.train(corpus);
    sc.stop();
    WordVectorSerializer.writeWordVectors(word2Vec.getLookupTable(), "/ext/Temp/sparkRuModel.txt");
}
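The model written at the end of this test can be loaded back for querying in the same way the next test reads its temp file. A minimal sketch, assuming the output path from the test above exists:

// Sketch: reload the serialized vectors and query them (path reused from the test above)
WordVectors vectors = WordVectorSerializer.loadTxtVectors(new File("/ext/Temp/sparkRuModel.txt"));
System.out.println(vectors.wordsNearest("day", 10));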
Use of org.apache.spark.SparkConf in project deeplearning4j by deeplearning4j.
The class Word2VecTest, method testConcepts.
@Test
public void testConcepts() throws Exception {
    // These are all default values for word2vec
    SparkConf sparkConf = new SparkConf().setMaster("local[8]").setAppName("sparktest");
    // Set up the SparkContext
    JavaSparkContext sc = new JavaSparkContext(sparkConf);
    // Path of data part-00000
    String dataPath = new ClassPathResource("raw_sentences.txt").getFile().getAbsolutePath();
    //dataPath = "/ext/Temp/part-00000";
    //String dataPath = new ClassPathResource("spark_word2vec_test.txt").getFile().getAbsolutePath();
    // Read in data
    JavaRDD<String> corpus = sc.textFile(dataPath);
    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());
    Word2Vec word2Vec = new Word2Vec.Builder().setNGrams(1).tokenizerFactory(t).seed(42L)
                    .negative(10).useAdaGrad(false).layerSize(150).windowSize(5)
                    .learningRate(0.025).minLearningRate(0.0001).iterations(1)
                    .batchSize(100).minWordFrequency(5).stopWords(Arrays.asList("three"))
                    .useUnknown(true).build();
    word2Vec.train(corpus);
    //word2Vec.setModelUtils(new FlatModelUtils());
    System.out.println("UNK: " + word2Vec.getWordVectorMatrix("UNK"));
    InMemoryLookupTable<VocabWord> table = (InMemoryLookupTable<VocabWord>) word2Vec.lookupTable();
    double sim = word2Vec.similarity("day", "night");
    System.out.println("day/night similarity: " + sim);
    /*
    System.out.println("Hornjo: " + word2Vec.getWordVectorMatrix("hornjoserbsce"));
    System.out.println("carro: " + word2Vec.getWordVectorMatrix("carro"));
    Collection<String> portu = word2Vec.wordsNearest("carro", 10);
    printWords("carro", portu, word2Vec);
    portu = word2Vec.wordsNearest("davi", 10);
    printWords("davi", portu, word2Vec);
    System.out.println("---------------------------------------");
    */
    Collection<String> words = word2Vec.wordsNearest("day", 10);
    printWords("day", words, word2Vec);
    assertTrue(words.contains("night"));
    assertTrue(words.contains("week"));
    assertTrue(words.contains("year"));
    sim = word2Vec.similarity("two", "four");
    System.out.println("two/four similarity: " + sim);
    words = word2Vec.wordsNearest("two", 10);
    printWords("two", words, word2Vec);
    // "three" should be absent from the neighbours of "two" because it was configured as a stop word
    assertFalse(words.contains("three"));
    assertTrue(words.contains("five"));
    assertTrue(words.contains("four"));
    sc.stop();
    // Test serialization: write the lookup table, read it back, and compare
    File tempFile = File.createTempFile("temp", "tmp");
    tempFile.deleteOnExit();
    int idx1 = word2Vec.vocab().wordFor("day").getIndex();
    INDArray array1 = word2Vec.getWordVectorMatrix("day").dup();
    VocabWord word1 = word2Vec.vocab().elementAtIndex(0);
    WordVectorSerializer.writeWordVectors(word2Vec.getLookupTable(), tempFile);
    WordVectors vectors = WordVectorSerializer.loadTxtVectors(tempFile);
    VocabWord word2 = ((VocabCache<VocabWord>) vectors.vocab()).elementAtIndex(0);
    VocabWord wordIT = ((VocabCache<VocabWord>) vectors.vocab()).wordFor("it");
    int idx2 = vectors.vocab().wordFor("day").getIndex();
    INDArray array2 = vectors.getWordVectorMatrix("day").dup();
    System.out.println("word 'i': " + word2);
    System.out.println("word 'it': " + wordIT);
    assertEquals(idx1, idx2);
    assertEquals(word1, word2);
    assertEquals(array1, array2);
}
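The similarity() calls above report the cosine similarity between two word vectors. A roughly equivalent check on the raw vectors, assuming ND4J's org.nd4j.linalg.ops.transforms.Transforms is available, would look like this:

// Sketch: cosine similarity computed directly from the stored vectors
INDArray day = word2Vec.getWordVectorMatrix("day");
INDArray night = word2Vec.getWordVectorMatrix("night");
double cosine = Transforms.cosineSim(day, night); // should be close to word2Vec.similarity("day", "night")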
Use of org.apache.spark.SparkConf in project deeplearning4j by deeplearning4j.
The class SparkSequenceVectorsTest, method setUp.
@Before
public void setUp() throws Exception {
    if (sequencesCyclic == null) {
        sequencesCyclic = new ArrayList<>();
        // 10 sequences in total
        for (int c = 0; c < 10; c++) {
            Sequence<VocabWord> sequence = new Sequence<>();
            for (int e = 0; e < 10; e++) {
                // elements "1".."9" each end up with a total frequency of 10 across all sequences
                sequence.addElement(new VocabWord(1.0, "" + e, (long) e));
            }
            // element "0" is added twice per sequence, for a total frequency of 20
            sequence.addElement(new VocabWord(1.0, "0", 0L));
            sequencesCyclic.add(sequence);
        }
    }
    SparkConf sparkConf = new SparkConf().setMaster("local[8]").setAppName("SeqVecTests");
    sc = new JavaSparkContext(sparkConf);
}
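The tests that follow (not shown here) would typically parallelize these prepared sequences into an RDD before training or exporting. A minimal sketch of that step, using only the fields initialized above:

// Sketch: turn the cyclic sequences built in setUp() into a Spark RDD for the actual tests
JavaRDD<Sequence<VocabWord>> seqRdd = sc.parallelize(sequencesCyclic);
System.out.println("Number of sequences: " + seqRdd.count());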