Example 1 with VocabCache

Use of org.deeplearning4j.models.word2vec.wordstore.VocabCache in project deeplearning4j.

From the class WordVectorSerializer, method readTextModel:

/**
     * Restores a Word2Vec model from the plain-text format (optionally gzip-compressed).
     *
     * @param modelFile the text model file to read
     * @return the restored Word2Vec model
     * @throws IOException if the file cannot be opened or read
     * @throws NumberFormatException if the header or a vector component cannot be parsed
     */
private static Word2Vec readTextModel(File modelFile) throws IOException, NumberFormatException {
    InMemoryLookupTable lookupTable;
    VocabCache cache;
    INDArray syn0;
    Word2Vec ret = new Word2Vec();
    try (BufferedReader reader = new BufferedReader(new InputStreamReader(
                    GzipUtils.isCompressedFilename(modelFile.getName())
                                    ? new GZIPInputStream(new FileInputStream(modelFile))
                                    : new FileInputStream(modelFile),
                    "UTF-8"))) {
        // The first line is a header: "<vocabulary size> <vector length>"
        String line = reader.readLine();
        String[] initial = line.split(" ");
        int words = Integer.parseInt(initial[0]);
        int layerSize = Integer.parseInt(initial[1]);
        syn0 = Nd4j.create(words, layerSize);
        cache = new InMemoryLookupCache(false);
        int currLine = 0;
        while ((line = reader.readLine()) != null) {
            // Each remaining line holds a word followed by layerSize vector components
            String[] split = line.split(" ");
            assert split.length == layerSize + 1;
            String word = split[0].replaceAll(whitespaceReplacement, " ");
            float[] vector = new float[split.length - 1];
            for (int i = 1; i < split.length; i++) {
                vector[i - 1] = Float.parseFloat(split[i]);
            }
            syn0.putRow(currLine, Nd4j.create(vector));
            cache.addWordToIndex(cache.numWords(), word);
            cache.addToken(new VocabWord(1, word));
            cache.putVocabWord(word);
            currLine++;
        }
        lookupTable = (InMemoryLookupTable) new InMemoryLookupTable.Builder().cache(cache).vectorLength(layerSize).build();
        lookupTable.setSyn0(syn0);
        ret.setVocab(cache);
        ret.setLookupTable(lookupTable);
    }
    return ret;
}
Also used: VocabWord (org.deeplearning4j.models.word2vec.VocabWord), InMemoryLookupCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache), GZIPInputStream (java.util.zip.GZIPInputStream), InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable), INDArray (org.nd4j.linalg.api.ndarray.INDArray), VocabCache (org.deeplearning4j.models.word2vec.wordstore.VocabCache), StaticWord2Vec (org.deeplearning4j.models.word2vec.StaticWord2Vec), Word2Vec (org.deeplearning4j.models.word2vec.Word2Vec)
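
For reference, the plain-text layout readTextModel parses looks like this hypothetical two-word, three-dimensional model: a header line "<vocabulary size> <vector length>", followed by one line per word holding the word and its layerSize vector components (values below are made up):

2 3
day 0.126 -0.437 0.891
night 0.201 -0.322 0.754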

Example 2 with VocabCache

Use of org.deeplearning4j.models.word2vec.wordstore.VocabCache in project deeplearning4j.

From the class WordVectorSerializer, method readWord2VecFromText:

/**
     * This method allows you to read ParagraphVectors from externally originated vectors and syn1.
     * So, technically, this method is compatible with any other w2v implementation.
     *
     * @param vectors   text file with words and their weights, aka syn0
     * @param hs    text file with hierarchical softmax (HS) layers, aka syn1
     * @param h_codes   text file with Huffman tree codes
     * @param h_points  text file with Huffman tree points
     * @param configuration the VectorsConfiguration to apply to the restored model
     * @return the restored Word2Vec model
     */
public static Word2Vec readWord2VecFromText(@NonNull File vectors, @NonNull File hs, @NonNull File h_codes, @NonNull File h_points, @NonNull VectorsConfiguration configuration) throws IOException {
    // first we load syn0
    Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(vectors);
    InMemoryLookupTable lookupTable = pair.getFirst();
    lookupTable.setNegative(configuration.getNegative());
    if (configuration.getNegative() > 0)
        lookupTable.initNegative();
    VocabCache<VocabWord> vocab = (VocabCache<VocabWord>) pair.getSecond();
    // now we load syn1
    BufferedReader reader = new BufferedReader(new FileReader(hs));
    String line = null;
    List<INDArray> rows = new ArrayList<>();
    while ((line = reader.readLine()) != null) {
        String[] split = line.split(" ");
        double[] array = new double[split.length];
        for (int i = 0; i < split.length; i++) {
            array[i] = Double.parseDouble(split[i]);
        }
        rows.add(Nd4j.create(array));
    }
    reader.close();
    // it's possible to have a full model without syn1
    if (rows.size() > 0) {
        INDArray syn1 = Nd4j.vstack(rows);
        lookupTable.setSyn1(syn1);
    }
    // now we transform mappings into huffman tree points
    reader = new BufferedReader(new FileReader(h_points));
    while ((line = reader.readLine()) != null) {
        String[] split = line.split(" ");
        VocabWord word = vocab.wordFor(decodeB64(split[0]));
        List<Integer> points = new ArrayList<>();
        for (int i = 1; i < split.length; i++) {
            points.add(Integer.parseInt(split[i]));
        }
        word.setPoints(points);
    }
    reader.close();
    // now we transform mappings into huffman tree codes
    reader = new BufferedReader(new FileReader(h_codes));
    while ((line = reader.readLine()) != null) {
        String[] split = line.split(" ");
        VocabWord word = vocab.wordFor(decodeB64(split[0]));
        List<Byte> codes = new ArrayList<>();
        for (int i = 1; i < split.length; i++) {
            codes.add(Byte.parseByte(split[i]));
        }
        word.setCodes(codes);
        word.setCodeLength((short) codes.size());
    }
    reader.close();
    Word2Vec.Builder builder = new Word2Vec.Builder(configuration).vocabCache(vocab).lookupTable(lookupTable).resetModel(false);
    TokenizerFactory factory = getTokenizerFactory(configuration);
    if (factory != null)
        builder.tokenizerFactory(factory);
    Word2Vec w2v = builder.build();
    return w2v;
}
Also used: TokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory), ArrayList (java.util.ArrayList), VocabWord (org.deeplearning4j.models.word2vec.VocabWord), AtomicInteger (java.util.concurrent.atomic.AtomicInteger), InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable), INDArray (org.nd4j.linalg.api.ndarray.INDArray), VocabCache (org.deeplearning4j.models.word2vec.wordstore.VocabCache), StaticWord2Vec (org.deeplearning4j.models.word2vec.StaticWord2Vec), Word2Vec (org.deeplearning4j.models.word2vec.Word2Vec)
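
A minimal sketch of how this restore path might be invoked. The file names are hypothetical and only the readWord2VecFromText call is taken from the code above; VectorsConfiguration's defaults are assumed to be acceptable for a model trained with hierarchical softmax:

import java.io.File;

import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.word2vec.Word2Vec;

public class ExternalW2vImportSketch {
    public static void main(String[] args) throws Exception {
        // Hypothetical export files from an external w2v implementation
        File vectors = new File("syn0.txt");   // words and their weights, aka syn0
        File hs = new File("syn1.txt");        // hierarchical softmax layers, aka syn1
        File codes = new File("codes.txt");    // Huffman tree codes, one word per line
        File points = new File("points.txt");  // Huffman tree points, one word per line

        // With negative == 0 (the assumed default), initNegative() is skipped above
        VectorsConfiguration configuration = new VectorsConfiguration();

        Word2Vec w2v = WordVectorSerializer.readWord2VecFromText(vectors, hs, codes, points, configuration);
        System.out.println(w2v.similarity("day", "night"));
    }
}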

Example 3 with VocabCache

Use of org.deeplearning4j.models.word2vec.wordstore.VocabCache in project deeplearning4j.

From the class WordVectorSerializerTest, method testIndexPersistence:

@Test
public void testIndexPersistence() throws Exception {
    File inputFile = new ClassPathResource("/big/raw_sentences.txt").getFile();
    SentenceIterator iter = UimaSentenceIterator.createWithPath(inputFile.getAbsolutePath());
    // Split on white spaces in the line to get words
    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());
    Word2Vec vec = new Word2Vec.Builder().minWordFrequency(5).iterations(1).epochs(1).layerSize(100)
                    .stopWords(new ArrayList<String>()).useAdaGrad(false).negativeSample(5).seed(42)
                    .windowSize(5).iterate(iter).tokenizerFactory(t).build();
    vec.fit();
    VocabCache orig = vec.getVocab();
    File tempFile = File.createTempFile("temp", "w2v");
    tempFile.deleteOnExit();
    WordVectorSerializer.writeWordVectors(vec, tempFile);
    WordVectors vec2 = WordVectorSerializer.loadTxtVectors(tempFile);
    VocabCache rest = vec2.vocab();
    assertEquals(orig.totalNumberOfDocs(), rest.totalNumberOfDocs());
    for (VocabWord word : vec.getVocab().vocabWords()) {
        INDArray array1 = vec.getWordVectorMatrix(word.getLabel());
        INDArray array2 = vec2.getWordVectorMatrix(word.getLabel());
        assertEquals(array1, array2);
    }
}
Also used: TokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory), DefaultTokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory), ArrayList (java.util.ArrayList), VocabWord (org.deeplearning4j.models.word2vec.VocabWord), ClassPathResource (org.datavec.api.util.ClassPathResource), SentenceIterator (org.deeplearning4j.text.sentenceiterator.SentenceIterator), UimaSentenceIterator (org.deeplearning4j.text.sentenceiterator.UimaSentenceIterator), CommonPreprocessor (org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor), INDArray (org.nd4j.linalg.api.ndarray.INDArray), VocabCache (org.deeplearning4j.models.word2vec.wordstore.VocabCache), Word2Vec (org.deeplearning4j.models.word2vec.Word2Vec), WordVectors (org.deeplearning4j.models.embeddings.wordvectors.WordVectors), File (java.io.File), Test (org.junit.Test)
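
Outside the test harness, the same round trip reduces to a save followed by a load; the path below is hypothetical, while writeWordVectors and loadTxtVectors are exactly the calls exercised by the test:

import java.io.File;

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.word2vec.Word2Vec;

public class RoundTripSketch {
    public static void saveAndReload(Word2Vec vec) throws Exception {
        File out = new File("vectors.txt"); // hypothetical output path
        WordVectorSerializer.writeWordVectors(vec, out);

        // Restore later without retraining; per testIndexPersistence, the
        // restored vectors should equal the originals word for word.
        WordVectors restored = WordVectorSerializer.loadTxtVectors(out);
        System.out.println(restored.similarity("day", "night"));
    }
}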

Example 4 with VocabCache

Use of org.deeplearning4j.models.word2vec.wordstore.VocabCache in project deeplearning4j.

From the class Glove, method train:

/**
     * Train on the corpus
     * @param rdd the RDD of raw text to train on
     * @return the vocab and the weight lookup table
     */
public Pair<VocabCache<VocabWord>, GloveWeightLookupTable> train(JavaRDD<String> rdd) throws Exception {
    // Each `train()` can use different parameters
    final JavaSparkContext sc = new JavaSparkContext(rdd.context());
    final SparkConf conf = sc.getConf();
    final int vectorLength = assignVar(VECTOR_LENGTH, conf, Integer.class);
    final boolean useAdaGrad = assignVar(ADAGRAD, conf, Boolean.class);
    final double negative = assignVar(NEGATIVE, conf, Double.class);
    final int numWords = assignVar(NUM_WORDS, conf, Integer.class);
    final int window = assignVar(WINDOW, conf, Integer.class);
    final double alpha = assignVar(ALPHA, conf, Double.class);
    final double minAlpha = assignVar(MIN_ALPHA, conf, Double.class);
    final int iterations = assignVar(ITERATIONS, conf, Integer.class);
    final int nGrams = assignVar(N_GRAMS, conf, Integer.class);
    final String tokenizer = assignVar(TOKENIZER, conf, String.class);
    final String tokenPreprocessor = assignVar(TOKEN_PREPROCESSOR, conf, String.class);
    final boolean removeStop = assignVar(REMOVE_STOPWORDS, conf, Boolean.class);
    Map<String, Object> tokenizerVarMap = new HashMap<String, Object>() {

        {
            put("numWords", numWords);
            put("nGrams", nGrams);
            put("tokenizer", tokenizer);
            put("tokenPreprocessor", tokenPreprocessor);
            put("removeStop", removeStop);
        }
    };
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(tokenizerVarMap);
    TextPipeline pipeline = new TextPipeline(rdd, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();
    // Get total word count
    Long totalWordCount = pipeline.getTotalWordCount();
    VocabCache<VocabWord> vocabCache = pipeline.getVocabCache();
    JavaRDD<Pair<List<String>, AtomicLong>> sentenceWordsCountRDD = pipeline.getSentenceWordsCountRDD();
    final Pair<VocabCache<VocabWord>, Long> vocabAndNumWords = new Pair<>(vocabCache, totalWordCount);
    vocabCacheBroadcast = sc.broadcast(vocabAndNumWords.getFirst());
    final GloveWeightLookupTable gloveWeightLookupTable = new GloveWeightLookupTable.Builder()
                    .cache(vocabAndNumWords.getFirst())
                    .lr(conf.getDouble(GlovePerformer.ALPHA, 0.01))
                    .maxCount(conf.getDouble(GlovePerformer.MAX_COUNT, 100))
                    .vectorLength(conf.getInt(GlovePerformer.VECTOR_LENGTH, 300))
                    .xMax(conf.getDouble(GlovePerformer.X_MAX, 0.75))
                    .build();
    gloveWeightLookupTable.resetWeights();
    gloveWeightLookupTable.getBiasAdaGrad().historicalGradient = Nd4j.ones(gloveWeightLookupTable.getSyn0().rows());
    gloveWeightLookupTable.getWeightAdaGrad().historicalGradient = Nd4j.ones(gloveWeightLookupTable.getSyn0().shape());
    log.info("Created lookup table of size " + Arrays.toString(gloveWeightLookupTable.getSyn0().shape()));
    CounterMap<String, String> coOccurrenceCounts = sentenceWordsCountRDD.map(new CoOccurrenceCalculator(symmetric, vocabCacheBroadcast, windowSize)).fold(new CounterMap<String, String>(), new CoOccurrenceCounts());
    Iterator<Pair<String, String>> pair2 = coOccurrenceCounts.getPairIterator();
    List<Triple<String, String, Double>> counts = new ArrayList<>();
    while (pair2.hasNext()) {
        Pair<String, String> next = pair2.next();
        if (coOccurrenceCounts.getCount(next.getFirst(), next.getSecond()) > gloveWeightLookupTable.getMaxCount()) {
            coOccurrenceCounts.setCount(next.getFirst(), next.getSecond(), gloveWeightLookupTable.getMaxCount());
        }
        counts.add(new Triple<>(next.getFirst(), next.getSecond(), coOccurrenceCounts.getCount(next.getFirst(), next.getSecond())));
    }
    log.info("Calculated co occurrences");
    JavaRDD<Triple<String, String, Double>> parallel = sc.parallelize(counts);
    JavaPairRDD<String, Tuple2<String, Double>> pairs = parallel.mapToPair(new PairFunction<Triple<String, String, Double>, String, Tuple2<String, Double>>() {

        @Override
        public Tuple2<String, Tuple2<String, Double>> call(Triple<String, String, Double> stringStringDoubleTriple) throws Exception {
            return new Tuple2<>(stringStringDoubleTriple.getFirst(), new Tuple2<>(stringStringDoubleTriple.getSecond(), stringStringDoubleTriple.getThird()));
        }
    });
    JavaPairRDD<VocabWord, Tuple2<VocabWord, Double>> pairsVocab = pairs.mapToPair(new PairFunction<Tuple2<String, Tuple2<String, Double>>, VocabWord, Tuple2<VocabWord, Double>>() {

        @Override
        public Tuple2<VocabWord, Tuple2<VocabWord, Double>> call(Tuple2<String, Tuple2<String, Double>> stringTuple2Tuple2) throws Exception {
            VocabWord w1 = vocabCacheBroadcast.getValue().wordFor(stringTuple2Tuple2._1());
            VocabWord w2 = vocabCacheBroadcast.getValue().wordFor(stringTuple2Tuple2._2()._1());
            return new Tuple2<>(w1, new Tuple2<>(w2, stringTuple2Tuple2._2()._2()));
        }
    });
    for (int i = 0; i < iterations; i++) {
        JavaRDD<GloveChange> change = pairsVocab.map(new Function<Tuple2<VocabWord, Tuple2<VocabWord, Double>>, GloveChange>() {

            @Override
            public GloveChange call(Tuple2<VocabWord, Tuple2<VocabWord, Double>> vocabWordTuple2Tuple2) throws Exception {
                VocabWord w1 = vocabWordTuple2Tuple2._1();
                VocabWord w2 = vocabWordTuple2Tuple2._2()._1();
                INDArray w1Vector = gloveWeightLookupTable.getSyn0().slice(w1.getIndex());
                INDArray w2Vector = gloveWeightLookupTable.getSyn0().slice(w2.getIndex());
                INDArray bias = gloveWeightLookupTable.getBias();
                double score = vocabWordTuple2Tuple2._2()._2();
                double xMax = gloveWeightLookupTable.getxMax();
                double maxCount = gloveWeightLookupTable.getMaxCount();
                // prediction = w1 . w2 + bias(w1) + bias(w2)
                double prediction = Nd4j.getBlasWrapper().dot(w1Vector, w2Vector);
                prediction += bias.getDouble(w1.getIndex()) + bias.getDouble(w2.getIndex());
                double weight = FastMath.pow(Math.min(1.0, (score / maxCount)), xMax);
                double fDiff = score > xMax ? prediction : weight * (prediction - Math.log(score));
                if (Double.isNaN(fDiff))
                    fDiff = Nd4j.EPS_THRESHOLD;
                //amount of change
                double gradient = fDiff;
                Pair<INDArray, Double> w1Update = update(gloveWeightLookupTable.getWeightAdaGrad(), gloveWeightLookupTable.getBiasAdaGrad(), gloveWeightLookupTable.getSyn0(), gloveWeightLookupTable.getBias(), w1, w1Vector, w2Vector, gradient);
                Pair<INDArray, Double> w2Update = update(gloveWeightLookupTable.getWeightAdaGrad(), gloveWeightLookupTable.getBiasAdaGrad(), gloveWeightLookupTable.getSyn0(), gloveWeightLookupTable.getBias(), w2, w2Vector, w1Vector, gradient);
                return new GloveChange(w1, w2, w1Update.getFirst(), w2Update.getFirst(), w1Update.getSecond(), w2Update.getSecond(), fDiff, gloveWeightLookupTable.getWeightAdaGrad().getHistoricalGradient().slice(w1.getIndex()), gloveWeightLookupTable.getWeightAdaGrad().getHistoricalGradient().slice(w2.getIndex()), gloveWeightLookupTable.getBiasAdaGrad().getHistoricalGradient().getDouble(w2.getIndex()), gloveWeightLookupTable.getBiasAdaGrad().getHistoricalGradient().getDouble(w1.getIndex()));
            }
        });
        List<GloveChange> gloveChanges = change.collect();
        double error = 0.0;
        for (GloveChange change2 : gloveChanges) {
            change2.apply(gloveWeightLookupTable);
            error += change2.getError();
        }
        List l = pairsVocab.collect();
        Collections.shuffle(l);
        pairsVocab = sc.parallelizePairs(l);
        log.info("Error at iteration " + i + " was " + error);
    }
    return new Pair<>(vocabAndNumWords.getFirst(), gloveWeightLookupTable);
}
Also used: CoOccurrenceCounts (org.deeplearning4j.spark.models.embeddings.glove.cooccurrences.CoOccurrenceCounts), TextPipeline (org.deeplearning4j.spark.text.functions.TextPipeline), Triple (org.deeplearning4j.berkeley.Triple), VocabCache (org.deeplearning4j.models.word2vec.wordstore.VocabCache), AtomicLong (java.util.concurrent.atomic.AtomicLong), CounterMap (org.deeplearning4j.berkeley.CounterMap), VocabWord (org.deeplearning4j.models.word2vec.VocabWord), CoOccurrenceCalculator (org.deeplearning4j.spark.models.embeddings.glove.cooccurrences.CoOccurrenceCalculator), JavaSparkContext (org.apache.spark.api.java.JavaSparkContext), Pair (org.deeplearning4j.berkeley.Pair), GloveWeightLookupTable (org.deeplearning4j.models.glove.GloveWeightLookupTable), INDArray (org.nd4j.linalg.api.ndarray.INDArray), Tuple2 (scala.Tuple2), SparkConf (org.apache.spark.SparkConf)
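
For orientation, the quantity fDiff computed inside the map function is the per-pair gradient signal of the standard GloVe weighted least-squares objective. Reading the variable names off the code above, maxCount plays the role of the paper's cutoff x_max, while the field named xMax is used as the exponent alpha. In LaTeX notation:

J = \sum_{i,j} f(X_{ij}) \left( w_i^\top w_j + b_i + b_j - \log X_{ij} \right)^2,
\qquad
f(x) = \min\left(1, \frac{x}{x_{\max}}\right)^{\alpha}

Here score corresponds to X_{ij}, prediction to w_i^\top w_j + b_i + b_j, weight to f(score), and fDiff = f(X_{ij}) (prediction - \log X_{ij}) is the shared factor passed to both AdaGrad updates.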

Example 5 with VocabCache

Use of org.deeplearning4j.models.word2vec.wordstore.VocabCache in project deeplearning4j.

From the class CoOccurrenceCalculator, method call:

@Override
public CounterMap<String, String> call(Pair<List<String>, AtomicLong> pair) throws Exception {
    List<String> sentence = pair.getFirst();
    CounterMap<String, String> coOCurreneCounts = new CounterMap<>();
    VocabCache vocab = this.vocab.value();
    for (int i = 0; i < sentence.size(); i++) {
        int wordIdx = vocab.indexOf(sentence.get(i));
        String w1 = ((VocabWord) vocab.wordFor(sentence.get(i))).getWord();
        if (// || w1.equals(Glove.UNK))
        wordIdx < 0)
            continue;
        int windowStop = Math.min(i + windowSize + 1, sentence.size());
        for (int j = i; j < windowStop; j++) {
            int otherWord = vocab.indexOf(sentence.get(j));
            String w2 = ((VocabWord) vocab.wordFor(sentence.get(j))).getWord();
            if (// || w2.equals(Glove.UNK))
            vocab.indexOf(sentence.get(j)) < 0)
                continue;
            if (otherWord == wordIdx)
                continue;
            if (wordIdx < otherWord) {
                coOCurreneCounts.incrementCount(sentence.get(i), sentence.get(j), 1.0 / (j - i + Nd4j.EPS_THRESHOLD));
                if (symmetric)
                    coOCurreneCounts.incrementCount(sentence.get(j), sentence.get(i), 1.0 / (j - i + Nd4j.EPS_THRESHOLD));
            } else {
                coOCurreneCounts.incrementCount(sentence.get(j), sentence.get(i), 1.0 / (j - i + Nd4j.EPS_THRESHOLD));
                if (symmetric)
                    coOCurreneCounts.incrementCount(sentence.get(i), sentence.get(j), 1.0 / (j - i + Nd4j.EPS_THRESHOLD));
            }
        }
    }
    return coOCurreneCounts;
}
Also used: CounterMap (org.deeplearning4j.berkeley.CounterMap), VocabCache (org.deeplearning4j.models.word2vec.wordstore.VocabCache), VocabWord (org.deeplearning4j.models.word2vec.VocabWord)
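
As a worked example of the inverse-distance weighting above (ignoring the tiny Nd4j.EPS_THRESHOLD stabilizer): in the sentence "the cat sat on" with windowSize = 2, the adjacent pair (the, cat) contributes 1.0 while (the, sat) at distance 2 contributes 0.5, and each contribution is mirrored when symmetric is set. A dependency-free sketch of the same counting idea, with hypothetical names and the vocabulary lookups omitted:

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class CoOccurrenceSketch {

    // Count distance-weighted co-occurrences within a symmetric window
    public static Map<String, Double> count(List<String> sentence, int windowSize, boolean symmetric) {
        Map<String, Double> counts = new HashMap<>();
        for (int i = 0; i < sentence.size(); i++) {
            int windowStop = Math.min(i + windowSize + 1, sentence.size());
            for (int j = i + 1; j < windowStop; j++) {
                double weight = 1.0 / (j - i); // closer pairs contribute more
                counts.merge(sentence.get(i) + "|" + sentence.get(j), weight, Double::sum);
                if (symmetric)
                    counts.merge(sentence.get(j) + "|" + sentence.get(i), weight, Double::sum);
            }
        }
        return counts;
    }

    public static void main(String[] args) {
        // (the, cat) -> 1.0, (the, sat) -> 0.5, (cat, sat) -> 1.0, ...
        System.out.println(count(Arrays.asList("the", "cat", "sat", "on"), 2, true));
    }
}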

Aggregations

VocabCache (org.deeplearning4j.models.word2vec.wordstore.VocabCache): 10 usages
VocabWord (org.deeplearning4j.models.word2vec.VocabWord): 9 usages
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 8 usages
InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable): 7 usages
ArrayList (java.util.ArrayList): 4 usages
JavaSparkContext (org.apache.spark.api.java.JavaSparkContext): 3 usages
ClassPathResource (org.datavec.api.util.ClassPathResource): 3 usages
WordVectors (org.deeplearning4j.models.embeddings.wordvectors.WordVectors): 3 usages
Word2Vec (org.deeplearning4j.models.word2vec.Word2Vec): 3 usages
TokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory): 3 usages
Test (org.junit.Test): 3 usages
File (java.io.File): 2 usages
AtomicInteger (java.util.concurrent.atomic.AtomicInteger): 2 usages
AtomicLong (java.util.concurrent.atomic.AtomicLong): 2 usages
SparkConf (org.apache.spark.SparkConf): 2 usages
CounterMap (org.deeplearning4j.berkeley.CounterMap): 2 usages
Pair (org.deeplearning4j.berkeley.Pair): 2 usages
GloveWeightLookupTable (org.deeplearning4j.models.glove.GloveWeightLookupTable): 2 usages
StaticWord2Vec (org.deeplearning4j.models.word2vec.StaticWord2Vec): 2 usages
AbstractCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache): 2 usages