
Example 11 with DeepLearningModel

Use of hex.deeplearning.DeepLearningModel in project h2o-2 by h2oai.

From the class DeepLearningAutoEncoderCategoricalTest, method run.

@Test
public void run() {
    long seed = 0xDECAF;
    Key file_train = NFSFileVec.make(find_test_file(PATH));
    Frame train = ParseDataset2.parse(Key.make(), new Key[] { file_train });
    DeepLearning p = new DeepLearning();
    p.source = train;
    p.autoencoder = true;
    p.response = train.lastVec();
    p.seed = seed;
    p.hidden = new int[] { 100, 50, 20 };
    //    p.ignored_cols = new int[]{0,1,2,3,6,7,8,10}; //Optional: ignore all categoricals
    //    p.ignored_cols = new int[]{4,5,9}; //Optional: ignore all numericals
    p.adaptive_rate = true;
    p.l1 = 1e-4;
    p.activation = DeepLearning.Activation.Tanh;
    p.train_samples_per_iteration = -1;
    p.loss = DeepLearning.Loss.MeanSquare;
    p.epochs = 2;
    //    p.shuffle_training_data = true;
    p.force_load_balance = true;
    p.score_training_samples = 0;
    p.score_validation_samples = 0;
    //    p.reproducible = true;
    p.invoke();
    // Verification of results
    StringBuilder sb = new StringBuilder();
    sb.append("Verifying results.\n");
    DeepLearningModel mymodel = UKV.get(p.dest());
    sb.append("Reported mean reconstruction error: " + mymodel.mse() + "\n");
    // Training data
    // Reconstruct data using the same helper functions and verify that self-reported MSE agrees
    final Frame l2 = mymodel.scoreAutoEncoder(train);
    final Vec l2vec = l2.anyVec();
    sb.append("Actual   mean reconstruction error: " + l2vec.mean() + "\n");
    // print stats and potential outliers
    // Flag rows whose error exceeds the (1 - 5/N) quantile, i.e. roughly the 5 worst-reconstructed rows
    double quantile = 1 - 5. / train.numRows();
    sb.append("The following training points are reconstructed with an error above the " + quantile * 100 + "-th percentile - potential \"outliers\" in training data.\n");
    double thresh = mymodel.calcOutlierThreshold(l2vec, quantile);
    for (long i = 0; i < l2vec.length(); i++) {
        if (l2vec.at(i) > thresh) {
            sb.append(String.format("row %d : l2vec error = %5f\n", i, l2vec.at(i)));
        }
    }
    Log.info(sb.toString());
    Assert.assertEquals(mymodel.mse(), l2vec.mean(), 1e-8);
    // Create reconstruction
    Log.info("Creating full reconstruction.");
    final Frame recon_train = mymodel.score(train);
    // cleanup
    recon_train.delete();
    train.delete();
    p.delete();
    mymodel.delete();
    l2.delete();
}
Also used: Frame (water.fvec.Frame), NFSFileVec (water.fvec.NFSFileVec), Vec (water.fvec.Vec), DeepLearning (hex.deeplearning.DeepLearning), Key (water.Key), DeepLearningModel (hex.deeplearning.DeepLearningModel), Test (org.junit.Test)
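The verification step in Example 11 compares the model's reported mean reconstruction error with the mean of the per-row errors, then flags rows whose error exceeds the (1 - 5/N) quantile, so roughly the five worst-reconstructed rows are printed. The plain-Java sketch below illustrates that arithmetic without H2O. It is a hypothetical illustration, not H2O's code: `ReconstructionOutliers`, `rowError`, `mean` and `threshold` are invented names, rows are assumed to be already-numeric arrays (H2O may also standardize numeric columns and one-hot encode categoricals before measuring the error), and the nearest-rank quantile rule is an assumption that may differ from what `calcOutlierThreshold` actually does.

```java
import java.util.Arrays;

// Hypothetical sketch, not H2O's implementation: rows are plain numeric arrays,
// whereas H2O's autoencoder may also standardize numerics and expand categoricals.
class ReconstructionOutliers {

    // Per-row reconstruction error: mean squared difference between a row and its reconstruction.
    static double rowError(double[] input, double[] recon) {
        double sum = 0;
        for (int j = 0; j < input.length; j++) {
            double d = input[j] - recon[j];
            sum += d * d;
        }
        return sum / input.length;
    }

    // Average of the per-row errors; the test asserts this equals the model's reported MSE.
    static double mean(double[] errors) {
        double sum = 0;
        for (double e : errors) {
            sum += e;
        }
        return sum / errors.length;
    }

    // Nearest-rank empirical quantile (an assumed rule; H2O may interpolate differently).
    static double threshold(double[] errors, double q) {
        double[] sorted = errors.clone();
        Arrays.sort(sorted);
        int idx = (int) Math.ceil(q * sorted.length) - 1;
        idx = Math.max(0, Math.min(idx, sorted.length - 1));
        return sorted[idx];
    }

    public static void main(String[] args) {
        // Made-up per-row errors for 20 rows: 1.0, 2.0, ..., 20.0.
        double[] errors = new double[20];
        for (int i = 0; i < errors.length; i++) {
            errors[i] = i + 1;
        }

        // Same rule as the test: quantile = 1 - 5/N, so roughly the 5 worst rows exceed it.
        double quantile = 1 - 5.0 / errors.length;
        double thresh = threshold(errors, quantile);
        System.out.printf("mean error = %.4f, threshold = %.4f%n", mean(errors), thresh);
        for (int i = 0; i < errors.length; i++) {
            if (errors[i] > thresh) {
                System.out.printf("row %d : error = %5f%n", i, errors[i]);
            }
        }
    }
}
```

With the made-up errors 1.0 through 20.0, the quantile is 0.75, the threshold is 15.0, and rows 15 to 19 are flagged. The strict `>` comparison, as in the test, means a row whose error equals the threshold is not reported.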

Example 12 with DeepLearningModel

Use of hex.deeplearning.DeepLearningModel in project h2o-3 by h2oai.

From the class XValPredictionsCheck, method testXValPredictions.

@Test
public void testXValPredictions() {
    final int nfolds = 3;
    Frame tfr = null;
    try {
        // Load data and append a fold-assignment column (3 folds, fixed seed)
        tfr = parse_test_file("smalldata/iris/iris_wheader.csv");
        Frame foldId = new Frame(new String[] { "foldId" }, new Vec[] { AstKFold.kfoldColumn(tfr.vec("class").makeZero(), nfolds, 543216789) });
        tfr.add(foldId);
        DKV.put(tfr);
        // GBM
        GBMModel.GBMParameters parms = new GBMModel.GBMParameters();
        parms._train = tfr._key;
        parms._response_column = "class";
        parms._ntrees = 1;
        parms._max_depth = 1;
        parms._fold_column = "foldId";
        parms._distribution = DistributionFamily.multinomial;
        parms._keep_cross_validation_predictions = true;
        GBM job = new GBM(parms);
        GBMModel gbm = job.trainModel().get();
        checkModel(gbm, foldId.anyVec(), 3);
        // DRF
        DRFModel.DRFParameters parmsDRF = new DRFModel.DRFParameters();
        parmsDRF._train = tfr._key;
        parmsDRF._response_column = "class";
        parmsDRF._ntrees = 1;
        parmsDRF._max_depth = 1;
        parmsDRF._fold_column = "foldId";
        parmsDRF._distribution = DistributionFamily.multinomial;
        parmsDRF._keep_cross_validation_predictions = true;
        DRF drfJob = new DRF(parmsDRF);
        DRFModel drf = drfJob.trainModel().get();
        checkModel(drf, foldId.anyVec(), 3);
        // GLM
        GLMModel.GLMParameters parmsGLM = new GLMModel.GLMParameters();
        parmsGLM._train = tfr._key;
        parmsGLM._response_column = "sepal_len";
        parmsGLM._fold_column = "foldId";
        parmsGLM._keep_cross_validation_predictions = true;
        GLM glmJob = new GLM(parmsGLM);
        GLMModel glm = glmJob.trainModel().get();
        checkModel(glm, foldId.anyVec(), 1);
        // DL
        DeepLearningModel.DeepLearningParameters parmsDL = new DeepLearningModel.DeepLearningParameters();
        parmsDL._train = tfr._key;
        parmsDL._response_column = "class";
        parmsDL._hidden = new int[] { 1 };
        parmsDL._epochs = 1;
        parmsDL._fold_column = "foldId";
        parmsDL._keep_cross_validation_predictions = true;
        DeepLearning dlJob = new DeepLearning(parmsDL);
        DeepLearningModel dl = dlJob.trainModel().get();
        checkModel(dl, foldId.anyVec(), 3);
    } finally {
        if (tfr != null)
            tfr.remove();
    }
}
Also used: Frame (water.fvec.Frame), DRFModel (hex.tree.drf.DRFModel), GLMModel (hex.glm.GLMModel), GLM (hex.glm.GLM), DeepLearning (hex.deeplearning.DeepLearning), GBMModel (hex.tree.gbm.GBMModel), GBM (hex.tree.gbm.GBM), DRF (hex.tree.drf.DRF), DeepLearningModel (hex.deeplearning.DeepLearningModel), Test (org.junit.Test)
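Example 12 builds a `foldId` column so that GBM, DRF, GLM and Deep Learning all share the same fold assignment, and `_keep_cross_validation_predictions` asks each builder to keep its holdout predictions. The idea can be sketched in plain Java: each row gets a fold id in [0, nfolds) from a seeded RNG, and each row's holdout prediction comes from the model trained on the other folds. This is a hypothetical illustration, not H2O's implementation: `FoldColumnSketch`, `kfoldColumn` (a stand-in for `AstKFold.kfoldColumn`, whose exact sampling scheme may differ) and `holdoutPredictions` are invented names, and the "model" is simply the mean response of its training rows.

```java
import java.util.Arrays;
import java.util.Random;

// Hypothetical sketch of fold-column cross-validation, not H2O's code.
class FoldColumnSketch {

    // Assigns each row a fold id in [0, nfolds) with a seeded RNG (assumed scheme;
    // AstKFold.kfoldColumn may draw ids differently, e.g. chunk by chunk).
    static int[] kfoldColumn(int nrows, int nfolds, long seed) {
        Random rng = new Random(seed);
        int[] fold = new int[nrows];
        for (int i = 0; i < nrows; i++) {
            fold[i] = rng.nextInt(nfolds);
        }
        return fold;
    }

    // Row i takes its prediction from the model trained without fold[i].
    // The stand-in "model k" predicts the mean response of the rows outside fold k.
    static double[] holdoutPredictions(double[] y, int[] fold, int nfolds) {
        double[] pred = new double[y.length];
        for (int k = 0; k < nfolds; k++) {
            double sum = 0;
            int n = 0;
            for (int i = 0; i < y.length; i++) {
                if (fold[i] != k) {
                    sum += y[i];
                    n++;
                }
            }
            double model = n > 0 ? sum / n : 0;
            for (int i = 0; i < y.length; i++) {
                if (fold[i] == k) {
                    pred[i] = model;
                }
            }
        }
        return pred;
    }

    public static void main(String[] args) {
        int nfolds = 3;
        int[] fold = kfoldColumn(150, nfolds, 543216789L); // 150 rows, the size of iris
        int[] counts = new int[nfolds];
        for (int f : fold) {
            counts[f]++;
        }
        System.out.println("rows per fold: " + Arrays.toString(counts));

        double[] y = {1, 2, 3, 4, 5, 6};
        int[] small = {0, 1, 2, 0, 1, 2};
        System.out.println("holdout predictions: " + Arrays.toString(holdoutPredictions(y, small, nfolds)));
    }
}
```

Every row receives exactly one holdout prediction, from a model that never saw it. Because the fold ids come from a fixed seed, rerunning with the same seed reproduces the same partition, which keeps a test like this deterministic.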

Aggregations

DeepLearningModel (hex.deeplearning.DeepLearningModel): 12
Frame (water.fvec.Frame): 10
DeepLearning (hex.deeplearning.DeepLearning): 9
Test (org.junit.Test): 6
Random (java.util.Random): 4
Key (water.Key): 4
Vec (water.fvec.Vec): 3
Neurons (hex.deeplearning.Neurons): 2
GLM (hex.glm.GLM): 2
GLMModel (hex.glm.GLMModel): 2
DRF (hex.tree.drf.DRF): 2
DRFModel (hex.tree.drf.DRFModel): 2
GBM (hex.tree.gbm.GBM): 2
GBMModel (hex.tree.gbm.GBMModel): 2
NFSFileVec (water.fvec.NFSFileVec): 2
NeuralNet (hex.NeuralNet): 1
Grid (hex.grid.Grid): 1
GridSearch (hex.grid.GridSearch): 1
DRFParametersV3 (hex.schemas.DRFV3.DRFParametersV3): 1
DeepLearningParametersV3 (hex.schemas.DeepLearningV3.DeepLearningParametersV3): 1