
Example 11 with InMemoryLookupTable

use of org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable in project deeplearning4j by deeplearning4j.

the class WordVectorSerializerTest method testFromTableAndVocab.

@Test
@Ignore
public void testFromTableAndVocab() throws IOException {
    // Load a Google-format word2vec model and pull out its in-memory lookup table and vocab cache
    WordVectors vec = WordVectorSerializer.loadGoogleModel(textFile, false);
    InMemoryLookupTable lookupTable = (InMemoryLookupTable) vec.lookupTable();
    InMemoryLookupCache lookupCache = (InMemoryLookupCache) vec.vocab();
    // Rebuild a WordVectors instance from the table/vocab pair and check that the vectors survived
    WordVectors wordVectors = WordVectorSerializer.fromTableAndVocab(lookupTable, lookupCache);
    double[] wordVector1 = wordVectors.getWordVector("Morgan_Freeman");
    double[] wordVector2 = wordVectors.getWordVector("JA_Montalbano");
    assertTrue(wordVector1.length == 300);
    assertTrue(wordVector2.length == 300);
    assertEquals(Doubles.asList(wordVector1).get(0), 0.044423, 1e-3);
    assertEquals(Doubles.asList(wordVector2).get(0), 0.051964, 1e-3);
}
Also used : InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) WordVectors(org.deeplearning4j.models.embeddings.wordvectors.WordVectors) InMemoryLookupCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache) Ignore(org.junit.Ignore) Test(org.junit.Test)
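
A WordVectors instance rebuilt this way behaves like any other: getWordVector, similarity and wordsNearest all work against the shared InMemoryLookupTable. A minimal standalone sketch, mirroring the casts in the test above (the model path and the query tokens are hypothetical; substitute any Google-format vectors file):

import java.io.File;

import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache;

public class FromTableAndVocabSketch {
    public static void main(String[] args) throws Exception {
        // Hypothetical path to a text-format (binary flag = false) Google word2vec file
        File modelFile = new File("path/to/google-format-vectors.txt");
        WordVectors original = WordVectorSerializer.loadGoogleModel(modelFile, false);

        // Split the loaded model into its two in-memory components ...
        InMemoryLookupTable table = (InMemoryLookupTable) original.lookupTable();
        InMemoryLookupCache vocab = (InMemoryLookupCache) original.vocab();

        // ... and stitch them back together into a fresh WordVectors instance
        WordVectors rebuilt = WordVectorSerializer.fromTableAndVocab(table, vocab);

        System.out.println(rebuilt.getWordVector("day").length);   // vector dimensionality
        System.out.println(rebuilt.similarity("day", "week"));     // cosine similarity
        System.out.println(rebuilt.wordsNearest("day", 10));       // nearest neighbours
    }
}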

Example 12 with InMemoryLookupTable

use of org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable in project deeplearning4j by deeplearning4j.

the class WordVectorSerializerTest method testMalformedLabels1.

@Test
public void testMalformedLabels1() throws Exception {
    // Labels deliberately contain spaces, newlines, backticks and underscores,
    // i.e. characters that could trip up a naive text serialization format
    List<String> words = new ArrayList<>();
    words.add("test A");
    words.add("test B");
    words.add("test\nC");
    words.add("test`D");
    words.add("test_E");
    words.add("test 5");
    AbstractCache<VocabWord> vocabCache = new AbstractCache<>();
    int cnt = 0;
    for (String word : words) {
        vocabCache.addToken(new VocabWord(1.0, word));
        vocabCache.addWordToIndex(cnt, word);
        cnt++;
    }
    vocabCache.elementAtIndex(1).markAsLabel(true);
    InMemoryLookupTable<VocabWord> lookupTable = new InMemoryLookupTable<>(vocabCache, 10, false, 0.01, Nd4j.getRandom(), 0.0);
    lookupTable.resetWeights(true);
    assertNotEquals(null, lookupTable.getSyn0());
    assertNotEquals(null, lookupTable.getSyn1());
    assertNotEquals(null, lookupTable.getExpTable());
    assertEquals(null, lookupTable.getSyn1Neg());
    ParagraphVectors vec = new ParagraphVectors.Builder().lookupTable(lookupTable).vocabCache(vocabCache).build();
    File tempFile = File.createTempFile("temp", "w2v");
    tempFile.deleteOnExit();
    // Round-trip the model through a temp file and verify the awkward labels survive
    WordVectorSerializer.writeParagraphVectors(vec, tempFile);
    ParagraphVectors restoredVec = WordVectorSerializer.readParagraphVectors(tempFile);
    for (String word : words) {
        assertEquals(true, restoredVec.hasWord(word));
    }
    assertTrue(restoredVec.getVocab().elementAtIndex(1).isLabel());
}
Also used : InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) ArrayList(java.util.ArrayList) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) AbstractCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache) File(java.io.File) ParagraphVectors(org.deeplearning4j.models.paragraphvectors.ParagraphVectors) Test(org.junit.Test)
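
The test builds its table through the positional InMemoryLookupTable constructor (vocab, vector length, AdaGrad flag, learning rate, RNG, negative). The same thing is often written with InMemoryLookupTable.Builder; below is a minimal sketch with roughly the same configuration. The exact builder chain is an assumption rather than something taken from this test:

import java.util.Arrays;

import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;

public class LookupTableBuilderSketch {
    public static void main(String[] args) {
        // Tiny vocab, built the same way as in the test above
        AbstractCache<VocabWord> vocabCache = new AbstractCache<>();
        String[] words = {"alpha", "beta", "gamma"};
        for (int i = 0; i < words.length; i++) {
            vocabCache.addToken(new VocabWord(1.0, words[i]));
            vocabCache.addWordToIndex(i, words[i]);
        }

        // Builder-style construction: vectorLength/useAdaGrad/lr mirror the positional arguments above
        InMemoryLookupTable<VocabWord> table =
                (InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>()
                        .cache(vocabCache)
                        .vectorLength(10)   // embedding dimensionality
                        .useAdaGrad(false)
                        .lr(0.01)           // learning rate
                        .build();

        // Allocate and randomly initialise the weight matrices (syn0/syn1)
        table.resetWeights(true);
        System.out.println(Arrays.toString(table.getSyn0().shape()));
    }
}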

Example 13 with InMemoryLookupTable

use of org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable in project deeplearning4j by deeplearning4j.

the class SentenceBatch method iterateSample.

/**
     * Iterate over the given two vocab words, applying one skip-gram update
     * (hierarchical softmax plus optional negative sampling).
     *
     * @param param shared Word2Vec parameters (weights, exp table, negative-sampling table)
     * @param w1 the first word to iterate on
     * @param w2 the second word to iterate on
     * @param alpha current learning rate
     * @param changed collects index triples recording which rows were touched by this update
     */
public void iterateSample(Word2VecParam param, VocabWord w1, VocabWord w2, double alpha, List<Triple<Integer, Integer, Integer>> changed) {
    if (w2 == null || w2.getIndex() < 0 || w1.getIndex() == w2.getIndex() || w1.getWord().equals("STOP") || w2.getWord().equals("STOP") || w1.getWord().equals("UNK") || w2.getWord().equals("UNK"))
        return;
    int vectorLength = param.getVectorLength();
    InMemoryLookupTable weights = param.getWeights();
    boolean useAdaGrad = param.isUseAdaGrad();
    double negative = param.getNegative();
    INDArray table = param.getTable();
    double[] expTable = param.getExpTable().getValue();
    double MAX_EXP = 6;
    int numWords = param.getNumWords();
    //current word vector
    INDArray l1 = weights.vector(w2.getWord());
    //error for current word and context
    INDArray neu1e = Nd4j.create(vectorLength);
    for (int i = 0; i < w1.getCodeLength(); i++) {
        int code = w1.getCodes().get(i);
        int point = w1.getPoints().get(i);
        INDArray syn1 = weights.getSyn1().slice(point);
        double dot = Nd4j.getBlasWrapper().level1().dot(syn1.length(), 1.0, l1, syn1);
        if (dot < -MAX_EXP || dot >= MAX_EXP)
            continue;
        int idx = (int) ((dot + MAX_EXP) * ((double) expTable.length / MAX_EXP / 2.0));
        //score
        double f = expTable[idx];
        //gradient
        double g = (1 - code - f) * (useAdaGrad ? w1.getGradient(i, alpha, alpha) : alpha);
        Nd4j.getBlasWrapper().level1().axpy(syn1.length(), g, syn1, neu1e);
        Nd4j.getBlasWrapper().level1().axpy(syn1.length(), g, l1, syn1);
        changed.add(new Triple<>(point, w1.getIndex(), -1));
    }
    changed.add(new Triple<>(w1.getIndex(), w2.getIndex(), -1));
    //negative sampling
    if (negative > 0) {
        int target = w1.getIndex();
        int label;
        INDArray syn1Neg = weights.getSyn1Neg().slice(target);
        for (int d = 0; d < negative + 1; d++) {
            if (d == 0) {
                label = 1;
            } else {
                nextRandom.set(nextRandom.get() * 25214903917L + 11);
                target = table.getInt((int) (nextRandom.get() >> 16) % table.length());
                if (target == 0)
                    target = (int) nextRandom.get() % (numWords - 1) + 1;
                if (target == w1.getIndex())
                    continue;
                label = 0;
            }
            double f = Nd4j.getBlasWrapper().dot(l1, syn1Neg);
            double g;
            if (f > MAX_EXP)
                g = useAdaGrad ? w1.getGradient(target, (label - 1), alpha) : (label - 1) * alpha;
            else if (f < -MAX_EXP)
                g = label * (useAdaGrad ? w1.getGradient(target, alpha, alpha) : alpha);
            else
                g = useAdaGrad ? w1.getGradient(target, label - expTable[(int) ((f + MAX_EXP) * (expTable.length / MAX_EXP / 2))], alpha) : (label - expTable[(int) ((f + MAX_EXP) * (expTable.length / MAX_EXP / 2))]) * alpha;
            Nd4j.getBlasWrapper().level1().axpy(l1.length(), g, neu1e, l1);
            Nd4j.getBlasWrapper().level1().axpy(l1.length(), g, syn1Neg, l1);
            changed.add(new Triple<>(-1, -1, label));
        }
    }
    Nd4j.getBlasWrapper().level1().axpy(l1.length(), 1.0f, neu1e, l1);
}
Also used : InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) INDArray(org.nd4j.linalg.api.ndarray.INDArray)
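
The expTable consumed above is the precomputed sigmoid lookup table from the original word2vec implementation: sigmoid values are tabulated for inputs uniformly spaced over [-MAX_EXP, MAX_EXP), and the index arithmetic in iterateSample() simply maps a dot product back into that range. A self-contained sketch of the idea (the table size of 1000 is an assumption; in the code above the table arrives via Word2VecParam.getExpTable()):

public class ExpTableSketch {
    public static void main(String[] args) {
        final int tableSize = 1000;   // assumed size; the real array comes from Word2VecParam
        final double MAX_EXP = 6.0;

        // Precompute sigmoid(x) for x uniformly spaced in [-MAX_EXP, MAX_EXP)
        double[] expTable = new double[tableSize];
        for (int i = 0; i < tableSize; i++) {
            double x = (i / (double) tableSize * 2 - 1) * MAX_EXP;  // map index -> x
            double e = Math.exp(x);
            expTable[i] = e / (e + 1.0);                            // sigmoid(x)
        }

        // The lookup used in iterateSample(): map a dot product back to a table index
        double dot = 1.7;
        int idx = (int) ((dot + MAX_EXP) * (expTable.length / MAX_EXP / 2.0));
        System.out.printf("sigmoid(%.2f) ~ %.5f (table) vs %.5f (exact)%n",
                dot, expTable[idx], 1.0 / (1.0 + Math.exp(-dot)));
    }
}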

Example 14 with InMemoryLookupTable

use of org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable in project deeplearning4j by deeplearning4j.

the class Word2Vec method train.

/**
     * Train the word2vec model on a given text corpus.
     *
     * @param corpusRDD training corpus
     * @throws Exception
     */
public void train(JavaRDD<String> corpusRDD) throws Exception {
    log.info("Start training ...");
    if (workers > 0)
        corpusRDD.repartition(workers);
    // SparkContext
    final JavaSparkContext sc = new JavaSparkContext(corpusRDD.context());
    // Pre-defined variables
    Map<String, Object> tokenizerVarMap = getTokenizerVarMap();
    Map<String, Object> word2vecVarMap = getWord2vecVarMap();
    // Variables to fill in train
    final JavaRDD<AtomicLong> sentenceWordsCountRDD;
    final JavaRDD<List<VocabWord>> vocabWordListRDD;
    final JavaPairRDD<List<VocabWord>, Long> vocabWordListSentenceCumSumRDD;
    final VocabCache<VocabWord> vocabCache;
    final JavaRDD<Long> sentenceCumSumCountRDD;
    int maxRep = 1;
    // Start Training //
    //////////////////////////////////////
    log.info("Tokenization and building VocabCache ...");
    // Process every sentence and build a VocabCache, which gets fed into a LookupCache
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(tokenizerVarMap);
    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();
    // Get total word count and put into word2vec variable map
    word2vecVarMap.put("totalWordCount", pipeline.getTotalWordCount());
    // Two RDDs: (vocab word list) and (sentence count). Both already cached
    sentenceWordsCountRDD = pipeline.getSentenceCountRDD();
    vocabWordListRDD = pipeline.getVocabWordListRDD();
    // Get the vocabCache and the broadcast vocabCache
    Broadcast<VocabCache<VocabWord>> vocabCacheBroadcast = pipeline.getBroadCastVocabCache();
    vocabCache = vocabCacheBroadcast.getValue();
    log.info("Vocab size: {}", vocabCache.numWords());
    //////////////////////////////////////
    log.info("Building Huffman Tree ...");
    // Building the Huffman tree updates the code and point fields of each VocabWord in the vocabCache
    /*
        We don't need to build tree here, since it was built earlier, at TextPipeline.buildVocabCache() call.
        
        Huffman huffman = new Huffman(vocabCache.vocabWords());
        huffman.build();
        huffman.applyIndexes(vocabCache);
        */
    //////////////////////////////////////
    log.info("Calculating cumulative sum of sentence counts ...");
    sentenceCumSumCountRDD = new CountCumSum(sentenceWordsCountRDD).buildCumSum();
    //////////////////////////////////////
    log.info("Mapping to RDD(vocabWordList, cumulative sentence count) ...");
    vocabWordListSentenceCumSumRDD = vocabWordListRDD.zip(sentenceCumSumCountRDD).setName("vocabWordListSentenceCumSumRDD");
    /////////////////////////////////////
    log.info("Broadcasting word2vec variables to workers ...");
    Broadcast<Map<String, Object>> word2vecVarMapBroadcast = sc.broadcast(word2vecVarMap);
    Broadcast<double[]> expTableBroadcast = sc.broadcast(expTable);
    /////////////////////////////////////
    log.info("Training word2vec sentences ...");
    FlatMapFunction firstIterFunc = new FirstIterationFunction(word2vecVarMapBroadcast, expTableBroadcast, vocabCacheBroadcast);
    @SuppressWarnings("unchecked") JavaRDD<Pair<VocabWord, INDArray>> indexSyn0UpdateEntryRDD = vocabWordListSentenceCumSumRDD.mapPartitions(firstIterFunc).map(new MapToPairFunction());
    // Get all the syn0 updates into a list in driver
    List<Pair<VocabWord, INDArray>> syn0UpdateEntries = indexSyn0UpdateEntryRDD.collect();
    // Instantiate syn0
    INDArray syn0 = Nd4j.zeros(vocabCache.numWords(), layerSize);
    // Updating syn0 first pass: just add vectors obtained from different nodes
    log.info("Averaging results...");
    Map<VocabWord, AtomicInteger> updates = new HashMap<>();
    Map<Long, Long> updaters = new HashMap<>();
    for (Pair<VocabWord, INDArray> syn0UpdateEntry : syn0UpdateEntries) {
        syn0.getRow(syn0UpdateEntry.getFirst().getIndex()).addi(syn0UpdateEntry.getSecond());
        // for proper averaging we later divide the resulting sums by the number of additions
        if (updates.containsKey(syn0UpdateEntry.getFirst())) {
            updates.get(syn0UpdateEntry.getFirst()).incrementAndGet();
        } else
            updates.put(syn0UpdateEntry.getFirst(), new AtomicInteger(1));
        if (!updaters.containsKey(syn0UpdateEntry.getFirst().getVocabId())) {
            updaters.put(syn0UpdateEntry.getFirst().getVocabId(), syn0UpdateEntry.getFirst().getAffinityId());
        }
    }
    // Updating syn0 second pass: average obtained vectors
    for (Map.Entry<VocabWord, AtomicInteger> entry : updates.entrySet()) {
        if (entry.getValue().get() > 1) {
            if (entry.getValue().get() > maxRep)
                maxRep = entry.getValue().get();
            syn0.getRow(entry.getKey().getIndex()).divi(entry.getValue().get());
        }
    }
    long totals = 0;
    log.info("Finished calculations...");
    vocab = vocabCache;
    InMemoryLookupTable<VocabWord> inMemoryLookupTable = new InMemoryLookupTable<VocabWord>();
    Environment env = EnvironmentUtils.buildEnvironment();
    env.setNumCores(maxRep);
    env.setAvailableMemory(totals);
    update(env, Event.SPARK);
    inMemoryLookupTable.setVocab(vocabCache);
    inMemoryLookupTable.setVectorLength(layerSize);
    inMemoryLookupTable.setSyn0(syn0);
    lookupTable = inMemoryLookupTable;
    modelUtils.init(lookupTable);
}
Also used : HashMap(java.util.HashMap) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) FlatMapFunction(org.apache.spark.api.java.function.FlatMapFunction) ArrayList(java.util.ArrayList) List(java.util.List) CountCumSum(org.deeplearning4j.spark.text.functions.CountCumSum) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) Pair(org.deeplearning4j.berkeley.Pair) TextPipeline(org.deeplearning4j.spark.text.functions.TextPipeline) AtomicLong(java.util.concurrent.atomic.AtomicLong) INDArray(org.nd4j.linalg.api.ndarray.INDArray) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) VocabCache(org.deeplearning4j.models.word2vec.wordstore.VocabCache) Environment(org.nd4j.linalg.heartbeat.reports.Environment) Map(java.util.Map)
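
Once train() has filled the InMemoryLookupTable (vocab plus the averaged syn0 matrix), it can be exposed as a plain WordVectors instance for querying. A hedged sketch of one way to do that, reusing the WordVectorSerializer.fromPair(...) path shown in the Glove example below; the helper class, method names and query tokens here are illustrative:

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;

public class SparkWord2VecQuerySketch {

    // Wrap the trained table and vocab in a WordVectors facade, exactly as fromPair() expects
    public static WordVectors asWordVectors(InMemoryLookupTable<VocabWord> table,
                                            VocabCache<VocabWord> vocab) {
        return WordVectorSerializer.fromPair(new Pair<>((InMemoryLookupTable) table, (VocabCache) vocab));
    }

    // Example queries against the distributed training result
    public static void demo(InMemoryLookupTable<VocabWord> table, VocabCache<VocabWord> vocab) {
        WordVectors vectors = asWordVectors(table, vocab);
        System.out.println(vectors.wordsNearest("day", 10));
        System.out.println(vectors.similarity("day", "week"));
    }
}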

Example 15 with InMemoryLookupTable

use of org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable in project deeplearning4j by deeplearning4j.

the class GloveTest method testGlove.

@Test
public void testGlove() throws Exception {
    Glove glove = new Glove(true, 5, 100);
    JavaRDD<String> corpus = sc.textFile(new ClassPathResource("raw_sentences.txt").getFile().getAbsolutePath()).map(new Function<String, String>() {

        @Override
        public String call(String s) throws Exception {
            return s.toLowerCase();
        }
    });
    Pair<VocabCache<VocabWord>, GloveWeightLookupTable> table = glove.train(corpus);
    WordVectors vectors = WordVectorSerializer.fromPair(new Pair<>((InMemoryLookupTable) table.getSecond(), (VocabCache) table.getFirst()));
    Collection<String> words = vectors.wordsNearest("day", 20);
    assertTrue(words.contains("week"));
}
Also used : GloveWeightLookupTable(org.deeplearning4j.models.glove.GloveWeightLookupTable) ClassPathResource(org.datavec.api.util.ClassPathResource) InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) VocabCache(org.deeplearning4j.models.word2vec.wordstore.VocabCache) WordVectors(org.deeplearning4j.models.embeddings.wordvectors.WordVectors) Test(org.junit.Test) BaseSparkTest(org.deeplearning4j.spark.text.BaseSparkTest)
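
Note that the cast of table.getSecond() to InMemoryLookupTable works because GloveWeightLookupTable is a subclass of InMemoryLookupTable, so GloVe weights can be fed through the same WordVectorSerializer.fromPair(...) path used for word2vec tables.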

Aggregations

InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) 29
INDArray (org.nd4j.linalg.api.ndarray.INDArray) 21
VocabWord (org.deeplearning4j.models.word2vec.VocabWord) 18
ArrayList (java.util.ArrayList) 13
AbstractCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache) 9
Test (org.junit.Test) 8
VocabCache (org.deeplearning4j.models.word2vec.wordstore.VocabCache) 7
File (java.io.File) 6
Word2Vec (org.deeplearning4j.models.word2vec.Word2Vec) 6
ZipFile (java.util.zip.ZipFile) 5
DL4JInvalidInputException (org.deeplearning4j.exception.DL4JInvalidInputException) 5
StaticWord2Vec (org.deeplearning4j.models.word2vec.StaticWord2Vec) 5
TokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory) 5
ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException) 5
AtomicInteger (java.util.concurrent.atomic.AtomicInteger) 4
ZipEntry (java.util.zip.ZipEntry) 4
ClassPathResource (org.datavec.api.util.ClassPathResource) 4
WordVectors (org.deeplearning4j.models.embeddings.wordvectors.WordVectors) 4
InMemoryLookupCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache) 4
GZIPInputStream (java.util.zip.GZIPInputStream) 3