Search in sources :

Example 86 with VocabWord

use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.

the class ParallelTransformerIteratorTest method hasNext.

/**
 * Walks the whole iterator produced by a {@link SentenceTransformer} built with
 * multithreading allowed, verifying that every returned sequence is non-null and
 * non-empty, and that the total number of sequences equals the expected 97162.
 */
@Test
public void hasNext() throws Exception {
    SentenceIterator lineSource = new BasicLineIterator(new ClassPathResource("/big/raw_sentences.txt").getFile());
    SentenceTransformer transformer = new SentenceTransformer.Builder()
            .iterator(lineSource)
            .allowMultithreading(true)
            .tokenizerFactory(factory)
            .build();

    int seen = 0;
    for (Iterator<Sequence<VocabWord>> sequences = transformer.iterator(); sequences.hasNext(); seen++) {
        // Same failure message is used for both per-sequence checks.
        String failure = "Failed on [" + seen + "] iteration";
        Sequence<VocabWord> current = sequences.next();
        assertNotEquals(failure, null, current);
        assertNotEquals(failure, 0, current.size());
    }

    assertEquals(97162, seen);
}
Also used : BasicLineIterator(org.deeplearning4j.text.sentenceiterator.BasicLineIterator) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) SentenceTransformer(org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer) Sequence(org.deeplearning4j.models.sequencevectors.sequence.Sequence) PrefetchingSentenceIterator(org.deeplearning4j.text.sentenceiterator.PrefetchingSentenceIterator) SentenceIterator(org.deeplearning4j.text.sentenceiterator.SentenceIterator) MutipleEpochsSentenceIterator(org.deeplearning4j.text.sentenceiterator.MutipleEpochsSentenceIterator) ClassPathResource(org.datavec.api.util.ClassPathResource) Test(org.junit.Test)

Example 87 with VocabWord

use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.

the class Word2VecDataSetIteratorTest method testIterator1.

/**
 * Basically all we want from this test - being able to finish without exceptions.
 * While iterating, every batch is also required to have the same feature shape
 * as the first batch returned by the iterator.
 */
