Example usage of org.deeplearning4j.datasets.iterator.AsyncDataSetIterator in the deeplearning4j project: the fit method of the ParallelWrapper class.
/**
 * This method takes DataSetIterator, and starts training over it by scheduling DataSets to different executors.
 * Workers are joined and parameters are averaged every {@code averagingFrequency} completed dispatch cycles.
 *
 * @param source iterator supplying training DataSets; it is reset before training starts
 */
public synchronized void fit(@NonNull DataSetIterator source) {
    stopFit.set(false);
    // Lazily create the worker "zoo" on the first fit() call; reused on subsequent calls.
    if (zoo == null) {
        zoo = new Trainer[workers];
        for (int cnt = 0; cnt < workers; cnt++) {
            zoo[cnt] = new Trainer(cnt, model, Nd4j.getAffinityManager().getDeviceForCurrentThread());
            // If MagicQueue (MQ) is enabled, pin each trainer thread to a device round-robin,
            // so data routed per-device by the queue lands on the matching worker thread.
            if (isMQ)
                Nd4j.getAffinityManager().attachThreadToDevice(zoo[cnt], cnt % Nd4j.getAffinityManager().getNumberOfDevices());
            zoo[cnt].setUncaughtExceptionHandler(handler);
            zoo[cnt].start();
        }
    }
    source.reset();
    DataSetIterator iterator;
    // Wrap the source in an async prefetching iterator where supported; with MQ enabled,
    // prefetched DataSets are additionally bucketed per device in SEQUENTIAL mode.
    if (prefetchSize > 0 && source.asyncSupported()) {
        if (isMQ) {
            if (workers % Nd4j.getAffinityManager().getNumberOfDevices() != 0)
                log.warn("Number of workers [{}] isn't optimal for available devices [{}]", workers, Nd4j.getAffinityManager().getNumberOfDevices());
            MagicQueue queue = new MagicQueue.Builder().setCapacityPerFlow(8).setMode(MagicQueue.Mode.SEQUENTIAL).setNumberOfBuckets(Nd4j.getAffinityManager().getNumberOfDevices()).build();
            iterator = new AsyncDataSetIterator(source, prefetchSize, queue);
        } else
            iterator = new AsyncDataSetIterator(source, prefetchSize);
    } else
        iterator = source;
    // NOTE(review): unlike the Spark worker path, the AsyncDataSetIterator created here is
    // never explicitly shut down; its prefetch thread is presumably left to terminate on
    // exhaustion — verify no thread lingers when stopFit aborts training early.
    // locker counts how many workers have been handed a DataSet in the current dispatch cycle.
    AtomicInteger locker = new AtomicInteger(0);
    int whiles = 0;
    while (iterator.hasNext() && !stopFit.get()) {
        whiles++;
        DataSet dataSet = iterator.next();
        if (dataSet == null)
            throw new ND4JIllegalStateException("You can't have NULL as DataSet");
        /*
        now dataSet should be dispatched to next free workers, until all workers are busy. And then we should block till all finished.
        */
        int pos = locker.getAndIncrement();
        if (zoo == null)
            throw new IllegalStateException("ParallelWrapper.shutdown() has been called too early and will fail from this point forward.");
        zoo[pos].feedDataSet(dataSet);
        /*
        if all workers are dispatched now, join till all are finished
        */
        if (pos + 1 == workers || !iterator.hasNext()) {
            iterationsCounter.incrementAndGet();
            // Block until every worker dispatched in this cycle has finished its DataSet.
            for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) {
                try {
                    zoo[cnt].waitTillRunning();
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
            Nd4j.getMemoryManager().invokeGcOccasionally();
            /*
            average model, and propagate it to whole
            */
            // Averaging happens only on a full cycle (all workers used) at the configured frequency;
            // partial final cycles (iterator exhausted early) skip averaging.
            if (iterationsCounter.get() % averagingFrequency == 0 && pos + 1 == workers) {
                double score = getScore(locker);
                // averaging updaters state
                if (model instanceof MultiLayerNetwork) {
                    if (averageUpdaters) {
                        Updater updater = ((MultiLayerNetwork) model).getUpdater();
                        int batchSize = 0;
                        if (updater != null && updater.getStateViewArray() != null) {
                            if (!legacyAveraging || Nd4j.getAffinityManager().getNumberOfDevices() == 1) {
                                // Fast path: average all workers' updater state views in one native call.
                                List<INDArray> updaters = new ArrayList<>();
                                for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) {
                                    MultiLayerNetwork workerModel = (MultiLayerNetwork) zoo[cnt].getModel();
                                    updaters.add(workerModel.getUpdater().getStateViewArray());
                                    batchSize += workerModel.batchSize();
                                }
                                Nd4j.averageAndPropagate(updater.getStateViewArray(), updaters);
                            } else {
                                // Legacy multi-device path: accumulate copies on the host, then divide
                                // by the number of contributing workers (cnt after the loop).
                                INDArray state = Nd4j.zeros(updater.getStateViewArray().shape());
                                int cnt = 0;
                                for (; cnt < workers && cnt < locker.get(); cnt++) {
                                    MultiLayerNetwork workerModel = (MultiLayerNetwork) zoo[cnt].getModel();
                                    state.addi(workerModel.getUpdater().getStateViewArray().dup());
                                    batchSize += workerModel.batchSize();
                                }
                                state.divi(cnt);
                                updater.setStateViewArray((MultiLayerNetwork) model, state, false);
                            }
                        }
                    }
                    ((MultiLayerNetwork) model).setScore(score);
                } else if (model instanceof ComputationGraph) {
                    averageUpdatersState(locker, score);
                }
                // Legacy averaging pushes the averaged master model back to every worker explicitly.
                if (legacyAveraging && Nd4j.getAffinityManager().getNumberOfDevices() > 1) {
                    for (int cnt = 0; cnt < workers; cnt++) {
                        zoo[cnt].updateModel(model);
                    }
                }
            }
            // Start a fresh dispatch cycle.
            locker.set(0);
        }
    }
    // sanity checks, or the dataset may never average
    // NOTE(review): wasAveraged is not assigned anywhere in this method as shown —
    // presumably a field (or a local lost in extraction) set true in the averaging branch; verify.
    if (!wasAveraged)
        log.warn("Parameters were never averaged on current fit(). Ratios of batch size, num workers, and averaging frequency may be responsible.");
    // throw new IllegalStateException("Parameters were never averaged. Please check batch size ratios, number of workers, and your averaging frequency.");
    log.debug("Iterations passed: {}", iterationsCounter.get());
    // iterationsCounter.set(0);
}
Example usage of org.deeplearning4j.datasets.iterator.AsyncDataSetIterator in the deeplearning4j project: the call method of the ExecuteWorkerFlatMapAdapter class.
/**
 * Spark flat-map entry point: trains the worker's model over the partition's DataSets,
 * optionally collecting timing stats, and returns the (single-element) training result.
 *
 * @param dataSetIterator iterator over this partition's raw DataSets
 * @return a singleton iterable containing the worker's training result
 * @throws Exception propagated from the training worker
 */
