Search in sources :

Example 1 with LabelAwareSentenceIterator

use of org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator in project deeplearning4j by deeplearning4j.

From the class SentenceIteratorConverter, method nextDocument:

@Override
public LabelledDocument nextDocument() {
    // Wraps the next sentence from the backing iterator in a LabelledDocument.
    // If the backend is label-aware, its labels are copied onto the document and
    // registered with the label generator; otherwise a generated label is used.
    LabelledDocument document = new LabelledDocument();
    document.setContent(backendIterator.nextSentence());
    if (backendIterator instanceof LabelAwareSentenceIterator) {
        List<String> labels = ((LabelAwareSentenceIterator) backendIterator).currentLabels();
        if (labels != null) {
            for (String label : labels) {
                document.addLabel(label);
                generator.storeLabel(label);
            }
        } else {
            // Fall back to the single-label API when currentLabels() returns null.
            String label = ((LabelAwareSentenceIterator) backendIterator).currentLabel();
            // BUG FIX: original code re-checked `labels != null` here, which is always
            // false inside this branch, so the single label was silently dropped.
            if (label != null) {
                document.addLabel(label);
                generator.storeLabel(label);
            }
        }
    } else if (generator != null)
        document.addLabel(generator.nextLabel());
    return document;
}
Also used : LabelledDocument(org.deeplearning4j.text.documentiterator.LabelledDocument) LabelAwareSentenceIterator(org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator)

Example 2 with LabelAwareSentenceIterator

use of org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator in project deeplearning4j by deeplearning4j.

From the class SentenceIteratorTest, method testLabelAware:

@Test
public void testLabelAware() throws Exception {
    // Single-line case: label "1" is the field before ';', content comes after.
    String s = "1; hello";
    ByteArrayInputStream bis = new ByteArrayInputStream(s.getBytes());
    LabelAwareSentenceIterator labelAwareSentenceIterator = new LabelAwareListSentenceIterator(bis, ";", 0, 1);
    assertTrue(labelAwareSentenceIterator.hasNext());
    labelAwareSentenceIterator.nextSentence();
    assertEquals("1", labelAwareSentenceIterator.currentLabel());
    // Multi-line case: labels are read from a classpath fixture file.
    InputStream is2 = new ClassPathResource("labelawaresentenceiterator.txt").getInputStream();
    LabelAwareSentenceIterator labelAwareSentenceIterator2 = new LabelAwareListSentenceIterator(is2, ";", 0, 1);
    int count = 0;
    Map<Integer, String> labels = new HashMap<>();
    while (labelAwareSentenceIterator2.hasNext()) {
        // nextSentence() must be called to advance the iterator, even though
        // only the label is inspected here (original bound it to an unused local).
        labelAwareSentenceIterator2.nextSentence();
        labels.put(count, labelAwareSentenceIterator2.currentLabel());
        count++;
    }
    assertEquals("SENT37", labels.get(0));
    assertEquals("SENT38", labels.get(1));
    assertEquals("SENT39", labels.get(2));
    assertEquals("SENT42", labels.get(3));
    assertEquals(4, count);
}
Also used : ByteArrayInputStream(java.io.ByteArrayInputStream) HashMap(java.util.HashMap) ByteArrayInputStream(java.io.ByteArrayInputStream) InputStream(java.io.InputStream) LabelAwareSentenceIterator(org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator) LabelAwareListSentenceIterator(org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareListSentenceIterator) ClassPathResource(org.datavec.api.util.ClassPathResource) Test(org.junit.Test)

Example 3 with LabelAwareSentenceIterator

use of org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator in project deeplearning4j by deeplearning4j.

From the class BagOfWordsVectorizerTest, method testBagOfWordsVectorizer:

