Use of hex.Trainer in the h2o-2 project by h2oai.
Class NeuralNetMnist, method execImpl.
/**
 * Parses the MNIST train and test sets, builds the network layers from the
 * training data and starts training. While the job is running, a background
 * timer prints throughput, training error, learning rate and momentum, and
 * then the test error.
 */
@Override
protected void execImpl() {
  Frame trainFrame = TestUtil.parseFromH2OFolder("smalldata/mnist/train.csv.gz");
  Frame testFrame = TestUtil.parseFromH2OFolder("smalldata/mnist/test.csv.gz");
  train = trainFrame.vecs();
  test = testFrame.vecs();

  // The label is the last column of each set: split it off from the inputs.
  final int trainLabelIdx = train.length - 1;
  final Vec trainLabels = train[trainLabelIdx];
  train = Utils.remove(train, trainLabelIdx);
  final int testLabelIdx = test.length - 1;
  final Vec testLabels = test[testLabelIdx];
  test = Utils.remove(test, testLabelIdx);

  final Layer[] layers = build(train, trainLabels, null, null);

  // Progress reporting runs on its own timer thread.
  final Timer monitor = new Timer();
  final long startNanos = System.nanoTime();
  final AtomicInteger evalCount = new AtomicInteger(1);

  TimerTask report = new TimerTask() {
    @Override
    public void run() {
      // Stop reporting as soon as the job is no longer running.
      if (!Job.isRunning(self())) {
        monitor.cancel();
        return;
      }

      double elapsedSec = (System.nanoTime() - startNanos) / 1e9;
      // _trainer may still be unset when the first tick fires; report 0 samples then.
      Trainer current = _trainer;
      long processed = (current != null) ? current.processed() : 0;
      int perSecond = (int) (processed / elapsedSec);

      StringBuilder line = new StringBuilder();
      line.append((int) elapsedSec).append("s, ")
          .append(processed).append(" samples (")
          .append(perSecond).append("/s) ");

      // Score on a separate net that shares the trained weights; the training
      // net's input and output layers are passed in so that (per the original
      // note) the same normalization stats are used.
      Layer[] trainScorer =
          build(train, trainLabels, (VecsInput) layers[0], (VecSoftmax) layers[layers.length - 1]);
      Layer.shareWeights(layers, trainScorer);
      // Training error is estimated with a limit of 1000 for speed
      // (presumably a row count; confirm against NeuralNet.eval).
      Errors trainErrors = NeuralNet.eval(trainScorer, 1000, null);
      line.append("train: ").append(String.valueOf(trainErrors));
      line.append(", rate: ").append(String.format("%.5g", layers[0].rate(processed)));
      line.append(", momentum: ").append(String.format("%.5g", layers[0].momentum(processed)));
      System.out.println(line.toString());

      // With a modulus of 1 the test error is computed on every tick.
      int tick = evalCount.incrementAndGet();
      if (tick % 1 == 0) {
        System.out.println("Computing test error");
        Layer[] testScorer =
            build(test, testLabels, (VecsInput) layers[0], (VecSoftmax) layers[layers.length - 1]);
        Layer.shareWeights(layers, testScorer);
        // NOTE(review): a limit of 0 presumably means "score the whole set"; confirm.
        Errors testErrors = NeuralNet.eval(testScorer, 0, null);
        System.out.println("Test error: " + testErrors);
      }
    }
  };

  // First report fires immediately, then again 10 ms after each run completes.
  monitor.schedule(report, 0, 10);

  startTraining(layers);
}
Aggregations