Usage of `org.deeplearning4j.text.sentenceiterator.MutipleEpochsSentenceIterator` in the deeplearning4j project (by deeplearning4j).
Example taken from class `ParallelTransformerIteratorTest`, method `testSpeedComparison1`.
/**
 * Logs wall-clock timings of {@link SentenceTransformer} with multithreading disabled and then
 * enabled, first over a plain {@link SentenceIterator} and then over a {@link LabelAwareIterator}
 * built from the same corpus (the latter two runs are labelled "Prefetched" in the log).
 *
 * <p>Every run reads {@code /big/raw_sentences.txt} wrapped in a 25-epoch
 * {@code MutipleEpochsSentenceIterator} and asserts that each produced sequence is non-null and
 * non-empty. Timings are only logged; no speed relationship is asserted.
 */
@Test
public void testSpeedComparison1() throws Exception {
    final String corpus = "/big/raw_sentences.txt";
    final int epochs = 25;

    SentenceIterator iterator = new MutipleEpochsSentenceIterator(
            new BasicLineIterator(new ClassPathResource(corpus).getFile()), epochs);

    SentenceTransformer transformer = new SentenceTransformer.Builder().iterator(iterator)
            .allowMultithreading(false).tokenizerFactory(factory).build();
    drainAndLogTime("Single-threaded", transformer.iterator());

    // Rewind the source before reusing it (presumably restarts all epochs — TODO confirm).
    iterator.reset();
    transformer = new SentenceTransformer.Builder().iterator(iterator)
            .allowMultithreading(true).tokenizerFactory(factory).build();
    drainAndLogTime("Multi-threaded", transformer.iterator());

    // Same corpus and epoch count, but fed through a BasicLabelAwareIterator wrapper.
    LabelAwareIterator lai = new BasicLabelAwareIterator.Builder(new MutipleEpochsSentenceIterator(
            new BasicLineIterator(new ClassPathResource(corpus).getFile()), epochs)).build();

    transformer = new SentenceTransformer.Builder().iterator(lai)
            .allowMultithreading(false).tokenizerFactory(factory).build();
    drainAndLogTime("Prefetched Single-threaded", transformer.iterator());

    lai.reset();
    transformer = new SentenceTransformer.Builder().iterator(lai)
            .allowMultithreading(true).tokenizerFactory(factory).build();
    drainAndLogTime("Prefetched Multi-threaded", transformer.iterator());
}

/**
 * Consumes {@code iter} to exhaustion and logs how long that took as
 * {@code "<label> time: <ms> ms"}.
 *
 * <p>Asserts that every sequence is non-null and non-empty.
 *
 * @param label prefix for the log line, e.g. {@code "Single-threaded"}
 * @param iter  transformer output to drain
 */
private void drainAndLogTime(String label, Iterator<Sequence<VocabWord>> iter) {
    // Per-run counter, so assertion messages give the index within this run, not a running total.
    int cnt = 0;
    // nanoTime is monotonic; currentTimeMillis can jump if the wall clock is adjusted mid-run.
    long start = System.nanoTime();
    while (iter.hasNext()) {
        Sequence<VocabWord> sequence = iter.next();
        assertNotEquals("Failed on [" + cnt + "] iteration", null, sequence);
        assertNotEquals("Failed on [" + cnt + "] iteration", 0, sequence.size());
        cnt++;
    }
    long elapsedMs = (System.nanoTime() - start) / 1000000L;
    log.info("{} time: {} ms", label, elapsedMs);
}
Aggregations