Search in sources :

Example 11 with WordVectors

use of org.deeplearning4j.models.embeddings.wordvectors.WordVectors in project deeplearning4j by deeplearning4j.

the class WordVectorSerializerTest method testUnifiedLoaderArchive1.

/**
 * Loads the same {@code word2vec.dl4j/file.w2v} archive through
 * {@code WordVectorSerializer.readWord2Vec} and
 * {@code WordVectorSerializer.readWord2VecModel(file, false)} and checks that both models
 * return an equal vector for the word "night". Also checks that the lookup table of the
 * model returned by {@code readWord2VecModel} exposes neither syn1 nor syn1Neg.
 *
 * @throws Exception if the classpath resource cannot be resolved or a model fails to load
 */
@Test
public void testUnifiedLoaderArchive1() throws Exception {
    logger.info("Executor name: {}", Nd4j.getExecutioner().getClass().getSimpleName());

    File archive = new ClassPathResource("word2vec.dl4j/file.w2v").getFile();

    // Same archive, two different loader entry points.
    WordVectors liveModel = WordVectorSerializer.readWord2Vec(archive);
    WordVectors unifiedModel = WordVectorSerializer.readWord2VecModel(archive, false);

    INDArray liveVector = liveModel.getWordVectorMatrix("night");
    INDArray unifiedVector = unifiedModel.getWordVectorMatrix("night");

    // The word must be known to the first model, and both loaders must agree on its vector.
    assertNotEquals(null, liveVector);
    assertEquals(liveVector, unifiedVector);

    // The model from readWord2VecModel(file, false) is expected to carry no syn1 / syn1Neg arrays.
    assertEquals(null, ((InMemoryLookupTable) unifiedModel.lookupTable()).getSyn1());
    assertEquals(null, ((InMemoryLookupTable) unifiedModel.lookupTable()).getSyn1Neg());
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) WordVectors(org.deeplearning4j.models.embeddings.wordvectors.WordVectors) File(java.io.File) ClassPathResource(org.datavec.api.util.ClassPathResource) Test(org.junit.Test)

Example 12 with WordVectors

use of org.deeplearning4j.models.embeddings.wordvectors.WordVectors in project deeplearning4j by deeplearning4j.

the class ParagraphVectorsTest method testGoogleModelForInference.

/**
 * Manual check (ignored by default): builds a {@code ParagraphVectors} instance on top of
 * word vectors loaded from a local copy of the Google News model, infers vectors for two
 * sentences and logs their cosine similarity. No assertion is made on the result.
 *
 * <p>The model is read from the absolute path
 * {@code /ext/GoogleNews-vectors-negative300.bin.gz}, so the test can only run on a machine
 * where that file exists.
 *
 * @throws Exception if the model file cannot be read or inference fails
 */
@Ignore
@Test
public void testGoogleModelForInference() throws Exception {
    WordVectors googleVectors = WordVectorSerializer.loadGoogleModelNonNormalized(new File("/ext/GoogleNews-vectors-negative300.bin.gz"), true, false);

    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());

    // Word vectors come from the pre-loaded model and are not trained further.
    ParagraphVectors pv = new ParagraphVectors.Builder()
            .tokenizerFactory(t)
            .iterations(10)
            .useHierarchicSoftmax(false)
            .trainWordVectors(false)
            .useExistingWordVectors(googleVectors)
            .negativeSample(10)
            .sequenceLearningAlgorithm(new DM<VocabWord>())
            .build();

    INDArray vec1 = pv.inferVector("This text is pretty awesome");
    INDArray vec2 = pv.inferVector("Fantastic process of crazy things happening inside just for history purposes");

    log.info("vec1/vec2: {}", Transforms.cosineSim(vec1, vec2));
}
Also used : DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) CommonPreprocessor(org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor) TokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory) INDArray(org.nd4j.linalg.api.ndarray.INDArray) DM(org.deeplearning4j.models.embeddings.learning.impl.sequence.DM) WordVectors(org.deeplearning4j.models.embeddings.wordvectors.WordVectors) File(java.io.File) Ignore(org.junit.Ignore) Test(org.junit.Test)

Example 13 with WordVectors

use of org.deeplearning4j.models.embeddings.wordvectors.WordVectors in project deeplearning4j by deeplearning4j.

the class TestCnnSentenceDataSetIterator method testSentenceIterator.

/**
 * Checks the first minibatch produced by {@code CnnSentenceDataSetIterator} for two labelled
 * sentences, once with {@code sentencesAlongHeight(true)} and once with {@code false}.
 *
 * <p>For each setting the test builds the expected feature array by hand from the word
 * vectors, then compares features, labels and feature mask against the iterator output, and
 * finally checks that {@code loadSingleSentence} matches the corresponding slice of the
 * minibatch.
 *
 * @throws Exception if the word-vector resource cannot be loaded
 */
