Search in sources:

Example 6 with GLMModel

Use of hex.glm.GLMModel in project h2o-2 by h2oai.

Class GLMMakeModel, method serve.

/**
 * Builds a new GLM model from the source {@code model}, but with user-supplied coefficient values.
 *
 * <p>{@code names} is a comma-separated list of coefficient names and {@code beta} holds one value
 * per name. If the names match the model's coefficient names exactly (same order, same count),
 * {@code beta} is used as-is. Otherwise each value is scattered into the position of its named
 * coefficient; coefficients that are not mentioned keep the initial value of the freshly
 * allocated array. The new model is stored under a freshly generated {@code destination_key}.
 *
 * @return {@code Response.done(this)} on success, or {@code Response.error(t)} for any failure,
 *     including a names/beta length mismatch or an unknown coefficient name
 */
@Override
protected Response serve() {
    try {
        String[] ns = names.split(",");
        double[] b;
        if (beta.length == model.coefficients_names.length && Arrays.equals(ns, model.coefficients_names)) {
            // Names already match the model's coefficient order: use beta directly.
            b = beta;
        } else {
            // Every name must have exactly one value; otherwise values would be dropped
            // (too few names) or read past the end of beta (too many names).
            if (ns.length != beta.length)
                throw new IllegalArgumentException("Number of coefficient names (" + ns.length
                        + ") does not match number of beta values (" + beta.length + ")");
            b = MemoryManager.malloc8d(model.coefficients_names.length);
            // Map each model coefficient name to its index so the supplied values can be reordered.
            HashMap<String, Integer> map = new HashMap<String, Integer>();
            for (int i = 0; i < model.coefficients_names.length; ++i) map.put(model.coefficients_names[i], i);
            for (int i = 0; i < ns.length; ++i) {
                String s = ns[i];
                Integer idx = map.get(s);
                if (idx == null)
                    throw new IllegalArgumentException("Unknown coefficient " + s);
                b[idx] = beta[i];
            }
        }
        destination_key = Key.make();
        GLMModel m = new GLMModel(model.get_params(), destination_key, model._dataKey, model.getParams(), model.coefficients_names, b, model.dinfo(), threshold);
        // Publish the model under its key: lock (deleting any prior value), then release the lock.
        m.delete_and_lock(null).unlock(null);
        return Response.done(this);
    } catch (Throwable t) {
        return Response.error(t);
    }
}
Also used : GLMModel(hex.glm.GLMModel) HashMap(java.util.HashMap)

Example 7 with GLMModel

Use of hex.glm.GLMModel in project h2o-3 by h2oai.

Class TestCase, method execute.

/**
 * Executes this accuracy test case and returns its result.
 *
 * <p>Without {@code grid}: trains a single model of type {@code algo} ("drf", "glm", "gbm" or
 * "dl") from {@code params} and times the training call. With {@code grid}: runs a grid search over
 * {@code hyperParms} and selects the model whose validation metric named by
 * {@code modelSelectionCriteria} is best (higher or lower, per {@code higherIsBetter}).
 * Trained models (or the grid) are deleted in a {@code finally} block; the test-case frames are
 * removed before the result is built.
 *
 * <p>When {@code params._nfolds > 0} the result's second metrics come from cross-validation,
 * otherwise from validation.
 *
 * @return the result, including training time and the chosen model's parameters as JSON
 * @throws Exception wrapping any failure while training or selecting a model, including an
 *     unsupported {@code algo} or a grid in which no model produced a usable score
 */
