Search in sources :

Example 1 with LabelAwareFileSentenceIterator

use of org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareFileSentenceIterator in project deeplearning4j by deeplearning4j.

the class Word2VecIteratorTest method testLabeledExample.

/**
 * Looks up word vectors for the default unknown token and for an arbitrary string,
 * then builds a {@code Word2VecDataSetIterator} over the "labeled/" classpath
 * directory and pulls a single {@code DataSet} from it.
 *
 * @throws Exception propagated from any of the exercised calls
 */
@Test
public void testLabeledExample() throws Exception {
    // Both lookups are expected to yield a non-null matrix.
    INDArray unknownTokenVector = vec.getWordVectorMatrix(Word2Vec.DEFAULT_UNK);
    assertNotEquals(null, unknownTokenVector);

    // presumably "2131241sdasdas" is a word absent from the vocabulary — TODO confirm
    INDArray arbitraryWordVector = vec.getWordVectorMatrix("2131241sdasdas");
    assertNotEquals(null, arbitraryWordVector);

    LabelAwareFileSentenceIterator sentences =
            new LabelAwareFileSentenceIterator(null, new ClassPathResource("labeled/").getFile());
    Word2VecDataSetIterator iter =
            new Word2VecDataSetIterator(vec, sentences, Arrays.asList("negative", "positive", "neutral"));

    // No assertion on the result: the test only requires that next() completes.
    DataSet next = iter.next();
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) DataSet(org.nd4j.linalg.dataset.DataSet) ClassPathResource(org.datavec.api.util.ClassPathResource) LabelAwareFileSentenceIterator(org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareFileSentenceIterator) Test(org.junit.Test)

Example 2 with LabelAwareFileSentenceIterator

use of org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareFileSentenceIterator in project deeplearning4j by deeplearning4j.

the class BagOfWordsVectorizerTest method testBagOfWordsVectorizer.

/**
 * Fits a {@link BagOfWordsVectorizer} on the documents under the "rootdir" classpath
 * directory and checks the resulting vocabulary statistics, the vectors produced by
 * {@code transform} and {@code vectorize}, and that a vectorizer restored through
 * {@code SerializationUtils} yields the same feature vector as the original.
 *
 * @throws Exception propagated from any of the exercised calls
 */
