Search in sources :

Example 6 with GLMParameters

use of hex.glm.GLMModel.GLMParameters in project h2o-3 by h2oai.

the class GLMTest method testMultinomialGradient.

@Test
public void testMultinomialGradient() {
    Key parsed = Key.make("covtype");
    Frame fr = null;
    double[][] beta = new double[][] { { 5.886754459, -0.270479620, -0.075466082, -0.157524534, -0.225843747, -0.975387326, -0.018808013, -0.597839451, 0.931896624, 1.060006010, 1.513888539, 0.588802780, 0.157815155, -2.158268564, -0.504962385, -1.218970183, -0.840958642, -0.425931637, -0.355548831, -0.845035489, -0.065364107, 0.215897656, 0.213009374, 0.006831714, 1.212368946, 0.006106444, -0.350643486, -0.268207009, -0.252099054, -1.374010836, 0.257935860, 0.397459631, 0.411530391, 0.728368253, 0.292076224, 0.170774269, -0.059574793, 0.273670163, 0.180844505, -0.186483071, 0.369186813, 0.161909512, 0.249411716, -0.094481604, 0.413354360, -0.419043967, 0.044517794, -0.252596992, -0.371926422, 0.253835004, 0.588162090, 0.123330837, 2.856812217 }, { 1.89790254, -0.29776886, 0.15613197, 0.37602123, -0.36464436, -0.30240244, -0.57284370, 0.62408956, -0.22369305, 0.33644602, 0.79886400, 0.65351945, -0.53682819, -0.58319898, -1.07762513, -0.28527470, 0.46563482, -0.76956081, -0.72513805, 0.29857876, 0.03993456, 0.15835864, -0.24797599, -0.02483503, 0.93822490, -0.12406087, -0.75837978, -0.23516944, -0.48520212, 0.73571466, 0.19652011, 0.21602846, -0.32743154, 0.49421903, -0.02262943, 0.08093216, 0.11524497, 0.21657128, 0.18072853, 0.30872666, 0.17947687, 0.20156151, 0.16812179, -0.12286908, 0.29630502, 0.09992565, -0.00603293, 0.20700058, -0.49706211, -0.14534034, -0.18819217, 0.03642680, 7.31828340 }, { -6.098728943, 0.284144173, 0.114373474, 0.328977319, 0.417830082, 0.285696150, -0.652674822, 0.319136906, -0.942440279, -1.619235397, -1.272568201, -0.079855555, 1.191263550, 0.205102353, 0.991773314, 0.930363203, 1.014021007, 0.651243292, 0.646532457, 0.914336030, 0.012171754, -0.053042102, 0.777710362, 0.527369151, -0.019496049, 0.186290583, 0.554926655, 0.476911685, 0.529207520, -0.133243060, -0.198957274, -0.561552913, -0.069239959, -0.236600870, -0.969503908, -0.848089244, 0.001498592, -0.241007311, -0.129271912, -0.259961677, -0.895676033, -0.865827509, -0.972629899, 0.307756211, -1.809423763, -0.199557594, 0.024221965, -0.024834485, 0.047044475, 0.028951561, -0.157701002, 0.007940593, -2.073329675 }, { -8.36044440, 0.10541672, -0.01628680, -0.43787017, 0.42383466, 2.45802808, 0.59818831, 0.61971728, -0.62598983, 0.20261555, -0.21909545, 0.35125447, -3.29155913, 3.74668257, 0.18126128, -0.13948924, 0.20465077, -0.39930635, 0.15704570, -0.01036891, 0.02822546, -0.02349234, -0.93922249, -0.20025910, 0.25184125, 0.06415974, 0.35271290, 0.04609060, 0.03018497, -0.10641540, 0.00354805, -0.12194129, 0.05115876, 0.23981864, -0.10007012, 0.04773226, 0.01217421, 0.02367464, 0.05552397, 0.05343606, -0.05818705, -0.30055029, -0.03898723, 0.02322906, -0.04908215, 0.04274038, 0.25045428, 0.08561191, 0.15228160, 0.67005377, 0.59311621, 0.58814959, -4.83776046 }, { -0.39251919, 0.07053038, 0.09397355, 0.19394977, -0.02030732, -0.87489691, 0.21295049, 0.31800509, -0.05347208, -1.03491602, 2.20106706, -1.20895873, 1.06158893, -3.29214054, -0.69334082, 0.62309414, -1.64753442, 0.10189669, -0.44746013, -1.04084383, -0.01997483, -0.23356180, 0.34384724, 0.37566329, -1.79316510, 0.46183758, -0.58814389, 0.12072985, 0.48349078, 1.18956325, 0.41962148, 0.18767160, -0.25252495, -1.13671540, 0.71488183, 0.27405258, -0.03527945, 0.43124949, -0.28740586, 0.35165348, 1.17594079, 1.13893507, 0.49423372, 0.30525649, 0.70809680, 0.16660330, -0.37726163, -0.14687217, -0.17079711, -1.01897715, -1.17494223, -0.72698683, 1.64022531 }, { -5.892381502, 0.295534637, -0.112763568, 0.080283203, 0.197113227, 0.525435203, 0.727252262, -1.190672917, 1.137103389, -0.648526151, -2.581362158, -0.268338673, 2.010179009, 0.902074450, 0.816138328, 0.557071470, 0.389932578, 0.009422297, 0.542270816, 0.550653667, 0.005211720, -0.071954379, 0.320008238, 0.155814784, -0.264213966, 0.320538295, 0.569730803, 0.444518874, 0.247279544, -0.319484330, -0.372129988, 0.340944707, -0.158424299, -0.479426774, 0.026966661, 0.273389077, -0.004744599, -0.339321329, -0.119323949, -0.210123558, -1.218998166, -0.740525896, 0.134778587, 0.252701229, 0.527468284, 0.214164427, -0.080104361, -0.021448994, 0.004509104, -0.189729053, -0.335041198, -0.080698796, -1.192518082 }, { 12.9594170391, -0.1873774300, -0.1599625360, -0.3838368119, -0.4279825390, -1.1164727575, -0.2940645257, -0.0924364781, -0.2234047720, 1.7036099945, -0.4407937881, -0.0364237384, -0.5924593214, 1.1797487023, 0.2867554171, -0.4667946900, 0.4142538835, 0.8322365174, 0.1822980332, 0.1326797653, -0.0002045542, 0.0077943238, -0.4673767424, -0.8405848140, -0.3255599769, -0.9148717663, 0.2197967986, -0.5848745645, -0.5528616430, 0.0078757154, -0.3065382365, -0.4586101971, 0.3449315968, 0.3903371200, 0.0582787537, 0.0012089013, -0.0293189213, -0.3648369414, 0.1189047254, -0.0572478953, 0.4482567793, 0.4044976082, -0.0349286763, -0.6715923088, -0.0867185553, 0.0951677966, 0.1442048837, 0.1531401571, 0.8359504674, 0.4012062075, 0.6745982951, 0.0518378060, -3.7117127004 } };
    double[] exp_grad = new double[] { -8.955455e-05, 6.429112e-04, 4.384381e-04, 1.363695e-03, 4.714468e-04, -2.264769e-03, 4.412849e-04, 1.461760e-03, -2.957754e-05, -2.244325e-03, -2.744438e-03, 9.109376e-04, 1.920764e-03, 7.562221e-04, 1.840414e-04, 2.455081e-04, 3.077885e-04, 2.833261e-04, 1.248686e-04, 2.509248e-04, 9.681260e-06, -1.097335e-04, 1.005934e-03, 5.623159e-04, -2.568397e-03, 1.113900e-03, 1.263858e-04, 9.075801e-05, 8.056571e-05, 1.848318e-04, -1.291357e-04, -3.710570e-04, 5.693621e-05, 1.328082e-04, 3.244018e-04, 4.130594e-04, 9.681066e-06, 5.215260e-04, 4.054695e-04, 2.904901e-05, -3.074865e-03, -1.247025e-04, 1.044981e-03, 8.612937e-04, 1.376526e-03, 4.543256e-05, -4.596319e-06, 3.062111e-05, 5.649646e-05, 5.392599e-04, 9.681357e-04, 2.298219e-04, -1.369109e-03, -6.884926e-04, -9.921529e-04, -5.369346e-04, -1.732447e-03, 5.677645e-04, 1.655432e-03, -4.786890e-04, -8.688757e-04, 2.922016e-04, 3.601210e-03, 4.050781e-03, -6.409806e-04, -2.788663e-03, -1.426483e-03, -1.946904e-04, -8.279536e-04, -3.148338e-04, 2.263577e-06, -1.320917e-04, 3.635088e-04, -1.024655e-05, 1.079612e-04, -1.607591e-03, -1.801967e-04, 2.548311e-03, -1.007139e-03, -1.336990e-04, 2.538803e-04, -4.851292e-04, -9.168206e-04, 1.027708e-04, 1.061545e-03, -4.098038e-05, 1.070448e-04, 3.220238e-04, -7.011285e-04, -1.024153e-05, -7.967380e-04, -2.708138e-04, -2.698165e-04, 3.088978e-03, 4.260939e-04, -5.868815e-04, -1.562233e-03, -1.007565e-03, -2.034456e-04, -6.198011e-04, -3.277194e-05, -5.976557e-05, -1.143198e-03, -1.025416e-03, 3.671158e-04, 1.448332e-03, 1.940231e-03, -6.130695e-04, -2.086460e-03, -2.969848e-04, 1.455597e-04, 1.745515e-03, 2.123991e-03, 9.036201e-04, -5.270206e-04, 1.053891e-03, 1.358911e-03, 2.528711e-04, 1.326987e-04, -1.825879e-03, -6.085616e-04, -1.347628e-04, 3.499544e-04, 3.616313e-04, -7.008672e-04, -1.211077e-03, 1.117824e-05, 3.535679e-05, -2.668903e-03, -2.399884e-04, 3.979678e-04, 2.519517e-04, 1.113206e-04, 6.029871e-04, 3.512828e-04, 2.134159e-04, 7.590052e-05, 1.729959e-04, 4.472972e-05, 2.094373e-04, 3.136961e-04, 1.835530e-04, 1.117824e-05, 8.225263e-05, 4.330828e-05, 3.354142e-05, 7.452883e-04, 4.631413e-04, 2.054077e-04, -5.520636e-05, 2.818063e-04, 5.246077e-05, 1.131811e-04, 3.535664e-05, 6.523360e-05, 3.072416e-04, 2.913399e-04, 2.422760e-04, -1.580841e-03, -1.117356e-04, 2.573351e-04, 8.117137e-04, 1.168873e-04, -4.216143e-04, -5.847717e-05, 3.501109e-04, 2.344622e-04, -1.330097e-04, -5.948309e-04, -2.349808e-04, -4.495448e-05, -1.916493e-04, 5.017336e-04, -8.440468e-05, 4.767465e-04, 2.485018e-04, 2.060573e-04, -1.527142e-04, -9.268231e-06, -1.985972e-06, -6.285478e-06, -2.214673e-05, 5.822250e-04, -7.069316e-05, -4.387924e-05, -2.774128e-04, -5.455282e-04, 3.186328e-04, -3.793242e-05, -1.349306e-05, -3.070112e-05, -7.951882e-06, -3.723186e-05, -5.571437e-05, -3.260780e-05, -1.987225e-06, -1.462245e-05, -7.699184e-06, -5.962867e-06, -1.316053e-04, -8.108570e-05, -3.651228e-05, -5.312255e-05, -5.009791e-05, -9.325808e-06, -2.012086e-05, -6.285571e-06, -1.159698e-05, -5.462022e-05, -5.179310e-05, -4.307092e-05, 2.810360e-04, 3.869942e-04, -3.450936e-05, -7.805675e-05, 6.405561e-04, -2.284402e-04, -1.866295e-04, -4.858359e-04, 3.496890e-04, 7.352780e-04, 5.767877e-04, -8.477014e-04, -5.512698e-05, 1.091158e-03, -1.900036e-04, -4.632766e-05, 1.086153e-05, -7.743051e-05, -7.545391e-04, -3.143243e-05, -6.316374e-05, -2.435782e-06, -7.707894e-06, 4.451785e-04, 2.043479e-04, -8.673378e-05, -3.314975e-05, -3.181369e-05, -5.422704e-04, -9.020739e-05, 6.747588e-04, 5.997742e-06, -9.729086e-04, -9.751490e-06, -4.565744e-05, -4.181943e-04, 7.522183e-04, -2.436958e-06, 2.531532e-04, -9.441600e-06, 2.317743e-04, 4.254207e-04, -3.224488e-04, 3.979052e-04, 2.066697e-04, 2.486194e-05, 1.189306e-04, -2.465884e-05, -7.708071e-06, -1.422152e-05, -6.697064e-05, -6.351172e-05, -5.281060e-05, 3.446379e-04, -1.212986e-03, 9.206612e-04, 6.469824e-04, -6.605882e-04, -1.646537e-05, -6.854543e-04, -2.079925e-03, -1.031449e-03, 3.926585e-04, -1.556234e-03, -1.129748e-03, -2.113480e-04, -4.922559e-04, 1.938461e-03, 6.900824e-04, 1.497533e-04, -6.140808e-04, -3.365137e-04, 8.516225e-04, 5.874586e-04, -9.342693e-06, -2.955083e-05, 2.692614e-03, -9.928211e-04, -3.326157e-04, -3.572773e-04, 1.641113e-04, 7.442831e-05, -2.543959e-04, -1.783712e-04, -6.343638e-05, 9.077554e-05, -3.738480e-05, -1.750387e-04, -6.568480e-04, -2.035799e-04, -9.342694e-06, -6.874421e-05, -3.619677e-05, -2.803369e-05, -6.228932e-04, -3.870861e-04, -1.103792e-03, 9.585360e-04, -7.037269e-05, 2.736606e-04, -9.459508e-05, -2.955084e-05, -5.452180e-05, -2.567899e-04, -2.434930e-04, -2.024919e-04, 1.321256e-03, -2.244563e-04, -1.811758e-04, 8.043173e-04, 5.688820e-04, -5.182511e-04, -2.056167e-04, 1.290635e-04, -1.049207e-03, -7.305304e-04, -8.364983e-04, -4.528248e-04, -2.113987e-04, 3.279472e-04, 2.459491e-04, 5.986061e-05, 7.984705e-05, 1.001005e-04, 2.377746e-04, 4.061439e-05, 8.161668e-05, 3.151497e-06, 9.959707e-06, 1.549140e-04, 6.411739e-05, 1.121613e-04, 7.559378e-05, 4.110778e-05, 6.574476e-05, 7.925128e-05, 6.011770e-05, 2.139605e-05, 4.934971e-05, -5.597385e-06, -1.913622e-04, 1.706349e-04, -4.115145e-04, 3.149101e-06, 2.317293e-05, -1.246264e-04, 9.448371e-06, -4.303234e-04, 2.608783e-05, 7.889196e-05, -3.559375e-04, -5.551586e-04, -2.777131e-04, 6.505911e-04, 1.033867e-05, 1.837583e-05, 6.750772e-04, 1.247379e-04, -5.408403e-04, -4.453114e-04 };
    Vec origRes = null;
    try {
        fr = parse_test_file(parsed, "smalldata/covtype/covtype.20k.data");
        fr.remove("C21").remove();
        fr.remove("C29").remove();
        GLMParameters params = new GLMParameters(Family.multinomial);
        params._response_column = "C55";
        // params._response = fr.find(params._response_column);
        params._ignored_columns = new String[] {};
        params._train = parsed;
        params._lambda = new double[] { 0 };
        params._alpha = new double[] { 0 };
        origRes = fr.remove("C55");
        Vec res = fr.add("C55", origRes.toCategoricalVec());
        double[] means = new double[res.domain().length];
        long[] bins = res.bins();
        double sumInv = 1.0 / ArrayUtils.sum(bins);
        for (int i = 0; i < bins.length; ++i) means[i] = bins[i] * sumInv;
        DataInfo dinfo = new DataInfo(fr, null, 1, true, TransformType.STANDARDIZE, DataInfo.TransformType.NONE, true, false, false, false, false, false);
        GLMTask.GLMMultinomialGradientTask gmt = new GLMTask.GLMMultinomialGradientTask(null, dinfo, 0, beta, 1.0 / fr.numRows()).doAll(dinfo._adaptedFrame);
        assertEquals(0.6421113, gmt._likelihood / fr.numRows(), 1e-8);
        System.out.println("likelihood = " + gmt._likelihood / fr.numRows());
        double[] g = gmt.gradient();
        for (int i = 0; i < g.length; ++i) assertEquals("Mismatch at coefficient '" + "' (" + i + ")", exp_grad[i], g[i], 1e-8);
    } finally {
        if (origRes != null)
            origRes.remove();
        if (fr != null)
            fr.delete();
    }
}
Also used : GLMTask(hex.glm.GLMTask) GLMParameters(hex.glm.GLMModel.GLMParameters)

