Search in sources:

Example 6 with ConvergenceChecker

Use of org.apache.commons.math3.optim.ConvergenceChecker in project GDSC-SMLM by aherbert.

The class PcPalmFitting, method runBoundedOptimiser.

/**
 * Fit the sum-of-squares model within bounds and return the best solution found.
 *
 * <p>A bounded non-linear conjugate gradient optimiser (deterministic) is run first, each restart
 * starting from the previous optimum, for {@code settings.fitRestarts + 1} attempts. A failure
 * stops the gradient restarts. A CMA-ES optimiser (non-deterministic) is then run from the
 * initial solution and again from the current best optimum for the same number of restarts;
 * the solution with the lowest value is kept.
 *
 * <p>The number of function evaluations actually used by every optimiser run (successful or
 * not) is accumulated into {@code boundedEvaluations}.
 *
 * @param initialSolution the initial solution
 * @param lowerB the lower bounds
 * @param upperB the upper bounds
 * @param function the sum-of-squares model function
 * @return the best solution, or null if no optimiser produced one
 */
private PointValuePair runBoundedOptimiser(double[] initialSolution, double[] lowerB, double[] upperB, SumOfSquaresModelFunction function) {
    // Create the functions to optimise
    final ObjectiveFunction objective = new ObjectiveFunction(new SumOfSquaresMultivariateFunction(function));
    final ObjectiveFunctionGradient gradient = new ObjectiveFunctionGradient(new SumOfSquaresMultivariateVectorFunction(function));
    final boolean debug = false;
    // Try a gradient optimiser since this will produce a deterministic solution
    PointValuePair optimum = null;
    boundedEvaluations = 0;
    final MaxEval maxEvaluations = new MaxEval(2000);
    for (int iteration = 0; iteration <= settings.fitRestarts; iteration++) {
        // Scoped to this iteration so a failure before construction cannot re-count
        // the evaluations of a previous optimiser.
        MultivariateOptimizer opt = null;
        try {
            final double relativeThreshold = 1e-6;
            // An absolute threshold of -1 can never be met so only the relative criterion applies
            opt = new BoundedNonLinearConjugateGradientOptimizer(BoundedNonLinearConjugateGradientOptimizer.Formula.FLETCHER_REEVES, new SimpleValueChecker(relativeThreshold, -1));
            // Restart from the previous optimum when one exists
            optimum = opt.optimize(maxEvaluations, gradient, objective, GoalType.MINIMIZE, new InitialGuess((optimum == null) ? initialSolution : optimum.getPointRef()), new SimpleBounds(lowerB, upperB));
            if (debug) {
                System.out.printf("Bounded Iter %d = %g (%d)\n", iteration, optimum.getValue(), opt.getEvaluations());
            }
        } catch (final RuntimeException ex) {
            // No need to restart
            break;
        } finally {
            if (opt != null) {
                boundedEvaluations += opt.getEvaluations();
            }
        }
    }
    // Try a CMAES optimiser which is non-deterministic. To overcome this we perform restarts.
    // CMAESOptimiser based on Matlab code:
    // https://www.lri.fr/~hansen/cmaes.m
    // Take the defaults from the Matlab documentation
    final double stopFitness = 0;
    final boolean isActiveCma = true;
    final int diagonalOnly = 0;
    final int checkFeasableCount = 1;
    final RandomGenerator random = new RandomGeneratorAdapter(UniformRandomProviders.create());
    final boolean generateStatistics = false;
    final ConvergenceChecker<PointValuePair> checker = new SimpleValueChecker(1e-6, 1e-10);
    // The sigma determines the search range for the variables. It should be 1/3 of the initial
    // search region.
    final double[] range = new double[lowerB.length];
    for (int i = 0; i < lowerB.length; i++) {
        range[i] = (upperB[i] - lowerB[i]) / 3;
    }
    final OptimizationData sigma = new CMAESOptimizer.Sigma(range);
    final OptimizationData popSize = new CMAESOptimizer.PopulationSize((int) (4 + Math.floor(3 * Math.log(initialSolution.length))));
    final SimpleBounds bounds = new SimpleBounds(lowerB, upperB);
    final CMAESOptimizer cmaes = new CMAESOptimizer(maxEvaluations.getMaxEval(), stopFitness, isActiveCma, diagonalOnly, checkFeasableCount, random, generateStatistics, checker);
    // Restart the optimiser several times and take the best answer.
    for (int iteration = 0; iteration <= settings.fitRestarts; iteration++) {
        try {
            // Start from the initial solution
            final PointValuePair constrainedSolution = cmaes.optimize(new InitialGuess(initialSolution), objective, GoalType.MINIMIZE, bounds, sigma, popSize, maxEvaluations);
            if (debug) {
                System.out.printf("CMAES Iter %d initial = %g (%d)\n", iteration, constrainedSolution.getValue(), cmaes.getEvaluations());
            }
            if (optimum == null || constrainedSolution.getValue() < optimum.getValue()) {
                optimum = constrainedSolution;
            }
        } catch (final TooManyEvaluationsException | TooManyIterationsException ex) {
            // Ignore: this run produced no solution
        } finally {
            // Count the evaluations this run actually used, exactly once, whether or not it
            // succeeded.
            boundedEvaluations += cmaes.getEvaluations();
        }
        if (optimum == null) {
            continue;
        }
        try {
            // Also restart from the current optimum
            final PointValuePair constrainedSolution = cmaes.optimize(new InitialGuess(optimum.getPointRef()), objective, GoalType.MINIMIZE, bounds, sigma, popSize, maxEvaluations);
            if (debug) {
                System.out.printf("CMAES Iter %d restart = %g (%d)\n", iteration, constrainedSolution.getValue(), cmaes.getEvaluations());
            }
            if (constrainedSolution.getValue() < optimum.getValue()) {
                optimum = constrainedSolution;
            }
        } catch (final TooManyEvaluationsException | TooManyIterationsException ex) {
            // Ignore: this run produced no solution
        } finally {
            boundedEvaluations += cmaes.getEvaluations();
        }
    }
    return optimum;
}
Also used : MultivariateOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.MultivariateOptimizer) MaxEval(org.apache.commons.math3.optim.MaxEval) InitialGuess(org.apache.commons.math3.optim.InitialGuess) SimpleBounds(org.apache.commons.math3.optim.SimpleBounds) ObjectiveFunction(org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction) SimpleValueChecker(org.apache.commons.math3.optim.SimpleValueChecker) RandomGenerator(org.apache.commons.math3.random.RandomGenerator) PointValuePair(org.apache.commons.math3.optim.PointValuePair) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) BoundedNonLinearConjugateGradientOptimizer(uk.ac.sussex.gdsc.smlm.math3.optim.nonlinear.scalar.gradient.BoundedNonLinearConjugateGradientOptimizer) TooManyIterationsException(org.apache.commons.math3.exception.TooManyIterationsException) RandomGeneratorAdapter(uk.ac.sussex.gdsc.core.utils.rng.RandomGeneratorAdapter) CMAESOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CMAESOptimizer) ObjectiveFunctionGradient(org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient) OptimizationData(org.apache.commons.math3.optim.OptimizationData)

