Search in sources :

Example 21 with IterationListener

use of org.deeplearning4j.optimize.api.IterationListener in project deeplearning4j by deeplearning4j.

The class MultiLayerNetwork, method setListeners.

/**
 * Sets this network's iteration listeners by delegating to {@code setListeners(Collection)}.
 *
 * <p>{@code null} entries, and a {@code null} array, are dropped before delegating. A call such as
 * {@code setListeners((IterationListener) null)} yields a one-element array holding {@code null},
 * which would otherwise result in a NullPointerException later on.
 *
 * @param listeners listeners to set; may be {@code null} or contain {@code null} entries
 */
@Override
public void setListeners(IterationListener... listeners) {
    Collection<IterationListener> nonNull = new ArrayList<>();
    if (listeners == null) {
        setListeners(nonNull);
        return;
    }
    for (IterationListener candidate : listeners) {
        if (candidate == null) {
            continue;
        }
        nonNull.add(candidate);
    }
    setListeners(nonNull);
}
Also used: IterationListener(org.deeplearning4j.optimize.api.IterationListener)

Example 22 with IterationListener

use of org.deeplearning4j.optimize.api.IterationListener in project deeplearning4j by deeplearning4j.

The class StochasticGradientDescent, method optimize.

/**
 * Runs {@code conf.getNumIterations()} gradient-descent iterations on {@code model}.
 *
 * <p>Each iteration, in order: obtains the gradient and score, applies the step function to the
 * model's current parameters, writes the parameters back via {@code setParams}, notifies every
 * iteration listener with the model's current iteration count, checks terminal conditions, and
 * finally increments the model's iteration count by one.
 *
 * @return always {@code true}
 */
@Override
public boolean optimize() {
    for (int i = 0; i < conf.getNumIterations(); i++) {
        Pair<Gradient, Double> pair = gradientAndScore();
        Gradient gradient = pair.getFirst();
        INDArray params = model.params();
        // The step function is given the current params and the gradient array; presumably it
        // updates params in place, since params is written back via setParams below.
        stepFunction.step(params, gradient.gradient());
        //Note: model.params() is always in-place for MultiLayerNetwork and ComputationGraph, hence no setParams is necessary there
        //However: for pretrain layers, params are NOT a view. Thus a setParams call is necessary
        //But setParams should be a no-op for MLN and CG
        model.setParams(params);
        // Listeners observe the iteration count BEFORE it is incremented for this iteration.
        int iterationCount = BaseOptimizer.getIterationCount(model);
        for (IterationListener listener : iterationListeners) listener.iterationDone(model, iterationCount);
        // NOTE(review): any result of checkTerminalConditions is discarded and the loop has no
        // break, so all conf.getNumIterations() iterations always run. Confirm this is intended.
        // oldScore/score are presumably fields refreshed by gradientAndScore(); verify in BaseOptimizer.
        checkTerminalConditions(pair.getFirst().gradient(), oldScore, score, i);
        BaseOptimizer.incrementIterationCount(model, 1);
    }
    return true;
}
Also used: Gradient(org.deeplearning4j.nn.gradient.Gradient) INDArray(org.nd4j.linalg.api.ndarray.INDArray) IterationListener(org.deeplearning4j.optimize.api.IterationListener)

Example 23 with IterationListener

use of org.deeplearning4j.optimize.api.IterationListener in project deeplearning4j by deeplearning4j.

The class SparkListenable, method setListeners.

/**
 * Set the listeners, along with a StatsStorageRouter that the results will be routed to (in the
 * case of any listeners that implement the {@link RoutingIterationListener} interface).
 *
 * <p>The given listeners replace any previously set ones. The resulting listener list is always
 * forwarded to the training master (if one is set). Passing {@code null} therefore clears the
 * listeners both here and on the training master, instead of leaving stale listeners there.
 *
 * @param statsStorage Stats storage router to place the results into; may be {@code null}
 * @param listeners    Listeners to set; may be {@code null}, which clears all listeners
 * @throws IllegalStateException if a RoutingIterationListener carries a storage router that is
 *                               not {@link Serializable} (Spark would fail to serialize it later)
 */
