Search in sources :

Example 11 with PointSample

use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.

the class TensorListTrainable method eval.

/**
 * Runs the network once over the supplied tensor lists and packages the outcome as a
 * normalized {@link PointSample}.
 * <p>
 * Inside a timed section the method builds the network inputs from {@code list} and
 * {@code mask}, evaluates the network, releases the input results, sums every value of
 * the output, accumulates the result into a fresh {@link DeltaSet}, and wraps the delta
 * set, a {@link StateSet} derived from it, the sum and the item count in a point sample.
 * The output data, the result and the delta set are released even if accumulation fails.
 *
 * @param list    the tensor lists handed to {@code getNNContext}
 * @param monitor optional monitor; when present and {@code verbosity() > 0} the elapsed
 *                time is logged
 * @return the normalized point sample
 */
@Nonnull
protected PointSample eval(@Nonnull final TensorList[] list, @Nullable final TrainingMonitor monitor) {
    final int inputs = data.length;
    assert 0 < inputs;
    final int items = data[0].length();
    assert 0 < items;
    @Nonnull final TimedResult<PointSample> timing = TimedResult.time(() -> {
        final Result[] inputResults = TensorListTrainable.getNNContext(list, mask);
        final Result output = network.eval(inputResults);
        // The inputs are no longer needed once the network has been evaluated.
        for (int i = 0; i < inputResults.length; i++) {
            @Nonnull final Result inputResult = inputResults[i];
            inputResult.getData().freeRef();
            inputResult.freeRef();
        }
        final TensorList outputData = output.getData();
        final DoubleSummaryStatistics outputStats = outputData.stream().flatMapToDouble(tensor -> {
            // Copy the values before releasing the tensor so the stream never reads freed data.
            final double[] values = tensor.getData().clone();
            tensor.freeRef();
            return Arrays.stream(values);
        }).summaryStatistics();
        final double sum = outputStats.getSum();
        @Nonnull final DeltaSet<Layer> deltas = new DeltaSet<Layer>();
        try {
            output.accumulate(deltas, 1.0);
            @Nonnull final StateSet<Layer> weights = new StateSet<>(deltas);
            @Nonnull final PointSample sample = new PointSample(deltas, weights, sum, 0.0, items);
            weights.freeRef();
            return sample;
        } finally {
            outputData.freeRef();
            output.freeRef();
            deltas.freeRef();
        }
    });
    if (monitor != null && verbosity() > 0) {
        final double seconds = timing.timeNanos / 1e9;
        monitor.log(String.format("Device completed %s items in %.3f sec", items, seconds));
    }
    @Nonnull final PointSample normalized = timing.result.normalize();
    timing.result.freeRef();
    return normalized;
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) Tensor(com.simiacryptus.mindseye.lang.Tensor) ReferenceCountingBase(com.simiacryptus.mindseye.lang.ReferenceCountingBase) DoubleSummaryStatistics(java.util.DoubleSummaryStatistics) Result(com.simiacryptus.mindseye.lang.Result) StateSet(com.simiacryptus.mindseye.lang.StateSet) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) TensorList(com.simiacryptus.mindseye.lang.TensorList) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) TimedResult(com.simiacryptus.util.lang.TimedResult) Layer(com.simiacryptus.mindseye.lang.Layer) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) Nonnull(javax.annotation.Nonnull) PointSample(com.simiacryptus.mindseye.lang.PointSample) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) TensorList(com.simiacryptus.mindseye.lang.TensorList) DoubleSummaryStatistics(java.util.DoubleSummaryStatistics) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) Layer(com.simiacryptus.mindseye.lang.Layer) Result(com.simiacryptus.mindseye.lang.Result) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) TimedResult(com.simiacryptus.util.lang.TimedResult) PointSample(com.simiacryptus.mindseye.lang.PointSample) StateSet(com.simiacryptus.mindseye.lang.StateSet) Nonnull(javax.annotation.Nonnull)

Example 12 with PointSample

use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.

the class LayerRateDiagnosticTrainer method measure.

/**
 * Measures the subject, reseeding and retrying until the measured sum is finite.
 * <p>
 * The subject is reseeded before every attempt. After the first attempt a failed reseed
 * aborts immediately; the retry budget is exhausted once more than ten retries have
 * already been made.
 *
 * @return the first point sample whose {@code sum} is finite
 * @throws IterativeStopException if a retry cannot reseed the subject or the retry budget
 *                                is exhausted
 */
public PointSample measure() {
    for (int attempt = 0; ; attempt++) {
        // Reseed unconditionally; the outcome only matters once we are retrying.
        final boolean reseeded = subject.reseed(System.nanoTime());
        if (!reseeded && attempt > 0) {
            throw new IterativeStopException();
        }
        if (attempt > 10) {
            throw new IterativeStopException();
        }
        final PointSample sample = subject.measure(monitor);
        if (Double.isFinite(sample.sum)) {
            return sample;
        }
    }
}
Also used : IterativeStopException(com.simiacryptus.mindseye.lang.IterativeStopException) PointSample(com.simiacryptus.mindseye.lang.PointSample)

