Use of hex.FrameSplitter in project h2o-3 by h2oai.
The class is DeepWaterAbstractIntegrationTest; the method is testCheckpointOverwriteWithBestModel.
/*
public ArrayList<int[]> texts2arrayOnehot(ArrayList<String> texts) {
int maxlen = 0;
int index = 0;
Map<String, Integer> dict = new HashMap<>();
dict.put(PADDING_SYMBOL, index);
index += 1;
for (String text : texts) {
String[] tokens = tokenize(text);
for (String token : tokens) {
if (!dict.containsKey(token)) {
dict.put(token, index);
index += 1;
}
}
int len = tokens.length;
if (len > maxlen) maxlen = len;
}
System.out.println(dict);
System.out.println("maxlen " + maxlen);
System.out.println("dict size " + dict.size());
Assert.assertEquals(38, maxlen);
Assert.assertEquals(88, index);
Assert.assertEquals(88, dict.size());
ArrayList<int[]> array = new ArrayList<>();
for (String text: texts) {
ArrayList<int[]> data = tokensToArray(tokenize(text), maxlen, dict);
System.out.println(text);
System.out.println(" rows " + data.size() + " cols " + data.get(0).length);
//for (int[] x : data) {
// System.out.println(Arrays.toString(x));
//}
array.addAll(data);
}
return array;
}
*/
/**
 * Trains a Deep Water model for one epoch with {@code _overwrite_with_best_model}
 * enabled, then trains a second model for ten epochs whose {@code _checkpoint}
 * is the first model's key, and asserts that the second model's validation
 * logloss is not larger than the first model's.
 *
 * <p>The iris data set is split with {@link FrameSplitter} (ratio array
 * {@code {0.8}}) into frames keyed "train" and "valid". Every frame and model
 * created here is deleted in the {@code finally} block.
 */
@Test
public void testCheckpointOverwriteWithBestModel() {
    Frame iris = null;
    Frame trainSplit = null;
    Frame validSplit = null;
    DeepWaterModel firstModel = null;
    DeepWaterModel resumedModel = null;
    try {
        iris = parse_test_file("./smalldata/iris/iris.csv");

        // Split the parsed frame into the two frames used for training and validation.
        double[] ratios = new double[] { 0.8 };
        Key[] splitKeys = new Key[] { Key.make("train"), Key.make("valid") };
        FrameSplitter splitter = new FrameSplitter(iris, ratios, splitKeys, null);
        splitter.compute2();
        trainSplit = splitter.getResult()[0];
        validSplit = splitter.getResult()[1];

        DeepWaterParameters baseParms = new DeepWaterParameters();
        baseParms._backend = getBackend();

        // Data.
        baseParms._train = trainSplit._key;
        baseParms._valid = validSplit._key;
        baseParms._response_column = "C5";

        // Network and training length.
        baseParms._hidden = new int[] { 50, 50 };
        baseParms._epochs = 1;
        baseParms._seed = 0xdecaf;
        baseParms._train_samples_per_iteration = 0;

        // Scoring and stopping.
        baseParms._score_duty_cycle = 1;
        baseParms._score_interval = 0;
        baseParms._stopping_rounds = 0;
        baseParms._overwrite_with_best_model = true;

        firstModel = new DeepWater(baseParms).trainModel().get();
        ModelMetricsMultinomial firstMetrics =
            (ModelMetricsMultinomial) firstModel._output._validation_metrics;
        double firstLogloss = firstMetrics.logloss();

        // Same settings, but more epochs and the first model as the checkpoint.
        DeepWaterParameters resumeParms = (DeepWaterParameters) baseParms.clone();
        resumeParms._epochs = 10;
        resumeParms._checkpoint = firstModel._key;

        resumedModel = new DeepWater(resumeParms).trainModel().get();
        ModelMetricsMultinomial resumedMetrics =
            (ModelMetricsMultinomial) resumedModel._output._validation_metrics;
        double resumedLogloss = resumedMetrics.logloss();

        Assert.assertTrue(resumedLogloss <= firstLogloss);
    } finally {
        // Each reference is null if the test failed before it was assigned.
        if (iris != null) {
            iris.delete();
        }
        if (firstModel != null) {
            firstModel.delete();
        }
        if (resumedModel != null) {
            resumedModel.delete();
        }
        if (trainSplit != null) {
            trainSplit.delete();
        }
        if (validSplit != null) {
            validSplit.delete();
        }
    }
}
Aggregations