Use of org.nd4j.linalg.exception.ND4JIllegalStateException in project nd4j by deeplearning4j.
In class TrainerProvider, method scanClasspath:
/**
 * Scans the {@code org} package for {@link TrainingDriver} implementations and registers an
 * instance of each concrete one in {@code trainers}, keyed by its {@code targetMessageClass()}.
 *
 * <p>Interfaces and abstract classes are skipped; every other class is instantiated through its
 * no-argument constructor.
 *
 * @throws RuntimeException if a driver class cannot be instantiated (the original exception is kept as cause)
 * @throws ND4JIllegalStateException if no TrainingDriver ends up registered
 */
protected void scanClasspath() {
    Reflections reflections = new Reflections("org");
    Set<Class<? extends TrainingDriver>> classes = reflections.getSubTypesOf(TrainingDriver.class);

    for (Class<? extends TrainingDriver> clazz : classes) {
        // only concrete classes can be instantiated
        if (clazz.isInterface() || Modifier.isAbstract(clazz.getModifiers()))
            continue;

        try {
            // Class.newInstance() is deprecated: it propagates checked constructor exceptions undeclared.
            // Going through the Constructor wraps them in InvocationTargetException instead.
            TrainingDriver driver = clazz.getDeclaredConstructor().newInstance();
            trainers.put(driver.targetMessageClass(), driver);
        } catch (Exception e) {
            throw new RuntimeException("Unable to instantiate TrainingDriver " + clazz.getName(), e);
        }
    }

    if (trainers.size() < 1)
        throw new ND4JIllegalStateException("No TrainingDrivers were found");
}
Use of org.nd4j.linalg.exception.ND4JIllegalStateException in project nd4j by deeplearning4j.
In class BaseTransport, method launch:
/**
 * This method starts transport mechanisms.
 *
 * PLEASE NOTE: init() method should be called prior to launch() call
 *
 * <p>Behaviour per threading model:
 * <ul>
 *   <li>{@code SINGLE_THREAD}: one thread (threadA) polls the shard subscription (when it is set)
 *       and the client subscription in turn.</li>
 *   <li>{@code DEDICATED_THREADS}: a daemon threadA polls the client subscription. For SHARD,
 *       BACKUP and MASTER nodes, a daemon threadB additionally polls the shard subscription when
 *       a shard message handler is configured. This call returns only after threadA has started
 *       running.</li>
 *   <li>{@code SAME_THREAD}: no threads are started; polling is done within the takeMessage loop.</li>
 * </ul>
 *
 * @param threading threading model to use; stored in {@code threadingModel}
 * @throws ND4JIllegalStateException if DEDICATED_THREADS is requested while the node role is NONE
 * @throws IllegalStateException if the threading model is not recognised
 */
