Search in sources :

Example 46 with INDArray

use of org.nd4j.linalg.api.ndarray.INDArray in project deeplearning4j by deeplearning4j.

the class WordVectorSerializer method readTextModel.

/**
 * Reads a word2vec model stored as plain text, optionally gzip-compressed.
 *
 * <p>The first line must contain the number of words and the vector length, separated by a
 * single space. Every following line contains a word followed by its space-separated vector
 * components.
 *
 * @param modelFile the text file to read; it is read through a {@link GZIPInputStream} when
 *                  {@code GzipUtils.isCompressedFilename} accepts its name
 * @return a new {@link Word2Vec} whose vocab cache and lookup table are filled from the file
 * @throws IOException if the file cannot be opened or read (including
 *                     {@link java.io.FileNotFoundException}), or if it is empty
 * @throws NumberFormatException if the header or a vector component is not a valid number
 */
private static Word2Vec readTextModel(File modelFile) throws IOException, NumberFormatException {
    InMemoryLookupTable lookupTable;
    VocabCache cache;
    INDArray syn0;
    Word2Vec ret = new Word2Vec();
    try (BufferedReader reader = new BufferedReader(new InputStreamReader(GzipUtils.isCompressedFilename(modelFile.getName()) ? new GZIPInputStream(new FileInputStream(modelFile)) : new FileInputStream(modelFile), "UTF-8"))) {
        String line = reader.readLine();
        if (line == null) {
            // An empty file has no header; fail with a descriptive error instead of an NPE.
            throw new IOException("Model file is empty, no header line found: " + modelFile.getAbsolutePath());
        }
        // Header: "<number of words> <layer size>"
        String[] initial = line.split(" ");
        int words = Integer.parseInt(initial[0]);
        int layerSize = Integer.parseInt(initial[1]);
        syn0 = Nd4j.create(words, layerSize);
        cache = new InMemoryLookupCache(false);
        int currLine = 0;
        while ((line = reader.readLine()) != null) {
            String[] split = line.split(" ");
            assert split.length == layerSize + 1;
            // Occurrences of whitespaceReplacement inside the stored word are turned back into spaces.
            String word = split[0].replaceAll(whitespaceReplacement, " ");
            float[] vector = new float[split.length - 1];
            for (int i = 1; i < split.length; i++) {
                vector[i - 1] = Float.parseFloat(split[i]);
            }
            // Row index in syn0 follows the order of the lines in the file.
            syn0.putRow(currLine, Nd4j.create(vector));
            cache.addWordToIndex(cache.numWords(), word);
            cache.addToken(new VocabWord(1, word));
            cache.putVocabWord(word);
            currLine++;
        }
        lookupTable = (InMemoryLookupTable) new InMemoryLookupTable.Builder().cache(cache).vectorLength(layerSize).build();
        lookupTable.setSyn0(syn0);
        ret.setVocab(cache);
        ret.setLookupTable(lookupTable);
    }
    return ret;
}
Also used : VocabWord(org.deeplearning4j.models.word2vec.VocabWord) InMemoryLookupCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache) GZIPInputStream(java.util.zip.GZIPInputStream) InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VocabCache(org.deeplearning4j.models.word2vec.wordstore.VocabCache) StaticWord2Vec(org.deeplearning4j.models.word2vec.StaticWord2Vec) Word2Vec(org.deeplearning4j.models.word2vec.Word2Vec)

Example 47 with INDArray

use of org.nd4j.linalg.api.ndarray.INDArray in project deeplearning4j by deeplearning4j.

the class WordVectorSerializer method readWord2VecFromText.

/**
 * Builds a {@link Word2Vec} model from externally originated text files holding the weights,
 * the HS layer and the Huffman tree data, so it can be used with output of other w2v
 * implementations that follow the same layout.
 *
 * <p>In the {@code h_codes} and {@code h_points} files every line starts with a word token
 * (passed through {@code decodeB64}) followed by space-separated numeric values.
 *
 * @param vectors       text file with words and their weights, aka Syn0
 * @param hs            text file with HS layers, aka Syn1; may contain no rows, in which case
 *                      syn1 is not set
 * @param h_codes       text file with Huffman tree codes
 * @param h_points      text file with Huffman tree points
 * @param configuration configuration used for the negative-sampling setting, the model builder
 *                      and the tokenizer factory lookup
 * @return the assembled model, built with {@code resetModel(false)}
 * @throws IOException if any of the files cannot be read
 */
