Usage of `org.deeplearning4j.models.word2vec.VocabWord` in the deeplearning4j project (by deeplearning4j).
From class `VocabWordFactoryTest`, method `testDeserialize`:
/**
 * Round-trips a {@link VocabWord} through its JSON form and checks that the
 * element factory rebuilds an equal instance.
 */
@Test
public void testDeserialize() throws Exception {
    AbstractElementFactory<VocabWord> elementFactory = new AbstractElementFactory<>(VocabWord.class);
    VocabWord original = new VocabWord(1, "word");

    // Echo the serialized form so a failing comparison is easy to diagnose.
    System.out.println("VocabWord JSON: " + original.toJSON());

    VocabWord restored = elementFactory.deserialize(original.toJSON());
    assertEquals(original, restored);
}
Usage of `org.deeplearning4j.models.word2vec.VocabWord` in the deeplearning4j project (by deeplearning4j).
From class `InMemoryLookupCache`, method `incrementWordCount`:
/**
 * Adds {@code increment} to the recorded count of {@code word}.
 *
 * <p>Three counters are updated, in this order: the per-word frequency table,
 * the element frequency of the word's token (only if the cache already holds a
 * token for it), and the running total of word occurrences. The method is
 * {@code synchronized}, so calls on the same cache instance do not interleave.
 *
 * @param word      the word whose count is increased; must be non-null and non-empty
 * @param increment the amount to add to each counter
 * @throws IllegalArgumentException if {@code word} is {@code null} or empty
 */
@Override
public synchronized void incrementWordCount(String word, int increment) {
    boolean invalidWord = word == null || word.isEmpty();
    if (invalidWord) {
        throw new IllegalArgumentException("Word can't be empty or null");
    }

    wordFrequencies.incrementCount(word, increment);

    // Only words that already have a token in the cache get their token frequency bumped.
    if (hasToken(word)) {
        tokenFor(word).increaseElementFrequency(increment);
    }

    totalWordOccurrences.set(totalWordOccurrences.get() + increment);
}
Usage of `org.deeplearning4j.models.word2vec.VocabWord` in the deeplearning4j project (by deeplearning4j).
From class `FlatModelUtilsTest`, method `testWordsNearestBasic1`:
/**
 * Runs {@code wordsNearest} three times for the same target word with
 * {@link BasicModelUtils} and verifies that the target's vector is unchanged
 * afterwards (i.e. the nearest-neighbour lookups do not mutate the model).
 */
@Test
public void testWordsNearestBasic1() throws Exception {
    // NOTE(review): `vec` is not declared in this method; presumably a test-class field
    // (it was previously loaded from a local model file) — confirm where it is initialised.
    vec.setModelUtils(new BasicModelUtils<VocabWord>());

    final String target = "energy";
    // Snapshot (dup) so later in-place changes to the model would be detectable.
    INDArray before = vec.getWordVectorMatrix(target).dup();
    System.out.println("[-]: " + before);
    System.out.println("[+]: " + Transforms.unitVec(before));

    String[] headings = {
            "Transpose model results:",
            "Transpose model results 2:",
            "Transpose model results 3:"
    };
    for (String heading : headings) {
        Collection<String> nearest = vec.wordsNearest(target, 10);
        log.info(heading);
        printWords(target, nearest, vec);
    }

    INDArray after = vec.getWordVectorMatrix(target).dup();
    assertEquals(before, after);
}
Usage of `org.deeplearning4j.models.word2vec.VocabWord` in the deeplearning4j project (by deeplearning4j).
From class `AbstractCacheTest`, method `testLabels`:
/**
 * Verifies that {@code words()} returns exactly the labels of the tokens
 * that were added to the cache.
 */
@Test
public void testLabels() throws Exception {
    AbstractCache<VocabWord> cache = new AbstractCache.Builder<VocabWord>().build();

    String[] labels = {"word", "test", "tester"};
    // Frequencies 1.0, 2.0, 3.0 in insertion order.
    for (int i = 0; i < labels.length; i++) {
        cache.addToken(new VocabWord(i + 1.0, labels[i]));
    }

    Collection<String> words = cache.words();
    assertEquals(3, words.size());
    for (String label : labels) {
        assertTrue(words.contains(label));
    }
}
Usage of `org.deeplearning4j.models.word2vec.VocabWord` in the deeplearning4j project (by deeplearning4j).
From class `AbstractCacheTest`, method `testRemoval`:
/**
 * Verifies that removing an element reduces both the word count and the
 * total number of word occurrences by that element's contribution.
 */
@Test
public void testRemoval() throws Exception {
    AbstractCache<VocabWord> cache = new AbstractCache.Builder<VocabWord>().build();

    double frequency = 1.0;
    for (String label : new String[] {"word", "test", "tester"}) {
        cache.addToken(new VocabWord(frequency, label));
        frequency += 1.0;
    }

    // Before removal: three words, occurrences 1 + 2 + 3.
    assertEquals(3, cache.numWords());
    assertEquals(6, cache.totalWordOccurrences());

    cache.removeElement("tester");

    // "tester" contributed 3 occurrences.
    assertEquals(2, cache.numWords());
    assertEquals(3, cache.totalWordOccurrences());
}
Aggregations