Usage of `com.simiacryptus.mindseye.network.SimpleLossNetwork` in the MindsEye project by SimiaCryptus.
Class `GDTest`, method `train`:
/**
 * Trains {@code network} using plain {@link GradientDescent} as the orientation strategy.
 *
 * <p>The model is paired with an {@link EntropyLossLayer} to form a supervised loss network.
 * That network is wrapped in a {@link SampledArrayTrainable} over {@code trainingData}.
 * Training is capped at 5 minutes and 500 iterations.
 *
 * @param log          notebook through which the training code is executed via {@code log.code}
 * @param network      model to train
 * @param trainingData training examples (presumably one row of input/label tensors per example — confirm)
 * @param monitor      passed to the trainer via {@code setMonitor}
 */
@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
  log.code(() -> {
    // Attach a loss layer so the trainer has a scalar objective to minimize.
    @Nonnull final EntropyLossLayer lossLayer = new EntropyLossLayer();
    @Nonnull final SimpleLossNetwork lossNetwork = new SimpleLossNetwork(network, lossLayer);
    // 1000 is the sample size given to SampledArrayTrainable (presumably rows per sample — confirm).
    @Nonnull final Trainable sampled = new SampledArrayTrainable(trainingData, lossNetwork, 1000);
    final GradientDescent direction = new GradientDescent();
    return new IterativeTrainer(sampled)
        .setMonitor(monitor)
        .setOrientation(direction)
        .setTimeout(5, TimeUnit.MINUTES)
        .setMaxIterations(500)
        .runAndFree();
  });
}
Usage of `com.simiacryptus.mindseye.network.SimpleLossNetwork` in the MindsEye project by SimiaCryptus.
Class `MomentumTest`, method `train`:
/**
 * Trains {@code network} using gradient descent with momentum.
 *
 * <p>The orientation is a {@link MomentumStrategy} around {@link GradientDescent}, with a
 * carry-over of 0.8, wrapped in a {@link ValidatingOrientationWrapper}. The model is paired
 * with an {@link EntropyLossLayer} and sampled from {@code trainingData}. Training is capped
 * at 5 minutes and 500 iterations.
 *
 * @param log          notebook through which the training code is executed via {@code log.code}
 * @param network      model to train
 * @param trainingData training examples (presumably one row of input/label tensors per example — confirm)
 * @param monitor      passed to the trainer via {@code setMonitor}
 */
@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
  log.code(() -> {
    @Nonnull final EntropyLossLayer lossLayer = new EntropyLossLayer();
    @Nonnull final SimpleLossNetwork lossNetwork = new SimpleLossNetwork(network, lossLayer);
    // 1000 is the sample size given to SampledArrayTrainable (presumably rows per sample — confirm).
    @Nonnull final Trainable sampled = new SampledArrayTrainable(trainingData, lossNetwork, 1000);
    // Momentum with carry-over 0.8 on top of plain gradient descent, wrapped for validation.
    final ValidatingOrientationWrapper direction =
        new ValidatingOrientationWrapper(new MomentumStrategy(new GradientDescent()).setCarryOver(0.8));
    return new IterativeTrainer(sampled)
        .setMonitor(monitor)
        .setOrientation(direction)
        .setTimeout(5, TimeUnit.MINUTES)
        .setMaxIterations(500)
        .runAndFree();
  });
}
Usage of `com.simiacryptus.mindseye.network.SimpleLossNetwork` in the MindsEye project by SimiaCryptus.
Class `QQNTest`, method `train`:
/**
 * Trains {@code network} with a {@link ValidatingTrainer} whose first regimen phase uses the
 * {@link QQN} orientation.
 *
 * <p>The model is paired with an {@link EntropyLossLayer}. The trainer receives two trainables:
 * <ul>
 *   <li>a {@link SampledArrayTrainable} over {@code trainingData};</li>
 *   <li>an {@link ArrayTrainable} over the whole of {@code trainingData}, presumably used as
 *       the validation objective (confirm).</li>
 * </ul>
 * Training is capped at 5 minutes and 500 iterations.
 *
 * @param log          notebook through which the training code is executed via {@code log.code}
 * @param network      model to train
 * @param trainingData training examples (presumably one row of input/label tensors per example — confirm)
 * @param monitor      passed to the trainer via {@code setMonitor}
 */
