Search in sources:

Example 1 with Grid

Use of hex.grid.Grid in project h2o-3 by h2oai.

From class GBMGridTest, method testDuplicatesCarsGrid.

/**
 * Runs a GBM grid search over a hyper-space in which every parameter value is duplicated
 * and checks that all models returned by the grid share a single key, i.e. the duplicated
 * parameter combinations did not produce separate models.
 */
//@Ignore("PUBDEV-1643")
@Test
public void testDuplicatesCarsGrid() {
    Grid grid = null;
    Frame fr = null;
    Vec old = null;
    try {
        fr = parse_test_file("smalldata/junit/cars_20mpg.csv");
        // Remove unique id
        fr.remove("name").remove();
        old = fr.remove("economy");
        // Move the response to the last column
        fr.add("economy", old);
        DKV.put(fr);
        // Hyper-space in which every parameter repeats the same value.
        // A plain HashMap is used instead of double-brace initialization, which would create
        // an anonymous HashMap subclass holding a reference to this test instance.
        HashMap<String, Object[]> hyperParms = new HashMap<>();
        hyperParms.put("_distribution", new DistributionFamily[] { DistributionFamily.gaussian });
        hyperParms.put("_ntrees", new Integer[] { 5, 5 });
        hyperParms.put("_max_depth", new Integer[] { 2, 2 });
        hyperParms.put("_learn_rate", new Double[] { .1, .1 });
        // Fire off a grid search
        GBMModel.GBMParameters params = new GBMModel.GBMParameters();
        params._train = fr._key;
        params._response_column = "economy";
        Job<Grid> gs = GridSearch.startGridSearch(null, params, hyperParms);
        grid = gs.get();
        // Check that duplicate models have not been constructed
        Model[] models = grid.getModels();
        assertTrue("Number of returned models has to be > 0", models.length > 0);
        // All of them should refer to the same model; compare keys by value, not by reference
        Key<Model> modelKey = models[0]._key;
        for (Model m : models) {
            assertTrue("Number of constructed models has to be equal to 1", modelKey.equals(m._key));
        }
    } finally {
        // NOTE(review): `old` was re-added to `fr` above, so fr.remove() presumably removes it
        // a second time -- confirm the double removal is harmless.
        if (old != null) {
            old.remove();
        }
        if (fr != null) {
            fr.remove();
        }
        if (grid != null) {
            grid.remove();
        }
    }
}
Also used : Frame(water.fvec.Frame) HashMap(java.util.HashMap) Grid(hex.grid.Grid) Vec(water.fvec.Vec) Model(hex.Model) Test(org.junit.Test)

Example 2 with Grid

Use of hex.grid.Grid in project h2o-3 by h2oai.

From class KMeansGridTest, method testUserPointsCarsGrid.

/**
 * Runs a KMeans grid search over four initialization strategies (including User, which reads
 * the centers supplied via {@code _user_points}) and checks that the grid contains exactly
 * four models.
 */
@Test
public void testUserPointsCarsGrid() {
    Grid grid = null;
    Frame fr = null;
    // Three user-supplied initial centers (three rows), matching _k = 3 below
    Frame init = ArrayUtils.frame(ard(ard(5.0, 3.4, 1.5, 0.2), ard(7.0, 3.2, 4.7, 1.4), ard(6.5, 3.0, 5.8, 2.2)));
    try {
        fr = parse_test_file("smalldata/iris/iris_wheader.csv");
        // Drop the "class" column from the training frame
        fr.remove("class").remove();
        DKV.put(fr);
        // Setup hyperparameter search space: only _init varies (four values)
        HashMap<String, Object[]> hyperParms = new HashMap<>();
        hyperParms.put("_k", new Integer[] { 3 });
        hyperParms.put("_init", new KMeans.Initialization[] { KMeans.Initialization.Random, KMeans.Initialization.PlusPlus, KMeans.Initialization.User, KMeans.Initialization.Furthest });
        hyperParms.put("_seed", new Long[] { 123456789L });
        // Fire off a grid search
        KMeansModel.KMeansParameters params = new KMeansModel.KMeansParameters();
        params._train = fr._key;
        params._user_points = init._key;
        // Get the Grid for this modeling class and frame
        Job<Grid> gs = GridSearch.startGridSearch(null, params, hyperParms);
        grid = gs.get();
        // Expect one model per value of _init
        int numModels = grid.getModels().length;
        System.out.println("Grid consists of " + numModels + " models");
        assertTrue("Grid should contain one model per initialization strategy (4), but has " + numModels, numModels == 4);
    } finally {
        if (fr != null) {
            fr.remove();
        }
        if (init != null) {
            init.remove();
        }
        if (grid != null) {
            grid.remove();
        }
    }
}
Also used : Frame(water.fvec.Frame) HashMap(java.util.HashMap) Grid(hex.grid.Grid) Test(org.junit.Test)

