Search in sources :

Example 26 with PointSample

use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.

the class LayerRateDiagnosticTrainer method run.

/**
 * Runs the per-layer rate diagnostic and returns the collected layer statistics.
 *
 * <p>The outer loop repeats while the timeout has not elapsed, the measured sum is above
 * {@code terminateThreshold}, and {@code maxIterations} has not been exceeded. Every
 * sub-iteration does two things:
 * <ol>
 *   <li>probes the orientation cursor at rate 0 and at a very small rate, and logs a
 *       per-layer step-size estimate derived from the two resulting deltas;</li>
 *   <li>runs the line search once per layer along the orientation direction filtered to
 *       that layer, and records a {@link LayerStats} entry for the layer.</li>
 * </ol>
 *
 * <p>In strict mode every per-layer step is reverted after it is measured, and only the
 * step that produced the lowest sum is re-applied once all layers have been tried.
 *
 * @return the layer-to-stats map obtained from {@code getLayerRates()}
 */
@Nonnull
public Map<Layer, LayerStats> run() {
    final long timeoutMs = System.currentTimeMillis() + timeout.toMillis();
    PointSample measure = measure();
    @Nonnull final ArrayList<Layer> layers = new ArrayList<>(measure.weights.getMap().keySet());
    while (timeoutMs > System.currentTimeMillis() && measure.sum > terminateThreshold) {
        if (currentIteration.get() > maxIterations) {
            break;
        }
        // Reference point for this phase; layer improvements are reported relative to it.
        final PointSample initialPhasePoint = measure();
        measure = initialPhasePoint;
        for (int subiteration = 0; subiteration < iterationsPerSample; subiteration++) {
            if (currentIteration.incrementAndGet() > maxIterations) {
                break;
            }
            {
                // Diagnostic only: results of this scope are logged, never applied.
                @Nonnull final SimpleLineSearchCursor orient = (SimpleLineSearchCursor) getOrientation().orient(subject, measure, monitor);
                // Tiny probe rate, scaled by the sum at the cursor origin.
                final double stepSize = 1e-12 * orient.origin.sum;
                @Nonnull final DeltaSet<Layer> pointB = orient.step(stepSize, monitor).point.delta.copy();
                // Stepping to rate 0 after the probe also leaves the cursor at its origin rate.
                @Nonnull final DeltaSet<Layer> pointA = orient.step(0.0, monitor).point.delta.copy();
                @Nonnull final DeltaSet<Layer> d1 = pointA;
                // Difference of the two deltas divided by the probe rate.
                @Nonnull final DeltaSet<Layer> d2 = d1.add(pointB.scale(-1)).scale(1.0 / stepSize);
                @Nonnull final Map<Layer, Double> steps = new HashMap<>();
                final double overallStepEstimate = d1.getMagnitude() / d2.getMagnitude();
                for (final Layer layer : layers) {
                    final DoubleBuffer<Layer> a = d2.get(layer, (double[]) null);
                    final DoubleBuffer<Layer> b = d1.get(layer, (double[]) null);
                    final double bmag = Math.sqrt(b.deltaStatistics().sumSq());
                    final double amag = Math.sqrt(a.deltaStatistics().sumSq());
                    // Normalized dot product of the two buffers; NaN when either magnitude is 0.
                    final double dot = a.dot(b) / (amag * bmag);
                    final double idealSize = bmag / (amag * dot);
                    steps.put(layer, idealSize);
                    monitor.log(String.format("Layers stats: %s (%s, %s, %s) => %s", layer, amag, bmag, dot, idealSize));
                }
                monitor.log(String.format("Estimated ideal rates for layers: %s (%s overall; probed at %s)", steps, overallStepEstimate, stepSize));
            }
            @Nullable SimpleLineSearchCursor bestOrient = null;
            @Nullable PointSample bestPoint = null;
            layerLoop: for (@Nonnull final Layer layer : layers) {
                @Nonnull SimpleLineSearchCursor orient = (SimpleLineSearchCursor) getOrientation().orient(subject, measure, monitor);
                @Nonnull final DeltaSet<Layer> direction = filterDirection(orient.direction, layer);
                if (direction.getMagnitude() == 0) {
                    monitor.log(String.format("Zero derivative for layer %s; skipping", layer));
                    continue layerLoop;
                }
                // Restrict the search to the current layer's part of the direction.
                orient = new SimpleLineSearchCursor(orient.subject, orient.origin, direction);
                final PointSample previous = measure;
                measure = getLineSearchStrategy().step(orient, monitor);
                if (isStrict()) {
                    monitor.log(String.format("Iteration %s reverting. Error: %s", currentIteration.get(), measure.sum));
                    monitor.log(String.format("Optimal rate for layer %s: %s", layer.getName(), measure.getRate()));
                    // Lower sum is better (the outer loop runs until the sum falls to the
                    // threshold), so keep the candidate with the smallest sum seen so far.
                    if (null == bestPoint || bestPoint.sum > measure.sum) {
                        bestOrient = orient;
                        bestPoint = measure;
                    }
                    getLayerRates().put(layer, new LayerStats(measure.getRate(), initialPhasePoint.sum - measure.sum));
                    // Undo this layer's step so the next layer is measured from the same point.
                    orient.step(0, monitor);
                    measure = previous;
                } else if (previous.sum == measure.sum) {
                    monitor.log(String.format("Iteration %s failed. Error: %s", currentIteration.get(), measure.sum));
                } else {
                    monitor.log(String.format("Iteration %s complete. Error: %s", currentIteration.get(), measure.sum));
                    monitor.log(String.format("Optimal rate for layer %s: %s", layer.getName(), measure.getRate()));
                    getLayerRates().put(layer, new LayerStats(measure.getRate(), initialPhasePoint.sum - measure.sum));
                }
            }
            monitor.log(String.format("Ideal rates: %s", getLayerRates()));
            // bestPoint and bestOrient are always assigned together, so one null check covers both.
            if (null != bestPoint) {
                bestOrient.step(bestPoint.rate, monitor);
            }
            monitor.onStepComplete(new Step(measure, currentIteration.get()));
        }
    }
    return getLayerRates();
}
Also used : DoubleBuffer(com.simiacryptus.mindseye.lang.DoubleBuffer) Nonnull(javax.annotation.Nonnull) ArrayList(java.util.ArrayList) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) Layer(com.simiacryptus.mindseye.lang.Layer) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) PointSample(com.simiacryptus.mindseye.lang.PointSample) HashMap(java.util.HashMap) Map(java.util.Map) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull)

