Search in sources :

Example 1 with ParagraphVectors

use of org.deeplearning4j.models.paragraphvectors.ParagraphVectors in project deeplearning4j by deeplearning4j.

the class WordVectorSerializer method readParagraphVectors.

/**
 * Restores a ParagraphVectors model previously saved with writeParagraphVectors().
 *
 * <p>The vectors, vocabulary and configuration are loaded through readWord2Vec(file) and wrapped
 * in a ParagraphVectors instance. If the archive additionally contains a "labels.txt" entry,
 * every vocabulary element named in it (one encoded label per line) is flagged as a label.
 *
 * @param file ZIP archive holding the serialized model
 * @return the restored ParagraphVectors model
 * @throws IOException if the archive cannot be opened or read
 */
public static ParagraphVectors readParagraphVectors(File file) throws IOException {
    Word2Vec w2v = readWord2Vec(file);
    // and "convert" it to ParaVec model + optionally trying to restore labels information
    ParagraphVectors vectors = new ParagraphVectors.Builder(w2v.getConfiguration()).vocabCache(w2v.getVocab()).lookupTable(w2v.getLookupTable()).resetModel(false).build();
    // try-with-resources: the archive handle is released even when label parsing fails
    try (ZipFile zipFile = new ZipFile(file)) {
        // now we try to restore labels information
        ZipEntry labels = zipFile.getEntry("labels.txt");
        if (labels != null) {
            // read the entry straight from the archive; no intermediate temp file is needed
            try (BufferedReader reader = new BufferedReader(new InputStreamReader(zipFile.getInputStream(labels), "UTF-8"))) {
                String line;
                while ((line = reader.readLine()) != null) {
                    VocabWord word = vectors.getVocab().tokenFor(decodeB64(line.trim()));
                    // labels that are missing from the vocabulary are skipped silently
                    if (word != null) {
                        word.markAsLabel(true);
                    }
                }
            }
        }
    }
    vectors.extractLabels();
    return vectors;
}
Also used : GZIPInputStream(java.util.zip.GZIPInputStream) StaticWord2Vec(org.deeplearning4j.models.word2vec.StaticWord2Vec) Word2Vec(org.deeplearning4j.models.word2vec.Word2Vec) ZipEntry(java.util.zip.ZipEntry) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) ZipFile(java.util.zip.ZipFile) ParagraphVectors(org.deeplearning4j.models.paragraphvectors.ParagraphVectors)

Example 2 with ParagraphVectors

use of org.deeplearning4j.models.paragraphvectors.ParagraphVectors in project deeplearning4j by deeplearning4j.

the class WordVectorSerializerTest method testMalformedLabels1.

/**
 * Round-trips a ParagraphVectors model whose vocabulary contains awkward tokens (spaces, a
 * newline, a backtick, an underscore) through writeParagraphVectors/readParagraphVectors and
 * checks that every token and the single label flag come back.
 */
@Test
public void testMalformedLabels1() throws Exception {
    List<String> tokens = new ArrayList<>();
    tokens.add("test A");
    tokens.add("test B");
    tokens.add("test\nC");
    tokens.add("test`D");
    tokens.add("test_E");
    tokens.add("test 5");

    // register every token under its list position
    AbstractCache<VocabWord> vocabCache = new AbstractCache<>();
    for (int idx = 0; idx < tokens.size(); idx++) {
        String token = tokens.get(idx);
        vocabCache.addToken(new VocabWord(1.0, token));
        vocabCache.addWordToIndex(idx, token);
    }
    // only the element at index 1 is a label
    vocabCache.elementAtIndex(1).markAsLabel(true);

    InMemoryLookupTable<VocabWord> lookupTable = new InMemoryLookupTable<>(vocabCache, 10, false, 0.01, Nd4j.getRandom(), 0.0);
    lookupTable.resetWeights(true);
    assertNotEquals(null, lookupTable.getSyn0());
    assertNotEquals(null, lookupTable.getSyn1());
    assertNotEquals(null, lookupTable.getExpTable());
    assertEquals(null, lookupTable.getSyn1Neg());

    ParagraphVectors vec = new ParagraphVectors.Builder()
            .lookupTable(lookupTable)
            .vocabCache(vocabCache)
            .build();

    File tempFile = File.createTempFile("temp", "w2v");
    tempFile.deleteOnExit();
    WordVectorSerializer.writeParagraphVectors(vec, tempFile);
    ParagraphVectors restoredVec = WordVectorSerializer.readParagraphVectors(tempFile);

    // every original token must survive the round trip
    for (String token : tokens) {
        assertTrue(restoredVec.hasWord(token));
    }
    assertTrue(restoredVec.getVocab().elementAtIndex(1).isLabel());
}
Also used : InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) ArrayList(java.util.ArrayList) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) AbstractCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache) File(java.io.File) ParagraphVectors(org.deeplearning4j.models.paragraphvectors.ParagraphVectors) Test(org.junit.Test)

Example 3 with ParagraphVectors