Example 7 with GLMParameters

use of hex.glm.GLMModel.GLMParameters in project h2o-3 by h2oai.

the class GLMTest method test_COD_Airlines_SingleLambda.

// test categorical autoexpansions, run on airlines which has several categorical columns,
// once on explicitly expanded data, once on h2o autoexpanded and compare the results
@Test
public void test_COD_Airlines_SingleLambda() {
    GLMModel model1 = null;
    //  Distance + Origin + Dest + UniqueCarrier
    Frame fr = parse_test_file(Key.make("Airlines"), "smalldata/airlines/AirlinesTrain.csv.zip");
    String[] ignoredCols = new String[] { "IsDepDelayed_REC" };
    try {
        Scope.enter();
        GLMParameters params = new GLMParameters(Family.binomial);
        params._response_column = "IsDepDelayed";
        params._ignored_columns = ignoredCols;
        params._train = fr._key;
        params._valid = fr._key;
        //null; //new double[]{0.02934};//{0.02934494}; // null;
        params._lambda = new double[] { 0.01 };
        params._alpha = new double[] { 1 };
        params._standardize = false;
        params._solver = Solver.COORDINATE_DESCENT_NAIVE;
        params._lambda_search = true;
        params._nlambdas = 5;
        GLM glm = new GLM(params);
        model1 = glm.trainModel().get();
        double[] beta = model1.beta();
        double l1pen = ArrayUtils.l1norm(beta, true);
        double l2pen = ArrayUtils.l2norm2(beta, true);
    //System.out.println( " lambda min " + params._l2pen[params._l2pen.length-1] );
    //System.out.println( " lambda_max " + model1._lambda_max);
    //System.out.println(" intercept " + beta[beta.length-1]);
    //      double objective = model1._output._training_metrics./model1._nobs +
    //              params._l2pen[params._l2pen.length-1]*params._alpha[0]*l1pen + params._l2pen[params._l2pen.length-1]*(1-params._alpha[0])*l2pen/2  ;
    //      System.out.println( " objective value " + objective);
    //      assertEquals(0.670921, objective,1e-4);
    } finally {
        fr.delete();
        if (model1 != null)
            model1.delete();
    }
}
Also used : GLMParameters(hex.glm.GLMModel.GLMParameters) BufferedString(water.parser.BufferedString)

