
Example 6 with BasicModelUtils

Use of org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils in the deeplearning4j project.

From class Word2VecTests, method testRunWord2Vec:

@Test
public void testRunWord2Vec() throws Exception {
    // Strip white space before and after for each line
    SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
    // Split on white spaces in the line to get words
    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());
    Word2Vec vec = new Word2Vec.Builder()
                    .minWordFrequency(1)
                    .iterations(3)
                    .batchSize(64)
                    .layerSize(100)
                    .stopWords(new ArrayList<String>())
                    .seed(42)
                    .learningRate(0.025)
                    .minLearningRate(0.001)
                    .sampling(0)
                    .elementsLearningAlgorithm(new SkipGram<VocabWord>())
                    .epochs(1)
                    .windowSize(5)
                    .allowParallelTokenization(true)
                    .modelUtils(new BasicModelUtils<VocabWord>())
                    .iterate(iter)
                    .tokenizerFactory(t)
                    .build();
    assertEquals(new ArrayList<String>(), vec.getStopWords());
    vec.fit();
    File tempFile = File.createTempFile("temp", "temp");
    tempFile.deleteOnExit();
    WordVectorSerializer.writeFullModel(vec, tempFile.getAbsolutePath());
    Collection<String> lst = vec.wordsNearest("day", 10);
    //log.info(Arrays.toString(lst.toArray()));
    printWords("day", lst, vec);
    assertEquals(10, lst.size());
    double sim = vec.similarity("day", "night");
    log.info("Day/night similarity: " + sim);
    assertTrue(sim < 1.0);
    assertTrue(sim > 0.4);
    assertTrue(lst.contains("week"));
    assertTrue(lst.contains("night"));
    assertTrue(lst.contains("year"));
    assertFalse(lst.contains(null));
    lst = vec.wordsNearest("day", 10);
    //log.info(Arrays.toString(lst.toArray()));
    printWords("day", lst, vec);
    assertTrue(lst.contains("week"));
    assertTrue(lst.contains("night"));
    assertTrue(lst.contains("year"));
    new File("cache.ser").delete();
    ArrayList<String> labels = new ArrayList<>();
    labels.add("day");
    labels.add("night");
    labels.add("week");
    INDArray matrix = vec.getWordVectors(labels);
    assertEquals(matrix.getRow(0), vec.getWordVectorMatrix("day"));
    assertEquals(matrix.getRow(1), vec.getWordVectorMatrix("night"));
    assertEquals(matrix.getRow(2), vec.getWordVectorMatrix("week"));
    WordVectorSerializer.writeWordVectors(vec, pathToWriteto);
}
Also used:
BasicLineIterator (org.deeplearning4j.text.sentenceiterator.BasicLineIterator)
TokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory)
DefaultTokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory)
SkipGram (org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram)
ArrayList (java.util.ArrayList)
SentenceIterator (org.deeplearning4j.text.sentenceiterator.SentenceIterator)
UimaSentenceIterator (org.deeplearning4j.text.sentenceiterator.UimaSentenceIterator)
CommonPreprocessor (org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor)
BasicModelUtils (org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils)
INDArray (org.nd4j.linalg.api.ndarray.INDArray)
File (java.io.File)
Test (org.junit.Test)
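In this test, wordsNearest() and similarity() on the Word2Vec instance are ultimately served by the BasicModelUtils passed to the builder. The same utility can also be driven directly once it is bound to a trained model's weight table. The snippet below is a minimal sketch, not part of the original test; it assumes lookupTable() is the accessor matching the setLookupTable() used in Example 7.

// Minimal sketch (not in the original test): querying BasicModelUtils directly.
static void queryWithUtils(Word2Vec vec) {
    BasicModelUtils<VocabWord> utils = new BasicModelUtils<>();
    // Assumption: lookupTable() exposes the model's weight table.
    utils.init(vec.lookupTable());
    Collection<String> nearest = utils.wordsNearest("day", 10);   // same query the test issues through vec
    double dayNight = utils.similarity("day", "night");           // cosine similarity between two vocab words
    System.out.println(nearest + " / " + dayNight);
}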

Example 7 with BasicModelUtils

Use of org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils in the deeplearning4j project.

From class WordVectorSerializer, method fromPair:

/**
     * Load word vectors from the given pair
     *
     * @param pair
     *            the given pair
     * @return a read only word vectors impl based on the given lookup table and vocab
     */
public static Word2Vec fromPair(Pair<InMemoryLookupTable, VocabCache> pair) {
    Word2Vec vectors = new Word2Vec();
    vectors.setLookupTable(pair.getFirst());
    vectors.setVocab(pair.getSecond());
    vectors.setModelUtils(new BasicModelUtils());
    return vectors;
}
Also used:
BasicModelUtils (org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils)
StaticWord2Vec (org.deeplearning4j.models.word2vec.StaticWord2Vec)
Word2Vec (org.deeplearning4j.models.word2vec.Word2Vec)
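fromPair() attaches a BasicModelUtils to a freshly constructed Word2Vec, so similarity and nearest-word queries work on a model that was only loaded, never trained, in the current JVM. The snippet below is a hedged usage sketch, not library code: it assumes loadTxt(File) in this version returns the Pair of InMemoryLookupTable and VocabCache that fromPair() expects, and the file path is a placeholder.

// Hedged sketch (assumptions noted): build a query-only Word2Vec from a text-format vector file.
Pair<InMemoryLookupTable, VocabCache> pair = WordVectorSerializer.loadTxt(new File("/path/to/vectors.txt")); // assumed reader, placeholder path
Word2Vec readOnly = WordVectorSerializer.fromPair(pair);
// These queries are answered by the BasicModelUtils that fromPair() attached:
double sim = readOnly.similarity("day", "night");
Collection<String> nearest = readOnly.wordsNearest("day", 10);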

Example 8 with BasicModelUtils

Use of org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils in the deeplearning4j project.

From class WordVectorSerializer, method readParagraphVectorsFromText:

/**
     * Restores previously serialized ParagraphVectors model
     *
     * Deprecation note: Please, consider using readParagraphVectors() method instead
     *
     * @param stream InputStream that contains previously serialized model
     * @return
     */
@Deprecated
public static ParagraphVectors readParagraphVectorsFromText(@NonNull InputStream stream) {
    try {
        BufferedReader reader = new BufferedReader(new InputStreamReader(stream, "UTF-8"));
        ArrayList<String> labels = new ArrayList<>();
        ArrayList<INDArray> arrays = new ArrayList<>();
        VocabCache<VocabWord> vocabCache = new AbstractCache.Builder<VocabWord>().build();
        String line = "";
        while ((line = reader.readLine()) != null) {
            String[] split = line.split(" ");
            split[1] = split[1].replaceAll(whitespaceReplacement, " ");
            VocabWord word = new VocabWord(1.0, split[1]);
            if (split[0].equals("L")) {
                // we have label element here
                word.setSpecial(true);
                word.markAsLabel(true);
                labels.add(word.getLabel());
            } else if (split[0].equals("E")) {
                // we have usual element, aka word here
                word.setSpecial(false);
                word.markAsLabel(false);
            } else
                throw new IllegalStateException("Source stream doesn't looks like ParagraphVectors serialized model");
            // this particular line is just for backward compatibility with InMemoryLookupCache
            word.setIndex(vocabCache.numWords());
            vocabCache.addToken(word);
            vocabCache.addWordToIndex(word.getIndex(), word.getLabel());
            // backward compatibility code
            vocabCache.putVocabWord(word.getLabel());
            float[] vector = new float[split.length - 2];
            for (int i = 2; i < split.length; i++) {
                vector[i - 2] = Float.parseFloat(split[i]);
            }
            INDArray row = Nd4j.create(vector);
            arrays.add(row);
        }
        // now we create syn0 matrix, using previously fetched rows
        /*INDArray syn = Nd4j.create(new int[]{arrays.size(), arrays.get(0).columns()});
            for (int i = 0; i < syn.rows(); i++) {
                syn.putRow(i, arrays.get(i));
            }*/
        INDArray syn = Nd4j.vstack(arrays);
        InMemoryLookupTable<VocabWord> lookupTable =
                        (InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>()
                                        .vectorLength(arrays.get(0).columns())
                                        .useAdaGrad(false)
                                        .cache(vocabCache)
                                        .build();
        Nd4j.clearNans(syn);
        lookupTable.setSyn0(syn);
        LabelsSource source = new LabelsSource(labels);
        ParagraphVectors vectors = new ParagraphVectors.Builder()
                        .labelsSource(source)
                        .vocabCache(vocabCache)
                        .lookupTable(lookupTable)
                        .modelUtils(new BasicModelUtils<VocabWord>())
                        .build();
        try {
            reader.close();
        } catch (Exception e) {
        }
        vectors.extractLabels();
        return vectors;
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}
Also used:
ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException)
ArrayList (java.util.ArrayList)
VocabWord (org.deeplearning4j.models.word2vec.VocabWord)
AbstractCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache)
ParagraphVectors (org.deeplearning4j.models.paragraphvectors.ParagraphVectors)
DL4JInvalidInputException (org.deeplearning4j.exception.DL4JInvalidInputException)
InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable)
INDArray (org.nd4j.linalg.api.ndarray.INDArray)
BasicModelUtils (org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils)
LabelsSource (org.deeplearning4j.text.documentiterator.LabelsSource)
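Here too, BasicModelUtils is the component that answers label and word queries once the ParagraphVectors model has been restored. The sketch below is a minimal, hedged example of calling this (deprecated) reader; the file name parameter and the label string are placeholders, and readParagraphVectors() would normally be preferred, as the deprecation note says.

// Hedged sketch (placeholders noted): restore a text-serialized ParagraphVectors model and query it.
static ParagraphVectors restoreFromText(File serializedModel) throws IOException {
    try (InputStream stream = new FileInputStream(serializedModel)) {
        ParagraphVectors restored = WordVectorSerializer.readParagraphVectorsFromText(stream);
        // Queries are served by the BasicModelUtils attached in the builder above.
        Collection<String> closest = restored.wordsNearest("some_label", 10);  // "some_label" is a placeholder
        return restored;
    }
}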

Aggregations

Types used across the BasicModelUtils examples, with usage counts:

BasicModelUtils (org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils): 8
BasicLineIterator (org.deeplearning4j.text.sentenceiterator.BasicLineIterator): 5
Word2Vec (org.deeplearning4j.models.word2vec.Word2Vec): 4
SentenceIterator (org.deeplearning4j.text.sentenceiterator.SentenceIterator): 4
DefaultTokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory): 4
TokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory): 4
Test (org.junit.Test): 4
File (java.io.File): 3
CommonPreprocessor (org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor): 3
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 3
ArrayList (java.util.ArrayList): 2
InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable): 2
SkipGram (org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram): 2
StaticWord2Vec (org.deeplearning4j.models.word2vec.StaticWord2Vec): 2
VocabWord (org.deeplearning4j.models.word2vec.VocabWord): 2
AbstractCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache): 2
UimaSentenceIterator (org.deeplearning4j.text.sentenceiterator.UimaSentenceIterator): 2
AtomicInteger (java.util.concurrent.atomic.AtomicInteger): 1
ZipFile (java.util.zip.ZipFile): 1
ClassPathResource (org.datavec.api.util.ClassPathResource): 1