Example 1 with MultiLayerUpdater

Use of org.deeplearning4j.nn.updater.MultiLayerUpdater in the deeplearning4j project.

From class ParameterAveragingTrainingWorker, method getInitialModel:

@Override
public MultiLayerNetwork getInitialModel() {
    if (configuration.isCollectTrainingStats())
        stats = new ParameterAveragingTrainingWorkerStats.ParameterAveragingTrainingWorkerStatsHelper();
    if (configuration.isCollectTrainingStats())
        stats.logBroadcastGetValueStart();
    NetBroadcastTuple tuple = broadcast.getValue();
    if (configuration.isCollectTrainingStats())
        stats.logBroadcastGetValueEnd();
    //Don't want to have shared configuration object: each may update its iteration count (for LR schedule etc) individually
    MultiLayerNetwork net = new MultiLayerNetwork(tuple.getConfiguration().clone());
    //Can't have a shared parameter array across executors for parameter averaging.
    //The parameters are already duplicated via unsafeDuplication(), so 'false' is passed for the clone-parameters-array arg
    net.init(tuple.getParameters().unsafeDuplication(), false);
    if (tuple.getUpdaterState() != null) {
        //Can't have shared updater state
        net.setUpdater(new MultiLayerUpdater(net, tuple.getUpdaterState().unsafeDuplication()));
    }
    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    configureListeners(net, tuple.getCounter().getAndIncrement());
    if (configuration.isCollectTrainingStats())
        stats.logInitEnd();
    return net;
}
Also used : MultiLayerUpdater(org.deeplearning4j.nn.updater.MultiLayerUpdater) GridExecutioner(org.nd4j.linalg.api.ops.executioner.GridExecutioner) NetBroadcastTuple(org.deeplearning4j.spark.api.worker.NetBroadcastTuple) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork)
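The key pattern in this worker is the defensive copy: the broadcast parameter and updater-state arrays are duplicated (via unsafeDuplication()) before a new network is initialized from them, so no executor mutates state shared with others. A minimal, dependency-free Java sketch of the same idea (DefensiveCopyDemo and initWorkerParams are illustrative names, not DL4J API):

```java
import java.util.Arrays;

public class DefensiveCopyDemo {
    // A toy "worker" that receives broadcast parameters and must not mutate the shared copy.
    // Analogous to initializing a net from tuple.getParameters().unsafeDuplication().
    static double[] initWorkerParams(double[] broadcastParams) {
        return broadcastParams.clone(); // independent copy for this worker
    }

    public static void main(String[] args) {
        double[] shared = {0.1, 0.2, 0.3};
        double[] workerA = initWorkerParams(shared);
        double[] workerB = initWorkerParams(shared);

        workerA[0] = 99.0; // worker A trains and updates only its own copy

        // The shared broadcast array and worker B are unaffected
        System.out.println(Arrays.toString(shared));  // [0.1, 0.2, 0.3]
        System.out.println(Arrays.toString(workerB)); // [0.1, 0.2, 0.3]
    }
}
```

Without the copy, one worker's in-place gradient updates would corrupt the parameters every other worker averages against.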

Example 2 with MultiLayerUpdater

Use of org.deeplearning4j.nn.updater.MultiLayerUpdater in the deeplearning4j project.

From class MultiLayerNetwork, method update:

/**
 * Assigns the parameters of this network to those of the specified
 * network. This is used when loading from input streams, factory methods, etc.
 *
 * @param network the network to copy parameters from
 */
public void update(MultiLayerNetwork network) {
    this.defaultConfiguration = (network.defaultConfiguration != null ? network.defaultConfiguration.clone() : null);
    if (network.input != null)
        //Dup in case of dropout etc
        setInput(network.input.dup());
    this.labels = network.labels;
    if (network.layers != null) {
        layers = new Layer[network.layers.length];
        for (int i = 0; i < layers.length; i++) {
            layers[i] = network.layers[i].clone();
        }
    } else {
        this.layers = null;
    }
    if (network.solver != null) {
        //Network updater state: should be cloned over also
        INDArray updaterView = network.getUpdater().getStateViewArray();
        if (updaterView != null) {
            Updater newUpdater = new MultiLayerUpdater(this, updaterView.dup());
            this.setUpdater(newUpdater);
        }
    } else {
        this.solver = null;
    }
}
Also used : MultiLayerUpdater(org.deeplearning4j.nn.updater.MultiLayerUpdater) INDArray(org.nd4j.linalg.api.ndarray.INDArray)
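Note that update() does not simply copy the layers array reference: it allocates a new array and clones each layer individually, so the two networks never share mutable layer state. The element-by-element deep copy can be sketched in plain Java (DeepCopyLayersDemo and its toy Layer class are illustrative, not DL4J types):

```java
public class DeepCopyLayersDemo {
    // Toy stand-in for a cloneable layer holding mutable weights
    static class Layer implements Cloneable {
        double weight;
        Layer(double w) { weight = w; }
        @Override public Layer clone() {
            return new Layer(weight); // independent copy of the mutable state
        }
    }

    // Mirrors the loop in MultiLayerNetwork.update():
    // allocate a new array, then clone each element into it.
    static Layer[] copyLayers(Layer[] source) {
        if (source == null) return null;
        Layer[] copy = new Layer[source.length];
        for (int i = 0; i < source.length; i++) {
            copy[i] = source[i].clone();
        }
        return copy;
    }

    public static void main(String[] args) {
        Layer[] original = { new Layer(1.0), new Layer(2.0) };
        Layer[] copied = copyLayers(original);

        copied[0].weight = 42.0; // mutate the copy only

        System.out.println(original[0].weight); // 1.0 — deep copy, not shared
    }
}
```

A shallow copy (original.clone() on the array, or System.arraycopy) would duplicate only the references, so training the copied network would silently mutate the source network's layers as well.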

Aggregations

MultiLayerUpdater (org.deeplearning4j.nn.updater.MultiLayerUpdater) — 2 usages
MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork) — 1 usage
NetBroadcastTuple (org.deeplearning4j.spark.api.worker.NetBroadcastTuple) — 1 usage
INDArray (org.nd4j.linalg.api.ndarray.INDArray) — 1 usage
GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner) — 1 usage