Search in sources:

Example 1 with Minimize

Use of ffx.algorithms.Minimize in project ffx by mjschnie.

In class Looptimizer, method energyAndGradient.

/**
 * Evaluate the wrapped potential at {@code x} and add the OSRW bias.
 * <p>
 * When the RESPA state is FAST, the unbiased potential energy is returned
 * unchanged. Otherwise, this method:
 * <ul>
 * <li>optionally runs a periodic OSRW optimization (a minimization at
 * lambda = 1) and then restores the prior lambda and RESPA state;</li>
 * <li>adds the recursion-kernel bias G(L, F_L) and the interpolated free
 * energy F(L) to the energy;</li>
 * <li>adds the kernel's coordinate-gradient contribution to
 * {@code gradient};</li>
 * <li>periodically updates F(L) and writes restart and traversal-snapshot
 * files;</li>
 * <li>finally propagates (or equilibrates) the lambda particle.</li>
 * </ul>
 *
 * @param x        coordinates at which to evaluate the energy.
 * @param gradient gradient array filled by the wrapped potential; the bias
 *                 contribution is then added to it in place.
 * @return the potential energy plus the OSRW bias energy (for the FAST
 *         state, only the potential energy).
 */
@Override
public double energyAndGradient(double[] x, double[] gradient) {
    double e = potential.energyAndGradient(x, gradient);
    /**
     * OSRW is propagated with the slowly varying terms.
     */
    if (state == STATE.FAST) {
        return e;
    }
    if (osrwOptimization && lambda > osrwOptimizationLambdaCutoff) {
        if (energyCount % osrwOptimizationFrequency == 0) {
            logger.info(String.format(" OSRW Minimization (Step %d)", energyCount));
            // Set Lambda value to 1.0.
            lambdaInterface.setLambda(1.0);
            potential.setEnergyTermState(STATE.BOTH);
            RefinementMinimize refinementMinimize = null;
            Minimize minimize = null;
            double[] xStart = null;
            double[] xFinal = null;
            // Optimize the system.
            if (useXRayMinimizer) {
                refinementMinimize = new RefinementMinimize(diffractionData);
                int n = refinementMinimize.refinementEnergy.getNumberOfVariables();
                xStart = new double[n];
                xStart = refinementMinimize.refinementEnergy.getCoordinates(xStart);
                refinementMinimize.minimize(osrwOptimizationEps);
                xFinal = new double[n];
                xFinal = refinementMinimize.refinementEnergy.getCoordinates(xFinal);
            } else {
                minimize = new Minimize(null, potential, null);
                int n = potential.getNumberOfVariables();
                xStart = new double[n];
                xStart = potential.getCoordinates(xStart);
                minimize.minimize(osrwOptimizationEps);
                xFinal = new double[n];
                xFinal = potential.getCoordinates(xFinal);
            }
            double minValue;
            if (useXRayMinimizer) {
                // Collect the minimum R value.
                minValue = diffractionData.getRCrystalStat();
            } else {
                // Collect the minimum energy.
                minValue = potential.getTotalEnergy();
            }
            // If a new minimum has been found, save its coordinates.
            if (minValue < osrwOptimum) {
                osrwOptimum = minValue;
                if (useXRayMinimizer) {
                    logger.info(String.format(" New minimum R found: %16.8f (Step %d).", osrwOptimum, energyCount));
                } else {
                    logger.info(String.format(" New minimum energy found: %16.8f (Step %d).", osrwOptimum, energyCount));
                }
                osrwOptimumCoords = xFinal;
                if (pdbFilter.writeFile(pdbFile, false)) {
                    // Pass the file name as an argument rather than splicing it into the
                    // format string: a '%' in the name would otherwise break String.format.
                    logger.info(String.format(" Wrote PDB file to %s", pdbFile.getName()));
                }
            }
            /**
             * Reset coordinates for X-ray minimization (parameters may
             * include B-Factors).
             */
            if (useXRayMinimizer) {
                refinementMinimize.refinementEnergy.energy(xStart);
            }
            /**
             * Revert to the coordinates, gradient lambda, and RESPA State
             * prior to optimization.
             */
            potential.setScaling(null);
            lambdaInterface.setLambda(lambda);
            potential.setEnergyTermState(state);
            double eCheck = potential.energyAndGradient(x, gradient);
            if (abs(eCheck - e) > osrwOptimizationTolerance) {
                logger.warning(String.format(" OSRW optimization could not revert coordinates %16.8f vs. %16.8f.", e, eCheck));
            }
        }
    }
    double biasEnergy = 0.0;
    dEdLambda = lambdaInterface.getdEdL();
    d2EdLambda2 = lambdaInterface.getd2EdL2();
    int lambdaBin = binForLambda(lambda);
    int FLambdaBin = binForFLambda(dEdLambda);
    // Keep the unbiased dU/dL for logging and for sending to the recursion kernel.
    double dEdU = dEdLambda;
    if (propagateLambda) {
        energyCount++;
    }
    /**
     * Calculate recursion kernel G(L, F_L) and its derivatives with respect
     * to L and F_L.
     */
    double dGdLambda = 0.0;
    double dGdFLambda = 0.0;
    double ls2 = (2.0 * dL) * (2.0 * dL);
    double FLs2 = (2.0 * dFL) * (2.0 * dFL);
    for (int iL = -biasCutoff; iL <= biasCutoff; iL++) {
        int lcenter = lambdaBin + iL;
        double deltaL = lambda - (lcenter * dL);
        double deltaL2 = deltaL * deltaL;
        // Mirror conditions for recursion kernel counts.
        int lcount = lcenter;
        double mirrorFactor = 1.0;
        if (lcount == 0 || lcount == lambdaBins - 1) {
            mirrorFactor = 2.0;
        } else if (lcount < 0) {
            lcount = -lcount;
        } else if (lcount > lambdaBins - 1) {
            // Number of bins past the last bin
            lcount -= (lambdaBins - 1);
            // Mirror bin
            lcount = lambdaBins - 1 - lcount;
        }
        for (int iFL = -biasCutoff; iFL <= biasCutoff; iFL++) {
            int FLcenter = FLambdaBin + iFL;
            /**
             * If either of the following FL edge conditions are true, then
             * there are no counts and we continue.
             */
            if (FLcenter < 0 || FLcenter >= FLambdaBins) {
                continue;
            }
            double deltaFL = dEdLambda - (minFLambda + FLcenter * dFL + dFL_2);
            double deltaFL2 = deltaFL * deltaFL;
            double weight = mirrorFactor * recursionKernel[lcount][FLcenter];
            double bias = weight * biasMag * exp(-deltaL2 / (2.0 * ls2)) * exp(-deltaFL2 / (2.0 * FLs2));
            biasEnergy += bias;
            dGdLambda -= deltaL / ls2 * bias;
            dGdFLambda -= deltaFL / FLs2 * bias;
        }
    }
    /**
     * Lambda gradient due to recursion kernel G(L, F_L).
     */
    dEdLambda += dGdLambda + dGdFLambda * d2EdLambda2;
    /**
     * Cartesian coordinate gradient due to recursion kernel G(L, F_L).
     */
    fill(dUdXdL, 0.0);
    lambdaInterface.getdEdXdL(dUdXdL);
    for (int i = 0; i < nVariables; i++) {
        gradient[i] += dGdFLambda * dUdXdL[i];
    }
    if (propagateLambda && energyCount > 0) {
        /**
         * Update free energy F(L) every ~10 steps.
         */
        if (energyCount % 10 == 0) {
            fLambdaUpdates++;
            boolean printFLambda = fLambdaUpdates % fLambdaPrintInterval == 0;
            totalFreeEnergy = updateFLambda(printFLambda);
            /**
             * Calculating Moving Average & Standard Deviation
             */
            totalAverage += totalFreeEnergy;
            totalSquare += Math.pow(totalFreeEnergy, 2);
            periodCount++;
            // Report once a full window of samples has been summed. The divisor is
            // the number of samples actually accumulated (previously window-1 samples
            // were divided by window, biasing both statistics).
            if (periodCount >= window) {
                double average = totalAverage / periodCount;
                double stdev = Math.sqrt((totalSquare - Math.pow(totalAverage, 2) / periodCount) / periodCount);
                logger.info(String.format(" The running average is %12.4f kcal/mol and the stdev is %8.4f kcal/mol.", average, stdev));
                totalAverage = 0;
                totalSquare = 0;
                periodCount = 0;
            }
        }
        if (energyCount % saveFrequency == 0) {
            if (algorithmListener != null) {
                algorithmListener.algorithmUpdate(lambdaOneAssembly);
            }
            /**
             * Only the rank 0 process writes the histogram restart file.
             */
            if (rank == 0) {
                try {
                    OSRWHistogramWriter osrwHistogramRestart = new OSRWHistogramWriter(new BufferedWriter(new FileWriter(histogramFile)));
                    osrwHistogramRestart.writeHistogramFile();
                    osrwHistogramRestart.flush();
                    osrwHistogramRestart.close();
                    logger.info(String.format(" Wrote OSRW histogram restart file to %s.", histogramFile.getName()));
                } catch (IOException ex) {
                    String message = " Exception writing OSRW histogram restart file.";
                    logger.log(Level.INFO, message, ex);
                }
            }
            /**
             * All ranks write a lambda restart file.
             */
            try {
                OSRWLambdaWriter osrwLambdaRestart = new OSRWLambdaWriter(new BufferedWriter(new FileWriter(lambdaFile)));
                osrwLambdaRestart.writeLambdaFile();
                osrwLambdaRestart.flush();
                osrwLambdaRestart.close();
                logger.info(String.format(" Wrote OSRW lambda restart file to %s.", lambdaFile.getName()));
            } catch (IOException ex) {
                String message = " Exception writing OSRW lambda restart file.";
                logger.log(Level.INFO, message, ex);
            }
        }
        /**
         * Write out snapshot upon each full lambda traversal.
         */
        if (writeTraversalSnapshots) {
            double heldTraversalLambda = 0.5;
            if (!traversalInHand.isEmpty()) {
                heldTraversalLambda = Double.parseDouble(traversalInHand.get(0).split(",")[0]);
                if ((lambda > 0.2 && traversalSnapshotTarget == 0) || (lambda < 0.8 && traversalSnapshotTarget == 1)) {
                    int snapshotCounts = Integer.parseInt(traversalInHand.get(0).split(",")[1]);
                    traversalInHand.remove(0);
                    File fileToWrite;
                    int numStructures;
                    if (traversalSnapshotTarget == 0) {
                        fileToWrite = lambdaZeroFile;
                        numStructures = ++lambdaZeroStructures;
                    } else {
                        fileToWrite = lambdaOneFile;
                        numStructures = ++lambdaOneStructures;
                    }
                    // try-with-resources closes the writer (and the underlying FileWriter)
                    // even when a write fails part-way through.
                    try (BufferedWriter bw = new BufferedWriter(new FileWriter(fileToWrite, true))) {
                        bw.write(String.format("MODEL        %d          L=%.4f  counts=%d", numStructures, heldTraversalLambda, snapshotCounts));
                        for (int i = 0; i < 50; i++) {
                            bw.write(" ");
                        }
                        bw.newLine();
                        for (int i = 0; i < traversalInHand.size(); i++) {
                            bw.write(traversalInHand.get(i));
                            bw.newLine();
                        }
                        bw.write("ENDMDL");
                        for (int i = 0; i < 75; i++) {
                            bw.write(" ");
                        }
                        bw.newLine();
                        // Flush before reporting success so the message reflects data on disk.
                        bw.flush();
                        logger.info(String.format(" Wrote traversal structure L=%.4f", heldTraversalLambda));
                    } catch (Exception exception) {
                        // Keep the cause so the failure can be diagnosed.
                        logger.log(Level.WARNING, String.format("Exception writing to file: %s", fileToWrite.getName()), exception);
                    }
                    heldTraversalLambda = 0.5;
                    traversalInHand.clear();
                    traversalSnapshotTarget = 1 - traversalSnapshotTarget;
                }
            }
            if (((lambda < 0.1 && traversalInHand.isEmpty()) || (lambda < heldTraversalLambda - 0.025 && !traversalInHand.isEmpty())) && (traversalSnapshotTarget == 0 || traversalSnapshotTarget == -1)) {
                if (lambdaZeroFilter == null) {
                    lambdaZeroFilter = new PDBFilter(lambdaZeroFile, lambdaZeroAssembly, null, null);
                    lambdaZeroFilter.setListMode(true);
                }
                lambdaZeroFilter.clearListOutput();
                lambdaZeroFilter.writeFileWithHeader(lambdaFile, String.format("%.4f,%d,", lambda, totalCounts));
                traversalInHand = lambdaZeroFilter.getListOutput();
                traversalSnapshotTarget = 0;
            } else if (((lambda > 0.9 && traversalInHand.isEmpty()) || (lambda > heldTraversalLambda + 0.025 && !traversalInHand.isEmpty())) && (traversalSnapshotTarget == 1 || traversalSnapshotTarget == -1)) {
                if (lambdaOneFilter == null) {
                    lambdaOneFilter = new PDBFilter(lambdaOneFile, lambdaOneAssembly, null, null);
                    lambdaOneFilter.setListMode(true);
                }
                lambdaOneFilter.clearListOutput();
                lambdaOneFilter.writeFileWithHeader(lambdaFile, String.format("%.4f,%d,", lambda, totalCounts));
                traversalInHand = lambdaOneFilter.getListOutput();
                traversalSnapshotTarget = 1;
            }
        }
    }
    /**
     * Compute the energy and gradient for the recursion slave at F(L) using
     * interpolation.
     */
    double freeEnergy = currentFreeEnergy();
    biasEnergy += freeEnergy;
    if (print) {
        logger.info(String.format(" %s %16.8f", "Bias Energy       ", biasEnergy));
        logger.info(String.format(" %s %16.8f  %s", "OSRW Potential    ", e + biasEnergy, "(Kcal/mole)"));
    }
    if (propagateLambda && energyCount > 0) {
        /**
         * Log the current Lambda state.
         */
        if (energyCount % printFrequency == 0) {
            if (lambdaBins < 1000) {
                logger.info(String.format(" L=%6.4f (%3d) F_LU=%10.4f F_LB=%10.4f F_L=%10.4f", lambda, lambdaBin, dEdU, dEdLambda - dEdU, dEdLambda));
            } else {
                logger.info(String.format(" L=%6.4f (%4d) F_LU=%10.4f F_LB=%10.4f F_L=%10.4f", lambda, lambdaBin, dEdU, dEdLambda - dEdU, dEdLambda));
            }
        }
        /**
         * Metadynamics grid counts (every 'countInterval' steps).
         */
        if (energyCount % countInterval == 0) {
            if (jobBackend != null) {
                if (world.size() > 1) {
                    jobBackend.setComment(String.format("Overall dG=%10.4f at %7.3e psec, Current: [L=%6.4f, F_L=%10.4f, dG=%10.4f] at %7.3e psec", totalFreeEnergy, totalCounts * dt * countInterval, lambda, dEdU, -freeEnergy, energyCount * dt));
                } else {
                    jobBackend.setComment(String.format("Overall dG=%10.4f at %7.3e psec, Current: [L=%6.4f, F_L=%10.4f, dG=%10.4f]", totalFreeEnergy, totalCounts * dt * countInterval, lambda, dEdU, -freeEnergy));
                }
            }
            if (asynchronous) {
                asynchronousSend(lambda, dEdU);
            } else {
                synchronousSend(lambda, dEdU);
            }
        }
    }
    /**
     * Propagate the Lambda particle.
     */
    if (propagateLambda) {
        langevin();
    } else {
        equilibrationCounts++;
        if (jobBackend != null) {
            jobBackend.setComment(String.format("Equilibration [L=%6.4f, F_L=%10.4f]", lambda, dEdU));
        }
        if (equilibrationCounts % 10 == 0) {
            logger.info(String.format(" L=%6.4f, F_L=%10.4f", lambda, dEdU));
        }
    }
    totalEnergy = e + biasEnergy;
    return totalEnergy;
}
Also used: Minimize(ffx.algorithms.Minimize) FileWriter(java.io.FileWriter) IOException(java.io.IOException) FileNotFoundException(java.io.FileNotFoundException) BufferedWriter(java.io.BufferedWriter) File(java.io.File) PDBFilter(ffx.potential.parsers.PDBFilter)

