Search in sources:

Example 96 with VocabWord

Use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.

Class PopularityWalkerTest, method setUp.

/**
 * Builds the shared 10-vertex fixture graph once and reuses it on later runs.
 * <p>
 * Vertex {@code v} stores a {@link VocabWord} labelled with its own index. Every vertex is
 * connected to {@code v + 3}. When that target would be 10 or more, the edge goes to vertex 0
 * instead (so 7, 8 and 9 all point at 0). Seven extra unit-weight edges are then added, including
 * three identical 0 -> 4 edges. Presumably this makes some vertices more "popular" than others
 * for the popularity walker under test; confirm against the test assertions.
 */
@Before
public void setUp() {
    // The graph is kept across test methods; only build it the first time.
    if (graph != null) {
        return;
    }

    graph = new Graph<>(10, false, new AbstractVertexFactory<VocabWord>());
    for (int v = 0; v < 10; v++) {
        graph.getVertex(v).setValue(new VocabWord(v, String.valueOf(v)));
        int target = (v + 3 >= 10) ? 0 : v + 3;
        graph.addEdge(v, target, 1.0, false);
    }

    // {from, to} pairs added on top of the ring-like base edges, in this exact order.
    int[][] extraEdges = {
        {0, 4}, {0, 4}, {0, 4},
        {4, 5},
        {1, 3},
        {9, 7},
        {5, 6},
    };
    for (int[] edge : extraEdges) {
        graph.addEdge(edge[0], edge[1], 1.0, false);
    }
}
Also used: AbstractVertexFactory (org.deeplearning4j.models.sequencevectors.graph.vertex.AbstractVertexFactory), VocabWord (org.deeplearning4j.models.word2vec.VocabWord), Before (org.junit.Before)

Example 97 with VocabWord

Use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.

Class RandomWalkerTest, method testGraphTraverseRandom1.

/**
 * Drains a {@link RandomWalker} of walk length 3 over {@code graph} and checks every produced
 * sequence.
 * <p>
 * Each sequence must be non-null, contain exactly 3 elements, and contain no null elements. The
 * walker must produce exactly 10 sequences in total. NOTE(review): this assumes the fixture
 * {@code graph} (defined outside this method) yields one walk per vertex and has 10 vertices;
 * confirm against the class's setUp.
 */
@Test
public void testGraphTraverseRandom1() throws Exception {
    RandomWalker<VocabWord> walker = (RandomWalker<VocabWord>) new RandomWalker.Builder<>(graph)
            .setNoEdgeHandling(NoEdgeHandling.SELF_LOOP_ON_DISCONNECTED)
            .setWalkLength(3)
            .build();
    int cnt = 0;
    while (walker.hasNext()) {
        Sequence<VocabWord> sequence = walker.next();
        // Assert non-null BEFORE dereferencing, so a null sequence is reported as an
        // assertion failure rather than a NullPointerException from getElements().
        assertNotEquals(null, sequence);
        assertEquals(3, sequence.getElements().size());
        for (VocabWord word : sequence.getElements()) {
            assertNotEquals(null, word);
        }
        cnt++;
    }
    assertEquals(10, cnt);
}
Also used: VocabWord (org.deeplearning4j.models.word2vec.VocabWord), Test (org.junit.Test)

Example 98 with VocabWord

Use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.

Class WeightedWalkerTest, method setUp.

/**
 * Builds the shared 10-vertex {@code basicGraph} once and reuses it on later runs.
 * <p>
 * The exact structure of this graph is not important: it is a basic graph for iteration
 * checks. Vertex {@code v} stores a {@link VocabWord} labelled with its own index. Every vertex
 * gets a weight-1 edge to {@code v + 3}, or to vertex 0 when that target would be 10 or more.
 * Seven extra edges with weights between 2 and 8 are then added.
 */
