
Example 6 with DistributionFamily

Use of hex.genmodel.utils.DistributionFamily in project h2o-3 by h2oai.

Source: class GBMTest, method testDistributions.

// just a simple sanity check - not a golden test
@Test
public void testDistributions() {
    Frame tfr = null, vfr = null, res = null;
    GBMModel gbm = null;
    for (DistributionFamily dist : new DistributionFamily[] { DistributionFamily.AUTO, DistributionFamily.gaussian, DistributionFamily.poisson, DistributionFamily.gamma, DistributionFamily.tweedie }) {
        Scope.enter();
        try {
            tfr = parse_test_file("smalldata/glm_test/cancar_logIn.csv");
            vfr = parse_test_file("smalldata/glm_test/cancar_logIn.csv");
            for (String s : new String[] { "Merit", "Class" }) {
                Scope.track(tfr.replace(tfr.find(s), tfr.vec(s).toCategoricalVec()));
                Scope.track(vfr.replace(vfr.find(s), vfr.vec(s).toCategoricalVec()));
            }
            DKV.put(tfr);
            DKV.put(vfr);
            GBMModel.GBMParameters parms = new GBMModel.GBMParameters();
            parms._train = tfr._key;
            parms._response_column = "Cost";
            parms._seed = 0xdecaf;
            parms._distribution = dist;
            parms._min_rows = 1;
            parms._ntrees = 30;
            //        parms._offset_column = "logInsured"; //POJO scoring not supported for offsets
            parms._learn_rate = 1e-3f;
            // Train one model for this distribution; its POJO scores must match in-cluster scores
            gbm = new GBM(parms).trainModel().get();
            res = gbm.score(vfr);
            Assert.assertTrue(gbm.testJavaScoring(vfr, res, 1e-15));
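            res.remove();
            // Already removed; null it so the finally block does not remove it a second time
            res = null;
            // The cast alone checks that training produced regression metrics; no values are asserted
            ModelMetricsRegression mm = (ModelMetricsRegression) gbm._output._training_metrics;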
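        } finally {
            // These variables live outside the loop: null each one after cleanup so that a
            // failure early in the next iteration cannot remove a stale object a second time
            if (tfr != null) {
                tfr.remove();
                tfr = null;
            }
            if (vfr != null) {
                vfr.remove();
                vfr = null;
            }
            if (res != null) {
                res.remove();
                res = null;
            }
            if (gbm != null) {
                gbm.delete();
                gbm = null;
            }
            Scope.exit();
        }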
    }
}
Also used: Frame (water.fvec.Frame), DistributionFamily (hex.genmodel.utils.DistributionFamily), Test (org.junit.Test)
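testDistributions trains the same regression GBM once per distribution family and requires the generated Java (POJO) scorer to reproduce the in-cluster predictions to within 1e-15. One thing that changes between families is how the raw ensemble score f (link space) becomes a prediction. The sketch below is hypothetical: it assumes the usual GLM-style conventions (identity link for gaussian, log link for poisson, gamma and tweedie) and that AUTO falls back to gaussian for a numeric response such as Cost. `LinkSketch` and its local `Family` enum are illustrative names, not H2O classes.

```java
/**
 * Hypothetical sketch of the inverse link applied to a link-space score f for the
 * regression families looped over in testDistributions. Assumes identity link for
 * gaussian and log link for poisson, gamma and tweedie; not H2O's implementation.
 */
final class LinkSketch {

    // Local stand-in enum; NOT hex.genmodel.utils.DistributionFamily.
    enum Family { GAUSSIAN, POISSON, GAMMA, TWEEDIE }

    /** Maps a link-space score f to a prediction in response space. */
    static double linkInv(Family family, double f) {
        switch (family) {
            case GAUSSIAN:
                return f;             // identity link: prediction equals the raw score
            case POISSON:
            case GAMMA:
            case TWEEDIE:
                return Math.exp(f);   // log link: predictions are always positive
            default:
                throw new IllegalArgumentException("unsupported family: " + family);
        }
    }

    public static void main(String[] args) {
        // Print the prediction each family would produce for the same raw score.
        for (Family family : Family.values()) {
            System.out.println(family + ": linkInv(0.5) = " + linkInv(family, 0.5));
        }
    }
}
```

Under these assumptions, identical trees would still yield different predictions per family, which is one reason the test checks POJO scoring separately for each distribution.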

Example 7 with DistributionFamily

Use of hex.genmodel.utils.DistributionFamily in project h2o-3 by h2oai.

Source: class DeepLearningGradientCheck, method checkDistributionGradients.
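What checkDistributionGradients compares, written out: with $D(w, y, \mu)$ standing for `d.deviance`, $g^{-1}$ for `d.linkInv` and $h(y, f)$ for `d.negHalfGradient`, the code sets an analytic derivative `grad` against a central difference `approxgrad` taken in link space. Dividing by $w$ assumes the deviance is proportional to the weight, so both sides are per-unit-weight derivatives with respect to $f$. The last line works the gaussian case (identity link, $D = w\,(y-\mu)^2$, $h = y - f$).

```latex
\texttt{grad} = -2\,h(y,f), \qquad
\texttt{approxgrad}
  = \frac{D\bigl(w,y,g^{-1}(f+\varepsilon)\bigr) - D\bigl(w,y,g^{-1}(f-\varepsilon)\bigr)}{2\,\varepsilon\,w}
  = \frac{1}{w}\frac{\partial D}{\partial f} + \frac{\varepsilon^{2}}{6\,w}\,\frac{\partial^{3} D}{\partial f^{3}}\Big|_{\xi},
  \quad \xi \in (f-\varepsilon,\, f+\varepsilon)

\text{gaussian:}\quad \frac{1}{w}\frac{\partial}{\partial f}\Bigl[w\,(y-f)^{2}\Bigr] = -2\,(y-f) = -2\,h(y,f)
```

With $N = 1000$ the code uses $\varepsilon = 1/(10N) = 10^{-4}$, so the truncation term is about $1.7\times10^{-9}$ times the third derivative, far below the $10^{-4}$ tolerance for smooth deviances. The half-step offset in $f = (i + 0.5)/N$ keeps $f$ at least $5\times10^{-4}$ away from $y \in \{0, 1\}$, which is wider than $\varepsilon$; assuming laplace and quantile use an identity link, their kink at $f = y$ therefore never falls inside the stencil $[f-\varepsilon, f+\varepsilon]$.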

@Test
public void checkDistributionGradients() {
    Random rng = new Random(0xDECAF);
    for (DistributionFamily dist : new DistributionFamily[] { DistributionFamily.AUTO, DistributionFamily.gaussian, DistributionFamily.laplace, DistributionFamily.quantile, DistributionFamily.huber, DistributionFamily.gamma, DistributionFamily.poisson, DistributionFamily.tweedie, DistributionFamily.bernoulli }) {
        DeepLearningParameters p = new DeepLearningParameters();
        p._distribution = dist;
        int N = 1000;
        double eps = 1. / (10. * N);
        for (double y : new double[] { 0, 1 }) {
            // scan f over the range -5..5 in function approximation space (link space)
            for (int i = -5 * N; i < 5 * N; ++i) {
                p._huber_alpha = rng.nextDouble() + 0.1;
                p._tweedie_power = 1.01 + rng.nextDouble() * 0.9;
                p._quantile_alpha = 0.05 + rng.nextDouble() * 0.9;
                Distribution d = new Distribution(p);
                // half-step offset: f is never exactly 0 (or 1), avoiding kinks at f = y
                double f = (i + 0.5) / N;
                //f in link space (model space)
                double grad = -2 * d.negHalfGradient(y, f);
                double w = rng.nextDouble() * 10;
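                // central difference of the deviance w.r.t. f; deviance is evaluated in real
                // (response) space via linkInv and divided by w to get a per-unit-weight value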
                double approxgrad = (d.deviance(w, y, d.linkInv(f + eps)) - d.deviance(w, y, d.linkInv(f - eps))) / (2 * eps * w);
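                // JUnit assertion: a plain Java `assert` is skipped unless the JVM runs with -ea
                Assert.assertTrue("gradient mismatch for " + dist + " at y=" + y + ", f=" + f, Math.abs(grad - approxgrad) <= 1e-4);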
            }
        }
    }
}
Also used: Random (java.util.Random), DistributionFamily (hex.genmodel.utils.DistributionFamily), Distribution (hex.Distribution), DeepLearningParameters (hex.deeplearning.DeepLearningModel.DeepLearningParameters), PrettyPrint (water.util.PrettyPrint), Test (org.junit.Test)
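The same gradient check can be reproduced without an H2O cluster. The sketch below mirrors the structure of checkDistributionGradients (same N, ε, half-step offset and seed) for two families only. Its deviance formulas are textbook forms written directly as functions of the link-space value f, which is an assumption for illustration; they are not taken from `hex.Distribution`. Unlike the test, which goes through `linkInv`, the sketch differentiates in link space directly; by the chain rule the two are equivalent. The weight is drawn from [1, 10) rather than [0, 10) to keep it away from zero.

```java
import java.util.Random;

/**
 * Self-contained sketch of a central-difference gradient check in the spirit of
 * checkDistributionGradients. Deviances are textbook link-space forms chosen for
 * illustration (assumptions), not H2O's hex.Distribution code.
 */
final class GradientCheckSketch {

    /** Deviance D(w, y, f) written directly as a function of the link-space value f. */
    interface Deviance { double apply(double w, double y, double f); }

    /** The "negative half gradient": -0.5 * dD/df for unit weight. */
    interface NegHalfGradient { double apply(double y, double f); }

    // Gaussian, identity link: D = w * (y - f)^2, so -0.5 * dD/df = y - f.
    static double gaussianDeviance(double w, double y, double f) {
        return w * (y - f) * (y - f);
    }

    static double gaussianNegHalfGradient(double y, double f) {
        return y - f;
    }

    // Poisson, log link (mu = exp(f)): D = -2w * (y*f - exp(f)) up to a y-only constant,
    // so -0.5 * dD/df = y - exp(f).
    static double poissonDeviance(double w, double y, double f) {
        return -2 * w * (y * f - Math.exp(f));
    }

    static double poissonNegHalfGradient(double y, double f) {
        return y - Math.exp(f);
    }

    /** Scans f over (-5, 5) for y in {0, 1}; returns the worst |analytic - numeric| gap. */
    static double maxGradientError(Deviance dev, NegHalfGradient nhg, long seed) {
        Random rng = new Random(seed);
        int n = 1000;
        double eps = 1.0 / (10.0 * n);
        double maxErr = 0;
        for (double y : new double[] {0, 1}) {
            for (int i = -5 * n; i < 5 * n; ++i) {
                double f = (i + 0.5) / n;            // half-step offset: f never equals 0 or 1
                double w = 1 + rng.nextDouble() * 9; // random weight in [1, 10)
                double analytic = -2 * nhg.apply(y, f);
                double numeric = (dev.apply(w, y, f + eps) - dev.apply(w, y, f - eps)) / (2 * eps * w);
                maxErr = Math.max(maxErr, Math.abs(analytic - numeric));
            }
        }
        return maxErr;
    }

    public static void main(String[] args) {
        double g = maxGradientError(GradientCheckSketch::gaussianDeviance,
                GradientCheckSketch::gaussianNegHalfGradient, 0xDECAF);
        double p = maxGradientError(GradientCheckSketch::poissonDeviance,
                GradientCheckSketch::poissonNegHalfGradient, 0xDECAF);
        System.out.println("gaussian max gradient error: " + g);
        System.out.println("poisson  max gradient error: " + p);
    }
}
```

Returning the worst-case error instead of asserting inside the loop makes it easy to see how much headroom the 1e-4 tolerance leaves for each family.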
