Search in sources:

Example 1 with GlrmRegularizer

Use of hex.genmodel.algos.glrm.GlrmRegularizer in project h2o-3 by h2oai.

In class GLRMCategoricalTest, method testLosses.

/**
 * Smoke-tests GLRM on the prostate dataset across every combination of numeric loss
 * (Quadratic, Absolute, Huber, Poisson) and categorical loss (Categorical, Ordinal).
 *
 * <p>For each combination a model is trained with X/Y regularizers and regularization
 * weights drawn from a {@link Random} seeded with a fixed constant, so the sequence of
 * configurations is reproducible. Each model is scored on the training frame and its
 * numeric/categorical error metrics are logged. There are no explicit assertions: the test
 * fails only if training, scoring, or the metrics lookup throws.
 *
 * @throws InterruptedException if waiting on model training is interrupted
 * @throws ExecutionException if model training fails
 */
@Test
public void testLosses() throws InterruptedException, ExecutionException {
    final long seed = 0xDECAF;
    final Random rng = new Random(seed);
    // Columns converted to categoricals: CAPSULE, RACE, DPROS, DCAPS
    final int[] cats = new int[] { 1, 3, 4, 5 };
    final GlrmRegularizer[] regs = new GlrmRegularizer[] { GlrmRegularizer.Quadratic, GlrmRegularizer.L1, GlrmRegularizer.NonNegative, GlrmRegularizer.OneSparse, GlrmRegularizer.UnitOneSparse, GlrmRegularizer.Simplex };
    final GlrmLoss[] numericLosses = new GlrmLoss[] { GlrmLoss.Quadratic, GlrmLoss.Absolute, GlrmLoss.Huber, GlrmLoss.Poisson };
    final GlrmLoss[] categoricalLosses = new GlrmLoss[] { GlrmLoss.Categorical, GlrmLoss.Ordinal };
    Frame train = null;
    Scope.enter();
    try {
        train = parse_test_file(Key.make("prostate.hex"), "smalldata/logreg/prostate.csv");
        for (int col : cats) {
            // Track the Vec returned by replace() (presumably the displaced numeric column)
            // in the enclosing Scope so it is cleaned up on Scope.exit().
            Scope.track(train.replace(col, train.vec(col).toCategoricalVec()));
        }
        // Drop the ID column from the frame and remove the detached Vec.
        train.remove("ID").remove();
        DKV.put(train._key, train);
        for (GlrmLoss loss : numericLosses) {
            for (GlrmLoss multiloss : categoricalLosses) {
                GLRMModel model = null;
                // Enter the per-model scope before the try so the finally's Scope.exit()
                // can only ever pop a scope that was actually entered.
                Scope.enter();
                try {
                    long myseed = rng.nextLong();
                    Log.info("GLRM using seed = " + myseed);
                    GLRMParameters parms = new GLRMParameters();
                    parms._train = train._key;
                    parms._transform = DataInfo.TransformType.NONE;
                    parms._k = 5;
                    parms._loss = loss;
                    parms._multi_loss = multiloss;
                    parms._init = GlrmInitialization.SVD;
                    // RNG draw order (regularizer X, regularizer Y, gamma X, gamma Y) is kept
                    // fixed so the configurations produced from the seed are unchanged.
                    parms._regularization_x = regs[rng.nextInt(regs.length)];
                    parms._regularization_y = regs[rng.nextInt(regs.length)];
                    // Random.nextDouble() is already in [0, 1); no Math.abs needed.
                    parms._gamma_x = rng.nextDouble();
                    parms._gamma_y = rng.nextDouble();
                    parms._recover_svd = false;
                    parms._seed = myseed;
                    parms._verbose = false;
                    parms._max_iterations = 500;
                    model = new GLRM(parms).trainModel().get();
                    Log.info("Iteration " + model._output._iterations + ": Objective value = " + model._output._objective);
                    // Scoring registers the metrics in the DKV; the prediction frame itself is not needed.
                    model.score(train).delete();
                    ModelMetricsGLRM mm = (ModelMetricsGLRM) ModelMetrics.getFromDKV(model, train);
                    Log.info("Numeric Sum of Squared Error = " + mm._numerr + "\tCategorical Misclassification Error = " + mm._caterr);
                } finally {
                    if (model != null) {
                        model.delete();
                    }
                    Scope.exit();
                }
            }
        }
    } finally {
        if (train != null) {
            train.delete();
        }
        Scope.exit();
    }
}
Also used: Frame (water.fvec.Frame), GlrmLoss (hex.genmodel.algos.glrm.GlrmLoss), Random (java.util.Random), GLRMParameters (hex.glrm.GLRMModel.GLRMParameters), GlrmRegularizer (hex.genmodel.algos.glrm.GlrmRegularizer), Test (org.junit.Test)

Aggregations

GlrmLoss (hex.genmodel.algos.glrm.GlrmLoss): 1, GlrmRegularizer (hex.genmodel.algos.glrm.GlrmRegularizer): 1, GLRMParameters (hex.glrm.GLRMModel.GLRMParameters): 1, Random (java.util.Random): 1, Test (org.junit.Test): 1, Frame (water.fvec.Frame): 1