Search in sources :

Example 1 with TreeModelRegression

use of org.knime.base.node.mine.treeensemble2.model.TreeModelRegression in project knime-core by knime.

Class RegressionTreePMMLTranslatorNodeModel, method execute.

/**
 * {@inheritDoc}
 * <p>
 * Wraps the regression tree carried by the first input port into a PMML port object. A translator built from
 * the tree, its meta data and the learn table spec is attached to the PMML output. Any warning the translator
 * reports is forwarded as this node's warning message.
 *
 * @param inObjects input port objects; index 0 must be a {@link RegressionTreeModelPortObject}
 * @param exec the execution context (not referenced by this method)
 * @return a one-element array holding the created {@link PMMLPortObject}
 * @throws Exception as declared by the overridden method
 */
@Override
protected PortObject[] execute(final PortObject[] inObjects, final ExecutionContext exec) throws Exception {
    final RegressionTreeModelPortObject treePortObject = (RegressionTreeModelPortObject) inObjects[0];
    final RegressionTreeModel regressionModel = treePortObject.getModel();
    final RegressionTreeModelPortObjectSpec modelSpec = treePortObject.getSpec();

    final PMMLPortObject pmmlOutput = new PMMLPortObject(createPMMLSpec(modelSpec, regressionModel));
    // arguments are evaluated left to right: tree, meta data, then learn table spec
    final RegressionTreeModelPMMLTranslator pmmlTranslator = new RegressionTreeModelPMMLTranslator(
        regressionModel.getTreeModel(), regressionModel.getMetaData(), modelSpec.getLearnTableSpec());
    pmmlOutput.addModelTranslater(pmmlTranslator);

    if (pmmlTranslator.hasWarning()) {
        setWarningMessage(pmmlTranslator.getWarning());
    }
    return new PortObject[]{pmmlOutput};
}
Also used : RegressionTreeModelPortObject(org.knime.base.node.mine.treeensemble2.model.RegressionTreeModelPortObject) PMMLPortObjectSpec(org.knime.core.node.port.pmml.PMMLPortObjectSpec) PMMLPortObject(org.knime.core.node.port.pmml.PMMLPortObject) RegressionTreeModel(org.knime.base.node.mine.treeensemble2.model.RegressionTreeModel) RegressionTreeModelPortObjectSpec(org.knime.base.node.mine.treeensemble2.model.RegressionTreeModelPortObjectSpec) PortObject(org.knime.core.node.port.PortObject) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression) RegressionTreeModelPMMLTranslator(org.knime.base.node.mine.treeensemble2.model.pmml.RegressionTreeModelPMMLTranslator)

Example 2 with TreeModelRegression

use of org.knime.base.node.mine.treeensemble2.model.TreeModelRegression in project knime-core by knime.

Class RegressionTreeModel, method load.

/**
 * Loads and returns a new regression tree model. The argument stream is NOT closed afterwards.
 *
 * @param in the stream to read from; it is wrapped in a non-closable stream, so closing the internal reader
 *            leaves it open
 * @param exec execution monitor (not consulted by this method)
 * @param treeBuildingInterner passed on to {@link TreeModelRegression#load}
 * @return the loaded model
 * @throws IOException if reading fails, the stored version is newer than supported, or the tree model is
 *             not terminated by a 0 byte
 * @throws CanceledExecutionException as declared; this method does not check {@code exec} itself
 */
public static RegressionTreeModel load(final InputStream in, final ExecutionMonitor exec,
    final TreeBuildingInterner treeBuildingInterner) throws IOException, CanceledExecutionException {
    // wrapping the argument (zip input) stream in a buffered stream
    // reduces read operation from, e.g. 42s to 2s
    final TreeModelDataInputStream input =
        new TreeModelDataInputStream(new BufferedInputStream(new NonClosableInputStream(in)));
    try {
        final int version = input.readInt();
        if (version > 20140201) {
            throw new IOException("Tree Ensemble version " + version + " not supported");
        }
        final TreeType type = TreeType.load(input);
        final TreeMetaData metaData = TreeMetaData.load(input);
        TreeModelRegression model;
        try {
            model = TreeModelRegression.load(input, metaData, treeBuildingInterner);
            if (input.readByte() != 0) {
                throw new IOException("Model not terminated by 0 byte");
            }
        } catch (IOException e) {
            throw new IOException("Can't read tree model. " + e.getMessage(), e);
        }
        return new RegressionTreeModel(metaData, model, type);
    } finally {
        // release the wrapper streams on every path (including errors); the NonClosableInputStream
        // wrapper keeps the method argument stream open
        input.close();
    }
}
Also used : BufferedInputStream(java.io.BufferedInputStream) NonClosableInputStream(org.knime.core.data.util.NonClosableInputStream) TreeMetaData(org.knime.base.node.mine.treeensemble2.data.TreeMetaData) IOException(java.io.IOException)

