Search in sources :

Example 1 with NonBlockingHashSet

use of water.nbhm.NonBlockingHashSet in project h2o-3 by h2oai.

The method checkAndInheritModelProperties of the class StackedEnsembleModel.

/**
 * Validates the configured base models and copies shared properties from the first
 * base model that can be found onto this ensemble.
 *
 * <p>The first base model found seeds the ensemble's supervised flag, model category,
 * distribution, domains, column names, ignored columns, response column and nfolds.
 * Every later base model found is compared against those seeded values. Base model
 * keys that cannot be resolved are logged and skipped.
 *
 * @throws H2OIllegalArgumentException if no base models were specified, if none of the
 *         specified base models could be found, if the ensemble's own ignored_columns or
 *         response_column disagree with the first base model, or if a later base model
 *         is inconsistent with the first one or fails the cross-validation requirements
 */
public void checkAndInheritModelProperties() {
    if (null == _parms._base_models || 0 == _parms._base_models.length)
        throw new H2OIllegalArgumentException("When creating a StackedEnsemble you must specify one or more models; found 0.");
    Model aModel = null;
    // Becomes true once the first resolvable base model has seeded the ensemble's properties.
    boolean beenHere = false;
    trainingFrameChecksum = _parms.train().checksum();
    for (Key<Model> k : _parms._base_models) {
        aModel = DKV.getGet(k);
        if (null == aModel) {
            Log.warn("Failed to find base model; skipping: " + k);
            continue;
        }
        if (beenHere) {
            // check that the base models are all consistent
            if (_output._isSupervised ^ aModel.isSupervised())
                throw new H2OIllegalArgumentException("Base models are inconsistent: there is a mix of supervised and unsupervised models: " + Arrays.toString(_parms._base_models));
            if (modelCategory != aModel._output.getModelCategory())
                throw new H2OIllegalArgumentException("Base models are inconsistent: there is a mix of different categories of models: " + Arrays.toString(_parms._base_models));
            Frame aTrainingFrame = aModel._parms.train();
            if (trainingFrameChecksum != aTrainingFrame.checksum())
                throw new H2OIllegalArgumentException("Base models are inconsistent: they use different training frames.  Found checksums: " + trainingFrameChecksum + " and: " + aTrainingFrame.checksum() + ".");
            // Column lists are compared as sets, so ordering differences are ignored.
            NonBlockingHashSet<String> aNames = new NonBlockingHashSet<>();
            aNames.addAll(Arrays.asList(aModel._output._names));
            if (!aNames.equals(this.names))
                throw new H2OIllegalArgumentException("Base models are inconsistent: they use different column lists.  Found: " + this.names + " and: " + aNames + ".");
            NonBlockingHashSet<String> anIgnoredColumns = new NonBlockingHashSet<>();
            if (null != aModel._parms._ignored_columns)
                anIgnoredColumns.addAll(Arrays.asList(aModel._parms._ignored_columns));
            // Report the set rather than the raw String[] so the message shows the column names
            // instead of the array's identity string.
            if (!anIgnoredColumns.equals(this.ignoredColumns))
                throw new H2OIllegalArgumentException("Base models are inconsistent: they use different ignored_column lists.  Found: " + this.ignoredColumns + " and: " + anIgnoredColumns + ".");
            if (!responseColumn.equals(aModel._parms._response_column))
                throw new H2OIllegalArgumentException("Base models are inconsistent: they use different response columns.  Found: " + responseColumn + " and: " + aModel._parms._response_column + ".");
            if (_output._domains.length != aModel._output._domains.length)
                throw new H2OIllegalArgumentException("Base models are inconsistent: there is a mix of different numbers of domains (categorical levels): " + Arrays.toString(_parms._base_models));
            if (nfolds != aModel._parms._nfolds)
                throw new H2OIllegalArgumentException("Base models are inconsistent: they use different values for nfolds.");
            // NOTE(review): the cross-validation checks below only run for the second and later
            // base models; the first model found is never checked against them. Confirm whether
            // that is intended before relying on them for single-model ensembles.
            // TODO: loosen this iff _parms._valid or if we add a separate holdout dataset for the ensemble
            if (aModel._parms._nfolds < 2)
                throw new H2OIllegalArgumentException("Base model does not use cross-validation: " + aModel._parms._nfolds);
            // TODO: loosen this iff it's consistent, like if we have a _fold_column
            if (aModel._parms._fold_assignment != Modulo)
                throw new H2OIllegalArgumentException("Base model does not use Modulo for cross-validation: " + aModel._parms._fold_assignment);
            if (!aModel._parms._keep_cross_validation_predictions)
                throw new H2OIllegalArgumentException("Base model does not keep cross-validation predictions: " + k);
            // Hack alert: DRF only does Bernoulli and Gaussian, so only compare _domains.length above.
            // TODO: If we're set to DistributionFamily.AUTO then GLM might auto-conform the response column
            // giving us inconsistencies.
            if (!(aModel instanceof DRFModel) && distributionFamily(aModel) != distributionFamily(this))
                Log.warn("Base models are inconsistent; they use different distributions: " + distributionFamily(this) + " and: " + distributionFamily(aModel) + ". Is this intentional?");
        } else {
            // !beenHere: this is the first base_model
            _output._isSupervised = aModel.isSupervised();
            this.modelCategory = aModel._output.getModelCategory();
            this._dist = new Distribution(distributionFamily(aModel));
            _output._domains = Arrays.copyOf(aModel._output._domains, aModel._output._domains.length);
            // TODO: set _parms._train to aModel._parms.train()
            _output._names = aModel._output._names;
            this.names = new NonBlockingHashSet<>();
            this.names.addAll(Arrays.asList(aModel._output._names));
            this.ignoredColumns = new NonBlockingHashSet<>();
            if (null != aModel._parms._ignored_columns)
                this.ignoredColumns.addAll(Arrays.asList(aModel._parms._ignored_columns));
            // If the ensemble specifies its own ignored_columns, they must be
            // consistent with the base_models:
            if (null != this._parms._ignored_columns) {
                NonBlockingHashSet<String> ensembleIgnoredColumns = new NonBlockingHashSet<>();
                ensembleIgnoredColumns.addAll(Arrays.asList(this._parms._ignored_columns));
                if (!ensembleIgnoredColumns.equals(this.ignoredColumns))
                    throw new H2OIllegalArgumentException("A StackedEnsemble takes its ignored_columns list from the base models.  An inconsistent list of ignored_columns was specified for the ensemble model.");
            }
            responseColumn = aModel._parms._response_column;
            if (!responseColumn.equals(_parms._response_column))
                throw new H2OIllegalArgumentException("StackedModel response_column must match the response_column of each base model.  Found: " + responseColumn + " and: " + _parms._response_column);
            nfolds = aModel._parms._nfolds;
            _parms._distribution = aModel._parms._distribution;
            beenHere = true;
        }
    }
    // beenHere is only set once a base model was actually found. Testing aModel here would
    // wrongly report "none found" whenever just the LAST key in the list failed to resolve.
    if (!beenHere)
        throw new H2OIllegalArgumentException("When creating a StackedEnsemble you must specify one or more models; " + _parms._base_models.length + " were specified but none of those were found: " + Arrays.toString(_parms._base_models));
}
Also used : Frame(water.fvec.Frame) DRFModel(hex.tree.drf.DRFModel) H2OIllegalArgumentException(water.exceptions.H2OIllegalArgumentException) GLMModel(hex.glm.GLMModel) DRFModel(hex.tree.drf.DRFModel) NonBlockingHashSet(water.nbhm.NonBlockingHashSet)

Aggregations

GLMModel (hex.glm.GLMModel): 1, DRFModel (hex.tree.drf.DRFModel): 1, H2OIllegalArgumentException (water.exceptions.H2OIllegalArgumentException): 1, Frame (water.fvec.Frame): 1, NonBlockingHashSet (water.nbhm.NonBlockingHashSet): 1