
Example 1 with GradientDescent

use of com.simiacryptus.mindseye.opt.orient.GradientDescent in project MindsEye by SimiaCryptus.

the class SimpleStochasticGradientDescentTest method train.

@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
    log.p("Training a model involves a few different components. First, our model is combined mapCoords a loss function. " + "Then we take that model and combine it mapCoords our training data to define a trainable object. " + "Finally, we use a simple iterative scheme to refine the weights of our model. " + "The final output is the last output value of the loss function when evaluating the last batch.");
    log.code(() -> {
        @Nonnull final SimpleLossNetwork supervisedNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
        @Nonnull final Trainable trainable = new SampledArrayTrainable(trainingData, supervisedNetwork, 10000);
        return new IterativeTrainer(trainable)
            .setMonitor(monitor)
            .setOrientation(new GradientDescent())
            .setTimeout(5, TimeUnit.MINUTES)
            .setMaxIterations(500)
            .runAndFree();
    });
}
Also used : IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) GradientDescent(com.simiacryptus.mindseye.opt.orient.GradientDescent) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) SimpleLossNetwork(com.simiacryptus.mindseye.network.SimpleLossNetwork) Trainable(com.simiacryptus.mindseye.eval.Trainable)
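
Each example passes a TrainingMonitor but never constructs one. Below is a minimal sketch of a monitor that records per-step fitness into the StepRecord history that TrainingTester.getMonitor builds elsewhere on this page; the Step accessors (point, time, iteration) and the StepRecord constructor signature are assumptions drawn from context, not verified API.

import com.simiacryptus.mindseye.opt.Step;
import com.simiacryptus.mindseye.opt.TrainingMonitor;
import com.simiacryptus.mindseye.test.StepRecord;
import java.util.List;

// Minimal sketch: record (fitness, wall-clock time, iteration) after every optimizer step.
// Assumes Step exposes point.getMean(), time, and iteration; adjust to the actual API.
public static TrainingMonitor recordingMonitor(final List<StepRecord> history) {
    return new TrainingMonitor() {
        @Override
        public void log(final String msg) {
            System.out.println(msg); // forward the optimizer's log lines
        }

        @Override
        public void onStepComplete(final Step currentPoint) {
            history.add(new StepRecord(currentPoint.point.getMean(), currentPoint.time, currentPoint.iteration));
        }
    };
}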

Example 2 with GradientDescent

use of com.simiacryptus.mindseye.opt.orient.GradientDescent in project MindsEye by SimiaCryptus.

the class BisectionLineSearchTest method train.

@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
    log.code(() -> {
        @Nonnull final SimpleLossNetwork supervisedNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
        @Nonnull final Trainable trainable = new SampledArrayTrainable(trainingData, supervisedNetwork, 1000);
        return new IterativeTrainer(trainable)
            .setMonitor(monitor)
            .setOrientation(new GradientDescent())
            .setLineSearchFactory((@Nonnull final CharSequence name) -> new BisectionSearch())
            .setTimeout(3, TimeUnit.MINUTES)
            .setMaxIterations(500)
            .runAndFree();
    });
}
Also used : IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) GradientDescent(com.simiacryptus.mindseye.opt.orient.GradientDescent) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) SimpleLossNetwork(com.simiacryptus.mindseye.network.SimpleLossNetwork) Trainable(com.simiacryptus.mindseye.eval.Trainable)
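
BisectionSearch locates a step size by repeatedly halving a bracketing interval. The following is a self-contained sketch of that idea on a one-dimensional objective, shown for intuition only; it is not MindsEye's BisectionSearch implementation. It assumes the directional derivative is negative at the left endpoint and positive at the right, so a minimum is bracketed.

import java.util.function.DoubleUnaryOperator;

// Illustrative bisection on the derivative's sign: find a step t where phi'(t) ~ 0,
// given phi'(lo) < 0 and phi'(hi) > 0.
static double bisectStep(DoubleUnaryOperator derivative, double lo, double hi, double tol) {
    while (hi - lo > tol) {
        double mid = 0.5 * (lo + hi);
        if (derivative.applyAsDouble(mid) < 0) {
            lo = mid; // still descending: the minimum lies to the right
        } else {
            hi = mid; // ascending: the minimum lies to the left
        }
    }
    return 0.5 * (lo + hi);
}

For example, with phi(t) = (t - 2)^2 the derivative is t -> 2 * (t - 2), and bisectStep over the bracket [0, 10] converges to t ≈ 2.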

Example 3 with GradientDescent

use of com.simiacryptus.mindseye.opt.orient.GradientDescent in project MindsEye by SimiaCryptus.

the class TrainingTester method trainGD.

/**
 * Trains using plain gradient descent and records the optimization history.
 *
 * @param log       the log
 * @param trainable the trainable
 * @return the list of recorded steps
 */
@Nonnull
public List<StepRecord> trainGD(@Nonnull final NotebookOutput log, final Trainable trainable) {
    log.p("First, we train using basic gradient descent method apply weak line search conditions.");
    @Nonnull final List<StepRecord> history = new ArrayList<>();
    @Nonnull final TrainingMonitor monitor = TrainingTester.getMonitor(history);
    try {
        log.code(() -> {
            return new IterativeTrainer(trainable)
                .setLineSearchFactory(label -> new ArmijoWolfeSearch())
                .setOrientation(new GradientDescent())
                .setMonitor(monitor)
                .setTimeout(30, TimeUnit.SECONDS)
                .setMaxIterations(250)
                .setTerminateThreshold(0)
                .runAndFree();
        });
    } catch (Throwable e) {
        // Training failures are swallowed here unless the tester is configured to rethrow.
        if (isThrowExceptions())
            throw new RuntimeException(e);
    }
    return history;
}
Also used : StepRecord(com.simiacryptus.mindseye.test.StepRecord) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) ArmijoWolfeSearch(com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch) Nonnull(javax.annotation.Nonnull) ArrayList(java.util.ArrayList) GradientDescent(com.simiacryptus.mindseye.opt.orient.GradientDescent)
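
The ArmijoWolfeSearch used above accepts a step only when it satisfies the weak line search conditions the text mentions. A sketch of the two acceptance tests follows; the constants c1 = 1e-4 and c2 = 0.9 are the conventional textbook choices, not necessarily the library's defaults.

// Sufficient-decrease (Armijo) and curvature (Wolfe) tests for a step of size alpha.
// f0, g0: value and directional derivative at the current point (g0 < 0 for a descent direction);
// fAlpha, gAlpha: value and directional derivative at the trial point.
static boolean satisfiesArmijo(double f0, double g0, double fAlpha, double alpha, double c1) {
    return fAlpha <= f0 + c1 * alpha * g0; // the step must reduce f at least proportionally to alpha
}

static boolean satisfiesWolfe(double g0, double gAlpha, double c2) {
    return gAlpha >= c2 * g0; // the slope must have flattened enough to rule out tiny steps
}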

Example 4 with GradientDescent

use of com.simiacryptus.mindseye.opt.orient.GradientDescent in project MindsEye by SimiaCryptus.

the class ImageDecompositionLab method train.

/**
 * Trains the network on the given data, optionally masking selected inputs.
 *
 * @param log            the log
 * @param monitor        the monitor
 * @param network        the network
 * @param data           the data
 * @param timeoutMinutes the timeout in minutes
 * @param mask           the input mask
 */
protected void train(@Nonnull final NotebookOutput log, final TrainingMonitor monitor, final Layer network, @Nonnull final Tensor[][] data, final int timeoutMinutes, final boolean... mask) {
    log.out("Training for %s minutes, mask=%s", timeoutMinutes, Arrays.toString(mask));
    log.code(() -> {
        @Nonnull SampledTrainable trainingSubject = new SampledArrayTrainable(data, network, data.length);
        trainingSubject = (SampledTrainable) ((TrainableDataMask) trainingSubject).setMask(mask);
        @Nonnull final ValidatingTrainer validatingTrainer = new ValidatingTrainer(trainingSubject, new ArrayTrainable(data, network))
            .setMaxTrainingSize(data.length)
            .setMinTrainingSize(5)
            .setMonitor(monitor)
            .setTimeout(timeoutMinutes, TimeUnit.MINUTES)
            .setMaxIterations(1000);
        // Both branches of the name check yield the same search, so a single factory suffices.
        validatingTrainer.getRegimen().get(0)
            .setOrientation(new GradientDescent())
            .setLineSearchFactory(name -> new QuadraticSearch().setCurrentRate(1.0));
        validatingTrainer.run();
    });
}
Also used : TrainableDataMask(com.simiacryptus.mindseye.eval.TrainableDataMask) SampledTrainable(com.simiacryptus.mindseye.eval.SampledTrainable) Nonnull(javax.annotation.Nonnull) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) QuadraticSearch(com.simiacryptus.mindseye.opt.line.QuadraticSearch) GradientDescent(com.simiacryptus.mindseye.opt.orient.GradientDescent) ValidatingTrainer(com.simiacryptus.mindseye.opt.ValidatingTrainer) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable)
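
QuadraticSearch picks a step by fitting a parabola to the objective along the search direction (the setCurrentRate(1.0) call above appears to seed the initial trial step). Here is a generic sketch of that one-dimensional model, for intuition rather than as MindsEye's actual internals: given phi(0), phi'(0), and one probe phi(a), step to the minimizer of the interpolating quadratic.

// Fit q(t) = f0 + g0*t + c*t^2 through phi(0) = f0, phi'(0) = g0, phi(a) = fa,
// then return its minimizer t* = -g0 / (2c). Illustrative only.
static double quadraticStep(double f0, double g0, double a, double fa) {
    double c = (fa - f0 - g0 * a) / (a * a); // curvature of the interpolating parabola
    if (c <= 0) return a; // no interior minimum; fall back to the probe step
    return -g0 / (2 * c);
}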

Example 5 with GradientDescent

use of com.simiacryptus.mindseye.opt.orient.GradientDescent in project MindsEye by SimiaCryptus.

the class TrainingTester method trainCjGD.

/**
 * Trains using the conjugate-gradient-style configuration and records the optimization history.
 *
 * @param log       the log
 * @param trainable the trainable
 * @return the list of recorded steps
 */
@Nonnull
public List<StepRecord> trainCjGD(@Nonnull final NotebookOutput log, final Trainable trainable) {
    log.p("First, we use a conjugate gradient descent method, which converges the fastest for purely linear functions.");
    @Nonnull final List<StepRecord> history = new ArrayList<>();
    @Nonnull final TrainingMonitor monitor = TrainingTester.getMonitor(history);
    try {
        log.code(() -> {
            return new IterativeTrainer(trainable)
                .setLineSearchFactory(label -> new QuadraticSearch())
                .setOrientation(new GradientDescent())
                .setMonitor(monitor)
                .setTimeout(30, TimeUnit.SECONDS)
                .setMaxIterations(250)
                .setTerminateThreshold(0)
                .runAndFree();
        });
    } catch (Throwable e) {
        // Training failures are swallowed here unless the tester is configured to rethrow.
        if (isThrowExceptions())
            throw new RuntimeException(e);
    }
    return history;
}
Also used : StepRecord(com.simiacryptus.mindseye.test.StepRecord) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) QuadraticSearch(com.simiacryptus.mindseye.opt.line.QuadraticSearch) ArrayList(java.util.ArrayList) GradientDescent(com.simiacryptus.mindseye.opt.orient.GradientDescent)
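
Note that the prose says "conjugate gradient" while the orientation is plain GradientDescent, i.e. steepest descent along the negative gradient. For contrast, here is a generic sketch of a nonlinear conjugate gradient direction update using the Polak-Ribiere coefficient; it illustrates the idea only and is not drawn from MindsEye.

// Polak-Ribiere nonlinear conjugate gradient: d_new = -g_new + beta * d_old,
// with beta = max(0, g_new . (g_new - g_old) / (g_old . g_old)).
// Assumes gOld is nonzero (i.e. the previous point was not stationary).
static double[] conjugateDirection(double[] gNew, double[] gOld, double[] dOld) {
    double num = 0, den = 0;
    for (int i = 0; i < gNew.length; i++) {
        num += gNew[i] * (gNew[i] - gOld[i]);
        den += gOld[i] * gOld[i];
    }
    double beta = Math.max(0, num / den); // a restart (beta = 0) reduces to steepest descent
    double[] d = new double[gNew.length];
    for (int i = 0; i < gNew.length; i++) {
        d[i] = -gNew[i] + beta * dOld[i];
    }
    return d;
}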

Aggregations

GradientDescent (com.simiacryptus.mindseye.opt.orient.GradientDescent): 7
Nonnull (javax.annotation.Nonnull): 7
IterativeTrainer (com.simiacryptus.mindseye.opt.IterativeTrainer): 6
SampledArrayTrainable (com.simiacryptus.mindseye.eval.SampledArrayTrainable): 5
Trainable (com.simiacryptus.mindseye.eval.Trainable): 4
EntropyLossLayer (com.simiacryptus.mindseye.layers.java.EntropyLossLayer): 4
SimpleLossNetwork (com.simiacryptus.mindseye.network.SimpleLossNetwork): 4
TrainingMonitor (com.simiacryptus.mindseye.opt.TrainingMonitor): 2
QuadraticSearch (com.simiacryptus.mindseye.opt.line.QuadraticSearch): 2
StepRecord (com.simiacryptus.mindseye.test.StepRecord): 2
ArrayList (java.util.ArrayList): 2
ArrayTrainable (com.simiacryptus.mindseye.eval.ArrayTrainable): 1
SampledTrainable (com.simiacryptus.mindseye.eval.SampledTrainable): 1
TrainableDataMask (com.simiacryptus.mindseye.eval.TrainableDataMask): 1
ValidatingTrainer (com.simiacryptus.mindseye.opt.ValidatingTrainer): 1
ArmijoWolfeSearch (com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch): 1