Search in sources:

Example 6 with ShallowSequenceElement

Use of org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement in project deeplearning4j by deeplearning4j.

The class SparkDM, method frameSequence.

/**
 * Frames one labelled sequence for training: for every element, its CBOW-style context window
 * plus all of the sequence's labels are passed to {@code iterateSample}.
 *
 * @param sequence     sequence to frame; it must carry sequence labels whenever it has elements
 * @param nextRandom   shared linear-congruential RNG state; advanced once per element (and by
 *                     subsampling, when sampling is enabled)
 * @param learningRate learning rate forwarded to {@code iterateSample}
 * @return the current thread's frame (presumably populated by {@code iterateSample}); a fresh
 *         frame is installed for the next call
 * @throws DL4JInvalidInputException if the sequence has elements but no sequence labels
 */
@Override
public Frame<? extends TrainingMessage> frameSequence(Sequence<ShallowSequenceElement> sequence, AtomicLong nextRandom, double learningRate) {
    // NOTE(review): totalElementsCount is hard-coded to 10L here - confirm this is intended.
    if (vectorsConfiguration.getSampling() > 0)
        sequence = BaseSparkLearningAlgorithm.applySubsampling(sequence, nextRandom, 10L, vectorsConfiguration.getSampling());

    int currentWindow = vectorsConfiguration.getWindow();
    if (vectorsConfiguration.getVariableWindows() != null && vectorsConfiguration.getVariableWindows().length != 0) {
        currentWindow = vectorsConfiguration.getVariableWindows()[RandomUtils.nextInt(vectorsConfiguration.getVariableWindows().length)];
    }

    // NOTE(review): this double-checked initialization is only safe if the `frame` field is
    // declared volatile - confirm at the field declaration.
    if (frame == null)
        synchronized (this) {
            if (frame == null)
                frame = new ThreadLocal<>();
        }
    if (frame.get() == null)
        frame.set(new Frame<CbowRequestMessage>(BasicSequenceProvider.getInstance().getNextValue()));

    int numElements = sequence.getElements().size();

    // Labels are what this algorithm learns, so reject a non-empty unlabelled sequence up front
    // instead of re-checking on every element.
    if (numElements > 0 && sequence.getSequenceLabels() == null)
        throw new DL4JInvalidInputException("Sequence passed via RDD has no labels within, nothing to learn here");

    // Label indices are the same for every position, so collect them once.
    List<Integer> labelIndices = new ArrayList<>();
    if (sequence.getSequenceLabels() != null) {
        for (ShallowSequenceElement label : sequence.getSequenceLabels()) {
            labelIndices.add(label.getIndex());
        }
    }

    for (int i = 0; i < numElements; i++) {
        nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11));
        // Reduce in long arithmetic before narrowing: truncating to int first can produce a
        // negative value and thus a negative window offset. floorMod also keeps b within
        // [0, currentWindow) when Math.abs overflows on Long.MIN_VALUE.
        int b = (int) Math.floorMod(nextRandom.get(), (long) currentWindow);
        int end = currentWindow * 2 + 1 - b;
        ShallowSequenceElement currentWord = sequence.getElementByIndex(i);

        List<Integer> intsList = new ArrayList<>();
        for (int a = b; a < end; a++) {
            // a == currentWindow is the centre element itself
            if (a == currentWindow)
                continue;
            int c = i - currentWindow + a;
            if (c >= 0 && c < sequence.size())
                intsList.add(sequence.getElementByIndex(c).getIndex());
        }
        // basically it's the same as CBOW, we just add labels here
        intsList.addAll(labelIndices);

        int[] windowWords = intsList.stream().mapToInt(Integer::intValue).toArray();
        if (windowWords.length < 1)
            continue;
        iterateSample(currentWord, windowWords, nextRandom, learningRate, false, 0, true, null);
    }

    Frame<CbowRequestMessage> currentFrame = frame.get();
    frame.set(new Frame<CbowRequestMessage>(BasicSequenceProvider.getInstance().getNextValue()));
    return currentFrame;
}
Also used: Frame(org.nd4j.parameterserver.distributed.messages.Frame), ShallowSequenceElement(org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement), ArrayList(java.util.ArrayList), DL4JInvalidInputException(org.deeplearning4j.exception.DL4JInvalidInputException), CbowRequestMessage(org.nd4j.parameterserver.distributed.messages.requests.CbowRequestMessage)

