Search in sources :

Example 1 with DeepLearningParameters

use of hex.deeplearning.DeepLearningModel.DeepLearningParameters in project h2o-3 by h2oai.

From the class DeepLearningMNIST, method run.

/**
 * Trains a Deep Learning model on the MNIST train/test CSV files and deletes it afterwards.
 * <p>
 * The test is {@code @Ignore}d. If either data file cannot be located, nothing is parsed or
 * trained and a hint on how to fetch the data is logged instead.
 */
@Test
@Ignore
public void run() {
    Scope.enter();
    Frame frame = null;
    Frame vframe = null;
    try {
        File file = FileUtils.locateFile("bigdata/laptop/mnist/train.csv.gz");
        File valid = FileUtils.locateFile("bigdata/laptop/mnist/test.csv.gz");
        // Both files are required: the validation file is parsed unconditionally below,
        // so a missing test file must take the same "please sync" path as a missing train file.
        if (file != null && valid != null) {
            NFSFileVec trainfv = NFSFileVec.make(file);
            frame = ParseDataset.parse(Key.make(), trainfv._key);
            NFSFileVec validfv = NFSFileVec.make(valid);
            vframe = ParseDataset.parse(Key.make(), validfv._key);

            // populate model parameters
            DeepLearningParameters p = new DeepLearningParameters();
            p._train = frame._key;
            p._valid = vframe._key;
            // last column is the response
            p._response_column = "C785";
            p._activation = DeepLearningParameters.Activation.RectifierWithDropout;
            //        p._activation = DeepLearningParameters.Activation.MaxoutWithDropout;
            p._hidden = new int[] { 128, 128, 128 };
            p._input_dropout_ratio = 0.0;
            // fixed learning rate without annealing or momentum
            p._adaptive_rate = false;
            p._rate = 0.005;
            p._rate_annealing = 0;
            p._momentum_start = 0;
            p._momentum_stable = 0;
            p._mini_batch_size = 1;
            p._train_samples_per_iteration = -1;
            //        p._score_duty_cycle = 0.1;
            p._shuffle_training_data = true;
            //        p._reproducible = true;
            //        p._l1= 1e-5;
            p._max_w2 = 1;
            //1000*10*5./6;
            p._epochs = 20;
            //faster as activations remain sparse
            p._sparse = true;

            // Replace the response column 'C785' with its categorical version in both frames
            // and publish the modified frames to the DKV.
            int ci = frame.find("C785");
            Scope.track(frame.replace(ci, frame.vecs()[ci].toCategoricalVec()));
            Scope.track(vframe.replace(ci, vframe.vecs()[ci].toCategoricalVec()));
            DKV.put(frame);
            DKV.put(vframe);

            // speed up training
            //        p._adaptive_rate = true; //disable adaptive per-weight learning rate -> default settings for learning rate and momentum are probably not ideal (slow convergence)
            //avoid extra communication cost upfront, got enough data on each node for load balancing
            p._replicate_training_data = true;
            //no need to keep the best model around
            // NOTE(review): the comment above reads as if this flag should be false - confirm intent.
            p._overwrite_with_best_model = true;
            p._classification_stop = -1;
            //        p._score_interval = 5; //score and print progress report (only) every 20 seconds
            //only score on a small sample of the training set -> don't want to spend too much time scoring (note: there will be at least 1 row per chunk)
            p._score_training_samples = 10000;

            DeepLearning dl = new DeepLearning(p, Key.<DeepLearningModel>make("dl_mnist_model"));
            DeepLearningModel model = dl.trainModel().get();
            if (model != null) {
                model.delete();
            }
        } else {
            Log.info("Please run ./gradlew syncBigDataLaptop in the top-level directory of h2o-3.");
        }
    } finally {
        Scope.exit();
        // Frames are only non-null when parsing got that far.
        if (vframe != null) {
            vframe.remove();
        }
        if (frame != null) {
            frame.remove();
        }
    }
}
Also used : Frame(water.fvec.Frame) NFSFileVec(water.fvec.NFSFileVec) DeepLearningParameters(hex.deeplearning.DeepLearningModel.DeepLearningParameters) File(java.io.File) Ignore(org.junit.Ignore) Test(org.junit.Test)

Example 2 with DeepLearningParameters

use of hex.deeplearning.DeepLearningModel.DeepLearningParameters in project h2o-3 by h2oai.

From the class DeepLearningMissingTest, method run.

/**
 * Trains Deep Learning models on the weather data with increasing fractions of injected
 * missing values, once per missing-values handling mode, and checks the summed losses.
 * <p>
 * A run that throws is recorded with a loss of 100 instead of aborting the test, and that
 * value is included in the per-mode sums asserted at the end.
 */
