Search in sources :

Example 41 with VocabWord

use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.

The following is the method call of the class FirstIterationFunctionAdapter:

@Override
public Iterable<Map.Entry<VocabWord, INDArray>> call(Iterator<Tuple2<List<VocabWord>, Long>> pairIter) {
    while (pairIter.hasNext()) {
        List<Pair<List<VocabWord>, Long>> batch = new ArrayList<>();
        while (pairIter.hasNext() && batch.size() < batchSize) {
            Tuple2<List<VocabWord>, Long> pair = pairIter.next();
            List<VocabWord> vocabWordsList = pair._1();
            Long sentenceCumSumCount = pair._2();
            batch.add(Pair.of(vocabWordsList, sentenceCumSumCount));
        }
        for (int i = 0; i < iterations; i++) {
            //System.out.println("Training sentence: " + vocabWordsList);
            for (Pair<List<VocabWord>, Long> pair : batch) {
                List<VocabWord> vocabWordsList = pair.getKey();
                Long sentenceCumSumCount = pair.getValue();
                double currentSentenceAlpha = Math.max(minAlpha, alpha - (alpha - minAlpha) * (sentenceCumSumCount / (double) totalWordCount));
                trainSentence(vocabWordsList, currentSentenceAlpha);
            }
        }
    }
    return indexSyn0VecMap.entrySet();
}
Also used : AtomicLong(java.util.concurrent.atomic.AtomicLong) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) Pair(org.apache.commons.lang3.tuple.Pair)

Example 42 with VocabWord

use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.

The following is the method skipGram of the class SecondIterationFunctionAdapter:

/**
 * Runs skip-gram training for the word at {@code ithWordInSentence}: every
 * neighbour inside the (randomly shrunk) context window is paired with the
 * centre word and passed to {@code iterateSample}.
 *
 * @param ithWordInSentence    index of the centre word within the sentence
 * @param vocabWordsList       the sentence being trained on
 * @param b                    random window-shrink offset (reduces the effective window)
 * @param currentSentenceAlpha learning rate to apply for this sentence
 */
public void skipGram(int ithWordInSentence, List<VocabWord> vocabWordsList, int b, double currentSentenceAlpha) {
    VocabWord centreWord = vocabWordsList.get(ithWordInSentence);
    if (centreWord == null || vocabWordsList.isEmpty()) {
        return;
    }
    int upperBound = window * 2 + 1 - b;
    for (int offset = b; offset < upperBound; offset++) {
        if (offset == window) {
            continue; // offset == window is the centre word itself; skip it
        }
        int neighbourIdx = ithWordInSentence - window + offset;
        if (neighbourIdx >= 0 && neighbourIdx < vocabWordsList.size()) {
            iterateSample(centreWord, vocabWordsList.get(neighbourIdx), currentSentenceAlpha);
        }
    }
}
Also used : VocabWord(org.deeplearning4j.models.word2vec.VocabWord)

Example 43 with VocabWord

use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.

The following is the method call of the class KeySequenceConvertFunction:

@Override
@Override
public Sequence<VocabWord> call(Tuple2<String, String> pair) throws Exception {
    // Convert a (label, text) tuple into a labelled Sequence: the key becomes
    // the sequence label and each non-empty token of the value becomes an element.
    Sequence<VocabWord> result = new Sequence<>();
    result.addSequenceLabel(new VocabWord(1.0, pair._1()));
    // Lazily build the tokenizer factory on first use (null until initialised).
    if (tokenizerFactory == null) {
        instantiateTokenizerFactory();
    }
    for (String token : tokenizerFactory.create(pair._2()).getTokens()) {
        // Skip null/empty tokens produced by the tokenizer.
        if (token != null && !token.isEmpty()) {
            result.addElement(new VocabWord(1.0, token));
        }
    }
    return result;
}
Also used : VocabWord(org.deeplearning4j.models.word2vec.VocabWord) Sequence(org.deeplearning4j.models.sequencevectors.sequence.Sequence)

Example 44 with VocabWord

use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.

The following is the method call of the class SecondIterationFunctionAdapter:

