
Example 21 with InMemoryLookupTable

Use of org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable in project deeplearning4j by deeplearning4j.

From class WordVectorSerializer, method writeParagraphVectors.

/**
     * This method saves a ParagraphVectors model into a compressed ZIP archive and writes it to the given output stream.
     */
public static void writeParagraphVectors(ParagraphVectors vectors, OutputStream stream) throws IOException {
    ZipOutputStream zipfile = new ZipOutputStream(new BufferedOutputStream(new CloseShieldOutputStream(stream)));
    ZipEntry syn0 = new ZipEntry("syn0.txt");
    zipfile.putNextEntry(syn0);
    // writing out syn0
    File tempFileSyn0 = File.createTempFile("paravec", "0");
    tempFileSyn0.deleteOnExit();
    writeWordVectors(vectors.lookupTable(), tempFileSyn0);
    BufferedInputStream fis = new BufferedInputStream(new FileInputStream(tempFileSyn0));
    writeEntry(fis, zipfile);
    fis.close();
    // writing out syn1
    File tempFileSyn1 = File.createTempFile("paravec", "1");
    tempFileSyn1.deleteOnExit();
    INDArray syn1 = ((InMemoryLookupTable<VocabWord>) vectors.getLookupTable()).getSyn1();
    if (syn1 != null)
        try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileSyn1))) {
            for (int x = 0; x < syn1.rows(); x++) {
                INDArray row = syn1.getRow(x);
                StringBuilder builder = new StringBuilder();
                for (int i = 0; i < row.length(); i++) {
                    builder.append(row.getDouble(i)).append(" ");
                }
                writer.println(builder.toString().trim());
            }
        }
    ZipEntry zSyn1 = new ZipEntry("syn1.txt");
    zipfile.putNextEntry(zSyn1);
    fis = new BufferedInputStream(new FileInputStream(tempFileSyn1));
    writeEntry(fis, zipfile);
    fis.close();
    File tempFileCodes = File.createTempFile("paravec", "h");
    tempFileCodes.deleteOnExit();
    ZipEntry hC = new ZipEntry("codes.txt");
    zipfile.putNextEntry(hC);
    // writing out huffman codes
    try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileCodes))) {
        for (int i = 0; i < vectors.getVocab().numWords(); i++) {
            VocabWord word = vectors.getVocab().elementAtIndex(i);
            StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" ");
            for (int code : word.getCodes()) {
                builder.append(code).append(" ");
            }
            writer.println(builder.toString().trim());
        }
    }
    fis = new BufferedInputStream(new FileInputStream(tempFileCodes));
    writeEntry(fis, zipfile);
    fis.close();
    File tempFileHuffman = File.createTempFile("paravec", "h");
    tempFileHuffman.deleteOnExit();
    ZipEntry hP = new ZipEntry("huffman.txt");
    zipfile.putNextEntry(hP);
    // writing out huffman tree points
    try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileHuffman))) {
        for (int i = 0; i < vectors.getVocab().numWords(); i++) {
            VocabWord word = vectors.getVocab().elementAtIndex(i);
            StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" ");
            for (int point : word.getPoints()) {
                builder.append(point).append(" ");
            }
            writer.println(builder.toString().trim());
        }
    }
    fis = new BufferedInputStream(new FileInputStream(tempFileHuffman));
    writeEntry(fis, zipfile);
    fis.close();
    ZipEntry config = new ZipEntry("config.json");
    zipfile.putNextEntry(config);
    writeEntry(new ByteArrayInputStream(vectors.getConfiguration().toJson().getBytes()), zipfile);
    ZipEntry labels = new ZipEntry("labels.txt");
    zipfile.putNextEntry(labels);
    StringBuilder builder = new StringBuilder();
    for (VocabWord word : vectors.getVocab().tokens()) {
        if (word.isLabel())
            builder.append(encodeB64(word.getLabel())).append("\n");
    }
    writeEntry(new ByteArrayInputStream(builder.toString().trim().getBytes()), zipfile);
    ZipEntry hF = new ZipEntry("frequencies.txt");
    zipfile.putNextEntry(hF);
    File tempFileFreqs = File.createTempFile("paravec", "h");
    tempFileFreqs.deleteOnExit();
    // writing out word frequencies
    try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileFreqs))) {
        for (int i = 0; i < vectors.getVocab().numWords(); i++) {
            VocabWord word = vectors.getVocab().elementAtIndex(i);
            builder = new StringBuilder(encodeB64(word.getLabel())).append(" ").append(word.getElementFrequency()).append(" ").append(vectors.getVocab().docAppearedIn(word.getLabel()));
            writer.println(builder.toString().trim());
        }
    }
    fis = new BufferedInputStream(new FileInputStream(tempFileFreqs));
    writeEntry(fis, zipfile);
    fis.close();
    zipfile.flush();
    zipfile.close();
}
Also used: ZipEntry (java.util.zip.ZipEntry), VocabWord (org.deeplearning4j.models.word2vec.VocabWord), InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable), INDArray (org.nd4j.linalg.api.ndarray.INDArray), ZipOutputStream (java.util.zip.ZipOutputStream), ZipFile (java.util.zip.ZipFile)
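A minimal usage sketch for this writer; the output path and the trained model instance ("vectors") are assumptions for illustration:

import java.io.BufferedOutputStream;
import java.io.FileOutputStream;
import java.io.OutputStream;

// "vectors" is assumed to be an already trained ParagraphVectors model.
try (OutputStream os = new BufferedOutputStream(new FileOutputStream("paravec_model.zip"))) {
    WordVectorSerializer.writeParagraphVectors(vectors, os);
}

Because the method wraps the target in a CloseShieldOutputStream, closing the internal ZipOutputStream does not close the caller's stream; the caller remains responsible for closing it, as the try-with-resources above does.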

Example 22 with InMemoryLookupTable

Use of org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable in project deeplearning4j by deeplearning4j.

From class WordVectorSerializer, method loadTxtVectors.

/**
     * This method can be used to load a previously saved model from an InputStream (such as an HDFS stream).
     *
     * Deprecation note: please consider using readWord2VecModel() or loadStaticModel() instead.
     *
     * @param stream InputStream that contains the previously serialized model
     * @param skipFirstLine set to TRUE if the first line contains a CSV header, FALSE otherwise
     * @return the restored WordVectors model
     * @throws IOException
     */
@Deprecated
public static WordVectors loadTxtVectors(@NonNull InputStream stream, boolean skipFirstLine) throws IOException {
    AbstractCache<VocabWord> cache = new AbstractCache.Builder<VocabWord>().build();
    BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
    String line = "";
    List<INDArray> arrays = new ArrayList<>();
    if (skipFirstLine)
        reader.readLine();
    while ((line = reader.readLine()) != null) {
        String[] split = line.split(" ");
        String word = split[0].replaceAll(whitespaceReplacement, " ");
        VocabWord word1 = new VocabWord(1.0, word);
        word1.setIndex(cache.numWords());
        cache.addToken(word1);
        cache.addWordToIndex(word1.getIndex(), word);
        cache.putVocabWord(word);
        float[] vector = new float[split.length - 1];
        for (int i = 1; i < split.length; i++) {
            vector[i - 1] = Float.parseFloat(split[i]);
        }
        INDArray row = Nd4j.create(vector);
        arrays.add(row);
    }
    InMemoryLookupTable<VocabWord> lookupTable = (InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>().vectorLength(arrays.get(0).columns()).cache(cache).build();
    INDArray syn = Nd4j.vstack(arrays);
    Nd4j.clearNans(syn);
    lookupTable.setSyn0(syn);
    return fromPair(Pair.makePair((InMemoryLookupTable) lookupTable, (VocabCache) cache));
}
Also used: ArrayList (java.util.ArrayList), VocabWord (org.deeplearning4j.models.word2vec.VocabWord), AbstractCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache), InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable), INDArray (org.nd4j.linalg.api.ndarray.INDArray), VocabCache (org.deeplearning4j.models.word2vec.wordstore.VocabCache)
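A minimal usage sketch for this deprecated loader; the file name and the probed word are assumptions for illustration:

import java.io.FileInputStream;
import java.io.InputStream;

// The assumed file has no CSV header line, hence skipFirstLine = false.
try (InputStream is = new FileInputStream("vectors.txt")) {
    WordVectors wordVectors = WordVectorSerializer.loadTxtVectors(is, false);
    double[] dayVector = wordVectors.getWordVector("day"); // "day" is an assumed vocab entry
}

Note that every token is loaded with a frequency of 1.0: the plain-text format stores only words and vectors, so the original counts are not recoverable.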

Example 23 with InMemoryLookupTable

Use of org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable in project deeplearning4j by deeplearning4j.

From class WordVectorSerializer, method readParagraphVectorsFromText.

/**
     * Restores a previously serialized ParagraphVectors model.
     *
     * Deprecation note: please consider using readParagraphVectors() instead.
     *
     * @param stream InputStream that contains the previously serialized model
     * @return the restored ParagraphVectors model
     */
@Deprecated
public static ParagraphVectors readParagraphVectorsFromText(@NonNull InputStream stream) {
    try {
        BufferedReader reader = new BufferedReader(new InputStreamReader(stream, "UTF-8"));
        ArrayList<String> labels = new ArrayList<>();
        ArrayList<INDArray> arrays = new ArrayList<>();
        VocabCache<VocabWord> vocabCache = new AbstractCache.Builder<VocabWord>().build();
        String line = "";
        while ((line = reader.readLine()) != null) {
            String[] split = line.split(" ");
            split[1] = split[1].replaceAll(whitespaceReplacement, " ");
            VocabWord word = new VocabWord(1.0, split[1]);
            if (split[0].equals("L")) {
                // we have label element here
                word.setSpecial(true);
                word.markAsLabel(true);
                labels.add(word.getLabel());
            } else if (split[0].equals("E")) {
                // we have usual element, aka word here
                word.setSpecial(false);
                word.markAsLabel(false);
            } else
                throw new IllegalStateException("Source stream doesn't look like a serialized ParagraphVectors model");
            // this particular line is just for backward compatibility with InMemoryLookupCache
            word.setIndex(vocabCache.numWords());
            vocabCache.addToken(word);
            vocabCache.addWordToIndex(word.getIndex(), word.getLabel());
            // backward compatibility code
            vocabCache.putVocabWord(word.getLabel());
            float[] vector = new float[split.length - 2];
            for (int i = 2; i < split.length; i++) {
                vector[i - 2] = Float.parseFloat(split[i]);
            }
            INDArray row = Nd4j.create(vector);
            arrays.add(row);
        }
        // now we create syn0 matrix, using previously fetched rows
        /*INDArray syn = Nd4j.create(new int[]{arrays.size(), arrays.get(0).columns()});
            for (int i = 0; i < syn.rows(); i++) {
                syn.putRow(i, arrays.get(i));
            }*/
        INDArray syn = Nd4j.vstack(arrays);
        InMemoryLookupTable<VocabWord> lookupTable = (InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>().vectorLength(arrays.get(0).columns()).useAdaGrad(false).cache(vocabCache).build();
        Nd4j.clearNans(syn);
        lookupTable.setSyn0(syn);
        LabelsSource source = new LabelsSource(labels);
        ParagraphVectors vectors = new ParagraphVectors.Builder().labelsSource(source).vocabCache(vocabCache).lookupTable(lookupTable).modelUtils(new BasicModelUtils<VocabWord>()).build();
        try {
            reader.close();
        } catch (Exception e) {
        }
        vectors.extractLabels();
        return vectors;
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}
Also used: ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException), ArrayList (java.util.ArrayList), VocabWord (org.deeplearning4j.models.word2vec.VocabWord), AbstractCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache), ParagraphVectors (org.deeplearning4j.models.paragraphvectors.ParagraphVectors), DL4JInvalidInputException (org.deeplearning4j.exception.DL4JInvalidInputException), InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable), INDArray (org.nd4j.linalg.api.ndarray.INDArray), BasicModelUtils (org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils), LabelsSource (org.deeplearning4j.text.documentiterator.LabelsSource)
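A minimal usage sketch for this deprecated reader; the input file name and the label "DOC_0" are assumptions for illustration:

import java.io.FileInputStream;
import java.io.InputStream;

// "paravec.txt" is an assumed file in the text format described above
// (lines prefixed "L" for labels, "E" for ordinary words).
try (InputStream is = new FileInputStream("paravec.txt")) {
    ParagraphVectors restored = WordVectorSerializer.readParagraphVectorsFromText(is);
    INDArray labelVector = restored.getLookupTable().vector("DOC_0"); // assumed label name
}

Since labels are restored through extractLabels(), similarity queries over document labels should work on the restored model just as on the original.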

Example 24 with InMemoryLookupTable

Use of org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable in project deeplearning4j by deeplearning4j.

From class WordVectorSerializer, method writeFullModel.

/**
     * Saves the full Word2Vec model in a way that allows model updates without rebuilding it from scratch.
     *
     * Deprecation note: please consider using writeWord2VecModel() instead.
     *
     * @param vec the Word2Vec instance to be saved
     * @param path the path the JSON model file should be written to
     */
@Deprecated
public static void writeFullModel(@NonNull Word2Vec vec, @NonNull String path) {
    /*
            Basically we need to save:
                    1. WeightLookupTable, especially syn0 and syn1 matrices
                    2. VocabCache, including only WordCounts
                    3. Settings from Word2Vec model: workers, layers, etc.
         */
    PrintWriter printWriter = null;
    try {
        printWriter = new PrintWriter(new OutputStreamWriter(new FileOutputStream(path), "UTF-8"));
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
    WeightLookupTable<VocabWord> lookupTable = vec.getLookupTable();
    // ((InMemoryLookupTable) lookupTable).getVocab(); //vec.getVocab();
    VocabCache<VocabWord> vocabCache = vec.getVocab();
    if (!(lookupTable instanceof InMemoryLookupTable))
        throw new IllegalStateException("At this moment only InMemoryLookupTable is supported.");
    VectorsConfiguration conf = vec.getConfiguration();
    conf.setVocabSize(vocabCache.numWords());
    printWriter.println(conf.toJson());
    //log.info("Word2Vec conf. JSON: " + conf.toJson());
    /*
            We have the following map:
            Line 0 - VectorsConfiguration JSON string
            Line 1 - expTable
            Line 2 - table

            All following lines are vocab/weight lookup table saved line by line as VocabularyWord JSON representation
         */
    // actually we don't need the expTable, since it produces identical results on subsequent runs as long as its size is unchanged
    // saving the expTable anyway, as a safeguard for future special cases
    StringBuilder builder = new StringBuilder();
    for (int x = 0; x < ((InMemoryLookupTable) lookupTable).getExpTable().length; x++) {
        builder.append(((InMemoryLookupTable) lookupTable).getExpTable()[x]).append(" ");
    }
    printWriter.println(builder.toString().trim());
    // saving table, available only if negative sampling is used
    if (conf.getNegative() > 0 && ((InMemoryLookupTable) lookupTable).getTable() != null) {
        builder = new StringBuilder();
        for (int x = 0; x < ((InMemoryLookupTable) lookupTable).getTable().columns(); x++) {
            builder.append(((InMemoryLookupTable) lookupTable).getTable().getDouble(x)).append(" ");
        }
        printWriter.println(builder.toString().trim());
    } else
        printWriter.println("");
    List<VocabWord> words = new ArrayList<>(vocabCache.vocabWords());
    for (SequenceElement word : words) {
        VocabularyWord vw = new VocabularyWord(word.getLabel());
        vw.setCount(vocabCache.wordFrequency(word.getLabel()));
        vw.setHuffmanNode(VocabularyHolder.buildNode(word.getCodes(), word.getPoints(), word.getCodeLength(), word.getIndex()));
        // writing down syn0
        INDArray syn0 = ((InMemoryLookupTable) lookupTable).getSyn0().getRow(vocabCache.indexOf(word.getLabel()));
        double[] dsyn0 = new double[syn0.columns()];
        for (int x = 0; x < conf.getLayersSize(); x++) {
            dsyn0[x] = syn0.getDouble(x);
        }
        vw.setSyn0(dsyn0);
        // writing down syn1
        INDArray syn1 = ((InMemoryLookupTable) lookupTable).getSyn1().getRow(vocabCache.indexOf(word.getLabel()));
        double[] dsyn1 = new double[syn1.columns()];
        for (int x = 0; x < syn1.columns(); x++) {
            dsyn1[x] = syn1.getDouble(x);
        }
        vw.setSyn1(dsyn1);
        // writing down syn1Neg, if negative sampling is used
        if (conf.getNegative() > 0 && ((InMemoryLookupTable) lookupTable).getSyn1Neg() != null) {
            INDArray syn1Neg = ((InMemoryLookupTable) lookupTable).getSyn1Neg().getRow(vocabCache.indexOf(word.getLabel()));
            double[] dsyn1Neg = new double[syn1Neg.columns()];
            for (int x = 0; x < syn1Neg.columns(); x++) {
                dsyn1Neg[x] = syn1Neg.getDouble(x);
            }
            vw.setSyn1Neg(dsyn1Neg);
        }
        // in case of UseAdaGrad == true - we should save gradients for each word in vocab
        if (conf.isUseAdaGrad() && ((InMemoryLookupTable) lookupTable).isUseAdaGrad()) {
            INDArray gradient = word.getHistoricalGradient();
            if (gradient == null)
                gradient = Nd4j.zeros(word.getCodes().size());
            double[] ada = new double[gradient.columns()];
            for (int x = 0; x < gradient.columns(); x++) {
                ada[x] = gradient.getDouble(x);
            }
            vw.setHistoricalGradient(ada);
        }
        printWriter.println(vw.toJson());
    }
    // at this moment we have whole vocab serialized
    printWriter.flush();
    printWriter.close();
}
Also used: ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException), SequenceElement (org.deeplearning4j.models.sequencevectors.sequence.SequenceElement), ArrayList (java.util.ArrayList), VocabWord (org.deeplearning4j.models.word2vec.VocabWord), DL4JInvalidInputException (org.deeplearning4j.exception.DL4JInvalidInputException), InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable), INDArray (org.nd4j.linalg.api.ndarray.INDArray), VocabularyWord (org.deeplearning4j.models.word2vec.wordstore.VocabularyWord)
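A minimal usage sketch; the output path is an assumption for illustration, and whether a matching loader (historically loadFullModel() in the same class) is still present depends on your DL4J version:

// "vec" is assumed to be an already trained Word2Vec model.
WordVectorSerializer.writeFullModel(vec, "word2vec_full_model.txt");

Because syn0, syn1 (and syn1Neg when negative sampling is used), the Huffman nodes, and, when AdaGrad is enabled, the historical gradients are all written out, a model restored from this file can resume training instead of being rebuilt from scratch.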

Example 25 with InMemoryLookupTable

Use of org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable in project deeplearning4j by deeplearning4j.

From class CBOW, method iterateSample.

public void iterateSample(T currentWord, int[] windowWords, AtomicLong nextRandom, double alpha, boolean isInference, int numLabels, boolean trainWords, INDArray inferenceVector) {
    int[] idxSyn1 = null;
    int[] codes = null;
    if (configuration.isUseHierarchicSoftmax()) {
        idxSyn1 = new int[currentWord.getCodeLength()];
        codes = new int[currentWord.getCodeLength()];
        for (int p = 0; p < currentWord.getCodeLength(); p++) {
            if (currentWord.getPoints().get(p) < 0)
                continue;
            codes[p] = currentWord.getCodes().get(p);
            idxSyn1[p] = currentWord.getPoints().get(p);
        }
    } else {
        idxSyn1 = new int[0];
        codes = new int[0];
    }
    if (negative > 0) {
        if (syn1Neg == null) {
            ((InMemoryLookupTable<T>) lookupTable).initNegative();
            syn1Neg = new DeviceLocalNDArray(((InMemoryLookupTable<T>) lookupTable).getSyn1Neg());
        }
    }
    if (batches.get() == null)
        batches.set(new ArrayList<Aggregate>());
    AggregateCBOW cbow = new AggregateCBOW(syn0.get(), syn1.get(), syn1Neg.get(), expTable.get(), table.get(), currentWord.getIndex(), windowWords, idxSyn1, codes, (int) negative, currentWord.getIndex(), lookupTable.layerSize(), alpha, nextRandom.get(), vocabCache.numWords(), numLabels, trainWords, inferenceVector);
    nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11));
    if (!isInference)
        batches.get().add(cbow);
    else
        Nd4j.getExecutioner().exec(cbow);
}
Also used: InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable), DeviceLocalNDArray (org.nd4j.linalg.util.DeviceLocalNDArray), ArrayList (java.util.ArrayList), AggregateCBOW (org.nd4j.linalg.api.ops.aggregates.impl.AggregateCBOW)
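The nextRandom update above is a linear congruential step using the multiplier 25214903917, the same constant as in the original word2vec C implementation (and java.util.Random); a standalone sketch of that generator, with an arbitrary seed chosen for illustration:

import java.util.concurrent.atomic.AtomicLong;

public class LcgSketch {
    public static void main(String[] args) {
        AtomicLong nextRandom = new AtomicLong(42); // arbitrary seed for this sketch
        for (int i = 0; i < 5; i++) {
            // Same update as in iterateSample: next = |prev * 25214903917 + 11|
            nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11));
            System.out.println(nextRandom.get());
        }
    }
}

Note also the batching behavior in iterateSample: during training the constructed AggregateCBOW op is queued in the thread-local batches list for batched execution, while during inference (isInference == true) it is executed immediately.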

Aggregations

InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable): 29 usages
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 21 usages
VocabWord (org.deeplearning4j.models.word2vec.VocabWord): 18 usages
ArrayList (java.util.ArrayList): 13 usages
AbstractCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache): 9 usages
Test (org.junit.Test): 8 usages
VocabCache (org.deeplearning4j.models.word2vec.wordstore.VocabCache): 7 usages
File (java.io.File): 6 usages
Word2Vec (org.deeplearning4j.models.word2vec.Word2Vec): 6 usages
ZipFile (java.util.zip.ZipFile): 5 usages
DL4JInvalidInputException (org.deeplearning4j.exception.DL4JInvalidInputException): 5 usages
StaticWord2Vec (org.deeplearning4j.models.word2vec.StaticWord2Vec): 5 usages
TokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory): 5 usages
ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 5 usages
AtomicInteger (java.util.concurrent.atomic.AtomicInteger): 4 usages
ZipEntry (java.util.zip.ZipEntry): 4 usages
ClassPathResource (org.datavec.api.util.ClassPathResource): 4 usages
WordVectors (org.deeplearning4j.models.embeddings.wordvectors.WordVectors): 4 usages
InMemoryLookupCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache): 4 usages
GZIPInputStream (java.util.zip.GZIPInputStream): 3 usages