Search in sources:

Example 1 with RoutingIterationListener

Use of org.deeplearning4j.api.storage.listener.RoutingIterationListener in the project deeplearning4j (by deeplearning4j).

Source: class ParallelWrapper, method setListeners.

/**
 * Set the listeners, along with a StatsStorageRouter that the results will be routed to (in the case of any
 * listeners that implement the {@link RoutingIterationListener} interface).
 * <p>
 * Non-null {@code listeners} are appended to the listeners already registered on this wrapper. Passing
 * {@code null} clears the registered listeners (previously this threw a {@link NullPointerException}).
 *
 * @param statsStorage Stats storage router to place the results into; may be {@code null} if every
 *                     {@link RoutingIterationListener} already carries its own router
 * @param listeners    Listeners to set; may be {@code null}
 * @throws IllegalStateException if a {@link RoutingIterationListener} carries a storage router that is not
 *                               {@link Serializable}
 */
public void setListeners(StatsStorageRouter statsStorage, Collection<? extends IterationListener> listeners) {
    //Check if we have any RoutingIterationListener instances that need a StatsStorage implementation...
    if (listeners != null) {
        for (IterationListener l : listeners) {
            if (!(l instanceof RoutingIterationListener)) {
                continue;
            }
            StatsStorageRouter listenerRouter = ((RoutingIterationListener) l).getStorageRouter();
            if (listenerRouter == null) {
                // Only a problem when no wrapper-level router was supplied either
                if (statsStorage == null) {
                    log.warn("RoutingIterationListener provided without providing any StatsStorage instance. Iterator may not function without one. Listener: {}", l);
                }
            } else if (!(listenerRouter instanceof Serializable)) {
                throw new IllegalStateException("RoutingIterationListener provided with non-serializable storage router");
            }
        }
    }
    // Assigned only after validation, so a rejected call leaves the previous router in place
    this.storageRouter = statsStorage;
    if (listeners == null) {
        this.listeners.clear();
    } else {
        this.listeners.addAll(listeners);
    }
}
Also used: ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException), Serializable (java.io.Serializable), RoutingIterationListener (org.deeplearning4j.api.storage.listener.RoutingIterationListener), IterationListener (org.deeplearning4j.optimize.api.IterationListener), StatsStorageRouterProvider (org.deeplearning4j.api.storage.StatsStorageRouterProvider)

Example 2 with RoutingIterationListener

Use of org.deeplearning4j.api.storage.listener.RoutingIterationListener in the project deeplearning4j (by deeplearning4j).

Source: class SparkListenable, method setListeners.

/**
 * Set the listeners, along with a StatsStorageRouter that the results will be routed to (in the
 * case of any listeners that implement the {@link RoutingIterationListener} interface).
 * <p>
 * Any previously registered listeners are replaced. When a training master is configured, the new
 * listener list and router are forwarded to it as well, including when {@code listeners} is
 * {@code null} (which clears the listeners in both places).
 *
 * @param statsStorage Stats storage router to place the results into; may be {@code null} if every
 *                     {@link RoutingIterationListener} already carries its own router
 * @param listeners Listeners to set; may be {@code null} to remove all listeners
 * @throws IllegalStateException if a {@link RoutingIterationListener} carries a storage router that
 *                               is not {@link Serializable}
 */
