Use of hex.MnistCanvas in project h2o-2 by h2oai.
The class NeuralNetMnistPretrain, method preTrain.
/**
 * Pre-trains layer {@code index} of {@code ls} as an auto-encoder.
 *
 * <p>A temporary network of {@code index + 2} layers is built:
 * <ul>
 *   <li>a new {@code VecsInput} over the vecs of {@code ls[0]};</li>
 *   <li>Tanh counterparts of layers {@code 1 .. index-1} (weights shared with {@code ls} via
 *       {@code Layer.shareWeights}, learning rate set to 0);</li>
 *   <li>on top, a Tanh layer and a TanhPrime layer sized like {@code ls[index - 1]}, both sharing
 *       weights with {@code ls[index]}, with rate 1e-5.</li>
 * </ul>
 * That network is trained with a {@code Trainer.Direct} for {@code epochs} epochs while an
 * {@code MnistCanvas} window visualizes it. The method blocks until the trainer finishes.
 *
 * @param ls     the full network; {@code ls[0]} must be a {@code VecsInput}
 * @param index  the layer to pre-train; must satisfy {@code 1 <= index < ls.length}
 * @param epochs number of epochs handed to the trainer
 * @throws IllegalArgumentException if {@code index} is outside {@code [1, ls.length - 1]}
 */
private final void preTrain(Layer[] ls, int index, int epochs) {
    // ls[index] feeds the encoder and ls[index - 1] sizes the decoder, so the input layer
    // (index 0) and anything past the end cannot be pre-trained. Without this check, index 0
    // only fails on ls[-1] after the temporary network has been partially built.
    if (index < 1 || index >= ls.length) {
        throw new IllegalArgumentException(
            "index must be in [1, " + (ls.length - 1) + "], got " + index);
    }
    // Build a network with the same layers below 'index', and an auto-encoder at the top.
    Layer[] pre = new Layer[index + 2];
    VecsInput input = (VecsInput) ls[0];
    pre[0] = new VecsInput(input.vecs, input);
    // NOTE(review): this call was previously commented "clone the parameters", while the
    // identical init(...) calls below were commented "share the parameters". Whether init()
    // copies or aliases the given params is not visible here; confirm. If it aliases, the
    // rate overrides below also change the learning rates of the real network's layers.
    pre[0].init(pre, 0, ls[0].params);
    for (int i = 1; i < index; i++) {
        // (A Rectifier variant was previously tried here instead of Tanh.)
        pre[i] = new Layer.Tanh(ls[i].units);
        Layer.shareWeights(ls[i], pre[i]);
        pre[i].init(pre, i, ls[i].params);
        // Turn off training for the lower layers (learning rate 0).
        pre[i].params.rate = 0;
    }
    // Auto-encoder: a layer plus a reverse layer on top, the latter sized back to ls[index - 1].
    pre[index] = new Layer.Tanh(ls[index].units);
    pre[index].init(pre, index, ls[index].params);
    pre[index].params.rate = 1e-5;
    pre[index + 1] = new Layer.TanhPrime(ls[index - 1].units);
    pre[index + 1].init(pre, index + 1, pre[index].params);
    pre[index + 1].params.rate = 1e-5;
    // NOTE(review): the hidden layers above share weights *before* init(), these two *after*.
    // The order may matter depending on what init() allocates; confirm before reordering.
    Layer.shareWeights(ls[index], pre[index]);
    Layer.shareWeights(ls[index], pre[index + 1]);
    _trainer = new Trainer.Direct(pre, epochs, self());
    // Basic visualization of images and weights. With EXIT_ON_CLOSE, closing this window
    // terminates the whole JVM, including any training still in progress.
    JFrame frame = new JFrame("H2O Pre-Training");
    frame.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
    MnistCanvas canvas = new MnistCanvas(_trainer);
    frame.setContentPane(canvas.init());
    frame.pack();
    frame.setLocationRelativeTo(null);
    frame.setVisible(true);
    _trainer.start();
    _trainer.join();
}
Use of hex.MnistCanvas in project h2o-2 by h2oai.
The class NeuralNetMnistPretrain, method startTraining.
/**
 * Runs the pre-training phase and then, if enabled, a full training phase on {@code ls}.
 *
 * <p>Pre-training delegates to the two-argument {@code preTrain} overload (defined elsewhere)
 * with the value 2, presumably the number of pre-training epochs.
 *
 * <p>The full training phase is currently switched off: its epoch count is hard-coded to 0, so
 * the method returns right after pre-training. When that count is raised above 0, the network
 * is trained with a {@code Trainer.Threaded} while an {@code MnistCanvas} window shows progress,
 * and the method blocks until the trainer finishes. {@code Trainer.Direct} and
 * {@code Trainer.MapReduce} were earlier alternatives for this phase.
 *
 * @param ls the network layers to train
 */
@Override
protected void startTraining(Layer[] ls) {
    int pretrainEpochs = 2;
    preTrain(ls, pretrainEpochs);

    // Full-training toggle: 0 disables the second phase entirely.
    int trainEpochs = 0;
    if (trainEpochs <= 0) {
        return;
    }

    _trainer = new Trainer.Threaded(ls, trainEpochs, self(), -1);

    // Simple window that visualizes images and weights while training runs;
    // closing it exits the JVM.
    JFrame window = new JFrame("H2O Training");
    window.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
    window.setContentPane(new MnistCanvas(_trainer).init());
    window.pack();
    window.setLocationRelativeTo(null);
    window.setVisible(true);

    _trainer.start();
    _trainer.join();
}
Aggregations