use of org.deeplearning4j.api.storage.listener.RoutingIterationListener in project deeplearning4j by deeplearning4j.
the class ParallelWrapper method setListeners.
/**
 * Set the listeners, along with a StatsStorageRouter that the results will be routed to (in the case of any listeners
 * that implement the {@link RoutingIterationListener} interface).
 * <p>
 * Non-null {@code listeners} are appended to the listeners already registered on this instance. Passing {@code null}
 * clears the currently registered listeners. The storage router is only stored after all listeners have been
 * validated, so a rejected call leaves the previous router in place.
 *
 * @param statsStorage Stats storage router to place the results into; may be {@code null}
 * @param listeners    Listeners to set; may be {@code null}
 * @throws IllegalStateException if a {@link RoutingIterationListener} carries a storage router that does not
 *                               implement {@link Serializable}
 */
public void setListeners(StatsStorageRouter statsStorage, Collection<? extends IterationListener> listeners) {
    // Check if we have any RoutingIterationListener instances that need a StatsStorage implementation...
    if (listeners != null) {
        for (IterationListener l : listeners) {
            if (!(l instanceof RoutingIterationListener)) {
                continue;
            }
            RoutingIterationListener rl = (RoutingIterationListener) l;
            StatsStorageRouter listenerRouter = rl.getStorageRouter();
            if (listenerRouter == null) {
                // Only worth warning about when no router was supplied to this call either
                if (statsStorage == null) {
                    log.warn("RoutingIterationListener provided without providing any StatsStorage instance. Iterator may not function without one. Listener: {}", l);
                }
            } else if (!(listenerRouter instanceof Serializable)) {
                // Fail fast instead of with a later, harder-to-diagnose serialization error.
                // NOTE(review): check retained from the original; confirm whether routers are ever serialized here.
                throw new IllegalStateException("RoutingIterationListener provided with non-serializable storage router");
            }
        }
    }
    this.storageRouter = statsStorage;
    if (listeners == null) {
        // Collection.addAll(null) would throw an NPE; treat null as "no listeners"
        this.listeners.clear();
    } else {
        this.listeners.addAll(listeners);
    }
}
use of org.deeplearning4j.api.storage.listener.RoutingIterationListener in project deeplearning4j by deeplearning4j.
the class SparkListenable method setListeners.
/**
 * Set the listeners, along with a StatsStorageRouter that the results will be routed to (in the
 * case of any listeners that implement the {@link RoutingIterationListener} interface).
 * <p>
 * Any previously registered listeners are replaced. If a training master is configured, it is passed the same
 * storage router and listener collection, so the two always agree. Passing {@code null} clears the listeners,
 * including on the training master.
 *
 * @param statsStorage Stats storage router to place the results into; may be {@code null}
 * @param listeners    Listeners to set; may be {@code null}
 * @throws IllegalStateException if a {@link RoutingIterationListener} carries a storage router that does not
 *                               implement {@link Serializable}
 */
public void setListeners(StatsStorageRouter statsStorage, Collection<? extends IterationListener> listeners) {
    // Check if we have any RoutingIterationListener instances that need a StatsStorage implementation...
    if (listeners != null) {
        for (IterationListener l : listeners) {
            if (l instanceof RoutingIterationListener) {
                RoutingIterationListener rl = (RoutingIterationListener) l;
                if (statsStorage == null && rl.getStorageRouter() == null) {
                    log.warn("RoutingIterationListener provided without providing any StatsStorage instance. Iterator may not function without one. Listener: {}", l);
                } else if (rl.getStorageRouter() != null && !(rl.getStorageRouter() instanceof Serializable)) {
                    // Spark would throw a (probably cryptic) serialization exception later anyway...
                    throw new IllegalStateException("RoutingIterationListener provided with non-serializable storage router " + "\nRoutingIterationListener class: " + rl.getClass().getName() + "\nStatsStorageRouter class: " + rl.getStorageRouter().getClass().getName());
                }
            }
        }
    }
    this.listeners.clear();
    if (listeners != null) {
        this.listeners.addAll(listeners);
    }
    // Propagate unconditionally so a null/cleared listener set also reaches the training master
    if (trainingMaster != null) {
        trainingMaster.setListeners(statsStorage, this.listeners);
    }
}
use of org.deeplearning4j.api.storage.listener.RoutingIterationListener in project deeplearning4j by deeplearning4j.
the class ParameterAveragingTrainingWorker method configureListeners.
/**
 * Attaches this worker's iteration listeners to the given model.
 * <p>
 * When {@code listenerRouterProvider} is set, each {@link RoutingIterationListener} is given a router obtained from
 * that provider and a worker ID of the form {@code <JVM UID>_<counter>}. Does nothing if there are no iteration
 * listeners.
 *
 * @param m       model to attach the listeners to; must be a MultiLayerNetwork or a ComputationGraph
 * @param counter value appended to the JVM UID to form the worker ID
 * @throws IllegalArgumentException if {@code m} is neither a MultiLayerNetwork nor a ComputationGraph
 */
private void configureListeners(Model m, int counter) {
    if (iterationListeners == null) {
        return;
    }
    List<IterationListener> list = new ArrayList<>(iterationListeners.size());
    // Same for every listener; built lazily on first use
    String workerID = null;
    for (IterationListener l : iterationListeners) {
        if (listenerRouterProvider != null && l instanceof RoutingIterationListener) {
            RoutingIterationListener rl = (RoutingIterationListener) l;
            rl.setStorageRouter(listenerRouterProvider.getRouter());
            if (workerID == null) {
                workerID = UIDProvider.getJVMUID() + "_" + counter;
            }
            rl.setWorkerID(workerID);
        }
        // Don't need to clone listeners: not from broadcast, so deserialization handles
        list.add(l);
    }
    if (m instanceof MultiLayerNetwork) {
        ((MultiLayerNetwork) m).setListeners(list);
    } else if (m instanceof ComputationGraph) {
        ((ComputationGraph) m).setListeners(list);
    } else {
        throw new IllegalArgumentException("Cannot configure listeners: unsupported model type "
                        + (m == null ? "null" : m.getClass().getName()));
    }
}
Aggregations