Search in sources:

Example 11 with TrainingMonitor

use of com.simiacryptus.mindseye.opt.TrainingMonitor in project MindsEye by SimiaCryptus.

From class TrustRegionStrategy, method orient.

/**
 * Wraps the inner orientation's line-search cursor so that every candidate position is projected
 * into the per-layer trust region before it is used.
 *
 * <p>The incoming {@code origin} is pushed onto the front of {@code history}, and the oldest
 * entries are dropped so that at most {@code maxHistory} samples are kept. The returned cursor
 * delegates to the inner cursor and then corrects the proposed step layer by layer.
 *
 * @param subject the trainable that is measured after each step is applied
 * @param origin  the current sample; becomes the newest history entry
 * @param monitor passed to the inner orientation
 * @return a cursor whose direction type is the inner cursor's type suffixed with "+Trust"
 */
@Nonnull
@Override
public LineSearchCursor orient(@Nonnull final Trainable subject, final PointSample origin, final TrainingMonitor monitor) {
    // Newest sample first; trim from the tail to respect maxHistory.
    history.add(0, origin);
    while (history.size() > maxHistory) {
        history.remove(history.size() - 1);
    }
    final SimpleLineSearchCursor cursor = inner.orient(subject, origin, monitor);
    return new LineSearchCursorBase() {

        @Nonnull
        @Override
        public CharSequence getDirectionType() {
            return cursor.getDirectionType() + "+Trust";
        }

        /**
         * Returns the inner cursor's position for {@code alpha} after projecting it into the trust
         * region. The returned set's delta arrays are modified in place by {@link #project}.
         */
        @Nonnull
        @Override
        public DeltaSet<Layer> position(final double alpha) {
            reset();
            @Nonnull final DeltaSet<Layer> adjustedPosVector = cursor.position(alpha);
            // The returned derivative set is discarded here; only the in-place position edit matters.
            project(adjustedPosVector, new TrainingMonitor());
            return adjustedPosVector;
        }

        /**
         * Projects each layer's proposed step in {@code deltaIn} into that layer's trust region,
         * if the layer has one.
         *
         * <p>Side effect: for every layer whose projection moved the point, the delta array in
         * {@code deltaIn} is overwritten with {@code projected - current}.
         *
         * @param deltaIn proposed step; modified in place
         * @param monitor currently unused (only referenced by the commented-out logging below)
         * @return a copy of the inner cursor's {@code direction}. For projected layers, the
         *         component along the projection normal is removed from that copy.
         */
        @Nonnull
        public DeltaSet<Layer> project(@Nonnull final DeltaSet<Layer> deltaIn, final TrainingMonitor monitor) {
            final DeltaSet<Layer> originalAlphaDerivative = cursor.direction;
            @Nonnull final DeltaSet<Layer> newAlphaDerivative = originalAlphaDerivative.copy();
            deltaIn.getMap().forEach((layer, buffer) -> {
                @Nullable final double[] delta = buffer.getDelta();
                if (null == delta)
                    return;
                final double[] currentPosition = buffer.target;
                @Nullable final double[] originalAlphaD = originalAlphaDerivative.get(layer, currentPosition).getDelta();
                @Nullable final double[] newAlphaD = newAlphaDerivative.get(layer, currentPosition).getDelta();
                @Nonnull final double[] proposedPosition = ArrayUtil.add(currentPosition, delta);
                final TrustRegion region = getRegionPolicy(layer);
                if (null != region) {
                    // This layer's arrays from each historical sample's weights; samples lacking the layer are skipped.
                    final Stream<double[]> zz = history.stream().map((@Nonnull final PointSample x) -> {
                        final DoubleBuffer<Layer> d = x.weights.getMap().get(layer);
                        @Nullable final double[] z = null == d ? null : d.getDelta();
                        return z;
                    });
                    final double[] projectedPosition = region.project(zz.filter(x -> null != x).toArray(i -> new double[i][]), proposedPosition);
                    // Reference comparison: assumes region.project returns the same array instance
                    // when no projection was needed. NOTE(review): confirm against TrustRegion implementations.
                    if (projectedPosition != proposedPosition) {
                        // Rewrite the step in place so it lands on the projected point.
                        for (int i = 0; i < projectedPosition.length; i++) {
                            delta[i] = projectedPosition[i] - currentPosition[i];
                        }
                        // Displacement applied by the projection; used below as the surface normal.
                        @Nonnull final double[] normal = ArrayUtil.subtract(projectedPosition, proposedPosition);
                        final double normalMagSq = ArrayUtil.dot(normal, normal);
                        if (0 < normalMagSq) {
                            final double a = ArrayUtil.dot(originalAlphaD, normal);
                            // NOTE(review): this guard skips the update only when a is exactly -1.
                            // It may have been intended as `a != 0` (when a == 0 the tangent equals
                            // originalAlphaD anyway). Confirm intent before changing.
                            if (a != -1) {
                                // Remove the normal component: tangent = d - (d.n / |n|^2) n
                                @Nonnull final double[] tangent = ArrayUtil.add(originalAlphaD, ArrayUtil.multiply(normal, -a / normalMagSq));
                                for (int i = 0; i < tangent.length; i++) {
                                    newAlphaD[i] = tangent[i];
                                }
                            // double newAlphaDerivSq = ArrayUtil.dot(tangent, tangent);
                            // double originalAlphaDerivSq = ArrayUtil.dot(originalAlphaD, originalAlphaD);
                            // assert(newAlphaDerivSq <= originalAlphaDerivSq);
                            // assert(Math.abs(ArrayUtil.dot(tangent, normal)) <= 1e-4);
                            // monitor.log(String.format("%s: normalMagSq = %s, newAlphaDerivSq = %s, originalAlphaDerivSq = %s", layer, normalMagSq, newAlphaDerivSq, originalAlphaDerivSq));
                            }
                        }
                    }
                }
            });
            return newAlphaDerivative;
        }

        @Override
        public void reset() {
            cursor.reset();
        }

        /**
         * Takes a step of size {@code alpha} and measures the subject at the new point.
         *
         * <p>The order matters. First the inner cursor is reset and positioned. Then the position
         * is projected in place, then applied via {@code accumulate(1)}, and only then is the
         * subject measured.
         *
         * @return the measured sample, with a derivative equal to the projected direction dotted
         *         with the sample's delta
         */
        @Nonnull
        @Override
        public LineSearchPoint step(final double alpha, final TrainingMonitor monitor) {
            cursor.reset();
            @Nonnull final DeltaSet<Layer> adjustedPosVector = cursor.position(alpha);
            @Nonnull final DeltaSet<Layer> adjustedGradient = project(adjustedPosVector, monitor);
            adjustedPosVector.accumulate(1);
            @Nonnull final PointSample sample = subject.measure(monitor).setRate(alpha);
            return new LineSearchPoint(sample, adjustedGradient.dot(sample.delta));
        }

        /** Releases the wrapped inner cursor. */
        @Override
        public void _free() {
            cursor.freeRef();
        }
    };
}
Also used : TrustRegion(com.simiacryptus.mindseye.opt.region.TrustRegion) IntStream(java.util.stream.IntStream) LineSearchPoint(com.simiacryptus.mindseye.opt.line.LineSearchPoint) TrustRegion(com.simiacryptus.mindseye.opt.region.TrustRegion) DoubleBuffer(com.simiacryptus.mindseye.lang.DoubleBuffer) ArrayUtil(com.simiacryptus.util.ArrayUtil) Trainable(com.simiacryptus.mindseye.eval.Trainable) List(java.util.List) Stream(java.util.stream.Stream) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Layer(com.simiacryptus.mindseye.lang.Layer) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) LineSearchCursor(com.simiacryptus.mindseye.opt.line.LineSearchCursor) LinkedList(java.util.LinkedList) Nonnull(javax.annotation.Nonnull) PointSample(com.simiacryptus.mindseye.lang.PointSample) LineSearchCursorBase(com.simiacryptus.mindseye.opt.line.LineSearchCursorBase) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) Layer(com.simiacryptus.mindseye.lang.Layer) LineSearchPoint(com.simiacryptus.mindseye.opt.line.LineSearchPoint) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) LineSearchPoint(com.simiacryptus.mindseye.opt.line.LineSearchPoint) PointSample(com.simiacryptus.mindseye.lang.PointSample) LineSearchCursorBase(com.simiacryptus.mindseye.opt.line.LineSearchCursorBase) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull)

