Use of hex.ModelMetricsMultinomial in project h2o-3 by h2oai.
From class DeepWaterAbstractIntegrationTest, method MNISTSparse.
/**
 * Trains a sparse DeepWater model (two hidden layers of 500 units) on the MNIST training set.
 * It then asserts that the mean per-class error on the MNIST test set is below 5%.
 *
 * <p>The test does nothing (and passes) when either MNIST file cannot be located through
 * {@code FileUtils.locateFile}.
 */
@Test
public void MNISTSparse() {
  Frame tr = null;
  Frame va = null;
  DeepWaterModel m = null;
  try {
    DeepWaterParameters p = new DeepWaterParameters();
    File file = FileUtils.locateFile("bigdata/laptop/mnist/train.csv.gz");
    File valid = FileUtils.locateFile("bigdata/laptop/mnist/test.csv.gz");
    // Both files are required. Checking only the training file meant a missing test file
    // caused a NullPointerException in NFSFileVec.make(valid) instead of skipping the test.
    if (file == null || valid == null) {
      return;
    }
    p._response_column = "C785";
    NFSFileVec trainfv = NFSFileVec.make(file);
    tr = ParseDataset.parse(Key.make(), trainfv._key);
    NFSFileVec validfv = NFSFileVec.make(valid);
    va = ParseDataset.parse(Key.make(), validfv._key);
    // Replace the response column in both frames with a categorical copy, so the model is
    // trained as a classifier (the multinomial metrics cast below relies on this). The
    // original column is presumably parsed as numeric — confirm. The replaced Vec is freed.
    for (Frame f : new Frame[] {tr, va}) {
      Vec v = f.remove(p._response_column);
      f.add(p._response_column, v.toCategoricalVec());
      v.remove();
    }
    // Re-publish the modified frames so the DKV copies include the categorical column.
    DKV.put(tr);
    DKV.put(va);
    p._backend = getBackend();
    p._train = tr._key;
    p._valid = va._key;
    p._hidden = new int[] {500, 500};
    p._sparse = true;
    DeepWater j = new DeepWater(p);
    m = j.trainModel().get();
    double err =
        ((ModelMetricsMultinomial) m._output._validation_metrics).mean_per_class_error();
    Assert.assertTrue("validation mean per-class error too high: " + err, err < 0.05);
  } finally {
    if (tr != null) {
      tr.remove();
    }
    if (va != null) {
      va.remove();
    }
    if (m != null) {
      m.remove();
    }
  }
}
Use of hex.ModelMetricsMultinomial in project h2o-3 by h2oai.
From class DeepWaterAbstractIntegrationTest, method MOJOTestImage.
/**
 * Briefly trains a DeepWater image model (1e-3 epochs) with the given network on the
 * cat/dog/mouse dataset and makes two checks against the training frame.
 *
 * <p>First, Java/MOJO scoring must agree with in-cluster scoring within 1e-3. Second, the
 * logloss recomputed from the predictions must match the model's reported training logloss
 * within 1e-3.
 *
 * @param network the DeepWater network architecture to train
 */
private void MOJOTestImage(DeepWaterParameters.Network network) {
  Frame train = null;
  DeepWaterModel model = null;
  Frame scored = null;
  try {
    DeepWaterParameters params = new DeepWaterParameters();
    params._backend = getBackend();
    train = parse_test_file("bigdata/laptop/deepwater/imagenet/cat_dog_mouse.csv");
    params._train = train._key;
    params._response_column = "C2";
    params._learning_rate = 1e-4;
    params._network = network;
    params._mini_batch_size = 4;
    params._train_samples_per_iteration = 8;
    params._epochs = 1e-3;
    model = new DeepWater(params).trainModel().get();

    scored = model.score(train);
    boolean javaScoringMatches = model.testJavaScoring(train, scored, 1e-3);
    Assert.assertTrue(javaScoringMatches);

    // Discard prediction column 0 (presumably the predicted label), leaving the per-class
    // probability columns that ModelMetricsMultinomial.make consumes.
    scored.remove(0).remove();
    double recomputedLogloss =
        ModelMetricsMultinomial.make(scored, train.vec(params._response_column)).logloss();
    double reportedLogloss =
        ((ModelMetricsMultinomial) model._output._training_metrics).logloss();
    Assert.assertTrue(Math.abs(recomputedLogloss - reportedLogloss) < 1e-3);
  } finally {
    if (train != null) {
      train.remove();
    }
    if (model != null) {
      model.remove();
    }
    if (scored != null) {
      scored.remove();
    }
  }
}
Use of hex.ModelMetricsMultinomial in project h2o-3 by h2oai.
From class DeepWaterAbstractIntegrationTest, method overWriteWithBestModel.
/**
 * Trains a LeNet image model on the cat/dog/mouse dataset for 50 epochs with
 * {@code _overwrite_with_best_model} enabled and very frequent scoring.
 * It then asserts that the final model's training logloss is below 2.
 */
@Test
public void overWriteWithBestModel() {
  DeepWaterModel model = null;
  Frame train = null;
  try {
    DeepWaterParameters params = new DeepWaterParameters();
    params._backend = getBackend();
    train = parse_test_file("bigdata/laptop/deepwater/imagenet/cat_dog_mouse.csv");
    params._train = train._key;
    params._response_column = "C2";

    // Image problem: 28x28 inputs fed to the LeNet architecture.
    params._problem_type = DeepWaterParameters.ProblemType.image;
    params._image_shape = new int[] {28, 28};
    params._network = lenet;

    // Optimizer settings; stopping_rounds = 0 presumably disables early stopping — confirm.
    params._epochs = 50;
    params._learning_rate = 0.01;
    params._momentum_start = 0.5;
    params._momentum_stable = 0.5;
    params._stopping_rounds = 0;

    // Score as often as possible: one default-sized mini-batch per iteration, full duty
    // cycle, and no minimum interval, so there are many scoring points to pick the best from.
    params._train_samples_per_iteration = params._mini_batch_size;
    params._score_duty_cycle = 1;
    params._score_interval = 0;
    params._overwrite_with_best_model = true;

    model = new DeepWater(params).trainModel().get();
    Log.info(model);
    double trainingLogloss =
        ((ModelMetricsMultinomial) model._output._training_metrics).logloss();
    Assert.assertTrue(trainingLogloss < 2);
  } finally {
    if (model != null) {
      model.remove();
    }
    if (train != null) {
      train.remove();
    }
  }
}
Aggregations