Usage example of org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator from the deeplearning4j project: the nextDocument method of the SentenceIteratorConverter class.
@Override
public LabelledDocument nextDocument() {
    // Wraps the next sentence from the backing iterator in a LabelledDocument.
    // If the backend is label-aware, its labels are copied onto the document and
    // recorded in the label generator; otherwise a generated label is used.
    LabelledDocument document = new LabelledDocument();
    document.setContent(backendIterator.nextSentence());
    if (backendIterator instanceof LabelAwareSentenceIterator) {
        List<String> labels = ((LabelAwareSentenceIterator) backendIterator).currentLabels();
        if (labels != null) {
            for (String label : labels) {
                document.addLabel(label);
                generator.storeLabel(label);
            }
        } else {
            // Fall back to the legacy single-label API.
            String label = ((LabelAwareSentenceIterator) backendIterator).currentLabel();
            // FIX: previously this guard tested "labels != null", which is always
            // false in this branch, so the fallback label was silently dropped.
            if (label != null) {
                document.addLabel(label);
                generator.storeLabel(label);
            }
        }
        // NOTE(review): generator is dereferenced without a null check here, while
        // the branch below guards it — confirm generator is always non-null when
        // the backend is label-aware.
    } else if (generator != null)
        document.addLabel(generator.nextLabel());
    return document;
}
Usage example of org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator from the deeplearning4j project: the testLabelAware method of the SentenceIteratorTest class.
@Test
public void testLabelAware() throws Exception {
    // Single-sentence case: the label "1" is parsed from before the ";" delimiter.
    String s = "1; hello";
    ByteArrayInputStream bis = new ByteArrayInputStream(s.getBytes());
    LabelAwareSentenceIterator labelAwareSentenceIterator = new LabelAwareListSentenceIterator(bis, ";", 0, 1);
    assertTrue(labelAwareSentenceIterator.hasNext());
    labelAwareSentenceIterator.nextSentence();
    assertEquals("1", labelAwareSentenceIterator.currentLabel());

    // Multi-sentence case from a classpath fixture: collect each sentence's
    // label, keyed by its position in the stream.
    InputStream is2 = new ClassPathResource("labelawaresentenceiterator.txt").getInputStream();
    LabelAwareSentenceIterator labelAwareSentenceIterator2 = new LabelAwareListSentenceIterator(is2, ";", 0, 1);
    int count = 0;
    Map<Integer, String> labels = new HashMap<>();
    while (labelAwareSentenceIterator2.hasNext()) {
        // nextSentence() must still be called to advance the iterator; the
        // previously-assigned (and unused) local "sentence" has been removed.
        labelAwareSentenceIterator2.nextSentence();
        labels.put(count, labelAwareSentenceIterator2.currentLabel());
        count++;
    }
    assertEquals("SENT37", labels.get(0));
    assertEquals("SENT38", labels.get(1));
    assertEquals("SENT39", labels.get(2));
    assertEquals("SENT42", labels.get(3));
    assertEquals(4, count);
}
Usage example of org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator from the deeplearning4j project: the testBagOfWordsVectorizer method of the BagOfWordsVectorizerTest class.
@Test
public void testBagOfWordsVectorizer() throws Exception {
    // Fit a bag-of-words model over the two-document "rootdir" corpus.
    // (An unused "labels" list previously declared here has been removed.)
    File rootDir = new ClassPathResource("rootdir").getFile();
    LabelAwareSentenceIterator iter = new LabelAwareFileSentenceIterator(rootDir);
    TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
    BagOfWordsVectorizer vectorizer = new BagOfWordsVectorizer.Builder()
                    .setMinWordFrequency(1)
                    .setStopWords(new ArrayList<String>())
                    .setTokenizerFactory(tokenizerFactory)
                    .setIterator(iter)
                    .allowParallelTokenization(false)
                    .build();
    vectorizer.fit();

    // "file." appears in both documents; "1" in only one.
    VocabWord word = vectorizer.getVocabCache().wordFor("file.");
    assumeNotNull(word);
    assertEquals(word, vectorizer.getVocabCache().tokenFor("file."));
    assertEquals(2, vectorizer.getVocabCache().totalNumberOfDocs());
    assertEquals(2, word.getSequencesCount());
    assertEquals(2, word.getElementFrequency(), 0.1);
    VocabWord word1 = vectorizer.getVocabCache().wordFor("1");
    assertEquals(1, word1.getSequencesCount());
    assertEquals(1, word1.getElementFrequency(), 0.1);
    log.info("Labels used: " + vectorizer.getLabelsSource().getLabels());
    assertEquals(2, vectorizer.getLabelsSource().getNumberOfLabelsUsed());

    // Raw-count transform: each column holds the token's occurrence count.
    INDArray array = vectorizer.transform("This is 2 file.");
    log.info("Transformed array: " + array);
    assertEquals(5, array.columns());
    VocabCache<VocabWord> vocabCache = vectorizer.getVocabCache();
    assertEquals(2, array.getDouble(vocabCache.tokenFor("This").getIndex()), 0.1);
    assertEquals(2, array.getDouble(vocabCache.tokenFor("is").getIndex()), 0.1);
    assertEquals(2, array.getDouble(vocabCache.tokenFor("file.").getIndex()), 0.1);
    assertEquals(0, array.getDouble(vocabCache.tokenFor("1").getIndex()), 0.1);
    assertEquals(1, array.getDouble(vocabCache.tokenFor("2").getIndex()), 0.1);

    // Vectorizing with a label must reproduce the same feature row, and the two
    // label names must map to distinct one-hot positions (label order is not
    // guaranteed, so only distinctness of the argmax indices is asserted).
    DataSet dataSet = vectorizer.vectorize("This is 2 file.", "label2");
    assertEquals(array, dataSet.getFeatureMatrix());
    INDArray labelz = dataSet.getLabels();
    log.info("Labels array: " + labelz);
    int idx2 = ((IndexAccumulation) Nd4j.getExecutioner().exec(new IMax(labelz))).getFinalResult();
    dataSet = vectorizer.vectorize("This is 1 file.", "label1");
    assertEquals(2, dataSet.getFeatureMatrix().getDouble(vocabCache.tokenFor("This").getIndex()), 0.1);
    assertEquals(2, dataSet.getFeatureMatrix().getDouble(vocabCache.tokenFor("is").getIndex()), 0.1);
    assertEquals(2, dataSet.getFeatureMatrix().getDouble(vocabCache.tokenFor("file.").getIndex()), 0.1);
    assertEquals(1, dataSet.getFeatureMatrix().getDouble(vocabCache.tokenFor("1").getIndex()), 0.1);
    assertEquals(0, dataSet.getFeatureMatrix().getDouble(vocabCache.tokenFor("2").getIndex()), 0.1);
    int idx1 = ((IndexAccumulation) Nd4j.getExecutioner().exec(new IMax(dataSet.getLabels()))).getFinalResult();
    assertNotEquals(idx2, idx1);

    // Serialization round-trip: the restored vectorizer (with a freshly injected
    // tokenizer factory, which is transient) must produce identical features.
    File tempFile = File.createTempFile("fdsf", "fdfsdf");
    tempFile.deleteOnExit();
    SerializationUtils.saveObject(vectorizer, tempFile);
    BagOfWordsVectorizer vectorizer2 = SerializationUtils.readObject(tempFile);
    vectorizer2.setTokenizerFactory(tokenizerFactory);
    dataSet = vectorizer2.vectorize("This is 2 file.", "label2");
    assertEquals(array, dataSet.getFeatureMatrix());
}
Usage example of org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator from the deeplearning4j project: the testTfIdfVectorizer method of the TfidfVectorizerTest class.
@Test
public void testTfIdfVectorizer() throws Exception {
    // Build a TF-IDF model over the three-document "tripledir" corpus.
    File corpusDir = new ClassPathResource("tripledir").getFile();
    LabelAwareSentenceIterator sentenceIterator = new LabelAwareFileSentenceIterator(corpusDir);
    TokenizerFactory tokenizer = new DefaultTokenizerFactory();
    TfidfVectorizer vectorizer = new TfidfVectorizer.Builder()
                    .setMinWordFrequency(1)
                    .setStopWords(new ArrayList<String>())
                    .setTokenizerFactory(tokenizer)
                    .setIterator(sentenceIterator)
                    .allowParallelTokenization(false)
                    .build();
    vectorizer.fit();

    // "file." is present in every document; "1" in exactly one.
    VocabWord fileToken = vectorizer.getVocabCache().wordFor("file.");
    assumeNotNull(fileToken);
    assertEquals(fileToken, vectorizer.getVocabCache().tokenFor("file."));
    assertEquals(3, vectorizer.getVocabCache().totalNumberOfDocs());
    assertEquals(3, fileToken.getSequencesCount());
    assertEquals(3, fileToken.getElementFrequency(), 0.1);
    VocabWord oneToken = vectorizer.getVocabCache().wordFor("1");
    assertEquals(1, oneToken.getSequencesCount());
    assertEquals(1, oneToken.getElementFrequency(), 0.1);

    log.info("Labels used: " + vectorizer.getLabelsSource().getLabels());
    assertEquals(3, vectorizer.getLabelsSource().getNumberOfLabelsUsed());
    assertEquals(3, vectorizer.getVocabCache().totalNumberOfDocs());
    assertEquals(11, vectorizer.numWordsEncountered());

    // Spot-check individual TF-IDF weights; "file." occurs in all documents,
    // so its weight comes out as zero.
    INDArray tfidfVector = vectorizer.transform("This is 3 file.");
    log.info("TF-IDF vector: " + Arrays.toString(tfidfVector.data().asDouble()));
    VocabCache<VocabWord> cache = vectorizer.getVocabCache();
    assertEquals(.04402, tfidfVector.getDouble(cache.tokenFor("This").getIndex()), 0.001);
    assertEquals(.04402, tfidfVector.getDouble(cache.tokenFor("is").getIndex()), 0.001);
    assertEquals(0.119, tfidfVector.getDouble(cache.tokenFor("3").getIndex()), 0.001);
    assertEquals(0, tfidfVector.getDouble(cache.tokenFor("file.").getIndex()), 0.001);

    // Vectorize with a label: exactly one of the three label slots may be hot.
    DataSet ds = vectorizer.vectorize("This is 3 file.", "label3");
    int activeLabels = 0;
    for (int idx = 0; idx < 3; idx++) {
        if (ds.getLabels().getDouble(idx) > 0.1)
            activeLabels++;
    }
    assertEquals(1, activeLabels);

    // Serialization round-trip: the restored vectorizer (tokenizer factory is
    // transient, so it is re-injected) must yield the same feature vector.
    File serialized = File.createTempFile("somefile", "Dsdas");
    serialized.deleteOnExit();
    SerializationUtils.saveObject(vectorizer, serialized);
    TfidfVectorizer restored = SerializationUtils.readObject(serialized);
    restored.setTokenizerFactory(tokenizer);
    ds = restored.vectorize("This is 3 file.", "label2");
    assertEquals(tfidfVector, ds.getFeatureMatrix());
}
Aggregations