Use of `hex.naivebayes.NaiveBayesModel.NaiveBayesParameters` in project h2o-3 by h2oai.
Class `NaiveBayesTest`, method `testIris`.
/**
 * Trains a Naive Bayes model on {@code smalldata/iris/iris_wheader.csv} with Laplace smoothing
 * disabled and metric computation turned off, scores the training frame with it, and asserts that
 * {@code testJavaScoring} accepts the resulting predictions (tolerance argument {@code 1e-6}).
 *
 * <p>The parsed frame, the score frame and the model are deleted in {@code finally} regardless of
 * the test outcome.
 */
@Test
public void testIris() throws InterruptedException, ExecutionException {
  NaiveBayesModel nbModel = null;
  Frame irisFrame = null;
  Frame predictions = null;
  try {
    irisFrame = parse_test_file(Key.make("iris_wheader.hex"), "smalldata/iris/iris_wheader.csv");

    final NaiveBayesParameters params = new NaiveBayesParameters();
    params._train = irisFrame._key;
    params._laplace = 0;
    // Column index 4 of the parsed frame is used as the response.
    params._response_column = irisFrame._names[4];
    params._compute_metrics = false;

    nbModel = new NaiveBayes(params).trainModel().get();

    // Model is built; score the training data to get a column of class assignments.
    predictions = nbModel.score(irisFrame);
    Assert.assertTrue(nbModel.testJavaScoring(irisFrame, predictions, 1e-6));
  } finally {
    if (irisFrame != null) {
      irisFrame.delete();
    }
    if (predictions != null) {
      predictions.delete();
    }
    if (nbModel != null) {
      nbModel.delete();
    }
  }
}
Use of `hex.naivebayes.NaiveBayesModel.NaiveBayesParameters` in project h2o-3 by h2oai.
Class `NaiveBayesTest`, method `testProstate`.
/**
 * Trains a Naive Bayes model on {@code smalldata/logreg/prostate.csv}, scores the training frame,
 * and asserts that {@code testJavaScoring} accepts the predictions (tolerance argument
 * {@code 1e-6}).
 *
 * <p>Before training, the columns at raw-file indices 1, 3, 4 and 5 (CAPSULE, RACE, DPROS, DCAPS)
 * are converted to categoricals and the "ID" column is removed. Column 0 of the resulting frame is
 * used as the response (presumably CAPSULE once ID is gone; verify against the data file).
 */
@Test
public void testProstate() throws InterruptedException, ExecutionException {
  NaiveBayesModel model = null;
  Frame train = null, score = null;
  // Categoricals: CAPSULE, RACE, DPROS, DCAPS (indices in the raw file, before "ID" is removed).
  final int[] cats = new int[] { 1, 3, 4, 5 };
  // Push the scope before entering the try, so the Scope.exit() in finally only ever pops a
  // scope that this test actually pushed.
  Scope.enter();
  try {
    train = parse_test_file(Key.make("prostate.hex"), "smalldata/logreg/prostate.csv");
    for (int cat : cats) {
      // replace() hands back the displaced vec; tracking it lets Scope.exit() clean it up.
      Scope.track(train.replace(cat, train.vec(cat).toCategoricalVec()));
    }
    // Drop the ID column from the frame and delete the removed vec.
    train.remove("ID").remove();
    // Republish the modified frame so the builder sees the updated columns.
    DKV.put(train._key, train);

    NaiveBayesParameters parms = new NaiveBayesParameters();
    parms._train = train._key;
    parms._laplace = 0;
    parms._response_column = train._names[0];
    parms._compute_metrics = true;
    model = new NaiveBayes(parms).trainModel().get();

    // Done building model; produce a score column with class assignments
    score = model.score(train);
    Assert.assertTrue(model.testJavaScoring(train, score, 1e-6));
  } finally {
    try {
      if (train != null) {
        train.delete();
      }
      if (score != null) {
        score.delete();
      }
      if (model != null) {
        model.delete();
      }
    } finally {
      // Always pop the scope, even if one of the deletes above throws; otherwise tracked vecs
      // leak and the scope stack is left unbalanced for subsequent tests.
      Scope.exit();
    }
  }
}
Use of `hex.naivebayes.NaiveBayesModel.NaiveBayesParameters` in project h2o-3 by h2oai.
Class `NaiveBayesTest`, method `testIrisValidation`.
/**
 * Splits {@code smalldata/iris/iris_wheader.csv} 50/50 into frames keyed "train.hex" and
 * "test.hex". It then trains a Naive Bayes model on the first split with the second split as the
 * validation frame, scores the second split, and asserts that {@code testJavaScoring} accepts the
 * predictions (tolerance argument {@code 1e-6}).
 *
 * <p>The source frame, both splits, the score frame and the model are deleted in
 * {@code finally}.
 */
