use of org.nd4j.linalg.api.ndarray.INDArray in project deeplearning4j by deeplearning4j.
the class WordVectorSerializer method readTextModel.
/**
 * Reads a word2vec model stored in the plain-text format: a header line
 * {@code "<wordCount> <layerSize>"} followed by one {@code "<word> <v1> ... <vN>"}
 * line per word. Files whose name indicates gzip compression are decompressed on the fly.
 *
 * @param modelFile text (optionally gzipped) model file
 * @return a Word2Vec instance whose vocab and lookup table were filled from the file
 * @throws FileNotFoundException if the file cannot be opened
 * @throws IOException if the file is empty, has a malformed header, or cannot be read
 * @throws NumberFormatException if a header field or vector component is not a valid number
 */
private static Word2Vec readTextModel(File modelFile) throws IOException, NumberFormatException {
    Word2Vec ret = new Word2Vec();
    // Each stream is declared as its own resource so the file handle is released even
    // when the GZIP wrapper fails to construct (e.g. the file is not really gzipped).
    try (InputStream fileStream = new FileInputStream(modelFile);
         InputStream dataStream = GzipUtils.isCompressedFilename(modelFile.getName()) ? new GZIPInputStream(fileStream) : fileStream;
         BufferedReader reader = new BufferedReader(new InputStreamReader(dataStream, "UTF-8"))) {
        String line = reader.readLine();
        if (line == null) {
            throw new IOException("Model file is empty: " + modelFile.getAbsolutePath());
        }
        String[] initial = line.split(" ");
        if (initial.length < 2) {
            throw new IOException("Malformed header (expected \"<words> <layerSize>\") in " + modelFile.getAbsolutePath());
        }
        int words = Integer.parseInt(initial[0]);
        int layerSize = Integer.parseInt(initial[1]);

        INDArray syn0 = Nd4j.create(words, layerSize);
        VocabCache cache = new InMemoryLookupCache(false);
        int currLine = 0;
        while ((line = reader.readLine()) != null) {
            String[] split = line.split(" ");
            // Debug-only sanity check: first token is the word, the rest is the vector.
            assert split.length == layerSize + 1;
            // Whitespace inside a word was replaced on write; restore it here.
            String word = split[0].replaceAll(whitespaceReplacement, " ");
            float[] vector = new float[split.length - 1];
            for (int i = 1; i < split.length; i++) {
                vector[i - 1] = Float.parseFloat(split[i]);
            }
            syn0.putRow(currLine, Nd4j.create(vector));
            // The word's index is its position in the file (current cache size).
            cache.addWordToIndex(cache.numWords(), word);
            cache.addToken(new VocabWord(1, word));
            cache.putVocabWord(word);
            currLine++;
        }

        InMemoryLookupTable lookupTable = (InMemoryLookupTable) new InMemoryLookupTable.Builder().cache(cache).vectorLength(layerSize).build();
        lookupTable.setSyn0(syn0);
        ret.setVocab(cache);
        ret.setLookupTable(lookupTable);
    }
    return ret;
}
use of org.nd4j.linalg.api.ndarray.INDArray in project deeplearning4j by deeplearning4j.
the class WordVectorSerializer method readWord2VecFromText.
/**
 * Builds a Word2Vec model from externally originated text files holding the vectors,
 * the syn1 weights and the Huffman tree data, so it can be fed by any other w2v
 * implementation that produces these files.
 *
 * @param vectors text file with words and their weights, aka Syn0
 * @param hs text file with HS layers, aka Syn1 (may contain no rows)
 * @param h_codes text file with Huffman tree codes
 * @param h_points text file with Huffman tree points
 * @param configuration configuration supplying the negative-sampling value, the tokenizer
 *        factory and the settings handed to the Word2Vec builder
 * @return the assembled Word2Vec model
 * @throws IOException if any of the files cannot be read
 */
