Search in sources :

Example 16 with InMemoryLookupTable

use of org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable in project deeplearning4j by deeplearning4j.

the class Word2VecTest method testConcepts.

@Test
public void testConcepts() throws Exception {
    SparkConf sparkConf = new SparkConf().setMaster("local[8]").setAppName("sparktest");
    JavaSparkContext sc = new JavaSparkContext(sparkConf);
    Word2Vec word2Vec;
    try {
        // Path of the training corpus on the test classpath
        String dataPath = new ClassPathResource("raw_sentences.txt").getFile().getAbsolutePath();
        JavaRDD<String> corpus = sc.textFile(dataPath);
        TokenizerFactory t = new DefaultTokenizerFactory();
        t.setTokenPreProcessor(new CommonPreprocessor());
        // "three" is registered as a stop word, so it must not show up among the neighbours checked below
        word2Vec = new Word2Vec.Builder().setNGrams(1).tokenizerFactory(t).seed(42L).negative(10)
                        .useAdaGrad(false).layerSize(150).windowSize(5).learningRate(0.025)
                        .minLearningRate(0.0001).iterations(1).batchSize(100).minWordFrequency(5)
                        .stopWords(Arrays.asList("three")).useUnknown(true).build();
        word2Vec.train(corpus);
        System.out.println("UNK: " + word2Vec.getWordVectorMatrix("UNK"));
        double sim = word2Vec.similarity("day", "night");
        System.out.println("day/night similarity: " + sim);

        Collection<String> words = word2Vec.wordsNearest("day", 10);
        printWords("day", words, word2Vec);
        assertTrue(words.contains("night"));
        assertTrue(words.contains("week"));
        assertTrue(words.contains("year"));

        sim = word2Vec.similarity("two", "four");
        System.out.println("two/four similarity: " + sim);
        words = word2Vec.wordsNearest("two", 10);
        printWords("two", words, word2Vec);
        // three should be absent due to stopWords
        assertFalse(words.contains("three"));
        assertTrue(words.contains("five"));
        assertTrue(words.contains("four"));
    } finally {
        // Stop the context even when an assertion fails; a context left running can break
        // later Spark tests executed in the same JVM.
        sc.stop();
    }

    // test serialization: write the trained vectors as text and check they load back unchanged
    File tempFile = File.createTempFile("temp", "tmp");
    tempFile.deleteOnExit();
    int idx1 = word2Vec.vocab().wordFor("day").getIndex();
    INDArray array1 = word2Vec.getWordVectorMatrix("day").dup();
    VocabWord word1 = word2Vec.vocab().elementAtIndex(0);
    WordVectorSerializer.writeWordVectors(word2Vec.getLookupTable(), tempFile);
    WordVectors vectors = WordVectorSerializer.loadTxtVectors(tempFile);
    VocabWord word2 = ((VocabCache<VocabWord>) vectors.vocab()).elementAtIndex(0);
    VocabWord wordIT = ((VocabCache<VocabWord>) vectors.vocab()).wordFor("it");
    int idx2 = vectors.vocab().wordFor("day").getIndex();
    INDArray array2 = vectors.getWordVectorMatrix("day").dup();
    System.out.println("word 'i': " + word2);
    System.out.println("word 'it': " + wordIT);
    assertEquals(idx1, idx2);
    assertEquals(word1, word2);
    assertEquals(array1, array2);
}
Also used : TokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) ClassPathResource(org.datavec.api.util.ClassPathResource) CommonPreprocessor(org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor) InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VocabCache(org.deeplearning4j.models.word2vec.wordstore.VocabCache) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) WordVectors(org.deeplearning4j.models.embeddings.wordvectors.WordVectors) SparkConf(org.apache.spark.SparkConf) File(java.io.File) Test(org.junit.Test)

Example 17 with InMemoryLookupTable

use of org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable in project deeplearning4j by deeplearning4j.

the class WordVectorSerializerTest method testWriteWordVectors.

@Test
@Ignore
public void testWriteWordVectors() throws IOException {
    // Round-trip the Google binary model through the text format and spot-check two vectors.
    WordVectors vec = WordVectorSerializer.loadGoogleModel(binaryFile, true);
    InMemoryLookupTable lookupTable = (InMemoryLookupTable) vec.lookupTable();
    InMemoryLookupCache lookupCache = (InMemoryLookupCache) vec.vocab();
    WordVectorSerializer.writeWordVectors(lookupTable, lookupCache, pathToWriteto);
    WordVectors wordVectors = WordVectorSerializer.loadTxtVectors(new File(pathToWriteto));
    double[] wordVector1 = wordVectors.getWordVector("Morgan_Freeman");
    double[] wordVector2 = wordVectors.getWordVector("JA_Montalbano");
    // JUnit convention: expected value first, actual value second
    assertEquals(300, wordVector1.length);
    assertEquals(300, wordVector2.length);
    assertEquals(0.044423, wordVector1[0], 1e-3);
    assertEquals(0.051964, wordVector2[0], 1e-3);
}
Also used : InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) WordVectors(org.deeplearning4j.models.embeddings.wordvectors.WordVectors) File(java.io.File) InMemoryLookupCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache) Ignore(org.junit.Ignore) Test(org.junit.Test)

