use of hex.tree.drf.DRFModel in project h2o-3 by h2oai.
the class StackedEnsembleModel method distributionFamily.
/**
 * Works out which {@link DistributionFamily} the given model uses.
 *
 * <p>Resolution order:
 * <ol>
 *   <li>DRF models are special-cased from their output category (binomial -> bernoulli,
 *       regression -> gaussian); multinomial DRF is rejected.</li>
 *   <li>If the model's parameters expose a {@code _family} field (GLM only, for now), that family is
 *       mapped to the {@link DistributionFamily} constant of the same name
 *       ({@code binomial} maps to {@code bernoulli}).</li>
 *   <li>Otherwise, if the model exposes a {@code _dist} field, its distribution is used, falling back to
 *       {@code _parms._distribution} when the field is null. {@code AUTO} is resolved from the output
 *       category the same way as for DRF.</li>
 * </ol>
 *
 * @param aModel the model to inspect
 * @return the resolved distribution family
 * @throws H2OIllegalArgumentException if the family cannot be determined, or if the reflective lookup fails
 */
private DistributionFamily distributionFamily(Model aModel) {
  // TODO: hack alert: In DRF, _parms._distribution is always set to multinomial. Yay.
  if (aModel instanceof DRFModel) {
    if (aModel._output.isBinomialClassifier()) {
      return DistributionFamily.bernoulli;
    } else if (aModel._output.isClassifier()) {
      throw new H2OIllegalArgumentException("Don't know how to set the distribution for a multinomial Random Forest classifier.");
    } else {
      return DistributionFamily.gaussian;
    }
  }
  try {
    Field familyField = ReflectionUtils.findNamedField(aModel._parms, "_family");
    // Only look for _dist when there is no _family; the two lookups are mutually exclusive.
    Field distributionField = (familyField != null ? null : ReflectionUtils.findNamedField(aModel, "_dist"));
    if (null != familyField) {
      // GLM only, for now
      GLMModel.GLMParameters.Family thisFamily = (GLMModel.GLMParameters.Family) familyField.get(aModel._parms);
      if (thisFamily == GLMModel.GLMParameters.Family.binomial) {
        return DistributionFamily.bernoulli;
      }
      try {
        return Enum.valueOf(DistributionFamily.class, thisFamily.toString());
      } catch (IllegalArgumentException e) {
        throw new H2OIllegalArgumentException("Don't know how to find the right DistributionFamily for Family: " + thisFamily);
      }
    }
    if (null != distributionField) {
      Distribution distribution = ((Distribution) distributionField.get(aModel));
      DistributionFamily distributionFamily;
      if (null != distribution) {
        distributionFamily = distribution.distribution;
      } else {
        distributionFamily = aModel._parms._distribution;
      }
      // NOTE: If the algo does smart guessing of the distribution family we need to duplicate the logic here.
      if (distributionFamily == DistributionFamily.AUTO) {
        if (aModel._output.isBinomialClassifier()) {
          distributionFamily = DistributionFamily.bernoulli;
        } else if (aModel._output.isClassifier()) {
          throw new H2OIllegalArgumentException("Don't know how to determine the distribution for a multinomial classifier.");
        } else {
          distributionFamily = DistributionFamily.gaussian;
        }
      }
      return distributionFamily;
    }
    throw new H2OIllegalArgumentException("Don't know how to stack models that have neither a distribution hyperparameter nor a family hyperparameter.");
  } catch (H2OIllegalArgumentException e) {
    // These are our own, already user-readable diagnostics; rethrow them untouched instead of
    // letting the generic handler below re-wrap them with the exception class name prepended.
    throw e;
  } catch (Exception e) {
    // Anything else (e.g. reflective access or cast failures) is reported as an illegal argument.
    throw new H2OIllegalArgumentException(e.toString(), e.toString());
  }
}
use of hex.tree.drf.DRFModel in project h2o-3 by h2oai.
the class StackedEnsembleModel method checkAndInheritModelProperties.
/**
 * Validates the configured base models against each other and copies the shared properties of the
 * first base model found onto this ensemble.
 *
 * <p>The first base model that can be fetched from the DKV seeds this ensemble's supervised flag, model
 * category, distribution, domains, names, ignored columns, response column, nfolds and
 * {@code _parms._distribution}. Every subsequent base model is checked for consistency with those
 * values. Base models whose keys cannot be resolved are skipped with a warning.
 *
 * @throws H2OIllegalArgumentException if no base models are specified, none can be found, the base
 *         models are inconsistent with each other, or the ensemble's own parameters conflict with them
 */