Example 3 with Grid

Use of hex.grid.Grid in project h2o-3 by h2oai.

From class KMeansGridTest, method testDuplicatesCarsGrid.

/**
 * Runs a KMeans grid search whose hyper-space repeats the same parameter combination three
 * times and checks that all models returned by the grid share a single key, i.e. the
 * duplicated combinations did not produce separate models.
 */
//@Ignore("PUBDEV-1643")
@Test
public void testDuplicatesCarsGrid() {
    Grid grid = null;
    Frame fr = null;
    try {
        fr = parse_test_file("smalldata/iris/iris_wheader.csv");
        // Drop the "class" column from the training frame
        fr.remove("class").remove();
        DKV.put(fr);
        // Setup hyperparameter search space: every parameter repeats the same value
        HashMap<String, Object[]> hyperParms = new HashMap<>();
        hyperParms.put("_k", new Integer[] { 3, 3, 3 });
        hyperParms.put("_init", new KMeans.Initialization[] { KMeans.Initialization.Random, KMeans.Initialization.Random, KMeans.Initialization.Random });
        hyperParms.put("_seed", new Long[] { 123456789L, 123456789L, 123456789L });
        // Fire off a grid search
        KMeansModel.KMeansParameters params = new KMeansModel.KMeansParameters();
        params._train = fr._key;
        // Get the Grid for this modeling class and frame
        Job<Grid> gs = GridSearch.startGridSearch(null, params, hyperParms);
        grid = gs.get();
        // Check that duplicate models have not been constructed
        Model[] models = grid.getModels();
        assertTrue("Number of returned models has to be > 0", models.length > 0);
        // All of them should refer to the same model; compare keys by value, not by reference
        Key<Model> modelKey = models[0]._key;
        for (Model m : models) {
            assertTrue("Number of constructed models has to be equal to 1", modelKey.equals(m._key));
        }
    } finally {
        if (fr != null) {
            fr.remove();
        }
        if (grid != null) {
            grid.remove();
        }
    }
}
Also used : Frame(water.fvec.Frame) HashMap(java.util.HashMap) Grid(hex.grid.Grid) Model(hex.Model) Test(org.junit.Test)

Example 4 with Grid

Use of hex.grid.Grid in project h2o-3 by h2oai.

From class GLRMGridTest, method testMultipleGridInvocation.

/**
 * Runs the same GLRM grid search twice under one grid key and checks three things:
 * each run covers the whole hyper-space (built + failed models), the grid reports the
 * expected hyper-parameter names, and both runs yield identical model keys.
 */
