Search in sources:

Example 1 with TreeNodeSignature

Use of org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature in project knime-core by KNIME.

The class RegressionGBTModelImporter, method importFromPMMLInternal.

/**
 * {@inheritDoc}
 * <p>
 * The mining model must hold a segmentation whose segments are combined by summation. The trees and their leaf
 * coefficient maps are read from that segmentation. The rescale constant of the first target serves as the
 * model's initial value.
 */
@Override
public GradientBoostedTreesModel importFromPMMLInternal(final MiningModel miningModel) {
    final Segmentation seg = miningModel.getSegmentation();
    CheckUtils.checkArgument(seg.getMultipleModelMethod() == MULTIPLEMODELMETHOD.SUM,
        "The provided segmentation has not the required sum as multiple model method but '%s' instead.",
        seg.getMultipleModelMethod());

    final Pair<List<TreeModelRegression>, List<Map<TreeNodeSignature, Double>>> treesAndCoefficients =
        readSumSegmentation(seg);
    final List<TreeModelRegression> treeList = treesAndCoefficients.getFirst();

    // TODO warn the user if the initial value is missing or anything else looks suspicious
    final double initialValue = miningModel.getTargets().getTargetList().get(0).getRescaleConstant();

    // at the moment only models learned on "ordinary" columns can be read back in
    return new GradientBoostedTreesModel(getMetaDataMapper().getTreeMetaData(),
        treeList.toArray(new TreeModelRegression[treeList.size()]), TreeType.Ordinary, initialValue,
        treesAndCoefficients.getSecond());
}
Also used : Segmentation(org.dmg.pmml.SegmentationDocument.Segmentation) List(java.util.List) GradientBoostedTreesModel(org.knime.base.node.mine.treeensemble2.model.GradientBoostedTreesModel) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression)

Example 2 with TreeNodeSignature

Use of org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature in project knime-core by KNIME.

The class SubsetColumnSampleStrategyTest, method testGetColumnSampleForTreeNode.

/**
 * Tests {@link SubsetColumnSampleStrategy#getColumnSampleForTreeNode(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature)}
 * and, implicitly, {@link SubsetColumnSample}, because the two always act in combination.
 * <p>
 * The root node and both of its children must each receive a sample of exactly five columns. The children must
 * receive the same column indices as the root.
 *
 * @throws Exception if any step of the test fails
 */
@Test
public void testGetColumnSampleForTreeNode() throws Exception {
    final SubsetColumnSampleStrategy strategy = new SubsetColumnSampleStrategy(createTreeData(), RD, 5);
    final TreeNodeSignatureFactory signatureFactory = createSignatureFactory();
    final TreeNodeSignature root = signatureFactory.getRootSignature();

    final ColumnSample rootSample = strategy.getColumnSampleForTreeNode(root);
    assertEquals("Wrong number of columns in sample.", 5, rootSample.getNumCols());
    final int[] rootColumns = rootSample.getColumnIndices();

    // children 0 and 1 (in that order) must reuse the root's column selection
    for (byte childIndex = 0; childIndex < 2; childIndex++) {
        final ColumnSample childSample =
            strategy.getColumnSampleForTreeNode(signatureFactory.getChildSignatureFor(root, childIndex));
        assertEquals("Wrong number of columns in sample.", 5, childSample.getNumCols());
        assertArrayEquals(rootColumns, childSample.getColumnIndices());
    }
}
Also used : TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeNodeSignatureFactory(org.knime.base.node.mine.treeensemble2.learner.TreeNodeSignatureFactory) Test(org.junit.Test)

Example 3 with TreeNodeSignature

Use of org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature in project knime-core by KNIME.

The class TreeLearnerRegression, method learnSingleTree.

/**
 * {@inheritDoc}
 * <p>
 * Learns a single regression tree. Root data memberships are created from the row sample, the target priors are
 * computed for them, and the tree is built recursively from the root. The root node is marked with
 * {@link TreeNodeTrueCondition#INSTANCE}. If the configuration is a
 * {@link GradientBoostingLearnerConfiguration}, the leaf list {@code m_leafs} is reset before building and
 * passed to the returned model. That list is presumably filled by {@code buildTreeNode}; confirm.
 */
