use of org.deeplearning4j.models.word2vec.wordstore.VocabCache in project deeplearning4j by deeplearning4j.
the class Word2Vec method train.
/**
* Trains the word2vec model on the given text corpus
*
* @param corpusRDD training corpus
* @throws Exception
*/
public void train(JavaRDD<String> corpusRDD) throws Exception {
log.info("Start training ...");
// repartition() returns a new RDD, so the result must be reassigned for the repartitioning to take effect
if (workers > 0)
corpusRDD = corpusRDD.repartition(workers);
// SparkContext
final JavaSparkContext sc = new JavaSparkContext(corpusRDD.context());
// Pre-defined variables
Map<String, Object> tokenizerVarMap = getTokenizerVarMap();
Map<String, Object> word2vecVarMap = getWord2vecVarMap();
// Variables to fill in train
final JavaRDD<AtomicLong> sentenceWordsCountRDD;
final JavaRDD<List<VocabWord>> vocabWordListRDD;
final JavaPairRDD<List<VocabWord>, Long> vocabWordListSentenceCumSumRDD;
final VocabCache<VocabWord> vocabCache;
final JavaRDD<Long> sentenceCumSumCountRDD;
int maxRep = 1;
// Start Training //
//////////////////////////////////////
log.info("Tokenization and building VocabCache ...");
// Process every sentence and build a VocabCache, which is then fed into a LookupCache
Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(tokenizerVarMap);
TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
pipeline.buildVocabCache();
pipeline.buildVocabWordListRDD();
// Get total word count and put into word2vec variable map
word2vecVarMap.put("totalWordCount", pipeline.getTotalWordCount());
// Two RDDs: (vocab word list) and (sentence count). Both are already cached.
sentenceWordsCountRDD = pipeline.getSentenceCountRDD();
vocabWordListRDD = pipeline.getVocabWordListRDD();
// Get the vocabCache and the broadcast vocabCache
Broadcast<VocabCache<VocabWord>> vocabCacheBroadcast = pipeline.getBroadCastVocabCache();
vocabCache = vocabCacheBroadcast.getValue();
log.info("Vocab size: {}", vocabCache.numWords());
//////////////////////////////////////
log.info("Building Huffman Tree ...");
// Building the Huffman tree updates the code and point of each VocabWord in the vocabCache
/*
We don't need to build the tree here, since it was already built during the TextPipeline.buildVocabCache() call.
Huffman huffman = new Huffman(vocabCache.vocabWords());
huffman.build();
huffman.applyIndexes(vocabCache);
*/
//////////////////////////////////////
log.info("Calculating cumulative sum of sentence counts ...");
sentenceCumSumCountRDD = new CountCumSum(sentenceWordsCountRDD).buildCumSum();
//////////////////////////////////////
log.info("Mapping to RDD(vocabWordList, cumulative sentence count) ...");
vocabWordListSentenceCumSumRDD = vocabWordListRDD.zip(sentenceCumSumCountRDD).setName("vocabWordListSentenceCumSumRDD");
/////////////////////////////////////
log.info("Broadcasting word2vec variables to workers ...");
Broadcast<Map<String, Object>> word2vecVarMapBroadcast = sc.broadcast(word2vecVarMap);
Broadcast<double[]> expTableBroadcast = sc.broadcast(expTable);
/////////////////////////////////////
log.info("Training word2vec sentences ...");
FlatMapFunction firstIterFunc = new FirstIterationFunction(word2vecVarMapBroadcast, expTableBroadcast, vocabCacheBroadcast);
@SuppressWarnings("unchecked") JavaRDD<Pair<VocabWord, INDArray>> indexSyn0UpdateEntryRDD = vocabWordListSentenceCumSumRDD.mapPartitions(firstIterFunc).map(new MapToPairFunction());
// Collect all the syn0 updates into a list on the driver
List<Pair<VocabWord, INDArray>> syn0UpdateEntries = indexSyn0UpdateEntryRDD.collect();
// Instantiate syn0
INDArray syn0 = Nd4j.zeros(vocabCache.numWords(), layerSize);
// Updating syn0 first pass: just add vectors obtained from different nodes
log.info("Averaging results...");
Map<VocabWord, AtomicInteger> updates = new HashMap<>();
Map<Long, Long> updaters = new HashMap<>();
for (Pair<VocabWord, INDArray> syn0UpdateEntry : syn0UpdateEntries) {
syn0.getRow(syn0UpdateEntry.getFirst().getIndex()).addi(syn0UpdateEntry.getSecond());
// for proper averaging, the resulting sums are later divided by the number of additions
if (updates.containsKey(syn0UpdateEntry.getFirst())) {
updates.get(syn0UpdateEntry.getFirst()).incrementAndGet();
} else
updates.put(syn0UpdateEntry.getFirst(), new AtomicInteger(1));
if (!updaters.containsKey(syn0UpdateEntry.getFirst().getVocabId())) {
updaters.put(syn0UpdateEntry.getFirst().getVocabId(), syn0UpdateEntry.getFirst().getAffinityId());
}
}
// Updating syn0 second pass: average obtained vectors
for (Map.Entry<VocabWord, AtomicInteger> entry : updates.entrySet()) {
if (entry.getValue().get() > 1) {
if (entry.getValue().get() > maxRep)
maxRep = entry.getValue().get();
syn0.getRow(entry.getKey().getIndex()).divi(entry.getValue().get());
}
}
long totals = 0;
log.info("Finished calculations...");
vocab = vocabCache;
InMemoryLookupTable<VocabWord> inMemoryLookupTable = new InMemoryLookupTable<VocabWord>();
Environment env = EnvironmentUtils.buildEnvironment();
env.setNumCores(maxRep);
env.setAvailableMemory(totals);
update(env, Event.SPARK);
inMemoryLookupTable.setVocab(vocabCache);
inMemoryLookupTable.setVectorLength(layerSize);
inMemoryLookupTable.setSyn0(syn0);
lookupTable = inMemoryLookupTable;
modelUtils.init(lookupTable);
}
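For reference, a minimal driver-side sketch of how this Spark train() method is typically wired up; the Builder options mirror the ones used in the Word2VecTest below, while the master URL and file name are illustrative assumptions.
// Sketch: invoking the Spark Word2Vec train() method (illustrative values)
SparkConf conf = new SparkConf().setMaster("local[*]").setAppName("word2vec-sketch");
JavaSparkContext sc = new JavaSparkContext(conf);
JavaRDD<String> corpus = sc.textFile("raw_sentences.txt");

TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());

Word2Vec word2Vec = new Word2Vec.Builder()
        .tokenizerFactory(t)
        .seed(42L)
        .layerSize(150)
        .windowSize(5)
        .minWordFrequency(5)
        .iterations(1)
        .build();

word2Vec.train(corpus);   // runs the train() method shown above
sc.stop();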
use of org.deeplearning4j.models.word2vec.wordstore.VocabCache in project deeplearning4j by deeplearning4j.
the class GloveTest method testGlove.
@Test
public void testGlove() throws Exception {
Glove glove = new Glove(true, 5, 100);
JavaRDD<String> corpus = sc.textFile(new ClassPathResource("raw_sentences.txt").getFile().getAbsolutePath()).map(new Function<String, String>() {
@Override
public String call(String s) throws Exception {
return s.toLowerCase();
}
});
Pair<VocabCache<VocabWord>, GloveWeightLookupTable> table = glove.train(corpus);
WordVectors vectors = WordVectorSerializer.fromPair(new Pair<>((InMemoryLookupTable) table.getSecond(), (VocabCache) table.getFirst()));
Collection<String> words = vectors.wordsNearest("day", 20);
assertTrue(words.contains("week"));
}
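If the trained GloVe vectors need to outlive the test, a hedged sketch of persisting and reloading them with WordVectorSerializer; the write/load calls mirror those used in the Word2VecTest below, and the temp-file prefix is illustrative.
// Sketch: persisting and reloading the trained GloVe vectors
// (assumes the writeWordVectors overload that takes a lookup table and a File)
File gloveFile = File.createTempFile("glove_vectors", "txt");
gloveFile.deleteOnExit();
WordVectorSerializer.writeWordVectors(table.getSecond(), gloveFile);

WordVectors restored = WordVectorSerializer.loadTxtVectors(gloveFile);
System.out.println(restored.wordsNearest("day", 10));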
use of org.deeplearning4j.models.word2vec.wordstore.VocabCache in project deeplearning4j by deeplearning4j.
the class Word2VecTest method testConcepts.
@Test
public void testConcepts() throws Exception {
// These are all default values for word2vec
SparkConf sparkConf = new SparkConf().setMaster("local[8]").setAppName("sparktest");
// Set SparkContext
JavaSparkContext sc = new JavaSparkContext(sparkConf);
// Path of data part-00000
String dataPath = new ClassPathResource("raw_sentences.txt").getFile().getAbsolutePath();
// dataPath = "/ext/Temp/part-00000";
// String dataPath = new ClassPathResource("spark_word2vec_test.txt").getFile().getAbsolutePath();
// Read in data
JavaRDD<String> corpus = sc.textFile(dataPath);
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());
Word2Vec word2Vec = new Word2Vec.Builder().setNGrams(1).tokenizerFactory(t).seed(42L).negative(10).useAdaGrad(false).layerSize(150).windowSize(5).learningRate(0.025).minLearningRate(0.0001).iterations(1).batchSize(100).minWordFrequency(5).stopWords(Arrays.asList("three")).useUnknown(true).build();
word2Vec.train(corpus);
//word2Vec.setModelUtils(new FlatModelUtils());
System.out.println("UNK: " + word2Vec.getWordVectorMatrix("UNK"));
InMemoryLookupTable<VocabWord> table = (InMemoryLookupTable<VocabWord>) word2Vec.lookupTable();
double sim = word2Vec.similarity("day", "night");
System.out.println("day/night similarity: " + sim);
/*
System.out.println("Hornjo: " + word2Vec.getWordVectorMatrix("hornjoserbsce"));
System.out.println("carro: " + word2Vec.getWordVectorMatrix("carro"));
Collection<String> portu = word2Vec.wordsNearest("carro", 10);
printWords("carro", portu, word2Vec);
portu = word2Vec.wordsNearest("davi", 10);
printWords("davi", portu, word2Vec);
System.out.println("---------------------------------------");
*/
Collection<String> words = word2Vec.wordsNearest("day", 10);
printWords("day", words, word2Vec);
assertTrue(words.contains("night"));
assertTrue(words.contains("week"));
assertTrue(words.contains("year"));
sim = word2Vec.similarity("two", "four");
System.out.println("two/four similarity: " + sim);
words = word2Vec.wordsNearest("two", 10);
printWords("two", words, word2Vec);
// three should be absent due to stopWords
assertFalse(words.contains("three"));
assertTrue(words.contains("five"));
assertTrue(words.contains("four"));
sc.stop();
// test serialization
File tempFile = File.createTempFile("temp", "tmp");
tempFile.deleteOnExit();
int idx1 = word2Vec.vocab().wordFor("day").getIndex();
INDArray array1 = word2Vec.getWordVectorMatrix("day").dup();
VocabWord word1 = word2Vec.vocab().elementAtIndex(0);
WordVectorSerializer.writeWordVectors(word2Vec.getLookupTable(), tempFile);
WordVectors vectors = WordVectorSerializer.loadTxtVectors(tempFile);
VocabWord word2 = ((VocabCache<VocabWord>) vectors.vocab()).elementAtIndex(0);
VocabWord wordIT = ((VocabCache<VocabWord>) vectors.vocab()).wordFor("it");
int idx2 = vectors.vocab().wordFor("day").getIndex();
INDArray array2 = vectors.getWordVectorMatrix("day").dup();
System.out.println("word 'i': " + word2);
System.out.println("word 'it': " + wordIT);
assertEquals(idx1, idx2);
assertEquals(word1, word2);
assertEquals(array1, array2);
}
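The printWords helper is not part of this snippet; a plausible sketch of it (hypothetical, for illustration only) simply walks the neighbour list and prints each word's similarity to the query term.
// Hypothetical helper, not from the original test:
// prints each neighbour and its cosine similarity to the target word
private static void printWords(String target, Collection<String> neighbours, Word2Vec vec) {
    System.out.println("Words close to [" + target + "]:");
    for (String word : neighbours) {
        System.out.println("  '" + word + "': " + vec.similarity(target, word));
    }
}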
use of org.deeplearning4j.models.word2vec.wordstore.VocabCache in project deeplearning4j by deeplearning4j.
the class WordVectorSerializer method loadTxtVectors.
/**
* This method can be used to load a previously saved model from an InputStream (e.g. an HDFS stream).
*
* Deprecation note: please consider using readWord2VecModel() or loadStaticModel() instead.
*
* @param stream InputStream that contains the previously serialized model
* @param skipFirstLine set to TRUE if the first line contains a CSV header, FALSE otherwise
* @return WordVectors model reconstructed from the stream
* @throws IOException
*/
@Deprecated
public static WordVectors loadTxtVectors(@NonNull InputStream stream, boolean skipFirstLine) throws IOException {
AbstractCache<VocabWord> cache = new AbstractCache.Builder<VocabWord>().build();
BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
String line = "";
List<INDArray> arrays = new ArrayList<>();
if (skipFirstLine)
reader.readLine();
while ((line = reader.readLine()) != null) {
String[] split = line.split(" ");
String word = split[0].replaceAll(whitespaceReplacement, " ");
VocabWord word1 = new VocabWord(1.0, word);
word1.setIndex(cache.numWords());
cache.addToken(word1);
cache.addWordToIndex(word1.getIndex(), word);
cache.putVocabWord(word);
float[] vector = new float[split.length - 1];
for (int i = 1; i < split.length; i++) {
vector[i - 1] = Float.parseFloat(split[i]);
}
INDArray row = Nd4j.create(vector);
arrays.add(row);
}
InMemoryLookupTable<VocabWord> lookupTable = (InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>().vectorLength(arrays.get(0).columns()).cache(cache).build();
INDArray syn = Nd4j.vstack(arrays);
Nd4j.clearNans(syn);
lookupTable.setSyn0(syn);
return fromPair(Pair.makePair((InMemoryLookupTable) lookupTable, (VocabCache) cache));
}
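The parser above expects one entry per line, a word followed by its space-separated float components; a minimal usage sketch under that assumption (the file name and its contents are illustrative).
// Illustrative input format, one "word v1 v2 ... vn" entry per line:
//   day 0.12 -0.48 0.33 ...
//   night 0.10 -0.51 0.29 ...
try (InputStream stream = new FileInputStream("vectors.txt")) {
    WordVectors vectors = loadTxtVectors(stream, false);   // false: no header line to skip
    INDArray day = vectors.getWordVectorMatrix("day");
    System.out.println("day vector: " + day);
}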
use of org.deeplearning4j.models.word2vec.wordstore.VocabCache in project deeplearning4j by deeplearning4j.
the class VocabCacheExporter method export.
@Override
public void export(JavaRDD<ExportContainer<VocabWord>> rdd) {
// beware: collecting the RDD to the driver is generally a VERY bad idea, but it works fine for testing purposes
List<ExportContainer<VocabWord>> list = rdd.collect();
if (vocabCache == null)
vocabCache = new AbstractCache<>();
INDArray syn0 = null;
// just roll through list
for (ExportContainer<VocabWord> element : list) {
VocabWord word = element.getElement();
INDArray weights = element.getArray();
if (syn0 == null)
syn0 = Nd4j.create(list.size(), weights.length());
vocabCache.addToken(word);
vocabCache.addWordToIndex(word.getIndex(), word.getLabel());
syn0.getRow(word.getIndex()).assign(weights);
}
if (lookupTable == null)
lookupTable = new InMemoryLookupTable.Builder<VocabWord>().cache(vocabCache).vectorLength(syn0.columns()).build();
lookupTable.setSyn0(syn0);
// this is bad & dirty, but we don't really need anything else for testing :)
word2Vec = WordVectorSerializer.fromPair(Pair.<InMemoryLookupTable, VocabCache>makePair(lookupTable, vocabCache));
}
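Once export() has run, the exporter holds a fully assembled in-memory model (vocab cache, lookup table, and the word2Vec built via fromPair); a hedged sketch of querying it, assuming an accessor such as getWord2Vec() exposes the assembled model (the accessor name is an assumption).
// Sketch: querying the model assembled by export(); getWord2Vec() is an assumed accessor
VocabCacheExporter exporter = new VocabCacheExporter();
// ... the Spark training job invokes exporter.export(rdd) ...
WordVectors vectors = exporter.getWord2Vec();          // hypothetical accessor for the word2Vec field
Collection<String> nearest = vectors.wordsNearest("day", 10);
double similarity = vectors.similarity("day", "night");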