Use of hex.ModelMetricsMultinomial in project h2o-3 by h2oai.
From the class DeepWaterAbstractIntegrationTest, method MNISTHinton.
/**
 * Trains a fully connected DeepWater network on MNIST and checks that the mean per-class
 * validation error is below 5%.
 *
 * <p>Network: hidden layers 1024/1024/2048, input dropout 0.1, hidden dropout 0.5 per layer,
 * learning rate 1e-3, mini-batch 32, 20 epochs, early stopping disabled.
 *
 * <p>If either the MNIST training or test file cannot be located, nothing is trained and the
 * test passes without checking anything.
 */
@Test
public void MNISTHinton() {
    Frame tr = null;
    Frame va = null;
    DeepWaterModel m = null;
    try {
        DeepWaterParameters p = new DeepWaterParameters();
        File file = FileUtils.locateFile("bigdata/laptop/mnist/train.csv.gz");
        File valid = FileUtils.locateFile("bigdata/laptop/mnist/test.csv.gz");
        // Both files are required: previously only the training file was checked, so a missing
        // test file led to NFSFileVec.make(null) instead of skipping the test.
        if (file != null && valid != null) {
            p._response_column = "C785";
            NFSFileVec trainfv = NFSFileVec.make(file);
            tr = ParseDataset.parse(Key.make(), trainfv._key);
            NFSFileVec validfv = NFSFileVec.make(valid);
            va = ParseDataset.parse(Key.make(), validfv._key);

            // Replace the numeric response with a categorical copy in both frames (the result is
            // read back as multinomial metrics), then free the original numeric vec.
            String col = p._response_column;
            Vec v = tr.remove(col);
            tr.add(col, v.toCategoricalVec());
            v.remove();
            v = va.remove(col);
            va.add(col, v.toCategoricalVec());
            v.remove();
            // Re-publish the modified frames so the DKV copies match the local ones.
            DKV.put(tr);
            DKV.put(va);

            p._backend = getBackend();
            p._hidden = new int[] { 1024, 1024, 2048 };
            p._input_dropout_ratio = 0.1;
            p._hidden_dropout_ratios = new double[] { 0.5, 0.5, 0.5 };
            p._stopping_rounds = 0;
            p._learning_rate = 1e-3;
            p._mini_batch_size = 32;
            p._epochs = 20;
            p._train = tr._key;
            p._valid = va._key;
            DeepWater j = new DeepWater(p);
            m = j.trainModel().get();
            Assert.assertTrue(((ModelMetricsMultinomial) (m._output._validation_metrics)).mean_per_class_error() < 0.05);
        }
    } finally {
        if (tr != null) {
            tr.remove();
        }
        if (va != null) {
            va.remove();
        }
        if (m != null) {
            m.remove();
        }
    }
}
Use of hex.ModelMetricsMultinomial in project h2o-3 by h2oai.
From the class DeepWaterAbstractIntegrationTest, method MNISTLenet.
/**
 * Trains the {@code lenet} DeepWater network on MNIST, treating each row as a 28x28
 * single-channel image, and checks that the mean per-class validation error is below 5%.
 *
 * <p>If either the MNIST training or test file cannot be located, nothing is trained and the
 * test passes without checking anything.
 */
