Search in sources :

Example 6 with GradientBoostedTreesModel

use of org.knime.base.node.mine.treeensemble2.model.GradientBoostedTreesModel in project knime-core by knime.

the class RegressionGBTModelImporter method importFromPMMLInternal.

/**
 * {@inheritDoc}
 * <p>
 * Requires the segmentation of the mining model to use {@code SUM} as its multiple model method, reads the
 * regression trees together with their coefficient maps from that segmentation and uses the rescale constant of
 * the first target as the initial value of the resulting model.
 */
@Override
public GradientBoostedTreesModel importFromPMMLInternal(final MiningModel miningModel) {
    final Segmentation sumSegmentation = miningModel.getSegmentation();
    final boolean usesSum = sumSegmentation.getMultipleModelMethod() == MULTIPLEMODELMETHOD.SUM;
    CheckUtils.checkArgument(usesSum,
        "The provided segmentation has not the required sum as multiple model method but '%s' instead.",
        sumSegmentation.getMultipleModelMethod());

    final Pair<List<TreeModelRegression>, List<Map<TreeNodeSignature, Double>>> treesAndCoefficientMaps =
        readSumSegmentation(sumSegmentation);
    final List<TreeModelRegression> regressionTrees = treesAndCoefficientMaps.getFirst();

    // TODO user should be warned if there is no initial value or anything else is fishy
    // the rescale constant of the first target carries the initial value of the ensemble
    final double rescaleConstant = miningModel.getTargets().getTargetList().get(0).getRescaleConstant();

    final TreeModelRegression[] treeArray =
        regressionTrees.toArray(new TreeModelRegression[regressionTrees.size()]);

    // currently only models learned on "ordinary" columns can be read back in
    return new GradientBoostedTreesModel(getMetaDataMapper().getTreeMetaData(), treeArray, TreeType.Ordinary,
        rescaleConstant, treesAndCoefficientMaps.getSecond());
}
Also used : Segmentation(org.dmg.pmml.SegmentationDocument.Segmentation) List(java.util.List) GradientBoostedTreesModel(org.knime.base.node.mine.treeensemble2.model.GradientBoostedTreesModel) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression)

Example 7 with GradientBoostedTreesModel

use of org.knime.base.node.mine.treeensemble2.model.GradientBoostedTreesModel in project knime-core by knime.

the class MGradientBoostedTreesLearner method learn.

/**
 * {@inheritDoc}
 * <p>
 * Learns {@code config.getNrModels()} regression trees one after another. The prediction for every row starts at
 * the median of the target column. In each iteration the residuals between the actual target values and the
 * current prediction are computed, clipped to the range {@code [-quantile, quantile]} (where {@code quantile} is
 * the value returned by {@code calculateAlphaQuantile} for the configured alpha) and used as target for the next
 * tree. After a tree is learned, its coefficient map is calculated and the current prediction is adapted.
 *
 * @param exec monitor that receives progress messages and is handed to the single tree learner
 * @return the learned gradient boosted trees model
 * @throws CanceledExecutionException if learning a single tree is canceled
 */