public static Word2Vec readWord2VecFromText(@NonNull File vectors, @NonNull File hs, @NonNull File h_codes, @NonNull File h_points, @NonNull VectorsConfiguration configuration) throws IOException {
    // first we load syn0
    Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(vectors);
    InMemoryLookupTable lookupTable = pair.getFirst();
    lookupTable.setNegative(configuration.getNegative());
    if (configuration.getNegative() > 0) {
        lookupTable.initNegative();
    }
    VocabCache<VocabWord> vocab = (VocabCache<VocabWord>) pair.getSecond();

    // now we load syn1; try-with-resources closes the reader even if a line fails to parse
    List<INDArray> rows = new ArrayList<>();
    try (BufferedReader reader = new BufferedReader(new FileReader(hs))) {
        String line;
        while ((line = reader.readLine()) != null) {
            String[] split = line.split(" ");
            double[] array = new double[split.length];
            for (int i = 0; i < split.length; i++) {
                array[i] = Double.parseDouble(split[i]);
            }
            rows.add(Nd4j.create(array));
        }
    }
    // it's possible to have full model without syn1
    if (rows.size() > 0) {
        INDArray syn1 = Nd4j.vstack(rows);
        lookupTable.setSyn1(syn1);
    }

    // now we transform mappings into huffman tree points
    try (BufferedReader reader = new BufferedReader(new FileReader(h_points))) {
        String line;
        while ((line = reader.readLine()) != null) {
            String[] split = line.split(" ");
            // NOTE(review): the result of wordFor is dereferenced without a null check
            VocabWord word = vocab.wordFor(decodeB64(split[0]));
            List<Integer> points = new ArrayList<>();
            for (int i = 1; i < split.length; i++) {
                points.add(Integer.parseInt(split[i]));
            }
            word.setPoints(points);
        }
    }

    // now we transform mappings into huffman tree codes
    try (BufferedReader reader = new BufferedReader(new FileReader(h_codes))) {
        String line;
        while ((line = reader.readLine()) != null) {
            String[] split = line.split(" ");
            VocabWord word = vocab.wordFor(decodeB64(split[0]));
            List<Byte> codes = new ArrayList<>();
            for (int i = 1; i < split.length; i++) {
                codes.add(Byte.parseByte(split[i]));
            }
            word.setCodes(codes);
            word.setCodeLength((short) codes.size());
        }
    }

    Word2Vec.Builder builder = new Word2Vec.Builder(configuration).vocabCache(vocab).lookupTable(lookupTable).resetModel(false);
    // a tokenizer factory is attached only when one can be derived from the configuration
    TokenizerFactory factory = getTokenizerFactory(configuration);
    if (factory != null) {
        builder.tokenizerFactory(factory);
    }
    return builder.build();
}
Also used : TokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory) ArrayList(java.util.ArrayList) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VocabCache(org.deeplearning4j.models.word2vec.wordstore.VocabCache) StaticWord2Vec(org.deeplearning4j.models.word2vec.StaticWord2Vec) Word2Vec(org.deeplearning4j.models.word2vec.Word2Vec)

Example 48 with INDArray

use of org.nd4j.linalg.api.ndarray.INDArray in project deeplearning4j by deeplearning4j.

the class WordVectorSerializer method readWord2Vec.

/**
 * This method restores Word2Vec model previously saved with writeWord2VecModel
 *
 * <p>The archive entries {@code syn0.txt}, {@code syn1.txt}, {@code codes.txt} and
 * {@code huffman.txt} are extracted to temporary files and passed to
 * {@code readWord2VecFromText} together with the configuration parsed from
 * {@code config.json}. The entries {@code frequencies.txt} and {@code syn1Neg.txt} are
 * optional and applied only when present.
 *
 * <p>While loading, periodic GC is switched off (if it was active) and the occasional GC
 * frequency is set to 50000; both settings are restored before the method returns or throws.
 *
 * PLEASE NOTE: This method loads FULL model, so don't use it if you're only going to use weights.
 *
 * @param file zip archive containing the model
 * @return the restored model
 * @throws IOException if the archive or one of the temporary files cannot be read or written
 */
