use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.
the class LBFGS method lbfgs.
/**
 * Builds a candidate search direction from the current measurement's delta and the
 * differences between consecutive history entries, using the L-BFGS two-loop recursion.
 * The candidate is copied into {@code direction} only when its dot product with
 * {@code measurement.delta} is negative.
 *
 * <p>Any failure (too little history, a zero {@code s.y} product, a non-finite intermediate
 * value, or any other throwable) is logged to the monitor and reported as {@code false};
 * nothing is propagated to the caller.
 *
 * @param measurement the current point; its {@code delta} seeds the recursion
 * @param monitor     receives the accept/reject/error log lines
 * @param history     previously measured points, consumed pairwise as (i, i + 1)
 * @param direction   receives the candidate when it is accepted
 * @return {@code true} if the candidate was accepted and copied into {@code direction},
 *         {@code false} if it was rejected or an error occurred
 */
private boolean lbfgs(@Nonnull PointSample measurement, @Nonnull TrainingMonitor monitor, @Nonnull List<PointSample> history, @Nonnull DeltaSet<Layer> direction) {
  try {
    @Nonnull DeltaSet<Layer> p = measurement.delta.copy();
    requireFinite(p);
    final int newest = history.size() - 1;
    if (newest < 1) {
      // At least one (previous, current) pair is needed; fail with a clear message instead of
      // relying on an index-out-of-bounds exception from history.get(-1).
      throw new IllegalStateException("Insufficient history: " + history.size());
    }
    @Nonnull final double[] alphas = new double[history.size()];
    // First loop: walk the history pairs from newest to oldest.
    for (int i = newest - 1; i >= 0; i--) {
      @Nonnull final DeltaSet<Layer> sd = history.get(i + 1).weights.subtract(history.get(i).weights);
      @Nonnull final DeltaSet<Layer> yd = history.get(i + 1).delta.subtract(history.get(i).delta);
      final double denominator = sd.dot(yd);
      if (0 == denominator) {
        throw new IllegalStateException("Orientation vanished.");
      }
      alphas[i] = p.dot(sd) / denominator;
      p = p.subtract(yd.scale(alphas[i]));
      requireFinite(p);
    }
    // Scale by (s.y)/(y.y) of the most recent pair; a zero y.y shows up as a non-finite value.
    @Nonnull final DeltaSet<Layer> sk = history.get(newest).weights.subtract(history.get(newest - 1).weights);
    @Nonnull final DeltaSet<Layer> yk = history.get(newest).delta.subtract(history.get(newest - 1).delta);
    p = p.scale(sk.dot(yk) / yk.dot(yk));
    requireFinite(p);
    // Second loop: walk the same pairs from oldest to newest.
    for (int i = 0; i < newest; i++) {
      @Nonnull final DeltaSet<Layer> sd = history.get(i + 1).weights.subtract(history.get(i).weights);
      @Nonnull final DeltaSet<Layer> yd = history.get(i + 1).delta.subtract(history.get(i).delta);
      final double beta = p.dot(yd) / sd.dot(yd);
      p = p.add(sd.scale(alphas[i] - beta));
      requireFinite(p);
    }
    boolean accept = measurement.delta.dot(p) < 0;
    if (accept) {
      monitor.log("Accepted: " + new Stats(direction, p));
      copy(p, direction);
    } else {
      monitor.log("Rejected: " + new Stats(direction, p));
    }
    return accept;
  } catch (Throwable e) {
    // Deliberate best-effort: every failure is reported as "no L-BFGS direction available".
    monitor.log(String.format("LBFGS Orientation Error: %s", e.getMessage()));
    return false;
  }
}

/**
 * Throws if any component of any delta in the set is NaN or infinite.
 *
 * @param deltas the delta set to validate
 * @throws IllegalStateException if a non-finite component is found
 */
private void requireFinite(@Nonnull DeltaSet<Layer> deltas) {
  if (!deltas.stream().parallel().allMatch(y -> Arrays.stream(y.getDelta()).allMatch(d -> Double.isFinite(d)))) {
    throw new IllegalStateException("Non-finite value");
  }
}
use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.
the class IterativeTrainer method measure.
/**
 * Measures the training subject, retrying while the measured mean is not finite.
 *
 * <p>When {@code reset} is true, each attempt first resets the orientation, shuffles every
 * {@link StochasticComponent} layer of a {@link DAGNetwork} subject and reseeds the subject.
 * A non-finite measurement is retried up to 10 times before giving up.
 *
 * @param reset whether to reset the orientation and reseed the subject before each attempt
 * @return the first measurement whose mean is finite
 * @throws IterativeStopException if the subject cannot be reseeded on a retry, or if no
 *                                finite measurement is obtained within the retry budget
 */