public void setListeners(StatsStorageRouter statsStorage, Collection<? extends IterationListener> listeners) {
    //Check that every RoutingIterationListener can actually report somewhere, and can be serialized
    if (listeners != null) {
        for (IterationListener l : listeners) {
            if (!(l instanceof RoutingIterationListener)) {
                continue;
            }
            RoutingIterationListener rl = (RoutingIterationListener) l;
            StatsStorageRouter listenerRouter = rl.getStorageRouter();
            if (statsStorage == null && listenerRouter == null) {
                log.warn("RoutingIterationListener provided without providing any StatsStorage instance. Iterator may not function without one. Listener: {}", l);
            } else if (listenerRouter != null && !(listenerRouter instanceof Serializable)) {
                //Spark would throw a (probably cryptic) serialization exception later anyway...
                throw new IllegalStateException("RoutingIterationListener provided with non-serializable storage router " + "\nRoutingIterationListener class: " + rl.getClass().getName() + "\nStatsStorageRouter class: " + listenerRouter.getClass().getName());
            }
        }
    }
    this.listeners.clear();
    if (listeners != null) {
        this.listeners.addAll(listeners);
    }
    // Forward unconditionally so that clearing the listeners here also clears them on the training
    // master; previously a null argument left the master with its stale listener list.
    if (trainingMaster != null) {
        trainingMaster.setListeners(statsStorage, this.listeners);
    }
}
Also used: VanillaStatsStorageRouterProvider (org.deeplearning4j.spark.impl.listeners.VanillaStatsStorageRouterProvider), Serializable (java.io.Serializable), RoutingIterationListener (org.deeplearning4j.api.storage.listener.RoutingIterationListener), IterationListener (org.deeplearning4j.optimize.api.IterationListener), StatsStorageRouterProvider (org.deeplearning4j.api.storage.StatsStorageRouterProvider)

Example 3 with RoutingIterationListener

Use of org.deeplearning4j.api.storage.listener.RoutingIterationListener in the project deeplearning4j (by deeplearning4j).

Source: class ParameterAveragingTrainingWorker, method configureListeners.

/**
 * Attaches this worker's iteration listeners to the given model.
 * <p>
 * When a listener router provider is available, every {@link RoutingIterationListener} is given
 * the router returned by {@code listenerRouterProvider.getRouter()} and a worker ID of the form
 * {@code <JVM UID>_<counter>}. Does nothing if no iteration listeners are configured.
 *
 * @param m       model to attach the listeners to; expected to be a {@link MultiLayerNetwork} or
 *                a {@code ComputationGraph}
 * @param counter worker counter, appended to the JVM UID to build the worker ID
 * @throws IllegalArgumentException if listeners are configured and {@code m} is neither a
 *                                  MultiLayerNetwork nor a ComputationGraph
 */
private void configureListeners(Model m, int counter) {
    if (iterationListeners == null) {
        return;
    }
    List<IterationListener> list = new ArrayList<>(iterationListeners.size());
    // Identical for every listener on this worker: build it once, and only if it is actually needed
    String workerID = null;
    for (IterationListener l : iterationListeners) {
        if (listenerRouterProvider != null && l instanceof RoutingIterationListener) {
            RoutingIterationListener rl = (RoutingIterationListener) l;
            rl.setStorageRouter(listenerRouterProvider.getRouter());
            if (workerID == null) {
                workerID = UIDProvider.getJVMUID() + "_" + counter;
            }
            rl.setWorkerID(workerID);
        }
        //Don't need to clone listeners: not from broadcast, so deserialization handles
        list.add(l);
    }
    if (m instanceof MultiLayerNetwork) {
        ((MultiLayerNetwork) m).setListeners(list);
    } else if (m instanceof ComputationGraph) {
        ((ComputationGraph) m).setListeners(list);
    } else {
        // Previously a blind cast produced an opaque ClassCastException / NullPointerException here
        throw new IllegalArgumentException("Cannot configure listeners: unsupported model type "
                + (m == null ? "null" : m.getClass().getName()));
    }
}
Also used: RoutingIterationListener (org.deeplearning4j.api.storage.listener.RoutingIterationListener), IterationListener (org.deeplearning4j.optimize.api.IterationListener), ArrayList (java.util.ArrayList), MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)

Aggregations

RoutingIterationListener (org.deeplearning4j.api.storage.listener.RoutingIterationListener): 3; IterationListener (org.deeplearning4j.optimize.api.IterationListener): 3; Serializable (java.io.Serializable): 2; StatsStorageRouterProvider (org.deeplearning4j.api.storage.StatsStorageRouterProvider): 2; ArrayList (java.util.ArrayList): 1; MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork): 1; VanillaStatsStorageRouterProvider (org.deeplearning4j.spark.impl.listeners.VanillaStatsStorageRouterProvider): 1; ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 1