Search in sources:

Example 26 with VocabWord

use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.

the class SequenceVectorsTest method testInternalVocabConstruction.

/**
 * Fits a {@code SequenceVectors} model over the sequences produced from
 * {@code big/raw_sentences.txt} and asserts that the similarity between
 * "day" and "night" is above 0.6. The ten nearest labels to "day" are
 * logged but not asserted on.
 */
@Test
public void testInternalVocabConstruction() throws Exception {
    // Source text is resolved from the classpath and read through a line iterator.
    File corpus = new ClassPathResource("big/raw_sentences.txt").getFile();
    BasicLineIterator lineIterator = new BasicLineIterator(corpus);

    TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
    tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());

    // Lines -> tokenized sentences -> sequences of VocabWord.
    SentenceTransformer sentenceTransformer = new SentenceTransformer.Builder()
            .iterator(lineIterator)
            .tokenizerFactory(tokenizerFactory)
            .build();
    AbstractSequenceIterator<VocabWord> sequences =
            new AbstractSequenceIterator.Builder<>(sentenceTransformer).build();

    SequenceVectors<VocabWord> model = new SequenceVectors.Builder<VocabWord>(new VectorsConfiguration())
            .minWordFrequency(5)
            .iterate(sequences)
            .batchSize(250)
            .iterations(1)
            .epochs(1)
            .resetModel(false)
            .trainElementsRepresentation(true)
            .build();

    logger.info("Fitting model...");
    model.fit();
    logger.info("Model ready...");

    double dayNightSimilarity = model.similarity("day", "night");
    logger.info("Day/night similarity: " + dayNightSimilarity);
    assertTrue(dayNightSimilarity > 0.6d);

    Collection<String> nearestToDay = model.wordsNearest("day", 10);
    logger.info("Nearest labels to 'day': " + nearestToDay);
}
Also used: BasicLineIterator (org.deeplearning4j.text.sentenceiterator.BasicLineIterator), TokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory), DefaultTokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory), VectorsConfiguration (org.deeplearning4j.models.embeddings.loader.VectorsConfiguration), VocabWord (org.deeplearning4j.models.word2vec.VocabWord), SentenceTransformer (org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer), ClassPathResource (org.datavec.api.util.ClassPathResource), CommonPreprocessor (org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor), AbstractSequenceIterator (org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator), File (java.io.File), Test (org.junit.Test)

Example 27 with VocabWord

use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.

the class PopularityWalkerTest method testPopularityWalker3.

/**
 * Draws 50 walks from a {@code PopularityWalker} (resetting the walker after each one)
 * and checks that every walk starts at label "0", that the element at position 1 is
 * always one of "4", "7" or "9", and that each of those three labels is seen at least once.
 */
@Test
public void testPopularityWalker3() throws Exception {
    GraphWalker<VocabWord> walker = new PopularityWalker.Builder<>(graph)
            .setWalkDirection(WalkDirection.FORWARD_ONLY)
            .setNoEdgeHandling(NoEdgeHandling.CUTOFF_ON_DISCONNECTED)
            .setWalkLength(10)
            .setPopularityMode(PopularityMode.MAXIMUM)
            .setPopularitySpread(3)
            .setSpreadSpectrum(SpreadSpectrum.PROPORTIONAL)
            .build();

    System.out.println("Connected [3] size: " + graph.getConnectedVertices(3).size());
    System.out.println("Connected [4] size: " + graph.getConnectedVertices(4).size());

    // Each flag flips to true the first time its label shows up at position 1.
    AtomicBoolean got4 = new AtomicBoolean(false);
    AtomicBoolean got7 = new AtomicBoolean(false);
    AtomicBoolean got9 = new AtomicBoolean(false);

    for (int attempt = 0; attempt < 50; attempt++) {
        Sequence<VocabWord> walk = walker.next();
        assertEquals("0", walk.getElements().get(0).getLabel());

        String secondLabel = walk.getElements().get(1).getLabel();
        System.out.println("Position at 1: [" + secondLabel + "]");

        boolean is4 = secondLabel.equals("4");
        boolean is7 = secondLabel.equals("7");
        boolean is9 = secondLabel.equals("9");

        got4.compareAndSet(false, is4);
        got7.compareAndSet(false, is7);
        got9.compareAndSet(false, is9);

        assertTrue(is4 || is7 || is9);

        walker.reset(false);
    }

    assertTrue(got4.get());
    assertTrue(got7.get());
    assertTrue(got9.get());
}
Also used : AtomicBoolean(java.util.concurrent.atomic.AtomicBoolean) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) Test(org.junit.Test)

Example 28 with VocabWord

use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.

the class RandomWalkerTest method testGraphTraverseRandom2.

/**
 * Exhausts a {@code RandomWalker} configured with {@code FORWARD_UNIQUE} direction and
 * {@code CUTOFF_ON_DISCONNECTED} edge handling, and checks that it yields exactly 10
 * sequences, each non-null, holding at most 10 elements, none of which is null.
 */
