Search in sources:

Example 16 with TrainingMonitor

Use of com.simiacryptus.mindseye.opt.TrainingMonitor in project MindsEye by SimiaCryptus.

Class OwlQn, method orient.

/**
 * Builds an OWL-QN style line-search cursor on top of the {@code inner} orientation strategy.
 * <p>
 * The inner direction is copied. Each component is shifted by {@code factor_L1} times the sign of the
 * corresponding weight; a weight of exactly zero counts as positive. If a shifted component's sign no
 * longer matches the inner direction's component, it is reset to the inner value. The returned
 * cursor's {@code step} also clamps to zero any nonzero weight that the step would move across, or
 * exactly onto, zero.
 * <p>
 * NOTE(review): the cast below assumes {@code inner.orient} returns a {@link SimpleLineSearchCursor};
 * any other cursor type fails with a ClassCastException.
 */
@Nonnull
@Override
public LineSearchCursor orient(final Trainable subject, @Nonnull final PointSample measurement, final TrainingMonitor monitor) {
    @Nonnull final SimpleLineSearchCursor gradient = (SimpleLineSearchCursor) inner.orient(subject, measurement, monitor);
    // Working copy of the inner direction; its components are adjusted in place below.
    @Nonnull final DeltaSet<Layer> searchDirection = gradient.direction.copy();
    // NOTE(review): orthant is filled in below but never read afterwards in this method.
    @Nonnull final DeltaSet<Layer> orthant = new DeltaSet<Layer>();
    // getLayers(...) is defined outside this view; presumably it selects/orders the direction's layers.
    for (@Nonnull final Layer layer : getLayers(gradient.direction.getMap().keySet())) {
        // Parameter values for this layer (the `target` of the inner direction's entry).
        final double[] weights = gradient.direction.getMap().get(layer).target;
        // Unadjusted inner direction for this layer.
        @Nullable final double[] delta = gradient.direction.getMap().get(layer).getDelta();
        @Nullable final double[] searchDir = searchDirection.get(layer, weights).getDelta();
        @Nullable final double[] suborthant = orthant.get(layer, weights).getDelta();
        for (int i = 0; i < searchDir.length; i++) {
            // sign(...) is defined elsewhere; presumably returns -1, 0 or 1.
            final int positionSign = sign(weights[i]);
            final int directionSign = sign(delta[i]);
            // Orthant of the current point; a zero weight takes the sign of the inner direction.
            suborthant[i] = 0 == positionSign ? directionSign : positionSign;
            // Add the L1 term; note that weights[i] == 0 is treated as positive here.
            searchDir[i] += factor_L1 * (weights[i] < 0 ? -1.0 : 1.0);
            // If the L1 shift changed the component's sign, fall back to the inner direction.
            if (sign(searchDir[i]) != directionSign) {
                searchDir[i] = delta[i];
            }
        }
        // NOTE(review): searchDir was already dereferenced in the loop above, so this assert can never fire.
        assert null != searchDir;
    }
    return new SimpleLineSearchCursor(subject, measurement, searchDirection) {

        /**
         * Moves the parameters by {@code alpha} along the search direction, clamping at zero any nonzero
         * weight whose sign would change, then measures the subject and reports the derivative along
         * the direction actually followed.
         */
        @Nonnull
        @Override
        public LineSearchPoint step(final double alpha, final TrainingMonitor monitor) {
            // Presumably resets all parameters to the origin values before applying this step.
            origin.weights.stream().forEach(d -> d.restore());
            // Direction actually followed; components of clamped weights are zeroed below.
            @Nonnull final DeltaSet<Layer> currentDirection = direction.copy();
            direction.getMap().forEach((layer, buffer) -> {
                if (null == buffer.getDelta())
                    return;
                @Nullable final double[] currentDelta = currentDirection.get(layer, buffer.target).getDelta();
                for (int i = 0; i < buffer.getDelta().length; i++) {
                    final double prevValue = buffer.target[i];
                    final double newValue = prevValue + buffer.getDelta()[i] * alpha;
                    // Orthant constraint: a nonzero weight may not change sign (or land on zero) in one
                    // step; it is clamped to zero and its direction component is dropped.
                    if (sign(prevValue) != 0 && sign(prevValue) != sign(newValue)) {
                        currentDelta[i] = 0;
                        buffer.target[i] = 0;
                    } else {
                        buffer.target[i] = newValue;
                    }
                }
            });
            @Nonnull final PointSample measure = subject.measure(monitor).setRate(alpha);
            // Directional derivative along the (possibly clamped) direction that was actually taken.
            return new LineSearchPoint(measure, currentDirection.dot(measure.delta));
        }
    }.setDirectionType("OWL/QN");
}
Also used : LineSearchPoint(com.simiacryptus.mindseye.opt.line.LineSearchPoint) Collection(java.util.Collection) Collectors(java.util.stream.Collectors) Trainable(com.simiacryptus.mindseye.eval.Trainable) FullyConnectedLayer(com.simiacryptus.mindseye.layers.java.FullyConnectedLayer) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Layer(com.simiacryptus.mindseye.lang.Layer) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) LineSearchCursor(com.simiacryptus.mindseye.opt.line.LineSearchCursor) Nonnull(javax.annotation.Nonnull) PointSample(com.simiacryptus.mindseye.lang.PointSample) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) FullyConnectedLayer(com.simiacryptus.mindseye.layers.java.FullyConnectedLayer) Layer(com.simiacryptus.mindseye.lang.Layer) LineSearchPoint(com.simiacryptus.mindseye.opt.line.LineSearchPoint) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) LineSearchPoint(com.simiacryptus.mindseye.opt.line.LineSearchPoint) PointSample(com.simiacryptus.mindseye.lang.PointSample) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull)

