Use of org.nd4j.linalg.exception.ND4JIllegalStateException in project deeplearning4j by deeplearning4j.
The class WordVectorSerializer, method readWord2VecModel:
/**
 * This method loads a Word2Vec model from a file saved in one of the following formats:
 * 1) Binary model, either compressed or not, like the well-known Google model
 * 2) Popular CSV word2vec text format
 * 3) DL4j compressed format
 *
 * Please note: if extended data isn't available, only weights will be loaded instead.
 *
 * @param file model file in one of the formats listed above
 * @param extendedModel if TRUE, we'll try to load HS states & Huffman tree info, if FALSE, only weights will be loaded
 * @return the restored Word2Vec model
 */
public static Word2Vec readWord2VecModel(@NonNull File file, boolean extendedModel) {
    InMemoryLookupTable<VocabWord> lookupTable = new InMemoryLookupTable<>();
    AbstractCache<VocabWord> vocabCache = new AbstractCache<>();
    Word2Vec vec;
    INDArray syn0 = null;
    VectorsConfiguration configuration = new VectorsConfiguration();
    if (!file.exists() || !file.isFile())
        throw new ND4JIllegalStateException("File [" + file.getAbsolutePath() + "] doesn't exist");
    int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
    boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
    if (originalPeriodic)
        Nd4j.getMemoryManager().togglePeriodicGc(false);
    Nd4j.getMemoryManager().setOccasionalGcFrequency(50000);
    // try to load zip format
    try {
        if (extendedModel) {
            log.debug("Trying full model restoration...");
            if (originalPeriodic)
                Nd4j.getMemoryManager().togglePeriodicGc(true);
            Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
            return readWord2Vec(file);
        } else {
            log.debug("Trying simplified model restoration...");
            File tmpFileSyn0 = File.createTempFile("word2vec", "syn");
            File tmpFileConfig = File.createTempFile("word2vec", "config");
            // we don't need the full model, so we go directly to the syn0 file
            ZipFile zipFile = new ZipFile(file);
            ZipEntry syn = zipFile.getEntry("syn0.txt");
            InputStream stream = zipFile.getInputStream(syn);
            Files.copy(stream, Paths.get(tmpFileSyn0.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING);
            // now we're restoring the configuration saved earlier
            ZipEntry config = zipFile.getEntry("config.json");
            if (config != null) {
                stream = zipFile.getInputStream(config);
                StringBuilder builder = new StringBuilder();
                try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream))) {
                    String line;
                    while ((line = reader.readLine()) != null) {
                        builder.append(line);
                    }
                }
                configuration = VectorsConfiguration.fromJson(builder.toString().trim());
            }
            ZipEntry ve = zipFile.getEntry("frequencies.txt");
            if (ve != null) {
                stream = zipFile.getInputStream(ve);
                AtomicInteger cnt = new AtomicInteger(0);
                try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream))) {
                    String line;
                    while ((line = reader.readLine()) != null) {
                        String[] split = line.split(" ");
                        VocabWord word = new VocabWord(Double.valueOf(split[1]), decodeB64(split[0]));
                        word.setIndex(cnt.getAndIncrement());
                        word.incrementSequencesCount(Long.valueOf(split[2]));
                        vocabCache.addToken(word);
                        vocabCache.addWordToIndex(word.getIndex(), word.getLabel());
                        Nd4j.getMemoryManager().invokeGcOccasionally();
                    }
                }
            }
            List<INDArray> rows = new ArrayList<>();
            // basically read up everything, call vstack and then return the model
            try (Reader reader = new CSVReader(tmpFileSyn0)) {
                AtomicInteger cnt = new AtomicInteger(0);
                while (reader.hasNext()) {
                    Pair<VocabWord, float[]> pair = reader.next();
                    VocabWord word = pair.getFirst();
                    INDArray vector = Nd4j.create(pair.getSecond());
                    if (ve != null) {
                        if (syn0 == null)
                            syn0 = Nd4j.create(vocabCache.numWords(), vector.length());
                        syn0.getRow(cnt.getAndIncrement()).assign(vector);
                    } else {
                        rows.add(vector);
                        vocabCache.addToken(word);
                        vocabCache.addWordToIndex(word.getIndex(), word.getLabel());
                    }
                    Nd4j.getMemoryManager().invokeGcOccasionally();
                }
            } catch (Exception e) {
                throw new RuntimeException(e);
            } finally {
                if (originalPeriodic)
                    Nd4j.getMemoryManager().togglePeriodicGc(true);
                Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
            }
            if (syn0 == null && vocabCache.numWords() > 0)
                syn0 = Nd4j.vstack(rows);
            if (syn0 == null) {
                log.error("Can't build syn0 table");
                throw new DL4JInvalidInputException("Can't build syn0 table");
            }
            lookupTable = new InMemoryLookupTable.Builder<VocabWord>().cache(vocabCache)
                            .vectorLength(syn0.columns()).useHierarchicSoftmax(false).useAdaGrad(false).build();
            lookupTable.setSyn0(syn0);
            try {
                tmpFileSyn0.delete();
                tmpFileConfig.delete();
            } catch (Exception e) {
                //
            }
        }
    } catch (Exception e) {
        // let's try to load this file as a CSV file
        try {
            log.debug("Trying CSV model restoration...");
            Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(file);
            lookupTable = pair.getFirst();
            vocabCache = (AbstractCache<VocabWord>) pair.getSecond();
        } catch (Exception ex) {
            // we fall back to trying the binary model instead
            try {
                log.debug("Trying binary model restoration...");
                if (originalPeriodic)
                    Nd4j.getMemoryManager().togglePeriodicGc(true);
                Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
                vec = loadGoogleModel(file, true, true);
                return vec;
            } catch (Exception ey) {
                // try to load without line breaks
                try {
                    if (originalPeriodic)
                        Nd4j.getMemoryManager().togglePeriodicGc(true);
                    Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
                    vec = loadGoogleModel(file, true, false);
                    return vec;
                } catch (Exception ez) {
                    throw new RuntimeException("Unable to guess input file format. Please use the corresponding loader directly");
                }
            }
        }
    }
    Word2Vec.Builder builder = new Word2Vec.Builder(configuration).lookupTable(lookupTable).useAdaGrad(false)
                    .vocabCache(vocabCache).layerSize(lookupTable.layerSize()).useHierarchicSoftmax(false)
                    .resetModel(false);
    /*
        Trying to restore TokenizerFactory & TokenPreProcessor
     */
    TokenizerFactory factory = getTokenizerFactory(configuration);
    if (factory != null)
        builder.tokenizerFactory(factory);
    vec = builder.build();
    return vec;
}
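For orientation, a minimal usage sketch of the method above; the file path is a placeholder, and the similarity call simply demonstrates that a weights-only model (extendedModel == false) is still queryable:

import java.io.File;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.word2vec.Word2Vec;

// "/path/to/model.zip" is a placeholder; any of the three supported formats should work here
Word2Vec w2v = WordVectorSerializer.readWord2VecModel(new File("/path/to/model.zip"), false);

// only weights were restored, but vocabulary lookups still work
double sim = w2v.similarity("day", "night");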
Use of org.nd4j.linalg.exception.ND4JIllegalStateException in project deeplearning4j by deeplearning4j.
The class ParallelWrapper, method fit:
/**
 * This method takes a DataSetIterator and starts training over it by scheduling DataSets to different executors.
 *
 * @param source DataSetIterator that supplies the training data
 */
public synchronized void fit(@NonNull DataSetIterator source) {
    stopFit.set(false);
    if (zoo == null) {
        zoo = new Trainer[workers];
        for (int cnt = 0; cnt < workers; cnt++) {
            zoo[cnt] = new Trainer(cnt, model, Nd4j.getAffinityManager().getDeviceForCurrentThread());
            // if we're using MagicQueue here, we'd like to attach each trainer thread to its own device
            if (isMQ)
                Nd4j.getAffinityManager().attachThreadToDevice(zoo[cnt],
                                cnt % Nd4j.getAffinityManager().getNumberOfDevices());
            zoo[cnt].setUncaughtExceptionHandler(handler);
            zoo[cnt].start();
        }
    }
    source.reset();
    DataSetIterator iterator;
    if (prefetchSize > 0 && source.asyncSupported()) {
        if (isMQ) {
            if (workers % Nd4j.getAffinityManager().getNumberOfDevices() != 0)
                log.warn("Number of workers [{}] isn't optimal for available devices [{}]", workers,
                                Nd4j.getAffinityManager().getNumberOfDevices());
            MagicQueue queue = new MagicQueue.Builder().setCapacityPerFlow(8).setMode(MagicQueue.Mode.SEQUENTIAL)
                            .setNumberOfBuckets(Nd4j.getAffinityManager().getNumberOfDevices()).build();
            iterator = new AsyncDataSetIterator(source, prefetchSize, queue);
        } else
            iterator = new AsyncDataSetIterator(source, prefetchSize);
    } else
        iterator = source;
    AtomicInteger locker = new AtomicInteger(0);
    int whiles = 0;
    while (iterator.hasNext() && !stopFit.get()) {
        whiles++;
        DataSet dataSet = iterator.next();
        if (dataSet == null)
            throw new ND4JIllegalStateException("You can't have NULL as DataSet");
        /*
            now dataSet should be dispatched to the next free worker, until all workers are busy. And then we should block till all are finished.
         */
        int pos = locker.getAndIncrement();
        if (zoo == null)
            throw new IllegalStateException("ParallelWrapper.shutdown() has been called too early and will fail from this point forward.");
        zoo[pos].feedDataSet(dataSet);
        /*
            if all workers are dispatched now, join till all are finished
         */
        if (pos + 1 == workers || !iterator.hasNext()) {
            iterationsCounter.incrementAndGet();
            for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) {
                try {
                    zoo[cnt].waitTillRunning();
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
            Nd4j.getMemoryManager().invokeGcOccasionally();
            /*
                average the model, and propagate it to the whole set of workers
             */
            if (iterationsCounter.get() % averagingFrequency == 0 && pos + 1 == workers) {
                double score = getScore(locker);
                // averaging updaters state
                if (model instanceof MultiLayerNetwork) {
                    if (averageUpdaters) {
                        Updater updater = ((MultiLayerNetwork) model).getUpdater();
                        int batchSize = 0;
                        if (updater != null && updater.getStateViewArray() != null) {
                            if (!legacyAveraging || Nd4j.getAffinityManager().getNumberOfDevices() == 1) {
                                List<INDArray> updaters = new ArrayList<>();
                                for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) {
                                    MultiLayerNetwork workerModel = (MultiLayerNetwork) zoo[cnt].getModel();
                                    updaters.add(workerModel.getUpdater().getStateViewArray());
                                    batchSize += workerModel.batchSize();
                                }
                                Nd4j.averageAndPropagate(updater.getStateViewArray(), updaters);
                            } else {
                                INDArray state = Nd4j.zeros(updater.getStateViewArray().shape());
                                int cnt = 0;
                                for (; cnt < workers && cnt < locker.get(); cnt++) {
                                    MultiLayerNetwork workerModel = (MultiLayerNetwork) zoo[cnt].getModel();
                                    state.addi(workerModel.getUpdater().getStateViewArray().dup());
                                    batchSize += workerModel.batchSize();
                                }
                                state.divi(cnt);
                                updater.setStateViewArray((MultiLayerNetwork) model, state, false);
                            }
                        }
                    }
                    ((MultiLayerNetwork) model).setScore(score);
                } else if (model instanceof ComputationGraph) {
                    averageUpdatersState(locker, score);
                }
                if (legacyAveraging && Nd4j.getAffinityManager().getNumberOfDevices() > 1) {
                    for (int cnt = 0; cnt < workers; cnt++) {
                        zoo[cnt].updateModel(model);
                    }
                }
            }
            locker.set(0);
        }
    }
    // sanity check: warn if parameters were never averaged during this fit() call
    if (!wasAveraged)
        log.warn("Parameters were never averaged on current fit(). Ratios of batch size, num workers, and averaging frequency may be responsible.");
    //throw new IllegalStateException("Parameters were never averaged. Please check batch size ratios, number of workers, and your averaging frequency.");
    log.debug("Iterations passed: {}", iterationsCounter.get());
    //iterationsCounter.set(0);
}
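As a rough sketch of how this fit() is typically driven; the builder values are illustrative, the method names follow the 0.x ParallelWrapper API, and model / trainIterator stand for a configured network and a DataSetIterator:

import org.deeplearning4j.parallelism.ParallelWrapper;

ParallelWrapper wrapper = new ParallelWrapper.Builder(model)
                // asynchronous prefetch; enables the AsyncDataSetIterator branch above
                .prefetchBuffer(24)
                // number of Trainer threads created in the zoo
                .workers(4)
                // parameters are averaged every 3 iterations (averagingFrequency above)
                .averagingFrequency(3)
                .build();

// schedules DataSets from the iterator across the workers, as implemented above
wrapper.fit(trainIterator);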
Use of org.nd4j.linalg.exception.ND4JIllegalStateException in project deeplearning4j by deeplearning4j.
The class ParagraphVectors, method inferVector:
/**
 * This method calculates an inferred vector for the given text.
 *
 * @param text raw text to build an inferred vector for
 * @param learningRate starting learning rate for inference
 * @param minLearningRate minimal learning rate for inference
 * @param iterations number of inference iterations
 * @return the inferred paragraph vector
 */
public INDArray inferVector(String text, double learningRate, double minLearningRate, int iterations) {
    if (tokenizerFactory == null)
        throw new IllegalStateException("TokenizerFactory should be defined, prior to predict() call");
    if (this.vocab == null || this.vocab.numWords() == 0)
        reassignExistingModel();
    List<String> tokens = tokenizerFactory.create(text).getTokens();
    List<VocabWord> document = new ArrayList<>();
    for (String token : tokens) {
        if (vocab.containsWord(token)) {
            document.add(vocab.wordFor(token));
        }
    }
    if (document.isEmpty())
        throw new ND4JIllegalStateException("Text passed for inference has no matches in model vocabulary.");
    return inferVector(document, learningRate, minLearningRate, iterations);
}
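A short usage sketch, assuming paragraphVectors is an already-trained model with a TokenizerFactory attached; the hyperparameter values are illustrative only:

INDArray inferred = paragraphVectors.inferVector(
                "This is some unlabeled text to vectorize", 0.025, 0.001, 5);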
Use of org.nd4j.linalg.exception.ND4JIllegalStateException in project deeplearning4j by deeplearning4j.
The class TrainingFunction, method call:
@Override
@SuppressWarnings("unchecked")
public void call(Sequence<T> sequence) throws Exception {
    /*
        Depending on the actual training mode, we'll go for SkipGram, CBOW, PV-DM or PV-DBOW
     */
    if (vectorsConfiguration == null)
        vectorsConfiguration = configurationBroadcast.getValue();
    if (paramServer == null) {
        paramServer = VoidParameterServer.getInstance();
        if (elementsLearningAlgorithm == null) {
            try {
                elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class
                                .forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
        driver = elementsLearningAlgorithm.getTrainingDriver();
        // FIXME: this init line should probably be removed, since init basically happens in VocabRddFunction
        paramServer.init(paramServerConfigurationBroadcast.getValue(), new RoutedTransport(), driver);
    }
    if (vectorsConfiguration == null)
        vectorsConfiguration = configurationBroadcast.getValue();
    if (shallowVocabCache == null)
        shallowVocabCache = vocabCacheBroadcast.getValue();
    if (elementsLearningAlgorithm == null && vectorsConfiguration.getElementsLearningAlgorithm() != null) {
        // TODO: do ELA initialization
        try {
            elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class
                            .forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
            elementsLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
    if (sequenceLearningAlgorithm == null && vectorsConfiguration.getSequenceLearningAlgorithm() != null) {
        // TODO: do SLA initialization
        try {
            sequenceLearningAlgorithm = (SparkSequenceLearningAlgorithm) Class
                            .forName(vectorsConfiguration.getSequenceLearningAlgorithm()).newInstance();
            sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
    if (elementsLearningAlgorithm == null && sequenceLearningAlgorithm == null) {
        throw new ND4JIllegalStateException("No LearningAlgorithms specified!");
    }
    /*
        at this moment we should have everything ready for the actual initialization.
        the only limitation we have: our sequence is detached from the actual vocabulary, so we need to merge it back, virtually
     */
    Sequence<ShallowSequenceElement> mergedSequence = new Sequence<>();
    for (T element : sequence.getElements()) {
        // it's possible to get null here, i.e. if the frequency for this element is below the minWordFrequency threshold
        ShallowSequenceElement reduced = shallowVocabCache.tokenFor(element.getStorageId());
        if (reduced != null)
            mergedSequence.addElement(reduced);
    }
    // do the same with labels: transfer them, if any
    if (sequenceLearningAlgorithm != null && vectorsConfiguration.isTrainSequenceVectors()) {
        for (T label : sequence.getSequenceLabels()) {
            ShallowSequenceElement reduced = shallowVocabCache.tokenFor(label.getStorageId());
            if (reduced != null)
                mergedSequence.addSequenceLabel(reduced);
        }
    }
    // FIXME: temporary hook
    if (sequence.size() > 0)
        paramServer.execDistributed(
                        elementsLearningAlgorithm.frameSequence(mergedSequence, new AtomicLong(119), 25e-3));
    else
        log.warn("Skipping empty sequence...");
}
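The reflective instantiation used above (Class.forName(...).newInstance() on a class name taken from the broadcast configuration) can be isolated into a small sketch; org.example.SkipGramImpl is a hypothetical class name standing in for the value of vectorsConfiguration.getElementsLearningAlgorithm():

// hypothetical fully-qualified class name, normally read from VectorsConfiguration
String implName = "org.example.SkipGramImpl";
try {
    // a public zero-arg constructor is required, which is why failures above are wrapped in RuntimeException
    Object algorithm = Class.forName(implName).newInstance();
} catch (Exception e) {
    throw new RuntimeException(e);
}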
Use of org.nd4j.linalg.exception.ND4JIllegalStateException in project nd4j by deeplearning4j.
The class Nd4j, method create:
public static INDArray create(double[] data, int[] shape, char ordering, long offset) {
    shape = getEnsuredShape(shape);
    if (shape.length == 1) {
        if (shape[0] != data.length)
            throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape)
                            + " doesn't match data length: " + data.length);
    }
    checkShapeValues(data.length, shape);
    INDArray ret = INSTANCE.create(data, shape, getStrides(shape, ordering), offset, ordering);
    logCreationIfNecessary(ret);
    return ret;
}