Example 8 with GLMParameters

use of hex.glm.GLMModel.GLMParameters in project h2o-3 by h2oai.

the class GLMTest method testXval.

@Test
public void testXval() {
    GLMModel model = null;
    Frame fr = parse_test_file("smalldata/glm_test/prostate_cat_replaced.csv");
    try {
        GLMParameters params = new GLMParameters(Family.binomial);
        params._response_column = "CAPSULE";
        params._ignored_columns = new String[] { "ID" };
        params._train = fr._key;
        params._lambda_search = true;
        params._nfolds = 3;
        params._standardize = false;
        GLM glm = new GLM(params);
        model = glm.trainModel().get();
    } finally {
        fr.delete();
        if (model != null) {
            for (Key k : model._output._cross_validation_models) Keyed.remove(k);
            model.delete();
        }
    }
}
Also used : GLMParameters(hex.glm.GLMModel.GLMParameters)

Example 9 with GLMParameters

use of hex.glm.GLMModel.GLMParameters in project h2o-3 by h2oai.

the class GLMTest method testGammaRegression.

/**
   * Test Gamma regression on simple and small synthetic dataset.
   * Equation is: y = 1/(x+1);
   *
   * @throws ExecutionException
   * @throws InterruptedException
   */
@Test
public void testGammaRegression() throws InterruptedException, ExecutionException {
    GLMModel model = null;
    Frame fr = null, res = null;
    try {
        // make data so that the expected coefficients is icept = col[0] = 1.0
        Key raw = Key.make("gamma_test_data_raw");
        Key parsed = Key.make("gamma_test_data_parsed");
        FVecTest.makeByteVec(raw, "x,y\n0,1\n1,0.5\n2,0.3333333\n3,0.25\n4,0.2\n5,0.1666667\n6,0.1428571\n7,0.125");
        fr = ParseDataset.parse(parsed, raw);
        //      /public GLM2(String desc, Key dest, Frame src, Family family, Link link, double alpha, double lambda) {
        //      double [] vals = new double[] {1.0,1.0};
        //public GLM2(String desc, Key dest, Frame src, Family family, Link link, double alpha, double lambda) {
        GLMParameters params = new GLMParameters(Family.gamma);
        // params._response = 1;
        params._response_column = fr._names[1];
        params._train = parsed;
        params._lambda = new double[] { 0 };
        model = new GLM(params).trainModel().get();
        for (double c : model.beta()) assertEquals(1.0, c, 1e-4);
        // test scoring
        testScoring(model, fr);
    } finally {
        if (fr != null)
            fr.delete();
        if (res != null)
            res.delete();
        if (model != null)
            model.delete();
    }
}
Also used : GLMParameters(hex.glm.GLMModel.GLMParameters)

