Use of org.deeplearning4j.models.word2vec.Huffman in project deeplearning4j by deeplearning4j.
The class GraphTransformer, method initialize().
/**
 * This method handles required initialization for GraphTransformer.
 */
protected void initialize() {
    log.info("Building Huffman tree for source graph...");
    int nVertices = sourceGraph.numVertices();
    //int[] degrees = new int[nVertices];
    //for( int i=0; i<nVertices; i++ )
    //    degrees[i] = sourceGraph.getVertexDegree(i);
    /*
    for (int y = 0; y < nVertices; y += 20) {
        int[] copy = Arrays.copyOfRange(degrees, y, y + 20);
        System.out.println("D: " + Arrays.toString(copy));
    }
    */
    //GraphHuffman huffman = new GraphHuffman(nVertices);
    //huffman.buildTree(degrees);
    log.info("Transferring Huffman tree info to nodes...");
    for (int i = 0; i < nVertices; i++) {
        T element = sourceGraph.getVertex(i).getValue();
        element.setElementFrequency(sourceGraph.getConnectedVertices(i).size());
        if (vocabCache != null)
            vocabCache.addToken(element);
    }
    if (vocabCache != null) {
        Huffman huffman = new Huffman(vocabCache.vocabWords());
        huffman.build();
        huffman.applyIndexes(vocabCache);
    }
}
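For reference, here is a minimal, self-contained sketch of the Huffman pattern that recurs throughout this page: populate a VocabCache with frequency-weighted elements, then build the tree and apply the resulting indexes back onto the cache. The three sample words and their frequencies are illustrative only; every call used here appears verbatim in the snippets on this page.

VocabCache<VocabWord> cache = new AbstractCache.Builder<VocabWord>().build();
cache.addToken(new VocabWord(120.0, "the"));
cache.addToken(new VocabWord(15.0, "graph"));
cache.addToken(new VocabWord(3.0, "vertex"));
// Build the Huffman tree over the current frequencies, then write codes, points and indexes back into the cache
Huffman huffman = new Huffman(cache.vocabWords());
huffman.build();
huffman.applyIndexes(cache);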
Use of org.deeplearning4j.models.word2vec.Huffman in project deeplearning4j by deeplearning4j.
The class VocabConstructor, method buildJointVocabulary().
/**
 * This method scans all sources passed through the builder and returns all words as a vocabulary.
 * If a target VocabCache was set during instance creation, it will be filled as well.
 *
 * @return the assembled vocabulary cache
 */
public VocabCache<T> buildJointVocabulary(boolean resetCounters, boolean buildHuffmanTree) {
    long lastTime = System.currentTimeMillis();
    long lastSequences = 0;
    long lastElements = 0;
    long startTime = lastTime;
    long startWords = 0;
    AtomicLong parsedCount = new AtomicLong(0);
    if (resetCounters && buildHuffmanTree)
        throw new IllegalStateException("You can't reset counters and build Huffman tree at the same time!");
    if (cache == null)
        cache = new AbstractCache.Builder<T>().build();
    log.debug("Target vocab size before building: [" + cache.numWords() + "]");
    final AtomicLong loopCounter = new AtomicLong(0);
    AbstractCache<T> topHolder = new AbstractCache.Builder<T>().minElementFrequency(0).build();
    int cnt = 0;
    int numProc = Runtime.getRuntime().availableProcessors();
    int numThreads = Math.max(numProc / 2, 2);
    ExecutorService executorService = new ThreadPoolExecutor(numThreads, numThreads, 0L, TimeUnit.MILLISECONDS, new LinkedTransferQueue<Runnable>());
    final AtomicLong execCounter = new AtomicLong(0);
    final AtomicLong finCounter = new AtomicLong(0);
    for (VocabSource<T> source : sources) {
        SequenceIterator<T> iterator = source.getIterator();
        iterator.reset();
        log.debug("Trying source iterator: [" + cnt + "]");
        log.debug("Target vocab size before building: [" + cache.numWords() + "]");
        cnt++;
        AbstractCache<T> tempHolder = new AbstractCache.Builder<T>().build();
        List<Long> timesHasNext = new ArrayList<>();
        List<Long> timesNext = new ArrayList<>();
        int sequences = 0;
        long time3 = 0;
        while (iterator.hasMoreSequences()) {
            Sequence<T> document = iterator.nextSequence();
            seqCount.incrementAndGet();
            parsedCount.addAndGet(document.size());
            tempHolder.incrementTotalDocCount();
            execCounter.incrementAndGet();
            VocabRunnable runnable = new VocabRunnable(tempHolder, document, finCounter, loopCounter);
            executorService.execute(runnable);
            // if we're not in parallel mode - wait till this runnable finishes
            if (!allowParallelBuilder) {
                while (execCounter.get() != finCounter.get())
                    LockSupport.parkNanos(1000);
            }
            // throttle submission if the workers fall too far behind
            while (execCounter.get() - finCounter.get() > numProc) {
                try {
                    Thread.sleep(1);
                } catch (Exception e) {
                }
            }
            sequences++;
            if (seqCount.get() % 100000 == 0) {
                long currentTime = System.currentTimeMillis();
                long currentSequences = seqCount.get();
                long currentElements = parsedCount.get();
                double seconds = (currentTime - lastTime) / (double) 1000;
                // Collections.sort(timesHasNext);
                // Collections.sort(timesNext);
                double seqPerSec = (currentSequences - lastSequences) / seconds;
                double elPerSec = (currentElements - lastElements) / seconds;
                // log.info("Document time: {} us; hasNext time: {} us", timesNext.get(timesNext.size() / 2), timesHasNext.get(timesHasNext.size() / 2));
                log.info("Sequences checked: [{}]; Current vocabulary size: [{}]; Sequences/sec: {}; Words/sec: {};", seqCount.get(), tempHolder.numWords(), String.format("%.2f", seqPerSec), String.format("%.2f", elPerSec));
                lastTime = currentTime;
                lastElements = currentElements;
                lastSequences = currentSequences;
                // timesHasNext.clear();
                // timesNext.clear();
            }
            // Firing scavenger loop
            if (enableScavenger && loopCounter.get() >= 2000000 && tempHolder.numWords() > 10000000) {
                log.info("Starting scavenger...");
                while (execCounter.get() != finCounter.get()) {
                    try {
                        Thread.sleep(2);
                    } catch (Exception e) {
                    }
                }
                filterVocab(tempHolder, Math.max(1, source.getMinWordFrequency() / 2));
                loopCounter.set(0);
            }
            // timesNext.add((time2 - time1) / 1000L);
            // timesHasNext.add((time1 - time3) / 1000L);
            // time3 = System.nanoTime();
        }
        // block until all threads are finished
        log.debug("Waiting till all processes stop...");
        while (execCounter.get() != finCounter.get()) {
            try {
                Thread.sleep(2);
            } catch (Exception e) {
            }
        }
        // apply minWordFrequency set for this source
        log.debug("Vocab size before truncation: [" + tempHolder.numWords() + "], NumWords: [" + tempHolder.totalWordOccurrences() + "], sequences parsed: [" + seqCount.get() + "], counter: [" + parsedCount.get() + "]");
        if (source.getMinWordFrequency() > 0) {
            filterVocab(tempHolder, source.getMinWordFrequency());
        }
        log.debug("Vocab size after truncation: [" + tempHolder.numWords() + "], NumWords: [" + tempHolder.totalWordOccurrences() + "], sequences parsed: [" + seqCount.get() + "], counter: [" + parsedCount.get() + "]");
        // at this moment we're ready to transfer
        topHolder.importVocabulary(tempHolder);
    }
    // at this moment we have a vocabulary full of words, and we have to reset counters before transferring everything back to the VocabCache
    //topHolder.resetWordCounters();
    System.gc();
    System.gc();
    try {
        Thread.sleep(1000);
    } catch (Exception e) {
        //
    }
    cache.importVocabulary(topHolder);
    // adding UNK word
    if (unk != null) {
        log.info("Adding UNK element to vocab...");
        unk.setSpecial(true);
        cache.addToken(unk);
    }
    if (resetCounters) {
        for (T element : cache.vocabWords()) {
            element.setElementFrequency(0);
        }
        cache.updateWordsOccurencies();
    }
    if (buildHuffmanTree) {
        Huffman huffman = new Huffman(cache.vocabWords());
        huffman.build();
        huffman.applyIndexes(cache);
        if (limit > 0) {
            LinkedBlockingQueue<String> labelsToRemove = new LinkedBlockingQueue<>();
            for (T element : cache.vocabWords()) {
                if (element.getIndex() > limit && !element.isSpecial() && !element.isLabel())
                    labelsToRemove.add(element.getLabel());
            }
            for (String label : labelsToRemove) {
                cache.removeElement(label);
            }
        }
    }
    executorService.shutdown();
    System.gc();
    System.gc();
    try {
        Thread.sleep(1000);
    } catch (Exception e) {
        //
    }
    long endSequences = seqCount.get();
    long endTime = System.currentTimeMillis();
    double seconds = (endTime - startTime) / (double) 1000;
    double seqPerSec = endSequences / seconds;
    log.info("Sequences checked: [{}], Current vocabulary size: [{}]; Sequences/sec: [{}];", seqCount.get(), cache.numWords(), String.format("%.2f", seqPerSec));
    return cache;
}
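A hedged usage sketch of the method above. The VocabConstructor.Builder setters addSource() and setTargetVocabCache() are assumptions inferred from the sources and cache fields referenced in the method (the builder itself is not shown on this page), and sequenceIterator stands in for any pre-existing SequenceIterator<VocabWord>.

AbstractCache<VocabWord> targetCache = new AbstractCache.Builder<VocabWord>().build();
VocabConstructor<VocabWord> constructor = new VocabConstructor.Builder<VocabWord>()
        .addSource(sequenceIterator, 5)      // assumed setter: attach a source with a minimum element frequency of 5
        .setTargetVocabCache(targetCache)    // assumed setter: also fill this cache
        .build();
// Keep the raw counters (resetCounters = false) and build the Huffman tree in the same pass
VocabCache<VocabWord> vocab = constructor.buildJointVocabulary(false, true);

Note that, per the guard at the top of the method, resetCounters and buildHuffmanTree cannot both be true in a single call.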
Use of org.deeplearning4j.models.word2vec.Huffman in project deeplearning4j by deeplearning4j.
The class BinaryCoOccurrenceReaderTest, method testHasMoreObjects2().
@Test
public void testHasMoreObjects2() throws Exception {
    File tempFile = File.createTempFile("tmp", "tmp");
    tempFile.deleteOnExit();
    VocabCache<VocabWord> vocabCache = new AbstractCache.Builder<VocabWord>().build();
    VocabWord word1 = new VocabWord(1.0, "human");
    VocabWord word2 = new VocabWord(2.0, "animal");
    VocabWord word3 = new VocabWord(3.0, "unknown");
    vocabCache.addToken(word1);
    vocabCache.addToken(word2);
    vocabCache.addToken(word3);
    Huffman huffman = new Huffman(vocabCache.vocabWords());
    huffman.build();
    huffman.applyIndexes(vocabCache);
    BinaryCoOccurrenceWriter<VocabWord> writer = new BinaryCoOccurrenceWriter<>(tempFile);
    CoOccurrenceWeight<VocabWord> object1 = new CoOccurrenceWeight<>();
    object1.setElement1(word1);
    object1.setElement2(word2);
    object1.setWeight(3.14159265);
    writer.writeObject(object1);
    CoOccurrenceWeight<VocabWord> object2 = new CoOccurrenceWeight<>();
    object2.setElement1(word2);
    object2.setElement2(word3);
    object2.setWeight(0.197);
    writer.writeObject(object2);
    CoOccurrenceWeight<VocabWord> object3 = new CoOccurrenceWeight<>();
    object3.setElement1(word1);
    object3.setElement2(word3);
    object3.setWeight(0.001);
    writer.writeObject(object3);
    writer.finish();
    BinaryCoOccurrenceReader<VocabWord> reader = new BinaryCoOccurrenceReader<>(tempFile, vocabCache, null);
    CoOccurrenceWeight<VocabWord> r1 = reader.nextObject();
    log.info("Object received: " + r1);
    assertNotEquals(null, r1);
    r1 = reader.nextObject();
    log.info("Object received: " + r1);
    assertNotEquals(null, r1);
    r1 = reader.nextObject();
    log.info("Object received: " + r1);
    assertNotEquals(null, r1);
}
Use of org.deeplearning4j.models.word2vec.Huffman in project deeplearning4j by deeplearning4j.
The class SparkSequenceVectors, method buildShallowVocabCache().
/**
 * This method builds a shallow vocabulary and the Huffman tree over it.
 *
 * @param counter per-element frequency counter, keyed by element id
 * @return a vocab cache of ShallowSequenceElements with Huffman indexes applied
 */
protected VocabCache<ShallowSequenceElement> buildShallowVocabCache(Counter<Long> counter) {
    // TODO: need simplified cache here, that will operate on Long instead of string labels
    VocabCache<ShallowSequenceElement> vocabCache = new AbstractCache<>();
    for (Long id : counter.keySet()) {
        ShallowSequenceElement shallowElement = new ShallowSequenceElement(counter.getCount(id), id);
        vocabCache.addToken(shallowElement);
    }
    // building huffman tree
    Huffman huffman = new Huffman(vocabCache.vocabWords());
    huffman.build();
    huffman.applyIndexes(vocabCache);
    return vocabCache;
}
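A minimal sketch of how this protected method might be exercised from a subclass (or test) of SparkSequenceVectors. The Counter#incrementCount call and the frequency values are assumptions for illustration; only keySet(), getCount() and the Huffman calls are taken from the method above.

Counter<Long> counter = new Counter<>();
counter.incrementCount(42L, 120);   // element id -> accumulated frequency (values are made up)
counter.incrementCount(17L, 15);
counter.incrementCount(99L, 3);
VocabCache<ShallowSequenceElement> shallowCache = buildShallowVocabCache(counter);
// After applyIndexes(), every ShallowSequenceElement carries its Huffman index, codes and points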
Use of org.deeplearning4j.models.word2vec.Huffman in project deeplearning4j by deeplearning4j.
The class TextPipelineTest, method testHuffman().
@Test
public void testHuffman() throws Exception {
    JavaSparkContext sc = getContext();
    JavaRDD<String> corpusRDD = getCorpusRDD(sc);
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());
    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    VocabCache<VocabWord> vocabCache = pipeline.getVocabCache();
    Huffman huffman = new Huffman(vocabCache.vocabWords());
    huffman.build();
    huffman.applyIndexes(vocabCache);
    Collection<VocabWord> vocabWords = vocabCache.vocabWords();
    System.out.println("Huffman Test:");
    for (VocabWord vocabWord : vocabWords) {
        System.out.println("Word: " + vocabWord);
        System.out.println(vocabWord.getCodes());
        System.out.println(vocabWord.getPoints());
    }
    sc.stop();
}
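As a hedged follow-up, the expected Huffman property can be checked directly on the cache built in this test (the snippet below would sit inside the test, before sc.stop()): more frequent words should never receive longer codes than rarer ones. getCodes() is used in the test above; getElementFrequency() is assumed to be the getter counterpart of the setElementFrequency() calls seen elsewhere on this page.

List<VocabWord> byFrequency = new ArrayList<>(vocabCache.vocabWords());
byFrequency.sort((a, b) -> Double.compare(b.getElementFrequency(), a.getElementFrequency()));
for (int i = 1; i < byFrequency.size(); i++) {
    // code length should be non-decreasing as frequency decreases
    assertTrue(byFrequency.get(i - 1).getCodes().size() <= byFrequency.get(i).getCodes().size());
}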