Example 18 with InMemoryLookupTable

use of org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable in project deeplearning4j by deeplearning4j.

the class WordVectorSerializerTest method testParaVecSerialization1.

@Test
public void testParaVecSerialization1() throws Exception {
    VectorsConfiguration configuration = new VectorsConfiguration();
    configuration.setIterations(14123);
    configuration.setLayersSize(156);
    INDArray syn0 = Nd4j.rand(100, configuration.getLayersSize());
    INDArray syn1 = Nd4j.rand(100, configuration.getLayersSize());
    AbstractCache<VocabWord> cache = new AbstractCache.Builder<VocabWord>().build();
    // Build 100 synthetic vocabulary entries with random points/codes so serialization of
    // those lists is exercised as well as the weight matrices.
    for (int i = 0; i < 100; i++) {
        VocabWord word = new VocabWord((float) i, "word_" + i);
        List<Integer> points = new ArrayList<>();
        List<Byte> codes = new ArrayList<>();
        int num = org.apache.commons.lang3.RandomUtils.nextInt(1, 20);
        for (int x = 0; x < num; x++) {
            points.add(org.apache.commons.lang3.RandomUtils.nextInt(1, 100000));
            // only one random byte is needed per code
            codes.add(org.apache.commons.lang3.RandomUtils.nextBytes(1)[0]);
        }
        // mark roughly 30% of the entries as labels so both element kinds are round-tripped
        if (org.apache.commons.lang3.RandomUtils.nextInt(0, 10) < 3) {
            word.markAsLabel(true);
        }
        word.setIndex(i);
        word.setPoints(points);
        word.setCodes(codes);
        cache.addToken(word);
        cache.addWordToIndex(i, word.getLabel());
    }
    InMemoryLookupTable<VocabWord> lookupTable = (InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>().vectorLength(configuration.getLayersSize()).cache(cache).build();
    lookupTable.setSyn0(syn0);
    lookupTable.setSyn1(syn1);
    ParagraphVectors originalVectors = new ParagraphVectors.Builder(configuration).vocabCache(cache).lookupTable(lookupTable).build();
    File tempFile = File.createTempFile("paravec", "tests");
    tempFile.deleteOnExit();
    WordVectorSerializer.writeParagraphVectors(originalVectors, tempFile);
    ParagraphVectors restoredVectors = WordVectorSerializer.readParagraphVectors(tempFile);
    InMemoryLookupTable<VocabWord> restoredLookupTable = (InMemoryLookupTable<VocabWord>) restoredVectors.getLookupTable();
    AbstractCache<VocabWord> restoredVocab = (AbstractCache<VocabWord>) restoredVectors.getVocab();
    // JUnit convention: expected (original) first, actual (restored) second
    assertEquals(lookupTable.getSyn0(), restoredLookupTable.getSyn0());
    assertEquals(lookupTable.getSyn1(), restoredLookupTable.getSyn1());
    for (int i = 0; i < cache.numWords(); i++) {
        VocabWord original = cache.elementAtIndex(i);
        VocabWord restored = restoredVocab.elementAtIndex(i);
        assertEquals(original.isLabel(), restored.isLabel());
        assertEquals(cache.wordAtIndex(i), restoredVocab.wordAtIndex(i));
        assertEquals(original.getElementFrequency(), restored.getElementFrequency(), 0.1f);
        // List.equals checks size and element-wise equality in order
        assertEquals(original.getPoints(), restored.getPoints());
        assertEquals(original.getCodes(), restored.getCodes());
    }
}
Also used : VectorsConfiguration(org.deeplearning4j.models.embeddings.loader.VectorsConfiguration) ArrayList(java.util.ArrayList) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) AbstractCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache) ParagraphVectors(org.deeplearning4j.models.paragraphvectors.ParagraphVectors) InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) File(java.io.File) Test(org.junit.Test)

Example 19 with InMemoryLookupTable

use of org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable in project deeplearning4j by deeplearning4j.

the class WordVectorSerializerTest method testFullModelSerialization.