@Test
public void run() {
    long seed = 1234;
    Log.info("");
    Log.info("STARTING.");
    Log.info("Using seed " + seed);
    Map<DeepLearningParameters.MissingValuesHandling, Double> sumErr = new TreeMap<>();
    StringBuilder sb = new StringBuilder();
    for (DeepLearningParameters.MissingValuesHandling mvh : new DeepLearningParameters.MissingValuesHandling[] { DeepLearningParameters.MissingValuesHandling.MeanImputation, DeepLearningParameters.MissingValuesHandling.Skip }) {
        double sumloss = 0;
        Map<Double, Double> map = new TreeMap<>();
        for (double missing_fraction : new double[] { 0, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.99 }) {
            // Per-run handles live inside the loop: if this run throws before (re)assigning
            // them, the cleanup below must not delete objects left over from an earlier run.
            DeepLearningModel mymodel = null;
            Frame train = null;
            Frame test = null;
            Frame data = null;
            double loss = 0;
            Scope.enter();
            try {
                NFSFileVec nfs = NFSFileVec.make("smalldata/junit/weather.csv");
                data = ParseDataset.parse(Key.make("data.hex"), nfs._key);
                Log.info("FrameSplitting");
                // Create holdout test data on clean data (before adding missing values)
                FrameSplitter fs = new FrameSplitter(data, new double[] { 0.75f }, generateNumKeys(data._key, 2), null);
                // submitted without an explicit join(); the split frames are fetched via getResult()
                H2O.submitTask(fs);
                Frame[] train_test = fs.getResult();
                train = train_test[0];
                test = train_test[1];
                Log.info("Done...");
                // add missing values to the training data (excluding the response)
                if (missing_fraction > 0) {
                    Frame frtmp = new Frame(Key.<Frame>make(), train.names(), train.vecs());
                    //exclude the response
                    frtmp.remove(frtmp.numCols() - 1);
                    //need to put the frame (to be modified) into DKV for MissingInserter to pick up
                    DKV.put(frtmp._key, frtmp);
                    FrameUtils.MissingInserter j = new FrameUtils.MissingInserter(frtmp._key, seed, missing_fraction);
                    //MissingInserter is non-blocking, must block here explicitly
                    j.execImpl().get();
                    //Delete the frame header (not the data)
                    DKV.remove(frtmp._key);
                }
                // Build a regularized DL model with polluted training data, score on clean validation set
                DeepLearningParameters p = new DeepLearningParameters();
                p._train = train._key;
                p._valid = test._key;
                p._response_column = train._names[train.numCols() - 1];
                //only for weather data
                p._ignored_columns = new String[] { train._names[1], train._names[22] };
                p._missing_values_handling = mvh;
                // DeepLearningParameters.Loss.ModifiedHuber;
                p._loss = DeepLearningParameters.Loss.CrossEntropy;
                p._activation = DeepLearningParameters.Activation.Rectifier;
                p._hidden = new int[] { 50, 50 };
                p._l1 = 1e-5;
                p._input_dropout_ratio = 0.2;
                p._epochs = 3;
                p._reproducible = true;
                p._seed = seed;
                p._elastic_averaging = false;
                // Convert response to categorical
                int ri = train.numCols() - 1;
                int ci = test.find(p._response_column);
                Scope.track(train.replace(ri, train.vecs()[ri].toCategoricalVec()));
                Scope.track(test.replace(ci, test.vecs()[ci].toCategoricalVec()));
                DKV.put(train);
                DKV.put(test);
                DeepLearning dl = new DeepLearning(p);
                Log.info("Starting with " + missing_fraction * 100 + "% missing values added.");
                mymodel = dl.trainModel().get();
                // Extract the scoring on validation set from the model
                loss = mymodel.loss();
                Log.info("Missing " + missing_fraction * 100 + "% -> logloss: " + loss);
            } catch (Throwable t) {
                // Deliberate best-effort: a failed run is scored with a fixed loss of 100
                // and the sweep continues with the next missing fraction.
                t.printStackTrace();
                loss = 100;
            } finally {
                Scope.exit();
                // cleanup of whatever this run managed to create
                if (mymodel != null) {
                    mymodel.delete();
                }
                if (train != null) {
                    train.delete();
                }
                if (test != null) {
                    test.delete();
                }
                if (data != null) {
                    data.delete();
                }
            }
            map.put(missing_fraction, loss);
            sumloss += loss;
        }
        sb.append("\nMethod: ").append(mvh.toString()).append("\n");
        sb.append("missing fraction --> loss\n");
        // One line per (missing fraction, loss) entry, derived from the array's string form.
        for (String s : Arrays.toString(map.entrySet().toArray()).split(",")) {
            sb.append(s.replace("=", " --> ")).append("\n");
        }
        sb.append('\n');
        sb.append("sum loss: ").append(sumloss).append("\n");
        sumErr.put(mvh, sumloss);
    }
    Log.info(sb.toString());
    Assert.assertEquals(405.5017, sumErr.get(DeepLearningParameters.MissingValuesHandling.Skip), 1e-2);
    Assert.assertEquals(3.914915, sumErr.get(DeepLearningParameters.MissingValuesHandling.MeanImputation), 1e-3);
}
Also used : FrameUtils(water.util.FrameUtils) Frame(water.fvec.Frame) NFSFileVec(water.fvec.NFSFileVec) DeepLearningParameters(hex.deeplearning.DeepLearningModel.DeepLearningParameters) FrameSplitter(hex.FrameSplitter) Test(org.junit.Test)

Example 3 with DeepLearningParameters

use of hex.deeplearning.DeepLearningModel.DeepLearningParameters in project h2o-3 by h2oai.

From the class DeepLearningProstateTest, method runFraction.