@Override
public TreeModelRegression learnSingleTree(final ExecutionMonitor exec, final RandomData rd) throws CanceledExecutionException {
    final TreeTargetNumericColumnData targetColumn = getTargetData();
    final TreeData data = getData();
    final RowSample rowSampling = getRowSampling();
    final TreeEnsembleLearnerConfiguration config = getConfig();
    final IDataIndexManager indexManager = getIndexManager();
    // memberships of all sampled rows in the root node
    DataMemberships rootDataMemberships = new RootDataMemberships(rowSampling, data, indexManager);
    RegressionPriors targetPriors = targetColumn.getPriors(rootDataMemberships, config);
    // one bit per attribute; presumably marks columns that may not be split on along the current path (confirm).
    // It must be empty again when the recursion returns to the root (see the assert below).
    BitSet forbiddenColumnSet = new BitSet(data.getNrAttributes());
    boolean isGradientBoosting = config instanceof GradientBoostingLearnerConfiguration;
    if (isGradientBoosting) {
        // start a fresh leaf list for this tree; it is handed to the model below
        m_leafs = new ArrayList<TreeNodeRegression>();
    }
    // NOTE(review): the column sample is drawn for TreeNodeSignature.ROOT_SIGNATURE, while the tree is built with
    // getSignatureFactory().getRootSignature(). These are presumably the same signature; confirm.
    final TreeNodeSignature rootSignature = TreeNodeSignature.ROOT_SIGNATURE;
    final ColumnSample rootColumnSample = getColSamplingStrategy().getColumnSampleForTreeNode(rootSignature);
    TreeNodeRegression rootNode = buildTreeNode(exec, 0, rootDataMemberships, rootColumnSample, getSignatureFactory().getRootSignature(), targetPriors, forbiddenColumnSet);
    // sanity check: no column may remain forbidden once the whole tree has been built
    assert forbiddenColumnSet.cardinality() == 0;
    // the root accepts every record
    rootNode.setTreeNodeCondition(TreeNodeTrueCondition.INSTANCE);
    if (isGradientBoosting) {
        return new TreeModelRegression(rootNode, m_leafs);
    }
    return new TreeModelRegression(rootNode);
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) ColumnSample(org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample) RegressionPriors(org.knime.base.node.mine.treeensemble2.data.RegressionPriors) BitSet(java.util.BitSet) TreeTargetNumericColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData) IDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeNodeRegression(org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression) GradientBoostingLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.gradientboosting.learner.GradientBoostingLearnerConfiguration) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) RowSample(org.knime.base.node.mine.treeensemble2.sample.row.RowSample)

Example 4 with TreeNodeSignature

Use of org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature in project knime-core by KNIME.

The class AbstractGradientBoostedTreesLearner, method adaptPreviousPrediction.

/**
 * Adapts the previous prediction by adding the predictions of the <b>tree</b> regulated by the respective
 * coefficients in <b>coefficientMap</b>.
 * <p>
 * For every row of the training data, this looks up the leaf of {@code tree} that the row falls into. The
 * coefficient stored for that leaf's signature is then added to the row's entry in {@code previousPrediction}.
 *
 * @param previousPrediction Prediction of the previous steps; updated in place, indexed by row of the data
 * @param tree the tree of the current iteration
 * @param coefficientMap contains the coefficients for the leafs of the tree
 * @throws IllegalStateException if the leaf matched by a row has no entry in {@code coefficientMap}
 */
protected void adaptPreviousPrediction(final double[] previousPrediction, final TreeModelRegression tree,
    final Map<TreeNodeSignature, Double> coefficientMap) {
    final TreeData data = getData();
    final IDataIndexManager indexManager = getIndexManager();
    final int nrRows = data.getNrRows();
    for (int i = 0; i < nrRows; i++) {
        final PredictorRecord record = createPredictorRecord(data, indexManager, i);
        final TreeNodeSignature leafSignature = tree.findMatchingNode(record).getSignature();
        final Double coefficient = coefficientMap.get(leafSignature);
        if (coefficient == null) {
            // auto-unboxing null in "+=" would otherwise fail with a message-less NullPointerException
            throw new IllegalStateException(
                "No coefficient available for leaf " + leafSignature + " matched by row " + i + ".");
        }
        previousPrediction[i] += coefficient;
    }
}
Also used : PredictorRecord(org.knime.base.node.mine.treeensemble2.data.PredictorRecord) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) IDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager)

