Search in sources :

Example 21 with Model

use of hex.Model in project h2o-3 by h2oai.

the class FramesHandler method doFetch.

/**
 * Looks up the frame referenced by {@code s.frame_id}, wraps it in a single-element
 * {@code frames} array on the schema and, when {@code s.find_compatible_models} is set,
 * also fills in the models that are compatible with that frame.
 *
 * @param version schema version handed to {@code SchemaServer.schema} for each compatible model
 * @param s       schema carrying the request fields; it is filled in place
 * @return the same {@code s} instance that was passed in
 */
private FramesV3 doFetch(int version, FramesV3 s) {
    s.createAndFillImpl();
    // safe
    final Frame fetched = getFromDKV("key", s.frame_id.key());
    // TODO: Refactor with FrameBaseV3
    s.frames = new FrameV3[1];
    s.frames[0] = new FrameV3(fetched, s.row_offset, s.row_count).fillFromImpl(fetched, s.row_offset, s.row_count, s.column_offset, s.column_count);
    if (!s.find_compatible_models) {
        // Caller did not ask for compatible models; the frame alone is the answer.
        return s;
    }
    final Model[] matches = Frames.findCompatibleModels(fetched, Models.fetchAll());
    s.compatible_models = new ModelSchemaV3[matches.length];
    final FrameV3 frameSchema = (FrameV3) s.frames[0];
    frameSchema.compatible_models = new String[matches.length];
    // Record each compatible model twice: as a full schema on the response and as a key string on the frame.
    for (int idx = 0; idx < matches.length; idx++) {
        final Model candidate = matches[idx];
        s.compatible_models[idx] = (ModelSchemaV3) SchemaServer.schema(version, candidate).fillFromImpl(candidate);
        frameSchema.compatible_models[idx] = candidate._key.toString();
    }
    return s;
}
Also used : Frame(water.fvec.Frame) Model(hex.Model)

Example 22 with Model

use of hex.Model in project h2o-3 by h2oai.

the class GBMGridTest method testCarsGrid.

/**
 * Runs a GBM grid search over the cars data set with a hyperparameter space that mixes
 * legal and illegal learn rates, then checks that models plus failures cover the whole
 * space, that the grid reports the expected hyperparameter names, and that the built and
 * failed parameter sets match the legal and illegal parts of the space respectively.
 */
