Example 1 with DistributionFamily

Use of hex.genmodel.utils.DistributionFamily in the h2o-3 project by h2oai.

From the class StackedEnsembleModel, method distributionFamily.

private DistributionFamily distributionFamily(Model aModel) {
    // TODO: hack alert: In DRF, _parms._distribution is always set to multinomial.  Yay.
    if (aModel instanceof DRFModel) {
        if (aModel._output.isBinomialClassifier())
            return DistributionFamily.bernoulli;
        else if (aModel._output.isClassifier())
            throw new H2OIllegalArgumentException("Don't know how to set the distribution for a multinomial Random Forest classifier.");
        else
            return DistributionFamily.gaussian;
    }
    try {
        // GLM exposes a _family hyperparameter on its parameters; most other algos expose a _dist field on the model instead.
        Field familyField = ReflectionUtils.findNamedField(aModel._parms, "_family");
        Field distributionField = (familyField != null ? null : ReflectionUtils.findNamedField(aModel, "_dist"));
        if (null != familyField) {
            // GLM only, for now
            GLMModel.GLMParameters.Family thisFamily = (GLMModel.GLMParameters.Family) familyField.get(aModel._parms);
            if (thisFamily == GLMModel.GLMParameters.Family.binomial) {
                return DistributionFamily.bernoulli;
            }
            try {
                return Enum.valueOf(DistributionFamily.class, thisFamily.toString());
            } catch (IllegalArgumentException e) {
                throw new H2OIllegalArgumentException("Don't know how to find the right DistributionFamily for Family: " + thisFamily);
            }
        }
        if (null != distributionField) {
            Distribution distribution = ((Distribution) distributionField.get(aModel));
            DistributionFamily distributionFamily;
            if (null != distribution)
                distributionFamily = distribution.distribution;
            else
                distributionFamily = aModel._parms._distribution;
            // NOTE: If the algo does smart guessing of the distribution family we need to duplicate the logic here.
            if (distributionFamily == DistributionFamily.AUTO) {
                if (aModel._output.isBinomialClassifier())
                    distributionFamily = DistributionFamily.bernoulli;
                else if (aModel._output.isClassifier())
                    throw new H2OIllegalArgumentException("Don't know how to determine the distribution for a multinomial classifier.");
                else
                    distributionFamily = DistributionFamily.gaussian;
            }
            return distributionFamily;
        }
        throw new H2OIllegalArgumentException("Don't know how to stack models that have neither a distribution hyperparameter nor a family hyperparameter.");
    } catch (Exception e) {
        throw new H2OIllegalArgumentException(e.toString(), e.toString());
    }
}
Also used : Field(java.lang.reflect.Field) DRFModel(hex.tree.drf.DRFModel) GLMModel(hex.glm.GLMModel) DistributionFamily(hex.genmodel.utils.DistributionFamily) H2OIllegalArgumentException(water.exceptions.H2OIllegalArgumentException)
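
The core trick above is name-based enum mapping: GLM's Family constants mostly share names with DistributionFamily, so Enum.valueOf bridges the two, with binomial special-cased to bernoulli. A minimal standalone sketch of that pattern, using hypothetical enums rather than the h2o classes:

// Minimal sketch of name-based enum mapping, assuming two hypothetical enums
// whose constants mostly share names (as GLM's Family and DistributionFamily do).
enum Family { binomial, gaussian, poisson, ordinal }
enum Dist { bernoulli, gaussian, poisson }

public class FamilyMapper {
    static Dist toDist(Family f) {
        if (f == Family.binomial)
            return Dist.bernoulli; // renamed constant needs an explicit case
        try {
            return Enum.valueOf(Dist.class, f.toString()); // name-for-name lookup
        } catch (IllegalArgumentException e) {
            throw new IllegalArgumentException("No distribution for family: " + f, e);
        }
    }

    public static void main(String[] args) {
        System.out.println(toDist(Family.binomial)); // bernoulli
        System.out.println(toDist(Family.poisson));  // poisson
        System.out.println(toDist(Family.ordinal));  // throws: no matching name
    }
}

As in the original, Enum.valueOf throws IllegalArgumentException for any family name with no counterpart, which the method converts into a descriptive error.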

Example 2 with DistributionFamily

Use of hex.genmodel.utils.DistributionFamily in the h2o-3 project by h2oai.

From the class DeepLearningProstateTest, method runFraction.

public void runFraction(float fraction) {
    long seed = 0xDECAFFF;
    Random rng = new Random(seed);
    String[] datasets = new String[2];
    int[][] responses = new int[datasets.length][];
    //CAPSULE (binomial), AGE (regression), GLEASON (multi-class)
    datasets[0] = "smalldata/logreg/prostate.csv";
    responses[0] = new int[] { 1, 2, 8 };
    //Iris-type (multi-class)
    datasets[1] = "smalldata/iris/iris.csv";
    responses[1] = new int[] { 4 };
    HashSet<Long> checkSums = new LinkedHashSet<>();
    int testcount = 0;
    int count = 0;
    for (int i = 0; i < datasets.length; ++i) {
        final String dataset = datasets[i];
        for (final int resp : responses[i]) {
            Frame frame = null, vframe = null;
            try {
                NFSFileVec nfs = TestUtil.makeNfsFileVec(dataset);
                frame = ParseDataset.parse(Key.make(), nfs._key);
                NFSFileVec vnfs = TestUtil.makeNfsFileVec(dataset);
                vframe = ParseDataset.parse(Key.make(), vnfs._key);
                boolean classification = !(i == 0 && resp == 2); // AGE (column 2 of prostate) is the lone regression target
                String respname = frame.name(resp);
                if (classification && !frame.vec(resp).isCategorical()) {
                    Vec r = frame.vec(resp).toCategoricalVec();
                    frame.remove(resp).remove();
                    frame.add(respname, r);
                    DKV.put(frame);
                    Vec vr = vframe.vec(respname).toCategoricalVec();
                    vframe.remove(respname).remove();
                    vframe.add(respname, vr);
                    DKV.put(vframe);
                }
                if (classification) {
                    assert (frame.vec(respname).isCategorical());
                    assert (vframe.vec(respname).isCategorical());
                }
                for (DeepLearningParameters.Loss loss : new DeepLearningParameters.Loss[] {
                        DeepLearningParameters.Loss.Automatic, DeepLearningParameters.Loss.CrossEntropy,
                        DeepLearningParameters.Loss.Huber,
                        // DeepLearningParameters.Loss.ModifiedHuber,
                        DeepLearningParameters.Loss.Absolute, DeepLearningParameters.Loss.Quadratic }) {
                    if (!classification && (loss == DeepLearningParameters.Loss.CrossEntropy || loss == DeepLearningParameters.Loss.ModifiedHuber))
                        continue;
                    for (DistributionFamily dist : new DistributionFamily[] {
                            DistributionFamily.AUTO, DistributionFamily.laplace, DistributionFamily.huber,
                            // DistributionFamily.modified_huber,
                            DistributionFamily.bernoulli, DistributionFamily.gaussian, DistributionFamily.poisson,
                            DistributionFamily.tweedie, DistributionFamily.gamma }) {
                        if (classification && dist != DistributionFamily.multinomial && dist != DistributionFamily.bernoulli && dist != DistributionFamily.modified_huber)
                            continue;
                        if (!classification) {
                            if (dist == DistributionFamily.multinomial || dist == DistributionFamily.bernoulli || dist == DistributionFamily.modified_huber)
                                continue;
                        }
                        boolean cont = false;
                        switch(dist) {
                            case tweedie:
                            case gamma:
                            case poisson:
                                if (loss != DeepLearningParameters.Loss.Automatic)
                                    cont = true;
                                break;
                            case huber:
                                if (loss != DeepLearningParameters.Loss.Huber && loss != DeepLearningParameters.Loss.Automatic)
                                    cont = true;
                                break;
                            case laplace:
                                if (loss != DeepLearningParameters.Loss.Absolute && loss != DeepLearningParameters.Loss.Automatic)
                                    cont = true;
                                break;
                            case modified_huber:
                                if (loss != DeepLearningParameters.Loss.ModifiedHuber && loss != DeepLearningParameters.Loss.Automatic)
                                    cont = true;
                                break;
                            case bernoulli:
                                if (loss != DeepLearningParameters.Loss.CrossEntropy && loss != DeepLearningParameters.Loss.Automatic)
                                    cont = true;
                                break;
                        }
                        if (cont)
                            continue;
                        for (boolean elastic_averaging : new boolean[] { true, false }) {
                            for (boolean replicate : new boolean[] { true, false }) {
                                for (DeepLearningParameters.Activation activation : new DeepLearningParameters.Activation[] { DeepLearningParameters.Activation.Tanh, DeepLearningParameters.Activation.TanhWithDropout, DeepLearningParameters.Activation.Rectifier, DeepLearningParameters.Activation.RectifierWithDropout, DeepLearningParameters.Activation.Maxout, DeepLearningParameters.Activation.MaxoutWithDropout }) {
                                    boolean reproducible = false;
                                    switch(dist) {
                                        case tweedie:
                                        case gamma:
                                        case poisson:
                                            //don't remember why - probably to force stability
                                            reproducible = true;
                                        default:
                                    }
                                    for (boolean load_balance : new boolean[] { true, false }) {
                                        for (boolean shuffle : new boolean[] { true, false }) {
                                            for (boolean balance_classes : new boolean[] { true, false }) {
                                                for (ClassSamplingMethod csm : new ClassSamplingMethod[] { ClassSamplingMethod.Stratified, ClassSamplingMethod.Uniform }) {
                                                    for (int scoretraining : new int[] { 200, 20, 0 }) {
                                                        for (int scorevalidation : new int[] { 200, 20, 0 }) {
                                                            for (int vf : new int[] {
                                                                    0,  // no validation
                                                                    1,  // same as source
                                                                    -1  // different validation frame
                                                            }) {
                                                                for (int n_folds : new int[] { 0, 2 }) {
                                                                    //FIXME: Add back
                                                                    if (n_folds > 0 && balance_classes)
                                                                        continue;
                                                                    for (boolean overwrite_with_best_model : new boolean[] { false, true }) {
                                                                        for (int train_samples_per_iteration : new int[] {
                                                                                -2,               // auto-tune
                                                                                -1,               // N epochs per iteration
                                                                                0,                // 1 epoch per iteration
                                                                                rng.nextInt(200), // <1 epoch per iteration
                                                                                500               // >1 epoch per iteration
                                                                        }) {
                                                                            DeepLearningModel model1 = null, model2 = null;
                                                                            count++;
                                                                            if (fraction < rng.nextFloat())
                                                                                continue;
                                                                            try {
                                                                                Log.info("**************************)");
                                                                                Log.info("Starting test #" + count);
                                                                                Log.info("**************************)");
                                                                                final double epochs = 7 + rng.nextDouble() + rng.nextInt(4);
                                                                                final int[] hidden = new int[] { 3 + rng.nextInt(4), 3 + rng.nextInt(6) };
                                                                                final double[] hidden_dropout_ratios = activation.name().contains("Hidden") ? new double[] { rng.nextFloat(), rng.nextFloat() } : null;
                                                                                Frame valid = null; // default: no validation frame
                                                                                if (vf == 1)
                                                                                    valid = frame;  // use the same frame for validation
                                                                                else if (vf == -1)
                                                                                    valid = vframe; // different validation frame (here: from the same file)
                                                                                long myseed = rng.nextLong();
                                                                                boolean replicate2 = rng.nextBoolean();
                                                                                boolean elastic_averaging2 = rng.nextBoolean();
                                                                                // build the model, with all kinds of shuffling/rebalancing/sampling
                                                                                DeepLearningParameters p = new DeepLearningParameters();
                                                                                {
                                                                                    Log.info("Using seed: " + myseed);
                                                                                    p._train = frame._key;
                                                                                    p._response_column = respname;
                                                                                    p._valid = valid == null ? null : valid._key;
                                                                                    p._hidden = hidden;
                                                                                    p._input_dropout_ratio = 0.1;
                                                                                    p._hidden_dropout_ratios = hidden_dropout_ratios;
                                                                                    p._activation = activation;
                                                                                    //                                      p.best_model_key = best_model_key;
                                                                                    p._overwrite_with_best_model = overwrite_with_best_model;
                                                                                    p._epochs = epochs;
                                                                                    p._loss = loss;
                                                                                    p._distribution = dist;
                                                                                    p._nfolds = n_folds;
                                                                                    p._seed = myseed;
                                                                                    p._train_samples_per_iteration = train_samples_per_iteration;
                                                                                    p._force_load_balance = load_balance;
                                                                                    p._replicate_training_data = replicate;
                                                                                    p._reproducible = reproducible;
                                                                                    p._shuffle_training_data = shuffle;
                                                                                    p._score_training_samples = scoretraining;
                                                                                    p._score_validation_samples = scorevalidation;
                                                                                    p._classification_stop = -1;
                                                                                    p._regression_stop = -1;
                                                                                    p._stopping_rounds = 0;
                                                                                    p._balance_classes = classification && balance_classes;
                                                                                    p._quiet_mode = true;
                                                                                    p._score_validation_sampling = csm;
                                                                                    p._elastic_averaging = elastic_averaging;
                                                                                    //                                      Log.info(new String(p.writeJSON(new AutoBuffer()).buf()).replace(",","\n"));
                                                                                    DeepLearning dl = new DeepLearning(p, Key.<DeepLearningModel>make(Key.make().toString() + "first"));
                                                                                    try {
                                                                                        model1 = dl.trainModel().get();
                                                                                        checkSums.add(model1.checksum());
                                                                                        testcount++;
                                                                                    } catch (Throwable t) {
                                                                                        model1 = DKV.getGet(dl.dest());
                                                                                        if (model1 != null)
                                                                                            Assert.assertTrue(model1._output._job.isCrashed());
                                                                                        throw t;
                                                                                    }
                                                                                    Log.info("Trained for " + model1.epoch_counter + " epochs.");
                                                                                    assert (((p._train_samples_per_iteration <= 0 || p._train_samples_per_iteration >= frame.numRows()) && model1.epoch_counter > epochs) || Math.abs(model1.epoch_counter - epochs) / epochs < 0.20);
                                                                                    // check that an iteration has the expected length, via when the first scoring happens
                                                                                    if (p._train_samples_per_iteration == 0) {
                                                                                        // no sampling - every node does its share of the full data
                                                                                        if (!replicate)
                                                                                            assert ((double) model1._output._scoring_history.get(1, 3) == 1);
                                                                                        else
                                                                                            assert ((double) model1._output._scoring_history.get(1, 3) > 0.7 && (double) model1._output._scoring_history.get(1, 3) < 1.3) : ("First scoring at " + model1._output._scoring_history.get(1, 3) + " epochs, should be closer to 1!" + "\n" + model1.toString());
                                                                                    } else if (p._train_samples_per_iteration == -1) {
                                                                                        // no sampling - every node does its share of the full data
                                                                                        if (!replicate)
                                                                                            assert ((double) model1._output._scoring_history.get(1, 3) == 1);
                                                                                        else // every node passes over the full dataset
                                                                                        {
                                                                                            if (!reproducible)
                                                                                                assert ((double) model1._output._scoring_history.get(1, 3) == H2O.CLOUD.size());
                                                                                        }
                                                                                    }
                                                                                    if (n_folds != 0) {
                                                                                        assert (model1._output._cross_validation_metrics != null);
                                                                                    } else {
                                                                                        assert (model1._output._cross_validation_metrics == null);
                                                                                    }
                                                                                }
                                                                                assert (model1.model_info().get_params()._l1 == 0);
                                                                                assert (model1.model_info().get_params()._l2 == 0);
                                                                                Assert.assertFalse(model1._output._job.isCrashed());
                                                                                if (n_folds != 0)
                                                                                    continue;
                                                                                // Do some more training via checkpoint restart
                                                                                // For n_folds, continue without n_folds (not yet implemented) - from now on, model2 will have n_folds=0...
                                                                                DeepLearningParameters p2 = new DeepLearningParameters();
                                                                                Assert.assertTrue(model1.model_info().get_processed_total() >= frame.numRows() * epochs);
                                                                                {
                                                                                    p2._checkpoint = model1._key;
                                                                                    p2._distribution = dist;
                                                                                    p2._loss = loss;
                                                                                    p2._nfolds = n_folds;
                                                                                    p2._train = frame._key;
                                                                                    p2._activation = activation;
                                                                                    p2._hidden = hidden;
                                                                                    p2._valid = valid == null ? null : valid._key;
                                                                                    p2._l1 = 1e-3;
                                                                                    p2._l2 = 1e-3;
                                                                                    p2._reproducible = reproducible;
                                                                                    p2._response_column = respname;
                                                                                    p2._overwrite_with_best_model = overwrite_with_best_model;
                                                                                    p2._quiet_mode = true;
                                                                                    //final amount of training epochs
                                                                                    p2._epochs = 2 * epochs;
                                                                                    p2._replicate_training_data = replicate2;
                                                                                    p2._stopping_rounds = 0;
                                                                                    p2._seed = myseed;
                                                                                    //                                              p2._loss = loss; //fall back to default
                                                                                    //                                              p2._distribution = dist; //fall back to default
                                                                                    p2._train_samples_per_iteration = train_samples_per_iteration;
                                                                                    p2._balance_classes = classification && balance_classes;
                                                                                    p2._elastic_averaging = elastic_averaging2;
                                                                                    DeepLearning dl = new DeepLearning(p2);
                                                                                    try {
                                                                                        model2 = dl.trainModel().get();
                                                                                    } catch (Throwable t) {
                                                                                        model2 = DKV.getGet(dl.dest());
                                                                                        if (model2 != null)
                                                                                            Assert.assertTrue(model2._output._job.isCrashed());
                                                                                        throw t;
                                                                                    }
                                                                                }
                                                                                Assert.assertTrue(model1._output._job.isDone());
                                                                                Assert.assertTrue(model2._output._job.isDone());
                                                                                assert (model1._parms != p2);
                                                                                assert (model1.model_info().get_params() != model2.model_info().get_params());
                                                                                assert (model1.model_info().get_params()._l1 == 0);
                                                                                assert (model1.model_info().get_params()._l2 == 0);
                                                                                if (!overwrite_with_best_model)
                                                                                    Assert.assertTrue(model2.model_info().get_processed_total() >= frame.numRows() * 2 * epochs);
                                                                                assert (p != p2);
                                                                                assert (p != model1.model_info().get_params());
                                                                                assert (p2 != model2.model_info().get_params());
                                                                                if (p._loss == DeepLearningParameters.Loss.Automatic) {
                                                                                    assert (p2._loss == DeepLearningParameters.Loss.Automatic);
                                                                                //                                              assert(model1.model_info().get_params()._loss != DeepLearningParameters.Loss.Automatic);
                                                                                //                                              assert(model2.model_info().get_params()._loss != DeepLearningParameters.Loss.Automatic);
                                                                                }
                                                                                assert (p._hidden_dropout_ratios == null);
                                                                                assert (p2._hidden_dropout_ratios == null);
                                                                                if (p._activation.toString().contains("WithDropout")) {
                                                                                    assert (model1.model_info().get_params()._hidden_dropout_ratios != null);
                                                                                    assert (model2.model_info().get_params()._hidden_dropout_ratios != null);
                                                                                    assert (Arrays.equals(model1.model_info().get_params()._hidden_dropout_ratios, model2.model_info().get_params()._hidden_dropout_ratios));
                                                                                }
                                                                                assert (p._l1 == 0);
                                                                                assert (p._l2 == 0);
                                                                                assert (p2._l1 == 1e-3);
                                                                                assert (p2._l2 == 1e-3);
                                                                                assert (model1.model_info().get_params()._l1 == 0);
                                                                                assert (model1.model_info().get_params()._l2 == 0);
                                                                                assert (model2.model_info().get_params()._l1 == 1e-3);
                                                                                assert (model2.model_info().get_params()._l2 == 1e-3);
                                                                                if (valid == null)
                                                                                    valid = frame;
                                                                                double threshold;
                                                                                if (model2._output.isClassifier()) {
                                                                                    Frame pred = null;
                                                                                    Vec labels, predlabels, pred2labels;
                                                                                    try {
                                                                                        pred = model2.score(valid);
                                                                                        DKV.put(Key.make("pred"), pred);
                                                                                        // Build a POJO, validate same results
                                                                                        if (!model2.testJavaScoring(valid, pred, 1e-6)) {
                                                                                            model2.testJavaScoring(valid, pred, 1e-6); // re-run so a failure can be stepped through in a debugger
                                                                                        }
                                                                                        Assert.assertTrue(model2.testJavaScoring(valid, pred, 1e-6));
                                                                                        hex.ModelMetrics mm = hex.ModelMetrics.getFromDKV(model2, valid);
                                                                                        double error;
                                                                                        // binary
                                                                                        if (model2._output.nclasses() == 2) {
                                                                                            assert (resp == 1);
                                                                                            threshold = mm.auc_obj().defaultThreshold();
                                                                                            error = mm.auc_obj().defaultErr();
                                                                                            // check that auc.cm() is the right CM
                                                                                            Assert.assertEquals(new ConfusionMatrix(mm.auc_obj().defaultCM(), valid.vec(respname).domain()).err(), error, 1e-15);
                                                                                            // check that calcError() is consistent as well (for CM=null, AUC!=null)
                                                                                            Assert.assertEquals(mm.cm().err(), error, 1e-15);
                                                                                            // check that the labels made with the default threshold are consistent with the CM that's reported by the AUC object
                                                                                            labels = valid.vec(respname);
                                                                                            predlabels = pred.vecs()[0];
                                                                                            ConfusionMatrix cm = buildCM(labels, predlabels);
                                                                                            Log.info("CM from pre-made labels:");
                                                                                            Log.info(cm.toASCII());
                                                                                            if (Math.abs(cm.err() - error) > 2e-2) {
                                                                                                ConfusionMatrix cm2 = buildCM(labels, predlabels);
                                                                                                Log.info(cm2.toASCII());
                                                                                            }
                                                                                            Assert.assertEquals(cm.err(), error, 2e-2);
                                                                                            // confirm that orig CM was made with the right threshold
                                                                                            // manually make labels with AUC-given default threshold
                                                                                            String ast = "(as.factor (> (cols pred [2]) " + threshold + "))";
                                                                                            Frame tmp = Rapids.exec(ast).getFrame();
                                                                                            pred2labels = tmp.vecs()[0];
                                                                                            cm = buildCM(labels, pred2labels);
                                                                                            Log.info("CM from self-made labels:");
                                                                                            Log.info(cm.toASCII());
                                                                                            //AUC-given F1-optimal threshold might not reproduce AUC-given CM-error identically, but should match up to 2%
                                                                                            Assert.assertEquals(cm.err(), error, 2e-2);
                                                                                            tmp.delete();
                                                                                        }
                                                                                        DKV.remove(Key.make("pred"));
                                                                                    } finally {
                                                                                        if (pred != null)
                                                                                            pred.delete();
                                                                                    }
                                                                                } else { // regression
                                                                                    Frame pred = null;
                                                                                    try {
                                                                                        pred = model2.score(valid);
                                                                                        // Build a POJO, validate same results
                                                                                        Assert.assertTrue(model2.testJavaScoring(frame, pred, 1e-6));
                                                                                    } finally {
                                                                                        if (pred != null)
                                                                                            pred.delete();
                                                                                    }
                                                                                }
                                                                                Log.info("Parameters combination " + count + ": PASS");
                                                                            } catch (H2OModelBuilderIllegalArgumentException | IllegalArgumentException ex) {
                                                                                System.err.println(ex);
                                                                                throw H2O.fail("should not get here");
                                                                            } catch (RuntimeException t) {
                                                                                // prepend "" and append the cause's message; this way we evade null messages
                                                                                String msg = "" + t.getMessage() + (t.getCause() == null ? "" : t.getCause().getMessage());
                                                                                Assert.assertTrue("Unexpected exception " + t + ": " + msg, msg.contains("unstable"));
                                                                            } catch (AssertionError ae) {
                                                                                // test assertions should be preserved
                                                                                throw ae;
                                                                            } catch (Throwable t) {
                                                                                t.printStackTrace();
                                                                                throw new RuntimeException(t);
                                                                            } finally {
                                                                                if (model1 != null) {
                                                                                    model1.deleteCrossValidationModels();
                                                                                    model1.delete();
                                                                                }
                                                                                if (model2 != null) {
                                                                                    model2.deleteCrossValidationModels();
                                                                                    model2.delete();
                                                                                }
                                                                            }
                                                                        }
                                                                    }
                                                                }
                                                            }
                                                        }
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            } finally {
                if (frame != null)
                    frame.delete();
                if (vframe != null)
                    vframe.delete();
            }
        }
    }
    Log.info("\n\n=============================================");
    Log.info("Tested " + testcount + " out of " + count + " parameter combinations.");
    Log.info("=============================================");
    if (checkSums.size() != testcount) {
        Log.info("Only found " + checkSums.size() + " unique checksums.");
    }
    Assert.assertTrue(checkSums.size() == testcount);
}
Also used : LinkedHashSet(java.util.LinkedHashSet) Frame(water.fvec.Frame) ConfusionMatrix(hex.ConfusionMatrix) NFSFileVec(water.fvec.NFSFileVec) DeepLearningParameters(hex.deeplearning.DeepLearningModel.DeepLearningParameters) ClassSamplingMethod(hex.deeplearning.DeepLearningModel.DeepLearningParameters.ClassSamplingMethod) Random(java.util.Random) H2OModelBuilderIllegalArgumentException(water.exceptions.H2OModelBuilderIllegalArgumentException) DistributionFamily(hex.genmodel.utils.DistributionFamily) Vec(water.fvec.Vec)
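
Note how the fraction argument throttles this otherwise enormous Cartesian grid: every combination increments count, but only those whose pre-seeded random draw falls under fraction actually run. Because the seed is constant, any given fraction reruns the same deterministic subset. A standalone sketch of that sampling pattern (hypothetical two-dimensional grid, not the h2o test harness):

import java.util.Random;

// Reproducible fractional grid sampling, as in runFraction above: count every
// combination, run only those whose random draw falls under the fraction.
public class GridSampler {
    public static void main(String[] args) {
        float fraction = 0.25f;             // run roughly 25% of all combinations
        Random rng = new Random(0xDECAFFF); // fixed seed -> same subset every run
        int count = 0, testcount = 0;
        for (int a = 0; a < 10; ++a) {
            for (int b = 0; b < 10; ++b) {
                count++;
                if (fraction < rng.nextFloat())
                    continue;               // skip this combination on this run
                testcount++;                // ... the expensive test would run here ...
            }
        }
        System.out.println("Tested " + testcount + " out of " + count + " combinations.");
    }
}

With a fixed seed the draw for each combination is fixed, so the subsets are nested: every combination that runs at fraction 0.1 also runs at fraction 0.25, which keeps cheap smoke runs and exhaustive runs consistent.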

Example 3 with DistributionFamily

Use of hex.genmodel.utils.DistributionFamily in the h2o-3 project by h2oai.

From the class DeepLearningGradientCheck, method gradientCheck.
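
The test below trains a model, then compares each back-propagated gradient against a central finite difference of the loss. As a warm-up, the core numerics can be sketched standalone (a hypothetical one-parameter example using the same eps choice and relative-error formula as the test, not the h2o API):

// Central-difference gradient check on f(x) = x^2, whose analytic gradient is 2x.
public class GradCheck {
    static double f(double x) { return x * x; }

    public static void main(String[] args) {
        double x = 3.0;
        double analytic = 2 * x;                   // plays the role of bpropGradient
        double eps = 1e-4 * Math.abs(x);           // scale the probe to the parameter
        if (eps == 0) eps = 1e-6;                  // guard exactly-zero parameters
        double up = f(x + eps), down = f(x - eps); // loss at x +/- eps
        double numeric = (up - down) / (2 * eps);  // central finite difference
        double relError = 2 * Math.abs(analytic - numeric)
                / (Math.abs(analytic) + Math.abs(numeric));
        System.out.println("analytic=" + analytic + " numeric=" + numeric + " relError=" + relError);
    }
}

The full test applies this probe to every bias and (sampled) weight of the trained network, restoring a deep-copied "golden" model before each measurement so the probes never contaminate one another.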

@Test
public void gradientCheck() {
    Frame tfr = null;
    DeepLearningModel dl = null;
    try {
        tfr = parse_test_file("smalldata/glm_test/cancar_logIn.csv");
        for (String s : new String[] { "Merit", "Class" }) {
            Vec f = tfr.vec(s).toCategoricalVec();
            tfr.remove(s).remove();
            tfr.add(s, f);
        }
        DKV.put(tfr);
        tfr.add("Binary", tfr.anyVec().makeZero());
        new MRTask() {

            public void map(Chunk[] c) {
                for (int i = 0; i < c[0]._len; ++i) if (c[0].at8(i) == 1)
                    c[1].set(i, 1);
            }
        }.doAll(tfr.vecs(new String[] { "Class", "Binary" }));
        Vec cv = tfr.vec("Binary").toCategoricalVec();
        tfr.remove("Binary").remove();
        tfr.add("Binary", cv);
        DKV.put(tfr);
        Random rng = new Random(0xDECAF);
        int count = 0;
        int failedcount = 0;
        double maxRelErr = 0;
        double meanRelErr = 0;
        for (DistributionFamily dist : new DistributionFamily[] {
                DistributionFamily.gaussian, DistributionFamily.laplace, DistributionFamily.quantile, DistributionFamily.huber,
                // DistributionFamily.modified_huber,
                DistributionFamily.gamma, DistributionFamily.poisson, DistributionFamily.AUTO,
                DistributionFamily.tweedie, DistributionFamily.multinomial, DistributionFamily.bernoulli }) {
            for (DeepLearningParameters.Activation act : new DeepLearningParameters.Activation[] {
                    // DeepLearningParameters.Activation.ExpRectifier,
                    DeepLearningParameters.Activation.Tanh, DeepLearningParameters.Activation.Rectifier }) {
                for (String response : new String[] {
                        "Binary", // binary classification
                        "Class",  // multi-class
                        "Cost"    // regression
                }) {
                    for (boolean adaptive : new boolean[] { true, false }) {
                        for (int miniBatchSize : new int[] { 1 }) {
                            if (response.equals("Class")) {
                                if (dist != DistributionFamily.multinomial && dist != DistributionFamily.AUTO)
                                    continue;
                            } else if (response.equals("Binary")) {
                                if (dist != DistributionFamily.modified_huber && dist != DistributionFamily.bernoulli && dist != DistributionFamily.AUTO)
                                    continue;
                            } else {
                                if (dist == DistributionFamily.multinomial || dist == DistributionFamily.modified_huber || dist == DistributionFamily.bernoulli)
                                    continue;
                            }
                            DeepLearningParameters parms = new DeepLearningParameters();
                            parms._huber_alpha = rng.nextDouble() + 0.1;
                            parms._tweedie_power = 1.01 + rng.nextDouble() * 0.9;
                            parms._quantile_alpha = 0.05 + rng.nextDouble() * 0.9;
                            parms._train = tfr._key;
                            //converge to a reasonable model to avoid too large gradients
                            parms._epochs = 100;
                            parms._l1 = 1e-3;
                            parms._l2 = 1e-3;
                            parms._force_load_balance = false;
                            parms._hidden = new int[] { 10, 10, 10 };
                            //otherwise we introduce small bprop errors
                            parms._fast_mode = false;
                            parms._response_column = response;
                            parms._distribution = dist;
                            parms._max_w2 = 10;
                            parms._seed = 0xaaabbb;
                            parms._activation = act;
                            parms._adaptive_rate = adaptive;
                            parms._rate = 1e-4;
                            parms._momentum_start = 0.9;
                            parms._momentum_stable = 0.99;
                            parms._mini_batch_size = miniBatchSize;
                            //                DeepLearningModelInfo.gradientCheck = null;
                            //tell it what gradient to collect
                            DeepLearningModelInfo.gradientCheck = new DeepLearningModelInfo.GradientCheck(0, 0, 0);
                            // Build a first model; all remaining models should be equal
                            DeepLearning job = new DeepLearning(parms);
                            try {
                                dl = job.trainModel().get();
                                boolean classification = response.equals("Class") || response.equals("Binary");
                                if (!classification) {
                                    Frame p = dl.score(tfr);
                                    hex.ModelMetrics mm = hex.ModelMetrics.getFromDKV(dl, tfr);
                                    double resdev = ((ModelMetricsRegression) mm)._mean_residual_deviance;
                                    Log.info("Mean residual deviance: " + resdev);
                                    p.delete();
                                }
                                //golden version
                                DeepLearningModelInfo modelInfo = IcedUtils.deepCopy(dl.model_info());
                                //                Log.info(modelInfo.toStringAll());
                                long before = dl.model_info().checksum_impl();
                                float meanLoss = 0;
                                // loop over every row in the dataset and check that the analytic and
                                // finite-difference gradients agree
                                for (int rId = 0; rId < tfr.numRows(); rId += 1 /*miniBatchSize*/) {
                                    // start from scratch - with a clean model
                                    dl.set_model_info(IcedUtils.deepCopy(modelInfo));
                                    final DataInfo di = dl.model_info().data_info();
                                    // populate miniBatch (consecutive rows)
                                    final DataInfo.Row[] rowsMiniBatch = new DataInfo.Row[miniBatchSize];
                                    for (int i = 0; i < rowsMiniBatch.length; ++i) {
                                        if (0 <= rId + i && rId + i < tfr.numRows()) {
                                            rowsMiniBatch[i] = new FrameTask.ExtractDenseRow(di, rId + i).doAll(di._adaptedFrame)._row;
                                        }
                                    }
                                    // loss at weight
                                    long cs = dl.model_info().checksum_impl();
                                    double loss = dl.meanLoss(rowsMiniBatch);
                                    assert (cs == before);
                                    assert (before == dl.model_info().checksum_impl());
                                    meanLoss += loss;
                                    for (int layer = 0; layer <= parms._hidden.length; ++layer) {
                                        int rows = dl.model_info().get_weights(layer).rows();
                                        assert (dl.model_info().get_biases(layer).size() == rows);
                                        for (int row = 0; row < rows; ++row) {
                                            // check the bias gradient; if (true) merely scopes the block
                                            if (true) {
                                                // start from scratch - with a clean model
                                                dl.set_model_info(IcedUtils.deepCopy(modelInfo));
                                                // do one forward propagation pass (and fill the mini-batch gradients -> set training=true)
                                                Neurons[] neurons = DeepLearningTask.makeNeuronsForTraining(dl.model_info());
                                                double[] responses = new double[miniBatchSize];
                                                double[] offsets = new double[miniBatchSize];
                                                int n = 0;
                                                for (DataInfo.Row myRow : rowsMiniBatch) {
                                                    if (myRow == null)
                                                        continue;
                                                    ((Neurons.Input) neurons[0]).setInput(-1, myRow.numIds, myRow.numVals, myRow.nBins, myRow.binIds, n);
                                                    responses[n] = myRow.response(0);
                                                    offsets[n] = myRow.offset;
                                                    n++;
                                                }
                                                DeepLearningTask.fpropMiniBatch(-1, /*seed doesn't matter*/
                                                neurons, dl.model_info(), null, true, /*training*/
                                                responses, offsets, n);
                                                // check that we didn't change the model's weights/biases
                                                long after = dl.model_info().checksum_impl();
                                                assert (after == before);
                                                // record the gradient since gradientChecking is enabled
                                                //tell it what gradient to collect
                                                DeepLearningModelInfo.gradientCheck = new DeepLearningModelInfo.GradientCheck(layer, row, -1);
                                                //update the weights and biases
                                                DeepLearningTask.bpropMiniBatch(neurons, n);
                                                assert (before != dl.model_info().checksum_impl());
                                                // reset the model back to the trained model
                                                dl.set_model_info(IcedUtils.deepCopy(modelInfo));
                                                assert (before == dl.model_info().checksum_impl());
                                                double bpropGradient = DeepLearningModelInfo.gradientCheck.gradient;
                                                // FIXME: re-enable this once the loss is computed from the de-standardized prediction/response
                                                //                    double actualResponse=myRow.response[0];
                                                //                    double predResponseLinkSpace = neurons[neurons.length-1]._a.get(0);
                                                //                    if (di._normRespMul != null) {
                                                //                      bpropGradient /= di._normRespMul[0]; //no shift for gradient
                                                //                      actualResponse = (actualResponse / di._normRespMul[0] + di._normRespSub[0]);
                                                //                      predResponseLinkSpace = (predResponseLinkSpace / di._normRespMul[0] + di._normRespSub[0]);
                                                //                    }
                                                //                    bpropGradient *= new Distribution(parms._distribution).gradient(actualResponse, predResponseLinkSpace);
                                                final double bias = dl.model_info().get_biases(layer).get(row);
                                                //don't make the weight deltas too small, or the float weights "won't notice"
                                                double eps = 1e-4 * Math.abs(bias);
                                                if (eps == 0)
                                                    eps = 1e-6;
                                                // loss at bias + eps
                                                dl.model_info().get_biases(layer).set(row, bias + eps);
                                                double up = dl.meanLoss(rowsMiniBatch);
                                                // loss at bias - eps
                                                dl.model_info().get_biases(layer).set(row, bias - eps);
                                                double down = dl.meanLoss(rowsMiniBatch);
                                                if (Math.abs(up - down) / Math.abs(up + down) < 1e-8) {
                                                    //relative change in loss function is too small -> skip
                                                    continue;
                                                }
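                                                // central finite-difference estimate: dLoss/dBias ~ (L(bias + eps) - L(bias - eps)) / (2 * eps)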
                                                double gradient = ((up - down) / (2. * eps));
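                                                // symmetric relative error between the backprop gradient and the finite-difference estimate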
                                                double relError = 2 * Math.abs(bpropGradient - gradient) / (Math.abs(gradient) + Math.abs(bpropGradient));
                                                count++;
                                                // if either gradient is tiny, check whether both are tiny
                                                if (Math.abs(gradient) < 1e-7 || Math.abs(bpropGradient) < 1e-7) {
                                                    if (Math.abs(bpropGradient - gradient) < 1e-7)
                                                        continue; // both tiny -> all good
                                                }
                                                meanRelErr += relError;
                                                if (relError > MAX_TOLERANCE) {
                                                    Log.info("\nDistribution: " + dl._parms._distribution);
                                                    Log.info("\nRow: " + rId);
                                                    Log.info("bias (layer " + layer + ", row " + row + "): " + bias + " +/- " + eps);
                                                    Log.info("loss: " + loss);
                                                    Log.info("losses up/down: " + up + " / " + down);
                                                    Log.info("=> Finite differences gradient: " + gradient);
                                                    Log.info("=> Back-propagation gradient  : " + bpropGradient);
                                                    Log.info("=> Relative error             : " + PrettyPrint.formatPct(relError));
                                                    failedcount++;
                                                }
                                            }
                                            int cols = dl.model_info().get_weights(layer).cols();
                                            for (int col = 0; col < cols; ++col) {
                                                if (rng.nextFloat() >= SAMPLE_RATE)
                                                    continue;
                                                // start from scratch - with a clean model
                                                dl.set_model_info(IcedUtils.deepCopy(modelInfo));
                                                // do one forward propagation pass (and fill the mini-batch gradients -> set training=true)
                                                Neurons[] neurons = DeepLearningTask.makeNeuronsForTraining(dl.model_info());
                                                double[] responses = new double[miniBatchSize];
                                                double[] offsets = new double[miniBatchSize];
                                                int n = 0;
                                                for (DataInfo.Row myRow : rowsMiniBatch) {
                                                    if (myRow == null)
                                                        continue;
                                                    ((Neurons.Input) neurons[0]).setInput(-1, myRow.numIds, myRow.numVals, myRow.nBins, myRow.binIds, n);
                                                    responses[n] = myRow.response(0);
                                                    offsets[n] = myRow.offset;
                                                    n++;
                                                }
                                                DeepLearningTask.fpropMiniBatch(-1 /*seed doesn't matter*/, neurons, dl.model_info(), null, true /*training*/, responses, offsets, n);
                                                // check that we didn't change the model's weights/biases
                                                long after = dl.model_info().checksum_impl();
                                                assert (after == before);
                                                // record the gradient since gradient checking is enabled
                                                // tell it what gradient to collect
                                                DeepLearningModelInfo.gradientCheck = new DeepLearningModelInfo.GradientCheck(layer, row, col);
                                                // update the weights
                                                DeepLearningTask.bpropMiniBatch(neurons, n);
                                                assert (before != dl.model_info().checksum_impl());
                                                // reset the model back to the trained model
                                                dl.set_model_info(IcedUtils.deepCopy(modelInfo));
                                                assert (before == dl.model_info().checksum_impl());
                                                double bpropGradient = DeepLearningModelInfo.gradientCheck.gradient;
                                                // FIXME: re-enable this once the loss is computed from the de-standardized prediction/response
                                                //                    double actualResponse=myRow.response[0];
                                                //                    double predResponseLinkSpace = neurons[neurons.length-1]._a.get(0);
                                                //                    if (di._normRespMul != null) {
                                                //                      bpropGradient /= di._normRespMul[0]; //no shift for gradient
                                                //                      actualResponse = (actualResponse / di._normRespMul[0] + di._normRespSub[0]);
                                                //                      predResponseLinkSpace = (predResponseLinkSpace / di._normRespMul[0] + di._normRespSub[0]);
                                                //                    }
                                                //                    bpropGradient *= new Distribution(parms._distribution).gradient(actualResponse, predResponseLinkSpace);
                                                final float weight = dl.model_info().get_weights(layer).get(row, col);
                                                //don't make the weight deltas too small, or the float weights "won't notice"
                                                double eps = 1e-4 * Math.abs(weight);
                                                if (eps == 0)
                                                    eps = 1e-6;
                                                // loss at weight + eps
                                                dl.model_info().get_weights(layer).set(row, col, (float) (weight + eps));
                                                double up = dl.meanLoss(rowsMiniBatch);
                                                // loss at weight - eps
                                                dl.model_info().get_weights(layer).set(row, col, (float) (weight - eps));
                                                double down = dl.meanLoss(rowsMiniBatch);
                                                if (Math.abs(up - down) / Math.abs(up + down) < 1e-8) {
                                                    //relative change in loss function is too small -> skip
                                                    continue;
                                                }
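                                                // central finite-difference estimate: dLoss/dWeight ~ (L(w + eps) - L(w - eps)) / (2 * eps)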
                                                double gradient = ((up - down) / (2. * eps));
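                                                // symmetric relative error between the backprop gradient and the finite-difference estimate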
                                                double relError = 2 * Math.abs(bpropGradient - gradient) / (Math.abs(gradient) + Math.abs(bpropGradient));
                                                count++;
                                                // if either gradient is tiny, check whether both are tiny
                                                if (Math.abs(gradient) < 1e-7 || Math.abs(bpropGradient) < 1e-7) {
                                                    if (Math.abs(bpropGradient - gradient) < 1e-7)
                                                        continue; // both tiny -> all good
                                                }
                                                meanRelErr += relError;
                                                if (relError > MAX_TOLERANCE) {
                                                    Log.info("\nDistribution: " + dl._parms._distribution);
                                                    Log.info("\nRow: " + rId);
                                                    Log.info("weight (layer " + layer + ", row " + row + ", col " + col + "): " + weight + " +/- " + eps);
                                                    Log.info("loss: " + loss);
                                                    Log.info("losses up/down: " + up + " / " + down);
                                                    Log.info("=> Finite differences gradient: " + gradient);
                                                    Log.info("=> Back-propagation gradient  : " + bpropGradient);
                                                    Log.info("=> Relative error             : " + PrettyPrint.formatPct(relError));
                                                    failedcount++;
                                                }
                                                // Assert.assertTrue(failedcount == 0);
                                                maxRelErr = Math.max(maxRelErr, relError);
                                                assert (!Double.isNaN(maxRelErr));
                                            }
                                        }
                                    }
                                }
                                meanLoss /= tfr.numRows();
                                Log.info("Mean loss: " + meanLoss);
                            //                  // FIXME: re-enable this
                            //                  if (parms._l1 == 0 && parms._l2 == 0) {
                            //                    assert(Math.abs(meanLoss-resdev)/Math.abs(resdev) < 1e-5);
                            //                  }
                            } catch (RuntimeException ex) {
                                dl = DKV.getGet(job.dest());
                                if (dl != null)
                                    Assert.assertTrue(dl.model_info().isUnstable());
                                else
                                    Assert.assertTrue(job.isStopped());
                            } finally {
                                if (dl != null)
                                    dl.delete();
                            }
                        }
                    }
                }
            }
        }
        Log.info("Number of tests: " + count);
        Log.info("Number of failed tests: " + failedcount);
        Log.info("Mean. relative error: " + meanRelErr / count);
        Log.info("Max. relative error: " + PrettyPrint.formatPct(maxRelErr));
        Assert.assertTrue("Error too large: " + maxRelErr + " >= " + MAX_TOLERANCE, maxRelErr < MAX_TOLERANCE);
        Assert.assertTrue("Failed count too large: " + failedcount + " > " + MAX_FAILED_COUNT, failedcount <= MAX_FAILED_COUNT);
    } finally {
        if (tfr != null)
            tfr.remove();
    }
}
Also used : Frame(water.fvec.Frame) DeepLearningParameters(hex.deeplearning.DeepLearningModel.DeepLearningParameters) ModelMetricsRegression(hex.ModelMetricsRegression) Random(java.util.Random) FrameTask(hex.FrameTask) DataInfo(hex.DataInfo) DistributionFamily(hex.genmodel.utils.DistributionFamily) Chunk(water.fvec.Chunk) PrettyPrint(water.util.PrettyPrint) Vec(water.fvec.Vec) Test(org.junit.Test)
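
The example above validates back-propagation by perturbing one bias or weight at a time and comparing the analytic gradient against a central finite-difference estimate of the loss. A minimal, self-contained sketch of that check (hypothetical names and a toy loss function, not the H2O API):

import java.util.function.DoubleUnaryOperator;

public class GradientCheckSketch {

    // Compares an analytic (back-propagated) gradient against a central
    // finite-difference estimate and returns the symmetric relative error,
    // mirroring the logic of the test above.
    static double relativeError(DoubleUnaryOperator lossAt, double param, double analyticGradient) {
        // scale eps to the parameter so float-precision weights "notice" the perturbation
        double eps = 1e-4 * Math.abs(param);
        if (eps == 0)
            eps = 1e-6;
        double up = lossAt.applyAsDouble(param + eps);
        double down = lossAt.applyAsDouble(param - eps);
        // central difference: dLoss/dParam ~ (L(p + eps) - L(p - eps)) / (2 * eps)
        double fdGradient = (up - down) / (2. * eps);
        return 2 * Math.abs(analyticGradient - fdGradient) / (Math.abs(analyticGradient) + Math.abs(fdGradient));
    }

    public static void main(String[] args) {
        // toy loss L(w) = (w - 3)^2, so dL/dw at w = 1 is exactly -4
        DoubleUnaryOperator loss = w -> (w - 3) * (w - 3);
        System.out.println("relative error: " + relativeError(loss, 1.0, -4.0)); // ~0
    }
}

The same eps scaling and symmetric-error formula appear in the test, which accepts a gradient when the relative error stays below MAX_TOLERANCE.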

Example 4 with DistributionFamily

use of hex.genmodel.utils.DistributionFamily in project h2o-3 by h2oai.

the class DeepLearningTest method testDistributions.

@Test
public void testDistributions() {
    Frame tfr = null, vfr = null, fr2 = null;
    DeepLearningModel dl = null;
    for (DistributionFamily dist : new DistributionFamily[] { AUTO, gaussian, poisson, gamma, tweedie }) {
        Scope.enter();
        try {
            tfr = parse_test_file("smalldata/glm_test/cancar_logIn.csv");
            for (String s : new String[] { "Merit", "Class" }) {
                Scope.track(tfr.replace(tfr.find(s), tfr.vec(s).toCategoricalVec()));
            }
            DKV.put(tfr);
            DeepLearningParameters parms = new DeepLearningParameters();
            parms._train = tfr._key;
            parms._epochs = 1;
            parms._reproducible = true;
            parms._hidden = new int[] { 50, 50 };
            parms._response_column = "Cost";
            parms._seed = 0xdecaf;
            parms._distribution = dist;
            // Train a model for this distribution family
            DeepLearning job = new DeepLearning(parms);
            dl = job.trainModel().get();
            ModelMetricsRegression mm = (ModelMetricsRegression) dl._output._training_metrics;
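            // for a numeric response, AUTO resolves to gaussian, whose deviance is the squared error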
            if (dist == gaussian || dist == AUTO)
                Assert.assertEquals(mm._mean_residual_deviance, mm._MSE, 1e-6);
            else
                assertTrue(mm._mean_residual_deviance != mm._MSE);
            assertTrue(dl.testJavaScoring(tfr, fr2 = dl.score(tfr), 1e-5));
        } finally {
            if (tfr != null)
                tfr.remove();
            if (dl != null)
                dl.delete();
            if (fr2 != null)
                fr2.delete();
            Scope.exit();
        }
    }
}
Also used : Frame(water.fvec.Frame) DistributionFamily(hex.genmodel.utils.DistributionFamily) DeepLearningParameters(hex.deeplearning.DeepLearningModel.DeepLearningParameters) Test(org.junit.Test)
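
The gaussian branch of that assertion holds because the gaussian deviance of one prediction is simply its squared error, so the mean residual deviance and the MSE are the same quantity. A standalone illustration (hypothetical numbers, not the test data):

public class GaussianDevianceSketch {
    public static void main(String[] args) {
        double[] actual = { 1.0, 2.0, 3.0 };
        double[] predicted = { 1.1, 1.9, 3.2 };
        double mse = 0, meanResidualDeviance = 0;
        for (int i = 0; i < actual.length; i++) {
            double err = actual[i] - predicted[i];
            mse += err * err;                  // squared error term
            meanResidualDeviance += err * err; // gaussian deviance: (y - yhat)^2
        }
        mse /= actual.length;
        meanResidualDeviance /= actual.length;
        // identical for the gaussian family, within floating-point noise
        System.out.println(mse + " == " + meanResidualDeviance);
    }
}

For poisson, gamma, and tweedie the deviance is a different function of the error, which is why the test expects the two metrics to diverge there.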

Example 5 with DistributionFamily

use of hex.genmodel.utils.DistributionFamily in project h2o-3 by h2oai.

the class GBMTest method testDeviances.

@Test
public void testDeviances() {
    for (DistributionFamily dist : DistributionFamily.values()) {
        if (dist == modified_huber)
            continue;
        Frame tfr = null;
        Frame res = null;
        Frame preds = null;
        GBMModel gbm = null;
        try {
            tfr = parse_test_file("./smalldata/gbm_test/BostonHousing.csv");
            GBMModel.GBMParameters parms = new GBMModel.GBMParameters();
            parms._train = tfr._key;
            String resp = tfr.lastVecName();
            if (dist == modified_huber || dist == bernoulli || dist == multinomial) {
                resp = dist == multinomial ? "rad" : "chas";
                Vec v = tfr.remove(resp);
                tfr.add(resp, v.toCategoricalVec());
                v.remove();
                DKV.put(tfr);
            }
            parms._response_column = resp;
            parms._distribution = dist;
            gbm = new GBM(parms).trainModel().get();
            preds = gbm.score(tfr);
            res = gbm.computeDeviances(tfr, preds, "myDeviances");
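            // the mean per-row deviance should match the corresponding training metric checked below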
            double meanDeviance = res.anyVec().mean();
            if (gbm._output.nclasses() == 2)
                Assert.assertEquals(meanDeviance, ((ModelMetricsBinomial) gbm._output._training_metrics)._logloss, 1e-6 * Math.abs(meanDeviance));
            else if (gbm._output.nclasses() > 2)
                Assert.assertEquals(meanDeviance, ((ModelMetricsMultinomial) gbm._output._training_metrics)._logloss, 1e-6 * Math.abs(meanDeviance));
            else
                Assert.assertEquals(meanDeviance, ((ModelMetricsRegression) gbm._output._training_metrics)._mean_residual_deviance, 1e-6 * Math.abs(meanDeviance));
        } finally {
            if (tfr != null)
                tfr.delete();
            if (res != null)
                res.delete();
            if (preds != null)
                preds.delete();
            if (gbm != null)
                gbm.delete();
        }
    }
}
Also used : Frame(water.fvec.Frame) DistributionFamily(hex.genmodel.utils.DistributionFamily) Vec(water.fvec.Vec) FVecTest.makeByteVec(water.fvec.FVecTest.makeByteVec) Test(org.junit.Test)
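
For a binomial model the per-row deviance that the test averages is the log loss of the predicted probability, which is why the mean lines up with the _logloss training metric. A hedged sketch of that relationship (hypothetical labels and probabilities; the clamping constant is a common convention, not a confirmed H2O value):

public class BinomialDevianceSketch {
    public static void main(String[] args) {
        int[] y = { 1, 0, 1, 1 };            // observed binary labels
        double[] p = { 0.9, 0.2, 0.7, 0.6 }; // predicted P(y = 1)
        double sum = 0;
        for (int i = 0; i < y.length; i++) {
            double pi = Math.min(1 - 1e-15, Math.max(1e-15, p[i])); // keep log() finite
            // per-row binomial deviance == log loss for that row
            sum += -(y[i] * Math.log(pi) + (1 - y[i]) * Math.log(1 - pi));
        }
        System.out.println("mean deviance / logloss: " + sum / y.length);
    }
}

The regression families follow the same pattern with their own deviance formulas, matched against _mean_residual_deviance instead of _logloss.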

Aggregations

DistributionFamily (hex.genmodel.utils.DistributionFamily) 7
Test (org.junit.Test) 5
Frame (water.fvec.Frame) 5
DeepLearningParameters (hex.deeplearning.DeepLearningModel.DeepLearningParameters) 4
Random (java.util.Random) 3
Vec (water.fvec.Vec) 3
PrettyPrint (water.util.PrettyPrint) 2
ConfusionMatrix (hex.ConfusionMatrix) 1
DataInfo (hex.DataInfo) 1
Distribution (hex.Distribution) 1
FrameTask (hex.FrameTask) 1
ModelMetricsRegression (hex.ModelMetricsRegression) 1
ClassSamplingMethod (hex.deeplearning.DeepLearningModel.DeepLearningParameters.ClassSamplingMethod) 1
GLMModel (hex.glm.GLMModel) 1
DRFModel (hex.tree.drf.DRFModel) 1
Field (java.lang.reflect.Field) 1
LinkedHashSet (java.util.LinkedHashSet) 1
H2OIllegalArgumentException (water.exceptions.H2OIllegalArgumentException) 1
H2OModelBuilderIllegalArgumentException (water.exceptions.H2OModelBuilderIllegalArgumentException) 1
Chunk (water.fvec.Chunk) 1