Use of org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable in project deeplearning4j by deeplearning4j.
In the class WordVectorSerializerTest, method testFromTableAndVocab:
@Test
@Ignore
public void testFromTableAndVocab() throws IOException {
    WordVectors vec = WordVectorSerializer.loadGoogleModel(textFile, false);
    InMemoryLookupTable lookupTable = (InMemoryLookupTable) vec.lookupTable();
    InMemoryLookupCache lookupCache = (InMemoryLookupCache) vec.vocab();
    WordVectors wordVectors = WordVectorSerializer.fromTableAndVocab(lookupTable, lookupCache);
    double[] wordVector1 = wordVectors.getWordVector("Morgan_Freeman");
    double[] wordVector2 = wordVectors.getWordVector("JA_Montalbano");
    assertEquals(300, wordVector1.length);
    assertEquals(300, wordVector2.length);
    assertEquals(0.044423, wordVector1[0], 1e-3);
    assertEquals(0.051964, wordVector2[0], 1e-3);
}
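The same loader also handles the binary Google format via its boolean flag. A minimal sketch, assuming a locally available binary model (the file name here is hypothetical, not part of the test):

// Hypothetical binary model file; the second argument selects binary parsing.
File binaryModel = new File("GoogleNews-vectors-negative300.bin");
WordVectors googleVec = WordVectorSerializer.loadGoogleModel(binaryModel, true);
double[] vector = googleVec.getWordVector("Morgan_Freeman"); // 300-dimensional, as asserted above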
Use of org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable in project deeplearning4j by deeplearning4j.
In the class WordVectorSerializerTest, method testMalformedLabels1:
@Test
public void testMalformedLabels1() throws Exception {
    List<String> words = new ArrayList<>();
    words.add("test A");
    words.add("test B");
    words.add("test\nC");
    words.add("test`D");
    words.add("test_E");
    words.add("test 5");

    AbstractCache<VocabWord> vocabCache = new AbstractCache<>();
    int cnt = 0;
    for (String word : words) {
        vocabCache.addToken(new VocabWord(1.0, word));
        vocabCache.addWordToIndex(cnt, word);
        cnt++;
    }
    vocabCache.elementAtIndex(1).markAsLabel(true);

    InMemoryLookupTable<VocabWord> lookupTable =
                    new InMemoryLookupTable<>(vocabCache, 10, false, 0.01, Nd4j.getRandom(), 0.0);
    lookupTable.resetWeights(true);
    assertNotNull(lookupTable.getSyn0());
    assertNotNull(lookupTable.getSyn1());
    assertNotNull(lookupTable.getExpTable());
    assertNull(lookupTable.getSyn1Neg());

    ParagraphVectors vec = new ParagraphVectors.Builder().lookupTable(lookupTable).vocabCache(vocabCache).build();

    File tempFile = File.createTempFile("temp", "w2v");
    tempFile.deleteOnExit();
    WordVectorSerializer.writeParagraphVectors(vec, tempFile);

    ParagraphVectors restoredVec = WordVectorSerializer.readParagraphVectors(tempFile);
    for (String word : words) {
        assertTrue(restoredVec.hasWord(word));
    }
    assertTrue(restoredVec.getVocab().elementAtIndex(1).isLabel());
}
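What the test guards against can also be spot-checked directly. A brief sketch under the same setup (vec and restoredVec as above; this extra assertion is illustrative, not part of the original test): a label containing a newline must map back to the same vector after the write/read cycle.

// Sketch: the vector of a whitespace-bearing label should survive the round trip.
INDArray before = vec.lookupTable().vector("test\nC");
INDArray after = restoredVec.lookupTable().vector("test\nC");
assertEquals(before, after);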
Use of org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable in project deeplearning4j by deeplearning4j.
In the class SentenceBatch, method iterateSample:
/**
 * Iterates over the given two vocab words, applying hierarchical-softmax
 * (and, if negative sampling is enabled, negative-sampling) updates.
 *
 * @param param   shared word2vec parameters (vector length, weights, exp table, ...)
 * @param w1      the first word to iterate on
 * @param w2      the second word to iterate on
 * @param alpha   the current learning rate
 * @param changed accumulator of triples recording which rows were touched
 */
public void iterateSample(Word2VecParam param, VocabWord w1, VocabWord w2, double alpha,
                List<Triple<Integer, Integer, Integer>> changed) {
    if (w2 == null || w2.getIndex() < 0 || w1.getIndex() == w2.getIndex() || w1.getWord().equals("STOP")
                    || w2.getWord().equals("STOP") || w1.getWord().equals("UNK") || w2.getWord().equals("UNK"))
        return;

    int vectorLength = param.getVectorLength();
    InMemoryLookupTable weights = param.getWeights();
    boolean useAdaGrad = param.isUseAdaGrad();
    double negative = param.getNegative();
    INDArray table = param.getTable();
    double[] expTable = param.getExpTable().getValue();
    double MAX_EXP = 6;
    int numWords = param.getNumWords();

    // current word vector
    INDArray l1 = weights.vector(w2.getWord());

    // error accumulator for the current word and context
    INDArray neu1e = Nd4j.create(vectorLength);

    // hierarchical softmax: walk the Huffman code of w1
    for (int i = 0; i < w1.getCodeLength(); i++) {
        int code = w1.getCodes().get(i);
        int point = w1.getPoints().get(i);
        INDArray syn1 = weights.getSyn1().slice(point);
        double dot = Nd4j.getBlasWrapper().level1().dot(syn1.length(), 1.0, l1, syn1);
        if (dot < -MAX_EXP || dot >= MAX_EXP)
            continue;
        int idx = (int) ((dot + MAX_EXP) * ((double) expTable.length / MAX_EXP / 2.0));
        // score
        double f = expTable[idx];
        // gradient
        double g = (1 - code - f) * (useAdaGrad ? w1.getGradient(i, alpha, alpha) : alpha);
        Nd4j.getBlasWrapper().level1().axpy(syn1.length(), g, syn1, neu1e);
        Nd4j.getBlasWrapper().level1().axpy(syn1.length(), g, l1, syn1);
        changed.add(new Triple<>(point, w1.getIndex(), -1));
    }
    changed.add(new Triple<>(w1.getIndex(), w2.getIndex(), -1));

    // negative sampling
    if (negative > 0) {
        int target = w1.getIndex();
        int label;
        INDArray syn1Neg = weights.getSyn1Neg().slice(target);
        for (int d = 0; d < negative + 1; d++) {
            if (d == 0) {
                // first pass uses the actual word as the positive example
                label = 1;
            } else {
                // draw a negative example from the unigram table
                nextRandom.set(nextRandom.get() * 25214903917L + 11);
                target = table.getInt((int) (nextRandom.get() >> 16) % table.length());
                if (target == 0)
                    target = (int) nextRandom.get() % (numWords - 1) + 1;
                if (target == w1.getIndex())
                    continue;
                label = 0;
            }
            double f = Nd4j.getBlasWrapper().dot(l1, syn1Neg);
            double g;
            if (f > MAX_EXP)
                g = useAdaGrad ? w1.getGradient(target, (label - 1), alpha) : (label - 1) * alpha;
            else if (f < -MAX_EXP)
                g = label * (useAdaGrad ? w1.getGradient(target, alpha, alpha) : alpha);
            else
                g = useAdaGrad
                                ? w1.getGradient(target,
                                                label - expTable[(int) ((f + MAX_EXP) * (expTable.length / MAX_EXP / 2))],
                                                alpha)
                                : (label - expTable[(int) ((f + MAX_EXP) * (expTable.length / MAX_EXP / 2))]) * alpha;
            Nd4j.getBlasWrapper().level1().axpy(l1.length(), g, neu1e, l1);
            Nd4j.getBlasWrapper().level1().axpy(l1.length(), g, syn1Neg, l1);
            changed.add(new Triple<>(-1, -1, label));
        }
    }

    // apply the accumulated error to the input vector
    Nd4j.getBlasWrapper().level1().axpy(l1.length(), 1.0f, neu1e, l1);
}
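The idx computation above assumes an exp table precomputed over [-MAX_EXP, MAX_EXP) in the usual word2vec fashion, so the hot loop can replace Math.exp with a lookup. For reference, a minimal sketch of how such a table is typically built (a hypothetical helper, not part of SentenceBatch):

// Hypothetical helper: precompute sigmoid values over [-maxExp, maxExp),
// matching the inverse of the index formula used above.
static double[] buildExpTable(int size, double maxExp) {
    double[] expTable = new double[size];
    for (int i = 0; i < size; i++) {
        double x = (i / (double) size * 2.0 - 1.0) * maxExp; // maps i to [-maxExp, maxExp)
        double e = Math.exp(x);
        expTable[i] = e / (e + 1.0); // sigmoid(x)
    }
    return expTable;
}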
Use of org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable in project deeplearning4j by deeplearning4j.
In the class Word2Vec, method train:
/**
 * Trains a word2vec model on the given text corpus.
 *
 * @param corpusRDD training corpus
 * @throws Exception if tokenization or training fails
 */
public void train(JavaRDD<String> corpusRDD) throws Exception {
    log.info("Start training ...");
    if (workers > 0)
        corpusRDD = corpusRDD.repartition(workers); // repartition returns a new RDD; reassign so it takes effect
    // SparkContext
    final JavaSparkContext sc = new JavaSparkContext(corpusRDD.context());

    // Pre-defined variables
    Map<String, Object> tokenizerVarMap = getTokenizerVarMap();
    Map<String, Object> word2vecVarMap = getWord2vecVarMap();

    // Variables to fill in during training
    final JavaRDD<AtomicLong> sentenceWordsCountRDD;
    final JavaRDD<List<VocabWord>> vocabWordListRDD;
    final JavaPairRDD<List<VocabWord>, Long> vocabWordListSentenceCumSumRDD;
    final VocabCache<VocabWord> vocabCache;
    final JavaRDD<Long> sentenceCumSumCountRDD;
    int maxRep = 1;

    // Start Training //
    //////////////////////////////////////
    log.info("Tokenization and building VocabCache ...");
    // Process every sentence and build a VocabCache, which gets fed into a LookupCache
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(tokenizerVarMap);
    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();

    // Get the total word count and put it into the word2vec variable map
    word2vecVarMap.put("totalWordCount", pipeline.getTotalWordCount());

    // Two RDDs: (vocab words list) and (sentence count), already cached
    sentenceWordsCountRDD = pipeline.getSentenceCountRDD();
    vocabWordListRDD = pipeline.getVocabWordListRDD();

    // Get the vocabCache and the broadcast vocabCache
    Broadcast<VocabCache<VocabWord>> vocabCacheBroadcast = pipeline.getBroadCastVocabCache();
    vocabCache = vocabCacheBroadcast.getValue();
    log.info("Vocab size: {}", vocabCache.numWords());

    //////////////////////////////////////
    log.info("Building Huffman Tree ...");
    // Building the Huffman tree would update the code and point of each vocabWord in vocabCache.
    /*
        We don't need to build the tree here, since it was built earlier, at the
        TextPipeline.buildVocabCache() call.

        Huffman huffman = new Huffman(vocabCache.vocabWords());
        huffman.build();
        huffman.applyIndexes(vocabCache);
    */

    //////////////////////////////////////
    log.info("Calculating cumulative sum of sentence counts ...");
    sentenceCumSumCountRDD = new CountCumSum(sentenceWordsCountRDD).buildCumSum();

    //////////////////////////////////////
    log.info("Mapping to RDD(vocabWordList, cumulative sentence count) ...");
    vocabWordListSentenceCumSumRDD =
                    vocabWordListRDD.zip(sentenceCumSumCountRDD).setName("vocabWordListSentenceCumSumRDD");

    /////////////////////////////////////
    log.info("Broadcasting word2vec variables to workers ...");
    Broadcast<Map<String, Object>> word2vecVarMapBroadcast = sc.broadcast(word2vecVarMap);
    Broadcast<double[]> expTableBroadcast = sc.broadcast(expTable);

    /////////////////////////////////////
    log.info("Training word2vec sentences ...");
    FlatMapFunction firstIterFunc = new FirstIterationFunction(word2vecVarMapBroadcast, expTableBroadcast, vocabCacheBroadcast);
    @SuppressWarnings("unchecked")
    JavaRDD<Pair<VocabWord, INDArray>> indexSyn0UpdateEntryRDD =
                    vocabWordListSentenceCumSumRDD.mapPartitions(firstIterFunc).map(new MapToPairFunction());

    // Collect all the syn0 updates into a list on the driver
    List<Pair<VocabWord, INDArray>> syn0UpdateEntries = indexSyn0UpdateEntryRDD.collect();

    // Instantiate syn0
    INDArray syn0 = Nd4j.zeros(vocabCache.numWords(), layerSize);

    // Updating syn0, first pass: just add the vectors obtained from the different nodes
    log.info("Averaging results...");
    Map<VocabWord, AtomicInteger> updates = new HashMap<>();
    Map<Long, Long> updaters = new HashMap<>();
    for (Pair<VocabWord, INDArray> syn0UpdateEntry : syn0UpdateEntries) {
        syn0.getRow(syn0UpdateEntry.getFirst().getIndex()).addi(syn0UpdateEntry.getSecond());

        // for proper averaging we need to divide the resulting sums later by the number of additions
        if (updates.containsKey(syn0UpdateEntry.getFirst())) {
            updates.get(syn0UpdateEntry.getFirst()).incrementAndGet();
        } else
            updates.put(syn0UpdateEntry.getFirst(), new AtomicInteger(1));

        if (!updaters.containsKey(syn0UpdateEntry.getFirst().getVocabId())) {
            updaters.put(syn0UpdateEntry.getFirst().getVocabId(), syn0UpdateEntry.getFirst().getAffinityId());
        }
    }

    // Updating syn0, second pass: average the obtained vectors
    for (Map.Entry<VocabWord, AtomicInteger> entry : updates.entrySet()) {
        if (entry.getValue().get() > 1) {
            if (entry.getValue().get() > maxRep)
                maxRep = entry.getValue().get();
            syn0.getRow(entry.getKey().getIndex()).divi(entry.getValue().get());
        }
    }

    long totals = 0;

    log.info("Finished calculations...");

    vocab = vocabCache;
    InMemoryLookupTable<VocabWord> inMemoryLookupTable = new InMemoryLookupTable<>();
    Environment env = EnvironmentUtils.buildEnvironment();
    env.setNumCores(maxRep);
    env.setAvailableMemory(totals);
    update(env, Event.SPARK);
    inMemoryLookupTable.setVocab(vocabCache);
    inMemoryLookupTable.setVectorLength(layerSize);
    inMemoryLookupTable.setSyn0(syn0);
    lookupTable = inMemoryLookupTable;
    modelUtils.init(lookupTable);
}
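For context, a minimal driver-side sketch of calling train(...). The SparkConf settings and corpus path are illustrative, and word2Vec is assumed to be an already-configured instance of this class; the Spark setup mirrors the GloveTest snippet below.

// Illustrative driver setup; not taken from the project source.
SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("word2vec-example");
JavaSparkContext sc = new JavaSparkContext(sparkConf);
JavaRDD<String> corpus = sc.textFile("raw_sentences.txt"); // hypothetical corpus path
word2Vec.train(corpus);
sc.stop();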
Use of org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable in project deeplearning4j by deeplearning4j.
In the class GloveTest, method testGlove:
@Test
public void testGlove() throws Exception {
    Glove glove = new Glove(true, 5, 100);
    JavaRDD<String> corpus = sc.textFile(new ClassPathResource("raw_sentences.txt").getFile().getAbsolutePath())
                    .map(new Function<String, String>() {
                        @Override
                        public String call(String s) throws Exception {
                            return s.toLowerCase();
                        }
                    });
    Pair<VocabCache<VocabWord>, GloveWeightLookupTable> table = glove.train(corpus);
    WordVectors vectors = WordVectorSerializer
                    .fromPair(new Pair<>((InMemoryLookupTable) table.getSecond(), (VocabCache) table.getFirst()));
    Collection<String> words = vectors.wordsNearest("day", 20);
    assertTrue(words.contains("week"));
}
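Once the Pair has been converted via fromPair, the full WordVectors API is available on the result. A short illustrative follow-up (the word choices here are examples, not assertions from the test):

// Illustrative further inspection of the trained GloVe vectors.
double sim = vectors.similarity("day", "week"); // cosine similarity between two words
Collection<String> neighbors = vectors.wordsNearest("night", 10);
System.out.println("similarity(day, week) = " + sim + ", neighbors of 'night': " + neighbors);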