Example 7 with ConvergenceChecker

Use of org.apache.commons.math3.optim.ConvergenceChecker in project GDSC-SMLM by aherbert.

The class BinomialFitter, method fitBinomial.

/**
 * Fit the binomial distribution (n,p) to the cumulative histogram. Performs fitting assuming a
 * fixed n value and attempts to optimise p.
 *
 * <p>Note: when {@code zeroTruncated} is true and {@code histogram[0] > 0} the histogram is
 * modified in place: the zero bin is set to zero and the remaining bins are divided by their sum.
 *
 * @param histogram The input histogram
 * @param mean The histogram mean (used to estimate p). Calculated if NaN.
 * @param n The n to evaluate
 * @param zeroTruncated True if the model should ignore n=0 (zero-truncated binomial)
 * @return The best fit: the point holds p and the value is the sum-of-squares of the fit; or null
 *         if there are no points to fit or fitting failed
 * @throws IllegalArgumentException If any of the input data values are negative
 * @throws IllegalArgumentException If fitting a zero-truncated binomial and there are no values
 *         above zero
 */
public PointValuePair fitBinomial(double[] histogram, double mean, int n, boolean zeroTruncated) {
    if (Double.isNaN(mean)) {
        mean = getMean(histogram);
    }
    if (zeroTruncated && histogram[0] > 0) {
        log("Fitting zero-truncated histogram but there are zero values - " + "Renormalising to ignore zero");
        double cumul = 0;
        for (int i = 1; i < histogram.length; i++) {
            cumul += histogram[i];
        }
        if (cumul == 0) {
            throw new IllegalArgumentException("Fitting zero-truncated histogram but there are no non-zero values");
        }
        histogram[0] = 0;
        for (int i = 1; i < histogram.length; i++) {
            histogram[i] /= cumul;
        }
    }
    final int nFittedPoints = Math.min(histogram.length, n + 1) - ((zeroTruncated) ? 1 : 0);
    if (nFittedPoints < 1) {
        log("No points to fit (%d): Histogram.length = %d, n = %d, zero-truncated = %b", nFittedPoints, histogram.length, n, zeroTruncated);
        return null;
    }
    // The model is only fitting the probability p
    // For a binomial n*p = mean => p = mean/n
    final double[] initialSolution = new double[] { Math.min(mean / n, 1) };
    // Create the function
    final BinomialModelFunction function = new BinomialModelFunction(histogram, n, zeroTruncated);
    final double[] lB = new double[1];
    final double[] uB = new double[] { 1 };
    final SimpleBounds bounds = new SimpleBounds(lB, uB);
    // Fit
    // CMAESOptimizer or BOBYQAOptimizer support bounds
    // CMAESOptimiser based on Matlab code:
    // https://www.lri.fr/~hansen/cmaes.m
    // Take the defaults from the Matlab documentation
    final int maxIterations = 2000;
    final double stopFitness = 0;
    final boolean isActiveCma = true;
    final int diagonalOnly = 0;
    final int checkFeasableCount = 1;
    final RandomGenerator random = new RandomGeneratorAdapter(UniformRandomProviders.create());
    final boolean generateStatistics = false;
    final ConvergenceChecker<PointValuePair> checker = new SimpleValueChecker(1e-6, 1e-10);
    // The sigma determines the search range for the variables. It should be 1/3 of the initial
    // search region.
    final OptimizationData sigma = new CMAESOptimizer.Sigma(new double[] { (uB[0] - lB[0]) / 3 });
    final OptimizationData popSize = new CMAESOptimizer.PopulationSize((int) (4 + Math.floor(3 * Math.log(2))));
    try {
        PointValuePair solution = null;
        boolean noRefit = maximumLikelihood;
        if (n == 1 && zeroTruncated) {
            // No need to fit
            solution = new PointValuePair(new double[] { 1 }, 0);
            noRefit = true;
        } else {
            // A likelihood is maximised; a sum-of-squares is minimised. The "best" solution
            // across restarts must be chosen consistently with this goal.
            final GoalType goalType = (maximumLikelihood) ? GoalType.MAXIMIZE : GoalType.MINIMIZE;
            // Iteratively fit
            final CMAESOptimizer opt = new CMAESOptimizer(maxIterations, stopFitness, isActiveCma, diagonalOnly, checkFeasableCount, random, generateStatistics, checker);
            for (int iteration = 0; iteration <= fitRestarts; iteration++) {
                try {
                    // Start from the initial solution
                    final PointValuePair result = opt.optimize(new InitialGuess(initialSolution), new ObjectiveFunction(function), goalType, bounds, sigma, popSize, new MaxIter(maxIterations), new MaxEval(maxIterations * 2));
                    if (isBetterSolution(goalType, result, solution)) {
                        solution = result;
                    }
                } catch (final TooManyEvaluationsException | TooManyIterationsException ex) {
                    // No solution
                }
                if (solution == null) {
                    continue;
                }
                try {
                    // Also restart from the current optimum
                    final PointValuePair result = opt.optimize(new InitialGuess(solution.getPointRef()), new ObjectiveFunction(function), goalType, bounds, sigma, popSize, new MaxIter(maxIterations), new MaxEval(maxIterations * 2));
                    if (isBetterSolution(goalType, result, solution)) {
                        solution = result;
                    }
                } catch (final TooManyEvaluationsException | TooManyIterationsException ex) {
                    // No solution
                }
            }
            if (solution == null) {
                return null;
            }
        }
        if (noRefit) {
            // Although we fit the log-likelihood, return the sum-of-squares to allow
            // comparison across different n
            final double p = solution.getPointRef()[0];
            double ss = 0;
            final double[] obs = function.pvalues;
            final double[] exp = function.getP(p);
            for (int i = 0; i < obs.length; i++) {
                ss += (obs[i] - exp[i]) * (obs[i] - exp[i]);
            }
            return new PointValuePair(solution.getPointRef(), ss);
        // We can do a LVM refit if the number of fitted points is more than 1.
        } else if (nFittedPoints > 1) {
            // Improve SS fit with a gradient based LVM optimizer
            final LevenbergMarquardtOptimizer optimizer = new LevenbergMarquardtOptimizer();
            try {
                final BinomialModelFunctionGradient gradientFunction = new BinomialModelFunctionGradient(histogram, n, zeroTruncated);
                // @formatter:off
                final LeastSquaresProblem problem = new LeastSquaresBuilder().maxEvaluations(Integer.MAX_VALUE).maxIterations(3000).start(solution.getPointRef()).target(gradientFunction.pvalues).weight(new DiagonalMatrix(gradientFunction.getWeights())).model(gradientFunction, gradientFunction::jacobian).build();
                // @formatter:on
                final Optimum lvmSolution = optimizer.optimize(problem);
                // Check the pValue is valid since the LVM is not bounded.
                final double p = lvmSolution.getPoint().getEntry(0);
                if (p <= 1 && p >= 0) {
                    // The residual dot product equals the sum-of-squares only if the weights are 1
                    final double ss = lvmSolution.getResiduals().dotProduct(lvmSolution.getResiduals());
                    if (ss < solution.getValue()) {
                        return new PointValuePair(lvmSolution.getPoint().toArray(), ss);
                    }
                }
            } catch (final TooManyIterationsException ex) {
                log("Failed to re-fit: Too many iterations: %s", ex.getMessage());
            } catch (final ConvergenceException ex) {
                log("Failed to re-fit: %s", ex.getMessage());
            } catch (final Exception ex) {
                // The refit is best-effort: report the failure and keep the CMAES solution
                log("Failed to re-fit: %s", ex.getMessage());
            }
        }
        return solution;
    } catch (final RuntimeException ex) {
        log("Failed to fit Binomial distribution with N=%d : %s", n, ex.getMessage());
    }
    return null;
}