Example 10 with GLMParameters

use of hex.glm.GLMModel.GLMParameters in project h2o-3 by h2oai.

the class GLMTest method testBounds.

// Leask xval keys
//  @Test public void testXval() {
//    GLMModel model = null;
//    Frame fr = parse_test_file("smalldata/glm_test/prostate_cat_replaced.csv");
//    Frame score = null;
//    try{
//      Scope.enter();
//      // R results
////      Coefficients:
////        (Intercept)           ID          AGE       RACER2       RACER3        DPROS        DCAPS          PSA          VOL      GLEASON
////          -8.894088     0.001588    -0.009589     0.231777    -0.459937     0.556231     0.556395     0.027854    -0.011355     1.010179
//      String [] cfs1 = new String [] {"Intercept","AGE", "RACE.R2","RACE.R3", "DPROS", "DCAPS", "PSA", "VOL", "GLEASON"};
//      double [] vals = new double [] {-8.14867, -0.01368, 0.32337, -0.38028, 0.55964, 0.49548, 0.02794, -0.01104, 0.97704};
//      GLMParameters params = new GLMParameters(Family.binomial);
//      params._n_folds = 10;
//      params._response_column = "CAPSULE";
//      params._ignored_columns = new String[]{"ID"};
//      params._train = fr._key;
//      params._lambda = new double[]{0};
//      model = new GLM(params,Key.make("prostate_model")).trainModel().get();
//      HashMap<String, Double> coefs = model.coefficients();
//      for(int i = 0; i < cfs1.length; ++i)
//        assertEquals(vals[i], coefs.get(cfs1[i]),1e-4);
//      GLMValidation val = model.trainVal();
////      assertEquals(512.3, val.nullDeviance(),1e-1);
////      assertEquals(378.3, val.residualDeviance(),1e-1);
////      assertEquals(396.3, val.AIC(),1e-1);
////      score = model.score(fr);
////
////      hex.ModelMetrics mm = hex.ModelMetrics.getFromDKV(model,fr);
////
////      AUCData adata = mm._aucdata;
////      assertEquals(val.auc(),adata.AUC(),1e-2);
////      GLMValidation val2 = new GLMValidationTsk(params,model._ymu,rank(model.beta())).doAll(new Vec[]{fr.vec("CAPSULE"),score.vec("1")})._val;
////      assertEquals(val.residualDeviance(),val2.residualDeviance(),1e-6);
////      assertEquals(val.nullDeviance(),val2.nullDeviance(),1e-6);
//    } finally {
//      fr.delete();
//      if(model != null)model.delete();
//      if(score != null)score.delete();
//      Scope.exit();
//    }
//  }
/**
   * Test bounds on prostate dataset, 2 cases :
   * 1) test against known result in glmnet (with elastic net regularization) with elastic net penalty
   * 2) test with no regularization, check the ginfo in the end.
   */
