Usage of org.nd4j.linalg.api.ndarray.INDArray in the deeplearning4j project (by deeplearning4j).
Class ParagraphVectors, method predictSeveral.
/**
 * Predicts several labels for the given document.
 * <p>
 * The document is represented by the mean of its word vectors (averaged row-wise
 * over one row per word). Every label from the labels source is then scored by the
 * cosine similarity between that mean and the label's own vector.
 * <p>
 * NOTE(review): {@code getWordVectorMatrix} may return {@code null} for words or
 * labels that are missing from the lookup table. Such a value would fail in
 * {@code putRow}/{@code cosineSim} here. Confirm whether callers guarantee
 * in-vocabulary input.
 *
 * @param document the document as a list of vocabulary words; must not be empty
 * @param limit    maximum number of labels to return; a value larger than the number
 *                 of known labels is clamped to that number
 * @return at most {@code limit} labels, in descending order of similarity
 * @throws IllegalStateException    if {@code document} is empty
 * @throws IllegalArgumentException if {@code limit} is negative
 */
@Deprecated
public Collection<String> predictSeveral(List<VocabWord> document, int limit) {
    /*
    This code was transferred from original ParagraphVectors DL4j implementation, and yet to be tested
    */
    if (document.isEmpty())
        throw new IllegalStateException("Document has no words inside");
    if (limit < 0)
        throw new IllegalArgumentException("limit must be non-negative, got " + limit);

    // One row per word; the column-wise mean over rows is the document representation.
    INDArray arr = Nd4j.create(document.size(), this.layerSize);
    for (int i = 0; i < document.size(); i++) {
        arr.putRow(i, getWordVectorMatrix(document.get(i).getWord()));
    }
    INDArray docMean = arr.mean(0);

    Counter<String> distances = new Counter<>();
    for (String s : labelsSource.getLabels()) {
        INDArray otherVec = getWordVectorMatrix(s);
        double sim = Transforms.cosineSim(docMean, otherVec);
        // Parameterized form avoids building the message when debug logging is off.
        log.debug("Similarity inside: [{}] -> {}", s, sim);
        distances.incrementCount(s, sim);
    }

    List<String> sortedLabels = distances.getSortedKeys();
    // Clamp so that asking for more labels than exist returns all of them
    // instead of throwing IndexOutOfBoundsException.
    return sortedLabels.subList(0, Math.min(limit, sortedLabels.size()));
}
Usage of org.nd4j.linalg.api.ndarray.INDArray in the deeplearning4j project (by deeplearning4j).
Class WordVectorSerializerTest, method testIndexPersistence.
/**
 * Trains a small Word2Vec model, writes it to a temporary file with
 * {@code writeWordVectors} and reads it back with {@code loadTxtVectors}. It then
 * checks that the document count and every word vector survive the round trip.
 */
@Test
public void testIndexPersistence() throws Exception {
    File corpus = new ClassPathResource("/big/raw_sentences.txt").getFile();
    SentenceIterator sentences = UimaSentenceIterator.createWithPath(corpus.getAbsolutePath());

    // Split on white spaces in the line to get words
    TokenizerFactory tokenizer = new DefaultTokenizerFactory();
    tokenizer.setTokenPreProcessor(new CommonPreprocessor());

    Word2Vec trained = new Word2Vec.Builder()
            .minWordFrequency(5)
            .iterations(1)
            .epochs(1)
            .layerSize(100)
            .stopWords(new ArrayList<String>())
            .useAdaGrad(false)
            .negativeSample(5)
            .seed(42)
            .windowSize(5)
            .iterate(sentences)
            .tokenizerFactory(tokenizer)
            .build();
    trained.fit();
    VocabCache originalVocab = trained.getVocab();

    // Round-trip the model through a temporary file.
    File serialized = File.createTempFile("temp", "w2v");
    serialized.deleteOnExit();
    WordVectorSerializer.writeWordVectors(trained, serialized);
    WordVectors restored = WordVectorSerializer.loadTxtVectors(serialized);
    VocabCache restoredVocab = restored.vocab();

    assertEquals(originalVocab.totalNumberOfDocs(), restoredVocab.totalNumberOfDocs());
    for (VocabWord vocabWord : trained.getVocab().vocabWords()) {
        String label = vocabWord.getLabel();
        assertEquals(trained.getWordVectorMatrix(label), restored.getWordVectorMatrix(label));
    }
}
Usage of org.nd4j.linalg.api.ndarray.INDArray in the deeplearning4j project (by deeplearning4j).
Class WordVectorSerializerTest, method testUnifiedLoaderArchive2.
/**
 * Loads the same word2vec model file with {@code readWord2Vec} and with the unified
 * {@code readWord2VecModel} loader (flag {@code true}). It checks that both return the
 * same vector for "night" and that the unified loader's lookup table has syn1 set.
 */