@Test
public void testGraphTraverseRandom2() throws Exception {
    // NoEdgeHandling is configured once; a second, earlier setNoEdgeHandling call would
    // presumably be overwritten by this one (builder source not visible here — confirm).
    RandomWalker<VocabWord> walker = (RandomWalker<VocabWord>) new RandomWalker.Builder<>(graph)
            .setWalkLength(20)
            .setWalkDirection(WalkDirection.FORWARD_UNIQUE)
            .setNoEdgeHandling(NoEdgeHandling.CUTOFF_ON_DISCONNECTED)
            .build();

    int cnt = 0;
    while (walker.hasNext()) {
        Sequence<VocabWord> sequence = walker.next();

        // Null check must precede any dereference, otherwise a null sequence surfaces
        // as a NullPointerException instead of an assertion failure.
        assertNotEquals(null, sequence);
        assertTrue(sequence.getElements().size() <= 10);

        for (VocabWord word : sequence.getElements()) {
            assertNotEquals(null, word);
        }
        cnt++;
    }
    assertEquals(10, cnt);
}
Also used : VocabWord(org.deeplearning4j.models.word2vec.VocabWord) Test(org.junit.Test)

Example 29 with VocabWord

use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.

the class RandomWalkerTest method setUp.

/**
 * Builds the three fixture graphs once; later invocations return immediately because
 * {@code graph} is already set. In every graph, vertex {@code i} carries a
 * {@code VocabWord} labelled with {@code i} and gets a weight-1.0 edge to {@code i + 3},
 * or to vertex 0 when {@code i + 3} falls outside the graph.
 */
@Before
public void setUp() throws Exception {
    if (graph != null) {
        return;
    }

    // 10 vertices, edges added with the last addEdge argument set to false.
    graph = new Graph<>(10, false, new AbstractVertexFactory<VocabWord>());
    for (int v = 0; v < 10; v++) {
        graph.getVertex(v).setValue(new VocabWord(v, String.valueOf(v)));
        int target = (v + 3 >= 10) ? 0 : v + 3;
        graph.addEdge(v, target, 1.0, false);
    }

    // Same layout, but edges added with the last addEdge argument set to true.
    graphDirected = new Graph<>(10, false, new AbstractVertexFactory<VocabWord>());
    for (int v = 0; v < 10; v++) {
        graphDirected.getVertex(v).setValue(new VocabWord(v, String.valueOf(v)));
        int target = (v + 3 >= 10) ? 0 : v + 3;
        graphDirected.addEdge(v, target, 1.0, true);
    }

    // 1000-vertex variant of the first graph.
    graphBig = new Graph<>(1000, false, new AbstractVertexFactory<VocabWord>());
    for (int v = 0; v < 1000; v++) {
        graphBig.getVertex(v).setValue(new VocabWord(v, String.valueOf(v)));
        int target = (v + 3 >= 1000) ? 0 : v + 3;
        graphBig.addEdge(v, target, 1.0, false);
    }
}
Also used : AbstractVertexFactory(org.deeplearning4j.models.sequencevectors.graph.vertex.AbstractVertexFactory) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) Before(org.junit.Before)

Example 30 with VocabWord

use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.

the class AbstractElementFactoryTest method testSerialize.

/**
 * Round-trips a {@code VocabWord} through {@code AbstractElementFactory}
 * (serialize, then deserialize) and asserts the result equals the original.
 */
@Test
public void testSerialize() throws Exception {
    VocabWord original = new VocabWord(1, "word");
    AbstractElementFactory<VocabWord> elementFactory = new AbstractElementFactory<>(VocabWord.class);

    System.out.println("VocabWord JSON: " + elementFactory.serialize(original));

    // Serialize again for the round trip, then rebuild the element from that output.
    String serialized = elementFactory.serialize(original);
    VocabWord restored = elementFactory.deserialize(serialized);

    assertEquals(original, restored);
}
Also used : VocabWord(org.deeplearning4j.models.word2vec.VocabWord) Test(org.junit.Test)

Aggregations

VocabWord (org.deeplearning4j.models.word2vec.VocabWord): 110, Test (org.junit.Test): 54, INDArray (org.nd4j.linalg.api.ndarray.INDArray): 31, AbstractCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache): 26, ClassPathResource (org.datavec.api.util.ClassPathResource): 23, BasicLineIterator (org.deeplearning4j.text.sentenceiterator.BasicLineIterator): 22, File (java.io.File): 20, InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable): 19, TokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory): 19, ArrayList (java.util.ArrayList): 17, DefaultTokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory): 17, CommonPreprocessor (org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor): 15, SentenceIterator (org.deeplearning4j.text.sentenceiterator.SentenceIterator): 14, AbstractSequenceIterator (org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator): 13, SentenceTransformer (org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer): 13, JavaSparkContext (org.apache.spark.api.java.JavaSparkContext): 12, ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 12, Sequence (org.deeplearning4j.models.sequencevectors.sequence.Sequence): 11, Word2Vec (org.deeplearning4j.models.word2vec.Word2Vec): 11, TextPipeline (org.deeplearning4j.spark.text.functions.TextPipeline): 10