@Test
public void testCarsGrid() {
    Grid<GBMModel.GBMParameters> carsGrid = null;
    Frame cars = null;
    Vec cylinders = null;
    try {
        cars = parse_test_file("smalldata/junit/cars.csv");
        // Drop the unique id column
        cars.remove("name").remove();
        // Re-append the response as the last column, in categorical form
        cylinders = cars.remove("cylinders");
        cars.add("cylinders", cylinders.toCategoricalVec());
        DKV.put(cars);

        // Hyperparameter search space: legal learn rates plus one illegal value
        final Double[] validLearnRates = new Double[] { 0.01, 0.1, 0.3 };
        final Double[] invalidLearnRates = new Double[] { -1.0 };
        HashMap<String, Object[]> searchSpace = new HashMap<String, Object[]>() {

            {
                put("_ntrees", new Integer[] { 1, 2 });
                put("_distribution", new DistributionFamily[] { DistributionFamily.multinomial });
                put("_max_depth", new Integer[] { 1, 2, 5 });
                put("_learn_rate", ArrayUtils.join(validLearnRates, invalidLearnRates));
            }
        };

        // Sorted names of the hyperparameters in play, and the size of their cross product
        String[] expectedNames = searchSpace.keySet().toArray(new String[searchSpace.size()]);
        Arrays.sort(expectedNames);
        int expectedCombinations = ArrayUtils.crossProductSize(searchSpace);

        // Launch the grid search for this modeling class and frame
        GBMModel.GBMParameters baseParams = new GBMModel.GBMParameters();
        baseParams._train = cars._key;
        baseParams._response_column = "cylinders";
        Job<Grid> searchJob = GridSearch.startGridSearch(null, baseParams, searchSpace);
        carsGrid = (Grid<GBMModel.GBMParameters>) searchJob.get();

        // Built models and failures together must account for every point of the space
        Assert.assertEquals("Size of grid (models+failures) should match to size of hyper space", expectedCombinations, carsGrid.getModelCount() + carsGrid.getFailureCount());

        // The grid must report the same hyperparameter names
        String[] actualNames = carsGrid.getHyperNames();
        Arrays.sort(actualNames);
        Assert.assertArrayEquals("Hyper parameters names should match!", expectedNames, actualNames);

        // Collect the hyperparameter values actually used by the built models
        Key<Model>[] modelKeys = carsGrid.getModelKeys();
        Map<String, Set<Object>> builtParams = GridTestUtils.initMap(expectedNames);
        for (int k = 0; k < modelKeys.length; k++) {
            GBMModel built = (GBMModel) modelKeys[k].get();
            System.out.println(built._output._scored_train[built._output._ntrees]._mse + " " + Arrays.deepToString(ArrayUtils.zip(carsGrid.getHyperNames(), carsGrid.getHyperValues(built._parms))));
            GridTestUtils.extractParams(builtParams, built._parms, expectedNames);
        }
        // Successful models must cover exactly the legal part of the space
        searchSpace.put("_learn_rate", validLearnRates);
        GridTestUtils.assertParamsEqual("Grid models parameters have to cover specified hyper space", searchSpace, builtParams);

        // Failed parameter sets must cover exactly the illegal part of the space
        Map<String, Set<Object>> rejectedParams = GridTestUtils.initMap(expectedNames);
        for (Model.Parameters rejected : carsGrid.getFailedParameters()) {
            GridTestUtils.extractParams(rejectedParams, rejected, expectedNames);
        }
        searchSpace.put("_learn_rate", invalidLearnRates);
        GridTestUtils.assertParamsEqual("Failed model parameters have to correspond to specified hyper space", searchSpace, rejectedParams);
    } finally {
        // Clean up in the same order as before: detached vec, frame, grid
        if (cylinders != null) {
            cylinders.remove();
        }
        if (cars != null) {
            cars.remove();
        }
        if (carsGrid != null) {
            carsGrid.remove();
        }
    }
}
Also used : Frame(water.fvec.Frame) Set(java.util.Set) HashMap(java.util.HashMap) Grid(hex.grid.Grid) Vec(water.fvec.Vec) Model(hex.Model) Key(water.Key) Test(org.junit.Test)

Example 23 with Model

use of hex.Model in project h2o-3 by h2oai.

the class GBMGridTest method testRandomCarsGrid.

//@Ignore("PUBDEV-1648")
/**
 * Builds a randomly sized GBM hyperparameter grid over the cars data set, checks that the
 * number of produced models equals the size of the search space, then rebuilds one randomly
 * chosen parameter combination as a standalone GBM model and prints its training MSE.
 *
 * <p>All randomness (dimension sizes, shuffles, picks) is drawn from one seeded
 * {@link Random}, and the seed is printed up front so a failing run can be replayed.
 */