Example 27 with PointSample

use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.

the class ValidatingTrainer method runPhase.

/**
 * Runs one training phase: resets and measures the phase, then repeatedly calls
 * {@code runStep} until the iteration budget is used up, a halt is requested, or a step
 * fails to lower the mean.
 *
 * @param epochParams supplies the training size, the iteration limit (a value of 0 or less
 *                    means no limit) and the timeout passed to {@code shouldHalt}
 * @param phase       the phase to reset, measure and step
 * @param i           the phase number; used only in the log message
 * @param seed        the seed passed to {@code reset}
 * @return an {@code EpochResult} built with {@code true} when the loop ran to completion,
 *         or with {@code false} when it stopped early (halt requested or non-improving step)
 */
@Nonnull
protected EpochResult runPhase(@Nonnull final EpochParams epochParams, @Nonnull final TrainingPhase phase, final int i, final long seed) {
    monitor.log(String.format("Phase %d: %s", i, phase));
    phase.trainingSubject.setTrainingSize(epochParams.trainingSize);
    monitor.log(String.format("resetAndMeasure; trainingSize=%s", epochParams.trainingSize));
    PointSample currentPoint = reset(phase, seed).measure(phase);
    // Mean of the freshly measured starting point; reported in every EpochResult.
    final double startingMean = currentPoint.getMean();
    assert 0 < currentPoint.delta.getMap().size() : "Nothing to optimize";
    int step = 1;
    while (epochParams.iterations <= 0 || step <= epochParams.iterations) {
        if (shouldHalt(monitor, epochParams.timeoutMs)) {
            return new EpochResult(false, startingMean, currentPoint, step);
        }
        // Bracket the step with wall-clock and cumulative GC time readings.
        final long wallStart = System.nanoTime();
        final long gcBefore = ManagementFactory.getGarbageCollectorMXBeans().stream().mapToLong(bean -> bean.getCollectionTime()).sum();
        @Nonnull final StepResult epoch = runStep(currentPoint, phase);
        final long gcAfter = ManagementFactory.getGarbageCollectorMXBeans().stream().mapToLong(bean -> bean.getCollectionTime()).sum();
        final long wallEnd = System.nanoTime();
        // Building the summary also resets the accumulated training-measurement timer.
        final CharSequence performance = String.format("%s in %.3f seconds; %.3f in orientation, %.3f in gc, %.3f in line search; %.3f trainAll time",
            epochParams.trainingSize,
            (wallEnd - wallStart) / 1e9,
            epoch.performance[0],
            (gcAfter - gcBefore) / 1e3,
            epoch.performance[1],
            trainingMeasurementTime.getAndSet(0) / 1e9);
        currentPoint = epoch.currentPoint.setRate(0.0);
        // A step that does not strictly lower the mean ends the phase.
        if (epoch.previous.getMean() <= epoch.currentPoint.getMean()) {
            monitor.log(String.format("Iteration %s failed, aborting. Error: %s (%s)", currentIteration.get(), epoch.currentPoint.getMean(), performance));
            return new EpochResult(false, startingMean, currentPoint, step);
        }
        monitor.log(String.format("Iteration %s complete. Error: %s (%s)", currentIteration.get(), epoch.currentPoint.getMean(), performance));
        monitor.onStepComplete(new Step(currentPoint, currentIteration.get()));
        step++;
    }
    return new EpochResult(true, startingMean, currentPoint, step);
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) DoubleStatistics(com.simiacryptus.util.data.DoubleStatistics) TemporalUnit(java.time.temporal.TemporalUnit) LineSearchStrategy(com.simiacryptus.mindseye.opt.line.LineSearchStrategy) SampledTrainable(com.simiacryptus.mindseye.eval.SampledTrainable) TrainableBase(com.simiacryptus.mindseye.eval.TrainableBase) DoubleBuffer(com.simiacryptus.mindseye.lang.DoubleBuffer) HashMap(java.util.HashMap) SampledCachedTrainable(com.simiacryptus.mindseye.eval.SampledCachedTrainable) ArmijoWolfeSearch(com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch) Function(java.util.function.Function) StateSet(com.simiacryptus.mindseye.lang.StateSet) ArrayList(java.util.ArrayList) TrainableWrapper(com.simiacryptus.mindseye.eval.TrainableWrapper) Trainable(com.simiacryptus.mindseye.eval.Trainable) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) Duration(java.time.Duration) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) ManagementFactory(java.lang.management.ManagementFactory) FailsafeLineSearchCursor(com.simiacryptus.mindseye.opt.line.FailsafeLineSearchCursor) LineSearchCursor(com.simiacryptus.mindseye.opt.line.LineSearchCursor) Nonnull(javax.annotation.Nonnull) IterativeStopException(com.simiacryptus.mindseye.lang.IterativeStopException) Util(com.simiacryptus.util.Util) StochasticComponent(com.simiacryptus.mindseye.layers.java.StochasticComponent) QQN(com.simiacryptus.mindseye.opt.orient.QQN) OrientationStrategy(com.simiacryptus.mindseye.opt.orient.OrientationStrategy) FastRandom(com.simiacryptus.util.FastRandom) Collectors(java.util.stream.Collectors) TimeUnit(java.util.concurrent.TimeUnit) AtomicLong(java.util.concurrent.atomic.AtomicLong) List(java.util.List) ChronoUnit(java.time.temporal.ChronoUnit) TimedResult(com.simiacryptus.util.lang.TimedResult) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) PointSample(com.simiacryptus.mindseye.lang.PointSample) 
Nonnull(javax.annotation.Nonnull) PointSample(com.simiacryptus.mindseye.lang.PointSample) Nonnull(javax.annotation.Nonnull)