@Test
public void testSentenceIterator() throws Exception {
    WordVectors w2v = WordVectorSerializer.readWord2VecModel(new ClassPathResource("word2vec/googleload/sample_vec.bin").getFile());
    int vectorSize = w2v.lookupTable().layerSize();

    List<String> sentences = new ArrayList<>();
    // First sentence: four tokens, all used when building the expected features.
    sentences.add("these balance Database model");
    // Second sentence: THISWORDDOESNTEXIST is left out of the expected tokens below.
    sentences.add("into same THISWORDDOESNTEXIST are");
    List<String> labelsForSentences = Arrays.asList("Positive", "Negative");

    // Tokens expected to contribute a vector, per sentence.
    List<String> firstTokens = Arrays.asList("these", "balance", "Database", "model");
    List<String> secondTokens = Arrays.asList("into", "same", "are");
    int maxLength = 4;

    // Order of labels: alphabetic. Positive -> [0,1]
    INDArray expLabels = Nd4j.create(new double[][] { { 0, 1 }, { 1, 0 } });

    for (boolean alongHeight : new boolean[] { true, false }) {
        // Word position is dimension 2 when alongHeight is true, dimension 3 otherwise.
        INDArray expectedFeatures = alongHeight
                ? Nd4j.create(2, 1, maxLength, vectorSize)
                : Nd4j.create(2, 1, vectorSize, maxLength);

        // The second sentence has only three expected tokens, so its last position is masked out.
        INDArray expectedFeatureMask = Nd4j.create(new double[][] { { 1, 1, 1, 1 }, { 1, 1, 1, 0 } });

        for (int pos = 0; pos < firstTokens.size(); pos++) {
            INDArray slot = alongHeight
                    ? expectedFeatures.get(NDArrayIndex.point(0), NDArrayIndex.point(0), NDArrayIndex.point(pos), NDArrayIndex.all())
                    : expectedFeatures.get(NDArrayIndex.point(0), NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(pos));
            slot.assign(w2v.getWordVectorMatrix(firstTokens.get(pos)));
        }

        for (int pos = 0; pos < secondTokens.size(); pos++) {
            INDArray slot = alongHeight
                    ? expectedFeatures.get(NDArrayIndex.point(1), NDArrayIndex.point(0), NDArrayIndex.point(pos), NDArrayIndex.all())
                    : expectedFeatures.get(NDArrayIndex.point(1), NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(pos));
            slot.assign(w2v.getWordVectorMatrix(secondTokens.get(pos)));
        }

        LabeledSentenceProvider provider = new CollectionLabeledSentenceProvider(sentences, labelsForSentences, null);
        CnnSentenceDataSetIterator dsi = new CnnSentenceDataSetIterator.Builder()
                .sentenceProvider(provider)
                .wordVectors(w2v)
                .maxSentenceLength(256)
                .minibatchSize(32)
                .sentencesAlongHeight(alongHeight)
                .build();

        // Whole-minibatch checks.
        DataSet ds = dsi.next();
        assertArrayEquals(expectedFeatures.shape(), ds.getFeatures().shape());
        assertEquals(expectedFeatures, ds.getFeatures());
        assertEquals(expLabels, ds.getLabels());
        assertEquals(expectedFeatureMask, ds.getFeaturesMaskArray());
        assertNull(ds.getLabelsMaskArray());

        // Single-sentence loading must agree with the matching slice of the minibatch.
        INDArray single0 = dsi.loadSingleSentence(sentences.get(0));
        INDArray single1 = dsi.loadSingleSentence(sentences.get(1));

        INDArray batchSlice0 = ds.getFeatures().get(NDArrayIndex.interval(0, 0, true), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all());
        // For the second sentence only the first three word positions are compared.
        INDArray batchSlice1 = alongHeight
                ? ds.getFeatures().get(NDArrayIndex.interval(1, 1, true), NDArrayIndex.all(), NDArrayIndex.interval(0, 3), NDArrayIndex.all())
                : ds.getFeatures().get(NDArrayIndex.interval(1, 1, true), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 3));

        assertArrayEquals(batchSlice0.shape(), single0.shape());
        assertArrayEquals(batchSlice1.shape(), single1.shape());
        assertEquals(batchSlice0, single0);
        assertEquals(batchSlice1, single1);
    }
}
Also used : DataSet(org.nd4j.linalg.dataset.api.DataSet) ArrayList(java.util.ArrayList) CollectionLabeledSentenceProvider(org.deeplearning4j.iterator.provider.CollectionLabeledSentenceProvider) ClassPathResource(org.datavec.api.util.ClassPathResource) INDArray(org.nd4j.linalg.api.ndarray.INDArray) WordVectors(org.deeplearning4j.models.embeddings.wordvectors.WordVectors) Test(org.junit.Test)