@Test
public void testRandomCarsGrid() {
    Grid grid = null;
    GBMModel gbmRebuilt = null;
    Frame fr = null;
    Vec old = null;
    try {
        fr = parse_test_file("smalldata/junit/cars.csv");
        fr.remove("name").remove();
        old = fr.remove("economy (mpg)");
        // response to last column
        fr.add("economy (mpg)", old);
        DKV.put(fr);
        // Setup random hyperparameter search space
        HashMap<String, Object[]> hyperParms = new HashMap<>();
        hyperParms.put("_distribution", new DistributionFamily[] { DistributionFamily.gaussian });
        // Construct random grid search space; print the seed first so it is visible even if the search fails
        long seed = System.nanoTime();
        Random rng = new Random(seed);
        System.out.println("Test seed: " + seed);
        int ntreesDim = rng.nextInt(4) + 1;
        int maxDepthDim = rng.nextInt(4) + 1;
        int learnRateDim = rng.nextInt(4) + 1;
        // For each dimension: shuffle the candidate values with the seeded rng and keep the first N
        Integer[] ntreesArr = interval(1, 25);
        ArrayList<Integer> ntreesList = new ArrayList<>(Arrays.asList(ntreesArr));
        Collections.shuffle(ntreesList, rng);
        Integer[] ntreesSpace = ntreesList.subList(0, ntreesDim).toArray(new Integer[0]);
        Integer[] maxDepthArr = interval(1, 10);
        ArrayList<Integer> maxDepthList = new ArrayList<>(Arrays.asList(maxDepthArr));
        Collections.shuffle(maxDepthList, rng);
        Integer[] maxDepthSpace = maxDepthList.subList(0, maxDepthDim).toArray(new Integer[0]);
        Double[] learnRateArr = interval(0.01, 1.0, 0.01);
        ArrayList<Double> learnRateList = new ArrayList<>(Arrays.asList(learnRateArr));
        Collections.shuffle(learnRateList, rng);
        Double[] learnRateSpace = learnRateList.subList(0, learnRateDim).toArray(new Double[0]);
        hyperParms.put("_ntrees", ntreesSpace);
        hyperParms.put("_max_depth", maxDepthSpace);
        hyperParms.put("_learn_rate", learnRateSpace);
        // Fire off a grid search
        GBMModel.GBMParameters params = new GBMModel.GBMParameters();
        params._train = fr._key;
        params._response_column = "economy (mpg)";
        // Get the Grid for this modeling class and frame
        Job<Grid> gs = GridSearch.startGridSearch(null, params, hyperParms);
        grid = gs.get();
        System.out.println("ntrees search space: " + Arrays.toString(ntreesSpace));
        System.out.println("max_depth search space: " + Arrays.toString(maxDepthSpace));
        System.out.println("learn_rate search space: " + Arrays.toString(learnRateSpace));
        // Check the cardinality of the grid
        Model[] ms = grid.getModels();
        int numModels = ms.length;
        System.out.println("Grid consists of " + numModels + " models");
        assertTrue(numModels == ntreesDim * maxDepthDim * learnRateDim);
        // Pick a random model from the grid: exactly one value per hyperparameter
        HashMap<String, Object[]> randomHyperParms = new HashMap<>();
        randomHyperParms.put("_distribution", new DistributionFamily[] { DistributionFamily.gaussian });
        Integer ntreeVal = ntreesSpace[rng.nextInt(ntreesSpace.length)];
        randomHyperParms.put("_ntrees", new Integer[] { ntreeVal });
        Integer maxDepthVal = maxDepthSpace[rng.nextInt(maxDepthSpace.length)];
        randomHyperParms.put("_max_depth", new Integer[] { maxDepthVal });
        Double learnRateVal = learnRateSpace[rng.nextInt(learnRateSpace.length)];
        randomHyperParms.put("_learn_rate", new Double[] { learnRateVal });
        //TODO: GBMModel gbmFromGrid = (GBMModel) g2.model(randomHyperParms).get();
        // Rebuild it with its parameters
        params._distribution = DistributionFamily.gaussian;
        params._ntrees = ntreeVal;
        params._max_depth = maxDepthVal;
        params._learn_rate = learnRateVal;
        GBM gbm = new GBM(params);
        gbmRebuilt = gbm.trainModel().get();
        assertTrue(gbm.isStopped());
        // Make sure the MSE metrics match
        //double fromGridMSE = gbmFromGrid._output._scored_train[gbmFromGrid._output._ntrees]._mse;
        double rebuiltMSE = gbmRebuilt._output._scored_train[gbmRebuilt._output._ntrees]._mse;
        //System.out.println("The random grid model's MSE: " + fromGridMSE);
        System.out.println("The rebuilt model's MSE: " + rebuiltMSE);
    //assertEquals(fromGridMSE, rebuiltMSE);
    } finally {
        if (old != null) {
            old.remove();
        }
        if (fr != null) {
            fr.remove();
        }
        if (grid != null) {
            grid.remove();
        }
        if (gbmRebuilt != null) {
            gbmRebuilt.remove();
        }
    }
}
Also used : Frame(water.fvec.Frame) HashMap(java.util.HashMap) Grid(hex.grid.Grid) ArrayList(java.util.ArrayList) Random(java.util.Random) Vec(water.fvec.Vec) Model(hex.Model) Test(org.junit.Test)

Example 24 with Model

use of hex.Model in project h2o-3 by h2oai.

the class ModelSerializationTest method testSimpleModel.

/**
 * Round-trips a minimal {@code BlahModel} through {@code saveAndLoad} and checks that the
 * reloaded model is binary-equal to the one that was saved.
 *
 * @throws IOException declared because {@code saveAndLoad} is called here
 */