@Before
public void setUp() throws Exception {
    // The graph is kept across test methods; only build it the first time.
    if (basicGraph != null) {
        return;
    }

    basicGraph = new Graph<>(10, false, new AbstractVertexFactory<VocabWord>());
    for (int v = 0; v < 10; v++) {
        basicGraph.getVertex(v).setValue(new VocabWord(v, String.valueOf(v)));
        int target = (v + 3 >= 10) ? 0 : v + 3;
        basicGraph.addEdge(v, target, 1, false);
    }

    // {from, to, weight} triples added in this exact order.
    int[][] weightedEdges = {
        {0, 4, 2},
        {0, 4, 4},
        {0, 4, 6},
        {4, 5, 8},
        {1, 3, 6},
        {9, 7, 4},
        {5, 6, 2},
    };
    for (int[] edge : weightedEdges) {
        basicGraph.addEdge(edge[0], edge[1], edge[2], false);
    }
}
Also used: AbstractVertexFactory (org.deeplearning4j.models.sequencevectors.graph.vertex.AbstractVertexFactory), VocabWord (org.deeplearning4j.models.word2vec.VocabWord), Before (org.junit.Before)

Example 99 with VocabWord

Use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.

Class AbstractElementFactoryTest, method testDeserialize.

/**
 * Round-trips a {@link VocabWord} through JSON.
 * <p>
 * The word is serialized with {@code toJSON()} and the JSON is printed for inspection. It is
 * then deserialized with an {@link AbstractElementFactory} for {@code VocabWord}, and the result
 * must compare equal to the original.
 */
@Test
public void testDeserialize() throws Exception {
    AbstractElementFactory<VocabWord> elementFactory = new AbstractElementFactory<>(VocabWord.class);
    VocabWord original = new VocabWord(1, "word");

    System.out.println("VocabWord JSON: " + original.toJSON());

    VocabWord restored = elementFactory.deserialize(original.toJSON());
    assertEquals(original, restored);
}
Also used: VocabWord (org.deeplearning4j.models.word2vec.VocabWord), Test (org.junit.Test)

Example 100 with VocabWord

Use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.

Class AbstractCacheTest, method testHuffman.

/**
 * Checks that applying Huffman-derived indexes to a small cache orders its words as expected.
 * <p>
 * Three tokens are added with first constructor arguments 1.0, 2.0 and 3.0 (presumably their
 * frequencies). A {@link Huffman} tree is built over the cache's tokens and its indexes are
 * written back into the cache. The test then asserts that the word with the largest value
 * ("tester") ends up at index 0, followed by "test" and then "word".
 */
@Test
public void testHuffman() throws Exception {
    String[] inserted = {"word", "test", "tester"};
    String[] expectedOrder = {"tester", "test", "word"};

    AbstractCache<VocabWord> vocab = new AbstractCache.Builder<VocabWord>().build();
    for (int n = 0; n < inserted.length; n++) {
        vocab.addToken(new VocabWord(n + 1.0, inserted[n]));
    }
    assertEquals(3, vocab.numWords());

    Huffman tree = new Huffman(vocab.tokens());
    tree.build();
    tree.applyIndexes(vocab);

    for (int idx = 0; idx < expectedOrder.length; idx++) {
        assertEquals(expectedOrder[idx], vocab.wordAtIndex(idx));
    }

    VocabWord top = vocab.tokenFor("tester");
    assertEquals(0, top.getIndex());
}
Also used: Huffman (org.deeplearning4j.models.word2vec.Huffman), VocabWord (org.deeplearning4j.models.word2vec.VocabWord), Test (org.junit.Test)

Aggregations

VocabWord (org.deeplearning4j.models.word2vec.VocabWord): 110; Test (org.junit.Test): 54; INDArray (org.nd4j.linalg.api.ndarray.INDArray): 31; AbstractCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache): 26; ClassPathResource (org.datavec.api.util.ClassPathResource): 23; BasicLineIterator (org.deeplearning4j.text.sentenceiterator.BasicLineIterator): 22; File (java.io.File): 20; InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable): 19; TokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory): 19; ArrayList (java.util.ArrayList): 17; DefaultTokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory): 17; CommonPreprocessor (org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor): 15; SentenceIterator (org.deeplearning4j.text.sentenceiterator.SentenceIterator): 14; AbstractSequenceIterator (org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator): 13; SentenceTransformer (org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer): 13; JavaSparkContext (org.apache.spark.api.java.JavaSparkContext): 12; ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 12; Sequence (org.deeplearning4j.models.sequencevectors.sequence.Sequence): 11; Word2Vec (org.deeplearning4j.models.word2vec.Word2Vec): 11; TextPipeline (org.deeplearning4j.spark.text.functions.TextPipeline): 10