Search in sources :

Example 1 with MathIllegalStateException

use of org.apache.commons.math3.exception.MathIllegalStateException in project GDSC-SMLM by aherbert.

the class BoundedNonLinearConjugateGradientOptimizer method doOptimize.

/**
 * {@inheritDoc}
 *
 * <p>Each iteration evaluates the objective at the current point, tests for convergence against
 * the previous iteration, performs a line search along the current search direction, moves the
 * point, re-applies the bounds and then updates the conjugate search direction using the
 * configured update formula.
 *
 * <p>When the gradient based line search cannot bracket an optimum (a
 * {@link MathIllegalStateException} is raised) the method falls back to the non-gradient
 * {@code LineSearch}.
 *
 * @return the point/value pair accepted by the convergence checker, or the current point when
 *         the preconditioned gradient product is exactly zero
 */
@Override
protected PointValuePair doOptimize() {
    final ConvergenceChecker<PointValuePair> checker = getConvergenceChecker();
    final double[] point = getStartPoint();
    final GoalType goal = getGoalType();
    final int n = point.length;
    sign = (goal == GoalType.MINIMIZE) ? -1 : 1;
    // Keep a copy of the point before clipping so the gradient check can compare
    // the bounded and unbounded positions.
    double[] unbounded = point.clone();
    applyBounds(point);
    double[] r = computeObjectiveGradient(point);
    checkGradients(r, unbounded);
    if (goal == GoalType.MINIMIZE) {
        for (int i = 0; i < n; i++) {
            r[i] = -r[i];
        }
    }
    // Initial search direction.
    double[] steepestDescent = preconditioner.precondition(point, r);
    double[] searchDirection = steepestDescent.clone();
    double delta = 0;
    for (int i = 0; i < n; ++i) {
        delta += r[i] * searchDirection[i];
    }
    // Used for non-gradient based line search.
    // Default tolerances apply unless the checker exposes its own thresholds.
    double rel = 1e-6;
    double abs = 1e-10;
    if (checker instanceof SimpleValueChecker) {
        final SimpleValueChecker valueChecker = (SimpleValueChecker) checker;
        rel = valueChecker.getRelativeThreshold();
        // The absolute tolerance must come from the absolute threshold
        // (previously the relative threshold was read twice).
        abs = valueChecker.getAbsoluteThreshold();
    }
    final LineSearch line = new LineSearch(Math.sqrt(rel), Math.sqrt(abs));
    PointValuePair current = null;
    int maxEval = getMaxEvaluations();
    while (true) {
        incrementIterationCount();
        final double objective = computeObjectiveValue(point);
        final PointValuePair previous = current;
        current = new PointValuePair(point, objective);
        if (previous != null && checker.converged(getIterations(), previous, current)) {
            // We have found an optimum.
            return current;
        }
        double step;
        if (useGradientLineSearch) {
            // Classic code using the gradient function for the line search:
            // Find the optimal step in the search direction.
            final UnivariateFunction lsf = new LineSearchFunction(point, searchDirection);
            final double uB;
            try {
                uB = findUpperBound(lsf, 0, initialStep);
                // Check if the bracket found a minimum. Otherwise just move to the new point.
                if (noBracket) {
                    step = uB;
                } else {
                    // XXX Last parameters is set to a value close to zero in order to
                    // work around the divergence problem in the "testCircleFitting"
                    // unit test (see MATH-439).
                    step = solver.solve(maxEval, lsf, 0, uB, 1e-15);
                    // Subtract used up evaluations.
                    maxEval -= solver.getEvaluations();
                }
            } catch (MathIllegalStateException ignored) {
                // Deliberately not rethrown: the gradient based search failed, so fall back to
                // a line search without gradient (as per Powell optimiser).
                final UnivariatePointValuePair optimum = line.search(point, searchDirection);
                step = optimum.getPoint();
            }
        } else {
            // Line search without gradient (as per Powell optimiser)
            final UnivariatePointValuePair optimum = line.search(point, searchDirection);
            step = optimum.getPoint();
        }
        // Move along the search direction, then clip back inside the bounds.
        for (int i = 0; i < n; ++i) {
            point[i] += step * searchDirection[i];
        }
        unbounded = point.clone();
        applyBounds(point);
        r = computeObjectiveGradient(point);
        checkGradients(r, unbounded);
        if (goal == GoalType.MINIMIZE) {
            for (int i = 0; i < n; ++i) {
                r[i] = -r[i];
            }
        }
        // Compute beta.
        final double deltaOld = delta;
        final double[] newSteepestDescent = preconditioner.precondition(point, r);
        delta = 0;
        for (int i = 0; i < n; ++i) {
            delta += r[i] * newSteepestDescent[i];
        }
        if (delta == 0) {
            // The gradient product has vanished: no further direction can be computed.
            return new PointValuePair(point, computeObjectiveValue(point));
        }
        final double beta;
        switch (updateFormula) {
            case FLETCHER_REEVES:
                beta = delta / deltaOld;
                break;
            case POLAK_RIBIERE:
                double deltaMid = 0;
                for (int i = 0; i < r.length; ++i) {
                    deltaMid += r[i] * steepestDescent[i];
                }
                beta = (delta - deltaMid) / deltaOld;
                break;
            default:
                // Should never happen.
                throw new MathInternalError();
        }
        steepestDescent = newSteepestDescent;
        // Compute conjugate search direction.
        if (getIterations() % n == 0 || beta < 0) {
            // Break conjugation: reset search direction.
            searchDirection = steepestDescent.clone();
        } else {
            // Compute new conjugate search direction.
            for (int i = 0; i < n; ++i) {
                searchDirection[i] = steepestDescent[i] + beta * searchDirection[i];
            }
        }
        // The gradient has already been adjusted for the search direction
        checkGradients(searchDirection, unbounded, -sign);
    }
}
Also used : UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) GoalType(org.apache.commons.math3.optim.nonlinear.scalar.GoalType) UnivariatePointValuePair(org.apache.commons.math3.optim.univariate.UnivariatePointValuePair) SimpleValueChecker(org.apache.commons.math3.optim.SimpleValueChecker) MathIllegalStateException(org.apache.commons.math3.exception.MathIllegalStateException) PointValuePair(org.apache.commons.math3.optim.PointValuePair) UnivariatePointValuePair(org.apache.commons.math3.optim.univariate.UnivariatePointValuePair) MathInternalError(org.apache.commons.math3.exception.MathInternalError)