/**
 * Test if a candidate solution improves on the current best solution for the optimisation goal.
 *
 * @param goalType the optimisation goal
 * @param candidate the candidate solution
 * @param best the current best solution (may be null)
 * @return true if there is no current best, or the candidate value is higher (when maximising)
 *         or lower (otherwise) than the best value
 */
private static boolean isBetterSolution(GoalType goalType, PointValuePair candidate, PointValuePair best) {
    if (best == null) {
        return true;
    }
    return (goalType == GoalType.MAXIMIZE)
        ? candidate.getValue() > best.getValue()
        : candidate.getValue() < best.getValue();
}
Also used : InitialGuess(org.apache.commons.math3.optim.InitialGuess) MaxEval(org.apache.commons.math3.optim.MaxEval) SimpleBounds(org.apache.commons.math3.optim.SimpleBounds) ObjectiveFunction(org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction) SimpleValueChecker(org.apache.commons.math3.optim.SimpleValueChecker) RandomGenerator(org.apache.commons.math3.random.RandomGenerator) PointValuePair(org.apache.commons.math3.optim.PointValuePair) LeastSquaresBuilder(org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) DiagonalMatrix(org.apache.commons.math3.linear.DiagonalMatrix) ConvergenceException(org.apache.commons.math3.exception.ConvergenceException) TooManyIterationsException(org.apache.commons.math3.exception.TooManyIterationsException) LeastSquaresProblem(org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem) RandomGeneratorAdapter(uk.ac.sussex.gdsc.core.utils.rng.RandomGeneratorAdapter) CMAESOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CMAESOptimizer) GoalType(org.apache.commons.math3.optim.nonlinear.scalar.GoalType) ConvergenceException(org.apache.commons.math3.exception.ConvergenceException) TooManyIterationsException(org.apache.commons.math3.exception.TooManyIterationsException) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Optimum(org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum) LevenbergMarquardtOptimizer(org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer) OptimizationData(org.apache.commons.math3.optim.OptimizationData) MaxIter(org.apache.commons.math3.optim.MaxIter)