@Test
public void testFullModelSerialization() throws Exception {
    File inputFile = new ClassPathResource("/big/raw_sentences.txt").getFile();
    SentenceIterator iter = UimaSentenceIterator.createWithPath(inputFile.getAbsolutePath());
    // Split on white spaces in the line to get words
    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());
    InMemoryLookupCache cache = new InMemoryLookupCache(false);
    WeightLookupTable table = new InMemoryLookupTable.Builder().vectorLength(100).useAdaGrad(false).negative(5.0).cache(cache).lr(0.025f).build();
    Word2Vec vec = new Word2Vec.Builder().minWordFrequency(5).iterations(1).epochs(1).layerSize(100).lookupTable(table).stopWords(new ArrayList<String>()).useAdaGrad(false).negativeSample(5).vocabCache(cache).seed(42).windowSize(5).iterate(iter).tokenizerFactory(t).build();
    assertEquals(new ArrayList<String>(), vec.getStopWords());
    vec.fit();
    Collection<String> lst = vec.wordsNearest("day", 10);
    System.out.println(lst);

    // Use a unique temp file rather than "tempModel.txt" in the working directory, so the test
    // does not litter the checkout or collide with concurrently running tests.
    File modelFile = File.createTempFile("tempModel", ".txt");
    modelFile.deleteOnExit();
    String modelPath = modelFile.getAbsolutePath();
    WordVectorSerializer.writeFullModel(vec, modelPath);
    assertTrue(modelFile.exists());
    assertTrue(modelFile.length() > 0);
    Word2Vec vec2 = WordVectorSerializer.loadFullModel(modelPath);
    assertNotEquals(null, vec2);
    assertEquals(vec.getConfiguration(), vec2.getConfiguration());
    assertTrue(ArrayUtils.isEquals(((InMemoryLookupTable) table).getExpTable(), ((InMemoryLookupTable) vec2.getLookupTable()).getExpTable()));
    InMemoryLookupTable restoredTable = (InMemoryLookupTable) vec2.lookupTable();
    assertEquals(cache.wordAtIndex(1), restoredTable.getVocab().wordAtIndex(1));
    assertEquals(cache.wordAtIndex(7), restoredTable.getVocab().wordAtIndex(7));
    assertEquals(cache.wordAtIndex(15), restoredTable.getVocab().wordAtIndex(15));

    // these checks only make sure INDArray equality works as the assertions below expect
    double[] array1 = new double[] { 0.323232325, 0.65756575, 0.12315, 0.12312315, 0.1232135, 0.12312315, 0.4343423425, 0.15 };
    double[] array2 = new double[] { 0.423232325, 0.25756575, 0.12375, 0.12311315, 0.1232035, 0.12318315, 0.4343493425, 0.25 };
    assertNotEquals(Nd4j.create(array1), Nd4j.create(array2));
    assertEquals(Nd4j.create(array1), Nd4j.create(array1));

    INDArray rSyn0_1 = restoredTable.getSyn0().slice(1);
    INDArray oSyn0_1 = ((InMemoryLookupTable) table).getSyn0().slice(1);
    assertEquals(oSyn0_1, rSyn0_1);

    // compare syn0 / syn1 / syn1Neg rows word by word to make sure the row order survived
    int cnt = 0;
    for (VocabWord word : cache.vocabWords()) {
        INDArray rSyn0 = restoredTable.getSyn0().slice(word.getIndex());
        INDArray oSyn0 = ((InMemoryLookupTable) table).getSyn0().slice(word.getIndex());
        assertEquals(rSyn0, oSyn0);
        assertEquals(1.0, arraysSimilarity(rSyn0, oSyn0), 0.001);
        INDArray rSyn1 = restoredTable.getSyn1().slice(word.getIndex());
        INDArray oSyn1 = ((InMemoryLookupTable) table).getSyn1().slice(word.getIndex());
        assertEquals(rSyn1, oSyn1);
        // we exclude word 222 since it has syn1 full of zeroes
        if (cnt != 222)
            assertEquals(1.0, arraysSimilarity(rSyn1, oSyn1), 0.001);
        if (((InMemoryLookupTable) table).getSyn1Neg() != null) {
            INDArray rSyn1Neg = restoredTable.getSyn1Neg().slice(word.getIndex());
            INDArray oSyn1Neg = ((InMemoryLookupTable) table).getSyn1Neg().slice(word.getIndex());
            assertEquals(rSyn1Neg, oSyn1Neg);
        }
        assertEquals(word.getHistoricalGradient(), restoredTable.getVocab().wordFor(word.getWord()).getHistoricalGradient());
        cnt++;
    }

    // at this moment we can assume that whole model is transferred, and we can call fit over new model
    iter = UimaSentenceIterator.createWithPath(inputFile.getAbsolutePath());
    vec2.setTokenizerFactory(t);
    vec2.setSentenceIterator(iter);
    vec2.fit();
    INDArray day1 = vec.getWordVectorMatrix("day");
    INDArray day2 = vec2.getWordVectorMatrix("day");
    INDArray night1 = vec.getWordVectorMatrix("night");
    INDArray night2 = vec2.getWordVectorMatrix("night");
    double simD = arraysSimilarity(day1, day2);
    double simN = arraysSimilarity(night1, night2);
    logger.info("Vec1 day: " + day1);
    logger.info("Vec2 day: " + day2);
    logger.info("Vec1 night: " + night1);
    logger.info("Vec2 night: " + night2);
    logger.info("Day/day cross-model similarity: " + simD);
    logger.info("Night/night cross-model similarity: " + simN);
    logger.info("Vec1 day/night similiraty: " + vec.similarity("day", "night"));
    logger.info("Vec2 day/night similiraty: " + vec2.similarity("day", "night"));
    // check if cross-model values are not the same
    assertNotEquals(1.0, simD, 0.001);
    assertNotEquals(1.0, simN, 0.001);
    // check if cross-model values are still close to each other
    assertTrue(simD > 0.70);
    assertTrue(simN > 0.70);
    modelFile.delete();
}
Also used : TokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) ClassPathResource(org.datavec.api.util.ClassPathResource) SentenceIterator(org.deeplearning4j.text.sentenceiterator.SentenceIterator) UimaSentenceIterator(org.deeplearning4j.text.sentenceiterator.UimaSentenceIterator) InMemoryLookupCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache) CommonPreprocessor(org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor) InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Word2Vec(org.deeplearning4j.models.word2vec.Word2Vec) WeightLookupTable(org.deeplearning4j.models.embeddings.WeightLookupTable) File(java.io.File) Test(org.junit.Test)