@Test
public void MNISTLenet() {
    Frame tr = null;
    Frame va = null;
    DeepWaterModel m = null;
    try {
        DeepWaterParameters p = new DeepWaterParameters();
        File file = FileUtils.locateFile("bigdata/laptop/mnist/train.csv.gz");
        File valid = FileUtils.locateFile("bigdata/laptop/mnist/test.csv.gz");
        // Both files are required: previously only the training file was checked, so a missing
        // test file led to NFSFileVec.make(null) instead of skipping the test.
        if (file != null && valid != null) {
            p._response_column = "C785";
            NFSFileVec trainfv = NFSFileVec.make(file);
            tr = ParseDataset.parse(Key.make(), trainfv._key);
            NFSFileVec validfv = NFSFileVec.make(valid);
            va = ParseDataset.parse(Key.make(), validfv._key);

            // Replace the numeric response with a categorical copy in both frames (the result is
            // read back as multinomial metrics), then free the original numeric vec.
            String col = p._response_column;
            Vec v = tr.remove(col);
            tr.add(col, v.toCategoricalVec());
            v.remove();
            v = va.remove(col);
            va.add(col, v.toCategoricalVec());
            v.remove();
            // Re-publish the modified frames so the DKV copies match the local ones.
            DKV.put(tr);
            DKV.put(va);

            p._backend = getBackend();
            p._train = tr._key;
            p._valid = va._key;
            p._image_shape = new int[] { 28, 28 };
            // Keep constant (e.g. always-blank) pixel columns so the input stays 28x28.
            p._ignore_const_cols = false;
            p._channels = 1;
            p._network = lenet;
            DeepWater j = new DeepWater(p);
            m = j.trainModel().get();
            Assert.assertTrue(((ModelMetricsMultinomial) (m._output._validation_metrics)).mean_per_class_error() < 0.05);
        }
    } finally {
        if (tr != null) {
            tr.remove();
        }
        if (va != null) {
            va.remove();
        }
        if (m != null) {
            m.remove();
        }
    }
}
Use of hex.ModelMetricsMultinomial in project h2o-3 by h2oai.
From the class DeepWaterAbstractIntegrationTest, method testCheckpointOverwriteWithBestModel2.
/**
 * Checks that a model restarted from a checkpoint honors the checkpointed model as the best
 * model seen so far.
 *
 * <p>Iris is split 80/20 into train/valid frames. A first model is trained with
 * {@code _epochs = 10} and {@code _overwrite_with_best_model = true}. A second model then
 * resumes from it with {@code _epochs = 20}. The resumed model's validation logloss must not
 * exceed the first model's.
 */