Example 14 with WordVectors

use of org.deeplearning4j.models.embeddings.wordvectors.WordVectors in project deeplearning4j by deeplearning4j.

the class GloveTest method testGlove.

/**
 * Trains Glove on the lower-cased lines of {@code raw_sentences.txt} via Spark, wraps the
 * resulting lookup table and vocabulary as {@code WordVectors}, and checks that "week" is
 * among the 20 nearest words to "day".
 *
 * @throws Exception if the corpus resource cannot be resolved or training fails
 */
@Test
public void testGlove() throws Exception {
    Glove glove = new Glove(true, 5, 100);

    String corpusPath = new ClassPathResource("raw_sentences.txt").getFile().getAbsolutePath();
    JavaRDD<String> rawLines = sc.textFile(corpusPath);
    // Lower-case every line before training.
    JavaRDD<String> corpus = rawLines.map(new Function<String, String>() {

        @Override
        public String call(String line) throws Exception {
            return line.toLowerCase();
        }
    });

    Pair<VocabCache<VocabWord>, GloveWeightLookupTable> trained = glove.train(corpus);

    // fromPair takes (lookup table, vocab), i.e. the reverse order of the pair returned by train.
    InMemoryLookupTable lookupTable = (InMemoryLookupTable) trained.getSecond();
    VocabCache vocab = (VocabCache) trained.getFirst();
    WordVectors vectors = WordVectorSerializer.fromPair(new Pair<>(lookupTable, vocab));

    Collection<String> nearest = vectors.wordsNearest("day", 20);
    assertTrue(nearest.contains("week"));
}
Also used : GloveWeightLookupTable(org.deeplearning4j.models.glove.GloveWeightLookupTable) ClassPathResource(org.datavec.api.util.ClassPathResource) InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) VocabCache(org.deeplearning4j.models.word2vec.wordstore.VocabCache) WordVectors(org.deeplearning4j.models.embeddings.wordvectors.WordVectors) Test(org.junit.Test) BaseSparkTest(org.deeplearning4j.spark.text.BaseSparkTest)

Example 15 with WordVectors

use of org.deeplearning4j.models.embeddings.wordvectors.WordVectors in project deeplearning4j by deeplearning4j.

the class Word2VecTest method testPortugeseW2V.

/**
 * Manual check (ignored by default): loads text-format word vectors from the absolute path
 * {@code /ext/Temp/para.txt}, switches the model to {@code FlatModelUtils}, and prints the
 * 10 nearest words for "carro" and for "davi". No assertion is made on the result.
 *
 * @throws Exception if the vectors file cannot be read
 */
@Test
@Ignore
public void testPortugeseW2V() throws Exception {
    WordVectors word2Vec = WordVectorSerializer.loadTxtVectors(new File("/ext/Temp/para.txt"));
    word2Vec.setModelUtils(new FlatModelUtils());

    // Same lookup-and-print sequence for each query word, in this order.
    for (String query : new String[] { "carro", "davi" }) {
        Collection<String> nearest = word2Vec.wordsNearest(query, 10);
        printWords(query, nearest, word2Vec);
    }
}
Also used : FlatModelUtils(org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils) WordVectors(org.deeplearning4j.models.embeddings.wordvectors.WordVectors) File(java.io.File) Ignore(org.junit.Ignore) Test(org.junit.Test)

Aggregations

WordVectors (org.deeplearning4j.models.embeddings.wordvectors.WordVectors): 26; Test (org.junit.Test): 26; File (java.io.File): 15; INDArray (org.nd4j.linalg.api.ndarray.INDArray): 15; ClassPathResource (org.datavec.api.util.ClassPathResource): 10; Ignore (org.junit.Ignore): 9; CommonPreprocessor (org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor): 6; DefaultTokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory): 6; TokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory): 6; InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable): 4; SentenceIterator (org.deeplearning4j.text.sentenceiterator.SentenceIterator): 4; ArrayList (java.util.ArrayList): 3; Word2Vec (org.deeplearning4j.models.word2vec.Word2Vec): 3; VocabCache (org.deeplearning4j.models.word2vec.wordstore.VocabCache): 3; InMemoryLookupCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache): 3; UimaSentenceIterator (org.deeplearning4j.text.sentenceiterator.UimaSentenceIterator): 3; VocabWord (org.deeplearning4j.models.word2vec.VocabWord): 2; BasicLineIterator (org.deeplearning4j.text.sentenceiterator.BasicLineIterator): 2; FileInputStream (java.io.FileInputStream): 1; FileOutputStream (java.io.FileOutputStream): 1