Use of org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer in project deeplearning4j by deeplearning4j.
From the class BaseTextVectorizer, method buildVocab:
/**
 * Builds the vocabulary from the sentences supplied by this vectorizer's iterator.
 *
 * <p>A new {@code AbstractCache} is created when no vocab cache has been set yet. The sentence
 * iterator held in {@code this.iterator} is wrapped in a {@code SentenceTransformer} that uses the
 * configured tokenizer factory, and the resulting sequence iterator is handed to a
 * {@code VocabConstructor} that writes into {@code vocabCache}, honouring the configured minimum
 * word frequency, stop words and parallel-tokenization flag.
 */
public void buildVocab() {
    if (vocabCache == null) {
        vocabCache = new AbstractCache.Builder<VocabWord>().build();
    }

    // Note: this.iterator is the field (sentence source); the sequence iterator below is a local.
    SentenceTransformer sentenceTransformer = new SentenceTransformer.Builder()
            .iterator(this.iterator)
            .tokenizerFactory(tokenizerFactory)
            .build();
    AbstractSequenceIterator<VocabWord> sequenceIterator =
            new AbstractSequenceIterator.Builder<>(sentenceTransformer).build();

    VocabConstructor<VocabWord> vocabConstructor = new VocabConstructor.Builder<VocabWord>()
            .addSource(sequenceIterator, minWordFrequency)
            .setTargetVocabCache(vocabCache)
            .setStopWords(stopWords)
            .allowParallelTokenization(isParallel)
            .build();

    vocabConstructor.buildJointVocabulary(false, true);
}
Use of org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer in project deeplearning4j by deeplearning4j.
From the class Word2Vec, method setTokenizerFactory:
/**
 * Sets the TokenizerFactory instance to be used during model building.
 *
 * <p>If a sentence iterator is already present, the sequence iterator is rebuilt right away so
 * that it tokenizes with the newly supplied factory.
 *
 * @param tokenizerFactory TokenizerFactory instance; must not be null
 */
public void setTokenizerFactory(@NonNull TokenizerFactory tokenizerFactory) {
    this.tokenizerFactory = tokenizerFactory;

    // Nothing to rebuild until a sentence source has been provided.
    if (sentenceIter == null) {
        return;
    }

    SentenceTransformer sentenceTransformer = new SentenceTransformer.Builder()
            .iterator(sentenceIter)
            .tokenizerFactory(this.tokenizerFactory)
            .build();
    this.iterator = new AbstractSequenceIterator.Builder<>(sentenceTransformer).build();
}
Use of org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer in project deeplearning4j by deeplearning4j.
From the class InMemoryLookupTableTest, method testConsumeOnNonEqualVocabs:
/**
 * Checks that {@code consume} copies vectors from a lookup table built on one vocabulary into a
 * lookup table built on a larger, merged vocabulary (source words plus three label entries).
 */
