
Example 46 with VocabWord

use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.

the class SentenceBatch method trainSentence.

/**
     * Train on a list of vocab words.
     * @param param shared training parameters (window size, etc.)
     * @param sentence the list of vocab words to train on
     * @param alpha the current learning rate
     * @param changed accumulator for indices updated during training
     */
public void trainSentence(Word2VecParam param, final List<VocabWord> sentence, double alpha, List<Triple<Integer, Integer, Integer>> changed) {
    if (sentence != null && !sentence.isEmpty()) {
        for (int i = 0; i < sentence.size(); i++) {
            VocabWord vocabWord = sentence.get(i);
            // Skip STOP markers (sentence-boundary tokens); train on real words only
            if (vocabWord != null && !vocabWord.getWord().endsWith("STOP")) {
                // Advance the shared LCG state (same constants as java.util.Random and the word2vec C code)
                nextRandom.set(nextRandom.get() * 25214903917L + 11);
                skipGram(param, i, sentence, (int) nextRandom.get() % param.getWindow(), alpha, changed);
            }
        }
    }
}
Also used : VocabWord(org.deeplearning4j.models.word2vec.VocabWord)
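The update nextRandom * 25214903917L + 11 is the linear congruential generator used by both java.util.Random and the original word2vec C implementation; each draw, taken modulo the window size, produces the random offset b that shrinks the effective context window. A minimal standalone sketch of that draw (the seed is illustrative, not from the example):

import java.util.concurrent.atomic.AtomicLong;

public class WindowShrinkDemo {
    public static void main(String[] args) {
        AtomicLong nextRandom = new AtomicLong(5); // illustrative seed
        int window = 5;
        for (int i = 0; i < 3; i++) {
            // Same LCG constants as java.util.Random (0x5DEECE66D, 0xB)
            nextRandom.set(nextRandom.get() * 25214903917L + 11);
            // Truncate to int, then take the remainder; Java's % can yield a
            // negative value here, which skipGram's bounds checks absorb
            int b = (int) nextRandom.get() % window;
            System.out.println("window shrink b = " + b);
        }
    }
}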

Example 47 with VocabWord

use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.

the class Word2Vec method train.

/**
     *  Train the word2vec model on a given text corpus.
     *
     * @param corpusRDD the training corpus, one sentence per string
     * @throws Exception if tokenization or vocabulary construction fails
     */
public void train(JavaRDD<String> corpusRDD) throws Exception {
    log.info("Start training ...");
    if (workers > 0)
        corpusRDD = corpusRDD.repartition(workers); // repartition returns a new RDD; reassign or it is a no-op
    // SparkContext
    final JavaSparkContext sc = new JavaSparkContext(corpusRDD.context());
    // Pre-defined variables
    Map<String, Object> tokenizerVarMap = getTokenizerVarMap();
    Map<String, Object> word2vecVarMap = getWord2vecVarMap();
    // Variables to fill in train
    final JavaRDD<AtomicLong> sentenceWordsCountRDD;
    final JavaRDD<List<VocabWord>> vocabWordListRDD;
    final JavaPairRDD<List<VocabWord>, Long> vocabWordListSentenceCumSumRDD;
    final VocabCache<VocabWord> vocabCache;
    final JavaRDD<Long> sentenceCumSumCountRDD;
    int maxRep = 1;
    // Start Training //
    //////////////////////////////////////
    log.info("Tokenization and building VocabCache ...");
    // Process each sentence and build a VocabCache, which later feeds the lookup table
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(tokenizerVarMap);
    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();
    // Get total word count and put into word2vec variable map
    word2vecVarMap.put("totalWordCount", pipeline.getTotalWordCount());
    // Two RDDs: the vocab word lists and the per-sentence word counts; both are already cached
    sentenceWordsCountRDD = pipeline.getSentenceCountRDD();
    vocabWordListRDD = pipeline.getVocabWordListRDD();
    // Get the broadcast vocabCache and unwrap its local value
    Broadcast<VocabCache<VocabWord>> vocabCacheBroadcast = pipeline.getBroadCastVocabCache();
    vocabCache = vocabCacheBroadcast.getValue();
    log.info("Vocab size: {}", vocabCache.numWords());
    //////////////////////////////////////
    log.info("Building Huffman Tree ...");
    // Building the Huffman tree would update the code and point fields of each VocabWord in the vocabCache
    /*
        We don't need to build tree here, since it was built earlier, at TextPipeline.buildVocabCache() call.
        
        Huffman huffman = new Huffman(vocabCache.vocabWords());
        huffman.build();
        huffman.applyIndexes(vocabCache);
        */
    //////////////////////////////////////
    log.info("Calculating cumulative sum of sentence counts ...");
    sentenceCumSumCountRDD = new CountCumSum(sentenceWordsCountRDD).buildCumSum();
    //////////////////////////////////////
    log.info("Mapping to RDD(vocabWordList, cumulative sentence count) ...");
    vocabWordListSentenceCumSumRDD = vocabWordListRDD.zip(sentenceCumSumCountRDD).setName("vocabWordListSentenceCumSumRDD");
    /////////////////////////////////////
    log.info("Broadcasting word2vec variables to workers ...");
    Broadcast<Map<String, Object>> word2vecVarMapBroadcast = sc.broadcast(word2vecVarMap);
    Broadcast<double[]> expTableBroadcast = sc.broadcast(expTable);
    /////////////////////////////////////
    log.info("Training word2vec sentences ...");
    FlatMapFunction firstIterFunc = new FirstIterationFunction(word2vecVarMapBroadcast, expTableBroadcast, vocabCacheBroadcast);
    @SuppressWarnings("unchecked") JavaRDD<Pair<VocabWord, INDArray>> indexSyn0UpdateEntryRDD = vocabWordListSentenceCumSumRDD.mapPartitions(firstIterFunc).map(new MapToPairFunction());
    // Get all the syn0 updates into a list in driver
    List<Pair<VocabWord, INDArray>> syn0UpdateEntries = indexSyn0UpdateEntryRDD.collect();
    // Instantiate syn0
    INDArray syn0 = Nd4j.zeros(vocabCache.numWords(), layerSize);
    // Updating syn0 first pass: just add vectors obtained from different nodes
    log.info("Averaging results...");
    Map<VocabWord, AtomicInteger> updates = new HashMap<>();
    Map<Long, Long> updaters = new HashMap<>();
    for (Pair<VocabWord, INDArray> syn0UpdateEntry : syn0UpdateEntries) {
        syn0.getRow(syn0UpdateEntry.getFirst().getIndex()).addi(syn0UpdateEntry.getSecond());
        // for proper averaging we need to divide resulting sums later, by the number of additions
        if (updates.containsKey(syn0UpdateEntry.getFirst())) {
            updates.get(syn0UpdateEntry.getFirst()).incrementAndGet();
        } else
            updates.put(syn0UpdateEntry.getFirst(), new AtomicInteger(1));
        if (!updaters.containsKey(syn0UpdateEntry.getFirst().getVocabId())) {
            updaters.put(syn0UpdateEntry.getFirst().getVocabId(), syn0UpdateEntry.getFirst().getAffinityId());
        }
    }
    // Updating syn0 second pass: average obtained vectors
    for (Map.Entry<VocabWord, AtomicInteger> entry : updates.entrySet()) {
        if (entry.getValue().get() > 1) {
            if (entry.getValue().get() > maxRep)
                maxRep = entry.getValue().get();
            syn0.getRow(entry.getKey().getIndex()).divi(entry.getValue().get());
        }
    }
    long totals = 0;
    log.info("Finished calculations...");
    vocab = vocabCache;
    InMemoryLookupTable<VocabWord> inMemoryLookupTable = new InMemoryLookupTable<VocabWord>();
    // Report a heartbeat usage event; the environment fields are repurposed
    // to carry rough stats (maxRep as "cores", totals as "memory")
    Environment env = EnvironmentUtils.buildEnvironment();
    env.setNumCores(maxRep);
    env.setAvailableMemory(totals);
    update(env, Event.SPARK);
    inMemoryLookupTable.setVocab(vocabCache);
    inMemoryLookupTable.setVectorLength(layerSize);
    inMemoryLookupTable.setSyn0(syn0);
    lookupTable = inMemoryLookupTable;
    modelUtils.init(lookupTable);
}
Also used : HashMap(java.util.HashMap) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) FlatMapFunction(org.apache.spark.api.java.function.FlatMapFunction) ArrayList(java.util.ArrayList) List(java.util.List) CountCumSum(org.deeplearning4j.spark.text.functions.CountCumSum) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) Pair(org.deeplearning4j.berkeley.Pair) TextPipeline(org.deeplearning4j.spark.text.functions.TextPipeline) AtomicLong(java.util.concurrent.atomic.AtomicLong) INDArray(org.nd4j.linalg.api.ndarray.INDArray) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) VocabCache(org.deeplearning4j.models.word2vec.wordstore.VocabCache) Environment(org.nd4j.linalg.heartbeat.reports.Environment) Map(java.util.Map)
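For context, here is a hypothetical driver for the train(...) method above, loosely modeled on the dl4j Spark examples; every builder option shown is an assumption to verify against the deeplearning4j-spark version in use:

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.deeplearning4j.spark.models.embeddings.word2vec.Word2Vec;

public class Word2VecTrainDemo {
    public static void main(String[] args) throws Exception {
        SparkConf conf = new SparkConf().setMaster("local[*]").setAppName("word2vec-spark");
        JavaSparkContext sc = new JavaSparkContext(conf);
        // One sentence per line; the file name is a placeholder
        JavaRDD<String> corpus = sc.textFile("raw_sentences.txt");
        Word2Vec word2Vec = new Word2Vec.Builder() // assumed builder API
                .layerSize(100)      // assumed option: vector dimensionality
                .windowSize(5)       // assumed option: context window
                .learningRate(0.025) // assumed option: initial alpha
                .build();
        word2Vec.train(corpus);
        sc.stop();
    }
}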

Example 48 with VocabWord

use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.

the class Word2VecPerformer method skipGram.

/**
     * Train via skip-gram around a single center word.
     * @param i index of the center word in the sentence
     * @param sentence the sentence as a list of vocab words
     * @param b random window-shrink offset
     * @param alpha the current learning rate
     */
public void skipGram(int i, List<VocabWord> sentence, int b, double alpha) {
    final VocabWord word = sentence.get(i);
    if (word != null && !sentence.isEmpty()) {
        int end = window * 2 + 1 - b;
        for (int a = b; a < end; a++) {
            if (a != window) {
                int c = i - window + a;
                if (c >= 0 && c < sentence.size()) {
                    VocabWord lastWord = sentence.get(c);
                    iterateSample(word, lastWord, alpha);
                }
            }
        }
    }
}
Also used : VocabWord(org.deeplearning4j.models.word2vec.VocabWord)
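The index arithmetic here is compact but easy to misread: b shrinks the window symmetrically, end = window * 2 + 1 - b bounds the loop, and c = i - window + a maps the loop counter onto sentence positions. A standalone sketch with illustrative values (window = 5, b = 2, center word at i = 10) enumerates the context indices that would be trained:

public class SkipGramWindowDemo {
    public static void main(String[] args) {
        int window = 5, b = 2, i = 10; // illustrative values only
        int end = window * 2 + 1 - b;  // 9, so a runs over 2..8
        for (int a = b; a < end; a++) {
            if (a != window) {          // a == window is the center word itself
                int c = i - window + a; // prints 7, 8, 9, 11, 12, 13
                System.out.println("context index: " + c);
            }
        }
    }
}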

Example 49 with VocabWord

use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.

the class Word2VecPerformerVoid method skipGram.

/**
     * Train via skip-gram around a single center word.
     * @param i index of the center word in the sentence
     * @param sentence the sentence as a list of vocab words
     * @param b random window-shrink offset
     * @param alpha the current learning rate
     */
public void skipGram(int i, List<VocabWord> sentence, int b, double alpha) {
    final VocabWord word = sentence.get(i);
    if (word != null && !sentence.isEmpty()) {
        int end = window * 2 + 1 - b;
        for (int a = b; a < end; a++) {
            if (a != window) {
                int c = i - window + a;
                if (c >= 0 && c < sentence.size()) {
                    VocabWord lastWord = sentence.get(c);
                    iterateSample(word, lastWord, alpha);
                }
            }
        }
    }
}
Also used : VocabWord(org.deeplearning4j.models.word2vec.VocabWord)

Example 50 with VocabWord

use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.

the class TextPipeline method addTokenToVocabCache.

private void addTokenToVocabCache(String stringToken, Double tokenCount) {
    // Turn the string token into an actual token (VocabWord) if it is not one already
    VocabWord actualToken;
    if (vocabCache.hasToken(stringToken)) {
        actualToken = vocabCache.tokenFor(stringToken);
        actualToken.increaseElementFrequency(tokenCount.intValue());
    } else {
        actualToken = new VocabWord(tokenCount, stringToken);
    }
    // Set the index of the actual token (vocabWord)
    // Put vocabWord into vocabs in InMemoryVocabCache
    boolean vocabContainsWord = vocabCache.containsWord(stringToken);
    if (!vocabContainsWord) {
        int idx = vocabCache.numWords();
        vocabCache.addToken(actualToken);
        actualToken.setIndex(idx);
        vocabCache.putVocabWord(stringToken);
    }
}
Also used : VocabWord(org.deeplearning4j.models.word2vec.VocabWord)
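A minimal sketch of driving the same token/vocab-word bookkeeping by hand; it assumes the in-memory AbstractCache implementation, and the word counts are made up:

import java.util.HashMap;
import java.util.Map;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;

public class VocabCacheDemo {
    public static void main(String[] args) {
        // Assumes AbstractCache as the VocabCache implementation
        VocabCache<VocabWord> vocabCache = new AbstractCache.Builder<VocabWord>().build();
        Map<String, Double> counts = new HashMap<>();
        counts.put("day", 3.0);
        counts.put("night", 2.0);
        for (Map.Entry<String, Double> e : counts.entrySet()) {
            VocabWord token;
            if (vocabCache.hasToken(e.getKey())) {
                token = vocabCache.tokenFor(e.getKey());       // known token: bump its frequency
                token.increaseElementFrequency(e.getValue().intValue());
            } else {
                token = new VocabWord(e.getValue(), e.getKey());
            }
            if (!vocabCache.containsWord(e.getKey())) {
                int idx = vocabCache.numWords(); // next free index, as in the method above
                vocabCache.addToken(token);
                token.setIndex(idx);
                vocabCache.putVocabWord(e.getKey());
            }
        }
        System.out.println("vocab size: " + vocabCache.numWords()); // 2
    }
}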

Aggregations

VocabWord (org.deeplearning4j.models.word2vec.VocabWord): 110
Test (org.junit.Test): 54
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 31
AbstractCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache): 26
ClassPathResource (org.datavec.api.util.ClassPathResource): 23
BasicLineIterator (org.deeplearning4j.text.sentenceiterator.BasicLineIterator): 22
File (java.io.File): 20
InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable): 19
TokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory): 19
ArrayList (java.util.ArrayList): 17
DefaultTokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory): 17
CommonPreprocessor (org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor): 15
SentenceIterator (org.deeplearning4j.text.sentenceiterator.SentenceIterator): 14
AbstractSequenceIterator (org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator): 13
SentenceTransformer (org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer): 13
JavaSparkContext (org.apache.spark.api.java.JavaSparkContext): 12
ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 12
Sequence (org.deeplearning4j.models.sequencevectors.sequence.Sequence): 11
Word2Vec (org.deeplearning4j.models.word2vec.Word2Vec): 11
TextPipeline (org.deeplearning4j.spark.text.functions.TextPipeline): 10