@Override
public void launch(@NonNull ThreadingModel threading) {
    this.threadingModel = threading;

    switch (threading) {
        case SINGLE_THREAD: {
            log.warn("SINGLE_THREAD model is used, performance will be significantly reduced");

            // single thread for all queues. shouldn't be used in real world
            threadA = new Thread(() -> {
                while (runner.get()) {
                    if (subscriptionForShards != null)
                        subscriptionForShards.poll(messageHandlerForShards, 512);

                    idler.idle(subscriptionForClients.poll(messageHandlerForClients, 512));
                }
            });

            threadA.start();
            break;
        }
        case DEDICATED_THREADS: {
            // we start separate thread for each handler

            /*
             * We definitely might use less conditional code here, BUT i'll keep it as is,
             * only because we want code to be obvious for people
             */

            // flipped to true by threadA as soon as it runs, so launch() can wait for it below
            final AtomicBoolean localRunner = new AtomicBoolean(false);

            if (nodeRole == NodeRole.NONE) {
                throw new ND4JIllegalStateException("No role is set for current node!");
            } else if (nodeRole == NodeRole.SHARD || nodeRole == NodeRole.BACKUP || nodeRole == NodeRole.MASTER) {
                // threadB: polls the shard subscription (only when a shard handler is configured)
                if (messageHandlerForShards != null)
                    threadB = new Thread(() -> {
                        while (runner.get())
                            idler.idle(subscriptionForShards.poll(messageHandlerForShards, 512));
                    });

                // threadA: polls the client subscription
                threadA = new Thread(() -> {
                    localRunner.set(true);
                    while (runner.get())
                        idler.idle(subscriptionForClients.poll(messageHandlerForClients, 512));
                });

                if (threadB != null) {
                    Nd4j.getAffinityManager().attachThreadToDevice(threadB, Nd4j.getAffinityManager().getDeviceForCurrentThread());
                    threadB.setDaemon(true);
                    threadB.setName("VoidParamServer subscription threadB [" + nodeRole + "]");
                    threadB.start();
                }
            } else {
                // threadA: polls the client subscription
                threadA = new Thread(() -> {
                    localRunner.set(true);
                    while (runner.get())
                        idler.idle(subscriptionForClients.poll(messageHandlerForClients, 512));
                });
            }

            // all roles have threadA anyway
            Nd4j.getAffinityManager().attachThreadToDevice(threadA, Nd4j.getAffinityManager().getDeviceForCurrentThread());
            threadA.setDaemon(true);
            threadA.setName("VoidParamServer subscription threadA [" + nodeRole + "]");
            threadA.start();

            // Block until threadA is running. An interrupt must not cut this wait short, so it is
            // recorded and the thread's interrupt status is restored afterwards instead of being lost.
            boolean interrupted = false;
            while (!localRunner.get()) {
                try {
                    Thread.sleep(50);
                } catch (InterruptedException e) {
                    interrupted = true;
                }
            }
            if (interrupted)
                Thread.currentThread().interrupt();
            break;
        }
        case SAME_THREAD: {
            // no additional threads at all, we do poll within takeMessage loop
            log.warn("SAME_THREAD model is used, performance will be dramatically reduced");
            break;
        }
        default:
            throw new IllegalStateException("Unknown thread model: [" + threading.toString() + "]");
    }
}
Use of org.nd4j.linalg.exception.ND4JIllegalStateException in project deeplearning4j by deeplearning4j.
In class InMemoryLookupTable, method consume:
/**
 * This method consumes weights of a given InMemoryLookupTable
 *
 * PLEASE NOTE: this method explicitly resets current weights
 *
 * <p>After {@code resetWeights(true)}, every row of the source syn0 is copied into the same row of
 * this table's syn0. The matching rows of syn1 and syn1Neg are copied along with it when both this
 * table and the source have that matrix. Otherwise that matrix is skipped, with a single log line.
 *
 * @param srcTable table to copy weights from; must have the same vector length and a non-null syn0
 * @throws IllegalStateException if the vector lengths differ, the source syn0 is null, or the
 *         source syn0 has more rows than this table's syn0 (after the reset)
 * @throws ND4JIllegalStateException if neither syn1 nor syn1Neg can be copied; this is raised
 *         before any row is copied
 */
