Search in sources :

Example 11 with DataInfo

use of hex.DataInfo in project h2o-3 by h2oai.

the class GLRMCategoricalTest method testExpandCatsIris.

/**
 * Checks that {@code GLRM.expandCats} expands the categorical "class" column
 * of a five-row iris sample into indicator columns.
 *
 * The sample rows are first permuted with {@code dinfo._permutation}, then
 * expanded. The expected matrix holds three indicator columns (one per class
 * level) followed by the four numeric measurements.
 */
@Test
public void testExpandCatsIris() throws InterruptedException, ExecutionException {
    final String[] colNames = { "sepal_len", "sepal_wid", "petal_len", "petal_wid", "class" };
    final String[][] colDomains = { null, null, null, null, { "setosa", "versicolor", "virginica" } };

    // Raw sample rows: four measurements plus the class level index.
    final double[][] sample = ard(
        ard(6.3, 2.5, 4.9, 1.5, 1),
        ard(5.7, 2.8, 4.5, 1.3, 1),
        ard(5.6, 2.8, 4.9, 2.0, 2),
        ard(5.0, 3.4, 1.6, 0.4, 0),
        ard(6.0, 2.2, 5.0, 1.5, 2));

    // Reference result: class indicators first, measurements after.
    final double[][] expected = ard(
        ard(0, 1, 0, 6.3, 2.5, 4.9, 1.5),
        ard(0, 1, 0, 5.7, 2.8, 4.5, 1.3),
        ard(0, 0, 1, 5.6, 2.8, 4.9, 2.0),
        ard(1, 0, 0, 5.0, 3.4, 1.6, 0.4),
        ard(0, 0, 1, 6.0, 2.2, 5.0, 1.5));

    Frame irisFrame = null;
    try {
        irisFrame = parse_test_file(Key.make("iris.hex"), "smalldata/iris/iris_wheader.csv");
        DataInfo dinfo = new DataInfo(irisFrame, null, 0, true,
            DataInfo.TransformType.NONE, DataInfo.TransformType.NONE,
            false, false, false,
            false, /* weights */
            false, /* offset */
            false /* fold */);

        String originalMsg = "Original matrix:\n" + colFormat(colNames, "%8.7s") + ArrayUtils.pprint(sample);
        Log.info(originalMsg);

        double[][] permuted = ArrayUtils.permuteCols(sample, dinfo._permutation);
        String permutedMsg = "Permuted matrix:\n" + colFormat(colNames, "%8.7s", dinfo._permutation) + ArrayUtils.pprint(permuted);
        Log.info(permutedMsg);

        double[][] expanded = GLRM.expandCats(permuted, dinfo);
        String expandedMsg = "Expanded matrix:\n" + colExpFormat(colNames, colDomains, "%8.7s", dinfo._permutation) + ArrayUtils.pprint(expanded);
        Log.info(expandedMsg);

        Assert.assertArrayEquals(expected, expanded);
    } finally {
        // Always release the parsed frame, even when an assertion fails.
        if (irisFrame != null) {
            irisFrame.delete();
        }
    }
}
Also used : DataInfo(hex.DataInfo) Frame(water.fvec.Frame) Test(org.junit.Test)

Example 12 with DataInfo

use of hex.DataInfo in project h2o-3 by h2oai.

the class L_BFGS_Test method logistic.

/**
 * Runs L-BFGS on a binomial GLM objective built from the prostate data set
 * and checks that {@code 2 * objective * numRows} equals 378.34 within 0.1.
 *
 * The "CAPSULE" column is removed and re-added to the frame and the "ID"
 * column is dropped before the {@code DataInfo} is built. The DataInfo is
 * published to the DKV for the duration of the test and removed, together
 * with the parsed frame, in the {@code finally} block.
 */
