Search in sources :

Example 16 with Sequence

use of org.deeplearning4j.models.sequencevectors.sequence.Sequence in project deeplearning4j by deeplearning4j.

the class PartitionTrainingFunction method call.

/**
 * Processes one iterator of sequences: lazily initializes the shared state held by this
 * function, then converts each incoming sequence into a sequence of
 * {@link ShallowSequenceElement}s and hands them to {@code trainAllAtOnce} in small batches.
 *
 * <p>Initialization, performed only for fields that are still {@code null}:
 * <ul>
 *   <li>the vectors configuration and the shallow vocab cache are read from their broadcasts;</li>
 *   <li>the parameter server instance is obtained and initialized with the training driver of
 *       the elements learning algorithm;</li>
 *   <li>the elements/sequence learning algorithms are instantiated reflectively from the class
 *       names stored in the vectors configuration, and each available algorithm is configured
 *       with the shallow vocab cache and the vectors configuration.</li>
 * </ul>
 *
 * <p>Elements (and labels) for which the shallow vocab cache returns {@code null} are dropped
 * from the converted sequence.
 *
 * @param sequenceIterator iterator over the sequences to train on
 * @throws RuntimeException if a learning algorithm class cannot be loaded or instantiated
 * @throws ND4JIllegalStateException if neither an elements nor a sequence learning algorithm
 *         is available after initialization
 * @throws Exception declared by the overridden method signature
 */
