Use of org.deeplearning4j.exception.InvalidStepException in project deeplearning4j by deeplearning4j.
The class BaseOptimizer, method optimize.
/**
 * Optimize call. This runs the optimizer.
 *
 * <p>Flow of one call:
 * <ol>
 *   <li>Obtain the current gradient/score pair and store the gradient in the search state
 *       (on the first call, when the search state is empty, {@code setupSearchState} is also run).</li>
 *   <li>Evaluate every termination condition against the initial gradient; return early if one fires.</li>
 *   <li>Run {@code preProcessLine()} and then {@code conf.getNumIterations()} iterations of:
 *       line search, parameter update (when the step is non-zero), gradient/score recomputation,
 *       {@code postStep}, listener notification, terminal-condition check, iteration-count increment.</li>
 * </ol>
 *
 * <p>An {@link InvalidStepException} raised by the line search is logged and treated as a
 * step size of {@code 0.0}, so that iteration leaves the parameters untouched.
 *
 * @return {@code true} on every path visible in this method: both the early exit on a
 *         termination condition and the normal completion of all iterations return {@code true}
 */
// TODO add flag to allow retaining state between mini batches and when to apply updates
@Override
public boolean optimize() {
    Pair<Gradient, Double> pair = gradientAndScore();

    // Emptiness has to be sampled before the put below, which makes the map non-empty.
    boolean firstCall = searchState.isEmpty();
    searchState.put(GRADIENT_KEY, pair.getFirst().gradient());
    if (firstCall) {
        // Only do this once
        setupSearchState(pair);
    }

    // Early exit: each termination condition is evaluated against the initial gradient,
    // with 0.0 passed for both score arguments.
    // NOTE(review): an earlier comment here described this check as disabled, but the code is live.
    for (TerminationCondition condition : terminationConditions) {
        // gradient() is re-invoked per condition, exactly as before; whether it returns a
        // fresh array on each call is not visible from this method, so it is not cached.
        if (condition.terminate(0.0, 0.0, new Object[] { pair.getFirst().gradient() })) {
            log.info("Hit termination condition {}", condition.getClass().getName());
            return true;
        }
    }

    // calculate initial search direction
    preProcessLine();

    for (int i = 0; i < conf.getNumIterations(); i++) {
        // Re-read the state every iteration: postStep (below) is expected to update it.
        INDArray gradient = (INDArray) searchState.get(GRADIENT_KEY);
        INDArray searchDirection = (INDArray) searchState.get(SEARCH_DIR);
        INDArray parameters = (INDArray) searchState.get(PARAMS_KEY);

        // perform one line search optimization
        try {
            step = lineMaximizer.optimize(parameters, gradient, searchDirection);
        } catch (InvalidStepException e) {
            log.warn("Invalid step...continuing another iteration: {}", e.getMessage());
            // A zero step makes the update below a no-op for this iteration.
            step = 0.0;
        }

        // Update parameters based on final/best step size returned by line search:
        if (step != 0.0) {
            // Calculate params. given step size
            stepFunction.step(parameters, searchDirection, step);
            model.setParams(parameters);
        } else {
            log.debug("Step size returned by line search is 0.0.");
        }

        pair = gradientAndScore();

        // updates searchDirection
        postStep(pair.getFirst().gradient());

        // invoke listeners (they see the iteration count before it is incremented below)
        int iterationCount = BaseOptimizer.getIterationCount(model);
        for (IterationListener listener : iterationListeners) {
            listener.iterationDone(model, iterationCount);
        }

        // check for termination conditions based on absolute change in score
        // NOTE(review): the call is a bare statement, so whatever it reports does not end the
        // loop here; all conf.getNumIterations() iterations always run. Confirm this is intended.
        checkTerminalConditions(pair.getFirst().gradient(), oldScore, score, i);
        incrementIterationCount(model, 1);
    }
    return true;
}
Aggregations