public void checkAndInheritModelProperties() {
  if (null == _parms._base_models || 0 == _parms._base_models.length)
    throw new H2OIllegalArgumentException("When creating a StackedEnsemble you must specify one or more models; found 0.");

  // Becomes true once the first resolvable base model has seeded the ensemble's properties.
  boolean beenHere = false;
  trainingFrameChecksum = _parms.train().checksum();

  for (Key<Model> k : _parms._base_models) {
    Model aModel = DKV.getGet(k);
    if (null == aModel) {
      Log.warn("Failed to find base model; skipping: " + k);
      continue;
    }

    if (beenHere) {
      // check that the base models are all consistent
      if (_output._isSupervised ^ aModel.isSupervised())
        throw new H2OIllegalArgumentException("Base models are inconsistent: there is a mix of supervised and unsupervised models: " + Arrays.toString(_parms._base_models));

      if (modelCategory != aModel._output.getModelCategory())
        throw new H2OIllegalArgumentException("Base models are inconsistent: there is a mix of different categories of models: " + Arrays.toString(_parms._base_models));

      Frame aTrainingFrame = aModel._parms.train();
      if (trainingFrameChecksum != aTrainingFrame.checksum())
        throw new H2OIllegalArgumentException("Base models are inconsistent: they use different training frames. Found checksums: " + trainingFrameChecksum + " and: " + aTrainingFrame.checksum() + ".");

      NonBlockingHashSet<String> aNames = new NonBlockingHashSet<>();
      aNames.addAll(Arrays.asList(aModel._output._names));
      if (!aNames.equals(this.names))
        throw new H2OIllegalArgumentException("Base models are inconsistent: they use different column lists. Found: " + this.names + " and: " + aNames + ".");

      NonBlockingHashSet<String> anIgnoredColumns = new NonBlockingHashSet<>();
      if (null != aModel._parms._ignored_columns)
        anIgnoredColumns.addAll(Arrays.asList(aModel._parms._ignored_columns));
      if (!anIgnoredColumns.equals(this.ignoredColumns))
        // Report the set, not the raw String[]: concatenating an array prints its identity, not its contents.
        throw new H2OIllegalArgumentException("Base models are inconsistent: they use different ignored_column lists. Found: " + this.ignoredColumns + " and: " + anIgnoredColumns + ".");

      if (!responseColumn.equals(aModel._parms._response_column))
        throw new H2OIllegalArgumentException("Base models are inconsistent: they use different response columns. Found: " + responseColumn + " and: " + aModel._parms._response_column + ".");

      if (_output._domains.length != aModel._output._domains.length)
        throw new H2OIllegalArgumentException("Base models are inconsistent: there is a mix of different numbers of domains (categorical levels): " + Arrays.toString(_parms._base_models));

      if (nfolds != aModel._parms._nfolds)
        throw new H2OIllegalArgumentException("Base models are inconsistent: they use different values for nfolds.");

      // TODO: loosen this iff _parms._valid or if we add a separate holdout dataset for the ensemble
      if (aModel._parms._nfolds < 2)
        throw new H2OIllegalArgumentException("Base model does not use cross-validation: " + aModel._parms._nfolds);

      // TODO: loosen this iff it's consistent, like if we have a _fold_column
      if (aModel._parms._fold_assignment != Modulo)
        throw new H2OIllegalArgumentException("Base model does not use Modulo for cross-validation: " + aModel._parms._fold_assignment);

      if (!aModel._parms._keep_cross_validation_predictions)
        throw new H2OIllegalArgumentException("Base model does not keep cross-validation predictions: " + k);

      // Hack alert: DRF only does Bernoulli and Gaussian, so only compare _domains.length above.
      if (!(aModel instanceof DRFModel) && distributionFamily(aModel) != distributionFamily(this))
        Log.warn("Base models are inconsistent; they use different distributions: " + distributionFamily(this) + " and: " + distributionFamily(aModel) + ". Is this intentional?");
      // TODO: If we're set to DistributionFamily.AUTO then GLM might auto-conform the response column
      // giving us inconsistencies.
    } else {
      // !beenHere: this is the first base_model
      // NOTE(review): the nfolds / fold_assignment / keep_cross_validation_predictions checks above are
      // only applied to the second and later base models; the first one is never validated for them.
      // Confirm whether that is intended before tightening it.
      _output._isSupervised = aModel.isSupervised();
      this.modelCategory = aModel._output.getModelCategory();
      this._dist = new Distribution(distributionFamily(aModel));
      _output._domains = Arrays.copyOf(aModel._output._domains, aModel._output._domains.length);

      // TODO: set _parms._train to aModel._parms.train()

      _output._names = aModel._output._names;
      this.names = new NonBlockingHashSet<>();
      this.names.addAll(Arrays.asList(aModel._output._names));

      this.ignoredColumns = new NonBlockingHashSet<>();
      if (null != aModel._parms._ignored_columns)
        this.ignoredColumns.addAll(Arrays.asList(aModel._parms._ignored_columns));

      // If the client has set _ignored_columns for the StackedEnsemble make sure it's
      // consistent with the base_models:
      if (null != this._parms._ignored_columns) {
        NonBlockingHashSet<String> ensembleIgnoredColumns = new NonBlockingHashSet<>();
        ensembleIgnoredColumns.addAll(Arrays.asList(this._parms._ignored_columns));
        if (!ensembleIgnoredColumns.equals(this.ignoredColumns))
          throw new H2OIllegalArgumentException("A StackedEnsemble takes its ignored_columns list from the base models. An inconsistent list of ignored_columns was specified for the ensemble model.");
      }

      responseColumn = aModel._parms._response_column;
      if (!responseColumn.equals(_parms._response_column))
        throw new H2OIllegalArgumentException("StackedModel response_column must match the response_column of each base model. Found: " + responseColumn + " and: " + _parms._response_column);

      nfolds = aModel._parms._nfolds;
      _parms._distribution = aModel._parms._distribution;
      beenHere = true;
    }
  }

  // beenHere is only set once a base model was actually resolved. Checking the last-fetched model
  // instead would wrongly fail when only the final key in the list is missing.
  if (!beenHere)
    throw new H2OIllegalArgumentException("When creating a StackedEnsemble you must specify one or more models; " + _parms._base_models.length + " were specified but none of those were found: " + Arrays.toString(_parms._base_models));
}
use of hex.tree.drf.DRFModel in project h2o-3 by h2oai.
the class ModelSerializationTest method prepareDRFModel.
/**
 * Parses a test dataset and trains a DRF model on it.
 *
 * <p>When {@code classification} is requested and the response column is not already categorical, the
 * column is replaced by its categorical version and the updated frame is written back to the DKV before
 * training. The parsed frame is deleted once training has finished (or failed).
 *
 * @param dataset        path of the test file to parse
 * @param ignoredColumns columns to exclude from training
 * @param response       name of the response column
 * @param classification whether the response must be treated as categorical
 * @param ntrees         number of trees to build
 * @return the trained DRF model
 */
