Search in sources :

Example 1 with DRFModel

use of hex.tree.drf.DRFModel in project h2o-3 by h2oai.

the class StackedEnsembleModel method distributionFamily.

/**
 * Works out which {@link DistributionFamily} corresponds to the given model.
 *
 * <p>Resolution order:
 * <ol>
 *   <li>DRF models are special-cased from the shape of their output (binomial classifier
 *       gives bernoulli, regression gives gaussian; multinomial is rejected).</li>
 *   <li>If the model's parameters have a {@code _family} field (GLM only, for now), the GLM
 *       family is mapped to the DistributionFamily constant of the same name, with
 *       {@code binomial} mapped to {@code bernoulli}.</li>
 *   <li>Otherwise, if the model has a {@code _dist} field, its distribution is used, falling
 *       back to {@code _parms._distribution} when the field is null. {@code AUTO} is resolved
 *       from the shape of the model's output in the same way as for DRF.</li>
 * </ol>
 *
 * @param aModel the model to inspect
 * @return the distribution family of the model
 * @throws H2OIllegalArgumentException if the family cannot be determined, or if the reflective
 *         lookup fails for any reason
 */
private DistributionFamily distributionFamily(Model aModel) {
    // TODO: hack alert: In DRF, _parms._distribution is always set to multinomial.  Yay.
    if (aModel instanceof DRFModel) {
        if (aModel._output.isBinomialClassifier()) {
            return DistributionFamily.bernoulli;
        }
        if (aModel._output.isClassifier()) {
            throw new H2OIllegalArgumentException("Don't know how to set the distribution for a multinomial Random Forest classifier.");
        }
        return DistributionFamily.gaussian;
    }
    try {
        Field familyField = ReflectionUtils.findNamedField(aModel._parms, "_family");
        // Only look for _dist when there is no _family; _family takes precedence.
        Field distributionField = (familyField != null ? null : ReflectionUtils.findNamedField(aModel, "_dist"));
        if (null != familyField) {
            // GLM only, for now
            GLMModel.GLMParameters.Family thisFamily = (GLMModel.GLMParameters.Family) familyField.get(aModel._parms);
            if (thisFamily == GLMModel.GLMParameters.Family.binomial) {
                return DistributionFamily.bernoulli;
            }
            try {
                return Enum.valueOf(DistributionFamily.class, thisFamily.toString());
            } catch (IllegalArgumentException e) {
                throw new H2OIllegalArgumentException("Don't know how to find the right DistributionFamily for Family: " + thisFamily);
            }
        }
        if (null != distributionField) {
            Distribution distribution = ((Distribution) distributionField.get(aModel));
            DistributionFamily distributionFamily;
            if (null != distribution) {
                distributionFamily = distribution.distribution;
            } else {
                distributionFamily = aModel._parms._distribution;
            }
            // NOTE: If the algo does smart guessing of the distribution family we need to duplicate the logic here.
            if (distributionFamily == DistributionFamily.AUTO) {
                if (aModel._output.isBinomialClassifier()) {
                    distributionFamily = DistributionFamily.bernoulli;
                } else if (aModel._output.isClassifier()) {
                    throw new H2OIllegalArgumentException("Don't know how to determine the distribution for a multinomial classifier.");
                } else {
                    distributionFamily = DistributionFamily.gaussian;
                }
            }
            return distributionFamily;
        }
        throw new H2OIllegalArgumentException("Don't know how to stack models that have neither a distribution hyperparameter nor a family hyperparameter.");
    } catch (H2OIllegalArgumentException e) {
        // Our own diagnostics are already in the right form; rethrow them unchanged instead of
        // re-wrapping, which would prefix the message with the exception class name.
        throw e;
    } catch (Exception e) {
        // Anything else (e.g. a reflective access failure) is reported as an illegal argument.
        throw new H2OIllegalArgumentException(e.toString(), e.toString());
    }
}
Also used : Field(java.lang.reflect.Field) DRFModel(hex.tree.drf.DRFModel) GLMModel(hex.glm.GLMModel) DistributionFamily(hex.genmodel.utils.DistributionFamily) H2OIllegalArgumentException(water.exceptions.H2OIllegalArgumentException)

Example 2 with DRFModel