@Test
public void testMultipleGridInvocation() {
    Grid<GLRMModel.GLRMParameters> grid = null;
    Frame fr = null;
    try {
        fr = parse_test_file("smalldata/iris/iris_wheader.csv");
        // Hyper-space. A plain HashMap is used instead of double-brace initialization, which
        // would create an anonymous HashMap subclass holding a reference to this test instance.
        HashMap<String, Object[]> hyperParms = new HashMap<>();
        hyperParms.put("_k", new Integer[] { 2, 4 });
        // Search over this range of the transform enum
        hyperParms.put("_transform", new DataInfo.TransformType[] { DataInfo.TransformType.NONE, DataInfo.TransformType.DEMEAN });
        // Name of used hyper parameters
        String[] hyperParamNames = hyperParms.keySet().toArray(new String[hyperParms.size()]);
        Arrays.sort(hyperParamNames);
        int hyperSpaceSize = ArrayUtils.crossProductSize(hyperParms);
        // Create default parameters
        GLRMModel.GLRMParameters params = new GLRMModel.GLRMParameters();
        params._train = fr._key;
        params._seed = 4224L;
        params._loss = GlrmLoss.Absolute;
        params._init = GlrmInitialization.SVD;
        //
        // Fire off a grid search multiple times with same key and make sure
        // that results are same
        //
        final int ITER_CNT = 2;
        // Generic array creation is not allowed in Java, hence the raw Key[][] allocation
        Key<Model>[][] modelKeys = new Key[ITER_CNT][];
        Key<Grid> gridKey = Key.make("GLRM_grid_iris" + Key.rand());
        for (int i = 0; i < ITER_CNT; i++) {
            Job<Grid> gs = GridSearch.startGridSearch(gridKey, params, hyperParms);
            grid = (Grid<GLRMModel.GLRMParameters>) gs.get();
            modelKeys[i] = grid.getModelKeys();
            // Make sure number of produced models match size of specified hyper space
            Assert.assertEquals("Size of grid should match to size of hyper space", hyperSpaceSize, grid.getModelCount() + grid.getFailureCount());
            //
            // Make sure that names of used parameters match
            //
            String[] gridHyperNames = grid.getHyperNames();
            Arrays.sort(gridHyperNames);
            Assert.assertArrayEquals("Hyper parameters names should match!", hyperParamNames, gridHyperNames);
        }
        Assert.assertArrayEquals("The model keys should be same between two iterations!", modelKeys[0], modelKeys[1]);
    } finally {
        if (fr != null) {
            fr.remove();
        }
        // Both iterations used the same gridKey, so removing the last retrieved grid covers both runs
        if (grid != null) {
            grid.remove();
        }
    }
}
Also used : DataInfo(hex.DataInfo) Frame(water.fvec.Frame) HashMap(java.util.HashMap) Grid(hex.grid.Grid) Model(hex.Model) Key(water.Key) Test(org.junit.Test)

Example 5 with Grid

Use of hex.grid.Grid in project h2o-3 by h2oai.

From class GLRMGridTest, method testGridAppend.

/**
 * Runs two GLRM grid searches under the same grid key with different {@code _k} values and
 * checks four things: the second run appends to the first, the combined model + failure count
 * equals the sum of both hyper-space sizes, the hyper-parameter names are consistent, and all
 * model keys are unique (PUBDEV-2633).
 */