public TestCaseResult execute() throws Exception, AssertionError {
    loadTestCaseDataSets();
    makeModelParameters();
    double startTime = 0, stopTime = 0;
    if (!grid) {
        Model.Output modelOutput = null;
        DRF drfJob;
        DRFModel drfModel = null;
        GLM glmJob;
        GLMModel glmModel = null;
        GBM gbmJob;
        GBMModel gbmModel = null;
        DeepLearning dlJob;
        DeepLearningModel dlModel = null;
        String bestModelJson = null;
        try {
            switch (algo) {
                case "drf":
                    drfJob = new DRF((DRFModel.DRFParameters) params);
                    AccuracyTestingSuite.summaryLog.println("Training DRF model.");
                    startTime = System.currentTimeMillis();
                    drfModel = drfJob.trainModel().get();
                    stopTime = System.currentTimeMillis();
                    modelOutput = drfModel._output;
                    bestModelJson = drfModel._parms.toJsonString();
                    break;
                case "glm":
                    glmJob = new GLM((GLMModel.GLMParameters) params, Key.<GLMModel>make("GLMModel"));
                    AccuracyTestingSuite.summaryLog.println("Training GLM model.");
                    startTime = System.currentTimeMillis();
                    glmModel = glmJob.trainModel().get();
                    stopTime = System.currentTimeMillis();
                    modelOutput = glmModel._output;
                    bestModelJson = glmModel._parms.toJsonString();
                    break;
                case "gbm":
                    gbmJob = new GBM((GBMModel.GBMParameters) params);
                    AccuracyTestingSuite.summaryLog.println("Training GBM model.");
                    startTime = System.currentTimeMillis();
                    gbmModel = gbmJob.trainModel().get();
                    stopTime = System.currentTimeMillis();
                    modelOutput = gbmModel._output;
                    bestModelJson = gbmModel._parms.toJsonString();
                    break;
                case "dl":
                    dlJob = new DeepLearning((DeepLearningModel.DeepLearningParameters) params);
                    AccuracyTestingSuite.summaryLog.println("Training DL model.");
                    startTime = System.currentTimeMillis();
                    dlModel = dlJob.trainModel().get();
                    stopTime = System.currentTimeMillis();
                    modelOutput = dlModel._output;
                    bestModelJson = dlModel._parms.toJsonString();
                    break;
                default:
                    // Without this, modelOutput stays null and the failure surfaces later as an NPE.
                    throw new IllegalArgumentException("Unsupported algo: " + algo);
            }
        } catch (Exception e) {
            throw new Exception(e);
        } finally {
            if (drfModel != null) {
                drfModel.delete();
            }
            if (glmModel != null) {
                glmModel.delete();
            }
            if (gbmModel != null) {
                gbmModel.delete();
            }
            if (dlModel != null) {
                dlModel.delete();
            }
        }
        // NOTE(review): not reached when training throws, so the test-case frames are not removed on failure.
        removeTestCaseDataSetFrames();
        // NOTE(review): metrics are read from modelOutput after its model was deleted above; this relies
        // on the in-memory Output object staying usable — confirm.
        if (params._nfolds > 0) {
            return new TestCaseResult(testCaseId, getMetrics(modelOutput._training_metrics), getMetrics(modelOutput._cross_validation_metrics), stopTime - startTime, bestModelJson, this, trainingDataSet, testingDataSet);
        } else {
            return new TestCaseResult(testCaseId, getMetrics(modelOutput._training_metrics), getMetrics(modelOutput._validation_metrics), stopTime - startTime, bestModelJson, this, trainingDataSet, testingDataSet);
        }
    } else {
        assert !modelSelectionCriteria.equals("");
        makeGridParameters();
        makeSearchCriteria();
        // Named gridResult so it does not shadow the boolean field `grid` tested above.
        Grid gridResult = null;
        Model bestModel = null;
        String bestModelJson = null;
        try {
            SchemaServer.registerAllSchemasIfNecessary();
            // TODO: Hack for PUBDEV-2812 — instantiate the builder and its parameter schema once per
            // algo before starting the grid search.
            switch (algo) {
                case "drf":
                    if (!drfRegistered) {
                        new DRF(true);
                        new DRFParametersV3();
                        drfRegistered = true;
                    }
                    break;
                case "glm":
                    if (!glmRegistered) {
                        new GLM(true);
                        new GLMParametersV3();
                        glmRegistered = true;
                    }
                    break;
                case "gbm":
                    if (!gbmRegistered) {
                        new GBM(true);
                        new GBMParametersV3();
                        gbmRegistered = true;
                    }
                    break;
                case "dl":
                    if (!dlRegistered) {
                        new DeepLearning(true);
                        new DeepLearningParametersV3();
                        dlRegistered = true;
                    }
                    break;
            }
            startTime = System.currentTimeMillis();
            // TODO: ModelParametersBuilderFactory parameter must be instantiated properly
            Job<Grid> gs = GridSearch.startGridSearch(null, params, hyperParms, new GridSearch.SimpleParametersBuilderFactory<>(), searchCriteria);
            gridResult = gs.get();
            stopTime = System.currentTimeMillis();
            boolean higherIsBetter = higherIsBetter(modelSelectionCriteria);
            double bestScore = higherIsBetter ? -Double.MAX_VALUE : Double.MAX_VALUE;
            for (Model m : gridResult.getModels()) {
                double validationMetricScore = getMetrics(m._output._validation_metrics).get(modelSelectionCriteria);
                AccuracyTestingSuite.summaryLog.println(modelSelectionCriteria + " for model " + m._key.toString() + " is " + validationMetricScore);
                if (higherIsBetter ? validationMetricScore > bestScore : validationMetricScore < bestScore) {
                    bestScore = validationMetricScore;
                    bestModel = m;
                    bestModelJson = bestModel._parms.toJsonString();
                }
            }
            // An empty grid, or one where every score is NaN, never updates bestModel.
            if (bestModel == null) {
                throw new IllegalStateException("No model in the grid produced a usable " + modelSelectionCriteria + " score");
            }
            AccuracyTestingSuite.summaryLog.println("Best model: " + bestModel._key.toString());
            AccuracyTestingSuite.summaryLog.println("Best model parameters: " + bestModelJson);
        } catch (Exception e) {
            throw new Exception(e);
        } finally {
            if (gridResult != null) {
                gridResult.delete();
            }
        }
        // NOTE(review): not reached when the grid search throws, so the test-case frames are not removed on failure.
        removeTestCaseDataSetFrames();
        // NOTE(review): bestModel's metrics are read after the grid was deleted above — confirm the
        // in-memory Output object remains valid.
        if (params._nfolds > 0) {
            return new TestCaseResult(testCaseId, getMetrics(bestModel._output._training_metrics), getMetrics(bestModel._output._cross_validation_metrics), stopTime - startTime, bestModelJson, this, trainingDataSet, testingDataSet);
        } else {
            return new TestCaseResult(testCaseId, getMetrics(bestModel._output._training_metrics), getMetrics(bestModel._output._validation_metrics), stopTime - startTime, bestModelJson, this, trainingDataSet, testingDataSet);
        }
    }
}
Also used: Grid(hex.grid.Grid) GLM(hex.glm.GLM) DeepLearning(hex.deeplearning.DeepLearning) GBMParametersV3(hex.schemas.GBMV3.GBMParametersV3) GBM(hex.tree.gbm.GBM) GBMModel(hex.tree.gbm.GBMModel) DRFModel(hex.tree.drf.DRFModel) GLMModel(hex.glm.GLMModel) IOException(java.io.IOException) GridSearch(hex.grid.GridSearch) DeepLearningParametersV3(hex.schemas.DeepLearningV3.DeepLearningParametersV3) DeepLearningModel(hex.deeplearning.DeepLearningModel) SharedTreeModel(hex.tree.SharedTreeModel) DRF(hex.tree.drf.DRF) GLMParametersV3(hex.schemas.GLMV3.GLMParametersV3) DRFParametersV3(hex.schemas.DRFV3.DRFParametersV3)