use of hex.tree.drf.DRFModel in project h2o-3 by h2oai.

the class StackedEnsembleModel method checkAndInheritModelProperties.

/**
 * Validates the base models of this ensemble against each other and copies the shared
 * properties of the first base model that is found into the ensemble.
 *
 * <p>The first base model found in the DKV seeds the ensemble's supervised flag, model category,
 * distribution, domains, column names, ignored columns, response column and nfolds. Every
 * subsequent base model is checked for consistency with those values. Base models whose key
 * cannot be resolved are logged and skipped.
 *
 * @throws H2OIllegalArgumentException if no base models were specified, none of them could be
 *         found, or the base models are inconsistent with each other or with the ensemble's
 *         own parameters
 */
public void checkAndInheritModelProperties() {
    if (null == _parms._base_models || 0 == _parms._base_models.length)
        throw new H2OIllegalArgumentException("When creating a StackedEnsemble you must specify one or more models; found 0.");
    Model aModel = null;
    // Becomes true once the first resolvable base model has been processed.
    boolean beenHere = false;
    trainingFrameChecksum = _parms.train().checksum();
    for (Key<Model> k : _parms._base_models) {
        aModel = DKV.getGet(k);
        if (null == aModel) {
            Log.warn("Failed to find base model; skipping: " + k);
            continue;
        }
        if (beenHere) {
            // check that the base models are all consistent
            if (_output._isSupervised ^ aModel.isSupervised())
                throw new H2OIllegalArgumentException("Base models are inconsistent: there is a mix of supervised and unsupervised models: " + Arrays.toString(_parms._base_models));
            if (modelCategory != aModel._output.getModelCategory())
                throw new H2OIllegalArgumentException("Base models are inconsistent: there is a mix of different categories of models: " + Arrays.toString(_parms._base_models));
            Frame aTrainingFrame = aModel._parms.train();
            if (trainingFrameChecksum != aTrainingFrame.checksum())
                throw new H2OIllegalArgumentException("Base models are inconsistent: they use different training frames.  Found checksums: " + trainingFrameChecksum + " and: " + aTrainingFrame.checksum() + ".");
            NonBlockingHashSet<String> aNames = new NonBlockingHashSet<>();
            aNames.addAll(Arrays.asList(aModel._output._names));
            if (!aNames.equals(this.names))
                throw new H2OIllegalArgumentException("Base models are inconsistent: they use different column lists.  Found: " + this.names + " and: " + aNames + ".");
            NonBlockingHashSet<String> anIgnoredColumns = new NonBlockingHashSet<>();
            if (null != aModel._parms._ignored_columns)
                anIgnoredColumns.addAll(Arrays.asList(aModel._parms._ignored_columns));
            // Report the set, not the raw String[]: concatenating an array prints its identity hash.
            if (!anIgnoredColumns.equals(this.ignoredColumns))
                throw new H2OIllegalArgumentException("Base models are inconsistent: they use different ignored_column lists.  Found: " + this.ignoredColumns + " and: " + anIgnoredColumns + ".");
            if (!responseColumn.equals(aModel._parms._response_column))
                throw new H2OIllegalArgumentException("Base models are inconsistent: they use different response columns.  Found: " + responseColumn + " and: " + aModel._parms._response_column + ".");
            if (_output._domains.length != aModel._output._domains.length)
                throw new H2OIllegalArgumentException("Base models are inconsistent: there is a mix of different numbers of domains (categorical levels): " + Arrays.toString(_parms._base_models));
            if (nfolds != aModel._parms._nfolds)
                throw new H2OIllegalArgumentException("Base models are inconsistent: they use different values for nfolds.");
            // TODO: loosen this iff _parms._valid or if we add a separate holdout dataset for the ensemble
            if (aModel._parms._nfolds < 2)
                throw new H2OIllegalArgumentException("Base model does not use cross-validation: " + aModel._parms._nfolds);
            // TODO: loosen this iff it's consistent, like if we have a _fold_column
            if (aModel._parms._fold_assignment != Modulo)
                throw new H2OIllegalArgumentException("Base model does not use Modulo for cross-validation: " + aModel._parms._fold_assignment);
            if (!aModel._parms._keep_cross_validation_predictions)
                throw new H2OIllegalArgumentException("Base model does not keep cross-validation predictions: " + k);
            // Hack alert: DRF only does Bernoulli and Gaussian, so only compare _domains.length above.
            if (!(aModel instanceof DRFModel) && distributionFamily(aModel) != distributionFamily(this))
                Log.warn("Base models are inconsistent; they use different distributions: " + distributionFamily(this) + " and: " + distributionFamily(aModel) + ". Is this intentional?");
        // TODO: If we're set to DistributionFamily.AUTO then GLM might auto-conform the response column
        // giving us inconsistencies.
        } else {
            // !beenHere: this is the first base_model
            // NOTE(review): the cross-validation checks above (nfolds >= 2, Modulo, kept CV
            // predictions) are only applied to the second and later base models - confirm
            // whether the first one should be validated too.
            _output._isSupervised = aModel.isSupervised();
            this.modelCategory = aModel._output.getModelCategory();
            this._dist = new Distribution(distributionFamily(aModel));
            _output._domains = Arrays.copyOf(aModel._output._domains, aModel._output._domains.length);
            // TODO: set _parms._train to aModel._parms.train()
            _output._names = aModel._output._names;
            this.names = new NonBlockingHashSet<>();
            this.names.addAll(Arrays.asList(aModel._output._names));
            this.ignoredColumns = new NonBlockingHashSet<>();
            if (null != aModel._parms._ignored_columns)
                this.ignoredColumns.addAll(Arrays.asList(aModel._parms._ignored_columns));
            // consistent with the base_models:
            if (null != this._parms._ignored_columns) {
                NonBlockingHashSet<String> ensembleIgnoredColumns = new NonBlockingHashSet<>();
                ensembleIgnoredColumns.addAll(Arrays.asList(this._parms._ignored_columns));
                if (!ensembleIgnoredColumns.equals(this.ignoredColumns))
                    throw new H2OIllegalArgumentException("A StackedEnsemble takes its ignored_columns list from the base models.  An inconsistent list of ignored_columns was specified for the ensemble model.");
            }
            responseColumn = aModel._parms._response_column;
            if (!responseColumn.equals(_parms._response_column))
                throw new H2OIllegalArgumentException("StackedModel response_column must match the response_column of each base model.  Found: " + responseColumn + " and: " + _parms._response_column);
            nfolds = aModel._parms._nfolds;
            _parms._distribution = aModel._parms._distribution;
            beenHere = true;
        }
    }
    // Test beenHere rather than aModel: aModel only reflects the LAST key, so a missing last
    // model would otherwise be reported as "none found" even when earlier ones were resolved.
    if (!beenHere)
        throw new H2OIllegalArgumentException("When creating a StackedEnsemble you must specify one or more models; " + _parms._base_models.length + " were specified but none of those were found: " + Arrays.toString(_parms._base_models));
}
Also used : Frame(water.fvec.Frame) DRFModel(hex.tree.drf.DRFModel) H2OIllegalArgumentException(water.exceptions.H2OIllegalArgumentException) GLMModel(hex.glm.GLMModel) NonBlockingHashSet(water.nbhm.NonBlockingHashSet)

