use of org.deeplearning4j.models.sequencevectors.sequence.SequenceElement in project deeplearning4j by deeplearning4j.
the class AbstractCache method removeElement.
/**
 * Removes the element registered under the given label from this cache.
 *
 * The element's frequency (truncated to a long) is subtracted from the total word
 * count, and the element is dropped from the index map, the label-keyed vocabulary
 * and the storage-id-keyed vocabulary.
 *
 * @param label label of the element to remove
 * @throws IllegalStateException if no element is registered under the label
 */
@Override
public void removeElement(String label) {
    // Reject unknown labels up front, before any of the maps is touched.
    if (!extendedVocabulary.containsKey(label)) {
        throw new IllegalStateException("Can't get label: '" + label + "'");
    }

    SequenceElement target = extendedVocabulary.get(label);

    // Subtract this element's share from the running total.
    long frequencyDelta = -((long) target.getElementFrequency());
    totalWordCount.getAndAdd(frequencyDelta);

    // Drop every reference to the element: by index, by label and by storage id.
    idxMap.remove(target.getIndex());
    extendedVocabulary.remove(label);
    vocabulary.remove(target.getStorageId());
}
use of org.deeplearning4j.models.sequencevectors.sequence.SequenceElement in project deeplearning4j by deeplearning4j.
the class WordVectorSerializer method writeFullModel.
/**
 * Saves the full Word2Vec model in a way that allows model updates without the model
 * being rebuilt from scratch.
 *
 * <p>The output is a line-oriented text file:
 * <ul>
 * <li>Line 0 - VectorsConfiguration JSON string</li>
 * <li>Line 1 - expTable values, separated by single spaces</li>
 * <li>Line 2 - negative sampling table values, separated by single spaces; an empty
 * line if negative sampling is not used or the table is absent</li>
 * <li>All following lines - vocab/weight lookup table saved line by line as
 * VocabularyWord JSON representation (word count, Huffman node, syn0, syn1 and,
 * where applicable, syn1Neg and the historical AdaGrad gradient)</li>
 * </ul>
 *
 * <p>The lookup table type is checked before the target file is opened, so an
 * unsupported table neither creates nor truncates the file. The writer is always
 * closed, including when serialization fails part-way through.
 *
 * @param vec - The Word2Vec instance to be saved
 * @param path - the path for json to be saved
 * @throws IllegalStateException if the model's lookup table is not an InMemoryLookupTable
 * @throws RuntimeException if the target file can't be opened for writing
 * @deprecated Please, consider using writeWord2VecModel() method instead
 */
@Deprecated
public static void writeFullModel(@NonNull Word2Vec vec, @NonNull String path) {
    /*
        Basically we need to save:
                1. WeightLookupTable, especially syn0 and syn1 matrices
                2. VocabCache, including only WordCounts
                3. Settings from Word2Vec model: workers, layers, etc.
     */
    WeightLookupTable<VocabWord> lookupTable = vec.getLookupTable();
    if (!(lookupTable instanceof InMemoryLookupTable))
        throw new IllegalStateException("At this moment only InMemoryLookupTable is supported.");

    // Cast once instead of repeating it at every access below.
    InMemoryLookupTable<VocabWord> inMemoryTable = (InMemoryLookupTable<VocabWord>) lookupTable;
    VocabCache<VocabWord> vocabCache = vec.getVocab();

    // try-with-resources guarantees the file handle is released on every exit path;
    // closing the PrintWriter also flushes everything written so far.
    try (PrintWriter printWriter = openFullModelWriter(path)) {
        VectorsConfiguration conf = vec.getConfiguration();
        conf.setVocabSize(vocabCache.numWords());
        printWriter.println(conf.toJson());

        // Actually we don't need expTable, since it produces exact results on subsequent
        // runs as long as the expTable size isn't modified.
        // Saving expTable just for a "special case in future".
        double[] expTable = inMemoryTable.getExpTable();
        StringBuilder builder = new StringBuilder();
        for (int x = 0; x < expTable.length; x++) {
            builder.append(expTable[x]).append(" ");
        }
        printWriter.println(builder.toString().trim());

        // saving table, available only if negative sampling is used
        INDArray negativeTable = inMemoryTable.getTable();
        if (conf.getNegative() > 0 && negativeTable != null) {
            builder = new StringBuilder();
            for (int x = 0; x < negativeTable.columns(); x++) {
                builder.append(negativeTable.getDouble(x)).append(" ");
            }
            printWriter.println(builder.toString().trim());
        } else
            printWriter.println("");

        // Iterate over a snapshot copy of the vocabulary.
        List<VocabWord> words = new ArrayList<>(vocabCache.vocabWords());
        for (SequenceElement word : words) {
            String label = word.getLabel();
            // The row index is the same for syn0, syn1 and syn1Neg, so look it up once.
            int row = vocabCache.indexOf(label);

            VocabularyWord vw = new VocabularyWord(label);
            vw.setCount(vocabCache.wordFrequency(label));
            vw.setHuffmanNode(VocabularyHolder.buildNode(word.getCodes(), word.getPoints(),
                            word.getCodeLength(), word.getIndex()));

            // writing down syn0: every column of the row is copied, so the array size and
            // the number of copied values can't disagree with the configured layer size
            vw.setSyn0(rowToDoubleArray(inMemoryTable.getSyn0().getRow(row)));

            // writing down syn1
            vw.setSyn1(rowToDoubleArray(inMemoryTable.getSyn1().getRow(row)));

            // writing down syn1Neg, if negative sampling is used
            if (conf.getNegative() > 0 && inMemoryTable.getSyn1Neg() != null) {
                vw.setSyn1Neg(rowToDoubleArray(inMemoryTable.getSyn1Neg().getRow(row)));
            }

            // in case of UseAdaGrad == true - we should save gradients for each word in vocab
            if (conf.isUseAdaGrad() && inMemoryTable.isUseAdaGrad()) {
                INDArray gradient = word.getHistoricalGradient();
                if (gradient == null)
                    gradient = Nd4j.zeros(word.getCodes().size());
                vw.setHistoricalGradient(rowToDoubleArray(gradient));
            }

            printWriter.println(vw.toJson());
        }
        // at this moment we have whole vocab serialized
    }
}

/**
 * Opens a UTF-8 PrintWriter on the given path, wrapping any failure in a RuntimeException.
 */
private static PrintWriter openFullModelWriter(String path) {
    try {
        return new PrintWriter(new OutputStreamWriter(new FileOutputStream(path), "UTF-8"));
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}

/**
 * Copies the first {@code columns()} values of the given array into a new double array.
 */
private static double[] rowToDoubleArray(INDArray row) {
    double[] values = new double[row.columns()];
    for (int x = 0; x < values.length; x++) {
        values[x] = row.getDouble(x);
    }
    return values;
}
Aggregations