Example 8 with GLMModel

Use of hex.glm.GLMModel in project h2o-2 by h2oai.

Class GLMPredict, method serve.

/**
 * Scores {@code data} with the GLM {@code model} (at the requested {@code lambda}) and stores the
 * predictions under the {@code prediction} key, generating a random key when none was given.
 *
 * @return a redirect to inspect the prediction frame, or {@code Response.error(t)} on any failure
 *     (including a missing model)
 */
@Override
protected Response serve() {
    try {
        if (model == null)
            throw new IllegalArgumentException("Model is required to perform validation!");
        GLMModel m = new GLMModel.GetScoringModelTask(null, model, lambda).invokeTask()._res;
        // Create a new random key
        if (prediction == null)
            prediction = Key.make("__Prediction_" + Key.make());
        // Reserve the destination key by write-locking an empty placeholder frame under it.
        new Frame(prediction, new String[0], new Vec[0]).delete_and_lock(null);
        Frame scored = m.score(data);
        // Jam the scored vecs into a frame that carries the destination key, then release the lock
        // taken on that key above.
        Frame result = new Frame(prediction, scored._names, scored.vecs());
        result.unlock(null);
        return Inspect2.redirect(this, prediction.toString());
    } catch (Throwable t) {
        return Response.error(t);
    }
}
Also used : GLMModel(hex.glm.GLMModel) Frame(water.fvec.Frame) Vec(water.fvec.Vec) RString(water.util.RString)