@Test
public void testGridAppend() {
    Grid<GLRMModel.GLRMParameters> grid = null;
    Frame fr = null;
    try {
        fr = parse_test_file("smalldata/iris/iris_wheader.csv");
        // Hyper-space. A plain HashMap is used instead of double-brace initialization, which
        // would create an anonymous HashMap subclass holding a reference to this test instance.
        HashMap<String, Object[]> hyperParms = new HashMap<>();
        hyperParms.put("_k", new Integer[] { 2, 4 });
        // Search over this range of the transform enum
        hyperParms.put("_transform", new DataInfo.TransformType[] { DataInfo.TransformType.NONE, DataInfo.TransformType.DEMEAN });
        // Name of used hyper parameters
        final String[] hyperParamNames1 = hyperParms.keySet().toArray(new String[hyperParms.size()]);
        Arrays.sort(hyperParamNames1);
        final int hyperSpaceSize1 = ArrayUtils.crossProductSize(hyperParms);
        // Create default parameters
        GLRMModel.GLRMParameters params = new GLRMModel.GLRMParameters();
        params._train = fr._key;
        params._seed = 4224L;
        params._loss = GlrmLoss.Absolute;
        params._init = GlrmInitialization.SVD;
        //
        // Fire off a grid search two times with the same key and make sure
        // that the final grid contains all models from both runs.
        //
        Key<Grid> gridKey = Key.make("GLRM_grid_iris" + Key.rand());
        // 1st iteration
        final Job<Grid> gs1 = GridSearch.startGridSearch(gridKey, params, hyperParms);
        grid = (Grid<GLRMModel.GLRMParameters>) gs1.get();
        // Make sure number of produced models match size of specified hyper space
        Assert.assertEquals("Size of grid should match to size of hyper space", hyperSpaceSize1, grid.getModelCount() + grid.getFailureCount());
        // Make sure that names of used parameters match
        String[] gridHyperNames1 = grid.getHyperNames();
        Arrays.sort(gridHyperNames1);
        Assert.assertArrayEquals("Hyper parameters names should match!", hyperParamNames1, gridHyperNames1);
        // 2nd iteration: a _k value not used in the first run, so new models are appended
        hyperParms.put("_k", new Integer[] { 3 });
        final String[] hyperParamNames2 = hyperParms.keySet().toArray(new String[hyperParms.size()]);
        Arrays.sort(hyperParamNames2);
        final int hyperSpaceSize2 = ArrayUtils.crossProductSize(hyperParms);
        Assert.assertArrayEquals("Names of hyperspaces should be same!", hyperParamNames1, hyperParamNames2);
        final Job<Grid> gs2 = GridSearch.startGridSearch(gridKey, params, hyperParms);
        grid = (Grid<GLRMModel.GLRMParameters>) gs2.get();
        // The grid now holds the results of both runs
        Assert.assertEquals("Size of grid should match to size of hyper space", hyperSpaceSize1 + hyperSpaceSize2, grid.getModelCount() + grid.getFailureCount());
        // Make sure that names of used parameters match
        String[] gridHyperNames2 = grid.getHyperNames();
        Arrays.sort(gridHyperNames2);
        Assert.assertArrayEquals("Hyper parameters names should match!", hyperParamNames2, gridHyperNames2);
        // Verify PUBDEV-2633 - unique names of models
        Set<String> modelNames = new HashSet<>(grid.getModelCount());
        for (Key<Model> modelKey : grid.getModelKeys()) {
            modelNames.add(modelKey.toString());
        }
        Assert.assertEquals("Model names should be unique!", grid.getModelCount(), modelNames.size());
    } finally {
        if (fr != null) {
            fr.remove();
        }
        if (grid != null) {
            grid.remove();
        }
    }
}
Also used : DataInfo(hex.DataInfo) Frame(water.fvec.Frame) HashMap(java.util.HashMap) Grid(hex.grid.Grid) Model(hex.Model) HashSet(java.util.HashSet) Test(org.junit.Test)

Aggregations

Usage counts: Grid (hex.grid.Grid): 15; HashMap (java.util.HashMap): 12; Frame (water.fvec.Frame): 12; Model (hex.Model): 11; Test (org.junit.Test): 11; Vec (water.fvec.Vec): 6; ArrayList (java.util.ArrayList): 3; Random (java.util.Random): 3; Set (java.util.Set): 3; DataInfo (hex.DataInfo): 2; Key (water.Key): 2; ModelBuilder (hex.ModelBuilder): 1; DeepLearning (hex.deeplearning.DeepLearning): 1; DeepLearningModel (hex.deeplearning.DeepLearningModel): 1; GLM (hex.glm.GLM): 1; GLMModel (hex.glm.GLMModel): 1; GridSearch (hex.grid.GridSearch): 1; DRFParametersV3 (hex.schemas.DRFV3.DRFParametersV3): 1; DeepLearningParametersV3 (hex.schemas.DeepLearningV3.DeepLearningParametersV3): 1; GBMParametersV3 (hex.schemas.GBMV3.GBMParametersV3): 1