Search in sources:

Example 11 with VocabWord

use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.

Class ParagraphVectors, method predict.

/**
 * Tokenizes raw text and returns the most probable label for it.
 * <p>
 * Tokens that the vocabulary does not contain are skipped. The remaining vocabulary words,
 * in token order, are handed to {@code predict(List<VocabWord>)}.
 *
 * @param rawText the text to classify
 * @return the label produced by {@code predict(List<VocabWord>)} for the in-vocabulary tokens
 * @throws IllegalStateException if no TokenizerFactory has been configured
 */
@Deprecated
public String predict(String rawText) {
    if (tokenizerFactory == null) {
        throw new IllegalStateException("TokenizerFactory should be defined, prior to predict() call");
    }
    List<VocabWord> knownWords = new ArrayList<>();
    for (String candidate : tokenizerFactory.create(rawText).getTokens()) {
        if (!vocab.containsWord(candidate)) {
            // tokens absent from the vocabulary are dropped rather than failing the prediction
            continue;
        }
        knownWords.add(vocab.wordFor(candidate));
    }
    return predict(knownWords);
}
Also used : ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) VocabWord(org.deeplearning4j.models.word2vec.VocabWord)

Example 12 with VocabWord

use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.

Class ParagraphVectors, method extractLabels.

/**
 * Collects every vocabulary element flagged as a label and caches them.
 * <p>
 * The label words are stored in {@code labelsList}. The matching rows pulled from the lookup
 * table weights are stored in {@code labelsMatrix}. If the vocabulary contains no labels,
 * both fields are left untouched.
 */
public void extractLabels() {
    Collection<VocabWord> vocabulary = vocab.vocabWords();
    // Gather label words, preserving the vocabulary's iteration order
    List<VocabWord> labelWords = new ArrayList<>();
    for (VocabWord candidate : vocabulary) {
        if (candidate.isLabel()) {
            labelWords.add(candidate);
        }
    }
    if (labelWords.isEmpty()) {
        return;
    }
    // Weight-matrix indexes, positioned exactly like the entries of labelWords
    int[] rowIndexes = new int[labelWords.size()];
    for (int pos = 0; pos < rowIndexes.length; pos++) {
        rowIndexes[pos] = labelWords.get(pos).getIndex();
    }
    labelsMatrix = Nd4j.pullRows(lookupTable.getWeights(), 1, rowIndexes);
    labelsList = labelWords;
}
Also used : VocabWord(org.deeplearning4j.models.word2vec.VocabWord)

Example 13 with VocabWord

use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.

Class ParagraphVectors, method similarityToLabel.

/**
 * Tokenizes raw text and returns its similarity to a specific label, based on mean value.
 * <p>
 * Tokens that the vocabulary does not contain are skipped. The remaining vocabulary words,
 * in token order, are handed to {@code similarityToLabel(List<VocabWord>, String)}.
 *
 * @param rawText the document text to compare
 * @param label   the label to compare the document against
 * @return the similarity computed by {@code similarityToLabel(List<VocabWord>, String)}
 * @throws IllegalStateException if no TokenizerFactory has been configured
 */
@Deprecated
public double similarityToLabel(String rawText, String label) {
    if (tokenizerFactory == null) {
        // The message names this method (it previously said "predict()", a copy-paste leftover)
        throw new IllegalStateException("TokenizerFactory should be defined, prior to similarityToLabel() call");
    }
    List<String> tokens = tokenizerFactory.create(rawText).getTokens();
    List<VocabWord> document = new ArrayList<>();
    for (String token : tokens) {
        if (vocab.containsWord(token)) {
            document.add(vocab.wordFor(token));
        }
    }
    return similarityToLabel(document, label);
}
Also used : ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) VocabWord(org.deeplearning4j.models.word2vec.VocabWord)

Example 14 with VocabWord

use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.

Class WordVectorSerializerTest, method testIndexPersistence.

/**
 * Trains a small Word2Vec model, writes it with {@code writeWordVectors}, and reloads it with
 * {@code loadTxtVectors}. It then checks that the document count and the vector of every
 * original vocabulary word match after the round trip.
 */