@Test
public void testBagOfWordsVectorizer() throws Exception {
    File rootDir = new ClassPathResource("rootdir").getFile();
    LabelAwareSentenceIterator iter = new LabelAwareFileSentenceIterator(rootDir);
    TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
    BagOfWordsVectorizer vectorizer = new BagOfWordsVectorizer.Builder()
            .setMinWordFrequency(1)
            .setStopWords(new ArrayList<String>())
            .setTokenizerFactory(tokenizerFactory)
            .setIterator(iter)
            .allowParallelTokenization(false)
            .build();
    vectorizer.fit();

    // Vocabulary statistics after fitting.
    VocabWord word = vectorizer.getVocabCache().wordFor("file.");
    // NOTE(review): assumeNotNull presumably comes from org.junit.Assume, in which case a
    // missing word skips the test instead of failing it — confirm this is intended.
    assumeNotNull(word);
    assertEquals(word, vectorizer.getVocabCache().tokenFor("file."));
    assertEquals(2, vectorizer.getVocabCache().totalNumberOfDocs());
    assertEquals(2, word.getSequencesCount());
    assertEquals(2, word.getElementFrequency(), 0.1);
    VocabWord word1 = vectorizer.getVocabCache().wordFor("1");
    assertEquals(1, word1.getSequencesCount());
    assertEquals(1, word1.getElementFrequency(), 0.1);
    log.info("Labels used: " + vectorizer.getLabelsSource().getLabels());
    assertEquals(2, vectorizer.getLabelsSource().getNumberOfLabelsUsed());

    // transform(): one column per vocabulary entry, checked per token.
    INDArray array = vectorizer.transform("This is 2 file.");
    log.info("Transformed array: " + array);
    assertEquals(5, array.columns());
    VocabCache<VocabWord> vocabCache = vectorizer.getVocabCache();
    assertEquals(2, array.getDouble(vocabCache.tokenFor("This").getIndex()), 0.1);
    assertEquals(2, array.getDouble(vocabCache.tokenFor("is").getIndex()), 0.1);
    assertEquals(2, array.getDouble(vocabCache.tokenFor("file.").getIndex()), 0.1);
    assertEquals(0, array.getDouble(vocabCache.tokenFor("1").getIndex()), 0.1);
    assertEquals(1, array.getDouble(vocabCache.tokenFor("2").getIndex()), 0.1);

    // vectorize() must produce the same features as transform() for the same text.
    DataSet dataSet = vectorizer.vectorize("This is 2 file.", "label2");
    assertEquals(array, dataSet.getFeatureMatrix());
    INDArray labelz = dataSet.getLabels();
    log.info("Labels array: " + labelz);
    // The position of each label is not asserted directly; the test only requires that
    // "label2" and "label1" end up at different argmax positions (checked below).
    int idx2 = ((IndexAccumulation) Nd4j.getExecutioner().exec(new IMax(labelz))).getFinalResult();

    dataSet = vectorizer.vectorize("This is 1 file.", "label1");
    assertEquals(2, dataSet.getFeatureMatrix().getDouble(vocabCache.tokenFor("This").getIndex()), 0.1);
    assertEquals(2, dataSet.getFeatureMatrix().getDouble(vocabCache.tokenFor("is").getIndex()), 0.1);
    assertEquals(2, dataSet.getFeatureMatrix().getDouble(vocabCache.tokenFor("file.").getIndex()), 0.1);
    assertEquals(1, dataSet.getFeatureMatrix().getDouble(vocabCache.tokenFor("1").getIndex()), 0.1);
    assertEquals(0, dataSet.getFeatureMatrix().getDouble(vocabCache.tokenFor("2").getIndex()), 0.1);
    int idx1 = ((IndexAccumulation) Nd4j.getExecutioner().exec(new IMax(dataSet.getLabels()))).getFinalResult();
    assertNotEquals(idx2, idx1);

    // Serialization check: the restored vectorizer gets the tokenizer factory set again
    // before use and must reproduce the original feature vector.
    File tempFile = File.createTempFile("fdsf", "fdfsdf");
    tempFile.deleteOnExit();
    SerializationUtils.saveObject(vectorizer, tempFile);
    BagOfWordsVectorizer vectorizer2 = SerializationUtils.readObject(tempFile);
    vectorizer2.setTokenizerFactory(tokenizerFactory);
    dataSet = vectorizer2.vectorize("This is 2 file.", "label2");
    assertEquals(array, dataSet.getFeatureMatrix());
}
Also used : TokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) DataSet(org.nd4j.linalg.dataset.DataSet) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) ClassPathResource(org.datavec.api.util.ClassPathResource) LabelAwareFileSentenceIterator(org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareFileSentenceIterator) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) INDArray(org.nd4j.linalg.api.ndarray.INDArray) IMax(org.nd4j.linalg.api.ops.impl.indexaccum.IMax) LabelAwareSentenceIterator(org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator) File(java.io.File) IndexAccumulation(org.nd4j.linalg.api.ops.IndexAccumulation) Test(org.junit.Test)

Example 3 with LabelAwareFileSentenceIterator

use of org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareFileSentenceIterator in project deeplearning4j by deeplearning4j.

the class TfidfVectorizerTest method testTfIdfVectorizer.

/**
 * Fits a {@code TfidfVectorizer} on the documents under the "tripledir" classpath
 * directory, then verifies vocabulary statistics, the TF-IDF weights produced for a
 * sample sentence, the label encoding returned by {@code vectorize}, and that a
 * vectorizer restored through {@code SerializationUtils} yields the same features.
 *
 * @throws Exception propagated from any of the exercised calls
 */