use of org.deeplearning4j.models.paragraphvectors.ParagraphVectors in project deeplearning4j by deeplearning4j.

the class WordVectorSerializerTest method testParaVecSerialization1.

/**
 * Builds a 100-element vocabulary with random points, codes and label flags, serializes the
 * resulting ParagraphVectors model, reads it back and compares weights and per-element
 * vocabulary data with the originals.
 */
@Test
public void testParaVecSerialization1() throws Exception {
    VectorsConfiguration configuration = new VectorsConfiguration();
    configuration.setIterations(14123);
    configuration.setLayersSize(156);

    INDArray syn0 = Nd4j.rand(100, configuration.getLayersSize());
    INDArray syn1 = Nd4j.rand(100, configuration.getLayersSize());

    AbstractCache<VocabWord> cache = new AbstractCache.Builder<VocabWord>().build();
    for (int idx = 0; idx < 100; idx++) {
        VocabWord element = new VocabWord((float) idx, "word_" + idx);

        // random-length point/code lists; the order of the random draws is kept as-is
        List<Integer> pointList = new ArrayList<>();
        List<Byte> codeList = new ArrayList<>();
        int count = org.apache.commons.lang3.RandomUtils.nextInt(1, 20);
        for (int left = count; left > 0; left--) {
            pointList.add(org.apache.commons.lang3.RandomUtils.nextInt(1, 100000));
            codeList.add(org.apache.commons.lang3.RandomUtils.nextBytes(10)[0]);
        }

        // flag the element as a label when the random draw is below 3
        if (RandomUtils.nextInt(10) < 3) {
            element.markAsLabel(true);
        }

        element.setIndex(idx);
        element.setPoints(pointList);
        element.setCodes(codeList);
        cache.addToken(element);
        cache.addWordToIndex(idx, element.getLabel());
    }

    InMemoryLookupTable<VocabWord> lookupTable = (InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>()
            .vectorLength(configuration.getLayersSize())
            .cache(cache)
            .build();
    lookupTable.setSyn0(syn0);
    lookupTable.setSyn1(syn1);

    ParagraphVectors originalVectors = new ParagraphVectors.Builder(configuration)
            .vocabCache(cache)
            .lookupTable(lookupTable)
            .build();

    File tempFile = File.createTempFile("paravec", "tests");
    tempFile.deleteOnExit();
    WordVectorSerializer.writeParagraphVectors(originalVectors, tempFile);
    ParagraphVectors restoredVectors = WordVectorSerializer.readParagraphVectors(tempFile);

    InMemoryLookupTable<VocabWord> restoredLookupTable = (InMemoryLookupTable<VocabWord>) restoredVectors.getLookupTable();
    AbstractCache<VocabWord> restoredVocab = (AbstractCache<VocabWord>) restoredVectors.getVocab();

    // weights must match
    assertEquals(restoredLookupTable.getSyn0(), lookupTable.getSyn0());
    assertEquals(restoredLookupTable.getSyn1(), lookupTable.getSyn1());

    // per-element comparison of label flag, word, frequency, points and codes
    for (int idx = 0; idx < cache.numWords(); idx++) {
        VocabWord expected = cache.elementAtIndex(idx);
        VocabWord actual = restoredVocab.elementAtIndex(idx);

        assertEquals(expected.isLabel(), actual.isLabel());
        assertEquals(cache.wordAtIndex(idx), restoredVocab.wordAtIndex(idx));
        assertEquals(expected.getElementFrequency(), actual.getElementFrequency(), 0.1f);

        List<Integer> originalPoints = expected.getPoints();
        List<Integer> restoredPoints = actual.getPoints();
        assertEquals(originalPoints.size(), restoredPoints.size());
        for (int pos = 0; pos < originalPoints.size(); pos++) {
            assertEquals(originalPoints.get(pos), restoredPoints.get(pos));
        }

        List<Byte> originalCodes = expected.getCodes();
        List<Byte> restoredCodes = actual.getCodes();
        assertEquals(originalCodes.size(), restoredCodes.size());
        for (int pos = 0; pos < originalCodes.size(); pos++) {
            assertEquals(originalCodes.get(pos), restoredCodes.get(pos));
        }
    }
}
Also used : VectorsConfiguration(org.deeplearning4j.models.embeddings.loader.VectorsConfiguration) ArrayList(java.util.ArrayList) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) AbstractCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache) ParagraphVectors(org.deeplearning4j.models.paragraphvectors.ParagraphVectors) InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) File(java.io.File) Test(org.junit.Test)

Example 4 with ParagraphVectors

use of org.deeplearning4j.models.paragraphvectors.ParagraphVectors in project deeplearning4j by deeplearning4j.

the class WordVectorSerializer method readParagraphVectorsFromText.

