Use of ffx.algorithms.Minimize in project ffx by mjschnie.
The class Looptimizer, method energyAndGradient.
/**
 * Computes the OSRW-biased energy and gradient for the Looptimizer.
 * <p>
 * The unbiased energy and gradient are first obtained from the wrapped
 * potential. When the RESPA state is {@code STATE.FAST}, that energy is
 * returned directly, because OSRW is propagated only with the slowly varying
 * terms. Otherwise this method may do the following:
 * <ol>
 * <li>Periodically run an optimization with lambda set to 1.0 and save any
 * new minimum.</li>
 * <li>Add the recursion-kernel bias G(L, F_L) and its lambda and Cartesian
 * gradient contributions.</li>
 * <li>Update the free energy estimate and write restart and traversal
 * files.</li>
 * <li>Propagate the lambda particle.</li>
 * </ol>
 *
 * @param x the coordinates at which to evaluate the energy.
 * @param gradient receives the gradient; the bias contribution is added to it
 * in place.
 * @return the unbiased energy plus the bias energy, or only the unbiased
 * energy for {@code STATE.FAST}.
 */
@Override
public double energyAndGradient(double[] x, double[] gradient) {
    double e = potential.energyAndGradient(x, gradient);

    // OSRW is propagated with the slowly varying terms.
    if (state == STATE.FAST) {
        return e;
    }

    if (osrwOptimization && lambda > osrwOptimizationLambdaCutoff) {
        if (energyCount % osrwOptimizationFrequency == 0) {
            logger.info(String.format(" OSRW Minimization (Step %d)", energyCount));
            // Optimize at lambda = 1.0 with all energy terms active.
            lambdaInterface.setLambda(1.0);
            potential.setEnergyTermState(STATE.BOTH);
            RefinementMinimize refinementMinimize = null;
            double[] xStart;
            double[] xFinal;
            if (useXRayMinimizer) {
                refinementMinimize = new RefinementMinimize(diffractionData);
                int n = refinementMinimize.refinementEnergy.getNumberOfVariables();
                xStart = refinementMinimize.refinementEnergy.getCoordinates(new double[n]);
                refinementMinimize.minimize(osrwOptimizationEps);
                xFinal = refinementMinimize.refinementEnergy.getCoordinates(new double[n]);
            } else {
                Minimize minimize = new Minimize(null, potential, null);
                int n = potential.getNumberOfVariables();
                xStart = potential.getCoordinates(new double[n]);
                minimize.minimize(osrwOptimizationEps);
                xFinal = potential.getCoordinates(new double[n]);
            }

            // Minimum R value for X-ray refinement, otherwise the minimum energy.
            double minValue = useXRayMinimizer ? diffractionData.getRCrystalStat() : potential.getTotalEnergy();

            // If a new minimum has been found, save its coordinates.
            if (minValue < osrwOptimum) {
                osrwOptimum = minValue;
                if (useXRayMinimizer) {
                    logger.info(String.format(" New minimum R found: %16.8f (Step %d).", osrwOptimum, energyCount));
                } else {
                    logger.info(String.format(" New minimum energy found: %16.8f (Step %d).", osrwOptimum, energyCount));
                }
                osrwOptimumCoords = xFinal;
                if (pdbFilter.writeFile(pdbFile, false)) {
                    // Pass the name as an argument so a '%' in it cannot break the format string.
                    logger.info(String.format(" Wrote PDB file to %s", pdbFile.getName()));
                }
            }

            // Reset coordinates for X-ray minimization (parameters may include B-Factors).
            if (useXRayMinimizer) {
                refinementMinimize.refinementEnergy.energy(xStart);
            }

            // Revert the coordinates, gradient, lambda and RESPA state to their pre-optimization values.
            potential.setScaling(null);
            lambdaInterface.setLambda(lambda);
            potential.setEnergyTermState(state);
            double eCheck = potential.energyAndGradient(x, gradient);
            if (abs(eCheck - e) > osrwOptimizationTolerance) {
                logger.warning(String.format(" OSRW optimization could not revert coordinates %16.8f vs. %16.8f.", e, eCheck));
            }
        }
    }

    double biasEnergy = 0.0;
    dEdLambda = lambdaInterface.getdEdL();
    d2EdLambda2 = lambdaInterface.getd2EdL2();
    int lambdaBin = binForLambda(lambda);
    int FLambdaBin = binForFLambda(dEdLambda);
    double dEdU = dEdLambda;
    if (propagateLambda) {
        energyCount++;
    }

    // Calculate recursion kernel G(L, F_L) and its derivatives with respect to L and F_L.
    double dGdLambda = 0.0;
    double dGdFLambda = 0.0;
    double ls2 = (2.0 * dL) * (2.0 * dL);
    double FLs2 = (2.0 * dFL) * (2.0 * dFL);
    for (int iL = -biasCutoff; iL <= biasCutoff; iL++) {
        int lcenter = lambdaBin + iL;
        double deltaL = lambda - (lcenter * dL);
        double deltaL2 = deltaL * deltaL;
        // Mirror conditions for recursion kernel counts.
        int lcount = lcenter;
        double mirrorFactor = 1.0;
        if (lcount == 0 || lcount == lambdaBins - 1) {
            mirrorFactor = 2.0;
        } else if (lcount < 0) {
            lcount = -lcount;
        } else if (lcount > lambdaBins - 1) {
            // Number of bins past the last bin.
            lcount -= (lambdaBins - 1);
            // Mirror bin.
            lcount = lambdaBins - 1 - lcount;
        }
        for (int iFL = -biasCutoff; iFL <= biasCutoff; iFL++) {
            int FLcenter = FLambdaBin + iFL;
            // Bins outside the F_L range hold no counts.
            if (FLcenter < 0 || FLcenter >= FLambdaBins) {
                continue;
            }
            double deltaFL = dEdLambda - (minFLambda + FLcenter * dFL + dFL_2);
            double deltaFL2 = deltaFL * deltaFL;
            double weight = mirrorFactor * recursionKernel[lcount][FLcenter];
            double bias = weight * biasMag * exp(-deltaL2 / (2.0 * ls2)) * exp(-deltaFL2 / (2.0 * FLs2));
            biasEnergy += bias;
            dGdLambda -= deltaL / ls2 * bias;
            dGdFLambda -= deltaFL / FLs2 * bias;
        }
    }

    // Lambda gradient due to recursion kernel G(L, F_L).
    dEdLambda += dGdLambda + dGdFLambda * d2EdLambda2;

    // Cartesian coordinate gradient due to recursion kernel G(L, F_L).
    fill(dUdXdL, 0.0);
    lambdaInterface.getdEdXdL(dUdXdL);
    for (int i = 0; i < nVariables; i++) {
        gradient[i] += dGdFLambda * dUdXdL[i];
    }

    if (propagateLambda && energyCount > 0) {
        // Update free energy F(L) every ~10 steps.
        if (energyCount % 10 == 0) {
            fLambdaUpdates++;
            boolean printFLambda = fLambdaUpdates % fLambdaPrintInterval == 0;
            totalFreeEnergy = updateFLambda(printFLambda);
            // Accumulate a block average and standard deviation of the total free energy.
            totalAverage += totalFreeEnergy;
            totalSquare += Math.pow(totalFreeEnergy, 2);
            periodCount++;
            // Report only once exactly 'window' samples have been summed, so that
            // dividing by window gives the true mean of the block.
            if (periodCount == window) {
                double average = totalAverage / window;
                double stdev = Math.sqrt((totalSquare - Math.pow(totalAverage, 2) / window) / window);
                logger.info(String.format(" The running average is %12.4f kcal/mol and the stdev is %8.4f kcal/mol.", average, stdev));
                totalAverage = 0;
                totalSquare = 0;
                periodCount = 0;
            }
        }
        if (energyCount % saveFrequency == 0) {
            if (algorithmListener != null) {
                algorithmListener.algorithmUpdate(lambdaOneAssembly);
            }
            // Only the rank 0 process writes the histogram restart file.
            if (rank == 0) {
                // The BufferedWriter is the managed resource, so the file handle is
                // released even if writing fails part way through.
                try (BufferedWriter histogramWriter = new BufferedWriter(new FileWriter(histogramFile))) {
                    OSRWHistogramWriter osrwHistogramRestart = new OSRWHistogramWriter(histogramWriter);
                    osrwHistogramRestart.writeHistogramFile();
                    osrwHistogramRestart.flush();
                    osrwHistogramRestart.close();
                    logger.info(String.format(" Wrote OSRW histogram restart file to %s.", histogramFile.getName()));
                } catch (IOException ex) {
                    String message = " Exception writing OSRW histogram restart file.";
                    logger.log(Level.INFO, message, ex);
                }
            }
            // All ranks write a lambda restart file.
            try (BufferedWriter lambdaWriter = new BufferedWriter(new FileWriter(lambdaFile))) {
                OSRWLambdaWriter osrwLambdaRestart = new OSRWLambdaWriter(lambdaWriter);
                osrwLambdaRestart.writeLambdaFile();
                osrwLambdaRestart.flush();
                osrwLambdaRestart.close();
                logger.info(String.format(" Wrote OSRW lambda restart file to %s.", lambdaFile.getName()));
            } catch (IOException ex) {
                String message = " Exception writing OSRW lambda restart file.";
                logger.log(Level.INFO, message, ex);
            }
        }

        // Write out snapshot upon each full lambda traversal.
        if (writeTraversalSnapshots) {
            double heldTraversalLambda = 0.5;
            if (!traversalInHand.isEmpty()) {
                heldTraversalLambda = Double.parseDouble(traversalInHand.get(0).split(",")[0]);
                if ((lambda > 0.2 && traversalSnapshotTarget == 0) || (lambda < 0.8 && traversalSnapshotTarget == 1)) {
                    int snapshotCounts = Integer.parseInt(traversalInHand.get(0).split(",")[1]);
                    traversalInHand.remove(0);
                    File fileToWrite;
                    int numStructures;
                    if (traversalSnapshotTarget == 0) {
                        fileToWrite = lambdaZeroFile;
                        numStructures = ++lambdaZeroStructures;
                    } else {
                        fileToWrite = lambdaOneFile;
                        numStructures = ++lambdaOneStructures;
                    }
                    // Append the held structure as a MODEL/ENDMDL record; the writer is
                    // always closed, even if a write fails.
                    try (BufferedWriter bw = new BufferedWriter(new FileWriter(fileToWrite, true))) {
                        bw.write(String.format("MODEL %d L=%.4f counts=%d", numStructures, heldTraversalLambda, snapshotCounts));
                        for (int i = 0; i < 50; i++) {
                            bw.write(" ");
                        }
                        bw.newLine();
                        for (String line : traversalInHand) {
                            bw.write(line);
                            bw.newLine();
                        }
                        bw.write("ENDMDL");
                        for (int i = 0; i < 75; i++) {
                            bw.write(" ");
                        }
                        bw.newLine();
                        bw.flush();
                        logger.info(String.format(" Wrote traversal structure L=%.4f", heldTraversalLambda));
                    } catch (Exception exception) {
                        logger.log(Level.WARNING, String.format("Exception writing to file: %s", fileToWrite.getName()), exception);
                    }
                    heldTraversalLambda = 0.5;
                    traversalInHand.clear();
                    traversalSnapshotTarget = 1 - traversalSnapshotTarget;
                }
            }
            if (((lambda < 0.1 && traversalInHand.isEmpty()) || (lambda < heldTraversalLambda - 0.025 && !traversalInHand.isEmpty())) && (traversalSnapshotTarget == 0 || traversalSnapshotTarget == -1)) {
                if (lambdaZeroFilter == null) {
                    lambdaZeroFilter = new PDBFilter(lambdaZeroFile, lambdaZeroAssembly, null, null);
                    lambdaZeroFilter.setListMode(true);
                }
                lambdaZeroFilter.clearListOutput();
                lambdaZeroFilter.writeFileWithHeader(lambdaFile, String.format("%.4f,%d,", lambda, totalCounts));
                traversalInHand = lambdaZeroFilter.getListOutput();
                traversalSnapshotTarget = 0;
            } else if (((lambda > 0.9 && traversalInHand.isEmpty()) || (lambda > heldTraversalLambda + 0.025 && !traversalInHand.isEmpty())) && (traversalSnapshotTarget == 1 || traversalSnapshotTarget == -1)) {
                if (lambdaOneFilter == null) {
                    lambdaOneFilter = new PDBFilter(lambdaOneFile, lambdaOneAssembly, null, null);
                    lambdaOneFilter.setListMode(true);
                }
                lambdaOneFilter.clearListOutput();
                lambdaOneFilter.writeFileWithHeader(lambdaFile, String.format("%.4f,%d,", lambda, totalCounts));
                traversalInHand = lambdaOneFilter.getListOutput();
                traversalSnapshotTarget = 1;
            }
        }
    }

    // Compute the energy and gradient for the recursion slave at F(L) using interpolation.
    double freeEnergy = currentFreeEnergy();
    biasEnergy += freeEnergy;
    if (print) {
        logger.info(String.format(" %s %16.8f", "Bias Energy       ", biasEnergy));
        logger.info(String.format(" %s %16.8f  %s", "OSRW Potential    ", e + biasEnergy, "(Kcal/mole)"));
    }

    if (propagateLambda && energyCount > 0) {
        // Log the current Lambda state.
        if (energyCount % printFrequency == 0) {
            if (lambdaBins < 1000) {
                logger.info(String.format(" L=%6.4f (%3d) F_LU=%10.4f F_LB=%10.4f F_L=%10.4f", lambda, lambdaBin, dEdU, dEdLambda - dEdU, dEdLambda));
            } else {
                logger.info(String.format(" L=%6.4f (%4d) F_LU=%10.4f F_LB=%10.4f F_L=%10.4f", lambda, lambdaBin, dEdU, dEdLambda - dEdU, dEdLambda));
            }
        }
        // Metadynamics grid counts (every 'countInterval' steps).
        if (energyCount % countInterval == 0) {
            if (jobBackend != null) {
                if (world.size() > 1) {
                    jobBackend.setComment(String.format("Overall dG=%10.4f at %7.3e psec, Current: [L=%6.4f, F_L=%10.4f, dG=%10.4f] at %7.3e psec", totalFreeEnergy, totalCounts * dt * countInterval, lambda, dEdU, -freeEnergy, energyCount * dt));
                } else {
                    jobBackend.setComment(String.format("Overall dG=%10.4f at %7.3e psec, Current: [L=%6.4f, F_L=%10.4f, dG=%10.4f]", totalFreeEnergy, totalCounts * dt * countInterval, lambda, dEdU, -freeEnergy));
                }
            }
            if (asynchronous) {
                asynchronousSend(lambda, dEdU);
            } else {
                synchronousSend(lambda, dEdU);
            }
        }
    }

    // Propagate the Lambda particle.
    if (propagateLambda) {
        langevin();
    } else {
        equilibrationCounts++;
        if (jobBackend != null) {
            jobBackend.setComment(String.format("Equilibration [L=%6.4f, F_L=%10.4f]", lambda, dEdU));
        }
        if (equilibrationCounts % 10 == 0) {
            logger.info(String.format(" L=%6.4f, F_L=%10.4f", lambda, dEdU));
        }
    }
    totalEnergy = e + biasEnergy;
    return totalEnergy;
}
Use of ffx.algorithms.Minimize in project ffx by mjschnie.
The class ModelingShell, method minimize.
/**
 * Minimizes the active molecular assembly.
 * <p>
 * Minimization is skipped, and {@code null} returned, in three cases: the
 * shell has been interrupted, another algorithm is already running, or there
 * is no active system. While the minimization runs, it is registered as the
 * shell's terminatable algorithm. That registration is always cleared when the
 * minimization finishes, even if it throws, so later algorithms are not
 * blocked.
 *
 * @param eps the convergence tolerance passed to {@code Minimize.minimize}
 * (presumably an RMS gradient criterion; confirm against Minimize).
 * @return the {@link ffx.numerics.Potential} returned by the minimization,
 * or {@code null} if minimization was skipped.
 */
public Potential minimize(double eps) {
    if (interrupted) {
        logger.info(" Algorithm interrupted - skipping minimization.");
        return null;
    }
    if (terminatableAlgorithm != null) {
        logger.info(" Algorithm already running - skipping minimization.");
        return null;
    }
    MolecularAssembly active = mainPanel.getHierarchy().getActive();
    if (active == null) {
        logger.info(" No active system to minimize.");
        return null;
    }
    Minimize minimize = new Minimize(active, this);
    terminatableAlgorithm = minimize;
    try {
        return minimize.minimize(eps);
    } finally {
        // Clear the registration even on failure; otherwise every later call
        // would be refused as "already running".
        terminatableAlgorithm = null;
    }
}
Aggregations