use of org.nd4j.linalg.api.ndarray.INDArray in project deeplearning4j by deeplearning4j.
the class WordVectorSerializer method readWord2VecModel.
/**
 * This method loads a Word2Vec model stored in one of the following formats:
 * 1) Binary model, either compressed or not, like the well-known Google model
 * 2) Popular CSV word2vec text format
 * 3) DL4j compressed format
 *
 * Please note: if extended data isn't available, only weights will be loaded instead.
 *
 * @param file model file
 * @param extendedModel if TRUE, we'll try to load HS states & Huffman tree info; if FALSE, only weights will be loaded
 * @return restored Word2Vec model
 */
public static Word2Vec readWord2VecModel(@NonNull File file, boolean extendedModel) {
InMemoryLookupTable<VocabWord> lookupTable = new InMemoryLookupTable<>();
AbstractCache<VocabWord> vocabCache = new AbstractCache<>();
Word2Vec vec;
INDArray syn0 = null;
VectorsConfiguration configuration = new VectorsConfiguration();
if (!file.exists() || !file.isFile())
throw new ND4JIllegalStateException("File [" + file.getAbsolutePath() + "] doesn't exist");
int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
if (originalPeriodic)
Nd4j.getMemoryManager().togglePeriodicGc(false);
Nd4j.getMemoryManager().setOccasionalGcFrequency(50000);
// try to load zip format
try {
if (extendedModel) {
log.debug("Trying full model restoration...");
if (originalPeriodic)
Nd4j.getMemoryManager().togglePeriodicGc(true);
Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
return readWord2Vec(file);
} else {
log.debug("Trying simplified model restoration...");
File tmpFileSyn0 = File.createTempFile("word2vec", "syn");
File tmpFileConfig = File.createTempFile("word2vec", "config");
// we don't need the full model, so we go directly to the syn0 entry
ZipFile zipFile = new ZipFile(file);
ZipEntry syn = zipFile.getEntry("syn0.txt");
InputStream stream = zipFile.getInputStream(syn);
Files.copy(stream, Paths.get(tmpFileSyn0.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING);
// now we're restoring configuration saved earlier
ZipEntry config = zipFile.getEntry("config.json");
if (config != null) {
stream = zipFile.getInputStream(config);
StringBuilder builder = new StringBuilder();
try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream))) {
String line;
while ((line = reader.readLine()) != null) {
builder.append(line);
}
}
configuration = VectorsConfiguration.fromJson(builder.toString().trim());
}
ZipEntry ve = zipFile.getEntry("frequencies.txt");
if (ve != null) {
stream = zipFile.getInputStream(ve);
AtomicInteger cnt = new AtomicInteger(0);
try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream))) {
String line;
while ((line = reader.readLine()) != null) {
String[] split = line.split(" ");
VocabWord word = new VocabWord(Double.valueOf(split[1]), decodeB64(split[0]));
word.setIndex(cnt.getAndIncrement());
word.incrementSequencesCount(Long.valueOf(split[2]));
vocabCache.addToken(word);
vocabCache.addWordToIndex(word.getIndex(), word.getLabel());
Nd4j.getMemoryManager().invokeGcOccasionally();
}
}
}
List<INDArray> rows = new ArrayList<>();
// basically read everything, call vstack and then return the model
try (Reader reader = new CSVReader(tmpFileSyn0)) {
AtomicInteger cnt = new AtomicInteger(0);
while (reader.hasNext()) {
Pair<VocabWord, float[]> pair = reader.next();
VocabWord word = pair.getFirst();
INDArray vector = Nd4j.create(pair.getSecond());
if (ve != null) {
if (syn0 == null)
syn0 = Nd4j.create(vocabCache.numWords(), vector.length());
syn0.getRow(cnt.getAndIncrement()).assign(vector);
} else {
rows.add(vector);
vocabCache.addToken(word);
vocabCache.addWordToIndex(word.getIndex(), word.getLabel());
}
Nd4j.getMemoryManager().invokeGcOccasionally();
}
} catch (Exception e) {
throw new RuntimeException(e);
} finally {
if (originalPeriodic)
Nd4j.getMemoryManager().togglePeriodicGc(true);
Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
}
if (syn0 == null && vocabCache.numWords() > 0)
syn0 = Nd4j.vstack(rows);
if (syn0 == null) {
log.error("Can't build syn0 table");
throw new DL4JInvalidInputException("Can't build syn0 table");
}
lookupTable = new InMemoryLookupTable.Builder<VocabWord>()
        .cache(vocabCache)
        .vectorLength(syn0.columns())
        .useHierarchicSoftmax(false)
        .useAdaGrad(false)
        .build();
lookupTable.setSyn0(syn0);
try {
tmpFileSyn0.delete();
tmpFileConfig.delete();
} catch (Exception e) {
//
}
}
} catch (Exception e) {
// let's try to load this file as a CSV file
try {
log.debug("Trying CSV model restoration...");
Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(file);
lookupTable = pair.getFirst();
vocabCache = (AbstractCache<VocabWord>) pair.getSecond();
} catch (Exception ex) {
// we fall back to trying the binary model instead
try {
log.debug("Trying binary model restoration...");
if (originalPeriodic)
Nd4j.getMemoryManager().togglePeriodicGc(true);
Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
vec = loadGoogleModel(file, true, true);
return vec;
} catch (Exception ey) {
// try to load without linebreaks
try {
if (originalPeriodic)
Nd4j.getMemoryManager().togglePeriodicGc(true);
Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
vec = loadGoogleModel(file, true, false);
return vec;
} catch (Exception ez) {
throw new RuntimeException("Unable to guess input file format. Please use corresponding loader directly");
}
}
}
}
Word2Vec.Builder builder = new Word2Vec.Builder(configuration)
        .lookupTable(lookupTable)
        .useAdaGrad(false)
        .vocabCache(vocabCache)
        .layerSize(lookupTable.layerSize())
        .useHierarchicSoftmax(false)
        .resetModel(false);
/*
Trying to restore TokenizerFactory & TokenPreProcessor
*/
TokenizerFactory factory = getTokenizerFactory(configuration);
if (factory != null)
builder.tokenizerFactory(factory);
vec = builder.build();
return vec;
}
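For illustration, a minimal usage sketch of this loader. The file path is a placeholder and extendedModel=false means only weights are restored; neither appears in the original snippet.
// Hypothetical usage: restore weights only from a previously saved model archive
File modelFile = new File("/path/to/word2vec-model.zip"); // placeholder path
Word2Vec w2v = WordVectorSerializer.readWord2VecModel(modelFile, false);
// the restored model can answer nearest-neighbour queries right away
System.out.println(w2v.wordsNearest("day", 10));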
use of org.nd4j.linalg.api.ndarray.INDArray in project deeplearning4j by deeplearning4j.
the class BasicModelUtils method wordsNearest.
/**
 * Words nearest to the given vector (typically the mean of positive and negative word vectors)
 *
 * @param words the query vector to search around
 * @param top the top n words
 * @return the words nearest the mean of the words
 */
@Override
public Collection<String> wordsNearest(INDArray words, int top) {
if (lookupTable instanceof InMemoryLookupTable) {
InMemoryLookupTable l = (InMemoryLookupTable) lookupTable;
INDArray syn0 = l.getSyn0();
if (!normalized) {
synchronized (this) {
if (!normalized) {
syn0.diviColumnVector(syn0.norm2(1));
normalized = true;
}
}
}
INDArray similarity = Transforms.unitVec(words).mmul(syn0.transpose());
List<Double> highToLowSimList = getTopN(similarity, top + 20);
List<WordSimilarity> result = new ArrayList<>();
for (int i = 0; i < highToLowSimList.size(); i++) {
String word = vocabCache.wordAtIndex(highToLowSimList.get(i).intValue());
if (word != null && !word.equals("UNK") && !word.equals("STOP")) {
INDArray otherVec = lookupTable.vector(word);
double sim = Transforms.cosineSim(words, otherVec);
result.add(new WordSimilarity(word, sim));
}
}
Collections.sort(result, new SimilarityComparator());
return getLabels(result, top);
}
Counter<String> distances = new Counter<>();
for (String s : vocabCache.words()) {
INDArray otherVec = lookupTable.vector(s);
double sim = Transforms.cosineSim(words, otherVec);
distances.incrementCount(s, sim);
}
distances.keepTopNKeys(top);
return distances.keySet();
}
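A hedged sketch of how this overload is typically reached: the caller supplies a single query vector, often the mean of several word vectors. The variable names below (lookupTable, modelUtils) are assumptions.
// Hypothetical direct call with the vector of one word
INDArray queryVector = lookupTable.vector("queen");
Collection<String> nearest = modelUtils.wordsNearest(queryVector, 5);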
use of org.nd4j.linalg.api.ndarray.INDArray in project deeplearning4j by deeplearning4j.
the class BasicModelUtils method similarity.
/**
 * Returns the similarity of two words. The result is in the range [-1,1], where -1.0 means the vectors point in exactly opposite directions (no similarity) and 1.0 means the two word vectors match exactly.
 * However, most of the time you'll see values in the range [0,1]; that depends on the training corpus.
 *
 * Returns NaN if either label is missing from the vocabulary, or if either label is null.
 *
 * @param label1 the first word
 * @param label2 the second word
 * @return a normalized similarity (cosine similarity)
 */
@Override
public double similarity(@NonNull String label1, @NonNull String label2) {
if (label1 == null || label2 == null) {
log.debug("LABELS: " + label1 + ": " + (label1 == null ? "null" : EXISTS) + ";" + label2 + " vec2:" + (label2 == null ? "null" : EXISTS));
return Double.NaN;
}
if (!vocabCache.hasToken(label1)) {
log.debug("Unknown token 1 requested: [{}]", label1);
return Double.NaN;
}
if (!vocabCache.hasToken(label2)) {
log.debug("Unknown token 2 requested: [{}]", label2);
return Double.NaN;
}
INDArray vec1 = lookupTable.vector(label1).dup();
INDArray vec2 = lookupTable.vector(label2).dup();
if (vec1 == null || vec2 == null) {
log.debug(label1 + ": " + (vec1 == null ? "null" : EXISTS) + ";" + label2 + " vec2:" + (vec2 == null ? "null" : EXISTS));
return Double.NaN;
}
if (label1.equals(label2))
return 1.0;
return Transforms.cosineSim(vec1, vec2);
}
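A brief usage sketch; vec is assumed to be a trained Word2Vec model backed by this ModelUtils implementation.
// Hypothetical usage: cosine similarity between two in-vocabulary words
double sim = vec.similarity("day", "night");
if (Double.isNaN(sim)) {
    // at least one of the labels is missing from the vocabulary
}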
use of org.nd4j.linalg.api.ndarray.INDArray in project deeplearning4j by deeplearning4j.
the class TreeModelUtils method wordsNearest.
@Override
public Collection<String> wordsNearest(Collection<String> positive, Collection<String> negative, int top) {
// Check every word is in the model
for (String p : SetUtils.union(new HashSet<>(positive), new HashSet<>(negative))) {
if (!vocabCache.containsWord(p)) {
return new ArrayList<>();
}
}
INDArray words = Nd4j.create(positive.size() + negative.size(), lookupTable.layerSize());
int row = 0;
for (String s : positive) {
words.putRow(row++, lookupTable.vector(s));
}
for (String s : negative) {
words.putRow(row++, lookupTable.vector(s).mul(-1));
}
INDArray mean = words.isMatrix() ? words.mean(0) : words;
return wordsNearest(mean, top);
}
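For context, a hedged example of the positive/negative query this method answers; vec is assumed to be a model whose modelUtils is a TreeModelUtils, with java.util.Arrays and java.util.Collections imported.
// Hypothetical analogy query: vector("king") - vector("man") + vector("woman") ~ "queen"
Collection<String> analogues = vec.wordsNearest(
        Arrays.asList("king", "woman"),   // positive terms
        Collections.singletonList("man"), // negative term
        10);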
use of org.nd4j.linalg.api.ndarray.INDArray in project deeplearning4j by deeplearning4j.
the class GloVe method iterateSample.
private double iterateSample(T element1, T element2, double score) {
//prediction: input + bias
if (element1.getIndex() < 0 || element1.getIndex() >= syn0.rows())
throw new IllegalArgumentException("Illegal index for word " + element1.getLabel());
if (element2.getIndex() < 0 || element2.getIndex() >= syn0.rows())
throw new IllegalArgumentException("Illegal index for word " + element2.getLabel());
INDArray w1Vector = syn0.slice(element1.getIndex());
INDArray w2Vector = syn0.slice(element2.getIndex());
//w1 * w2 + bias
double prediction = Nd4j.getBlasWrapper().dot(w1Vector, w2Vector);
prediction += bias.getDouble(element1.getIndex()) + bias.getDouble(element2.getIndex()) - Math.log(score);
// Math.pow(Math.min(1.0,(score / maxCount)),xMax);
double fDiff = (score > xMax) ? prediction : Math.pow(score / xMax, alpha) * prediction;
if (Double.isNaN(fDiff))
fDiff = Nd4j.EPS_THRESHOLD;
//amount of change
double gradient = fDiff * learningRate;
//note the update step here: the gradient is
//the gradient of the OPPOSITE word
//for adagrad we will use the index of the word passed in
//for the gradient calculation we will use the context vector
update(element1, w1Vector, w2Vector, gradient);
update(element2, w2Vector, w1Vector, gradient);
return 0.5 * fDiff * prediction;
}
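The fDiff term above mirrors the standard GloVe weighting f(x) = (x / xMax)^alpha for x < xMax and 1 otherwise, applied to the raw prediction error. A standalone sketch of that weighting follows; xMax = 100 and alpha = 0.75 are the common defaults and are assumptions here, not values taken from the snippet.
// Standalone sketch of the GloVe weighting applied to the prediction error (defaults assumed)
static double gloveWeight(double cooccurrenceCount, double xMax, double alpha) {
    return cooccurrenceCount > xMax ? 1.0 : Math.pow(cooccurrenceCount / xMax, alpha);
}
// in iterateSample: fDiff = gloveWeight(score, xMax, alpha) * prediction,
// where prediction = w1 · w2 + bias1 + bias2 - log(score)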