Example 13 with PointSample

use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.

the class ValidatingTrainer method runStep.

/**
 * Performs one optimization step for the given phase: orient, line-search, validate.
 * <p>
 * The iteration counter is advanced, the phase's orientation strategy produces a search
 * direction, and the line search strategy registered for that direction type (created
 * through the phase's factory and cached on first use) is stepped through a
 * {@link FailsafeLineSearchCursor}. The best point found is restored and returned
 * together with the timings of both stages.
 *
 * @param previousPoint the point the step starts from
 * @param phase         the training phase supplying subject, orientation and line search
 * @return the step result holding the previous point, the best point and the
 *         orientation / line-search durations in seconds
 * @throws IllegalStateException if the best point's mean is greater than the previous
 *                               point's mean
 */
@Nonnull
protected StepResult runStep(@Nonnull final PointSample previousPoint, @Nonnull final TrainingPhase phase) {
    currentIteration.incrementAndGet();
    @Nonnull final TimedResult<LineSearchCursor> orientTiming = TimedResult.time(() -> phase.orientation.orient(phase.trainingSubject, previousPoint, monitor));
    final LineSearchCursor direction = orientTiming.result;
    final CharSequence directionType = direction.getDirectionType();
    final LineSearchStrategy strategy;
    if (!phase.lineSearchStrategyMap.containsKey(directionType)) {
        // First time this direction type is seen: build its strategy and remember it.
        monitor.log(String.format("Constructing line search parameters: %s", directionType));
        strategy = phase.lineSearchFactory.apply(direction.getDirectionType());
        phase.lineSearchStrategyMap.put(directionType, strategy);
    } else {
        strategy = phase.lineSearchStrategyMap.get(directionType);
    }
    @Nonnull final TimedResult<PointSample> searchTiming = TimedResult.time(() -> {
        @Nonnull final FailsafeLineSearchCursor failsafe = new FailsafeLineSearchCursor(direction, previousPoint, monitor);
        strategy.step(failsafe, monitor);
        return failsafe.getBest(monitor).restore();
    });
    final PointSample bestPoint = searchTiming.result;
    if (bestPoint.getMean() > previousPoint.getMean()) {
        throw new IllegalStateException(bestPoint.getMean() + " > " + previousPoint.getMean());
    }
    monitor.log(compare(previousPoint, bestPoint));
    final double orientSeconds = orientTiming.timeNanos / 1e9;
    final double searchSeconds = searchTiming.timeNanos / 1e9;
    return new StepResult(previousPoint, bestPoint, new double[] { orientSeconds, searchSeconds });
}
Also used : FailsafeLineSearchCursor(com.simiacryptus.mindseye.opt.line.FailsafeLineSearchCursor) LineSearchCursor(com.simiacryptus.mindseye.opt.line.LineSearchCursor) Nonnull(javax.annotation.Nonnull) LineSearchStrategy(com.simiacryptus.mindseye.opt.line.LineSearchStrategy) PointSample(com.simiacryptus.mindseye.lang.PointSample) FailsafeLineSearchCursor(com.simiacryptus.mindseye.opt.line.FailsafeLineSearchCursor) Nonnull(javax.annotation.Nonnull)

Example 14 with PointSample

use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.

the class ValidatingTrainer method measure.

/**
 * Measures the phase's training subject, retrying until the measured mean is finite.
 * <p>
 * After every non-finite measurement the phase's orientation is reset before the next
 * attempt. At most eleven measurements are taken.
 *
 * @param phase the training phase whose subject is measured
 * @return the first point sample with a finite mean
 * @throws IterativeStopException if no finite measurement is obtained within the budget
 */
private PointSample measure(@Nonnull final TrainingPhase phase) {
    for (int attempt = 0; attempt <= 10; attempt++) {
        final PointSample sample = phase.trainingSubject.measure(monitor);
        if (Double.isFinite(sample.getMean())) {
            return sample;
        }
        phase.orientation.reset();
    }
    throw new IterativeStopException();
}
Also used : IterativeStopException(com.simiacryptus.mindseye.lang.IterativeStopException) PointSample(com.simiacryptus.mindseye.lang.PointSample)

Example 15 with PointSample

use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.

the class ValidatingTrainer method run.