Example 8 with ConvergenceChecker

Use of org.apache.commons.math3.optim.ConvergenceChecker in project GDSC-SMLM by aherbert.

The class JumpDistanceAnalysis, method createCmaesOptimizer.

/**
 * Build the CMA-ES optimiser used for fitting.
 *
 * <p>Configuration: at most 2000 iterations, a stop fitness of 0, active CMA enabled, the
 * diagonal-only covariance update for the first 20 iterations, one feasibility check, no
 * statistics, and convergence when successive values agree to a relative threshold of 1e-8 or
 * an absolute threshold of 1e-10.
 *
 * @return a new optimiser with its own random source
 */
private static CMAESOptimizer createCmaesOptimizer() {
    final RandomGenerator rng = new RandomGeneratorAdapter(UniformRandomProviders.create());
    final ConvergenceChecker<PointValuePair> valueConvergence = new SimpleValueChecker(1e-8, 1e-10);
    return new CMAESOptimizer(
        2000,   // maximum iterations
        0.0,    // stop fitness
        true,   // active CMA
        20,     // diagonal-only iterations
        1,      // feasibility check count
        rng,
        false,  // generate statistics
        valueConvergence);
}
Also used : RandomGeneratorAdapter(uk.ac.sussex.gdsc.core.utils.rng.RandomGeneratorAdapter) CMAESOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CMAESOptimizer) SimpleValueChecker(org.apache.commons.math3.optim.SimpleValueChecker) RandomGenerator(org.apache.commons.math3.random.RandomGenerator) PointValuePair(org.apache.commons.math3.optim.PointValuePair)