Example 17 with TrainingMonitor

Use of com.simiacryptus.mindseye.opt.TrainingMonitor in project MindsEye by SimiaCryptus.

Class QQN, method orient.

/**
 * Returns a quadratic quasi-Newton (QQN) cursor that blends steepest descent with the L-BFGS direction.
 * <p>
 * The inner L-BFGS strategy first records {@code origin} in its history and computes its own
 * direction {@code q}. The L-BFGS and negative-gradient magnitudes may differ by more than 1e-2
 * relative to their sum. In that case the negative gradient is rescaled to the L-BFGS magnitude,
 * giving {@code g}, and the returned cursor follows the curve
 * {@code position(t) = (t - t^2) * g + t^2 * q}. The curve leaves the origin along {@code g} and
 * reaches {@code q} at {@code t = 1}. Each measured step is also added to the inner history.
 * Otherwise, including when the gradient is exactly zero and cannot be rescaled, the plain L-BFGS
 * cursor is returned.
 *
 * @param subject the trainable being optimized
 * @param origin  the current point and its gradient
 * @param monitor receives diagnostic log output
 * @return the quadratic cursor, or the inner L-BFGS cursor when no blending is applied
 */
@Override
public LineSearchCursor orient(@Nonnull final Trainable subject, @Nonnull final PointSample origin, @Nonnull final TrainingMonitor monitor) {
    inner.addToHistory(origin, monitor);
    final SimpleLineSearchCursor lbfgsCursor = inner.orient(subject, origin, monitor);
    final DeltaSet<Layer> lbfgs = lbfgsCursor.direction;
    @Nonnull final DeltaSet<Layer> gd = origin.delta.scale(-1.0);
    final double lbfgsMag = lbfgs.getMagnitude();
    final double gdMag = gd.getMagnitude();
    // gdMag > 0 guards the rescale below. With a zero gradient, lbfgsMag / gdMag is Infinity and
    // 0 * Infinity fills the scaled gradient with NaN, which would poison every point on the curve.
    if (gdMag > 0 && Math.abs(lbfgsMag - gdMag) / (lbfgsMag + gdMag) > 1e-2) {
        @Nonnull final DeltaSet<Layer> scaledGradient = gd.scale(lbfgsMag / gdMag);
        monitor.log(String.format("Returning Quadratic Cursor %s GD, %s QN", gdMag, lbfgsMag));
        gd.freeRef();
        return new LineSearchCursorBase() {

            @Nonnull
            @Override
            public CharSequence getDirectionType() {
                return CURSOR_NAME;
            }

            /** Offset from the origin at parameter {@code t}: (t - t^2) * scaledGradient + t^2 * lbfgs. */
            @Override
            public DeltaSet<Layer> position(final double t) {
                if (!Double.isFinite(t))
                    throw new IllegalArgumentException();
                return scaledGradient.scale(t - t * t).add(lbfgs.scale(t * t));
            }

            @Override
            public void reset() {
                lbfgsCursor.reset();
            }

            /** Moves to {@code position(t)}, measures, records the sample, and returns the curve derivative. */
            @Nonnull
            @Override
            public LineSearchPoint step(final double t, @Nonnull final TrainingMonitor monitor) {
                if (!Double.isFinite(t))
                    throw new IllegalArgumentException();
                reset();
                position(t).accumulate(1);
                @Nonnull final PointSample sample = subject.measure(monitor).setRate(t);
                inner.addToHistory(sample, monitor);
                // d/dt of position(t): (1 - 2t) * scaledGradient + 2t * lbfgs.
                @Nonnull final DeltaSet<Layer> tangent = scaledGradient.scale(1 - 2 * t).add(lbfgs.scale(2 * t));
                return new LineSearchPoint(sample, tangent.dot(sample.delta));
            }

            @Override
            public void _free() {
                scaledGradient.freeRef();
                lbfgsCursor.freeRef();
            }
        };
    } else {
        gd.freeRef();
        return lbfgsCursor;
    }
}
Also used : TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Nonnull(javax.annotation.Nonnull) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) LineSearchPoint(com.simiacryptus.mindseye.opt.line.LineSearchPoint) PointSample(com.simiacryptus.mindseye.lang.PointSample) Layer(com.simiacryptus.mindseye.lang.Layer) LineSearchCursorBase(com.simiacryptus.mindseye.opt.line.LineSearchCursorBase)