@Test
public void testUnifiedLoaderArchive2() throws Exception {
    logger.info("Executor name: {}", Nd4j.getExecutioner().getClass().getSimpleName());

    File modelFile = new ClassPathResource("word2vec.dl4j/file.w2v").getFile();
    WordVectors viaReader = WordVectorSerializer.readWord2Vec(modelFile);
    WordVectors viaUnified = WordVectorSerializer.readWord2VecModel(modelFile, true);

    INDArray expectedVector = viaReader.getWordVectorMatrix("night");
    INDArray actualVector = viaUnified.getWordVectorMatrix("night");
    assertNotEquals(null, expectedVector);
    assertEquals(expectedVector, actualVector);

    InMemoryLookupTable unifiedTable = (InMemoryLookupTable) viaUnified.lookupTable();
    assertNotEquals(null, unifiedTable.getSyn1());
}
Usage of org.nd4j.linalg.api.ndarray.INDArray in the deeplearning4j project (by deeplearning4j).
Class WordVectorSerializerTest, method testUnifiedLoaderText.
/**
 * Loads {@code textFile} with {@code loadTxtVectors} and with the unified
 * {@code readWord2VecModel} loader. It checks that both produce the same vector for
 * "Morgan_Freeman".
 *
 * @throws Exception if the vectors cannot be loaded
 */
@Test
public void testUnifiedLoaderText() throws Exception {
    logger.info("Executor name: {}", Nd4j.getExecutioner().getClass().getSimpleName());

    WordVectors txtLoaded = WordVectorSerializer.loadTxtVectors(textFile);
    WordVectors unifiedLoaded = WordVectorSerializer.readWord2VecModel(textFile, true);

    INDArray fromTxt = txtLoaded.getWordVectorMatrix("Morgan_Freeman");
    INDArray fromUnified = unifiedLoaded.getWordVectorMatrix("Morgan_Freeman");
    assertNotEquals(null, fromTxt);
    assertEquals(fromTxt, fromUnified);

    // An extended model was requested, but the file carries no syn1/huffman data,
    // so the loader is expected to quietly fall back to the simplified model.
    InMemoryLookupTable unifiedTable = (InMemoryLookupTable) unifiedLoaded.lookupTable();
    assertEquals(null, unifiedTable.getSyn1());
}
Usage of org.nd4j.linalg.api.ndarray.INDArray in the deeplearning4j project (by deeplearning4j).
Class WordVectorSerializerTest, method testStaticLoaderText.
/**
 * Loads {@code textFile} with {@code loadTxtVectors} and as a static model via
 * {@code loadStaticModel}. It checks that both return the same vector for
 * "Morgan_Freeman".
 *
 * @throws Exception if the vectors cannot be loaded
 */
@Test
public void testStaticLoaderText() throws Exception {
    logger.info("Executor name: {}", Nd4j.getExecutioner().getClass().getSimpleName());

    WordVectors txtModel = WordVectorSerializer.loadTxtVectors(textFile);
    WordVectors staticModel = WordVectorSerializer.loadStaticModel(textFile);

    INDArray txtVector = txtModel.getWordVectorMatrix("Morgan_Freeman");
    INDArray staticVector = staticModel.getWordVectorMatrix("Morgan_Freeman");

    assertNotEquals(null, txtVector);
    assertEquals(txtVector, staticVector);
}
Aggregations