Example 2 with MathIllegalStateException

use of org.apache.commons.math3.exception.MathIllegalStateException in project GDSC-SMLM by aherbert.

the class BoundedNonLinearConjugateGradientOptimizer method doOptimize.

/**
 * Runs the bounded non-linear conjugate gradient optimisation.
 *
 * <p>Every pass of the loop evaluates the objective, checks convergence against the previous
 * pass, finds a step along the search direction, moves and re-bounds the point, and then
 * derives the next conjugate direction from the configured update formula.
 *
 * <p>If the gradient based line search raises a {@link MathIllegalStateException} the step is
 * obtained from the non-gradient {@code LineSearch} instead.
 *
 * @return the pair accepted by the convergence checker, or the current point when the
 *         preconditioned gradient product is exactly zero
 */
@Override
protected PointValuePair doOptimize() {
    final ConvergenceChecker<PointValuePair> checker = getConvergenceChecker();
    final double[] point = getStartPoint();
    final GoalType goal = getGoalType();
    final int n = point.length;
    sign = (goal == GoalType.MINIMIZE) ? -1 : 1;
    // Snapshot of the point before it is clipped to the bounds.
    double[] unbounded = point.clone();
    applyBounds(point);
    double[] gradient = computeObjectiveGradient(point);
    checkGradients(gradient, unbounded);
    if (goal == GoalType.MINIMIZE) {
        for (int i = 0; i < n; i++) {
            gradient[i] = -gradient[i];
        }
    }
    // Initial search direction.
    double[] steepestDescent = preconditioner.precondition(point, gradient);
    double[] searchDirection = steepestDescent.clone();
    double delta = 0;
    for (int i = 0; i < n; ++i) {
        delta += gradient[i] * searchDirection[i];
    }
    // Used for non-gradient based line search.
    // Defaults are replaced by the checker's own thresholds when it exposes them.
    double rel = 1e-6;
    double abs = 1e-10;
    if (checker instanceof SimpleValueChecker) {
        final SimpleValueChecker valueChecker = (SimpleValueChecker) checker;
        rel = valueChecker.getRelativeThreshold();
        // Read the absolute threshold here; reading the relative one twice was a slip.
        abs = valueChecker.getAbsoluteThreshold();
    }
    final LineSearch line = new LineSearch(Math.sqrt(rel), Math.sqrt(abs));
    PointValuePair current = null;
    int maxEval = getMaxEvaluations();
    for (;;) {
        incrementIterationCount();
        final double objective = computeObjectiveValue(point);
        final PointValuePair previous = current;
        current = new PointValuePair(point, objective);
        if (previous != null && checker.converged(getIterations(), previous, current)) {
            // We have found an optimum.
            return current;
        }
        double step;
        if (useGradientLineSearch) {
            // Classic code using the gradient function for the line search:
            // Find the optimal step in the search direction.
            final UnivariateFunction lsf = new LineSearchFunction(point, searchDirection);
            final double uB;
            try {
                uB = findUpperBound(lsf, 0, initialStep);
                // Check if the bracket found a minimum. Otherwise just move to the new point.
                if (noBracket) {
                    step = uB;
                } else {
                    // XXX Last parameters is set to a value close to zero in order to
                    // work around the divergence problem in the "testCircleFitting"
                    // unit test (see MATH-439).
                    step = solver.solve(maxEval, lsf, 0, uB, 1e-15);
                    // Subtract used up evaluations.
                    maxEval -= solver.getEvaluations();
                }
            } catch (final MathIllegalStateException ignored) {
                // Intentionally swallowed: bracketing failed, so use the
                // line search without gradient (as per Powell optimiser).
                final UnivariatePointValuePair optimum = line.search(point, searchDirection);
                step = optimum.getPoint();
            }
        } else {
            // Line search without gradient (as per Powell optimiser)
            final UnivariatePointValuePair optimum = line.search(point, searchDirection);
            step = optimum.getPoint();
        }
        // Advance along the search direction and clip to the bounds.
        for (int i = 0; i < n; ++i) {
            point[i] += step * searchDirection[i];
        }
        unbounded = point.clone();
        applyBounds(point);
        gradient = computeObjectiveGradient(point);
        checkGradients(gradient, unbounded);
        if (goal == GoalType.MINIMIZE) {
            for (int i = 0; i < n; ++i) {
                gradient[i] = -gradient[i];
            }
        }
        // Compute beta.
        final double deltaOld = delta;
        final double[] newSteepestDescent = preconditioner.precondition(point, gradient);
        delta = 0;
        for (int i = 0; i < n; ++i) {
            delta += gradient[i] * newSteepestDescent[i];
        }
        if (delta == 0) {
            // Nothing left to descend along; report the current point.
            return new PointValuePair(point, computeObjectiveValue(point));
        }
        final double beta;
        switch (updateFormula) {
            case FLETCHER_REEVES:
                beta = delta / deltaOld;
                break;
            case POLAK_RIBIERE:
                double deltaMid = 0;
                for (int i = 0; i < gradient.length; ++i) {
                    deltaMid += gradient[i] * steepestDescent[i];
                }
                beta = (delta - deltaMid) / deltaOld;
                break;
            default:
                // Should never happen.
                throw new MathInternalError();
        }
        steepestDescent = newSteepestDescent;
        // Compute conjugate search direction.
        if (getIterations() % n == 0 || beta < 0) {
            // Break conjugation: reset search direction.
            searchDirection = steepestDescent.clone();
        } else {
            // Compute new conjugate search direction.
            for (int i = 0; i < n; ++i) {
                searchDirection[i] = steepestDescent[i] + beta * searchDirection[i];
            }
        }
        // The gradient has already been adjusted for the search direction
        checkGradients(searchDirection, unbounded, -sign);
    }
}
Also used : UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) GoalType(org.apache.commons.math3.optim.nonlinear.scalar.GoalType) UnivariatePointValuePair(org.apache.commons.math3.optim.univariate.UnivariatePointValuePair) SimpleValueChecker(org.apache.commons.math3.optim.SimpleValueChecker) MathIllegalStateException(org.apache.commons.math3.exception.MathIllegalStateException) PointValuePair(org.apache.commons.math3.optim.PointValuePair) UnivariatePointValuePair(org.apache.commons.math3.optim.univariate.UnivariatePointValuePair) MathInternalError(org.apache.commons.math3.exception.MathInternalError)