@Override
public Iterable<R> call(Iterator<DataSet> dataSetIterator) throws Exception {
    WorkerConfiguration dataConfig = worker.getDataConfiguration();
    final boolean isGraph = dataConfig.isGraphNetwork();
    boolean stats = dataConfig.isCollectTrainingStats();
    StatsCalculationHelper s = (stats ? new StatsCalculationHelper() : null);
    if (stats)
        s.logMethodStartTime();
    // Empty partition: short-circuit with the "no data" result (with or without stats).
    if (!dataSetIterator.hasNext()) {
        if (stats) {
            s.logReturnTime();
            Pair<R, SparkTrainingStats> pair = worker.getFinalResultNoDataWithStats();
            pair.getFirst().setStats(s.build(pair.getSecond()));
            return Collections.singletonList(pair.getFirst());
        } else {
            return Collections.singletonList(worker.getFinalResultNoData());
        }
    }
    int batchSize = dataConfig.getBatchSizePerWorker();
    final int prefetchCount = dataConfig.getPrefetchNumBatches();
    // Re-batch the raw DataSets to the configured per-worker minibatch size,
    // and add background prefetch when configured.
    DataSetIterator batchedIterator = new IteratorDataSetIterator(dataSetIterator, batchSize);
    if (prefetchCount > 0) {
        batchedIterator = new AsyncDataSetIterator(batchedIterator, prefetchCount);
    }
    try {
        MultiLayerNetwork net = null;
        ComputationGraph graph = null;
        if (stats)
            s.logInitialModelBefore();
        if (isGraph)
            graph = worker.getInitialModelGraph();
        else
            net = worker.getInitialModel();
        if (stats)
            s.logInitialModelAfter();
        int miniBatchCount = 0;
        int maxMinibatches = (dataConfig.getMaxBatchesPerWorker() > 0 ? dataConfig.getMaxBatchesPerWorker() : Integer.MAX_VALUE);
        while (batchedIterator.hasNext() && miniBatchCount++ < maxMinibatches) {
            if (stats)
                s.logNextDataSetBefore();
            DataSet next = batchedIterator.next();
            if (stats)
                s.logNextDataSetAfter(next.numExamples());
            if (stats) {
                s.logProcessMinibatchBefore();
                Pair<R, SparkTrainingStats> result;
                // The final flag (no more data) lets the worker finalize on the last minibatch.
                if (isGraph)
                    result = worker.processMinibatchWithStats(next, graph, !batchedIterator.hasNext());
                else
                    result = worker.processMinibatchWithStats(next, net, !batchedIterator.hasNext());
                s.logProcessMinibatchAfter();
                // A non-null result means the worker produced its final output mid-stream.
                if (result != null) {
                    //Terminate training immediately
                    s.logReturnTime();
                    SparkTrainingStats workerStats = result.getSecond();
                    SparkTrainingStats returnStats = s.build(workerStats);
                    result.getFirst().setStats(returnStats);
                    return Collections.singletonList(result.getFirst());
                }
            } else {
                R result;
                if (isGraph)
                    result = worker.processMinibatch(next, graph, !batchedIterator.hasNext());
                else
                    result = worker.processMinibatch(next, net, !batchedIterator.hasNext());
                if (result != null) {
                    //Terminate training immediately
                    return Collections.singletonList(result);
                }
            }
        }
        //For some reason, we didn't return already. Normally this shouldn't happen
        // (e.g. maxMinibatches was hit before the worker signalled completion).
        if (stats) {
            s.logReturnTime();
            Pair<R, SparkTrainingStats> pair;
            if (isGraph)
                pair = worker.getFinalResultWithStats(graph);
            else
                pair = worker.getFinalResultWithStats(net);
            pair.getFirst().setStats(s.build(pair.getSecond()));
            return Collections.singletonList(pair.getFirst());
        } else {
            if (isGraph)
                return Collections.singletonList(worker.getFinalResult(graph));
            else
                return Collections.singletonList(worker.getFinalResult(net));
        }
    } finally {
        //Make sure we shut down the async thread properly...
        if (batchedIterator instanceof AsyncDataSetIterator) {
            ((AsyncDataSetIterator) batchedIterator).shutdown();
        }
        // Drain any queued native ops before this Spark task completes.
        if (Nd4j.getExecutioner() instanceof GridExecutioner)
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }
}
Example usage of org.deeplearning4j.datasets.iterator.AsyncDataSetIterator in the deeplearning4j project: the testSequentialIterable method of the MagicQueueTest class.
@Test
public void testSequentialIterable() throws Exception {
    // Build a fixed-size collection of identical 3-element DataSets.
    final int totalSets = 1024;
    List<DataSet> dataSets = new ArrayList<>();
    int i = 0;
    while (i < totalSets) {
        dataSets.add(new DataSet(Nd4j.create(new float[] { 1f, 2f, 3f }), Nd4j.create(new float[] { 1f, 2f, 3f })));
        i++;
    }
    int deviceCount = Nd4j.getAffinityManager().getNumberOfDevices();
    ExistingDataSetIterator source = new ExistingDataSetIterator(dataSets);
    MagicQueue magicQueue = new MagicQueue.Builder().setMode(MagicQueue.Mode.SEQUENTIAL).setCapacityPerFlow(32).build();
    AsyncDataSetIterator async = new AsyncDataSetIterator(source, 10, magicQueue);
    int round = 0;
    while (async.hasNext()) {
        DataSet ds = async.next();
        // the iterator must never hand back a null dataset
        assertNotEquals("Failed on round " + round, null, ds);
        // in SEQUENTIAL mode, arrays should be assigned to devices in round-robin order
        assertEquals(round % deviceCount, Nd4j.getAffinityManager().getDeviceForArray(ds.getFeatures()).intValue());
        assertEquals(round % deviceCount, Nd4j.getAffinityManager().getDeviceForArray(ds.getLabels()).intValue());
        round++;
    }
    // every dataset fed in must have come back out
    assertEquals(dataSets.size(), round);
}
Example usage of org.deeplearning4j.datasets.iterator.AsyncDataSetIterator in the deeplearning4j project: the fit method of the MultiLayerNetwork class.
/**
 * Fit the network using a DataSetIterator: optional layerwise pretraining followed by
 * backprop training over every DataSet the iterator yields. Epoch start/end listeners
 * are notified around the whole pass.
 *
 * @param iterator source of training DataSets; wrapped in an AsyncDataSetIterator
 *                 for background prefetch when the iterator supports it
 */