Example 7 with ShallowSequenceElement

Use of org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement in project deeplearning4j by deeplearning4j.

The class SparkSequenceVectorsTest, method testFrequenciesCount.

/**
 * Fits sequence vectors on the cyclic fixture and checks the element frequency counter and the
 * resulting shallow vocabulary (size, frequencies and index of element "0").
 */
@Test
public void testFrequenciesCount() throws Exception {
    JavaRDD<Sequence<VocabWord>> rdd = sc.parallelize(sequencesCyclic);
    SparkSequenceVectors<VocabWord> vectors = new SparkSequenceVectors<>();
    vectors.fitSequences(rdd);

    Counter<Long> frequencies = vectors.getCounter();
    Sequence<VocabWord> reference = sequencesCyclic.get(0);
    // NOTE(review): the upper bound excludes the last element of the reference sequence - confirm intended.
    int upperBound = reference.getElements().size() - 1;

    // element "0" should have frequency of 20
    assertEquals(20, frequencies.getCount(0L), 1e-5);
    // elements 1 - 9 should have frequencies of 10
    for (int idx = 1; idx < upperBound; idx++) {
        assertEquals(10, frequencies.getCount(reference.getElementByIndex(idx).getStorageId()), 1e-5);
    }

    VocabCache<ShallowSequenceElement> vocab = vectors.getShallowVocabCache();
    assertEquals(10, vocab.numWords());

    ShallowSequenceElement elementZero = vocab.tokenFor(0L);
    ShallowSequenceElement elementOne = vocab.tokenFor(1L);
    assertNotEquals(null, elementZero);
    assertEquals(20.0, elementZero.getElementFrequency(), 1e-5);
    assertEquals(0, elementZero.getIndex());
    assertEquals(10.0, elementOne.getElementFrequency(), 1e-5);
}
Also used: ShallowSequenceElement(org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement), VocabWord(org.deeplearning4j.models.word2vec.VocabWord), Sequence(org.deeplearning4j.models.sequencevectors.sequence.Sequence), Test(org.junit.Test)

Example 8 with ShallowSequenceElement

Use of org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement in project deeplearning4j by deeplearning4j.

The class DistributedFunction, method call.

/**
 * Resolves {@code word} against the broadcast shallow vocabulary, stamps it with its vocabulary
 * index, and packages it with the vector fetched from the parameter server for that index.
 *
 * @param word element to export
 * @return container holding the (re-indexed) element and its vector
 * @throws IllegalStateException if {@code word}'s storage id is not present in the shallow
 *                               vocabulary ({@code tokenFor} returned null)
 */
@Override
public ExportContainer<T> call(T word) throws Exception {
    // lazily unwrap the broadcast vocabulary once per function instance
    if (shallowVocabCache == null)
        shallowVocabCache = shallowVocabBroadcast.getValue();

    // tokenFor may return null, e.g. for elements filtered out of the vocabulary; fail with
    // context instead of a bare NullPointerException further down.
    ShallowSequenceElement reduced = shallowVocabCache.tokenFor(word.getStorageId());
    if (reduced == null)
        throw new IllegalStateException("Element with storageId [" + word.getStorageId() + "] is not present in the shallow vocabulary");

    word.setIndex(reduced.getIndex());
    ExportContainer<T> container = new ExportContainer<>();
    container.setElement(word);
    container.setArray(VoidParameterServer.getInstance().getVector(reduced.getIndex()));
    return container;
}
Also used: ShallowSequenceElement(org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement), ExportContainer(org.deeplearning4j.spark.models.sequencevectors.export.ExportContainer)

