Usage of org.deeplearning4j.models.word2vec.VocabWord in the deeplearning4j project: class ParallelTransformerIteratorTest, method hasNext.
@Test
public void hasNext() throws Exception {
    // Drive the raw-sentences corpus through a multithreaded SentenceTransformer and
    // verify the parallel iterator delivers a non-null, non-empty sequence per line.
    SentenceIterator iterator = new BasicLineIterator(new ClassPathResource("/big/raw_sentences.txt").getFile());
    SentenceTransformer transformer = new SentenceTransformer.Builder().iterator(iterator)
                    .allowMultithreading(true).tokenizerFactory(factory).build();

    Iterator<Sequence<VocabWord>> sequences = transformer.iterator();

    int count = 0;
    Sequence<VocabWord> current = null;
    for (; sequences.hasNext(); count++) {
        current = sequences.next();
        assertNotEquals("Failed on [" + count + "] iteration", null, current);
        assertNotEquals("Failed on [" + count + "] iteration", 0, current.size());
    }
    // log.info("Last element: {}", current.asLabels());

    // Exactly one sequence per corpus line is expected.
    assertEquals(97162, count);
}
Usage of org.deeplearning4j.models.word2vec.VocabWord in the deeplearning4j project: class Word2VecDataSetIteratorTest, method testIterator1.
/**
 * Basically all we want from this test - being able to finish without exceptions.
 * Trains a small Word2Vec model on the raw-sentences corpus, then iterates the
 * resulting Word2VecDataSetIterator and checks every batch keeps the same feature shape.
 */
@Test
public void testIterator1() throws Exception {
    File inputFile = new ClassPathResource("/big/raw_sentences.txt").getFile();
    SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());

    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());

    Word2Vec vec = // we make sure we'll have some missing words
                    new Word2Vec.Builder().minWordFrequency(10).iterations(1).learningRate(0.025).layerSize(150)
                                    .seed(42).sampling(0).negativeSample(0).useHierarchicSoftmax(true).windowSize(5)
                                    .modelUtils(new BasicModelUtils<VocabWord>()).useAdaGrad(false).iterate(iter)
                                    .workers(8).tokenizerFactory(t)
                                    .elementsLearningAlgorithm(new CBOW<VocabWord>()).build();
    vec.fit();

    List<String> labels = new ArrayList<>();
    labels.add("positive");
    labels.add("negative");

    Word2VecDataSetIterator iterator = new Word2VecDataSetIterator(vec, getLASI(iter, labels), labels, 1);
    // Shape of the first batch is the reference all later batches must match.
    INDArray array = iterator.next().getFeatures();
    while (iterator.hasNext()) {
        DataSet ds = iterator.next();

        // Use getFeatures() consistently; getFeatureMatrix() is a deprecated alias.
        assertArrayEquals(array.shape(), ds.getFeatures().shape());
    }
}
Usage of org.deeplearning4j.models.word2vec.VocabWord in the deeplearning4j project: class VocabConstructorTest, method testCounter1.
@Test
public void testCounter1() throws Exception {
    VocabCache<VocabWord> vocabCache = new AbstractCache.Builder<VocabWord>().build();

    final List<VocabWord> words = new ArrayList<>();
    words.add(new VocabWord(1, "word"));
    words.add(new VocabWord(2, "test"));
    words.add(new VocabWord(1, "here"));

    // Single-shot source: each iterator() hands out the word list exactly once.
    Iterable<Sequence<VocabWord>> iterable = () -> new Iterator<Sequence<VocabWord>>() {
        private final AtomicBoolean exhausted = new AtomicBoolean(false);

        @Override
        public boolean hasNext() {
            return !exhausted.getAndSet(true);
        }

        @Override
        public Sequence<VocabWord> next() {
            return new Sequence<>(words);
        }

        @Override
        public void remove() {
            throw new UnsupportedOperationException();
        }
    };

    SequenceIterator<VocabWord> sequenceIterator = new AbstractSequenceIterator.Builder<>(iterable).build();

    VocabConstructor<VocabWord> constructor = new VocabConstructor.Builder<VocabWord>()
                    .addSource(sequenceIterator, 0).useAdaGrad(false).setTargetVocabCache(vocabCache).build();
    constructor.buildJointVocabulary(false, true);

    // All three words land in the cache; "test" occurred once, regardless of its preset weight.
    assertEquals(3, vocabCache.numWords());
    assertEquals(1, vocabCache.wordFrequency("test"));
}
Usage of org.deeplearning4j.models.word2vec.VocabWord in the deeplearning4j project: class VocabConstructorTest, method testCounter2.
@Test
public void testCounter2() throws Exception {
    VocabCache<VocabWord> vocabCache = new AbstractCache.Builder<VocabWord>().build();

    final List<VocabWord> words = new ArrayList<>();
    words.add(new VocabWord(1, "word"));
    words.add(new VocabWord(0, "test"));
    words.add(new VocabWord(1, "here"));

    // One-sequence source: the flag flips on first hasNext(), ending iteration after a single next().
    Iterable<Sequence<VocabWord>> iterable = () -> new Iterator<Sequence<VocabWord>>() {
        private final AtomicBoolean pending = new AtomicBoolean(true);

        @Override
        public boolean hasNext() {
            return pending.getAndSet(false);
        }

        @Override
        public Sequence<VocabWord> next() {
            return new Sequence<>(words);
        }

        @Override
        public void remove() {
            throw new UnsupportedOperationException();
        }
    };

    SequenceIterator<VocabWord> sequenceIterator = new AbstractSequenceIterator.Builder<>(iterable).build();

    VocabConstructor<VocabWord> constructor = new VocabConstructor.Builder<VocabWord>()
                    .addSource(sequenceIterator, 0).useAdaGrad(false).setTargetVocabCache(vocabCache).build();
    constructor.buildJointVocabulary(false, true);

    // "test" was created with weight 0, but one occurrence still yields frequency 1.
    assertEquals(3, vocabCache.numWords());
    assertEquals(1, vocabCache.wordFrequency("test"));
}
Usage of org.deeplearning4j.models.word2vec.VocabWord in the deeplearning4j project: class VocabConstructorTest, method testMergedVocab1.
/**
 * Here we test basic vocab transfer, done WITHOUT labels.
 * Builds a source vocabulary from the corpus, then merges it into an empty
 * target cache and checks the word counts match.
 * @throws Exception
 */
