Search in sources:

Example 16 with Trainable

use of com.simiacryptus.mindseye.eval.Trainable in project MindsEye by SimiaCryptus.

the class TrainingTester method trainLBFGS.

/**
 * Optimizes the given trainable with an L-BFGS orientation and an Armijo-Wolfe
 * line search, collecting the step records produced along the way.
 *
 * <p>The trainer is configured with a 30 second timeout, 100 iterations per
 * sample, at most 250 iterations and a terminate threshold of 0. Any failure
 * raised while training is rethrown wrapped in a {@link RuntimeException} only
 * when {@code isThrowExceptions()} is true; otherwise it is dropped and the
 * records gathered so far are returned.
 *
 * @param log       the notebook output that receives the description and the training code block
 * @param trainable the trainable to optimize
 * @return the list handed to {@code TrainingTester.getMonitor}, which is expected to fill it during training
 */
@Nonnull
public List<StepRecord> trainLBFGS(@Nonnull final NotebookOutput log, final Trainable trainable) {
    log.p("Next, we apply the same optimization using L-BFGS, which is nearly ideal for purely second-order or quadratic functions.");
    @Nonnull final List<StepRecord> steps = new ArrayList<>();
    @Nonnull final TrainingMonitor stepMonitor = TrainingTester.getMonitor(steps);
    try {
        log.code(() -> {
            return new IterativeTrainer(trainable)
                .setLineSearchFactory(label -> new ArmijoWolfeSearch())
                .setOrientation(new LBFGS())
                .setMonitor(stepMonitor)
                .setTimeout(30, TimeUnit.SECONDS)
                .setIterationsPerSample(100)
                .setMaxIterations(250)
                .setTerminateThreshold(0)
                .runAndFree();
        });
    } catch (Throwable failure) {
        // Best-effort run: failures only propagate when the tester is configured to throw.
        if (isThrowExceptions()) {
            throw new RuntimeException(failure);
        }
    }
    return steps;
}
Also used : StepRecord(com.simiacryptus.mindseye.test.StepRecord) PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) BiFunction(java.util.function.BiFunction) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) HashMap(java.util.HashMap) Random(java.util.Random) Result(com.simiacryptus.mindseye.lang.Result) ArmijoWolfeSearch(com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch) ArrayList(java.util.ArrayList) Trainable(com.simiacryptus.mindseye.eval.Trainable) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) QuadraticSearch(com.simiacryptus.mindseye.opt.line.QuadraticSearch) LBFGS(com.simiacryptus.mindseye.opt.orient.LBFGS) RecursiveSubspace(com.simiacryptus.mindseye.opt.orient.RecursiveSubspace) StepRecord(com.simiacryptus.mindseye.test.StepRecord) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) ReferenceCounting(com.simiacryptus.mindseye.lang.ReferenceCounting) IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) Logger(org.slf4j.Logger) PlotCanvas(smile.plot.PlotCanvas) QQN(com.simiacryptus.mindseye.opt.orient.QQN) GradientDescent(com.simiacryptus.mindseye.opt.orient.GradientDescent) BasicTrainable(com.simiacryptus.mindseye.eval.BasicTrainable) StaticLearningRate(com.simiacryptus.mindseye.opt.line.StaticLearningRate) TestUtil(com.simiacryptus.mindseye.test.TestUtil) DAGNode(com.simiacryptus.mindseye.network.DAGNode) DoubleStream(java.util.stream.DoubleStream) java.awt(java.awt) TimeUnit(java.util.concurrent.TimeUnit) List(java.util.List) Stream(java.util.stream.Stream) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) 
TensorList(com.simiacryptus.mindseye.lang.TensorList) Step(com.simiacryptus.mindseye.opt.Step) ProblemRun(com.simiacryptus.mindseye.test.ProblemRun) javax.swing(javax.swing) LBFGS(com.simiacryptus.mindseye.opt.orient.LBFGS) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) ArmijoWolfeSearch(com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch) Nonnull(javax.annotation.Nonnull) ArrayList(java.util.ArrayList) Nonnull(javax.annotation.Nonnull)

Example 17 with Trainable

