Use of org.deeplearning4j.spark.api.stats.SparkTrainingStats in the deeplearning4j project.
Class ParameterAveragingTrainingWorker, method processMinibatchWithStats:
@Override
public Pair<ParameterAveragingTrainingResult, SparkTrainingStats> processMinibatchWithStats(DataSet dataSet, MultiLayerNetwork network, boolean isLast) {
    // Delegate the actual training step, then pair the result with the
    // accumulated training stats (null when stats collection is disabled).
    ParameterAveragingTrainingResult trainingResult = processMinibatch(dataSet, network, isLast);
    if (trainingResult == null) {
        return null;
    }
    return new Pair<>(trainingResult, stats == null ? null : stats.build());
}
Use of org.deeplearning4j.spark.api.stats.SparkTrainingStats in the deeplearning4j project.
Class ParameterAveragingTrainingWorker, method getFinalResultWithStats:
@Override
public Pair<ParameterAveragingTrainingResult, SparkTrainingStats> getFinalResultWithStats(ComputationGraph graph) {
    // Delegate to the stats-free variant, then attach the accumulated
    // training stats (null when stats collection is disabled).
    ParameterAveragingTrainingResult finalResult = getFinalResult(graph);
    if (finalResult == null) {
        return null;
    }
    return new Pair<>(finalResult, stats == null ? null : stats.build());
}
Use of org.deeplearning4j.spark.api.stats.SparkTrainingStats in the deeplearning4j project.
Class ParameterAveragingElementCombineFunction, method call:
@Override
public ParameterAveragingAggregationTuple call(ParameterAveragingAggregationTuple v1, ParameterAveragingAggregationTuple v2) throws Exception {
    // Spark combiner step: merge two partial aggregation tuples into one.
    // Null-safety first: either side may be entirely absent.
    if (v1 == null)
        return v2;
    else if (v2 == null)
        return v1;

    //Handle edge case of less data than executors: in this case, one (or both) of v1 and v2 might not have any contents...
    if (v1.getParametersSum() == null)
        return v2;
    else if (v2.getParametersSum() == null)
        return v1;

    // Parameter sums are combined in place (addi mutates v1's array)
    INDArray newParams = v1.getParametersSum().addi(v2.getParametersSum());

    // Updater state: null-safe in-place sum; result may be null if neither side has state
    INDArray updaterStateSum;
    if (v1.getUpdaterStateSum() == null) {
        updaterStateSum = v2.getUpdaterStateSum();
    } else {
        updaterStateSum = v1.getUpdaterStateSum();
        if (v2.getUpdaterStateSum() != null)
            updaterStateSum.addi(v2.getUpdaterStateSum());
    }

    double scoreSum = v1.getScoreSum() + v2.getScoreSum();
    int aggregationCount = v1.getAggregationsCount() + v2.getAggregationsCount();

    // Training stats: fold v2's stats into v1's when both are present
    SparkTrainingStats stats = v1.getSparkTrainingStats();
    if (v2.getSparkTrainingStats() != null) {
        if (stats == null)
            stats = v2.getSparkTrainingStats();
        else
            stats.addOtherTrainingStats(v2.getSparkTrainingStats());
    }

    // Make sure the queued addi ops above have actually executed before the arrays leave this method
    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();

    // Listener data: identical null-safe merge for all three collections
    Collection<StorageMetaData> listenerMetaData = mergeNullable(v1.getListenerMetaData(), v2.getListenerMetaData());
    Collection<Persistable> listenerStaticInfo = mergeNullable(v1.getListenerStaticInfo(), v2.getListenerStaticInfo());
    Collection<Persistable> listenerUpdates = mergeNullable(v1.getListenerUpdates(), v2.getListenerUpdates());

    return new ParameterAveragingAggregationTuple(newParams, updaterStateSum, scoreSum, aggregationCount, stats, listenerMetaData, listenerStaticInfo, listenerUpdates);
}

/**
 * Null-safe collection merge: returns {@code second} when {@code first} is null; otherwise adds
 * all elements of {@code second} (when non-null) into {@code first} and returns {@code first}.
 * Note: mutates {@code first}, matching the original in-place merge behavior.
 */