Example 3 with TreeModelRegression

use of org.knime.base.node.mine.treeensemble2.model.TreeModelRegression in project knime-core by knime.

Class RegressionGBTModelImporter, method importFromPMMLInternal.

/**
 * {@inheritDoc}
 * <p>
 * Reads the regression trees and their per-leaf coefficient maps from the mining model's segmentation. The
 * segmentation must use the SUM multiple model method. The rescale constant of the first declared target is
 * used as the model's initial value.
 *
 * @throws IllegalArgumentException if the mining model has no segmentation, the segmentation's multiple model
 *             method is not SUM, or the mining model declares no target
 */
@Override
public GradientBoostedTreesModel importFromPMMLInternal(final MiningModel miningModel) {
    final Segmentation segmentation = miningModel.getSegmentation();
    // an absent optional element yields null; fail with a clear message instead of an NPE below
    CheckUtils.checkArgument(segmentation != null, "The provided mining model has no segmentation.");
    CheckUtils.checkArgument(segmentation.getMultipleModelMethod() == MULTIPLEMODELMETHOD.SUM,
        "The provided segmentation has not the required sum as multiple model method but '%s' instead.",
        segmentation.getMultipleModelMethod());
    final Pair<List<TreeModelRegression>, List<Map<TreeNodeSignature, Double>>> treesCoefficientMapsPair =
        readSumSegmentation(segmentation);
    final List<TreeModelRegression> trees = treesCoefficientMapsPair.getFirst();
    // the initial value is taken from the first target; reject models that declare none
    CheckUtils.checkArgument(
        miningModel.getTargets() != null && !miningModel.getTargets().getTargetList().isEmpty(),
        "The provided mining model declares no target, so the initial value cannot be determined.");
    // TODO user should be warned if the rescale constant (initial value) is absent or anything else is fishy
    final double initialValue = miningModel.getTargets().getTargetList().get(0).getRescaleConstant();
    // currently only models learned on "ordinary" columns can be read back in
    return new GradientBoostedTreesModel(getMetaDataMapper().getTreeMetaData(),
        trees.toArray(new TreeModelRegression[0]), TreeType.Ordinary, initialValue,
        treesCoefficientMapsPair.getSecond());
}
Also used : Segmentation(org.dmg.pmml.SegmentationDocument.Segmentation) List(java.util.List) GradientBoostedTreesModel(org.knime.base.node.mine.treeensemble2.model.GradientBoostedTreesModel) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression)

Example 4 with TreeModelRegression

use of org.knime.base.node.mine.treeensemble2.model.TreeModelRegression in project knime-core by knime.

Class TreeLearnerRegression, method learnSingleTree.

/**
 * {@inheritDoc}
 * <p>
 * Grows a single regression tree on the current row sample. Growth starts at the root with the target priors
 * of the root memberships and an initially empty set of forbidden columns. The root node receives the
 * always-true condition. For a {@link GradientBoostingLearnerConfiguration}, a fresh {@code m_leafs} list is
 * created before the tree is built and is handed to the resulting model. Presumably {@code buildTreeNode}
 * fills it with the leaves; that is not visible here.
 */
