
Example 1 with NetBroadcastTuple

Use of org.deeplearning4j.spark.api.worker.NetBroadcastTuple in the project deeplearning4j by deeplearning4j.

From the class ParameterAveragingTrainingWorker, method getInitialModel:

@Override
public MultiLayerNetwork getInitialModel() {
    if (configuration.isCollectTrainingStats()) {
        stats = new ParameterAveragingTrainingWorkerStats.ParameterAveragingTrainingWorkerStatsHelper();
        stats.logBroadcastGetValueStart();
    }
    NetBroadcastTuple tuple = broadcast.getValue();
    if (configuration.isCollectTrainingStats())
        stats.logBroadcastGetValueEnd();
    //Don't want to have shared configuration object: each may update its iteration count (for LR schedule etc) individually
    MultiLayerNetwork net = new MultiLayerNetwork(tuple.getConfiguration().clone());
    //Can't have a shared parameter array across executors for parameter averaging; unsafeDuplication() below gives this worker its own copy, so the clone-parameters-array argument can be false
    net.init(tuple.getParameters().unsafeDuplication(), false);
    if (tuple.getUpdaterState() != null) {
        //Can't have shared updater state
        net.setUpdater(new MultiLayerUpdater(net, tuple.getUpdaterState().unsafeDuplication()));
    }
    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    configureListeners(net, tuple.getCounter().getAndIncrement());
    if (configuration.isCollectTrainingStats())
        stats.logInitEnd();
    return net;
}
Also used:
MultiLayerUpdater (org.deeplearning4j.nn.updater.MultiLayerUpdater)
GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner)
NetBroadcastTuple (org.deeplearning4j.spark.api.worker.NetBroadcastTuple)
MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)
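The reason for duplicating the broadcast parameters (rather than aliasing them) can be shown without any DL4J dependency. The sketch below is a hypothetical, library-free illustration using plain double arrays in place of INDArrays; the class and method names are invented for the demo, not part of deeplearning4j:

```java
import java.util.Arrays;

// Hypothetical demo: why a worker must train on a *copy* of the broadcast
// parameters rather than on the shared array itself.
public class SharedParamsDemo {

    // In-place gradient-style update: shifts every parameter by delta.
    static void train(double[] params, double delta) {
        for (int i = 0; i < params.length; i++) {
            params[i] += delta;
        }
    }

    // Training on a copy leaves the broadcast array intact
    // (analogous to tuple.getParameters().unsafeDuplication() above).
    static boolean copyKeepsBroadcastIntact() {
        double[] broadcastParams = {1.0, 2.0, 3.0};
        double[] workerCopy = Arrays.copyOf(broadcastParams, broadcastParams.length);
        train(workerCopy, 0.5);
        return Arrays.equals(broadcastParams, new double[] {1.0, 2.0, 3.0});
    }

    // Training on an alias corrupts the broadcast value for everyone else.
    static boolean aliasCorruptsBroadcast() {
        double[] broadcastParams = {1.0, 2.0, 3.0};
        double[] workerAlias = broadcastParams; // no copy: same array object
        train(workerAlias, 0.5);
        return !Arrays.equals(broadcastParams, new double[] {1.0, 2.0, 3.0});
    }

    public static void main(String[] args) {
        System.out.println(copyKeepsBroadcastIntact() && aliasCorruptsBroadcast()); // true
    }
}
```

The same reasoning applies to the updater state: each worker gets its own duplicated copy so that in-place updates on one worker cannot leak into another.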

Example 2 with NetBroadcastTuple

Use of org.deeplearning4j.spark.api.worker.NetBroadcastTuple in the project deeplearning4j by deeplearning4j.

From the class ParameterAveragingTrainingWorker, method getInitialModelGraph:

@Override
public ComputationGraph getInitialModelGraph() {
    if (configuration.isCollectTrainingStats()) {
        stats = new ParameterAveragingTrainingWorkerStats.ParameterAveragingTrainingWorkerStatsHelper();
        stats.logBroadcastGetValueStart();
    }
    NetBroadcastTuple tuple = broadcast.getValue();
    if (configuration.isCollectTrainingStats())
        stats.logBroadcastGetValueEnd();
    //Don't want to have shared configuration object: each may update its iteration count (for LR schedule etc) individually
    ComputationGraph net = new ComputationGraph(tuple.getGraphConfiguration().clone());
    //Can't have a shared parameter array across executors for parameter averaging; unsafeDuplication() below gives this worker its own copy, so the clone-parameters-array argument can be false
    net.init(tuple.getParameters().unsafeDuplication(), false);
    if (tuple.getUpdaterState() != null) {
        //Again: can't have shared updater state
        net.setUpdater(new ComputationGraphUpdater(net, tuple.getUpdaterState().unsafeDuplication()));
    }
    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    configureListeners(net, tuple.getCounter().getAndIncrement());
    if (configuration.isCollectTrainingStats())
        stats.logInitEnd();
    return net;
}
Also used:
GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner)
ComputationGraphUpdater (org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater)
NetBroadcastTuple (org.deeplearning4j.spark.api.worker.NetBroadcastTuple)
ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph)
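Both examples end with configureListeners(net, tuple.getCounter().getAndIncrement()), which hands each initialized model a distinct id from a counter carried on the broadcast tuple. A minimal sketch of that counter pattern, assuming a hypothetical SimpleTuple stand-in (the real NetBroadcastTuple also carries the configuration, parameters, and updater state):

```java
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical stand-in for the broadcast tuple: only the counter is modelled.
public class TupleCounterDemo {

    static final class SimpleTuple {
        private final AtomicInteger counter = new AtomicInteger(0);

        AtomicInteger getCounter() {
            return counter;
        }
    }

    public static void main(String[] args) {
        SimpleTuple tuple = new SimpleTuple();
        // Each getInitialModel()/getInitialModelGraph() call claims the next id,
        // so listeners configured on successive models can be told apart.
        int firstModel = tuple.getCounter().getAndIncrement();
        int secondModel = tuple.getCounter().getAndIncrement();
        System.out.println(firstModel + "," + secondModel); // prints 0,1
    }
}
```

AtomicInteger.getAndIncrement() returns the current value and then increments it atomically, so concurrent callers can never receive the same id.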

Aggregations

NetBroadcastTuple (org.deeplearning4j.spark.api.worker.NetBroadcastTuple): 2 usages
GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner): 2 usages
ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph): 1 usage
MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork): 1 usage
MultiLayerUpdater (org.deeplearning4j.nn.updater.MultiLayerUpdater): 1 usage
ComputationGraphUpdater (org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater): 1 usage