Search in sources :

Example 1 with InvalidStepException

Use of org.deeplearning4j.exception.InvalidStepException in the deeplearning4j project (by deeplearning4j).

From the class BaseOptimizer, method optimize.

/**
 * Runs the optimizer.
 *
 * <p>Computes the gradient and score, stores the gradient in the search state (initialising the
 * rest of the search state on first use), checks every termination condition once against that
 * gradient, and then runs {@code conf.getNumIterations()} line-search iterations. Each iteration
 * applies the step returned by the line search (unless it is 0.0), recomputes gradient and
 * score, updates the search direction, notifies the iteration listeners and increments the
 * model's iteration count.
 *
 * @return {@code true} on every path visible in this method: either early, when a termination
 *         condition fires on the initial gradient, or after the iteration loop completes
 */
// TODO add flag to allow retaining state between mini batches and when to apply updates
@Override
public boolean optimize() {
    Pair<Gradient, Double> gradientScore = gradientAndScore();

    // The gradient is always recorded; the rest of the search state is set up only
    // the first time through (i.e. when the state map was still empty).
    final boolean firstRun = searchState.isEmpty();
    searchState.put(GRADIENT_KEY, gradientScore.getFirst().gradient());
    if (firstRun) {
        setupSearchState(gradientScore);
    }

    // Up-front termination check on the current gradient. This check is active (an older
    // note here described it as commented out). Both score arguments are passed as 0.0.
    for (TerminationCondition terminationCondition : terminationConditions) {
        Object[] extraArgs = new Object[] { gradientScore.getFirst().gradient() };
        if (!terminationCondition.terminate(0.0, 0.0, extraArgs)) {
            continue;
        }
        log.info("Hit termination condition " + terminationCondition.getClass().getName());
        return true;
    }

    // Calculate the initial search direction.
    preProcessLine();

    for (int iteration = 0; iteration < conf.getNumIterations(); iteration++) {
        final INDArray currentGradient = (INDArray) searchState.get(GRADIENT_KEY);
        final INDArray direction = (INDArray) searchState.get(SEARCH_DIR);
        final INDArray params = (INDArray) searchState.get(PARAMS_KEY);

        // One line search; an invalid step is logged and treated as a zero step so that
        // the loop carries on with the next iteration.
        try {
            step = lineMaximizer.optimize(params, currentGradient, direction);
        } catch (InvalidStepException e) {
            log.warn("Invalid step...continuing another iteration: {}", e.getMessage());
            step = 0.0;
        }

        // Apply the step size returned by the line search, unless it is zero.
        if (step == 0.0) {
            log.debug("Step size returned by line search is 0.0.");
        } else {
            stepFunction.step(params, direction, step);
            model.setParams(params);
        }

        gradientScore = gradientAndScore();

        // Updates the search direction from the fresh gradient.
        postStep(gradientScore.getFirst().gradient());

        // Notify listeners with the iteration count read before it is incremented below.
        int currentIteration = BaseOptimizer.getIterationCount(model);
        for (IterationListener iterationListener : iterationListeners) {
            iterationListener.iterationDone(model, currentIteration);
        }

        // Termination check based on gradient and score; its result is not used here.
        checkTerminalConditions(gradientScore.getFirst().gradient(), oldScore, score, iteration);
        incrementIterationCount(model, 1);
    }
    return true;
}
Also used : Gradient(org.deeplearning4j.nn.gradient.Gradient) INDArray(org.nd4j.linalg.api.ndarray.INDArray) InvalidStepException(org.deeplearning4j.exception.InvalidStepException)

Aggregations

InvalidStepException (org.deeplearning4j.exception.InvalidStepException)1 Gradient (org.deeplearning4j.nn.gradient.Gradient)1 INDArray (org.nd4j.linalg.api.ndarray.INDArray)1