@Test
public void testTfIdfVectorizer() throws Exception {
    File rootDir = new ClassPathResource("tripledir").getFile();
    LabelAwareSentenceIterator iter = new LabelAwareFileSentenceIterator(rootDir);
    TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
    TfidfVectorizer vectorizer = new TfidfVectorizer.Builder()
            .setMinWordFrequency(1)
            .setStopWords(new ArrayList<String>())
            .setTokenizerFactory(tokenizerFactory)
            .setIterator(iter)
            .allowParallelTokenization(false)
            .build();
    vectorizer.fit();

    // Vocabulary statistics gathered during fit().
    VocabWord fileToken = vectorizer.getVocabCache().wordFor("file.");
    assumeNotNull(fileToken);
    assertEquals(fileToken, vectorizer.getVocabCache().tokenFor("file."));
    assertEquals(3, vectorizer.getVocabCache().totalNumberOfDocs());
    assertEquals(3, fileToken.getSequencesCount());
    assertEquals(3, fileToken.getElementFrequency(), 0.1);

    VocabWord oneToken = vectorizer.getVocabCache().wordFor("1");
    assertEquals(1, oneToken.getSequencesCount());
    assertEquals(1, oneToken.getElementFrequency(), 0.1);

    log.info("Labels used: " + vectorizer.getLabelsSource().getLabels());
    assertEquals(3, vectorizer.getLabelsSource().getNumberOfLabelsUsed());
    assertEquals(3, vectorizer.getVocabCache().totalNumberOfDocs());
    assertEquals(11, vectorizer.numWordsEncountered());

    // TF-IDF weights for a sample sentence, looked up by each token's vocabulary index.
    INDArray tfidf = vectorizer.transform("This is 3 file.");
    log.info("TF-IDF vector: " + Arrays.toString(tfidf.data().asDouble()));
    VocabCache<VocabWord> vocabCache = vectorizer.getVocabCache();
    assertEquals(0.04402, tfidf.getDouble(vocabCache.tokenFor("This").getIndex()), 0.001);
    assertEquals(0.04402, tfidf.getDouble(vocabCache.tokenFor("is").getIndex()), 0.001);
    assertEquals(0.119, tfidf.getDouble(vocabCache.tokenFor("3").getIndex()), 0.001);
    assertEquals(0, tfidf.getDouble(vocabCache.tokenFor("file.").getIndex()), 0.001);

    // Exactly one of the first three label entries must exceed 0.1; which position it
    // occupies is not asserted.
    DataSet dataSet = vectorizer.vectorize("This is 3 file.", "label3");
    INDArray labelRow = dataSet.getLabels();
    int activeLabels = 0;
    for (int pos = 0; pos < 3; pos++) {
        if (labelRow.getDouble(pos) > 0.1) {
            activeLabels++;
        }
    }
    assertEquals(1, activeLabels);

    // Round-trip through serialization; the tokenizer factory is set again on the
    // restored instance before it is used.
    File serialized = File.createTempFile("somefile", "Dsdas");
    serialized.deleteOnExit();
    SerializationUtils.saveObject(vectorizer, serialized);
    TfidfVectorizer restored = SerializationUtils.readObject(serialized);
    restored.setTokenizerFactory(tokenizerFactory);
    dataSet = restored.vectorize("This is 3 file.", "label2");
    assertEquals(tfidf, dataSet.getFeatureMatrix());
}
Also used : TokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) DataSet(org.nd4j.linalg.dataset.DataSet) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) ClassPathResource(org.datavec.api.util.ClassPathResource) LabelAwareFileSentenceIterator(org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareFileSentenceIterator) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) INDArray(org.nd4j.linalg.api.ndarray.INDArray) LabelAwareSentenceIterator(org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator) File(java.io.File) Test(org.junit.Test)

Aggregations

ClassPathResource (org.datavec.api.util.ClassPathResource)3 LabelAwareFileSentenceIterator (org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareFileSentenceIterator)3 Test (org.junit.Test)3 INDArray (org.nd4j.linalg.api.ndarray.INDArray)3 DataSet (org.nd4j.linalg.dataset.DataSet)3 File (java.io.File)2 VocabWord (org.deeplearning4j.models.word2vec.VocabWord)2 LabelAwareSentenceIterator (org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator)2 DefaultTokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory)2 TokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory)2 IndexAccumulation (org.nd4j.linalg.api.ops.IndexAccumulation)1 IMax (org.nd4j.linalg.api.ops.impl.indexaccum.IMax)1