private static <T> Collection<T> mergeNullable(Collection<T> first, Collection<T> second) {
    if (first == null)
        return second;
    if (second != null)
        first.addAll(second);
    return first;
}
Use of org.deeplearning4j.spark.api.stats.SparkTrainingStats in the deeplearning4j project.
Class ExecuteWorkerFlatMapAdapter, method call:
@Override
// Runs the worker's full training loop over this Spark partition's DataSets and returns a
// single-element list containing the worker's result. When stats collection is enabled, a
// StatsCalculationHelper records timing around each phase; the exact ordering of the log*
// calls relative to the training calls is significant and must be preserved.
public Iterable<R> call(Iterator<DataSet> dataSetIterator) throws Exception {
WorkerConfiguration dataConfig = worker.getDataConfiguration();
// Worker may train either a MultiLayerNetwork or a ComputationGraph
final boolean isGraph = dataConfig.isGraphNetwork();
boolean stats = dataConfig.isCollectTrainingStats();
StatsCalculationHelper s = (stats ? new StatsCalculationHelper() : null);
if (stats)
s.logMethodStartTime();
// Empty partition: return the worker's "no data" result immediately
if (!dataSetIterator.hasNext()) {
if (stats) {
s.logReturnTime();
Pair<R, SparkTrainingStats> pair = worker.getFinalResultNoDataWithStats();
pair.getFirst().setStats(s.build(pair.getSecond()));
return Collections.singletonList(pair.getFirst());
} else {
return Collections.singletonList(worker.getFinalResultNoData());
}
}
// Re-batch the partition's examples to the configured minibatch size,
// optionally wrapping with async prefetch on a background thread
int batchSize = dataConfig.getBatchSizePerWorker();
final int prefetchCount = dataConfig.getPrefetchNumBatches();
DataSetIterator batchedIterator = new IteratorDataSetIterator(dataSetIterator, batchSize);
if (prefetchCount > 0) {
batchedIterator = new AsyncDataSetIterator(batchedIterator, prefetchCount);
}
try {
// Exactly one of net/graph is initialized, per isGraph
MultiLayerNetwork net = null;
ComputationGraph graph = null;
if (stats)
s.logInitialModelBefore();
if (isGraph)
graph = worker.getInitialModelGraph();
else
net = worker.getInitialModel();
if (stats)
s.logInitialModelAfter();
int miniBatchCount = 0;
// 0 or negative max-batches config means "no limit"
int maxMinibatches = (dataConfig.getMaxBatchesPerWorker() > 0 ? dataConfig.getMaxBatchesPerWorker() : Integer.MAX_VALUE);
while (batchedIterator.hasNext() && miniBatchCount++ < maxMinibatches) {
if (stats)
s.logNextDataSetBefore();
DataSet next = batchedIterator.next();
if (stats)
s.logNextDataSetAfter(next.numExamples());
if (stats) {
s.logProcessMinibatchBefore();
Pair<R, SparkTrainingStats> result;
// The isLast flag tells the worker when this is the final minibatch
if (isGraph)
result = worker.processMinibatchWithStats(next, graph, !batchedIterator.hasNext());
else
result = worker.processMinibatchWithStats(next, net, !batchedIterator.hasNext());
s.logProcessMinibatchAfter();
// Non-null result from the worker means: stop now and return it
if (result != null) {
//Terminate training immediately
s.logReturnTime();
SparkTrainingStats workerStats = result.getSecond();
SparkTrainingStats returnStats = s.build(workerStats);
result.getFirst().setStats(returnStats);
return Collections.singletonList(result.getFirst());
}
} else {
R result;
if (isGraph)
result = worker.processMinibatch(next, graph, !batchedIterator.hasNext());
else
result = worker.processMinibatch(next, net, !batchedIterator.hasNext());
if (result != null) {
//Terminate training immediately
return Collections.singletonList(result);
}
}
}
//For some reason, we didn't return already. Normally this shouldn't happen
if (stats) {
s.logReturnTime();
Pair<R, SparkTrainingStats> pair;
if (isGraph)
pair = worker.getFinalResultWithStats(graph);
else
pair = worker.getFinalResultWithStats(net);
pair.getFirst().setStats(s.build(pair.getSecond()));
return Collections.singletonList(pair.getFirst());
} else {
if (isGraph)
return Collections.singletonList(worker.getFinalResult(graph));
else
return Collections.singletonList(worker.getFinalResult(net));
}
} finally {
//Make sure we shut down the async thread properly...
if (batchedIterator instanceof AsyncDataSetIterator) {
((AsyncDataSetIterator) batchedIterator).shutdown();
}
// Flush any queued ND4J ops before leaving the task
if (Nd4j.getExecutioner() instanceof GridExecutioner)
((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
}
}
Use of org.deeplearning4j.spark.api.stats.SparkTrainingStats in the deeplearning4j project.
Class ExecuteWorkerMultiDataSetFlatMapAdapter, method call:
@Override
public Iterable<R> call(Iterator<MultiDataSet> dataSetIterator) throws Exception {
    // Train the worker's ComputationGraph on this partition's MultiDataSets,
    // optionally recording timing statistics around each phase.
    WorkerConfiguration config = worker.getDataConfiguration();
    boolean collectStats = config.isCollectTrainingStats();
    StatsCalculationHelper timing = collectStats ? new StatsCalculationHelper() : null;
    if (collectStats) {
        timing.logMethodStartTime();
    }

    if (!dataSetIterator.hasNext()) {
        //Sometimes: no data
        if (collectStats) {
            timing.logReturnTime();
        }
        return Collections.emptyList();
    }

    // Re-batch to the configured minibatch size; optionally prefetch asynchronously
    MultiDataSetIterator batched = new IteratorMultiDataSetIterator(dataSetIterator, config.getBatchSizePerWorker());
    int prefetch = config.getPrefetchNumBatches();
    if (prefetch > 0) {
        batched = new AsyncMultiDataSetIterator(batched, prefetch);
    }

    try {
        if (collectStats) {
            timing.logInitialModelBefore();
        }
        ComputationGraph model = worker.getInitialModelGraph();
        if (collectStats) {
            timing.logInitialModelAfter();
        }

        int processed = 0;
        // Non-positive max-batches config means "no limit"
        int limit = config.getMaxBatchesPerWorker() > 0 ? config.getMaxBatchesPerWorker() : Integer.MAX_VALUE;
        while (batched.hasNext() && processed++ < limit) {
            if (collectStats) {
                timing.logNextDataSetBefore();
            }
            MultiDataSet mds = batched.next();
            if (collectStats) {
                timing.logNextDataSetAfter(mds.getFeatures(0).size(0));
                timing.logProcessMinibatchBefore();
                Pair<R, SparkTrainingStats> out = worker.processMinibatchWithStats(mds, model, !batched.hasNext());
                timing.logProcessMinibatchAfter();
                if (out != null) {
                    //Terminate training immediately
                    timing.logReturnTime();
                    out.getFirst().setStats(timing.build(out.getSecond()));
                    return Collections.singletonList(out.getFirst());
                }
            } else {
                R out = worker.processMinibatch(mds, model, !batched.hasNext());
                if (out != null) {
                    //Terminate training immediately
                    return Collections.singletonList(out);
                }
            }
        }

        //For some reason, we didn't return already. Normally this shouldn't happen
        if (!collectStats) {
            return Collections.singletonList(worker.getFinalResult(model));
        }
        timing.logReturnTime();
        Pair<R, SparkTrainingStats> finalPair = worker.getFinalResultWithStats(model);
        finalPair.getFirst().setStats(timing.build(finalPair.getSecond()));
        return Collections.singletonList(finalPair.getFirst());
    } finally {
        //Make sure we shut down the async thread properly...
        if (batched instanceof AsyncMultiDataSetIterator) {
            ((AsyncMultiDataSetIterator) batched).shutdown();
        }
        // Flush any queued ND4J ops before leaving the task
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
        }
    }
}
Aggregations