private DRFModel prepareDRFModel(String dataset, String[] ignoredColumns, String response, boolean classification, int ntrees) {
  final Frame frame = parse_test_file(dataset);
  try {
    // Short-circuit keeps the response lookup from happening for regression runs.
    final boolean mustConvertResponse = classification && !frame.vec(response).isCategorical();
    if (mustConvertResponse) {
      final int responseIdx = frame.find(response);
      // Swap in the categorical column and drop the vec that replace() hands back.
      frame.replace(responseIdx, frame.vec(response).toCategoricalVec()).remove();
      DKV.put(frame._key, frame);
    }

    final DRFModel.DRFParameters params = new DRFModel.DRFParameters();
    params._train = frame._key;
    params._ignored_columns = ignoredColumns;
    params._response_column = response;
    params._ntrees = ntrees;
    params._score_each_iteration = true;

    final DRF builder = new DRF(params);
    return builder.trainModel().get();
  } finally {
    if (null != frame) {
      frame.delete();
    }
  }
}
use of hex.tree.drf.DRFModel in project h2o-3 by h2oai.
the class ModelSerializationTest method testDRFModelBinomial.
/**
 * Trains a binomial DRF model on the prostate dataset, round-trips it through
 * {@code saveAndLoad}, and asserts that the model and its trees are unchanged.
 */
@Test
public void testDRFModelBinomial() throws IOException {
  DRFModel original = null;
  DRFModel restored = null;
  try {
    original = prepareDRFModel("smalldata/logreg/prostate.csv", ar("ID"), "CAPSULE", true, 5);
    // Capture the trees before the round trip so they can be compared afterwards.
    final CompressedTree[][] originalTrees = getTrees(original);

    restored = saveAndLoad(original);

    // The reloaded model and each of its trees must match the original.
    assertModelBinaryEquals(original, restored);
    assertTreeEquals("Trees have to be binary same", originalTrees, getTrees(restored));
  } finally {
    if (null != original) {
      original.delete();
    }
    if (null != restored) {
      restored.delete();
    }
  }
}
use of hex.tree.drf.DRFModel in project h2o-3 by h2oai.
the class ModelSerializationTest method testDRFModelMultinomial.
/**
 * Trains a multinomial DRF model on the iris dataset, round-trips it through
 * {@code saveAndLoad}, and asserts that the model and its trees are unchanged.
 *
 * <p>Both the trained and the reloaded model are deleted afterwards, matching the cleanup done by
 * {@code testDRFModelBinomial}.
 */
@Test
public void testDRFModelMultinomial() throws IOException {
  DRFModel model = null, loadedModel = null;
  try {
    model = prepareDRFModel("smalldata/iris/iris.csv", ESA, "C5", true, 5);
    CompressedTree[][] trees = getTrees(model);
    loadedModel = saveAndLoad(model);
    // And compare
    assertModelBinaryEquals(model, loadedModel);
    CompressedTree[][] loadedTrees = getTrees(loadedModel);
    assertTreeEquals("Trees have to be binary same", trees, loadedTrees);
  } finally {
    // Delete the trained model too, not just the reloaded copy, so the test leaves nothing behind.
    if (model != null)
      model.delete();
    if (loadedModel != null)
      loadedModel.delete();
  }
}
Aggregations