public static Word2Vec readWord2VecFromText(@NonNull File vectors, @NonNull File hs, @NonNull File h_codes, @NonNull File h_points, @NonNull VectorsConfiguration configuration) throws IOException {
    // first we load syn0
    Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(vectors);
    InMemoryLookupTable lookupTable = pair.getFirst();
    lookupTable.setNegative(configuration.getNegative());
    if (configuration.getNegative() > 0) {
        lookupTable.initNegative();
    }
    VocabCache<VocabWord> vocab = (VocabCache<VocabWord>) pair.getSecond();

    // now we load syn1: every line is one row of space-separated doubles
    List<INDArray> rows = new ArrayList<>();
    try (BufferedReader reader = new BufferedReader(new FileReader(hs))) {
        String line;
        while ((line = reader.readLine()) != null) {
            String[] split = line.split(" ");
            double[] array = new double[split.length];
            for (int i = 0; i < split.length; i++) {
                array[i] = Double.parseDouble(split[i]);
            }
            rows.add(Nd4j.create(array));
        }
    }
    // it's possible to have full model without syn1
    if (!rows.isEmpty()) {
        INDArray syn1 = Nd4j.vstack(rows);
        lookupTable.setSyn1(syn1);
    }

    // now we transform mappings into huffman tree points:
    // "<base64 word> <point> <point> ..."
    try (BufferedReader reader = new BufferedReader(new FileReader(h_points))) {
        String line;
        while ((line = reader.readLine()) != null) {
            String[] split = line.split(" ");
            VocabWord word = vocab.wordFor(decodeB64(split[0]));
            List<Integer> points = new ArrayList<>();
            for (int i = 1; i < split.length; i++) {
                points.add(Integer.parseInt(split[i]));
            }
            word.setPoints(points);
        }
    }

    // now we transform mappings into huffman tree codes:
    // "<base64 word> <code> <code> ..."
    try (BufferedReader reader = new BufferedReader(new FileReader(h_codes))) {
        String line;
        while ((line = reader.readLine()) != null) {
            String[] split = line.split(" ");
            VocabWord word = vocab.wordFor(decodeB64(split[0]));
            List<Byte> codes = new ArrayList<>();
            for (int i = 1; i < split.length; i++) {
                codes.add(Byte.parseByte(split[i]));
            }
            word.setCodes(codes);
            word.setCodeLength((short) codes.size());
        }
    }

    Word2Vec.Builder builder = new Word2Vec.Builder(configuration).vocabCache(vocab).lookupTable(lookupTable).resetModel(false);
    TokenizerFactory factory = getTokenizerFactory(configuration);
    if (factory != null) {
        builder.tokenizerFactory(factory);
    }
    return builder.build();
}
use of org.nd4j.linalg.api.ndarray.INDArray in project deeplearning4j by deeplearning4j.
the class WordVectorSerializer method readWord2Vec.
/**
 * This method restores Word2Vec model previously saved with writeWord2VecModel
 *
 * PLEASE NOTE: This method loads FULL model, so don't use it if you're only going to use weights.
 *
 * The archive entries syn0.txt, syn1.txt, codes.txt, huffman.txt and config.json are read
 * unconditionally; frequencies.txt and syn1Neg.txt are applied only when present.
 * While loading, periodic GC is switched off (if it was active) and the occasional GC
 * frequency is set to 50000; both settings are restored before the method returns.
 *
 * @param file zip archive holding the serialized model
 * @return the restored Word2Vec model
 * @throws IOException if the archive or one of its entries cannot be read
 */
