Search in sources :

Example 1 with hex

Use of hex in project h2o-3 by h2oai.

From the class DeepLearningTest, method basicDL.

/**
 * Trains a DeepLearning model on a parsed training file, scores it, and asserts that the
 * resulting metrics match the expected values. As a last step the training file is parsed a
 * second time, scored again, and the fresh predictions are passed to {@code testJavaScoring}.
 *
 * <p>All frames, predictions and the model created here are released in the {@code finally}
 * block, which also closes the {@code Scope} opened at the start.
 *
 * @param fnametrain     name of the training file handed to {@code parse_test_file}; its hash
 *                       code also seeds the RNG that picks {@code _export_weights_and_biases}
 * @param hexnametrain   suffix appended to {@code "DL_model_"} to build the model key
 * @param fnametest      name of a separate test file, or {@code null} to take the checked
 *                       metrics from the training frame instead
 * @param prep           forwarded to {@code unifyFrame} together with the training frame
 * @param epochs         value assigned to the {@code _epochs} parameter
 * @param expCM          expected confusion matrix (compared only when {@code classification})
 * @param expRespDom     expected response domain, compared with the last entry of the model's
 *                       output domains (only when {@code classification})
 * @param expMSE         expected MSE (compared only when not {@code classification})
 * @param hidden         value assigned to the {@code _hidden} parameter
 * @param l1             value assigned to the {@code _l1} parameter
 * @param classification selects the confusion-matrix checks over the MSE check; also forwarded
 *                       to {@code unifyFrame}
 * @param activation     value assigned to the {@code _activation} parameter
 * @throws Throwable if anything in training, scoring or the assertions fails
 */
public void basicDL(String fnametrain, String hexnametrain, String fnametest, PrepData prep, int epochs, double[][] expCM, String[] expRespDom, double expMSE, int[] hidden, double l1, boolean classification, DeepLearningParameters.Activation activation) throws Throwable {
    Scope.enter();
    DeepLearningParameters params = new DeepLearningParameters();
    Frame trainFrame = null;
    Frame holdoutFrame = null;
    Frame holdoutPreds = null;
    Frame rescoredInput = null;
    Frame rescoredPreds = null;
    DeepLearningModel model = null;
    try {
        trainFrame = parse_test_file(fnametrain);
        Vec detached = unifyFrame(params, trainFrame, prep, classification);
        if (detached != null) {
            Scope.track(detached);
        }
        DKV.put(trainFrame._key, trainFrame);

        // Data: the response column is the last column of the frame as stored in the DKV.
        params._train = trainFrame._key;
        Frame published = (Frame) DKV.getGet(params._train);
        params._response_column = published.lastVecName();

        // Network settings supplied by the caller.
        params._hidden = hidden;
        params._activation = activation;
        params._l1 = l1;
        params._epochs = epochs;

        // Fixed settings shared by every invocation of this helper.
        params._seed = (1L << 32) | 2;
        params._reproducible = true;
        params._stopping_rounds = 0;
        params._elastic_averaging = false;

        // The export flag is derived from the training file name, so it is stable per file.
        params._export_weights_and_biases = RandomUtils.getRNG(fnametrain.hashCode()).nextBoolean();

        // Train and block until the model is available.
        String modelName = "DL_model_" + hexnametrain;
        DeepLearning job = new DeepLearning(params, Key.<DeepLearningModel>make(modelName));
        model = job.trainModel().get();
        Log.info(model._output);
        //HEX-1817
        assertTrue(job.isStopped());

        // The checked metrics come from the test file when one is given, else from the training frame.
        final Frame metricsFrame;
        if (fnametest == null) {
            metricsFrame = trainFrame;
        } else {
            holdoutFrame = parse_test_file(fnametest);
            metricsFrame = holdoutFrame;
        }
        holdoutPreds = model.score(metricsFrame);
        hex.ModelMetrics metrics = hex.ModelMetrics.getFromDKV(model, metricsFrame);

        // Parse the training file a second time and score that fresh copy.
        rescoredInput = parse_test_file(fnametrain);
        rescoredPreds = model.score(rescoredInput);

        if (!classification) {
            String mseMismatch = "Expected: " + expMSE + ", Got: " + metrics.mse();
            boolean mseMatches = MathUtils.compare(expMSE, metrics.mse(), 1e-8, 1e-8);
            assertTrue(mseMismatch, mseMatches);
            Log.info("\nOOB Training MSE: " + metrics.mse());
            Log.info("\nTraining MSE: " + hex.ModelMetrics.getFromDKV(model, rescoredInput).mse());
        } else {
            String cmMismatch = "Expected: " + Arrays.deepToString(expCM) + ", Got: " + Arrays.deepToString(metrics.cm()._cm);
            boolean cmMatches = Arrays.deepEquals(metrics.cm()._cm, expCM);
            assertTrue(cmMismatch, cmMatches);
            String[][] outputDomains = model._output._domains;
            String[] responseDomain = outputDomains[outputDomains.length - 1];
            Assert.assertArrayEquals("CM domain differs!", expRespDom, responseDomain);
            Log.info("\nTraining CM:\n" + metrics.cm().toASCII());
            Log.info("\nTraining CM:\n" + hex.ModelMetrics.getFromDKV(model, rescoredInput).cm().toASCII());
        }

        // Metrics lookup for the re-parsed frame; the returned value is not used here.
        hex.ModelMetrics.getFromDKV(model, rescoredInput);

        // Build a POJO, validate same results
        assertTrue(model.testJavaScoring(rescoredInput, rescoredPreds, 1e-5));
    } finally {
        // Release everything that was created, in a fixed order, then leave the scope.
        if (trainFrame != null) {
            trainFrame.remove();
        }
        if (holdoutFrame != null) {
            holdoutFrame.remove();
        }
        if (model != null) {
            model.delete();
        }
        if (holdoutPreds != null) {
            holdoutPreds.delete();
        }
        if (rescoredInput != null) {
            rescoredInput.delete();
        }
        if (rescoredPreds != null) {
            rescoredPreds.delete();
        }
        Scope.exit();
    }
}
Also used : Frame(water.fvec.Frame) NFSFileVec(water.fvec.NFSFileVec) Vec(water.fvec.Vec) DeepLearningParameters(hex.deeplearning.DeepLearningModel.DeepLearningParameters) hex(hex)

Aggregations

hex (hex): 1, DeepLearningParameters (hex.deeplearning.DeepLearningModel.DeepLearningParameters): 1, Frame (water.fvec.Frame): 1, NFSFileVec (water.fvec.NFSFileVec): 1, Vec (water.fvec.Vec): 1