Example 51 with INDArray

use of org.nd4j.linalg.api.ndarray.INDArray in project deeplearning4j by deeplearning4j.

the class WordVectorSerializer method readWord2VecModel.

/**
     * This method restores a previously saved Word2Vec model. The file can be in one of the following formats:
     * 1) Binary model, either compressed or not, like the well-known Google model
     * 2) Popular CSV word2vec text format
     * 3) DL4j compressed format
     *
     * Please note: if extended data isn't available, only weights will be loaded instead.
     *
     * @param file          the model file to restore from
     * @param extendedModel if TRUE, we'll try to load HS states & Huffman tree info; if FALSE, only weights will be loaded
     * @return the restored Word2Vec model
     */
public static Word2Vec readWord2VecModel(@NonNull File file, boolean extendedModel) {
    InMemoryLookupTable<VocabWord> lookupTable = new InMemoryLookupTable<>();
    AbstractCache<VocabWord> vocabCache = new AbstractCache<>();
    Word2Vec vec;
    INDArray syn0 = null;
    VectorsConfiguration configuration = new VectorsConfiguration();
    if (!file.exists() || !file.isFile())
        throw new ND4JIllegalStateException("File [" + file.getAbsolutePath() + "] doesn't exist");
    int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
    boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
    if (originalPeriodic)
        Nd4j.getMemoryManager().togglePeriodicGc(false);
    Nd4j.getMemoryManager().setOccasionalGcFrequency(50000);
    // try to load zip format
    try {
        if (extendedModel) {
            log.debug("Trying full model restoration...");
            if (originalPeriodic)
                Nd4j.getMemoryManager().togglePeriodicGc(true);
            Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
            return readWord2Vec(file);
        } else {
            log.debug("Trying simplified model restoration...");
            File tmpFileSyn0 = File.createTempFile("word2vec", "syn");
            File tmpFileConfig = File.createTempFile("word2vec", "config");
            // we don't need full model, so we go directly to syn0 file
            ZipFile zipFile = new ZipFile(file);
            ZipEntry syn = zipFile.getEntry("syn0.txt");
            InputStream stream = zipFile.getInputStream(syn);
            Files.copy(stream, Paths.get(tmpFileSyn0.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING);
            // now we're restoring configuration saved earlier
            ZipEntry config = zipFile.getEntry("config.json");
            if (config != null) {
                stream = zipFile.getInputStream(config);
                StringBuilder builder = new StringBuilder();
                try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream))) {
                    String line;
                    while ((line = reader.readLine()) != null) {
                        builder.append(line);
                    }
                }
                configuration = VectorsConfiguration.fromJson(builder.toString().trim());
            }
            ZipEntry ve = zipFile.getEntry("frequencies.txt");
            if (ve != null) {
                stream = zipFile.getInputStream(ve);
                AtomicInteger cnt = new AtomicInteger(0);
                try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream))) {
                    String line;
                    while ((line = reader.readLine()) != null) {
                        String[] split = line.split(" ");
                        VocabWord word = new VocabWord(Double.valueOf(split[1]), decodeB64(split[0]));
                        word.setIndex(cnt.getAndIncrement());
                        word.incrementSequencesCount(Long.valueOf(split[2]));
                        vocabCache.addToken(word);
                        vocabCache.addWordToIndex(word.getIndex(), word.getLabel());
                        Nd4j.getMemoryManager().invokeGcOccasionally();
                    }
                }
            }
            List<INDArray> rows = new ArrayList<>();
            // read everything into memory, call vstack, and then return the model
            try (Reader reader = new CSVReader(tmpFileSyn0)) {
                AtomicInteger cnt = new AtomicInteger(0);
                while (reader.hasNext()) {
                    Pair<VocabWord, float[]> pair = reader.next();
                    VocabWord word = pair.getFirst();
                    INDArray vector = Nd4j.create(pair.getSecond());
                    if (ve != null) {
                        if (syn0 == null)
                            syn0 = Nd4j.create(vocabCache.numWords(), vector.length());
                        syn0.getRow(cnt.getAndIncrement()).assign(vector);
                    } else {
                        rows.add(vector);
                        vocabCache.addToken(word);
                        vocabCache.addWordToIndex(word.getIndex(), word.getLabel());
                    }
                    Nd4j.getMemoryManager().invokeGcOccasionally();
                }
            } catch (Exception e) {
                throw new RuntimeException(e);
            } finally {
                if (originalPeriodic)
                    Nd4j.getMemoryManager().togglePeriodicGc(true);
                Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
            }
            if (syn0 == null && vocabCache.numWords() > 0)
                syn0 = Nd4j.vstack(rows);
            if (syn0 == null) {
                log.error("Can't build syn0 table");
                throw new DL4JInvalidInputException("Can't build syn0 table");
            }
            lookupTable = new InMemoryLookupTable.Builder<VocabWord>().cache(vocabCache).vectorLength(syn0.columns()).useHierarchicSoftmax(false).useAdaGrad(false).build();
            lookupTable.setSyn0(syn0);
            try {
                tmpFileSyn0.delete();
                tmpFileConfig.delete();
            } catch (Exception e) {
            //
            }
        }
    } catch (Exception e) {
        // let's try to load this file as a CSV file
        try {
            log.debug("Trying CSV model restoration...");
            Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(file);
            lookupTable = pair.getFirst();
            vocabCache = (AbstractCache<VocabWord>) pair.getSecond();
        } catch (Exception ex) {
            // we fall back to trying the binary model instead
            try {
                log.debug("Trying binary model restoration...");
                if (originalPeriodic)
                    Nd4j.getMemoryManager().togglePeriodicGc(true);
                Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
                vec = loadGoogleModel(file, true, true);
                return vec;
            } catch (Exception ey) {
                // try to load without linebreaks
                try {
                    if (originalPeriodic)
                        Nd4j.getMemoryManager().togglePeriodicGc(true);
                    Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
                    vec = loadGoogleModel(file, true, false);
                    return vec;
                } catch (Exception ez) {
                    throw new RuntimeException("Unable to guess input file format. Please use corresponding loader directly");
                }
            }
        }
    }
    Word2Vec.Builder builder = new Word2Vec.Builder(configuration).lookupTable(lookupTable).useAdaGrad(false).vocabCache(vocabCache).layerSize(lookupTable.layerSize()).useHierarchicSoftmax(false).resetModel(false);
    /*
        Trying to restore TokenizerFactory & TokenPreProcessor
     */
    TokenizerFactory factory = getTokenizerFactory(configuration);
    if (factory != null)
        builder.tokenizerFactory(factory);
    vec = builder.build();
    return vec;
}
Also used : ZipEntry(java.util.zip.ZipEntry) ArrayList(java.util.ArrayList) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) AbstractCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache) InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) StaticWord2Vec(org.deeplearning4j.models.word2vec.StaticWord2Vec) Word2Vec(org.deeplearning4j.models.word2vec.Word2Vec) Pair(org.deeplearning4j.berkeley.Pair) TokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory) GZIPInputStream(java.util.zip.GZIPInputStream) DL4JInvalidInputException(org.deeplearning4j.exception.DL4JInvalidInputException) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ZipFile(java.util.zip.ZipFile) AtomicInteger(java.util.concurrent.atomic.AtomicInteger)
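
A minimal usage sketch for this method (the model path below is a hypothetical placeholder, not from the source):

import java.io.File;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.word2vec.Word2Vec;

// hypothetical path to a previously saved model; adjust to your environment
File modelFile = new File("/path/to/word2vec_model.zip");
// extendedModel = false: restore weights only, skipping HS states & Huffman tree info
Word2Vec vec = WordVectorSerializer.readWord2VecModel(modelFile, false);
System.out.println(vec.wordsNearest("day", 10));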

Example 52 with INDArray

use of org.nd4j.linalg.api.ndarray.INDArray in project deeplearning4j by deeplearning4j.

the class BasicModelUtils method wordsNearest.

/**
     * Returns the words nearest to the mean of the given word vector(s).
     *
     * @param words the vector (or matrix of vectors) to search around
     * @param top the top n words
     * @return the words nearest the mean of the words
     */
@Override
public Collection<String> wordsNearest(INDArray words, int top) {
    if (lookupTable instanceof InMemoryLookupTable) {
        InMemoryLookupTable l = (InMemoryLookupTable) lookupTable;
        INDArray syn0 = l.getSyn0();
        if (!normalized) {
            synchronized (this) {
                if (!normalized) {
                    syn0.diviColumnVector(syn0.norm2(1));
                    normalized = true;
                }
            }
        }
        INDArray similarity = Transforms.unitVec(words).mmul(syn0.transpose());
        List<Double> highToLowSimList = getTopN(similarity, top + 20);
        List<WordSimilarity> result = new ArrayList<>();
        for (int i = 0; i < highToLowSimList.size(); i++) {
            String word = vocabCache.wordAtIndex(highToLowSimList.get(i).intValue());
            if (word != null && !word.equals("UNK") && !word.equals("STOP")) {
                INDArray otherVec = lookupTable.vector(word);
                double sim = Transforms.cosineSim(words, otherVec);
                result.add(new WordSimilarity(word, sim));
            }
        }
        Collections.sort(result, new SimilarityComparator());
        return getLabels(result, top);
    }
    Counter<String> distances = new Counter<>();
    for (String s : vocabCache.words()) {
        INDArray otherVec = lookupTable.vector(s);
        double sim = Transforms.cosineSim(words, otherVec);
        distances.incrementCount(s, sim);
    }
    distances.keepTopNKeys(top);
    return distances.keySet();
}
Also used : InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) Counter(org.deeplearning4j.berkeley.Counter) INDArray(org.nd4j.linalg.api.ndarray.INDArray)
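
A minimal usage sketch, assuming vec is an already-trained Word2Vec model (the query words are illustrative):

import java.util.Collection;
import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.nd4j.linalg.api.ndarray.INDArray;

BasicModelUtils<VocabWord> utils = new BasicModelUtils<>();
utils.init(vec.lookupTable());

// query with the mean of two word vectors
INDArray mean = vec.getWordVectorMatrix("king").add(vec.getWordVectorMatrix("queen")).div(2);
Collection<String> nearest = utils.wordsNearest(mean, 10);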

Example 53 with INDArray

use of org.nd4j.linalg.api.ndarray.INDArray in project deeplearning4j by deeplearning4j.

the class BasicModelUtils method similarity.

/**
     * Returns the similarity of 2 words. The result will be in the range [-1,1], where -1.0 means exactly
     * opposite vectors, i.e. NO similarity, and 1.0 means the two word vectors match exactly.
     * However, most of the time you'll see values in the range [0,1]; that depends on the training corpus.
     *
     * Returns NaN if either label does not exist in the vocab, or if either label is null.
     *
     * @param label1 the first word
     * @param label2 the second word
     * @return a normalized similarity (cosine similarity)
     */
@Override
public double similarity(@NonNull String label1, @NonNull String label2) {
    if (label1 == null || label2 == null) {
        log.debug("LABELS: " + label1 + ": " + (label1 == null ? "null" : EXISTS) + ";" + label2 + " vec2:" + (label2 == null ? "null" : EXISTS));
        return Double.NaN;
    }
    if (!vocabCache.hasToken(label1)) {
        log.debug("Unknown token 1 requested: [{}]", label1);
        return Double.NaN;
    }
    if (!vocabCache.hasToken(label2)) {
        log.debug("Unknown token 2 requested: [{}]", label2);
        return Double.NaN;
    }
    // fetch vectors first and null-check before dup() to avoid a potential NPE
    INDArray vec1 = lookupTable.vector(label1);
    INDArray vec2 = lookupTable.vector(label2);
    if (vec1 == null || vec2 == null) {
        log.debug(label1 + ": " + (vec1 == null ? "null" : EXISTS) + ";" + label2 + " vec2:" + (vec2 == null ? "null" : EXISTS));
        return Double.NaN;
    }
    vec1 = vec1.dup();
    vec2 = vec2.dup();
    if (label1.equals(label2))
        return 1.0;
    return Transforms.cosineSim(vec1, vec2);
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray)
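
A minimal usage sketch, assuming vec is an already-trained Word2Vec model (the words are illustrative):

double sim = vec.similarity("day", "night");         // cosine similarity, usually in [0,1]
double same = vec.similarity("day", "day");          // identical labels: 1.0
double none = vec.similarity("day", "notaword123");  // unknown token: NaN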

Example 54 with INDArray

use of org.nd4j.linalg.api.ndarray.INDArray in project deeplearning4j by deeplearning4j.

the class TreeModelUtils method wordsNearest.

@Override
public Collection<String> wordsNearest(Collection<String> positive, Collection<String> negative, int top) {
    // Check every word is in the model
    for (String p : SetUtils.union(new HashSet<>(positive), new HashSet<>(negative))) {
        if (!vocabCache.containsWord(p)) {
            return new ArrayList<>();
        }
    }
    INDArray words = Nd4j.create(positive.size() + negative.size(), lookupTable.layerSize());
    int row = 0;
    for (String s : positive) {
        words.putRow(row++, lookupTable.vector(s));
    }
    for (String s : negative) {
        words.putRow(row++, lookupTable.vector(s).mul(-1));
    }
    INDArray mean = words.isMatrix() ? words.mean(0) : words;
    return wordsNearest(mean, top);
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) DataPoint(org.deeplearning4j.clustering.sptree.DataPoint)
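
A minimal usage sketch of the classic analogy query, assuming vec is an already-trained model:

import java.util.Arrays;
import java.util.Collection;

// king - man + woman: with a well-trained corpus, "queen" typically ranks near the top
Collection<String> result = vec.wordsNearest(Arrays.asList("king", "woman"), Arrays.asList("man"), 5);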

Example 55 with INDArray

use of org.nd4j.linalg.api.ndarray.INDArray in project deeplearning4j by deeplearning4j.

the class GloVe method iterateSample.

private double iterateSample(T element1, T element2, double score) {
    //prediction: input + bias
    if (element1.getIndex() < 0 || element1.getIndex() >= syn0.rows())
        throw new IllegalArgumentException("Illegal index for word " + element1.getLabel());
    if (element2.getIndex() < 0 || element2.getIndex() >= syn0.rows())
        throw new IllegalArgumentException("Illegal index for word " + element2.getLabel());
    INDArray w1Vector = syn0.slice(element1.getIndex());
    INDArray w2Vector = syn0.slice(element2.getIndex());
    //w1 * w2 + bias
    double prediction = Nd4j.getBlasWrapper().dot(w1Vector, w2Vector);
    prediction += bias.getDouble(element1.getIndex()) + bias.getDouble(element2.getIndex()) - Math.log(score);
    // Math.pow(Math.min(1.0,(score / maxCount)),xMax);
    double fDiff = (score > xMax) ? prediction : Math.pow(score / xMax, alpha) * prediction;
    if (Double.isNaN(fDiff))
        fDiff = Nd4j.EPS_THRESHOLD;
    //amount of change
    double gradient = fDiff * learningRate;
    //note the update step here: the gradient is
    //the gradient of the OPPOSITE word
    //for adagrad we will use the index of the word passed in
    //for the gradient calculation we will use the context vector
    update(element1, w1Vector, w2Vector, gradient);
    update(element2, w2Vector, w1Vector, gradient);
    return 0.5 * fDiff * prediction;
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray)
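
The weighting used above is the GloVe weighting function f(x) = min(1, (x / xMax)^alpha). Below is a standalone sketch of the cost term this method accumulates; xMax = 100 and alpha = 0.75 (the defaults from the GloVe paper) are assumptions here, not values from the source:

// f(x) caps at 1.0 for frequent pairs and down-weights rare co-occurrences
static double weight(double score, double xMax, double alpha) {
    return (score > xMax) ? 1.0 : Math.pow(score / xMax, alpha);
}

// cost contribution of one co-occurring pair, mirroring iterateSample's return value:
// 0.5 * f(x) * (w1.w2 + b1 + b2 - log(x))^2
static double pairCost(double dot, double b1, double b2, double score, double xMax, double alpha) {
    double diff = dot + b1 + b2 - Math.log(score);
    return 0.5 * weight(score, xMax, alpha) * diff * diff;
}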

Aggregations

INDArray (org.nd4j.linalg.api.ndarray.INDArray): 1034 uses
Test (org.junit.Test): 453 uses
NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration): 173 uses
DataSet (org.nd4j.linalg.dataset.DataSet): 171 uses
MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork): 166 uses
MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration): 143 uses
Gradient (org.deeplearning4j.nn.gradient.Gradient): 100 uses
Layer (org.deeplearning4j.nn.api.Layer): 82 uses
NormalDistribution (org.deeplearning4j.nn.conf.distribution.NormalDistribution): 77 uses
OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer): 69 uses
DefaultGradient (org.deeplearning4j.nn.gradient.DefaultGradient): 68 uses
File (java.io.File): 67 uses
DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer): 66 uses
ArrayList (java.util.ArrayList): 65 uses
ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph): 62 uses
DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator): 62 uses
Pair (org.deeplearning4j.berkeley.Pair): 56 uses
Random (java.util.Random): 54 uses
ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration): 53 uses
IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator): 44 uses