use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.
the class ParagraphVectors method predict.
/**
 * Predicts the most probable label for a piece of raw text.
 * <p>
 * The text is split with the configured tokenizer factory, tokens that are not
 * present in the vocabulary are dropped, and the remaining vocabulary entries are
 * passed to the list-based {@code predict} overload.
 *
 * @param rawText text to tokenize and classify
 * @return the label produced by the list-based {@code predict} overload
 * @throws IllegalStateException if no tokenizer factory has been configured
 */
@Deprecated
public String predict(String rawText) {
    if (tokenizerFactory == null) {
        throw new IllegalStateException("TokenizerFactory should be defined, prior to predict() call");
    }

    List<VocabWord> knownWords = new ArrayList<>();
    for (String candidate : tokenizerFactory.create(rawText).getTokens()) {
        // Tokens outside the vocabulary contribute nothing to the document
        if (!vocab.containsWord(candidate)) {
            continue;
        }
        knownWords.add(vocab.wordFor(candidate));
    }
    return predict(knownWords);
}
use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.
the class ParagraphVectors method extractLabels.
/**
 * Collects the vocabulary entries flagged as labels and caches their weight rows.
 * <p>
 * When at least one label is found, {@code labelsMatrix} is set to the rows pulled
 * from the lookup-table weights at the labels' indexes and {@code labelsList} is set
 * to the labels themselves. When no label is found, both fields are left untouched.
 */
public void extractLabels() {
    Collection<VocabWord> allWords = vocab.vocabWords();

    // Keep only the entries marked as labels, in vocabulary iteration order
    List<VocabWord> labelWords = new ArrayList<>();
    for (VocabWord candidate : allWords) {
        if (candidate.isLabel()) {
            labelWords.add(candidate);
        }
    }

    // Nothing to extract: leave labelsMatrix / labelsList as they were
    if (labelWords.isEmpty()) {
        return;
    }

    // Index array follows the same order as the label list
    int[] rowIndexes = new int[labelWords.size()];
    for (int pos = 0; pos < rowIndexes.length; pos++) {
        rowIndexes[pos] = labelWords.get(pos).getIndex();
    }

    // Pull the label rows and create a new matrix
    labelsMatrix = Nd4j.pullRows(lookupTable.getWeights(), 1, rowIndexes);
    labelsList = labelWords;
}
use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.
the class ParagraphVectors method similarityToLabel.
/**
 * Returns the similarity of a raw-text document to a specific label, based on mean value.
 * <p>
 * The text is split with the configured tokenizer factory, tokens that are not
 * present in the vocabulary are dropped, and the remaining vocabulary entries are
 * passed to the list-based {@code similarityToLabel} overload.
 *
 * @param rawText text to tokenize and compare
 * @param label   label to compare the document against
 * @return the similarity computed by the list-based {@code similarityToLabel} overload
 * @throws IllegalStateException if no tokenizer factory has been configured
 */
@Deprecated
public double similarityToLabel(String rawText, String label) {
    if (tokenizerFactory == null) {
        // Message names this method; it previously pointed at predict() by mistake
        throw new IllegalStateException(
                "TokenizerFactory should be defined, prior to similarityToLabel() call");
    }

    List<String> tokens = tokenizerFactory.create(rawText).getTokens();
    List<VocabWord> document = new ArrayList<>();
    for (String token : tokens) {
        // Tokens outside the vocabulary contribute nothing to the document
        if (vocab.containsWord(token)) {
            document.add(vocab.wordFor(token));
        }
    }
    return similarityToLabel(document, label);
}
use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.
the class WordVectorSerializerTest method testIndexPersistence.
/**
 * Trains a small Word2Vec model, writes its word vectors to a temporary file,
 * loads them back, and checks that the document count and the vector of every
 * vocabulary word are equal before and after the round trip.
 */
@Test
public void testIndexPersistence() throws Exception {
    File corpus = new ClassPathResource("/big/raw_sentences.txt").getFile();
    SentenceIterator sentences = UimaSentenceIterator.createWithPath(corpus.getAbsolutePath());

    // Split on white spaces in the line to get words
    TokenizerFactory tokenizer = new DefaultTokenizerFactory();
    tokenizer.setTokenPreProcessor(new CommonPreprocessor());

    Word2Vec trained = new Word2Vec.Builder()
            .minWordFrequency(5)
            .iterations(1)
            .epochs(1)
            .layerSize(100)
            .stopWords(new ArrayList<String>())
            .useAdaGrad(false)
            .negativeSample(5)
            .seed(42)
            .windowSize(5)
            .iterate(sentences)
            .tokenizerFactory(tokenizer)
            .build();
    trained.fit();
    VocabCache originalVocab = trained.getVocab();

    // Round-trip the model through a temporary text file
    File vectorsFile = File.createTempFile("temp", "w2v");
    vectorsFile.deleteOnExit();
    WordVectorSerializer.writeWordVectors(trained, vectorsFile);
    WordVectors restored = WordVectorSerializer.loadTxtVectors(vectorsFile);
    VocabCache restoredVocab = restored.vocab();

    assertEquals(originalVocab.totalNumberOfDocs(), restoredVocab.totalNumberOfDocs());
    for (VocabWord entry : trained.getVocab().vocabWords()) {
        INDArray expected = trained.getWordVectorMatrix(entry.getLabel());
        INDArray actual = restored.getWordVectorMatrix(entry.getLabel());
        assertEquals(expected, actual);
    }
}
use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.
the class WordVectorSerializerTest method testMalformedLabels1.
/**
 * Builds a vocabulary of words containing spaces, a newline, a backtick and an
 * underscore, marks the element at index 1 as a label, writes a ParagraphVectors
 * model to a temporary file and checks that, after reading it back, every word is
 * still present and the element at index 1 is still a label.
 */
@Test
public void testMalformedLabels1() throws Exception {
    List<String> words = new ArrayList<>();
    words.add("test A");
    words.add("test B");
    words.add("test\nC");
    words.add("test`D");
    words.add("test_E");
    words.add("test 5");

    // Register each word at the index matching its position in the list
    AbstractCache<VocabWord> vocabCache = new AbstractCache<>();
    for (int idx = 0; idx < words.size(); idx++) {
        String current = words.get(idx);
        vocabCache.addToken(new VocabWord(1.0, current));
        vocabCache.addWordToIndex(idx, current);
    }
    vocabCache.elementAtIndex(1).markAsLabel(true);

    InMemoryLookupTable<VocabWord> lookupTable =
            new InMemoryLookupTable<>(vocabCache, 10, false, 0.01, Nd4j.getRandom(), 0.0);
    lookupTable.resetWeights(true);
    assertNotEquals(null, lookupTable.getSyn0());
    assertNotEquals(null, lookupTable.getSyn1());
    assertNotEquals(null, lookupTable.getExpTable());
    assertEquals(null, lookupTable.getSyn1Neg());

    ParagraphVectors vec = new ParagraphVectors.Builder()
            .lookupTable(lookupTable)
            .vocabCache(vocabCache)
            .build();

    // Round-trip the model through a temporary file
    File modelFile = File.createTempFile("temp", "w2v");
    modelFile.deleteOnExit();
    WordVectorSerializer.writeParagraphVectors(vec, modelFile);
    ParagraphVectors restoredVec = WordVectorSerializer.readParagraphVectors(modelFile);

    for (String expectedWord : words) {
        assertEquals(true, restoredVec.hasWord(expectedWord));
    }
    assertTrue(restoredVec.getVocab().elementAtIndex(1).isLabel());
}
Aggregations