/**
 * Runs the epoch-based training loop and returns the final validation mean.
 * <p>
 * Each epoch runs every phase of the regimen, re-measures the validation subject, logs a
 * summary, checks the stopping conditions, and then adapts the epoch's iteration count
 * and training size from the observed training / validation improvement ratios.
 *
 * @return the mean of the last validation measurement stored in the epoch parameters
 * @throws RuntimeException wrapping any {@link Throwable} raised while training
 */
public double run() {
    try {
        final long timeoutAt = System.currentTimeMillis() + timeout.toMillis();
        clearValidationNoise();
        @Nonnull final EpochParams epochParams = new EpochParams(timeoutAt, epochIterations, getTrainingSize(), validationSubject.measure(monitor));
        int epochNumber = 0;
        int iterationNumber = 0;
        int lastImprovement = 0;
        double lowestValidation = Double.POSITIVE_INFINITY;
        while (true) {
            if (shouldHalt(monitor, timeoutAt)) {
                monitor.log("Training halted");
                break;
            }
            monitor.log(String.format("Epoch parameters: %s, %s", epochParams.trainingSize, epochParams.iterations));
            @Nonnull final List<TrainingPhase> regimen = getRegimen();
            final long seed = System.nanoTime();
            // Index into the same snapshot whose size bounds the range.
            final List<EpochResult> epochResults = IntStream.range(0, regimen.size())
                .mapToObj(i -> runPhase(epochParams, regimen.get(i), i, seed))
                .collect(Collectors.toList());
            final EpochResult primaryPhase = epochResults.get(0);
            iterationNumber += primaryPhase.iterations;
            final double trainingDelta = primaryPhase.currentPoint.getMean() / primaryPhase.priorMean;
            clearValidationNoise();
            final PointSample currentValidation = validationSubject.measure(monitor);
            final double validationMean = currentValidation.getMean();
            // Ratios below 1.0 mean the corresponding mean decreased during this epoch.
            final double validationDelta = validationMean / epochParams.validation.getMean();
            final double overtraining = Math.log(trainingDelta) / Math.log(validationDelta);
            final double adj1 = Math.pow(Math.log(getTrainingTarget()) / Math.log(validationDelta), adjustmentFactor);
            final double adj2 = Math.pow(overtraining / getOvertrainingTarget(), adjustmentFactor);
            if (validationMean < lowestValidation) {
                lowestValidation = validationMean;
                lastImprovement = iterationNumber;
            }
            monitor.log(String.format("Epoch %d result apply %s iterations, %s/%s samples: {validation *= 2^%.5f; training *= 2^%.3f; Overtraining = %.2f}, {itr*=%.2f, len*=%.2f} %s since improvement; %.4f validation time", ++epochNumber, primaryPhase.iterations, epochParams.trainingSize, getMaxTrainingSize(), Math.log(validationDelta) / Math.log(2), Math.log(trainingDelta) / Math.log(2), overtraining, adj1, adj2, iterationNumber - lastImprovement, validatingMeasurementTime.getAndSet(0) / 1e9));
            if (!primaryPhase.continueTraining) {
                monitor.log(String.format("Training %d runPhase halted", epochNumber));
                break;
            }
            if (epochParams.trainingSize >= getMaxTrainingSize()) {
                final double roll = FastRandom.INSTANCE.random();
                if (roll > Math.pow(2 - validationDelta, pessimism)) {
                    monitor.log(String.format("Training randomly converged: %3f", roll));
                    break;
                } else {
                    if (iterationNumber - lastImprovement > improvmentStaleThreshold) {
                        if (disappointments.incrementAndGet() > getDisappointmentThreshold()) {
                            monitor.log(String.format("Training converged after %s iterations", iterationNumber - lastImprovement));
                            break;
                        } else {
                            monitor.log(String.format("Training failed to converged on %s attempt after %s iterations", disappointments.get(), iterationNumber - lastImprovement));
                        }
                    } else {
                        disappointments.set(0);
                    }
                }
            }
            if (validationDelta < 1.0 && trainingDelta < 1.0) {
                if (adj1 < 1 - adjustmentTolerance || adj1 > 1 + adjustmentTolerance) {
                    epochParams.iterations = Math.max(getMinEpochIterations(), Math.min(getMaxEpochIterations(), (int) (primaryPhase.iterations * adj1)));
                }
                // Only adjust when adj2 leaves the tolerance band, mirroring the adj1 check.
                // The previous form (adj2 < 1 + tol || adj2 > 1 - tol) held for every
                // non-NaN adj2 with a non-negative tolerance, so the band was ignored.
                if (adj2 < 1 - adjustmentTolerance || adj2 > 1 + adjustmentTolerance) {
                    epochParams.trainingSize = Math.max(0, Math.min(Math.max(getMinTrainingSize(), Math.min(getMaxTrainingSize(), (int) (epochParams.trainingSize * adj2))), epochParams.trainingSize));
                }
            } else {
                // NOTE(review): the outer Math.min against the current trainingSize caps the
                // result at the current size, so the "* 5" can never grow it — confirm intent.
                epochParams.trainingSize = Math.max(0, Math.min(Math.max(getMinTrainingSize(), Math.min(getMaxTrainingSize(), epochParams.trainingSize * 5)), epochParams.trainingSize));
                epochParams.iterations = 1;
            }
            epochParams.validation = currentValidation;
        }
        clearValidationNoise();
        return epochParams.validation.getMean();
    } catch (@Nonnull final Throwable e) {
        throw new RuntimeException(e);
    }
}

