Usage of org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator in the deeplearning4j project: class ComputationGraph, method fit.
/**
 * Fit the ComputationGraph using a MultiDataSetIterator.
 * <p>
 * If the supplied iterator supports asynchronous prefetch, it is wrapped in an
 * {@link AsyncMultiDataSetIterator} (queue size 2) so that the next minibatch is
 * loaded on a background thread while the current one is being trained on.
 *
 * @param multi the source of training MultiDataSets; consumed until exhausted
 */
public void fit(MultiDataSetIterator multi) {
    if (flattenedGradients == null)
        initGradientsView();

    MultiDataSetIterator multiDataSetIterator;
    if (multi.asyncSupported()) {
        // Prefetch up to 2 minibatches on a background thread
        multiDataSetIterator = new AsyncMultiDataSetIterator(multi, 2);
    } else {
        multiDataSetIterator = multi;
    }

    try {
        if (configuration.isPretrain()) {
            pretrain(multiDataSetIterator);
        }

        if (configuration.isBackprop()) {
            while (multiDataSetIterator.hasNext()) {
                MultiDataSet next = multiDataSetIterator.next();
                // Defensive stop: a dataset without features or labels cannot be trained on
                if (next.getFeatures() == null || next.getLabels() == null)
                    break;

                if (configuration.getBackpropType() == BackpropType.TruncatedBPTT) {
                    doTruncatedBPTT(next.getFeatures(), next.getLabels(),
                                    next.getFeaturesMaskArrays(), next.getLabelsMaskArrays());
                } else {
                    boolean hasMaskArrays = next.hasMaskArrays();
                    if (hasMaskArrays) {
                        setLayerMaskArrays(next.getFeaturesMaskArrays(), next.getLabelsMaskArrays());
                    }

                    setInputs(next.getFeatures());
                    setLabels(next.getLabels());
                    // Lazily create the solver on first use
                    if (solver == null) {
                        solver = new Solver.Builder().configure(defaultConfiguration)
                                        .listeners(listeners).model(this).build();
                    }
                    solver.optimize();

                    if (hasMaskArrays) {
                        clearLayerMaskArrays();
                    }
                }

                Nd4j.getMemoryManager().invokeGcOccasionally();
            }
        }
    } finally {
        // Bug fix: previously the async iterator's prefetch thread was leaked.
        // Shut it down even if pretrain/backprop throws.
        if (multiDataSetIterator instanceof AsyncMultiDataSetIterator) {
            ((AsyncMultiDataSetIterator) multiDataSetIterator).shutdown();
        }
    }
}
Usage of org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator in the deeplearning4j project: class ExecuteWorkerMultiDataSetFlatMapAdapter, method call.
// Spark flat-map worker entry point: trains the worker's ComputationGraph on the
// partition's MultiDataSets and returns either an early-termination result or the
// worker's final result. Optionally collects timing stats when configured.
@Override
public Iterable<R> call(Iterator<MultiDataSet> dataSetIterator) throws Exception {
WorkerConfiguration dataConfig = worker.getDataConfiguration();
boolean stats = dataConfig.isCollectTrainingStats();
// Helper is only allocated when stats collection is enabled; all uses below are
// guarded by the same 'stats' flag.
StatsCalculationHelper s = (stats ? new StatsCalculationHelper() : null);
if (stats)
s.logMethodStartTime();
if (!dataSetIterator.hasNext()) {
if (stats)
s.logReturnTime();
//Sometimes: no data
return Collections.emptyList();
}
int batchSize = dataConfig.getBatchSizePerWorker();
final int prefetchCount = dataConfig.getPrefetchNumBatches();
// Re-batch the raw element iterator into minibatches of the configured size
MultiDataSetIterator batchedIterator = new IteratorMultiDataSetIterator(dataSetIterator, batchSize);
if (prefetchCount > 0) {
// Background-prefetch up to prefetchCount minibatches
batchedIterator = new AsyncMultiDataSetIterator(batchedIterator, prefetchCount);
}
try {
if (stats)
s.logInitialModelBefore();
ComputationGraph net = worker.getInitialModelGraph();
if (stats)
s.logInitialModelAfter();
int miniBatchCount = 0;
// 0 or negative maxBatchesPerWorker means "no limit"
int maxMinibatches = (dataConfig.getMaxBatchesPerWorker() > 0 ? dataConfig.getMaxBatchesPerWorker() : Integer.MAX_VALUE);
while (batchedIterator.hasNext() && miniBatchCount++ < maxMinibatches) {
if (stats)
s.logNextDataSetBefore();
MultiDataSet next = batchedIterator.next();
if (stats)
s.logNextDataSetAfter(next.getFeatures(0).size(0));
if (stats) {
// Stats path: time the minibatch and merge worker stats into the return value
s.logProcessMinibatchBefore();
Pair<R, SparkTrainingStats> result = worker.processMinibatchWithStats(next, net, !batchedIterator.hasNext());
s.logProcessMinibatchAfter();
// A non-null result from the worker signals early termination of training
if (result != null) {
//Terminate training immediately
s.logReturnTime();
SparkTrainingStats workerStats = result.getSecond();
SparkTrainingStats returnStats = s.build(workerStats);
result.getFirst().setStats(returnStats);
return Collections.singletonList(result.getFirst());
}
} else {
// Non-stats path: same early-termination contract, no timing overhead
R result = worker.processMinibatch(next, net, !batchedIterator.hasNext());
if (result != null) {
//Terminate training immediately
return Collections.singletonList(result);
}
}
}
//For some reason, we didn't return already. Normally this shouldn't happen
if (stats) {
s.logReturnTime();
Pair<R, SparkTrainingStats> pair = worker.getFinalResultWithStats(net);
pair.getFirst().setStats(s.build(pair.getSecond()));
return Collections.singletonList(pair.getFirst());
} else {
return Collections.singletonList(worker.getFinalResult(net));
}
} finally {
//Make sure we shut down the async thread properly...
if (batchedIterator instanceof AsyncMultiDataSetIterator) {
((AsyncMultiDataSetIterator) batchedIterator).shutdown();
}
// Block until any queued ND4J ops have completed before the worker returns
if (Nd4j.getExecutioner() instanceof GridExecutioner)
((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
}
}
Usage of org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator in the deeplearning4j project: class ParallelWrapper, method fit.
/**
 * Fit all workers in parallel using a MultiDataSetIterator.
 * <p>
 * MultiDataSets are dispatched round-robin to the worker threads; once every worker
 * is busy (or the iterator is exhausted) the caller blocks until all dispatched
 * batches are done, then periodically averages model parameters/updater state.
 *
 * @param source the training data source; it is reset before use
 */
public synchronized void fit(@NonNull MultiDataSetIterator source) {
    stopFit.set(false);
    if (zoo == null) {
        zoo = new Trainer[workers];
        for (int cnt = 0; cnt < workers; cnt++) {
            // we pass true here, to tell Trainer to use MultiDataSet queue for training
            zoo[cnt] = new Trainer(cnt, model, Nd4j.getAffinityManager().getDeviceForCurrentThread(), true);
            zoo[cnt].setUncaughtExceptionHandler(handler);
            zoo[cnt].start();
        }
    } else {
        // Workers already exist (e.g. a prior fit(DataSetIterator) call): switch them to MDS mode
        for (int cnt = 0; cnt < workers; cnt++) {
            zoo[cnt].useMDS = true;
        }
    }

    source.reset();

    MultiDataSetIterator iterator;
    if (prefetchSize > 0 && source.asyncSupported()) {
        iterator = new AsyncMultiDataSetIterator(source, prefetchSize);
    } else {
        iterator = source;
    }

    try {
        AtomicInteger locker = new AtomicInteger(0);
        while (iterator.hasNext() && !stopFit.get()) {
            MultiDataSet dataSet = iterator.next();
            if (dataSet == null)
                throw new ND4JIllegalStateException("You can't have NULL as MultiDataSet");

            /*
             now dataSet should be dispatched to next free workers, until all workers are busy. And then we should block till all finished.
            */
            int pos = locker.getAndIncrement();
            zoo[pos].feedMultiDataSet(dataSet);

            /*
             if all workers are dispatched now, join till all are finished
            */
            if (pos + 1 == workers || !iterator.hasNext()) {
                iterationsCounter.incrementAndGet();

                for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) {
                    try {
                        zoo[cnt].waitTillRunning();
                    } catch (Exception e) {
                        throw new RuntimeException(e);
                    }
                }

                Nd4j.getMemoryManager().invokeGcOccasionally();

                /*
                 average model, and propagate it to whole
                */
                if (iterationsCounter.get() % averagingFrequency == 0 && pos + 1 == workers) {
                    double score = getScore(locker);

                    // averaging updaters state
                    if (model instanceof ComputationGraph) {
                        averageUpdatersState(locker, score);
                    } else
                        throw new RuntimeException("MultiDataSet must only be used with ComputationGraph model");

                    if (legacyAveraging && Nd4j.getAffinityManager().getNumberOfDevices() > 1) {
                        for (int cnt = 0; cnt < workers; cnt++) {
                            zoo[cnt].updateModel(model);
                        }
                    }
                }
                locker.set(0);
            }
        }
    } finally {
        // Bug fix: previously the async prefetch thread was leaked when prefetchSize > 0.
        // Shut it down even if a worker throws or fit is stopped early.
        if (iterator instanceof AsyncMultiDataSetIterator) {
            ((AsyncMultiDataSetIterator) iterator).shutdown();
        }
    }

    // sanity checks, or the dataset may never average
    if (!wasAveraged)
        log.warn("Parameters were never averaged on current fit(). Ratios of batch size, num workers, and averaging frequency may be responsible.");
    //            throw new IllegalStateException("Parameters were never averaged. Please check batch size ratios, number of workers, and your averaging frequency.");

    log.debug("Iterations passed: {}", iterationsCounter.get());
    //        iterationsCounter.set(0);
}
Usage of org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator in the deeplearning4j project: class ParameterServerParallelWrapper, method fit.
/**
 * Fit using a MultiDataSetIterator: every MultiDataSet is pushed to the parameter
 * server training queue via {@code addObject(...)}.
 * <p>
 * When prefetch is enabled and the source supports async iteration, the source is
 * wrapped in an {@link AsyncMultiDataSetIterator} for background loading.
 *
 * @param multiDataSetIterator the training data source; consumed until exhausted
 */
public void fit(MultiDataSetIterator multiDataSetIterator) {
    // Lazy one-time initialization from the first iterator seen
    if (!init)
        init(multiDataSetIterator);

    MultiDataSetIterator iterator;
    if (preFetchSize > 0 && multiDataSetIterator.asyncSupported()) {
        iterator = new AsyncMultiDataSetIterator(multiDataSetIterator, preFetchSize);
    } else {
        iterator = multiDataSetIterator;
    }

    try {
        while (iterator.hasNext()) {
            org.nd4j.linalg.dataset.api.MultiDataSet next = iterator.next();
            addObject(next);
        }
    } finally {
        // Bug fix: previously the async prefetch thread was leaked.
        if (iterator instanceof AsyncMultiDataSetIterator) {
            ((AsyncMultiDataSetIterator) iterator).shutdown();
        }
    }
}
Aggregations