Usage of org.nd4j.linalg.api.ops.executioner.GridExecutioner in the deeplearning4j project (by deeplearning4j).
Class ParameterAveragingTrainingWorker, method getFinalResult (ComputationGraph overload).
/**
 * Packages the final state of a {@link ComputationGraph} trained by this worker into a
 * {@link ParameterAveragingTrainingResult}.
 *
 * <p>The result carries the network's parameters and score, plus:
 * <ul>
 *   <li>the updater state view, only when {@code saveUpdater} is set and the network has an
 *       updater (otherwise {@code null});</li>
 *   <li>listener storage metadata, static info and updates, only when a router provider is
 *       configured and its router is a {@link VanillaStatsStorageRouter} (otherwise
 *       {@code null}).</li>
 * </ul>
 *
 * @param network the graph whose final state is reported
 * @return the training result for this worker
 */
@Override
public ParameterAveragingTrainingResult getFinalResult(ComputationGraph network) {
    ComputationGraphUpdater graphUpdater = saveUpdater ? network.getUpdater() : null;
    INDArray updaterView = (graphUpdater != null) ? graphUpdater.getStateViewArray() : null;

    // Flush the GridExecutioner's op queue (blocking call) before the network state is captured.
    if (Nd4j.getExecutioner() instanceof GridExecutioner) {
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }

    StatsStorageRouter router =
            (listenerRouterProvider != null) ? listenerRouterProvider.getRouter() : null;

    Collection<StorageMetaData> metaData = null;
    Collection<Persistable> staticInfo = null;
    Collection<Persistable> updates = null;
    // instanceof is false for null, so this also covers "no router provider configured".
    if (router instanceof VanillaStatsStorageRouter) {
        // TODO: depending on the concrete router type is ugly; find a better way to collect listener data
        VanillaStatsStorageRouter vanillaRouter = (VanillaStatsStorageRouter) router;
        metaData = vanillaRouter.getStorageMetaData();
        staticInfo = vanillaRouter.getStaticInfo();
        updates = vanillaRouter.getUpdates();
    }

    return new ParameterAveragingTrainingResult(network.params(), updaterView, network.score(),
            metaData, staticInfo, updates);
}
Usage of org.nd4j.linalg.api.ops.executioner.GridExecutioner in the deeplearning4j project (by deeplearning4j).
Class ParameterAveragingTrainingWorker, method getFinalResult (MultiLayerNetwork overload).
/**
 * Packages the final state of a {@link MultiLayerNetwork} trained by this worker into a
 * {@link ParameterAveragingTrainingResult}.
 *
 * <p>The result carries the network's parameters and score, plus:
 * <ul>
 *   <li>the updater state view, only when {@code saveUpdater} is set and the network has an
 *       updater (otherwise {@code null});</li>
 *   <li>listener storage metadata, static info and updates, only when a router provider is
 *       configured and its router is a {@link VanillaStatsStorageRouter} (otherwise
 *       {@code null}).</li>
 * </ul>
 *
 * @param network the multi-layer network whose final state is reported
 * @return the training result for this worker
 */
@Override
public ParameterAveragingTrainingResult getFinalResult(MultiLayerNetwork network) {
    Updater layerUpdater = saveUpdater ? network.getUpdater() : null;
    INDArray updaterView = (layerUpdater != null) ? layerUpdater.getStateViewArray() : null;

    // Flush the GridExecutioner's op queue (blocking call) before the network state is captured.
    if (Nd4j.getExecutioner() instanceof GridExecutioner) {
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }

    StatsStorageRouter router =
            (listenerRouterProvider != null) ? listenerRouterProvider.getRouter() : null;

    Collection<StorageMetaData> metaData = null;
    Collection<Persistable> staticInfo = null;
    Collection<Persistable> updates = null;
    // instanceof is false for null, so this also covers "no router provider configured".
    if (router instanceof VanillaStatsStorageRouter) {
        // TODO: depending on the concrete router type is ugly; find a better way to collect listener data
        VanillaStatsStorageRouter vanillaRouter = (VanillaStatsStorageRouter) router;
        metaData = vanillaRouter.getStorageMetaData();
        staticInfo = vanillaRouter.getStaticInfo();
        updates = vanillaRouter.getUpdates();
    }

    return new ParameterAveragingTrainingResult(network.params(), updaterView, network.score(),
            metaData, staticInfo, updates);
}
Usage of org.nd4j.linalg.api.ops.executioner.GridExecutioner in the deeplearning4j project (by deeplearning4j).
Class ParameterAveragingElementAddFunction, method call.
/**
 * Folds one worker's {@link ParameterAveragingTrainingResult} into the running aggregation tuple.
 *
 * <p>When {@code tuple} is {@code null} (the first element), a new tuple is built directly from
 * {@code result} with an aggregation count of 1. Otherwise:
 * <ul>
 *   <li>parameters and updater state are summed in place ({@code addi}) into the tuple's arrays;</li>
 *   <li>scores are added;</li>
 *   <li>Spark training stats are merged;</li>
 *   <li>listener metadata, static info and updates from {@code result} are appended to the
 *       tuple's collections;</li>
 *   <li>the aggregation count is incremented.</li>
 * </ul>
 *
 * <p>Note: the tuple's parameter array, updater-state array, training stats and listener
 * collections are mutated and reused in the returned tuple rather than copied.
 *
 * @param tuple  the aggregation so far, or {@code null} for the first element
 * @param result the worker result to add
 * @return the updated aggregation tuple
 * @throws Exception declared by the function signature; not thrown explicitly here
 */