Example 12 with TrainingMonitor

use of com.simiacryptus.mindseye.opt.TrainingMonitor in project MindsEye by SimiaCryptus.

From class LocalSparkTrainable, method measure.

/**
 * Measures the network on the sampled RDD. Each partition is collected to the driver, evaluated
 * with a {@link PartitionTask}, and the per-partition results are summed.
 *
 * <p>Partitions with no rows are skipped. When verbose, the method logs the time spent evaluating
 * partitions, the time spent reducing, and {@code sampledRDD.count()} (an additional RDD action).
 *
 * @param monitor not used by this implementation
 * @return the aggregated sample, normalized via {@link PointSample#normalize()}
 * @throws IllegalStateException if every partition is empty, leaving nothing to aggregate
 * @throws RuntimeException      wrapping any checked exception raised while evaluating a partition
 */
@Nonnull
@Override
public PointSample measure(final TrainingMonitor monitor) {
    final long time1 = System.nanoTime();
    final JavaRDD<Tensor[]> javaRDD = sampledRDD.toJavaRDD();
    assert !javaRDD.isEmpty();
    final List<ReducableResult> mapPartitions = javaRDD.partitions().stream().map(partition -> {
        try {
            final List<Tensor[]>[] array = javaRDD.collectPartitions(new int[] { partition.index() });
            assert 0 < array.length;
            if (0 == Arrays.stream(array).mapToInt((@Nonnull final List<Tensor[]> x) -> x.size()).sum()) {
                // Empty partition: contributes nothing and is dropped by the null filter below.
                return null;
            }
            assert 0 < Arrays.stream(array).mapToInt(x -> x.stream().mapToInt(y -> y.length).sum()).sum();
            final Stream<Tensor[]> stream = Arrays.stream(array).flatMap(i -> i.stream());
            @Nonnull final Iterator<Tensor[]> iterator = stream.iterator();
            return new PartitionTask(network).call(iterator).next();
        } catch (@Nonnull final RuntimeException e) {
            throw e;
        } catch (@Nonnull final Exception e) {
            throw new RuntimeException(e);
        }
    }).filter(x -> null != x).collect(Collectors.toList());
    final long time2 = System.nanoTime();
    // Fail descriptively instead of via an unchecked Optional.get() (NoSuchElementException)
    // when no partition produced a result; the isEmpty() assertion above only runs with -ea.
    @Nonnull final SparkTrainable.ReducableResult result = mapPartitions.stream()
        .reduce(SparkTrainable.ReducableResult::add)
        .orElseThrow(() -> new IllegalStateException("Cannot measure: every partition of the sampled RDD is empty"));
    if (isVerbose()) {
        log.info(String.format("Measure timing: %.3f / %.3f for %s items", (time2 - time1) * 1e-9, (System.nanoTime() - time2) * 1e-9, sampledRDD.count()));
    }
    @Nonnull final DeltaSet<Layer> measuredDelta = getDelta(result);
    return new PointSample(measuredDelta, new StateSet<Layer>(measuredDelta), result.sum, 0.0, result.count).normalize();
}
Also used : Arrays(java.util.Arrays) Logger(org.slf4j.Logger) Iterator(java.util.Iterator) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) Collectors(java.util.stream.Collectors) StateSet(com.simiacryptus.mindseye.lang.StateSet) List(java.util.List) Stream(java.util.stream.Stream) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Layer(com.simiacryptus.mindseye.lang.Layer) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) RDD(org.apache.spark.rdd.RDD) Nonnull(javax.annotation.Nonnull) PointSample(com.simiacryptus.mindseye.lang.PointSample) JavaRDD(org.apache.spark.api.java.JavaRDD) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) Layer(com.simiacryptus.mindseye.lang.Layer) List(java.util.List) PointSample(com.simiacryptus.mindseye.lang.PointSample) StateSet(com.simiacryptus.mindseye.lang.StateSet) Nonnull(javax.annotation.Nonnull)