Example 9 with ConvergenceChecker

Use of org.apache.commons.math3.optim.ConvergenceChecker in project GDSC-SMLM by aherbert.

The class BoundedNonLinearConjugateGradientOptimizer, method doOptimize.

/**
 * Run a non-linear conjugate gradient optimisation with the point constrained to the bounds.
 *
 * <p>After each move the point is clamped with {@code applyBounds}, and the gradient is checked
 * against the unbounded point with {@code checkGradients}. When {@code useGradientLineSearch} is
 * set, the step along the search direction is found by bracketing and solving on the line-search
 * function; if that throws a {@link MathIllegalStateException}, a value-only line search is used
 * instead. Otherwise the value-only line search is always used.
 *
 * <p>The value-only line search tolerances are the square roots of the relative and absolute
 * thresholds of a {@link SimpleValueChecker} convergence checker, when those thresholds are
 * positive. Otherwise the defaults 1e-6 (relative) and 1e-10 (absolute) are used.
 *
 * @return the optimum found
 */
@Override
protected PointValuePair doOptimize() {
    final ConvergenceChecker<PointValuePair> checker = getConvergenceChecker();
    final double[] point = getStartPoint();
    final GoalType goal = getGoalType();
    final int n = point.length;
    sign = (goal == GoalType.MINIMIZE) ? -1 : 1;
    double[] unbounded = point.clone();
    applyBounds(point);
    double[] gradient = computeObjectiveGradient(point);
    checkGradients(gradient, unbounded);
    if (goal == GoalType.MINIMIZE) {
        for (int i = 0; i < n; i++) {
            gradient[i] = -gradient[i];
        }
    }
    // Initial search direction.
    double[] steepestDescent = preconditioner.precondition(point, gradient);
    double[] searchDirection = steepestDescent.clone();
    double delta = 0;
    for (int i = 0; i < n; ++i) {
        delta += gradient[i] * searchDirection[i];
    }
    // Tolerances for the non-gradient based line search. Only positive checker thresholds are
    // used: a SimpleValueChecker may be given a non-positive threshold to disable that
    // criterion, and its square root would be NaN (or zero).
    double rel = 1e-6;
    double abs = 1e-10;
    if (checker instanceof SimpleValueChecker) {
        final SimpleValueChecker valueChecker = (SimpleValueChecker) checker;
        if (valueChecker.getRelativeThreshold() > 0) {
            rel = valueChecker.getRelativeThreshold();
        }
        if (valueChecker.getAbsoluteThreshold() > 0) {
            abs = valueChecker.getAbsoluteThreshold();
        }
    }
    final LineSearch line = new LineSearch(Math.sqrt(rel), Math.sqrt(abs));
    PointValuePair current = null;
    int maxEval = getMaxEvaluations();
    for (; ; ) {
        incrementIterationCount();
        final double objective = computeObjectiveValue(point);
        final PointValuePair previous = current;
        current = new PointValuePair(point, objective);
        if (previous != null && checker.converged(getIterations(), previous, current)) {
            // We have found an optimum.
            return current;
        }
        double step;
        if (useGradientLineSearch) {
            // Classic code using the gradient function for the line search:
            // Find the optimal step in the search direction.
            final UnivariateFunction lsf = new LineSearchFunction(point, searchDirection);
            final double uB;
            try {
                uB = findUpperBound(lsf, 0, initialStep);
                // Check if the bracket found a minimum. Otherwise just move to the new point.
                if (noBracket) {
                    step = uB;
                } else {
                    // XXX Last parameters is set to a value close to zero in order to
                    // work around the divergence problem in the "testCircleFitting"
                    // unit test (see MATH-439).
                    step = solver.solve(maxEval, lsf, 0, uB, 1e-15);
                    // Subtract used up evaluations.
                    maxEval -= solver.getEvaluations();
                }
            } catch (final MathIllegalStateException ex) {
                // Failed to bracket or solve:
                // line search without gradient (as per Powell optimiser)
                final UnivariatePointValuePair optimum = line.search(point, searchDirection);
                step = optimum.getPoint();
            }
        } else {
            // Line search without gradient (as per Powell optimiser)
            final UnivariatePointValuePair optimum = line.search(point, searchDirection);
            step = optimum.getPoint();
        }
        for (int i = 0; i < point.length; ++i) {
            point[i] += step * searchDirection[i];
        }
        unbounded = point.clone();
        applyBounds(point);
        gradient = computeObjectiveGradient(point);
        checkGradients(gradient, unbounded);
        if (goal == GoalType.MINIMIZE) {
            for (int i = 0; i < n; ++i) {
                gradient[i] = -gradient[i];
            }
        }
        // Compute beta.
        final double deltaOld = delta;
        final double[] newSteepestDescent = preconditioner.precondition(point, gradient);
        delta = 0;
        for (int i = 0; i < n; ++i) {
            delta += gradient[i] * newSteepestDescent[i];
        }
        if (delta == 0) {
            return new PointValuePair(point, computeObjectiveValue(point));
        }
        final double beta;
        switch(updateFormula) {
            case FLETCHER_REEVES:
                beta = delta / deltaOld;
                break;
            case POLAK_RIBIERE:
                double deltaMid = 0;
                for (int i = 0; i < gradient.length; ++i) {
                    deltaMid += gradient[i] * steepestDescent[i];
                }
                beta = (delta - deltaMid) / deltaOld;
                break;
            default:
                // Should never happen.
                throw new MathInternalError();
        }
        steepestDescent = newSteepestDescent;
        // Compute conjugate search direction.
        if (getIterations() % n == 0 || beta < 0) {
            // Break conjugation: reset search direction.
            searchDirection = steepestDescent.clone();
        } else {
            // Compute new conjugate search direction.
            for (int i = 0; i < n; ++i) {
                searchDirection[i] = steepestDescent[i] + beta * searchDirection[i];
            }
        }
        // The gradient has already been adjusted for the search direction
        checkGradients(searchDirection, unbounded, -sign);
    }
}
Also used : UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) GoalType(org.apache.commons.math3.optim.nonlinear.scalar.GoalType) UnivariatePointValuePair(org.apache.commons.math3.optim.univariate.UnivariatePointValuePair) SimpleValueChecker(org.apache.commons.math3.optim.SimpleValueChecker) MathIllegalStateException(org.apache.commons.math3.exception.MathIllegalStateException) PointValuePair(org.apache.commons.math3.optim.PointValuePair) UnivariatePointValuePair(org.apache.commons.math3.optim.univariate.UnivariatePointValuePair) MathInternalError(org.apache.commons.math3.exception.MathInternalError)

