
Example 26 with InMemoryLookupTable

Use of org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable in project deeplearning4j by deeplearning4j.

The class SkipGram, method iterateSample.

public double iterateSample(T w1, T lastWord, AtomicLong nextRandom, double alpha, boolean isInference, INDArray inferenceVector) {
    if (w1 == null || lastWord == null || (lastWord.getIndex() < 0 && !isInference) || w1.getIndex() == lastWord.getIndex() || w1.getLabel().equals("STOP") || lastWord.getLabel().equals("STOP") || w1.getLabel().equals("UNK") || lastWord.getLabel().equals("UNK")) {
        return 0.0;
    }
    double score = 0.0;
    int[] idxSyn1 = null;
    int[] codes = null;
    if (configuration.isUseHierarchicSoftmax()) {
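        // hierarchical softmax: collect w1's Huffman codes and inner-node (point) indices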
        idxSyn1 = new int[w1.getCodeLength()];
        codes = new int[w1.getCodeLength()];
        for (int i = 0; i < w1.getCodeLength(); i++) {
            int code = w1.getCodes().get(i);
            int point = w1.getPoints().get(i);
            if (point >= vocabCache.numWords() || point < 0)
                continue;
            codes[i] = code;
            idxSyn1[i] = point;
        }
    } else {
        idxSyn1 = new int[0];
        codes = new int[0];
    }
    int target = w1.getIndex();
    // negative sampling: lazily initialize the syn1Neg weights on first use
    if (negative > 0) {
        if (syn1Neg == null) {
            ((InMemoryLookupTable<T>) lookupTable).initNegative();
            syn1Neg = new DeviceLocalNDArray(((InMemoryLookupTable<T>) lookupTable).getSyn1Neg());
        }
    }
    if (batches.get() == null) {
        batches.set(new ArrayList<Aggregate>());
    }
    //log.info("VocabWords: {}; lastWordIndex: {}; syn1neg: {}", vocabCache.numWords(), lastWord.getIndex(), syn1Neg.get().rows());
    AggregateSkipGram sg = new AggregateSkipGram(syn0.get(), syn1.get(), syn1Neg.get(), expTable.get(), table.get(), lastWord.getIndex(), idxSyn1, codes, (int) negative, target, vectorLength, alpha, nextRandom.get(), vocabCache.numWords(), inferenceVector);
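    // advance word2vec's linear congruential RNG; the constants match the original word2vec C code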
    nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11));
    if (!isInference)
        batches.get().add(sg);
    else
        Nd4j.getExecutioner().exec(sg);
    return score;
}
Also used : InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) DeviceLocalNDArray(org.nd4j.linalg.util.DeviceLocalNDArray) AggregateSkipGram(org.nd4j.linalg.api.ops.aggregates.impl.AggregateSkipGram) Aggregate(org.nd4j.linalg.api.ops.aggregates.Aggregate)
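
The nextRandom.set(...) update in iterateSample above is word2vec's linear congruential generator. A minimal, self-contained sketch of just that step (the constants are copied from the method above; the class name is invented for illustration):

import java.util.concurrent.atomic.AtomicLong;

public class Word2VecRngSketch {
    // same LCG step as in iterateSample above
    static void advance(AtomicLong nextRandom) {
        nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11));
    }

    public static void main(String[] args) {
        AtomicLong r = new AtomicLong(42L);
        for (int i = 0; i < 3; i++) {
            advance(r);
            System.out.println(r.get());
        }
    }
}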

Example 27 with InMemoryLookupTable

Use of org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable in project deeplearning4j by deeplearning4j.

The class BasicModelUtils, method wordsNearestSum.

/**
     * Words nearest based on positive and negative words
     *
     * @param words the vector to find the nearest words to
     * @param top the top n words
     * @return the words nearest the mean of the words
     */
@Override
public Collection<String> wordsNearestSum(INDArray words, int top) {
    if (lookupTable instanceof InMemoryLookupTable) {
        InMemoryLookupTable l = (InMemoryLookupTable) lookupTable;
        INDArray syn0 = l.getSyn0();
        INDArray weights = syn0.norm2(0).rdivi(1).muli(words);
        INDArray distances = syn0.mulRowVector(weights).sum(1);
        INDArray[] sorted = Nd4j.sortWithIndices(distances, 0, false);
        INDArray sort = sorted[0];
        List<String> ret = new ArrayList<>();
        if (top > sort.length())
            top = sort.length();
        // UNK/STOP entries are skipped below, so extend the scan window to still return `top` words
        int end = top;
        for (int i = 0; i < end; i++) {
            String add = vocabCache.wordAtIndex(sort.getInt(i));
            if (add == null || add.equals("UNK") || add.equals("STOP")) {
                end++;
                if (end >= sort.length())
                    break;
                continue;
            }
            ret.add(add);
        }
        return ret;
    }
    Counter<String> distances = new Counter<>();
    for (String s : vocabCache.words()) {
        INDArray otherVec = lookupTable.vector(s);
        double sim = Transforms.cosineSim(words, otherVec);
        distances.incrementCount(s, sim);
    }
    distances.keepTopNKeys(top);
    return distances.keySet();
}
Also used : InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) Counter(org.deeplearning4j.berkeley.Counter) INDArray(org.nd4j.linalg.api.ndarray.INDArray)
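
The fallback branch above brute-forces cosine similarity over the whole vocabulary. A minimal plain-Java sketch of the same idea, runnable without ND4J (the class name, helper names, and toy vocabulary are invented for illustration):

import java.util.*;
import java.util.stream.Collectors;

public class NearestWordsSketch {
    // plain-Java cosine similarity, mirroring Transforms.cosineSim in the fallback branch
    static double cosineSim(double[] a, double[] b) {
        double dot = 0, na = 0, nb = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            na += a[i] * a[i];
            nb += b[i] * b[i];
        }
        return dot / (Math.sqrt(na) * Math.sqrt(nb));
    }

    // brute-force top-N by similarity, the same idea as Counter.keepTopNKeys above
    static List<String> wordsNearest(Map<String, double[]> vocab, double[] query, int top) {
        return vocab.entrySet().stream()
                .sorted((x, y) -> Double.compare(cosineSim(query, y.getValue()),
                                                 cosineSim(query, x.getValue())))
                .limit(top)
                .map(Map.Entry::getKey)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        Map<String, double[]> vocab = new HashMap<>();
        vocab.put("day", new double[] {1.0, 0.1});
        vocab.put("night", new double[] {0.9, 0.2});
        vocab.put("house", new double[] {0.0, 1.0});
        System.out.println(wordsNearest(vocab, new double[] {1.0, 0.0}, 2));
    }
}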

Example 28 with InMemoryLookupTable

Use of org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable in project deeplearning4j by deeplearning4j.

The class ParagraphVectorsTest, method testParagraphVectorsModelling1.

/**
     * This test doesn't really care about actual results; we only care about equality between the live model and restored models.
     *
     * @throws Exception
     */
@Test
public void testParagraphVectorsModelling1() throws Exception {
    ClassPathResource resource = new ClassPathResource("/big/raw_sentences.txt");
    File file = resource.getFile();
    SentenceIterator iter = new BasicLineIterator(file);
    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());
    LabelsSource source = new LabelsSource("DOC_");
    ParagraphVectors vec = new ParagraphVectors.Builder().minWordFrequency(1).iterations(5).seed(119).epochs(1).layerSize(150).learningRate(0.025).labelsSource(source).windowSize(5).sequenceLearningAlgorithm(new DM<VocabWord>()).iterate(iter).trainWordVectors(true).tokenizerFactory(t).workers(4).sampling(0).build();
    vec.fit();
    VocabCache<VocabWord> cache = vec.getVocab();
    File fullFile = File.createTempFile("paravec", "tests");
    fullFile.deleteOnExit();
    INDArray originalSyn1_17 = ((InMemoryLookupTable) vec.getLookupTable()).getSyn1().getRow(17).dup();
    WordVectorSerializer.writeParagraphVectors(vec, fullFile);
    int cnt1 = cache.wordFrequency("day");
    int cnt2 = cache.wordFrequency("me");
    assertNotEquals(1, cnt1);
    assertNotEquals(1, cnt2);
    assertNotEquals(cnt1, cnt2);
    assertEquals(97406, cache.numWords());
    assertTrue(vec.hasWord("DOC_16392"));
    assertTrue(vec.hasWord("DOC_3720"));
    List<String> result = new ArrayList<>(vec.nearestLabels(vec.getWordVectorMatrix("DOC_16392"), 10));
    System.out.println("nearest labels: " + result);
    for (String label : result) {
        System.out.println(label + "/DOC_16392: " + vec.similarity(label, "DOC_16392"));
    }
    assertTrue(result.contains("DOC_16392"));
    //assertTrue(result.contains("DOC_21383"));
    /*
            We have a few lines that contain pretty similar words.
            These sentences should be pretty close to each other in vector space
         */
    // line 3721: This is my way .
    // line 6348: This is my case .
    // line 9836: This is my house .
    // line 12493: This is my world .
    // line 16393: This is my work .
    // this is special sentence, that has nothing common with previous sentences
    // line 9853: We now have one .
    double similarityD = vec.similarity("day", "night");
    log.info("day/night similarity: " + similarityD);
    if (similarityD < 0.0) {
        log.info("Day: " + Arrays.toString(vec.getWordVectorMatrix("day").dup().data().asDouble()));
        log.info("Night: " + Arrays.toString(vec.getWordVectorMatrix("night").dup().data().asDouble()));
    }
    List<String> labelsOriginal = vec.labelsSource.getLabels();
    double similarityW = vec.similarity("way", "work");
    log.info("way/work similarity: " + similarityW);
    double similarityH = vec.similarity("house", "world");
    log.info("house/world similarity: " + similarityH);
    double similarityC = vec.similarity("case", "way");
    log.info("case/way similarity: " + similarityC);
    double similarity1 = vec.similarity("DOC_9835", "DOC_12492");
    log.info("9835/12492 similarity: " + similarity1);
    //        assertTrue(similarity1 > 0.7d);
    double similarity2 = vec.similarity("DOC_3720", "DOC_16392");
    log.info("3720/16392 similarity: " + similarity2);
    //        assertTrue(similarity2 > 0.7d);
    double similarity3 = vec.similarity("DOC_6347", "DOC_3720");
    log.info("6347/3720 similarity: " + similarity3);
    //        assertTrue(similarity3 > 0.7d);
    // likelihood in this case should be significantly lower
    double similarityX = vec.similarity("DOC_3720", "DOC_9852");
    log.info("3720/9852 similarity: " + similarityX);
    assertTrue(similarityX < 0.5d);
    File tempFile = File.createTempFile("paravec", "ser");
    tempFile.deleteOnExit();
    INDArray day = vec.getWordVectorMatrix("day").dup();
    /*
            Testing txt serialization
         */
    File tempFile2 = File.createTempFile("paravec", "ser");
    tempFile2.deleteOnExit();
    WordVectorSerializer.writeWordVectors(vec, tempFile2);
    ParagraphVectors vec3 = WordVectorSerializer.readParagraphVectorsFromText(tempFile2);
    INDArray day3 = vec3.getWordVectorMatrix("day").dup();
    List<String> labelsRestored = vec3.labelsSource.getLabels();
    assertEquals(day, day3);
    assertEquals(labelsOriginal.size(), labelsRestored.size());
    /*
         Testing binary serialization
        */
    SerializationUtils.saveObject(vec, tempFile);
    ParagraphVectors vec2 = (ParagraphVectors) SerializationUtils.readObject(tempFile);
    INDArray day2 = vec2.getWordVectorMatrix("day").dup();
    List<String> labelsBinary = vec2.labelsSource.getLabels();
    assertEquals(day, day2);
    tempFile.delete();
    assertEquals(labelsOriginal.size(), labelsBinary.size());
    INDArray original = vec.getWordVectorMatrix("DOC_16392").dup();
    INDArray originalPreserved = original.dup();
    INDArray inferredA1 = vec.inferVector("This is my work .");
    INDArray inferredB1 = vec.inferVector("This is my work .");
    double cosAO1 = Transforms.cosineSim(inferredA1.dup(), original.dup());
    double cosAB1 = Transforms.cosineSim(inferredA1.dup(), inferredB1.dup());
    log.info("Cos O/A: {}", cosAO1);
    log.info("Cos A/B: {}", cosAB1);
    //        assertTrue(cosAO1 > 0.45);
    assertTrue(cosAB1 > 0.95);
    //assertArrayEquals(inferredA.data().asDouble(), inferredB.data().asDouble(), 0.01);
    ParagraphVectors restoredVectors = WordVectorSerializer.readParagraphVectors(fullFile);
    restoredVectors.setTokenizerFactory(t);
    INDArray restoredSyn1_17 = ((InMemoryLookupTable) restoredVectors.getLookupTable()).getSyn1().getRow(17).dup();
    assertEquals(originalSyn1_17, restoredSyn1_17);
    INDArray originalRestored = vec.getWordVectorMatrix("DOC_16392").dup();
    assertEquals(originalPreserved, originalRestored);
    INDArray inferredA2 = restoredVectors.inferVector("This is my work .");
    INDArray inferredB2 = restoredVectors.inferVector("This is my work .");
    INDArray inferredC2 = restoredVectors.inferVector("world way case .");
    double cosAO2 = Transforms.cosineSim(inferredA2.dup(), original.dup());
    double cosAB2 = Transforms.cosineSim(inferredA2.dup(), inferredB2.dup());
    double cosAAX = Transforms.cosineSim(inferredA1.dup(), inferredA2.dup());
    double cosAC2 = Transforms.cosineSim(inferredC2.dup(), inferredA2.dup());
    log.info("Cos A2/B2: {}", cosAB2);
    log.info("Cos A1/A2: {}", cosAAX);
    log.info("Cos O/A2: {}", cosAO2);
    log.info("Cos C2/A2: {}", cosAC2);
    log.info("Vector: {}", Arrays.toString(inferredA1.data().asFloat()));
    log.info("cosAO2: {}", cosAO2);
    //  assertTrue(cosAO2 > 0.45);
    assertTrue(cosAB2 > 0.95);
    assertTrue(cosAAX > 0.95);
}
Also used : BasicLineIterator(org.deeplearning4j.text.sentenceiterator.BasicLineIterator) TokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) ArrayList(java.util.ArrayList) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) ClassPathResource(org.datavec.api.util.ClassPathResource) SentenceIterator(org.deeplearning4j.text.sentenceiterator.SentenceIterator) FileSentenceIterator(org.deeplearning4j.text.sentenceiterator.FileSentenceIterator) AggregatingSentenceIterator(org.deeplearning4j.text.sentenceiterator.AggregatingSentenceIterator) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) CommonPreprocessor(org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor) InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) LabelsSource(org.deeplearning4j.text.documentiterator.LabelsSource) File(java.io.File) Test(org.junit.Test)
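
The heart of this test is the save/restore round trip. A condensed sketch of just that round trip, using only calls that appear in the test above (assume vec is the fitted ParagraphVectors and t is its TokenizerFactory):

// assumes: ParagraphVectors vec (already fitted), TokenizerFactory t
File modelFile = File.createTempFile("paravec", "model");
modelFile.deleteOnExit();

WordVectorSerializer.writeParagraphVectors(vec, modelFile);
ParagraphVectors restored = WordVectorSerializer.readParagraphVectors(modelFile);
// re-attach the tokenizer factory before inference, as the test does above
restored.setTokenizerFactory(t);

INDArray a = vec.inferVector("This is my work .");
INDArray b = restored.inferVector("This is my work .");
// live and restored models should infer near-identical vectors for the same sentence
assertTrue(Transforms.cosineSim(a, b) > 0.95);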

Example 29 with InMemoryLookupTable

Use of org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable in project deeplearning4j by deeplearning4j.

The class VocabCacheExporter, method export.

@Override
public void export(JavaRDD<ExportContainer<VocabWord>> rdd) {
    // beware: collecting the whole RDD to the driver is generally a VERY bad idea, but it works fine for testing purposes
    List<ExportContainer<VocabWord>> list = rdd.collect();
    if (vocabCache == null)
        vocabCache = new AbstractCache<>();
    INDArray syn0 = null;
    // just roll through the list, filling the vocab cache and the weight matrix
    for (ExportContainer<VocabWord> element : list) {
        VocabWord word = element.getElement();
        INDArray weights = element.getArray();
        if (syn0 == null)
            syn0 = Nd4j.create(list.size(), weights.length());
        vocabCache.addToken(word);
        vocabCache.addWordToIndex(word.getIndex(), word.getLabel());
        syn0.getRow(word.getIndex()).assign(weights);
    }
    if (lookupTable == null)
        lookupTable = new InMemoryLookupTable.Builder<VocabWord>().cache(vocabCache).vectorLength(syn0.columns()).build();
    lookupTable.setSyn0(syn0);
    // this is bad & dirty, but we don't really need anything else for testing :)
    word2Vec = WordVectorSerializer.fromPair(Pair.<InMemoryLookupTable, VocabCache>makePair(lookupTable, vocabCache));
}
Also used : InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VocabCache(org.deeplearning4j.models.word2vec.wordstore.VocabCache) ExportContainer(org.deeplearning4j.spark.models.sequencevectors.export.ExportContainer) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) AbstractCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache)
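
The manual wiring in export() above (vocab entries, a syn0 matrix, then the builder) reduces to a short recipe. A minimal sketch, with the single word, its index, and the vector length invented for illustration:

// build a one-word vocabulary; frequency, label, and index are illustrative
AbstractCache<VocabWord> vocab = new AbstractCache<>();
VocabWord day = new VocabWord(1.0, "day");
day.setIndex(0);
vocab.addToken(day);
vocab.addWordToIndex(0, "day");

// one row per word; 100 is an arbitrary vector length
INDArray syn0 = Nd4j.create(1, 100);

InMemoryLookupTable<VocabWord> table = new InMemoryLookupTable.Builder<VocabWord>()
        .cache(vocab)
        .vectorLength(syn0.columns())
        .build();
table.setSyn0(syn0);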

Aggregations

InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable): 29 usages
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 21 usages
VocabWord (org.deeplearning4j.models.word2vec.VocabWord): 18 usages
ArrayList (java.util.ArrayList): 13 usages
AbstractCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache): 9 usages
Test (org.junit.Test): 8 usages
VocabCache (org.deeplearning4j.models.word2vec.wordstore.VocabCache): 7 usages
File (java.io.File): 6 usages
Word2Vec (org.deeplearning4j.models.word2vec.Word2Vec): 6 usages
ZipFile (java.util.zip.ZipFile): 5 usages
DL4JInvalidInputException (org.deeplearning4j.exception.DL4JInvalidInputException): 5 usages
StaticWord2Vec (org.deeplearning4j.models.word2vec.StaticWord2Vec): 5 usages
TokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory): 5 usages
ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 5 usages
AtomicInteger (java.util.concurrent.atomic.AtomicInteger): 4 usages
ZipEntry (java.util.zip.ZipEntry): 4 usages
ClassPathResource (org.datavec.api.util.ClassPathResource): 4 usages
WordVectors (org.deeplearning4j.models.embeddings.wordvectors.WordVectors): 4 usages
InMemoryLookupCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache): 4 usages
GZIPInputStream (java.util.zip.GZIPInputStream): 3 usages