@Override
public AbstractGradientBoostingModel learn(final ExecutionMonitor exec) throws CanceledExecutionException {
    final TreeData actualData = getData();
    final GradientBoostingLearnerConfiguration config = getConfig();
    final int nrModels = config.getNrModels();
    final TreeTargetNumericColumnData actualTarget = getTarget();
    final int nrRows = actualTarget.getNrRows();
    // the median of the target is the starting prediction for every row
    final double initialValue = actualTarget.getMedian();
    final ArrayList<TreeModelRegression> models = new ArrayList<>(nrModels);
    final ArrayList<Map<TreeNodeSignature, Double>> coefficientMaps = new ArrayList<>(nrModels);
    final double[] previousPrediction = new double[nrRows];
    Arrays.fill(previousPrediction, initialValue);
    final RandomData rd = config.createRandomData();
    final double alpha = config.getAlpha();
    final TreeNodeSignatureFactory signatureFactory;
    final int maxLevels = config.getMaxLevels();
    // this should be the default
    if (maxLevels < TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE) {
        final int capacity = IntMath.pow(2, maxLevels - 1);
        signatureFactory = new TreeNodeSignatureFactory(capacity);
    } else {
        signatureFactory = new TreeNodeSignatureFactory();
    }
    exec.setMessage("Learning model");
    for (int i = 0; i < nrModels; i++) {
        // residuals of the current ensemble prediction
        final double[] residuals = new double[nrRows];
        for (int j = 0; j < nrRows; j++) {
            residuals[j] = actualTarget.getValueFor(j) - previousPrediction[j];
        }
        final double quantile = calculateAlphaQuantile(residuals, alpha);
        // residuals whose magnitude exceeds the quantile are clipped to +/- quantile
        final double[] gradients = new double[residuals.length];
        for (int j = 0; j < gradients.length; j++) {
            gradients[j] = Math.abs(residuals[j]) <= quantile ? residuals[j] : quantile * Math.signum(residuals[j]);
        }
        final TreeData residualData = createResidualDataFromArray(gradients, actualData);
        // every tree gets its own random data object seeded from the shared one
        final RandomData rdSingle =
            TreeEnsembleLearnerConfiguration.createRandomData(rd.nextLong(Long.MIN_VALUE, Long.MAX_VALUE));
        final RowSample rowSample = getRowSampler().createRowSample(rdSingle);
        final TreeLearnerRegression treeLearner = new TreeLearnerRegression(getConfig(), residualData,
            getIndexManager(), signatureFactory, rdSingle, rowSample);
        final TreeModelRegression tree = treeLearner.learnSingleTree(exec, rdSingle);
        final Map<TreeNodeSignature, Double> coefficientMap = calcCoefficientMap(residuals, quantile, tree);
        adaptPreviousPrediction(previousPrediction, tree, coefficientMap);
        models.add(tree);
        coefficientMaps.add(coefficientMap);
        // i is zero-based, so i + 1 trees are finished at this point
        final int finished = i + 1;
        exec.setProgress(((double)finished) / nrModels, "Finished level " + finished + "/" + nrModels);
    }
    return new GradientBoostedTreesModel(getConfig(), actualData.getMetaData(),
        models.toArray(new TreeModelRegression[models.size()]), actualData.getTreeType(), initialValue,
        coefficientMaps);
}
Also used : RandomData(org.apache.commons.math.random.RandomData) ArrayList(java.util.ArrayList) TreeTargetNumericColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData) GradientBoostedTreesModel(org.knime.base.node.mine.treeensemble2.model.GradientBoostedTreesModel) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression) GradientBoostingLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.gradientboosting.learner.GradientBoostingLearnerConfiguration) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) TreeLearnerRegression(org.knime.base.node.mine.treeensemble2.learner.TreeLearnerRegression) RowSample(org.knime.base.node.mine.treeensemble2.sample.row.RowSample) HashMap(java.util.HashMap) Map(java.util.Map) TreeNodeSignatureFactory(org.knime.base.node.mine.treeensemble2.learner.TreeNodeSignatureFactory)

Example 8 with GradientBoostedTreesModel

use of org.knime.base.node.mine.treeensemble2.model.GradientBoostedTreesModel in project knime-core by knime.

the class TreeEnsembleModel method load.

/**
 * Reads a tree ensemble model from the given stream.
 * <p>
 * The stream starts with a version number; versions greater than 20160114 are rejected. Version 20160114
 * additionally stores a byte identifying the ensemble type, for all other versions the type {@code 'r'} is
 * assumed. The wrapper stream created here is closed on every exit path, the argument stream itself is left
 * open.
 *
 * @param in stream to read the model from, it is not closed by this method
 * @return the model that was read, its concrete class depends on the stored ensemble type
 * @throws IOException if the version is not supported, a single tree model can't be read or is not terminated
 *             by a 0 byte, or reading from the stream fails
 * @throws IllegalStateException if the stored ensemble type is unknown
 */
