Usage of org.deeplearning4j.models.word2vec.VocabWord in the deeplearning4j project (deeplearning4j/deeplearning4j).
From class SequenceVectorsTest, method testInternalVocabConstruction:
/**
 * Trains a {@link SequenceVectors} model on the bundled "big/raw_sentences.txt" corpus without
 * supplying a pre-built vocabulary, then checks that "day" and "night" have a similarity above
 * 0.6 and logs the ten labels nearest to "day".
 */
@Test
public void testInternalVocabConstruction() throws Exception {
    File corpus = new ClassPathResource("big/raw_sentences.txt").getFile();

    // Default tokenizer with CommonPreprocessor applied to every token.
    TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
    tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());

    // One line of the corpus -> one tokenized sentence.
    SentenceTransformer sentenceTransformer = new SentenceTransformer.Builder()
            .iterator(new BasicLineIterator(corpus))
            .tokenizerFactory(tokenizerFactory)
            .build();

    AbstractSequenceIterator<VocabWord> sequences =
            new AbstractSequenceIterator.Builder<>(sentenceTransformer).build();

    // No vocab cache is handed to the builder; presumably fit() builds it from `sequences`,
    // keeping words seen at least 5 times — confirm against SequenceVectors.Builder.
    SequenceVectors<VocabWord> vectors = new SequenceVectors.Builder<VocabWord>(new VectorsConfiguration())
            .minWordFrequency(5)
            .iterate(sequences)
            .batchSize(250)
            .iterations(1)
            .epochs(1)
            .resetModel(false)
            .trainElementsRepresentation(true)
            .build();

    logger.info("Fitting model...");
    vectors.fit();
    logger.info("Model ready...");

    double daySimilarity = vectors.similarity("day", "night");
    logger.info("Day/night similarity: " + daySimilarity);
    assertTrue(daySimilarity > 0.6d);

    Collection<String> nearestToDay = vectors.wordsNearest("day", 10);
    logger.info("Nearest labels to 'day': " + nearestToDay);
}
Usage of org.deeplearning4j.models.word2vec.VocabWord in the deeplearning4j project (deeplearning4j/deeplearning4j).
From class PopularityWalkerTest, method testPopularityWalker3:
/**
 * Runs 50 forward-only PopularityWalker walks (MAXIMUM popularity mode, spread 3, proportional
 * spread spectrum, walk length 10, cut off on disconnected vertices). Every walk must start at
 * the vertex labelled "0" and step next to "4", "7" or "9", and each of those three must be
 * chosen at least once across all walks.
 */
@Test
public void testPopularityWalker3() throws Exception {
    GraphWalker<VocabWord> walker = new PopularityWalker.Builder<>(graph)
            .setWalkDirection(WalkDirection.FORWARD_ONLY)
            .setNoEdgeHandling(NoEdgeHandling.CUTOFF_ON_DISCONNECTED)
            .setWalkLength(10)
            .setPopularityMode(PopularityMode.MAXIMUM)
            .setPopularitySpread(3)
            .setSpreadSpectrum(SpreadSpectrum.PROPORTIONAL)
            .build();

    System.out.println("Connected [3] size: " + graph.getConnectedVertices(3).size());
    System.out.println("Connected [4] size: " + graph.getConnectedVertices(4).size());

    // The loop is single-threaded, so plain "seen at least once" flags suffice.
    boolean sawFour = false;
    boolean sawSeven = false;
    boolean sawNine = false;

    for (int walk = 0; walk < 50; walk++) {
        Sequence<VocabWord> sequence = walker.next();
        assertEquals("0", sequence.getElements().get(0).getLabel());

        String secondLabel = sequence.getElements().get(1).getLabel();
        System.out.println("Position at 1: [" + secondLabel + "]");

        boolean isFour = secondLabel.equals("4");
        boolean isSeven = secondLabel.equals("7");
        boolean isNine = secondLabel.equals("9");
        sawFour |= isFour;
        sawSeven |= isSeven;
        sawNine |= isNine;

        assertTrue(isFour || isSeven || isNine);
        walker.reset(false);
    }

    assertTrue(sawFour);
    assertTrue(sawSeven);
    assertTrue(sawNine);
}
Usage of org.deeplearning4j.models.word2vec.VocabWord in the deeplearning4j project (deeplearning4j/deeplearning4j).
From class RandomWalkerTest, method testGraphTraverseRandom2:
/**
 * Drains a forward-unique RandomWalker (requested walk length 20) over the 10-vertex test graph.
 * Every produced sequence and every element in it must be non-null, no sequence may hold more
 * than 10 elements, and exactly 10 sequences must be produced.
 */
