Search in sources:

Example 1 with Pair

Use of org.deeplearning4j.berkeley.Pair in project deeplearning4j by deeplearning4j.

In the class WordVectorSerializer, method loadTxt.

/**
 * Loads word vectors stored in the plain-text word2vec format into an in-memory lookup table and
 * vocab cache (sets syn0 and the vocab).
 *
 * <p>Every data line is {@code <word> <v1> <v2> ...}, separated by single spaces, and the word
 * token is passed through {@code decodeB64}. The first line is treated as a header and skipped in
 * two cases: when it contains no space, or when every space-separated token parses as a
 * {@code long}. For such a numeric header with exactly three tokens, the third token is added to
 * the cache's total document count (a DL4J-only value). Empty data lines are skipped.
 *
 * @param vectorsFile the path of the file to load; it is read as UTF-8
 * @return a Pair holding the lookup table and the vocab cache.
 * @throws FileNotFoundException if the input file does not exist
 * @throws UnsupportedEncodingException if the UTF-8 charset is not supported
 * @throws IllegalArgumentException if the file contains no word vectors
 */
public static Pair<InMemoryLookupTable, VocabCache> loadTxt(File vectorsFile) throws FileNotFoundException, UnsupportedEncodingException {
    AbstractCache cache = new AbstractCache<>();
    List<INDArray> arrays = new ArrayList<>();
    // Single UTF-8 pass over the file: a header is simply skipped. The previous version re-opened
    // the file with a platform-charset FileReader, and it dropped the only line of a one-line,
    // header-less file.
    BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(vectorsFile), "UTF-8"));
    LineIterator iter = IOUtils.lineIterator(reader);
    try {
        boolean firstLine = true;
        while (iter.hasNext()) {
            String line = iter.nextLine();
            if (firstLine) {
                firstLine = false;
                //look for spaces
                if (!line.contains(" ")) {
                    log.debug("Skipping first line");
                    continue;
                }
                // A first line made only of integers is a header. header[0] and header[1] are
                // passed to printOutProjectedMemoryUse (presumably word count and vector length);
                // [2] - number of documents <-- DL4j-only value
                String[] tokens = line.split(" ");
                boolean numericHeader;
                try {
                    long[] header = new long[tokens.length];
                    for (int x = 0; x < tokens.length; x++) {
                        header[x] = Long.parseLong(tokens[x]);
                    }
                    if (tokens.length == 3)
                        cache.incrementTotalDocCount(header[2]);
                    printOutProjectedMemoryUse(header[0], (int) header[1], 1);
                    numericHeader = true;
                } catch (Exception e) {
                    // any conversion failure means the first line already holds a word vector
                    numericHeader = false;
                }
                if (numericHeader)
                    continue;
            }
            // blank lines carry no vector; they used to become zero-length rows and break vstack
            if (line.isEmpty())
                continue;
            String[] split = line.split(" ");
            String word = decodeB64(split[0]);
            VocabWord word1 = new VocabWord(1.0, word);
            word1.setIndex(cache.numWords());
            cache.addToken(word1);
            cache.addWordToIndex(word1.getIndex(), word);
            cache.putVocabWord(word);
            float[] vector = new float[split.length - 1];
            for (int i = 1; i < split.length; i++) {
                vector[i - 1] = Float.parseFloat(split[i]);
            }
            arrays.add(Nd4j.create(vector));
        }
    } finally {
        iter.close();
        try {
            reader.close();
        } catch (Exception ignored) {
            // best-effort close; there is nothing useful to do if it fails
        }
    }
    if (arrays.isEmpty())
        throw new IllegalArgumentException("No word vectors found in file [" + vectorsFile.getAbsolutePath() + "]");
    INDArray syn = Nd4j.vstack(arrays);
    InMemoryLookupTable lookupTable = (InMemoryLookupTable) new InMemoryLookupTable.Builder().vectorLength(arrays.get(0).columns()).useAdaGrad(false).cache(cache).useHierarchicSoftmax(false).build();
    if (Nd4j.ENFORCE_NUMERICAL_STABILITY)
        Nd4j.clearNans(syn);
    lookupTable.setSyn0(syn);
    return new Pair<>(lookupTable, (VocabCache) cache);
}
Also used : ArrayList(java.util.ArrayList) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) AbstractCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache) LineIterator(org.apache.commons.io.LineIterator) BasicLineIterator(org.deeplearning4j.text.sentenceiterator.BasicLineIterator) DL4JInvalidInputException(org.deeplearning4j.exception.DL4JInvalidInputException) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Pair(org.deeplearning4j.berkeley.Pair)