@Test
public void logistic() {
    Key frameKey = Key.make("prostate");
    DataInfo dinfo = null;
    try {
        GLMParameters params = new GLMParameters(Family.binomial, Family.binomial.defaultLink);
        params._alpha = new double[] { 0 };
        params._lambda = new double[] { 1e-5 };
        params._obj_reg = 1 / 380.0;

        Frame train = parse_test_file(frameKey, "smalldata/glm_test/prostate_cat_replaced.csv");
        train.add("CAPSULE", train.remove("CAPSULE"));
        train.remove("ID").remove();
        Frame valid = new Frame(train._names.clone(), train.vecs().clone());

        dinfo = new DataInfo(train, valid, 1, false,
            DataInfo.TransformType.STANDARDIZE, DataInfo.TransformType.NONE,
            true, false, false,
            false, /* weights */
            false, /* offset */
            false /* fold */);
        DKV.put(dinfo._key, dinfo);

        GLMGradientSolver solver = new GLMGradientSolver(null, params, dinfo, 1e-5, null);
        L_BFGS optimizer = new L_BFGS().setGradEps(1e-8);

        // Start from zeros, except the last slot, which is seeded with the
        // link of the CAPSULE mean (presumably the intercept; confirm).
        double[] start = MemoryManager.malloc8d(dinfo.fullN() + 1);
        start[start.length - 1] = new GLMWeightsFun(params).link(train.vec("CAPSULE").mean());

        // Prints the running objective and squared gradient norm per callback.
        L_BFGS.ProgressMonitor printer = new L_BFGS.ProgressMonitor() {

            int _step = 0;

            public boolean progress(double[] beta, GradientInfo ginfo) {
                System.out.println(++_step + ":" + ginfo._objVal + ", " + ArrayUtils.l2norm2(ginfo._gradient, false));
                return true;
            }
        };

        L_BFGS.Result result = optimizer.solve(solver, start, solver.getGradient(start), printer);
        assertEquals(378.34, 2 * result.ginfo._objVal * train.numRows(), 1e-1);
    } finally {
        if (dinfo != null) {
            DKV.remove(dinfo._key);
        }
        // The frame may not exist if parsing failed, so look it up by key.
        Value stored = DKV.get(frameKey);
        if (stored != null) {
            Frame parsed = stored.<Frame>get();
            parsed.delete();
        }
    }
}
Also used : DataInfo(hex.DataInfo) Frame(water.fvec.Frame) GLMGradientSolver(hex.glm.GLM.GLMGradientSolver) GradientInfo(hex.optimization.OptimizationUtils.GradientInfo) GLMWeightsFun(hex.glm.GLMModel.GLMWeightsFun) GLMParameters(hex.glm.GLMModel.GLMParameters) Test(org.junit.Test)

Example 13 with DataInfo

use of hex.DataInfo in project h2o-3 by h2oai.

the class L_BFGS_Test method testArcene.

// Test LSM on arcene - wide dataset with ~10k columns
// test warm start and max #iterations
/**
 * Exercises L-BFGS on a gaussian GLM objective over the arcene data set.
 *
 * Three solves are performed:
 * <ol>
 *   <li>r1: from the initial coefficients, capped at 20 iterations;</li>
 *   <li>r2: warm-started from r1's coefficients and gradient info, capped at 50 iterations;</li>
 *   <li>r3: a fresh solver from the initial coefficients, capped at 100 iterations.</li>
 * </ol>
 * The test asserts that r1 used exactly 20 iterations, that the warm-started
 * run (r2) and the single long run (r3) reach the same objective value, that
 * the penalised objective is close to 1e-4, and that r3 stopped before its
 * iteration cap.
 */