@SuppressWarnings("unchecked")
@Override
public void call(Iterator<Sequence<T>> sequenceIterator) throws Exception {
    // number of converted sequences accumulated before each trainAllAtOnce call
    final int batchSize = 8;

    /*
     * first we initialize
     */
    if (vectorsConfiguration == null) {
        vectorsConfiguration = configurationBroadcast.getValue();
    }

    if (paramServer == null) {
        paramServer = VoidParameterServer.getInstance();

        // the training driver comes from the elements learning algorithm, so it must exist here
        if (elementsLearningAlgorithm == null) {
            try {
                elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class.forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        driver = elementsLearningAlgorithm.getTrainingDriver();

        // FIXME: init line should probably be removed, basically init happens in VocabRddFunction
        paramServer.init(paramServerConfigurationBroadcast.getValue(), new RoutedTransport(), driver);
    }

    if (shallowVocabCache == null) {
        shallowVocabCache = vocabCacheBroadcast.getValue();
    }

    // reachable when the parameter server was already set but the algorithm was not
    if (elementsLearningAlgorithm == null && vectorsConfiguration.getElementsLearningAlgorithm() != null) {
        // TODO: do ELA initialization
        try {
            elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class.forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    if (elementsLearningAlgorithm != null) {
        elementsLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
    }

    if (sequenceLearningAlgorithm == null && vectorsConfiguration.getSequenceLearningAlgorithm() != null) {
        // TODO: do SLA initialization
        try {
            sequenceLearningAlgorithm = (SparkSequenceLearningAlgorithm) Class.forName(vectorsConfiguration.getSequenceLearningAlgorithm()).newInstance();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    // single configuration point: covers both a pre-existing and a freshly created instance
    if (sequenceLearningAlgorithm != null) {
        sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
    }

    if (elementsLearningAlgorithm == null && sequenceLearningAlgorithm == null) {
        throw new ND4JIllegalStateException("No LearningAlgorithms specified!");
    }

    List<Sequence<ShallowSequenceElement>> sequences = new ArrayList<>();

    // now we roll through Sequences and prepare/convert/"learn" them
    while (sequenceIterator.hasNext()) {
        Sequence<T> sequence = sequenceIterator.next();

        Sequence<ShallowSequenceElement> mergedSequence = new Sequence<>();
        for (T element : sequence.getElements()) {
            // it's possible to get null here, i.e. if frequency for this element is below minWordFrequency threshold
            ShallowSequenceElement reduced = shallowVocabCache.tokenFor(element.getStorageId());

            if (reduced != null) {
                mergedSequence.addElement(reduced);
            }
        }

        // do the same with labels, transfer them, if any
        if (sequenceLearningAlgorithm != null && vectorsConfiguration.isTrainSequenceVectors()) {
            for (T label : sequence.getSequenceLabels()) {
                ShallowSequenceElement reduced = shallowVocabCache.tokenFor(label.getStorageId());

                if (reduced != null) {
                    mergedSequence.addSequenceLabel(reduced);
                }
            }
        }

        sequences.add(mergedSequence);
        if (sequences.size() >= batchSize) {
            trainAllAtOnce(sequences);
            sequences.clear();
        }
    }

    if (!sequences.isEmpty()) {
        // finishing training round, to make sure we don't have trails
        trainAllAtOnce(sequences);
        sequences.clear();
    }
}
Also used : ShallowSequenceElement(org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement) ArrayList(java.util.ArrayList) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) Sequence(org.deeplearning4j.models.sequencevectors.sequence.Sequence) RoutedTransport(org.nd4j.parameterserver.distributed.transport.RoutedTransport)

Example 17 with Sequence

use of org.deeplearning4j.models.sequencevectors.sequence.Sequence in project deeplearning4j by deeplearning4j.

the class BaseSparkLearningAlgorithm method applySubsampling.

/**
 * Applies frequency-based subsampling to a sequence.
 *
 * <p>When {@code prob} is not greater than zero the input sequence is returned as is.
 * Otherwise a new sequence is built that carries over the sequence id, the sequence labels and
 * the sequence label (the latter two only when non-null), and contains those elements that
 * survive the random draw. For each element a keep score is computed from its frequency,
 * {@code prob} and {@code totalElementsCount}; {@code nextRandom} is advanced with a linear
 * congruential step and the element is skipped when the score is below the draw taken from the
 * low 16 bits of the new state.
 *
 * @param sequence source sequence; must not be null
 * @param nextRandom random state, updated once per element; must not be null
 * @param totalElementsCount total number of elements used to scale the threshold
 * @param prob subsampling threshold; subsampling is active only when greater than zero
 * @return the subsampled copy, or {@code sequence} itself when subsampling is inactive
 */
public static Sequence<ShallowSequenceElement> applySubsampling(@NonNull Sequence<ShallowSequenceElement> sequence, @NonNull AtomicLong nextRandom, long totalElementsCount, double prob) {
    // written as a negation so that a NaN threshold also leaves the input untouched
    if (!(prob > 0)) {
        return sequence;
    }

    Sequence<ShallowSequenceElement> kept = new Sequence<>();
    kept.setSequenceId(sequence.getSequenceId());

    if (sequence.getSequenceLabels() != null) {
        kept.setSequenceLabels(sequence.getSequenceLabels());
    }

    if (sequence.getSequenceLabel() != null) {
        kept.setSequenceLabel(sequence.getSequenceLabel());
    }

    // identical for every element, so computed once up front
    final double scaledTotal = prob * (double) totalElementsCount;

    for (ShallowSequenceElement candidate : sequence.getElements()) {
        double keepScore = (Math.sqrt(candidate.getElementFrequency() / scaledTotal) + 1) * scaledTotal / candidate.getElementFrequency();

        // advance the shared random state, then take its low 16 bits as a draw in [0, 1)
        nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11));
        double draw = (nextRandom.get() & 0xFFFF) / (double) 65536;

        // negated comparison: an element whose score is NaN is kept, never skipped
        if (!(keepScore < draw)) {
            kept.addElement(candidate);
        }
    }

    return kept;
}
Also used : ShallowSequenceElement(org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement) Sequence(org.deeplearning4j.models.sequencevectors.sequence.Sequence)

Example 18 with Sequence

use of org.deeplearning4j.models.sequencevectors.sequence.Sequence in project deeplearning4j by deeplearning4j.

the class SparkWord2Vec method fitSentences.

/**
 * Fits the model on an RDD of raw sentences.
 *
 * <p>The configuration is validated, the environment is broadcast through a
 * {@link JavaSparkContext} wrapped around the RDD's own context, each sentence is mapped
 * through a {@link TokenizerFunction}, and the resulting RDD of sequences is passed to
 * {@code super.fitSequences}.
 *
 * @param sentences RDD with one sentence per record
 */
public void fitSentences(JavaRDD<String> sentences) {
    validateConfiguration();

    // wrap the context the RDD already lives in, so the broadcasts go through it
    final JavaSparkContext sparkContext = new JavaSparkContext(sentences.context());
    broadcastEnvironment(sparkContext);

    // tokenization turns each String into a Sequence<VocabWord>
    JavaRDD<Sequence<VocabWord>> tokenized = sentences.map(new TokenizerFunction(configurationBroadcast));

    // from here on it is plain sequence training
    super.fitSequences(tokenized);
}
Also used : TokenizerFunction(org.deeplearning4j.spark.models.sequencevectors.functions.TokenizerFunction) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) Sequence(org.deeplearning4j.models.sequencevectors.sequence.Sequence)

Aggregations

Sequence (org.deeplearning4j.models.sequencevectors.sequence.Sequence)18 VocabWord (org.deeplearning4j.models.word2vec.VocabWord)11 Test (org.junit.Test)5 JavaSparkContext (org.apache.spark.api.java.JavaSparkContext)4 ShallowSequenceElement (org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement)4 BasicLineIterator (org.deeplearning4j.text.sentenceiterator.BasicLineIterator)4 SentenceIterator (org.deeplearning4j.text.sentenceiterator.SentenceIterator)4 ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException)3 ArrayList (java.util.ArrayList)2 AtomicBoolean (java.util.concurrent.atomic.AtomicBoolean)2 ClassPathResource (org.datavec.api.util.ClassPathResource)2 SequenceIterator (org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator)2 AbstractSequenceIterator (org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator)2 SentenceTransformer (org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer)2 AbstractCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache)2 FileLabelAwareIterator (org.deeplearning4j.text.documentiterator.FileLabelAwareIterator)2 MutipleEpochsSentenceIterator (org.deeplearning4j.text.sentenceiterator.MutipleEpochsSentenceIterator)2 PrefetchingSentenceIterator (org.deeplearning4j.text.sentenceiterator.PrefetchingSentenceIterator)2 RoutedTransport (org.nd4j.parameterserver.distributed.transport.RoutedTransport)2 List (java.util.List)1