Use of org.deeplearning4j.nn.updater.MultiLayerUpdater in the deeplearning4j project, by deeplearning4j.
Class ParameterAveragingTrainingWorker, method getInitialModel.
/**
 * Creates this worker's own network from the broadcast tuple.
 *
 * <p>The configuration is cloned and the parameter array and updater state (when present) are
 * duplicated, so the returned network does not share those objects with the broadcast value.
 * Timing events are logged to {@code stats} only when training stats collection is enabled.
 *
 * @return the newly initialised network, with listeners configured
 */
@Override
public MultiLayerNetwork getInitialModel() {
    if (configuration.isCollectTrainingStats()) {
        stats = new ParameterAveragingTrainingWorkerStats.ParameterAveragingTrainingWorkerStatsHelper();
        stats.logBroadcastGetValueStart();
    }

    NetBroadcastTuple broadcastValue = broadcast.getValue();

    if (configuration.isCollectTrainingStats()) {
        stats.logBroadcastGetValueEnd();
    }

    // Each worker needs its own configuration object: each may advance its iteration count
    // (for LR schedule etc) independently of the others.
    MultiLayerNetwork workerNet = new MultiLayerNetwork(broadcastValue.getConfiguration().clone());

    // The parameter array must not be shared across executors for parameter averaging. It is
    // duplicated explicitly here, so 'false' is passed for the clone-parameters-array argument.
    workerNet.init(broadcastValue.getParameters().unsafeDuplication(), false);

    if (broadcastValue.getUpdaterState() != null) {
        // Updater state must not be shared either, so hand the updater its own duplicate.
        MultiLayerUpdater workerUpdater =
                new MultiLayerUpdater(workerNet, broadcastValue.getUpdaterState().unsafeDuplication());
        workerNet.setUpdater(workerUpdater);
    }

    // NOTE(review): presumably ensures queued ops (e.g. the duplications above) have completed
    // before the network is used - confirm against GridExecutioner docs.
    if (Nd4j.getExecutioner() instanceof GridExecutioner) {
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }

    configureListeners(workerNet, broadcastValue.getCounter().getAndIncrement());

    if (configuration.isCollectTrainingStats()) {
        stats.logInitEnd();
    }

    return workerNet;
}
Use of org.deeplearning4j.nn.updater.MultiLayerUpdater in the deeplearning4j project, by deeplearning4j.
Class MultiLayerNetwork, method update.
/**
 * Assigns the state of the given network to this one. This is used in loading from
 * input streams, factory methods, etc.
 *
 * <p>What is copied:
 * <ul>
 *   <li>the default configuration is cloned (or set to null if the source has none);</li>
 *   <li>the input is duplicated and set, but only if the source has an input;</li>
 *   <li>the labels reference is assigned directly (not duplicated);</li>
 *   <li>each layer is cloned into a new array (or the array is set to null);</li>
 *   <li>if the source has a solver and its updater has a state view, a new updater is built
 *       for this network from a duplicate of that state; if the source has no solver, this
 *       network's solver is set to null.</li>
 * </ul>
 *
 * @param network the network to get parameters from
 */
public void update(MultiLayerNetwork network) {
    if (network.defaultConfiguration == null) {
        this.defaultConfiguration = null;
    } else {
        this.defaultConfiguration = network.defaultConfiguration.clone();
    }

    if (network.input != null) {
        // Dup in case of dropout etc
        setInput(network.input.dup());
    }

    this.labels = network.labels;

    if (network.layers == null) {
        this.layers = null;
    } else {
        this.layers = new Layer[network.layers.length];
        for (int idx = 0; idx < this.layers.length; idx++) {
            this.layers[idx] = network.layers[idx].clone();
        }
    }

    if (network.solver == null) {
        this.solver = null;
        return;
    }

    // Network updater state: should be cloned over also
    INDArray sourceUpdaterState = network.getUpdater().getStateViewArray();
    if (sourceUpdaterState != null) {
        Updater clonedUpdater = new MultiLayerUpdater(this, sourceUpdaterState.dup());
        this.setUpdater(clonedUpdater);
    }
}
Aggregations