Use of hex.deeplearning.DeepLearning in project h2o-3 by h2oai.
The class TestCase, method execute:
public TestCaseResult execute() throws Exception, AssertionError {
    loadTestCaseDataSets();
    makeModelParameters();
    double startTime = 0, stopTime = 0;
    if (!grid) {
        // Single-model case: train one model of the requested algo and time it.
        Model.Output modelOutput = null;
        DRF drfJob;
        DRFModel drfModel = null;
        GLM glmJob;
        GLMModel glmModel = null;
        GBM gbmJob;
        GBMModel gbmModel = null;
        DeepLearning dlJob;
        DeepLearningModel dlModel = null;
        String bestModelJson = null;
        try {
            switch (algo) {
                case "drf":
                    drfJob = new DRF((DRFModel.DRFParameters) params);
                    AccuracyTestingSuite.summaryLog.println("Training DRF model.");
                    startTime = System.currentTimeMillis();
                    drfModel = drfJob.trainModel().get();
                    stopTime = System.currentTimeMillis();
                    modelOutput = drfModel._output;
                    bestModelJson = drfModel._parms.toJsonString();
                    break;
                case "glm":
                    glmJob = new GLM((GLMModel.GLMParameters) params, Key.<GLMModel>make("GLMModel"));
                    AccuracyTestingSuite.summaryLog.println("Training GLM model.");
                    startTime = System.currentTimeMillis();
                    glmModel = glmJob.trainModel().get();
                    stopTime = System.currentTimeMillis();
                    modelOutput = glmModel._output;
                    bestModelJson = glmModel._parms.toJsonString();
                    break;
                case "gbm":
                    gbmJob = new GBM((GBMModel.GBMParameters) params);
                    AccuracyTestingSuite.summaryLog.println("Training GBM model.");
                    startTime = System.currentTimeMillis();
                    gbmModel = gbmJob.trainModel().get();
                    stopTime = System.currentTimeMillis();
                    modelOutput = gbmModel._output;
                    bestModelJson = gbmModel._parms.toJsonString();
                    break;
                case "dl":
                    dlJob = new DeepLearning((DeepLearningModel.DeepLearningParameters) params);
                    AccuracyTestingSuite.summaryLog.println("Training DL model.");
                    startTime = System.currentTimeMillis();
                    dlModel = dlJob.trainModel().get();
                    stopTime = System.currentTimeMillis();
                    modelOutput = dlModel._output;
                    bestModelJson = dlModel._parms.toJsonString();
                    break;
            }
        } catch (Exception e) {
            throw new Exception(e);
        } finally {
            // Always free the trained model's DKV state, even on failure.
            if (drfModel != null) {
                drfModel.delete();
            }
            if (glmModel != null) {
                glmModel.delete();
            }
            if (gbmModel != null) {
                gbmModel.delete();
            }
            if (dlModel != null) {
                dlModel.delete();
            }
        }
        removeTestCaseDataSetFrames();
        // If cross-validation was used, report CV metrics; otherwise validation metrics.
        if (params._nfolds > 0) {
            return new TestCaseResult(testCaseId, getMetrics(modelOutput._training_metrics), getMetrics(modelOutput._cross_validation_metrics), stopTime - startTime, bestModelJson, this, trainingDataSet, testingDataSet);
        } else {
            return new TestCaseResult(testCaseId, getMetrics(modelOutput._training_metrics), getMetrics(modelOutput._validation_metrics), stopTime - startTime, bestModelJson, this, trainingDataSet, testingDataSet);
        }
    } else {
        // Grid-search case: train a grid and pick the best model by the selection criterion.
        assert !modelSelectionCriteria.equals("");
        makeGridParameters();
        makeSearchCriteria();
        Grid grid = null;
        Model bestModel = null;
        String bestModelJson = null;
        try {
            SchemaServer.registerAllSchemasIfNecessary();
            // TODO: Hack for PUBDEV-2812 -- one-time registration of each algo's REST schema.
            switch (algo) {
                case "drf":
                    if (!drfRegistered) {
                        new DRF(true);
                        new DRFParametersV3();
                        drfRegistered = true;
                    }
                    break;
                case "glm":
                    if (!glmRegistered) {
                        new GLM(true);
                        new GLMParametersV3();
                        glmRegistered = true;
                    }
                    break;
                case "gbm":
                    if (!gbmRegistered) {
                        new GBM(true);
                        new GBMParametersV3();
                        gbmRegistered = true;
                    }
                    break;
                case "dl":
                    if (!dlRegistered) {
                        new DeepLearning(true);
                        new DeepLearningParametersV3();
                        dlRegistered = true;
                    }
                    break;
            }
            startTime = System.currentTimeMillis();
            // TODO: ModelParametersBuilderFactory parameter must be instantiated properly
            Job<Grid> gs = GridSearch.startGridSearch(null, params, hyperParms, new GridSearch.SimpleParametersBuilderFactory<>(), searchCriteria);
            grid = gs.get();
            stopTime = System.currentTimeMillis();
            // Scan the grid and keep the model with the best validation score.
            boolean higherIsBetter = higherIsBetter(modelSelectionCriteria);
            double bestScore = higherIsBetter ? -Double.MAX_VALUE : Double.MAX_VALUE;
            for (Model m : grid.getModels()) {
                double validationMetricScore = getMetrics(m._output._validation_metrics).get(modelSelectionCriteria);
                AccuracyTestingSuite.summaryLog.println(modelSelectionCriteria + " for model " + m._key.toString() + " is " + validationMetricScore);
                if (higherIsBetter ? validationMetricScore > bestScore : validationMetricScore < bestScore) {
                    bestScore = validationMetricScore;
                    bestModel = m;
                    bestModelJson = bestModel._parms.toJsonString();
                }
            }
            AccuracyTestingSuite.summaryLog.println("Best model: " + bestModel._key.toString());
            AccuracyTestingSuite.summaryLog.println("Best model parameters: " + bestModelJson);
        } catch (Exception e) {
            throw new Exception(e);
        } finally {
            if (grid != null) {
                grid.delete();
            }
        }
        removeTestCaseDataSetFrames();
        // If cross-validation was used, report CV metrics; otherwise validation metrics.
        if (params._nfolds > 0) {
            return new TestCaseResult(testCaseId, getMetrics(bestModel._output._training_metrics), getMetrics(bestModel._output._cross_validation_metrics), stopTime - startTime, bestModelJson, this, trainingDataSet, testingDataSet);
        } else {
            return new TestCaseResult(testCaseId, getMetrics(bestModel._output._training_metrics), getMetrics(bestModel._output._validation_metrics), stopTime - startTime, bestModelJson, this, trainingDataSet, testingDataSet);
        }
    }
}
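The grid loop depends on two helpers not shown in this snippet: getMetrics(...), which flattens an H2O metrics object into a name-to-value map, and higherIsBetter(...), which fixes the comparison direction per criterion. A minimal sketch of the latter, assuming gain-style metrics improve upward while error- and loss-style metrics improve downward (the metric names here are illustrative, not taken from the suite):

// Hypothetical helper: decide the comparison direction for a selection criterion.
static boolean higherIsBetter(String criterion) {
    switch (criterion.toLowerCase()) {
        case "auc":
        case "r2":
            return true;  // gain-style metrics: larger is better
        default:
            return false; // MSE, logloss, error rates: smaller is better
    }
}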
Use of hex.deeplearning.DeepLearning in project h2o-2 by h2oai.
The class DeepLearningMissingTest, method run:
@Test
public void run() {
    long seed = new Random().nextLong();
    DeepLearningModel mymodel = null;
    Frame train = null;
    Frame test = null;
    Frame data = null;
    DeepLearning p;
    Log.info("");
    Log.info("STARTING.");
    Log.info("Using seed " + seed);
    Map<DeepLearning.MissingValuesHandling, Double> sumErr = new TreeMap<DeepLearning.MissingValuesHandling, Double>();
    StringBuilder sb = new StringBuilder();
    for (DeepLearning.MissingValuesHandling mvh : new DeepLearning.MissingValuesHandling[] { DeepLearning.MissingValuesHandling.Skip, DeepLearning.MissingValuesHandling.MeanImputation }) {
        double sumerr = 0;
        Map<Double, Double> map = new TreeMap<Double, Double>();
        for (double missing_fraction : new double[] { 0, 0.1, 0.25, 0.5, 0.75, 1 }) {
            try {
                Key file = NFSFileVec.make(find_test_file("smalldata/weather.csv"));
                // Key file = NFSFileVec.make(find_test_file("smalldata/mnist/test.csv.gz"));
                data = ParseDataset2.parse(Key.make("data.hex"), new Key[] { file });
                // Create holdout test data on clean data (before adding missing values)
                FrameSplitter fs = new FrameSplitter(data, new float[] { 0.75f });
                H2O.submitTask(fs).join();
                Frame[] train_test = fs.getResult();
                train = train_test[0];
                test = train_test[1];
                // Add missing values to the training data (excluding the response)
                if (missing_fraction > 0) {
                    Frame frtmp = new Frame(Key.make(), train.names(), train.vecs());
                    // Exclude the response column
                    frtmp.remove(frtmp.numCols() - 1);
                    DKV.put(frtmp._key, frtmp);
                    InsertMissingValues imv = new InsertMissingValues();
                    imv.missing_fraction = missing_fraction;
                    // Use the same seed for Skip and MeanImputation, so both
                    // strategies see an identical missing-value pattern!
                    imv.seed = seed;
                    imv.key = frtmp._key;
                    imv.serve();
                    // Just remove the Frame header (not the chunks)
                    DKV.remove(frtmp._key);
                }
                // Build a regularized DL model with polluted training data, score on clean validation set
                p = new DeepLearning();
                p.source = train;
                p.validation = test;
                p.response = train.lastVec();
                // Only for weather data
                p.ignored_cols = new int[] { 1, 22 };
                p.missing_values_handling = mvh;
                p.activation = DeepLearning.Activation.RectifierWithDropout;
                p.hidden = new int[] { 200, 200 };
                p.l1 = 1e-5;
                p.input_dropout_ratio = 0.2;
                p.epochs = 10;
                p.quiet_mode = true;
                try {
                    Log.info("Starting with " + missing_fraction * 100 + "% missing values added.");
                    p.invoke();
                } catch (Throwable t) {
                    t.printStackTrace();
                    throw new RuntimeException(t);
                } finally {
                    p.delete();
                }
                // Extract the scoring on the validation set from the model
                mymodel = UKV.get(p.dest());
                DeepLearningModel.Errors[] errs = mymodel.scoring_history();
                DeepLearningModel.Errors lasterr = errs[errs.length - 1];
                double err = lasterr.valid_err;
                Log.info("Missing " + missing_fraction * 100 + "% -> Err: " + err);
                map.put(missing_fraction, err);
                sumerr += err;
            } catch (Throwable t) {
                t.printStackTrace();
                throw new RuntimeException(t);
            } finally {
                // Cleanup
                if (mymodel != null) {
                    mymodel.delete_xval_models();
                    mymodel.delete_best_model();
                    mymodel.delete();
                }
                if (train != null)
                    train.delete();
                if (test != null)
                    test.delete();
                if (data != null)
                    data.delete();
            }
        }
        sb.append("\nMethod: " + mvh.toString() + "\n");
        sb.append("missing fraction --> Error\n");
        for (String s : Arrays.toString(map.entrySet().toArray()).split(",")) sb.append(s.replace("=", " --> ")).append("\n");
        sb.append('\n');
        sb.append("Sum Err: " + sumerr + "\n");
        sumErr.put(mvh, sumerr);
    }
    Log.info(sb.toString());
    // Skip should accumulate more validation error than MeanImputation...
    Assert.assertTrue(sumErr.get(DeepLearning.MissingValuesHandling.Skip) > sumErr.get(DeepLearning.MissingValuesHandling.MeanImputation));
    // ...and this bound holds true for both datasets
    Assert.assertTrue(sumErr.get(DeepLearning.MissingValuesHandling.MeanImputation) < 2);
}
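The holdout split above is a reusable H2O-2 pattern; here is a minimal standalone sketch, where data stands in for any parsed Frame and the 0.75 ratio matches the test:

// Split a parsed Frame 75/25 into train and test; FrameSplitter runs as a
// distributed task, so submit it and wait for completion before reading results.
FrameSplitter fs = new FrameSplitter(data, new float[] { 0.75f });
H2O.submitTask(fs).join();
Frame train = fs.getResult()[0]; // ~75% of the rows
Frame test = fs.getResult()[1];  // remaining ~25%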
Use of hex.deeplearning.DeepLearning in project h2o-2 by h2oai.
The class DeepLearningAutoEncoderCategoricalTest, method run:
@Test
public void run() {
    long seed = 0xDECAF;
    Key file_train = NFSFileVec.make(find_test_file(PATH));
    Frame train = ParseDataset2.parse(Key.make(), new Key[] { file_train });
    DeepLearning p = new DeepLearning();
    p.source = train;
    p.autoencoder = true;
    p.response = train.lastVec();
    p.seed = seed;
    p.hidden = new int[] { 100, 50, 20 };
    // p.ignored_cols = new int[]{0,1,2,3,6,7,8,10}; //Optional: ignore all categoricals
    // p.ignored_cols = new int[]{4,5,9}; //Optional: ignore all numericals
    p.adaptive_rate = true;
    p.l1 = 1e-4;
    p.activation = DeepLearning.Activation.Tanh;
    p.train_samples_per_iteration = -1;
    p.loss = DeepLearning.Loss.MeanSquare;
    p.epochs = 2;
    // p.shuffle_training_data = true;
    p.force_load_balance = true;
    p.score_training_samples = 0;
    p.score_validation_samples = 0;
    // p.reproducible = true;
    p.invoke();
    // Verification of results
    StringBuilder sb = new StringBuilder();
    sb.append("Verifying results.\n");
    DeepLearningModel mymodel = UKV.get(p.dest());
    sb.append("Reported mean reconstruction error: " + mymodel.mse() + "\n");
    // Training data:
    // Reconstruct the data using the same helper functions and verify that the self-reported MSE agrees
    final Frame l2 = mymodel.scoreAutoEncoder(train);
    final Vec l2vec = l2.anyVec();
    sb.append("Actual mean reconstruction error: " + l2vec.mean() + "\n");
    // Print stats and potential outliers: choose the quantile so that roughly
    // 5 training rows land above the resulting threshold.
    double quantile = 1 - 5. / train.numRows();
    sb.append("The following training points are reconstructed with an error above the " + quantile * 100 + "-th percentile - potential \"outliers\" in testing data.\n");
    double thresh = mymodel.calcOutlierThreshold(l2vec, quantile);
    for (long i = 0; i < l2vec.length(); i++) {
        if (l2vec.at(i) > thresh) {
            sb.append(String.format("row %d : l2vec error = %5f\n", i, l2vec.at(i)));
        }
    }
    Log.info(sb.toString());
    Assert.assertEquals(mymodel.mse(), l2vec.mean(), 1e-8);
    // Create full reconstruction
    Log.info("Creating full reconstruction.");
    final Frame recon_train = mymodel.score(train);
    // Cleanup
    recon_train.delete();
    train.delete();
    p.delete();
    mymodel.delete();
    l2.delete();
}
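A quick check of the threshold arithmetic above: with, say, 1,000 training rows, quantile = 1 - 5/1000 = 0.995, so calcOutlierThreshold returns the 99.5th percentile of the per-row reconstruction error, and about five rows are expected to exceed it regardless of dataset size.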
Use of hex.deeplearning.DeepLearning in project h2o-2 by h2oai.
The class DeepLearningMnist, method execImpl:
@Override
protected void execImpl() {
    Log.info("Parsing data.");
    //long seed = 0xC0FFEE;
    long seed = new Random().nextLong();
    double fraction = 1.0;
    // Frame trainf = sampleFrame(TestUtil.parseFromH2OFolder("smalldata/mnist/train10x.csv"), (long)(600000*fraction), seed);
    Frame trainf = sampleFrame(TestUtil.parseFromH2OFolder("smalldata/mnist/train.csv.gz"), (long) (60000 * fraction), seed);
    Frame testf = sampleFrame(TestUtil.parseFromH2OFolder("smalldata/mnist/test.csv.gz"), (long) (10000 * fraction), seed + 1);
    Log.info("Done.");
    DeepLearning p = new DeepLearning();
    // Hinton parameters -> should lead to ~1% test error after a few dozen million samples
    p.seed = seed;
    // p.hidden = new int[]{1024,1024,2048};
    p.hidden = new int[] { 128, 128, 256 };
    p.activation = DeepLearning.Activation.RectifierWithDropout;
    p.loss = DeepLearning.Loss.CrossEntropy;
    p.input_dropout_ratio = 0.2;
    p.epochs = 10;
    p.l1 = 1e-5;
    p.l2 = 0;
    // Hard-coded toggle: flip to false to use the manual learning-rate settings below.
    if (true) {
        // Automatic (adaptive) learning rate
        p.adaptive_rate = true;
        p.rho = 0.99;
        p.epsilon = 1e-8;
        // p.max_w2 = 15;
        p.max_w2 = Float.POSITIVE_INFINITY;
    } else {
        // Manual learning rate with annealing and a momentum ramp
        p.adaptive_rate = false;
        p.rate = 0.01;
        p.rate_annealing = 1e-6;
        p.momentum_start = 0.5;
        p.momentum_ramp = 1800000;
        p.momentum_stable = 0.99;
        // p.max_w2 = 15;
        p.max_w2 = Float.POSITIVE_INFINITY;
    }
    p.initial_weight_distribution = DeepLearning.InitialWeightDistribution.UniformAdaptive;
    // p.initial_weight_scale = 0.01
    p.classification = true;
    p.diagnostics = true;
    p.expert_mode = true;
    p.score_training_samples = 1000;
    p.score_validation_samples = 10000;
    p.validation = testf;
    p.source = trainf;
    p.response = trainf.lastVec();
    p.ignored_cols = null;
    p.classification_stop = -1;
    p.train_samples_per_iteration = -1;
    p.score_interval = 30;
    p.variable_importances = false;
    // To match old NeuralNet behavior
    p.fast_mode = true;
    // p.ignore_const_cols = true;
    // To match old NeuralNet behavior and to have images look straight
    p.ignore_const_cols = false;
    p.shuffle_training_data = false;
    p.force_load_balance = true;
    p.replicate_training_data = true;
    p.quiet_mode = false;
    p.invoke();
    // visualize((DeepLearningModel) UKV.get(p.dest()));
}
Use of hex.deeplearning.DeepLearning in project h2o-3 by h2oai.
The class XValPredictionsCheck, method testXValPredictions:
@Test
public void testXValPredictions() {
    final int nfolds = 3;
    Frame tfr = null;
    try {
        // Load data and attach an explicit fold-assignment column
        tfr = parse_test_file("smalldata/iris/iris_wheader.csv");
        Frame foldId = new Frame(new String[] { "foldId" }, new Vec[] { AstKFold.kfoldColumn(tfr.vec("class").makeZero(), nfolds, 543216789) });
        tfr.add(foldId);
        DKV.put(tfr);
        // GBM
        GBMModel.GBMParameters parms = new GBMModel.GBMParameters();
        parms._train = tfr._key;
        parms._response_column = "class";
        parms._ntrees = 1;
        parms._max_depth = 1;
        parms._fold_column = "foldId";
        parms._distribution = DistributionFamily.multinomial;
        parms._keep_cross_validation_predictions = true;
        GBM job = new GBM(parms);
        GBMModel gbm = job.trainModel().get();
        checkModel(gbm, foldId.anyVec(), 3);
        // DRF
        DRFModel.DRFParameters parmsDRF = new DRFModel.DRFParameters();
        parmsDRF._train = tfr._key;
        parmsDRF._response_column = "class";
        parmsDRF._ntrees = 1;
        parmsDRF._max_depth = 1;
        parmsDRF._fold_column = "foldId";
        parmsDRF._distribution = DistributionFamily.multinomial;
        parmsDRF._keep_cross_validation_predictions = true;
        DRF drfJob = new DRF(parmsDRF);
        DRFModel drf = drfJob.trainModel().get();
        checkModel(drf, foldId.anyVec(), 3);
        // GLM (regression on sepal_len)
        GLMModel.GLMParameters parmsGLM = new GLMModel.GLMParameters();
        parmsGLM._train = tfr._key;
        parmsGLM._response_column = "sepal_len";
        parmsGLM._fold_column = "foldId";
        parmsGLM._keep_cross_validation_predictions = true;
        GLM glmJob = new GLM(parmsGLM);
        GLMModel glm = glmJob.trainModel().get();
        checkModel(glm, foldId.anyVec(), 1);
        // DL
        DeepLearningModel.DeepLearningParameters parmsDL = new DeepLearningModel.DeepLearningParameters();
        parmsDL._train = tfr._key;
        parmsDL._response_column = "class";
        parmsDL._hidden = new int[] { 1 };
        parmsDL._epochs = 1;
        parmsDL._fold_column = "foldId";
        parmsDL._keep_cross_validation_predictions = true;
        DeepLearning dlJob = new DeepLearning(parmsDL);
        DeepLearningModel dl = dlJob.trainModel().get();
        checkModel(dl, foldId.anyVec(), 3);
    } finally {
        if (tfr != null)
            tfr.remove();
    }
}
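The fold-column pattern exercised four times above generalizes to any of these algorithms; a minimal sketch restating it with the same names the test uses:

// Build a per-row fold assignment (values 0..nfolds-1) aligned with the data,
// then point the algorithm at it via _fold_column instead of setting _nfolds.
Vec fold = AstKFold.kfoldColumn(tfr.vec("class").makeZero(), 3, 543216789);
tfr.add(new Frame(new String[] { "foldId" }, new Vec[] { fold }));
DKV.put(tfr);
parms._fold_column = "foldId";                   // CV splits now follow foldId
parms._keep_cross_validation_predictions = true; // retain per-fold holdout predictions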