Use of org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils in the deeplearning4j project (by deeplearning4j).
From the class Word2VecTests, method testRunWord2Vec:
@Test
public void testRunWord2Vec() throws Exception {
    // Corpus source: one sentence per line, trimmed of surrounding whitespace.
    SentenceIterator sentences = new BasicLineIterator(inputFile.getAbsolutePath());
    // Whitespace tokenizer with the common lowercasing/punctuation preprocessor.
    TokenizerFactory tokenizer = new DefaultTokenizerFactory();
    tokenizer.setTokenPreProcessor(new CommonPreprocessor());

    // Configure a SkipGram Word2Vec model with a fixed seed for reproducibility.
    Word2Vec model = new Word2Vec.Builder()
                    .minWordFrequency(1)
                    .iterations(3)
                    .batchSize(64)
                    .layerSize(100)
                    .stopWords(new ArrayList<String>())
                    .seed(42)
                    .learningRate(0.025)
                    .minLearningRate(0.001)
                    .sampling(0)
                    .elementsLearningAlgorithm(new SkipGram<VocabWord>())
                    .epochs(1)
                    .windowSize(5)
                    .allowParallelTokenization(true)
                    .modelUtils(new BasicModelUtils<VocabWord>())
                    .iterate(sentences)
                    .tokenizerFactory(tokenizer)
                    .build();

    // Stop-word list was explicitly set empty above; confirm it round-trips.
    assertEquals(new ArrayList<String>(), model.getStopWords());
    model.fit();

    // Persist the full model to a throwaway file to exercise serialization.
    File modelFile = File.createTempFile("temp", "temp");
    modelFile.deleteOnExit();
    WordVectorSerializer.writeFullModel(model, modelFile.getAbsolutePath());

    Collection<String> nearest = model.wordsNearest("day", 10);
    //log.info(Arrays.toString(nearest.toArray()));
    printWords("day", nearest, model);
    assertEquals(10, nearest.size());

    double daySimilarity = model.similarity("day", "night");
    log.info("Day/night similarity: " + daySimilarity);
    // Similarity must be meaningful but not degenerate (identical vectors).
    assertTrue(daySimilarity < 1.0);
    assertTrue(daySimilarity > 0.4);

    assertTrue(nearest.contains("week"));
    assertTrue(nearest.contains("night"));
    assertTrue(nearest.contains("year"));
    assertFalse(nearest.contains(null));

    // Second lookup must be stable and return the same neighborhood.
    nearest = model.wordsNearest("day", 10);
    //log.info(Arrays.toString(nearest.toArray()));
    printWords("day", nearest, model);
    assertTrue(nearest.contains("week"));
    assertTrue(nearest.contains("night"));
    assertTrue(nearest.contains("year"));

    new File("cache.ser").delete();

    // Batch vector lookup must agree row-by-row with single-word lookups.
    ArrayList<String> queryWords = new ArrayList<>();
    queryWords.add("day");
    queryWords.add("night");
    queryWords.add("week");
    INDArray batchVectors = model.getWordVectors(queryWords);
    assertEquals(batchVectors.getRow(0), model.getWordVectorMatrix("day"));
    assertEquals(batchVectors.getRow(1), model.getWordVectorMatrix("night"));
    assertEquals(batchVectors.getRow(2), model.getWordVectorMatrix("week"));

    WordVectorSerializer.writeWordVectors(model, pathToWriteto);
}
Use of org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils in the deeplearning4j project (by deeplearning4j).
From the class WordVectorSerializer, method fromPair:
/**
 * Builds a read-only {@link Word2Vec} instance from the given lookup table and vocabulary.
 *
 * @param pair pair of the in-memory lookup table (first) and the vocab cache (second)
 * @return a Word2Vec wired with the supplied table, vocab, and default BasicModelUtils
 */
public static Word2Vec fromPair(Pair<InMemoryLookupTable, VocabCache> pair) {
    Word2Vec vectors = new Word2Vec();
    vectors.setLookupTable(pair.getFirst());
    vectors.setVocab(pair.getSecond());
    // Typed instantiation for consistency with other call sites
    // (avoids the raw-type BasicModelUtils and its unchecked warning).
    vectors.setModelUtils(new BasicModelUtils<VocabWord>());
    return vectors;
}
Use of org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils in the deeplearning4j project (by deeplearning4j).
From the class WordVectorSerializer, method readParagraphVectorsFromText:
/**
 * Restores a previously serialized ParagraphVectors model from its plain-text form.
 *
 * Each line is expected as: {@code <type> <token> <v0> <v1> ...} where {@code <type>}
 * is "L" for a label element or "E" for an ordinary word element.
 *
 * Deprecation note: Please, consider using readParagraphVectors() method instead
 *
 * @param stream InputStream that contains previously serialized model; closed by this method
 * @return the restored ParagraphVectors model
 * @throws RuntimeException wrapping any I/O or parse failure, or an
 *         IllegalStateException if a line's type marker is neither "L" nor "E"
 */
@Deprecated
public static ParagraphVectors readParagraphVectorsFromText(@NonNull InputStream stream) {
    // try-with-resources guarantees the reader (and underlying stream) is closed
    // on every path; the original closed it only on success and swallowed failures.
    try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream, "UTF-8"))) {
        List<String> labels = new ArrayList<>();
        List<INDArray> arrays = new ArrayList<>();
        VocabCache<VocabWord> vocabCache = new AbstractCache.Builder<VocabWord>().build();
        String line;
        while ((line = reader.readLine()) != null) {
            String[] split = line.split(" ");
            // Whitespace inside tokens was escaped on write; restore it.
            split[1] = split[1].replaceAll(whitespaceReplacement, " ");
            VocabWord word = new VocabWord(1.0, split[1]);
            if (split[0].equals("L")) {
                // we have label element here
                word.setSpecial(true);
                word.markAsLabel(true);
                labels.add(word.getLabel());
            } else if (split[0].equals("E")) {
                // we have usual element, aka word here
                word.setSpecial(false);
                word.markAsLabel(false);
            } else {
                throw new IllegalStateException(
                                "Source stream doesn't looks like ParagraphVectors serialized model");
            }
            // this particular line is just for backward compatibility with InMemoryLookupCache
            word.setIndex(vocabCache.numWords());
            vocabCache.addToken(word);
            vocabCache.addWordToIndex(word.getIndex(), word.getLabel());
            // backward compatibility code
            vocabCache.putVocabWord(word.getLabel());

            // Remaining columns are the element's vector components.
            float[] vector = new float[split.length - 2];
            for (int i = 2; i < split.length; i++) {
                vector[i - 2] = Float.parseFloat(split[i]);
            }
            arrays.add(Nd4j.create(vector));
        }

        // Stack all fetched rows into the syn0 weight matrix.
        INDArray syn = Nd4j.vstack(arrays);
        InMemoryLookupTable<VocabWord> lookupTable =
                        (InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>()
                                        .vectorLength(arrays.get(0).columns())
                                        .useAdaGrad(false)
                                        .cache(vocabCache)
                                        .build();
        Nd4j.clearNans(syn);
        lookupTable.setSyn0(syn);

        LabelsSource source = new LabelsSource(labels);
        ParagraphVectors vectors = new ParagraphVectors.Builder()
                        .labelsSource(source)
                        .vocabCache(vocabCache)
                        .lookupTable(lookupTable)
                        .modelUtils(new BasicModelUtils<VocabWord>())
                        .build();
        vectors.extractLabels();
        return vectors;
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}
Aggregations