@Test
public void testCheckpointOverwriteWithBestModel2() {
    Frame iris = null;
    Frame trainSplit = null;
    Frame validSplit = null;
    DeepWaterModel initial = null;
    DeepWaterModel resumed = null;
    try {
        iris = parse_test_file("./smalldata/iris/iris.csv");
        double[] ratios = new double[] { 0.8 };
        Key[] destinations = new Key[] { Key.make("train"), Key.make("valid") };
        FrameSplitter splitter = new FrameSplitter(iris, ratios, destinations, null);
        splitter.compute2();
        trainSplit = splitter.getResult()[0];
        validSplit = splitter.getResult()[1];

        DeepWaterParameters initialParms = new DeepWaterParameters();
        initialParms._backend = getBackend();
        initialParms._train = trainSplit._key;
        initialParms._valid = validSplit._key;
        initialParms._epochs = 10;
        initialParms._response_column = "C5";
        initialParms._hidden = new int[] { 50, 50 };
        initialParms._seed = 0xdecaf;
        // Score after every iteration so the "best model" is tracked as closely as possible.
        initialParms._train_samples_per_iteration = 0;
        initialParms._score_duty_cycle = 1;
        initialParms._score_interval = 0;
        initialParms._stopping_rounds = 0;
        initialParms._overwrite_with_best_model = true;

        DeepWater initialBuilder = new DeepWater(initialParms);
        initial = initialBuilder.trainModel().get();
        ModelMetricsMultinomial initialMetrics = (ModelMetricsMultinomial) initial._output._validation_metrics;
        double initialLogloss = initialMetrics.logloss();

        DeepWaterParameters resumedParms = (DeepWaterParameters) initialParms.clone();
        resumedParms._epochs = 20;
        resumedParms._checkpoint = initial._key;

        DeepWater resumedBuilder = new DeepWater(resumedParms);
        resumed = resumedBuilder.trainModel().get();
        ModelMetricsMultinomial resumedMetrics = (ModelMetricsMultinomial) resumed._output._validation_metrics;
        double resumedLogloss = resumedMetrics.logloss();

        Assert.assertTrue(resumedLogloss <= initialLogloss);
    } finally {
        if (iris != null) {
            iris.delete();
        }
        if (initial != null) {
            initial.delete();
        }
        if (resumed != null) {
            resumed.delete();
        }
        if (trainSplit != null) {
            trainSplit.delete();
        }
        if (validSplit != null) {
            validSplit.delete();
        }
    }
}
Use of hex.ModelMetricsMultinomial in project h2o-3 by h2oai.
From the class DeepWaterAbstractIntegrationTest, method testCheckpointOverwriteWithBestModel.
/*
public ArrayList<int[]> texts2arrayOnehot(ArrayList<String> texts) {
int maxlen = 0;
int index = 0;
Map<String, Integer> dict = new HashMap<>();
dict.put(PADDING_SYMBOL, index);
index += 1;
for (String text : texts) {
String[] tokens = tokenize(text);
for (String token : tokens) {
if (!dict.containsKey(token)) {
dict.put(token, index);
index += 1;
}
}
int len = tokens.length;
if (len > maxlen) maxlen = len;
}
System.out.println(dict);
System.out.println("maxlen " + maxlen);
System.out.println("dict size " + dict.size());
Assert.assertEquals(38, maxlen);
Assert.assertEquals(88, index);
Assert.assertEquals(88, dict.size());
ArrayList<int[]> array = new ArrayList<>();
for (String text: texts) {
ArrayList<int[]> data = tokensToArray(tokenize(text), maxlen, dict);
System.out.println(text);
System.out.println(" rows " + data.size() + " cols " + data.get(0).length);
//for (int[] x : data) {
// System.out.println(Arrays.toString(x));
//}
array.addAll(data);
}
return array;
}
*/
/**
 * Checks that resuming training from a checkpoint, with {@code _overwrite_with_best_model}
 * enabled, does not end with a worse validation logloss than the checkpointed model.
 *
 * <p>Iris is split 80/20 into train/valid frames. The first model trains with
 * {@code _epochs = 1}. The second resumes from it with {@code _epochs = 10}.
 */
@Test
public void testCheckpointOverwriteWithBestModel() {
    Frame source = null;
    Frame trainPart = null;
    Frame validPart = null;
    DeepWaterModel checkpointModel = null;
    DeepWaterModel continuedModel = null;
    try {
        source = parse_test_file("./smalldata/iris/iris.csv");
        FrameSplitter splitter = new FrameSplitter(
                source,
                new double[] { 0.8 },
                new Key[] { Key.make("train"), Key.make("valid") },
                null);
        splitter.compute2();
        trainPart = splitter.getResult()[0];
        validPart = splitter.getResult()[1];

        DeepWaterParameters settings = new DeepWaterParameters();
        settings._backend = getBackend();
        settings._train = trainPart._key;
        settings._valid = validPart._key;
        settings._epochs = 1;
        settings._response_column = "C5";
        settings._hidden = new int[] { 50, 50 };
        settings._seed = 0xdecaf;
        settings._train_samples_per_iteration = 0;
        settings._score_duty_cycle = 1;
        settings._score_interval = 0;
        settings._stopping_rounds = 0;
        settings._overwrite_with_best_model = true;

        checkpointModel = new DeepWater(settings).trainModel().get();
        double before = ((ModelMetricsMultinomial) checkpointModel._output._validation_metrics).logloss();

        // Same configuration, but continue from the first model for more epochs.
        DeepWaterParameters continuation = (DeepWaterParameters) settings.clone();
        continuation._epochs = 10;
        continuation._checkpoint = checkpointModel._key;
        continuedModel = new DeepWater(continuation).trainModel().get();
        double after = ((ModelMetricsMultinomial) continuedModel._output._validation_metrics).logloss();

        Assert.assertTrue(after <= before);
    } finally {
        if (source != null) {
            source.delete();
        }
        if (checkpointModel != null) {
            checkpointModel.delete();
        }
        if (continuedModel != null) {
            continuedModel.delete();
        }
        if (trainPart != null) {
            trainPart.delete();
        }
        if (validPart != null) {
            validPart.delete();
        }
    }
}
Use of hex.ModelMetricsMultinomial in project h2o-3 by h2oai.
From the class DeepWaterAbstractIntegrationTest, method restoreState.
/**
 * Checks that a DeepWater model survives a round trip through a raw {@code byte[]} with
 * unchanged scoring.
 *
 * <p>Trains a tiny model on {@code cat_dog_mouse.csv} (learning rate 0, 0.01 epochs). It
 * scores the training frame, serializes the model with {@link AutoBuffer}, removes the
 * original, restores it from the bytes and scores again. It then asserts two things, each
 * within a relative tolerance of 1e-5:
 * <ul>
 *   <li>the model's recorded training logloss matches the freshly scored logloss;</li>
 *   <li>the restored model's logloss matches the original's.</li>
 * </ul>
 *
 * @param network the DeepWater network architecture to exercise
 */
