Use of org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator in the deeplearning4j project:
class VocabConstructorTest, method testBuildJointVocabulary2.
@Test
public void testBuildJointVocabulary2() throws Exception {
    // Source corpus and an empty target vocabulary cache.
    File corpus = new ClassPathResource("big/raw_sentences.txt").getFile();
    SentenceIterator lineIterator = new BasicLineIterator(corpus);
    VocabCache<VocabWord> vocab = new AbstractCache.Builder<VocabWord>().build();

    // Tokenize each line, then expose the transformer as a sequence iterator.
    SentenceTransformer sentenceTransformer = new SentenceTransformer.Builder()
            .iterator(lineIterator)
            .tokenizerFactory(t)
            .build();
    AbstractSequenceIterator<VocabWord> sequences =
            new AbstractSequenceIterator.Builder<>(sentenceTransformer).build();

    // Build the joint vocabulary, keeping words with frequency >= 5.
    VocabConstructor<VocabWord> vocabConstructor = new VocabConstructor.Builder<VocabWord>()
            .addSource(sequences, 5)
            .useAdaGrad(false)
            .setTargetVocabCache(vocab)
            .build();
    vocabConstructor.buildJointVocabulary(false, true);

    // assertFalse(vocab.hasToken("including"));
    assertEquals(242, vocab.numWords());
    assertEquals("i", vocab.wordAtIndex(1));
    assertEquals("it", vocab.wordAtIndex(0));
    assertEquals(634303, vocab.totalWordOccurrences());
}
Use of org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator in the deeplearning4j project:
class VocabConstructorTest, method testBuildJointVocabulary1.
@Test
public void testBuildJointVocabulary1() throws Exception {
    // Source corpus and an empty target vocabulary cache.
    File corpus = new ClassPathResource("big/raw_sentences.txt").getFile();
    SentenceIterator lineIterator = new BasicLineIterator(corpus);
    VocabCache<VocabWord> vocab = new AbstractCache.Builder<VocabWord>().build();

    SentenceTransformer sentenceTransformer = new SentenceTransformer.Builder()
            .iterator(lineIterator)
            .tokenizerFactory(t)
            .build();

    // The transformer is packed into an AbstractSequenceIterator.
    AbstractSequenceIterator<VocabWord> sequences =
            new AbstractSequenceIterator.Builder<>(sentenceTransformer).build();

    // Minimum frequency 0: every token makes it into the vocabulary.
    VocabConstructor<VocabWord> vocabConstructor = new VocabConstructor.Builder<VocabWord>()
            .addSource(sequences, 0)
            .useAdaGrad(false)
            .setTargetVocabCache(vocab)
            .build();
    vocabConstructor.buildJointVocabulary(true, false);

    assertEquals(244, vocab.numWords());
    assertEquals(0, vocab.totalWordOccurrences());
}
Use of org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator in the deeplearning4j project:
class SequenceVectorsTest, method testInternalVocabConstruction.
@Test
public void testInternalVocabConstruction() throws Exception {
    // Corpus: one raw sentence per line.
    File corpusFile = new ClassPathResource("big/raw_sentences.txt").getFile();
    BasicLineIterator lineIterator = new BasicLineIterator(corpusFile);

    // Default tokenizer with common token normalization.
    TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
    tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());

    SentenceTransformer sentenceTransformer = new SentenceTransformer.Builder()
            .iterator(lineIterator)
            .tokenizerFactory(tokenizerFactory)
            .build();
    AbstractSequenceIterator<VocabWord> sequences =
            new AbstractSequenceIterator.Builder<>(sentenceTransformer).build();

    // No external vocab is supplied, so the model builds it internally during fit().
    SequenceVectors<VocabWord> vectors = new SequenceVectors.Builder<VocabWord>(new VectorsConfiguration())
            .minWordFrequency(5)
            .iterate(sequences)
            .batchSize(250)
            .iterations(1)
            .epochs(1)
            .resetModel(false)
            .trainElementsRepresentation(true)
            .build();

    logger.info("Fitting model...");
    vectors.fit();
    logger.info("Model ready...");

    // Semantically related words should end up close in vector space.
    double sim = vectors.similarity("day", "night");
    logger.info("Day/night similarity: " + sim);
    assertTrue(sim > 0.6d);

    Collection<String> labels = vectors.wordsNearest("day", 10);
    logger.info("Nearest labels to 'day': " + labels);
}
Use of org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator in the deeplearning4j project:
class SequenceVectorsTest, method testDeepWalk.
@Test
@Ignore
public void testDeepWalk() throws Exception {
    Heartbeat.getInstance().disableHeartbeat();

    // Vocabulary over graph vertices and the blogger graph itself.
    AbstractCache<Blogger> bloggerVocab = new AbstractCache.Builder<Blogger>().build();
    Graph<Blogger, Double> bloggerGraph = buildGraph();

    // Popularity-biased walker: forward-unique walks of length 40,
    // restarting on disconnected vertices.
    GraphWalker<Blogger> walker = new PopularityWalker.Builder<>(bloggerGraph)
            .setNoEdgeHandling(NoEdgeHandling.RESTART_ON_DISCONNECTED)
            .setWalkLength(40)
            .setWalkDirection(WalkDirection.FORWARD_UNIQUE)
            .setRestartProbability(0.05)
            .setPopularitySpread(10)
            .setPopularityMode(PopularityMode.MAXIMUM)
            .setSpreadSpectrum(SpreadSpectrum.PROPORTIONAL)
            .build();
    /*
    GraphWalker<Blogger> walker = new RandomWalker.Builder<Blogger>(graph)
            .setNoEdgeHandling(NoEdgeHandling.RESTART_ON_DISCONNECTED)
            .setWalkLength(40)
            .setWalkDirection(WalkDirection.RANDOM)
            .setRestartProbability(0.05)
            .build();
    */

    GraphTransformer<Blogger> walkTransformer = new GraphTransformer.Builder<>(bloggerGraph)
            .setGraphWalker(walker)
            .shuffleOnReset(true)
            .setVocabCache(bloggerVocab)
            .build();

    Blogger firstBlogger = bloggerGraph.getVertex(0).getValue();
    assertEquals(119, firstBlogger.getElementFrequency(), 0.001);
    logger.info("Blogger: " + firstBlogger);

    AbstractSequenceIterator<Blogger> walkSequences =
            new AbstractSequenceIterator.Builder<>(walkTransformer).build();

    // Explicit lookup table so the seed and layer size are fixed.
    WeightLookupTable<Blogger> lookupTable = new InMemoryLookupTable.Builder<Blogger>()
            .lr(0.025)
            .vectorLength(150)
            .useAdaGrad(false)
            .cache(bloggerVocab)
            .seed(42)
            .build();
    lookupTable.resetWeights(true);

    SequenceVectors<Blogger> vectors = new SequenceVectors.Builder<Blogger>(new VectorsConfiguration())
            .lookupTable(lookupTable)
            .iterate(walkSequences)
            .vocabCache(bloggerVocab)
            .batchSize(1000)
            .iterations(1)
            .epochs(10)
            .resetModel(false)
            .trainElementsRepresentation(true)
            .trainSequencesRepresentation(false)
            .elementsLearningAlgorithm(new SkipGram<Blogger>())
            .learningRate(0.025)
            .layerSize(150)
            .sampling(0)
            .negativeSample(0)
            .windowSize(4)
            .workers(6)
            .seed(42)
            .build();

    vectors.fit();
    vectors.setModelUtils(new FlatModelUtils());

    // logger.info("12: " + Arrays.toString(vectors.getWordVector("12")));
    double sim = vectors.similarity("12", "72");
    Collection<String> list = vectors.wordsNearest("12", 20);
    logger.info("12->72: " + sim);
    printWords("12", list, vectors);
    assertTrue(sim > 0.10);
    assertFalse(Double.isNaN(sim));
}
Use of org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator in the deeplearning4j project:
class BaseTextVectorizer, method buildVocab.
public void buildVocab() {
    // Lazily create the target vocabulary cache on first use.
    if (vocabCache == null) {
        vocabCache = new AbstractCache.Builder<VocabWord>().build();
    }

    // Turn the configured sentence iterator into a sequence of tokenized VocabWords.
    SentenceTransformer sentenceTransformer = new SentenceTransformer.Builder()
            .iterator(this.iterator)
            .tokenizerFactory(tokenizerFactory)
            .build();
    AbstractSequenceIterator<VocabWord> sequenceIterator =
            new AbstractSequenceIterator.Builder<>(sentenceTransformer).build();

    // Populate vocabCache, honoring the configured minimum frequency,
    // stop words, and parallel-tokenization setting.
    VocabConstructor<VocabWord> vocabConstructor = new VocabConstructor.Builder<VocabWord>()
            .addSource(sequenceIterator, minWordFrequency)
            .setTargetVocabCache(vocabCache)
            .setStopWords(stopWords)
            .allowParallelTokenization(isParallel)
            .build();
    vocabConstructor.buildJointVocabulary(false, true);
}
Aggregations