@Test
public void testBagOfWordsVectorizer() throws Exception {
    // Fits a bag-of-words vectorizer over a two-label classpath directory and
    // verifies vocabulary counts, transform vectors, labelled DataSets, and
    // that the vectorizer survives Java serialization round-tripping.
    File rootDir = new ClassPathResource("rootdir").getFile();
    LabelAwareSentenceIterator iter = new LabelAwareFileSentenceIterator(rootDir);
    List<String> labels = Arrays.asList("label1", "label2");
    TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
    BagOfWordsVectorizer vectorizer = new BagOfWordsVectorizer.Builder().setMinWordFrequency(1).setStopWords(new ArrayList<String>()).setTokenizerFactory(tokenizerFactory).setIterator(iter).allowParallelTokenization(false).build();
    vectorizer.fit();
    // "file." appears in both fixture documents, hence the expected counts of 2.
    VocabWord word = vectorizer.getVocabCache().wordFor("file.");
    assumeNotNull(word);
    assertEquals(word, vectorizer.getVocabCache().tokenFor("file."));
    assertEquals(2, vectorizer.getVocabCache().totalNumberOfDocs());
    assertEquals(2, word.getSequencesCount());
    assertEquals(2, word.getElementFrequency(), 0.1);
    // "1" appears in only one of the two documents.
    VocabWord word1 = vectorizer.getVocabCache().wordFor("1");
    assertEquals(1, word1.getSequencesCount());
    assertEquals(1, word1.getElementFrequency(), 0.1);
    log.info("Labels used: " + vectorizer.getLabelsSource().getLabels());
    assertEquals(2, vectorizer.getLabelsSource().getNumberOfLabelsUsed());
    ///////////////////
    // Raw count vector for an unseen sentence; counts of 2 presumably come from
    // the transform doubling each occurrence — TODO confirm against the fixture.
    INDArray array = vectorizer.transform("This is 2 file.");
    log.info("Transformed array: " + array);
    assertEquals(5, array.columns());
    VocabCache<VocabWord> vocabCache = vectorizer.getVocabCache();
    assertEquals(2, array.getDouble(vocabCache.tokenFor("This").getIndex()), 0.1);
    assertEquals(2, array.getDouble(vocabCache.tokenFor("is").getIndex()), 0.1);
    assertEquals(2, array.getDouble(vocabCache.tokenFor("file.").getIndex()), 0.1);
    assertEquals(0, array.getDouble(vocabCache.tokenFor("1").getIndex()), 0.1);
    assertEquals(1, array.getDouble(vocabCache.tokenFor("2").getIndex()), 0.1);
    // Vectorizing with an explicit label should yield the same features plus labels.
    DataSet dataSet = vectorizer.vectorize("This is 2 file.", "label2");
    assertEquals(array, dataSet.getFeatureMatrix());
    INDArray labelz = dataSet.getLabels();
    log.info("Labels array: " + labelz);
    // Index of the active (argmax) label for "label2".
    int idx2 = ((IndexAccumulation) Nd4j.getExecutioner().exec(new IMax(labelz))).getFinalResult();
    //        assertEquals(1.0, dataSet.getLabels().getDouble(0), 0.1);
    //        assertEquals(0.0, dataSet.getLabels().getDouble(1), 0.1);
    dataSet = vectorizer.vectorize("This is 1 file.", "label1");
    assertEquals(2, dataSet.getFeatureMatrix().getDouble(vocabCache.tokenFor("This").getIndex()), 0.1);
    assertEquals(2, dataSet.getFeatureMatrix().getDouble(vocabCache.tokenFor("is").getIndex()), 0.1);
    assertEquals(2, dataSet.getFeatureMatrix().getDouble(vocabCache.tokenFor("file.").getIndex()), 0.1);
    assertEquals(1, dataSet.getFeatureMatrix().getDouble(vocabCache.tokenFor("1").getIndex()), 0.1);
    assertEquals(0, dataSet.getFeatureMatrix().getDouble(vocabCache.tokenFor("2").getIndex()), 0.1);
    int idx1 = ((IndexAccumulation) Nd4j.getExecutioner().exec(new IMax(dataSet.getLabels()))).getFinalResult();
    //assertEquals(0.0, dataSet.getLabels().getDouble(0), 0.1);
    //assertEquals(1.0, dataSet.getLabels().getDouble(1), 0.1);
    // Different labels must map to different one-hot positions.
    assertNotEquals(idx2, idx1);
    // Serialization check
    File tempFile = File.createTempFile("fdsf", "fdfsdf");
    tempFile.deleteOnExit();
    SerializationUtils.saveObject(vectorizer, tempFile);
    BagOfWordsVectorizer vectorizer2 = SerializationUtils.readObject(tempFile);
    // TokenizerFactory is not serialized, so it must be re-attached after loading.
    vectorizer2.setTokenizerFactory(tokenizerFactory);
    dataSet = vectorizer2.vectorize("This is 2 file.", "label2");
    assertEquals(array, dataSet.getFeatureMatrix());
}
Also used : TokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) DataSet(org.nd4j.linalg.dataset.DataSet) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) ClassPathResource(org.datavec.api.util.ClassPathResource) LabelAwareFileSentenceIterator(org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareFileSentenceIterator) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) INDArray(org.nd4j.linalg.api.ndarray.INDArray) IMax(org.nd4j.linalg.api.ops.impl.indexaccum.IMax) LabelAwareSentenceIterator(org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator) File(java.io.File) IndexAccumulation(org.nd4j.linalg.api.ops.IndexAccumulation) Test(org.junit.Test)

Example 4 with LabelAwareSentenceIterator

use of org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator in project deeplearning4j by deeplearning4j.

From the class TfidfVectorizerTest, method testTfIdfVectorizer:

