use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.
the class LayerRateDiagnosticTrainer method run.
/**
 * Probes each layer individually to estimate a per-layer learning rate.
 *
 * <p>The outer loop runs until the timeout deadline passes, the measured sum
 * drops to {@code terminateThreshold} or below, or the iteration counter
 * exceeds {@code maxIterations}. Each sub-iteration does two things:
 * <ol>
 * <li>Takes a tiny probe step along the full orientation and logs, per layer,
 * a finite-difference estimate of an "ideal" step size.</li>
 * <li>Runs the line search once per layer along the direction restricted to
 * that layer (via {@code filterDirection}) and records the resulting rate
 * and the change in sum in {@code getLayerRates()}.</li>
 * </ol>
 * In strict mode every per-layer step is reverted after it is recorded, and
 * only the single best candidate (lowest resulting sum) is re-applied at the
 * end of the sub-iteration.
 *
 * @return the map returned by {@code getLayerRates()}, populated with one
 * {@code LayerStats} entry per layer that was successfully probed
 */
@Nonnull
public Map<Layer, LayerStats> run() {
  // Absolute wall-clock deadline for the whole diagnostic run.
  final long timeoutMs = System.currentTimeMillis() + timeout.toMillis();
  PointSample measure = measure();
  // Snapshot of the layers present in the first measurement's weight map.
  @Nonnull final ArrayList<Layer> layers = new ArrayList<>(measure.weights.getMap().keySet());
  while (timeoutMs > System.currentTimeMillis() && measure.sum > terminateThreshold) {
    if (currentIteration.get() > maxIterations) {
      break;
    }
    // Baseline for this phase; per-layer improvements are reported relative to its sum.
    final PointSample initialPhasePoint = measure();
    measure = initialPhasePoint;
    for (int subiteration = 0; subiteration < iterationsPerSample; subiteration++) {
      if (currentIteration.incrementAndGet() > maxIterations) {
        break;
      }
      {
        // Finite-difference probe: compare the delta at a very small step with the delta at rate 0.
        @Nonnull final SimpleLineSearchCursor orient = (SimpleLineSearchCursor) getOrientation().orient(subject, measure, monitor);
        final double stepSize = 1e-12 * orient.origin.sum;
        // The rate-0 step is taken second, after the probe step.
        @Nonnull final DeltaSet<Layer> pointB = orient.step(stepSize, monitor).point.delta.copy();
        @Nonnull final DeltaSet<Layer> pointA = orient.step(0.0, monitor).point.delta.copy();
        @Nonnull final DeltaSet<Layer> d1 = pointA;
        // (pointA - pointB) / stepSize: change of the delta per unit step.
        @Nonnull final DeltaSet<Layer> d2 = d1.add(pointB.scale(-1)).scale(1.0 / stepSize);
        @Nonnull final Map<Layer, Double> steps = new HashMap<>();
        final double overallStepEstimate = d1.getMagnitude() / d2.getMagnitude();
        for (final Layer layer : layers) {
          final DoubleBuffer<Layer> a = d2.get(layer, (double[]) null);
          final DoubleBuffer<Layer> b = d1.get(layer, (double[]) null);
          final double bmag = Math.sqrt(b.deltaStatistics().sumSq());
          final double amag = Math.sqrt(a.deltaStatistics().sumSq());
          // Normalized dot product of the two per-layer buffers.
          final double dot = a.dot(b) / (amag * bmag);
          final double idealSize = bmag / (amag * dot);
          steps.put(layer, idealSize);
          monitor.log(String.format("Layers stats: %s (%s, %s, %s) => %s", layer, amag, bmag, dot, idealSize));
        }
        monitor.log(String.format("Estimated ideal rates for layers: %s (%s overall; probed at %s)", steps, overallStepEstimate, stepSize));
      }
      @Nullable SimpleLineSearchCursor bestOrient = null;
      @Nullable PointSample bestPoint = null;
      layerLoop: for (@Nonnull final Layer layer : layers) {
        @Nonnull SimpleLineSearchCursor orient = (SimpleLineSearchCursor) getOrientation().orient(subject, measure, monitor);
        // Restrict the search direction to the current layer only.
        @Nonnull final DeltaSet<Layer> direction = filterDirection(orient.direction, layer);
        if (direction.getMagnitude() == 0) {
          monitor.log(String.format("Zero derivative for layer %s; skipping", layer));
          continue layerLoop;
        }
        orient = new SimpleLineSearchCursor(orient.subject, orient.origin, direction);
        final PointSample previous = measure;
        measure = getLineSearchStrategy().step(orient, monitor);
        if (isStrict()) {
          monitor.log(String.format("Iteration %s reverting. Error: %s", currentIteration.get(), measure.sum));
          monitor.log(String.format("Optimal rate for layer %s: %s", layer.getName(), measure.getRate()));
          // The sum is being minimized (the run stops once it falls to terminateThreshold),
          // so the best candidate is the one with the LOWEST sum. The previous comparison
          // was inverted and retained the worst candidate instead.
          if (null == bestPoint || bestPoint.sum > measure.sum) {
            bestOrient = orient;
            bestPoint = measure;
          }
          getLayerRates().put(layer, new LayerStats(measure.getRate(), initialPhasePoint.sum - measure.sum));
          // Strict mode: undo this layer's step so the next layer is probed from the same point.
          orient.step(0, monitor);
          measure = previous;
        } else if (previous.sum == measure.sum) {
          monitor.log(String.format("Iteration %s failed. Error: %s", currentIteration.get(), measure.sum));
        } else {
          monitor.log(String.format("Iteration %s complete. Error: %s", currentIteration.get(), measure.sum));
          monitor.log(String.format("Optimal rate for layer %s: %s", layer.getName(), measure.getRate()));
          getLayerRates().put(layer, new LayerStats(measure.getRate(), initialPhasePoint.sum - measure.sum));
        }
      }
      monitor.log(String.format("Ideal rates: %s", getLayerRates()));
      // bestPoint is only ever set in strict mode; re-apply the winning layer step.
      if (null != bestPoint) {
        bestOrient.step(bestPoint.rate, monitor);
      }
      monitor.onStepComplete(new Step(measure, currentIteration.get()));
    }
  }
  return getLayerRates();
}
use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.
the class ValidatingTrainer method runPhase.
/**
 * Executes one training phase: resets and measures the phase, then repeatedly
 * calls {@code runStep} until the iteration budget is used up, a halt is
 * requested, or a step fails to lower the mean.
 *
 * <p>An {@code epochParams.iterations} value of zero or less disables the
 * iteration limit, so only the halt check or a non-improving step ends the loop.
 *
 * @param epochParams supplies the training size, iteration limit and timeout
 * @param phase       the phase whose training subject is sized, reset and stepped
 * @param i           phase index; used only in the log message
 * @param seed        seed forwarded to {@code reset}
 * @return an {@code EpochResult} whose first argument is {@code true} only when
 *         the loop ran to its iteration limit, and {@code false} when it was
 *         halted or aborted on a non-improving step
 */