Example 10 with ConvergenceChecker

Use of org.apache.commons.math3.optim.ConvergenceChecker in project GDSC-SMLM by aherbert.

The class CustomPowellOptimizer, method doOptimize.

// CHECKSTYLE.OFF: LocalVariableName
// CHECKSTYLE.OFF: ParameterName
/**
 * Run a Powell-style direction-set search from the start point.
 *
 * <p>Each iteration performs a line search along every direction in the current direction set,
 * starting from the basis vectors. It then tests for convergence using, in order:
 * <ol>
 * <li>the optional {@code positionChecker};</li>
 * <li>a relative/absolute equality test of the function values using {@code relativeThreshold}
 * and {@code absoluteThreshold};</li>
 * <li>the optional user convergence checker.</li>
 * </ol>
 * If {@code basisConvergence} is set and the direction set has been modified, convergence
 * instead resets the directions to the basis vectors and the search continues.
 *
 * <p>The start point and each extrapolated point are passed through {@code applyBounds}.
 *
 * @return the better (according to the goal type) of the previous and current iterates at
 *         convergence
 */
@Override
protected PointValuePair doOptimize() {
    final GoalType goal = getGoalType();
    final double[] guess = getStartPoint();
    final int n = guess.length;
    // Mark when we have modified the basis vectors
    boolean nonBasis = false;
    double[][] direc = createBasisVectors(n);
    final ConvergenceChecker<PointValuePair> checker = getConvergenceChecker();
    double[] x = guess;
    // Ensure the point is within bounds
    applyBounds(x);
    double functionValue = computeObjectiveValue(x);
    // x1 holds the point at the start of the current iteration
    double[] x1 = x.clone();
    for (; ; ) {
        incrementIterationCount();
        // Function value at the start of this iteration
        final double fX = functionValue;
        double fX2 = 0;
        // Largest single-direction decrease in this iteration and the index of that direction
        double delta = 0;
        int bigInd = 0;
        for (int i = 0; i < n; i++) {
            fX2 = functionValue;
            final UnivariatePointValuePair optimum = line.search(x, direc[i]);
            functionValue = optimum.getValue();
            x = newPoint(x, direc[i], optimum.getPoint());
            if ((fX2 - functionValue) > delta) {
                delta = fX2 - functionValue;
                bigInd = i;
            }
        }
        // Points are not copied (final argument false) so x1/x must not be modified afterwards
        // while previous/current are in use. NOTE(review): verify newPoint returns a new array.
        final PointValuePair previous = new PointValuePair(x1, fX, false);
        final PointValuePair current = new PointValuePair(x, functionValue, false);
        boolean stop = false;
        if (positionChecker != null) {
            // Check for convergence on the position
            stop = positionChecker.converged(getIterations(), previous, current);
        }
        if (!stop) {
            // Check if we have improved from an impossible position
            if (Double.isInfinite(fX) || Double.isNaN(fX)) {
                if (Double.isInfinite(functionValue) || Double.isNaN(functionValue)) {
                    // Nowhere to go
                    stop = true;
                }
            } else {
                stop = DoubleEquality.almostEqualRelativeOrAbsolute(fX, functionValue, relativeThreshold, absoluteThreshold);
            }
        }
        if (!stop && checker != null) {
            stop = checker.converged(getIterations(), previous, current);
        }
        boolean reset = false;
        if (stop) {
            // Only allow convergence using the basis vectors, i.e. we cannot move along any dimension
            if (basisConvergence && nonBasis) {
                // Reset to the basis vectors and continue
                reset = true;
            } else {
                final PointValuePair answer;
                if (goal == GoalType.MINIMIZE) {
                    answer = (functionValue < fX) ? current : previous;
                } else {
                    answer = (functionValue > fX) ? current : previous;
                }
                return answer;
            }
        }
        if (reset) {
            direc = createBasisVectors(n);
            nonBasis = false;
        }
        // d is the net displacement of this iteration; x2 = 2x - x1 extrapolates along it
        final double[] d = new double[n];
        final double[] x2 = new double[n];
        for (int i = 0; i < n; i++) {
            d[i] = x[i] - x1[i];
            x2[i] = x[i] + d[i];
        }
        applyBounds(x2);
        x1 = x.clone();
        fX2 = computeObjectiveValue(x2);
        // See if we can continue along the overall search direction to find a better value
        if (fX > fX2) {
            // Check if:
            // 1. The decrease along the average direction was not due to any single direction's
            // decrease
            // 2. There is a substantial second derivative along the average direction and we are close
            // to its minimum
            double t = 2 * (fX + fX2 - 2 * functionValue);
            double temp = fX - functionValue - delta;
            t *= temp * temp;
            temp = fX - fX2;
            t -= delta * temp * temp;
            if (t < 0.0) {
                final UnivariatePointValuePair optimum = line.search(x, d);
                functionValue = optimum.getValue();
                if (reset) {
                    // After a reset keep the basis vectors: move along d without adding it
                    // to the direction set
                    x = newPoint(x, d, optimum.getPoint());
                    continue;
                }
                // Presumably result[0] is the new point and result[1] the new direction
                // (see newPointAndDirection)
                final double[][] result = newPointAndDirection(x, d, optimum.getPoint());
                x = result[0];
                // Discard the direction of largest decrease: replace it with the last direction
                // and store the new direction in the last slot
                final int lastInd = n - 1;
                direc[bigInd] = direc[lastInd];
                direc[lastInd] = result[1];
                nonBasis = true;
            }
        }
    }
}
Also used : GoalType(org.apache.commons.math3.optim.nonlinear.scalar.GoalType) UnivariatePointValuePair(org.apache.commons.math3.optim.univariate.UnivariatePointValuePair) PointValuePair(org.apache.commons.math3.optim.PointValuePair) UnivariatePointValuePair(org.apache.commons.math3.optim.univariate.UnivariatePointValuePair)

