Search in sources :

Example 1 with TransformType

Use of hex.FrameTask.DataInfo.TransformType in project h2o-2 by h2oai.

The init method of class GLM2.

/**
 * Validates the GLM parameters and prepares the training data and coefficient constraints.
 *
 * <p>In order, this method:
 * <ol>
 *   <li>resolves the default link and builds the {@code GLMParams};</li>
 *   <li>copies {@code source} into {@code source2} without the ignored columns and the offset;</li>
 *   <li>checks the response column against the selected family;</li>
 *   <li>builds the {@code DataInfo} (re-attaching the offset just in front of the response and
 *       excluding it from standardization);</li>
 *   <li>translates the optional {@code beta_constraints} frame into {@code _lbs}, {@code _ubs},
 *       {@code _bgs} and {@code _rho}, rescaling them when predictors are transformed;</li>
 *   <li>applies the {@code non_negative} flag to the lower bounds.</li>
 * </ol>
 *
 * <p>Any {@link RuntimeException} raised along the way is printed, {@code cleanup()} is invoked,
 * and the exception is rethrown.
 *
 * @throws IllegalArgumentException if the parameters, the response column or the beta
 *     constraints are invalid
 */
@Override
public void init() {
    try {
        super.init();
        if (family == Family.gamma)
            setHighAccuracy();
        if (link == Link.family_default)
            link = family.defaultLink;
        _intercept = intercept ? 1 : 0;
        // TODO
        // NOTE(review): this override is applied for every family, not only tweedie - confirm intended.
        tweedie_link_power = 1 - tweedie_variance_power;
        if (tweedie_link_power == 0)
            link = Link.log;
        _glm = new GLMParams(family, tweedie_variance_power, link, tweedie_link_power);
        source2 = new Frame(source);
        assert sorted(ignored_cols);
        source2.remove(ignored_cols);
        if (offset != null)
            // remove offset and add it later explicitly (so that it does not interfere with DataInfo.prepareFrame)
            source2.remove(source2.find(offset));
        if (nlambdas == -1)
            nlambdas = 100;
        if (lambda_search && lambda.length > 1)
            throw new IllegalArgumentException("Can not supply both lambda_search and multiple lambdas. If lambda_search is on, GLM expects only one value of lambda_value, representing the lambda_value min (smallest lambda_value in the lambda_value search).");
        // check the response
        if (response.isEnum() && family != Family.binomial)
            throw new IllegalArgumentException("Invalid response variable, trying to run regression with categorical response!");
        switch(family) {
            case poisson:
            case tweedie:
                if (response.min() < 0)
                    throw new IllegalArgumentException("Illegal response column for family='" + family + "', response must be >= 0.");
                break;
            case gamma:
                if (response.min() <= 0)
                    throw new IllegalArgumentException("Invalid response for family='Gamma', response must be > 0!");
                break;
            case binomial:
                if (response.min() < 0 || response.max() > 1)
                    throw new IllegalArgumentException("Illegal response column for family='Binomial', response must in <0,1> range!");
                break;
            default:
        }
        toEnum = family == Family.binomial && (!response.isEnum() && (response.min() < 0 || response.max() > 1));
        if (source2.numCols() <= 1 && !intercept)
            throw new IllegalArgumentException("There are no predictors left after ignoring constant columns in the dataset and no intercept => No parameters to estimate.");
        Frame fr = DataInfo.prepareFrame(source2, response, new int[0], toEnum, true, true);
        if (offset != null) {
            // now put the offset just in front of response
            int id = source.find(offset);
            String name = source.names()[id];
            String responseName = fr.names()[fr.numCols() - 1];
            Vec responseVec = fr.remove(fr.numCols() - 1);
            fr.add(name, offset);
            fr.add(responseName, responseVec);
            _noffsets = 1;
        }
        TransformType dt = TransformType.NONE;
        if (standardize)
            dt = intercept ? TransformType.STANDARDIZE : TransformType.DESCALE;
        _srcDinfo = new DataInfo(fr, 1, intercept, use_all_factor_levels || lambda_search, dt, DataInfo.TransformType.NONE);
        if (offset != null && dt != TransformType.NONE) {
            // do not standardize offset
            if (_srcDinfo._normMul != null)
                _srcDinfo._normMul[_srcDinfo._normMul.length - 1] = 1;
            if (_srcDinfo._normSub != null)
                _srcDinfo._normSub[_srcDinfo._normSub.length - 1] = 0;
        }
        if (!intercept && _srcDinfo._cats > 0)
            throw new IllegalArgumentException("Models with no intercept are only supported with all-numeric predictors.");
        _activeData = _srcDinfo;
        if (higher_accuracy)
            setHighAccuracy();
        if (beta_constraints != null) {
            Vec v = beta_constraints.vec("names");
            if (v == null)
                throw new IllegalArgumentException("Invalid beta constraints file, missing column with predictor names");
            // for now only enums allowed here; the domain is dereferenced below, so fail early
            // with a clear message instead of a NullPointerException if it is missing
            String[] dom = v.domain();
            if (dom == null)
                throw new IllegalArgumentException("Invalid beta constraints file, column 'names' must be categorical (enum)");
            String[] names = Utils.append(_srcDinfo.coefNames(), "Intercept");
            int[] map = Utils.asInts(v);
            // each predictor may be constrained at most once
            HashSet<Integer> seen = new HashSet<Integer>();
            for (int i : map) {
                if (!seen.add(i))
                    throw new IllegalArgumentException("Invalid beta constraints file, got duplicate constraints for '" + dom[i] + "'");
            }
            if (!Arrays.deepEquals(dom, names)) {
                // need mapping: translate each row's domain index into the coefficient index of the same name
                HashMap<String, Integer> nameToIdx = new HashMap<String, Integer>();
                for (int i = 0; i < names.length; ++i) {
                    nameToIdx.put(names[i], i);
                }
                int[] newMap = MemoryManager.malloc4(map.length);
                for (int i = 0; i < map.length; ++i) {
                    Integer idx = nameToIdx.get(dom[map[i]]);
                    if (idx == null)
                        throw new IllegalArgumentException("unknown predictor name '" + dom[map[i]] + "'");
                    newMap[i] = idx;
                }
                map = newMap;
            }
            // index of the first numeric coefficient; _normMul/_normSub are indexed relative to it
            final int numoff = _srcDinfo.numStart();
            if ((v = beta_constraints.vec("lower_bounds")) != null) {
                _lbs = map == null ? Utils.asDoubles(v) : mapVec(Utils.asDoubles(v), makeAry(names.length, Double.NEGATIVE_INFINITY), map);
                if (_srcDinfo._normMul != null) {
                    // bring finite bounds onto the scale of the transformed predictors
                    for (int i = numoff; i < _srcDinfo.fullN(); ++i) {
                        if (Double.isInfinite(_lbs[i]))
                            continue;
                        _lbs[i] /= _srcDinfo._normMul[i - numoff];
                    }
                }
            }
            if ((v = beta_constraints.vec("upper_bounds")) != null) {
                _ubs = map == null ? Utils.asDoubles(v) : mapVec(Utils.asDoubles(v), makeAry(names.length, Double.POSITIVE_INFINITY), map);
                if (_srcDinfo._normMul != null) {
                    for (int i = numoff; i < _srcDinfo.fullN(); ++i) {
                        if (Double.isInfinite(_ubs[i]))
                            continue;
                        _ubs[i] /= _srcDinfo._normMul[i - numoff];
                    }
                }
            }
            if (_lbs != null && _ubs != null) {
                for (int i = 0; i < _lbs.length; ++i) {
                    if (_lbs[i] > _ubs[i])
                        throw new IllegalArgumentException("Invalid upper/lower bounds: lower bounds must be <= upper bounds for all variables.");
                }
            }
            if ((v = beta_constraints.vec("beta_given")) != null) {
                _bgs = map == null ? Utils.asDoubles(v) : mapVec(Utils.asDoubles(v), makeAry(names.length, 0), map);
                if (_srcDinfo._normMul != null) {
                    double norm = 0;
                    for (int i = numoff; i < _srcDinfo.fullN(); ++i) {
                        // _normSub can be null independently of _normMul (the offset handling above
                        // checks them separately); without centering there is no intercept shift
                        if (_srcDinfo._normSub != null)
                            norm += _bgs[i] * _srcDinfo._normSub[i - numoff];
                        _bgs[i] /= _srcDinfo._normMul[i - numoff];
                    }
                    // NOTE(review): sign of the intercept shift kept as in the original - confirm
                    // against the definition of _normSub/_normMul.
                    if (_intercept == 1)
                        _bgs[_bgs.length - 1] -= norm;
                }
            }
            if ((v = beta_constraints.vec("rho")) != null)
                _rho = map == null ? Utils.asDoubles(v) : mapVec(Utils.asDoubles(v), makeAry(names.length, 0), map);
            else if (_bgs != null)
                throw new IllegalArgumentException("Missing vector of penalties (rho) in beta_constraints file.");
            // warn about (but tolerate) columns that are not part of the constraints format
            String[] cols = new String[] { "names", "rho", "beta_given", "lower_bounds", "upper_bounds" };
            Arrays.sort(cols);
            for (String str : beta_constraints.names()) {
                if (Arrays.binarySearch(cols, str) < 0)
                    Log.warn("unknown column in beta_constraints file: '" + str + "'");
            }
        }
        if (non_negative) {
            // make sure lb is >= 0
            final int interceptIdx = _srcDinfo.fullN();
            if (_lbs == null)
                _lbs = new double[interceptIdx + 1];
            // no bounds for intercept
            _lbs[interceptIdx] = Double.NEGATIVE_INFINITY;
            // clamp every other lower bound; the intercept slot is skipped so the
            // "no bounds" value assigned above is not overwritten with 0
            for (int i = 0; i < _lbs.length; ++i) {
                if (i != interceptIdx && _lbs[i] < 0)
                    _lbs[i] = 0;
            }
        }
    } catch (RuntimeException e) {
        // release whatever was set up so far, then let the caller see the original failure
        e.printStackTrace();
        cleanup();
        throw e;
    }
}
Also used : DataInfo(hex.FrameTask.DataInfo) Frame(water.fvec.Frame) HashMap(java.util.HashMap) RString(water.util.RString) TransformType(hex.FrameTask.DataInfo.TransformType) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) Vec(water.fvec.Vec) HashSet(java.util.HashSet)

Aggregations

DataInfo (hex.FrameTask.DataInfo): 1, TransformType (hex.FrameTask.DataInfo.TransformType): 1, HashMap (java.util.HashMap): 1, HashSet (java.util.HashSet): 1, AtomicInteger (java.util.concurrent.atomic.AtomicInteger): 1, Frame (water.fvec.Frame): 1, Vec (water.fvec.Vec): 1, RString (water.util.RString): 1