Usage of `org.deeplearning4j.optimize.api.IterationListener` in the deeplearning4j project (by deeplearning4j).
Class `MultiLayerNetwork`, method `setListeners`:
/**
 * Sets this network's iteration listeners by delegating to the {@code Collection} overload of
 * {@code setListeners}, after discarding any {@code null} entries.
 * <p>
 * Filtering matters for varargs calls such as {@code setListeners((IterationListener) null)}: they
 * produce a one-element array holding {@code null}, which would otherwise cause a
 * NullPointerException later. A {@code null} or empty array results in an empty collection being
 * passed on.
 *
 * @param listeners listeners to set; the array may be {@code null} and may contain {@code null} elements
 */
@Override
public void setListeners(IterationListener... listeners) {
    Collection<IterationListener> nonNull = new ArrayList<>();
    int count = (listeners == null) ? 0 : listeners.length;
    for (int idx = 0; idx < count; idx++) {
        IterationListener candidate = listeners[idx];
        if (candidate != null) {
            nonNull.add(candidate);
        }
    }
    setListeners(nonNull);
}
Usage of `org.deeplearning4j.optimize.api.IterationListener` in the deeplearning4j project (by deeplearning4j).
Class `StochasticGradientDescent`, method `optimize`:
/**
 * Performs {@code conf.getNumIterations()} stochastic gradient descent iterations on the model.
 * <p>
 * Each iteration does the following, in order:
 * <ol>
 *   <li>computes the gradient and score;</li>
 *   <li>applies the step function to the model's current parameters;</li>
 *   <li>writes the parameters back to the model;</li>
 *   <li>notifies every iteration listener with the model's iteration count;</li>
 *   <li>evaluates the terminal conditions;</li>
 *   <li>increments the model's iteration count by one.</li>
 * </ol>
 *
 * @return always {@code true}
 */
@Override
public boolean optimize() {
    int iter = 0;
    while (iter < conf.getNumIterations()) {
        Pair<Gradient, Double> gradAndScore = gradientAndScore();
        Gradient grad = gradAndScore.getFirst();

        INDArray current = model.params();
        stepFunction.step(current, grad.gradient());
        // For MultiLayerNetwork and ComputationGraph, model.params() is updated in place, so
        // setParams is effectively a no-op there. Pretrain layers do NOT return a view, so the
        // updated parameters must be written back explicitly.
        model.setParams(current);

        int iterCount = BaseOptimizer.getIterationCount(model);
        for (IterationListener listener : iterationListeners) {
            listener.iterationDone(model, iterCount);
        }

        // NOTE(review): any return value of checkTerminalConditions is discarded, so it cannot end
        // this loop early - confirm that is intended.
        checkTerminalConditions(gradAndScore.getFirst().gradient(), oldScore, score, iter);

        BaseOptimizer.incrementIterationCount(model, 1);
        iter++;
    }
    return true;
}
Usage of `org.deeplearning4j.optimize.api.IterationListener` in the deeplearning4j project (by deeplearning4j).
Class `SparkListenable`, method `setListeners`:
/**
 * Set the listeners, along with a StatsStorageRouter that the results will be shuffled to (in the
 * case of any listeners that implement the {@link RoutingIterationListener} interface).
 * <p>
 * Every {@link RoutingIterationListener} is validated first:
 * <ul>
 *   <li>a warning is logged if neither the listener nor {@code statsStorage} provides a router;</li>
 *   <li>an {@link IllegalStateException} is thrown if the listener's own router is not
 *       {@link Serializable}. Spark would otherwise fail later with a less helpful serialization
 *       error.</li>
 * </ul>
 * The local listener collection is then replaced. If a training master is set, the updated
 * collection is always forwarded to it. Passing {@code null} clears the listeners there as well,
 * instead of leaving the training master with stale ones.
 *
 * @param statsStorage Stats storage router to place the results into; may be {@code null}
 * @param listeners    Listeners to set; {@code null} clears all listeners
 * @throws IllegalStateException if a RoutingIterationListener carries a non-serializable storage router
 */
