Search in sources:

Example 6 with ModelMetricsMultinomial

Use of hex.ModelMetricsMultinomial in project h2o-3 by h2oai.

From the class DeepWaterAbstractIntegrationTest, method MNISTSparse.

/**
 * Trains a DeepWater model with {@code _sparse = true} and two hidden layers of 500 units on
 * MNIST, using the MNIST test file as the validation frame. Asserts that the validation mean
 * per-class error is below 5%.
 *
 * <p>The test body is skipped (it passes without training) when either the MNIST training file or
 * the MNIST test file cannot be located.
 */
@Test
public void MNISTSparse() {
    Frame tr = null;
    Frame va = null;
    DeepWaterModel m = null;
    try {
        DeepWaterParameters p = new DeepWaterParameters();
        File file = FileUtils.locateFile("bigdata/laptop/mnist/train.csv.gz");
        File valid = FileUtils.locateFile("bigdata/laptop/mnist/test.csv.gz");
        // Both files are required. Checking only the training file would hand a null File to
        // NFSFileVec.make(valid) whenever just the test file is missing.
        if (file != null && valid != null) {
            p._response_column = "C785";
            NFSFileVec trainfv = NFSFileVec.make(file);
            tr = ParseDataset.parse(Key.make(), trainfv._key);
            NFSFileVec validfv = NFSFileVec.make(valid);
            va = ParseDataset.parse(Key.make(), validfv._key);
            // Replace the response column in both frames with its categorical version (the
            // validation metrics are read as ModelMetricsMultinomial below); the replaced vecs
            // are deleted.
            for (String col : new String[] { p._response_column }) {
                Vec v = tr.remove(col);
                tr.add(col, v.toCategoricalVec());
                v.remove();
                v = va.remove(col);
                va.add(col, v.toCategoricalVec());
                v.remove();
            }
            // Re-publish the modified frames in the DKV.
            DKV.put(tr);
            DKV.put(va);
            p._backend = getBackend();
            p._train = tr._key;
            p._valid = va._key;
            p._hidden = new int[] { 500, 500 };
            p._sparse = true;
            DeepWater j = new DeepWater(p);
            m = j.trainModel().get();
            double meanPerClassError =
                ((ModelMetricsMultinomial) (m._output._validation_metrics)).mean_per_class_error();
            Assert.assertTrue("validation mean per-class error too high: " + meanPerClassError,
                meanPerClassError < 0.05);
        }
    } finally {
        if (tr != null)
            tr.remove();
        if (va != null)
            va.remove();
        if (m != null)
            m.remove();
    }
}
Also used: Frame (water.fvec.Frame), ShuffleSplitFrame (hex.splitframe.ShuffleSplitFrame), NFSFileVec (water.fvec.NFSFileVec), Vec (water.fvec.Vec), ModelMetricsMultinomial (hex.ModelMetricsMultinomial)

Example 7 with ModelMetricsMultinomial

Use of hex.ModelMetricsMultinomial in project h2o-3 by h2oai.

From the class DeepWaterAbstractIntegrationTest, method MOJOTestImage.

/**
 * Runs a very short DeepWater training (1e-3 epochs, mini-batch 4) on the cat/dog/mouse image
 * dataset with the given network. It then checks two things:
 * <ul>
 *   <li>{@code testJavaScoring} agrees with the in-cluster predictions on the training frame
 *       (tolerance 1e-3);</li>
 *   <li>the logloss recomputed from the prediction frame matches the model's reported training
 *       logloss to within 1e-3.</li>
 * </ul>
 *
 * @param network the DeepWater network architecture to train
 */
