use of org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache in project deeplearning4j by deeplearning4j.
the class VocabConstructor method buildJointVocabulary.
/**
* This method scans all sources passed through the builder and returns all words as a vocabulary.
* If a target VocabCache was set during instance creation, it will be filled as well.
*
* @param resetCounters if true, element frequencies are reset to zero once the vocabulary is built
* @param buildHuffmanTree if true, a Huffman tree is built and its indexes are applied to the vocabulary
* @return the resulting vocabulary cache
*/
public VocabCache<T> buildJointVocabulary(boolean resetCounters, boolean buildHuffmanTree) {
long lastTime = System.currentTimeMillis();
long lastSequences = 0;
long lastElements = 0;
long startTime = lastTime;
long startWords = 0;
AtomicLong parsedCount = new AtomicLong(0);
if (resetCounters && buildHuffmanTree)
throw new IllegalStateException("You can't reset counters and build Huffman tree at the same time!");
if (cache == null)
cache = new AbstractCache.Builder<T>().build();
log.debug("Target vocab size before building: [" + cache.numWords() + "]");
final AtomicLong loopCounter = new AtomicLong(0);
AbstractCache<T> topHolder = new AbstractCache.Builder<T>().minElementFrequency(0).build();
int cnt = 0;
int numProc = Runtime.getRuntime().availableProcessors();
int numThreads = Math.max(numProc / 2, 2);
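// dedicated executor for vocabulary accumulation; its queue is unbounded, so submission is throttled manually below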
ExecutorService executorService = new ThreadPoolExecutor(numThreads, numThreads, 0L, TimeUnit.MILLISECONDS, new LinkedTransferQueue<Runnable>());
final AtomicLong execCounter = new AtomicLong(0);
final AtomicLong finCounter = new AtomicLong(0);
for (VocabSource<T> source : sources) {
SequenceIterator<T> iterator = source.getIterator();
iterator.reset();
log.debug("Trying source iterator: [" + cnt + "]");
log.debug("Target vocab size before building: [" + cache.numWords() + "]");
cnt++;
AbstractCache<T> tempHolder = new AbstractCache.Builder<T>().build();
List<Long> timesHasNext = new ArrayList<>();
List<Long> timesNext = new ArrayList<>();
int sequences = 0;
long time3 = 0;
while (iterator.hasMoreSequences()) {
Sequence<T> document = iterator.nextSequence();
seqCount.incrementAndGet();
parsedCount.addAndGet(document.size());
tempHolder.incrementTotalDocCount();
execCounter.incrementAndGet();
VocabRunnable runnable = new VocabRunnable(tempHolder, document, finCounter, loopCounter);
executorService.execute(runnable);
// if we're not in parallel mode, wait until this runnable finishes
if (!allowParallelBuilder) {
while (execCounter.get() != finCounter.get()) LockSupport.parkNanos(1000);
}
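// back-pressure: don't let submitted tasks get more than numProc ahead of completed ones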
while (execCounter.get() - finCounter.get() > numProc) {
try {
Thread.sleep(1);
} catch (Exception e) {
}
}
sequences++;
if (seqCount.get() % 100000 == 0) {
long currentTime = System.currentTimeMillis();
long currentSequences = seqCount.get();
long currentElements = parsedCount.get();
double seconds = (currentTime - lastTime) / (double) 1000;
// Collections.sort(timesHasNext);
// Collections.sort(timesNext);
double seqPerSec = (currentSequences - lastSequences) / seconds;
double elPerSec = (currentElements - lastElements) / seconds;
// log.info("Document time: {} us; hasNext time: {} us", timesNext.get(timesNext.size() / 2), timesHasNext.get(timesHasNext.size() / 2));
log.info("Sequences checked: [{}]; Current vocabulary size: [{}]; Sequences/sec: {}; Words/sec: {};", seqCount.get(), tempHolder.numWords(), String.format("%.2f", seqPerSec), String.format("%.2f", elPerSec));
lastTime = currentTime;
lastElements = currentElements;
lastSequences = currentSequences;
// timesHasNext.clear();
// timesNext.clear();
}
/**
* Firing scavenger loop
*/
if (enableScavenger && loopCounter.get() >= 2000000 && tempHolder.numWords() > 10000000) {
log.info("Starting scavenger...");
while (execCounter.get() != finCounter.get()) {
try {
Thread.sleep(2);
} catch (Exception e) {
}
}
filterVocab(tempHolder, Math.max(1, source.getMinWordFrequency() / 2));
loopCounter.set(0);
}
// timesNext.add((time2 - time1) / 1000L);
// timesHasNext.add((time1 - time3) / 1000L);
// time3 = System.nanoTime();
}
// block until all submitted tasks are finished
log.debug("Waiting until all processes stop...");
while (execCounter.get() != finCounter.get()) {
try {
Thread.sleep(2);
} catch (Exception e) {
}
}
// apply minWordFrequency set for this source
log.debug("Vocab size before truncation: [" + tempHolder.numWords() + "], NumWords: [" + tempHolder.totalWordOccurrences() + "], sequences parsed: [" + seqCount.get() + "], counter: [" + parsedCount.get() + "]");
if (source.getMinWordFrequency() > 0) {
filterVocab(tempHolder, source.getMinWordFrequency());
}
log.debug("Vocab size after truncation: [" + tempHolder.numWords() + "], NumWords: [" + tempHolder.totalWordOccurrences() + "], sequences parsed: [" + seqCount.get() + "], counter: [" + parsedCount.get() + "]");
// at this moment we're ready to transfer
topHolder.importVocabulary(tempHolder);
}
// at this moment, the vocabulary is full of words, and we have to reset counters before transferring everything back to the VocabCache
//topHolder.resetWordCounters();
System.gc();
System.gc();
try {
Thread.sleep(1000);
} catch (Exception e) {
//
}
cache.importVocabulary(topHolder);
// adding UNK word
if (unk != null) {
log.info("Adding UNK element to vocab...");
unk.setSpecial(true);
cache.addToken(unk);
}
if (resetCounters) {
for (T element : cache.vocabWords()) {
element.setElementFrequency(0);
}
cache.updateWordsOccurencies();
}
if (buildHuffmanTree) {
Huffman huffman = new Huffman(cache.vocabWords());
huffman.build();
huffman.applyIndexes(cache);
if (limit > 0) {
LinkedBlockingQueue<String> labelsToRemove = new LinkedBlockingQueue<>();
for (T element : cache.vocabWords()) {
if (element.getIndex() > limit && !element.isSpecial() && !element.isLabel())
labelsToRemove.add(element.getLabel());
}
for (String label : labelsToRemove) {
cache.removeElement(label);
}
}
}
executorService.shutdown();
System.gc();
System.gc();
try {
Thread.sleep(1000);
} catch (Exception e) {
//
}
long endSequences = seqCount.get();
long endTime = System.currentTimeMillis();
double seconds = (endTime - startTime) / (double) 1000;
double seqPerSec = endSequences / seconds;
log.info("Sequences checked: [{}], Current vocabulary size: [{}]; Sequences/sec: [{}];", seqCount.get(), cache.numWords(), String.format("%.2f", seqPerSec));
return cache;
}
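For orientation, a minimal usage sketch of this method follows. It is hedged: it assumes the VocabConstructor.Builder calls (addSource, setTargetVocabCache) as used inside DL4J's Word2Vec setup, a SequenceIterator<VocabWord> named sequenceIterator built elsewhere, and an arbitrary minimum word frequency of 5.
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.VocabConstructor;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
// target cache that buildJointVocabulary will fill (sequenceIterator is assumed to be built elsewhere)
AbstractCache<VocabWord> targetCache = new AbstractCache.Builder<VocabWord>().build();
VocabConstructor<VocabWord> constructor = new VocabConstructor.Builder<VocabWord>()
    .addSource(sequenceIterator, 5) // per-source minWordFrequency; 5 is an arbitrary example value
    .setTargetVocabCache(targetCache)
    .build();
// keep the raw counts (resetCounters = false) and build the Huffman tree (buildHuffmanTree = true)
VocabCache<VocabWord> vocab = constructor.buildJointVocabulary(false, true);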
use of org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache in project deeplearning4j by deeplearning4j.
the class ParagraphVectorsTest method testParagraphVectorsDM.
@Test
public void testParagraphVectorsDM() throws Exception {
ClassPathResource resource = new ClassPathResource("/big/raw_sentences.txt");
File file = resource.getFile();
SentenceIterator iter = new BasicLineIterator(file);
AbstractCache<VocabWord> cache = new AbstractCache.Builder<VocabWord>().build();
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());
LabelsSource source = new LabelsSource("DOC_");
ParagraphVectors vec = new ParagraphVectors.Builder()
    .minWordFrequency(1)
    .iterations(2)
    .seed(119)
    .epochs(3)
    .layerSize(100)
    .learningRate(0.025)
    .labelsSource(source)
    .windowSize(5)
    .iterate(iter)
    .trainWordVectors(true)
    .vocabCache(cache)
    .tokenizerFactory(t)
    .negativeSample(0)
    .useHierarchicSoftmax(true)
    .sampling(0)
    .workers(1)
    .usePreciseWeightInit(true)
    .sequenceLearningAlgorithm(new DM<VocabWord>())
    .build();
vec.fit();
int cnt1 = cache.wordFrequency("day");
int cnt2 = cache.wordFrequency("me");
assertNotEquals(1, cnt1);
assertNotEquals(1, cnt2);
assertNotEquals(cnt1, cnt2);
double simDN = vec.similarity("day", "night");
log.info("day/night similariry: {}", simDN);
double similarity1 = vec.similarity("DOC_9835", "DOC_12492");
log.info("9835/12492 similarity: " + similarity1);
// assertTrue(similarity1 > 0.2d);
double similarity2 = vec.similarity("DOC_3720", "DOC_16392");
log.info("3720/16392 similarity: " + similarity2);
// assertTrue(similarity2 > 0.2d);
double similarity3 = vec.similarity("DOC_6347", "DOC_3720");
log.info("6347/3720 similarity: " + similarity3);
// assertTrue(similarity3 > 0.6d);
double similarityX = vec.similarity("DOC_3720", "DOC_9852");
log.info("3720/9852 similarity: " + similarityX);
assertTrue(similarityX < 0.5d);
// testing DM inference now
INDArray original = vec.getWordVectorMatrix("DOC_16392").dup();
INDArray inferredA1 = vec.inferVector("This is my work");
INDArray inferredB1 = vec.inferVector("This is my work .");
double cosAO1 = Transforms.cosineSim(inferredA1.dup(), original.dup());
double cosAB1 = Transforms.cosineSim(inferredA1.dup(), inferredB1.dup());
log.info("Cos O/A: {}", cosAO1);
log.info("Cos A/B: {}", cosAB1);
}
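The snippet above only logs the two cosine similarities. A hedged follow-up one might append is sketched below; the thresholds are illustrative assumptions, not values taken from the original test.
// illustrative assertions (thresholds are assumptions): two near-identical inferred sentences
// should be very close to each other...
assertTrue(cosAB1 > 0.95);
// ...and the inferred vector should land reasonably close to the trained vector for DOC_16392,
// which corresponds to the sentence "This is my work ."
assertTrue(cosAO1 > 0.4);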
use of org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache in project deeplearning4j by deeplearning4j.
the class ParagraphVectorsTest method testParagraphVectorsWithWordVectorsModelling1.
@Test
public void testParagraphVectorsWithWordVectorsModelling1() throws Exception {
ClassPathResource resource = new ClassPathResource("/big/raw_sentences.txt");
File file = resource.getFile();
SentenceIterator iter = new BasicLineIterator(file);
// InMemoryLookupCache cache = new InMemoryLookupCache(false);
AbstractCache<VocabWord> cache = new AbstractCache.Builder<VocabWord>().build();
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());
LabelsSource source = new LabelsSource("DOC_");
ParagraphVectors vec = new ParagraphVectors.Builder()
    .minWordFrequency(1)
    .iterations(3)
    .epochs(1)
    .layerSize(100)
    .learningRate(0.025)
    .labelsSource(source)
    .windowSize(5)
    .iterate(iter)
    .trainWordVectors(true)
    .vocabCache(cache)
    .tokenizerFactory(t)
    .sampling(0)
    .build();
vec.fit();
int cnt1 = cache.wordFrequency("day");
int cnt2 = cache.wordFrequency("me");
assertNotEquals(1, cnt1);
assertNotEquals(1, cnt2);
assertNotEquals(cnt1, cnt2);
/*
We have a few lines that contain closely related words.
These sentences should be fairly close to each other in vector space.
*/
// line 3721: This is my way .
// line 6348: This is my case .
// line 9836: This is my house .
// line 12493: This is my world .
// line 16393: This is my work .
// this is a special sentence that has nothing in common with the previous sentences
// line 9853: We now have one .
assertTrue(vec.hasWord("DOC_3720"));
double similarityD = vec.similarity("day", "night");
log.info("day/night similarity: " + similarityD);
double similarityW = vec.similarity("way", "work");
log.info("way/work similarity: " + similarityW);
double similarityH = vec.similarity("house", "world");
log.info("house/world similarity: " + similarityH);
double similarityC = vec.similarity("case", "way");
log.info("case/way similarity: " + similarityC);
double similarity1 = vec.similarity("DOC_9835", "DOC_12492");
log.info("9835/12492 similarity: " + similarity1);
// assertTrue(similarity1 > 0.7d);
double similarity2 = vec.similarity("DOC_3720", "DOC_16392");
log.info("3720/16392 similarity: " + similarity2);
// assertTrue(similarity2 > 0.7d);
double similarity3 = vec.similarity("DOC_6347", "DOC_3720");
log.info("6347/3720 similarity: " + similarity3);
// assertTrue(similarity2 > 0.7d);
// similarity in this case should be significantly lower;
// however, since the corpus is small and weight initialization is random, this test can occasionally fail
double similarityX = vec.similarity("DOC_3720", "DOC_9852");
log.info("3720/9852 similarity: " + similarityX);
assertTrue(similarityX < 0.5d);
double sim119 = vec.similarityToLabel("This is my case .", "DOC_6347");
double sim120 = vec.similarityToLabel("This is my case .", "DOC_3720");
log.info("1/2: " + sim119 + "/" + sim120);
//assertEquals(similarity3, sim119, 0.001);
}
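To inspect which labelled documents the model places closest to a given one, a short hedged addition could be appended; wordsNearest is the same WordVectors call used in the other tests here, and the top-10 value is arbitrary.
// sketch: list the ten labels/words nearest to DOC_3720 in the learned space
Collection<String> nearestToDoc = vec.wordsNearest("DOC_3720", 10);
log.info("Nearest to DOC_3720: " + nearestToDoc);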
use of org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache in project deeplearning4j by deeplearning4j.
the class SequenceVectorsTest method testDeepWalk.
@Test
@Ignore
public void testDeepWalk() throws Exception {
Heartbeat.getInstance().disableHeartbeat();
AbstractCache<Blogger> vocabCache = new AbstractCache.Builder<Blogger>().build();
Graph<Blogger, Double> graph = buildGraph();
GraphWalker<Blogger> walker = new PopularityWalker.Builder<>(graph)
    .setNoEdgeHandling(NoEdgeHandling.RESTART_ON_DISCONNECTED)
    .setWalkLength(40)
    .setWalkDirection(WalkDirection.FORWARD_UNIQUE)
    .setRestartProbability(0.05)
    .setPopularitySpread(10)
    .setPopularityMode(PopularityMode.MAXIMUM)
    .setSpreadSpectrum(SpreadSpectrum.PROPORTIONAL)
    .build();
/*
GraphWalker<Blogger> walker = new RandomWalker.Builder<Blogger>(graph)
.setNoEdgeHandling(NoEdgeHandling.RESTART_ON_DISCONNECTED)
.setWalkLength(40)
.setWalkDirection(WalkDirection.RANDOM)
.setRestartProbability(0.05)
.build();
*/
GraphTransformer<Blogger> graphTransformer = new GraphTransformer.Builder<>(graph)
    .setGraphWalker(walker)
    .shuffleOnReset(true)
    .setVocabCache(vocabCache)
    .build();
Blogger blogger = graph.getVertex(0).getValue();
assertEquals(119, blogger.getElementFrequency(), 0.001);
logger.info("Blogger: " + blogger);
AbstractSequenceIterator<Blogger> sequenceIterator = new AbstractSequenceIterator.Builder<>(graphTransformer).build();
WeightLookupTable<Blogger> lookupTable = new InMemoryLookupTable.Builder<Blogger>()
    .lr(0.025)
    .vectorLength(150)
    .useAdaGrad(false)
    .cache(vocabCache)
    .seed(42)
    .build();
lookupTable.resetWeights(true);
SequenceVectors<Blogger> vectors = new SequenceVectors.Builder<Blogger>(new VectorsConfiguration())
    .lookupTable(lookupTable)
    .iterate(sequenceIterator)
    .vocabCache(vocabCache)
    .batchSize(1000)
    .iterations(1)
    .epochs(10)
    .resetModel(false)
    .trainElementsRepresentation(true)
    .trainSequencesRepresentation(false)
    .elementsLearningAlgorithm(new SkipGram<Blogger>())
    .learningRate(0.025)
    .layerSize(150)
    .sampling(0)
    .negativeSample(0)
    .windowSize(4)
    .workers(6)
    .seed(42)
    .build();
vectors.fit();
vectors.setModelUtils(new FlatModelUtils());
// logger.info("12: " + Arrays.toString(vectors.getWordVector("12")));
double sim = vectors.similarity("12", "72");
Collection<String> list = vectors.wordsNearest("12", 20);
logger.info("12->72: " + sim);
printWords("12", list, vectors);
assertTrue(sim > 0.10);
assertFalse(Double.isNaN(sim));
}
use of org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache in project deeplearning4j by deeplearning4j.
the class SparkSequenceVectors method buildShallowVocabCache.
/**
* This method builds a shallow vocabulary and its Huffman tree.
*
* @param counter element frequencies keyed by element id
* @return the shallow vocabulary cache
*/
protected VocabCache<ShallowSequenceElement> buildShallowVocabCache(Counter<Long> counter) {
// TODO: need simplified cache here, that will operate on Long instead of string labels
VocabCache<ShallowSequenceElement> vocabCache = new AbstractCache<>();
for (Long id : counter.keySet()) {
ShallowSequenceElement shallowElement = new ShallowSequenceElement(counter.getCount(id), id);
vocabCache.addToken(shallowElement);
}
// building huffman tree
Huffman huffman = new Huffman(vocabCache.vocabWords());
huffman.build();
huffman.applyIndexes(vocabCache);
return vocabCache;
}
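A minimal sketch of the input this method expects is shown below. It is illustrative only: in SparkSequenceVectors the Counter<Long> of element frequencies comes from an RDD aggregation, the element ids and counts here are invented, and the exact Counter API (incrementCount signature) depends on the DL4J/ND4J version.
// illustrative input: frequencies keyed by element id, as buildShallowVocabCache expects
Counter<Long> counter = new Counter<>();
counter.incrementCount(10L, 25.0); // hypothetical element with id 10, seen 25 times
counter.incrementCount(11L, 3.0); // hypothetical element with id 11, seen 3 times
VocabCache<ShallowSequenceElement> shallowVocab = buildShallowVocabCache(counter);
// every ShallowSequenceElement in shallowVocab now has its Huffman codes and index assigned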