public void setListeners(StatsStorageRouter statsStorage, Collection<? extends IterationListener> listeners) {
    //Check if we have any RoutingIterationListener instances that need a StatsStorage implementation...
    if (listeners != null) {
        for (IterationListener l : listeners) {
            if (!(l instanceof RoutingIterationListener)) {
                continue;
            }
            RoutingIterationListener rl = (RoutingIterationListener) l;
            StatsStorageRouter listenerRouter = rl.getStorageRouter();
            if (statsStorage == null && listenerRouter == null) {
                log.warn("RoutingIterationListener provided without providing any StatsStorage instance. Iterator may not function without one. Listener: {}", l);
            } else if (listenerRouter != null && !(listenerRouter instanceof Serializable)) {
                //Spark would throw a (probably cryptic) serialization exception later anyway...
                throw new IllegalStateException("RoutingIterationListener provided with non-serializable storage router " + "\nRoutingIterationListener class: " + rl.getClass().getName() + "\nStatsStorageRouter class: " + listenerRouter.getClass().getName());
            }
        }
    }

    this.listeners.clear();
    if (listeners != null) {
        this.listeners.addAll(listeners);
    }
    // Propagate unconditionally. Previously a null argument cleared the local listeners but left
    // the training master holding the old ones.
    if (trainingMaster != null) {
        trainingMaster.setListeners(statsStorage, this.listeners);
    }
}
Usage of `org.deeplearning4j.optimize.api.IterationListener` in the deeplearning4j project (by deeplearning4j).
Class `ParameterAveragingTrainingWorker`, method `configureListeners`:
/**
 * Attaches this worker's iteration listeners to the given model.
 * <p>
 * If {@code listenerRouterProvider} is set, every {@link RoutingIterationListener} receives:
 * <ul>
 *   <li>a storage router obtained from {@code listenerRouterProvider.getRouter()}, requested once
 *       per listener;</li>
 *   <li>a worker ID of the form {@code <JVM UID>_<counter>}.</li>
 * </ul>
 * Does nothing when no iteration listeners are configured.
 *
 * @param m       model to attach the listeners to; must be a MultiLayerNetwork or a ComputationGraph
 * @param counter counter appended to the JVM UID to form the worker ID
 * @throws IllegalArgumentException if listeners are configured and {@code m} is neither a
 *                                  MultiLayerNetwork nor a ComputationGraph (including {@code null})
 */
private void configureListeners(Model m, int counter) {
    if (iterationListeners == null) {
        return;
    }

    // Loop-invariant: built once instead of per listener. Assumes UIDProvider.getJVMUID() returns
    // the same value for the JVM's lifetime - TODO confirm.
    String workerID = (listenerRouterProvider != null) ? UIDProvider.getJVMUID() + "_" + counter : null;

    List<IterationListener> list = new ArrayList<>(iterationListeners.size());
    for (IterationListener l : iterationListeners) {
        if (listenerRouterProvider != null && l instanceof RoutingIterationListener) {
            RoutingIterationListener rl = (RoutingIterationListener) l;
            rl.setStorageRouter(listenerRouterProvider.getRouter());
            rl.setWorkerID(workerID);
        }
        //Don't need to clone listeners: not from broadcast, so deserialization handles
        list.add(l);
    }

    if (m instanceof MultiLayerNetwork) {
        ((MultiLayerNetwork) m).setListeners(list);
    } else if (m instanceof ComputationGraph) {
        ((ComputationGraph) m).setListeners(list);
    } else {
        // Previously any other type fell through to an unchecked ComputationGraph cast.
        throw new IllegalArgumentException("Cannot configure listeners for unsupported model type: "
                + (m == null ? "null" : m.getClass().getName()));
    }
}
Aggregations