Search in sources:

Example 1 using Errors

Use of hex.NeuralNet.Errors in project h2o-2 by h2oai.

From the class NeuralNetMnist, method execImpl.

/**
 * Loads the MNIST train/test CSVs, builds the network, and starts training while a
 * background {@link Timer} periodically reports progress.
 *
 * <p>The last column of each frame is split off as the label vector. On every timer tick,
 * while this job is still running, the monitor:
 * <ul>
 *   <li>prints elapsed time, samples processed, throughput, and a training-error estimate
 *       plus the current learning rate and momentum;</li>
 *   <li>every {@code testEvalInterval} ticks, also prints the error on the test set.</li>
 * </ul>
 * Scoring uses freshly built nets that share weights with the training net and reuse its
 * input/output layers, so normalization stats match training. Once the job is no longer
 * running, the next tick cancels the timer.
 */
@Override
protected void execImpl() {
    Frame trainf = TestUtil.parseFromH2OFolder("smalldata/mnist/train.csv.gz");
    Frame testf = TestUtil.parseFromH2OFolder("smalldata/mnist/test.csv.gz");
    train = trainf.vecs();
    test = testf.vecs();
    // Labels are on the last column for this dataset: split them off from the features.
    final Vec trainLabels = train[train.length - 1];
    train = Utils.remove(train, train.length - 1);
    final Vec testLabels = test[test.length - 1];
    test = Utils.remove(test, test.length - 1);
    final Layer[] ls = build(train, trainLabels, null, null);

    // Monitor training. Timer.schedule(task, delay, period) runs with fixed delay, in ms.
    final long monitorPeriodMs = 10;
    // Compute the (expensive) test error on every Nth tick; 1 == every tick.
    final int testEvalInterval = 1;
    final Timer timer = new Timer();
    final long start = System.nanoTime();
    final AtomicInteger evals = new AtomicInteger(1);
    timer.schedule(new TimerTask() {

        @Override
        public void run() {
            if (!Job.isRunning(self())) {
                timer.cancel();
                return;
            }
            double time = (System.nanoTime() - start) / 1e9;
            Trainer trainer = _trainer;
            long processed = trainer == null ? 0 : trainer.processed();
            // Guard against a zero elapsed time, which would yield Infinity/NaN.
            int ps = time > 0 ? (int) (processed / time) : 0;

            // Build separate nets for scoring purposes, use same normalization stats as for training
            Layer[] temp = build(train, trainLabels, (VecsInput) ls[0], (VecSoftmax) ls[ls.length - 1]);
            Layer.shareWeights(ls, temp);
            // Estimate training error on subset of dataset for speed
            // (1000 is presumably the number of rows sampled — confirm against NeuralNet.eval).
            Errors e = NeuralNet.eval(temp, 1000, null);

            StringBuilder text = new StringBuilder();
            text.append((int) time).append("s, ")
                .append(processed).append(" samples (").append(ps).append("/s) ")
                .append("train: ").append(e)
                .append(", rate: ").append(String.format("%.5g", ls[0].rate(processed)))
                .append(", momentum: ").append(String.format("%.5g", ls[0].momentum(processed)));
            System.out.println(text);

            if (evals.incrementAndGet() % testEvalInterval == 0) {
                System.out.println("Computing test error");
                temp = build(test, testLabels, (VecsInput) ls[0], (VecSoftmax) ls[ls.length - 1]);
                Layer.shareWeights(ls, temp);
                // NOTE(review): 0 appears to request evaluation over the whole test set — confirm.
                e = NeuralNet.eval(temp, 0, null);
                System.out.println("Test error: " + e);
            }
        }
    }, 0, monitorPeriodMs);
    startTraining(ls);
}
Also used: Frame(water.fvec.Frame) VecSoftmax(hex.Layer.VecSoftmax) Trainer(hex.Trainer) Layer(hex.Layer) Errors(hex.NeuralNet.Errors) Timer(java.util.Timer) TimerTask(java.util.TimerTask) Vec(water.fvec.Vec) AppendableVec(water.fvec.AppendableVec) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) VecsInput(hex.Layer.VecsInput)

Aggregations

Layer (hex.Layer): 1, VecSoftmax (hex.Layer.VecSoftmax): 1, VecsInput (hex.Layer.VecsInput): 1, Errors (hex.NeuralNet.Errors): 1, Trainer (hex.Trainer): 1, Timer (java.util.Timer): 1, TimerTask (java.util.TimerTask): 1, AtomicInteger (java.util.concurrent.atomic.AtomicInteger): 1, AppendableVec (water.fvec.AppendableVec): 1, Frame (water.fvec.Frame): 1, Vec (water.fvec.Vec): 1