Example 9 with ShallowSequenceElement

Use of org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement in project deeplearning4j by deeplearning4j.

The class PartitionTrainingFunction, method call.

/**
 * Trains on one Spark partition of sequences. Lazily initializes configuration, parameter server
 * and learning algorithms, maps every sequence onto the shallow vocabulary, and trains in
 * batches of up to 8 sequences.
 *
 * @param sequenceIterator the partition's sequences
 * @throws ND4JIllegalStateException if neither an elements nor a sequence learning algorithm is configured
 */
@SuppressWarnings("unchecked")
@Override
public void call(Iterator<Sequence<T>> sequenceIterator) throws Exception {
    // first we initialize
    if (vectorsConfiguration == null)
        vectorsConfiguration = configurationBroadcast.getValue();

    if (paramServer == null) {
        paramServer = VoidParameterServer.getInstance();
        // NOTE(review): the training driver comes from the elements learning algorithm, so this
        // path requires getElementsLearningAlgorithm() to be configured - confirm intended.
        if (elementsLearningAlgorithm == null)
            elementsLearningAlgorithm = instantiateAlgorithm(vectorsConfiguration.getElementsLearningAlgorithm(), SparkElementsLearningAlgorithm.class);
        driver = elementsLearningAlgorithm.getTrainingDriver();
        // FIXME: init line should probably be removed, basically init happens in VocabRddFunction
        paramServer.init(paramServerConfigurationBroadcast.getValue(), new RoutedTransport(), driver);
    }

    if (shallowVocabCache == null)
        shallowVocabCache = vocabCacheBroadcast.getValue();

    if (elementsLearningAlgorithm == null && vectorsConfiguration.getElementsLearningAlgorithm() != null)
        elementsLearningAlgorithm = instantiateAlgorithm(vectorsConfiguration.getElementsLearningAlgorithm(), SparkElementsLearningAlgorithm.class);
    if (elementsLearningAlgorithm != null)
        elementsLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);

    if (sequenceLearningAlgorithm == null && vectorsConfiguration.getSequenceLearningAlgorithm() != null)
        sequenceLearningAlgorithm = instantiateAlgorithm(vectorsConfiguration.getSequenceLearningAlgorithm(), SparkSequenceLearningAlgorithm.class);
    // configure exactly once, for both freshly created and already existing instances
    if (sequenceLearningAlgorithm != null)
        sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);

    if (elementsLearningAlgorithm == null && sequenceLearningAlgorithm == null) {
        throw new ND4JIllegalStateException("No LearningAlgorithms specified!");
    }

    final int batchSize = 8;
    List<Sequence<ShallowSequenceElement>> sequences = new ArrayList<>();
    // now we roll through Sequences and prepare/convert/"learn" them
    while (sequenceIterator.hasNext()) {
        Sequence<T> sequence = sequenceIterator.next();
        Sequence<ShallowSequenceElement> mergedSequence = new Sequence<>();
        for (T element : sequence.getElements()) {
            // it's possible to get null here, i.e. if frequency for this element is below minWordFrequency threshold
            ShallowSequenceElement reduced = shallowVocabCache.tokenFor(element.getStorageId());
            if (reduced != null)
                mergedSequence.addElement(reduced);
        }
        // do the same with labels, transfer them, if any (a sequence may have no labels at all)
        if (sequenceLearningAlgorithm != null && vectorsConfiguration.isTrainSequenceVectors()
                        && sequence.getSequenceLabels() != null) {
            for (T label : sequence.getSequenceLabels()) {
                ShallowSequenceElement reduced = shallowVocabCache.tokenFor(label.getStorageId());
                if (reduced != null)
                    mergedSequence.addSequenceLabel(reduced);
            }
        }
        sequences.add(mergedSequence);
        if (sequences.size() >= batchSize) {
            trainAllAtOnce(sequences);
            sequences.clear();
        }
    }

    if (!sequences.isEmpty()) {
        // finishing training round, to make sure we don't have trails
        trainAllAtOnce(sequences);
        sequences.clear();
    }
}

