Use of org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingAggregationTuple in the project deeplearning4j (by deeplearning4j).
Example from class ParameterAveragingTrainingMaster, method processResults.
/**
 * Aggregates the per-worker results of one training split and applies the averaged values to the
 * master model.
 *
 * <p>The worker results are reduced with {@link ParameterAveragingElementAddFunction} and
 * {@link ParameterAveragingElementCombineFunction} into a single
 * {@link ParameterAveragingAggregationTuple}. The tuple holds:
 * <ul>
 *   <li>the summed parameters,</li>
 *   <li>the summed updater state,</li>
 *   <li>the summed score,</li>
 *   <li>the number of aggregated results.</li>
 * </ul>
 * The sums are divided by that count and written into {@code network} if it is non-null, otherwise
 * into {@code graph}. Listener metadata, static info and updates carried by the tuple are forwarded
 * to {@code statsStorage} when one is configured. Finally the configuration's iteration count is
 * advanced by {@code numIterations * averagingFrequency}.
 *
 * @param network     the multi-layer network to update, or {@code null} when {@code graph} is used
 * @param graph       the computation graph to update; only used when {@code network} is {@code null}
 * @param results     the training results produced by the workers for this split
 * @param splitNum    index of the split just processed (used for logging)
 * @param totalSplits total number of splits (used for logging)
 */
private void processResults(SparkDl4jMultiLayer network, SparkComputationGraph graph,
                JavaRDD<ParameterAveragingTrainingResult> results, int splitNum, int totalSplits) {
    if (collectTrainingStats) {
        stats.logAggregateStartTime();
    }

    // The zero value passed to aggregate(...) is null, so the reduction itself can yield null when
    // no partition produces a tuple (e.g. an RDD without elements). Treat that like "no data"
    // instead of failing with a NullPointerException.
    ParameterAveragingAggregationTuple tuple = results.aggregate(null,
                    new ParameterAveragingElementAddFunction(), new ParameterAveragingElementCombineFunction());
    INDArray params = (tuple == null ? null : tuple.getParametersSum());
    int aggCount = (tuple == null ? 0 : tuple.getAggregationsCount());
    SparkTrainingStats aggregatedStats = (tuple == null ? null : tuple.getSparkTrainingStats());

    if (collectTrainingStats) {
        stats.logAggregationEndTime();
        stats.logProcessParamsUpdaterStart();
    }

    // Only average when there is something to average: dividing by a zero count would silently
    // write Inf/NaN values into the model.
    boolean haveAveragedParams = params != null && aggCount > 0;
    if (haveAveragedParams) {
        params.divi(aggCount);
        INDArray updaterState = tuple.getUpdaterStateSum();
        if (updaterState != null) {
            // May be null if all SGD updaters, for example
            updaterState.divi(aggCount);
        }

        if (network != null) {
            MultiLayerNetwork net = network.getNetwork();
            net.setParameters(params);
            if (updaterState != null) {
                net.getUpdater().setStateViewArray(null, updaterState, false);
            }
            network.setScore(tuple.getScoreSum() / aggCount);
        } else {
            ComputationGraph g = graph.getNetwork();
            g.setParams(params);
            if (updaterState != null) {
                g.getUpdater().setStateViewArray(updaterState);
            }
            graph.setScore(tuple.getScoreSum() / aggCount);
        }
    } else {
        log.info("Skipping imbalanced split with no data for all executors");
    }

    if (collectTrainingStats) {
        stats.logProcessParamsUpdaterEnd();
        stats.addWorkerStats(aggregatedStats);
    }

    // Forward whatever listener data the workers collected to the configured storage.
    if (statsStorage != null && tuple != null) {
        Collection<StorageMetaData> meta = tuple.getListenerMetaData();
        if (meta != null && !meta.isEmpty()) {
            statsStorage.putStorageMetaData(meta);
        }
        Collection<Persistable> staticInfo = tuple.getListenerStaticInfo();
        if (staticInfo != null && !staticInfo.isEmpty()) {
            statsStorage.putStaticInfo(staticInfo);
        }
        Collection<Persistable> updates = tuple.getListenerUpdates();
        if (updates != null && !updates.isEmpty()) {
            statsStorage.putUpdate(updates);
        }
    }

    // Grid executioners queue operations; block until the queued work has been flushed.
    if (Nd4j.getExecutioner() instanceof GridExecutioner) {
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }

    log.info("Completed training of split {} of {}", splitNum, totalSplits);

    // Advance the iteration counter only when parameters were actually applied; params may be null
    // for the edge case of an empty RDD.
    if (haveAveragedParams) {
        if (network != null) {
            MultiLayerConfiguration conf = network.getNetwork().getLayerWiseConfigurations();
            int numUpdates = network.getNetwork().conf().getNumIterations() * averagingFrequency;
            conf.setIterationCount(conf.getIterationCount() + numUpdates);
        } else {
            ComputationGraphConfiguration conf = graph.getNetwork().getConfiguration();
            int numUpdates = graph.getNetwork().conf().getNumIterations() * averagingFrequency;
            conf.setIterationCount(conf.getIterationCount() + numUpdates);
        }
    }
}
Aggregations