@Override
public TreeModelRegression learnSingleTree(final ExecutionMonitor exec, final RandomData rd)
    throws CanceledExecutionException {
    final TreeTargetNumericColumnData target = getTargetData();
    final TreeData treeData = getData();
    final RowSample sample = getRowSampling();
    final TreeEnsembleLearnerConfiguration learnerConfig = getConfig();
    final IDataIndexManager dataIndexManager = getIndexManager();

    final DataMemberships rootMemberships = new RootDataMemberships(sample, treeData, dataIndexManager);
    final RegressionPriors rootPriors = target.getPriors(rootMemberships, learnerConfig);
    final BitSet forbiddenColumns = new BitSet(treeData.getNrAttributes());

    final boolean boosting = learnerConfig instanceof GradientBoostingLearnerConfiguration;
    if (boosting) {
        // (re)initialised before buildTreeNode runs
        m_leafs = new ArrayList<>();
    }

    final ColumnSample rootColumns =
        getColSamplingStrategy().getColumnSampleForTreeNode(TreeNodeSignature.ROOT_SIGNATURE);
    final TreeNodeRegression root = buildTreeNode(exec, 0, rootMemberships, rootColumns,
        getSignatureFactory().getRootSignature(), rootPriors, forbiddenColumns);
    // the forbidden-column set is expected to be empty again once the tree has been built
    assert forbiddenColumns.cardinality() == 0;
    root.setTreeNodeCondition(TreeNodeTrueCondition.INSTANCE);

    return boosting ? new TreeModelRegression(root, m_leafs) : new TreeModelRegression(root);
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) ColumnSample(org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample) RegressionPriors(org.knime.base.node.mine.treeensemble2.data.RegressionPriors) BitSet(java.util.BitSet) TreeTargetNumericColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData) IDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeNodeRegression(org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression) GradientBoostingLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.gradientboosting.learner.GradientBoostingLearnerConfiguration) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) RowSample(org.knime.base.node.mine.treeensemble2.sample.row.RowSample)

Example 5 with TreeModelRegression

use of org.knime.base.node.mine.treeensemble2.model.TreeModelRegression in project knime-core by knime.

Class AbstractGradientBoostedTreesLearner, method adaptPreviousPrediction.

/**
 * Adapts the previous prediction by adding the predictions of the <b>tree</b> regulated by the respective
 * coefficients in <b>coefficientMap</b>.
 * <p>
 * For every data row, the node of {@code tree} matching that row is looked up. The coefficient stored for that
 * node's signature is then added to {@code previousPrediction} at the row's index.
 *
 * @param previousPrediction Prediction of the previous steps, updated in place; indexed by row, so it must have
 *            at least as many entries as the data has rows
 * @param tree the tree of the current iteration
 * @param coefficientMap contains the coefficients for the leafs of the tree
 * @throws IllegalStateException if {@code coefficientMap} holds no coefficient for a matched node
 */
protected void adaptPreviousPrediction(final double[] previousPrediction, final TreeModelRegression tree,
    final Map<TreeNodeSignature, Double> coefficientMap) {
    final TreeData data = getData();
    final IDataIndexManager indexManager = getIndexManager();
    final int nrRows = data.getNrRows();
    for (int i = 0; i < nrRows; i++) {
        final PredictorRecord record = createPredictorRecord(data, indexManager, i);
        final TreeNodeSignature leafSignature = tree.findMatchingNode(record).getSignature();
        final Double coefficient = coefficientMap.get(leafSignature);
        if (coefficient == null) {
            // auto-unboxing a missing coefficient would otherwise fail with an uninformative NullPointerException
            throw new IllegalStateException(
                "No coefficient found for tree node " + leafSignature + " (row index " + i + ")");
        }
        previousPrediction[i] += coefficient;
    }
}
Also used : PredictorRecord(org.knime.base.node.mine.treeensemble2.data.PredictorRecord) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) IDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager)

Aggregations

TreeModelRegression (org.knime.base.node.mine.treeensemble2.model.TreeModelRegression)13 TreeNodeSignature (org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature)10 TreeNodeRegression (org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression)6 Map (java.util.Map)5 TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData)5 ArrayList (java.util.ArrayList)4 HashMap (java.util.HashMap)4 Segment (org.dmg.pmml.SegmentDocument.Segment)4 Segmentation (org.dmg.pmml.SegmentationDocument.Segmentation)4 PredictorRecord (org.knime.base.node.mine.treeensemble2.data.PredictorRecord)4 RandomData (org.apache.commons.math.random.RandomData)3 TreeTargetNumericColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData)3 IDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager)3 RegressionTreeModel (org.knime.base.node.mine.treeensemble2.model.RegressionTreeModel)3 DataRow (org.knime.core.data.DataRow)3 DataTableSpec (org.knime.core.data.DataTableSpec)3 IOException (java.io.IOException)2 List (java.util.List)2 Target (org.dmg.pmml.TargetDocument.Target)2 Targets (org.dmg.pmml.TargetsDocument.Targets)2