Usage of `hex.FrameTask.DataInfo.TransformType` in project h2o-2 by h2oai.
Class `GLM2`, method `init`.
/**
 * Validates the GLM parameters and prepares the data used for fitting.
 *
 * <p>Resolves the link function, builds {@code _glm}, copies {@code source} into
 * {@code source2} without the ignored columns (and without the offset, which is re-added
 * just before the response), checks the response against the chosen family, builds
 * {@code _srcDinfo} (optionally standardized or descaled), and translates the optional
 * {@code beta_constraints} frame and the {@code non_negative} flag into per-coefficient
 * lower/upper bounds ({@code _lbs}/{@code _ubs}), given coefficients ({@code _bgs}) and
 * penalties ({@code _rho}).
 *
 * <p>On any {@link RuntimeException} the stack trace is printed, {@code cleanup()} is called
 * and the exception is rethrown.
 *
 * @throws IllegalArgumentException if the parameters, the response column or the
 *     beta_constraints frame are invalid
 */
@Override
public void init() {
  try {
    super.init();
    if (family == Family.gamma)
      setHighAccuracy();
    if (link == Link.family_default)
      link = family.defaultLink;
    _intercept = intercept ? 1 : 0;
    // TODO
    tweedie_link_power = 1 - tweedie_variance_power;
    if (tweedie_link_power == 0)
      link = Link.log;
    _glm = new GLMParams(family, tweedie_variance_power, link, tweedie_link_power);
    source2 = new Frame(source);
    assert sorted(ignored_cols);
    source2.remove(ignored_cols);
    if (offset != null)
      // remove offset and add it later explicitly (so that it does not interfere with DataInfo.prepareFrame)
      source2.remove(source2.find(offset));
    if (nlambdas == -1)
      nlambdas = 100;
    if (lambda_search && lambda.length > 1)
      throw new IllegalArgumentException("Can not supply both lambda_search and multiple lambdas. If lambda_search is on, GLM expects only one value of lambda_value, representing the lambda_value min (smallest lambda_value in the lambda_value search).");
    // check the response
    if (response.isEnum() && family != Family.binomial)
      throw new IllegalArgumentException("Invalid response variable, trying to run regression with categorical response!");
    switch (family) {
      case poisson:
      case tweedie:
        if (response.min() < 0)
          throw new IllegalArgumentException("Illegal response column for family='" + family + "', response must be >= 0.");
        break;
      case gamma:
        if (response.min() <= 0)
          throw new IllegalArgumentException("Invalid response for family='Gamma', response must be > 0!");
        break;
      case binomial:
        if (response.min() < 0 || response.max() > 1)
          throw new IllegalArgumentException("Illegal response column for family='Binomial', response must be in <0,1> range!");
        break;
      default:
    }
    toEnum = family == Family.binomial && (!response.isEnum() && (response.min() < 0 || response.max() > 1));
    if (source2.numCols() <= 1 && !intercept)
      throw new IllegalArgumentException("There are no predictors left after ignoring constant columns in the dataset and no intercept => No parameters to estimate.");
    Frame fr = DataInfo.prepareFrame(source2, response, new int[0], toEnum, true, true);
    if (offset != null) {
      // now put the offset just in front of response
      int id = source.find(offset);
      String name = source.names()[id];
      String responseName = fr.names()[fr.numCols() - 1];
      Vec responseVec = fr.remove(fr.numCols() - 1);
      fr.add(name, offset);
      fr.add(responseName, responseVec);
      _noffsets = 1;
    }
    TransformType dt = TransformType.NONE;
    if (standardize)
      dt = intercept ? TransformType.STANDARDIZE : TransformType.DESCALE;
    _srcDinfo = new DataInfo(fr, 1, intercept, use_all_factor_levels || lambda_search, dt, TransformType.NONE);
    if (offset != null && dt != TransformType.NONE) {
      // do not standardize offset (it is the last numeric column, placed just before the response)
      if (_srcDinfo._normMul != null)
        _srcDinfo._normMul[_srcDinfo._normMul.length - 1] = 1;
      if (_srcDinfo._normSub != null)
        _srcDinfo._normSub[_srcDinfo._normSub.length - 1] = 0;
    }
    if (!intercept && _srcDinfo._cats > 0)
      throw new IllegalArgumentException("Models with no intercept are only supported with all-numeric predictors.");
    _activeData = _srcDinfo;
    if (higher_accuracy)
      setHighAccuracy();
    if (beta_constraints != null) {
      Vec v = beta_constraints.vec("names");
      if (v == null)
        throw new IllegalArgumentException("Invalid beta constraints file, missing column with predictor names");
      // for now only enums allowed here
      String[] dom = v.domain();
      String[] names = Utils.append(_srcDinfo.coefNames(), "Intercept");
      int[] map = Utils.asInts(v);
      HashSet<Integer> s = new HashSet<Integer>();
      for (int i : map)
        if (!s.add(i))
          throw new IllegalArgumentException("Invalid beta constraints file, got duplicate constraints for '" + dom[i] + "'");
      if (!Arrays.deepEquals(dom, names)) {
        // the constraint rows are not in coefficient order: map each row to its coefficient index
        HashMap<String, Integer> m = new HashMap<String, Integer>();
        for (int i = 0; i < names.length; ++i)
          m.put(names[i], i);
        int[] newMap = MemoryManager.malloc4(map.length);
        for (int i = 0; i < map.length; ++i) {
          Integer idx = m.get(dom[map[i]]);
          if (idx == null)
            throw new IllegalArgumentException("unknown predictor name '" + dom[map[i]] + "'");
          newMap[i] = idx;
        }
        map = newMap;
      }
      final int numoff = _srcDinfo.numStart();
      if ((v = beta_constraints.vec("lower_bounds")) != null) {
        _lbs = map == null ? Utils.asDoubles(v) : mapVec(Utils.asDoubles(v), makeAry(names.length, Double.NEGATIVE_INFINITY), map);
        Log.info("lower bounds = " + Arrays.toString(_lbs));
        // bounds on numeric coefficients are rescaled into the normalized coefficient space
        if (_srcDinfo._normMul != null) {
          for (int i = numoff; i < _srcDinfo.fullN(); ++i) {
            if (Double.isInfinite(_lbs[i]))
              continue;
            _lbs[i] /= _srcDinfo._normMul[i - numoff];
          }
        }
      }
      if ((v = beta_constraints.vec("upper_bounds")) != null) {
        _ubs = map == null ? Utils.asDoubles(v) : mapVec(Utils.asDoubles(v), makeAry(names.length, Double.POSITIVE_INFINITY), map);
        Log.info("upper bounds = " + Arrays.toString(_ubs));
        if (_srcDinfo._normMul != null) {
          for (int i = numoff; i < _srcDinfo.fullN(); ++i) {
            if (Double.isInfinite(_ubs[i]))
              continue;
            _ubs[i] /= _srcDinfo._normMul[i - numoff];
          }
        }
      }
      if (_lbs != null && _ubs != null) {
        for (int i = 0; i < _lbs.length; ++i)
          if (_lbs[i] > _ubs[i])
            throw new IllegalArgumentException("Invalid upper/lower bounds: lower bounds must be <= upper bounds for all variables.");
      }
      if ((v = beta_constraints.vec("beta_given")) != null) {
        _bgs = map == null ? Utils.asDoubles(v) : mapVec(Utils.asDoubles(v), makeAry(names.length, 0), map);
        if (_srcDinfo._normMul != null) {
          // Move beta_given into the normalized space. Assuming DataInfo transforms a numeric
          // column as x' = (x - normSub) * normMul (consistent with dividing coefficients by
          // normMul below), b' = b / normMul and the intercept absorbs +sum(b * normSub).
          double norm = 0;
          for (int i = numoff; i < _srcDinfo.fullN(); ++i) {
            // _normSub may be null (e.g. when only descaling, which happens without an intercept)
            if (_srcDinfo._normSub != null)
              norm += _bgs[i] * _srcDinfo._normSub[i - numoff];
            _bgs[i] /= _srcDinfo._normMul[i - numoff];
          }
          if (_intercept == 1)
            _bgs[_bgs.length - 1] += norm;
        }
      }
      if ((v = beta_constraints.vec("rho")) != null)
        _rho = map == null ? Utils.asDoubles(v) : mapVec(Utils.asDoubles(v), makeAry(names.length, 0), map);
      else if (_bgs != null)
        throw new IllegalArgumentException("Missing vector of penalties (rho) in beta_constraints file.");
      String[] cols = new String[] { "names", "rho", "beta_given", "lower_bounds", "upper_bounds" };
      Arrays.sort(cols);
      for (String str : beta_constraints.names())
        if (Arrays.binarySearch(cols, str) < 0)
          Log.warn("unknown column in beta_constraints file: '" + str + "'");
    }
    if (non_negative) {
      // make sure lb is >= 0
      if (_lbs == null)
        _lbs = new double[_srcDinfo.fullN() + 1];
      // no bounds for intercept
      _lbs[_srcDinfo.fullN()] = Double.NEGATIVE_INFINITY;
      for (int i = 0; i < _lbs.length; ++i)
        if (_lbs[i] < 0)
          _lbs[i] = 0;
    }
  } catch (RuntimeException e) {
    e.printStackTrace();
    cleanup();
    throw e;
  }
}
Aggregations