/**
 * Restores a previously serialized ParagraphVectors model from its text representation.
 *
 * <p>Each input line is split on single spaces: the first token is the element type ("L" for a
 * label, "E" for a regular element), the second is the element name (with the whitespace
 * placeholder turned back into spaces) and the remaining tokens are the vector components.
 *
 * <p>The stream is closed before this method returns, whether parsing succeeds or fails.
 *
 * <p>Deprecation note: Please, consider using readParagraphVectors() method instead
 *
 * @param stream InputStream that contains previously serialized model
 * @return the restored ParagraphVectors model
 * @throws RuntimeException wrapping any failure raised while reading or parsing the stream,
 *         including malformed lines and an empty stream
 */
@Deprecated
public static ParagraphVectors readParagraphVectorsFromText(@NonNull InputStream stream) {
    BufferedReader reader = null;
    try {
        reader = new BufferedReader(new InputStreamReader(stream, "UTF-8"));
        ArrayList<String> labels = new ArrayList<>();
        ArrayList<INDArray> arrays = new ArrayList<>();
        VocabCache<VocabWord> vocabCache = new AbstractCache.Builder<VocabWord>().build();
        String line;
        while ((line = reader.readLine()) != null) {
            String[] split = line.split(" ");
            // a valid line carries at least the type marker and the element name
            if (split.length < 2) {
                throw new IllegalStateException("Source stream doesn't looks like ParagraphVectors serialized model");
            }
            split[1] = split[1].replaceAll(whitespaceReplacement, " ");
            VocabWord word = new VocabWord(1.0, split[1]);
            if (split[0].equals("L")) {
                // we have label element here
                word.setSpecial(true);
                word.markAsLabel(true);
                labels.add(word.getLabel());
            } else if (split[0].equals("E")) {
                // we have usual element, aka word here
                word.setSpecial(false);
                word.markAsLabel(false);
            } else {
                throw new IllegalStateException("Source stream doesn't looks like ParagraphVectors serialized model");
            }
            // this particular line is just for backward compatibility with InMemoryLookupCache
            word.setIndex(vocabCache.numWords());
            vocabCache.addToken(word);
            vocabCache.addWordToIndex(word.getIndex(), word.getLabel());
            // backward compatibility code
            vocabCache.putVocabWord(word.getLabel());
            float[] vector = new float[split.length - 2];
            for (int i = 2; i < split.length; i++) {
                vector[i - 2] = Float.parseFloat(split[i]);
            }
            INDArray row = Nd4j.create(vector);
            arrays.add(row);
        }
        // without at least one row there is nothing to stack and no vector length to read
        if (arrays.isEmpty()) {
            throw new IllegalStateException("Source stream doesn't contain any ParagraphVectors elements");
        }
        // now we create syn0 matrix, using previously fetched rows
        INDArray syn = Nd4j.vstack(arrays);
        InMemoryLookupTable<VocabWord> lookupTable = (InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>().vectorLength(arrays.get(0).columns()).useAdaGrad(false).cache(vocabCache).build();
        Nd4j.clearNans(syn);
        lookupTable.setSyn0(syn);
        LabelsSource source = new LabelsSource(labels);
        ParagraphVectors vectors = new ParagraphVectors.Builder().labelsSource(source).vocabCache(vocabCache).lookupTable(lookupTable).modelUtils(new BasicModelUtils<VocabWord>()).build();
        vectors.extractLabels();
        return vectors;
    } catch (Exception e) {
        throw new RuntimeException(e);
    } finally {
        if (reader != null) {
            try {
                reader.close();
            } catch (Exception ignored) {
                // closing is best-effort: a failure here must not mask the parse result or error
            }
        }
    }
}
Also used : ArrayList(java.util.ArrayList) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) AbstractCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache) ParagraphVectors(org.deeplearning4j.models.paragraphvectors.ParagraphVectors) DL4JInvalidInputException(org.deeplearning4j.exception.DL4JInvalidInputException) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) BasicModelUtils(org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils) LabelsSource(org.deeplearning4j.text.documentiterator.LabelsSource)

Aggregations

ParagraphVectors (org.deeplearning4j.models.paragraphvectors.ParagraphVectors)4 VocabWord (org.deeplearning4j.models.word2vec.VocabWord)4 ArrayList (java.util.ArrayList)3 InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable)3 AbstractCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache)3 File (java.io.File)2 Test (org.junit.Test)2 INDArray (org.nd4j.linalg.api.ndarray.INDArray)2 GZIPInputStream (java.util.zip.GZIPInputStream)1 ZipEntry (java.util.zip.ZipEntry)1 ZipFile (java.util.zip.ZipFile)1 DL4JInvalidInputException (org.deeplearning4j.exception.DL4JInvalidInputException)1 VectorsConfiguration (org.deeplearning4j.models.embeddings.loader.VectorsConfiguration)1 BasicModelUtils (org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils)1 StaticWord2Vec (org.deeplearning4j.models.word2vec.StaticWord2Vec)1 Word2Vec (org.deeplearning4j.models.word2vec.Word2Vec)1 LabelsSource (org.deeplearning4j.text.documentiterator.LabelsSource)1 ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException)1