use of com.simiacryptus.mindseye.eval.Trainable in project MindsEye by SimiaCryptus.

the class TrainingTester method train.

/**
 * Wraps {@code layer} in a mean-squared-loss network, runs the supplied optimizer
 * against it, and logs a summary of the outcome.
 *
 * <p>The network has {@code data[0].length} inputs: all but the last are fed to
 * {@code layer}, and the last is compared with the layer's output by a
 * {@code MeanSqLossLayer}. This method releases (via {@code freeRef}) the trainable
 * and network it creates, and also {@code layer} and every tensor in {@code data},
 * whether or not training succeeds.
 *
 * @param log   the notebook output receiving the report
 * @param opt   the optimization routine; invoked with {@code log} and the constructed trainable
 * @param layer the layer under test; freed before this method returns
 * @param data  the training rows; every tensor is freed before this method returns
 * @param mask  optional mask forwarded to {@code ArrayTrainable.setMask} when non-empty
 * @return the step history returned by {@code opt}
 */
private List<StepRecord> train(@Nonnull NotebookOutput log, @Nonnull BiFunction<NotebookOutput, Trainable, List<StepRecord>> opt, @Nonnull Layer layer, @Nonnull Tensor[][] data, @Nonnull boolean... mask) {
    try {
        // Number of network inputs is taken from the width of the first data row.
        int inputs = data[0].length;
        @Nonnull final PipelineNetwork network = new PipelineNetwork(inputs);
        // Inputs 0..inputs-2 feed the layer; input inputs-1 is the second operand of the loss.
        network.wrap(new MeanSqLossLayer(), network.add(layer, IntStream.range(0, inputs - 1).mapToObj(i -> network.getInput(i)).toArray(i -> new DAGNode[i])), network.getInput(inputs - 1));
        @Nonnull ArrayTrainable trainable = new ArrayTrainable(data, network);
        // Only apply a mask when the caller actually supplied one.
        if (0 < mask.length)
            trainable.setMask(mask);
        List<StepRecord> history;
        try {
            history = opt.apply(log, trainable);
            // Best (minimum) fitness over the run; an empty history counts as 1, i.e. not converged.
            if (history.stream().mapToDouble(x -> x.fitness).min().orElse(1) > 1e-5) {
                if (!network.isFrozen()) {
                    log.p("This training apply resulted in the following configuration:");
                    log.code(() -> {
                        // One line per state array of the network.
                        return network.state().stream().map(Arrays::toString).reduce((a, b) -> a + "\n" + b).orElse("");
                    });
                }
                if (0 < mask.length) {
                    log.p("And regressed input:");
                    log.code(() -> {
                        // Only the first tensor of the flattened data is printed.
                        return Arrays.stream(data).flatMap(x -> Arrays.stream(x)).limit(1).map(x -> x.prettyPrint()).reduce((a, b) -> a + "\n" + b).orElse("");
                    });
                }
                log.p("To produce the following output:");
                log.code(() -> {
                    // NOTE(review): pop(data) is defined elsewhere; presumably it drops the
                    // trailing target column so only the layer's inputs remain — confirm.
                    Result[] array = ConstantResult.batchResultArray(pop(data));
                    @Nullable Result eval = layer.eval(array);
                    // Release the temporary input results and their data once evaluation is done.
                    for (@Nonnull Result result : array) {
                        result.freeRef();
                        result.getData().freeRef();
                    }
                    TensorList tensorList = eval.getData();
                    eval.freeRef();
                    // Print only the first output tensor, freeing each tensor after it is rendered.
                    String str = tensorList.stream().limit(1).map(x -> {
                        String s = x.prettyPrint();
                        x.freeRef();
                        return s;
                    }).reduce((a, b) -> a + "\n" + b).orElse("");
                    tensorList.freeRef();
                    return str;
                });
            } else {
                log.p("Training Converged");
            }
        } finally {
            // Always release the objects created by this method.
            trainable.freeRef();
            network.freeRef();
        }
        return history;
    } finally {
        // The caller's layer and data tensors are released here as well.
        layer.freeRef();
        for (@Nonnull Tensor[] tensors : data) {
            for (@Nonnull Tensor tensor : tensors) {
                tensor.freeRef();
            }
        }
    }
}
Also used : PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) BiFunction(java.util.function.BiFunction) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) HashMap(java.util.HashMap) Random(java.util.Random) Result(com.simiacryptus.mindseye.lang.Result) ArmijoWolfeSearch(com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch) ArrayList(java.util.ArrayList) Trainable(com.simiacryptus.mindseye.eval.Trainable) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) QuadraticSearch(com.simiacryptus.mindseye.opt.line.QuadraticSearch) LBFGS(com.simiacryptus.mindseye.opt.orient.LBFGS) RecursiveSubspace(com.simiacryptus.mindseye.opt.orient.RecursiveSubspace) StepRecord(com.simiacryptus.mindseye.test.StepRecord) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) ReferenceCounting(com.simiacryptus.mindseye.lang.ReferenceCounting) IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) Logger(org.slf4j.Logger) PlotCanvas(smile.plot.PlotCanvas) QQN(com.simiacryptus.mindseye.opt.orient.QQN) GradientDescent(com.simiacryptus.mindseye.opt.orient.GradientDescent) BasicTrainable(com.simiacryptus.mindseye.eval.BasicTrainable) StaticLearningRate(com.simiacryptus.mindseye.opt.line.StaticLearningRate) TestUtil(com.simiacryptus.mindseye.test.TestUtil) DAGNode(com.simiacryptus.mindseye.network.DAGNode) DoubleStream(java.util.stream.DoubleStream) java.awt(java.awt) TimeUnit(java.util.concurrent.TimeUnit) List(java.util.List) Stream(java.util.stream.Stream) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) TensorList(com.simiacryptus.mindseye.lang.TensorList) 
Step(com.simiacryptus.mindseye.opt.Step) ProblemRun(com.simiacryptus.mindseye.test.ProblemRun) javax.swing(javax.swing) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) TensorList(com.simiacryptus.mindseye.lang.TensorList) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) Result(com.simiacryptus.mindseye.lang.Result) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) StepRecord(com.simiacryptus.mindseye.test.StepRecord) Arrays(java.util.Arrays) Nullable(javax.annotation.Nullable)