Example 28 with PointSample

use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.

the class QuadraticSearch method filter.

/**
 * Returns the sample to hand back to the caller, applying {@code stepSize} to the rate of
 * the given point.
 *
 * <p>When {@code stepSize} is exactly 1.0 the given point is returned with one extra
 * reference added. Otherwise the cursor is stepped to {@code point.rate * stepSize} and the
 * sample of that step is returned, again with one extra reference; the intermediate
 * {@code LineSearchPoint} is released.
 *
 * @param cursor  the cursor used to take the rescaled step
 * @param point   the sample whose rate is rescaled
 * @param monitor the monitor passed through to the cursor
 * @return the given point or the sample at the rescaled rate
 */
private PointSample filter(@Nonnull final LineSearchCursor cursor, @Nonnull final PointSample point, final TrainingMonitor monitor) {
    if (stepSize != 1.0) {
        final LineSearchPoint rescaledStep = cursor.step(point.rate * stepSize, monitor);
        final PointSample rescaledSample = rescaledStep.point;
        // Take our own reference before releasing the wrapper that holds the sample.
        rescaledSample.addRef();
        rescaledStep.freeRef();
        return rescaledSample;
    }
    point.addRef();
    return point;
}
Also used : PointSample(com.simiacryptus.mindseye.lang.PointSample)