@Test
public void testConsumeOnNonEqualVocabs() throws Exception {
    TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
    tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());

    // --- Source vocabulary, built from the raw sentences corpus ---
    AbstractCache<VocabWord> cacheSource = new AbstractCache.Builder<VocabWord>().build();
    ClassPathResource sentencesResource = new ClassPathResource("big/raw_sentences.txt");
    BasicLineIterator lineIterator = new BasicLineIterator(sentencesResource.getFile());
    SentenceTransformer sourceTransformer = new SentenceTransformer.Builder()
            .iterator(lineIterator)
            .tokenizerFactory(tokenizerFactory)
            .build();
    AbstractSequenceIterator<VocabWord> sourceSequences =
            new AbstractSequenceIterator.Builder<>(sourceTransformer).build();
    VocabConstructor<VocabWord> sourceConstructor = new VocabConstructor.Builder<VocabWord>()
            .addSource(sourceSequences, 1)
            .setTargetVocabCache(cacheSource)
            .build();
    sourceConstructor.buildJointVocabulary(false, true);
    assertEquals(244, cacheSource.numWords());

    InMemoryLookupTable<VocabWord> sourceTable =
            (InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>()
                    .vectorLength(100)
                    .cache(cacheSource)
                    .build();
    sourceTable.resetWeights(true);

    // --- Target vocabulary, built from labeled documents and merged with the source cache ---
    AbstractCache<VocabWord> cacheTarget = new AbstractCache.Builder<VocabWord>().build();
    FileLabelAwareIterator labeledDocuments = new FileLabelAwareIterator.Builder()
            .addSourceFolder(new ClassPathResource("/paravec/labeled").getFile())
            .build();
    SentenceTransformer labeledTransformer = new SentenceTransformer.Builder()
            .iterator(labeledDocuments)
            .tokenizerFactory(tokenizerFactory)
            .build();
    AbstractSequenceIterator<VocabWord> labeledSequences =
            new AbstractSequenceIterator.Builder<>(labeledTransformer).build();
    VocabConstructor<VocabWord> mergingConstructor = new VocabConstructor.Builder<VocabWord>()
            .addSource(labeledSequences, 1)
            .setTargetVocabCache(cacheTarget)
            .build();
    mergingConstructor.buildMergedVocabulary(cacheSource, true);

    // The target VocabCache holds 3 additional entries: the labels.
    assertEquals(cacheSource.numWords() + 3, cacheTarget.numWords());

    InMemoryLookupTable<VocabWord> targetTable =
            (InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>()
                    .vectorLength(100)
                    .cache(cacheTarget)
                    .seed(18)
                    .build();
    targetTable.resetWeights(true);

    // Before consuming, the vectors for the same word are expected to differ between the tables.
    assertNotEquals(sourceTable.vector("day"), targetTable.vector("day"));

    targetTable.consume(sourceTable);

    // After consuming, the shared word carries the source vector; the target keeps its extra rows.
    assertEquals(sourceTable.vector("day"), targetTable.vector("day"));
    assertTrue(sourceTable.syn0.rows() < targetTable.syn0.rows());
    assertEquals(sourceTable.syn0.rows() + 3, targetTable.syn0.rows());
}
Use of org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer in project deeplearning4j by deeplearning4j.
From the class InMemoryLookupTableTest, method testConsumeOnEqualVocabs:
/**
 * Checks that {@code consume} copies vectors between two lookup tables that share the same
 * vocabulary cache but were initialised with different seeds.
 */
@Test
public void testConsumeOnEqualVocabs() throws Exception {
    TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
    tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());

    // Build the shared vocabulary from the raw sentences corpus.
    AbstractCache<VocabWord> cacheSource = new AbstractCache.Builder<VocabWord>().build();
    ClassPathResource sentencesResource = new ClassPathResource("big/raw_sentences.txt");
    BasicLineIterator lineIterator = new BasicLineIterator(sentencesResource.getFile());
    SentenceTransformer sentenceTransformer = new SentenceTransformer.Builder()
            .iterator(lineIterator)
            .tokenizerFactory(tokenizerFactory)
            .build();
    AbstractSequenceIterator<VocabWord> sequences =
            new AbstractSequenceIterator.Builder<>(sentenceTransformer).build();
    VocabConstructor<VocabWord> constructor = new VocabConstructor.Builder<VocabWord>()
            .addSource(sequences, 1)
            .setTargetVocabCache(cacheSource)
            .build();
    constructor.buildJointVocabulary(false, true);
    assertEquals(244, cacheSource.numWords());

    // Two tables over the same cache, differing only in their seed.
    InMemoryLookupTable<VocabWord> donorTable =
            (InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>()
                    .vectorLength(100)
                    .cache(cacheSource)
                    .seed(17)
                    .build();
    donorTable.resetWeights(true);

    InMemoryLookupTable<VocabWord> receiverTable =
            (InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>()
                    .vectorLength(100)
                    .cache(cacheSource)
                    .seed(15)
                    .build();
    receiverTable.resetWeights(true);

    assertNotEquals(donorTable.vector("day"), receiverTable.vector("day"));

    receiverTable.consume(donorTable);

    assertEquals(donorTable.vector("day"), receiverTable.vector("day"));
}
Use of org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer in project deeplearning4j by deeplearning4j.
From the class AbstractCoOccurrencesTest, method testFit1:
/**
 * Fits co-occurrences over a single-line corpus and checks that iterating the result yields
 * exactly 16 word pairs.
 */
@Test
public void testFit1() throws Exception {
    File corpusFile = new ClassPathResource("other/oneline.txt").getFile();

    AbstractCache<VocabWord> vocabCache = new AbstractCache.Builder<VocabWord>().build();
    BasicLineIterator lineIterator = new BasicLineIterator(corpusFile);

    TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
    tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());

    SentenceTransformer sentenceTransformer = new SentenceTransformer.Builder()
            .iterator(lineIterator)
            .tokenizerFactory(tokenizerFactory)
            .build();
    AbstractSequenceIterator<VocabWord> sequenceIterator =
            new AbstractSequenceIterator.Builder<>(sentenceTransformer).build();

    VocabConstructor<VocabWord> vocabConstructor = new VocabConstructor.Builder<VocabWord>()
            .addSource(sequenceIterator, 1)
            .setTargetVocabCache(vocabCache)
            .build();
    vocabConstructor.buildJointVocabulary(false, true);

    // The same sequence iterator instance is reused as the co-occurrence source.
    AbstractCoOccurrences<VocabWord> coOccurrences = new AbstractCoOccurrences.Builder<VocabWord>()
            .iterate(sequenceIterator)
            .vocabCache(vocabCache)
            .symmetric(false)
            .windowSize(15)
            .build();
    coOccurrences.fit();

    Iterator<Pair<Pair<VocabWord, VocabWord>, Double>> pairs = coOccurrences.iterator();
    assertNotEquals(null, pairs);

    // Collect the word pairs while counting iterations independently of the list size.
    List<Pair<VocabWord, VocabWord>> list = new ArrayList<>();
    int visited = 0;
    while (pairs.hasNext()) {
        list.add(pairs.next().getFirst());
        visited++;
    }

    log.info("CoOccurrences: " + list);
    assertEquals(16, list.size());
    assertEquals(16, visited);
}
Aggregations