Use of com.simiacryptus.mindseye.network.SimpleLossNetwork in project MindsEye by SimiaCryptus.
The class RecursiveSubspaceTest, method train.
@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
  log.code(() -> {
    @Nonnull final SimpleLossNetwork supervisedNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
    @Nonnull ValidatingTrainer trainer = new ValidatingTrainer(
        new SampledArrayTrainable(trainingData, supervisedNetwork, 1000, 1000),
        new ArrayTrainable(trainingData, supervisedNetwork, 1000).cached()).setMonitor(monitor);
    // Use a fixed learning rate for LBFGS-named line searches, a quadratic search otherwise.
    trainer.getRegimen().get(0).setOrientation(getOrientation())
        .setLineSearchFactory(name -> name.toString().contains("LBFGS") ? new StaticLearningRate(1.0) : new QuadraticSearch());
    return trainer.setTimeout(15, TimeUnit.MINUTES).setMaxIterations(500).run();
  });
}
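The train method above expects a caller-supplied TrainingMonitor. As a minimal sketch (an assumption, not code from the project), a monitor only needs to override the two callbacks these examples rely on, log and onStepComplete:

TrainingMonitor monitor = new TrainingMonitor() {
  @Override
  public void log(final String msg) {
    // Forward trainer output to stdout; a notebook or logger could be used instead.
    System.out.println(msg);
  }

  @Override
  public void onStepComplete(final Step currentPoint) {
    // Inspect currentPoint here, e.g. to record per-iteration metrics.
  }
};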
Use of com.simiacryptus.mindseye.network.SimpleLossNetwork in project MindsEye by SimiaCryptus.
The class SingleOrthantTrustRegionTest, method train.
@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
  log.code(() -> {
    @Nonnull final SimpleLossNetwork supervisedNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
    @Nonnull final Trainable trainable = new SampledArrayTrainable(trainingData, supervisedNetwork, 10000);
    // Apply the SingleOrthant trust region to every layer.
    @Nonnull final TrustRegionStrategy trustRegionStrategy = new TrustRegionStrategy() {
      @Override
      public TrustRegion getRegionPolicy(final Layer layer) {
        return new SingleOrthant();
      }
    };
    return new IterativeTrainer(trainable)
        .setIterationsPerSample(100)
        .setMonitor(monitor)
        .setOrientation(trustRegionStrategy)
        .setTimeout(3, TimeUnit.MINUTES)
        .setMaxIterations(500)
        .runAndFree();
  });
}
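Because getRegionPolicy receives each layer, a strategy can also vary the region per layer rather than returning the same policy everywhere. A hedged sketch, where the layer-name test is purely illustrative and not taken from the project:

@Nonnull final TrustRegionStrategy perLayerStrategy = new TrustRegionStrategy() {
  @Override
  public TrustRegion getRegionPolicy(final Layer layer) {
    // Illustrative mapping: constrain layers whose class name mentions "Bias" to a single
    // orthant, and give everything else the adaptive trust sphere used in TrustSphereTest below.
    if (layer != null && layer.getClass().getSimpleName().contains("Bias")) {
      return new SingleOrthant();
    }
    return new AdaptiveTrustSphere();
  }
};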
Use of com.simiacryptus.mindseye.network.SimpleLossNetwork in project MindsEye by SimiaCryptus.
The class AutoencoderNetwork, method train.
/**
 * Creates the training parameters used to train this autoencoder network.
 *
 * @return the autoencoder network training parameters
 */
@Nonnull
public AutoencoderNetwork.TrainingParameters train() {
  return new AutoencoderNetwork.TrainingParameters() {
    @Nonnull
    @Override
    public SimpleLossNetwork getTrainingNetwork() {
      // Student network: encoder followed by decoder, trained on mean-squared reconstruction error.
      @Nonnull final PipelineNetwork student = new PipelineNetwork();
      student.add(encoder);
      student.add(decoder);
      return new SimpleLossNetwork(student, new MeanSqLossLayer());
    }

    @Nonnull
    @Override
    protected TrainingMonitor wrap(@Nonnull final TrainingMonitor monitor) {
      return new TrainingMonitor() {
        @Override
        public void log(final String msg) {
          monitor.log(msg);
        }

        @Override
        public void onStepComplete(final Step currentPoint) {
          // Re-randomize the injected noise after every training step, then delegate.
          inputNoise.shuffle();
          encodedNoise.shuffle(StochasticComponent.random.get().nextLong());
          monitor.onStepComplete(currentPoint);
        }
      };
    }
  };
}
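The other snippets on this page drive a SimpleLossNetwork with an IterativeTrainer; a hedged sketch of applying the same pattern to the training network returned here. The autoencoder, trainingData, and monitor names are assumptions supplied by the caller, the sample count is illustrative, and the sketch omits the noise-reshuffling monitor wrapper shown above:

// Assumes an AutoencoderNetwork instance named autoencoder plus caller-supplied
// trainingData and monitor (all hypothetical names).
@Nonnull final SimpleLossNetwork trainingNetwork = autoencoder.train().getTrainingNetwork();
@Nonnull final Trainable trainable = new SampledArrayTrainable(trainingData, trainingNetwork, 1000);
new IterativeTrainer(trainable)
    .setMonitor(monitor)
    .setTimeout(3, TimeUnit.MINUTES)
    .setMaxIterations(500)
    .runAndFree();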
Use of com.simiacryptus.mindseye.network.SimpleLossNetwork in project MindsEye by SimiaCryptus.
The class TrustSphereTest, method train.
@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
  log.code(() -> {
    @Nonnull final SimpleLossNetwork supervisedNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
    @Nonnull final Trainable trainable = new SampledArrayTrainable(trainingData, supervisedNetwork, 10000);
    @Nonnull final TrustRegionStrategy trustRegionStrategy = new TrustRegionStrategy() {
      @Override
      public TrustRegion getRegionPolicy(final Layer layer) {
        return new AdaptiveTrustSphere();
      }
    };
    return new IterativeTrainer(trainable)
        .setIterationsPerSample(100)
        .setMonitor(monitor)
        .setOrientation(trustRegionStrategy)
        .setTimeout(3, TimeUnit.MINUTES)
        .setMaxIterations(500)
        .runAndFree();
  });
}
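SingleOrthantTrustRegionTest and TrustSphereTest differ only in the region returned by getRegionPolicy. A hedged consolidation sketch, not from the project: it assumes runAndFree returns the final objective as a double (as the snippets above imply by returning it from log.code) and uses java.util.function.Supplier:

double runTrustRegionTest(final Layer network, final Tensor[][] trainingData,
    final TrainingMonitor monitor, final Supplier<TrustRegion> regionSupplier) {
  @Nonnull final SimpleLossNetwork supervisedNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
  @Nonnull final Trainable trainable = new SampledArrayTrainable(trainingData, supervisedNetwork, 10000);
  @Nonnull final TrustRegionStrategy strategy = new TrustRegionStrategy() {
    @Override
    public TrustRegion getRegionPolicy(final Layer layer) {
      // Same caller-chosen region policy for every layer.
      return regionSupplier.get();
    }
  };
  return new IterativeTrainer(trainable)
      .setIterationsPerSample(100)
      .setMonitor(monitor)
      .setOrientation(strategy)
      .setTimeout(3, TimeUnit.MINUTES)
      .setMaxIterations(500)
      .runAndFree();
}

Under those assumptions it would be called as runTrustRegionTest(network, trainingData, monitor, SingleOrthant::new) or with AdaptiveTrustSphere::new.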
Use of com.simiacryptus.mindseye.network.SimpleLossNetwork in project MindsEye by SimiaCryptus.
The class QuadraticLineSearchTest, method train.
@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
  log.code(() -> {
    @Nonnull final SimpleLossNetwork supervisedNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
    @Nonnull final Trainable trainable = new SampledArrayTrainable(trainingData, supervisedNetwork, 1000);
    return new IterativeTrainer(trainable)
        .setMonitor(monitor)
        .setOrientation(new GradientDescent())
        .setLineSearchFactory((@Nonnull final CharSequence name) -> new QuadraticSearch())
        .setTimeout(3, TimeUnit.MINUTES)
        .setMaxIterations(500)
        .runAndFree();
  });
}
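For comparison, the same trainer can be wired with the fixed-step StaticLearningRate used by RecursiveSubspaceTest above instead of the quadratic search; a hedged sketch of that one-line swap:

return new IterativeTrainer(trainable)
    .setMonitor(monitor)
    .setOrientation(new GradientDescent())
    // Fixed step size of 1.0 in place of the quadratic line search.
    .setLineSearchFactory((@Nonnull final CharSequence name) -> new StaticLearningRate(1.0))
    .setTimeout(3, TimeUnit.MINUTES)
    .setMaxIterations(500)
    .runAndFree();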