@Nullable
public PointSample measure(boolean reset) {
  @Nullable PointSample currentPoint = null;
  int retries = 0;
  do {
    if (reset) {
      orientation.reset();
      if (subject.getLayer() instanceof DAGNetwork) {
        ((DAGNetwork) subject.getLayer()).visitLayers(layer -> {
          if (layer instanceof StochasticComponent) {
            ((StochasticComponent) layer).shuffle(StochasticComponent.random.get().nextLong());
          }
        });
      }
      if (!subject.reseed(System.nanoTime())) {
        // A failed reseed is tolerated on the first attempt only; on a retry the subject
        // cannot be reseeded, so measuring again is pointless.
        if (retries > 0) {
          if (null != currentPoint) {
            // Release the rejected measurement before aborting.
            currentPoint.freeRef();
          }
          throw new IterativeStopException("Failed to reset training subject");
        }
      } else {
        monitor.log("Reset training subject");
      }
    }
    if (null != currentPoint) {
      currentPoint.freeRef();
    }
    currentPoint = subject.measure(monitor);
    // Retry while the result is unusable and the budget is not spent. The previous condition
    // (10 < retries++) was false on the first pass, so no retry ever happened.
  } while (!Double.isFinite(currentPoint.getMean()) && retries++ < 10);
  if (!Double.isFinite(currentPoint.getMean())) {
    currentPoint.freeRef();
    throw new IterativeStopException();
  }
  return currentPoint;
}
use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.
the class IterativeTrainer method run.
/**
 * Runs the training loop and returns the final mean.
 *
 * <p>The loop ends when the timeout elapses, the current mean is no longer above
 * {@code terminateThreshold}, {@code maxIterations} is exceeded, or an iteration fails to
 * improve the mean and the subject cannot be reseeded. A non-positive
 * {@code iterationsPerSample} places no limit on the number of sub-iterations.
 *
 * @return the mean of the last point, or {@code NaN} if there is no current point
 */
public double run() {
  final long timeoutMs = System.currentTimeMillis() + timeout.toMillis();
  long lastIterationTime = System.nanoTime();
  @Nullable PointSample currentPoint = measure(true);
  mainLoop:
  while (timeoutMs > System.currentTimeMillis() && currentPoint.getMean() > terminateThreshold) {
    if (currentIteration.get() > maxIterations) {
      break;
    }
    currentPoint.freeRef();
    currentPoint = measure(true);
    assert 0 < currentPoint.delta.getMap().size() : "Nothing to optimize";
    subiterationLoop:
    for (int subiteration = 0; subiteration < iterationsPerSample || iterationsPerSample <= 0; subiteration++) {
      if (timeoutMs < System.currentTimeMillis()) {
        break mainLoop;
      }
      if (currentIteration.incrementAndGet() > maxIterations) {
        break mainLoop;
      }
      currentPoint.freeRef();
      currentPoint = measure(true);
      @Nullable final PointSample _currentPoint = currentPoint;
      @Nonnull final TimedResult<LineSearchCursor> timedOrientation = TimedResult.time(() -> orientation.orient(subject, _currentPoint, monitor));
      final LineSearchCursor direction = timedOrientation.result;
      final CharSequence directionType = direction.getDirectionType();
      // Keep an extra reference to the starting point so it can be compared with the
      // line-search result after currentPoint has been replaced.
      @Nullable final PointSample previous = currentPoint;
      previous.addRef();
      try {
        @Nonnull final TimedResult<PointSample> timedLineSearch = TimedResult.time(() -> step(direction, directionType, previous));
        currentPoint.freeRef();
        currentPoint = timedLineSearch.result;
        final long now = System.nanoTime();
        final CharSequence perfString = String.format("Total: %.4f; Orientation: %.4f; Line Search: %.4f", (now - lastIterationTime) / 1e9, timedOrientation.timeNanos / 1e9, timedLineSearch.timeNanos / 1e9);
        lastIterationTime = now;
        monitor.log(String.format("Fitness changed from %s to %s", previous.getMean(), currentPoint.getMean()));
        if (previous.getMean() <= currentPoint.getMean()) {
          if (previous.getMean() < currentPoint.getMean()) {
            // The mean got worse: take a zero-length step along the same cursor instead.
            monitor.log(String.format("Resetting Iteration %s", perfString));
            currentPoint.freeRef();
            currentPoint = direction.step(0, monitor).point;
          } else {
            monitor.log(String.format("Static Iteration %s", perfString));
          }
          if (subject.reseed(System.nanoTime())) {
            monitor.log(String.format("Iteration %s failed, retrying. Error: %s", currentIteration.get(), currentPoint.getMean()));
            monitor.log(String.format("Previous Error: %s -> %s", previous.getRate(), previous.getMean()));
            break subiterationLoop;
          } else {
            monitor.log(String.format("Iteration %s failed, aborting. Error: %s", currentIteration.get(), currentPoint.getMean()));
            monitor.log(String.format("Previous Error: %s -> %s", previous.getRate(), previous.getMean()));
            break mainLoop;
          }
        } else {
          // perfString is passed as an argument rather than spliced into the format string.
          monitor.log(String.format("Iteration %s complete. Error: %s %s", currentIteration.get(), currentPoint.getMean(), perfString));
        }
        monitor.onStepComplete(new Step(currentPoint, currentIteration.get()));
      } finally {
        // Runs on normal completion, on both labelled breaks and on exceptions.
        previous.freeRef();
        direction.freeRef();
      }
    }
  }
  if (subject.getLayer() instanceof DAGNetwork) {
    // Ask every stochastic layer of the network to clear its noise.
    ((DAGNetwork) subject.getLayer()).visitLayers(layer -> {
      if (layer instanceof StochasticComponent) {
        ((StochasticComponent) layer).clearNoise();
      }
    });
  }
  if (null == currentPoint) {
    return Double.NaN;
  }
  // Only release the point when there is one; previously freeRef() ran even on the null path.
  final double mean = currentPoint.getMean();
  currentPoint.freeRef();
  return mean;
}
use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.
the class RoundRobinTrainer method run.
/**
 * Runs the round-robin training loop and returns the final error sum.
 *
 * <p>Each sub-iteration applies every configured orientation strategy once, in order, each
 * with its own line-search strategy (created on first use and cached by direction type plus
 * the orientation's identity hash). If a full pass over the orientations does not lower the
 * sum, the subject is reseeded and a new sample is measured; if reseeding fails, the run
 * ends. The run also ends on timeout, when the sum is no longer above
 * {@code terminateThreshold}, or when {@code maxIterations} is exceeded.
 *
 * @return the sum of the last point, or {@code NaN} if there is no current point
 */
