Search in sources :

Example 6 with Sequence

use of org.deeplearning4j.models.sequencevectors.sequence.Sequence in project deeplearning4j by deeplearning4j.

the class KeySequenceConvertFunction method call.

/**
 * Turns a pair of strings into a labelled sequence of words.
 *
 * <p>The first tuple element is wrapped into a {@code VocabWord} and added as a sequence label.
 * The second tuple element is tokenized, and every token that is neither {@code null} nor empty
 * is appended to the sequence, in the order the tokenizer returned it.
 *
 * @param pair tuple whose first element becomes the label and whose second element is tokenized
 * @return the sequence carrying the label and the accepted tokens
 * @throws Exception if any of the calls made here fails
 */
@Override
public Sequence<VocabWord> call(Tuple2<String, String> pair) throws Exception {
    Sequence<VocabWord> result = new Sequence<>();
    VocabWord label = new VocabWord(1.0, pair._1());
    result.addSequenceLabel(label);

    // The tokenizer factory is created on first use only.
    if (tokenizerFactory == null) {
        instantiateTokenizerFactory();
    }

    for (String candidate : tokenizerFactory.create(pair._2()).getTokens()) {
        boolean usable = candidate != null && !candidate.isEmpty();
        if (usable) {
            result.addElement(new VocabWord(1.0, candidate));
        }
    }
    return result;
}
Also used : VocabWord(org.deeplearning4j.models.word2vec.VocabWord) Sequence(org.deeplearning4j.models.sequencevectors.sequence.Sequence)

Example 7 with Sequence

use of org.deeplearning4j.models.sequencevectors.sequence.Sequence in project deeplearning4j by deeplearning4j.

the class SparkSequenceVectorsTest method testFrequenciesCount.

/**
 * Fits sequence vectors on the {@code sequencesCyclic} fixture and verifies both the element
 * counter and the shallow vocabulary built from it.
 */
@Test
public void testFrequenciesCount() throws Exception {
    JavaRDD<Sequence<VocabWord>> rdd = sc.parallelize(sequencesCyclic);
    SparkSequenceVectors<VocabWord> vectors = new SparkSequenceVectors<>();
    vectors.fitSequences(rdd);

    Counter<Long> frequencies = vectors.getCounter();

    // element "0" should have frequency of 20
    assertEquals(20, frequencies.getCount(0L), 1e-5);

    // elements 1 - 9 should have frequencies of 10
    // (the last index is skipped: it holds the repeated "0" element)
    Sequence<VocabWord> reference = sequencesCyclic.get(0);
    int limit = reference.getElements().size() - 1;
    for (int idx = 1; idx < limit; idx++) {
        assertEquals(10, frequencies.getCount(reference.getElementByIndex(idx).getStorageId()), 1e-5);
    }

    // the shallow vocabulary must hold the 10 distinct elements
    VocabCache<ShallowSequenceElement> vocabulary = vectors.getShallowVocabCache();
    assertEquals(10, vocabulary.numWords());

    ShallowSequenceElement elementZero = vocabulary.tokenFor(0L);
    ShallowSequenceElement elementOne = vocabulary.tokenFor(1L);

    assertNotEquals(null, elementZero);
    assertEquals(20.0, elementZero.getElementFrequency(), 1e-5);
    assertEquals(0, elementZero.getIndex());
    assertEquals(10.0, elementOne.getElementFrequency(), 1e-5);
}
Also used : ShallowSequenceElement(org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) Sequence(org.deeplearning4j.models.sequencevectors.sequence.Sequence) Test(org.junit.Test)

Example 8 with Sequence

use of org.deeplearning4j.models.sequencevectors.sequence.Sequence in project deeplearning4j by deeplearning4j.

the class SparkSequenceVectorsTest method setUp.

/**
 * Prepares the test fixture: builds {@code sequencesCyclic} once (only while it is still
 * {@code null}) and creates a new local Spark context on every invocation.
 */
@Before
public void setUp() throws Exception {
    if (sequencesCyclic == null) {
        sequencesCyclic = new ArrayList<>();
        // 10 sequences in total
        for (int seqIdx = 0; seqIdx < 10; seqIdx++) {
            Sequence<VocabWord> cyclic = new Sequence<>();
            // elements "0" .. "9", one occurrence each per sequence
            for (int id = 0; id < 10; id++) {
                cyclic.addElement(new VocabWord(1.0, String.valueOf(id), (long) id));
            }
            // a second occurrence of element "0", so it ends up with a frequency of 20
            cyclic.addElement(new VocabWord(1.0, "0", 0L));
            sequencesCyclic.add(cyclic);
        }
    }

    SparkConf conf = new SparkConf()
            .setMaster("local[8]")
            .setAppName("SeqVecTests");
    sc = new JavaSparkContext(conf);
}
Also used : VocabWord(org.deeplearning4j.models.word2vec.VocabWord) Sequence(org.deeplearning4j.models.sequencevectors.sequence.Sequence) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) SparkConf(org.apache.spark.SparkConf) Before(org.junit.Before)

Example 9 with Sequence

