
Example 1 with GLMGradientSolver

use of hex.glm.GLM.GLMGradientSolver in project h2o-3 by h2oai.

the class ComputationState method setLambda.

public void setLambda(double lambda) {
    adjustToNewLambda(0, _lambda);
    // strong rules are to be applied on the gradient with no l2 penalty
    // NOTE: we start with lambdaOld being 0, not lambda_max
    // non-recursive strong rules should use lambdaMax instead of _lambda
    // However, it seems to be working nicely to use 0 instead and be more aggressive on the predictor pruning
    // (should be safe as we check the KKTs anyway)
    applyStrongRules(lambda, _lambda);
    adjustToNewLambda(lambda, 0);
    _lambda = lambda;
    _gslvr = new GLMGradientSolver(_job, _parms, _activeData, l2pen(), _activeBC);
}
Also used : GLMGradientSolver(hex.glm.GLM.GLMGradientSolver)
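
The comments above describe the key trick: the gradient is first rescaled to a zero-penalty baseline, the strong rules then screen predictors for the new penalty, and the solver is rebuilt with the matching l2 term. As a hedged illustration of where this method sits, the sketch below walks a decreasing lambda sequence; the `state` variable and the `lambdas` array are assumptions made for this example, not part of the snippet above.

// Illustrative driver only: `state` is a ComputationState and `lambdas`
// a decreasing penalty sequence (both assumed for this sketch).
double[] lambdas = { 1e-1, 1e-2, 1e-3, 1e-4 };
for (double lambda : lambdas) {
    // re-screens the active predictors for the new penalty and rebuilds
    // the GLMGradientSolver with the corresponding l2 term
    state.setLambda(lambda);
    // ... run the inner solver on the active columns here ...
}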

Example 2 with GLMGradientSolver

use of hex.glm.GLM.GLMGradientSolver in project h2o-3 by h2oai.

the class ComputationState method removeCols.

public int[] removeCols(int[] cols) {
    int[] activeCols = ArrayUtils.removeIds(_activeData.activeCols(), cols);
    if (_beta != null)
        _beta = ArrayUtils.removeIds(_beta, cols);
    if (_u != null)
        _u = ArrayUtils.removeIds(_u, cols);
    if (_ginfo != null && _ginfo._gradient != null)
        _ginfo._gradient = ArrayUtils.removeIds(_ginfo._gradient, cols);
    _activeData = _dinfo.filterExpandedColumns(activeCols);
    _activeBC = _bc.filterExpandedColumns(activeCols);
    _gslvr = new GLMGradientSolver(_job, _parms, _activeData, (1 - _alpha) * _lambda, _activeBC);
    return activeCols;
}
Also used : GLMGradientSolver(hex.glm.GLM.GLMGradientSolver)
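
removeCols shrinks the coefficient vector, the dual vector and the stored gradient in lock-step with the active column list, then rebuilds the gradient solver over the reduced DataInfo. The sketch below shows how a caller might use it to prune columns; `state`, `beta` and the helper `zeroCoefficientIds` are hypothetical names introduced only for this illustration.

// Hypothetical pruning step: collect ids of columns to drop (here, columns
// whose coefficients are exactly zero) and remove them in one call.
int[] zeroCols = zeroCoefficientIds(beta);   // assumed helper, not part of h2o-3
if (zeroCols.length > 0) {
    int[] remaining = state.removeCols(zeroCols);
    // `remaining` holds the surviving active column indices; later gradient
    // evaluations go through the freshly rebuilt solver.
}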

Example 3 with GLMGradientSolver

use of hex.glm.GLM.GLMGradientSolver in project h2o-3 by h2oai.

the class L_BFGS_Test method logistic.

@Test
public void logistic() {
    Key parsedKey = Key.make("prostate");
    DataInfo dinfo = null;
    try {
        GLMParameters glmp = new GLMParameters(Family.binomial, Family.binomial.defaultLink);
        glmp._alpha = new double[] { 0 };
        glmp._lambda = new double[] { 1e-5 };
        Frame source = parse_test_file(parsedKey, "smalldata/glm_test/prostate_cat_replaced.csv");
        source.add("CAPSULE", source.remove("CAPSULE"));
        source.remove("ID").remove();
        Frame valid = new Frame(source._names.clone(), source.vecs().clone());
        dinfo = new DataInfo(source, valid, 1, false, DataInfo.TransformType.STANDARDIZE, DataInfo.TransformType.NONE, true, false, false, /* weights */
        false, /* offset */
        false, /* fold */
        false);
        DKV.put(dinfo._key, dinfo);
        glmp._obj_reg = 1 / 380.0;
        GLMGradientSolver solver = new GLMGradientSolver(null, glmp, dinfo, 1e-5, null);
        L_BFGS lbfgs = new L_BFGS().setGradEps(1e-8);
        double[] beta = MemoryManager.malloc8d(dinfo.fullN() + 1);
        beta[beta.length - 1] = new GLMWeightsFun(glmp).link(source.vec("CAPSULE").mean());
        L_BFGS.Result r = lbfgs.solve(solver, beta, solver.getGradient(beta), new L_BFGS.ProgressMonitor() {

            int _i = 0;

            public boolean progress(double[] beta, GradientInfo ginfo) {
                System.out.println(++_i + ":" + ginfo._objVal + ", " + ArrayUtils.l2norm2(ginfo._gradient, false));
                return true;
            }
        });
        assertEquals(378.34, 2 * r.ginfo._objVal * source.numRows(), 1e-1);
    } finally {
        if (dinfo != null)
            DKV.remove(dinfo._key);
        Value v = DKV.get(parsedKey);
        if (v != null) {
            v.<Frame>get().delete();
        }
    }
}
Also used : DataInfo(hex.DataInfo) Frame(water.fvec.Frame) GLMGradientSolver(hex.glm.GLM.GLMGradientSolver) GradientInfo(hex.optimization.OptimizationUtils.GradientInfo) GLMWeightsFun(hex.glm.GLMModel.GLMWeightsFun) GLMParameters(hex.glm.GLMModel.GLMParameters) Test(org.junit.Test)
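
The test drives GLMGradientSolver through L-BFGS, but the solver can also be queried directly. A minimal sketch, assuming `glmp` and `dinfo` have been prepared exactly as in the test above:

// Sketch only: evaluate the penalized objective and gradient at a single point.
GLMGradientSolver solver = new GLMGradientSolver(null, glmp, dinfo, 1e-5, null);
double[] beta = MemoryManager.malloc8d(dinfo.fullN() + 1);  // all-zero start, intercept last
GradientInfo ginfo = solver.getGradient(beta);
System.out.println("objective = " + ginfo._objVal
        + ", ||gradient||^2 = " + ArrayUtils.l2norm2(ginfo._gradient, false));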

Example 4 with GLMGradientSolver

use of hex.glm.GLM.GLMGradientSolver in project h2o-3 by h2oai.

the class L_BFGS_Test method testArcene.

// Test LSM on arcene - wide dataset with ~10k columns
// test warm start and max # of iterations
@Test
public void testArcene() {
    Key parsedKey = Key.make("arcene_parsed");
    DataInfo dinfo = null;
    try {
        Frame source = parse_test_file(parsedKey, "smalldata/glm_test/arcene.csv");
        Frame valid = new Frame(source._names.clone(), source.vecs().clone());
        GLMParameters glmp = new GLMParameters(Family.gaussian);
        glmp._lambda = new double[] { 1e-5 };
        glmp._alpha = new double[] { 0 };
        glmp._obj_reg = 0.01;
        dinfo = new DataInfo(source, valid, 1, false, DataInfo.TransformType.STANDARDIZE, DataInfo.TransformType.NONE, true, false, false, /* weights */
        false, /* offset */
        false, /* fold */
        false);
        DKV.put(dinfo._key, dinfo);
        GradientSolver solver = new GLMGradientSolver(null, glmp, dinfo, 1e-5, null);
        L_BFGS lbfgs = new L_BFGS().setMaxIter(20);
        double[] beta = MemoryManager.malloc8d(dinfo.fullN() + 1);
        beta[beta.length - 1] = new GLMWeightsFun(glmp).link(source.lastVec().mean());
        L_BFGS.Result r1 = lbfgs.solve(solver, beta.clone(), solver.getGradient(beta), new L_BFGS.ProgressMonitor() {

            int _i = 0;

            public boolean progress(double[] beta, GradientInfo ginfo) {
                System.out.println(++_i + ":" + ginfo._objVal);
                return true;
            }
        });
        lbfgs.setMaxIter(50);
        final int iter = r1.iter;
        L_BFGS.Result r2 = lbfgs.solve(solver, r1.coefs, r1.ginfo, new L_BFGS.ProgressMonitor() {

            int _i = 0;

            public boolean progress(double[] beta, GradientInfo ginfo) {
                System.out.println(iter + " + " + ++_i + ":" + ginfo._objVal);
                return true;
            }
        });
        System.out.println();
        lbfgs = new L_BFGS().setMaxIter(100);
        L_BFGS.Result r3 = lbfgs.solve(solver, beta.clone(), solver.getGradient(beta), new L_BFGS.ProgressMonitor() {

            int _i = 0;

            public boolean progress(double[] beta, GradientInfo ginfo) {
                System.out.println(++_i + ":" + ginfo._objVal + ", " + ArrayUtils.l2norm2(ginfo._gradient, false));
                return true;
            }
        });
        assertEquals(r1.iter, 20);
        //      assertEquals (r1.iter + r2.iter,r3.iter); // should be equal? got mismatch by 2
        assertEquals(r2.ginfo._objVal, r3.ginfo._objVal, 1e-8);
        assertEquals(.5 * glmp._lambda[0] * ArrayUtils.l2norm(r3.coefs, true) + r3.ginfo._objVal, 1e-4, 5e-4);
        assertTrue("iter# expected < 100, got " + r3.iter, r3.iter < 100);
    } finally {
        if (dinfo != null)
            DKV.remove(dinfo._key);
        Value v = DKV.get(parsedKey);
        if (v != null) {
            v.<Frame>get().delete();
        }
    }
}
Also used : DataInfo(hex.DataInfo) Frame(water.fvec.Frame) GLMGradientSolver(hex.glm.GLM.GLMGradientSolver) GradientSolver(hex.optimization.OptimizationUtils.GradientSolver) GradientInfo(hex.optimization.OptimizationUtils.GradientInfo) GLMWeightsFun(hex.glm.GLMModel.GLMWeightsFun) GLMParameters(hex.glm.GLMModel.GLMParameters) Test(org.junit.Test)
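
The part of the test worth isolating is the warm start: the second solve resumes from the coefficients and gradient info returned by the first run instead of starting over. A condensed sketch, assuming `solver` and `beta` are set up as in the test above:

// Quiet progress monitor so the sketch stays self-contained.
L_BFGS.ProgressMonitor quiet = new L_BFGS.ProgressMonitor() {
    public boolean progress(double[] b, GradientInfo g) { return true; }  // never abort
};
// Cold start, capped at 20 iterations.
L_BFGS lbfgs = new L_BFGS().setMaxIter(20);
L_BFGS.Result first = lbfgs.solve(solver, beta.clone(), solver.getGradient(beta), quiet);
// Warm start: reuse the returned coefficients and gradient info, then allow more iterations.
lbfgs.setMaxIter(50);
L_BFGS.Result second = lbfgs.solve(solver, first.coefs, first.ginfo, quiet);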

Example 5 with GLMGradientSolver

use of hex.glm.GLM.GLMGradientSolver in project h2o-3 by h2oai.

the class ComputationState method applyStrongRules.

/**
   * Apply strong rules to filter out predictors expected to be inactive (i.e. to have zero
   * coefficients) at the new lambda. Updates the active data, beta, bound constraints,
   * gradient info and gradient solver in place; no value is returned.
   */
protected void applyStrongRules(double lambdaNew, double lambdaOld) {
    lambdaNew = Math.min(_lambdaMax, lambdaNew);
    lambdaOld = Math.min(_lambdaMax, lambdaOld);
    if (_parms._family == Family.multinomial) /* && _parms._solver != GLMParameters.Solver.L_BFGS */
    {
        applyStrongRulesMultinomial(lambdaNew, lambdaOld);
        return;
    }
    int P = _dinfo.fullN();
    _activeBC = _bc;
    _activeData = _activeData != null ? _activeData : _dinfo;
    _allIn = _allIn || _parms._alpha[0] * lambdaNew == 0 || _activeBC.hasBounds();
    if (!_allIn) {
        int newlySelected = 0;
        final double rhs = Math.max(0, _alpha * (2 * lambdaNew - lambdaOld));
        int[] newCols = MemoryManager.malloc4(P);
        int j = 0;
        int[] oldActiveCols = _activeData._activeCols == null ? new int[] { P } : _activeData.activeCols();
        for (int i = 0; i < P; ++i) {
            if (j < oldActiveCols.length && oldActiveCols[j] == i)
                j++;
            else if (_ginfo._gradient[i] > rhs || -_ginfo._gradient[i] > rhs)
                newCols[newlySelected++] = i;
        }
        if (_parms._max_active_predictors != -1 && (oldActiveCols.length + newlySelected - 1) > _parms._max_active_predictors) {
            Integer[] bigInts = ArrayUtils.toIntegers(newCols, 0, newlySelected);
            Arrays.sort(bigInts, new Comparator<Integer>() {

                @Override
                public int compare(Integer o1, Integer o2) {
                    return (int) Math.signum(_ginfo._gradient[o2.intValue()] * _ginfo._gradient[o2.intValue()] - _ginfo._gradient[o1.intValue()] * _ginfo._gradient[o1.intValue()]);
                }
            });
            newCols = ArrayUtils.toInt(bigInts, 0, _parms._max_active_predictors - oldActiveCols.length + 1);
            Arrays.sort(newCols);
        } else
            newCols = Arrays.copyOf(newCols, newlySelected);
        // merge the already-active columns back in
        newCols = ArrayUtils.sortedMerge(oldActiveCols, newCols);
        int active = newCols.length;
        _allIn = active == P;
        if (!_allIn) {
            int[] cols = newCols;
            // the intercept is always selected, even if the intercept parameter is false (it is dropped later, but it is still needed elsewhere)
            assert cols[active - 1] == P;
            _beta = ArrayUtils.select(_beta, cols);
            if (_u != null)
                _u = ArrayUtils.select(_u, cols);
            _activeData = _dinfo.filterExpandedColumns(cols);
            assert _activeData.activeCols().length == _beta.length;
            assert _u == null || _activeData.activeCols().length == _u.length;
            _ginfo = new GLMGradientInfo(_ginfo._likelihood, _ginfo._objVal, ArrayUtils.select(_ginfo._gradient, cols));
            _activeBC = _bc.filterExpandedColumns(_activeData.activeCols());
            _gslvr = new GLMGradientSolver(_job, _parms, _activeData, (1 - _alpha) * _lambda, _bc);
            assert _beta.length == cols.length;
            return;
        }
    }
    _activeData = _dinfo;
}
Also used : GLMGradientInfo(hex.glm.GLM.GLMGradientInfo) GLMGradientSolver(hex.glm.GLM.GLMGradientSolver) BetaConstraint(hex.glm.GLM.BetaConstraint)
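
The heart of the screening loop is a single threshold comparison per predictor. A stand-alone sketch of that check follows; the method name and its scalar arguments are hypothetical stand-ins for the fields used above:

// Sequential strong rule: keep predictor j at lambdaNew when
// |grad_j| > alpha * (2 * lambdaNew - lambdaOld), with the threshold clamped at zero.
static boolean survivesStrongRule(double gradJ, double alpha,
                                  double lambdaNew, double lambdaOld) {
    double rhs = Math.max(0, alpha * (2 * lambdaNew - lambdaOld));
    return gradJ > rhs || -gradJ > rhs;  // i.e. |gradJ| > rhs
}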

Aggregations

GLMGradientSolver (hex.glm.GLM.GLMGradientSolver)6 DataInfo (hex.DataInfo)2 BetaConstraint (hex.glm.GLM.BetaConstraint)2 GLMGradientInfo (hex.glm.GLM.GLMGradientInfo)2 GLMParameters (hex.glm.GLMModel.GLMParameters)2 GLMWeightsFun (hex.glm.GLMModel.GLMWeightsFun)2 GradientInfo (hex.optimization.OptimizationUtils.GradientInfo)2 Test (org.junit.Test)2 Frame (water.fvec.Frame)2 GradientSolver (hex.optimization.OptimizationUtils.GradientSolver)1