@Test
public void testIndexPersistence() throws Exception {
    File inputFile = new ClassPathResource("/big/raw_sentences.txt").getFile();
    SentenceIterator iter = UimaSentenceIterator.createWithPath(inputFile.getAbsolutePath());
    // Split on white spaces in the line to get words
    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());
    Word2Vec vec = new Word2Vec.Builder()
            .minWordFrequency(5)
            .iterations(1)
            .epochs(1)
            .layerSize(100)
            .stopWords(new ArrayList<String>())
            .useAdaGrad(false)
            .negativeSample(5)
            .seed(42)
            .windowSize(5)
            .iterate(iter)
            .tokenizerFactory(t)
            .build();
    vec.fit();
    // Typed (not raw) so the vocabulary can be iterated as VocabWord directly below
    VocabCache<VocabWord> orig = vec.getVocab();
    File tempFile = File.createTempFile("temp", "w2v");
    tempFile.deleteOnExit();
    WordVectorSerializer.writeWordVectors(vec, tempFile);
    WordVectors vec2 = WordVectorSerializer.loadTxtVectors(tempFile);
    // Only the document count is read from the restored cache, so a wildcard suffices
    VocabCache<?> rest = vec2.vocab();
    assertEquals(orig.totalNumberOfDocs(), rest.totalNumberOfDocs());
    for (VocabWord word : orig.vocabWords()) {
        INDArray array1 = vec.getWordVectorMatrix(word.getLabel());
        INDArray array2 = vec2.getWordVectorMatrix(word.getLabel());
        assertEquals(array1, array2);
    }
}
Also used : TokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) ArrayList(java.util.ArrayList) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) ClassPathResource(org.datavec.api.util.ClassPathResource) SentenceIterator(org.deeplearning4j.text.sentenceiterator.SentenceIterator) UimaSentenceIterator(org.deeplearning4j.text.sentenceiterator.UimaSentenceIterator) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) CommonPreprocessor(org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VocabCache(org.deeplearning4j.models.word2vec.wordstore.VocabCache) Word2Vec(org.deeplearning4j.models.word2vec.Word2Vec) WordVectors(org.deeplearning4j.models.embeddings.wordvectors.WordVectors) File(java.io.File) Test(org.junit.Test)

Example 15 with VocabWord

use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.

Class WordVectorSerializerTest, method testMalformedLabels1.

/**
 * Builds a ParagraphVectors model whose words contain spaces, a newline and punctuation.
 * It marks the element at index 1 as a label and round-trips the model through
 * {@code writeParagraphVectors}/{@code readParagraphVectors}. It then checks that every word
 * and the label flag survive.
 */
@Test
public void testMalformedLabels1() throws Exception {
    // The "malformed" words under test: embedded space, newline, backtick, underscore
    List<String> words = new ArrayList<>();
    words.add("test A");
    words.add("test B");
    words.add("test\nC");
    words.add("test`D");
    words.add("test_E");
    words.add("test 5");
    AbstractCache<VocabWord> vocabCache = new AbstractCache<>();
    int cnt = 0;
    for (String word : words) {
        vocabCache.addToken(new VocabWord(1.0, word));
        vocabCache.addWordToIndex(cnt, word);
        cnt++;
    }
    vocabCache.elementAtIndex(1).markAsLabel(true);
    InMemoryLookupTable<VocabWord> lookupTable = new InMemoryLookupTable<>(vocabCache, 10, false, 0.01, Nd4j.getRandom(), 0.0);
    lookupTable.resetWeights(true);
    assertNotEquals(null, lookupTable.getSyn0());
    assertNotEquals(null, lookupTable.getSyn1());
    assertNotEquals(null, lookupTable.getExpTable());
    assertEquals(null, lookupTable.getSyn1Neg());
    ParagraphVectors vec = new ParagraphVectors.Builder().lookupTable(lookupTable).vocabCache(vocabCache).build();
    File tempFile = File.createTempFile("temp", "w2v");
    tempFile.deleteOnExit();
    WordVectorSerializer.writeParagraphVectors(vec, tempFile);
    ParagraphVectors restoredVec = WordVectorSerializer.readParagraphVectors(tempFile);
    for (String word : words) {
        // assertTrue with a message reports which word was lost, unlike assertEquals(true, ...)
        assertTrue("Word lost after serialization round-trip: " + word, restoredVec.hasWord(word));
    }
    assertTrue("Label flag lost for element at index 1", restoredVec.getVocab().elementAtIndex(1).isLabel());
}
Also used : InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) ArrayList(java.util.ArrayList) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) AbstractCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache) File(java.io.File) ParagraphVectors(org.deeplearning4j.models.paragraphvectors.ParagraphVectors) Test(org.junit.Test)

Aggregations

VocabWord (org.deeplearning4j.models.word2vec.VocabWord)110 Test (org.junit.Test)54 INDArray (org.nd4j.linalg.api.ndarray.INDArray)31 AbstractCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache)26 ClassPathResource (org.datavec.api.util.ClassPathResource)23 BasicLineIterator (org.deeplearning4j.text.sentenceiterator.BasicLineIterator)22 File (java.io.File)20 InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable)19 TokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory)19 ArrayList (java.util.ArrayList)17 DefaultTokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory)17 CommonPreprocessor (org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor)15 SentenceIterator (org.deeplearning4j.text.sentenceiterator.SentenceIterator)14 AbstractSequenceIterator (org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator)13 SentenceTransformer (org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer)13 JavaSparkContext (org.apache.spark.api.java.JavaSparkContext)12 ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException)12 Sequence (org.deeplearning4j.models.sequencevectors.sequence.Sequence)11 Word2Vec (org.deeplearning4j.models.word2vec.Word2Vec)11 TextPipeline (org.deeplearning4j.spark.text.functions.TextPipeline)10