@Test
public void testSimpleModel() throws IOException {
    // Build a trivial model and publish it under its own key
    final BlahModel.BlahParameters blahParams = new BlahModel.BlahParameters();
    final BlahModel.BlahOutput blahOutput = new BlahModel.BlahOutput(false, false, false);
    final Model original = new BlahModel(Key.make("BLAHModel"), blahParams, blahOutput);
    DKV.put(original._key, original);

    // Save the model, load it back, and compare the two copies
    Model reloaded = null;
    try {
        reloaded = saveAndLoad(original);
        assertModelBinaryEquals(original, reloaded);
    } finally {
        // Only the reloaded copy is deleted here, and only if loading got that far
        if (reloaded != null) {
            reloaded.delete();
        }
    }
}
Also used : GLMModel(hex.glm.GLMModel) GBMModel(hex.tree.gbm.GBMModel) Model(hex.Model) SharedTreeModel(hex.tree.SharedTreeModel) DRFModel(hex.tree.drf.DRFModel) Test(org.junit.Test)

Example 25 with Model

use of hex.Model in project h2o-3 by h2oai.

the class DRFGridTest method testRandomCarsGrid.

//@Ignore("PUBDEV-1648")
/**
 * Builds a randomly sized DRF hyperparameter grid over the cars data set, checks that built
 * models plus failures account for the whole search space, then rebuilds one randomly chosen
 * parameter combination as a standalone DRF model and prints its training MSE.
 *
 * <p>All randomness (dimension sizes, shuffles, picks) is drawn from one seeded
 * {@link Random}, and the seed is printed before the grid search starts so a failing run
 * can be replayed.
 */