public static TreeEnsembleModel load(final InputStream in) throws IOException {
    // wrapping the argument (zip input) stream in a buffered stream
    // reduces read operation from, e.g. 42s to 2s
    // try-with-resources closes the wrapper also when reading fails;
    // this does not close the method argument stream!!
    try (TreeModelDataInputStream input =
        new TreeModelDataInputStream(new BufferedInputStream(new NonClosableInputStream(in)))) {
        int version = input.readInt();
        if (version > 20160114) {
            throw new IOException("Tree Ensemble version " + version + " not supported");
        }
        // only version 20160114 stores the ensemble type, otherwise fall back to 'r'
        byte ensembleType;
        if (version == 20160114) {
            ensembleType = input.readByte();
        } else {
            ensembleType = 'r';
        }
        TreeType type = TreeType.load(input);
        TreeMetaData metaData = TreeMetaData.load(input);
        int nrModels = input.readInt();
        // version 20121019 has no flag in the stream and is treated as containing class distributions
        boolean containsClassDistribution;
        if (version == 20121019) {
            containsClassDistribution = true;
        } else {
            containsClassDistribution = input.readBoolean();
        }
        input.setContainsClassDistribution(containsClassDistribution);
        AbstractTreeModel[] models = new AbstractTreeModel[nrModels];
        // every ensemble type other than 'r' is read with regression trees regardless of the meta data
        boolean isRegression = metaData.isRegression();
        if (ensembleType != 'r') {
            isRegression = true;
        }
        final TreeBuildingInterner treeBuildingInterner = new TreeBuildingInterner();
        for (int i = 0; i < nrModels; i++) {
            AbstractTreeModel singleModel;
            try {
                singleModel = isRegression ? TreeModelRegression.load(input, metaData, treeBuildingInterner)
                    : TreeModelClassification.load(input, metaData, treeBuildingInterner);
                if (input.readByte() != 0) {
                    throw new IOException("Model not terminated by 0 byte");
                }
            } catch (IOException e) {
                // add the position of the failing tree to the message, keep the cause
                throw new IOException("Can't read tree model " + (i + 1) + "/" + nrModels + ": " + e.getMessage(),
                    e);
            }
            models[i] = singleModel;
        }
        TreeEnsembleModel result;
        switch (ensembleType) {
            case 'r':
                result = new TreeEnsembleModel(metaData, models, type, containsClassDistribution);
                break;
            case 'g':
                result = new GradientBoostingModel(metaData, models, type, containsClassDistribution);
                break;
            case 't':
                result = new GradientBoostedTreesModel(metaData, models, type, containsClassDistribution);
                break;
            case 'm':
                result = new MultiClassGradientBoostedTreesModel(metaData, models, type, containsClassDistribution);
                break;
            default:
                throw new IllegalStateException("Unknown ensemble type: '" + (char)ensembleType + "'");
        }
        // lets the concrete model read its additional data from the stream
        result.loadData(input);
        return result;
    }
}
Also used : IOException(java.io.IOException) BufferedInputStream(java.io.BufferedInputStream) NonClosableInputStream(org.knime.core.data.util.NonClosableInputStream) TreeMetaData(org.knime.base.node.mine.treeensemble2.data.TreeMetaData)

Example 9 with GradientBoostedTreesModel

use of org.knime.base.node.mine.treeensemble2.model.GradientBoostedTreesModel in project knime-core by knime.

the class RegressionGBTModelExporter method doWrite.

/**
 * {@inheritDoc}
 * <p>
 * Adds a target carrying the target attribute name and the initial value (as rescale constant) to the mining
 * model and then writes all regression trees with their coefficient maps into a new sum segmentation.
 */
@Override
protected void doWrite(final MiningModel model) {
    // the initial value is stored as rescale constant of a newly added target
    final Target initialValueTarget = model.addNewTargets().addNewTarget();
    final GradientBoostedTreesModel ensemble = getGBTModel();
    initialValueTarget.setField(ensemble.getMetaData().getTargetMetaData().getAttributeName());
    initialValueTarget.setRescaleConstant(ensemble.getInitialValue());

    // the trees themselves go into a new segmentation
    final Segmentation sumSegmentation = model.addNewSegmentation();
    final int treeCount = ensemble.getNrModels();
    final List<TreeModelRegression> regressionTrees = IntStream.range(0, treeCount)
        .mapToObj(treeIdx -> ensemble.getTreeModelRegression(treeIdx))
        .collect(Collectors.toList());
    writeSumSegmentation(sumSegmentation, regressionTrees, ensemble.getCoeffientMaps());
}
Also used : Target(org.dmg.pmml.TargetDocument.Target) Segmentation(org.dmg.pmml.SegmentationDocument.Segmentation) Targets(org.dmg.pmml.TargetsDocument.Targets) GradientBoostedTreesModel(org.knime.base.node.mine.treeensemble2.model.GradientBoostedTreesModel) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression)