@Override
public void fit(DataSetIterator iterator) {
    DataSetIterator iter;
    // we're wrapping all iterators into AsyncDataSetIterator to provide background prefetch - where appropriate
    if (iterator.asyncSupported()) {
        iter = new AsyncDataSetIterator(iterator, 2);
    } else {
        iter = iterator;
    }
    try {
        if (trainingListeners.size() > 0) {
            for (TrainingListener tl : trainingListeners) {
                tl.onEpochStart(this);
            }
        }
        if (layerWiseConfigurations.isPretrain()) {
            pretrain(iter);
            // Rewind so the backprop pass below sees the full dataset again.
            if (iter.resetSupported()) {
                iter.reset();
            }
        }
        if (layerWiseConfigurations.isBackprop()) {
            update(TaskUtils.buildTask(iter));
            if (!iter.hasNext() && iter.resetSupported()) {
                iter.reset();
            }
            while (iter.hasNext()) {
                DataSet next = iter.next();
                // A DataSet missing features or labels marks the end of usable data.
                if (next.getFeatureMatrix() == null || next.getLabels() == null)
                    break;
                boolean hasMaskArrays = next.hasMaskArrays();
                if (layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT) {
                    doTruncatedBPTT(next.getFeatureMatrix(), next.getLabels(), next.getFeaturesMaskArray(), next.getLabelsMaskArray());
                } else {
                    if (hasMaskArrays)
                        setLayerMaskArrays(next.getFeaturesMaskArray(), next.getLabelsMaskArray());
                    setInput(next.getFeatureMatrix());
                    setLabels(next.getLabels());
                    // Lazily build the solver on the first minibatch, then reuse it.
                    if (solver == null) {
                        solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
                    }
                    solver.optimize();
                }
                if (hasMaskArrays)
                    clearLayerMaskArrays();
                Nd4j.getMemoryManager().invokeGcOccasionally();
            }
        } else if (layerWiseConfigurations.isPretrain()) {
            log.warn("Warning: finetune is not applied.");
        }
        if (trainingListeners.size() > 0) {
            for (TrainingListener tl : trainingListeners) {
                tl.onEpochEnd(this);
            }
        }
    } finally {
        // Fix: shut down the async prefetch thread even when we break out early or an
        // exception is thrown, mirroring the cleanup done in the Spark worker path.
        if (iter instanceof AsyncDataSetIterator) {
            ((AsyncDataSetIterator) iter).shutdown();
        }
    }
}
Example usage of org.deeplearning4j.datasets.iterator.AsyncDataSetIterator in the deeplearning4j project: the fit method of the ComputationGraph class.
/**
 * Fit the ComputationGraph using a DataSetIterator.
 * Note that this method can only be used with ComputationGraphs with 1 input and 1 output.
 *
 * @param iterator source of training DataSets; wrapped in an AsyncDataSetIterator
 *                 for background prefetch when the iterator supports it
 * @throws UnsupportedOperationException if this graph has multiple inputs or outputs
 */