@Test
public void testIterator1() throws Exception {
    File inputFile = new ClassPathResource("/big/raw_sentences.txt").getFile();
    SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());

    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());

    Word2Vec vec = new Word2Vec.Builder()
            // we make sure we'll have some missing words
            .minWordFrequency(10)
            .iterations(1)
            .learningRate(0.025)
            .layerSize(150)
            .seed(42)
            .sampling(0)
            .negativeSample(0)
            .useHierarchicSoftmax(true)
            .windowSize(5)
            .modelUtils(new BasicModelUtils<VocabWord>())
            .useAdaGrad(false)
            .iterate(iter)
            .workers(8)
            .tokenizerFactory(t)
            .elementsLearningAlgorithm(new CBOW<VocabWord>())
            .build();
    vec.fit();

    List<String> labels = new ArrayList<>();
    labels.add("positive");
    labels.add("negative");

    Word2VecDataSetIterator iterator = new Word2VecDataSetIterator(vec, getLASI(iter, labels), labels, 1);

    // The first batch provides the reference shape for all following batches.
    INDArray array = iterator.next().getFeatures();
    while (iterator.hasNext()) {
        DataSet ds = iterator.next();
        // Read features through the same accessor as the reference batch above
        // (NOTE(review): getFeatureMatrix() appears to be the older alias - confirm
        // against the ND4J version in use).
        assertArrayEquals(array.shape(), ds.getFeatures().shape());
    }
}
Also used : BasicLineIterator(org.deeplearning4j.text.sentenceiterator.BasicLineIterator) TokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) DataSet(org.nd4j.linalg.dataset.DataSet) ArrayList(java.util.ArrayList) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) ClassPathResource(org.datavec.api.util.ClassPathResource) SentenceIterator(org.deeplearning4j.text.sentenceiterator.SentenceIterator) LabelAwareSentenceIterator(org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) CommonPreprocessor(org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Word2Vec(org.deeplearning4j.models.word2vec.Word2Vec) CBOW(org.deeplearning4j.models.embeddings.learning.impl.elements.CBOW) File(java.io.File) Test(org.junit.Test)

Example 88 with VocabWord

use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.

the class VocabConstructorTest method testCounter1.

/**
 * Feeds a single three-word sequence ("word", "test", "here") through a
 * {@link VocabConstructor} and checks that the target cache ends up with three
 * words and reports a frequency of 1 for "test".
 */
@Test
public void testCounter1() throws Exception {
    VocabCache<VocabWord> targetCache = new AbstractCache.Builder<VocabWord>().build();

    final List<VocabWord> tokens = new ArrayList<>();
    tokens.add(new VocabWord(1, "word"));
    tokens.add(new VocabWord(2, "test"));
    tokens.add(new VocabWord(1, "here"));

    // Every iterator handed out by this source reports exactly one available sequence.
    Iterable<Sequence<VocabWord>> oneShotSource = new Iterable<Sequence<VocabWord>>() {

        @Override
        public Iterator<Sequence<VocabWord>> iterator() {
            return new Iterator<Sequence<VocabWord>>() {

                // true until the first hasNext() call, false afterwards
                private AtomicBoolean pending = new AtomicBoolean(true);

                @Override
                public boolean hasNext() {
                    // Note: the flag is consumed by this call, not by next().
                    return pending.getAndSet(false);
                }

                @Override
                public Sequence<VocabWord> next() {
                    return new Sequence<>(tokens);
                }

                @Override
                public void remove() {
                    throw new UnsupportedOperationException();
                }
            };
        }
    };

    SequenceIterator<VocabWord> sequences = new AbstractSequenceIterator.Builder<>(oneShotSource).build();

    VocabConstructor<VocabWord> vocabBuilder = new VocabConstructor.Builder<VocabWord>()
            .addSource(sequences, 0)
            .useAdaGrad(false)
            .setTargetVocabCache(targetCache)
            .build();
    vocabBuilder.buildJointVocabulary(false, true);

    assertEquals(3, targetCache.numWords());
    assertEquals(1, targetCache.wordFrequency("test"));
}
Also used : VocabWord(org.deeplearning4j.models.word2vec.VocabWord) Sequence(org.deeplearning4j.models.sequencevectors.sequence.Sequence) AbstractCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache) AtomicBoolean(java.util.concurrent.atomic.AtomicBoolean) AbstractSequenceIterator(org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator) FileLabelAwareIterator(org.deeplearning4j.text.documentiterator.FileLabelAwareIterator) SentenceIterator(org.deeplearning4j.text.sentenceiterator.SentenceIterator) AbstractSequenceIterator(org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator) BasicLineIterator(org.deeplearning4j.text.sentenceiterator.BasicLineIterator) SequenceIterator(org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator) Test(org.junit.Test)

Example 89 with VocabWord

use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.

the class VocabConstructorTest method testCounter2.

/**
 * Variant of the counter test in which the "test" element is created with a
 * first constructor argument of 0. After building the joint vocabulary the cache
 * is still expected to hold three words and to report a frequency of 1 for "test".
 */
@Test
public void testCounter2() throws Exception {
    VocabCache<VocabWord> targetCache = new AbstractCache.Builder<VocabWord>().build();

    final List<VocabWord> tokens = new ArrayList<>();
    tokens.add(new VocabWord(1, "word"));
    tokens.add(new VocabWord(0, "test"));
    tokens.add(new VocabWord(1, "here"));

    // Source whose iterators each announce a single sequence built from the tokens above.
    Iterable<Sequence<VocabWord>> oneShotSource = new Iterable<Sequence<VocabWord>>() {

        @Override
        public Iterator<Sequence<VocabWord>> iterator() {
            return new Iterator<Sequence<VocabWord>>() {

                // flips to false on the first hasNext() call
                private AtomicBoolean pending = new AtomicBoolean(true);

                @Override
                public boolean hasNext() {
                    // Note: querying availability is what consumes the flag here.
                    return pending.getAndSet(false);
                }

                @Override
                public Sequence<VocabWord> next() {
                    return new Sequence<>(tokens);
                }

                @Override
                public void remove() {
                    throw new UnsupportedOperationException();
                }
            };
        }
    };

    SequenceIterator<VocabWord> sequences = new AbstractSequenceIterator.Builder<>(oneShotSource).build();

    VocabConstructor<VocabWord> vocabBuilder = new VocabConstructor.Builder<VocabWord>()
            .addSource(sequences, 0)
            .useAdaGrad(false)
            .setTargetVocabCache(targetCache)
            .build();
    vocabBuilder.buildJointVocabulary(false, true);

    assertEquals(3, targetCache.numWords());
    assertEquals(1, targetCache.wordFrequency("test"));
}
Also used : VocabWord(org.deeplearning4j.models.word2vec.VocabWord) Sequence(org.deeplearning4j.models.sequencevectors.sequence.Sequence) AbstractCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache) AtomicBoolean(java.util.concurrent.atomic.AtomicBoolean) AbstractSequenceIterator(org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator) FileLabelAwareIterator(org.deeplearning4j.text.documentiterator.FileLabelAwareIterator) SentenceIterator(org.deeplearning4j.text.sentenceiterator.SentenceIterator) AbstractSequenceIterator(org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator) BasicLineIterator(org.deeplearning4j.text.sentenceiterator.BasicLineIterator) SequenceIterator(org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator) Test(org.junit.Test)