public void consume(InMemoryLookupTable<T> srcTable) {
    if (srcTable.vectorLength != this.vectorLength)
        throw new IllegalStateException("You can't consume lookupTable with different vector lengths");

    if (srcTable.syn0 == null)
        throw new IllegalStateException("Source lookupTable Syn0 is NULL");

    this.resetWeights(true);

    if (srcTable.syn0.rows() > this.syn0.rows())
        throw new IllegalStateException("You can't consume lookupTable with built for larger vocabulary without updating your vocabulary first");

    // Availability of syn1 / syn1Neg does not change from row to row, so decide once.
    boolean copySyn1 = this.syn1 != null && srcTable.syn1 != null;
    boolean copySyn1Neg = this.syn1Neg != null && srcTable.syn1Neg != null;

    if (!copySyn1)
        log.info("Skipping syn1 merge");
    if (!copySyn1Neg)
        log.info("Skipping syn1Neg merge");

    // Fail before touching any row, so syn0 is never left with only part of the source copied in.
    if (!copySyn1 && !copySyn1Neg)
        throw new ND4JIllegalStateException("srcTable has no syn1/syn1neg");

    for (int x = 0; x < srcTable.syn0.rows(); x++) {
        this.syn0.putRow(x, srcTable.syn0.getRow(x));

        if (copySyn1)
            this.syn1.putRow(x, srcTable.syn1.getRow(x));

        if (copySyn1Neg)
            this.syn1Neg.putRow(x, srcTable.syn1Neg.getRow(x));
    }
}
Use of org.nd4j.linalg.exception.ND4JIllegalStateException in project deeplearning4j by deeplearning4j.
In class ParagraphVectors, method inferVector:
/**
 * This method calculates inferred vector for given document
 *
 * <p>If no sequence learning algorithm is set yet, a PV-DM learner ({@code DM}) is created once
 * under a lock. It is configured with this model's vocab, lookup table and configuration, then
 * stored in {@code sequenceLearningAlgorithm} for later calls. The document is wrapped in a
 * {@link Sequence} whose label is a randomly named {@link VocabWord} before inference.
 *
 * @param document        words of the document; must not be empty
 * @param learningRate    passed through to {@code inferSequence}
 * @param minLearningRate passed through to {@code inferSequence}
 * @param iterations      passed through to {@code inferSequence}
 * @return the vector returned by the learner's {@code inferSequence}
 * @throws ND4JIllegalStateException if {@code document} is empty
 */
public INDArray inferVector(@NonNull List<VocabWord> document, double learningRate, double minLearningRate, int iterations) {
    // Validate first, so an invalid call does not trigger lazy learner creation as a side effect.
    if (document.isEmpty())
        throw new ND4JIllegalStateException("Impossible to apply inference to empty list of words");

    // Lazily create the learner. The local snapshot is used from here on; the shared field is
    // not re-read after initialisation.
    SequenceLearningAlgorithm<VocabWord> learner = sequenceLearningAlgorithm;
    if (learner == null) {
        synchronized (this) {
            if (sequenceLearningAlgorithm == null) {
                log.info("Creating new PV-DM learner...");
                learner = new DM<VocabWord>();
                learner.configure(vocab, lookupTable, configuration);
                sequenceLearningAlgorithm = learner;
            } else {
                learner = sequenceLearningAlgorithm;
            }
        }
    }

    Sequence<VocabWord> sequence = new Sequence<>();
    sequence.addElements(document);
    sequence.setSequenceLabel(new VocabWord(1.0, String.valueOf(new Random().nextInt())));

    initLearners();

    return learner.inferSequence(sequence, seed, learningRate, minLearningRate, iterations);
}
Use of org.nd4j.linalg.exception.ND4JIllegalStateException in project deeplearning4j by deeplearning4j.
In class ParallelWrapper, method fit:
/**
 * Trains the wrapped model on every MultiDataSet from {@code source}, spreading them over the
 * {@code Trainer} worker threads in rounds.
 *
 * <p>In each round, consecutive MultiDataSets go to workers 0, 1, 2, ... until either all
 * {@code workers} have received one or the iterator is exhausted. The method then waits for the
 * workers that received data. Averaging is triggered when the round counter
 * {@code iterationsCounter} is a multiple of {@code averagingFrequency} AND the round was full
 * ({@code pos + 1 == workers}); a trailing partial round is therefore never averaged. This counter
 * is a field and is not reset here (the reset at the end is commented out), so it keeps counting
 * across calls. Averaging is only supported for {@link ComputationGraph} models.
 *
 * <p>The loop checks {@code stopFit} before fetching each MultiDataSet; it is cleared to
 * {@code false} on entry.
 *
 * @param source data to train on; {@code reset()} is called on it first, and it is wrapped in an
 *               {@code AsyncMultiDataSetIterator} when {@code prefetchSize > 0} and
 *               {@code source.asyncSupported()}
 * @throws ND4JIllegalStateException if the iterator yields a {@code null} MultiDataSet
 * @throws RuntimeException if waiting for a worker throws (the cause is wrapped), or if averaging
 *                          is triggered while the model is not a ComputationGraph
 */