@Test
public void testMergedVocab1() throws Exception {
    AbstractCache<VocabWord> cacheSource = new AbstractCache.Builder<VocabWord>().build();
    AbstractCache<VocabWord> cacheTarget = new AbstractCache.Builder<VocabWord>().build();

    ClassPathResource resource = new ClassPathResource("big/raw_sentences.txt");
    BasicLineIterator underlyingIterator = new BasicLineIterator(resource.getFile());
    SentenceTransformer transformer =
                    new SentenceTransformer.Builder().iterator(underlyingIterator).tokenizerFactory(t).build();
    AbstractSequenceIterator<VocabWord> sequenceIterator =
                    new AbstractSequenceIterator.Builder<>(transformer).build();

    // Populate the source cache from the corpus.
    VocabConstructor<VocabWord> vocabConstructor = new VocabConstructor.Builder<VocabWord>()
                    .addSource(sequenceIterator, 1).setTargetVocabCache(cacheSource).build();
    vocabConstructor.buildJointVocabulary(false, true);

    int sourceSize = cacheSource.numWords();
    log.info("Source Vocab size: " + sourceSize);

    // Transfer the source vocabulary into the target cache and compare sizes.
    VocabConstructor<VocabWord> vocabTransfer = new VocabConstructor.Builder<VocabWord>()
                    .addSource(sequenceIterator, 1).setTargetVocabCache(cacheTarget).build();
    vocabTransfer.buildMergedVocabulary(cacheSource, false);

    assertEquals(sourceSize, cacheTarget.numWords());
}
Aggregations