Usage of org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache in the deeplearning4j project (by deeplearning4j).
Class VocabCacheExporter, method export.
/**
 * Collects every exported word/vector pair from the given RDD and assembles them into an
 * in-memory vocabulary ({@code vocabCache}), lookup table ({@code lookupTable}) and the
 * {@code word2Vec} model built from them.
 *
 * <p>The whole RDD is pulled into local memory via {@code collect()}, so this is only suitable
 * for small vocabularies such as those used in tests.
 *
 * <p>Each word's {@link VocabWord#getIndex()} selects its row in {@code syn0}, which has exactly
 * one row per collected element; indices are therefore expected to lie in
 * {@code [0, number of elements)} — presumably assigned contiguously upstream (TODO confirm).
 * Existing {@code vocabCache} / {@code lookupTable} instances are reused when already set.
 *
 * @param rdd word/vector containers to export
 * @throws IllegalStateException if the RDD contains no elements (so no vector length can be
 *     inferred), or if a word's index falls outside the rows allocated for {@code syn0}
 */
@Override
public void export(JavaRDD<ExportContainer<VocabWord>> rdd) {
    // beware, generally that's VERY bad idea, but will work fine for testing purposes
    List<ExportContainer<VocabWord>> list = rdd.collect();
    if (list.isEmpty()) {
        // Without at least one vector we cannot size syn0; fail clearly instead of an NPE later.
        throw new IllegalStateException("Cannot export an empty RDD: no word vectors were collected");
    }

    if (vocabCache == null) {
        vocabCache = new AbstractCache<>();
    }

    // One row per collected word; the vector length is taken from the first element and all
    // vectors are expected to share it.
    INDArray syn0 = Nd4j.create(list.size(), list.get(0).getArray().length());

    // just roll through list
    for (ExportContainer<VocabWord> element : list) {
        VocabWord word = element.getElement();
        INDArray weights = element.getArray();
        int index = word.getIndex();
        if (index < 0 || index >= list.size()) {
            // Guards against unassigned or non-contiguous indices, which would otherwise fail
            // (or address an unintended row) inside the matrix access below.
            throw new IllegalStateException("Word '" + word.getLabel() + "' has index " + index
                    + ", outside the valid row range [0, " + list.size() + ")");
        }
        vocabCache.addToken(word);
        vocabCache.addWordToIndex(index, word.getLabel());
        syn0.getRow(index).assign(weights);
    }

    if (lookupTable == null) {
        lookupTable = new InMemoryLookupTable.Builder<VocabWord>().cache(vocabCache).vectorLength(syn0.columns()).build();
    }
    lookupTable.setSyn0(syn0);

    // this is bad & dirty, but we don't really need anything else for testing :)
    word2Vec = WordVectorSerializer.fromPair(Pair.<InMemoryLookupTable, VocabCache>makePair(lookupTable, vocabCache));
}
Aggregations