Example 18 with TrainingMonitor

Use of com.simiacryptus.mindseye.opt.TrainingMonitor in project MindsEye by SimiaCryptus.

Class RecursiveSubspace, method buildSubspace.

/**
 * Builds a layer that evaluates the training objective within the subspace spanned by the per-layer
 * negative-gradient directions.
 * <p>
 * The search direction is {@code -measurement.delta}. Every non-placeholder layer in that direction gets
 * its own scalar coefficient. All {@link PlaceholderLayer} entries, if any, share one coefficient stored
 * at index 0. The coefficients live in the {@code weights} field, which is reallocated (zero-filled) only
 * when the required count changes.
 * <p>
 * Evaluating the returned layer calls {@code restore()} on a backed-up full copy of {@code measurement}.
 * It then accumulates each layer's direction scaled by its coefficient, measures {@code subject}, and
 * outputs the mean as a single-element tensor. The backward pass adds one derivative per coefficient
 * to the buffer entry for the returned layer.
 *
 * @param subject     the trainable whose objective is measured at each evaluation
 * @param measurement the point (and gradient) the subspace is built around
 * @param monitor     receives diagnostic log messages
 * @return the subspace layer; its {@code getJson} throws {@link IllegalStateException} and its
 *     {@code state()} is {@code null}
 */