Example 3 with MathIllegalStateException

use of org.apache.commons.math3.exception.MathIllegalStateException in project GDSC-SMLM by aherbert.

the class BoundedNonLinearConjugateGradientOptimizer method findUpperBoundWithChecks.

/**
 * Finds the upper bound b ensuring bracketing of a root between a and b.
 *
 * <p>The step is grown until {@code f(a)} and {@code f(b)} have opposite signs. If the function
 * evaluates to NaN before a bracket is found, the {@code noBracket} flag is set and the method
 * returns the last step that produced a non-NaN value. If no such step has been made yet, the
 * step is repeatedly shrunk by a factor of 10 to look for one.
 *
 * @param f
 *            function whose root must be bracketed.
 * @param a
 *            lower bound of the interval.
 * @param h
 *            initial step to try.
 * @return b such that f(a) and f(b) have opposite signs; or, when {@code noBracket} has been
 *         set, the last evaluated position with a non-NaN function value.
 * @throws MathIllegalStateException
 *             if no bracket can be found.
 */
@SuppressWarnings("unused")
private double findUpperBoundWithChecks(final UnivariateFunction f, final double a, final double h) {
    noBracket = false;
    final double yA = f.value(a);
    // Check we have a gradient. This should be true unless something slipped by.
    if (Double.isNaN(yA)) {
        throw new MathIllegalStateException(LocalizedFormats.UNABLE_TO_BRACKET_OPTIMUM_IN_LINE_SEARCH);
    }
    double yB = yA;
    // NaN marks "no valid step made yet". It must be tested with Double.isNaN because
    // any ==/!= comparison against NaN yields a constant result.
    double lastB = Double.NaN;
    for (double step = h; step < Double.MAX_VALUE; step *= FastMath.max(2, yA / yB)) {
        double b = a + step;
        yB = f.value(b);
        if (yA * yB <= 0) {
            return b;
        }
        if (Double.isNaN(yB)) {
            // We have moved along the vector to a point where we have no gradient.
            // Bracketing is impossible.
            noBracket = true;
            // Check we made at least one step to a place with a new gradient
            if (!Double.isNaN(lastB)) {
                // Return the point we reached as the minimum
                return lastB;
            }
            // We made no valid steps. Scan back towards a with ever smaller steps
            // to find a point with a valid gradient.
            for (step *= 0.1; step > Double.MIN_VALUE; step *= 0.1) {
                b = a + step;
                yB = f.value(b);
                if (yA * yB <= 0) {
                    return b;
                }
                if (!Double.isNaN(yB)) {
                    lastB = b;
                }
            }
            if (!Double.isNaN(lastB)) {
                // Return the point we reached as the minimum
                return lastB;
            }
            throw new MathIllegalStateException(LocalizedFormats.UNABLE_TO_BRACKET_OPTIMUM_IN_LINE_SEARCH);
        }
        lastB = b;
    }
    throw new MathIllegalStateException(LocalizedFormats.UNABLE_TO_BRACKET_OPTIMUM_IN_LINE_SEARCH);
}
Also used : MathIllegalStateException(org.apache.commons.math3.exception.MathIllegalStateException)

Aggregations

MathIllegalStateException (org.apache.commons.math3.exception.MathIllegalStateException)3 UnivariateFunction (org.apache.commons.math3.analysis.UnivariateFunction)2 MathInternalError (org.apache.commons.math3.exception.MathInternalError)2 PointValuePair (org.apache.commons.math3.optim.PointValuePair)2 SimpleValueChecker (org.apache.commons.math3.optim.SimpleValueChecker)2 GoalType (org.apache.commons.math3.optim.nonlinear.scalar.GoalType)2 UnivariatePointValuePair (org.apache.commons.math3.optim.univariate.UnivariatePointValuePair)2