Use of hex.SplitFrame in the h2o-3 project by h2oai: class DRFTest, method testColSamplingPerTree.
@Test
public void testColSamplingPerTree() {
    Frame tfr = null;
    Key[] ksplits = new Key[0];
    try {
        tfr = parse_test_file("./smalldata/gbm_test/ecology_model.csv");
        SplitFrame sf = new SplitFrame(tfr, new double[] { 0.5, 0.5 }, new Key[] { Key.make("train.hex"), Key.make("test.hex") });
        // Invoke the split job and block until both destination frames exist.
        sf.exec().get();
        ksplits = sf._destination_frames;
        float[] sample_rates = new float[] { 0.2f, 0.4f, 0.6f, 0.8f, 1.0f };
        float[] col_sample_rates = new float[] { 0.4f, 0.6f, 0.8f, 1.0f };
        float[] col_sample_rates_per_tree = new float[] { 0.4f, 0.6f, 0.8f, 1.0f };
        // Validation MSE -> (row sample rate, col sample rate, col sample rate per tree).
        // TreeMap iterates in ascending MSE order, i.e. best to worst model.
        // NOTE(review): two configurations that happen to produce an identical MSE
        // would overwrite each other here; acceptable for a smoke test.
        Map<Double, Triple<Float>> hm = new TreeMap<>();
        for (float sample_rate : sample_rates) {
            for (float col_sample_rate : col_sample_rates) {
                for (float col_sample_rate_per_tree : col_sample_rates_per_tree) {
                    Scope.enter();
                    // Declared inside the loop so the finally block can only ever
                    // delete the model built by THIS iteration. (Previously the
                    // declaration lived outside the loops without being nulled
                    // after delete(), so a failure before trainModel() completed
                    // would re-delete the previous iteration's stale model.)
                    DRFModel drf = null;
                    try {
                        DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
                        parms._train = ksplits[0];
                        parms._valid = ksplits[1];
                        // Regression on the "Angaus" response column.
                        parms._response_column = "Angaus";
                        parms._seed = 12345;
                        parms._min_rows = 1;
                        parms._max_depth = 15;
                        parms._ntrees = 2;
                        // Translate the fractional column sample rate into mtries
                        // (number of candidate columns per split), at least 1.
                        parms._mtries = Math.max(1, (int) (col_sample_rate * (tfr.numCols() - 1)));
                        parms._col_sample_rate_per_tree = col_sample_rate_per_tree;
                        parms._sample_rate = sample_rate;
                        // Build a first model; all remaining models should be equal
                        DRF job = new DRF(parms);
                        drf = job.trainModel().get();
                        // too slow, but passes (now)
                        // // Build a POJO, validate same results
                        // Frame pred = drf.score(tfr);
                        // Assert.assertTrue(drf.testJavaScoring(tfr,pred,1e-15));
                        // pred.remove();
                        ModelMetricsRegression mm = (ModelMetricsRegression) drf._output._validation_metrics;
                        hm.put(mm.mse(), new Triple<>(sample_rate, col_sample_rate, col_sample_rate_per_tree));
                    } finally {
                        if (drf != null)
                            drf.delete();
                        Scope.exit();
                    }
                }
            }
        }
        Iterator<Map.Entry<Double, Triple<Float>>> it;
        Triple<Float> last = null;
        // iterator over results (min to max MSE) - best to worst
        for (it = hm.entrySet().iterator(); it.hasNext(); ) {
            Map.Entry<Double, Triple<Float>> n = it.next();
            Log.info("MSE: " + n.getKey() + ", row sample: " + n.getValue().v1 + ", col sample: " + n.getValue().v2 + ", col sample per tree: " + n.getValue().v3);
            last = n.getValue();
        }
        // worst validation MSE should belong to the most overfit case (1.0, 1.0, 1.0)
        // Assert.assertTrue(last.v1==sample_rates[sample_rates.length-1]);
        // Assert.assertTrue(last.v2==col_sample_rates[col_sample_rates.length-1]);
        // Assert.assertTrue(last.v3==col_sample_rates_per_tree[col_sample_rates_per_tree.length-1]);
    } finally {
        // Remove the parsed frame and both split frames from the DKV.
        if (tfr != null)
            tfr.remove();
        for (Key k : ksplits) if (k != null)
            k.remove();
    }
}
Use of hex.SplitFrame in the h2o-3 project by h2oai: class PCATest, method testIrisSplitScoring.
@Test
public void testIrisSplitScoring() throws InterruptedException, ExecutionException {
    PCAModel model = null;
    Frame iris = null;
    Frame scored = null;
    Frame trainFrame = null;
    Frame testFrame = null;
    try {
        // Parse the iris data and split it 50/50 into train and test frames.
        iris = parse_test_file("smalldata/iris/iris_wheader.csv");
        SplitFrame splitJob = new SplitFrame(iris, new double[] { 0.5, 0.5 }, new Key[] { Key.make("train.hex"), Key.make("test.hex") });
        // Run the split job to completion.
        splitJob.exec().get();
        Key[] splits = splitJob._destination_frames;
        trainFrame = DKV.get(splits[0]).get();
        testFrame = DKV.get(splits[1]).get();
        // Fit a 4-component PCA via GramSVD on the training split,
        // validating against the test split.
        PCAModel.PCAParameters params = new PCAModel.PCAParameters();
        params._train = splits[0];
        params._valid = splits[1];
        params._k = 4;
        params._max_iterations = 1000;
        params._pca_method = PCAParameters.Method.GramSVD;
        model = new PCA(params).trainModel().get();
        // Score the held-out split and check the generated Java scoring
        // code agrees with the in-H2O predictions.
        scored = model.score(testFrame);
        Assert.assertTrue(model.testJavaScoring(testFrame, scored, 1e-5));
    } finally {
        // Clean up all frames and the model regardless of outcome.
        if (iris != null)
            iris.delete();
        if (scored != null)
            scored.delete();
        if (trainFrame != null)
            trainFrame.delete();
        if (testFrame != null)
            testFrame.delete();
        if (model != null)
            model.delete();
    }
}
Use of hex.SplitFrame in the h2o-3 project by h2oai: class KMeansTest, method testValidation.
@Test
public void testValidation() {
    // Exercise k-means with and without standardization.
    for (boolean standardize : new boolean[] { true, false }) {
        Frame fr = null, fr2 = null;
        Frame tr = null, te = null;
        // Declared inside the loop so the finally block can only ever delete
        // the model from THIS iteration. (Previously kmm was declared outside
        // the loop and never reset after delete(), so a failure early in the
        // second iteration would re-delete the first iteration's stale model.)
        KMeansModel kmm = null;
        try {
            fr = parse_test_file("smalldata/iris/iris_wheader.csv");
            SplitFrame sf = new SplitFrame(fr, new double[] { 0.5, 0.5 }, new Key[] { Key.make("train.hex"), Key.make("test.hex") });
            // Invoke the split job and wait for completion.
            sf.exec().get();
            Key<Frame>[] ksplits = sf._destination_frames;
            tr = DKV.get(ksplits[0]).get();
            te = DKV.get(ksplits[1]).get();
            KMeansModel.KMeansParameters parms = new KMeansModel.KMeansParameters();
            parms._train = ksplits[0];
            parms._valid = ksplits[1];
            parms._k = 3;
            parms._standardize = standardize;
            parms._max_iterations = 10;
            parms._init = KMeans.Initialization.Random;
            kmm = doSeed(parms, 0);
            // Iris last column is categorical; make sure centers are ordered in the
            // same order as the iris columns.
            double[][] /*features*/
            centers = kmm._output._centers_raw;
            for (int k = 0; k < parms._k; k++) {
                double flower = centers[k][4];
                // A categorical center coordinate must be an integral level index.
                Assert.assertTrue("categorical column expected", flower == (int) flower);
            }
            // Done building model; produce a score column with cluster choices
            fr2 = kmm.score(te);
            Assert.assertTrue(kmm.testJavaScoring(te, fr2, 1e-15));
        } finally {
            // Clean up every frame and the model for this iteration.
            if (tr != null)
                tr.delete();
            if (te != null)
                te.delete();
            if (fr2 != null)
                fr2.delete();
            if (fr != null)
                fr.delete();
            if (kmm != null)
                kmm.delete();
        }
    }
}
Use of hex.SplitFrame in the h2o-3 project by h2oai: class SVDTest, method testIrisSplitScoring.
@Test
public void testIrisSplitScoring() throws InterruptedException, ExecutionException {
    SVDModel model = null;
    Frame iris = null;
    Frame scored = null;
    Frame trainFrame = null;
    Frame testFrame = null;
    try {
        // Parse the iris data and split it 50/50 into train and test frames.
        iris = parse_test_file("smalldata/iris/iris_wheader.csv");
        SplitFrame splitJob = new SplitFrame(iris, new double[] { 0.5, 0.5 }, new Key[] { Key.make("train.hex"), Key.make("test.hex") });
        // Run the split job to completion.
        splitJob.exec().get();
        Key[] splits = splitJob._destination_frames;
        trainFrame = DKV.get(splits[0]).get();
        testFrame = DKV.get(splits[1]).get();
        // Fit a 4-vector SVD via the Power method on the training split,
        // validating against the test split; skip persisting the V frame.
        SVDModel.SVDParameters params = new SVDModel.SVDParameters();
        params._train = splits[0];
        params._valid = splits[1];
        params._nv = 4;
        params._max_iterations = 1000;
        params._svd_method = SVDParameters.Method.Power;
        params._save_v_frame = false;
        model = new SVD(params).trainModel().get();
        // Score the held-out split and check the generated Java scoring
        // code agrees with the in-H2O predictions.
        scored = model.score(testFrame);
        Assert.assertTrue(model.testJavaScoring(testFrame, scored, 1e-5));
    } finally {
        // Clean up all frames and the model regardless of outcome.
        if (iris != null)
            iris.delete();
        if (scored != null)
            scored.delete();
        if (trainFrame != null)
            trainFrame.delete();
        if (testFrame != null)
            testFrame.delete();
        if (model != null)
            model.delete();
    }
}
Aggregations