@Test
public void testGraphTraverseRandom2() throws Exception {
    // No-edge handling is configured exactly once. CUTOFF_ON_DISCONNECTED presumably ends a walk
    // early when no usable outgoing edge remains, instead of throwing — confirm in NoEdgeHandling.
    RandomWalker<VocabWord> walker = (RandomWalker<VocabWord>) new RandomWalker.Builder<>(graph)
            .setWalkLength(20)
            .setWalkDirection(WalkDirection.FORWARD_UNIQUE)
            .setNoEdgeHandling(NoEdgeHandling.CUTOFF_ON_DISCONNECTED)
            .build();

    int cnt = 0;
    while (walker.hasNext()) {
        Sequence<VocabWord> sequence = walker.next();
        // Null-check before any dereference so a missing sequence is reported as an
        // assertion failure rather than a NullPointerException.
        assertNotEquals(null, sequence);
        // The graph has only 10 vertices, so even though 20 steps were requested the
        // sequence must not exceed 10 elements.
        assertTrue(sequence.getElements().size() <= 10);
        for (VocabWord word : sequence.getElements()) {
            assertNotEquals(null, word);
        }
        cnt++;
    }
    // Presumably one walk per vertex of the 10-vertex graph.
    assertEquals(10, cnt);
}
Usage of org.deeplearning4j.models.word2vec.VocabWord in the deeplearning4j project (deeplearning4j/deeplearning4j).
From class RandomWalkerTest, method setUp:
/**
 * Builds the shared test graphs once; later calls return immediately because {@code graph}
 * is already set.
 *
 * <p>In every graph, vertex {@code v} holds {@code new VocabWord(v, String.valueOf(v))} and
 * gets one edge of weight 1.0 to {@code v + 3}, or to vertex 0 when {@code v + 3} is past the
 * last vertex. The three graphs differ as follows:
 * <ul>
 *   <li>{@code graph}: 10 vertices, edges added with {@code addEdge(..., false)}.</li>
 *   <li>{@code graphBig}: 1000 vertices, edges added with {@code addEdge(..., false)}.</li>
 *   <li>{@code graphDirected}: 10 vertices, edges added with {@code addEdge(..., true)},
 *       presumably making them directed.</li>
 * </ul>
 */
@Before
public void setUp() throws Exception {
    // All three graphs are created together, so checking one of them is enough.
    if (graph != null) {
        return;
    }

    graph = new Graph<>(10, false, new AbstractVertexFactory<VocabWord>());
    for (int v = 0; v < 10; v++) {
        graph.getVertex(v).setValue(new VocabWord(v, String.valueOf(v)));
        int target = (v + 3 >= 10) ? 0 : v + 3;
        graph.addEdge(v, target, 1.0, false);
    }

    graphDirected = new Graph<>(10, false, new AbstractVertexFactory<VocabWord>());
    for (int v = 0; v < 10; v++) {
        graphDirected.getVertex(v).setValue(new VocabWord(v, String.valueOf(v)));
        int target = (v + 3 >= 10) ? 0 : v + 3;
        graphDirected.addEdge(v, target, 1.0, true);
    }

    graphBig = new Graph<>(1000, false, new AbstractVertexFactory<VocabWord>());
    for (int v = 0; v < 1000; v++) {
        graphBig.getVertex(v).setValue(new VocabWord(v, String.valueOf(v)));
        int target = (v + 3 >= 1000) ? 0 : v + 3;
        graphBig.addEdge(v, target, 1.0, false);
    }
}
Usage of org.deeplearning4j.models.word2vec.VocabWord in the deeplearning4j project (deeplearning4j/deeplearning4j).
From class AbstractElementFactoryTest, method testSerialize:
/**
 * Round-trips a {@link VocabWord} through {@link AbstractElementFactory} serialization. The
 * serialized form is printed, deserialized back, and must equal the original word.
 */
@Test
public void testSerialize() throws Exception {
    VocabWord word = new VocabWord(1, "word");
    AbstractElementFactory<VocabWord> factory = new AbstractElementFactory<>(VocabWord.class);

    // Serialize once, so the string that is printed is exactly the one being deserialized.
    String json = factory.serialize(word);
    System.out.println("VocabWord JSON: " + json);

    VocabWord word2 = factory.deserialize(json);
    assertEquals(word, word2);
}
Aggregations