Example 2 with Pair

Use of org.deeplearning4j.berkeley.Pair in project deeplearning4j by deeplearning4j.

In the class WordVectorSerializer, method readWord2VecModel.

/**
 * Restores a word2vec model. The following formats are tried in this order:
 * 1) DL4j compressed (zip) format: full restoration via {@code readWord2Vec} when extendedModel
 *    is TRUE; otherwise only syn0.txt, config.json and frequencies.txt are read from the archive
 * 2) Popular CSV word2vec text format (via {@code loadTxt})
 * 3) Binary model, either compressed or not. Like well-known Google Model (first with, then
 *    without linebreaks)
 *
 * Please note: if extended data isn't available, only weights will be loaded instead.
 *
 * While loading, periodic GC is disabled and the occasional GC frequency is raised to 50000.
 * The original settings are restored on every exit path.
 *
 * @param file model file to read
 * @param extendedModel if TRUE, we'll try to load HS states & Huffman tree info, if FALSE, only weights will be loaded
 * @return the restored Word2Vec model
 * @throws ND4JIllegalStateException if the file doesn't exist or is not a regular file
 * @throws RuntimeException if none of the supported formats could be read (the last failure is attached as the cause)
 */
public static Word2Vec readWord2VecModel(@NonNull File file, boolean extendedModel) {
    InMemoryLookupTable<VocabWord> lookupTable = new InMemoryLookupTable<>();
    AbstractCache<VocabWord> vocabCache = new AbstractCache<>();
    Word2Vec vec;
    INDArray syn0 = null;
    VectorsConfiguration configuration = new VectorsConfiguration();
    if (!file.exists() || !file.isFile())
        throw new ND4JIllegalStateException("File [" + file.getAbsolutePath() + "] doesn't exist");
    int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
    boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
    if (originalPeriodic)
        Nd4j.getMemoryManager().togglePeriodicGc(false);
    Nd4j.getMemoryManager().setOccasionalGcFrequency(50000);
    // try to load zip format
    try {
        if (extendedModel) {
            log.debug("Trying full model restoration...");
            if (originalPeriodic)
                Nd4j.getMemoryManager().togglePeriodicGc(true);
            Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
            return readWord2Vec(file);
        } else {
            log.debug("Trying simplified model restoration...");
            File tmpFileSyn0 = File.createTempFile("word2vec", "syn");
            // try-with-resources: the archive used to stay open forever. The finally block
            // removes the temporary syn0 copy on failure as well as on success.
            try (ZipFile zipFile = new ZipFile(file)) {
                // we don't need full model, so we go directly to syn0 file
                ZipEntry syn = zipFile.getEntry("syn0.txt");
                InputStream stream = zipFile.getInputStream(syn);
                Files.copy(stream, Paths.get(tmpFileSyn0.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING);
                // now we're restoring configuration saved earlier
                ZipEntry config = zipFile.getEntry("config.json");
                if (config != null) {
                    stream = zipFile.getInputStream(config);
                    StringBuilder builder = new StringBuilder();
                    try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream))) {
                        String line;
                        while ((line = reader.readLine()) != null) {
                            builder.append(line);
                        }
                    }
                    configuration = VectorsConfiguration.fromJson(builder.toString().trim());
                }
                // each frequencies.txt line: <base64 word> <element frequency> <sequences count>
                ZipEntry ve = zipFile.getEntry("frequencies.txt");
                if (ve != null) {
                    stream = zipFile.getInputStream(ve);
                    AtomicInteger cnt = new AtomicInteger(0);
                    try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream))) {
                        String line;
                        while ((line = reader.readLine()) != null) {
                            String[] split = line.split(" ");
                            VocabWord word = new VocabWord(Double.valueOf(split[1]), decodeB64(split[0]));
                            word.setIndex(cnt.getAndIncrement());
                            word.incrementSequencesCount(Long.valueOf(split[2]));
                            vocabCache.addToken(word);
                            vocabCache.addWordToIndex(word.getIndex(), word.getLabel());
                            Nd4j.getMemoryManager().invokeGcOccasionally();
                        }
                    }
                }
                List<INDArray> rows = new ArrayList<>();
                // basically read up everything, call vstack and then return model
                try (Reader reader = new CSVReader(tmpFileSyn0)) {
                    AtomicInteger cnt = new AtomicInteger(0);
                    while (reader.hasNext()) {
                        Pair<VocabWord, float[]> pair = reader.next();
                        VocabWord word = pair.getFirst();
                        INDArray vector = Nd4j.create(pair.getSecond());
                        if (ve != null) {
                            // vocab already known from frequencies.txt: fill a preallocated matrix
                            if (syn0 == null)
                                syn0 = Nd4j.create(vocabCache.numWords(), vector.length());
                            syn0.getRow(cnt.getAndIncrement()).assign(vector);
                        } else {
                            rows.add(vector);
                            vocabCache.addToken(word);
                            vocabCache.addWordToIndex(word.getIndex(), word.getLabel());
                        }
                        Nd4j.getMemoryManager().invokeGcOccasionally();
                    }
                } catch (Exception e) {
                    throw new RuntimeException(e);
                } finally {
                    if (originalPeriodic)
                        Nd4j.getMemoryManager().togglePeriodicGc(true);
                    Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
                }
                if (syn0 == null && vocabCache.numWords() > 0)
                    syn0 = Nd4j.vstack(rows);
            } finally {
                tmpFileSyn0.delete();
            }
            if (syn0 == null) {
                log.error("Can't build syn0 table");
                throw new DL4JInvalidInputException("Can't build syn0 table");
            }
            lookupTable = new InMemoryLookupTable.Builder<VocabWord>().cache(vocabCache).vectorLength(syn0.columns()).useHierarchicSoftmax(false).useAdaGrad(false).build();
            lookupTable.setSyn0(syn0);
        }
    } catch (Exception e) {
        // let's try to load this file as csv file
        try {
            log.debug("Trying CSV model restoration...");
            Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(file);
            lookupTable = pair.getFirst();
            vocabCache = (AbstractCache<VocabWord>) pair.getSecond();
        } catch (Exception ex) {
            // we fallback to trying binary model instead
            try {
                log.debug("Trying binary model restoration...");
                if (originalPeriodic)
                    Nd4j.getMemoryManager().togglePeriodicGc(true);
                Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
                vec = loadGoogleModel(file, true, true);
                return vec;
            } catch (Exception ey) {
                // try to load without linebreaks
                try {
                    if (originalPeriodic)
                        Nd4j.getMemoryManager().togglePeriodicGc(true);
                    Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
                    vec = loadGoogleModel(file, true, false);
                    return vec;
                } catch (Exception ez) {
                    throw new RuntimeException("Unable to guess input file format. Please use corresponding loader directly", ez);
                }
            }
        } finally {
            // A successful CSV restoration used to return with periodic GC still disabled;
            // restore the caller's GC settings on every exit from the fallback chain.
            if (originalPeriodic)
                Nd4j.getMemoryManager().togglePeriodicGc(true);
            Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
        }
    }
    Word2Vec.Builder builder = new Word2Vec.Builder(configuration).lookupTable(lookupTable).useAdaGrad(false).vocabCache(vocabCache).layerSize(lookupTable.layerSize()).useHierarchicSoftmax(false).resetModel(false);
    /*
            Trying to restore TokenizerFactory & TokenPreProcessor
         */
    TokenizerFactory factory = getTokenizerFactory(configuration);
    if (factory != null)
        builder.tokenizerFactory(factory);
    vec = builder.build();
    return vec;
}
Also used: ZipEntry(java.util.zip.ZipEntry) ArrayList(java.util.ArrayList) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) AbstractCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache) InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) StaticWord2Vec(org.deeplearning4j.models.word2vec.StaticWord2Vec) Word2Vec(org.deeplearning4j.models.word2vec.Word2Vec) Pair(org.deeplearning4j.berkeley.Pair) TokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory) GZIPInputStream(java.util.zip.GZIPInputStream) DL4JInvalidInputException(org.deeplearning4j.exception.DL4JInvalidInputException) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ZipFile(java.util.zip.ZipFile) AtomicInteger(java.util.concurrent.atomic.AtomicInteger)

