Use of hex in project h2o-3 by h2oai.
The class DeepLearningTest, method basicDL.
/**
 * Trains a Deep Learning model on the frame parsed from {@code fnametrain}, scores it, and checks
 * the resulting metrics against the expected values. Finally validates the model through
 * {@code testJavaScoring} on a freshly parsed copy of the training file.
 *
 * @param fnametrain     training file, handed to {@code parse_test_file}
 * @param hexnametrain   suffix appended to {@code "DL_model_"} to form the model key
 * @param fnametest      optional test file; when {@code null} the checked metrics come from
 *                       scoring the training frame instead
 * @param prep           frame preparation step, forwarded to {@code unifyFrame}
 * @param epochs         value assigned to {@code _epochs}
 * @param expCM          expected confusion matrix (only checked when {@code classification})
 * @param expRespDom     expected response domain (only checked when {@code classification})
 * @param expMSE         expected MSE (only checked when not {@code classification}); compared
 *                       via {@code MathUtils.compare} with tolerances of 1e-8
 * @param hidden         value assigned to {@code _hidden}
 * @param l1             value assigned to {@code _l1}
 * @param classification selects the confusion-matrix checks over the MSE check
 * @param activation     value assigned to {@code _activation}
 * @throws Throwable any failure raised while parsing, training, scoring or asserting
 */
public void basicDL(String fnametrain, String hexnametrain, String fnametest, PrepData prep, int epochs, double[][] expCM, String[] expRespDom, double expMSE, int[] hidden, double l1, boolean classification, DeepLearningParameters.Activation activation) throws Throwable {
  Scope.enter();
  DeepLearningParameters dl = new DeepLearningParameters();
  Frame frTest = null, pred = null;
  Frame frTrain = null;
  Frame test = null, res = null;
  DeepLearningModel model = null;
  try {
    frTrain = parse_test_file(fnametrain);
    Vec removeme = unifyFrame(dl, frTrain, prep, classification);
    if (removeme != null) {
      Scope.track(removeme);
    }
    DKV.put(frTrain._key, frTrain);

    // Configure DL
    dl._train = frTrain._key;
    dl._response_column = ((Frame) DKV.getGet(dl._train)).lastVecName();
    dl._seed = (1L << 32) | 2;
    dl._reproducible = true;
    dl._epochs = epochs;
    dl._stopping_rounds = 0;
    dl._activation = activation;
    // Toggled from an RNG seeded with the training file name's hash code.
    dl._export_weights_and_biases = RandomUtils.getRNG(fnametrain.hashCode()).nextBoolean();
    dl._hidden = hidden;
    dl._l1 = l1;
    dl._elastic_averaging = false;

    // Invoke DL and block till the end
    DeepLearning job = new DeepLearning(dl, Key.<DeepLearningModel>make("DL_model_" + hexnametrain));
    // Get the model
    model = job.trainModel().get();
    Log.info(model._output);
    //HEX-1817
    assertTrue(job.isStopped());

    // Score the test file when one is given, otherwise the training frame; the metrics
    // checked below are the ones stored for that scored frame.
    Frame scored = frTrain;
    if (fnametest != null) {
      frTest = parse_test_file(fnametest);
      scored = frTest;
    }
    pred = model.score(scored);
    hex.ModelMetrics mm = hex.ModelMetrics.getFromDKV(model, scored);

    // Independently parsed copy of the training file, used for the second set of logged
    // metrics and for the Java scoring check.
    test = parse_test_file(fnametrain);
    res = model.score(test);

    if (classification) {
      assertTrue("Expected: " + Arrays.deepToString(expCM) + ", Got: " + Arrays.deepToString(mm.cm()._cm), Arrays.deepEquals(mm.cm()._cm, expCM));
      String[] cmDom = model._output._domains[model._output._domains.length - 1];
      Assert.assertArrayEquals("CM domain differs!", expRespDom, cmDom);
      Log.info("\nTraining CM:\n" + mm.cm().toASCII());
      Log.info("\nTraining CM:\n" + hex.ModelMetrics.getFromDKV(model, test).cm().toASCII());
    } else {
      assertTrue("Expected: " + expMSE + ", Got: " + mm.mse(), MathUtils.compare(expMSE, mm.mse(), 1e-8, 1e-8));
      Log.info("\nOOB Training MSE: " + mm.mse());
      Log.info("\nTraining MSE: " + hex.ModelMetrics.getFromDKV(model, test).mse());
    }

    // Build a POJO, validate same results
    assertTrue(model.testJavaScoring(test, res, 1e-5));
  } finally {
    try {
      if (frTrain != null) {
        frTrain.remove();
      }
      if (frTest != null) {
        frTest.remove();
      }
      // Remove the model
      if (model != null) {
        model.delete();
      }
      if (pred != null) {
        pred.delete();
      }
      if (test != null) {
        test.delete();
      }
      if (res != null) {
        res.delete();
      }
    } finally {
      // Always balance the Scope.enter() above, even when a cleanup call throws.
      Scope.exit();
    }
  }
}
Aggregations