
Example 1 with Huffman

Use of org.deeplearning4j.models.word2vec.Huffman in project deeplearning4j.

From class GraphTransformer, method initialize().

/**
     * This method handles required initialization for GraphTransformer
     */
protected void initialize() {
    log.info("Building Huffman tree for source graph...");
    int nVertices = sourceGraph.numVertices();
    log.info("Transferring Huffman tree info to nodes...");
    for (int i = 0; i < nVertices; i++) {
        T element = sourceGraph.getVertex(i).getValue();
        // the vertex degree serves as the element's frequency for Huffman weighting
        element.setElementFrequency(sourceGraph.getConnectedVertices(i).size());
        if (vocabCache != null)
            vocabCache.addToken(element);
    }
    if (vocabCache != null) {
        // build the Huffman tree over the collected vocabulary and write indexes back
        Huffman huffman = new Huffman(vocabCache.vocabWords());
        huffman.build();
        huffman.applyIndexes(vocabCache);
    }
}
Also used: Huffman (org.deeplearning4j.models.word2vec.Huffman)
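
The same three-step pattern recurs in every example below: populate a VocabCache, call Huffman.build() to assign codes, then Huffman.applyIndexes() to write frequency-ordered indexes back into the cache. A minimal standalone sketch (the words and frequencies are invented for illustration):

import org.deeplearning4j.models.word2vec.Huffman;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;

public class HuffmanPatternSketch {
    public static void main(String[] args) {
        VocabCache<VocabWord> cache = new AbstractCache.Builder<VocabWord>().build();
        // hypothetical frequencies; in practice these come from corpus counts
        cache.addToken(new VocabWord(5.0, "the"));
        cache.addToken(new VocabWord(2.0, "cat"));
        cache.addToken(new VocabWord(1.0, "sat"));
        // build() assigns Huffman codes and points to each element;
        // applyIndexes() writes the resulting indexes back into the cache
        Huffman huffman = new Huffman(cache.vocabWords());
        huffman.build();
        huffman.applyIndexes(cache);
        System.out.println("Vocab size after indexing: " + cache.numWords());
    }
}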

Example 2 with Huffman

Use of org.deeplearning4j.models.word2vec.Huffman in project deeplearning4j.

From class VocabConstructor, method buildJointVocabulary().

/**
     * This method scans all sources passed through the builder and returns all words as a vocab.
     * If a target VocabCache was set during instance creation, it will be filled as well.
     *
     * @param resetCounters if true, element frequencies are reset to zero after the scan
     * @param buildHuffmanTree if true, a Huffman tree is built over the resulting vocabulary
     * @return the constructed vocabulary cache
     */
public VocabCache<T> buildJointVocabulary(boolean resetCounters, boolean buildHuffmanTree) {
    long lastTime = System.currentTimeMillis();
    long lastSequences = 0;
    long lastElements = 0;
    long startTime = lastTime;
    long startWords = 0;
    AtomicLong parsedCount = new AtomicLong(0);
    if (resetCounters && buildHuffmanTree)
        throw new IllegalStateException("You can't reset counters and build Huffman tree at the same time!");
    if (cache == null)
        cache = new AbstractCache.Builder<T>().build();
    log.debug("Target vocab size before building: [" + cache.numWords() + "]");
    final AtomicLong loopCounter = new AtomicLong(0);
    AbstractCache<T> topHolder = new AbstractCache.Builder<T>().minElementFrequency(0).build();
    int cnt = 0;
    int numProc = Runtime.getRuntime().availableProcessors();
    int numThreads = Math.max(numProc / 2, 2);
    ExecutorService executorService = new ThreadPoolExecutor(numThreads, numThreads, 0L, TimeUnit.MILLISECONDS, new LinkedTransferQueue<Runnable>());
    final AtomicLong execCounter = new AtomicLong(0);
    final AtomicLong finCounter = new AtomicLong(0);
    for (VocabSource<T> source : sources) {
        SequenceIterator<T> iterator = source.getIterator();
        iterator.reset();
        log.debug("Trying source iterator: [" + cnt + "]");
        log.debug("Target vocab size before building: [" + cache.numWords() + "]");
        cnt++;
        AbstractCache<T> tempHolder = new AbstractCache.Builder<T>().build();
        int sequences = 0;
        while (iterator.hasMoreSequences()) {
            Sequence<T> document = iterator.nextSequence();
            seqCount.incrementAndGet();
            parsedCount.addAndGet(document.size());
            tempHolder.incrementTotalDocCount();
            execCounter.incrementAndGet();
            VocabRunnable runnable = new VocabRunnable(tempHolder, document, finCounter, loopCounter);
            executorService.execute(runnable);
            // if we're not in parallel mode, wait until this runnable finishes
            if (!allowParallelBuilder) {
                while (execCounter.get() != finCounter.get()) LockSupport.parkNanos(1000);
            }
            // throttle the producer if workers fall too far behind
            while (execCounter.get() - finCounter.get() > numProc) {
                try {
                    Thread.sleep(1);
                } catch (Exception e) {
                    // ignore and keep throttling
                }
            }
            sequences++;
            if (seqCount.get() % 100000 == 0) {
                long currentTime = System.currentTimeMillis();
                long currentSequences = seqCount.get();
                long currentElements = parsedCount.get();
                double seconds = (currentTime - lastTime) / (double) 1000;
                double seqPerSec = (currentSequences - lastSequences) / seconds;
                double elPerSec = (currentElements - lastElements) / seconds;
                log.info("Sequences checked: [{}]; Current vocabulary size: [{}]; Sequences/sec: {}; Words/sec: {};", seqCount.get(), tempHolder.numWords(), String.format("%.2f", seqPerSec), String.format("%.2f", elPerSec));
                lastTime = currentTime;
                lastElements = currentElements;
                lastSequences = currentSequences;
            }
            // fire a scavenger pass: once enough elements have been processed and
            // the temporary vocab has grown too large, trim low-frequency entries
            if (enableScavenger && loopCounter.get() >= 2000000 && tempHolder.numWords() > 10000000) {
                log.info("Starting scavenger...");
                while (execCounter.get() != finCounter.get()) {
                    try {
                        Thread.sleep(2);
                    } catch (Exception e) {
                        // ignore and keep waiting
                    }
                }
                filterVocab(tempHolder, Math.max(1, source.getMinWordFrequency() / 2));
                loopCounter.set(0);
            }
        }
        // block until all worker threads have finished
        log.debug("Waiting till all processes stop...");
        while (execCounter.get() != finCounter.get()) {
            try {
                Thread.sleep(2);
            } catch (Exception e) {
                // ignore and keep waiting
            }
        }
        // apply minWordFrequency set for this source
        log.debug("Vocab size before truncation: [" + tempHolder.numWords() + "],  NumWords: [" + tempHolder.totalWordOccurrences() + "], sequences parsed: [" + seqCount.get() + "], counter: [" + parsedCount.get() + "]");
        if (source.getMinWordFrequency() > 0) {
            filterVocab(tempHolder, source.getMinWordFrequency());
        }
        log.debug("Vocab size after truncation: [" + tempHolder.numWords() + "],  NumWords: [" + tempHolder.totalWordOccurrences() + "], sequences parsed: [" + seqCount.get() + "], counter: [" + parsedCount.get() + "]");
        // at this moment we're ready to transfer
        topHolder.importVocabulary(tempHolder);
    }
    // at this moment we have a vocabulary full of words, ready to be transferred back into the VocabCache
    System.gc();
    System.gc();
    try {
        Thread.sleep(1000);
    } catch (Exception e) {
        // ignore
    }
    cache.importVocabulary(topHolder);
    // adding UNK word
    if (unk != null) {
        log.info("Adding UNK element to vocab...");
        unk.setSpecial(true);
        cache.addToken(unk);
    }
    if (resetCounters) {
        for (T element : cache.vocabWords()) {
            element.setElementFrequency(0);
        }
        cache.updateWordsOccurencies();
    }
    if (buildHuffmanTree) {
        Huffman huffman = new Huffman(cache.vocabWords());
        huffman.build();
        huffman.applyIndexes(cache);
        if (limit > 0) {
            LinkedBlockingQueue<String> labelsToRemove = new LinkedBlockingQueue<>();
            for (T element : cache.vocabWords()) {
                if (element.getIndex() > limit && !element.isSpecial() && !element.isLabel())
                    labelsToRemove.add(element.getLabel());
            }
            for (String label : labelsToRemove) {
                cache.removeElement(label);
            }
        }
    }
    executorService.shutdown();
    System.gc();
    System.gc();
    try {
        Thread.sleep(1000);
    } catch (Exception e) {
        // ignore
    }
    long endSequences = seqCount.get();
    long endTime = System.currentTimeMillis();
    double seconds = (endTime - startTime) / (double) 1000;
    double seqPerSec = endSequences / seconds;
    log.info("Sequences checked: [{}], Current vocabulary size: [{}]; Sequences/sec: [{}];", seqCount.get(), cache.numWords(), String.format("%.2f", seqPerSec));
    return cache;
}
Also used: AbstractCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache), AtomicLong (java.util.concurrent.atomic.AtomicLong), Huffman (org.deeplearning4j.models.word2vec.Huffman)
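
A typical invocation of the method above goes through the VocabConstructor builder. This is a hedged sketch: the builder methods addSource() and setTargetVocabCache() reflect the dl4j VocabConstructor.Builder API as commonly used, so verify them against your version, and `iterator` is an assumed SequenceIterator<VocabWord> over your corpus:

VocabCache<VocabWord> targetCache = new AbstractCache.Builder<VocabWord>().build();
VocabConstructor<VocabWord> constructor = new VocabConstructor.Builder<VocabWord>()
        .addSource(iterator, 5) // 5 = min word frequency for this source (assumed signature)
        .setTargetVocabCache(targetCache)
        .build();
// keep counters (resetCounters = false) and build the Huffman tree
VocabCache<VocabWord> vocab = constructor.buildJointVocabulary(false, true);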

Example 3 with Huffman

Use of org.deeplearning4j.models.word2vec.Huffman in project deeplearning4j.

From class BinaryCoOccurrenceReaderTest, method testHasMoreObjects2().

@Test
public void testHasMoreObjects2() throws Exception {
    File tempFile = File.createTempFile("tmp", "tmp");
    tempFile.deleteOnExit();
    VocabCache<VocabWord> vocabCache = new AbstractCache.Builder<VocabWord>().build();
    VocabWord word1 = new VocabWord(1.0, "human");
    VocabWord word2 = new VocabWord(2.0, "animal");
    VocabWord word3 = new VocabWord(3.0, "unknown");
    vocabCache.addToken(word1);
    vocabCache.addToken(word2);
    vocabCache.addToken(word3);
    Huffman huffman = new Huffman(vocabCache.vocabWords());
    huffman.build();
    huffman.applyIndexes(vocabCache);
    // write three co-occurrence pairs to the temporary file
    BinaryCoOccurrenceWriter<VocabWord> writer = new BinaryCoOccurrenceWriter<>(tempFile);
    CoOccurrenceWeight<VocabWord> object1 = new CoOccurrenceWeight<>();
    object1.setElement1(word1);
    object1.setElement2(word2);
    object1.setWeight(3.14159265);
    writer.writeObject(object1);
    CoOccurrenceWeight<VocabWord> object2 = new CoOccurrenceWeight<>();
    object2.setElement1(word2);
    object2.setElement2(word3);
    object2.setWeight(0.197);
    writer.writeObject(object2);
    CoOccurrenceWeight<VocabWord> object3 = new CoOccurrenceWeight<>();
    object3.setElement1(word1);
    object3.setElement2(word3);
    object3.setWeight(0.001);
    writer.writeObject(object3);
    writer.finish();
    // read the pairs back; each of the three reads should return a non-null object
    BinaryCoOccurrenceReader<VocabWord> reader = new BinaryCoOccurrenceReader<>(tempFile, vocabCache, null);
    CoOccurrenceWeight<VocabWord> r1 = reader.nextObject();
    log.info("Object received: " + r1);
    assertNotEquals(null, r1);
    r1 = reader.nextObject();
    log.info("Object received: " + r1);
    assertNotEquals(null, r1);
    r1 = reader.nextObject();
    log.info("Object received: " + r1);
    assertNotEquals(null, r1);
}
Also used: Huffman (org.deeplearning4j.models.word2vec.Huffman), VocabWord (org.deeplearning4j.models.word2vec.VocabWord), File (java.io.File), AbstractCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache), Test (org.junit.Test)
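
When the number of stored pairs isn't known in advance, the three fixed reads above can be replaced by a loop. A hedged sketch, assuming BinaryCoOccurrenceReader exposes the hasMoreObjects() method that the test's name refers to:

while (reader.hasMoreObjects()) {
    CoOccurrenceWeight<VocabWord> pair = reader.nextObject();
    log.info("Object received: " + pair);
}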

Example 4 with Huffman

Use of org.deeplearning4j.models.word2vec.Huffman in project deeplearning4j.

From class SparkSequenceVectors, method buildShallowVocabCache().

/**
     * This method builds a shallow vocabulary and its Huffman tree
     *
     * @param counter per-element occurrence counts, keyed by element id
     * @return the resulting vocab cache of ShallowSequenceElement entries
     */
protected VocabCache<ShallowSequenceElement> buildShallowVocabCache(Counter<Long> counter) {
    // TODO: need simplified cache here, that will operate on Long instead of string labels
    VocabCache<ShallowSequenceElement> vocabCache = new AbstractCache<>();
    for (Long id : counter.keySet()) {
        ShallowSequenceElement shallowElement = new ShallowSequenceElement(counter.getCount(id), id);
        vocabCache.addToken(shallowElement);
    }
    // building huffman tree
    Huffman huffman = new Huffman(vocabCache.vocabWords());
    huffman.build();
    huffman.applyIndexes(vocabCache);
    return vocabCache;
}
Also used: ShallowSequenceElement (org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement), Huffman (org.deeplearning4j.models.word2vec.Huffman), AbstractCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache)
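
For context, a hedged sketch of how the counter argument might be assembled; it assumes the berkeley Counter API (incrementCount()) that this dl4j version uses elsewhere, and the ids and counts are invented. Since buildShallowVocabCache() is protected, this would run inside SparkSequenceVectors or a subclass:

Counter<Long> counter = new Counter<>();
// hypothetical element ids and occurrence counts; in SparkSequenceVectors
// these come from a frequency-aggregation pass over the distributed corpus
counter.incrementCount(42L, 10.0);
counter.incrementCount(7L, 3.0);
VocabCache<ShallowSequenceElement> shallowVocab = buildShallowVocabCache(counter);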

Example 5 with Huffman

Use of org.deeplearning4j.models.word2vec.Huffman in project deeplearning4j.

From class TextPipelineTest, method testHuffman().

@Test
public void testHuffman() throws Exception {
    JavaSparkContext sc = getContext();
    JavaRDD<String> corpusRDD = getCorpusRDD(sc);
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());
    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    VocabCache<VocabWord> vocabCache = pipeline.getVocabCache();
    Huffman huffman = new Huffman(vocabCache.vocabWords());
    huffman.build();
    huffman.applyIndexes(vocabCache);
    Collection<VocabWord> vocabWords = vocabCache.vocabWords();
    System.out.println("Huffman Test:");
    for (VocabWord vocabWord : vocabWords) {
        // codes and points were populated by huffman.build() above
        System.out.println("Word: " + vocabWord);
        System.out.println(vocabWord.getCodes());
        System.out.println(vocabWord.getPoints());
    }
    sc.stop();
}
Also used: Huffman (org.deeplearning4j.models.word2vec.Huffman), VocabWord (org.deeplearning4j.models.word2vec.VocabWord), JavaSparkContext (org.apache.spark.api.java.JavaSparkContext), TextPipeline (org.deeplearning4j.spark.text.functions.TextPipeline), Test (org.junit.Test)
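
For context on the output above: in word2vec's hierarchical softmax, getCodes() holds the element's binary Huffman code and getPoints() the indexes of the inner tree nodes along its root-to-leaf path; both are computed by huffman.build(), while applyIndexes() assigns the frequency-ordered element indexes.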

Aggregations

Huffman (org.deeplearning4j.models.word2vec.Huffman): 9
VocabWord (org.deeplearning4j.models.word2vec.VocabWord): 5
Test (org.junit.Test): 5
AbstractCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache): 4
File (java.io.File): 2
AtomicLong (java.util.concurrent.atomic.AtomicLong): 2
JavaSparkContext (org.apache.spark.api.java.JavaSparkContext): 2
TextPipeline (org.deeplearning4j.spark.text.functions.TextPipeline): 2
ArrayList (java.util.ArrayList): 1
List (java.util.List): 1
Pair (org.deeplearning4j.berkeley.Pair): 1
ShallowSequenceElement (org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement): 1
FirstIterationFunction (org.deeplearning4j.spark.models.embeddings.word2vec.FirstIterationFunction): 1
MapToPairFunction (org.deeplearning4j.spark.models.embeddings.word2vec.MapToPairFunction): 1
CountCumSum (org.deeplearning4j.spark.text.functions.CountCumSum): 1