Example 3 with DRFModel

use of hex.tree.drf.DRFModel in project h2o-3 by h2oai.

the class ModelSerializationTest method prepareDRFModel.

/**
 * Parses a test dataset and trains a DRF model on it.
 *
 * <p>When {@code classification} is requested and the response column is not already
 * categorical, the column is replaced by its categorical version and the updated frame is
 * put back into the DKV before training. The parsed frame is deleted once training has
 * finished or failed.
 *
 * @param dataset        path of the test file to parse
 * @param ignoredColumns columns to exclude from training
 * @param response       name of the response column
 * @param classification whether the response must be treated as categorical
 * @param ntrees         number of trees to build
 * @return the trained DRF model
 */
private DRFModel prepareDRFModel(String dataset, String[] ignoredColumns, String response, boolean classification, int ntrees) {
    Frame frame = parse_test_file(dataset);
    try {
        boolean mustConvertResponse = classification && !frame.vec(response).isCategorical();
        if (mustConvertResponse) {
            // Swap in the categorical vec and drop the numeric one it replaced.
            int responseIdx = frame.find(response);
            frame.replace(responseIdx, frame.vec(response).toCategoricalVec()).remove();
            DKV.put(frame._key, frame);
        }

        DRFModel.DRFParameters params = new DRFModel.DRFParameters();
        params._train = frame._key;
        params._response_column = response;
        params._ignored_columns = ignoredColumns;
        params._ntrees = ntrees;
        params._score_each_iteration = true;

        DRF builder = new DRF(params);
        return builder.trainModel().get();
    } finally {
        if (frame != null) {
            frame.delete();
        }
    }
}
Also used : Frame(water.fvec.Frame) DRFModel(hex.tree.drf.DRFModel) DRF(hex.tree.drf.DRF)