private void MOJOTestImage(DeepWaterParameters.Network network) {
    Frame train = null;
    DeepWaterModel model = null;
    Frame predictions = null;
    try {
        DeepWaterParameters params = new DeepWaterParameters();
        params._backend = getBackend();
        train = parse_test_file("bigdata/laptop/deepwater/imagenet/cat_dog_mouse.csv");
        params._train = train._key;
        params._response_column = "C2";
        params._learning_rate = 1e-4;
        params._network = network;
        params._mini_batch_size = 4;
        params._train_samples_per_iteration = 8;
        params._epochs = 1e-3;
        model = new DeepWater(params).trainModel().get();

        // Score the original training frame and compare with Java scoring.
        predictions = model.score(train);
        Assert.assertTrue(model.testJavaScoring(train, predictions, 1e-3));

        // Drop and delete the first prediction column (presumably the predicted label) so the
        // remaining columns can be passed to ModelMetricsMultinomial.make.
        predictions.remove(0).remove();
        double recomputedLogloss =
            ModelMetricsMultinomial.make(predictions, train.vec(params._response_column)).logloss();
        double reportedLogloss =
            ((ModelMetricsMultinomial) model._output._training_metrics).logloss();
        Assert.assertTrue(Math.abs(recomputedLogloss - reportedLogloss) < 1e-3);
    } finally {
        if (train != null) {
            train.remove();
        }
        if (model != null) {
            model.remove();
        }
        if (predictions != null) {
            predictions.remove();
        }
    }
}
Also used: Frame (water.fvec.Frame), ShuffleSplitFrame (hex.splitframe.ShuffleSplitFrame), ModelMetricsMultinomial (hex.ModelMetricsMultinomial)

Example 8 with ModelMetricsMultinomial

Use of hex.ModelMetricsMultinomial in project h2o-3 by h2oai.

From the class DeepWaterAbstractIntegrationTest, method overWriteWithBestModel.

/**
 * Trains a LeNet image model for 50 epochs on the cat/dog/mouse dataset with
 * {@code _overwrite_with_best_model} enabled. Scoring is requested as often as possible. Asserts
 * that the final training logloss is below 2.
 */
@Test
public void overWriteWithBestModel() {
    DeepWaterModel model = null;
    Frame train = null;
    try {
        DeepWaterParameters params = new DeepWaterParameters();
        params._backend = getBackend();
        train = parse_test_file("bigdata/laptop/deepwater/imagenet/cat_dog_mouse.csv");
        params._train = train._key;
        params._response_column = "C2";

        // Optimiser settings.
        params._epochs = 50;
        params._learning_rate = 0.01;
        params._momentum_start = 0.5;
        params._momentum_stable = 0.5;
        // 0 stopping rounds: no early stopping.
        params._stopping_rounds = 0;

        // Treat the input as 28x28 images fed to LeNet.
        params._image_shape = new int[] { 28, 28 };
        params._network = lenet;
        params._problem_type = DeepWaterParameters.ProblemType.image;

        // Score as often as possible: one (default-sized) mini-batch per iteration, full scoring
        // duty cycle, and no minimum interval between scoring events.
        params._train_samples_per_iteration = params._mini_batch_size;
        params._score_duty_cycle = 1;
        params._score_interval = 0;
        params._overwrite_with_best_model = true;

        model = new DeepWater(params).trainModel().get();
        Log.info(model);
        double trainingLogloss = ((ModelMetricsMultinomial) model._output._training_metrics).logloss();
        Assert.assertTrue(trainingLogloss < 2);
    } finally {
        if (model != null) {
            model.remove();
        }
        if (train != null) {
            train.remove();
        }
    }
}
Also used: Frame (water.fvec.Frame), ShuffleSplitFrame (hex.splitframe.ShuffleSplitFrame), ModelMetricsMultinomial (hex.ModelMetricsMultinomial)

Aggregations

ModelMetricsMultinomial (hex.ModelMetricsMultinomial): 8; ShuffleSplitFrame (hex.splitframe.ShuffleSplitFrame): 8; Frame (water.fvec.Frame): 8; NFSFileVec (water.fvec.NFSFileVec): 3; Vec (water.fvec.Vec): 3; FrameSplitter (hex.FrameSplitter): 2