Example 3 with Pair

Use of org.deeplearning4j.berkeley.Pair in project deeplearning4j by deeplearning4j.

In the class WordVectorSerializer, method loadStaticModel.

/**
 * This method restores previously saved w2v model. File can be in one of the following formats,
 * which are tried in this order:
 * 1) DL4j compressed (zip) format - only syn0.txt is read from the archive
 * 2) Popular CSV word2vec text format
 * 3) Binary model, either compressed or not. Like well-known Google Model
 *
 * In return you get StaticWord2Vec model, which might be used as lookup table only in multi-gpu environment.
 * While loading, periodic GC is disabled and the occasional GC frequency is raised to 50000.
 * The original settings are restored afterwards.
 *
 * @param file File should point to previously saved w2v model
 * @return a StaticWord2Vec whose vectors are kept in a CompressedRamStorage using the NoOp compressor
 * @throws RuntimeException wrapping a FileNotFoundException if the file is missing or is a directory,
 *         or if none of the supported formats can be read
 */
// TODO: this method needs better name :)
public static WordVectors loadStaticModel(File file) {
    if (!file.exists() || file.isDirectory())
        throw new RuntimeException(new FileNotFoundException("File [" + file.getAbsolutePath() + "] was not found"));
    int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
    boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
    if (originalPeriodic)
        Nd4j.getMemoryManager().togglePeriodicGc(false);
    Nd4j.getMemoryManager().setOccasionalGcFrequency(50000);
    CompressedRamStorage<Integer> storage = new CompressedRamStorage.Builder<Integer>().useInplaceCompression(false).setCompressor(new NoOp()).emulateIsAbsent(false).build();
    VocabCache<VocabWord> vocabCache = new AbstractCache.Builder<VocabWord>().build();
    // if zip - that's dl4j format
    try {
        log.debug("Trying DL4j format...");
        File tmpFileSyn0 = File.createTempFile("word2vec", "syn");
        try {
            // copy syn0.txt out of the archive; the archive is closed as soon as the copy is done
            try (ZipFile zipFile = new ZipFile(file)) {
                ZipEntry syn0 = zipFile.getEntry("syn0.txt");
                InputStream stream = zipFile.getInputStream(syn0);
                Files.copy(stream, Paths.get(tmpFileSyn0.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING);
            }
            storage.clear();
            try (Reader reader = new CSVReader(tmpFileSyn0)) {
                while (reader.hasNext()) {
                    Pair<VocabWord, float[]> pair = reader.next();
                    VocabWord word = pair.getFirst();
                    storage.store(word.getIndex(), pair.getSecond());
                    vocabCache.addToken(word);
                    vocabCache.addWordToIndex(word.getIndex(), word.getLabel());
                    Nd4j.getMemoryManager().invokeGcOccasionally();
                }
            } catch (Exception e) {
                throw new RuntimeException(e);
            } finally {
                if (originalPeriodic)
                    Nd4j.getMemoryManager().togglePeriodicGc(true);
                Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
            }
        } finally {
            // The temporary syn0 copy was previously never deleted, not even on success, and
            // every non-zip input leaked an empty temp file.
            tmpFileSyn0.delete();
        }
    } catch (Exception e) {
        //
        try {
            // try to load file as text csv
            vocabCache = new AbstractCache.Builder<VocabWord>().build();
            storage.clear();
            log.debug("Trying CSVReader...");
            try (Reader reader = new CSVReader(file)) {
                while (reader.hasNext()) {
                    Pair<VocabWord, float[]> pair = reader.next();
                    VocabWord word = pair.getFirst();
                    storage.store(word.getIndex(), pair.getSecond());
                    vocabCache.addToken(word);
                    vocabCache.addWordToIndex(word.getIndex(), word.getLabel());
                    Nd4j.getMemoryManager().invokeGcOccasionally();
                }
            } catch (Exception ef) {
                // rethrow so the enclosing handler falls back to loading the data as a binary model
                throw new RuntimeException(ef);
            } finally {
                if (originalPeriodic)
                    Nd4j.getMemoryManager().togglePeriodicGc(true);
                Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
            }
        } catch (Exception ex) {
            // otherwise it's probably google model. which might be compressed or not
            log.debug("Trying BinaryReader...");
            vocabCache = new AbstractCache.Builder<VocabWord>().build();
            storage.clear();
            try (Reader reader = new BinaryReader(file)) {
                while (reader.hasNext()) {
                    Pair<VocabWord, float[]> pair = reader.next();
                    VocabWord word = pair.getFirst();
                    storage.store(word.getIndex(), pair.getSecond());
                    vocabCache.addToken(word);
                    vocabCache.addWordToIndex(word.getIndex(), word.getLabel());
                    Nd4j.getMemoryManager().invokeGcOccasionally();
                }
            } catch (Exception ez) {
                // keep the last failure as the cause so callers can see why binary loading failed
                throw new RuntimeException("Unable to guess input file format", ez);
            } finally {
                if (originalPeriodic)
                    Nd4j.getMemoryManager().togglePeriodicGc(true);
                Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
            }
        } finally {
            if (originalPeriodic)
                Nd4j.getMemoryManager().togglePeriodicGc(true);
            Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
        }
    }
    StaticWord2Vec word2Vec = new StaticWord2Vec.Builder(storage, vocabCache).build();
    return word2Vec;
}
Also used: GZIPInputStream(java.util.zip.GZIPInputStream) NoOp(org.nd4j.compression.impl.NoOp) ZipEntry(java.util.zip.ZipEntry) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) AbstractCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache) DL4JInvalidInputException(org.deeplearning4j.exception.DL4JInvalidInputException) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) StaticWord2Vec(org.deeplearning4j.models.word2vec.StaticWord2Vec) ZipFile(java.util.zip.ZipFile) Pair(org.deeplearning4j.berkeley.Pair)