@Override
@Override
public Iterable<Entry<VocabWord, INDArray>> call(Iterator<Tuple2<List<VocabWord>, Long>> pairIter) {
    // Acquire the per-JVM vocab holder and seed its RNG/vector state.
    this.vocabHolder = VocabHolder.getInstance();
    this.vocabHolder.setSeed(seed, vectorLength);
    // Negative sampling is only set up when negative > 0.
    if (negative > 0) {
        negativeHolder = NegativeHolder.getInstance();
        negativeHolder.initHolder(vocab, expTable, this.vectorLength);
    }
    // Consume the partition in mini-batches of up to batchSize sentences,
    // training each batch for the configured number of iterations.
    while (pairIter.hasNext()) {
        List<Pair<List<VocabWord>, Long>> currentBatch = new ArrayList<>();
        while (currentBatch.size() < batchSize && pairIter.hasNext()) {
            Tuple2<List<VocabWord>, Long> tuple = pairIter.next();
            currentBatch.add(Pair.of(tuple._1(), tuple._2()));
        }
        for (int iter = 0; iter < iterations; iter++) {
            for (Pair<List<VocabWord>, Long> sentencePair : currentBatch) {
                // Linearly decay the learning rate from alpha towards minAlpha
                // based on the sentence's cumulative position in the corpus,
                // never dropping below minAlpha.
                double decayedAlpha =
                        alpha - (alpha - minAlpha) * (sentencePair.getValue() / (double) totalWordCount);
                trainSentence(sentencePair.getKey(), Math.max(minAlpha, decayedAlpha));
            }
        }
    }
    return vocabHolder.getSplit(vocab);
}
Also used : ArrayList(java.util.ArrayList) AtomicLong(java.util.concurrent.atomic.AtomicLong) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) ArrayList(java.util.ArrayList) List(java.util.List) Pair(org.apache.commons.lang3.tuple.Pair)

Example 45 with VocabWord

use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.

The following is the method skipGram of the class SentenceBatch:

/**
 * Train via skip gram: pairs the word at index {@code i} with each neighbour
 * inside the (randomly shrunk) context window and hands each pair to
 * {@code iterateSample}.
 *
 * @param param    shared word2vec parameters; supplies the window size
 * @param i        index of the current (centre) word within the sentence
 * @param sentence the sentence to train on
 * @param b        random window-shrink offset reducing the effective window
 * @param alpha    the learning rate
 * @param changed  index triples passed through to {@code iterateSample}
 */
public void skipGram(Word2VecParam param, int i, List<VocabWord> sentence, int b, double alpha, List<Triple<Integer, Integer, Integer>> changed) {
    final VocabWord centre = sentence.get(i);
    final int window = param.getWindow();
    if (centre == null || sentence.isEmpty()) {
        return;
    }
    final int limit = window * 2 + 1 - b;
    for (int offset = b; offset < limit; offset++) {
        if (offset == window) {
            continue; // the centre word is not trained against itself
        }
        final int neighbourPos = i - window + offset;
        if (neighbourPos >= 0 && neighbourPos < sentence.size()) {
            iterateSample(param, centre, sentence.get(neighbourPos), alpha, changed);
        }
    }
}
Also used : VocabWord(org.deeplearning4j.models.word2vec.VocabWord)

Aggregations

VocabWord (org.deeplearning4j.models.word2vec.VocabWord)110 Test (org.junit.Test)54 INDArray (org.nd4j.linalg.api.ndarray.INDArray)31 AbstractCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache)26 ClassPathResource (org.datavec.api.util.ClassPathResource)23 BasicLineIterator (org.deeplearning4j.text.sentenceiterator.BasicLineIterator)22 File (java.io.File)20 InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable)19 TokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory)19 ArrayList (java.util.ArrayList)17 DefaultTokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory)17 CommonPreprocessor (org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor)15 SentenceIterator (org.deeplearning4j.text.sentenceiterator.SentenceIterator)14 AbstractSequenceIterator (org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator)13 SentenceTransformer (org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer)13 JavaSparkContext (org.apache.spark.api.java.JavaSparkContext)12 ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException)12 Sequence (org.deeplearning4j.models.sequencevectors.sequence.Sequence)11 Word2Vec (org.deeplearning4j.models.word2vec.Word2Vec)11 TextPipeline (org.deeplearning4j.spark.text.functions.TextPipeline)10