@Test
public void testRandomCarsGrid() {
    Grid grid = null;
    DRFModel drfRebuilt = null;
    Frame fr = null;
    try {
        fr = parse_test_file("smalldata/junit/cars.csv");
        fr.remove("name").remove();
        // The detached vec is re-added to the frame unchanged, so fr.remove() below cleans it up
        Vec old = fr.remove("economy (mpg)");
        // response to last column
        fr.add("economy (mpg)", old);
        DKV.put(fr);
        // Setup random hyperparameter search space
        HashMap<String, Object[]> hyperParms = new HashMap<>();
        // Construct random grid search space
        long seed = System.nanoTime();
        Random rng = new Random(seed);
        System.out.println("Test seed: " + seed);
        // Limit to 1-3 randomly, 4 times.  Average total number of models is
        // 2^4, or 16.  Max is 81 models.
        int ntreesDim = rng.nextInt(3) + 1;
        int maxDepthDim = rng.nextInt(3) + 1;
        int mtriesDim = rng.nextInt(3) + 1;
        int sampleRateDim = rng.nextInt(3) + 1;
        // For each dimension: shuffle the candidate values with the seeded rng and keep the first N
        Integer[] ntreesArr = interval(1, 15);
        ArrayList<Integer> ntreesList = new ArrayList<>(Arrays.asList(ntreesArr));
        Collections.shuffle(ntreesList, rng);
        Integer[] ntreesSpace = ntreesList.subList(0, ntreesDim).toArray(new Integer[0]);
        Integer[] maxDepthArr = interval(1, 10);
        ArrayList<Integer> maxDepthList = new ArrayList<>(Arrays.asList(maxDepthArr));
        Collections.shuffle(maxDepthList, rng);
        Integer[] maxDepthSpace = maxDepthList.subList(0, maxDepthDim).toArray(new Integer[0]);
        Integer[] mtriesArr = interval(1, 5);
        ArrayList<Integer> mtriesList = new ArrayList<>(Arrays.asList(mtriesArr));
        Collections.shuffle(mtriesList, rng);
        Integer[] mtriesSpace = mtriesList.subList(0, mtriesDim).toArray(new Integer[0]);
        Double[] sampleRateArr = interval(0.01, 0.99, 0.01);
        ArrayList<Double> sampleRateList = new ArrayList<>(Arrays.asList(sampleRateArr));
        Collections.shuffle(sampleRateList, rng);
        Double[] sampleRateSpace = sampleRateList.subList(0, sampleRateDim).toArray(new Double[0]);
        hyperParms.put("_ntrees", ntreesSpace);
        hyperParms.put("_max_depth", maxDepthSpace);
        hyperParms.put("_mtries", mtriesSpace);
        hyperParms.put("_sample_rate", sampleRateSpace);
        // Fire off a grid search
        DRFModel.DRFParameters params = new DRFModel.DRFParameters();
        params._train = fr._key;
        params._response_column = "economy (mpg)";
        // Get the Grid for this modeling class and frame
        Job<Grid> gs = GridSearch.startGridSearch(null, params, hyperParms);
        grid = gs.get();
        System.out.println("ntrees search space: " + Arrays.toString(ntreesSpace));
        System.out.println("max_depth search space: " + Arrays.toString(maxDepthSpace));
        System.out.println("mtries search space: " + Arrays.toString(mtriesSpace));
        System.out.println("sample_rate search space: " + Arrays.toString(sampleRateSpace));
        // Check the cardinality of the grid: built models plus failures must cover the whole space
        Model[] ms = grid.getModels();
        int numModels = ms.length;
        System.out.println("Grid consists of " + numModels + " models");
        assertEquals("Number of models should match hyper space size", ntreesDim * maxDepthDim * sampleRateDim * mtriesDim, numModels + grid.getFailureCount());
        // Pick a random model from the grid: exactly one value per hyperparameter
        HashMap<String, Object[]> randomHyperParms = new HashMap<>();
        Integer ntreeVal = ntreesSpace[rng.nextInt(ntreesSpace.length)];
        randomHyperParms.put("_ntrees", new Integer[] { ntreeVal });
        Integer maxDepthVal = maxDepthSpace[rng.nextInt(maxDepthSpace.length)];
        randomHyperParms.put("_max_depth", new Integer[] { maxDepthVal });
        Integer mtriesVal = mtriesSpace[rng.nextInt(mtriesSpace.length)];
        randomHyperParms.put("_mtries", new Integer[] { mtriesVal });
        Double sampleRateVal = sampleRateSpace[rng.nextInt(sampleRateSpace.length)];
        randomHyperParms.put("_sample_rate", new Double[] { sampleRateVal });
        //TODO: DRFModel drfFromGrid = (DRFModel) g2.model(randomHyperParms).get();
        // Rebuild it with its parameters
        params._ntrees = ntreeVal;
        params._max_depth = maxDepthVal;
        params._mtries = mtriesVal;
        // NOTE(review): assumes DRFParameters exposes a numeric _sample_rate field, as the "_sample_rate" hyperparameter key implies
        params._sample_rate = sampleRateVal;
        drfRebuilt = new DRF(params).trainModel().get();
        // Make sure the MSE metrics match
        //double fromGridMSE = drfFromGrid._output._scored_train[drfFromGrid._output._ntrees]._mse;
        double rebuiltMSE = drfRebuilt._output._scored_train[drfRebuilt._output._ntrees]._mse;
        //System.out.println("The random grid model's MSE: " + fromGridMSE);
        System.out.println("The rebuilt model's MSE: " + rebuiltMSE);
    //assertEquals(fromGridMSE, rebuiltMSE);
    } finally {
        if (fr != null) {
            fr.remove();
        }
        if (grid != null) {
            grid.remove();
        }
        if (drfRebuilt != null) {
            drfRebuilt.remove();
        }
    }
}
Also used : Frame(water.fvec.Frame) HashMap(java.util.HashMap) Grid(hex.grid.Grid) ArrayList(java.util.ArrayList) Random(java.util.Random) Vec(water.fvec.Vec) Model(hex.Model) Test(org.junit.Test)

Aggregations

Model (hex.Model)28 Frame (water.fvec.Frame)15 Grid (hex.grid.Grid)11 HashMap (java.util.HashMap)11 Test (org.junit.Test)11 MojoModel (hex.genmodel.MojoModel)8 Vec (water.fvec.Vec)6 URI (java.net.URI)4 ArrayList (java.util.ArrayList)4 Key (water.Key)4 Persist (water.persist.Persist)4 Random (java.util.Random)3 Set (java.util.Set)3 DataInfo (hex.DataInfo)2 ModelMojoWriter (hex.ModelMojoWriter)1 GLMModel (hex.glm.GLMModel)1 SharedTreeModel (hex.tree.SharedTreeModel)1 DRFModel (hex.tree.drf.DRFModel)1 GBMModel (hex.tree.gbm.GBMModel)1 HashSet (java.util.HashSet)1