Example 13 with TrainingMonitor

use of com.simiacryptus.mindseye.opt.TrainingMonitor in project MindsEye by SimiaCryptus.

From class TrainingTester, method trainCjGD.

/**
 * Runs the "CjGD" training pass: an {@link IterativeTrainer} using {@link GradientDescent}
 * orientation and a {@link QuadraticSearch} line search, limited to 30 seconds and 250 iterations.
 *
 * <p>NOTE(review): despite the method name and the narrative text, the orientation configured
 * here is plain {@code GradientDescent}. Confirm that this is intended.
 *
 * @param log       notebook receiving the narrative paragraph and the {@code log.code} block
 * @param trainable the objective to optimize
 * @return the list handed to {@code TrainingTester.getMonitor}; it is returned even when training
 *         fails, unless {@code isThrowExceptions()} is set
 */
@Nonnull
public List<StepRecord> trainCjGD(@Nonnull final NotebookOutput log, final Trainable trainable) {
    log.p("First, we use a conjugate gradient descent method, which converges the fastest for purely linear functions.");
    @Nonnull final List<StepRecord> stepRecords = new ArrayList<>();
    @Nonnull final TrainingMonitor monitor = TrainingTester.getMonitor(stepRecords);
    try {
        log.code(() -> {
            return new IterativeTrainer(trainable).setLineSearchFactory(label -> new QuadraticSearch()).setOrientation(new GradientDescent()).setMonitor(monitor).setTimeout(30, TimeUnit.SECONDS).setMaxIterations(250).setTerminateThreshold(0).runAndFree();
        });
    } catch (Throwable failure) {
        // Best-effort by default: a failed run still yields whatever records were captured.
        if (!isThrowExceptions()) {
            return stepRecords;
        }
        throw new RuntimeException(failure);
    }
    return stepRecords;
}
Also used : StepRecord(com.simiacryptus.mindseye.test.StepRecord) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) QuadraticSearch(com.simiacryptus.mindseye.opt.line.QuadraticSearch) ArrayList(java.util.ArrayList) GradientDescent(com.simiacryptus.mindseye.opt.orient.GradientDescent) Nonnull(javax.annotation.Nonnull)