@Test
public void testBounds() {
    //    glmnet's result:
    //    res2 <- glmnet(x=M,y=D$CAPSULE,lower.limits=-.5,upper.limits=.5,family='binomial')
    //    res2$beta[,58]
    //    AGE        RACE          DPROS       PSA         VOL         GLEASON
    //    -0.00616326 -0.50000000  0.50000000  0.03628192 -0.01249324  0.50000000 //    res2$a0[100]
    //    res2$a0[58]
    //    s57
    //    -4.155864
    //    lambda = 0.001108, null dev =  512.2888, res dev = 379.7597
    GLMModel model = null;
    Key parsed = Key.make("prostate_parsed");
    Key modelKey = Key.make("prostate_model");
    Frame fr = parse_test_file(parsed, "smalldata/logreg/prostate.csv");
    Key betaConsKey = Key.make("beta_constraints");
    String[] cfs1 = new String[] { "AGE", "RACE", "DPROS", "DCAPS", "PSA", "VOL", "GLEASON", "Intercept" };
    double[] vals = new double[] { -0.006502588, -0.500000000, 0.500000000, 0.400000000, 0.034826559, -0.011661747, 0.500000000, -4.564024 };
    //    [AGE, RACE, DPROS, DCAPS, PSA, VOL, GLEASON, Intercept]
    FVecTest.makeByteVec(betaConsKey, "names, lower_bounds, upper_bounds\n AGE, -.5, .5\n RACE, -.5, .5\n DCAPS, -.4, .4\n DPROS, -.5, .5 \nPSA, -.5, .5\n VOL, -.5, .5\nGLEASON, -.5, .5");
    Frame betaConstraints = ParseDataset.parse(Key.make("beta_constraints.hex"), betaConsKey);
    try {
        // H2O differs on intercept and race, same residual deviance though
        GLMParameters params = new GLMParameters();
        params._standardize = true;
        params._family = Family.binomial;
        params._beta_constraints = betaConstraints._key;
        params._response_column = "CAPSULE";
        params._ignored_columns = new String[] { "ID" };
        params._train = fr._key;
        params._objective_epsilon = 0;
        params._alpha = new double[] { 1 };
        params._lambda = new double[] { 0.001607 };
        params._obj_reg = 1.0 / 380;
        GLM glm = new GLM(params, modelKey);
        model = glm.trainModel().get();
        assertTrue(glm.isStopped());
        //      Map<String, Double> coefs =  model.coefficients();
        //      for (int i = 0; i < cfs1.length; ++i)
        //        assertEquals(vals[i], coefs.get(cfs1[i]), 1e-1);
        ModelMetricsBinomialGLM val = (ModelMetricsBinomialGLM) model._output._training_metrics;
        assertEquals(512.2888, val._nullDev, 1e-1);
        // 388.4952716196743
        assertTrue(val._resDev <= 388.5);
        model.delete();
        params._lambda = new double[] { 0 };
        params._alpha = new double[] { 0 };
        FVecTest.makeByteVec(betaConsKey, "names, lower_bounds, upper_bounds\n RACE, -.5, .5\n DCAPS, -.4, .4\n DPROS, -.5, .5 \nPSA, -.5, .5\n VOL, -.5, .5");
        betaConstraints = ParseDataset.parse(Key.make("beta_constraints.hex"), betaConsKey);
        glm = new GLM(params, modelKey);
        model = glm.trainModel().get();
        assertTrue(glm.isStopped());
        double[] beta = model.beta();
        System.out.println("beta = " + Arrays.toString(beta));
        fr.add("CAPSULE", fr.remove("CAPSULE"));
        fr.remove("ID").remove();
        DKV.put(fr._key, fr);
        // now check the ginfo
        DataInfo dinfo = new DataInfo(fr, null, 1, true, TransformType.NONE, DataInfo.TransformType.NONE, true, false, false, false, false, false);
        GLMGradientTask lt = new GLMBinomialGradientTask(null, dinfo, params, 0, beta).doAll(dinfo._adaptedFrame);
        double[] grad = lt._gradient;
        String[] names = model.dinfo().coefNames();
        BufferedString tmpStr = new BufferedString();
        outer: for (int i = 0; i < names.length; ++i) {
            for (int j = 0; j < betaConstraints.numRows(); ++j) {
                if (betaConstraints.vec("names").atStr(tmpStr, j).toString().equals(names[i])) {
                    if (Math.abs(beta[i] - betaConstraints.vec("lower_bounds").at(j)) < 1e-4 || Math.abs(beta[i] - betaConstraints.vec("upper_bounds").at(j)) < 1e-4) {
                        continue outer;
                    }
                }
            }
            assertEquals(0, grad[i], 1e-2);
        }
    } finally {
        fr.delete();
        betaConstraints.delete();
        if (model != null)
            model.delete();
    }
}
Also used : BufferedString(water.parser.BufferedString) GLMParameters(hex.glm.GLMModel.GLMParameters) BufferedString(water.parser.BufferedString)

