Use of org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable in project deeplearning4j by deeplearning4j.
The class WordVectorSerializer, method readWord2Vec.
/**
 * This method restores a Word2Vec model previously saved with writeWord2VecModel.
 *
 * PLEASE NOTE: This method loads the FULL model, so don't use it if you're only going to use the weights.
 *
 * @param file the model file previously written by writeWord2VecModel
 * @return the restored Word2Vec model
 * @throws IOException if the archive can't be read
 */
@Deprecated
public static Word2Vec readWord2Vec(File file) throws IOException {
    File tmpFileSyn0 = File.createTempFile("word2vec", "0");
    File tmpFileSyn1 = File.createTempFile("word2vec", "1");
    File tmpFileC = File.createTempFile("word2vec", "c");
    File tmpFileH = File.createTempFile("word2vec", "h");
    File tmpFileF = File.createTempFile("word2vec", "f");
    tmpFileSyn0.deleteOnExit();
    tmpFileSyn1.deleteOnExit();
    tmpFileH.deleteOnExit();
    tmpFileC.deleteOnExit();
    tmpFileF.deleteOnExit();

    // suspend periodic GC during restoration; the original settings are restored in the finally block
    int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
    boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
    if (originalPeriodic)
        Nd4j.getMemoryManager().togglePeriodicGc(false);
    Nd4j.getMemoryManager().setOccasionalGcFrequency(50000);

    // try-with-resources so the archive is closed even on failure
    try (ZipFile zipFile = new ZipFile(file)) {
        ZipEntry syn0 = zipFile.getEntry("syn0.txt");
        InputStream stream = zipFile.getInputStream(syn0);
        Files.copy(stream, Paths.get(tmpFileSyn0.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING);

        ZipEntry syn1 = zipFile.getEntry("syn1.txt");
        stream = zipFile.getInputStream(syn1);
        Files.copy(stream, Paths.get(tmpFileSyn1.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING);

        ZipEntry codes = zipFile.getEntry("codes.txt");
        stream = zipFile.getInputStream(codes);
        Files.copy(stream, Paths.get(tmpFileC.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING);

        ZipEntry huffman = zipFile.getEntry("huffman.txt");
        stream = zipFile.getInputStream(huffman);
        Files.copy(stream, Paths.get(tmpFileH.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING);

        ZipEntry config = zipFile.getEntry("config.json");
        stream = zipFile.getInputStream(config);
        StringBuilder builder = new StringBuilder();
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream))) {
            String line;
            while ((line = reader.readLine()) != null) {
                builder.append(line);
            }
        }
        VectorsConfiguration configuration = VectorsConfiguration.fromJson(builder.toString().trim());

        // the first four files make up the w2v model itself
        Word2Vec w2v = readWord2VecFromText(tmpFileSyn0, tmpFileSyn1, tmpFileC, tmpFileH, configuration);

        // we read frequencies from frequencies.txt; however, this file may be absent
        ZipEntry frequencies = zipFile.getEntry("frequencies.txt");
        if (frequencies != null) {
            stream = zipFile.getInputStream(frequencies);
            try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream))) {
                String line;
                while ((line = reader.readLine()) != null) {
                    String[] split = line.split(" ");
                    VocabWord word = w2v.getVocab().tokenFor(decodeB64(split[0]));
                    word.setElementFrequency((long) Double.parseDouble(split[1]));
                    word.setSequencesCount((long) Double.parseDouble(split[2]));
                }
            }
        }

        ZipEntry zsyn1Neg = zipFile.getEntry("syn1Neg.txt");
        if (zsyn1Neg != null) {
            stream = zipFile.getInputStream(zsyn1Neg);
            try (InputStreamReader isr = new InputStreamReader(stream);
                 BufferedReader reader = new BufferedReader(isr)) {
                String line;
                List<INDArray> rows = new ArrayList<>();
                while ((line = reader.readLine()) != null) {
                    String[] split = line.split(" ");
                    double[] array = new double[split.length];
                    for (int i = 0; i < split.length; i++) {
                        array[i] = Double.parseDouble(split[i]);
                    }
                    rows.add(Nd4j.create(array));
                }
                // it's possible to have a full model without syn1Neg (trained without negative sampling)
                if (rows.size() > 0) {
                    INDArray syn1Neg = Nd4j.vstack(rows);
                    ((InMemoryLookupTable) w2v.getLookupTable()).setSyn1Neg(syn1Neg);
                }
            }
        }
        return w2v;
    } finally {
        if (originalPeriodic)
            Nd4j.getMemoryManager().togglePeriodicGc(true);
        Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
    }
}
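
A minimal usage sketch of the round trip (the "model.zip" path is hypothetical; the serializer calls are the ones shown in this section):

import java.io.File;
import java.io.IOException;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.word2vec.Word2Vec;

public class FullModelRestore {
    public static void main(String[] args) throws IOException {
        // hypothetical: a model written earlier via WordVectorSerializer.writeWord2VecModel(vec, zip)
        File zip = new File("model.zip");
        // full restore: syn0, syn1, codes, Huffman tree, plus frequencies and syn1Neg when present
        Word2Vec w2v = WordVectorSerializer.readWord2Vec(zip);
        System.out.println("Restored vocab size: " + w2v.getVocab().numWords());
    }
}

Since readWord2Vec is deprecated, the non-deprecated readWord2VecModel(file, true) shown later in this section delegates to it and is the more future-proof entry point.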
Use of org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable in project deeplearning4j by deeplearning4j.
The class WordVectorSerializer, method readSequenceVectors.
/**
 * This method loads a previously saved SequenceVectors model from an InputStream.
 *
 * @param factory the SequenceElementFactory used to deserialize vocabulary elements
 * @param stream the stream to read the model from
 * @param <T> the type of SequenceElement the model was built on
 * @return the restored SequenceVectors model
 * @throws IOException if the stream can't be read
 */
public static <T extends SequenceElement> SequenceVectors<T> readSequenceVectors(@NonNull SequenceElementFactory<T> factory, @NonNull InputStream stream) throws IOException {
    BufferedReader reader = new BufferedReader(new InputStreamReader(stream, "UTF-8"));

    // at first we load the vectors configuration: the first line holds it, Base64-encoded
    String line = reader.readLine();
    VectorsConfiguration configuration = VectorsConfiguration.fromJson(new String(Base64.decodeBase64(line), "UTF-8"));

    // each remaining line is one element plus its vector
    AbstractCache<T> vocabCache = new AbstractCache.Builder<T>().build();
    List<INDArray> rows = new ArrayList<>();
    while ((line = reader.readLine()) != null) {
        ElementPair pair = ElementPair.fromEncodedJson(line);
        T element = factory.deserialize(pair.getObject());
        rows.add(Nd4j.create(pair.getVector()));
        vocabCache.addToken(element);
        vocabCache.addWordToIndex(element.getIndex(), element.getLabel());
    }
    reader.close();

    InMemoryLookupTable<T> lookupTable = (InMemoryLookupTable<T>) new InMemoryLookupTable.Builder<T>()
            .vectorLength(rows.get(0).columns())
            .build();
    // build syn0 in one shot (replaces the older row-by-row putRow loop)
    INDArray syn0 = Nd4j.vstack(rows);
    lookupTable.setSyn0(syn0);

    SequenceVectors<T> vectors = new SequenceVectors.Builder<T>(configuration)
            .vocabCache(vocabCache)
            .lookupTable(lookupTable)
            .resetModel(false)
            .build();
    return vectors;
}
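
A usage sketch, assuming the model was written with the matching writeSequenceVectors method and that the elements are plain VocabWords, for which VocabWordFactory is the stock factory (the file name is hypothetical):

import java.io.FileInputStream;
import java.io.IOException;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.sequencevectors.SequenceVectors;
import org.deeplearning4j.models.sequencevectors.serialization.VocabWordFactory;
import org.deeplearning4j.models.word2vec.VocabWord;

public class SequenceVectorsRestore {
    public static void main(String[] args) throws IOException {
        try (FileInputStream fis = new FileInputStream("seqvec.dat")) { // hypothetical path
            // the factory decides how each serialized element is turned back into a VocabWord
            SequenceVectors<VocabWord> sv =
                    WordVectorSerializer.readSequenceVectors(new VocabWordFactory(), fis);
            System.out.println("Restored elements: " + sv.getVocab().numWords());
        }
    }
}

A custom SequenceElement subtype just needs its own SequenceElementFactory implementation passed in place of VocabWordFactory.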
Use of org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable in project deeplearning4j by deeplearning4j.
The class WordVectorSerializer, method readWord2VecModel.
/**
 * This method restores a Word2Vec model from a file in one of three formats:
 * 1) binary model, either compressed or not, like the well-known Google model;
 * 2) the popular CSV word2vec text format;
 * 3) the DL4J compressed format.
 *
 * Please note: if extended data isn't available, only weights will be loaded.
 *
 * @param file the model file
 * @param extendedModel if TRUE, we'll try to load HS states & Huffman tree info; if FALSE, only weights will be loaded
 * @return the restored Word2Vec model
 */
public static Word2Vec readWord2VecModel(@NonNull File file, boolean extendedModel) {
    InMemoryLookupTable<VocabWord> lookupTable = new InMemoryLookupTable<>();
    AbstractCache<VocabWord> vocabCache = new AbstractCache<>();
    Word2Vec vec;
    INDArray syn0 = null;
    VectorsConfiguration configuration = new VectorsConfiguration();
    if (!file.exists() || !file.isFile())
        throw new ND4JIllegalStateException("File [" + file.getAbsolutePath() + "] doesn't exist");

    // suspend periodic GC while loading; the original settings are restored on every exit path
    int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
    boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
    if (originalPeriodic)
        Nd4j.getMemoryManager().togglePeriodicGc(false);
    Nd4j.getMemoryManager().setOccasionalGcFrequency(50000);

    // try to load zip format first
    try {
        if (extendedModel) {
            log.debug("Trying full model restoration...");
            if (originalPeriodic)
                Nd4j.getMemoryManager().togglePeriodicGc(true);
            Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
            return readWord2Vec(file);
        } else {
            log.debug("Trying simplified model restoration...");
            File tmpFileSyn0 = File.createTempFile("word2vec", "syn");
            File tmpFileConfig = File.createTempFile("word2vec", "config");

            // we don't need the full model, so we go directly to the syn0 file
            ZipFile zipFile = new ZipFile(file);
            ZipEntry syn = zipFile.getEntry("syn0.txt");
            InputStream stream = zipFile.getInputStream(syn);
            Files.copy(stream, Paths.get(tmpFileSyn0.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING);

            // now we're restoring the configuration saved earlier
            ZipEntry config = zipFile.getEntry("config.json");
            if (config != null) {
                stream = zipFile.getInputStream(config);
                StringBuilder builder = new StringBuilder();
                try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream))) {
                    String line;
                    while ((line = reader.readLine()) != null) {
                        builder.append(line);
                    }
                }
                configuration = VectorsConfiguration.fromJson(builder.toString().trim());
            }

            // if frequencies.txt is present, the vocabulary is built from it up front
            ZipEntry ve = zipFile.getEntry("frequencies.txt");
            if (ve != null) {
                stream = zipFile.getInputStream(ve);
                AtomicInteger cnt = new AtomicInteger(0);
                try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream))) {
                    String line;
                    while ((line = reader.readLine()) != null) {
                        String[] split = line.split(" ");
                        VocabWord word = new VocabWord(Double.valueOf(split[1]), decodeB64(split[0]));
                        word.setIndex(cnt.getAndIncrement());
                        word.incrementSequencesCount(Long.valueOf(split[2]));
                        vocabCache.addToken(word);
                        vocabCache.addWordToIndex(word.getIndex(), word.getLabel());
                        Nd4j.getMemoryManager().invokeGcOccasionally();
                    }
                }
            }

            List<INDArray> rows = new ArrayList<>();
            // basically read up everything, call vstack and then return the model
            try (Reader reader = new CSVReader(tmpFileSyn0)) {
                AtomicInteger cnt = new AtomicInteger(0);
                while (reader.hasNext()) {
                    Pair<VocabWord, float[]> pair = reader.next();
                    VocabWord word = pair.getFirst();
                    INDArray vector = Nd4j.create(pair.getSecond());
                    if (ve != null) {
                        // vocab was already populated from frequencies.txt, so write rows in place
                        if (syn0 == null)
                            syn0 = Nd4j.create(vocabCache.numWords(), vector.length());
                        syn0.getRow(cnt.getAndIncrement()).assign(vector);
                    } else {
                        rows.add(vector);
                        vocabCache.addToken(word);
                        vocabCache.addWordToIndex(word.getIndex(), word.getLabel());
                    }
                    Nd4j.getMemoryManager().invokeGcOccasionally();
                }
            } catch (Exception e) {
                throw new RuntimeException(e);
            } finally {
                if (originalPeriodic)
                    Nd4j.getMemoryManager().togglePeriodicGc(true);
                Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
            }

            if (syn0 == null && vocabCache.numWords() > 0)
                syn0 = Nd4j.vstack(rows);
            if (syn0 == null) {
                log.error("Can't build syn0 table");
                throw new DL4JInvalidInputException("Can't build syn0 table");
            }

            lookupTable = new InMemoryLookupTable.Builder<VocabWord>()
                    .cache(vocabCache)
                    .vectorLength(syn0.columns())
                    .useHierarchicSoftmax(false)
                    .useAdaGrad(false)
                    .build();
            lookupTable.setSyn0(syn0);
            try {
                tmpFileSyn0.delete();
                tmpFileConfig.delete();
            } catch (Exception e) {
                //
            }
        }
    } catch (Exception e) {
        // not a zip archive, so let's try to load this file as a CSV file
        try {
            log.debug("Trying CSV model restoration...");
            Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(file);
            lookupTable = pair.getFirst();
            vocabCache = (AbstractCache<VocabWord>) pair.getSecond();
        } catch (Exception ex) {
            // we fall back to trying the binary model instead
            try {
                log.debug("Trying binary model restoration...");
                if (originalPeriodic)
                    Nd4j.getMemoryManager().togglePeriodicGc(true);
                Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
                vec = loadGoogleModel(file, true, true);
                return vec;
            } catch (Exception ey) {
                // try to load without line breaks
                try {
                    if (originalPeriodic)
                        Nd4j.getMemoryManager().togglePeriodicGc(true);
                    Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
                    vec = loadGoogleModel(file, true, false);
                    return vec;
                } catch (Exception ez) {
                    throw new RuntimeException("Unable to guess input file format. Please use corresponding loader directly");
                }
            }
        }
    }

    Word2Vec.Builder builder = new Word2Vec.Builder(configuration)
            .lookupTable(lookupTable)
            .useAdaGrad(false)
            .vocabCache(vocabCache)
            .layerSize(lookupTable.layerSize())
            .useHierarchicSoftmax(false)
            .resetModel(false);

    // trying to restore the TokenizerFactory & TokenPreProcessor
    TokenizerFactory factory = getTokenizerFactory(configuration);
    if (factory != null)
        builder.tokenizerFactory(factory);

    vec = builder.build();
    return vec;
}
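
A usage sketch of the format-guessing loader (the file paths are hypothetical; the Google binary file name is the classic example the Javadoc alludes to):

import java.io.File;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.word2vec.Word2Vec;

public class GuessFormatLoad {
    public static void main(String[] args) {
        // weights only: works for the DL4J zip, CSV text, and Google binary formats alike
        Word2Vec weightsOnly = WordVectorSerializer.readWord2VecModel(
                new File("GoogleNews-vectors-negative300.bin.gz"), false);

        // full restore (HS states, Huffman tree): only available for the DL4J zip format
        Word2Vec full = WordVectorSerializer.readWord2VecModel(new File("model.zip"), true);

        System.out.println("Vector length: " + weightsOnly.getWordVector("day").length); // "day" is a hypothetical vocab entry
    }
}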
Use of org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable in project deeplearning4j by deeplearning4j.
The class BasicModelUtils, method wordsNearest.
/**
 * Returns the words nearest to the given vector, e.g. the mean of a set of positive and negative word vectors.
 *
 * @param words the vector to search around
 * @param top the top n words
 * @return the words nearest the given vector
 */
@Override
public Collection<String> wordsNearest(INDArray words, int top) {
    if (lookupTable instanceof InMemoryLookupTable) {
        InMemoryLookupTable l = (InMemoryLookupTable) lookupTable;
        INDArray syn0 = l.getSyn0();

        // normalize the syn0 rows once, lazily and thread-safely (double-checked locking)
        if (!normalized) {
            synchronized (this) {
                if (!normalized) {
                    syn0.diviColumnVector(syn0.norm2(1));
                    normalized = true;
                }
            }
        }

        // cosine similarity of the query against every row of syn0 in one matrix multiply
        INDArray similarity = Transforms.unitVec(words).mmul(syn0.transpose());
        // over-fetch by 20 so filtered-out tokens (UNK/STOP/nulls) still leave enough candidates
        List<Double> highToLowSimList = getTopN(similarity, top + 20);
        List<WordSimilarity> result = new ArrayList<>();
        for (int i = 0; i < highToLowSimList.size(); i++) {
            String word = vocabCache.wordAtIndex(highToLowSimList.get(i).intValue());
            if (word != null && !word.equals("UNK") && !word.equals("STOP")) {
                INDArray otherVec = lookupTable.vector(word);
                double sim = Transforms.cosineSim(words, otherVec);
                result.add(new WordSimilarity(word, sim));
            }
        }
        Collections.sort(result, new SimilarityComparator());
        return getLabels(result, top);
    }

    // fallback for lookup tables that don't expose syn0: brute-force cosine similarity per word
    Counter<String> distances = new Counter<>();
    for (String s : vocabCache.words()) {
        INDArray otherVec = lookupTable.vector(s);
        double sim = Transforms.cosineSim(words, otherVec);
        distances.incrementCount(s, sim);
    }
    distances.keepTopNKeys(top);
    return distances.keySet();
}
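
A sketch of reaching this method through a loaded model (the path and query word are hypothetical; this assumes the standard WordVectors API, where wordsNearest(INDArray, int) delegates to the model-utils implementation shown above):

import java.io.File;
import java.util.Collection;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.nd4j.linalg.api.ndarray.INDArray;

public class NearestWords {
    public static void main(String[] args) {
        Word2Vec vec = WordVectorSerializer.readWord2VecModel(new File("model.zip"), false); // hypothetical path
        // query with a single word's row vector; any INDArray of the right length works,
        // e.g. the mean of several word vectors
        INDArray query = vec.getWordVectorMatrix("king"); // hypothetical vocab entry
        Collection<String> nearest = vec.wordsNearest(query, 10);
        System.out.println(nearest);
    }
}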
Use of org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable in project deeplearning4j by deeplearning4j.
The class SkipGram, method configure.
/**
 * SkipGram initialization over the given vocabulary and WeightLookupTable.
 *
 * @param vocabCache the vocabulary to train over
 * @param lookupTable the lookup table holding the syn0/syn1/syn1Neg weights
 * @param configuration the training configuration
 */
@Override
public void configure(@NonNull VocabCache<T> vocabCache, @NonNull WeightLookupTable<T> lookupTable, @NonNull VectorsConfiguration configuration) {
    this.vocabCache = vocabCache;
    this.lookupTable = lookupTable;
    this.configuration = configuration;

    InMemoryLookupTable<T> inMemoryTable = (InMemoryLookupTable<T>) lookupTable;

    // negative sampling requires syn1Neg; initialize it lazily if it's missing
    if (configuration.getNegative() > 0) {
        if (inMemoryTable.getSyn1Neg() == null) {
            log.info("Initializing syn1Neg...");
            inMemoryTable.setUseHS(configuration.isUseHierarchicSoftmax());
            inMemoryTable.setNegative(configuration.getNegative());
            inMemoryTable.resetWeights(false);
        }
    }

    // device-local wrappers so each device/thread works against its own view of the weights
    this.expTable = new DeviceLocalNDArray(Nd4j.create(inMemoryTable.getExpTable()));
    this.syn0 = new DeviceLocalNDArray(inMemoryTable.getSyn0());
    this.syn1 = new DeviceLocalNDArray(inMemoryTable.getSyn1());
    this.syn1Neg = new DeviceLocalNDArray(inMemoryTable.getSyn1Neg());
    this.table = new DeviceLocalNDArray(inMemoryTable.getTable());

    this.window = configuration.getWindow();
    this.useAdaGrad = configuration.isUseAdaGrad();
    this.negative = configuration.getNegative();
    this.sampling = configuration.getSampling();
    this.variableWindows = configuration.getVariableWindows();
    this.vectorLength = configuration.getLayersSize();
}
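
A schematic sketch of how configure is wired up (the one-word vocabulary is hypothetical and stands in for a corpus-built one; the VectorsConfiguration setters are assumed to follow its usual Lombok-style accessors; the point is that negative > 0 in the configuration is what triggers the lazy syn1Neg initialization above):

import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;

public class SkipGramWiring {
    public static void main(String[] args) {
        // hypothetical one-word vocabulary, built the same way the serializer code above does it
        AbstractCache<VocabWord> vocab = new AbstractCache.Builder<VocabWord>().build();
        VocabWord word = new VocabWord(1.0, "day");
        word.setIndex(0);
        vocab.addToken(word);
        vocab.addWordToIndex(0, "day");

        InMemoryLookupTable<VocabWord> table = new InMemoryLookupTable.Builder<VocabWord>()
                .cache(vocab).vectorLength(100).build();
        table.resetWeights(); // allocate syn0/syn1 before wiring

        VectorsConfiguration conf = new VectorsConfiguration();
        conf.setNegative(5); // > 0, so configure() will lazily initialize syn1Neg

        SkipGram<VocabWord> skipGram = new SkipGram<>();
        skipGram.configure(vocab, table, conf);
    }
}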