Search in sources:

Example 1 with VocabularyWord

use of org.deeplearning4j.models.word2vec.wordstore.VocabularyWord in project deeplearning4j by deeplearning4j.

In class WordVectorSerializer, method loadFullModel:

/**
 * Loads a full Word2Vec model previously written by {@code writeFullModel}.
 *
 * <p>Expected file layout, one record per line:
 * <ol>
 *   <li>{@code VectorsConfiguration} JSON;</li>
 *   <li>expTable (skipped: identical results are recomputed for an unchanged vocabulary);</li>
 *   <li>negative-sampling table (skipped for the same reason);</li>
 *   <li>one {@code VocabularyWord} JSON per remaining line.</li>
 * </ol>
 *
 * <p>The file is read twice. The first pass rebuilds the vocabulary, which is needed to size
 * the lookup table. The second pass copies each word's syn0/syn1/syn1Neg rows into the freshly
 * allocated table. The vocabulary is restored verbatim; no frequency-based truncation is applied.
 *
 * <p>NOTE(review): AdaGrad historical gradients are written by {@code writeFullModel} but are
 * not restored here.
 *
 * <p>Deprecation note: Please, consider using readWord2VecModel() or loadStaticModel() method instead
 *
 * @param path - path to previously stored w2v json model
 * @return - Word2Vec instance
 * @throws FileNotFoundException if {@code path} cannot be opened
 */
@Deprecated
public static Word2Vec loadFullModel(@NonNull String path) throws FileNotFoundException {
    BasicLineIterator iterator = new BasicLineIterator(new File(path));
    try {
        // line 0: model configuration
        String confJson = iterator.nextSentence();
        log.info("Word2Vec conf. JSON: " + confJson);
        VectorsConfiguration configuration = VectorsConfiguration.fromJson(confJson);

        // lines 1-2: expTable and negative table; not needed to restore an unchanged vocab
        iterator.nextSentence();
        iterator.nextSentence();

        // first pass: rebuild the vocabulary, including the serialized Huffman tree data
        AbstractCache<VocabWord> vocabCache = new AbstractCache.Builder<VocabWord>().build();
        while (iterator.hasNext()) {
            VocabularyWord word = VocabularyWord.fromJson(iterator.nextSentence());
            VocabWord vw = new VocabWord(word.getCount(), word.getWord());
            // the stored Huffman index is the row the writer used, so reuse it as the vocab index
            vw.setIndex(word.getHuffmanNode().getIdx());
            vw.setCodeLength(word.getHuffmanNode().getLength());
            vw.setPoints(arrayToList(word.getHuffmanNode().getPoint(), word.getHuffmanNode().getLength()));
            vw.setCodes(arrayToList(word.getHuffmanNode().getCode(), word.getHuffmanNode().getLength()));
            vocabCache.addToken(vw);
            vocabCache.addWordToIndex(vw.getIndex(), vw.getLabel());
            vocabCache.putVocabWord(vw.getWord());
        }

        InMemoryLookupTable lookupTable = (InMemoryLookupTable) new InMemoryLookupTable.Builder()
                .negative(configuration.getNegative())
                .useAdaGrad(configuration.isUseAdaGrad())
                .lr(configuration.getLearningRate())
                .cache(vocabCache)
                .vectorLength(configuration.getLayersSize())
                .build();
        // allocate syn0/syn1 (and syn1Neg when negative sampling is configured)
        lookupTable.resetWeights(true);

        // second pass: skip the 3 header lines, then transfer weights row by row
        iterator.reset();
        iterator.nextSentence();
        iterator.nextSentence();
        iterator.nextSentence();
        boolean restoreSyn1Neg = configuration.getNegative() > 0;
        while (iterator.hasNext()) {
            VocabularyWord word = VocabularyWord.fromJson(iterator.nextSentence());
            int row = vocabCache.indexOf(word.getWord());
            lookupTable.getSyn0().getRow(row).assign(Nd4j.create(word.getSyn0()));
            // syn1 is normally addressed via tree points; for deserialization a row copy suffices
            lookupTable.getSyn1().getRow(row).assign(Nd4j.create(word.getSyn1()));
            if (restoreSyn1Neg) {
                lookupTable.getSyn1Neg().getRow(row).assign(Nd4j.create(word.getSyn1Neg()));
            }
        }

        Word2Vec vec = new Word2Vec.Builder(configuration)
                .vocabCache(vocabCache)
                .lookupTable(lookupTable)
                .resetModel(false)
                .build();
        vec.setModelUtils(new BasicModelUtils());
        return vec;
    } finally {
        // release the underlying file stream
        iterator.finish();
    }
}
Also used : BasicLineIterator(org.deeplearning4j.text.sentenceiterator.BasicLineIterator) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) AbstractCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache) VocabularyHolder(org.deeplearning4j.models.word2vec.wordstore.VocabularyHolder) InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) BasicModelUtils(org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) VocabularyWord(org.deeplearning4j.models.word2vec.wordstore.VocabularyWord) StaticWord2Vec(org.deeplearning4j.models.word2vec.StaticWord2Vec) Word2Vec(org.deeplearning4j.models.word2vec.Word2Vec) ZipFile(java.util.zip.ZipFile)