@Nullable
public Layer buildSubspace(@Nonnull Trainable subject, @Nonnull PointSample measurement, @Nonnull TrainingMonitor monitor) {
    @Nonnull PointSample origin = measurement.copyFull().backup();
    @Nonnull final DeltaSet<Layer> direction = measurement.delta.scale(-1);
    final double magnitude = direction.getMagnitude();
    if (Math.abs(magnitude) < 1e-10) {
        monitor.log(String.format("Zero gradient: %s", magnitude));
    } else if (Math.abs(magnitude) < 1e-5) {
        monitor.log(String.format("Low gradient: %s", magnitude));
    }
    boolean hasPlaceholders = direction.getMap().keySet().stream().anyMatch(x -> x instanceof PlaceholderLayer);
    // Non-placeholder layers; position i maps to weights[i], shifted by one when the shared
    // placeholder coefficient occupies weights[0].
    List<Layer> deltaLayers = direction.getMap().keySet().stream().filter(x -> !(x instanceof PlaceholderLayer)).collect(Collectors.toList());
    int size = deltaLayers.size() + (hasPlaceholders ? 1 : 0);
    if (null == weights || weights.length != size)
        weights = new double[size];
    return new LayerBase() {

        @Nonnull
        Layer self = this;

        @Nonnull
        @Override
        public Result eval(Result... array) {
            assertAlive();
            origin.restore();
            IntStream.range(0, deltaLayers.size()).forEach(i -> {
                direction.getMap().get(deltaLayers.get(i)).accumulate(weights[hasPlaceholders ? (i + 1) : i]);
            });
            if (hasPlaceholders) {
                direction.getMap().entrySet().stream().filter(x -> x.getKey() instanceof PlaceholderLayer).forEach(entry -> entry.getValue().accumulate(weights[0]));
            }
            PointSample measure = subject.measure(monitor);
            double mean = measure.getMean();
            monitor.log(String.format("RecursiveSubspace: %s <- %s", mean, Arrays.toString(weights)));
            // Balanced by direction.freeRef() in the Result's _free below.
            direction.addRef();
            return new Result(TensorArray.wrap(new Tensor(mean)), (DeltaSet<Layer> buffer, TensorList data) -> {
                // One derivative per coefficient, in the same order as the weights array.
                DoubleStream deltaStream = deltaLayers.stream().mapToDouble(layer -> projectGradient(direction, measure.delta, layer));
                if (hasPlaceholders) {
                    // All placeholders share weights[0], so their contributions are summed.
                    double placeholderDelta = direction.getMap().keySet().stream().filter(x -> x instanceof PlaceholderLayer).mapToDouble(layer -> projectGradient(direction, measure.delta, layer)).sum();
                    deltaStream = DoubleStream.concat(DoubleStream.of(placeholderDelta), deltaStream);
                }
                buffer.get(self, weights).addInPlace(deltaStream.toArray()).freeRef();
            }) {

                @Override
                protected void _free() {
                    measure.freeRef();
                    direction.freeRef();
                }

                @Override
                public boolean isAlive() {
                    return true;
                }
            };
        }

        @Override
        protected void _free() {
            direction.freeRef();
            origin.freeRef();
            super._free();
        }

        @Nonnull
        @Override
        public JsonObject getJson(Map<CharSequence, byte[]> resources, DataSerializer dataSerializer) {
            throw new IllegalStateException();
        }

        @Nullable
        @Override
        public List<double[]> state() {
            return null;
        }
    };
}

/**
 * Returns the measured gradient for {@code layer} dotted with that layer's search direction, divided by
 * the direction's length {@code sqrt(a.dot(a))}. The divisor is floored at 1e-8 to avoid division by zero.
 * <p>
 * NOTE(review): the plain chain-rule derivative would be the unnormalized dot product; confirm the
 * division by the direction length is an intended preconditioning.
 */
private static double projectGradient(@Nonnull DeltaSet<Layer> direction, @Nonnull DeltaSet<Layer> gradient, Layer layer) {
    Delta<Layer> a = direction.getMap().get(layer);
    Delta<Layer> b = gradient.getMap().get(layer);
    return b.dot(a) / Math.max(Math.sqrt(a.dot(a)), 1e-8);
}
Also used : IntStream(java.util.stream.IntStream) JsonObject(com.google.gson.JsonObject) Arrays(java.util.Arrays) Tensor(com.simiacryptus.mindseye.lang.Tensor) Result(com.simiacryptus.mindseye.lang.Result) ArmijoWolfeSearch(com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) StateSet(com.simiacryptus.mindseye.lang.StateSet) Trainable(com.simiacryptus.mindseye.eval.Trainable) Delta(com.simiacryptus.mindseye.lang.Delta) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Map(java.util.Map) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) Layer(com.simiacryptus.mindseye.lang.Layer) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) BasicTrainable(com.simiacryptus.mindseye.eval.BasicTrainable) Collectors(java.util.stream.Collectors) DoubleStream(java.util.stream.DoubleStream) List(java.util.List) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) TensorList(com.simiacryptus.mindseye.lang.TensorList) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) PointSample(com.simiacryptus.mindseye.lang.PointSample) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) TensorList(com.simiacryptus.mindseye.lang.TensorList) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) Layer(com.simiacryptus.mindseye.lang.Layer) Result(com.simiacryptus.mindseye.lang.Result) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) DoubleStream(java.util.stream.DoubleStream) PointSample(com.simiacryptus.mindseye.lang.PointSample) Map(java.util.Map) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) 
DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) Nullable(javax.annotation.Nullable)