Example 18 with Trainable

use of com.simiacryptus.mindseye.eval.Trainable in project MindsEye by SimiaCryptus.

the class TrustSphereTest method train.

/**
 * Trains {@code network} against an entropy loss, using a trust-region strategy
 * that assigns an {@code AdaptiveTrustSphere} to every layer.
 *
 * <p>Training samples 10000 items per batch, with 100 iterations per sample, a
 * 3 minute timeout and at most 500 iterations.
 *
 * @param log          the notebook output that executes and records the training code
 * @param network      the layer to train
 * @param trainingData the training rows
 * @param monitor      the monitor attached to the trainer
 */
@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
    log.code(() -> {
        @Nonnull final SimpleLossNetwork lossNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
        @Nonnull final Trainable sampledTrainable = new SampledArrayTrainable(trainingData, lossNetwork, 10000);
        // Every layer gets the same region policy: a fresh adaptive trust sphere.
        @Nonnull final TrustRegionStrategy sphereStrategy = new TrustRegionStrategy() {

            @Override
            public TrustRegion getRegionPolicy(final Layer layer) {
                return new AdaptiveTrustSphere();
            }
        };
        return new IterativeTrainer(sampledTrainable)
            .setIterationsPerSample(100)
            .setMonitor(monitor)
            .setOrientation(sphereStrategy)
            .setTimeout(3, TimeUnit.MINUTES)
            .setMaxIterations(500)
            .runAndFree();
    });
}
Also used : IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) SimpleLossNetwork(com.simiacryptus.mindseye.network.SimpleLossNetwork) Trainable(com.simiacryptus.mindseye.eval.Trainable) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) Layer(com.simiacryptus.mindseye.lang.Layer) TrustRegionStrategy(com.simiacryptus.mindseye.opt.orient.TrustRegionStrategy)

Example 19 with Trainable

use of com.simiacryptus.mindseye.eval.Trainable in project MindsEye by SimiaCryptus.

the class QuadraticLineSearchTest method train.