Example 9 with GLMModel

Use of hex.glm.GLMModel in project h2o-3 by h2oai.

Class XValPredictionsCheck, method testXValPredictions.

/**
 * Trains GBM, DRF, GLM and DeepLearning models on iris using a shared 3-fold fold column with
 * {@code _keep_cross_validation_predictions} enabled, and checks each model's cross-validation
 * predictions via {@code checkModel}. All trained models and the training frame are removed in
 * {@code finally}, so keys are not leaked even when an assertion fails.
 */
@Test
public void testXValPredictions() {
    final int nfolds = 3;
    Frame tfr = null;
    GBMModel gbm = null;
    DRFModel drf = null;
    GLMModel glm = null;
    DeepLearningModel dl = null;
    try {
        // Load data, hack frames
        tfr = parse_test_file("smalldata/iris/iris_wheader.csv");
        Frame foldId = new Frame(new String[] { "foldId" }, new Vec[] { AstKFold.kfoldColumn(tfr.vec("class").makeZero(), nfolds, 543216789) });
        // The fold column becomes part of tfr, so tfr.remove() below also cleans it up.
        tfr.add(foldId);
        DKV.put(tfr);
        // GBM
        GBMModel.GBMParameters parms = new GBMModel.GBMParameters();
        parms._train = tfr._key;
        parms._response_column = "class";
        parms._ntrees = 1;
        parms._max_depth = 1;
        parms._fold_column = "foldId";
        parms._distribution = DistributionFamily.multinomial;
        parms._keep_cross_validation_predictions = true;
        GBM job = new GBM(parms);
        gbm = job.trainModel().get();
        checkModel(gbm, foldId.anyVec(), 3);
        // DRF
        DRFModel.DRFParameters parmsDRF = new DRFModel.DRFParameters();
        parmsDRF._train = tfr._key;
        parmsDRF._response_column = "class";
        parmsDRF._ntrees = 1;
        parmsDRF._max_depth = 1;
        parmsDRF._fold_column = "foldId";
        parmsDRF._distribution = DistributionFamily.multinomial;
        parmsDRF._keep_cross_validation_predictions = true;
        DRF drfJob = new DRF(parmsDRF);
        drf = drfJob.trainModel().get();
        checkModel(drf, foldId.anyVec(), 3);
        // GLM (regression on sepal_len, hence a single prediction column)
        GLMModel.GLMParameters parmsGLM = new GLMModel.GLMParameters();
        parmsGLM._train = tfr._key;
        parmsGLM._response_column = "sepal_len";
        parmsGLM._fold_column = "foldId";
        parmsGLM._keep_cross_validation_predictions = true;
        GLM glmJob = new GLM(parmsGLM);
        glm = glmJob.trainModel().get();
        checkModel(glm, foldId.anyVec(), 1);
        // DL
        DeepLearningModel.DeepLearningParameters parmsDL = new DeepLearningModel.DeepLearningParameters();
        parmsDL._train = tfr._key;
        parmsDL._response_column = "class";
        parmsDL._hidden = new int[] { 1 };
        parmsDL._epochs = 1;
        parmsDL._fold_column = "foldId";
        parmsDL._keep_cross_validation_predictions = true;
        DeepLearning dlJob = new DeepLearning(parmsDL);
        dl = dlJob.trainModel().get();
        checkModel(dl, foldId.anyVec(), 3);
    } finally {
        // Delete the trained models first; they were previously left behind in the DKV.
        if (gbm != null)
            gbm.delete();
        if (drf != null)
            drf.delete();
        if (glm != null)
            glm.delete();
        if (dl != null)
            dl.delete();
        if (tfr != null)
            tfr.remove();
    }
}
Also used : Frame(water.fvec.Frame) DRFModel(hex.tree.drf.DRFModel) GLMModel(hex.glm.GLMModel) GLM(hex.glm.GLM) DeepLearning(hex.deeplearning.DeepLearning) GBMModel(hex.tree.gbm.GBMModel) GBM(hex.tree.gbm.GBM) DRF(hex.tree.drf.DRF) DeepLearningModel(hex.deeplearning.DeepLearningModel) Test(org.junit.Test)