Example 14 with TrainingMonitor

use of com.simiacryptus.mindseye.opt.TrainingMonitor in project MindsEye by SimiaCryptus.

From class TrainingTester, method trainLBFGS.

/**
 * Runs the L-BFGS training pass: an {@link IterativeTrainer} using {@link LBFGS} orientation and an
 * {@link ArmijoWolfeSearch} line search. It is limited to 30 seconds and 250 iterations, with 100
 * iterations per sample.
 *
 * @param log       notebook receiving the narrative paragraph and the {@code log.code} block
 * @param trainable the objective to optimize
 * @return the list handed to {@code TrainingTester.getMonitor}; it is returned even when training
 *         fails, unless {@code isThrowExceptions()} is set
 */
@Nonnull
public List<StepRecord> trainLBFGS(@Nonnull final NotebookOutput log, final Trainable trainable) {
    log.p("Next, we apply the same optimization using L-BFGS, which is nearly ideal for purely second-order or quadratic functions.");
    @Nonnull final List<StepRecord> stepRecords = new ArrayList<>();
    @Nonnull final TrainingMonitor monitor = TrainingTester.getMonitor(stepRecords);
    try {
        log.code(() -> {
            return new IterativeTrainer(trainable).setLineSearchFactory(label -> new ArmijoWolfeSearch()).setOrientation(new LBFGS()).setMonitor(monitor).setTimeout(30, TimeUnit.SECONDS).setIterationsPerSample(100).setMaxIterations(250).setTerminateThreshold(0).runAndFree();
        });
    } catch (Throwable failure) {
        // Best-effort by default: a failed run still yields whatever records were captured.
        if (!isThrowExceptions()) {
            return stepRecords;
        }
        throw new RuntimeException(failure);
    }
    return stepRecords;
}
Also used : StepRecord(com.simiacryptus.mindseye.test.StepRecord) PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) BiFunction(java.util.function.BiFunction) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) HashMap(java.util.HashMap) Random(java.util.Random) Result(com.simiacryptus.mindseye.lang.Result) ArmijoWolfeSearch(com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch) ArrayList(java.util.ArrayList) Trainable(com.simiacryptus.mindseye.eval.Trainable) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) QuadraticSearch(com.simiacryptus.mindseye.opt.line.QuadraticSearch) LBFGS(com.simiacryptus.mindseye.opt.orient.LBFGS) RecursiveSubspace(com.simiacryptus.mindseye.opt.orient.RecursiveSubspace) StepRecord(com.simiacryptus.mindseye.test.StepRecord) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) ReferenceCounting(com.simiacryptus.mindseye.lang.ReferenceCounting) IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) Logger(org.slf4j.Logger) PlotCanvas(smile.plot.PlotCanvas) QQN(com.simiacryptus.mindseye.opt.orient.QQN) GradientDescent(com.simiacryptus.mindseye.opt.orient.GradientDescent) BasicTrainable(com.simiacryptus.mindseye.eval.BasicTrainable) StaticLearningRate(com.simiacryptus.mindseye.opt.line.StaticLearningRate) TestUtil(com.simiacryptus.mindseye.test.TestUtil) DAGNode(com.simiacryptus.mindseye.network.DAGNode) DoubleStream(java.util.stream.DoubleStream) java.awt(java.awt) TimeUnit(java.util.concurrent.TimeUnit) List(java.util.List) Stream(java.util.stream.Stream) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) 
TensorList(com.simiacryptus.mindseye.lang.TensorList) Step(com.simiacryptus.mindseye.opt.Step) ProblemRun(com.simiacryptus.mindseye.test.ProblemRun) javax.swing(javax.swing) LBFGS(com.simiacryptus.mindseye.opt.orient.LBFGS) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) ArmijoWolfeSearch(com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch) Nonnull(javax.annotation.Nonnull) ArrayList(java.util.ArrayList) Nonnull(javax.annotation.Nonnull)