@Deprecated
public static Word2Vec readWord2Vec(File file) throws IOException {
    File tmpFileSyn0 = File.createTempFile("word2vec", "0");
    File tmpFileSyn1 = File.createTempFile("word2vec", "1");
    File tmpFileC = File.createTempFile("word2vec", "c");
    File tmpFileH = File.createTempFile("word2vec", "h");
    File tmpFileF = File.createTempFile("word2vec", "f");
    tmpFileSyn0.deleteOnExit();
    tmpFileSyn1.deleteOnExit();
    tmpFileH.deleteOnExit();
    tmpFileC.deleteOnExit();
    tmpFileF.deleteOnExit();

    int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
    boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
    if (originalPeriodic) {
        Nd4j.getMemoryManager().togglePeriodicGc(false);
    }
    Nd4j.getMemoryManager().setOccasionalGcFrequency(50000);

    // Closing the ZipFile also closes every entry stream obtained from it, so the
    // archive handle and all entry streams are released on both success and failure.
    try (ZipFile zipFile = new ZipFile(file)) {
        // unpack the four text parts into temp files for readWord2VecFromText
        ZipEntry syn0 = zipFile.getEntry("syn0.txt");
        InputStream stream = zipFile.getInputStream(syn0);
        Files.copy(stream, Paths.get(tmpFileSyn0.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING);

        ZipEntry syn1 = zipFile.getEntry("syn1.txt");
        stream = zipFile.getInputStream(syn1);
        Files.copy(stream, Paths.get(tmpFileSyn1.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING);

        ZipEntry codes = zipFile.getEntry("codes.txt");
        stream = zipFile.getInputStream(codes);
        Files.copy(stream, Paths.get(tmpFileC.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING);

        ZipEntry huffman = zipFile.getEntry("huffman.txt");
        stream = zipFile.getInputStream(huffman);
        Files.copy(stream, Paths.get(tmpFileH.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING);

        // config.json is concatenated line by line (line breaks dropped) and then parsed
        ZipEntry config = zipFile.getEntry("config.json");
        stream = zipFile.getInputStream(config);
        StringBuilder json = new StringBuilder();
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream))) {
            String line;
            while ((line = reader.readLine()) != null) {
                json.append(line);
            }
        }
        VectorsConfiguration configuration = VectorsConfiguration.fromJson(json.toString().trim());

        // we read first 4 files as w2v model
        Word2Vec w2v = readWord2VecFromText(tmpFileSyn0, tmpFileSyn1, tmpFileC, tmpFileH, configuration);

        // we read frequencies from frequencies.txt, however it's possible that we might not have this file
        ZipEntry frequencies = zipFile.getEntry("frequencies.txt");
        if (frequencies != null) {
            stream = zipFile.getInputStream(frequencies);
            try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream))) {
                String line;
                while ((line = reader.readLine()) != null) {
                    // "<base64 word> <element frequency> <sequences count>"
                    String[] split = line.split(" ");
                    VocabWord word = w2v.getVocab().tokenFor(decodeB64(split[0]));
                    word.setElementFrequency((long) Double.parseDouble(split[1]));
                    word.setSequencesCount((long) Double.parseDouble(split[2]));
                }
            }
        }

        // syn1Neg.txt is optional as well: one row of space-separated doubles per line
        ZipEntry zsyn1Neg = zipFile.getEntry("syn1Neg.txt");
        if (zsyn1Neg != null) {
            stream = zipFile.getInputStream(zsyn1Neg);
            try (InputStreamReader isr = new InputStreamReader(stream);
                 BufferedReader reader = new BufferedReader(isr)) {
                String line;
                List<INDArray> rows = new ArrayList<>();
                while ((line = reader.readLine()) != null) {
                    String[] split = line.split(" ");
                    double[] array = new double[split.length];
                    for (int i = 0; i < split.length; i++) {
                        array[i] = Double.parseDouble(split[i]);
                    }
                    rows.add(Nd4j.create(array));
                }
                // it's possible to have full model without syn1Neg
                if (!rows.isEmpty()) {
                    INDArray syn1Neg = Nd4j.vstack(rows);
                    ((InMemoryLookupTable) w2v.getLookupTable()).setSyn1Neg(syn1Neg);
                }
            }
        }
        return w2v;
    } finally {
        // restore the GC settings that were changed for the duration of the load
        if (originalPeriodic) {
            Nd4j.getMemoryManager().togglePeriodicGc(true);
        }
        Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
    }
}
use of org.nd4j.linalg.api.ndarray.INDArray in project deeplearning4j by deeplearning4j.
the class WordVectorSerializer method readSequenceVectors.
/**
 * This method loads previously saved SequenceVectors model from InputStream.
 *
 * The first line of the stream is the Base64-encoded JSON configuration; every following
 * line is one encoded element together with its vector. The stream is closed by this
 * method, whether loading succeeds or fails.
 *
 * @param factory factory used to deserialize each stored element
 * @param stream stream holding the serialized model, read as UTF-8
 * @param <T> element type of the resulting SequenceVectors
 * @return the restored SequenceVectors model
 * @throws IOException if the stream cannot be read, is empty, or contains no elements
 */
