
Example 1 with MnistCanvas

Use of hex.MnistCanvas in project h2o-2 by h2oai.

Class NeuralNetMnistPretrain, method preTrain.

private final void preTrain(Layer[] ls, int index, int epochs) {
    // Build a network with same layers below 'index', and an auto-encoder at the top
    Layer[] pre = new Layer[index + 2];
    VecsInput input = (VecsInput) ls[0];
    pre[0] = new VecsInput(input.vecs, input);
    //clone the parameters
    pre[0].init(pre, 0, ls[0].params);
    for (int i = 1; i < index; i++) {
        //pre[i] = new Layer.Rectifier(ls[i].units);
        pre[i] = new Layer.Tanh(ls[i].units);
        Layer.shareWeights(ls[i], pre[i]);
        //share the parameters
        pre[i].init(pre, i, ls[i].params);
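        // turn off training for these layers (learning rate 0); if init() shares
        // rather than copies ls[i].params, this also zeroes ls[i]'s own rate
        pre[i].params.rate = 0;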
    }
    // Auto-encoder is a layer and a reverse layer on top
    //pre[index] = new Layer.Rectifier(ls[index].units);
    //pre[index + 1] = new Layer.RectifierPrime(ls[index - 1].units);
    pre[index] = new Layer.Tanh(ls[index].units);
    pre[index].init(pre, index, ls[index].params);
    pre[index].params.rate = 1e-5;
    pre[index + 1] = new Layer.TanhPrime(ls[index - 1].units);
    pre[index + 1].init(pre, index + 1, pre[index].params);
    pre[index + 1].params.rate = 1e-5;
    Layer.shareWeights(ls[index], pre[index]);
    Layer.shareWeights(ls[index], pre[index + 1]);
    _trainer = new Trainer.Direct(pre, epochs, self());
    // Basic visualization of images and weights
    JFrame frame = new JFrame("H2O Pre-Training");
    frame.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
    MnistCanvas canvas = new MnistCanvas(_trainer);
    frame.setContentPane(canvas.init());
    frame.pack();
    frame.setLocationRelativeTo(null);
    frame.setVisible(true);
    _trainer.start();
    _trainer.join();
}
Also used: MnistCanvas (hex.MnistCanvas), VecsInput (hex.Layer.VecsInput), Trainer (hex.Trainer), Layer (hex.Layer)
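The stack that preTrain assembles can be illustrated without H2O. The plain Java sketch below is hypothetical (PretrainPlan, Entry and plan are illustration names, not h2o-2 code): given the unit counts of the full network and an index, it lists the slots of the pre array exactly as the method above fills them. Slot 0 is the input, slots 1 to index-1 share weights with the original layers and are frozen with rate 0, slot index is the encoder trained at 1e-5, and slot index+1 is the decoder, whose width is ls[index - 1].units so that it reconstructs the layer below the one being pre-trained. The layer widths in main are illustrative only, apart from 784 = 28x28 MNIST pixels.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch (not an h2o-2 API): reproduces the indexing and learning
// rates of the pre[] stack built by preTrain(ls, index, epochs).
public class PretrainPlan {

    // One slot of pre[]: role, width, learning rate (null = not set in the excerpt).
    static final class Entry {
        final String role;
        final int units;
        final Double rate;

        Entry(String role, int units, Double rate) {
            this.role = role;
            this.units = units;
            this.rate = rate;
        }

        @Override
        public String toString() {
            return role + "(" + units + " units, rate=" + (rate == null ? "unset" : rate) + ")";
        }
    }

    // units[i] plays the role of ls[i].units in the original method.
    static List<Entry> plan(int[] units, int index) {
        if (index < 1 || index >= units.length) {
            throw new IllegalArgumentException("index must be between 1 and units.length - 1");
        }
        List<Entry> pre = new ArrayList<>(index + 2);
        // pre[0]: input layer over the same vectors as ls[0]
        pre.add(new Entry("input", units[0], null));
        // pre[1 .. index-1]: layers sharing weights with ls[i], frozen with rate = 0
        for (int i = 1; i < index; i++) {
            pre.add(new Entry("frozen", units[i], 0.0));
        }
        // pre[index]: encoder sharing weights with ls[index], trained at 1e-5
        pre.add(new Entry("encoder", units[index], 1e-5));
        // pre[index+1]: decoder back to the width of layer index-1, also sharing ls[index]'s weights
        pre.add(new Entry("decoder", units[index - 1], 1e-5));
        return pre;
    }

    public static void main(String[] args) {
        // Illustrative widths: 784 inputs, two hidden layers, 10 outputs.
        int[] units = {784, 500, 250, 10};
        for (int index = 1; index < units.length - 1; index++) {
            System.out.println("index=" + index + ": " + plan(units, index));
        }
    }
}
```

For index = 1 no layer is frozen and the decoder reconstructs the raw input (784 units); for index = 2 the first hidden layer is frozen and the decoder reconstructs its 500 activations.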

Example 2 with MnistCanvas

Use of hex.MnistCanvas in project h2o-2 by h2oai.

Class NeuralNetMnistPretrain, method startTraining.

@Override
protected void startTraining(Layer[] ls) {
    int pretrain_epochs = 2;
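    // two-argument overload, not shown in this excerpt; Example 1 is the per-layer preTrain(ls, index, epochs)
    preTrain(ls, pretrain_epochs);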
    // actual run: with epochs = 0 the block below is skipped, so only pre-training executes
    int epochs = 0;
    if (epochs > 0) {
        //    _trainer = new Trainer.Direct(ls, epochs, self());
        _trainer = new Trainer.Threaded(ls, epochs, self(), -1);
        // Basic visualization of images and weights
        JFrame frame = new JFrame("H2O Training");
        frame.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
        MnistCanvas canvas = new MnistCanvas(_trainer);
        frame.setContentPane(canvas.init());
        frame.pack();
        frame.setLocationRelativeTo(null);
        //_trainer = new Trainer.MapReduce(ls, epochs, self());
        frame.setVisible(true);
        _trainer.start();
        _trainer.join();
    }
}
Also used: MnistCanvas (hex.MnistCanvas), Trainer (hex.Trainer)
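startTraining calls preTrain(ls, pretrain_epochs), a two-argument overload that this excerpt does not show. Assuming it follows the usual greedy layer-wise scheme, running the three-argument preTrain from Example 1 once per hidden layer from the bottom up, it would look roughly like the hypothetical sketch below; GreedyPretrain, preTrainAll and the BiConsumer callback are illustration names, not h2o-2 code. Bottom-up order matters because Example 1 freezes every layer below index, so those layers should already have been pre-trained.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiConsumer;

// Hypothetical sketch of a greedy layer-wise driver; levelTrainer stands in
// for preTrain(ls, index, epochs) from Example 1.
public class GreedyPretrain {

    // Returns the levels in the order they were pre-trained.
    static List<Integer> preTrainAll(int layerCount, int epochs,
                                     BiConsumer<Integer, Integer> levelTrainer) {
        List<Integer> visited = new ArrayList<>();
        // Assumption: only hidden layers 1 .. layerCount-2 are pre-trained; layer 0 is
        // the input and the last layer is the output, left to the supervised run.
        for (int index = 1; index < layerCount - 1; index++) {
            levelTrainer.accept(index, epochs);
            visited.add(index);
        }
        return visited;
    }

    public static void main(String[] args) {
        int pretrainEpochs = 2; // same value as pretrain_epochs in startTraining
        List<Integer> order = preTrainAll(4, pretrainEpochs,
                (index, epochs) -> System.out.println(
                        "pre-training level " + index + " for " + epochs + " epochs"));
        System.out.println("levels visited: " + order);
    }
}
```

Passing the per-level step as a callback keeps the loop independent of the Layer and Trainer types, so the ordering logic can be checked on its own.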

Aggregations

MnistCanvas (hex.MnistCanvas): 2
Trainer (hex.Trainer): 2
Layer (hex.Layer): 1
VecsInput (hex.Layer.VecsInput): 1