@Deprecated
public static Word2Vec readWord2Vec(File file) throws IOException {
    File tmpFileSyn0 = File.createTempFile("word2vec", "0");
    File tmpFileSyn1 = File.createTempFile("word2vec", "1");
    File tmpFileC = File.createTempFile("word2vec", "c");
    File tmpFileH = File.createTempFile("word2vec", "h");
    File tmpFileF = File.createTempFile("word2vec", "f");
    tmpFileSyn0.deleteOnExit();
    tmpFileSyn1.deleteOnExit();
    tmpFileH.deleteOnExit();
    tmpFileC.deleteOnExit();
    tmpFileF.deleteOnExit();
    int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
    boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
    if (originalPeriodic)
        Nd4j.getMemoryManager().togglePeriodicGc(false);
    Nd4j.getMemoryManager().setOccasionalGcFrequency(50000);
    // The archive is opened in try-with-resources so its file handle is released on every path;
    // it is closed before the finally block restores the GC settings.
    try (ZipFile zipFile = new ZipFile(file)) {
        ZipEntry syn0 = zipFile.getEntry("syn0.txt");
        try (InputStream stream = zipFile.getInputStream(syn0)) {
            Files.copy(stream, Paths.get(tmpFileSyn0.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING);
        }
        ZipEntry syn1 = zipFile.getEntry("syn1.txt");
        try (InputStream stream = zipFile.getInputStream(syn1)) {
            Files.copy(stream, Paths.get(tmpFileSyn1.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING);
        }
        ZipEntry codes = zipFile.getEntry("codes.txt");
        try (InputStream stream = zipFile.getInputStream(codes)) {
            Files.copy(stream, Paths.get(tmpFileC.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING);
        }
        ZipEntry huffman = zipFile.getEntry("huffman.txt");
        try (InputStream stream = zipFile.getInputStream(huffman)) {
            Files.copy(stream, Paths.get(tmpFileH.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING);
        }
        ZipEntry config = zipFile.getEntry("config.json");
        StringBuilder builder = new StringBuilder();
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(zipFile.getInputStream(config)))) {
            String line;
            // lines are concatenated without their line terminators
            while ((line = reader.readLine()) != null) {
                builder.append(line);
            }
        }
        VectorsConfiguration configuration = VectorsConfiguration.fromJson(builder.toString().trim());
        // we read first 4 files as w2v model
        Word2Vec w2v = readWord2VecFromText(tmpFileSyn0, tmpFileSyn1, tmpFileC, tmpFileH, configuration);
        // we read frequencies from frequencies.txt, however it's possible that we might not have this file
        ZipEntry frequencies = zipFile.getEntry("frequencies.txt");
        if (frequencies != null) {
            try (BufferedReader reader = new BufferedReader(new InputStreamReader(zipFile.getInputStream(frequencies)))) {
                String line;
                while ((line = reader.readLine()) != null) {
                    // line layout: <encoded word> <element frequency> <sequences count>
                    String[] split = line.split(" ");
                    VocabWord word = w2v.getVocab().tokenFor(decodeB64(split[0]));
                    word.setElementFrequency((long) Double.parseDouble(split[1]));
                    word.setSequencesCount((long) Double.parseDouble(split[2]));
                }
            }
        }
        ZipEntry zsyn1Neg = zipFile.getEntry("syn1Neg.txt");
        if (zsyn1Neg != null) {
            try (InputStreamReader isr = new InputStreamReader(zipFile.getInputStream(zsyn1Neg));
                BufferedReader reader = new BufferedReader(isr)) {
                String line;
                List<INDArray> rows = new ArrayList<>();
                while ((line = reader.readLine()) != null) {
                    String[] split = line.split(" ");
                    double[] array = new double[split.length];
                    for (int i = 0; i < split.length; i++) {
                        array[i] = Double.parseDouble(split[i]);
                    }
                    rows.add(Nd4j.create(array));
                }
                // it's possible to have full model without syn1Neg
                if (rows.size() > 0) {
                    INDArray syn1Neg = Nd4j.vstack(rows);
                    ((InMemoryLookupTable) w2v.getLookupTable()).setSyn1Neg(syn1Neg);
                }
            }
        }
        return w2v;
    } finally {
        // restore the GC settings captured before loading
        if (originalPeriodic)
            Nd4j.getMemoryManager().togglePeriodicGc(true);
        Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
    }
}
Also used : GZIPInputStream(java.util.zip.GZIPInputStream) ZipEntry(java.util.zip.ZipEntry) ArrayList(java.util.ArrayList) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) ZipFile(java.util.zip.ZipFile) INDArray(org.nd4j.linalg.api.ndarray.INDArray) StaticWord2Vec(org.deeplearning4j.models.word2vec.StaticWord2Vec) Word2Vec(org.deeplearning4j.models.word2vec.Word2Vec) ZipFile(java.util.zip.ZipFile)

Example 49 with INDArray

use of org.nd4j.linalg.api.ndarray.INDArray in project deeplearning4j by deeplearning4j.

the class WordVectorSerializer method readSequenceVectors.

/**
 * This method loads previously saved SequenceVectors model from InputStream
 *
 * <p>The first line of the stream is a Base64-encoded JSON {@code VectorsConfiguration}; every
 * following line is decoded with {@code ElementPair.fromEncodedJson} into one element and its
 * vector. The stream is read as UTF-8 and is closed by this method, also when reading fails.
 *
 * <p>NOTE(review): a stream without any element lines fails at {@code rows.get(0)}.
 *
 * @param factory factory used to deserialize each element
 * @param stream  stream holding the serialized model
 * @param <T>     element type of the resulting model
 * @return the restored model, built with {@code resetModel(false)}
 * @throws IOException if the stream cannot be read
 */
public static <T extends SequenceElement> SequenceVectors<T> readSequenceVectors(@NonNull SequenceElementFactory<T> factory, @NonNull InputStream stream) throws IOException {
    VectorsConfiguration configuration;
    AbstractCache<T> vocabCache = new AbstractCache.Builder<T>().build();
    List<INDArray> rows = new ArrayList<>();
    // try-with-resources closes the reader (and the wrapped stream) even if a line fails to parse
    try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream, "UTF-8"))) {
        // at first we load vectors configuration
        String line = reader.readLine();
        configuration = VectorsConfiguration.fromJson(new String(Base64.decodeBase64(line), "UTF-8"));
        while ((line = reader.readLine()) != null) {
            ElementPair pair = ElementPair.fromEncodedJson(line);
            T element = factory.deserialize(pair.getObject());
            rows.add(Nd4j.create(pair.getVector()));
            vocabCache.addToken(element);
            vocabCache.addWordToIndex(element.getIndex(), element.getLabel());
        }
    }
    // vector length is taken from the first row
    InMemoryLookupTable<T> lookupTable = (InMemoryLookupTable<T>) new InMemoryLookupTable.Builder<T>().vectorLength(rows.get(0).columns()).build();
    // rows are stacked in the order they were read
    INDArray syn0 = Nd4j.vstack(rows);
    lookupTable.setSyn0(syn0);
    return new SequenceVectors.Builder<T>(configuration).vocabCache(vocabCache).lookupTable(lookupTable).resetModel(false).build();
}
Also used : ArrayList(java.util.ArrayList) AbstractCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache) InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) INDArray(org.nd4j.linalg.api.ndarray.INDArray)