@Test
public void testTfIdfVectorizer() throws Exception {
    // Fits a TF-IDF vectorizer over a three-document classpath directory and
    // checks vocab stats, TF-IDF weights, label one-hot output, and that the
    // vectorizer produces identical vectors after serialization round-trip.
    File rootDir = new ClassPathResource("tripledir").getFile();
    LabelAwareSentenceIterator iter = new LabelAwareFileSentenceIterator(rootDir);
    TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
    TfidfVectorizer vectorizer = new TfidfVectorizer.Builder().setMinWordFrequency(1).setStopWords(new ArrayList<String>()).setTokenizerFactory(tokenizerFactory).setIterator(iter).allowParallelTokenization(false).build();
    vectorizer.fit();
    // "file." appears in all three fixture documents.
    VocabWord word = vectorizer.getVocabCache().wordFor("file.");
    assumeNotNull(word);
    assertEquals(word, vectorizer.getVocabCache().tokenFor("file."));
    assertEquals(3, vectorizer.getVocabCache().totalNumberOfDocs());
    assertEquals(3, word.getSequencesCount());
    assertEquals(3, word.getElementFrequency(), 0.1);
    // "1" appears in only one document.
    VocabWord word1 = vectorizer.getVocabCache().wordFor("1");
    assertEquals(1, word1.getSequencesCount());
    assertEquals(1, word1.getElementFrequency(), 0.1);
    log.info("Labels used: " + vectorizer.getLabelsSource().getLabels());
    assertEquals(3, vectorizer.getLabelsSource().getNumberOfLabelsUsed());
    assertEquals(3, vectorizer.getVocabCache().totalNumberOfDocs());
    assertEquals(11, vectorizer.numWordsEncountered());
    INDArray vector = vectorizer.transform("This is 3 file.");
    log.info("TF-IDF vector: " + Arrays.toString(vector.data().asDouble()));
    VocabCache<VocabWord> vocabCache = vectorizer.getVocabCache();
    // Words occurring in every document ("file.") get an IDF of 0, hence weight 0;
    // the rarer "3" gets the highest weight.
    assertEquals(.04402, vector.getDouble(vocabCache.tokenFor("This").getIndex()), 0.001);
    assertEquals(.04402, vector.getDouble(vocabCache.tokenFor("is").getIndex()), 0.001);
    assertEquals(0.119, vector.getDouble(vocabCache.tokenFor("3").getIndex()), 0.001);
    assertEquals(0, vector.getDouble(vocabCache.tokenFor("file.").getIndex()), 0.001);
    DataSet dataSet = vectorizer.vectorize("This is 3 file.", "label3");
    //assertEquals(0.0, dataSet.getLabels().getDouble(0), 0.1);
    //assertEquals(0.0, dataSet.getLabels().getDouble(1), 0.1);
    //assertEquals(1.0, dataSet.getLabels().getDouble(2), 0.1);
    // Exactly one of the three label slots should be active (one-hot).
    int cnt = 0;
    for (int i = 0; i < 3; i++) {
        if (dataSet.getLabels().getDouble(i) > 0.1)
            cnt++;
    }
    assertEquals(1, cnt);
    File tempFile = File.createTempFile("somefile", "Dsdas");
    tempFile.deleteOnExit();
    SerializationUtils.saveObject(vectorizer, tempFile);
    TfidfVectorizer vectorizer2 = SerializationUtils.readObject(tempFile);
    // TokenizerFactory is not serialized; re-attach before using the loaded copy.
    vectorizer2.setTokenizerFactory(tokenizerFactory);
    // NOTE(review): label here is "label2" while the pre-serialization call used
    // "label3" — only features are compared, so this passes, but confirm intent.
    dataSet = vectorizer2.vectorize("This is 3 file.", "label2");
    assertEquals(vector, dataSet.getFeatureMatrix());
}
Also used : TokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) DataSet(org.nd4j.linalg.dataset.DataSet) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) ClassPathResource(org.datavec.api.util.ClassPathResource) LabelAwareFileSentenceIterator(org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareFileSentenceIterator) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) INDArray(org.nd4j.linalg.api.ndarray.INDArray) LabelAwareSentenceIterator(org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator) File(java.io.File) Test(org.junit.Test)

Aggregations

LabelAwareSentenceIterator (org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator)4 ClassPathResource (org.datavec.api.util.ClassPathResource)3 Test (org.junit.Test)3 File (java.io.File)2 VocabWord (org.deeplearning4j.models.word2vec.VocabWord)2 LabelAwareFileSentenceIterator (org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareFileSentenceIterator)2 DefaultTokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory)2 TokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory)2 INDArray (org.nd4j.linalg.api.ndarray.INDArray)2 DataSet (org.nd4j.linalg.dataset.DataSet)2 ByteArrayInputStream (java.io.ByteArrayInputStream)1 InputStream (java.io.InputStream)1 HashMap (java.util.HashMap)1 LabelledDocument (org.deeplearning4j.text.documentiterator.LabelledDocument)1 LabelAwareListSentenceIterator (org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareListSentenceIterator)1 IndexAccumulation (org.nd4j.linalg.api.ops.IndexAccumulation)1 IMax (org.nd4j.linalg.api.ops.impl.indexaccum.IMax)1