public double run() {
  final long timeoutMs = System.currentTimeMillis() + timeout.toMillis();
  PointSample currentPoint = measure();
  mainLoop:
  while (timeoutMs > System.currentTimeMillis() && currentPoint.sum > terminateThreshold) {
    if (currentIteration.get() > maxIterations) {
      break;
    }
    currentPoint = measure();
    subiterationLoop:
    for (int subiteration = 0; subiteration < iterationsPerSample; subiteration++) {
      final PointSample previousOrientations = currentPoint;
      for (@Nonnull final OrientationStrategy<?> orientation : orientations) {
        if (currentIteration.incrementAndGet() > maxIterations) {
          // The iteration budget is spent: stop the whole run. Leaving only this inner loop
          // would fall through to the no-improvement check below, which logs a spurious
          // "MacroIteration failed" and reseeds the subject before the run ends anyway.
          break mainLoop;
        }
        final LineSearchCursor direction = orientation.orient(subject, currentPoint, monitor);
        @Nonnull final CharSequence directionType = direction.getDirectionType() + "+" + Long.toHexString(System.identityHashCode(orientation));
        LineSearchStrategy lineSearchStrategy;
        if (lineSearchStrategyMap.containsKey(directionType)) {
          lineSearchStrategy = lineSearchStrategyMap.get(directionType);
        } else {
          log.info(String.format("Constructing line search parameters: %s", directionType));
          lineSearchStrategy = lineSearchFactory.apply(directionType);
          lineSearchStrategyMap.put(directionType, lineSearchStrategy);
        }
        final PointSample previous = currentPoint;
        currentPoint = lineSearchStrategy.step(direction, monitor);
        monitor.onStepComplete(new Step(currentPoint, currentIteration.get()));
        if (previous.sum == currentPoint.sum) {
          monitor.log(String.format("Iteration %s failed, ignoring. Error: %s", currentIteration.get(), currentPoint.sum));
        } else {
          monitor.log(String.format("Iteration %s complete. Error: %s", currentIteration.get(), currentPoint.sum));
        }
      }
      // A full pass over all orientations that did not lower the sum counts as a failure.
      if (previousOrientations.sum <= currentPoint.sum) {
        if (subject.reseed(System.nanoTime())) {
          monitor.log(String.format("MacroIteration %s failed, retrying. Error: %s", currentIteration.get(), currentPoint.sum));
          break subiterationLoop;
        } else {
          monitor.log(String.format("MacroIteration %s failed, aborting. Error: %s", currentIteration.get(), currentPoint.sum));
          break mainLoop;
        }
      }
    }
  }
  return null == currentPoint ? Double.NaN : currentPoint.sum;
}
use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.
the class RoundRobinTrainer method measure.
/**
 * Reseeds and measures the training subject until the measured sum is finite.
 *
 * <p>The subject is reseeded before every attempt. A failed reseed is tolerated on the first
 * attempt only, and the attempt counter is capped: once it exceeds 10 the method gives up.
 *
 * @return the first measurement whose sum is finite
 * @throws IterativeStopException if reseeding fails on a retry or the attempt cap is exceeded
 */
public PointSample measure() {
  for (int attempt = 0; ; attempt++) {
    final boolean reseeded = subject.reseed(System.nanoTime());
    if (attempt > 0 && !reseeded) {
      throw new IterativeStopException();
    }
    if (attempt > 10) {
      throw new IterativeStopException();
    }
    final PointSample sample = subject.measure(monitor);
    if (Double.isFinite(sample.sum)) {
      assert Double.isFinite(sample.sum);
      return sample;
    }
  }
}
Aggregations