Search in sources:

Example 1 with Triple

Use of water.util.Triple in project h2o-3 by h2oai.

From class DRFTest, method testColSamplingPerTree.

/**
 * Sweeps DRF sampling parameters on a 50/50 train/validation split of
 * {@code ecology_model.csv} (regression on {@code Angaus}) and logs every configuration ordered
 * from lowest (best) to highest (worst) validation MSE.
 *
 * <p>Swept parameters: row sample rate ({@code _sample_rate}), per-split column sample rate
 * (passed to DRF as an absolute {@code _mtries} column count) and per-tree column sample rate
 * ({@code _col_sample_rate_per_tree}). The test makes no assertions on the results; it only
 * requires that every configuration trains and produces regression validation metrics.
 */
@Test
public void testColSamplingPerTree() {
    Frame tfr = null;
    Key[] ksplits = new Key[0];
    try {
        tfr = parse_test_file("./smalldata/gbm_test/ecology_model.csv");
        SplitFrame sf = new SplitFrame(tfr, new double[] { 0.5, 0.5 }, new Key[] { Key.make("train.hex"), Key.make("test.hex") });
        // Invoke the job
        sf.exec().get();
        ksplits = sf._destination_frames;
        float[] sample_rates = new float[] { 0.2f, 0.4f, 0.6f, 0.8f, 1.0f };
        float[] col_sample_rates = new float[] { 0.4f, 0.6f, 0.8f, 1.0f };
        float[] col_sample_rates_per_tree = new float[] { 0.4f, 0.6f, 0.8f, 1.0f };
        // Sorted by validation MSE so that iteration runs best-to-worst. Distinct configurations
        // can end up with exactly the same MSE, so each key holds a list; a single-valued map
        // would silently keep only the last configuration put under that MSE.
        Map<Double, java.util.List<Triple<Float>>> hm = new TreeMap<>();
        for (float sample_rate : sample_rates) {
            for (float col_sample_rate : col_sample_rates) {
                for (float col_sample_rate_per_tree : col_sample_rates_per_tree) {
                    Scope.enter();
                    // Scoped to this iteration: if training fails, the finally block must not
                    // delete the previous iteration's (already deleted) model a second time.
                    DRFModel drf = null;
                    try {
                        DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
                        parms._train = ksplits[0];
                        parms._valid = ksplits[1];
                        //regression
                        parms._response_column = "Angaus";
                        parms._seed = 12345;
                        parms._min_rows = 1;
                        parms._max_depth = 15;
                        parms._ntrees = 2;
                        // Per-split column sampling expressed as a column count (at least 1);
                        // numCols() - 1 presumably excludes the response column -- TODO confirm.
                        parms._mtries = Math.max(1, (int) (col_sample_rate * (tfr.numCols() - 1)));
                        parms._col_sample_rate_per_tree = col_sample_rate_per_tree;
                        parms._sample_rate = sample_rate;
                        // Each iteration trains an independent model with its own parameters.
                        DRF job = new DRF(parms);
                        drf = job.trainModel().get();
                        // POJO validation is too slow to run for every configuration, but passes (now):
                        //   Frame pred = drf.score(tfr);
                        //   Assert.assertTrue(drf.testJavaScoring(tfr,pred,1e-15));
                        //   pred.remove();
                        ModelMetricsRegression mm = (ModelMetricsRegression) drf._output._validation_metrics;
                        Double mse = mm.mse();
                        java.util.List<Triple<Float>> sameMse = hm.get(mse);
                        if (sameMse == null) {
                            sameMse = new java.util.ArrayList<>();
                            hm.put(mse, sameMse);
                        }
                        sameMse.add(new Triple<>(sample_rate, col_sample_rate, col_sample_rate_per_tree));
                    } finally {
                        if (drf != null)
                            drf.delete();
                        Scope.exit();
                    }
                }
            }
        }
        // iterate over results (min to max MSE) - best to worst
        for (Map.Entry<Double, java.util.List<Triple<Float>>> n : hm.entrySet()) {
            for (Triple<Float> t : n.getValue()) {
                Log.info("MSE: " + n.getKey() + ", row sample: " + t.v1 + ", col sample: " + t.v2 + ", col sample per tree: " + t.v3);
            }
        }
        // Not asserted: the worst (last) validation MSE was expected to belong to the most
        // overfit case (row sample 1.0, col sample 1.0, col sample per tree 1.0).
    } finally {
        if (tfr != null)
            tfr.remove();
        for (Key k : ksplits) if (k != null)
            k.remove();
    }
}
Also used: Frame (water.fvec.Frame), SplitFrame (hex.SplitFrame), ModelMetricsRegression (hex.ModelMetricsRegression), Triple (water.util.Triple), Test (org.junit.Test)

Aggregations

ModelMetricsRegression (hex.ModelMetricsRegression): 1, SplitFrame (hex.SplitFrame): 1, Test (org.junit.Test): 1, Frame (water.fvec.Frame): 1, Triple (water.util.Triple): 1