/**
 * Trains {@code network} against an entropy loss using plain gradient descent
 * combined with a quadratic line search.
 *
 * <p>Training samples 1000 items per batch, with a 3 minute timeout and at most
 * 500 iterations.
 *
 * @param log          the notebook output that executes and records the training code
 * @param network      the layer to train
 * @param trainingData the training rows
 * @param monitor      the monitor attached to the trainer
 */
@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
    log.code(() -> {
        @Nonnull final SimpleLossNetwork lossNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
        @Nonnull final Trainable sampledTrainable = new SampledArrayTrainable(trainingData, lossNetwork, 1000);
        return new IterativeTrainer(sampledTrainable)
            .setMonitor(monitor)
            .setOrientation(new GradientDescent())
            // The search name is ignored: every line search is a new QuadraticSearch.
            .setLineSearchFactory((@Nonnull final CharSequence label) -> new QuadraticSearch())
            .setTimeout(3, TimeUnit.MINUTES)
            .setMaxIterations(500)
            .runAndFree();
    });
}
Also used : IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) GradientDescent(com.simiacryptus.mindseye.opt.orient.GradientDescent) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) SimpleLossNetwork(com.simiacryptus.mindseye.network.SimpleLossNetwork) Trainable(com.simiacryptus.mindseye.eval.Trainable) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable)

Example 20 with Trainable

use of com.simiacryptus.mindseye.eval.Trainable in project MindsEye by SimiaCryptus.

the class OwlQn method orient.

/**
 * Produces an "OWL/QN" line-search cursor by adjusting the direction computed by
 * the {@code inner} orientation strategy with an L1 term, and by clamping weights
 * to zero when a step would flip their sign.
 *
 * <p>For each layer returned by {@code getLayers(...)}, every component of the
 * copied search direction is offset by {@code factor_L1} times the sign of the
 * weight (weights that are not negative count as positive). If the offset changes
 * the component's sign relative to the inner direction, the component is reset to
 * the inner direction's value.
 *
 * <p>NOTE(review): {@code sign} is defined elsewhere; this documentation assumes it
 * returns a negative, zero or positive int mirroring its argument's sign — confirm.
 *
 * @param subject     the trainable being optimized; re-measured on each step
 * @param measurement the current point sample
 * @param monitor     the training monitor forwarded to the inner strategy
 * @return a cursor whose {@code step} keeps weights from crossing zero
 */
