Example 11 with ShallowSequenceElement

Use of org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement in the deeplearning4j project by deeplearning4j.

From the class SparkDBOW, method frameSequence.

@Override
public Frame<? extends TrainingMessage> frameSequence(Sequence<ShallowSequenceElement> sequence, AtomicLong nextRandom, double learningRate) {
    // subsample the sequence when sampling is enabled in the configuration
    if (vectorsConfiguration.getSampling() > 0) {
        sequence = BaseSparkLearningAlgorithm.applySubsampling(sequence, nextRandom, 10L, vectorsConfiguration.getSampling());
    }
    // use the fixed window, or pick one of the configured variable windows at random
    // (note: currentWindow is not used in the rest of this method)
    int currentWindow = vectorsConfiguration.getWindow();
    if (vectorsConfiguration.getVariableWindows() != null && vectorsConfiguration.getVariableWindows().length != 0) {
        currentWindow = vectorsConfiguration.getVariableWindows()[RandomUtils.nextInt(vectorsConfiguration.getVariableWindows().length)];
    }
    // lazily create the per-thread frame holder (double-checked locking;
    // this is only safe if the `frame` field is declared volatile)
    if (frame == null) {
        synchronized (this) {
            if (frame == null) {
                frame = new ThreadLocal<>();
            }
        }
    }
    if (frame.get() == null) {
        frame.set(new Frame<SkipGramRequestMessage>(BasicSequenceProvider.getInstance().getNextValue()));
    }
    // DBOW: pair every sequence label with every element of the sequence
    for (ShallowSequenceElement lastWord : sequence.getSequenceLabels()) {
        for (ShallowSequenceElement word : sequence.getElements()) {
            iterateSample(word, lastWord, nextRandom, learningRate);
            // advance the shared random state (linear congruential step)
            nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11));
        }
    }
    // the thread-local Frame now holds the queued messages: hand it off to the
    // VoidParameterServer for processing and install a fresh frame for the next call
    Frame<SkipGramRequestMessage> currentFrame = frame.get();
    frame.set(new Frame<SkipGramRequestMessage>(BasicSequenceProvider.getInstance().getNextValue()));
    return currentFrame;
}
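The nested loop is the DBOW-specific part of frameSequence: every label of the sequence is paired with every element, and the shared random state advances once per pair using the multiplier 25214903917 and increment 11 (the same constants as the java.util.Random linear congruential generator). The following is a hypothetical, self-contained sketch of just that loop: plain strings stand in for ShallowSequenceElement, and collecting the pair stands in for iterateSample. DbowPairsSketch, nextRandom and framePairs are illustrative names, not DL4J APIs.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;

// Hypothetical sketch: strings stand in for ShallowSequenceElement, and
// collecting "label->word" replaces the call to iterateSample(...).
public class DbowPairsSketch {

    // Same update as frameSequence: x * 25214903917 + 11 (wraps on overflow),
    // then Math.abs. Note that Math.abs(Long.MIN_VALUE) is still negative.
    static long nextRandom(long current) {
        return Math.abs(current * 25214903917L + 11);
    }

    // Pairs every label with every element, label-major, advancing the
    // shared random state once per pair.
    static List<String> framePairs(List<String> labels, List<String> words, AtomicLong random) {
        List<String> pairs = new ArrayList<>();
        for (String label : labels) {
            for (String word : words) {
                pairs.add(label + "->" + word);
                random.set(nextRandom(random.get()));
            }
        }
        return pairs;
    }

    public static void main(String[] args) {
        AtomicLong random = new AtomicLong(1L);
        List<String> pairs = framePairs(Arrays.asList("DOC_1"), Arrays.asList("the", "cat", "sat"), random);
        System.out.println(pairs); // [DOC_1->the, DOC_1->cat, DOC_1->sat]
        System.out.println(nextRandom(1L)); // 25214903928
    }
}
```

With two labels and three elements the loop yields six pairs, so the amount of work per sequence grows with labels × elements rather than with a context window, which is consistent with currentWindow going unused in the method.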
Also used: Frame (org.nd4j.parameterserver.distributed.messages.Frame), ShallowSequenceElement (org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement), SkipGramRequestMessage (org.nd4j.parameterserver.distributed.messages.requests.SkipGramRequestMessage)
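frameSequence keeps one open Frame per thread: the ThreadLocal holder is created lazily with double-checked locking, messages accumulate in the current thread's frame, and at the end the filled frame is returned while a fresh one is installed. The sketch below is a hypothetical, self-contained illustration of that pattern only. Batch stands in for Frame, an AtomicLong counter stands in for BasicSequenceProvider, and ThreadLocal.withInitial replaces the explicit null check on frame.get(); none of these names are ND4J or DL4J APIs.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;

// Hypothetical sketch of per-thread batching; not the ND4J Frame class.
public class ThreadLocalBatcher {

    // Minimal stand-in for Frame: an id plus the collected messages.
    static final class Batch {
        final long id;
        final List<String> messages = new ArrayList<>();
        Batch(long id) { this.id = id; }
    }

    // Stand-in for a global id source such as BasicSequenceProvider.
    private static final AtomicLong SEQUENCE = new AtomicLong();

    // volatile is required for double-checked locking to publish the field safely.
    private volatile ThreadLocal<Batch> current;

    private ThreadLocal<Batch> local() {
        ThreadLocal<Batch> tl = current;
        if (tl == null) {
            synchronized (this) {
                tl = current;
                if (tl == null) {
                    // each thread lazily gets its own Batch on first get()
                    tl = ThreadLocal.withInitial(() -> new Batch(SEQUENCE.incrementAndGet()));
                    current = tl;
                }
            }
        }
        return tl;
    }

    // Adds a message to the calling thread's open batch.
    void add(String message) {
        local().get().messages.add(message);
    }

    // Returns the filled batch and installs a fresh one for the calling thread.
    Batch takeBatch() {
        ThreadLocal<Batch> tl = local();
        Batch done = tl.get();
        tl.set(new Batch(SEQUENCE.incrementAndGet()));
        return done;
    }

    public static void main(String[] args) {
        ThreadLocalBatcher batcher = new ThreadLocalBatcher();
        batcher.add("DOC_1->the");
        batcher.add("DOC_1->cat");
        Batch first = batcher.takeBatch();
        System.out.println(first.messages); // [DOC_1->the, DOC_1->cat]
        System.out.println(batcher.takeBatch().messages.isEmpty()); // true
    }
}
```

Swapping in a new batch before returning the old one means the caller owns the returned batch outright, while later calls on the same thread keep writing to a separate, empty one.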

Aggregations

ShallowSequenceElement (org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement): 11
Frame (org.nd4j.parameterserver.distributed.messages.Frame): 5
Sequence (org.deeplearning4j.models.sequencevectors.sequence.Sequence): 4
ArrayList (java.util.ArrayList): 3
AtomicLong (java.util.concurrent.atomic.AtomicLong): 2
ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 2
CbowRequestMessage (org.nd4j.parameterserver.distributed.messages.requests.CbowRequestMessage): 2
SkipGramRequestMessage (org.nd4j.parameterserver.distributed.messages.requests.SkipGramRequestMessage): 2
RoutedTransport (org.nd4j.parameterserver.distributed.transport.RoutedTransport): 2
DL4JInvalidInputException (org.deeplearning4j.exception.DL4JInvalidInputException): 1
Huffman (org.deeplearning4j.models.word2vec.Huffman): 1
VocabWord (org.deeplearning4j.models.word2vec.VocabWord): 1
AbstractCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache): 1
ExportContainer (org.deeplearning4j.spark.models.sequencevectors.export.ExportContainer): 1
Test (org.junit.Test): 1