Example 2 with Minimize

Use of ffx.algorithms.Minimize in project ffx by mjschnie.

In class ModelingShell, method minimize.

/**
 * <p>
 * Minimize the active molecular assembly with {@link Minimize}.</p>
 * <p>
 * Minimization is skipped (and {@code null} returned) when the shell has
 * been interrupted, when another terminatable algorithm is already running,
 * or when there is no active system. While minimization runs, the
 * {@link Minimize} instance is registered as the current terminatable
 * algorithm. It is always unregistered afterwards, even if minimization
 * throws.</p>
 *
 * @param eps convergence criterion passed to {@code Minimize.minimize(double)}.
 * @return the {@link ffx.numerics.Potential} returned by the minimizer, or
 *         {@code null} if minimization was skipped.
 */
public Potential minimize(double eps) {
    if (interrupted) {
        logger.info(" Algorithm interrupted - skipping minimization.");
        return null;
    }
    if (terminatableAlgorithm != null) {
        logger.info(" Algorithm already running - skipping minimization.");
        return null;
    }
    MolecularAssembly active = mainPanel.getHierarchy().getActive();
    if (active == null) {
        logger.info(" No active system to minimize.");
        return null;
    }
    Minimize minimize = new Minimize(active, this);
    terminatableAlgorithm = minimize;
    try {
        return minimize.minimize(eps);
    } finally {
        // Clear the slot even on failure; otherwise every later call would be
        // rejected as "Algorithm already running".
        terminatableAlgorithm = null;
    }
}
Also used: MolecularAssembly(ffx.potential.MolecularAssembly) Minimize(ffx.algorithms.Minimize) Potential(ffx.numerics.Potential)

Aggregations

Minimize (ffx.algorithms.Minimize): 2; Potential (ffx.numerics.Potential): 1; MolecularAssembly (ffx.potential.MolecularAssembly): 1; PDBFilter (ffx.potential.parsers.PDBFilter): 1; BufferedWriter (java.io.BufferedWriter): 1; File (java.io.File): 1; FileNotFoundException (java.io.FileNotFoundException): 1; FileWriter (java.io.FileWriter): 1; IOException (java.io.IOException): 1