@Override
public ParameterAveragingAggregationTuple call(ParameterAveragingAggregationTuple tuple, ParameterAveragingTrainingResult result) throws Exception {
    if (tuple == null) {
        return ParameterAveragingAggregationTuple.builder()
                .parametersSum(result.getParameters())
                .updaterStateSum(result.getUpdaterState())
                .scoreSum(result.getScore())
                .aggregationsCount(1)
                .sparkTrainingStats(result.getSparkTrainingStats())
                .listenerMetaData(result.getListenerMetaData())
                .listenerStaticInfo(result.getListenerStaticInfo())
                .listenerUpdates(result.getListenerUpdates())
                .build();
    }

    INDArray params = tuple.getParametersSum().addi(result.getParameters());

    INDArray updaterStateSum;
    if (tuple.getUpdaterStateSum() == null) {
        updaterStateSum = result.getUpdaterState();
    } else {
        updaterStateSum = tuple.getUpdaterStateSum();
        if (result.getUpdaterState() != null) {
            updaterStateSum.addi(result.getUpdaterState());
        }
    }

    double scoreSum = tuple.getScoreSum() + result.getScore();

    SparkTrainingStats stats = tuple.getSparkTrainingStats();
    if (result.getSparkTrainingStats() != null) {
        if (stats == null) {
            stats = result.getSparkTrainingStats();
        } else {
            stats.addOtherTrainingStats(result.getSparkTrainingStats());
        }
    }

    // Flush the GridExecutioner's op queue (blocking call) so the in-place sums above are applied
    // before the arrays are passed on.
    if (Nd4j.getExecutioner() instanceof GridExecutioner) {
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }

    // The accumulator always comes from the tuple and the new entries always come from the result.
    // Never merge a collection into itself: Collection.addAll is undefined in that case.
    Collection<StorageMetaData> listenerMetaData =
            mergeListenerCollections(tuple.getListenerMetaData(), result.getListenerMetaData());
    Collection<Persistable> listenerStaticInfo =
            mergeListenerCollections(tuple.getListenerStaticInfo(), result.getListenerStaticInfo());
    Collection<Persistable> listenerUpdates =
            mergeListenerCollections(tuple.getListenerUpdates(), result.getListenerUpdates());

    return new ParameterAveragingAggregationTuple(params, updaterStateSum, scoreSum,
            tuple.getAggregationsCount() + 1, stats, listenerMetaData, listenerStaticInfo, listenerUpdates);
}

/**
 * Appends {@code additions} to {@code accumulator}, treating {@code null} as "no entries".
 *
 * @param accumulator the collection gathered so far (mutated when non-null), or {@code null}
 * @param additions   entries to add, or {@code null}
 * @return the behavior depends on which arguments are null:
 *         <ul>
 *           <li>{@code additions} when {@code accumulator} is {@code null};</li>
 *           <li>otherwise {@code accumulator}, with {@code additions} appended if non-null.</li>
 *         </ul>
 */
private static <T> Collection<T> mergeListenerCollections(Collection<T> accumulator, Collection<T> additions) {
    if (accumulator == null) {
        return additions;
    }
    if (additions != null) {
        accumulator.addAll(additions);
    }
    return accumulator;
}
Aggregations