use of org.deeplearning4j.berkeley.Pair in project deeplearning4j by deeplearning4j.
the class WordVectorSerializer method loadTxt.
/**
 * Loads an in memory cache from the given path (sets syn0 and the vocab).
 *
 * The whole file is read as UTF-8 text, one entry per line: a word (passed through
 * {@code decodeB64}) followed by its space-separated float vector components.
 * The first line is treated as a header and skipped when it contains no spaces, or when every
 * space-separated token on it parses as a {@code long}. In the latter case, a third token is
 * added to the cache's total document count. Blank lines are ignored.
 *
 * @param vectorsFile the path of the file to load
 * @return a Pair holding the lookup table and the vocab cache.
 * @throws FileNotFoundException if the input file does not exist
 * @throws UnsupportedEncodingException if the JVM does not support UTF-8
 * @throws NumberFormatException if a vector component cannot be parsed as a float
 * @throws IllegalStateException if the file contains no word vectors
 */
public static Pair<InMemoryLookupTable, VocabCache> loadTxt(File vectorsFile) throws FileNotFoundException, UnsupportedEncodingException {
    BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(vectorsFile), "UTF-8"));
    AbstractCache<VocabWord> cache = new AbstractCache<>();
    LineIterator iter = IOUtils.lineIterator(reader);
    List<INDArray> arrays = new ArrayList<>();

    try {
        // The first line is either a header (skipped) or already the first word vector (kept pending).
        String pending = null;
        if (iter.hasNext()) {
            String firstLine = iter.nextLine();
            boolean hasHeader;
            if (!firstLine.contains(" ")) {
                log.debug("Skipping first line");
                hasHeader = true;
            } else {
                // An all-integer line is a header: [0] number of words, [1] vector size,
                // [2] number of documents <-- DL4j-only value.
                // Real vector lines start with a word, so any conversion failure means "not a header".
                String[] split = firstLine.split(" ");
                try {
                    long[] header = new long[split.length];
                    for (int x = 0; x < split.length; x++) {
                        header[x] = Long.parseLong(split[x]);
                    }
                    if (split.length == 3)
                        cache.incrementTotalDocCount(header[2]);
                    printOutProjectedMemoryUse(header[0], (int) header[1], 1);
                    hasHeader = true;
                } catch (Exception e) {
                    hasHeader = false;
                }
            }
            if (!hasHeader)
                pending = firstLine;
        }

        // Checking 'pending' first ensures a file consisting of a single vector line is not dropped.
        while (pending != null || iter.hasNext()) {
            String line = (pending != null) ? pending : iter.nextLine();
            pending = null;
            if (line.trim().isEmpty())
                continue;

            String[] split = line.split(" ");
            String word = decodeB64(split[0]);
            VocabWord vocabWord = new VocabWord(1.0, word);
            vocabWord.setIndex(cache.numWords());
            cache.addToken(vocabWord);
            cache.addWordToIndex(vocabWord.getIndex(), word);
            cache.putVocabWord(word);

            float[] vector = new float[split.length - 1];
            for (int i = 1; i < split.length; i++) {
                vector[i - 1] = Float.parseFloat(split[i]);
            }
            arrays.add(Nd4j.create(vector));
        }
    } finally {
        // Release the file handle even when parsing fails: readWord2VecModel calls this method
        // speculatively and falls back to other formats on any exception.
        try {
            iter.close();
        } catch (Exception ignored) {
            // best-effort close
        }
        try {
            reader.close();
        } catch (Exception ignored) {
            // best-effort close
        }
    }

    if (arrays.isEmpty())
        throw new IllegalStateException("No word vectors found in file [" + vectorsFile.getAbsolutePath() + "]");

    INDArray syn = Nd4j.vstack(arrays);
    InMemoryLookupTable lookupTable = (InMemoryLookupTable) new InMemoryLookupTable.Builder().vectorLength(arrays.get(0).columns()).useAdaGrad(false).cache(cache).useHierarchicSoftmax(false).build();
    if (Nd4j.ENFORCE_NUMERICAL_STABILITY)
        Nd4j.clearNans(syn);
    lookupTable.setSyn0(syn);

    return new Pair<>(lookupTable, (VocabCache) cache);
}
use of org.deeplearning4j.berkeley.Pair in project deeplearning4j by deeplearning4j.
the class WordVectorSerializer method readWord2VecModel.
/**
 * Restores a previously saved w2v model. The file can be in one of the following formats:
 * 1) Binary model, either compressed or not. Like well-known Google Model
 * 2) Popular CSV word2vec text format
 * 3) DL4j compressed format
 *
 * Formats are tried in the following order:
 * 1) DL4j zip. When {@code extendedModel} is TRUE this is a full restoration via
 *    {@code readWord2Vec}; otherwise only syn0.txt, config.json and frequencies.txt are read.
 * 2) CSV text via {@link #loadTxt(File)}.
 * 3) The Google binary format, with and without line breaks.
 *
 * Please note: if extended data isn't available, only weights will be loaded instead.
 *
 * While loading, periodic GC is disabled and the occasional GC frequency is set to 50000.
 * The original memory-manager settings are restored before this method returns or throws.
 *
 * @param file          previously saved model file; must exist and be a regular file
 * @param extendedModel if TRUE, we'll try to load HS states & Huffman tree info, if FALSE, only weights will be loaded
 * @return the restored Word2Vec model
 * @throws ND4JIllegalStateException if {@code file} doesn't exist or isn't a regular file
 * @throws RuntimeException if none of the supported formats could be read
 */