use of org.deeplearning4j.models.sequencevectors.sequence.Sequence in project deeplearning4j by deeplearning4j.

the class ParagraphVectors method inferVector.

/**
 * Calculates the inferred vector for the given document.
 *
 * <p>If no sequence learning algorithm is set yet, a PV-DM ({@code DM}) learner is created,
 * configured with the current vocab, lookup table and configuration, and stored for later calls.
 * The creation happens inside a block synchronized on this instance.
 *
 * @param document non-null, non-empty list of words to infer a vector for
 * @param learningRate learning rate handed to the learner's {@code inferSequence}
 * @param minLearningRate minimal learning rate handed to the learner's {@code inferSequence}
 * @param iterations number of iterations handed to the learner's {@code inferSequence}
 * @return the array returned by the learner's {@code inferSequence} call
 * @throws ND4JIllegalStateException if {@code document} is empty
 */
public INDArray inferVector(@NonNull List<VocabWord> document, double learningRate, double minLearningRate, int iterations) {
    // Snapshot the field once; the snapshot (not the field) is used for inference below, so a
    // concurrent change of the field cannot swap the learner mid-call.
    // NOTE(review): lock-free reads of this field are only safely published if it is volatile - confirm.
    SequenceLearningAlgorithm<VocabWord> learner = sequenceLearningAlgorithm;
    if (learner == null) {
        synchronized (this) {
            // Re-read under the lock: another thread may have installed a learner meanwhile.
            learner = sequenceLearningAlgorithm;
            if (learner == null) {
                log.info("Creating new PV-DM learner...");
                learner = new DM<VocabWord>();
                learner.configure(vocab, lookupTable, configuration);
                sequenceLearningAlgorithm = learner;
            }
        }
    }

    if (document.isEmpty())
        throw new ND4JIllegalStateException("Impossible to apply inference to empty list of words");

    Sequence<VocabWord> sequence = new Sequence<>();
    sequence.addElements(document);
    // The sequence gets a random integer, rendered as a string, as its label.
    sequence.setSequenceLabel(new VocabWord(1.0, String.valueOf(new Random().nextInt())));

    initLearners();

    return learner.inferSequence(sequence, seed, learningRate, minLearningRate, iterations);
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) Sequence(org.deeplearning4j.models.sequencevectors.sequence.Sequence)

Example 10 with Sequence

use of org.deeplearning4j.models.sequencevectors.sequence.Sequence in project deeplearning4j by deeplearning4j.

the class SentenceTransformer method transformToSequence.

/**
 * Tokenizes the given string and collects the usable tokens into a sequence.
 *
 * <p>Tokens that are {@code null}, empty, or empty after trimming are skipped. Every other token
 * is appended untrimmed as a {@code VocabWord}. The sequence id is taken from
 * {@code sentenceCounter}, which is incremented on every call.
 *
 * @param object text to tokenize
 * @return the sequence of accepted tokens with its sequence id set
 */
@Override
public Sequence<VocabWord> transformToSequence(String object) {
    Sequence<VocabWord> result = new Sequence<>();

    for (String candidate : tokenizerFactory.create(object).getTokens()) {
        // trim().isEmpty() also covers the plain empty string
        boolean blank = candidate == null || candidate.trim().isEmpty();
        if (!blank) {
            result.addElement(new VocabWord(1.0, candidate));
        }
    }

    result.setSequenceId(sentenceCounter.getAndIncrement());
    return result;
}
Also used : VocabWord(org.deeplearning4j.models.word2vec.VocabWord) Sequence(org.deeplearning4j.models.sequencevectors.sequence.Sequence) Tokenizer(org.deeplearning4j.text.tokenization.tokenizer.Tokenizer)

Aggregations

Sequence (org.deeplearning4j.models.sequencevectors.sequence.Sequence)18 VocabWord (org.deeplearning4j.models.word2vec.VocabWord)11 Test (org.junit.Test)5 JavaSparkContext (org.apache.spark.api.java.JavaSparkContext)4 ShallowSequenceElement (org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement)4 BasicLineIterator (org.deeplearning4j.text.sentenceiterator.BasicLineIterator)4 SentenceIterator (org.deeplearning4j.text.sentenceiterator.SentenceIterator)4 ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException)3 ArrayList (java.util.ArrayList)2 AtomicBoolean (java.util.concurrent.atomic.AtomicBoolean)2 ClassPathResource (org.datavec.api.util.ClassPathResource)2 SequenceIterator (org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator)2 AbstractSequenceIterator (org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator)2 SentenceTransformer (org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer)2 AbstractCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache)2 FileLabelAwareIterator (org.deeplearning4j.text.documentiterator.FileLabelAwareIterator)2 MutipleEpochsSentenceIterator (org.deeplearning4j.text.sentenceiterator.MutipleEpochsSentenceIterator)2 PrefetchingSentenceIterator (org.deeplearning4j.text.sentenceiterator.PrefetchingSentenceIterator)2 RoutedTransport (org.nd4j.parameterserver.distributed.transport.RoutedTransport)2 List (java.util.List)1