Example 5 with TreeNodeSignature

Use of org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature in project knime-core by KNIME.

The class LKGradientBoostedTreesLearner, method calculateCoefficientMap.

/**
 * Computes the coefficient of every leaf of {@code tree}, already scaled by the configured learning rate.
 * <p>
 * Let {@code r_i} be the pseudo residuals of the rows in a leaf and {@code K = numClasses}. The coefficient is
 * {@code (K - 1) / K * sum(r_i) / sum(|r_i| * (1 - |r_i|))}, the leaf update of Friedman's L_K_TreeBoost. The
 * denominator can be (practically) zero, for example for a leaf without rows or one whose residuals are all 0
 * or +/-1. In that case the coefficient is 0 rather than NaN or an infinite value, because those would corrupt
 * every subsequent prediction.
 *
 * @param tree the tree whose leaves receive coefficients
 * @param pseudoResiduals data whose (numeric) target column holds the pseudo residuals, indexed by the leaves'
 *            row indices
 * @param numClasses the number of classes K
 * @return a map from leaf signature to learning-rate-scaled coefficient
 */
private Map<TreeNodeSignature, Double> calculateCoefficientMap(final TreeModelRegression tree,
    final TreeData pseudoResiduals, final double numClasses) {
    // denominators with a smaller magnitude are treated as zero; scikit-learn's multinomial deviance uses the
    // same guard
    final double minDenominator = 1e-150;
    final List<TreeNodeRegression> leafs = tree.getLeafs();
    final Map<TreeNodeSignature, Double> coefficientMap = new HashMap<TreeNodeSignature, Double>();
    final TreeTargetNumericColumnData pseudoTarget =
        (TreeTargetNumericColumnData)pseudoResiduals.getTargetColumn();
    final double learningRate = getConfig().getLearningRate();
    for (TreeNodeRegression leaf : leafs) {
        double sumTop = 0;
        double sumBottom = 0;
        for (int index : leaf.getRowIndicesInTreeData()) {
            final double val = pseudoTarget.getValueFor(index);
            sumTop += val;
            final double absVal = Math.abs(val);
            sumBottom += absVal * (1 - absVal);
        }
        final double coefficient;
        if (Math.abs(sumBottom) < minDenominator) {
            coefficient = 0;
        } else {
            coefficient = (numClasses - 1) / numClasses * (sumTop / sumBottom);
        }
        coefficientMap.put(leaf.getSignature(), learningRate * coefficient);
    }
    return coefficientMap;
}
Also used : HashMap(java.util.HashMap) TreeTargetNumericColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeNodeRegression(org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression)

Aggregations

TreeNodeSignature (org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature): 20
TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData): 10
TreeModelRegression (org.knime.base.node.mine.treeensemble2.model.TreeModelRegression): 8
ArrayList (java.util.ArrayList): 6
TreeNodeRegression (org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression): 6
TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration): 6
Map (java.util.Map): 5
RandomData (org.apache.commons.math.random.RandomData): 5
TreeAttributeColumnData (org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData): 5
TreeNodeSignatureFactory (org.knime.base.node.mine.treeensemble2.learner.TreeNodeSignatureFactory): 5
ColumnSample (org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample): 5
BitSet (java.util.BitSet): 4
HashMap (java.util.HashMap): 4
Segment (org.dmg.pmml.SegmentDocument.Segment): 4
TreeTargetNominalColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData): 4
RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships): 4
List (java.util.List): 3
Segmentation (org.dmg.pmml.SegmentationDocument.Segmentation): 3
Test (org.junit.Test): 3
TreeTargetNumericColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData): 3