@Nonnull
protected EpochResult runPhase(@Nonnull final EpochParams epochParams, @Nonnull final TrainingPhase phase, final int i, final long seed) {
  monitor.log(String.format("Phase %d: %s", i, phase));
  phase.trainingSubject.setTrainingSize(epochParams.trainingSize);
  monitor.log(String.format("resetAndMeasure; trainingSize=%s", epochParams.trainingSize));
  PointSample currentPoint = reset(phase, seed).measure(phase);
  // Mean of the initial measurement; reported unchanged in every result.
  final double pointMean = currentPoint.getMean();
  assert 0 < currentPoint.delta.getMap().size() : "Nothing to optimize";
  int step = 1;
  while (epochParams.iterations <= 0 || step <= epochParams.iterations) {
    if (shouldHalt(monitor, epochParams.timeoutMs)) {
      return new EpochResult(false, pointMean, currentPoint, step);
    }
    // Bracket the step with wall-clock and cumulative GC-time readings.
    final long wallStart = System.nanoTime();
    final long gcBefore = ManagementFactory.getGarbageCollectorMXBeans().stream().mapToLong(bean -> bean.getCollectionTime()).sum();
    @Nonnull final StepResult outcome = runStep(currentPoint, phase);
    final long gcAfter = ManagementFactory.getGarbageCollectorMXBeans().stream().mapToLong(bean -> bean.getCollectionTime()).sum();
    final long wallEnd = System.nanoTime();
    final double elapsedSeconds = (wallEnd - wallStart) / 1e9;
    final double gcSeconds = (gcAfter - gcBefore) / 1e3;
    final CharSequence timing = String.format("%s in %.3f seconds; %.3f in orientation, %.3f in gc, %.3f in line search; %.3f trainAll time", epochParams.trainingSize, elapsedSeconds, outcome.performance[0], gcSeconds, outcome.performance[1], trainingMeasurementTime.getAndSet(0) / 1e9);
    currentPoint = outcome.currentPoint.setRate(0.0);
    // A step that does not strictly lower the mean ends the phase.
    final boolean notImproved = outcome.previous.getMean() <= outcome.currentPoint.getMean();
    if (notImproved) {
      monitor.log(String.format("Iteration %s failed, aborting. Error: %s (%s)", currentIteration.get(), outcome.currentPoint.getMean(), timing));
      return new EpochResult(false, pointMean, currentPoint, step);
    }
    monitor.log(String.format("Iteration %s complete. Error: %s (%s)", currentIteration.get(), outcome.currentPoint.getMean(), timing));
    monitor.onStepComplete(new Step(currentPoint, currentIteration.get()));
    step++;
  }
  return new EpochResult(true, pointMean, currentPoint, step);
}
use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.
the class QuadraticSearch method filter.
/**
 * Returns the sample to use after applying this search's {@code stepSize} factor.
 *
 * <p>When {@code stepSize} is exactly 1.0 the given point is returned as is;
 * otherwise the cursor is stepped to {@code point.rate * stepSize} and that
 * step's point is returned. In both cases {@code addRef()} is called on the
 * returned sample before it is handed back.
 *
 * @param cursor  cursor used to take the rescaled step
 * @param point   the candidate sample whose rate is rescaled
 * @param monitor monitor forwarded to the cursor
 * @return the original point, or the point measured at the rescaled rate
 */
