use of com.simiacryptus.mindseye.opt.orient.GradientDescent in project MindsEye by SimiaCryptus.
the class SimpleStochasticGradientDescentTest method train.
@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
log.p("Training a model involves a few different components. First, our model is combined mapCoords a loss function. " + "Then we take that model and combine it mapCoords our training data to define a trainable object. " + "Finally, we use a simple iterative scheme to refine the weights of our model. " + "The final output is the last output value of the loss function when evaluating the last batch.");
log.code(() -> {
@Nonnull final SimpleLossNetwork supervisedNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
@Nonnull final Trainable trainable = new SampledArrayTrainable(trainingData, supervisedNetwork, 10000);
return new IterativeTrainer(trainable).setMonitor(monitor).setOrientation(new GradientDescent()).setTimeout(5, TimeUnit.MINUTES).setMaxIterations(500).runAndFree();
});
}
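All of these examples thread a TrainingMonitor through the trainer. A minimal sketch of one, assuming the overridable log and onStepComplete hooks from com.simiacryptus.mindseye.opt (the iteration field on Step is taken from the MindsEye sources and may differ between versions):
final TrainingMonitor monitor = new TrainingMonitor() {
  @Override
  public void log(final String msg) {
    // Forward trainer messages to the console.
    System.out.println(msg);
  }

  @Override
  public void onStepComplete(final Step currentPoint) {
    // Called after each optimization step completes.
    System.out.println("completed iteration " + currentPoint.iteration);
  }
};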
use of com.simiacryptus.mindseye.opt.orient.GradientDescent in project MindsEye by SimiaCryptus.
the class BisectionLineSearchTest method train.
@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
  log.code(() -> {
    @Nonnull final SimpleLossNetwork supervisedNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
    @Nonnull final Trainable trainable = new SampledArrayTrainable(trainingData, supervisedNetwork, 1000);
    return new IterativeTrainer(trainable)
        .setMonitor(monitor)
        .setOrientation(new GradientDescent())
        .setLineSearchFactory((@Nonnull final CharSequence name) -> new BisectionSearch())
        .setTimeout(3, TimeUnit.MINUTES)
        .setMaxIterations(500)
        .runAndFree();
  });
}
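The line search strategy is pluggable via setLineSearchFactory, which receives the orientation's cursor name and returns a search strategy. A sketch of the same trainer using the ArmijoWolfeSearch seen in TrainingTester below, everything else unchanged:
log.code(() -> {
  // Identical setup, swapping bisection for an Armijo-Wolfe line search.
  return new IterativeTrainer(trainable)
      .setMonitor(monitor)
      .setOrientation(new GradientDescent())
      .setLineSearchFactory(name -> new ArmijoWolfeSearch())
      .setTimeout(3, TimeUnit.MINUTES)
      .setMaxIterations(500)
      .runAndFree();
});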
use of com.simiacryptus.mindseye.opt.orient.GradientDescent in project MindsEye by SimiaCryptus.
the class TrainingTester method trainGD.
/**
 * Trains using basic gradient descent and records the step history.
 *
 * @param log       the log
 * @param trainable the trainable
 * @return the list of recorded training steps
 */
@Nonnull
public List<StepRecord> trainGD(@Nonnull final NotebookOutput log, final Trainable trainable) {
log.p("First, we train using basic gradient descent method apply weak line search conditions.");
@Nonnull final List<StepRecord> history = new ArrayList<>();
@Nonnull final TrainingMonitor monitor = TrainingTester.getMonitor(history);
try {
log.code(() -> {
return new IterativeTrainer(trainable).setLineSearchFactory(label -> new ArmijoWolfeSearch()).setOrientation(new GradientDescent()).setMonitor(monitor).setTimeout(30, TimeUnit.SECONDS).setMaxIterations(250).setTerminateThreshold(0).runAndFree();
});
} catch (Throwable e) {
if (isThrowExceptions())
throw new RuntimeException(e);
}
return history;
}
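The returned history can be inspected directly once training finishes. A sketch, where tester is an illustrative TrainingTester instance and the public fitness and iteration fields are assumed from the MindsEye StepRecord sources:
// Print the recorded loss trajectory of a gradient descent run.
final List<StepRecord> history = tester.trainGD(log, trainable);
for (final StepRecord step : history) {
  System.out.println("iteration " + step.iteration + ": fitness " + step.fitness);
}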
use of com.simiacryptus.mindseye.opt.orient.GradientDescent in project MindsEye by SimiaCryptus.
the class ImageDecompositionLab method train.
/**
 * Trains the network on the given data.
 *
 * @param log            the log
 * @param monitor        the monitor
 * @param network        the network
 * @param data           the data
 * @param timeoutMinutes the training timeout, in minutes
 * @param mask           the data mask passed to TrainableDataMask.setMask
 */
protected void train(@Nonnull final NotebookOutput log, final TrainingMonitor monitor, final Layer network, @Nonnull final Tensor[][] data, final int timeoutMinutes, final boolean... mask) {
log.out("Training for %s minutes, mask=%s", timeoutMinutes, Arrays.toString(mask));
log.code(() -> {
@Nonnull SampledTrainable trainingSubject = new SampledArrayTrainable(data, network, data.length);
trainingSubject = (SampledTrainable) ((TrainableDataMask) trainingSubject).setMask(mask);
@Nonnull final ValidatingTrainer validatingTrainer = new ValidatingTrainer(trainingSubject, new ArrayTrainable(data, network)).setMaxTrainingSize(data.length).setMinTrainingSize(5).setMonitor(monitor).setTimeout(timeoutMinutes, TimeUnit.MINUTES).setMaxIterations(1000);
validatingTrainer.getRegimen().get(0).setOrientation(new GradientDescent()).setLineSearchFactory(name -> name.equals(QQN.CURSOR_NAME) ? new QuadraticSearch().setCurrentRate(1.0) : new QuadraticSearch().setCurrentRate(1.0));
validatingTrainer.run();
});
}
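A hypothetical invocation of this method; the varargs mask is handed straight to TrainableDataMask.setMask, here marking the first column of each data row as learnable and the second as fixed (the exact mask semantics are those of TrainableDataMask):
// Ten-minute run with a two-column mask.
train(log, monitor, network, data, 10, true, false);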
use of com.simiacryptus.mindseye.opt.orient.GradientDescent in project MindsEye by SimiaCryptus.
the class TrainingTester method trainCjGD.
/**
 * Trains using a conjugate gradient descent method and records the step history.
 *
 * @param log       the log
 * @param trainable the trainable
 * @return the list of recorded training steps
 */
@Nonnull
public List<StepRecord> trainCjGD(@Nonnull final NotebookOutput log, final Trainable trainable) {
log.p("First, we use a conjugate gradient descent method, which converges the fastest for purely linear functions.");
@Nonnull final List<StepRecord> history = new ArrayList<>();
@Nonnull final TrainingMonitor monitor = TrainingTester.getMonitor(history);
try {
log.code(() -> {
return new IterativeTrainer(trainable).setLineSearchFactory(label -> new QuadraticSearch()).setOrientation(new GradientDescent()).setMonitor(monitor).setTimeout(30, TimeUnit.SECONDS).setMaxIterations(250).setTerminateThreshold(0).runAndFree();
});
} catch (Throwable e) {
if (isThrowExceptions())
throw new RuntimeException(e);
}
return history;
}
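Because trainGD and trainCjGD return the same kind of StepRecord history, the two strategies can be compared on one trainable. A sketch, with tester again an illustrative TrainingTester instance and fitness assumed from the StepRecord sources:
// Compare the final loss reached by each strategy.
final List<StepRecord> gd = tester.trainGD(log, trainable);
final List<StepRecord> cjgd = tester.trainCjGD(log, trainable);
System.out.println("GD final fitness:   " + gd.get(gd.size() - 1).fitness);
System.out.println("CjGD final fitness: " + cjgd.get(cjgd.size() - 1).fitness);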