Use of hex.deeplearning.DeepLearningModel in project h2o-2 by h2oai.
From the class DeepLearningAutoEncoderCategoricalTest, method run():
@Test
public void run() {
  long seed = 0xDECAF;
  Key file_train = NFSFileVec.make(find_test_file(PATH));
  Frame train = ParseDataset2.parse(Key.make(), new Key[] { file_train });
  DeepLearning p = new DeepLearning();
  p.source = train;
  p.autoencoder = true;
  p.response = train.lastVec();
  p.seed = seed;
  p.hidden = new int[] { 100, 50, 20 };
  // p.ignored_cols = new int[]{0,1,2,3,6,7,8,10}; // Optional: ignore all categoricals
  // p.ignored_cols = new int[]{4,5,9};            // Optional: ignore all numericals
  p.adaptive_rate = true;
  p.l1 = 1e-4;
  p.activation = DeepLearning.Activation.Tanh;
  p.train_samples_per_iteration = -1;
  p.loss = DeepLearning.Loss.MeanSquare;
  p.epochs = 2;
  // p.shuffle_training_data = true;
  p.force_load_balance = true;
  p.score_training_samples = 0;
  p.score_validation_samples = 0;
  // p.reproducible = true;
  p.invoke();

  // Verification of results
  StringBuilder sb = new StringBuilder();
  sb.append("Verifying results.\n");
  DeepLearningModel mymodel = UKV.get(p.dest());
  sb.append("Reported mean reconstruction error: " + mymodel.mse() + "\n");

  // Training data
  // Reconstruct data using the same helper functions and verify that self-reported MSE agrees
  final Frame l2 = mymodel.scoreAutoEncoder(train);
  final Vec l2vec = l2.anyVec();
  sb.append("Actual mean reconstruction error: " + l2vec.mean() + "\n");

  // print stats and potential outliers
  double quantile = 1 - 5. / train.numRows();
sb.append("The following training points are reconstructed with an error above the " + quantile * 100 + "-th percentile - potential \"outliers\" in testing data.\n");
  double thresh = mymodel.calcOutlierThreshold(l2vec, quantile);
  for (long i = 0; i < l2vec.length(); i++) {
    if (l2vec.at(i) > thresh) {
      sb.append(String.format("row %d : l2vec error = %5f\n", i, l2vec.at(i)));
    }
  }
  Log.info(sb.toString());
  Assert.assertEquals(mymodel.mse(), l2vec.mean(), 1e-8);

  // Create reconstruction
  Log.info("Creating full reconstruction.");
  final Frame recon_train = mymodel.score(train);

  // cleanup
  recon_train.delete();
  train.delete();
  p.delete();
  mymodel.delete();
  l2.delete();
}
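A minimal sketch of reusing the trained autoencoder for anomaly detection on a second frame, built only from the calls exercised in the test above (scoreAutoEncoder, calcOutlierThreshold, Vec.at) and assuming the same imports as that test; the helper name reportOutliers and the newData frame are illustrative, not part of the h2o-2 sources:

// Hypothetical helper (not from the original test): scores per-row reconstruction error
// on newData and logs every row whose error exceeds the threshold for the given quantile.
static void reportOutliers(DeepLearningModel model, Frame newData, double quantile) {
  final Frame err = model.scoreAutoEncoder(newData);  // one column of per-row reconstruction error
  try {
    final Vec errVec = err.anyVec();
    final double thresh = model.calcOutlierThreshold(errVec, quantile);
    for (long i = 0; i < errVec.length(); i++) {
      if (errVec.at(i) > thresh)
        Log.info(String.format("row %d : reconstruction error = %5f", i, errVec.at(i)));
    }
  } finally {
    err.delete();  // drop the temporary error frame
  }
}

With quantile chosen as in the test (1 - 5.0 / newData.numRows()), roughly the five worst-reconstructed rows would be reported.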
Use of hex.deeplearning.DeepLearningModel in project h2o-3 by h2oai.
From the class XValPredictionsCheck, method testXValPredictions():
@Test
public void testXValPredictions() {
  final int nfolds = 3;
  Frame tfr = null;
  try {
    // Load data, hack frames
    tfr = parse_test_file("smalldata/iris/iris_wheader.csv");
    Frame foldId = new Frame(new String[] { "foldId" }, new Vec[] { AstKFold.kfoldColumn(tfr.vec("class").makeZero(), nfolds, 543216789) });
    tfr.add(foldId);
    DKV.put(tfr);

    // GBM
    GBMModel.GBMParameters parms = new GBMModel.GBMParameters();
    parms._train = tfr._key;
    parms._response_column = "class";
    parms._ntrees = 1;
    parms._max_depth = 1;
    parms._fold_column = "foldId";
    parms._distribution = DistributionFamily.multinomial;
    parms._keep_cross_validation_predictions = true;
    GBM job = new GBM(parms);
    GBMModel gbm = job.trainModel().get();
    checkModel(gbm, foldId.anyVec(), 3);

    // DRF
    DRFModel.DRFParameters parmsDRF = new DRFModel.DRFParameters();
    parmsDRF._train = tfr._key;
    parmsDRF._response_column = "class";
    parmsDRF._ntrees = 1;
    parmsDRF._max_depth = 1;
    parmsDRF._fold_column = "foldId";
    parmsDRF._distribution = DistributionFamily.multinomial;
    parmsDRF._keep_cross_validation_predictions = true;
    DRF drfJob = new DRF(parmsDRF);
    DRFModel drf = drfJob.trainModel().get();
    checkModel(drf, foldId.anyVec(), 3);

    // GLM
    GLMModel.GLMParameters parmsGLM = new GLMModel.GLMParameters();
    parmsGLM._train = tfr._key;
    parmsGLM._response_column = "sepal_len";
    parmsGLM._fold_column = "foldId";
    parmsGLM._keep_cross_validation_predictions = true;
    GLM glmJob = new GLM(parmsGLM);
    GLMModel glm = glmJob.trainModel().get();
    checkModel(glm, foldId.anyVec(), 1);

    // DL
    DeepLearningModel.DeepLearningParameters parmsDL = new DeepLearningModel.DeepLearningParameters();
    parmsDL._train = tfr._key;
    parmsDL._response_column = "class";
    parmsDL._hidden = new int[] { 1 };
    parmsDL._epochs = 1;
    parmsDL._fold_column = "foldId";
    parmsDL._keep_cross_validation_predictions = true;
    DeepLearning dlJob = new DeepLearning(parmsDL);
    DeepLearningModel dl = dlJob.trainModel().get();
    checkModel(dl, foldId.anyVec(), 3);
  } finally {
    if (tfr != null)
      tfr.remove();
  }
}
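The setup step shared by all four trainers above is the hand-built fold column. A minimal sketch of that step in isolation, using only the calls from the test (AstKFold.kfoldColumn, Frame.add, DKV.put); the helper name addFoldColumn and its parameters are illustrative, not part of XValPredictionsCheck:

// Hypothetical helper (not from the original test): builds an nfolds-way fold assignment
// from a zeroed copy of an existing column and appends it to the training frame so that
// _fold_column can reference it by name ("foldId").
static Frame addFoldColumn(Frame tfr, String anyColumn, int nfolds, long seed) {
  Vec fold = AstKFold.kfoldColumn(tfr.vec(anyColumn).makeZero(), nfolds, seed);
  Frame foldId = new Frame(new String[] { "foldId" }, new Vec[] { fold });
  tfr.add(foldId);  // append the fold column in place
  DKV.put(tfr);     // republish the modified frame so the trainers see the new column
  return foldId;
}

With the fold column in place, each parameters object only needs _fold_column = "foldId" and _keep_cross_validation_predictions = true, exactly as in the four trainers above.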