/**
 * Calls {@code clearNoise()} on every {@link StochasticComponent} layer of the validation
 * subject's layer when that layer is a {@link DAGNetwork}; otherwise does nothing.
 */
private void clearValidationNoise() {
    if (validationSubject.getLayer() instanceof DAGNetwork) {
        ((DAGNetwork) validationSubject.getLayer()).visitLayers(layer -> {
            if (layer instanceof StochasticComponent)
                ((StochasticComponent) layer).clearNoise();
        });
    }
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) DoubleStatistics(com.simiacryptus.util.data.DoubleStatistics) TemporalUnit(java.time.temporal.TemporalUnit) LineSearchStrategy(com.simiacryptus.mindseye.opt.line.LineSearchStrategy) SampledTrainable(com.simiacryptus.mindseye.eval.SampledTrainable) TrainableBase(com.simiacryptus.mindseye.eval.TrainableBase) DoubleBuffer(com.simiacryptus.mindseye.lang.DoubleBuffer) HashMap(java.util.HashMap) SampledCachedTrainable(com.simiacryptus.mindseye.eval.SampledCachedTrainable) ArmijoWolfeSearch(com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch) Function(java.util.function.Function) StateSet(com.simiacryptus.mindseye.lang.StateSet) ArrayList(java.util.ArrayList) TrainableWrapper(com.simiacryptus.mindseye.eval.TrainableWrapper) Trainable(com.simiacryptus.mindseye.eval.Trainable) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) Duration(java.time.Duration) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) ManagementFactory(java.lang.management.ManagementFactory) FailsafeLineSearchCursor(com.simiacryptus.mindseye.opt.line.FailsafeLineSearchCursor) LineSearchCursor(com.simiacryptus.mindseye.opt.line.LineSearchCursor) Nonnull(javax.annotation.Nonnull) IterativeStopException(com.simiacryptus.mindseye.lang.IterativeStopException) Util(com.simiacryptus.util.Util) StochasticComponent(com.simiacryptus.mindseye.layers.java.StochasticComponent) QQN(com.simiacryptus.mindseye.opt.orient.QQN) OrientationStrategy(com.simiacryptus.mindseye.opt.orient.OrientationStrategy) FastRandom(com.simiacryptus.util.FastRandom) Collectors(java.util.stream.Collectors) TimeUnit(java.util.concurrent.TimeUnit) AtomicLong(java.util.concurrent.atomic.AtomicLong) List(java.util.List) ChronoUnit(java.time.temporal.ChronoUnit) TimedResult(com.simiacryptus.util.lang.TimedResult) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) PointSample(com.simiacryptus.mindseye.lang.PointSample) 
Nonnull(javax.annotation.Nonnull) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) StochasticComponent(com.simiacryptus.mindseye.layers.java.StochasticComponent) PointSample(com.simiacryptus.mindseye.lang.PointSample)

Aggregations

PointSample (com.simiacryptus.mindseye.lang.PointSample)33 Nonnull (javax.annotation.Nonnull)24 Layer (com.simiacryptus.mindseye.lang.Layer)16 Nullable (javax.annotation.Nullable)14 DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet)10 TrainingMonitor (com.simiacryptus.mindseye.opt.TrainingMonitor)9 SimpleLineSearchCursor (com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor)9 StateSet (com.simiacryptus.mindseye.lang.StateSet)8 LineSearchCursor (com.simiacryptus.mindseye.opt.line.LineSearchCursor)8 List (java.util.List)8 Trainable (com.simiacryptus.mindseye.eval.Trainable)7 Arrays (java.util.Arrays)7 Collectors (java.util.stream.Collectors)7 IterativeStopException (com.simiacryptus.mindseye.lang.IterativeStopException)6 Map (java.util.Map)6 IntStream (java.util.stream.IntStream)6 DoubleBuffer (com.simiacryptus.mindseye.lang.DoubleBuffer)5 PlaceholderLayer (com.simiacryptus.mindseye.layers.java.PlaceholderLayer)5 FailsafeLineSearchCursor (com.simiacryptus.mindseye.opt.line.FailsafeLineSearchCursor)5 LineSearchStrategy (com.simiacryptus.mindseye.opt.line.LineSearchStrategy)5