Example usage of org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer in the deeplearning4j project:
from the class Word2Vec, method setSentenceIterator.
/**
 * Defines the {@link SentenceIterator} instance that will be used as the training corpus source.
 * <p>
 * A {@link TokenizerFactory} must already have been configured via {@code setTokenizerFactory()};
 * otherwise the call is a no-op and an error is logged (behavior intentionally preserved —
 * consider throwing IllegalStateException in a future API revision).
 *
 * @param iterator SentenceIterator instance providing raw sentences; must not be null
 */
public void setSentenceIterator(@NonNull SentenceIterator iterator) {
    if (tokenizerFactory != null) {
        // Parallel tokenization is enabled by default when no configuration is present.
        boolean allowParallel = configuration == null || configuration.isAllowParallelTokenization();
        SentenceTransformer transformer = new SentenceTransformer.Builder()
                .iterator(iterator)
                .tokenizerFactory(tokenizerFactory)
                .allowMultithreading(allowParallel)
                .build();
        this.iterator = new AbstractSequenceIterator.Builder<>(transformer).build();
    } else {
        // Fixed message: previously referenced a non-existent "setSentenceIter()" method.
        log.error("Please call setTokenizerFactory() prior to setSentenceIterator() call.");
    }
}
Example usage of org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer in the deeplearning4j project:
from the class ParallelTransformerIteratorTest, method testSpeedComparison1.
// Benchmark-style test comparing four tokenization configurations over the same corpus:
// (1) SentenceIterator + single-threaded, (2) SentenceIterator + multi-threaded,
// (3) LabelAwareIterator + single-threaded, (4) LabelAwareIterator + multi-threaded.
// Each phase asserts every produced sequence is non-null and non-empty, and logs elapsed time.
@Test
public void testSpeedComparison1() throws Exception {
// Corpus is replayed 25 times via MutipleEpochsSentenceIterator (sic — project class name)
// to make the timing differences measurable.
SentenceIterator iterator = new MutipleEpochsSentenceIterator(new BasicLineIterator(new ClassPathResource("/big/raw_sentences.txt").getFile()), 25);
// Phase 1: single-threaded tokenization over a plain SentenceIterator.
SentenceTransformer transformer = new SentenceTransformer.Builder().iterator(iterator).allowMultithreading(false).tokenizerFactory(factory).build();
Iterator<Sequence<VocabWord>> iter = transformer.iterator();
// NOTE(review): cnt accumulates across all four phases and is never reset, so the
// "[cnt] iteration" in later failure messages is a global count, not a per-phase index.
int cnt = 0;
long time1 = System.currentTimeMillis();
while (iter.hasNext()) {
Sequence<VocabWord> sequence = iter.next();
assertNotEquals("Failed on [" + cnt + "] iteration", null, sequence);
assertNotEquals("Failed on [" + cnt + "] iteration", 0, sequence.size());
cnt++;
}
long time2 = System.currentTimeMillis();
log.info("Single-threaded time: {} ms", time2 - time1);
// Phase 2: same source iterator, rewound, now with multithreaded tokenization.
iterator.reset();
transformer = new SentenceTransformer.Builder().iterator(iterator).allowMultithreading(true).tokenizerFactory(factory).build();
iter = transformer.iterator();
time1 = System.currentTimeMillis();
while (iter.hasNext()) {
Sequence<VocabWord> sequence = iter.next();
assertNotEquals("Failed on [" + cnt + "] iteration", null, sequence);
assertNotEquals("Failed on [" + cnt + "] iteration", 0, sequence.size());
cnt++;
}
time2 = System.currentTimeMillis();
log.info("Multi-threaded time: {} ms", time2 - time1);
SentenceIterator baseIterator = iterator;
baseIterator.reset();
// Phase 3: wrap a fresh multi-epoch iterator in a LabelAwareIterator ("prefetched" path),
// single-threaded tokenization.
LabelAwareIterator lai = new BasicLabelAwareIterator.Builder(new MutipleEpochsSentenceIterator(new BasicLineIterator(new ClassPathResource("/big/raw_sentences.txt").getFile()), 25)).build();
transformer = new SentenceTransformer.Builder().iterator(lai).allowMultithreading(false).tokenizerFactory(factory).build();
iter = transformer.iterator();
time1 = System.currentTimeMillis();
while (iter.hasNext()) {
Sequence<VocabWord> sequence = iter.next();
assertNotEquals("Failed on [" + cnt + "] iteration", null, sequence);
assertNotEquals("Failed on [" + cnt + "] iteration", 0, sequence.size());
cnt++;
}
time2 = System.currentTimeMillis();
log.info("Prefetched Single-threaded time: {} ms", time2 - time1);
// Phase 4: same LabelAwareIterator, rewound, with multithreaded tokenization.
lai.reset();
transformer = new SentenceTransformer.Builder().iterator(lai).allowMultithreading(true).tokenizerFactory(factory).build();
iter = transformer.iterator();
time1 = System.currentTimeMillis();
while (iter.hasNext()) {
Sequence<VocabWord> sequence = iter.next();
assertNotEquals("Failed on [" + cnt + "] iteration", null, sequence);
assertNotEquals("Failed on [" + cnt + "] iteration", 0, sequence.size());
cnt++;
}
time2 = System.currentTimeMillis();
log.info("Prefetched Multi-threaded time: {} ms", time2 - time1);
}
Example usage of org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer in the deeplearning4j project:
from the class VocabConstructorTest, method testBuildJointVocabulary2.
// Builds a vocabulary over the raw_sentences corpus with counting enabled
// (buildJointVocabulary(false, true)) and a min word frequency of 5, then
// verifies the exact vocabulary size, the two most frequent words, and the
// total word-occurrence count.
@Test
public void testBuildJointVocabulary2() throws Exception {
    File corpusFile = new ClassPathResource("big/raw_sentences.txt").getFile();
    SentenceIterator sentences = new BasicLineIterator(corpusFile);
    VocabCache<VocabWord> vocab = new AbstractCache.Builder<VocabWord>().build();

    // Tokenize each line into a Sequence<VocabWord> using the shared tokenizer factory `t`.
    SentenceTransformer sentenceTransformer = new SentenceTransformer.Builder()
            .iterator(sentences)
            .tokenizerFactory(t)
            .build();
    AbstractSequenceIterator<VocabWord> sequences =
            new AbstractSequenceIterator.Builder<>(sentenceTransformer).build();

    VocabConstructor<VocabWord> vocabBuilder = new VocabConstructor.Builder<VocabWord>()
            .addSource(sequences, 5)
            .useAdaGrad(false)
            .setTargetVocabCache(vocab)
            .build();
    vocabBuilder.buildJointVocabulary(false, true);

    // assertFalse(cache.hasToken("including"));
    assertEquals(242, vocab.numWords());
    assertEquals("i", vocab.wordAtIndex(1));
    assertEquals("it", vocab.wordAtIndex(0));
    assertEquals(634303, vocab.totalWordOccurrences());
}
Example usage of org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer in the deeplearning4j project:
from the class VocabConstructorTest, method testBuildJointVocabulary1.
// Builds a vocabulary over the raw_sentences corpus in "tokens only" mode
// (buildJointVocabulary(true, false)) with no frequency threshold, then
// checks the vocabulary size and that no occurrences were counted.
@Test
public void testBuildJointVocabulary1() throws Exception {
    File corpusFile = new ClassPathResource("big/raw_sentences.txt").getFile();
    SentenceIterator sentences = new BasicLineIterator(corpusFile);
    VocabCache<VocabWord> vocab = new AbstractCache.Builder<VocabWord>().build();

    SentenceTransformer sentenceTransformer = new SentenceTransformer.Builder()
            .iterator(sentences)
            .tokenizerFactory(t)
            .build();

    // Wrap the transformer so the VocabConstructor can consume it as a sequence source.
    AbstractSequenceIterator<VocabWord> sequences =
            new AbstractSequenceIterator.Builder<>(sentenceTransformer).build();

    VocabConstructor<VocabWord> vocabBuilder = new VocabConstructor.Builder<VocabWord>()
            .addSource(sequences, 0)
            .useAdaGrad(false)
            .setTargetVocabCache(vocab)
            .build();
    vocabBuilder.buildJointVocabulary(true, false);

    assertEquals(244, vocab.numWords());
    assertEquals(0, vocab.totalWordOccurrences());
}
Example usage of org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer in the deeplearning4j project:
from the class SequenceVectorsTest, method testInternalVocabConstruction.
// Trains a SequenceVectors model directly from a SentenceTransformer-backed
// sequence iterator (internal vocab construction, resetModel(false)) and
// sanity-checks the learned embeddings via the day/night similarity.
@Test
public void testInternalVocabConstruction() throws Exception {
    File corpusFile = new ClassPathResource("big/raw_sentences.txt").getFile();
    BasicLineIterator lineIterator = new BasicLineIterator(corpusFile);

    TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
    tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());

    SentenceTransformer sentenceTransformer = new SentenceTransformer.Builder()
            .iterator(lineIterator)
            .tokenizerFactory(tokenizerFactory)
            .build();
    AbstractSequenceIterator<VocabWord> sequences =
            new AbstractSequenceIterator.Builder<>(sentenceTransformer).build();

    SequenceVectors<VocabWord> model = new SequenceVectors.Builder<VocabWord>(new VectorsConfiguration())
            .minWordFrequency(5)
            .iterate(sequences)
            .batchSize(250)
            .iterations(1)
            .epochs(1)
            .resetModel(false)
            .trainElementsRepresentation(true)
            .build();

    logger.info("Fitting model...");
    model.fit();
    logger.info("Model ready...");

    // Semantically related words should end up close in the embedding space.
    double dayNightSimilarity = model.similarity("day", "night");
    logger.info("Day/night similarity: " + dayNightSimilarity);
    assertTrue(dayNightSimilarity > 0.6d);

    Collection<String> nearest = model.wordsNearest("day", 10);
    logger.info("Nearest labels to 'day': " + nearest);
}
Aggregations