private PointSample filter(@Nonnull final LineSearchCursor cursor, @Nonnull final PointSample point, final TrainingMonitor monitor) {
  if (stepSize != 1.0) {
    final LineSearchPoint rescaledStep = cursor.step(point.rate * stepSize, monitor);
    final PointSample rescaledPoint = rescaledStep.point;
    // Take a reference on the sample before releasing the step that holds it.
    rescaledPoint.addRef();
    rescaledStep.freeRef();
    return rescaledPoint;
  }
  point.addRef();
  return point;
}
use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.
the class QuadraticSearch method step.
/**
 * Performs one line-search step, first raising {@code currentRate} to
 * {@code getMinRate()} if it is below that value, and afterwards storing the
 * rate of the resulting sample via {@code setCurrentRate}.
 *
 * @param cursor  the line-search cursor passed on to {@code _step}
 * @param monitor the monitor passed on to {@code _step}
 * @return the sample produced by {@code _step}
 */
@Override
public PointSample step(@Nonnull final LineSearchCursor cursor, @Nonnull final TrainingMonitor monitor) {
  final boolean belowMinimum = currentRate < getMinRate();
  if (belowMinimum) {
    currentRate = getMinRate();
  }
  final PointSample result = _step(cursor, monitor);
  // Remember the rate that was reached so the next call starts from it.
  setCurrentRate(result.rate);
  return result;
}
use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.
the class StaticLearningRate method step.
/**
 * Steps along the cursor at the configured {@code rate}, halving the rate
 * each time the resulting sum is not lower than the sum at rate zero.
 *
 * <p>A non-finite sum is treated as positive infinity. If the rate falls
 * below {@code getMinimumRate()} without finding a decrease, the rate-zero
 * point is returned instead. {@code addRef()} is called on the returned
 * sample, and both intermediate line-search points are released.
 *
 * @param cursor  the cursor to step along
 * @param monitor receives a log line for every non-decreasing attempt
 * @return the first sample whose sum decreased, or the rate-zero sample
 */
@Override
public PointSample step(@Nonnull final LineSearchCursor cursor, @Nonnull final TrainingMonitor monitor) {
  double candidateRate = rate;
  final LineSearchPoint origin = cursor.step(0, monitor);
  // Sum at rate zero: the value every attempt has to beat.
  final double baseline = origin.point.sum;
  @Nullable LineSearchPoint attempt = null;
  for (;;) {
    // Release the previous failed attempt before taking a new one.
    if (attempt != null) {
      attempt.freeRef();
    }
    attempt = cursor.step(candidateRate, monitor);
    final double rawValue = attempt.point.sum;
    final double attemptValue = Double.isFinite(rawValue) ? rawValue : Double.POSITIVE_INFINITY;
    final boolean nonDecreasing = attemptValue + baseline * 1e-15 > baseline;
    if (!nonDecreasing) {
      final PointSample accepted = attempt.point;
      accepted.addRef();
      origin.freeRef();
      attempt.freeRef();
      return accepted;
    }
    monitor.log(String.format("Non-decreasing runStep. %s > %s at " + candidateRate, attemptValue, baseline));
    candidateRate /= 2;
    if (candidateRate < getMinimumRate()) {
      // Give up and fall back to the starting point.
      attempt.freeRef();
      final PointSample fallback = origin.point;
      fallback.addRef();
      origin.freeRef();
      return fallback;
    }
  }
}
Aggregations