Example use of hex.genmodel.algos.glrm.GlrmRegularizer in the h2o-3 project by h2oai.
It appears in the testLosses method of class GLRMCategoricalTest.
/**
 * Trains a GLRM model on the prostate data set for every combination of numeric loss
 * ({@code Quadratic}, {@code Absolute}, {@code Huber}, {@code Poisson}) and categorical
 * multi-loss ({@code Categorical}, {@code Ordinal}).
 *
 * <p>For each combination, the X/Y regularizers and their weights are drawn from a
 * {@link Random} seeded with a fixed constant, so the sequence of configurations is
 * reproducible. The test only checks that training and scoring complete without error. The
 * objective value and the GLRM error metrics are logged, not asserted.
 *
 * <p>NOTE(review): the checked exceptions on the signature are kept as declared; presumably
 * they come from waiting on the training job — confirm.
 */
@Test
public void testLosses() throws InterruptedException, ExecutionException {
  final long seed = 0xDECAF;
  final Random rng = new Random(seed);
  // Column indices (taken BEFORE the "ID" column is dropped) of the categoricals:
  // CAPSULE, RACE, DPROS, DCAPS
  final int[] cats = new int[] {1, 3, 4, 5};
  final GlrmRegularizer[] regs = new GlrmRegularizer[] {
      GlrmRegularizer.Quadratic, GlrmRegularizer.L1, GlrmRegularizer.NonNegative,
      GlrmRegularizer.OneSparse, GlrmRegularizer.UnitOneSparse, GlrmRegularizer.Simplex
  };
  final GlrmLoss[] losses = new GlrmLoss[] {
      GlrmLoss.Quadratic, GlrmLoss.Absolute, GlrmLoss.Huber, GlrmLoss.Poisson
  };
  final GlrmLoss[] multiLosses = new GlrmLoss[] {GlrmLoss.Categorical, GlrmLoss.Ordinal};

  Frame train = null;
  Scope.enter();
  try {
    train = parse_test_file(Key.make("prostate.hex"), "smalldata/logreg/prostate.csv");
    // Convert each categorical column in place; the replaced (numeric) vecs are tracked so
    // that the outer Scope.exit() cleans them up.
    for (int col : cats) {
      Scope.track(train.replace(col, train.vec(col).toCategoricalVec()));
    }
    // Drop the identifier column from the frame and remove its vec.
    train.remove("ID").remove();
    DKV.put(train._key, train);

    for (GlrmLoss loss : losses) {
      for (GlrmLoss multiloss : multiLosses) {
        GLRMModel model = null;
        // Enter the per-model scope before the try, so the finally below only ever exits a
        // scope that was actually entered (and never pops the outer one by mistake).
        Scope.enter();
        try {
          // The order of rng calls below (seed, reg X, reg Y, gamma X, gamma Y) determines
          // the reproducible sequence of configurations; keep it stable.
          long myseed = rng.nextLong();
          Log.info("GLRM using seed = " + myseed);

          GLRMParameters parms = new GLRMParameters();
          parms._train = train._key;
          parms._transform = DataInfo.TransformType.NONE;
          parms._k = 5;
          parms._loss = loss;
          parms._multi_loss = multiloss;
          parms._init = GlrmInitialization.SVD;
          parms._regularization_x = regs[rng.nextInt(regs.length)];
          parms._regularization_y = regs[rng.nextInt(regs.length)];
          // Random.nextDouble() is already in [0.0, 1.0), so no Math.abs is needed.
          parms._gamma_x = rng.nextDouble();
          parms._gamma_y = rng.nextDouble();
          parms._recover_svd = false;
          parms._seed = myseed;
          parms._verbose = false;
          parms._max_iterations = 500;

          model = new GLRM(parms).trainModel().get();
          Log.info("Iteration " + model._output._iterations + ": Objective value = " + model._output._objective);

          // Scoring registers the GLRM metrics; the prediction frame itself is not needed.
          model.score(train).delete();
          ModelMetricsGLRM mm = (ModelMetricsGLRM) ModelMetrics.getFromDKV(model, train);
          Log.info("Numeric Sum of Squared Error = " + mm._numerr + "\tCategorical Misclassification Error = " + mm._caterr);
        } finally {
          if (model != null) {
            model.delete();
          }
          Scope.exit();
        }
      }
    }
  } finally {
    if (train != null) {
      train.delete();
    }
    Scope.exit();
  }
}
Aggregations