Example 4 with DRFModel

use of hex.tree.drf.DRFModel in project h2o-3 by h2oai.

the class ModelSerializationTest method testDRFModelBinomial.

/**
 * Trains a binomial DRF model on the prostate dataset, round-trips it through
 * save/load, and checks that both the model and its trees survive unchanged.
 */
@Test
public void testDRFModelBinomial() throws IOException {
    DRFModel trained = null;
    DRFModel reloaded = null;
    try {
        trained = prepareDRFModel("smalldata/logreg/prostate.csv", ar("ID"), "CAPSULE", true, 5);
        // Capture the trees before serialization so they can be compared afterwards.
        CompressedTree[][] treesBefore = getTrees(trained);
        reloaded = saveAndLoad(trained);
        // And compare
        assertModelBinaryEquals(trained, reloaded);
        assertTreeEquals("Trees have to be binary same", treesBefore, getTrees(reloaded));
    } finally {
        if (trained != null) {
            trained.delete();
        }
        if (reloaded != null) {
            reloaded.delete();
        }
    }
}
Also used : DRFModel(hex.tree.drf.DRFModel) Test(org.junit.Test)

Example 5 with DRFModel

use of hex.tree.drf.DRFModel in project h2o-3 by h2oai.

the class ModelSerializationTest method testDRFModelMultinomial.

/**
 * Trains a multinomial DRF model on the iris dataset, round-trips it through
 * save/load, and checks that both the model and its trees survive unchanged.
 */
@Test
public void testDRFModelMultinomial() throws IOException {
    DRFModel model = null, loadedModel = null;
    try {
        model = prepareDRFModel("smalldata/iris/iris.csv", ESA, "C5", true, 5);
        CompressedTree[][] trees = getTrees(model);
        loadedModel = saveAndLoad(model);
        // And compare
        assertModelBinaryEquals(model, loadedModel);
        CompressedTree[][] loadedTrees = getTrees(loadedModel);
        assertTreeEquals("Trees have to be binary same", trees, loadedTrees);
    } finally {
        // Delete the trained model as well as the reloaded copy, as the binomial test does;
        // otherwise the trained model is left behind after the test.
        if (model != null)
            model.delete();
        if (loadedModel != null)
            loadedModel.delete();
    }
}
Also used : DRFModel(hex.tree.drf.DRFModel) Test(org.junit.Test)

Aggregations

DRFModel (hex.tree.drf.DRFModel)8 GLMModel (hex.glm.GLMModel)4 Frame (water.fvec.Frame)4 DRF (hex.tree.drf.DRF)3 Test (org.junit.Test)3 DeepLearning (hex.deeplearning.DeepLearning)2 DeepLearningModel (hex.deeplearning.DeepLearningModel)2 GLM (hex.glm.GLM)2 GBM (hex.tree.gbm.GBM)2 GBMModel (hex.tree.gbm.GBMModel)2 H2OIllegalArgumentException (water.exceptions.H2OIllegalArgumentException)2 DistributionFamily (hex.genmodel.utils.DistributionFamily)1 Grid (hex.grid.Grid)1 GridSearch (hex.grid.GridSearch)1 DRFParametersV3 (hex.schemas.DRFV3.DRFParametersV3)1 DeepLearningParametersV3 (hex.schemas.DeepLearningV3.DeepLearningParametersV3)1 GBMParametersV3 (hex.schemas.GBMV3.GBMParametersV3)1 GLMParametersV3 (hex.schemas.GLMV3.GLMParametersV3)1 SharedTreeModel (hex.tree.SharedTreeModel)1 IOException (java.io.IOException)1