@Test
public void testIrisValidation() throws InterruptedException, ExecutionException {
  NaiveBayesModel nbModel = null;
  Frame source = null;
  Frame predictions = null;
  Frame trainSplit = null;
  Frame testSplit = null;
  try {
    source = parse_test_file("smalldata/iris/iris_wheader.csv");

    final Key[] targetKeys = new Key[] { Key.make("train.hex"), Key.make("test.hex") };
    SplitFrame splitter = new SplitFrame(source, new double[] { 0.5, 0.5 }, targetKeys);
    // Launch the split job and block until it completes.
    splitter.exec().get();
    final Key[] splitKeys = splitter._destination_frames;
    trainSplit = DKV.get(splitKeys[0]).get();
    testSplit = DKV.get(splitKeys[1]).get();

    NaiveBayesParameters params = new NaiveBayesParameters();
    params._train = splitKeys[0];
    params._valid = splitKeys[1];
    // Laplace smoothing is needed here (presumably because a half-sized training split can leave
    // zero counts for some feature/class combinations).
    params._laplace = 0.01;
    params._response_column = source._names[4];
    params._compute_metrics = true;

    nbModel = new NaiveBayes(params).trainModel().get();

    // Model is built; score the held-out split to get a column of class assignments.
    predictions = nbModel.score(testSplit);
    Assert.assertTrue(nbModel.testJavaScoring(testSplit, predictions, 1e-6));
  } finally {
    if (source != null) {
      source.delete();
    }
    if (predictions != null) {
      predictions.delete();
    }
    if (trainSplit != null) {
      trainSplit.delete();
    }
    if (testSplit != null) {
      testSplit.delete();
    }
    if (nbModel != null) {
      nbModel.delete();
    }
  }
}
Use of `hex.naivebayes.NaiveBayesModel.NaiveBayesParameters` in project h2o-3 by h2oai.
Class `NaiveBayesTest`, method `testCovtype`.
/**
 * Trains a Naive Bayes model on {@code smalldata/covtype/covtype.20k.data} with Laplace smoothing
 * disabled and metric computation off, scores the training frame, and asserts that
 * {@code testJavaScoring} accepts the predictions (tolerance argument {@code 1e-6}).
 *
 * <p>Column 54 is converted to a categorical and used as the response.
 */
@Test
public void testCovtype() throws InterruptedException, ExecutionException {
  NaiveBayesModel model = null;
  Frame train = null, score = null;
  // Index of the response column in the parsed frame.
  final int responseIdx = 54;
  // Push the scope before entering the try, so the Scope.exit() in finally only ever pops a
  // scope that this test actually pushed.
  Scope.enter();
  try {
    train = parse_test_file(Key.make("covtype.hex"), "smalldata/covtype/covtype.20k.data");
    // Change response to categorical; track the displaced vec so Scope.exit() cleans it up.
    Scope.track(train.replace(responseIdx, train.vecs()[responseIdx].toCategoricalVec()));
    // Republish the modified frame so the builder sees the categorical response.
    DKV.put(train);

    NaiveBayesParameters parms = new NaiveBayesParameters();
    parms._train = train._key;
    parms._laplace = 0;
    parms._response_column = train._names[responseIdx];
    parms._compute_metrics = false;
    model = new NaiveBayes(parms).trainModel().get();

    // Done building model; produce a score column with class assignments
    score = model.score(train);
    Assert.assertTrue(model.testJavaScoring(train, score, 1e-6));
  } finally {
    try {
      if (train != null) {
        train.delete();
      }
      if (score != null) {
        score.delete();
      }
      if (model != null) {
        model.delete();
      }
    } finally {
      // Always pop the scope, even if one of the deletes above throws; otherwise tracked vecs
      // leak and the scope stack is left unbalanced for subsequent tests.
      Scope.exit();
    }
  }
}
Aggregations