Search in sources:

Example 1 with AUCData

Use of water.api.AUCData in project h2o-2 by h2oai.

From the class DeepLearningProstateTest, method runFraction.

/**
 * Runs a randomized sweep over DeepLearning parameter combinations on the prostate and iris
 * test datasets, executing approximately {@code fraction} of all visited combinations.
 *
 * <p>For every executed combination this method trains a first model, continues training it via
 * a checkpoint restart, renders HTML for the resulting models (and for the cross-validation models
 * when {@code n_folds != 0}), and, for classifiers, cross-checks the AUC-derived error against
 * confusion-matrix errors computed with the default labels, with a 0.5 threshold, and with the
 * AUC-selected threshold.
 *
 * @param fraction probability with which each visited combination is executed; a combination is
 *     skipped when {@code fraction < rng.nextFloat()}
 * @throws RuntimeException wrapping any failure of an executed combination
 */
public void runFraction(float fraction) {
    long seed = 0xDECAF;
    Random rng = new Random(seed);
    String[] datasets = new String[2];
    int[][] responses = new int[datasets.length][];
    datasets[0] = "smalldata/./logreg/prostate.csv";
    // response columns: CAPSULE (binomial), AGE (regression), GLEASON (multi-class)
    responses[0] = new int[] { 1, 2, 8 };
    datasets[1] = "smalldata/iris/iris.csv";
    // response column: Iris-type (multi-class)
    responses[1] = new int[] { 4 };
    int testcount = 0; // combinations that were executed and passed
    int count = 0; // all visited combinations, including skipped ones
    for (int i = 0; i < datasets.length; ++i) {
        String dataset = datasets[i];
        // Declared outside the try so the finally block releases whichever frames were parsed,
        // even if a later parse (or file lookup) throws.
        Frame frame = null;
        Frame vframe = null;
        try {
            Key file = NFSFileVec.make(find_test_file(dataset));
            frame = ParseDataset2.parse(Key.make(), new Key[] { file });
            // Parse the same file a second time to obtain a distinct validation frame.
            Key vfile = NFSFileVec.make(find_test_file(dataset));
            vframe = ParseDataset2.parse(Key.make(), new Key[] { vfile });
            for (boolean replicate : new boolean[] { true, false }) {
                for (boolean load_balance : new boolean[] { true, false }) {
                    for (boolean shuffle : new boolean[] { true, false }) {
                        for (boolean balance_classes : new boolean[] { true, false }) {
                            for (int resp : responses[i]) {
                                for (DeepLearning.ClassSamplingMethod csm : new DeepLearning.ClassSamplingMethod[] { DeepLearning.ClassSamplingMethod.Stratified, DeepLearning.ClassSamplingMethod.Uniform }) {
                                    for (int scoretraining : new int[] { 200, 20, 0 }) {
                                        for (int scorevalidation : new int[] { 200, 20, 0 }) {
                                            for (int vf : new int[] {
                                                    0,  // no validation
                                                    1,  // same as source
                                                    -1  // different validation frame
                                            }) {
                                                for (int n_folds : new int[] { 0, 2 }) {
                                                    // cross-validation is only combined with "no validation"
                                                    if (n_folds != 0 && vf != 0)
                                                        continue;
                                                    // only false: otherwise it leaks
                                                    for (boolean keep_cv_splits : new boolean[] { false }) {
                                                        for (boolean override_with_best_model : new boolean[] { false, true }) {
                                                            for (int train_samples_per_iteration : new int[] {
                                                                    -2,               // auto-tune
                                                                    -1,               // N epochs per iteration
                                                                    0,                // 1 epoch per iteration
                                                                    rng.nextInt(200), // <1 epoch per iteration
                                                                    500               // >1 epoch per iteration
                                                            }) {
                                                                DeepLearningModel model1 = null, model2 = null;
                                                                Key dest = null, dest_tmp = null;
                                                                count++;
                                                                // randomly skip combinations so only ~fraction of them run
                                                                if (fraction < rng.nextFloat())
                                                                    continue;
                                                                try {
                                                                    Log.info("**************************)");
                                                                    Log.info("Starting test #" + count);
                                                                    Log.info("**************************)");
                                                                    final double epochs = 7 + rng.nextDouble() + rng.nextInt(4);
                                                                    final int[] hidden = new int[] { 1 + rng.nextInt(4), 1 + rng.nextInt(6) };
                                                                    Frame valid = null; // no validation
                                                                    if (vf == 1)
                                                                        valid = frame; // use the same frame for validation
                                                                    else if (vf == -1)
                                                                        valid = vframe; // different validation frame (here: from the same file)
                                                                    // build the model, with all kinds of shuffling/rebalancing/sampling
                                                                    dest_tmp = Key.make("first");
                                                                    {
                                                                        Log.info("Using seed: " + seed);
                                                                        DeepLearning p = new DeepLearning();
                                                                        p.checkpoint = null;
                                                                        p.destination_key = dest_tmp;
                                                                        p.source = frame;
                                                                        p.response = frame.vecs()[resp];
                                                                        p.validation = valid;
                                                                        p.hidden = hidden;
                                                                        // prostate column 2 (AGE) is the regression response
                                                                        if (i == 0 && resp == 2)
                                                                            p.classification = false;
                                                                        p.override_with_best_model = override_with_best_model;
                                                                        p.epochs = epochs;
                                                                        p.n_folds = n_folds;
                                                                        p.keep_cross_validation_splits = keep_cv_splits;
                                                                        p.seed = seed;
                                                                        p.train_samples_per_iteration = train_samples_per_iteration;
                                                                        p.force_load_balance = load_balance;
                                                                        p.replicate_training_data = replicate;
                                                                        p.shuffle_training_data = shuffle;
                                                                        p.score_training_samples = scoretraining;
                                                                        p.score_validation_samples = scorevalidation;
                                                                        p.classification_stop = -1;
                                                                        p.regression_stop = -1;
                                                                        p.balance_classes = balance_classes;
                                                                        p.quiet_mode = true;
                                                                        p.score_validation_sampling = csm;
                                                                        try {
                                                                            p.invoke();
                                                                        } catch (Throwable t) {
                                                                            t.printStackTrace();
                                                                            throw new RuntimeException(t);
                                                                        } finally {
                                                                            p.delete();
                                                                        }
                                                                        model1 = UKV.get(dest_tmp);
                                                                        // trained epochs should be close to the requested number (within 20%),
                                                                        // or exceed it when whole-frame iterations are used
                                                                        assert (((p.train_samples_per_iteration <= 0 || p.train_samples_per_iteration >= frame.numRows()) && model1.epoch_counter > epochs) || Math.abs(model1.epoch_counter - epochs) / epochs < 0.20);
                                                                        if (n_folds != 0) {
                                                                            // test HTML of cv models, then release them
                                                                            for (Key k : model1.get_params().xval_models) {
                                                                                DeepLearningModel cv_model = UKV.get(k);
                                                                                StringBuilder sb = new StringBuilder();
                                                                                cv_model.generateHTML("cv", sb);
                                                                                cv_model.delete_best_model();
                                                                                cv_model.delete();
                                                                            }
                                                                        }
                                                                    }
                                                                    // Do some more training via checkpoint restart
                                                                    // For n_folds, continue without n_folds (not yet implemented) - from now on, model2 will have n_folds=0...
                                                                    dest = Key.make("restart");
                                                                    DeepLearning p = new DeepLearning();
                                                                    //this actually *requires* frame to also still be in UKV (because of DataInfo...)
                                                                    final DeepLearningModel tmp_model = UKV.get(dest_tmp);
                                                                    // Check for null BEFORE dereferencing; Assert works even when -ea is off.
                                                                    Assert.assertNotNull(tmp_model);
                                                                    //HEX-1817
                                                                    Assert.assertTrue(tmp_model.get_params().state == Job.JobState.DONE);
                                                                    Assert.assertTrue(tmp_model.model_info().get_processed_total() >= frame.numRows() * epochs);
                                                                    p.checkpoint = dest_tmp;
                                                                    p.destination_key = dest;
                                                                    p.n_folds = 0;
                                                                    p.source = frame;
                                                                    p.validation = valid;
                                                                    p.response = frame.vecs()[resp];
                                                                    if (i == 0 && resp == 2)
                                                                        p.classification = false;
                                                                    p.override_with_best_model = override_with_best_model;
                                                                    p.epochs = epochs;
                                                                    p.seed = seed;
                                                                    p.train_samples_per_iteration = train_samples_per_iteration;
                                                                    try {
                                                                        p.invoke();
                                                                    } catch (Throwable t) {
                                                                        t.printStackTrace();
                                                                        throw new RuntimeException(t);
                                                                    } finally {
                                                                        p.delete();
                                                                    }
                                                                    // score and check result (on full data)
                                                                    //this actually *requires* frame to also still be in UKV (because of DataInfo...)
                                                                    model2 = UKV.get(dest);
                                                                    //HEX-1817
                                                                    Assert.assertTrue(model2.get_params().state == Job.JobState.DONE);
                                                                    // test HTML
                                                                    {
                                                                        StringBuilder sb = new StringBuilder();
                                                                        model2.generateHTML("test", sb);
                                                                    }
                                                                    // score and check result of the best_model
                                                                    if (model2.actual_best_model_key != null) {
                                                                        final DeepLearningModel best_model = UKV.get(model2.actual_best_model_key);
                                                                        //HEX-1817
                                                                        Assert.assertTrue(best_model.get_params().state == Job.JobState.DONE);
                                                                        // test HTML
                                                                        {
                                                                            StringBuilder sb = new StringBuilder();
                                                                            best_model.generateHTML("test", sb);
                                                                        }
                                                                        if (override_with_best_model) {
                                                                            Assert.assertEquals(best_model.error(), model2.error(), 0);
                                                                        }
                                                                    }
                                                                    // without a validation frame, score on the training frame
                                                                    if (valid == null)
                                                                        valid = frame;
                                                                    double threshold = 0;
                                                                    if (model2.isClassifier()) {
                                                                        Frame pred = null, pred2 = null;
                                                                        try {
                                                                            pred = model2.score(valid);
                                                                            StringBuilder sb = new StringBuilder();
                                                                            AUC auc = new AUC();
                                                                            double error = 0;
                                                                            // binary
                                                                            if (model2.nclasses() == 2) {
                                                                                auc.actual = valid;
                                                                                assert (resp == 1);
                                                                                auc.vactual = valid.vecs()[resp];
                                                                                auc.predict = pred;
                                                                                // third prediction column (presumably the positive-class probability)
                                                                                auc.vpredict = pred.vecs()[2];
                                                                                auc.invoke();
                                                                                auc.toASCII(sb);
                                                                                AUCData aucd = auc.data();
                                                                                threshold = aucd.threshold();
                                                                                error = aucd.err();
                                                                                Log.info(sb);
                                                                                // check that auc.cm() is the right CM
                                                                                Assert.assertEquals(new ConfusionMatrix(aucd.cm()).err(), error, 1e-15);
                                                                                // check that calcError() is consistent as well (for CM=null, AUC!=null)
                                                                                Assert.assertEquals(model2.calcError(valid, auc.vactual, pred, pred, "training", false, 0, null, auc, null), error, 1e-15);
                                                                            }
                                                                            // Compute CM from the model's own predicted labels (first prediction column)
                                                                            double CMerrorOrig;
                                                                            {
                                                                                sb = new StringBuilder();
                                                                                water.api.ConfusionMatrix CM = new water.api.ConfusionMatrix();
                                                                                CM.actual = valid;
                                                                                CM.vactual = valid.vecs()[resp];
                                                                                CM.predict = pred;
                                                                                CM.vpredict = pred.vecs()[0];
                                                                                CM.invoke();
                                                                                sb.append("\n");
                                                                                sb.append("Threshold: " + "default\n");
                                                                                CM.toASCII(sb);
                                                                                Log.info(sb);
                                                                                CMerrorOrig = new ConfusionMatrix(CM.cm).err();
                                                                            }
                                                                            // confirm that orig CM was made with threshold 0.5
                                                                            // put pred2 into UKV, and allow access
                                                                            pred2 = new Frame(Key.make("pred2"), pred.names(), pred.vecs());
                                                                            pred2.delete_and_lock(null);
                                                                            pred2.unlock(null);
                                                                            if (model2.nclasses() == 2) {
                                                                                // make labels with 0.5 threshold for binary classifier
                                                                                Env ev = Exec2.exec("pred2[,1]=pred2[,3]>=" + 0.5);
                                                                                try {
                                                                                    pred2 = ev.popAry();
                                                                                    String skey = ev.key();
                                                                                    ev.subRef(pred2, skey);
                                                                                } finally {
                                                                                    if (ev != null)
                                                                                        ev.remove_and_unlock();
                                                                                }
                                                                                water.api.ConfusionMatrix CM = new water.api.ConfusionMatrix();
                                                                                CM.actual = valid;
                                                                                // same response column as the AUC computation above (resp == 1 there)
                                                                                CM.vactual = valid.vecs()[resp];
                                                                                CM.predict = pred2;
                                                                                CM.vpredict = pred2.vecs()[0];
                                                                                CM.invoke();
                                                                                sb = new StringBuilder();
                                                                                sb.append("\n");
                                                                                sb.append("Threshold: " + 0.5 + "\n");
                                                                                CM.toASCII(sb);
                                                                                Log.info(sb);
                                                                                double threshErr = new ConfusionMatrix(CM.cm).err();
                                                                                Assert.assertEquals(threshErr, CMerrorOrig, 1e-15);
                                                                                // make labels with AUC-given threshold for best F1
                                                                                ev = Exec2.exec("pred2[,1]=pred2[,3]>=" + threshold);
                                                                                try {
                                                                                    pred2 = ev.popAry();
                                                                                    String skey = ev.key();
                                                                                    ev.subRef(pred2, skey);
                                                                                } finally {
                                                                                    if (ev != null)
                                                                                        ev.remove_and_unlock();
                                                                                }
                                                                                CM = new water.api.ConfusionMatrix();
                                                                                CM.actual = valid;
                                                                                CM.vactual = valid.vecs()[resp];
                                                                                CM.predict = pred2;
                                                                                CM.vpredict = pred2.vecs()[0];
                                                                                CM.invoke();
                                                                                sb = new StringBuilder();
                                                                                sb.append("\n");
                                                                                sb.append("Threshold: ").append(threshold).append("\n");
                                                                                CM.toASCII(sb);
                                                                                Log.info(sb);
                                                                                double threshErr2 = new ConfusionMatrix(CM.cm).err();
                                                                                Assert.assertEquals(threshErr2, error, 1e-15);
                                                                            }
                                                                        } finally {
                                                                            if (pred != null)
                                                                                pred.delete();
                                                                            if (pred2 != null)
                                                                                pred2.delete();
                                                                        }
                                                                    }
                                                                    //classifier
                                                                    Log.info("Parameters combination " + count + ": PASS");
                                                                    testcount++;
                                                                } catch (Throwable t) {
                                                                    t.printStackTrace();
                                                                    throw new RuntimeException(t);
                                                                } finally {
                                                                    if (model1 != null) {
                                                                        model1.delete_xval_models();
                                                                        model1.delete_best_model();
                                                                        model1.delete();
                                                                    }
                                                                    if (model2 != null) {
                                                                        model2.delete_xval_models();
                                                                        model2.delete_best_model();
                                                                        model2.delete();
                                                                    }
                                                                }
                                                            }
                                                        }
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        } finally {
            if (frame != null)
                frame.delete();
            if (vframe != null)
                vframe.delete();
        }
    }
    Log.info("\n\n=============================================");
    Log.info("Tested " + testcount + " out of " + count + " parameter combinations.");
    Log.info("=============================================");
}
Also used: Frame (water.fvec.Frame), AUCData (water.api.AUCData), DeepLearning (hex.deeplearning.DeepLearning), Env (water.exec.Env), AUC (water.api.AUC), Random (java.util.Random), water (water), DeepLearningModel (hex.deeplearning.DeepLearningModel)

Aggregations

DeepLearning (hex.deeplearning.DeepLearning): 1, DeepLearningModel (hex.deeplearning.DeepLearningModel): 1, Random (java.util.Random): 1, water (water): 1, AUC (water.api.AUC): 1, AUCData (water.api.AUCData): 1, Env (water.exec.Env): 1, Frame (water.fvec.Frame): 1