use of org.deeplearning4j.text.documentiterator.FileLabelAwareIterator in project deeplearning4j by deeplearning4j.
the class ParagraphVectorsTest method testParagraphVectorsOverExistingWordVectorsModel.
/*
 * In this test we build a Word2Vec model and then reuse its vocabulary and weights for ParagraphVectors.
 * There's no need to run this test on CI (travis); use it manually, only for problem detection.
 * NOTE(review): the method carries only @Test (no @Ignore), so it still runs wherever this class runs.
 */
@Test
public void testParagraphVectorsOverExistingWordVectorsModel() throws Exception {
    // Build w2v from multiple sources so that the vocabulary covers everything.
    ClassPathResource resourceSentences = new ClassPathResource("/big/raw_sentences.txt");
    ClassPathResource resourceMixed = new ClassPathResource("/paravec");
    SentenceIterator iter = new AggregatingSentenceIterator.Builder()
            .addSentenceIterator(new BasicLineIterator(resourceSentences.getFile()))
            .addSentenceIterator(new FileSentenceIterator(resourceMixed.getFile()))
            .build();

    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());

    Word2Vec wordVectors = new Word2Vec.Builder()
            .minWordFrequency(1)
            .batchSize(250)
            .iterations(1)
            .epochs(3)
            .learningRate(0.025)
            .layerSize(150)
            .minLearningRate(0.001)
            .elementsLearningAlgorithm(new SkipGram<VocabWord>())
            .useHierarchicSoftmax(true)
            .windowSize(5)
            .iterate(iter)
            .tokenizerFactory(t)
            .build();
    wordVectors.fit();

    VocabWord dayA = wordVectors.getVocab().tokenFor("day");
    // Fail with a clear message instead of an NPE further down if "day" was never learned.
    assertTrue("'day' is missing from the Word2Vec vocabulary", dayA != null);
    // dup() takes a copy; presumably so that ParagraphVectors training (which reuses this
    // model's weights below) cannot mutate the reference vector.
    INDArray vectorDay1 = wordVectors.getWordVectorMatrix("day").dup();

    // At this moment the w2v model is ready; now use it for ParagraphVectors.
    FileLabelAwareIterator labelAwareIterator = new FileLabelAwareIterator.Builder()
            .addSourceFolder(new ClassPathResource("/paravec/labeled").getFile())
            .build();
    // Documents from this iterator will be used for classification.
    FileLabelAwareIterator unlabeledIterator = new FileLabelAwareIterator.Builder()
            .addSourceFolder(new ClassPathResource("/paravec/unlabeled").getFile())
            .build();

    // Build the classifier with the pre-built w2v model passed in; trainWordVectors(false)
    // is intended to keep the existing word vectors untouched.
    ParagraphVectors paragraphVectors = new ParagraphVectors.Builder()
            .iterate(labelAwareIterator)
            .learningRate(0.025)
            .minLearningRate(0.001)
            .iterations(5)
            .epochs(1)
            .layerSize(150)
            .tokenizerFactory(t)
            .sequenceLearningAlgorithm(new DBOW<VocabWord>())
            .useHierarchicSoftmax(true)
            .trainWordVectors(false)
            .useExistingWordVectors(wordVectors)
            .build();
    paragraphVectors.fit();

    // The reused vocabulary must keep the original index for "day".
    VocabWord dayB = paragraphVectors.getVocab().tokenFor("day");
    assertTrue("'day' is missing from the ParagraphVectors vocabulary", dayB != null);
    assertEquals(dayA.getIndex(), dayB.getIndex());

    // The "day" vector after ParagraphVectors training should stay close to the w2v original.
    INDArray vectorDay2 = paragraphVectors.getWordVectorMatrix("day").dup();
    double crossDay = arraysSimilarity(vectorDay1, vectorDay2);
    log.info("Day1: " + vectorDay1);
    log.info("Day2: " + vectorDay2);
    log.info("Cross-Day similarity: " + crossDay);
    log.info("Cross-Day similiarity 2: " + Transforms.cosineSim(vectorDay1, vectorDay2));
    assertTrue("Cross-Day similarity too low: " + crossDay, crossDay > 0.9d);

    log.info("Zfinance: " + paragraphVectors.getWordVectorMatrix("Zfinance"));
    log.info("Zhealth: " + paragraphVectors.getWordVectorMatrix("Zhealth"));
    log.info("Zscience: " + paragraphVectors.getWordVectorMatrix("Zscience"));

    // Guard against an empty fixture folder before pulling the first document.
    assertTrue("No documents found under /paravec/unlabeled", unlabeledIterator.hasNextDocument());
    LabelledDocument document = unlabeledIterator.nextDocument();
    log.info("Results for document '" + document.getLabel() + "'");
    for (String result : paragraphVectors.predictSeveral(document, 3)) {
        double sim = paragraphVectors.similarityToLabel(document, result);
        log.info("Similarity to [" + result + "] is [" + sim + "]");
    }

    String topPrediction = paragraphVectors.predict(document);
    assertEquals("Zfinance", topPrediction);
}
use of org.deeplearning4j.text.documentiterator.FileLabelAwareIterator in project deeplearning4j by deeplearning4j.
the class InMemoryLookupTableTest method testConsumeOnNonEqualVocabs.
@Test
public void testConsumeOnNonEqualVocabs() throws Exception {
    TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
    tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());

    // Stage 1: source vocabulary and lookup table, built from the raw sentences corpus.
    AbstractCache<VocabWord> sourceVocab = new AbstractCache.Builder<VocabWord>().build();
    ClassPathResource sentencesResource = new ClassPathResource("big/raw_sentences.txt");
    BasicLineIterator lineIterator = new BasicLineIterator(sentencesResource.getFile());
    SentenceTransformer sentenceTransformer = new SentenceTransformer.Builder()
            .iterator(lineIterator)
            .tokenizerFactory(tokenizerFactory)
            .build();
    AbstractSequenceIterator<VocabWord> sentenceSequences =
            new AbstractSequenceIterator.Builder<>(sentenceTransformer).build();
    VocabConstructor<VocabWord> sourceConstructor = new VocabConstructor.Builder<VocabWord>()
            .addSource(sentenceSequences, 1)
            .setTargetVocabCache(sourceVocab)
            .build();
    sourceConstructor.buildJointVocabulary(false, true);
    assertEquals(244, sourceVocab.numWords());

    InMemoryLookupTable<VocabWord> sourceTable = (InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>()
            .vectorLength(100)
            .cache(sourceVocab)
            .build();
    sourceTable.resetWeights(true);

    // Stage 2: target vocabulary = source vocabulary merged with labels from the labeled documents.
    AbstractCache<VocabWord> targetVocab = new AbstractCache.Builder<VocabWord>().build();
    FileLabelAwareIterator labeledDocs = new FileLabelAwareIterator.Builder()
            .addSourceFolder(new ClassPathResource("/paravec/labeled").getFile())
            .build();
    SentenceTransformer documentTransformer = new SentenceTransformer.Builder()
            .iterator(labeledDocs)
            .tokenizerFactory(tokenizerFactory)
            .build();
    AbstractSequenceIterator<VocabWord> documentSequences =
            new AbstractSequenceIterator.Builder<>(documentTransformer).build();
    VocabConstructor<VocabWord> targetConstructor = new VocabConstructor.Builder<VocabWord>()
            .addSource(documentSequences, 1)
            .setTargetVocabCache(targetVocab)
            .build();
    targetConstructor.buildMergedVocabulary(sourceVocab, true);
    // The three extra target entries are the document labels.
    assertEquals(sourceVocab.numWords() + 3, targetVocab.numWords());

    InMemoryLookupTable<VocabWord> targetTable = (InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>()
            .vectorLength(100)
            .cache(targetVocab)
            .seed(18)
            .build();
    targetTable.resetWeights(true);

    // Stage 3: "day" differs before consume(), matches afterwards, and the target table keeps its label rows.
    assertNotEquals(sourceTable.vector("day"), targetTable.vector("day"));
    targetTable.consume(sourceTable);
    assertEquals(sourceTable.vector("day"), targetTable.vector("day"));
    assertTrue(sourceTable.syn0.rows() < targetTable.syn0.rows());
    assertEquals(sourceTable.syn0.rows() + 3, targetTable.syn0.rows());
}
use of org.deeplearning4j.text.documentiterator.FileLabelAwareIterator in project deeplearning4j by deeplearning4j.
the class VocabConstructorTest method testMergedVocabWithLabels1.
@Test
public void testMergedVocabWithLabels1() throws Exception {
    AbstractCache<VocabWord> cacheSource = new AbstractCache.Builder<VocabWord>().build();
    AbstractCache<VocabWord> cacheTarget = new AbstractCache.Builder<VocabWord>().build();

    // Source vocabulary: built from the raw sentences corpus only.
    // `t` is a tokenizer factory declared outside this method (presumably a test-class field).
    ClassPathResource resource = new ClassPathResource("big/raw_sentences.txt");
    BasicLineIterator underlyingIterator = new BasicLineIterator(resource.getFile());
    SentenceTransformer transformer = new SentenceTransformer.Builder().iterator(underlyingIterator).tokenizerFactory(t).build();
    AbstractSequenceIterator<VocabWord> sequenceIterator = new AbstractSequenceIterator.Builder<>(transformer).build();
    VocabConstructor<VocabWord> vocabConstructor = new VocabConstructor.Builder<VocabWord>().addSource(sequenceIterator, 1).setTargetVocabCache(cacheSource).build();
    vocabConstructor.buildJointVocabulary(false, true);
    int sourceSize = cacheSource.numWords();
    log.info("Source Vocab size: " + sourceSize);

    // Target vocabulary: the source vocabulary merged with labels from /paravec/labeled.
    FileLabelAwareIterator labelAwareIterator = new FileLabelAwareIterator.Builder().addSourceFolder(new ClassPathResource("/paravec/labeled").getFile()).build();
    transformer = new SentenceTransformer.Builder().iterator(labelAwareIterator).tokenizerFactory(t).build();
    sequenceIterator = new AbstractSequenceIterator.Builder<>(transformer).build();
    VocabConstructor<VocabWord> vocabTransfer = new VocabConstructor.Builder<VocabWord>().addSource(sequenceIterator, 1).setTargetVocabCache(cacheTarget).build();
    vocabTransfer.buildMergedVocabulary(cacheSource, true);
    // those +3 go for 3 additional entries in target VocabCache: labels
    assertEquals(sourceSize + 3, cacheTarget.numWords());

    // Every transferred element must keep its original index (checked over the full source range,
    // not just a few sampled positions).
    for (int i = 0; i < sourceSize; i++) {
        assertEquals("word at index " + i, cacheSource.wordAtIndex(i), cacheTarget.wordAtIndex(i));
    }

    // Newly added labels must live beyond the source index space: indexes are zero-based,
    // so the first free index is exactly sourceSize.
    for (String label : new String[] {"Zfinance", "Zscience", "Zhealth"}) {
        int labelIndex = cacheTarget.indexOf(label);
        assertTrue(label + " index " + labelIndex + " overlaps the source index range", labelIndex >= sourceSize);
        assertTrue(label + " index " + labelIndex + " is beyond the target vocab size", labelIndex < cacheTarget.numWords());
    }
}
Aggregations