Example 29 with PointSample

use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.

the class QuadraticSearch method step.

/**
 * Performs one line-search step and remembers the rate it ended on.
 *
 * <p>The current rate is first raised to {@code getMinRate()} if it is below it; the search
 * itself is delegated to {@code _step}, and the rate of the returned sample is stored via
 * {@code setCurrentRate}.
 *
 * @param cursor  the cursor to search along
 * @param monitor the monitor passed through to {@code _step}
 * @return the sample produced by {@code _step}
 */
@Override
public PointSample step(@Nonnull final LineSearchCursor cursor, @Nonnull final TrainingMonitor monitor) {
    // Never start a search below the configured minimum rate.
    if (getMinRate() > currentRate) {
        currentRate = getMinRate();
    }
    final PointSample result = _step(cursor, monitor);
    setCurrentRate(result.rate);
    return result;
}
Also used : PointSample(com.simiacryptus.mindseye.lang.PointSample)

Example 30 with PointSample

use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.

the class StaticLearningRate method step.

/**
 * Steps along the cursor at the configured {@code rate}, halving the rate until the value
 * no longer exceeds the starting value (within a small relative tolerance).
 *
 * <p>If the halved rate drops below {@code getMinimumRate()} before an acceptable step is
 * found, the sample at rate 0 is returned instead. The returned sample carries one extra
 * reference; every {@code LineSearchPoint} obtained here is released before returning.
 *
 * @param cursor  the cursor to step along
 * @param monitor the monitor passed to the cursor and used for logging
 * @return the sample of the first acceptable step, or the starting sample if none is found
 */
@Override
public PointSample step(@Nonnull final LineSearchCursor cursor, @Nonnull final TrainingMonitor monitor) {
    double trialRate = rate;
    final LineSearchPoint origin = cursor.step(0, monitor);
    // Value at rate 0; every trial is compared against it.
    final double originValue = origin.point.sum;
    @Nullable LineSearchPoint trial = null;
    for (;;) {
        // Release the rejected trial from the previous pass before taking a new one.
        if (trial != null) {
            trial.freeRef();
        }
        trial = cursor.step(trialRate, monitor);
        final double rawValue = trial.point.sum;
        // Treat NaN and infinities alike as "infinitely bad".
        final double trialValue = Double.isFinite(rawValue) ? rawValue : Double.POSITIVE_INFINITY;
        final boolean notDecreasing = trialValue + originValue * 1e-15 > originValue;
        if (!notDecreasing) {
            final PointSample accepted = trial.point;
            accepted.addRef();
            origin.freeRef();
            trial.freeRef();
            return accepted;
        }
        monitor.log(String.format("Non-decreasing runStep. %s > %s at " + trialRate, trialValue, originValue));
        trialRate /= 2;
        if (trialRate < getMinimumRate()) {
            // Give up: hand back the starting sample unchanged.
            if (trial != null) {
                trial.freeRef();
            }
            final PointSample fallback = origin.point;
            fallback.addRef();
            origin.freeRef();
            return fallback;
        }
    }
}
Also used : PointSample(com.simiacryptus.mindseye.lang.PointSample) Nullable(javax.annotation.Nullable)

Aggregations

PointSample (com.simiacryptus.mindseye.lang.PointSample)33 Nonnull (javax.annotation.Nonnull)24 Layer (com.simiacryptus.mindseye.lang.Layer)16 Nullable (javax.annotation.Nullable)14 DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet)10 TrainingMonitor (com.simiacryptus.mindseye.opt.TrainingMonitor)9 SimpleLineSearchCursor (com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor)9 StateSet (com.simiacryptus.mindseye.lang.StateSet)8 LineSearchCursor (com.simiacryptus.mindseye.opt.line.LineSearchCursor)8 List (java.util.List)8 Trainable (com.simiacryptus.mindseye.eval.Trainable)7 Arrays (java.util.Arrays)7 Collectors (java.util.stream.Collectors)7 IterativeStopException (com.simiacryptus.mindseye.lang.IterativeStopException)6 Map (java.util.Map)6 IntStream (java.util.stream.IntStream)6 DoubleBuffer (com.simiacryptus.mindseye.lang.DoubleBuffer)5 PlaceholderLayer (com.simiacryptus.mindseye.layers.java.PlaceholderLayer)5 FailsafeLineSearchCursor (com.simiacryptus.mindseye.opt.line.FailsafeLineSearchCursor)5 LineSearchStrategy (com.simiacryptus.mindseye.opt.line.LineSearchStrategy)5