Search in sources:

Example 46 with DeepLearningParameters

Use of hex.deeplearning.DeepLearningModel.DeepLearningParameters in project h2o-3 by h2oai.

Class DeepLearningTest, method testHuberDeltaLarge.

/**
 * Trains a Huber-loss DeepLearning model on BostonHousing with {@code _huber_alpha = 1}.
 *
 * <p>With that alpha the Huber loss is meant to act just like gaussian, so both the training mean
 * residual deviance and the training MSE are expected near the same value.
 */
@Test
public void testHuberDeltaLarge() {
    Frame boston = null;
    DeepLearningModel model = null;
    try {
        boston = parse_test_file("./smalldata/gbm_test/BostonHousing.csv");

        DeepLearningParameters p = new DeepLearningParameters();
        p._train = boston._key;
        p._response_column = boston.lastVecName();
        p._reproducible = true;
        p._hidden = new int[] { 20, 20 };
        p._seed = 0xdecaf;
        p._distribution = huber;
        p._huber_alpha = 1; // just like gaussian

        model = new DeepLearning(p).trainModel().get();

        // Same target (the MSE) and tolerance for both metrics.
        final double expectedMse = 12.93808;
        final double tolerance = 0.7;
        Assert.assertEquals(expectedMse,
            ((ModelMetricsRegression) model._output._training_metrics)._mean_residual_deviance, tolerance);
        Assert.assertEquals(expectedMse,
            ((ModelMetricsRegression) model._output._training_metrics)._MSE, tolerance);
    } finally {
        if (boston != null) {
            boston.delete();
        }
        if (model != null) {
            model.deleteCrossValidationModels();
            model.delete();
        }
    }
}
Also used : Frame(water.fvec.Frame) DeepLearningParameters(hex.deeplearning.DeepLearningModel.DeepLearningParameters) Test(org.junit.Test)

Example 47 with DeepLearningParameters

Use of hex.deeplearning.DeepLearningModel.DeepLearningParameters in project h2o-3 by h2oai.

Class DeepLearningTest, method testCheckpointSameEpochs.

/**
 * Checks that restarting training from a checkpoint with the same epoch count is refused.
 *
 * <p>A first model is trained for 10 epochs. A second build is then requested from that checkpoint,
 * again with {@code _epochs = 10}. The build must throw {@link H2OIllegalArgumentException}. If it
 * completes instead, the test fails.
 */
@Test
public void testCheckpointSameEpochs() {
    Frame tfr = null;
    DeepLearningModel dl = null;
    DeepLearningModel dl2 = null;
    try {
        tfr = parse_test_file("./smalldata/iris/iris.csv");
        DeepLearningParameters parms = new DeepLearningParameters();
        parms._train = tfr._key;
        parms._epochs = 10;
        parms._response_column = "C5";
        parms._reproducible = true;
        parms._hidden = new int[] { 2, 2 };
        parms._seed = 0xdecaf;
        parms._variable_importances = true;
        dl = new DeepLearning(parms).trainModel().get();

        // Same parameters, restarted from the first model, asking for the same number of epochs
        // the checkpointed model was already trained for.
        DeepLearningParameters parms2 = (DeepLearningParameters) parms.clone();
        parms2._epochs = 10;
        parms2._checkpoint = dl._key;
        try {
            dl2 = new DeepLearning(parms2).trainModel().get();
            Assert.fail("Should toss exception instead of reaching here");
        } catch (H2OIllegalArgumentException expected) {
            // Expected outcome: the checkpoint restart is rejected. Nothing else to verify.
            // Assert.fail throws AssertionError, which is intentionally not caught here.
        }
    } finally {
        if (tfr != null) {
            tfr.delete();
        }
        if (dl != null) {
            dl.delete();
        }
        // Only non-null if the build unexpectedly succeeded.
        if (dl2 != null) {
            dl2.delete();
        }
    }
}
Also used : Frame(water.fvec.Frame) H2OIllegalArgumentException(water.exceptions.H2OIllegalArgumentException) DeepLearningParameters(hex.deeplearning.DeepLearningModel.DeepLearningParameters) Test(org.junit.Test)

Example 48 with DeepLearningParameters

Use of hex.deeplearning.DeepLearningModel.DeepLearningParameters in project h2o-3 by h2oai.

Class DeepLearningTest, method basicDL.

/**
 * Trains a DeepLearning model on {@code fnametrain} and checks its scoring results against expectations.
 *
 * <p>Metrics are read for {@code fnametest} when it is given, otherwise for the training frame.
 * <ul>
 *   <li>Classification: the confusion matrix and the response domain must match exactly.</li>
 *   <li>Regression: the MSE is compared via {@code MathUtils.compare} with 1e-8 tolerances.</li>
 * </ul>
 * Finally, {@code model.testJavaScoring} is run with a 1e-5 tolerance on a separately parsed copy of the
 * training file.
 *
 * @param fnametrain     path of the training file
 * @param hexnametrain   suffix of the model key {@code "DL_model_" + hexnametrain}
 * @param fnametest      path of the file whose metrics are checked, or {@code null} to use the training frame
 * @param prep           passed to {@code unifyFrame} together with {@code classification}
 * @param epochs         value assigned to {@code _epochs}
 * @param expCM          expected confusion matrix (classification only)
 * @param expRespDom     expected response domain (classification only)
 * @param expMSE         expected MSE (regression only)
 * @param hidden         value assigned to {@code _hidden}
 * @param l1             value assigned to {@code _l1}
 * @param classification selects the confusion-matrix checks instead of the MSE check
 * @param activation     value assigned to {@code _activation}
 * @throws Throwable     anything raised while parsing, training, scoring or asserting
 */