Example 50 with INDArray

use of org.nd4j.linalg.api.ndarray.INDArray in project deeplearning4j by deeplearning4j.

the class WordVectorSerializer method writeTsneFormat.

/**
 * Write the tsne format
 *
 * <p>One line is written per non-null vocabulary word: the comma-separated values of the
 * word's row in {@code tsne}, another comma, the word with spaces replaced by
 * {@code whitespaceReplacement}, a trailing space and a newline. The file is written as UTF-8.
 *
 * @param vec
 *            the word vectors to use for labeling; its vocab must be an
 *            {@code InMemoryLookupCache}
 * @param tsne
 *            the tsne array to write, indexed by word index
 * @param csv
 *            the file to use
 * @throws Exception if the file cannot be written
 */
public static void writeTsneFormat(Word2Vec vec, INDArray tsne, File csv) throws Exception {
    int words = 0;
    // try-with-resources flushes and closes the writer on every path, including failures
    try (BufferedWriter write = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(csv), "UTF-8"))) {
        InMemoryLookupCache l = (InMemoryLookupCache) vec.vocab();
        for (String word : vec.vocab().words()) {
            if (word == null) {
                continue;
            }
            StringBuilder sb = new StringBuilder();
            INDArray wordVector = tsne.getRow(l.wordFor(word).getIndex());
            for (int j = 0; j < wordVector.length(); j++) {
                sb.append(wordVector.getDouble(j));
                if (j < wordVector.length() - 1) {
                    sb.append(",");
                }
            }
            sb.append(",");
            sb.append(word.replaceAll(" ", whitespaceReplacement));
            sb.append(" ");
            sb.append("\n");
            write.write(sb.toString());
            // count the written word so the log line below reports the real total
            words++;
        }
    }
    log.info("Wrote " + words + " with size of " + vec.lookupTable().layerSize());
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) InMemoryLookupCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache)

Aggregations

INDArray (org.nd4j.linalg.api.ndarray.INDArray)1034 Test (org.junit.Test)453 NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration)173 DataSet (org.nd4j.linalg.dataset.DataSet)171 MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)166 MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration)143 Gradient (org.deeplearning4j.nn.gradient.Gradient)100 Layer (org.deeplearning4j.nn.api.Layer)82 NormalDistribution (org.deeplearning4j.nn.conf.distribution.NormalDistribution)77 OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer)69 DefaultGradient (org.deeplearning4j.nn.gradient.DefaultGradient)68 File (java.io.File)67 DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer)66 ArrayList (java.util.ArrayList)65 ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph)62 DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator)62 Pair (org.deeplearning4j.berkeley.Pair)56 Random (java.util.Random)54 ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration)53 IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator)44