public void restoreState(DeepWaterParameters.Network network) {
    DeepWaterModel m1 = null;
    DeepWaterModel m2 = null;
    Frame tr = null;
    Frame pred = null;
    try {
        DeepWaterParameters p = new DeepWaterParameters();
        p._backend = getBackend();
        p._train = (tr = parse_test_file("bigdata/laptop/deepwater/imagenet/cat_dog_mouse.csv"))._key;
        p._network = network;
        p._response_column = "C2";
        p._mini_batch_size = 2;
        p._train_samples_per_iteration = p._mini_batch_size;
        p._learning_rate = 0e-3;
        p._seed = 12345;
        p._epochs = 0.01;
        p._quiet_mode = true;
        p._problem_type = DeepWaterParameters.ProblemType.image;
        m1 = new DeepWater(p).trainModel().get();

        Log.info("Scoring the original model.");
        pred = m1.score(tr);
        // Drop the first (predicted label) column; the remaining columns are passed to
        // ModelMetricsMultinomial.make together with the actual response.
        pred.remove(0).remove();
        ModelMetricsMultinomial mm1 = ModelMetricsMultinomial.make(pred, tr.vec(p._response_column));
        // Capture the training logloss now, because m1 is removed below.
        double trainLL = ((ModelMetricsMultinomial) m1._output._training_metrics).logloss();
        Log.info("Original LL: " + trainLL);
        Log.info("Scored LL: " + mm1.logloss());
        pred.remove();
        pred = null; // already removed; keep the finally block from removing it a second time

        Log.info("Keeping the raw byte[] of the model.");
        byte[] raw = new AutoBuffer().put(m1).buf();
        Log.info("Removing the model from the DKV.");
        m1.remove();
        m1 = null; // already removed; keep the finally block from deleting it a second time

        Log.info("Restoring the model from the raw byte[].");
        m2 = new AutoBuffer(raw).get();
        Log.info("Scoring the restored model.");
        pred = m2.score(tr);
        pred.remove(0).remove();
        ModelMetricsMultinomial mm2 = ModelMetricsMultinomial.make(pred, tr.vec(p._response_column));
        Log.info("Restored LL: " + mm2.logloss());

        // make sure scoring is self-consistent
        Assert.assertEquals(trainLL, mm1.logloss(), 1e-5 * mm1.logloss());
        // ... and that the restored model scores like the original
        Assert.assertEquals(mm1.logloss(), mm2.logloss(), 1e-5 * mm1.logloss());
    } finally {
        if (m1 != null) {
            m1.delete();
        }
        if (m2 != null) {
            m2.delete();
        }
        if (tr != null) {
            tr.remove();
        }
        if (pred != null) {
            pred.remove();
        }
    }
}
Aggregations