Use of org.deeplearning4j.text.documentiterator.LabelAwareIterator in project deeplearning4j by deeplearning4j.
From the class ParagraphVectorsTest, method testParagraphVectorsReducedLabels1:
/**
 * Manual diagnostic only: this test makes no assertions, so a green run is not indicative of anything.
 * It is {@code @Ignore}d because there is no need for it on Travis; run it by hand when hunting for problems.
 *
 * <p>Fits ParagraphVectors on the documents under the {@code /labeled} classpath folder, averages the
 * vectors looked up for "I", "am" and "sad.", and logs the cosine similarity of that mean against the
 * vectors looked up for "negative", "neutral" and "positive".
 *
 * @throws Exception anything thrown while resolving the resource folder or fitting the model is propagated
 */
@Test
@Ignore
public void testParagraphVectorsReducedLabels1() throws Exception {
    File labeledDir = new ClassPathResource("/labeled").getFile();
    LabelAwareIterator documents = new FileLabelAwareIterator.Builder()
            .addSourceFolder(labeledDir)
            .build();
    TokenizerFactory tokenizer = new DefaultTokenizerFactory();

    // The text corpus is REALLY small: some kind of "results" only show up with a HIGH number of
    // epochs (around 30), but there is no reason to keep it that high for this check.
    ParagraphVectors vec = new ParagraphVectors.Builder()
            .minWordFrequency(1)
            .epochs(3)
            .layerSize(100)
            .stopWords(new ArrayList<String>())
            .windowSize(5)
            .iterate(documents)
            .tokenizerFactory(tokenizer)
            .build();
    vec.fit();
    // To inspect the trained vectors, dump them with:
    // WordVectorSerializer.writeWordVectors(vec, "vectors.txt");

    // Stack the vectors of the three sentence tokens as rows of one array.
    String[] tokens = {"I", "am", "sad."};
    INDArray words = Nd4j.create(tokens.length, vec.lookupTable().layerSize());
    for (int row = 0; row < tokens.length; row++) {
        words.putRow(row, vec.lookupTable().vector(tokens[row]));
    }

    // Average along dimension 0 when the stacked array is a matrix; otherwise use it as-is.
    INDArray mean;
    if (words.isMatrix()) {
        mean = words.mean(0);
    } else {
        mean = words;
    }

    log.info("Mean" + Arrays.toString(mean.dup().data().asDouble()));

    INDArray negative = vec.lookupTable().vector("negative");
    log.info("Array" + Arrays.toString(negative.dup().data().asDouble()));

    double simN = Transforms.cosineSim(mean, negative);
    log.info("Similarity negative: " + simN);

    INDArray neutral = vec.lookupTable().vector("neutral");
    double simP = Transforms.cosineSim(mean, neutral);
    log.info("Similarity neutral: " + simP);

    INDArray positive = vec.lookupTable().vector("positive");
    double simV = Transforms.cosineSim(mean, positive);
    log.info("Similarity positive: " + simV);
}
Use of org.deeplearning4j.text.documentiterator.LabelAwareIterator in project deeplearning4j by deeplearning4j.
From the class ParallelTransformerIteratorTest, method testSpeedComparison1:
/**
 * Logs how long it takes to drain a {@link SentenceTransformer} iterator over
 * {@code /big/raw_sentences.txt} (wrapped for 25 epochs) in four configurations:
 * multithreading off and on, fed first by the sentence iterator directly and then through a
 * {@link BasicLabelAwareIterator} built over a fresh iterator on the same file (the passes
 * logged as "Prefetched").
 *
 * <p>The timings are only logged, never asserted; the assertions check that every produced
 * sequence is non-null and non-empty.
 *
 * @throws Exception anything thrown while resolving the resource or iterating is propagated
 */
@Test
public void testSpeedComparison1() throws Exception {
    SentenceIterator iterator = new MutipleEpochsSentenceIterator(new BasicLineIterator(new ClassPathResource("/big/raw_sentences.txt").getFile()), 25);

    SentenceTransformer transformer = new SentenceTransformer.Builder().iterator(iterator).allowMultithreading(false).tokenizerFactory(factory).build();
    log.info("Single-threaded time: {} ms", drainAndTimeMillis(transformer.iterator()));

    iterator.reset();
    transformer = new SentenceTransformer.Builder().iterator(iterator).allowMultithreading(true).tokenizerFactory(factory).build();
    log.info("Multi-threaded time: {} ms", drainAndTimeMillis(transformer.iterator()));

    // Kept from the original flow; the remaining passes read from a separately constructed iterator.
    iterator.reset();

    LabelAwareIterator lai = new BasicLabelAwareIterator.Builder(new MutipleEpochsSentenceIterator(new BasicLineIterator(new ClassPathResource("/big/raw_sentences.txt").getFile()), 25)).build();
    transformer = new SentenceTransformer.Builder().iterator(lai).allowMultithreading(false).tokenizerFactory(factory).build();
    log.info("Prefetched Single-threaded time: {} ms", drainAndTimeMillis(transformer.iterator()));

    lai.reset();
    transformer = new SentenceTransformer.Builder().iterator(lai).allowMultithreading(true).tokenizerFactory(factory).build();
    log.info("Prefetched Multi-threaded time: {} ms", drainAndTimeMillis(transformer.iterator()));
}

/**
 * Drains the given iterator, asserting each sequence is non-null and non-empty, and returns the
 * wall-clock time the drain took.
 *
 * <p>The iteration index used in failure messages starts at zero for every call, so it identifies
 * the position within the pass being run rather than a running total across passes.
 *
 * @param sequences iterator to exhaust; timing starts after it has been created by the caller
 * @return elapsed milliseconds measured with {@link System#currentTimeMillis()}
 */
private static long drainAndTimeMillis(Iterator<Sequence<VocabWord>> sequences) {
    int cnt = 0;
    long start = System.currentTimeMillis();
    while (sequences.hasNext()) {
        Sequence<VocabWord> sequence = sequences.next();
        assertNotEquals("Failed on [" + cnt + "] iteration", null, sequence);
        assertNotEquals("Failed on [" + cnt + "] iteration", 0, sequence.size());
        cnt++;
    }
    return System.currentTimeMillis() - start;
}
Aggregations