public void runFraction(float fraction) {
    long seed = 0xDECAFFF;
    Random rng = new Random(seed);
    String[] datasets = new String[2];
    int[][] responses = new int[datasets.length][];
    //CAPSULE (binomial), AGE (regression), GLEASON (multi-class)
    datasets[0] = "smalldata/logreg/prostate.csv";
    //CAPSULE (binomial), AGE (regression), GLEASON (multi-class)
    responses[0] = new int[] { 1, 2, 8 };
    //Iris-type (multi-class)
    datasets[1] = "smalldata/iris/iris.csv";
    //Iris-type (multi-class)
    responses[1] = new int[] { 4 };
    HashSet<Long> checkSums = new LinkedHashSet<>();
    int testcount = 0;
    int count = 0;
    for (int i = 0; i < datasets.length; ++i) {
        final String dataset = datasets[i];
        for (final int resp : responses[i]) {
            Frame frame = null, vframe = null;
            try {
                NFSFileVec nfs = TestUtil.makeNfsFileVec(dataset);
                frame = ParseDataset.parse(Key.make(), nfs._key);
                NFSFileVec vnfs = TestUtil.makeNfsFileVec(dataset);
                vframe = ParseDataset.parse(Key.make(), vnfs._key);
                boolean classification = !(i == 0 && resp == 2);
                String respname = frame.name(resp);
                if (classification && !frame.vec(resp).isCategorical()) {
                    Vec r = frame.vec(resp).toCategoricalVec();
                    frame.remove(resp).remove();
                    frame.add(respname, r);
                    DKV.put(frame);
                    Vec vr = vframe.vec(respname).toCategoricalVec();
                    vframe.remove(respname).remove();
                    vframe.add(respname, vr);
                    DKV.put(vframe);
                }
                if (classification) {
                    assert (frame.vec(respname).isCategorical());
                    assert (vframe.vec(respname).isCategorical());
                }
                for (DeepLearningParameters.Loss loss : new DeepLearningParameters.Loss[] { DeepLearningParameters.Loss.Automatic, DeepLearningParameters.Loss.CrossEntropy, DeepLearningParameters.Loss.Huber, //                DeepLearningParameters.Loss.ModifiedHuber,
                DeepLearningParameters.Loss.Absolute, DeepLearningParameters.Loss.Quadratic }) {
                    if (!classification && (loss == DeepLearningParameters.Loss.CrossEntropy || loss == DeepLearningParameters.Loss.ModifiedHuber))
                        continue;
                    for (DistributionFamily dist : new DistributionFamily[] { DistributionFamily.AUTO, DistributionFamily.laplace, DistributionFamily.huber, //                  DistributionFamily.modified_huber,
                    DistributionFamily.bernoulli, DistributionFamily.gaussian, DistributionFamily.poisson, DistributionFamily.tweedie, DistributionFamily.gamma }) {
                        if (classification && dist != DistributionFamily.multinomial && dist != DistributionFamily.bernoulli && dist != DistributionFamily.modified_huber)
                            continue;
                        if (!classification) {
                            if (dist == DistributionFamily.multinomial || dist == DistributionFamily.bernoulli || dist == DistributionFamily.modified_huber)
                                continue;
                        }
                        boolean cont = false;
                        switch(dist) {
                            case tweedie:
                            case gamma:
                            case poisson:
                                if (loss != DeepLearningParameters.Loss.Automatic)
                                    cont = true;
                                break;
                            case huber:
                                if (loss != DeepLearningParameters.Loss.Huber && loss != DeepLearningParameters.Loss.Automatic)
                                    cont = true;
                                break;
                            case laplace:
                                if (loss != DeepLearningParameters.Loss.Absolute && loss != DeepLearningParameters.Loss.Automatic)
                                    cont = true;
                                break;
                            case modified_huber:
                                if (loss != DeepLearningParameters.Loss.ModifiedHuber && loss != DeepLearningParameters.Loss.Automatic)
                                    cont = true;
                                break;
                            case bernoulli:
                                if (loss != DeepLearningParameters.Loss.CrossEntropy && loss != DeepLearningParameters.Loss.Automatic)
                                    cont = true;
                                break;
                        }
                        if (cont)
                            continue;
                        for (boolean elastic_averaging : new boolean[] { true, false }) {
                            for (boolean replicate : new boolean[] { true, false }) {
                                for (DeepLearningParameters.Activation activation : new DeepLearningParameters.Activation[] { DeepLearningParameters.Activation.Tanh, DeepLearningParameters.Activation.TanhWithDropout, DeepLearningParameters.Activation.Rectifier, DeepLearningParameters.Activation.RectifierWithDropout, DeepLearningParameters.Activation.Maxout, DeepLearningParameters.Activation.MaxoutWithDropout }) {
                                    boolean reproducible = false;
                                    switch(dist) {
                                        case tweedie:
                                        case gamma:
                                        case poisson:
                                            //don't remember why - probably to force stability
                                            reproducible = true;
                                        default:
                                    }
                                    for (boolean load_balance : new boolean[] { true, false }) {
                                        for (boolean shuffle : new boolean[] { true, false }) {
                                            for (boolean balance_classes : new boolean[] { true, false }) {
                                                for (ClassSamplingMethod csm : new ClassSamplingMethod[] { ClassSamplingMethod.Stratified, ClassSamplingMethod.Uniform }) {
                                                    for (int scoretraining : new int[] { 200, 20, 0 }) {
                                                        for (int scorevalidation : new int[] { 200, 20, 0 }) {
                                                            for (int vf : new int[] { //no validation
                                                            0, //same as source
                                                            1, //different validation frame
                                                            -1 }) {
                                                                for (int n_folds : new int[] { 0, 2 }) {
                                                                    //FIXME: Add back
                                                                    if (n_folds > 0 && balance_classes)
                                                                        continue;
                                                                    for (boolean overwrite_with_best_model : new boolean[] { false, true }) {
                                                                        for (int train_samples_per_iteration : new int[] { //auto-tune
                                                                        -2, //N epochs per iteration
                                                                        -1, //1 epoch per iteration
                                                                        0, // <1 epoch per iteration
                                                                        rng.nextInt(200), //>1 epoch per iteration
                                                                        500 }) {
                                                                            DeepLearningModel model1 = null, model2 = null;
                                                                            count++;
                                                                            if (fraction < rng.nextFloat())
                                                                                continue;
                                                                            try {
                                                                                Log.info("**************************)");
                                                                                Log.info("Starting test #" + count);
                                                                                Log.info("**************************)");
                                                                                final double epochs = 7 + rng.nextDouble() + rng.nextInt(4);
                                                                                final int[] hidden = new int[] { 3 + rng.nextInt(4), 3 + rng.nextInt(6) };
                                                                                final double[] hidden_dropout_ratios = activation.name().contains("Hidden") ? new double[] { rng.nextFloat(), rng.nextFloat() } : null;
                                                                                //no validation
                                                                                Frame valid = null;
                                                                                if (//use the same frame for validation
                                                                                vf == 1)
                                                                                    //use the same frame for validation
                                                                                    valid = frame;
                                                                                else if (vf == -1)
                                                                                    //different validation frame (here: from the same file)
                                                                                    valid = vframe;
                                                                                long myseed = rng.nextLong();
                                                                                boolean replicate2 = rng.nextBoolean();
                                                                                boolean elastic_averaging2 = rng.nextBoolean();
                                                                                // build the model, with all kinds of shuffling/rebalancing/sampling
                                                                                DeepLearningParameters p = new DeepLearningParameters();
                                                                                {
                                                                                    Log.info("Using seed: " + myseed);
                                                                                    p._train = frame._key;
                                                                                    p._response_column = respname;
                                                                                    p._valid = valid == null ? null : valid._key;
                                                                                    p._hidden = hidden;
                                                                                    p._input_dropout_ratio = 0.1;
                                                                                    p._hidden_dropout_ratios = hidden_dropout_ratios;
                                                                                    p._activation = activation;
                                                                                    //                                      p.best_model_key = best_model_key;
                                                                                    p._overwrite_with_best_model = overwrite_with_best_model;
                                                                                    p._epochs = epochs;
                                                                                    p._loss = loss;
                                                                                    p._distribution = dist;
                                                                                    p._nfolds = n_folds;
                                                                                    p._seed = myseed;
                                                                                    p._train_samples_per_iteration = train_samples_per_iteration;
                                                                                    p._force_load_balance = load_balance;
                                                                                    p._replicate_training_data = replicate;
                                                                                    p._reproducible = reproducible;
                                                                                    p._shuffle_training_data = shuffle;
                                                                                    p._score_training_samples = scoretraining;
                                                                                    p._score_validation_samples = scorevalidation;
                                                                                    p._classification_stop = -1;
                                                                                    p._regression_stop = -1;
                                                                                    p._stopping_rounds = 0;
                                                                                    p._balance_classes = classification && balance_classes;
                                                                                    p._quiet_mode = true;
                                                                                    p._score_validation_sampling = csm;
                                                                                    p._elastic_averaging = elastic_averaging;
                                                                                    //                                      Log.info(new String(p.writeJSON(new AutoBuffer()).buf()).replace(",","\n"));
                                                                                    DeepLearning dl = new DeepLearning(p, Key.<DeepLearningModel>make(Key.make().toString() + "first"));
                                                                                    try {
                                                                                        model1 = dl.trainModel().get();
                                                                                        checkSums.add(model1.checksum());
                                                                                        testcount++;
                                                                                    } catch (Throwable t) {
                                                                                        model1 = DKV.getGet(dl.dest());
                                                                                        if (model1 != null)
                                                                                            Assert.assertTrue(model1._output._job.isCrashed());
                                                                                        throw t;
                                                                                    }
                                                                                    Log.info("Trained for " + model1.epoch_counter + " epochs.");
                                                                                    assert (((p._train_samples_per_iteration <= 0 || p._train_samples_per_iteration >= frame.numRows()) && model1.epoch_counter > epochs) || Math.abs(model1.epoch_counter - epochs) / epochs < 0.20);
                                                                                    // check that iteration is of the expected length - check via when first scoring happens
                                                                                    if (p._train_samples_per_iteration == 0) {
                                                                                        // no sampling - every node does its share of the full data
                                                                                        if (!replicate)
                                                                                            assert ((double) model1._output._scoring_history.get(1, 3) == 1);
                                                                                        else
                                                                                            assert ((double) model1._output._scoring_history.get(1, 3) > 0.7 && (double) model1._output._scoring_history.get(1, 3) < 1.3) : ("First scoring at " + model1._output._scoring_history.get(1, 3) + " epochs, should be closer to 1!" + "\n" + model1.toString());
                                                                                    } else if (p._train_samples_per_iteration == -1) {
                                                                                        // no sampling - every node does its share of the full data
                                                                                        if (!replicate)
                                                                                            assert ((double) model1._output._scoring_history.get(1, 3) == 1);
                                                                                        else // every node passes over the full dataset
                                                                                        {
                                                                                            if (!reproducible)
                                                                                                assert ((double) model1._output._scoring_history.get(1, 3) == H2O.CLOUD.size());
                                                                                        }
                                                                                    }
                                                                                    if (n_folds != 0) {
                                                                                        assert (model1._output._cross_validation_metrics != null);
                                                                                    } else {
                                                                                        assert (model1._output._cross_validation_metrics == null);
                                                                                    }
                                                                                }
                                                                                assert (model1.model_info().get_params()._l1 == 0);
                                                                                assert (model1.model_info().get_params()._l2 == 0);
                                                                                Assert.assertFalse(model1._output._job.isCrashed());
                                                                                if (n_folds != 0)
                                                                                    continue;
                                                                                // Do some more training via checkpoint restart
                                                                                // For n_folds, continue without n_folds (not yet implemented) - from now on, model2 will have n_folds=0...
                                                                                DeepLearningParameters p2 = new DeepLearningParameters();
                                                                                Assert.assertTrue(model1.model_info().get_processed_total() >= frame.numRows() * epochs);
                                                                                {
                                                                                    p2._checkpoint = model1._key;
                                                                                    p2._distribution = dist;
                                                                                    p2._loss = loss;
                                                                                    p2._nfolds = n_folds;
                                                                                    p2._train = frame._key;
                                                                                    p2._activation = activation;
                                                                                    p2._hidden = hidden;
                                                                                    p2._valid = valid == null ? null : valid._key;
                                                                                    p2._l1 = 1e-3;
                                                                                    p2._l2 = 1e-3;
                                                                                    p2._reproducible = reproducible;
                                                                                    p2._response_column = respname;
                                                                                    p2._overwrite_with_best_model = overwrite_with_best_model;
                                                                                    p2._quiet_mode = true;
                                                                                    //final amount of training epochs
                                                                                    p2._epochs = 2 * epochs;
                                                                                    p2._replicate_training_data = replicate2;
                                                                                    p2._stopping_rounds = 0;
                                                                                    p2._seed = myseed;
                                                                                    //                                              p2._loss = loss; //fall back to default
                                                                                    //                                              p2._distribution = dist; //fall back to default
                                                                                    p2._train_samples_per_iteration = train_samples_per_iteration;
                                                                                    p2._balance_classes = classification && balance_classes;
                                                                                    p2._elastic_averaging = elastic_averaging2;
                                                                                    DeepLearning dl = new DeepLearning(p2);
                                                                                    try {
                                                                                        model2 = dl.trainModel().get();
                                                                                    } catch (Throwable t) {
                                                                                        model2 = DKV.getGet(dl.dest());
                                                                                        if (model2 != null)
                                                                                            Assert.assertTrue(model2._output._job.isCrashed());
                                                                                        throw t;
                                                                                    }
                                                                                }
                                                                                Assert.assertTrue(model1._output._job.isDone());
                                                                                Assert.assertTrue(model2._output._job.isDone());
                                                                                assert (model1._parms != p2);
                                                                                assert (model1.model_info().get_params() != model2.model_info().get_params());
                                                                                assert (model1.model_info().get_params()._l1 == 0);
                                                                                assert (model1.model_info().get_params()._l2 == 0);
                                                                                if (!overwrite_with_best_model)
                                                                                    Assert.assertTrue(model2.model_info().get_processed_total() >= frame.numRows() * 2 * epochs);
                                                                                assert (p != p2);
                                                                                assert (p != model1.model_info().get_params());
                                                                                assert (p2 != model2.model_info().get_params());
                                                                                if (p._loss == DeepLearningParameters.Loss.Automatic) {
                                                                                    assert (p2._loss == DeepLearningParameters.Loss.Automatic);
                                                                                //                                              assert(model1.model_info().get_params()._loss != DeepLearningParameters.Loss.Automatic);
                                                                                //                                              assert(model2.model_info().get_params()._loss != DeepLearningParameters.Loss.Automatic);
                                                                                }
                                                                                assert (p._hidden_dropout_ratios == null);
                                                                                assert (p2._hidden_dropout_ratios == null);
                                                                                if (p._activation.toString().contains("WithDropout")) {
                                                                                    assert (model1.model_info().get_params()._hidden_dropout_ratios != null);
                                                                                    assert (model2.model_info().get_params()._hidden_dropout_ratios != null);
                                                                                    assert (Arrays.equals(model1.model_info().get_params()._hidden_dropout_ratios, model2.model_info().get_params()._hidden_dropout_ratios));
                                                                                }
                                                                                assert (p._l1 == 0);
                                                                                assert (p._l2 == 0);
                                                                                assert (p2._l1 == 1e-3);
                                                                                assert (p2._l2 == 1e-3);
                                                                                assert (model1.model_info().get_params()._l1 == 0);
                                                                                assert (model1.model_info().get_params()._l2 == 0);
                                                                                assert (model2.model_info().get_params()._l1 == 1e-3);
                                                                                assert (model2.model_info().get_params()._l2 == 1e-3);
                                                                                if (valid == null)
                                                                                    valid = frame;
                                                                                double threshold;
                                                                                if (model2._output.isClassifier()) {
                                                                                    Frame pred = null;
                                                                                    Vec labels, predlabels, pred2labels;
                                                                                    try {
                                                                                        pred = model2.score(valid);
                                                                                        DKV.put(Key.make("pred"), pred);
                                                                                        // Build a POJO, validate same results
                                                                                        if (!model2.testJavaScoring(valid, pred, 1e-6)) {
                                                                                            model2.testJavaScoring(valid, pred, 1e-6);
                                                                                        }
                                                                                        Assert.assertTrue(model2.testJavaScoring(valid, pred, 1e-6));
                                                                                        hex.ModelMetrics mm = hex.ModelMetrics.getFromDKV(model2, valid);
                                                                                        double error;
                                                                                        // binary
                                                                                        if (model2._output.nclasses() == 2) {
                                                                                            assert (resp == 1);
                                                                                            threshold = mm.auc_obj().defaultThreshold();
                                                                                            error = mm.auc_obj().defaultErr();
                                                                                            // check that auc.cm() is the right CM
                                                                                            Assert.assertEquals(new ConfusionMatrix(mm.auc_obj().defaultCM(), valid.vec(respname).domain()).err(), error, 1e-15);
                                                                                            // check that calcError() is consistent as well (for CM=null, AUC!=null)
                                                                                            Assert.assertEquals(mm.cm().err(), error, 1e-15);
                                                                                            // check that the labels made with the default threshold are consistent with the CM that's reported by the AUC object
                                                                                            labels = valid.vec(respname);
                                                                                            predlabels = pred.vecs()[0];
                                                                                            ConfusionMatrix cm = buildCM(labels, predlabels);
                                                                                            Log.info("CM from pre-made labels:");
                                                                                            Log.info(cm.toASCII());
                                                                                            if (Math.abs(cm.err() - error) > 2e-2) {
                                                                                                ConfusionMatrix cm2 = buildCM(labels, predlabels);
                                                                                                Log.info(cm2.toASCII());
                                                                                            }
                                                                                            Assert.assertEquals(cm.err(), error, 2e-2);
                                                                                            // confirm that orig CM was made with the right threshold
                                                                                            // manually make labels with AUC-given default threshold
                                                                                            String ast = "(as.factor (> (cols pred [2]) " + threshold + "))";
                                                                                            Frame tmp = Rapids.exec(ast).getFrame();
                                                                                            pred2labels = tmp.vecs()[0];
                                                                                            cm = buildCM(labels, pred2labels);
                                                                                            Log.info("CM from self-made labels:");
                                                                                            Log.info(cm.toASCII());
                                                                                            //AUC-given F1-optimal threshold might not reproduce AUC-given CM-error identically, but should match up to 2%
                                                                                            Assert.assertEquals(cm.err(), error, 2e-2);
                                                                                            tmp.delete();
                                                                                        }
                                                                                        DKV.remove(Key.make("pred"));
                                                                                    } finally {
                                                                                        if (pred != null)
                                                                                            pred.delete();
                                                                                    }
                                                                                } else //classifier
                                                                                {
                                                                                    Frame pred = null;
                                                                                    try {
                                                                                        pred = model2.score(valid);
                                                                                        // Build a POJO, validate same results
                                                                                        Assert.assertTrue(model2.testJavaScoring(frame, pred, 1e-6));
                                                                                    } finally {
                                                                                        if (pred != null)
                                                                                            pred.delete();
                                                                                    }
                                                                                }
                                                                                Log.info("Parameters combination " + count + ": PASS");
                                                                            } catch (H2OModelBuilderIllegalArgumentException | IllegalArgumentException ex) {
                                                                                System.err.println(ex);
                                                                                throw H2O.fail("should not get here");
                                                                            } catch (RuntimeException t) {
                                                                                String msg = // this way we evade null messages
                                                                                "" + t.getMessage() + (t.getCause() == null ? "" : t.getCause().getMessage());
                                                                                Assert.assertTrue("Unexpected exception " + t + ": " + msg, msg.contains("unstable"));
                                                                            } catch (AssertionError ae) {
                                                                                // test assertions should be preserved
                                                                                throw ae;
                                                                            } catch (Throwable t) {
                                                                                t.printStackTrace();
                                                                                throw new RuntimeException(t);
                                                                            } finally {
                                                                                if (model1 != null) {
                                                                                    model1.deleteCrossValidationModels();
                                                                                    model1.delete();
                                                                                }
                                                                                if (model2 != null) {
                                                                                    model2.deleteCrossValidationModels();
                                                                                    model2.delete();
                                                                                }
                                                                            }
                                                                        }
                                                                    }
                                                                }
                                                            }
                                                        }
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            } finally {
                if (frame != null)
                    frame.delete();
                if (vframe != null)
                    vframe.delete();
            }
        }
    }
    Log.info("\n\n=============================================");
    Log.info("Tested " + testcount + " out of " + count + " parameter combinations.");
    Log.info("=============================================");
    if (checkSums.size() != testcount) {
        Log.info("Only found " + checkSums.size() + " unique checksums.");
    }
    Assert.assertTrue(checkSums.size() == testcount);
}
Also used : LinkedHashSet(java.util.LinkedHashSet) Frame(water.fvec.Frame) ConfusionMatrix(hex.ConfusionMatrix) NFSFileVec(water.fvec.NFSFileVec) DeepLearningParameters(hex.deeplearning.DeepLearningModel.DeepLearningParameters) ClassSamplingMethod(hex.deeplearning.DeepLearningModel.DeepLearningParameters.ClassSamplingMethod) Random(java.util.Random) H2OModelBuilderIllegalArgumentException(water.exceptions.H2OModelBuilderIllegalArgumentException) DistributionFamily(hex.genmodel.utils.DistributionFamily) H2OModelBuilderIllegalArgumentException(water.exceptions.H2OModelBuilderIllegalArgumentException) NFSFileVec(water.fvec.NFSFileVec) Vec(water.fvec.Vec)

Example 4 with DeepLearningParameters

use of hex.deeplearning.DeepLearningModel.DeepLearningParameters in project h2o-3 by h2oai.

the class DeepLearningReproducibilityTest method run.

/**
 * Trains the same Deep Learning configuration {@code N} times on
 * {@code smalldata/junit/weather.csv}, once with {@code _reproducible = true} and once with
 * {@code _reproducible = false}.
 *
 * <p>In reproducible mode every repeat must yield the same sampled weight, the same loss, the
 * same model-info checksum and (within 1e-5, see PUBDEV-892) the same predictions. In
 * non-reproducible mode every repeat's loss must differ from the loss of repeat 0; mean and
 * standard deviation of the losses are only logged.
 *
 * <p>All frames, models and the {@link Scope} are released in {@code finally} blocks so that a
 * failing assertion does not leak keys into the shared store.
 */
@Test
public void run() {
    NFSFileVec ff = TestUtil.makeNfsFileVec("smalldata/junit/weather.csv");
    Frame golden = ParseDataset.parse(Key.make("golden.hex"), ff._key);
    try {
        // repeat index --> loss; keys 0..N-1 are overwritten on the second (non-repro) pass
        Map<Integer, Float> repeatErrs = new TreeMap<>();
        final int N = 3;
        StringBuilder sb = new StringBuilder();
        float reproError = 0;
        for (boolean repro : new boolean[] { true, false }) {
            Scope.enter();
            Frame[] preds = new Frame[N];
            try {
                long[] checksums = new long[N];
                double[] numbers = new double[N];
                for (int repeat = 0; repeat < N; ++repeat) {
                    // Declared per repeat so the cleanup below never touches a stale object
                    // from an earlier iteration.
                    DeepLearningModel mymodel = null;
                    Frame data = null;
                    try {
                        NFSFileVec file = TestUtil.makeNfsFileVec("smalldata/junit/weather.csv");
                        data = ParseDataset.parse(Key.make("data.hex"), file._key);
                        // test parser consistency
                        Assert.assertTrue(TestUtil.isBitIdentical(data, golden));
                        // train and test are aliases of the very same parsed frame
                        Frame train = data;
                        Frame test = data;
                        DeepLearningParameters p = new DeepLearningParameters();
                        p._train = train._key;
                        p._valid = test._key;
                        // last column is the response; convert it to categorical
                        int ci = train.names().length - 1;
                        p._response_column = train.names()[ci];
                        Scope.track(train.replace(ci, train.vecs()[ci].toCategoricalVec()));
                        DKV.put(train);
                        // for weather data
                        p._ignored_columns = new String[] { "EvapMM", "RISK_MM" };
                        p._activation = DeepLearningParameters.Activation.RectifierWithDropout;
                        p._hidden = new int[] { 32, 58 };
                        p._l1 = 1e-5;
                        p._l2 = 3e-5;
                        p._seed = 0xbebe;
                        p._loss = DeepLearningParameters.Loss.CrossEntropy;
                        p._input_dropout_ratio = 0.2;
                        p._train_samples_per_iteration = 3;
                        p._hidden_dropout_ratios = new double[] { 0.4, 0.1 };
                        p._epochs = 1.32;
                        p._quiet_mode = true;
                        p._reproducible = repro;
                        DeepLearning dl = new DeepLearning(p);
                        mymodel = dl.trainModel().get();
                        // Extract the scoring on validation set from the model
                        preds[repeat] = mymodel.score(test);
                        // Scoring the same model repeatedly must be bit-identical
                        for (int i = 0; i < 5; ++i) {
                            Frame tmp = mymodel.score(test);
                            Assert.assertTrue("Prediction #" + i + " for repeat #" + repeat + " differs!", TestUtil.isBitIdentical(preds[repeat], tmp));
                            tmp.delete();
                        }
                        Log.info("Prediction:\n" + FrameUtils.chunkSummary(preds[repeat]).toString());
                        // one arbitrary weight as a cheap fingerprint of the trained network
                        numbers[repeat] = mymodel.model_info().get_weights(0).get(23, 4);
                        // check that the model state is consistent
                        checksums[repeat] = mymodel.model_info().checksum_impl();
                        repeatErrs.put(repeat, mymodel.loss());
                    } finally {
                        // cleanup
                        if (mymodel != null) {
                            mymodel.delete();
                        }
                        if (data != null) {
                            data.delete();
                        }
                    }
                }
                sb.append("Reproducibility: ").append(repro ? "on" : "off").append("\n");
                sb.append("Repeat # --> Validation Loss\n");
                for (String s : Arrays.toString(repeatErrs.entrySet().toArray()).split(",")) {
                    sb.append(s.replace("=", " --> ")).append("\n");
                }
                sb.append('\n');
                Log.info(sb.toString());
                if (repro) {
                    // check reproducibility
                    for (double error : numbers) {
                        assertTrue(Arrays.toString(numbers), error == numbers[0]);
                    }
                    for (Float error : repeatErrs.values()) {
                        assertTrue(error.equals(repeatErrs.get(0)));
                    }
                    for (long cs : checksums) {
                        assertTrue(cs == checksums[0]);
                    }
                    for (Frame f : preds) {
                        // PUBDEV-892: predictions should be bit-identical (tolerance 1e-15),
                        // but are currently only checked to 1e-5.
                        for (int i = 0; i < f.vecs().length; ++i) {
                            TestUtil.assertVecEquals(f.vecs()[i], preds[0].vecs()[i], 1e-5);
                        }
                    }
                    reproError = repeatErrs.get(0);
                } else {
                    // check standard deviation of non-reproducible mode
                    double mean = 0;
                    for (Float error : repeatErrs.values()) {
                        mean += error;
                    }
                    mean /= N;
                    // check non-reproducibility (Hogwild! will never reproduce).
                    // Compare by value: `!=` on boxed Floats tests object identity and
                    // would pass even when the losses are numerically equal.
                    Float first = repeatErrs.get(0);
                    for (int i = 1; i < N; ++i) {
                        assertTrue("Repeat #" + i + " reproduced the loss of repeat #0: " + first, !repeatErrs.get(i).equals(first));
                    }
                    Log.info("mean error: " + mean);
                    double stddev = 0;
                    for (Float error : repeatErrs.values()) {
                        stddev += (error - mean) * (error - mean);
                    }
                    stddev /= N;
                    stddev = Math.sqrt(stddev);
                    Log.info("standard deviation: " + stddev);
                    // NOTE: a bound of stddev < 0.3 / sqrt(N) used to be asserted here; it is disabled.
                    Log.info("difference to reproducible mode: " + Math.abs(mean - reproError) / stddev + " standard deviations");
                }
            } finally {
                // Release predictions and the scope even if training or an assertion failed
                try {
                    for (Frame f : preds) {
                        if (f != null) {
                            f.delete();
                        }
                    }
                } finally {
                    Scope.exit();
                }
            }
        }
    } finally {
        golden.delete();
    }
}
Also used : Frame(water.fvec.Frame) NFSFileVec(water.fvec.NFSFileVec) DeepLearningParameters(hex.deeplearning.DeepLearningModel.DeepLearningParameters) Test(org.junit.Test)

Example 5 with DeepLearningParameters

use of hex.deeplearning.DeepLearningModel.DeepLearningParameters in project h2o-3 by h2oai.

the class DeepLearningScoreTest method testPubDev928.

/**
 * Load simple dataset, rebalance to a number of chunks &gt; number of rows, and run deep learning.
 *
 * <p>The prostate data is rebalanced into {@code numRows + 1} chunks, the test asserts that the
 * rebalanced frame contains at least one zero-length chunk, and then a Deep Learning model is
 * trained on it for 5 epochs with {@code CAPSULE} as the response. The test passes if training
 * completes without throwing.
 *
 * <p>Every step after key creation runs inside the {@code try} so that the parsed frame, the
 * rebalanced frame and the model are deleted even when rebalancing or the chunk assertion fails.
 */
@Test
public void testPubDev928() {
    Key rebalancedKey = Key.make("rebalanced");
    Frame fr = null;
    Frame rebalanced = null;
    DeepLearningModel dlModel = null;
    try {
        // Create rebalanced dataset
        NFSFileVec nfs = TestUtil.makeNfsFileVec("smalldata/logreg/prostate.csv");
        fr = ParseDataset.parse(Key.make(), nfs._key);
        // One more chunk than rows guarantees that some chunk has to stay empty
        RebalanceDataSet rb = new RebalanceDataSet(fr, rebalancedKey, (int) (fr.numRows() + 1));
        H2O.submitTask(rb);
        rb.join();
        rebalanced = DKV.get(rebalancedKey).get();
        // Assert that there is at least one 0-len chunk
        assertZeroLengthChunk("Rebalanced dataset should contain at least one 0-len chunk!", rebalanced.anyVec());
        // Launch Deep Learning
        DeepLearningParameters dlParams = new DeepLearningParameters();
        dlParams._train = rebalancedKey;
        dlParams._epochs = 5;
        dlParams._response_column = "CAPSULE";
        dlModel = new DeepLearning(dlParams).trainModel().get();
    } finally {
        if (fr != null) {
            fr.delete();
        }
        if (rebalanced != null) {
            rebalanced.delete();
        }
        if (dlModel != null) {
            dlModel.delete();
        }
    }
}
Also used : Frame(water.fvec.Frame) RebalanceDataSet(water.fvec.RebalanceDataSet) NFSFileVec(water.fvec.NFSFileVec) DeepLearningParameters(hex.deeplearning.DeepLearningModel.DeepLearningParameters) Key(water.Key) Test(org.junit.Test)

Aggregations

DeepLearningParameters (hex.deeplearning.DeepLearningModel.DeepLearningParameters): 58, Frame (water.fvec.Frame): 54, Test (org.junit.Test): 52, NFSFileVec (water.fvec.NFSFileVec): 26, Vec (water.fvec.Vec): 22, hex (hex): 5, Ignore (org.junit.Ignore): 5, DistributionFamily (hex.genmodel.utils.DistributionFamily): 4, Random (java.util.Random): 4, H2OIllegalArgumentException (water.exceptions.H2OIllegalArgumentException): 3, DataInfo (hex.DataInfo): 2, File (java.io.File): 2, Key (water.Key): 2, H2OModelBuilderIllegalArgumentException (water.exceptions.H2OModelBuilderIllegalArgumentException): 2, PrettyPrint (water.util.PrettyPrint): 2, ConfusionMatrix (hex.ConfusionMatrix): 1, Distribution (hex.Distribution): 1, FrameSplitter (hex.FrameSplitter): 1, FrameTask (hex.FrameTask): 1, ModelMetricsBinomial (hex.ModelMetricsBinomial): 1