@Test
public void testArcene() {
    Key parsedKey = Key.make("arcene_parsed");
    DataInfo dinfo = null;
    try {
        Frame source = parse_test_file(parsedKey, "smalldata/glm_test/arcene.csv");
        Frame valid = new Frame(source._names.clone(), source.vecs().clone());
        GLMParameters glmp = new GLMParameters(Family.gaussian);
        glmp._lambda = new double[] { 1e-5 };
        glmp._alpha = new double[] { 0 };
        glmp._obj_reg = 0.01;
        dinfo = new DataInfo(source, valid, 1, false,
            DataInfo.TransformType.STANDARDIZE, DataInfo.TransformType.NONE,
            true, false, false,
            false, /* weights */
            false, /* offset */
            false /* fold */);
        DKV.put(dinfo._key, dinfo);
        GradientSolver solver = new GLMGradientSolver(null, glmp, dinfo, 1e-5, null);
        L_BFGS lbfgs = new L_BFGS().setMaxIter(20);
        // All-zero start except the last slot, seeded with the link of the
        // mean of the frame's last column.
        double[] beta = MemoryManager.malloc8d(dinfo.fullN() + 1);
        beta[beta.length - 1] = new GLMWeightsFun(glmp).link(source.lastVec().mean());

        // First run: capped at 20 iterations. A clone of beta is passed so the
        // later fresh run (r3) starts from the same untouched coefficients.
        L_BFGS.Result r1 = lbfgs.solve(solver, beta.clone(), solver.getGradient(beta), new L_BFGS.ProgressMonitor() {

            int _i = 0;

            public boolean progress(double[] beta, GradientInfo ginfo) {
                System.out.println(++_i + ":" + ginfo._objVal);
                return true;
            }
        });

        // Second run: warm start from the first run's result.
        lbfgs.setMaxIter(50);
        final int iter = r1.iter;
        L_BFGS.Result r2 = lbfgs.solve(solver, r1.coefs, r1.ginfo, new L_BFGS.ProgressMonitor() {

            int _i = 0;

            public boolean progress(double[] beta, GradientInfo ginfo) {
                System.out.println(iter + " + " + ++_i + ":" + ginfo._objVal);
                return true;
            }
        });
        System.out.println();

        // Third run: a fresh solver from the original starting point.
        lbfgs = new L_BFGS().setMaxIter(100);
        L_BFGS.Result r3 = lbfgs.solve(solver, beta.clone(), solver.getGradient(beta), new L_BFGS.ProgressMonitor() {

            int _i = 0;

            public boolean progress(double[] beta, GradientInfo ginfo) {
                System.out.println(++_i + ":" + ginfo._objVal + ", " + ArrayUtils.l2norm2(ginfo._gradient, false));
                return true;
            }
        });

        // JUnit's assertEquals takes (expected, actual): keep that order so
        // failure messages report the right value as "expected".
        assertEquals(20, r1.iter);
        // NOTE: r1.iter + r2.iter is intentionally not asserted equal to
        // r3.iter; the original author observed a mismatch by 2.
        assertEquals(r2.ginfo._objVal, r3.ginfo._objVal, 1e-8);
        assertEquals(1e-4, .5 * glmp._lambda[0] * ArrayUtils.l2norm(r3.coefs, true) + r3.ginfo._objVal, 5e-4);
        assertTrue("iter# expected < 100, got " + r3.iter, r3.iter < 100);
    } finally {
        if (dinfo != null) {
            DKV.remove(dinfo._key);
        }
        // Look the frame up by key so cleanup also works if setup failed midway.
        Value v = DKV.get(parsedKey);
        if (v != null) {
            v.<Frame>get().delete();
        }
    }
}
Also used : DataInfo(hex.DataInfo) Frame(water.fvec.Frame) GLMGradientSolver(hex.glm.GLM.GLMGradientSolver) GradientSolver(hex.optimization.OptimizationUtils.GradientSolver) GradientInfo(hex.optimization.OptimizationUtils.GradientInfo) GLMWeightsFun(hex.glm.GLMModel.GLMWeightsFun) GLMParameters(hex.glm.GLMModel.GLMParameters) Test(org.junit.Test)

Aggregations

DataInfo (hex.DataInfo)13 Frame (water.fvec.Frame)6 Test (org.junit.Test)5 BetaConstraint (hex.glm.GLM.BetaConstraint)3 DeepLearningParameters (hex.deeplearning.DeepLearningModel.DeepLearningParameters)2 GLMGradientSolver (hex.glm.GLM.GLMGradientSolver)2 GLMParameters (hex.glm.GLMModel.GLMParameters)2 GLMWeightsFun (hex.glm.GLMModel.GLMWeightsFun)2 Gram (hex.gram.Gram)2 GradientInfo (hex.optimization.OptimizationUtils.GradientInfo)2 Vec (water.fvec.Vec)2 ValFrame (water.rapids.vals.ValFrame)2 FrameTask (hex.FrameTask)1 ModelMetricsRegression (hex.ModelMetricsRegression)1 ToEigenVec (hex.ToEigenVec)1 DistributionFamily (hex.genmodel.utils.DistributionFamily)1 GLMGradientInfo (hex.glm.GLM.GLMGradientInfo)1 GLMModel (hex.glm.GLMModel)1 GLMOutput (hex.glm.GLMModel.GLMOutput)1 GradientSolver (hex.optimization.OptimizationUtils.GradientSolver)1