public static Word2Vec readWord2VecModel(@NonNull File file, boolean extendedModel) {
    InMemoryLookupTable<VocabWord> lookupTable = new InMemoryLookupTable<>();
    AbstractCache<VocabWord> vocabCache = new AbstractCache<>();
    INDArray syn0 = null;
    VectorsConfiguration configuration = new VectorsConfiguration();

    if (!file.exists() || !file.isFile())
        throw new ND4JIllegalStateException("File [" + file.getAbsolutePath() + "] doesn't exist");

    int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
    boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();

    if (originalPeriodic)
        Nd4j.getMemoryManager().togglePeriodicGc(false);
    Nd4j.getMemoryManager().setOccasionalGcFrequency(50000);

    try {
        // try to load zip format
        try {
            if (extendedModel) {
                log.debug("Trying full model restoration...");
                // full restoration runs with the caller's original GC settings
                if (originalPeriodic)
                    Nd4j.getMemoryManager().togglePeriodicGc(true);
                Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
                return readWord2Vec(file);
            }

            log.debug("Trying simplified model restoration...");
            File tmpFileSyn0 = File.createTempFile("word2vec", "syn");
            // try-with-resources releases the zip handle before any fallback below reopens the file
            try (ZipFile zipFile = new ZipFile(file)) {
                // we don't need full model, so we go directly to syn0 file
                ZipEntry syn = zipFile.getEntry("syn0.txt");
                try (InputStream stream = zipFile.getInputStream(syn)) {
                    Files.copy(stream, Paths.get(tmpFileSyn0.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING);
                }

                // now we're restoring configuration saved earlier
                ZipEntry config = zipFile.getEntry("config.json");
                if (config != null) {
                    StringBuilder builder = new StringBuilder();
                    // NOTE(review): uses the platform default charset; confirm it matches how config.json is written
                    try (BufferedReader reader = new BufferedReader(new InputStreamReader(zipFile.getInputStream(config)))) {
                        String line;
                        while ((line = reader.readLine()) != null) {
                            builder.append(line);
                        }
                    }
                    configuration = VectorsConfiguration.fromJson(builder.toString().trim());
                }

                // optional vocab: each line is "<b64 word> <frequency> <sequences count>"
                ZipEntry ve = zipFile.getEntry("frequencies.txt");
                if (ve != null) {
                    AtomicInteger cnt = new AtomicInteger(0);
                    try (BufferedReader reader = new BufferedReader(new InputStreamReader(zipFile.getInputStream(ve)))) {
                        String line;
                        while ((line = reader.readLine()) != null) {
                            String[] split = line.split(" ");
                            VocabWord word = new VocabWord(Double.valueOf(split[1]), decodeB64(split[0]));
                            word.setIndex(cnt.getAndIncrement());
                            word.incrementSequencesCount(Long.valueOf(split[2]));
                            vocabCache.addToken(word);
                            vocabCache.addWordToIndex(word.getIndex(), word.getLabel());
                            Nd4j.getMemoryManager().invokeGcOccasionally();
                        }
                    }
                }

                List<INDArray> rows = new ArrayList<>();
                // basically read up everything, call vstack and then return model
                try (Reader reader = new CSVReader(tmpFileSyn0)) {
                    AtomicInteger cnt = new AtomicInteger(0);
                    while (reader.hasNext()) {
                        Pair<VocabWord, float[]> pair = reader.next();
                        VocabWord word = pair.getFirst();
                        INDArray vector = Nd4j.create(pair.getSecond());
                        if (ve != null) {
                            // vocab already known from frequencies.txt: fill a preallocated table row by row
                            if (syn0 == null)
                                syn0 = Nd4j.create(vocabCache.numWords(), vector.length());
                            syn0.getRow(cnt.getAndIncrement()).assign(vector);
                        } else {
                            rows.add(vector);
                            vocabCache.addToken(word);
                            vocabCache.addWordToIndex(word.getIndex(), word.getLabel());
                        }
                        Nd4j.getMemoryManager().invokeGcOccasionally();
                    }
                }

                if (syn0 == null && vocabCache.numWords() > 0)
                    syn0 = Nd4j.vstack(rows);

                if (syn0 == null) {
                    log.error("Can't build syn0 table");
                    throw new DL4JInvalidInputException("Can't build syn0 table");
                }

                lookupTable = new InMemoryLookupTable.Builder<VocabWord>().cache(vocabCache).vectorLength(syn0.columns()).useHierarchicSoftmax(false).useAdaGrad(false).build();
                lookupTable.setSyn0(syn0);
            } finally {
                // remove the temp copy whether or not the zip could be parsed
                tmpFileSyn0.delete();
            }
        } catch (Exception e) {
            // let's try to load this file as csv file
            try {
                log.debug("Trying CSV model restoration...");
                Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(file);
                lookupTable = pair.getFirst();
                vocabCache = (AbstractCache<VocabWord>) pair.getSecond();
            } catch (Exception ex) {
                // we fallback to trying binary model instead; it runs with the caller's original GC settings
                if (originalPeriodic)
                    Nd4j.getMemoryManager().togglePeriodicGc(true);
                Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
                try {
                    log.debug("Trying binary model restoration...");
                    return loadGoogleModel(file, true, true);
                } catch (Exception ey) {
                    // try to load without linebreaks
                    try {
                        return loadGoogleModel(file, true, false);
                    } catch (Exception ez) {
                        throw new RuntimeException("Unable to guess input file format. Please use corresponding loader directly", ez);
                    }
                }
            }
        }
    } finally {
        // Single restore point for every path (including a successful CSV load).
        // Restoring twice is harmless when a branch above already did it.
        if (originalPeriodic)
            Nd4j.getMemoryManager().togglePeriodicGc(true);
        Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
    }

    Word2Vec.Builder builder = new Word2Vec.Builder(configuration).lookupTable(lookupTable).useAdaGrad(false).vocabCache(vocabCache).layerSize(lookupTable.layerSize()).useHierarchicSoftmax(false).resetModel(false);

    /*
        Trying to restore TokenizerFactory & TokenPreProcessor
     */
    TokenizerFactory factory = getTokenizerFactory(configuration);
    if (factory != null)
        builder.tokenizerFactory(factory);

    return builder.build();
}
use of org.deeplearning4j.berkeley.Pair in project deeplearning4j by deeplearning4j.
the class WordVectorSerializer method loadStaticModel.
/**
 * Restores a previously saved w2v model as a lookup-only {@link StaticWord2Vec}.
 * File can be in one of the following formats:
 * 1) Binary model, either compressed or not. Like well-known Google Model
 * 2) Popular CSV word2vec text format
 * 3) DL4j compressed format
 *
 * Formats are tried in the order DL4j zip, CSV text, Google binary. Every fallback attempt
 * starts from an empty vocab cache and a cleared storage, so a partially read earlier attempt
 * leaves nothing behind. Periodic GC is disabled while loading, and the original memory-manager
 * settings are restored before this method returns or throws.
 *
 * In return you get StaticWord2Vec model, which might be used as lookup table only in multi-gpu environment.
 *
 * @param file File should point to previously saved w2v model
 * @return a StaticWord2Vec built from the CompressedRamStorage and vocab cache filled here
 * @throws RuntimeException wrapping a FileNotFoundException if the file is missing or is a directory,
 *                          or if none of the supported formats could be read
 */
// TODO: this method needs better name :)
public static WordVectors loadStaticModel(File file) {
    if (!file.exists() || file.isDirectory())
        throw new RuntimeException(new FileNotFoundException("File [" + file.getAbsolutePath() + "] was not found"));

    int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
    boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();

    if (originalPeriodic)
        Nd4j.getMemoryManager().togglePeriodicGc(false);
    Nd4j.getMemoryManager().setOccasionalGcFrequency(50000);

    CompressedRamStorage<Integer> storage = new CompressedRamStorage.Builder<Integer>().useInplaceCompression(false).setCompressor(new NoOp()).emulateIsAbsent(false).build();
    VocabCache<VocabWord> vocabCache = new AbstractCache.Builder<VocabWord>().build();

    try {
        // if zip - that's dl4j format
        try {
            log.debug("Trying DL4j format...");
            File tmpFileSyn0 = File.createTempFile("word2vec", "syn");
            // try-with-resources releases the zip handle before any fallback reopens the file
            try (ZipFile zipFile = new ZipFile(file)) {
                ZipEntry syn0 = zipFile.getEntry("syn0.txt");
                try (InputStream stream = zipFile.getInputStream(syn0)) {
                    Files.copy(stream, Paths.get(tmpFileSyn0.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING);
                }

                storage.clear();
                try (Reader reader = new CSVReader(tmpFileSyn0)) {
                    fillStaticStorage(reader, storage, vocabCache);
                }
            } finally {
                // the temp copy is only needed while reading; previously it was never deleted
                tmpFileSyn0.delete();
            }
        } catch (Exception e) {
            try {
                // try to load file as text csv
                vocabCache = new AbstractCache.Builder<VocabWord>().build();
                storage.clear();
                log.debug("Trying CSVReader...");
                try (Reader reader = new CSVReader(file)) {
                    fillStaticStorage(reader, storage, vocabCache);
                }
            } catch (Exception ex) {
                // otherwise it's probably google model. which might be compressed or not
                log.debug("Trying BinaryReader...");
                vocabCache = new AbstractCache.Builder<VocabWord>().build();
                storage.clear();
                try (Reader reader = new BinaryReader(file)) {
                    fillStaticStorage(reader, storage, vocabCache);
                } catch (Exception ez) {
                    throw new RuntimeException("Unable to guess input file format", ez);
                }
            }
        }
    } finally {
        // single restore point for every path, successful or not
        if (originalPeriodic)
            Nd4j.getMemoryManager().togglePeriodicGc(true);
        Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
    }

    return new StaticWord2Vec.Builder(storage, vocabCache).build();
}

/**
 * Drains {@code reader}, storing each vector under its word index and registering the word in {@code vocabCache}.
 */
private static void fillStaticStorage(Reader reader, CompressedRamStorage<Integer> storage, VocabCache<VocabWord> vocabCache) {
    while (reader.hasNext()) {
        Pair<VocabWord, float[]> pair = reader.next();
        VocabWord word = pair.getFirst();
        storage.store(word.getIndex(), pair.getSecond());
        vocabCache.addToken(word);
        vocabCache.addWordToIndex(word.getIndex(), word.getLabel());
        Nd4j.getMemoryManager().invokeGcOccasionally();
    }
}
use of org.deeplearning4j.berkeley.Pair in project deeplearning4j by deeplearning4j.
the class ParameterAveragingTrainingWorker method processMinibatchWithStats.
/**
 * Processes one minibatch and pairs the outcome with this worker's training stats.
 *
 * @return {@code null} when {@code processMinibatch} yields no result. Otherwise, the result paired
 *         with {@code stats.build()}, or paired with {@code null} when {@code stats} is {@code null}.
 */
@Override
public Pair<ParameterAveragingTrainingResult, SparkTrainingStats> processMinibatchWithStats(DataSet dataSet, MultiLayerNetwork network, boolean isLast) {
    ParameterAveragingTrainingResult minibatchResult = processMinibatch(dataSet, network, isLast);
    if (minibatchResult != null) {
        SparkTrainingStats builtStats = null;
        if (stats != null) {
            builtStats = stats.build();
        }
        return new Pair<>(minibatchResult, builtStats);
    }
    return null;
}
use of org.deeplearning4j.berkeley.Pair in project deeplearning4j by deeplearning4j.
the class ParameterAveragingTrainingWorker method getFinalResultWithStats.
/**
 * Fetches the final result for {@code graph} and pairs it with this worker's training stats.
 *
 * @return {@code null} when {@code getFinalResult} yields no result. Otherwise, the result paired
 *         with {@code stats.build()}, or paired with {@code null} when {@code stats} is {@code null}.
 */
@Override
public Pair<ParameterAveragingTrainingResult, SparkTrainingStats> getFinalResultWithStats(ComputationGraph graph) {
    ParameterAveragingTrainingResult finalResult = getFinalResult(graph);
    if (finalResult != null) {
        SparkTrainingStats builtStats = null;
        if (stats != null) {
            builtStats = stats.build();
        }
        return new Pair<>(finalResult, builtStats);
    }
    return null;
}
Aggregations