@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
  log.code(() -> {
    @Nonnull final SimpleLossNetwork supervisedNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
    // The meaning of 1000 and 10000 is defined by SampledArrayTrainable's 4-arg constructor
    // (presumably sample size and batch size — confirm).
    @Nonnull final ValidatingTrainer trainer = new ValidatingTrainer(
        new SampledArrayTrainable(trainingData, supervisedNetwork, 1000, 10000),
        new ArrayTrainable(trainingData, supervisedNetwork)).setMonitor(monitor);
    // Replace the orientation of the first regimen phase with QQN.
    trainer.getRegimen().get(0).setOrientation(new QQN());
    // NOTE(review): this calls run() rather than runAndFree(), unlike the IterativeTrainer-based
    // examples; confirm whether the trainer's resources should be released here.
    return trainer.setTimeout(5, TimeUnit.MINUTES).setMaxIterations(500).run();
  });
}
Usage of `com.simiacryptus.mindseye.network.SimpleLossNetwork` in the MindsEye project by SimiaCryptus.
Class `LinearSumConstraintTest`, method `train`:
/**
 * Trains {@code network} with a {@link TrustRegionStrategy}.
 *
 * <p>The strategy assigns a fresh {@link LinearSumConstraint} as the trust-region policy for
 * every layer. The model is paired with an {@link EntropyLossLayer} and sampled from
 * {@code trainingData}. Training is capped at 3 minutes and 500 iterations.
 *
 * @param log          notebook through which the training code is executed via {@code log.code}
 * @param network      model to train
 * @param trainingData training examples (presumably one row of input/label tensors per example — confirm)
 * @param monitor      passed to the trainer via {@code setMonitor}
 */
@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
  log.code(() -> {
    @Nonnull final EntropyLossLayer lossLayer = new EntropyLossLayer();
    @Nonnull final SimpleLossNetwork lossNetwork = new SimpleLossNetwork(network, lossLayer);
    // 10000 is the sample size given to SampledArrayTrainable (presumably rows per sample — confirm).
    @Nonnull final Trainable sampled = new SampledArrayTrainable(trainingData, lossNetwork, 10000);
    // Same policy for every layer: the layer argument is ignored.
    @Nonnull final TrustRegionStrategy constrained = new TrustRegionStrategy() {
      @Override
      public TrustRegion getRegionPolicy(final Layer layer) {
        return new LinearSumConstraint();
      }
    };
    // setIterationsPerSample(100): presumably re-samples the data every 100 iterations — confirm.
    return new IterativeTrainer(sampled)
        .setIterationsPerSample(100)
        .setMonitor(monitor)
        .setOrientation(constrained)
        .setTimeout(3, TimeUnit.MINUTES)
        .setMaxIterations(500)
        .runAndFree();
  });
}
Usage of `com.simiacryptus.mindseye.network.SimpleLossNetwork` in the MindsEye project by SimiaCryptus.
Class `L2NormalizationTest`, method `train`:
/**
 * Trains {@code network} with an L2 weight penalty applied through {@link L12Normalizer}.
 *
 * <p>The training pieces are assembled as follows:
 * <ul>
 *   <li>The model is paired with an {@link EntropyLossLayer} and sampled from
 *       {@code trainingData}.</li>
 *   <li>The resulting trainable is wrapped in an {@link L12Normalizer} that reports an L1
 *       coefficient of 0.0 and an L2 coefficient of 1e4 for every layer.</li>
 *   <li>No orientation is set, so the {@link IterativeTrainer} default is used.</li>
 * </ul>
 * Training is capped at 3 minutes and 500 iterations. An explanatory paragraph is written to
 * {@code log} first.
 *
 * @param log          notebook that receives the explanation and executes the training code
 * @param network      model to train
 * @param trainingData training examples (presumably one row of input/label tensors per example — confirm)
 * @param monitor      passed to the trainer via {@code setMonitor}
 */
@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
  // The original text read "combined mapCoords" / "combine it mapCoords", a garbled rename of "with".
  log.p("Training a model involves a few different components. First, our model is combined with a loss function. " + "Then we take that model and combine it with our training data to define a trainable object. " + "Finally, we use a simple iterative scheme to refine the weights of our model. " + "The final output is the last output value of the loss function when evaluating the last batch.");
  log.code(() -> {
    @Nonnull final SimpleLossNetwork supervisedNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
    @Nonnull final Trainable trainable = new L12Normalizer(new SampledArrayTrainable(trainingData, supervisedNetwork, 1000)) {
      // Expose the wrapped trainable's layer.
      @Override
      public Layer getLayer() {
        return inner.getLayer();
      }

      // No L1 penalty on any layer.
      @Override
      protected double getL1(final Layer layer) {
        return 0.0;
      }

      // Same L2 coefficient for every layer.
      @Override
      protected double getL2(final Layer layer) {
        return 1e4;
      }
    };
    return new IterativeTrainer(trainable).setMonitor(monitor).setTimeout(3, TimeUnit.MINUTES).setMaxIterations(500).runAndFree();
  });
}
Aggregations