Usage of org.deeplearning4j.models.word2vec.Word2Vec in project deeplearning4j by deeplearning4j.
Class WordVectorSerializer, method readWord2VecModel.
/**
 * Restores a {@link Word2Vec} model from the given file, guessing its format. The following
 * formats are attempted, in this order:
 * 1) DL4j compressed (zip) format
 * 2) Popular CSV word2vec text format
 * 3) Binary model, either compressed or not, like the well-known Google model
 *    (tried first with line breaks between vectors, then without)
 *
 * Please note: if extended data isn't available, only weights will be loaded instead.
 *
 * While parsing, ND4J periodic GC is switched off (if it was on) and the occasional-GC frequency
 * is set to 50000. The original memory-manager settings are restored before this method returns
 * or throws.
 *
 * @param file          model file; must exist and be a regular file
 * @param extendedModel if TRUE, we'll try to load HS states & Huffman tree info, if FALSE, only weights will be loaded
 * @return restored Word2Vec model
 * @throws ND4JIllegalStateException if the file doesn't exist or isn't a regular file
 * @throws RuntimeException if none of the supported formats could be read; the last loader failure is attached as the cause
 */
public static Word2Vec readWord2VecModel(@NonNull File file, boolean extendedModel) {
    InMemoryLookupTable<VocabWord> lookupTable = new InMemoryLookupTable<>();
    AbstractCache<VocabWord> vocabCache = new AbstractCache<>();
    Word2Vec vec;
    INDArray syn0 = null;
    VectorsConfiguration configuration = new VectorsConfiguration();
    if (!file.exists() || !file.isFile())
        throw new ND4JIllegalStateException("File [" + file.getAbsolutePath() + "] doesn't exist");

    int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
    boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
    if (originalPeriodic)
        Nd4j.getMemoryManager().togglePeriodicGc(false);
    Nd4j.getMemoryManager().setOccasionalGcFrequency(50000);

    try {
        // try to load zip format
        try {
            if (extendedModel) {
                log.debug("Trying full model restoration...");
                // the full loader runs with the caller's original GC settings
                if (originalPeriodic)
                    Nd4j.getMemoryManager().togglePeriodicGc(true);
                Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
                return readWord2Vec(file);
            } else {
                log.debug("Trying simplified model restoration...");
                File tmpFileSyn0 = File.createTempFile("word2vec", "syn");
                // resources are closed before the finally runs, so the temp file is no longer open on delete
                try (ZipFile zipFile = new ZipFile(file)) {
                    // we don't need full model, so we go directly to syn0 file
                    ZipEntry syn = zipFile.getEntry("syn0.txt");
                    try (InputStream stream = zipFile.getInputStream(syn)) {
                        Files.copy(stream, Paths.get(tmpFileSyn0.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING);
                    }

                    // now we're restoring configuration saved earlier
                    ZipEntry config = zipFile.getEntry("config.json");
                    if (config != null) {
                        StringBuilder json = new StringBuilder();
                        try (BufferedReader reader = new BufferedReader(new InputStreamReader(zipFile.getInputStream(config)))) {
                            String line;
                            while ((line = reader.readLine()) != null) {
                                json.append(line);
                            }
                        }
                        configuration = VectorsConfiguration.fromJson(json.toString().trim());
                    }

                    // optional vocabulary; each line is "<base64 label> <value passed to VocabWord ctor,
                    // presumably frequency> <sequences count>", separated by single spaces
                    ZipEntry ve = zipFile.getEntry("frequencies.txt");
                    if (ve != null) {
                        AtomicInteger cnt = new AtomicInteger(0);
                        try (BufferedReader reader = new BufferedReader(new InputStreamReader(zipFile.getInputStream(ve)))) {
                            String line;
                            while ((line = reader.readLine()) != null) {
                                String[] split = line.split(" ");
                                VocabWord word = new VocabWord(Double.valueOf(split[1]), decodeB64(split[0]));
                                word.setIndex(cnt.getAndIncrement());
                                word.incrementSequencesCount(Long.valueOf(split[2]));
                                vocabCache.addToken(word);
                                vocabCache.addWordToIndex(word.getIndex(), word.getLabel());
                                Nd4j.getMemoryManager().invokeGcOccasionally();
                            }
                        }
                    }

                    List<INDArray> rows = new ArrayList<>();
                    // basically read up everything, call vstack and then return model
                    try (Reader reader = new CSVReader(tmpFileSyn0)) {
                        AtomicInteger cnt = new AtomicInteger(0);
                        while (reader.hasNext()) {
                            Pair<VocabWord, float[]> pair = reader.next();
                            VocabWord word = pair.getFirst();
                            INDArray vector = Nd4j.create(pair.getSecond());
                            if (ve != null) {
                                // vocabulary already known: fill a preallocated matrix row by row
                                if (syn0 == null)
                                    syn0 = Nd4j.create(vocabCache.numWords(), vector.length());
                                syn0.getRow(cnt.getAndIncrement()).assign(vector);
                            } else {
                                // no vocabulary file: build the vocab from the CSV and stack rows later
                                rows.add(vector);
                                vocabCache.addToken(word);
                                vocabCache.addWordToIndex(word.getIndex(), word.getLabel());
                            }
                            Nd4j.getMemoryManager().invokeGcOccasionally();
                        }
                    } catch (Exception e) {
                        throw new RuntimeException(e);
                    } finally {
                        if (originalPeriodic)
                            Nd4j.getMemoryManager().togglePeriodicGc(true);
                        Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
                    }

                    if (syn0 == null && vocabCache.numWords() > 0)
                        syn0 = Nd4j.vstack(rows);
                    if (syn0 == null) {
                        log.error("Can't build syn0 table");
                        throw new DL4JInvalidInputException("Can't build syn0 table");
                    }
                    lookupTable = new InMemoryLookupTable.Builder<VocabWord>().cache(vocabCache).vectorLength(syn0.columns())
                                    .useHierarchicSoftmax(false).useAdaGrad(false).build();
                    lookupTable.setSyn0(syn0);
                } finally {
                    // best-effort cleanup, also on failure
                    tmpFileSyn0.delete();
                }
            }
        } catch (Exception e) {
            // let's try to load this file as csv file
            try {
                log.debug("Trying CSV model restoration...");
                Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(file);
                lookupTable = pair.getFirst();
                vocabCache = (AbstractCache<VocabWord>) pair.getSecond();
            } catch (Exception ex) {
                // we fallback to trying binary model instead
                try {
                    log.debug("Trying binary model restoration...");
                    if (originalPeriodic)
                        Nd4j.getMemoryManager().togglePeriodicGc(true);
                    Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
                    return loadGoogleModel(file, true, true);
                } catch (Exception ey) {
                    // try to load without linebreaks
                    try {
                        return loadGoogleModel(file, true, false);
                    } catch (Exception ez) {
                        throw new RuntimeException("Unable to guess input file format. Please use corresponding loader directly", ez);
                    }
                }
            }
        }
    } finally {
        // restore caller's GC settings on every exit path (previously skipped when the CSV fallback succeeded)
        if (originalPeriodic)
            Nd4j.getMemoryManager().togglePeriodicGc(true);
        Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
    }

    Word2Vec.Builder builder = new Word2Vec.Builder(configuration).lookupTable(lookupTable).useAdaGrad(false)
                    .vocabCache(vocabCache).layerSize(lookupTable.layerSize()).useHierarchicSoftmax(false)
                    .resetModel(false);
    /*
        Trying to restore TokenizerFactory & TokenPreProcessor
     */
    TokenizerFactory factory = getTokenizerFactory(configuration);
    if (factory != null)
        builder.tokenizerFactory(factory);
    vec = builder.build();
    return vec;
}
Usage of org.deeplearning4j.models.word2vec.Word2Vec in project deeplearning4j by deeplearning4j.
Class WordVectorSerializer, method readParagraphVectors.
/**
 * Restores a ParagraphVectors model previously saved with writeParagraphVectors().
 * <p>
 * The word vectors, vocabulary and configuration come from {@code readWord2Vec(file)}. If the
 * archive also contains a {@code labels.txt} entry, each line is base64-decoded. Every vocabulary
 * word found that way is marked as a label. Finally, {@code extractLabels()} is invoked.
 *
 * @param file zip archive produced by writeParagraphVectors()
 * @return restored ParagraphVectors model
 * @throws IOException if the archive can't be opened or read
 */
public static ParagraphVectors readParagraphVectors(File file) throws IOException {
    Word2Vec w2v = readWord2Vec(file);
    // and "convert" it to ParaVec model + optionally trying to restore labels information
    ParagraphVectors vectors = new ParagraphVectors.Builder(w2v.getConfiguration()).vocabCache(w2v.getVocab())
                    .lookupTable(w2v.getLookupTable()).resetModel(false).build();
    // try-with-resources: the archive was previously never closed
    try (ZipFile zipFile = new ZipFile(file)) {
        // now we try to restore labels information
        ZipEntry labels = zipFile.getEntry("labels.txt");
        if (labels != null) {
            // read the entry directly instead of copying it to a temporary file first
            try (BufferedReader reader = new BufferedReader(new InputStreamReader(zipFile.getInputStream(labels)))) {
                String line;
                while ((line = reader.readLine()) != null) {
                    VocabWord word = vectors.getVocab().tokenFor(decodeB64(line.trim()));
                    if (word != null) {
                        word.markAsLabel(true);
                    }
                }
            }
        }
    }
    vectors.extractLabels();
    return vectors;
}
Usage of org.deeplearning4j.models.word2vec.Word2Vec in project deeplearning4j by deeplearning4j.
Class UITest, method testPosting.
@Test
public void testPosting() throws Exception {
    // Train a tiny model, round-trip it through the text serializer, then plot the
    // restored vocabulary with BarnesHutTsne through a UI connection on localhost:9000.
    File corpus = new ClassPathResource("/basic/word2vec_advance.txt").getFile();
    SentenceIterator sentences = UimaSentenceIterator.createWithPath(corpus.getAbsolutePath());

    // Split on white spaces in the line to get words
    TokenizerFactory tokenizer = new DefaultTokenizerFactory();
    tokenizer.setTokenPreProcessor(new CommonPreprocessor());

    Word2Vec model = new Word2Vec.Builder()
                    .minWordFrequency(1)
                    .iterations(1)
                    .epochs(1)
                    .layerSize(20)
                    .stopWords(new ArrayList<String>())
                    .useAdaGrad(false)
                    .negativeSample(5)
                    .seed(42)
                    .windowSize(5)
                    .iterate(sentences)
                    .tokenizerFactory(tokenizer)
                    .build();
    model.fit();

    File serialized = File.createTempFile("temp", "w2v");
    serialized.deleteOnExit();
    WordVectorSerializer.writeWordVectors(model, serialized);
    WordVectors restored = WordVectorSerializer.loadTxtVectors(serialized);

    //Initialize
    UIServer.getInstance();
    UiConnectionInfo connectionInfo = new UiConnectionInfo.Builder()
                    .setAddress("localhost")
                    .setPort(9000)
                    .build();
    BarnesHutTsne tsne = new BarnesHutTsne.Builder()
                    .normalize(false)
                    .setFinalMomentum(0.8f)
                    .numDimension(2)
                    .setMaxIter(10)
                    .build();
    restored.lookupTable().plotVocab(tsne, restored.lookupTable().getVocabCache().numWords(), connectionInfo);

    // NOTE(review): 100 s sleep presumably keeps the JVM alive for manual inspection of the plot
    Thread.sleep(100000);
}
Usage of org.deeplearning4j.models.word2vec.Word2Vec in project deeplearning4j by deeplearning4j.
Class VectorsConfigurationTest, method testFromW2V.
@Test
public void testFromW2V() throws Exception {
    // A Word2Vec built from a configuration must report an equal configuration back.
    VectorsConfiguration expected = new VectorsConfiguration();
    expected.setHugeModelExpected(true);
    expected.setWindow(5);
    expected.setIterations(3);
    expected.setLayersSize(200);
    expected.setLearningRate(1.4d);
    expected.setSampling(0.0005d);
    expected.setMinLearningRate(0.25d);
    expected.setEpochs(1);

    String corpusPath = new ClassPathResource("/big/raw_sentences.txt").getFile().getAbsolutePath();
    SentenceIterator sentences = UimaSentenceIterator.createWithPath(corpusPath);

    Word2Vec model = new Word2Vec.Builder(expected).iterate(sentences).build();
    assertEquals(expected, model.getConfiguration());
}
Usage of org.deeplearning4j.models.word2vec.Word2Vec in project deeplearning4j by deeplearning4j.
Class WordVectorSerializerTest, method testIndexPersistence.
@Test
public void testIndexPersistence() throws Exception {
    // Writing vectors with writeWordVectors and reading them back with loadTxtVectors must
    // preserve the document count and the vector of every vocabulary word.
    String corpusPath = new ClassPathResource("/big/raw_sentences.txt").getFile().getAbsolutePath();
    SentenceIterator sentences = UimaSentenceIterator.createWithPath(corpusPath);

    // Split on white spaces in the line to get words
    TokenizerFactory tokenizer = new DefaultTokenizerFactory();
    tokenizer.setTokenPreProcessor(new CommonPreprocessor());

    Word2Vec original = new Word2Vec.Builder()
                    .minWordFrequency(5)
                    .iterations(1)
                    .epochs(1)
                    .layerSize(100)
                    .stopWords(new ArrayList<String>())
                    .useAdaGrad(false)
                    .negativeSample(5)
                    .seed(42)
                    .windowSize(5)
                    .iterate(sentences)
                    .tokenizerFactory(tokenizer)
                    .build();
    original.fit();
    VocabCache originalVocab = original.getVocab();

    File serialized = File.createTempFile("temp", "w2v");
    serialized.deleteOnExit();
    WordVectorSerializer.writeWordVectors(original, serialized);

    WordVectors restored = WordVectorSerializer.loadTxtVectors(serialized);
    VocabCache restoredVocab = restored.vocab();
    assertEquals(originalVocab.totalNumberOfDocs(), restoredVocab.totalNumberOfDocs());

    for (VocabWord token : original.getVocab().vocabWords()) {
        String label = token.getLabel();
        assertEquals(original.getWordVectorMatrix(label), restored.getWordVectorMatrix(label));
    }
}
Aggregations