Example 2 with VocabularyWord

use of org.deeplearning4j.models.word2vec.wordstore.VocabularyWord in project deeplearning4j by deeplearning4j.

In class WordVectorSerializer, method writeFullModel:

/**
 * Saves a full Word2Vec model in a form that allows later model updates without rebuilding
 * it from scratch.
 *
 * <p>File layout, one record per line:
 * <ol>
 *   <li>{@code VectorsConfiguration} JSON (its vocab size is updated first);</li>
 *   <li>expTable values, space separated;</li>
 *   <li>negative-sampling table values, space separated, or an empty line if unavailable;</li>
 *   <li>one {@code VocabularyWord} JSON per vocabulary word. Each record holds the count,
 *       Huffman node, syn0, syn1, and syn1Neg / AdaGrad history when those are in use.</li>
 * </ol>
 *
 * <p>Side effect: calls {@code setVocabSize} on the model's own configuration object.
 *
 * <p>Deprecation note: Please, consider using writeWord2VecModel() method instead
 *
 * @param vec - The Word2Vec instance to be saved
 * @param path - the path for json to be saved
 * @throws IllegalStateException if the model's lookup table is not an InMemoryLookupTable
 *         (checked before the file is created)
 * @throws RuntimeException if the file cannot be opened or an I/O error occurs while writing
 */