Example 15 with TrainingMonitor

use of com.simiacryptus.mindseye.opt.TrainingMonitor in project MindsEye by SimiaCryptus.

From class LBFGS, method lbfgs.

/**
 * Attempts to overwrite {@code direction} with an L-BFGS direction computed by the two-loop
 * recursion over {@code history}.
 *
 * <p>Weight differences {@code s} and delta differences {@code y} are taken between consecutive
 * history entries. The last pair ({@code size-1} vs {@code size-2}) supplies the initial scaling
 * {@code (s.y)/(y.y)}, so the list is treated as ordered oldest to newest. The computed vector
 * {@code p} is accepted only when {@code measurement.delta . p < 0}; otherwise {@code direction}
 * is left untouched.
 *
 * @param measurement the current sample; its {@code delta} seeds the recursion
 * @param monitor     receives one accept / reject / error log line
 * @param history     past samples, at least two are required
 * @param direction   overwritten with {@code p} when it is accepted
 * @return {@code true} if {@code direction} was updated. Returns {@code false} if {@code p} was
 *         rejected or could not be computed: too little history, a vanishing {@code s.y},
 *         non-finite values, or any other error.
 */
private boolean lbfgs(@Nonnull PointSample measurement, @Nonnull TrainingMonitor monitor, @Nonnull List<PointSample> history, @Nonnull DeltaSet<Layer> direction) {
    try {
        if (history.size() < 2) {
            // The recursion needs at least one (s, y) pair; previously this surfaced as an opaque index error.
            monitor.log(String.format("LBFGS Orientation Error: %s", "insufficient history (" + history.size() + " samples)"));
            return false;
        }
        @Nonnull DeltaSet<Layer> p = measurement.delta.copy();
        requireFinite(p);
        @Nonnull final double[] alphas = new double[history.size()];
        // First loop: newest pair to oldest.
        for (int i = history.size() - 2; i >= 0; i--) {
            @Nonnull final DeltaSet<Layer> sd = history.get(i + 1).weights.subtract(history.get(i).weights);
            @Nonnull final DeltaSet<Layer> yd = history.get(i + 1).delta.subtract(history.get(i).delta);
            final double denominator = sd.dot(yd);
            if (0 == denominator) {
                throw new IllegalStateException("Orientation vanished.");
            }
            alphas[i] = p.dot(sd) / denominator;
            p = p.subtract(yd.scale(alphas[i]));
            requireFinite(p);
        }
        // Initial inverse-Hessian scaling taken from the most recent pair.
        @Nonnull final DeltaSet<Layer> sk = history.get(history.size() - 1).weights.subtract(history.get(history.size() - 2).weights);
        @Nonnull final DeltaSet<Layer> yk = history.get(history.size() - 1).delta.subtract(history.get(history.size() - 2).delta);
        p = p.scale(sk.dot(yk) / yk.dot(yk));
        requireFinite(p);
        // Second loop: oldest pair to newest. s.y is non-zero here because the first loop already checked it.
        for (int i = 0; i < history.size() - 1; i++) {
            @Nonnull final DeltaSet<Layer> sd = history.get(i + 1).weights.subtract(history.get(i).weights);
            @Nonnull final DeltaSet<Layer> yd = history.get(i + 1).delta.subtract(history.get(i).delta);
            final double beta = p.dot(yd) / sd.dot(yd);
            p = p.add(sd.scale(alphas[i] - beta));
            requireFinite(p);
        }
        boolean accept = measurement.delta.dot(p) < 0;
        if (accept) {
            monitor.log("Accepted: " + new Stats(direction, p));
            copy(p, direction);
        } else {
            monitor.log("Rejected: " + new Stats(direction, p));
        }
        return accept;
    } catch (Throwable e) {
        // Log the exception itself rather than getMessage(): it keeps the exception type, and many
        // exceptions (e.g. NullPointerException) carry a null message.
        monitor.log(String.format("LBFGS Orientation Error: %s", e));
        return false;
    }
}

