Search in sources :

Example 1 with CbowRequestMessage

use of org.nd4j.parameterserver.distributed.messages.requests.CbowRequestMessage in project deeplearning4j by deeplearning4j.

In class SparkCBOW, method frameSequence:

/**
 * Converts one sequence into CBOW training requests and returns the frame that collected them.
 *
 * <p>For every element of the (optionally subsampled) sequence a randomly shrunk context window
 * is built around it, the indices of the surrounding elements are gathered, and
 * {@code iterateSample} is invoked with them. Once the whole sequence has been processed, the
 * frame held in the calling thread's {@code ThreadLocal} slot is returned and replaced by a
 * fresh, empty one.
 *
 * @param sequence     sequence to process; replaced by its subsampled version when the
 *                     configured sampling value is greater than zero
 * @param nextRandom   random state; advanced once per element and forwarded to {@code iterateSample}
 * @param learningRate learning rate forwarded to {@code iterateSample}
 * @return the frame that was current for this thread while the sequence was processed
 */
@Override
public Frame<? extends TrainingMessage> frameSequence(Sequence<ShallowSequenceElement> sequence, AtomicLong nextRandom, double learningRate) {
    // FIXME: totalElementsCount should have real value
    if (vectorsConfiguration.getSampling() > 0) {
        sequence = BaseSparkLearningAlgorithm.applySubsampling(sequence, nextRandom, 10L, vectorsConfiguration.getSampling());
    }

    // Use the fixed window unless variable windows are configured, in which case pick one at random.
    int currentWindow = vectorsConfiguration.getWindow();
    if (vectorsConfiguration.getVariableWindows() != null && vectorsConfiguration.getVariableWindows().length != 0) {
        currentWindow = vectorsConfiguration.getVariableWindows()[RandomUtils.nextInt(vectorsConfiguration.getVariableWindows().length)];
    }

    // Lazily create the per-thread frame holder, then make sure this thread has a frame to fill.
    if (frame == null) {
        synchronized (this) {
            if (frame == null) {
                frame = new ThreadLocal<>();
            }
        }
    }
    if (frame.get() == null) {
        frame.set(new Frame<CbowRequestMessage>(BasicSequenceProvider.getInstance().getNextValue()));
    }

    for (int i = 0; i < sequence.getElements().size(); i++) {
        nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11));

        // The remainder must be taken on the long value BEFORE narrowing to int: a cast binds
        // tighter than '%', so "(int) x % w" truncates x first and may yield a negative offset,
        // which would widen the window beyond the configured size. |x % w| < w, so both
        // Math.abs and the narrowing cast are lossless here.
        int b = (int) Math.abs(nextRandom.get() % currentWindow);
        int end = currentWindow * 2 + 1 - b;

        ShallowSequenceElement currentWord = sequence.getElementByIndex(i);
        List<Integer> intsList = new ArrayList<>();
        for (int a = b; a < end; a++) {
            // a == currentWindow corresponds to the current element itself; skip it
            if (a != currentWindow) {
                int c = i - currentWindow + a;
                if (c >= 0 && c < sequence.size()) {
                    ShallowSequenceElement lastWord = sequence.getElementByIndex(c);
                    intsList.add(lastWord.getIndex());
                }
            }
        }

        // just converting values to int
        int[] windowWords = new int[intsList.size()];
        for (int x = 0; x < windowWords.length; x++) {
            windowWords[x] = intsList.get(x);
        }

        // no context collected for this element, nothing to train on
        if (windowWords.length < 1) {
            continue;
        }

        iterateSample(currentWord, windowWords, nextRandom, learningRate, false, 0, true, null);
    }

    // Hand the filled frame to the caller and leave a fresh one for the next sequence.
    Frame<CbowRequestMessage> currentFrame = frame.get();
    frame.set(new Frame<CbowRequestMessage>(BasicSequenceProvider.getInstance().getNextValue()));
    return currentFrame;
}
Also used : Frame(org.nd4j.parameterserver.distributed.messages.Frame) ShallowSequenceElement(org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement) ArrayList(java.util.ArrayList) CbowRequestMessage(org.nd4j.parameterserver.distributed.messages.requests.CbowRequestMessage)

Example 2 with CbowRequestMessage

use of org.nd4j.parameterserver.distributed.messages.requests.CbowRequestMessage in project deeplearning4j by deeplearning4j.

In class SparkDM, method frameSequence:

/**
 * Converts one labelled sequence into training requests and returns the frame that collected them.
 *
 * <p>The procedure mirrors the CBOW one: for every element of the (optionally subsampled)
 * sequence a randomly shrunk context window is built and the indices of the surrounding
 * elements are gathered. In addition, the indices of all sequence labels are appended to that
 * context before {@code iterateSample} is invoked. Once the whole sequence has been processed,
 * the frame held in the calling thread's {@code ThreadLocal} slot is returned and replaced by a
 * fresh, empty one.
 *
 * @param sequence     sequence to process; replaced by its subsampled version when the
 *                     configured sampling value is greater than zero
 * @param nextRandom   random state; advanced once per element and forwarded to {@code iterateSample}
 * @param learningRate learning rate forwarded to {@code iterateSample}
 * @return the frame that was current for this thread while the sequence was processed
 * @throws DL4JInvalidInputException if the sequence has at least one element but
 *                                   {@code getSequenceLabels()} returns {@code null}
 */