Example 10 with GLMModel

Use of hex.glm.GLMModel in project h2o-3 by h2oai.

Class MakeGLMModelHandler, method make_model.

/**
 * Creates a new GLM model that copies the source model's parameters and structure, but with some
 * coefficients overridden by user-supplied values.
 *
 * <p>{@code args.names[i]} is overridden with {@code args.beta[i]}; coefficients that are not named
 * keep the source model's value. The new model uses a copy of the source model's DataInfo with the
 * predictor transform set to NONE. It is stored in the DKV under {@code args.dest} or under a fresh
 * key.
 *
 * @param version API version (unused here)
 * @param args request holding the source model key, coefficient names/values and optional destination
 * @return the schema of the newly created model
 * @throws IllegalArgumentException if the source model is missing, names and beta differ in
 *     length, or a name is not a coefficient of the source model
 */
public GLMModelV3 make_model(int version, MakeGLMModelV3 args) {
    GLMModel model = DKV.getGet(args.model.key());
    if (model == null)
        throw new IllegalArgumentException("missing source model " + args.model);
    if (args.names.length != args.beta.length)
        throw new IllegalArgumentException("number of coefficient names (" + args.names.length
                + ") does not match number of beta values (" + args.beta.length + ")");
    String[] names = model._output.coefficientNames();
    Map<String, Double> coefs = model.coefficients();
    for (int i = 0; i < args.names.length; ++i) {
        // Reject names the model does not have instead of silently ignoring them.
        if (!coefs.containsKey(args.names[i]))
            throw new IllegalArgumentException("unknown coefficient " + args.names[i]);
        coefs.put(args.names[i], args.beta[i]);
    }
    // Lay the (possibly overridden) coefficients out in the model's coefficient order.
    double[] beta = model.beta().clone();
    for (int i = 0; i < beta.length; ++i) beta[i] = coefs.get(names[i]);
    GLMModel m = new GLMModel(args.dest != null ? args.dest.key() : Key.make(), model._parms, null, model._ymu, Double.NaN, Double.NaN, -1);
    // Change the transform on a copy: mutating model.dinfo() directly would also alter the
    // source model's DataInfo (NOTE(review): assumes dinfo() returns the model's own instance).
    DataInfo dinfo = model.dinfo().clone();
    dinfo.setPredictorTransform(TransformType.NONE);
    m._output = new GLMOutput(dinfo, model._output._names, model._output._domains, names, model._output._binomial, beta);
    DKV.put(m._key, m);
    GLMModelV3 res = new GLMModelV3();
    res.fillFromImpl(m);
    return res;
}
Also used : DataInfo(hex.DataInfo) GLMOutput(hex.glm.GLMModel.GLMOutput) GLMModel(hex.glm.GLMModel)

Aggregations

GLMModel (hex.glm.GLMModel): 11, Frame (water.fvec.Frame): 5, GLM (hex.glm.GLM): 4, Test (org.junit.Test): 4, GBM (hex.tree.gbm.GBM): 3, GBMModel (hex.tree.gbm.GBMModel): 3, DeepLearning (hex.deeplearning.DeepLearning): 2, DeepLearningModel (hex.deeplearning.DeepLearningModel): 2, DRF (hex.tree.drf.DRF): 2, DRFModel (hex.tree.drf.DRFModel): 2, NFSFileVec (water.fvec.NFSFileVec): 2, DataInfo (hex.DataInfo): 1, GLMOutput (hex.glm.GLMModel.GLMOutput): 1, Grid (hex.grid.Grid): 1, GridSearch (hex.grid.GridSearch): 1, L_BFGS (hex.optimization.L_BFGS): 1, Quantile (hex.quantile.Quantile): 1, QuantileModel (hex.quantile.QuantileModel): 1, DRFParametersV3 (hex.schemas.DRFV3.DRFParametersV3): 1, DeepLearningParametersV3 (hex.schemas.DeepLearningV3.DeepLearningParametersV3): 1