Use of water.util.Triple in project h2o-3 by h2oai.
The class DRFTest, method testColSamplingPerTree.
/**
 * Trains one small DRF regression model for every combination of row sample rate,
 * column sample rate (translated into {@code _mtries}) and per-tree column sample rate,
 * and records the validation MSE of each model.
 *
 * <p>The input file is split with {@code SplitFrame} using the ratios {@code {0.5, 0.5}};
 * the first split is used as the training frame and the second as the validation frame.
 * The collected results are logged in ascending order of MSE. The assertions on which
 * combination ends up with the highest MSE are currently disabled, so the test only
 * checks that every model builds without throwing.
 *
 * <p>Each model is deleted in its own {@code finally} block; the parsed frame and the
 * split keys are removed once the whole sweep has finished.
 */
@Test
public void testColSamplingPerTree() {
  Frame tfr = null;
  Key[] ksplits = new Key[0];
  try {
    tfr = parse_test_file("./smalldata/gbm_test/ecology_model.csv");
    SplitFrame sf = new SplitFrame(tfr, new double[] { 0.5, 0.5 }, new Key[] { Key.make("train.hex"), Key.make("test.hex") });
    // Invoke the job
    sf.exec().get();
    ksplits = sf._destination_frames;
    float[] sample_rates = new float[] { 0.2f, 0.4f, 0.6f, 0.8f, 1.0f };
    float[] col_sample_rates = new float[] { 0.4f, 0.6f, 0.8f, 1.0f };
    float[] col_sample_rates_per_tree = new float[] { 0.4f, 0.6f, 0.8f, 1.0f };
    // Keyed by validation MSE, so iteration runs from lowest to highest MSE.
    // Two combinations with exactly the same MSE keep only the one stored last.
    Map<Double, Triple<Float>> hm = new TreeMap<>();
    for (float sample_rate : sample_rates) {
      for (float col_sample_rate : col_sample_rates) {
        for (float col_sample_rate_per_tree : col_sample_rates_per_tree) {
          // Scoped to this iteration: if training fails, the finally block must not
          // touch the (already deleted) model left over from an earlier iteration.
          DRFModel drf = null;
          Scope.enter();
          try {
            DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
            parms._train = ksplits[0];
            parms._valid = ksplits[1];
            //regression
            parms._response_column = "Angaus";
            parms._seed = 12345;
            parms._min_rows = 1;
            parms._max_depth = 15;
            parms._ntrees = 2;
            // Column sample rate is expressed as a column count, never less than 1.
            parms._mtries = Math.max(1, (int) (col_sample_rate * (tfr.numCols() - 1)));
            parms._col_sample_rate_per_tree = col_sample_rate_per_tree;
            parms._sample_rate = sample_rate;
            DRF job = new DRF(parms);
            drf = job.trainModel().get();
            // too slow, but passes (now)
            // // Build a POJO, validate same results
            // Frame pred = drf.score(tfr);
            // Assert.assertTrue(drf.testJavaScoring(tfr,pred,1e-15));
            // pred.remove();
            ModelMetricsRegression mm = (ModelMetricsRegression) drf._output._validation_metrics;
            hm.put(mm.mse(), new Triple<>(sample_rate, col_sample_rate, col_sample_rate_per_tree));
          } finally {
            if (drf != null) {
              drf.delete();
            }
            Scope.exit();
          }
        }
      }
    }
    // Retained for the disabled assertions below: after the loop it holds the
    // combination with the highest validation MSE.
    Triple<Float> last = null;
    // iterate over results (min to max MSE) - best to worst
    for (Map.Entry<Double, Triple<Float>> n : hm.entrySet()) {
      Log.info("MSE: " + n.getKey() + ", row sample: " + n.getValue().v1 + ", col sample: " + n.getValue().v2 + ", col sample per tree: " + n.getValue().v3);
      last = n.getValue();
    }
    // worst validation MSE should belong to the most overfit case (1.0, 1.0, 1.0)
    // Assert.assertTrue(last.v1==sample_rates[sample_rates.length-1]);
    // Assert.assertTrue(last.v2==col_sample_rates[col_sample_rates.length-1]);
    // Assert.assertTrue(last.v3==col_sample_rates_per_tree[col_sample_rates_per_tree.length-1]);
  } finally {
    if (tfr != null) {
      tfr.remove();
    }
    for (Key k : ksplits) {
      if (k != null) {
        k.remove();
      }
    }
  }
}
Aggregations