@Override
public Frame<? extends TrainingMessage> frameSequence(Sequence<ShallowSequenceElement> sequence, AtomicLong nextRandom, double learningRate) {
    if (vectorsConfiguration.getSampling() > 0) {
        sequence = BaseSparkLearningAlgorithm.applySubsampling(sequence, nextRandom, 10L, vectorsConfiguration.getSampling());
    }

    // Use the fixed window unless variable windows are configured, in which case pick one at random.
    int currentWindow = vectorsConfiguration.getWindow();
    if (vectorsConfiguration.getVariableWindows() != null && vectorsConfiguration.getVariableWindows().length != 0) {
        currentWindow = vectorsConfiguration.getVariableWindows()[RandomUtils.nextInt(vectorsConfiguration.getVariableWindows().length)];
    }

    // Lazily create the per-thread frame holder, then make sure this thread has a frame to fill.
    if (frame == null) {
        synchronized (this) {
            if (frame == null) {
                frame = new ThreadLocal<>();
            }
        }
    }
    if (frame.get() == null) {
        frame.set(new Frame<CbowRequestMessage>(BasicSequenceProvider.getInstance().getNextValue()));
    }

    for (int i = 0; i < sequence.getElements().size(); i++) {
        nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11));

        // The remainder must be taken on the long value BEFORE narrowing to int: a cast binds
        // tighter than '%', so "(int) x % w" truncates x first and may yield a negative offset,
        // which would widen the window beyond the configured size. |x % w| < w, so both
        // Math.abs and the narrowing cast are lossless here.
        int b = (int) Math.abs(nextRandom.get() % currentWindow);
        int end = currentWindow * 2 + 1 - b;

        ShallowSequenceElement currentWord = sequence.getElementByIndex(i);
        List<Integer> intsList = new ArrayList<>();
        for (int a = b; a < end; a++) {
            // a == currentWindow corresponds to the current element itself; skip it
            if (a != currentWindow) {
                int c = i - currentWindow + a;
                if (c >= 0 && c < sequence.size()) {
                    ShallowSequenceElement lastWord = sequence.getElementByIndex(c);
                    intsList.add(lastWord.getIndex());
                }
            }
        }

        // basically it's the same as CBOW, we just add labels here
        if (sequence.getSequenceLabels() != null) {
            for (ShallowSequenceElement label : sequence.getSequenceLabels()) {
                intsList.add(label.getIndex());
            }
        } else {
            // FIXME: we probably should throw this exception earlier?
            throw new DL4JInvalidInputException("Sequence passed via RDD has no labels within, nothing to learn here");
        }

        // just converting values to int
        int[] windowWords = new int[intsList.size()];
        for (int x = 0; x < windowWords.length; x++) {
            windowWords[x] = intsList.get(x);
        }

        // neither context elements nor labels collected, nothing to train on
        if (windowWords.length < 1) {
            continue;
        }

        iterateSample(currentWord, windowWords, nextRandom, learningRate, false, 0, true, null);
    }

    // Hand the filled frame to the caller and leave a fresh one for the next sequence.
    Frame<CbowRequestMessage> currentFrame = frame.get();
    frame.set(new Frame<CbowRequestMessage>(BasicSequenceProvider.getInstance().getNextValue()));
    return currentFrame;
}
Also used : Frame(org.nd4j.parameterserver.distributed.messages.Frame) ShallowSequenceElement(org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement) ArrayList(java.util.ArrayList) DL4JInvalidInputException(org.deeplearning4j.exception.DL4JInvalidInputException) CbowRequestMessage(org.nd4j.parameterserver.distributed.messages.requests.CbowRequestMessage)

Example 3 with CbowRequestMessage

use of org.nd4j.parameterserver.distributed.messages.requests.CbowRequestMessage in project deeplearning4j by deeplearning4j.

In class SparkCBOW, method iterateSample:

/**
 * Builds a single {@link CbowRequestMessage} for the given element and stacks it onto the
 * frame held for the calling thread.
 *
 * <p>When hierarchic softmax is enabled, the element's codes and points are copied into
 * arrays of length {@code getCodeLength()}; positions whose point is negative are left at
 * their default value of zero. Otherwise both arrays are empty.
 *
 * @param currentWord     element the request is built for
 * @param windowWords     indices forwarded to the request as its context
 * @param nextRandom      random state; its current value is forwarded to the request
 * @param alpha           learning rate forwarded to the request
 * @param isInference     not used by this implementation
 * @param numLabels       not used by this implementation
 * @param trainWords      not used by this implementation
 * @param inferenceVector not used by this implementation
 */
protected void iterateSample(ShallowSequenceElement currentWord, int[] windowWords, AtomicLong nextRandom, double alpha, boolean isInference, int numLabels, boolean trainWords, INDArray inferenceVector) {
    // Without hierarchic softmax the arrays stay empty and the copy loop below never runs.
    final int length = vectorsConfiguration.isUseHierarchicSoftmax() ? currentWord.getCodeLength() : 0;
    final int[] syn1Indices = new int[length];
    final byte[] codeBits = new byte[length];

    for (int pos = 0; pos < length; pos++) {
        final int point = currentWord.getPoints().get(pos);
        if (point >= 0) {
            codeBits[pos] = currentWord.getCodes().get(pos);
            syn1Indices[pos] = point;
        }
    }

    final int negative = (int) vectorsConfiguration.getNegative();
    final CbowRequestMessage request = new CbowRequestMessage(windowWords, syn1Indices, currentWord.getIndex(), codeBits, negative, alpha, nextRandom.get());
    frame.get().stackMessage(request);
}
Also used : CbowRequestMessage(org.nd4j.parameterserver.distributed.messages.requests.CbowRequestMessage)

Aggregations

CbowRequestMessage (org.nd4j.parameterserver.distributed.messages.requests.CbowRequestMessage)3 ArrayList (java.util.ArrayList)2 ShallowSequenceElement (org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement)2 Frame (org.nd4j.parameterserver.distributed.messages.Frame)2 DL4JInvalidInputException (org.deeplearning4j.exception.DL4JInvalidInputException)1