Example 20 with InMemoryLookupTable

use of org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable in project deeplearning4j by deeplearning4j.

the class CBOW method configure.

@Override
public void configure(@NonNull VocabCache<T> vocabCache, @NonNull WeightLookupTable<T> lookupTable, @NonNull VectorsConfiguration configuration) {
    // Remember the collaborators and copy the hyper-parameters this algorithm reads later.
    this.vocabCache = vocabCache;
    this.lookupTable = lookupTable;
    this.configuration = configuration;
    this.window = configuration.getWindow();
    this.useAdaGrad = configuration.isUseAdaGrad();
    this.negative = configuration.getNegative();
    this.sampling = configuration.getSampling();

    // The weight matrices are read directly from the table, so it is treated as an InMemoryLookupTable.
    InMemoryLookupTable<T> memoryTable = (InMemoryLookupTable<T>) lookupTable;

    // Negative sampling needs syn1Neg; build it lazily if the table does not have it yet.
    boolean mustInitNegative = configuration.getNegative() > 0 && memoryTable.getSyn1Neg() == null;
    if (mustInitNegative) {
        logger.info("Initializing syn1Neg...");
        memoryTable.setUseHS(configuration.isUseHierarchicSoftmax());
        memoryTable.setNegative(configuration.getNegative());
        memoryTable.resetWeights(false);
    }

    // Wrap each matrix in a DeviceLocalNDArray holder.
    INDArray weights0 = memoryTable.getSyn0();
    this.syn0 = new DeviceLocalNDArray(weights0);
    INDArray weights1 = memoryTable.getSyn1();
    this.syn1 = new DeviceLocalNDArray(weights1);
    INDArray weightsNeg = memoryTable.getSyn1Neg();
    this.syn1Neg = new DeviceLocalNDArray(weightsNeg);
    INDArray expValues = Nd4j.create(memoryTable.getExpTable());
    this.expTable = new DeviceLocalNDArray(expValues);
    INDArray samplingTable = memoryTable.getTable();
    this.table = new DeviceLocalNDArray(samplingTable);
    this.variableWindows = configuration.getVariableWindows();
}
Also used : InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) DeviceLocalNDArray(org.nd4j.linalg.util.DeviceLocalNDArray)

Aggregations

InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable)29 INDArray (org.nd4j.linalg.api.ndarray.INDArray)21 VocabWord (org.deeplearning4j.models.word2vec.VocabWord)18 ArrayList (java.util.ArrayList)13 AbstractCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache)9 Test (org.junit.Test)8 VocabCache (org.deeplearning4j.models.word2vec.wordstore.VocabCache)7 File (java.io.File)6 Word2Vec (org.deeplearning4j.models.word2vec.Word2Vec)6 ZipFile (java.util.zip.ZipFile)5 DL4JInvalidInputException (org.deeplearning4j.exception.DL4JInvalidInputException)5 StaticWord2Vec (org.deeplearning4j.models.word2vec.StaticWord2Vec)5 TokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory)5 ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException)5 AtomicInteger (java.util.concurrent.atomic.AtomicInteger)4 ZipEntry (java.util.zip.ZipEntry)4 ClassPathResource (org.datavec.api.util.ClassPathResource)4 WordVectors (org.deeplearning4j.models.embeddings.wordvectors.WordVectors)4 InMemoryLookupCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache)4 GZIPInputStream (java.util.zip.GZIPInputStream)3