Aggregations

PointValuePair (org.apache.commons.math3.optim.PointValuePair)13 SimpleValueChecker (org.apache.commons.math3.optim.SimpleValueChecker)8 GoalType (org.apache.commons.math3.optim.nonlinear.scalar.GoalType)6 CMAESOptimizer (org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CMAESOptimizer)6 RandomGenerator (org.apache.commons.math3.random.RandomGenerator)6 TooManyEvaluationsException (org.apache.commons.math3.exception.TooManyEvaluationsException)4 TooManyIterationsException (org.apache.commons.math3.exception.TooManyIterationsException)4 InitialGuess (org.apache.commons.math3.optim.InitialGuess)4 MaxEval (org.apache.commons.math3.optim.MaxEval)4 OptimizationData (org.apache.commons.math3.optim.OptimizationData)4 SimpleBounds (org.apache.commons.math3.optim.SimpleBounds)4 ObjectiveFunction (org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction)4 UnivariatePointValuePair (org.apache.commons.math3.optim.univariate.UnivariatePointValuePair)4 RandomGeneratorAdapter (uk.ac.sussex.gdsc.core.utils.rng.RandomGeneratorAdapter)3 UnivariateFunction (org.apache.commons.math3.analysis.UnivariateFunction)2 ConvergenceException (org.apache.commons.math3.exception.ConvergenceException)2 MathIllegalStateException (org.apache.commons.math3.exception.MathIllegalStateException)2 MathInternalError (org.apache.commons.math3.exception.MathInternalError)2 LeastSquaresBuilder (org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder)2 Optimum (org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum)2