Search in sources :

Example 1 with VanillaStatsStorageRouterProvider

Use of org.deeplearning4j.spark.impl.listeners.VanillaStatsStorageRouterProvider in the deeplearning4j project (by deeplearning4j).

In the class SparkListenable, method setListeners.

/**
 * Set the listeners, along with a StatsStorageRouter that the results will be shuffled to (in the
 * case of any listeners that implement the {@link RoutingIterationListener} interface).
 * <p>
 * Every {@link RoutingIterationListener} in {@code listeners} is validated first:
 * <ul>
 *   <li>if neither {@code statsStorage} nor the listener's own storage router is set, a warning is logged;</li>
 *   <li>if the listener's own storage router is not {@link Serializable}, an exception is thrown.</li>
 * </ul>
 * The previous listeners are then replaced by {@code listeners}. If a training master is set, it receives
 * the same stats storage and listener collection. Passing {@code null} for {@code listeners} is handled
 * exactly like passing an empty collection: all listeners are removed, including on the training master.
 *
 * @param statsStorage Stats storage router to place the results into (may be null)
 * @param listeners    Listeners to set (may be null to remove all listeners)
 * @throws IllegalStateException if a RoutingIterationListener carries a non-serializable storage router
 */
public void setListeners(StatsStorageRouter statsStorage, Collection<? extends IterationListener> listeners) {
    //Check if we have any RoutingIterationListener instances that need a StatsStorage implementation...
    StatsStorageRouterProvider routerProvider = null;
    if (listeners != null) {
        for (IterationListener l : listeners) {
            if (l instanceof RoutingIterationListener) {
                RoutingIterationListener rl = (RoutingIterationListener) l;
                StatsStorageRouter listenerRouter = rl.getStorageRouter();
                if (statsStorage == null && listenerRouter == null) {
                    log.warn("RoutingIterationListener provided without providing any StatsStorage instance. Iterator may not function without one. Listener: {}", l);
                } else if (listenerRouter != null && !(listenerRouter instanceof Serializable)) {
                    //Spark would throw a (probably cryptic) serialization exception later anyway - fail fast with a clear message
                    throw new IllegalStateException("RoutingIterationListener provided with non-serializable storage router " + "\nRoutingIterationListener class: " + rl.getClass().getName() + "\nStatsStorageRouter class: " + listenerRouter.getClass().getName());
                }
                //Need to give workers a router provider...
                //NOTE(review): routerProvider is never read after this assignment within this method;
                //confirm whether it was meant to be handed to the training master / workers.
                if (routerProvider == null) {
                    routerProvider = new VanillaStatsStorageRouterProvider();
                }
            }
        }
    }
    this.listeners.clear();
    if (listeners != null) {
        this.listeners.addAll(listeners);
    }
    //Always forward to the training master: a null argument must clear its listeners too, otherwise
    //training would keep using the old listeners. null is handled exactly like an empty collection here.
    if (trainingMaster != null) {
        trainingMaster.setListeners(statsStorage, this.listeners);
    }
}
Also used: VanillaStatsStorageRouterProvider (org.deeplearning4j.spark.impl.listeners.VanillaStatsStorageRouterProvider), Serializable (java.io.Serializable), RoutingIterationListener (org.deeplearning4j.api.storage.listener.RoutingIterationListener), IterationListener (org.deeplearning4j.optimize.api.IterationListener), StatsStorageRouterProvider (org.deeplearning4j.api.storage.StatsStorageRouterProvider)

Aggregations

Serializable (java.io.Serializable): 1; StatsStorageRouterProvider (org.deeplearning4j.api.storage.StatsStorageRouterProvider): 1; RoutingIterationListener (org.deeplearning4j.api.storage.listener.RoutingIterationListener): 1; IterationListener (org.deeplearning4j.optimize.api.IterationListener): 1; VanillaStatsStorageRouterProvider (org.deeplearning4j.spark.impl.listeners.VanillaStatsStorageRouterProvider): 1