Search in sources:

Example 1 with StaticWord2Vec

Use of org.deeplearning4j.models.word2vec.StaticWord2Vec in the deeplearning4j project (by deeplearning4j).

Found in method loadStaticModel of class WordVectorSerializer.

/**
 * Restores a previously saved word2vec model as a {@link StaticWord2Vec}. The file may be in one of
 * the following formats, which are tried in this order:
 * 1) DL4j compressed format (a zip archive containing a {@code syn0.txt} entry)
 * 2) Popular CSV word2vec text format
 * 3) Binary model, either compressed or not, like the well-known Google model
 *
 * The returned StaticWord2Vec is intended to be used as a lookup table only (e.g. in a multi-gpu environment).
 *
 * While loading, periodic GC is disabled and occasional GC frequency is raised to 50000; the
 * caller's original settings are restored exactly once, whether loading succeeds or fails.
 *
 * @param file File should point to previously saved w2v model
 * @return a StaticWord2Vec backed by the vectors and vocabulary read from {@code file}
 * @throws RuntimeException wrapping a FileNotFoundException if {@code file} does not exist or is a
 *         directory, or "Unable to guess input file format" (with the last failure as cause) if no
 *         reader could parse it
 */
// TODO: this method needs better name :)
public static WordVectors loadStaticModel(File file) {
    if (!file.exists() || file.isDirectory())
        throw new RuntimeException(new FileNotFoundException("File [" + file.getAbsolutePath() + "] was not found"));

    // Tune GC for the bulk load; the original settings are restored in the finally below.
    int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
    boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
    if (originalPeriodic)
        Nd4j.getMemoryManager().togglePeriodicGc(false);
    Nd4j.getMemoryManager().setOccasionalGcFrequency(50000);

    CompressedRamStorage<Integer> storage = new CompressedRamStorage.Builder<Integer>().useInplaceCompression(false).setCompressor(new NoOp()).emulateIsAbsent(false).build();
    VocabCache<VocabWord> vocabCache;
    try {
        vocabCache = loadStaticVectorsGuessingFormat(file, storage);
    } finally {
        // Restore once, after all format attempts (previously restored before the fallbacks ran).
        if (originalPeriodic)
            Nd4j.getMemoryManager().togglePeriodicGc(true);
        Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
    }

    StaticWord2Vec word2Vec = new StaticWord2Vec.Builder(storage, vocabCache).build();
    return word2Vec;
}

/**
 * Tries the DL4j zip, CSV text and binary formats in turn, filling {@code storage} and returning
 * the matching vocabulary for the first format that parses successfully.
 *
 * @throws RuntimeException "Unable to guess input file format" if every format fails
 */
private static VocabCache<VocabWord> loadStaticVectorsGuessingFormat(File file, CompressedRamStorage<Integer> storage) {
    // if zip - that's dl4j format
    try {
        log.debug("Trying DL4j format...");
        return loadStaticVectorsFromDl4jZip(file, storage);
    } catch (Exception ignored) {
        // Not a DL4j zip archive (or no syn0.txt entry): fall through to the CSV text format.
    }

    try {
        // try to load file as text csv
        log.debug("Trying CSVReader...");
        try (Reader reader = new CSVReader(file)) {
            return fillStaticStorage(reader, storage);
        }
    } catch (Exception ignored) {
        // Not CSV text: fall through to the binary format.
    }

    // otherwise it's probably google model, which might be compressed or not
    log.debug("Trying BinaryReader...");
    try (Reader reader = new BinaryReader(file)) {
        return fillStaticStorage(reader, storage);
    } catch (Exception ez) {
        throw new RuntimeException("Unable to guess input file format", ez);
    }
}

/**
 * Reads the {@code syn0.txt} entry of a DL4j zip model into {@code storage}. The entry is copied to
 * a temporary file (removed afterwards) and parsed with {@link CSVReader}.
 */
private static VocabCache<VocabWord> loadStaticVectorsFromDl4jZip(File file, CompressedRamStorage<Integer> storage) throws Exception {
    File tmpFileSyn0 = File.createTempFile("word2vec", "syn");
    try {
        try (ZipFile zipFile = new ZipFile(file)) {
            ZipEntry syn0 = zipFile.getEntry("syn0.txt");
            if (syn0 == null)
                throw new FileNotFoundException("Entry [syn0.txt] was not found in [" + file.getAbsolutePath() + "]");
            try (InputStream stream = zipFile.getInputStream(syn0)) {
                Files.copy(stream, Paths.get(tmpFileSyn0.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING);
            }
        }
        // The reader is fully drained and closed before the temp file is deleted below.
        try (Reader reader = new CSVReader(tmpFileSyn0)) {
            return fillStaticStorage(reader, storage);
        }
    } finally {
        if (!tmpFileSyn0.delete())
            tmpFileSyn0.deleteOnExit();
    }
}

/**
 * Clears {@code storage}, then stores every (word, vector) pair produced by {@code reader} in it,
 * keyed by word index, and registers each word in a freshly built vocabulary cache.
 *
 * @return the vocabulary cache populated from {@code reader}
 */
private static VocabCache<VocabWord> fillStaticStorage(Reader reader, CompressedRamStorage<Integer> storage) throws Exception {
    // A previous, failed format attempt may have left partial data behind.
    storage.clear();
    VocabCache<VocabWord> vocabCache = new AbstractCache.Builder<VocabWord>().build();
    while (reader.hasNext()) {
        Pair<VocabWord, float[]> pair = reader.next();
        VocabWord word = pair.getFirst();
        storage.store(word.getIndex(), pair.getSecond());
        vocabCache.addToken(word);
        vocabCache.addWordToIndex(word.getIndex(), word.getLabel());
        Nd4j.getMemoryManager().invokeGcOccasionally();
    }
    return vocabCache;
}
Also used: GZIPInputStream(java.util.zip.GZIPInputStream) NoOp(org.nd4j.compression.impl.NoOp) ZipEntry(java.util.zip.ZipEntry) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) AbstractCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache) DL4JInvalidInputException(org.deeplearning4j.exception.DL4JInvalidInputException) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) StaticWord2Vec(org.deeplearning4j.models.word2vec.StaticWord2Vec) ZipFile(java.util.zip.ZipFile) Pair(org.deeplearning4j.berkeley.Pair)

Aggregations

AtomicInteger (java.util.concurrent.atomic.AtomicInteger): 1 · GZIPInputStream (java.util.zip.GZIPInputStream): 1 · ZipEntry (java.util.zip.ZipEntry): 1 · ZipFile (java.util.zip.ZipFile): 1 · Pair (org.deeplearning4j.berkeley.Pair): 1 · DL4JInvalidInputException (org.deeplearning4j.exception.DL4JInvalidInputException): 1 · StaticWord2Vec (org.deeplearning4j.models.word2vec.StaticWord2Vec): 1 · VocabWord (org.deeplearning4j.models.word2vec.VocabWord): 1 · AbstractCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache): 1 · NoOp (org.nd4j.compression.impl.NoOp): 1 · ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 1