Example 19 with TrainingMonitor

Use of com.simiacryptus.mindseye.opt.TrainingMonitor in project MindsEye by SimiaCryptus.

Class RecursiveSubspace, method train.

/**
 * Optimizes the trainable state of {@code macroLayer} with an inner L-BFGS / Armijo-Wolfe trainer.
 * <p>
 * The layer is evaluated over a single sample that contains no input tensors, so it must compute its
 * objective without data inputs. Iteration count and termination threshold come from
 * {@code getIterations()} and {@code terminateThreshold}. Log messages from the inner trainer are
 * forwarded to {@code monitor}, prefixed with a tab.
 *
 * @param monitor    receives the inner trainer's log output
 * @param macroLayer the layer to optimize
 */
public void train(@Nonnull TrainingMonitor monitor, Layer macroLayer) {
    @Nonnull BasicTrainable inner = new BasicTrainable(macroLayer);
    // A single sample with zero input tensors.
    @Nonnull ArrayTrainable trainable = new ArrayTrainable(inner, new Tensor[][] { {} });
    inner.freeRef();
    try {
        new IterativeTrainer(trainable).setOrientation(new LBFGS()).setLineSearchFactory(n -> new ArmijoWolfeSearch()).setMonitor(new TrainingMonitor() {

            @Override
            public void log(String msg) {
                monitor.log("\t" + msg);
            }
        }).setMaxIterations(getIterations()).setIterationsPerSample(getIterations()).setTerminateThreshold(terminateThreshold).runAndFree();
    } finally {
        // Release our reference even if training throws, so the trainable is not leaked.
        trainable.freeRef();
    }
}
Also used : BasicTrainable(com.simiacryptus.mindseye.eval.BasicTrainable) IntStream(java.util.stream.IntStream) JsonObject(com.google.gson.JsonObject) Arrays(java.util.Arrays) Tensor(com.simiacryptus.mindseye.lang.Tensor) Result(com.simiacryptus.mindseye.lang.Result) ArmijoWolfeSearch(com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) StateSet(com.simiacryptus.mindseye.lang.StateSet) Trainable(com.simiacryptus.mindseye.eval.Trainable) Delta(com.simiacryptus.mindseye.lang.Delta) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Map(java.util.Map) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) Layer(com.simiacryptus.mindseye.lang.Layer) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) BasicTrainable(com.simiacryptus.mindseye.eval.BasicTrainable) Collectors(java.util.stream.Collectors) DoubleStream(java.util.stream.DoubleStream) List(java.util.List) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) TensorList(com.simiacryptus.mindseye.lang.TensorList) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) PointSample(com.simiacryptus.mindseye.lang.PointSample) IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) ArmijoWolfeSearch(com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch) Nonnull(javax.annotation.Nonnull) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable)

Aggregations

TrainingMonitor (com.simiacryptus.mindseye.opt.TrainingMonitor)19 Nonnull (javax.annotation.Nonnull)19 Layer (com.simiacryptus.mindseye.lang.Layer)13 List (java.util.List)12 Nullable (javax.annotation.Nullable)12 PointSample (com.simiacryptus.mindseye.lang.PointSample)10 Tensor (com.simiacryptus.mindseye.lang.Tensor)10 Arrays (java.util.Arrays)10 DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet)9 Trainable (com.simiacryptus.mindseye.eval.Trainable)8 StepRecord (com.simiacryptus.mindseye.test.StepRecord)8 IntStream (java.util.stream.IntStream)8 ArrayTrainable (com.simiacryptus.mindseye.eval.ArrayTrainable)7 TensorList (com.simiacryptus.mindseye.lang.TensorList)7 ArrayList (java.util.ArrayList)7 Result (com.simiacryptus.mindseye.lang.Result)6 StateSet (com.simiacryptus.mindseye.lang.StateSet)6 IterativeTrainer (com.simiacryptus.mindseye.opt.IterativeTrainer)6 ConstantResult (com.simiacryptus.mindseye.lang.ConstantResult)5 TensorArray (com.simiacryptus.mindseye.lang.TensorArray)5