public synchronized void fit(@NonNull MultiDataSetIterator source) {
    stopFit.set(false);
    if (zoo == null) {
        // first use: create and start one Trainer per worker, all attached to the caller's current device
        zoo = new Trainer[workers];
        for (int cnt = 0; cnt < workers; cnt++) {
            // we pass true here, to tell Trainer to use MultiDataSet queue for training
            zoo[cnt] = new Trainer(cnt, model, Nd4j.getAffinityManager().getDeviceForCurrentThread(), true);
            zoo[cnt].setUncaughtExceptionHandler(handler);
            zoo[cnt].start();
        }
    } else {
        // trainers already exist: switch each of them to the MultiDataSet queue
        for (int cnt = 0; cnt < workers; cnt++) {
            zoo[cnt].useMDS = true;
        }
    }
    // NOTE(review): reset() is called unconditionally — confirm every iterator passed here supports it
    source.reset();
    MultiDataSetIterator iterator;
    if (prefetchSize > 0 && source.asyncSupported()) {
        // presumably prefetches up to prefetchSize MultiDataSets asynchronously — TODO confirm
        iterator = new AsyncMultiDataSetIterator(source, prefetchSize);
    } else
        iterator = source;
    // NOTE(review): the async wrapper is not shut down in this method — confirm it releases its resources
    // Number of MultiDataSets dispatched in the current round; getAndIncrement() also yields the
    // index of the next worker to feed.
    AtomicInteger locker = new AtomicInteger(0);
    while (iterator.hasNext() && !stopFit.get()) {
        MultiDataSet dataSet = iterator.next();
        if (dataSet == null)
            throw new ND4JIllegalStateException("You can't have NULL as MultiDataSet");
        /*
         now dataSet should be dispatched to next free workers, until all workers are busy. And then we should block till all finished.
        */
        int pos = locker.getAndIncrement();
        zoo[pos].feedMultiDataSet(dataSet);
        /*
         if all workers are dispatched now, join till all are finished
        */
        // a round also ends early when the iterator runs out of data
        if (pos + 1 == workers || !iterator.hasNext()) {
            iterationsCounter.incrementAndGet();
            // only wait for the workers that actually received data in this round
            for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) {
                try {
                    zoo[cnt].waitTillRunning();
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
            Nd4j.getMemoryManager().invokeGcOccasionally();
            /*
             average model, and propagate it to whole
            */
            // partial rounds (pos + 1 < workers) are skipped here even on an averaging iteration
            if (iterationsCounter.get() % averagingFrequency == 0 && pos + 1 == workers) {
                double score = getScore(locker);
                // averaging updaters state
                if (model instanceof ComputationGraph) {
                    averageUpdatersState(locker, score);
                } else
                    throw new RuntimeException("MultiDataSet must only be used with ComputationGraph model");
                // legacy averaging on a multi-device setup: hand the model back to every trainer
                if (legacyAveraging && Nd4j.getAffinityManager().getNumberOfDevices() > 1) {
                    for (int cnt = 0; cnt < workers; cnt++) {
                        zoo[cnt].updateModel(model);
                    }
                }
            }
            // start the next round from worker 0
            locker.set(0);
        }
    }
    // sanity checks, or the dataset may never average
    // NOTE(review): wasAveraged is not assigned in this method; presumably set by the averaging helpers — confirm
    if (!wasAveraged)
        log.warn("Parameters were never averaged on current fit(). Ratios of batch size, num workers, and averaging frequency may be responsible.");
    // Disabled stricter check, kept for reference:
    // throw new IllegalStateException("Parameters were never averaged. Please check batch size ratios, number of workers, and your averaging frequency.");
    log.debug("Iterations passed: {}", iterationsCounter.get());
    // Disabled reset: iterationsCounter therefore keeps counting across fit() calls.
    // iterationsCounter.set(0);
}
Aggregations