Usage of `org.deeplearning4j.models.word2vec.StaticWord2Vec` in the deeplearning4j project:
class `WordVectorSerializer`, method `loadStaticModel`.
/**
 * Restores a previously saved word2vec model as a lookup-only {@link StaticWord2Vec}.
 * The file format is detected automatically by trying, in this order:
 * <ol>
 *   <li>DL4j compressed (zip) format containing a {@code syn0.txt} entry</li>
 *   <li>Popular CSV/text word2vec format</li>
 *   <li>Binary model, compressed or not (like the well-known Google model)</li>
 * </ol>
 * <p>
 * The returned StaticWord2Vec model is meant to be used as a lookup table only, e.g. in a multi-GPU environment.
 * <p>
 * While loading, periodic GC is disabled (if it was active) and the occasional-GC frequency is set to 50000.
 * Both settings are restored exactly once before this method returns or throws.
 *
 * @param file File should point to previously saved w2v model
 * @return StaticWord2Vec model backed by the vectors read from {@code file}
 * @throws RuntimeException wrapping a {@link FileNotFoundException} if {@code file} does not exist or is a
 *         directory, or with message "Unable to guess input file format" if none of the supported formats
 *         could be read (the binary reader's failure is attached as the cause)
 */
// TODO: this method needs better name :)
public static WordVectors loadStaticModel(File file) {
    if (!file.exists() || file.isDirectory())
        throw new RuntimeException(new FileNotFoundException("File [" + file.getAbsolutePath() + "] was not found"));

    CompressedRamStorage<Integer> storage = new CompressedRamStorage.Builder<Integer>().useInplaceCompression(false)
                    .setCompressor(new NoOp()).emulateIsAbsent(false).build();

    int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
    boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();

    VocabCache<VocabWord> vocabCache;
    try {
        // invokeGcOccasionally() is called once per word while reading; presumably periodic GC is
        // switched off and the frequency raised to keep GC overhead low during the bulk load.
        if (originalPeriodic)
            Nd4j.getMemoryManager().togglePeriodicGc(false);
        Nd4j.getMemoryManager().setOccasionalGcFrequency(50000);

        vocabCache = loadStaticModelVocab(file, storage);
    } finally {
        // Single restore point: runs once on success and on every failure path.
        if (originalPeriodic)
            Nd4j.getMemoryManager().togglePeriodicGc(true);
        Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
    }

    return new StaticWord2Vec.Builder(storage, vocabCache).build();
}

/**
 * Fills {@code storage} with the vectors from {@code file}, trying the DL4j zip, CSV text and binary
 * formats in turn. The first format that reads without error wins.
 *
 * @param file    model file to read
 * @param storage destination for the vectors, keyed by word index; cleared before each attempt
 * @return vocabulary built by the successful reader
 * @throws RuntimeException "Unable to guess input file format" when all three readers fail
 */
private static VocabCache<VocabWord> loadStaticModelVocab(File file, CompressedRamStorage<Integer> storage) {
    try {
        log.debug("Trying DL4j format...");
        return readStaticModelFromDl4jZip(file, storage);
    } catch (Exception ignored) {
        // Not a DL4j zip (or not readable as one): fall through to the CSV text format.
    }

    try {
        log.debug("Trying CSVReader...");
        try (Reader reader = new CSVReader(file)) {
            return readStaticModelEntries(reader, storage);
        }
    } catch (Exception ignored) {
        // Not CSV text: fall through to the binary (Google-style) format, which is the last resort.
    }

    log.debug("Trying BinaryReader...");
    try (Reader reader = new BinaryReader(file)) {
        return readStaticModelEntries(reader, storage);
    } catch (Exception ez) {
        throw new RuntimeException("Unable to guess input file format", ez);
    }
}

/**
 * Reads a DL4j zip model: extracts its {@code syn0.txt} entry to a temporary file and parses that file
 * as CSV. The zip, the entry stream and the temporary file are always released, even on failure.
 *
 * @param file    candidate DL4j zip archive
 * @param storage destination for the vectors, keyed by word index
 * @return vocabulary read from {@code syn0.txt}
 * @throws Exception if {@code file} is not a zip, lacks {@code syn0.txt}, or its contents cannot be parsed
 */
private static VocabCache<VocabWord> readStaticModelFromDl4jZip(File file, CompressedRamStorage<Integer> storage)
                throws Exception {
    try (ZipFile zipFile = new ZipFile(file)) {
        ZipEntry syn0 = zipFile.getEntry("syn0.txt");
        if (syn0 == null)
            throw new FileNotFoundException("Entry [syn0.txt] was not found in [" + file.getAbsolutePath() + "]");

        // Only create the temp file once we know this really is a DL4j archive.
        File tmpFileSyn0 = File.createTempFile("word2vec", "syn");
        try {
            try (InputStream stream = zipFile.getInputStream(syn0)) {
                Files.copy(stream, tmpFileSyn0.toPath(), StandardCopyOption.REPLACE_EXISTING);
            }
            try (Reader reader = new CSVReader(tmpFileSyn0)) {
                return readStaticModelEntries(reader, storage);
            }
        } finally {
            // The reader is closed by now, so the extracted copy can be removed.
            tmpFileSyn0.delete();
        }
    }
}

/**
 * Drains {@code reader} into {@code storage} (one vector per word index) and into a freshly built vocab cache.
 * {@code storage} is cleared first so vectors left over from a previously failed format attempt are discarded.
 *
 * @param reader  open reader positioned at the first word
 * @param storage destination for the vectors, keyed by {@link VocabWord#getIndex()}
 * @return vocabulary containing every word produced by {@code reader}
 * @throws Exception whatever the reader throws while iterating
 */
private static VocabCache<VocabWord> readStaticModelEntries(Reader reader, CompressedRamStorage<Integer> storage)
                throws Exception {
    storage.clear();
    VocabCache<VocabWord> vocabCache = new AbstractCache.Builder<VocabWord>().build();
    while (reader.hasNext()) {
        Pair<VocabWord, float[]> pair = reader.next();
        VocabWord word = pair.getFirst();
        storage.store(word.getIndex(), pair.getSecond());
        vocabCache.addToken(word);
        vocabCache.addWordToIndex(word.getIndex(), word.getLabel());
        Nd4j.getMemoryManager().invokeGcOccasionally();
    }
    return vocabCache;
}
Aggregations