Example 90 with VocabWord

use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.

the class VocabConstructorTest method testMergedVocab1.

/**
 * Basic vocabulary transfer test, done WITHOUT labels: a vocabulary is first built
 * into a source cache, then merged into a separate target cache, and both caches
 * are expected to end up with the same number of words.
 *
 * @throws Exception if the corpus resource cannot be read or vocabulary building fails
 */
@Test
public void testMergedVocab1() throws Exception {
    AbstractCache<VocabWord> cacheSource = new AbstractCache.Builder<VocabWord>().build();

    ClassPathResource corpus = new ClassPathResource("big/raw_sentences.txt");
    BasicLineIterator lines = new BasicLineIterator(corpus.getFile());
    SentenceTransformer transformer = new SentenceTransformer.Builder()
            .iterator(lines)
            .tokenizerFactory(t)
            .build();
    AbstractSequenceIterator<VocabWord> sequences = new AbstractSequenceIterator.Builder<>(transformer).build();

    // Step 1: build the vocabulary into the source cache.
    VocabConstructor<VocabWord> sourceBuilder = new VocabConstructor.Builder<VocabWord>()
            .addSource(sequences, 1)
            .setTargetVocabCache(cacheSource)
            .build();
    sourceBuilder.buildJointVocabulary(false, true);

    int sourceSize = cacheSource.numWords();
    log.info("Source Vocab size: " + sourceSize);

    // Step 2: merge the source cache into a fresh target cache, reusing the same sequence iterator.
    AbstractCache<VocabWord> cacheTarget = new AbstractCache.Builder<VocabWord>().build();
    VocabConstructor<VocabWord> transferBuilder = new VocabConstructor.Builder<VocabWord>()
            .addSource(sequences, 1)
            .setTargetVocabCache(cacheTarget)
            .build();
    transferBuilder.buildMergedVocabulary(cacheSource, false);

    assertEquals(sourceSize, cacheTarget.numWords());
}
Also used : BasicLineIterator(org.deeplearning4j.text.sentenceiterator.BasicLineIterator) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) SentenceTransformer(org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer) AbstractCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache) ClassPathResource(org.datavec.api.util.ClassPathResource) AbstractSequenceIterator(org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator) Test(org.junit.Test)

Aggregations

VocabWord (org.deeplearning4j.models.word2vec.VocabWord)110 Test (org.junit.Test)54 INDArray (org.nd4j.linalg.api.ndarray.INDArray)31 AbstractCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache)26 ClassPathResource (org.datavec.api.util.ClassPathResource)23 BasicLineIterator (org.deeplearning4j.text.sentenceiterator.BasicLineIterator)22 File (java.io.File)20 InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable)19 TokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory)19 ArrayList (java.util.ArrayList)17 DefaultTokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory)17 CommonPreprocessor (org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor)15 SentenceIterator (org.deeplearning4j.text.sentenceiterator.SentenceIterator)14 AbstractSequenceIterator (org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator)13 SentenceTransformer (org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer)13 JavaSparkContext (org.apache.spark.api.java.JavaSparkContext)12 ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException)12 Sequence (org.deeplearning4j.models.sequencevectors.sequence.Sequence)11 Word2Vec (org.deeplearning4j.models.word2vec.Word2Vec)11 TextPipeline (org.deeplearning4j.spark.text.functions.TextPipeline)10