@Deprecated
public static void writeFullModel(@NonNull Word2Vec vec, @NonNull String path) {
    WeightLookupTable<VocabWord> lookupTable = vec.getLookupTable();
    // validate before touching the filesystem: no empty file, no leaked writer on failure
    if (!(lookupTable instanceof InMemoryLookupTable))
        throw new IllegalStateException("At this moment only InMemoryLookupTable is supported.");
    InMemoryLookupTable<VocabWord> table = (InMemoryLookupTable<VocabWord>) lookupTable;
    VocabCache<VocabWord> vocabCache = vec.getVocab();
    VectorsConfiguration conf = vec.getConfiguration();

    PrintWriter printWriter;
    try {
        printWriter = new PrintWriter(new OutputStreamWriter(new FileOutputStream(path),
                java.nio.charset.StandardCharsets.UTF_8));
    } catch (FileNotFoundException e) {
        throw new RuntimeException(e);
    }

    try {
        // line 0: configuration
        conf.setVocabSize(vocabCache.numWords());
        printWriter.println(conf.toJson());

        // line 1: expTable, saved only for a possible future use; loaders recompute it
        StringBuilder builder = new StringBuilder();
        for (double value : table.getExpTable()) {
            builder.append(value).append(' ');
        }
        printWriter.println(builder.toString().trim());

        // line 2: negative-sampling table, available only if negative sampling is used
        INDArray negTable = table.getTable();
        if (conf.getNegative() > 0 && negTable != null) {
            builder = new StringBuilder();
            for (int x = 0; x < negTable.columns(); x++) {
                builder.append(negTable.getDouble(x)).append(' ');
            }
            printWriter.println(builder.toString().trim());
        } else {
            printWriter.println("");
        }

        // remaining lines: one VocabularyWord JSON per vocabulary entry
        boolean writeSyn1Neg = conf.getNegative() > 0 && table.getSyn1Neg() != null;
        boolean writeAdaGrad = conf.isUseAdaGrad() && table.isUseAdaGrad();
        List<VocabWord> words = new ArrayList<>(vocabCache.vocabWords());
        for (SequenceElement word : words) {
            int row = vocabCache.indexOf(word.getLabel());
            VocabularyWord vw = new VocabularyWord(word.getLabel());
            vw.setCount(vocabCache.wordFrequency(word.getLabel()));
            vw.setHuffmanNode(VocabularyHolder.buildNode(word.getCodes(), word.getPoints(),
                    word.getCodeLength(), word.getIndex()));
            // all weight rows are copied over their full width
            vw.setSyn0(rowToDoubleArray(table.getSyn0().getRow(row)));
            vw.setSyn1(rowToDoubleArray(table.getSyn1().getRow(row)));
            if (writeSyn1Neg) {
                vw.setSyn1Neg(rowToDoubleArray(table.getSyn1Neg().getRow(row)));
            }
            // with AdaGrad enabled, persist each word's gradient history (zeros if none yet)
            if (writeAdaGrad) {
                INDArray gradient = word.getHistoricalGradient();
                if (gradient == null)
                    gradient = Nd4j.zeros(word.getCodes().size());
                vw.setHistoricalGradient(rowToDoubleArray(gradient));
            }
            printWriter.println(vw.toJson());
        }
    } finally {
        printWriter.close();
    }
    // PrintWriter never throws on write failure; surface it instead of leaving a truncated model
    if (printWriter.checkError())
        throw new RuntimeException("Failed to write Word2Vec model to " + path);
}

/** Copies every column of a row vector into a new {@code double[]} of the same width. */
private static double[] rowToDoubleArray(INDArray row) {
    double[] values = new double[row.columns()];
    for (int x = 0; x < values.length; x++) {
        values[x] = row.getDouble(x);
    }
    return values;
}
Also used : ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) SequenceElement(org.deeplearning4j.models.sequencevectors.sequence.SequenceElement) ArrayList(java.util.ArrayList) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) DL4JInvalidInputException(org.deeplearning4j.exception.DL4JInvalidInputException) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VocabularyWord(org.deeplearning4j.models.word2vec.wordstore.VocabularyWord)

Aggregations

InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable)2 VocabWord (org.deeplearning4j.models.word2vec.VocabWord)2 VocabularyWord (org.deeplearning4j.models.word2vec.wordstore.VocabularyWord)2 INDArray (org.nd4j.linalg.api.ndarray.INDArray)2 ArrayList (java.util.ArrayList)1 AtomicInteger (java.util.concurrent.atomic.AtomicInteger)1 ZipFile (java.util.zip.ZipFile)1 DL4JInvalidInputException (org.deeplearning4j.exception.DL4JInvalidInputException)1 BasicModelUtils (org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils)1 SequenceElement (org.deeplearning4j.models.sequencevectors.sequence.SequenceElement)1 StaticWord2Vec (org.deeplearning4j.models.word2vec.StaticWord2Vec)1 Word2Vec (org.deeplearning4j.models.word2vec.Word2Vec)1 VocabularyHolder (org.deeplearning4j.models.word2vec.wordstore.VocabularyHolder)1 AbstractCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache)1 BasicLineIterator (org.deeplearning4j.text.sentenceiterator.BasicLineIterator)1 ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException)1