
Example 1 with L_BFGS

Use of hex.optimization.L_BFGS in project h2o-3 by h2oai.

From the class GLM, method init:

@Override
public void init(boolean expensive) {
    super.init(expensive);
    hide("_balance_classes", "Not applicable since class balancing is not required for GLM.");
    hide("_max_after_balance_size", "Not applicable since class balancing is not required for GLM.");
    hide("_class_sampling_factors", "Not applicable since class balancing is not required for GLM.");
    _parms.validate(this);
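    // Validate the response column against the requested family; each case below enforces that family's response-type constraints.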
    if (_response != null) {
        if (!isClassifier() && _response.isCategorical())
            error("_response", H2O.technote(2, "Regression requires numeric response, got categorical."));
        switch(_parms._family) {
            case binomial:
                if (!_response.isBinary() && _nclass != 2)
                    error("_family", H2O.technote(2, "Binomial requires the response to be a 2-class categorical or a binary column (0/1)"));
                break;
            case multinomial:
                if (_nclass <= 2)
                    error("_family", H2O.technote(2, "Multinomial requires a categorical response with at least 3 levels (for 2 class problem use family=binomial."));
                break;
            case poisson:
                if (_nclass != 1)
                    error("_family", "Poisson requires the response to be numeric.");
                if (_response.min() < 0)
                    error("_family", "Poisson requires response >= 0");
                if (!_response.isInt())
                    warn("_family", "Poisson expects non-negative integer response, got floats.");
                break;
            case gamma:
                if (_nclass != 1)
                    error("_distribution", H2O.technote(2, "Gamma requires the response to be numeric."));
                if (_response.min() <= 0)
                    error("_family", "Gamma requires positive respone");
                break;
            case tweedie:
                if (_nclass != 1)
                    error("_family", H2O.technote(2, "Tweedie requires the response to be numeric."));
                break;
            case quasibinomial:
                if (_nclass != 1)
                    error("_family", H2O.technote(2, "Quasi_binomial requires the response to be numeric."));
                break;
            case gaussian:
                //          if (_nclass != 1) error("_family", H2O.technote(2, "Gaussian requires the response to be numeric."));
                break;
            default:
                error("_family", "Invalid distribution: " + _parms._distribution);
        }
    }
    if (expensive) {
        if (error_count() > 0)
            return;
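        // Default elastic-net mixing: pure ridge (alpha = 0) for L-BFGS, a 50/50 L1/L2 mix for the other solvers.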
        if (_parms._alpha == null)
            _parms._alpha = new double[] { _parms._solver == Solver.L_BFGS ? 0 : .5 };
        if (_parms._lambda_search && _parms._nlambdas == -1)
            // fewer lambdas needed for ridge
            _parms._nlambdas = _parms._alpha[0] == 0 ? 30 : 100;
        _lsc = new LambdaSearchScoringHistory(_parms._valid != null, _parms._nfolds > 1);
        _sc = new ScoringHistory();
        // make sure we have all the rollups computed in parallel
        _train.bulkRollups();
        _t0 = System.currentTimeMillis();
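        // Penalized or intercept-free fits keep all factor levels; the penalty resolves the collinearity introduced by the full dummy encoding.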
        if (_parms._lambda_search || !_parms._intercept || _parms._lambda == null || _parms._lambda[0] > 0)
            _parms._use_all_factor_levels = true;
        if (_parms._link == Link.family_default)
            _parms._link = _parms._family.defaultLink;
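        // Build the DataInfo: the (optionally standardized) expanded view of train/valid, honoring the missing-value handling and any weight/offset/fold columns.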
        _dinfo = new DataInfo(_train.clone(), _valid, 1, _parms._use_all_factor_levels || _parms._lambda_search, _parms._standardize ? DataInfo.TransformType.STANDARDIZE : DataInfo.TransformType.NONE, DataInfo.TransformType.NONE, _parms._missing_values_handling == MissingValuesHandling.Skip, _parms._missing_values_handling == MissingValuesHandling.MeanImputation, false, hasWeightCol(), hasOffsetCol(), hasFoldCol(), _parms._interactions);
        if (_parms._max_iterations == -1) {
            // fill in default max iterations
            int numclasses = _parms._family == Family.multinomial ? nclasses() : 1;
            if (_parms._solver == Solver.L_BFGS) {
                _parms._max_iterations = _parms._lambda_search ? _parms._nlambdas * 100 * numclasses : numclasses * Math.max(20, _dinfo.fullN() >> 2);
                if (_parms._alpha[0] > 0)
                    _parms._max_iterations *= 10;
            } else
                _parms._max_iterations = _parms._lambda_search ? 10 * _parms._nlambdas : 50;
        }
        if (_valid != null)
            _validDinfo = _dinfo.validDinfo(_valid);
        _state = new ComputationState(_job, _parms, _dinfo, null, nclasses());
        // are we skipping extra rows (beyond those with weight == 0)?
        boolean skippingRows = (_parms._missing_values_handling == MissingValuesHandling.Skip && _train.hasNAs());
        if (hasWeightCol() || skippingRows) {
            // need to re-compute means and sd
            // && _parms._lambda_search && _parms._alpha[0] > 0;
            boolean setWeights = skippingRows;
            if (setWeights) {
                Vec wc = _weights == null ? _dinfo._adaptedFrame.anyVec().makeCon(1) : _weights.makeCopy();
                _dinfo.setWeights(_generatedWeights = "__glm_gen_weights", wc);
            }
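            // One pass over the data to compute the (weighted) response mean, observation count, and predictor means/SDs, excluding skipped rows.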
            YMUTask ymt = new YMUTask(_dinfo, _parms._family == Family.multinomial ? nclasses() : 1, setWeights, skippingRows, true).doAll(_dinfo._adaptedFrame);
            if (ymt.wsum() == 0)
                throw new IllegalArgumentException("No rows left in the dataset after filtering out rows with missing values. Ignore columns with many NAs or impute your missing values prior to calling glm.");
            Log.info(LogMsg("using " + ymt.nobs() + " nobs out of " + _dinfo._adaptedFrame.numRows() + " total"));
            // if sparse data, need second pass to compute variance
            _nobs = ymt.nobs();
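            // obj_reg normalizes the objective by the total weight so it behaves as a per-observation average.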
            if (_parms._obj_reg == -1)
                _parms._obj_reg = 1.0 / ymt.wsum();
            if (!_parms._stdOverride)
                _dinfo.updateWeightedSigmaAndMean(ymt.predictorSDs(), ymt.predictorMeans());
            if (_parms._family == Family.multinomial) {
                _state._ymu = MemoryManager.malloc8d(_nclass);
                for (int i = 0; i < _state._ymu.length; ++i) _state._ymu[i] = _priorClassDist[i];
            } else
                _state._ymu = _parms._intercept ? ymt._yMu : new double[] { _parms.linkInv(0) };
        } else {
            _nobs = _train.numRows();
            if (_parms._obj_reg == -1)
                _parms._obj_reg = 1.0 / _nobs;
            if (_parms._family == Family.multinomial) {
                _state._ymu = MemoryManager.malloc8d(_nclass);
                for (int i = 0; i < _state._ymu.length; ++i) _state._ymu[i] = _priorClassDist[i];
            } else
                _state._ymu = new double[] { _parms._intercept ? _train.lastVec().mean() : _parms.linkInv(0) };
        }
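        // Optional user-supplied bounds and/or proximal penalties on the coefficients.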
        BetaConstraint bc = (_parms._beta_constraints != null) ? new BetaConstraint(_parms._beta_constraints.get()) : new BetaConstraint();
        if ((bc.hasBounds() || bc.hasProximalPenalty()) && _parms._compute_p_values)
            error("_compute_p_values", "P-values can not be computed for constrained problems");
        _state.setBC(bc);
        if (hasOffsetCol() && _parms._intercept) {
            // fit intercept
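            // With an offset present, the intercept is fit on its own: a one-dimensional L-BFGS solve over a model with every predictor filtered out, started from the negated offset mean.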
            GLMGradientSolver gslvr = new GLMGradientSolver(_job, _parms, _dinfo.filterExpandedColumns(new int[0]), 0, _state.activeBC());
            double[] x = new L_BFGS().solve(gslvr, new double[] { -_offset.mean() }).coefs;
            Log.info(LogMsg("fitted intercept = " + x[0]));
            x[0] = _parms.linkInv(x[0]);
            _state._ymu = x;
        }
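        // A user-supplied prior shifts the intercept so predictions reflect the prior class probability instead of the training distribution.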
        if (_parms._prior > 0)
            _iceptAdjust = -Math.log(_state._ymu[0] * (1 - _parms._prior) / (_parms._prior * (1 - _state._ymu[0])));
        ArrayList<Vec> vecs = new ArrayList<>();
        if (_weights != null)
            vecs.add(_weights);
        if (_offset != null)
            vecs.add(_offset);
        vecs.add(_response);
        double[] beta = getNullBeta();
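        // The gradient at the null model determines lambda_max, the smallest penalty that keeps every coefficient at zero.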
        GLMGradientInfo ginfo = new GLMGradientSolver(_job, _parms, _dinfo, 0, _state.activeBC()).getGradient(beta);
        _lmax = lmax(ginfo._gradient);
        _state.setLambdaMax(_lmax);
        _model = new GLMModel(_result, _parms, GLM.this, _state._ymu, _dinfo._adaptedFrame.lastVec().sigma(), _lmax, _nobs);
        String[] warns = _model.adaptTestForTrain(_valid, true, true);
        for (String s : warns) _job.warn(s);
        if (_parms._lambda_min_ratio == -1) {
            _parms._lambda_min_ratio = (_nobs >> 4) > _dinfo.fullN() ? 1e-4 : 1e-2;
            if (_parms._alpha[0] == 0)
                // smaller lambda min for ridge since we are starting quite high
                _parms._lambda_min_ratio *= 1e-2;
        }
        _state.updateState(beta, ginfo);
        if (_parms._lambda == null) {
            // no lambda given; derive the sequence as a fraction of lambda_max
            if (_parms._lambda_search) {
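                // Lay out nlambdas values decaying geometrically from lambda_max down to lambda_max * lambda_min_ratio.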
                _parms._lambda = new double[_parms._nlambdas];
                double dec = Math.pow(_parms._lambda_min_ratio, 1.0 / (_parms._nlambdas - 1));
                _parms._lambda[0] = _lmax;
                double l = _lmax;
                for (int i = 1; i < _parms._nlambdas; ++i) _parms._lambda[i] = (l *= dec);
            // todo set the null submodel
            } else
                _parms._lambda = new double[] { 10 * _parms._lambda_min_ratio * _lmax };
        }
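        // If cross-validation already picked a lambda, truncate the path just past it and make the last entry exactly the CV estimate.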
        if (!Double.isNaN(_lambdaCVEstimate)) {
            for (int i = 0; i < _parms._lambda.length; ++i) if (_parms._lambda[i] < _lambdaCVEstimate) {
                _parms._lambda = Arrays.copyOf(_parms._lambda, i + 1);
                break;
            }
            _parms._lambda[_parms._lambda.length - 1] = _lambdaCVEstimate;
        }
        if (_parms._objective_epsilon == -1) {
            if (_parms._lambda_search)
                _parms._objective_epsilon = 1e-4;
            else
                // lower default objective epsilon for non-standardized problems (mostly to match classical tools)
                _parms._objective_epsilon = _parms._lambda[0] == 0 ? 1e-6 : 1e-4;
        }
        if (_parms._gradient_epsilon == -1) {
            _parms._gradient_epsilon = _parms._lambda[0] == 0 ? 1e-6 : 1e-4;
            if (_parms._lambda_search)
                _parms._gradient_epsilon *= 1e-2;
        }
        // clone2 so that I don't change instance which is in the DKV directly
        // (clone2 also shallow clones _output)
        _model.clone2().delete_and_lock(_job._key);
    }
}
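
For context, the lambda search path built above is just a geometric sequence running from lambda_max down to lambda_max * lambda_min_ratio. A minimal standalone sketch of that computation (lambdaPath and its parameter names are illustrative, not part of the h2o-3 API):

static double[] lambdaPath(double lambdaMax, double lambdaMinRatio, int nLambdas) {
    // Common ratio chosen so the last entry equals lambdaMax * lambdaMinRatio.
    double dec = Math.pow(lambdaMinRatio, 1.0 / (nLambdas - 1));
    double[] lambda = new double[nLambdas];
    lambda[0] = lambdaMax;
    for (int i = 1; i < nLambdas; ++i)
        lambda[i] = lambda[i - 1] * dec;
    return lambda;
}

For example, lambdaPath(1.0, 1e-4, 5) yields {1.0, 0.1, 0.01, 0.001, 0.0001}; the loop in the lambda-search branch above computes the same sequence in place.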
Also used: GLMModel (hex.glm.GLMModel), ArrayList (java.util.ArrayList), L_BFGS (hex.optimization.L_BFGS), BufferedString (water.parser.BufferedString), H2OModelBuilderIllegalArgumentException (water.exceptions.H2OModelBuilderIllegalArgumentException)

Aggregations

GLMModel (hex.glm.GLMModel): 1 usage
L_BFGS (hex.optimization.L_BFGS): 1 usage
ArrayList (java.util.ArrayList): 1 usage
H2OModelBuilderIllegalArgumentException (water.exceptions.H2OModelBuilderIllegalArgumentException): 1 usage
BufferedString (water.parser.BufferedString): 1 usage