Aggregations

GLMParameters (hex.glm.GLMModel.GLMParameters)50 Test (org.junit.Test)23 Solver (hex.glm.GLMModel.GLMParameters.Solver)16 ModelMetricsBinomialGLM (hex.ModelMetricsBinomialGLM)13 BufferedString (water.parser.BufferedString)10 ModelMetricsRegressionGLM (hex.ModelMetricsRegressionGLM)8 Frame (water.fvec.Frame)7 H2OModelBuilderIllegalArgumentException (water.exceptions.H2OModelBuilderIllegalArgumentException)6 ModelMetricsMultinomialGLM (hex.ModelMetricsBinomialGLM.ModelMetricsMultinomialGLM)4 GLMWeightsFun (hex.glm.GLMModel.GLMWeightsFun)4 HashMap (java.util.HashMap)4 NFSFileVec (water.fvec.NFSFileVec)3 hex (hex)2 DataInfo (hex.DataInfo)2 GLMGradientSolver (hex.glm.GLM.GLMGradientSolver)2 GradientInfo (hex.optimization.OptimizationUtils.GradientInfo)2 GLMTask (hex.glm.GLMTask)1 GradientSolver (hex.optimization.OptimizationUtils.GradientSolver)1 H2OCountedCompleter (water.H2O.H2OCountedCompleter)1