
Example 1 with ShallowSequenceElement

Use of org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement in project deeplearning4j by deeplearning4j.

From class PartitionTrainingFunction, method trainAllAtOnce.

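protected void trainAllAtOnce(List<Sequence<ShallowSequenceElement>> sequences) {
    // one frame collects the training messages produced for every sequence in this partition
    Frame bigFrame = new Frame(BasicSequenceProvider.getInstance().getNextValue());
    for (Sequence<ShallowSequenceElement> sequence : sequences) {
        // learningRate is a double parameter, so pass a double literal (as TrainingFunction does)
        Frame frame = elementsLearningAlgorithm.frameSequence(sequence, new AtomicLong(119L), 25e-3);
        bigFrame.stackMessages(frame.getMessages());
    }
    // a single distributed call for the whole partition, skipped when nothing was produced
    if (bigFrame.size() > 0)
        paramServer.execDistributed(bigFrame);
}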
Also used: Frame (org.nd4j.parameterserver.distributed.messages.Frame), AtomicLong (java.util.concurrent.atomic.AtomicLong), ShallowSequenceElement (org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement)
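The batching idea in trainAllAtOnce (frame every sequence, stack all resulting messages into one frame, and make a single distributed call only if that frame is non-empty) can be shown without the nd4j parameter-server classes. The sketch below is a minimal illustration under that assumption: `BatchTraining`, its nested `MessageBatch`, and the `frameSequence`/`dispatch` callbacks are hypothetical stand-ins, not the `Frame` or `VoidParameterServer` API.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;
import java.util.function.Function;

// Hypothetical illustration of "frame everything, then send once"; not the nd4j Frame API.
final class BatchTraining {
    private BatchTraining() {}

    // Stand-in for a frame: an id plus an ordered list of messages.
    static final class MessageBatch<M> {
        private final long id;
        private final List<M> messages = new ArrayList<>();

        MessageBatch(long id) { this.id = id; }

        long getId() { return id; }

        // Appends another batch's messages, preserving their order.
        void stackMessages(List<M> more) { messages.addAll(more); }

        List<M> getMessages() { return messages; }

        int size() { return messages.size(); }
    }

    // Frames every sequence, stacks all messages into one batch and dispatches it once,
    // only if it is non-empty. Returns the number of dispatched messages.
    static <S, M> int trainAllAtOnce(List<S> sequences,
                                     Function<S, List<M>> frameSequence,
                                     Consumer<MessageBatch<M>> dispatch,
                                     long batchId) {
        MessageBatch<M> bigBatch = new MessageBatch<>(batchId);
        for (S sequence : sequences) {
            bigBatch.stackMessages(frameSequence.apply(sequence));
        }
        if (bigBatch.size() == 0) {
            return 0; // nothing to send, so no network call at all
        }
        dispatch.accept(bigBatch);
        return bigBatch.size();
    }

    public static void main(String[] args) {
        List<List<String>> sequences = List.of(List.of("a", "b"), List.of(), List.of("c"));
        List<MessageBatch<String>> sent = new ArrayList<>();
        int n = trainAllAtOnce(sequences, s -> s, sent::add, 1L);
        System.out.println(n + " messages in " + sent.size() + " dispatch(es)"); // prints "3 messages in 1 dispatch(es)"
    }
}
```

The trade-off is one larger message per partition instead of one per sequence, which reduces the number of round trips to the parameter server.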

Example 2 with ShallowSequenceElement

Use of org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement in project deeplearning4j by deeplearning4j.

From class TrainingFunction, method call.

@Override
@SuppressWarnings("unchecked")
public void call(Sequence<T> sequence) throws Exception {
    // Depending on the actual training mode, we'll go for SkipGram/CBOW/PV-DM/PV-DBOW or whatever else is configured
    if (vectorsConfiguration == null)
        vectorsConfiguration = configurationBroadcast.getValue();
    if (shallowVocabCache == null)
        shallowVocabCache = vocabCacheBroadcast.getValue();
    // Learning algorithms are instantiated AND configured before the parameter server is initialized,
    // so the training driver requested below always comes from a configured algorithm
    if (elementsLearningAlgorithm == null && vectorsConfiguration.getElementsLearningAlgorithm() != null) {
        // TODO: do ELA initialization
        try {
            elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class.forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
            elementsLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
    if (sequenceLearningAlgorithm == null && vectorsConfiguration.getSequenceLearningAlgorithm() != null) {
        // TODO: do SLA initialization
        try {
            sequenceLearningAlgorithm = (SparkSequenceLearningAlgorithm) Class.forName(vectorsConfiguration.getSequenceLearningAlgorithm()).newInstance();
            sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
    if (elementsLearningAlgorithm == null && sequenceLearningAlgorithm == null) {
        throw new ND4JIllegalStateException("No LearningAlgorithms specified!");
    }
    if (paramServer == null) {
        paramServer = VoidParameterServer.getInstance();
        // the training driver is provided by the elements learning algorithm, which this hook still requires
        driver = elementsLearningAlgorithm.getTrainingDriver();
        // FIXME: init line should probably be removed, basically init happens in VocabRddFunction
        paramServer.init(paramServerConfigurationBroadcast.getValue(), new RoutedTransport(), driver);
    }
    // At this point everything is ready for training. The only limitation is that our sequence is
    // detached from the actual vocabulary, so we need to merge it back virtually
    Sequence<ShallowSequenceElement> mergedSequence = new Sequence<>();
    for (T element : sequence.getElements()) {
        // it's possible to get null here, i.e. if frequency for this element is below minWordFrequency threshold
        ShallowSequenceElement reduced = shallowVocabCache.tokenFor(element.getStorageId());
        if (reduced != null)
            mergedSequence.addElement(reduced);
    }
    // do the same with labels, transfer them, if any
    if (sequenceLearningAlgorithm != null && vectorsConfiguration.isTrainSequenceVectors()) {
        for (T label : sequence.getSequenceLabels()) {
            ShallowSequenceElement reduced = shallowVocabCache.tokenFor(label.getStorageId());
            if (reduced != null)
                mergedSequence.addSequenceLabel(reduced);
        }
    }
    // FIXME: temporary hook
    // check the merged sequence: all of its elements may have been dropped by the vocabulary lookup above
    if (mergedSequence.size() > 0)
        paramServer.execDistributed(elementsLearningAlgorithm.frameSequence(mergedSequence, new AtomicLong(119), 25e-3));
    else
        log.warn("Skipping empty sequence...");
}
Also used: AtomicLong (java.util.concurrent.atomic.AtomicLong), ShallowSequenceElement (org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement), ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException), Sequence (org.deeplearning4j.models.sequencevectors.sequence.Sequence), RoutedTransport (org.nd4j.parameterserver.distributed.transport.RoutedTransport)
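Two steps of call() are easy to isolate: instantiating a learning algorithm from its class name, and the "virtual" merge that drops elements missing from the shallow vocabulary. A minimal sketch, assuming the vocabulary can be modelled as a plain `Map` from storage id to vocabulary index; the `VocabMerge` class and its methods are hypothetical. It uses `getDeclaredConstructor().newInstance()`, the replacement for `Class.newInstance()`, which has been deprecated since Java 9.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

// Hypothetical helpers mirroring two steps of call(); not DL4J classes.
final class VocabMerge {
    private VocabMerge() {}

    // Instantiates a class by name through its no-arg constructor and checks its type.
    // Like call(), callers should check the name for null first: Class.forName(null) throws NPE.
    static <T> T instantiate(String className, Class<T> type) {
        try {
            Object instance = Class.forName(className).getDeclaredConstructor().newInstance();
            return type.cast(instance);
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("Cannot instantiate " + className, e);
        }
    }

    // Maps each storage id to its vocabulary index; ids missing from the vocabulary
    // (for example, elements below the minimum frequency) are skipped.
    static List<Integer> mergeWithVocab(List<Long> storageIds, Map<Long, Integer> vocabIndex) {
        List<Integer> merged = new ArrayList<>();
        for (Long id : storageIds) {
            Integer index = vocabIndex.get(id);
            if (index != null) {
                merged.add(index);
            }
        }
        return merged;
    }

    public static void main(String[] args) {
        StringBuilder sb = instantiate("java.lang.StringBuilder", StringBuilder.class);
        System.out.println(sb.append("loaded")); // prints "loaded"
        List<Integer> merged = mergeWithVocab(List.of(10L, 99L, 20L, 10L), Map.of(10L, 0, 20L, 1));
        System.out.println(merged); // prints [0, 1, 0]
        // A merge can come back empty; such a sequence should be skipped, not sent for training.
        System.out.println(mergeWithVocab(List.of(99L), Map.of(10L, 0)).isEmpty()); // prints true
    }
}
```

The last line of `main` is why the emptiness check belongs on the merged sequence rather than on the raw input.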

Example 3 with ShallowSequenceElement

Use of org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement in project deeplearning4j by deeplearning4j.

From class SparkCBOW, method frameSequence.

@Override
public Frame<? extends TrainingMessage> frameSequence(Sequence<ShallowSequenceElement> sequence, AtomicLong nextRandom, double learningRate) {
    // FIXME: totalElementsCount should have real value
    if (vectorsConfiguration.getSampling() > 0)
        sequence = BaseSparkLearningAlgorithm.applySubsampling(sequence, nextRandom, 10L, vectorsConfiguration.getSampling());
    int currentWindow = vectorsConfiguration.getWindow();
    if (vectorsConfiguration.getVariableWindows() != null && vectorsConfiguration.getVariableWindows().length != 0) {
        currentWindow = vectorsConfiguration.getVariableWindows()[RandomUtils.nextInt(vectorsConfiguration.getVariableWindows().length)];
    }
    if (frame == null)
        synchronized (this) {
            if (frame == null)
                frame = new ThreadLocal<>();
        }
    if (frame.get() == null)
        frame.set(new Frame<CbowRequestMessage>(BasicSequenceProvider.getInstance().getNextValue()));
    for (int i = 0; i < sequence.getElements().size(); i++) {
        nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11));
        // take the modulo before the cast, as SparkSkipGram does; casting first can yield a negative b
        int b = (int) (nextRandom.get() % currentWindow);
        int end = currentWindow * 2 + 1 - b;
        ShallowSequenceElement currentWord = sequence.getElementByIndex(i);
        List<Integer> intsList = new ArrayList<>();
        for (int a = b; a < end; a++) {
            if (a != currentWindow) {
                int c = i - currentWindow + a;
                if (c >= 0 && c < sequence.size()) {
                    ShallowSequenceElement lastWord = sequence.getElementByIndex(c);
                    intsList.add(lastWord.getIndex());
                }
            }
        }
        // convert the boxed List<Integer> into a primitive int[]
        int[] windowWords = new int[intsList.size()];
        for (int x = 0; x < windowWords.length; x++) {
            windowWords[x] = intsList.get(x);
        }
        if (windowWords.length < 1)
            continue;
        iterateSample(currentWord, windowWords, nextRandom, learningRate, false, 0, true, null);
    }
    Frame<CbowRequestMessage> currentFrame = frame.get();
    frame.set(new Frame<CbowRequestMessage>(BasicSequenceProvider.getInstance().getNextValue()));
    return currentFrame;
}
Also used: Frame (org.nd4j.parameterserver.distributed.messages.Frame), ShallowSequenceElement (org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement), ArrayList (java.util.ArrayList), CbowRequestMessage (org.nd4j.parameterserver.distributed.messages.requests.CbowRequestMessage)
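The window arithmetic in SparkCBOW.frameSequence follows the word2vec scheme: a linear congruential step updates `nextRandom`, `b = nextRandom % window` shrinks the effective window, and the positions `i - window + a` for `a` in `[b, 2 * window + 1 - b)`, other than the centre, are collected as context. The sketch below reproduces only that arithmetic on plain `int` vocabulary indices, with the modulo taken before the cast; `CbowWindows` is a hypothetical helper, and subsampling, variable windows, and the `iterateSample` call are left out. frameSequence skips positions whose context comes back empty.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical helper reproducing only the CBOW window arithmetic from frameSequence.
final class CbowWindows {
    private CbowWindows() {}

    // The linear congruential step used by the snippets above.
    static long nextRandom(long r) {
        return Math.abs(r * 25214903917L + 11);
    }

    // For each position i, returns the vocabulary indices of its CBOW context words.
    static List<int[]> contexts(int[] sequence, int window, long seed) {
        List<int[]> result = new ArrayList<>();
        long r = seed;
        for (int i = 0; i < sequence.length; i++) {
            r = nextRandom(r);
            // Modulo before the cast: casting a large long to int first can give a negative b.
            int b = (int) (r % window);
            int end = window * 2 + 1 - b;
            List<Integer> context = new ArrayList<>();
            for (int a = b; a < end; a++) {
                if (a == window) {
                    continue; // skip the centre word itself
                }
                int c = i - window + a;
                if (c >= 0 && c < sequence.length) {
                    context.add(sequence[c]);
                }
            }
            result.add(context.stream().mapToInt(Integer::intValue).toArray());
        }
        return result;
    }

    public static void main(String[] args) {
        // With window = 1, b is always 0, so each word sees exactly its immediate neighbours.
        for (int[] ctx : contexts(new int[] {5, 6, 7}, 1, 119L)) {
            System.out.println(Arrays.toString(ctx));
        }
        // prints [6], then [5, 7], then [6]
    }
}
```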

Example 4 with ShallowSequenceElement

Use of org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement in project deeplearning4j by deeplearning4j.

From class SparkSkipGram, method frameSequence.

@Override
public Frame<? extends TrainingMessage> frameSequence(Sequence<ShallowSequenceElement> sequence, AtomicLong nextRandom, double learningRate) {
    // FIXME: totalElementsCount should have real value
    if (vectorsConfiguration.getSampling() > 0)
        sequence = BaseSparkLearningAlgorithm.applySubsampling(sequence, nextRandom, 10L, vectorsConfiguration.getSampling());
    int currentWindow = vectorsConfiguration.getWindow();
    if (vectorsConfiguration.getVariableWindows() != null && vectorsConfiguration.getVariableWindows().length != 0) {
        currentWindow = vectorsConfiguration.getVariableWindows()[RandomUtils.nextInt(vectorsConfiguration.getVariableWindows().length)];
    }
    if (frame == null)
        synchronized (this) {
            if (frame == null)
                frame = new ThreadLocal<>();
        }
    if (frame.get() == null)
        frame.set(new Frame<SkipGramRequestMessage>(BasicSequenceProvider.getInstance().getNextValue()));
    for (int i = 0; i < sequence.size(); i++) {
        nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11));
        ShallowSequenceElement word = sequence.getElementByIndex(i);
        if (word == null)
            continue;
        int b = (int) (nextRandom.get() % currentWindow);
        int end = currentWindow * 2 + 1 - b;
        for (int a = b; a < end; a++) {
            if (a != currentWindow) {
                int c = i - currentWindow + a;
                if (c >= 0 && c < sequence.size()) {
                    ShallowSequenceElement lastWord = sequence.getElementByIndex(c);
                    iterateSample(word, lastWord, nextRandom, learningRate);
                    nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11));
                }
            }
        }
    }
    // at this point the ThreadLocal frame holds this sequence's messages: swap in a fresh frame for the next call
    // and return the filled one to the caller, which sends it to VoidParameterServer for processing
    Frame<SkipGramRequestMessage> currentFrame = frame.get();
    frame.set(new Frame<SkipGramRequestMessage>(BasicSequenceProvider.getInstance().getNextValue()));
    return currentFrame;
}
Also used: Frame (org.nd4j.parameterserver.distributed.messages.Frame), ShallowSequenceElement (org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement), SkipGramRequestMessage (org.nd4j.parameterserver.distributed.messages.requests.SkipGramRequestMessage)
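SparkSkipGram.frameSequence uses the same window arithmetic as the CBOW version, but instead of collecting one context per position it emits one (word, context word) training pair per neighbour, passing each to `iterateSample`, and advances `nextRandom` after every pair. A minimal sketch of just the pair generation, assuming plain `int` vocabulary indices; `SkipGramPairs` is a hypothetical helper, and subsampling and variable windows are omitted.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper reproducing the skip-gram pair generation from frameSequence.
final class SkipGramPairs {
    private SkipGramPairs() {}

    // The linear congruential step used by the snippets above.
    static long nextRandom(long r) {
        return Math.abs(r * 25214903917L + 11);
    }

    // Returns (word, context word) pairs as two-element arrays, in generation order.
    static List<int[]> pairs(int[] sequence, int window, long seed) {
        List<int[]> result = new ArrayList<>();
        long r = seed;
        for (int i = 0; i < sequence.length; i++) {
            r = nextRandom(r);
            int b = (int) (r % window);
            int end = window * 2 + 1 - b;
            for (int a = b; a < end; a++) {
                if (a == window) {
                    continue; // the centre word is never its own context
                }
                int c = i - window + a;
                if (c >= 0 && c < sequence.length) {
                    result.add(new int[] {sequence[i], sequence[c]});
                    r = nextRandom(r); // skip-gram also advances the generator after each pair
                }
            }
        }
        return result;
    }

    public static void main(String[] args) {
        for (int[] p : pairs(new int[] {5, 6, 7}, 1, 119L)) {
            System.out.println(p[0] + " -> " + p[1]);
        }
        // prints 5 -> 6, 6 -> 5, 6 -> 7, 7 -> 6 (one per line)
    }
}
```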

Example 5 with ShallowSequenceElement

Use of org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement in project deeplearning4j by deeplearning4j.

From class SparkSequenceVectors, method buildShallowVocabCache.

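/**
 * This method builds a shallow vocabulary and the Huffman tree for it.
 *
 * @param counter element frequencies, keyed by element storage id
 * @return vocabulary cache of ShallowSequenceElements with Huffman indexes applied
 */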
protected VocabCache<ShallowSequenceElement> buildShallowVocabCache(Counter<Long> counter) {
    // TODO: need simplified cache here, that will operate on Long instead of string labels
    VocabCache<ShallowSequenceElement> vocabCache = new AbstractCache<>();
    for (Long id : counter.keySet()) {
        ShallowSequenceElement shallowElement = new ShallowSequenceElement(counter.getCount(id), id);
        vocabCache.addToken(shallowElement);
    }
    // building huffman tree
    Huffman huffman = new Huffman(vocabCache.vocabWords());
    huffman.build();
    huffman.applyIndexes(vocabCache);
    return vocabCache;
}
Also used: ShallowSequenceElement (org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement), Huffman (org.deeplearning4j.models.word2vec.Huffman), AbstractCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache)
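buildShallowVocabCache delegates tree construction to DL4J's `Huffman` class. For intuition, here is a generic textbook Huffman construction over the same kind of input (element id to frequency): the two least frequent nodes are merged repeatedly, and codes are read off the root-to-leaf paths, so frequent elements get shorter codes. This is a sketch, not the `Huffman` implementation; `HuffmanSketch`, the `Map<Long, Long>` input, and the tie-breaking rule are assumptions.

```java
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.TreeMap;

// Generic textbook Huffman coding over element frequencies; not DL4J's Huffman class.
final class HuffmanSketch {
    private HuffmanSketch() {}

    private static final class Node {
        final long count;
        final Long id;     // null for internal nodes
        final Node left;
        final Node right;
        final long order;  // creation order, used to break frequency ties deterministically

        Node(long count, Long id, Node left, Node right, long order) {
            this.count = count;
            this.id = id;
            this.left = left;
            this.right = right;
            this.order = order;
        }
    }

    // Maps each element id to its Huffman code as a "0"/"1" string.
    static Map<Long, String> codes(Map<Long, Long> counts) {
        Map<Long, String> result = new HashMap<>();
        PriorityQueue<Node> queue = new PriorityQueue<>(
                Comparator.comparingLong((Node n) -> n.count).thenComparingLong(n -> n.order));
        long order = 0;
        for (Map.Entry<Long, Long> e : new TreeMap<>(counts).entrySet()) {
            queue.add(new Node(e.getValue(), e.getKey(), null, null, order++));
        }
        if (queue.isEmpty()) {
            return result;
        }
        if (queue.size() == 1) {
            result.put(queue.peek().id, "0"); // a lone element still needs a one-bit code
            return result;
        }
        // Repeatedly merge the two least frequent nodes until only the root remains.
        while (queue.size() > 1) {
            Node first = queue.poll();
            Node second = queue.poll();
            queue.add(new Node(first.count + second.count, null, first, second, order++));
        }
        assign(queue.poll(), "", result);
        return result;
    }

    // Walks the tree: left edges append "0", right edges append "1".
    private static void assign(Node node, String prefix, Map<Long, String> out) {
        if (node.id != null) {
            out.put(node.id, prefix);
            return;
        }
        assign(node.left, prefix + "0", out);
        assign(node.right, prefix + "1", out);
    }

    public static void main(String[] args) {
        Map<Long, Long> counts = Map.of(1L, 5L, 2L, 9L, 3L, 12L, 4L, 13L, 5L, 16L, 6L, 45L);
        System.out.println(new TreeMap<>(codes(counts)));
        // prints {1=1100, 2=1101, 3=100, 4=101, 5=111, 6=0}
    }
}
```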

Aggregations

ShallowSequenceElement (org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement): 11
Frame (org.nd4j.parameterserver.distributed.messages.Frame): 5
Sequence (org.deeplearning4j.models.sequencevectors.sequence.Sequence): 4
ArrayList (java.util.ArrayList): 3
AtomicLong (java.util.concurrent.atomic.AtomicLong): 2
ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 2
CbowRequestMessage (org.nd4j.parameterserver.distributed.messages.requests.CbowRequestMessage): 2
SkipGramRequestMessage (org.nd4j.parameterserver.distributed.messages.requests.SkipGramRequestMessage): 2
RoutedTransport (org.nd4j.parameterserver.distributed.transport.RoutedTransport): 2
DL4JInvalidInputException (org.deeplearning4j.exception.DL4JInvalidInputException): 1
Huffman (org.deeplearning4j.models.word2vec.Huffman): 1
VocabWord (org.deeplearning4j.models.word2vec.VocabWord): 1
AbstractCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache): 1
ExportContainer (org.deeplearning4j.spark.models.sequencevectors.export.ExportContainer): 1
Test (org.junit.Test): 1