Example 4 with Pair

Use of org.deeplearning4j.berkeley.Pair in project deeplearning4j by deeplearning4j.

In the class ParameterAveragingTrainingWorker, method processMinibatchWithStats.

/**
 * Processes one minibatch and pairs the training result with a snapshot of the collected stats.
 *
 * @return {@code null} when {@code processMinibatch} yields no result; otherwise the result
 *         paired with {@code stats.build()}, or with {@code null} when no stats are being collected
 */
@Override
public Pair<ParameterAveragingTrainingResult, SparkTrainingStats> processMinibatchWithStats(DataSet dataSet, MultiLayerNetwork network, boolean isLast) {
    ParameterAveragingTrainingResult minibatchResult = processMinibatch(dataSet, network, isLast);
    if (minibatchResult != null) {
        SparkTrainingStats collected = null;
        if (stats != null) {
            collected = stats.build();
        }
        return new Pair<>(minibatchResult, collected);
    }
    return null;
}
Also used : SparkTrainingStats(org.deeplearning4j.spark.api.stats.SparkTrainingStats) Pair(org.deeplearning4j.berkeley.Pair)

Example 5 with Pair

Use of org.deeplearning4j.berkeley.Pair in project deeplearning4j by deeplearning4j.

In the class ParameterAveragingTrainingWorker, method getFinalResultWithStats.

