Usage of com.simiacryptus.mindseye.eval.SampledArrayTrainable in the MindsEye project by SimiaCryptus.
Class TrustSphereTest, method train:
/**
 * Trains {@code network} under an entropy loss, orienting each step with a trust-region
 * strategy that constrains every layer with its own freshly created AdaptiveTrustSphere.
 *
 * <p>The entire run is executed inside {@code log.code(...)}, so the training code is
 * recorded in the notebook output (presumably along with the value returned by
 * {@code runAndFree()} — confirm against NotebookOutput).
 *
 * @param log          notebook into which the training code is rendered
 * @param network      model to train; paired with an EntropyLossLayer to form the loss network
 * @param trainingData training examples handed to SampledArrayTrainable
 * @param monitor      receives progress callbacks from the IterativeTrainer
 */
@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
  log.code(() -> {
    @Nonnull final SimpleLossNetwork lossNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
    // NOTE(review): 10000 is presumably the number of examples sampled per evaluation — confirm.
    @Nonnull final Trainable sampled = new SampledArrayTrainable(trainingData, lossNetwork, 10000);
    // Every call hands out a brand-new trust sphere, so no region policy is shared between layers.
    @Nonnull final TrustRegionStrategy perLayerSphere = new TrustRegionStrategy() {
      @Override
      public TrustRegion getRegionPolicy(final Layer layer) {
        return new AdaptiveTrustSphere();
      }
    };
    return new IterativeTrainer(sampled)
        .setIterationsPerSample(100)
        .setMonitor(monitor)
        .setOrientation(perLayerSphere)
        .setTimeout(3, TimeUnit.MINUTES)
        .setMaxIterations(500)
        .runAndFree();
  });
}
Usage of com.simiacryptus.mindseye.eval.SampledArrayTrainable in the MindsEye project by SimiaCryptus.
Class QuadraticLineSearchTest, method train:
/**
 * Trains {@code network} under an entropy loss using plain gradient-descent orientation,
 * with a QuadraticSearch line search created for every line-search request.
 *
 * <p>Training stops after 3 minutes or 500 iterations, whichever limit the trainer hits first.
 * The run executes inside {@code log.code(...)} so it is recorded in the notebook.
 *
 * @param log          notebook into which the training code is rendered
 * @param network      model to train; paired with an EntropyLossLayer to form the loss network
 * @param trainingData training examples handed to SampledArrayTrainable
 * @param monitor      receives progress callbacks from the IterativeTrainer
 */
@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
  log.code(() -> {
    @Nonnull final SimpleLossNetwork lossNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
    // NOTE(review): 1000 is presumably the number of examples sampled per evaluation — confirm.
    @Nonnull final Trainable sampled = new SampledArrayTrainable(trainingData, lossNetwork, 1000);
    return new IterativeTrainer(sampled)
        .setMonitor(monitor)
        .setOrientation(new GradientDescent())
        // The factory ignores the requested name and always supplies a new QuadraticSearch.
        .setLineSearchFactory((@Nonnull final CharSequence label) -> new QuadraticSearch())
        .setTimeout(3, TimeUnit.MINUTES)
        .setMaxIterations(500)
        .runAndFree();
  });
}
Usage of com.simiacryptus.mindseye.eval.SampledArrayTrainable in the MindsEye project by SimiaCryptus.
Class StaticRateTest, method train:
/**
 * Trains {@code network} under an entropy loss using plain gradient-descent orientation and
 * a fixed step size: every line-search request gets a new StaticLearningRate(0.001).
 *
 * <p>Training stops after 3 minutes or 500 iterations, whichever limit the trainer hits first.
 * The run executes inside {@code log.code(...)} so it is recorded in the notebook.
 *
 * @param log          notebook into which the training code is rendered
 * @param network      model to train; paired with an EntropyLossLayer to form the loss network
 * @param trainingData training examples handed to SampledArrayTrainable
 * @param monitor      receives progress callbacks from the IterativeTrainer
 */
@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
  log.code(() -> {
    @Nonnull final SimpleLossNetwork lossNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
    // NOTE(review): 1000 is presumably the number of examples sampled per evaluation — confirm.
    @Nonnull final Trainable sampled = new SampledArrayTrainable(trainingData, lossNetwork, 1000);
    return new IterativeTrainer(sampled)
        .setMonitor(monitor)
        .setOrientation(new GradientDescent())
        // The factory ignores the requested name and always supplies a fixed 0.001 rate.
        .setLineSearchFactory((@Nonnull final CharSequence label) -> new StaticLearningRate(0.001))
        .setTimeout(3, TimeUnit.MINUTES)
        .setMaxIterations(500)
        .runAndFree();
  });
}
Usage of com.simiacryptus.mindseye.eval.SampledArrayTrainable in the MindsEye project by SimiaCryptus.
Class L1NormalizationTest, method train:
/**
 * Trains {@code network} under an entropy loss with L1 regularization. The sampled
 * trainable is wrapped in an L12Normalizer that applies an L1 coefficient of 1.0 and an
 * L2 coefficient of 0 to every layer.
 *
 * <p>No orientation or line search is configured, so the IterativeTrainer's own defaults apply.
 * Training stops after 3 minutes or 500 iterations. The run executes inside
 * {@code log.code(...)} so it is recorded in the notebook.
 *
 * @param log          notebook into which the training code is rendered
 * @param network      model to train; paired with an EntropyLossLayer to form the loss network
 * @param trainingData training examples handed to SampledArrayTrainable
 * @param monitor      receives progress callbacks from the IterativeTrainer
 */
@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
  log.code(() -> {
    @Nonnull final SimpleLossNetwork lossNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
    // NOTE(review): 1000 is presumably the number of examples sampled per evaluation — confirm.
    final SampledArrayTrainable sampled = new SampledArrayTrainable(trainingData, lossNetwork, 1000);
    @Nonnull final Trainable regularized = new L12Normalizer(sampled) {
      // Same uniform weighting for every layer: pure L1, no L2 term.
      @Override
      protected double getL2(final Layer layer) {
        return 0;
      }

      @Override
      protected double getL1(final Layer layer) {
        return 1.0;
      }

      // Expose the wrapped trainable's layer rather than anything the normalizer builds.
      @Override
      public Layer getLayer() {
        return inner.getLayer();
      }
    };
    return new IterativeTrainer(regularized)
        .setMonitor(monitor)
        .setTimeout(3, TimeUnit.MINUTES)
        .setMaxIterations(500)
        .runAndFree();
  });
}
Aggregations