public void fit(DataSetIterator iterator) {
    if (flattenedGradients == null)
        initGradientsView();
    if (numInputArrays != 1 || numOutputArrays != 1)
        throw new UnsupportedOperationException("Cannot train ComputationGraph network with " + " multiple inputs or outputs using a DataSetIterator");
    DataSetIterator dataSetIterator;
    // we're wrapping all iterators into AsyncDataSetIterator to provide background prefetch - where appropriate
    if (iterator.asyncSupported()) {
        dataSetIterator = new AsyncDataSetIterator(iterator, 2);
    } else
        dataSetIterator = iterator;
    try {
        if (trainingListeners.size() > 0) {
            for (TrainingListener tl : trainingListeners) {
                tl.onEpochStart(this);
            }
        }
        if (configuration.isPretrain()) {
            pretrain(dataSetIterator);
        }
        if (configuration.isBackprop()) {
            update(TaskUtils.buildTask(dataSetIterator));
            while (dataSetIterator.hasNext()) {
                DataSet next = dataSetIterator.next();
                // A DataSet missing features or labels marks the end of usable data.
                if (next.getFeatures() == null || next.getLabels() == null)
                    break;
                boolean hasMaskArrays = next.hasMaskArrays();
                if (hasMaskArrays) {
                    INDArray[] fMask = (next.getFeaturesMaskArray() != null ? new INDArray[] { next.getFeaturesMaskArray() } : null);
                    INDArray[] lMask = (next.getLabelsMaskArray() != null ? new INDArray[] { next.getLabelsMaskArray() } : null);
                    setLayerMaskArrays(fMask, lMask);
                }
                if (configuration.getBackpropType() == BackpropType.TruncatedBPTT) {
                    doTruncatedBPTT(new INDArray[] { next.getFeatures() }, new INDArray[] { next.getLabels() }, (hasMaskArrays ? new INDArray[] { next.getFeaturesMaskArray() } : null), (hasMaskArrays ? new INDArray[] { next.getLabelsMaskArray() } : null));
                } else {
                    setInput(0, next.getFeatures());
                    setLabel(0, next.getLabels());
                    // Lazily build the solver on the first minibatch, then reuse it.
                    if (solver == null) {
                        solver = //TODO; don't like this
                        new Solver.Builder().configure(defaultConfiguration).listeners(listeners).model(this).build();
                    }
                    solver.optimize();
                }
                if (hasMaskArrays) {
                    clearLayerMaskArrays();
                }
                Nd4j.getMemoryManager().invokeGcOccasionally();
            }
        }
        if (trainingListeners.size() > 0) {
            for (TrainingListener tl : trainingListeners) {
                tl.onEpochEnd(this);
            }
        }
    } finally {
        // Fix: shut down the async prefetch thread even when we break out early or an
        // exception is thrown, mirroring the cleanup done in the Spark worker path.
        if (dataSetIterator instanceof AsyncDataSetIterator) {
            ((AsyncDataSetIterator) dataSetIterator).shutdown();
        }
    }
}
Aggregations