Search in sources :

Example 1 with SequenceElement

Use of org.deeplearning4j.models.sequencevectors.sequence.SequenceElement in the project deeplearning4j (by deeplearning4j).

In the class AbstractCache, the method removeElement:

@Override
public void removeElement(String label) {
    // Unknown label: nothing to remove, signal the inconsistent state to the caller.
    if (!extendedVocabulary.containsKey(label)) {
        throw new IllegalStateException("Can't get label: '" + label + "'");
    }
    SequenceElement element = extendedVocabulary.get(label);
    // Roll the element's frequency out of the running total before dropping it.
    totalWordCount.getAndAdd((long) element.getElementFrequency() * -1);
    // Purge every view of the element: index map, label map, and storage-id map.
    idxMap.remove(element.getIndex());
    extendedVocabulary.remove(label);
    vocabulary.remove(element.getStorageId());
}
Also used : SequenceElement(org.deeplearning4j.models.sequencevectors.sequence.SequenceElement)

Example 2 with SequenceElement

Use of org.deeplearning4j.models.sequencevectors.sequence.SequenceElement in the project deeplearning4j (by deeplearning4j).

In the class WordVectorSerializer, the method writeFullModel:

/**
     * Saves the full Word2Vec model in a way that allows model updates without it
     * being rebuilt from scratch.
     *
     * Deprecation note: Please consider using the writeWord2VecModel() method instead.
     *
     * @param vec - the Word2Vec instance to be saved
     * @param path - the path the JSON model will be written to
     * @throws RuntimeException if the target file cannot be opened for writing
     * @throws IllegalStateException if the model's lookup table is not an InMemoryLookupTable
     */
@Deprecated
public static void writeFullModel(@NonNull Word2Vec vec, @NonNull String path) {
    /*
            Basically we need to save:
                    1. WeightLookupTable, especially syn0 and syn1 matrices
                    2. VocabCache, including only WordCounts
                    3. Settings from the Word2Vec model: workers, layers, etc.
         */
    PrintWriter printWriter = null;
    try {
        printWriter = new PrintWriter(new OutputStreamWriter(new FileOutputStream(path), "UTF-8"));
    } catch (Exception e) {
        // Wrap the checked I/O failure: callers of this deprecated API expect unchecked.
        throw new RuntimeException(e);
    }
    // try/finally added: the original leaked the PrintWriter whenever any exception
    // was thrown between opening the stream and the final close().
    try {
        WeightLookupTable<VocabWord> lookupTable = vec.getLookupTable();
        VocabCache<VocabWord> vocabCache = vec.getVocab();
        if (!(lookupTable instanceof InMemoryLookupTable))
            throw new IllegalStateException("At this moment only InMemoryLookupTable is supported.");
        // Cast once up front instead of re-casting on every single access below.
        InMemoryLookupTable inMemoryTable = (InMemoryLookupTable) lookupTable;
        VectorsConfiguration conf = vec.getConfiguration();
        conf.setVocabSize(vocabCache.numWords());
        printWriter.println(conf.toJson());
        /*
            File layout:
            Line 0 - VectorsConfiguration JSON string
            Line 1 - expTable
            Line 2 - table (empty line when negative sampling is disabled)

            All following lines are vocab/weight lookup table entries saved line by line
            as VocabularyWord JSON representations.
         */
        // We don't strictly need expTable, since it produces identical results on
        // subsequent runs as long as its size is unchanged; saved for future-proofing.
        StringBuilder builder = new StringBuilder();
        for (double value : inMemoryTable.getExpTable()) {
            builder.append(value).append(" ");
        }
        printWriter.println(builder.toString().trim());
        // The sampling table is only available when negative sampling is used.
        if (conf.getNegative() > 0 && inMemoryTable.getTable() != null) {
            builder = new StringBuilder();
            for (int x = 0; x < inMemoryTable.getTable().columns(); x++) {
                builder.append(inMemoryTable.getTable().getDouble(x)).append(" ");
            }
            printWriter.println(builder.toString().trim());
        } else {
            printWriter.println("");
        }
        List<VocabWord> words = new ArrayList<>(vocabCache.vocabWords());
        for (SequenceElement word : words) {
            VocabularyWord vw = new VocabularyWord(word.getLabel());
            vw.setCount(vocabCache.wordFrequency(word.getLabel()));
            vw.setHuffmanNode(VocabularyHolder.buildNode(word.getCodes(), word.getPoints(),
                            word.getCodeLength(), word.getIndex()));
            int wordIndex = vocabCache.indexOf(word.getLabel());
            // Fix: the original iterated conf.getLayersSize() over an array sized
            // syn0.columns(), risking ArrayIndexOutOfBoundsException if the config and
            // the table disagree; now syn0 is copied by its actual row width, matching
            // how syn1 and syn1Neg were already handled.
            vw.setSyn0(rowAsDoubleArray(inMemoryTable.getSyn0().getRow(wordIndex)));
            vw.setSyn1(rowAsDoubleArray(inMemoryTable.getSyn1().getRow(wordIndex)));
            // syn1Neg exists only when negative sampling is used.
            if (conf.getNegative() > 0 && inMemoryTable.getSyn1Neg() != null) {
                vw.setSyn1Neg(rowAsDoubleArray(inMemoryTable.getSyn1Neg().getRow(wordIndex)));
            }
            // With AdaGrad enabled we also persist each word's historical gradient.
            if (conf.isUseAdaGrad() && inMemoryTable.isUseAdaGrad()) {
                INDArray gradient = word.getHistoricalGradient();
                if (gradient == null)
                    gradient = Nd4j.zeros(word.getCodes().size());
                vw.setHistoricalGradient(rowAsDoubleArray(gradient));
            }
            printWriter.println(vw.toJson());
        }
        // At this moment the whole vocab is serialized.
        printWriter.flush();
    } finally {
        printWriter.close();
    }
}

/** Copies a row vector into a plain double[] sized by the row's column count. */
private static double[] rowAsDoubleArray(INDArray row) {
    double[] result = new double[row.columns()];
    for (int x = 0; x < row.columns(); x++) {
        result[x] = row.getDouble(x);
    }
    return result;
}
Also used : ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) SequenceElement(org.deeplearning4j.models.sequencevectors.sequence.SequenceElement) ArrayList(java.util.ArrayList) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) DL4JInvalidInputException(org.deeplearning4j.exception.DL4JInvalidInputException) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VocabularyWord(org.deeplearning4j.models.word2vec.wordstore.VocabularyWord)

Aggregations

SequenceElement (org.deeplearning4j.models.sequencevectors.sequence.SequenceElement)2 ArrayList (java.util.ArrayList)1 DL4JInvalidInputException (org.deeplearning4j.exception.DL4JInvalidInputException)1 InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable)1 VocabWord (org.deeplearning4j.models.word2vec.VocabWord)1 VocabularyWord (org.deeplearning4j.models.word2vec.wordstore.VocabularyWord)1 INDArray (org.nd4j.linalg.api.ndarray.INDArray)1 ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException)1