@Nonnull
@Override
public LineSearchCursor orient(final Trainable subject, @Nonnull final PointSample measurement, final TrainingMonitor monitor) {
    @Nonnull final SimpleLineSearchCursor gradient = (SimpleLineSearchCursor) inner.orient(subject, measurement, monitor);
    @Nonnull final DeltaSet<Layer> searchDirection = gradient.direction.copy();
    @Nonnull final DeltaSet<Layer> orthant = new DeltaSet<Layer>();
    for (@Nonnull final Layer layer : getLayers(gradient.direction.getMap().keySet())) {
        final double[] weights = gradient.direction.getMap().get(layer).target;
        @Nullable final double[] delta = gradient.direction.getMap().get(layer).getDelta();
        @Nullable final double[] searchDir = searchDirection.get(layer, weights).getDelta();
        @Nullable final double[] suborthant = orthant.get(layer, weights).getDelta();
        // Check before the first dereference; after the loop this assertion could never
        // fire, because a null array would already have raised a NullPointerException.
        assert null != searchDir;
        for (int i = 0; i < searchDir.length; i++) {
            final int positionSign = sign(weights[i]);
            final int directionSign = sign(delta[i]);
            // Orthant of a component: the weight's sign, or the direction's sign when the weight is zero.
            suborthant[i] = 0 == positionSign ? directionSign : positionSign;
            // L1 contribution: factor_L1 times the sign of the weight.
            searchDir[i] += factor_L1 * (weights[i] < 0 ? -1.0 : 1.0);
            // If the L1 term changed the component's sign, fall back to the inner direction.
            if (sign(searchDir[i]) != directionSign) {
                searchDir[i] = delta[i];
            }
        }
    }
    return new SimpleLineSearchCursor(subject, measurement, searchDirection) {

        /**
         * Restores the origin weights, moves them by {@code alpha} along the
         * direction while clamping any weight that would change sign to zero, and
         * measures the subject at the resulting point.
         *
         * @param alpha   the step size along the search direction
         * @param monitor the training monitor passed to {@code subject.measure}
         * @return the measured point, paired with the dot product of the clamped
         *         direction and the measured delta
         */
        @Nonnull
        @Override
        public LineSearchPoint step(final double alpha, final TrainingMonitor monitor) {
            origin.weights.stream().forEach(d -> d.restore());
            // Work on a copy so zeroing clamped components does not alter the cursor's direction.
            @Nonnull final DeltaSet<Layer> currentDirection = direction.copy();
            direction.getMap().forEach((layer, buffer) -> {
                if (null == buffer.getDelta())
                    return;
                @Nullable final double[] currentDelta = currentDirection.get(layer, buffer.target).getDelta();
                for (int i = 0; i < buffer.getDelta().length; i++) {
                    final double prevValue = buffer.target[i];
                    final double newValue = prevValue + buffer.getDelta()[i] * alpha;
                    if (sign(prevValue) != 0 && sign(prevValue) != sign(newValue)) {
                        // A non-zero weight would change sign: pin it at zero and drop
                        // that component from the effective direction.
                        currentDelta[i] = 0;
                        buffer.target[i] = 0;
                    } else {
                        buffer.target[i] = newValue;
                    }
                }
            });
            @Nonnull final PointSample measure = subject.measure(monitor).setRate(alpha);
            return new LineSearchPoint(measure, currentDirection.dot(measure.delta));
        }
    }.setDirectionType("OWL/QN");
}
Also used : LineSearchPoint(com.simiacryptus.mindseye.opt.line.LineSearchPoint) Collection(java.util.Collection) Collectors(java.util.stream.Collectors) Trainable(com.simiacryptus.mindseye.eval.Trainable) FullyConnectedLayer(com.simiacryptus.mindseye.layers.java.FullyConnectedLayer) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Layer(com.simiacryptus.mindseye.lang.Layer) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) LineSearchCursor(com.simiacryptus.mindseye.opt.line.LineSearchCursor) Nonnull(javax.annotation.Nonnull) PointSample(com.simiacryptus.mindseye.lang.PointSample) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) FullyConnectedLayer(com.simiacryptus.mindseye.layers.java.FullyConnectedLayer) Layer(com.simiacryptus.mindseye.lang.Layer) LineSearchPoint(com.simiacryptus.mindseye.opt.line.LineSearchPoint) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) LineSearchPoint(com.simiacryptus.mindseye.opt.line.LineSearchPoint) PointSample(com.simiacryptus.mindseye.lang.PointSample) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull)

Aggregations

Trainable (com.simiacryptus.mindseye.eval.Trainable)25 Nonnull (javax.annotation.Nonnull)25 IterativeTrainer (com.simiacryptus.mindseye.opt.IterativeTrainer)22 Layer (com.simiacryptus.mindseye.lang.Layer)15 EntropyLossLayer (com.simiacryptus.mindseye.layers.java.EntropyLossLayer)13 SimpleLossNetwork (com.simiacryptus.mindseye.network.SimpleLossNetwork)13 SampledArrayTrainable (com.simiacryptus.mindseye.eval.SampledArrayTrainable)12 ArrayTrainable (com.simiacryptus.mindseye.eval.ArrayTrainable)10 Tensor (com.simiacryptus.mindseye.lang.Tensor)9 List (java.util.List)9 TrainingMonitor (com.simiacryptus.mindseye.opt.TrainingMonitor)8 ArrayList (java.util.ArrayList)8 Map (java.util.Map)8 PipelineNetwork (com.simiacryptus.mindseye.network.PipelineNetwork)7 ArmijoWolfeSearch (com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch)7 GradientDescent (com.simiacryptus.mindseye.opt.orient.GradientDescent)7 StepRecord (com.simiacryptus.mindseye.test.StepRecord)7 Arrays (java.util.Arrays)7 IntStream (java.util.stream.IntStream)7 DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet)5