
Example 1 with LabelAwareIterator

Use of org.deeplearning4j.text.documentiterator.LabelAwareIterator in project deeplearning4j by deeplearning4j.

From class ParagraphVectorsTest, method testParagraphVectorsReducedLabels1.

/**
 * This test is not indicative, and there's no need to run it on Travis.
 * Run it manually, and only to help detect problems.
 *
 * @throws Exception
 */
@Test
@Ignore
public void testParagraphVectorsReducedLabels1() throws Exception {
    ClassPathResource resource = new ClassPathResource("/labeled");
    File file = resource.getFile();
    LabelAwareIterator iter = new FileLabelAwareIterator.Builder().addSourceFolder(file).build();
    TokenizerFactory t = new DefaultTokenizerFactory();
    /*
     * Please note: the text corpus is REALLY small, and some kind of "results" may only
     * appear with a HIGH number of epochs, like 30. But there's no reason to keep it that high.
     */
    ParagraphVectors vec = new ParagraphVectors.Builder()
            .minWordFrequency(1)
            .epochs(3)
            .layerSize(100)
            .stopWords(new ArrayList<String>())
            .windowSize(5)
            .iterate(iter)
            .tokenizerFactory(t)
            .build();
    vec.fit();
    //WordVectorSerializer.writeWordVectors(vec, "vectors.txt");
    // Look up the vectors for the tokens of the phrase "I am sad.".
    INDArray w1 = vec.lookupTable().vector("I");
    INDArray w2 = vec.lookupTable().vector("am");
    INDArray w3 = vec.lookupTable().vector("sad.");
    // Stack them as the rows of a 3 x layerSize matrix, then average over the rows.
    INDArray words = Nd4j.create(3, vec.lookupTable().layerSize());
    words.putRow(0, w1);
    words.putRow(1, w2);
    words.putRow(2, w3);
    INDArray mean = words.isMatrix() ? words.mean(0) : words;
    log.info("Mean: {}", Arrays.toString(mean.dup().data().asDouble()));
    log.info("Label vector \"negative\": {}", Arrays.toString(vec.lookupTable().vector("negative").dup().data().asDouble()));
    // Cosine similarity between the mean word vector and each label vector.
    double simNegative = Transforms.cosineSim(mean, vec.lookupTable().vector("negative"));
    log.info("Similarity negative: {}", simNegative);
    double simNeutral = Transforms.cosineSim(mean, vec.lookupTable().vector("neutral"));
    log.info("Similarity neutral: {}", simNeutral);
    double simPositive = Transforms.cosineSim(mean, vec.lookupTable().vector("positive"));
    log.info("Similarity positive: {}", simPositive);
}
Also used: DefaultTokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory), TokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory), INDArray (org.nd4j.linalg.api.ndarray.INDArray), LabelAwareIterator (org.deeplearning4j.text.documentiterator.LabelAwareIterator), FileLabelAwareIterator (org.deeplearning4j.text.documentiterator.FileLabelAwareIterator), File (java.io.File), ClassPathResource (org.datavec.api.util.ClassPathResource), Ignore (org.junit.Ignore), Test (org.junit.Test)
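The scoring step in this test averages the vectors of the words "I", "am" and "sad." and compares that mean against the label vectors "negative", "neutral" and "positive" by cosine similarity. The plain-Java sketch below shows the same arithmetic on `double[]` arrays instead of ND4J `INDArray`s. The class `LabelScoringSketch` and its methods are hypothetical; they mirror what `words.mean(0)` and `Transforms.cosineSim` compute, not how ND4J implements them. The `bestLabel` step is an added illustration: the test itself only logs the three similarities.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical helper (not part of DL4J or ND4J): scores a phrase against label
// vectors by averaging its word vectors and taking cosine similarities.
class LabelScoringSketch {

    // Element-wise mean of equally sized vectors, like words.mean(0) on a row-stacked matrix.
    public static double[] mean(double[]... rows) {
        if (rows.length == 0) {
            throw new IllegalArgumentException("need at least one vector");
        }
        double[] out = new double[rows[0].length];
        for (double[] row : rows) {
            if (row.length != out.length) {
                throw new IllegalArgumentException("vectors must have the same length");
            }
            for (int i = 0; i < row.length; i++) {
                out[i] += row[i];
            }
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= rows.length;
        }
        return out;
    }

    // Cosine similarity: dot(a, b) / (|a| * |b|); NaN if either vector is all zeros.
    public static double cosineSim(double[] a, double[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("vectors must have the same length");
        }
        double dot = 0, normA = 0, normB = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    // Added illustration: returns the label whose vector is most similar to the phrase mean.
    public static String bestLabel(Map<String, double[]> labelVectors, double[] phraseMean) {
        String best = null;
        double bestSim = Double.NEGATIVE_INFINITY;
        for (Map.Entry<String, double[]> e : labelVectors.entrySet()) {
            double sim = cosineSim(phraseMean, e.getValue());
            if (sim > bestSim) {
                bestSim = sim;
                best = e.getKey();
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Made-up 2-dimensional vectors; the test above trains 100-dimensional ones.
        Map<String, double[]> labels = new LinkedHashMap<>();
        labels.put("negative", new double[] {1, 0});
        labels.put("positive", new double[] {0, 1});
        double[] phrase = mean(new double[] {1, 0}, new double[] {1, 0}, new double[] {1, 1});
        for (Map.Entry<String, double[]> e : labels.entrySet()) {
            System.out.println("Similarity " + e.getKey() + ": " + cosineSim(phrase, e.getValue()));
        }
        System.out.println("Best label: " + bestLabel(labels, phrase)); // prints "Best label: negative"
    }
}
```

Here the phrase mean is [1, 1/3], so its cosine similarity to "negative" (about 0.95) beats "positive" (about 0.32). With real vectors the same comparison is only as meaningful as the training, which is why the test's own comment warns that a corpus this small needs many epochs before the numbers say anything.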

Example 2 with LabelAwareIterator

Use of org.deeplearning4j.text.documentiterator.LabelAwareIterator in project deeplearning4j by deeplearning4j.

From class ParallelTransformerIteratorTest, method testSpeedComparison1.

@Test
public void testSpeedComparison1() throws Exception {
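    // Iterate over raw_sentences.txt for 25 epochs. `factory` is a TokenizerFactory
    // declared outside this method (in the test class) and not shown in this excerpt.
    SentenceIterator iterator = new MutipleEpochsSentenceIterator(
            new BasicLineIterator(new ClassPathResource("/big/raw_sentences.txt").getFile()), 25);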
    SentenceTransformer transformer = new SentenceTransformer.Builder().iterator(iterator).allowMultithreading(false).tokenizerFactory(factory).build();
    Iterator<Sequence<VocabWord>> iter = transformer.iterator();
    int cnt = 0;
    long time1 = System.currentTimeMillis();
    while (iter.hasNext()) {
        Sequence<VocabWord> sequence = iter.next();
        assertNotEquals("Failed on [" + cnt + "] iteration", null, sequence);
        assertNotEquals("Failed on [" + cnt + "] iteration", 0, sequence.size());
        cnt++;
    }
    long time2 = System.currentTimeMillis();
    log.info("Single-threaded time: {} ms", time2 - time1);
    iterator.reset();
    transformer = new SentenceTransformer.Builder().iterator(iterator).allowMultithreading(true).tokenizerFactory(factory).build();
    iter = transformer.iterator();
    // Restart the counter so failure messages report the position within this run.
    cnt = 0;
    time1 = System.currentTimeMillis();
    while (iter.hasNext()) {
        Sequence<VocabWord> sequence = iter.next();
        assertNotEquals("Failed on [" + cnt + "] iteration", null, sequence);
        assertNotEquals("Failed on [" + cnt + "] iteration", 0, sequence.size());
        cnt++;
    }
    time2 = System.currentTimeMillis();
    log.info("Multi-threaded time: {} ms", time2 - time1);
    SentenceIterator baseIterator = iterator;
    baseIterator.reset();
    LabelAwareIterator lai = new BasicLabelAwareIterator.Builder(new MutipleEpochsSentenceIterator(new BasicLineIterator(new ClassPathResource("/big/raw_sentences.txt").getFile()), 25)).build();
    transformer = new SentenceTransformer.Builder().iterator(lai).allowMultithreading(false).tokenizerFactory(factory).build();
    iter = transformer.iterator();
    // Restart the counter so failure messages report the position within this run.
    cnt = 0;
    time1 = System.currentTimeMillis();
    while (iter.hasNext()) {
        Sequence<VocabWord> sequence = iter.next();
        assertNotEquals("Failed on [" + cnt + "] iteration", null, sequence);
        assertNotEquals("Failed on [" + cnt + "] iteration", 0, sequence.size());
        cnt++;
    }
    time2 = System.currentTimeMillis();
    log.info("Prefetched Single-threaded time: {} ms", time2 - time1);
    lai.reset();
    transformer = new SentenceTransformer.Builder().iterator(lai).allowMultithreading(true).tokenizerFactory(factory).build();
    iter = transformer.iterator();
    // Restart the counter so failure messages report the position within this run.
    cnt = 0;
    time1 = System.currentTimeMillis();
    while (iter.hasNext()) {
        Sequence<VocabWord> sequence = iter.next();
        assertNotEquals("Failed on [" + cnt + "] iteration", null, sequence);
        assertNotEquals("Failed on [" + cnt + "] iteration", 0, sequence.size());
        cnt++;
    }
    time2 = System.currentTimeMillis();
    log.info("Prefetched Multi-threaded time: {} ms", time2 - time1);
}
Also used: BasicLineIterator (org.deeplearning4j.text.sentenceiterator.BasicLineIterator), MutipleEpochsSentenceIterator (org.deeplearning4j.text.sentenceiterator.MutipleEpochsSentenceIterator), BasicLabelAwareIterator (org.deeplearning4j.text.documentiterator.BasicLabelAwareIterator), AsyncLabelAwareIterator (org.deeplearning4j.text.documentiterator.AsyncLabelAwareIterator), LabelAwareIterator (org.deeplearning4j.text.documentiterator.LabelAwareIterator), VocabWord (org.deeplearning4j.models.word2vec.VocabWord), SentenceTransformer (org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer), Sequence (org.deeplearning4j.models.sequencevectors.sequence.Sequence), PrefetchingSentenceIterator (org.deeplearning4j.text.sentenceiterator.PrefetchingSentenceIterator), SentenceIterator (org.deeplearning4j.text.sentenceiterator.SentenceIterator), ClassPathResource (org.datavec.api.util.ClassPathResource), Test (org.junit.Test)
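testSpeedComparison1 wraps the line iterator in `MutipleEpochsSentenceIterator(..., 25)`, which, judging by its name and constructor argument, replays the underlying corpus 25 times, and then times single-threaded against multi-threaded transformation of the same stream. The sketch below reproduces that benchmarking pattern in plain Java. It does not use or reimplement DL4J's `SentenceTransformer`: `RepeatingIterator`, `transformAll` and the whitespace `tokenize` method are hypothetical stand-ins, the last one replacing the test's `factory`.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.stream.Collectors;

// Hypothetical sketch (not DL4J code): replay a small corpus for several epochs,
// then time sequential vs. parallel tokenization of the replayed lines.
class EpochTimingSketch {

    // Yields every line of `lines`, start to finish, `epochs` times in a row.
    public static class RepeatingIterator implements Iterator<String> {
        private final List<String> lines;
        private final int epochs;
        private int epoch = 0;
        private int pos = 0;

        public RepeatingIterator(List<String> lines, int epochs) {
            this.lines = lines;
            this.epochs = epochs;
        }

        @Override
        public boolean hasNext() {
            return !lines.isEmpty() && epoch < epochs;
        }

        @Override
        public String next() {
            if (!hasNext()) {
                throw new NoSuchElementException();
            }
            String line = lines.get(pos++);
            if (pos == lines.size()) { // end of one pass: rewind for the next epoch
                pos = 0;
                epoch++;
            }
            return line;
        }
    }

    // Whitespace tokenizer standing in for the test's TokenizerFactory (assumption).
    public static List<String> tokenize(String line) {
        return Arrays.asList(line.trim().split("\\s+"));
    }

    // Drains the source, then tokenizes each line; collect() keeps input order even in parallel.
    public static List<List<String>> transformAll(Iterator<String> source, boolean parallel) {
        List<String> lines = new ArrayList<>();
        source.forEachRemaining(lines::add);
        return (parallel ? lines.parallelStream() : lines.stream())
                .map(EpochTimingSketch::tokenize)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<String> corpus = Arrays.asList("the cat sat", "on the mat", "it was happy");
        for (boolean parallel : new boolean[] {false, true}) {
            long start = System.nanoTime(); // monotonic clock, suited to measuring intervals
            List<List<String>> sequences = transformAll(new RepeatingIterator(corpus, 25), parallel);
            long elapsedMs = (System.nanoTime() - start) / 1_000_000;
            System.out.println((parallel ? "Multi-threaded" : "Single-threaded")
                    + ": " + sequences.size() + " sequences in " + elapsedMs + " ms");
        }
    }
}
```

Both runs report 75 sequences (3 lines times 25 epochs). With a toy corpus this small, the timings mostly reflect JVM warm-up rather than any threading benefit. Unlike the sketch, which drains the iterator before tokenizing, the test consumes `transformer.iterator()` one sequence at a time. The sketch uses `System.nanoTime()` instead of the test's `System.currentTimeMillis()` because the former is intended for elapsed-time measurement and is unaffected by wall-clock adjustments.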