/**
 * Throws {@code IllegalStateException("Non-finite value")} if any element of any delta in
 * {@code set} is NaN or infinite.
 */
private static void requireFinite(@Nonnull final DeltaSet<Layer> set) {
    if (!set.stream().parallel().allMatch(y -> Arrays.stream(y.getDelta()).allMatch(d -> Double.isFinite(d)))) {
        throw new IllegalStateException("Non-finite value");
    }
}
Also used : Arrays(java.util.Arrays) LineSearchPoint(com.simiacryptus.mindseye.opt.line.LineSearchPoint) Collectors(java.util.stream.Collectors) ArrayUtil(com.simiacryptus.util.ArrayUtil) TreeSet(java.util.TreeSet) Trainable(com.simiacryptus.mindseye.eval.Trainable) Delta(com.simiacryptus.mindseye.lang.Delta) List(java.util.List) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Map(java.util.Map) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) Layer(com.simiacryptus.mindseye.lang.Layer) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) DoubleBufferSet(com.simiacryptus.mindseye.lang.DoubleBufferSet) Comparator(java.util.Comparator) Nonnull(javax.annotation.Nonnull) PointSample(com.simiacryptus.mindseye.lang.PointSample) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) Layer(com.simiacryptus.mindseye.lang.Layer) LineSearchPoint(com.simiacryptus.mindseye.opt.line.LineSearchPoint)

Aggregations

TrainingMonitor (com.simiacryptus.mindseye.opt.TrainingMonitor)19 Nonnull (javax.annotation.Nonnull)19 Layer (com.simiacryptus.mindseye.lang.Layer)13 List (java.util.List)12 Nullable (javax.annotation.Nullable)12 PointSample (com.simiacryptus.mindseye.lang.PointSample)10 Tensor (com.simiacryptus.mindseye.lang.Tensor)10 Arrays (java.util.Arrays)10 DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet)9 Trainable (com.simiacryptus.mindseye.eval.Trainable)8 StepRecord (com.simiacryptus.mindseye.test.StepRecord)8 IntStream (java.util.stream.IntStream)8 ArrayTrainable (com.simiacryptus.mindseye.eval.ArrayTrainable)7 TensorList (com.simiacryptus.mindseye.lang.TensorList)7 ArrayList (java.util.ArrayList)7 Result (com.simiacryptus.mindseye.lang.Result)6 StateSet (com.simiacryptus.mindseye.lang.StateSet)6 IterativeTrainer (com.simiacryptus.mindseye.opt.IterativeTrainer)6 ConstantResult (com.simiacryptus.mindseye.lang.ConstantResult)5 TensorArray (com.simiacryptus.mindseye.lang.TensorArray)5