/**
 * Reflectively instantiates {@code className} via its no-arg constructor and casts it to {@code type}.
 * Any failure (null or unknown class name, instantiation error, wrong type) is wrapped in a
 * RuntimeException.
 */
private <A> A instantiateAlgorithm(String className, Class<A> type) {
    try {
        return type.cast(Class.forName(className).newInstance());
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}
Also used: ShallowSequenceElement(org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement), ArrayList(java.util.ArrayList), ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException), Sequence(org.deeplearning4j.models.sequencevectors.sequence.Sequence), RoutedTransport(org.nd4j.parameterserver.distributed.transport.RoutedTransport)

Example 10 with ShallowSequenceElement

Use of org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement in project deeplearning4j by deeplearning4j.

The class BaseSparkLearningAlgorithm, method applySubsampling.

/**
 * Randomly drops frequent elements from a sequence (word2vec-style subsampling).
 *
 * @param sequence           sequence to subsample
 * @param nextRandom         shared RNG state; advanced once per element examined
 * @param totalElementsCount corpus size used to scale the threshold
 * @param prob               subsampling threshold; when it is not positive the input is returned unchanged
 * @return a new sequence with the same id/labels holding the surviving elements, or {@code sequence}
 *         itself when subsampling is disabled
 */
public static Sequence<ShallowSequenceElement> applySubsampling(@NonNull Sequence<ShallowSequenceElement> sequence, @NonNull AtomicLong nextRandom, long totalElementsCount, double prob) {
    // written as !(prob > 0) so that a NaN threshold also disables subsampling
    if (!(prob > 0))
        return sequence;

    Sequence<ShallowSequenceElement> subsampled = new Sequence<>();
    subsampled.setSequenceId(sequence.getSequenceId());
    if (sequence.getSequenceLabels() != null)
        subsampled.setSequenceLabels(sequence.getSequenceLabels());
    if (sequence.getSequenceLabel() != null)
        subsampled.setSequenceLabel(sequence.getSequenceLabel());

    final double threshold = prob * (double) totalElementsCount;
    for (ShallowSequenceElement element : sequence.getElements()) {
        double frequency = element.getElementFrequency();
        double keepScore = (Math.sqrt(frequency / threshold) + 1) * threshold / frequency;

        nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11));
        double draw = (nextRandom.get() & 0xFFFF) / (double) 65536;

        // keep the element unless its score falls below the random draw
        if (!(keepScore < draw))
            subsampled.addElement(element);
    }
    return subsampled;
}
Also used: ShallowSequenceElement(org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement), Sequence(org.deeplearning4j.models.sequencevectors.sequence.Sequence)

Aggregations

ShallowSequenceElement (org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement): 11; Frame (org.nd4j.parameterserver.distributed.messages.Frame): 5; Sequence (org.deeplearning4j.models.sequencevectors.sequence.Sequence): 4; ArrayList (java.util.ArrayList): 3; AtomicLong (java.util.concurrent.atomic.AtomicLong): 2; ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 2; CbowRequestMessage (org.nd4j.parameterserver.distributed.messages.requests.CbowRequestMessage): 2; SkipGramRequestMessage (org.nd4j.parameterserver.distributed.messages.requests.SkipGramRequestMessage): 2; RoutedTransport (org.nd4j.parameterserver.distributed.transport.RoutedTransport): 2; DL4JInvalidInputException (org.deeplearning4j.exception.DL4JInvalidInputException): 1; Huffman (org.deeplearning4j.models.word2vec.Huffman): 1; VocabWord (org.deeplearning4j.models.word2vec.VocabWord): 1; AbstractCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache): 1; ExportContainer (org.deeplearning4j.spark.models.sequencevectors.export.ExportContainer): 1; Test (org.junit.Test): 1