public void setListeners(StatsStorageRouter statsStorage, Collection<? extends IterationListener> listeners) {
    //Check if we have any RoutingIterationListener instances that need a StatsStorage implementation...
    if (listeners != null) {
        for (IterationListener l : listeners) {
            if (!(l instanceof RoutingIterationListener)) {
                continue;
            }
            RoutingIterationListener rl = (RoutingIterationListener) l;
            StatsStorageRouter router = rl.getStorageRouter();
            if (statsStorage == null && router == null) {
                log.warn("RoutingIterationListener provided without providing any StatsStorage instance. Iterator may not function without one. Listener: {}", l);
            } else if (router != null && !(router instanceof Serializable)) {
                //Spark would throw a (probably cryptic) serialization exception later anyway...
                throw new IllegalStateException("RoutingIterationListener provided with non-serializable storage router " + "\nRoutingIterationListener class: " + rl.getClass().getName() + "\nStatsStorageRouter class: " + router.getClass().getName());
            }
        }
    }
    this.listeners.clear();
    if (listeners != null) {
        this.listeners.addAll(listeners);
    }
    //Propagate unconditionally so that clearing listeners (null argument) also clears them on the
    //training master, keeping both in sync
    if (trainingMaster != null) {
        trainingMaster.setListeners(statsStorage, this.listeners);
    }
}
Also used: VanillaStatsStorageRouterProvider(org.deeplearning4j.spark.impl.listeners.VanillaStatsStorageRouterProvider) Serializable(java.io.Serializable) RoutingIterationListener(org.deeplearning4j.api.storage.listener.RoutingIterationListener) IterationListener(org.deeplearning4j.optimize.api.IterationListener) StatsStorageRouterProvider(org.deeplearning4j.api.storage.StatsStorageRouterProvider)

Example 24 with IterationListener

use of org.deeplearning4j.optimize.api.IterationListener in project deeplearning4j by deeplearning4j.

The class ParameterAveragingTrainingWorker, method configureListeners.

/**
 * Installs this worker's iteration listeners on the given model.
 *
 * <p>If a {@code listenerRouterProvider} is configured, each {@link RoutingIterationListener} is
 * given a storage router obtained from it and a worker ID of the form {@code <JVM UID>_<counter>}.
 * Does nothing when {@code iterationListeners} is {@code null}.
 *
 * @param m       model to configure; must be a {@link MultiLayerNetwork} or a ComputationGraph
 * @param counter value appended to the JVM UID to form the worker ID
 * @throws IllegalArgumentException if {@code m} is {@code null} or neither a MultiLayerNetwork
 *                                  nor a ComputationGraph (previously a ClassCastException/NPE)
 */
private void configureListeners(Model m, int counter) {
    if (iterationListeners == null) {
        return;
    }
    List<IterationListener> list = new ArrayList<>(iterationListeners.size());
    //Worker ID is the same for every listener of this call: build it once, and only if needed
    String workerID = null;
    for (IterationListener l : iterationListeners) {
        if (listenerRouterProvider != null && l instanceof RoutingIterationListener) {
            RoutingIterationListener rl = (RoutingIterationListener) l;
            rl.setStorageRouter(listenerRouterProvider.getRouter());
            if (workerID == null) {
                workerID = UIDProvider.getJVMUID() + "_" + counter;
            }
            rl.setWorkerID(workerID);
        }
        //Don't need to clone listeners: not from broadcast, so deserialization handles
        list.add(l);
    }
    if (m instanceof MultiLayerNetwork) {
        ((MultiLayerNetwork) m).setListeners(list);
    } else if (m instanceof ComputationGraph) {
        ((ComputationGraph) m).setListeners(list);
    } else {
        throw new IllegalArgumentException("Unsupported model type: expected MultiLayerNetwork or ComputationGraph, got "
                        + (m == null ? "null" : m.getClass().getName()));
    }
}
Also used: RoutingIterationListener(org.deeplearning4j.api.storage.listener.RoutingIterationListener) IterationListener(org.deeplearning4j.optimize.api.IterationListener) ArrayList(java.util.ArrayList) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork)

Aggregations

IterationListener (org.deeplearning4j.optimize.api.IterationListener): 24
ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener): 15
Test (org.junit.Test): 15
DataSet (org.nd4j.linalg.dataset.DataSet): 12
MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork): 11
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 9
NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration): 7
IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator): 6
MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration): 5
DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator): 5
RoutingIterationListener (org.deeplearning4j.api.storage.listener.RoutingIterationListener): 4
Evaluation (org.deeplearning4j.eval.Evaluation): 4
OptimizationAlgorithm (org.deeplearning4j.nn.api.OptimizationAlgorithm): 4
Layer (org.deeplearning4j.nn.api.Layer): 3
RnnOutputLayer (org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer): 3
SplitTestAndTrain (org.nd4j.linalg.dataset.SplitTestAndTrain): 3
Serializable (java.io.Serializable): 2
StatsStorageRouterProvider (org.deeplearning4j.api.storage.StatsStorageRouterProvider): 2
IOutputLayer (org.deeplearning4j.nn.api.layers.IOutputLayer): 2
RecurrentLayer (org.deeplearning4j.nn.api.layers.RecurrentLayer): 2