public void basicDL(String fnametrain, String hexnametrain, String fnametest, PrepData prep, int epochs, double[][] expCM, String[] expRespDom, double expMSE, int[] hidden, double l1, boolean classification, DeepLearningParameters.Activation activation) throws Throwable {
    Scope.enter();
    DeepLearningParameters dl = new DeepLearningParameters();
    Frame frTest = null, pred = null;
    Frame frTrain = null;
    Frame test = null, res = null;
    DeepLearningModel model = null;
    try {
        frTrain = parse_test_file(fnametrain);
        Vec removeme = unifyFrame(dl, frTrain, prep, classification);
        if (removeme != null)
            Scope.track(removeme);
        DKV.put(frTrain._key, frTrain);
        // Configure DL
        dl._train = frTrain._key;
        dl._response_column = ((Frame) DKV.getGet(dl._train)).lastVecName();
        dl._seed = (1L << 32) | 2;
        dl._reproducible = true;
        dl._epochs = epochs;
        dl._stopping_rounds = 0;
        dl._activation = activation;
        dl._export_weights_and_biases = RandomUtils.getRNG(fnametrain.hashCode()).nextBoolean();
        dl._hidden = hidden;
        dl._l1 = l1;
        dl._elastic_averaging = false;
        // Invoke DL and block till the end
        DeepLearning job = new DeepLearning(dl, Key.<DeepLearningModel>make("DL_model_" + hexnametrain));
        // Get the model
        model = job.trainModel().get();
        Log.info(model._output);
        //HEX-1817
        assertTrue(job.isStopped());

        // Score the test file if one was given, otherwise the training frame, then look up the
        // metrics for that frame (presumably stored in the DKV by the score call).
        final boolean onTestSet = fnametest != null;
        hex.ModelMetrics mm;
        if (onTestSet) {
            frTest = parse_test_file(fnametest);
            pred = model.score(frTest);
            mm = hex.ModelMetrics.getFromDKV(model, frTest);
        } else {
            pred = model.score(frTrain);
            mm = hex.ModelMetrics.getFromDKV(model, frTrain);
        }
        // Used in log messages so they name the frame mm was actually computed on.
        final String mmLabel = onTestSet ? "Test" : "Training";

        // A separately parsed copy of the training file, used for logging and the POJO check.
        test = parse_test_file(fnametrain);
        res = model.score(test);
        if (classification) {
            // Check the CM of the scored frame
            assertTrue("Expected: " + Arrays.deepToString(expCM) + ", Got: " + Arrays.deepToString(mm.cm()._cm), Arrays.deepEquals(mm.cm()._cm, expCM));
            String[] cmDom = model._output._domains[model._output._domains.length - 1];
            Assert.assertArrayEquals("CM domain differs!", expRespDom, cmDom);
            Log.info("\n" + mmLabel + " CM:\n" + mm.cm().toASCII());
            Log.info("\nTraining CM:\n" + hex.ModelMetrics.getFromDKV(model, test).cm().toASCII());
        } else {
            assertTrue("Expected: " + expMSE + ", Got: " + mm.mse(), MathUtils.compare(expMSE, mm.mse(), 1e-8, 1e-8));
            Log.info("\n" + mmLabel + " MSE: " + mm.mse());
            Log.info("\nTraining MSE: " + hex.ModelMetrics.getFromDKV(model, test).mse());
        }
        // Build a POJO, validate same results
        assertTrue(model.testJavaScoring(test, res, 1e-5));
    } finally {
        if (frTrain != null)
            frTrain.remove();
        if (frTest != null)
            frTest.remove();
        // Remove the model
        if (model != null)
            model.delete();
        if (pred != null)
            pred.delete();
        if (test != null)
            test.delete();
        if (res != null)
            res.delete();
        Scope.exit();
    }
}
Also used : Frame(water.fvec.Frame) NFSFileVec(water.fvec.NFSFileVec) Vec(water.fvec.Vec) DeepLearningParameters(hex.deeplearning.DeepLearningModel.DeepLearningParameters) hex(hex)

Example 49 with DeepLearningParameters

Use of hex.deeplearning.DeepLearningModel.DeepLearningParameters in project h2o-3 by h2oai.

Class DeepLearningTest, method testCategoricalEncodingRegressionHuber.