/**
 * Produces the final result for the given graph and pairs it with a snapshot of the collected stats.
 *
 * @return {@code null} when {@code getFinalResult} yields no result; otherwise the result
 *         paired with {@code stats.build()}, or with {@code null} when no stats are being collected
 */
@Override
public Pair<ParameterAveragingTrainingResult, SparkTrainingStats> getFinalResultWithStats(ComputationGraph graph) {
    ParameterAveragingTrainingResult finalResult = getFinalResult(graph);
    if (finalResult != null) {
        SparkTrainingStats collected = null;
        if (stats != null) {
            collected = stats.build();
        }
        return new Pair<>(finalResult, collected);
    }
    return null;
}
Also used : SparkTrainingStats(org.deeplearning4j.spark.api.stats.SparkTrainingStats) Pair(org.deeplearning4j.berkeley.Pair)

Aggregations

Pair (org.deeplearning4j.berkeley.Pair) 81, INDArray (org.nd4j.linalg.api.ndarray.INDArray) 56, Gradient (org.deeplearning4j.nn.gradient.Gradient) 28, DefaultGradient (org.deeplearning4j.nn.gradient.DefaultGradient) 25, ArrayList (java.util.ArrayList) 8, DL4JInvalidInputException (org.deeplearning4j.exception.DL4JInvalidInputException) 7, VocabWord (org.deeplearning4j.models.word2vec.VocabWord) 7, AtomicLong (java.util.concurrent.atomic.AtomicLong) 5, JavaSparkContext (org.apache.spark.api.java.JavaSparkContext) 5, Test (org.junit.Test) 5, AtomicInteger (java.util.concurrent.atomic.AtomicInteger) 4, SparkTrainingStats (org.deeplearning4j.spark.api.stats.SparkTrainingStats) 4, TextPipeline (org.deeplearning4j.spark.text.functions.TextPipeline) 4, ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException) 4, DoublePointer (org.bytedeco.javacpp.DoublePointer) 3, FloatPointer (org.bytedeco.javacpp.FloatPointer) 3, Pointer (org.bytedeco.javacpp.Pointer) 3, ShortPointer (org.bytedeco.javacpp.ShortPointer) 3, Counter (org.deeplearning4j.berkeley.Counter) 3, InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) 3