public static <T extends SequenceElement> SequenceVectors<T> readSequenceVectors(@NonNull SequenceElementFactory<T> factory, @NonNull InputStream stream) throws IOException {
    VectorsConfiguration configuration;
    AbstractCache<T> vocabCache = new AbstractCache.Builder<T>().build();
    List<INDArray> rows = new ArrayList<>();

    try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream, "UTF-8"))) {
        // at first we load vectors configuration
        String line = reader.readLine();
        if (line == null) {
            throw new IOException("SequenceVectors stream is empty: configuration line is missing");
        }
        configuration = VectorsConfiguration.fromJson(new String(Base64.decodeBase64(line), "UTF-8"));

        // then one element (plus its vector) per line
        while ((line = reader.readLine()) != null) {
            ElementPair pair = ElementPair.fromEncodedJson(line);
            T element = factory.deserialize(pair.getObject());
            rows.add(Nd4j.create(pair.getVector()));
            vocabCache.addToken(element);
            vocabCache.addWordToIndex(element.getIndex(), element.getLabel());
        }
    }

    // The vector length is taken from the first row, so at least one element is required.
    if (rows.isEmpty()) {
        throw new IOException("SequenceVectors stream contains no elements");
    }

    InMemoryLookupTable<T> lookupTable = (InMemoryLookupTable<T>) new InMemoryLookupTable.Builder<T>().vectorLength(rows.get(0).columns()).build();
    INDArray syn0 = Nd4j.vstack(rows);
    lookupTable.setSyn0(syn0);

    return new SequenceVectors.Builder<T>(configuration).vocabCache(vocabCache).lookupTable(lookupTable).resetModel(false).build();
}
use of org.nd4j.linalg.api.ndarray.INDArray in project deeplearning4j by deeplearning4j.
the class WordVectorSerializer method writeTsneFormat.
/**
 * Write the tsne format: one UTF-8 line per vocabulary word, holding the comma-separated
 * components of that word's row in {@code tsne}, then a comma, the word (with spaces
 * replaced by the whitespace replacement token) and a trailing space.
 *
 * @param vec
 *            the word vectors to use for labeling
 * @param tsne
 *            the tsne array to write; rows are looked up by each word's vocab index
 * @param csv
 *            the file to use
 * @throws Exception if the file cannot be written
 */
public static void writeTsneFormat(Word2Vec vec, INDArray tsne, File csv) throws Exception {
    int words = 0;
    // try-with-resources flushes and closes the writer even when a row fails to serialize
    try (BufferedWriter write = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(csv), "UTF-8"))) {
        InMemoryLookupCache l = (InMemoryLookupCache) vec.vocab();
        for (String word : vec.vocab().words()) {
            if (word == null) {
                continue;
            }
            StringBuilder sb = new StringBuilder();
            INDArray wordVector = tsne.getRow(l.wordFor(word).getIndex());
            for (int j = 0; j < wordVector.length(); j++) {
                if (j > 0) {
                    sb.append(",");
                }
                sb.append(wordVector.getDouble(j));
            }
            sb.append(",");
            sb.append(word.replaceAll(" ", whitespaceReplacement));
            sb.append(" ");
            sb.append("\n");
            write.write(sb.toString());
            // count the rows actually written so the log below reports a real number
            words++;
        }
    }
    log.info("Wrote " + words + " with size of " + vec.lookupTable().layerSize());
}
Aggregations