/**
 * Trains a Huber-loss DeepLearning regressor on titanic_alt with the "age" column as response.
 *
 * <p>The model uses Binary categorical encoding and 3-fold cross-validation. Validation runs on the
 * training frame itself, so the training and validation deviances are asserted to be equal.
 */
@Test
public void testCategoricalEncodingRegressionHuber() {
    Frame titanic = null;
    DeepLearningModel model = null;
    try {
        final String response = "age";
        titanic = parse_test_file("./smalldata/junit/titanic_alt.csv");
        // A binary response column is replaced by its categorical version; otherwise it is kept as-is.
        if (titanic.vec(response).isBinary()) {
            Vec original = titanic.remove(response);
            Vec asCategorical = original.toCategoricalVec();
            titanic.add(response, asCategorical);
            original.remove();
        }
        DKV.put(titanic);

        DeepLearningParameters p = new DeepLearningParameters();
        p._train = titanic._key;
        p._valid = titanic._key;
        p._response_column = response;
        p._reproducible = true;
        p._hidden = new int[] { 20, 20 };
        p._seed = 0xdecaf;
        p._nfolds = 3;
        p._distribution = huber;
        p._categorical_encoding = Model.Parameters.CategoricalEncodingScheme.Binary;

        model = new DeepLearning(p).trainModel().get();

        Assert.assertEquals(87.26206135855,
            ((ModelMetricsRegression) model._output._training_metrics)._mean_residual_deviance, 1e-4);
        Assert.assertEquals(87.26206135855,
            ((ModelMetricsRegression) model._output._validation_metrics)._mean_residual_deviance, 1e-4);
        Assert.assertEquals(117.8014,
            ((ModelMetricsRegression) model._output._cross_validation_metrics)._mean_residual_deviance, 1e-4);
        // Cell (3, 0) of the CV metrics summary table, checked with a much looser tolerance.
        String summaryCell = (String) model._output._cross_validation_metrics_summary.get(3, 0);
        Assert.assertEquals(117.8014, Double.parseDouble(summaryCell), 1);
    } finally {
        if (titanic != null) {
            titanic.remove();
        }
        if (model != null) {
            model.deleteCrossValidationModels();
            model.delete();
        }
    }
}
Also used : Frame(water.fvec.Frame) NFSFileVec(water.fvec.NFSFileVec) Vec(water.fvec.Vec) DeepLearningParameters(hex.deeplearning.DeepLearningModel.DeepLearningParameters) Test(org.junit.Test)

Example 50 with DeepLearningParameters

Use of hex.deeplearning.DeepLearningModel.DeepLearningParameters in project h2o-3 by h2oai.

Class DeepLearningTest, method testCrossValidation.

/**
 * Trains a 4-fold cross-validated DeepLearning model on BostonHousing (last column as response) and
 * pins both the training MSE and the cross-validation MSE.
 */
@Test
public void testCrossValidation() {
    Frame boston = null;
    DeepLearningModel model = null;
    try {
        boston = parse_test_file("./smalldata/gbm_test/BostonHousing.csv");

        DeepLearningParameters p = new DeepLearningParameters();
        p._train = boston._key;
        p._response_column = boston.lastVecName();
        p._reproducible = true;
        p._hidden = new int[] { 20, 20 };
        p._seed = 0xdecaf;
        p._nfolds = 4;

        model = new DeepLearning(p).trainModel().get();

        final double eps = 1e-6;
        Assert.assertEquals(12.959355363801334, model._output._training_metrics._MSE, eps);
        Assert.assertEquals(17.296871012606317, model._output._cross_validation_metrics._MSE, eps);
    } finally {
        if (boston != null) {
            boston.delete();
        }
        if (model != null) {
            model.deleteCrossValidationModels();
            model.delete();
        }
    }
}
Also used : Frame(water.fvec.Frame) DeepLearningParameters(hex.deeplearning.DeepLearningModel.DeepLearningParameters) Test(org.junit.Test)

Aggregations

DeepLearningParameters (hex.deeplearning.DeepLearningModel.DeepLearningParameters): 58
Frame (water.fvec.Frame): 54
Test (org.junit.Test): 52
NFSFileVec (water.fvec.NFSFileVec): 26
Vec (water.fvec.Vec): 22
hex (hex): 5
Ignore (org.junit.Ignore): 5
DistributionFamily (hex.genmodel.utils.DistributionFamily): 4
Random (java.util.Random): 4
H2OIllegalArgumentException (water.exceptions.H2OIllegalArgumentException): 3
DataInfo (hex.DataInfo): 2
File (java.io.File): 2
Key (water.Key): 2
H2OModelBuilderIllegalArgumentException (water.exceptions.H2OModelBuilderIllegalArgumentException): 2
PrettyPrint (water.util.PrettyPrint): 2
ConfusionMatrix (hex.ConfusionMatrix): 1
Distribution (hex.Distribution): 1
FrameSplitter (hex.FrameSplitter): 1
FrameTask (hex.FrameTask): 1
ModelMetricsBinomial (hex.ModelMetricsBinomial): 1