Example 10 with GradientBoostedTreesModel

use of org.knime.base.node.mine.treeensemble2.model.GradientBoostedTreesModel in project knime-core by knime.

the class GradientBoostingPredictorCellFactory method createFactory.

/**
 * Creates a cell factory that appends a single prediction column of type double for the given predictor.
 * The column name is taken from the predictor configuration and made unique with respect to the data spec of
 * the predictor.
 *
 * @param predictor provides model, model spec, data spec and configuration
 * @return the cell factory for the given predictor
 * @throws InvalidSettingsException declared by this method; presumably raised by one of the spec lookups
 *             (not visible here)
 */
public static GradientBoostingPredictorCellFactory createFactory(final GradientBoostingPredictor<GradientBoostedTreesModel> predictor) throws InvalidSettingsException {
    final TreeEnsembleModelPortObjectSpec ensembleSpec = predictor.getModelSpec();
    final DataTableSpec trainingTableSpec = ensembleSpec.getLearnTableSpec();
    final DataTableSpec inputTableSpec = predictor.getDataSpec();

    // the generator makes sure the new column name is unique within the input table
    final UniqueNameGenerator uniqueNames = new UniqueNameGenerator(inputTableSpec);
    final String predictionColumnName = predictor.getConfiguration().getPredictionColumnName();
    final DataColumnSpec predictionColumnSpec = uniqueNames.newColumn(predictionColumnName, DoubleCell.TYPE);

    return new GradientBoostingPredictorCellFactory(predictionColumnSpec, predictor.getModel(), trainingTableSpec,
        ensembleSpec.calculateFilterIndices(inputTableSpec));
}
Also used : DataTableSpec(org.knime.core.data.DataTableSpec) DataColumnSpec(org.knime.core.data.DataColumnSpec) TreeEnsembleModelPortObjectSpec(org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObjectSpec) UniqueNameGenerator(org.knime.core.util.UniqueNameGenerator)

Aggregations

GradientBoostedTreesModel (org.knime.base.node.mine.treeensemble2.model.GradientBoostedTreesModel): 8, TreeEnsembleModelPortObjectSpec (org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObjectSpec): 5, DataTableSpec (org.knime.core.data.DataTableSpec): 5, GradientBoostingModelPortObject (org.knime.base.node.mine.treeensemble2.model.GradientBoostingModelPortObject): 3, MultiClassGradientBoostedTreesModel (org.knime.base.node.mine.treeensemble2.model.MultiClassGradientBoostedTreesModel): 3, TreeModelRegression (org.knime.base.node.mine.treeensemble2.model.TreeModelRegression): 3, GradientBoostingPredictor (org.knime.base.node.mine.treeensemble2.node.gradientboosting.predictor.GradientBoostingPredictor): 3, ColumnRearranger (org.knime.core.data.container.ColumnRearranger): 3, Segmentation (org.dmg.pmml.SegmentationDocument.Segmentation): 2, TreeNodeSignature (org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature): 2, PMMLPortObjectSpec (org.knime.core.node.port.pmml.PMMLPortObjectSpec): 2, BufferedInputStream (java.io.BufferedInputStream): 1, IOException (java.io.IOException): 1, ArrayList (java.util.ArrayList): 1, HashMap (java.util.HashMap): 1, List (java.util.List): 1, Map (java.util.Map): 1, RandomData (org.apache.commons.math.random.RandomData): 1, Target (org.dmg.pmml.TargetDocument.Target): 1, Targets (org.dmg.pmml.TargetsDocument.Targets): 1