Search in sources :

Example 21 with TreeNodeSignature

use of org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature in project knime-core by knime.

the class RFSubsetColumnSampleStrategyTest method testGetColumnSampleForTreeNode.

/**
 * Checks {@link RFSubsetColumnSampleStrategy#getColumnSampleForTreeNode(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature)}
 * for the root node and its children with index 0 and 1: every sample must report five columns, all three index
 * arrays must have the same length, and the three samples must not be element-wise identical.
 *
 * @throws Exception if creating the fixtures or drawing a sample fails
 */
@Test
public void testGetColumnSampleForTreeNode() throws Exception {
    final RFSubsetColumnSampleStrategy strategy = new RFSubsetColumnSampleStrategy(createTreeData(), RD, 5);
    final TreeNodeSignatureFactory signatureFactory = createSignatureFactory();
    final TreeNodeSignature root = signatureFactory.getRootSignature();

    // sample for the root node
    final ColumnSample rootSample = strategy.getColumnSampleForTreeNode(root);
    assertEquals("Wrong number of columns in sample.", 5, rootSample.getNumCols());
    final int[] rootColumns = rootSample.getColumnIndices();

    // sample for the child with index 0
    final ColumnSample firstChildSample =
        strategy.getColumnSampleForTreeNode(signatureFactory.getChildSignatureFor(root, (byte) 0));
    assertEquals("Wrong number of columns in sample.", 5, firstChildSample.getNumCols());
    final int[] firstChildColumns = firstChildSample.getColumnIndices();

    // sample for the child with index 1
    final ColumnSample secondChildSample =
        strategy.getColumnSampleForTreeNode(signatureFactory.getChildSignatureFor(root, (byte) 1));
    assertEquals("Wrong number of columns in sample.", 5, secondChildSample.getNumCols());
    final int[] secondChildColumns = secondChildSample.getColumnIndices();

    assertEquals("sample sizes differ.", rootColumns.length, firstChildColumns.length);
    assertEquals("sample sizes differ.", rootColumns.length, secondChildColumns.length);
    assertEquals("sample sizes differ.", firstChildColumns.length, secondChildColumns.length);

    // walk the three index arrays in lockstep and stop at the first position where they disagree
    boolean allIdentical = true;
    int pos = 0;
    while (allIdentical && pos < rootColumns.length) {
        final int rootColumn = rootColumns[pos];
        allIdentical = rootColumn == firstChildColumns[pos] && rootColumn == secondChildColumns[pos];
        pos++;
    }
    assertFalse("It is very unlikely that we get 3 times the same column sample.", allIdentical);
}
Also used : TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeNodeSignatureFactory(org.knime.base.node.mine.treeensemble2.learner.TreeNodeSignatureFactory) Test(org.junit.Test)

Example 22 with TreeNodeSignature

use of org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature in project knime-core by knime.

the class TreeLearnerClassification method buildTreeNode.

/**
 * Recursively builds the classification tree node identified by {@code treeNodeSignature} together with all of
 * its descendants.
 *
 * <p>
 * If the configured missing value handling is {@link MissingValueHandling#Surrogate}, the split candidates are
 * searched, a surrogate split is learned from the first candidate and exactly two children are created. Otherwise
 * the single best split is used and one child is created per child condition of that split. When no split is
 * found, a node without children is returned.
 * </p>
 *
 * @param exec monitor that is checked for cancellation before the node is built
 * @param currentDepth depth of the node to build; children are built with {@code currentDepth + 1}
 * @param dataMemberships memberships of the node to build, from which the child memberships are derived
 * @param columnSample column sample handed to the split search (and to the surrogate learning)
 * @param treeNodeSignature signature of the node to build
 * @param targetPriors priors stored in the created node
 * @param forbiddenColumnSet bit set indexed by attribute index; in the non-surrogate case the bit of the split
 *            attribute is set while the children are built if the column cannot be split further, and is cleared
 *            again before returning
 * @return the created node, with children if a split was found
 * @throws CanceledExecutionException if the execution was canceled
 */
private TreeNodeClassification buildTreeNode(final ExecutionMonitor exec, final int currentDepth, final DataMemberships dataMemberships, final ColumnSample columnSample, final TreeNodeSignature treeNodeSignature, final ClassificationPriors targetPriors, final BitSet forbiddenColumnSet) throws CanceledExecutionException {
    final TreeData data = getData();
    final TreeEnsembleLearnerConfiguration config = getConfig();
    exec.checkCanceled();
    final boolean useSurrogates = getConfig().getMissingValueHandling() == MissingValueHandling.Surrogate;
    TreeNodeCondition[] childConditions;
    // only ever becomes true in the non-surrogate branch; drives the cleanup at the end of the method
    boolean markAttributeAsForbidden = false;
    final TreeTargetNominalColumnData targetColumn = (TreeTargetNominalColumnData) data.getTargetColumn();
    TreeNodeClassification[] childNodes;
    int attributeIndex = -1;
    if (useSurrogates) {
        SplitCandidate[] candidates = findBestSplitsClassification(currentDepth, dataMemberships, columnSample, treeNodeSignature, targetPriors, forbiddenColumnSet);
        if (candidates == null) {
            // no split found: return a node without children
            return new TreeNodeClassification(treeNodeSignature, targetPriors, config);
        }
        // the surrogate split is learned from the first candidate
        SurrogateSplit surrogateSplit = Surrogates.learnSurrogates(dataMemberships, candidates[0], data, columnSample, config, getRandomData());
        childConditions = surrogateSplit.getChildConditions();
        BitSet[] childMarkers = surrogateSplit.getChildMarkers();
        // this branch always creates exactly two children
        childNodes = new TreeNodeClassification[2];
        for (int i = 0; i < 2; i++) {
            DataMemberships childMemberships = dataMemberships.createChildMemberships(childMarkers[i]);
            ClassificationPriors childTargetPriors = targetColumn.getDistribution(childMemberships, config);
            TreeNodeSignature childSignature = getSignatureFactory().getChildSignatureFor(treeNodeSignature, (byte) i);
            ColumnSample childColumnSample = getColSamplingStrategy().getColumnSampleForTreeNode(childSignature);
            childNodes[i] = buildTreeNode(exec, currentDepth + 1, childMemberships, childColumnSample, childSignature, childTargetPriors, forbiddenColumnSet);
            childNodes[i].setTreeNodeCondition(childConditions[i]);
        }
    } else {
        // handle non surrogate case
        SplitCandidate bestSplit = findBestSplitClassification(currentDepth, dataMemberships, columnSample, treeNodeSignature, targetPriors, forbiddenColumnSet);
        if (bestSplit == null) {
            // no split found: return a node without children
            return new TreeNodeClassification(treeNodeSignature, targetPriors, config);
        }
        TreeAttributeColumnData splitColumn = bestSplit.getColumnData();
        attributeIndex = splitColumn.getMetaData().getAttributeIndex();
        markAttributeAsForbidden = !bestSplit.canColumnBeSplitFurther();
        // writes the flag in both directions: sets the bit if the column is exhausted, clears it otherwise
        forbiddenColumnSet.set(attributeIndex, markAttributeAsForbidden);
        childConditions = bestSplit.getChildConditions();
        childNodes = new TreeNodeClassification[childConditions.length];
        if (childConditions.length > Short.MAX_VALUE) {
            throw new RuntimeException("Too many children when splitting " + "attribute " + bestSplit.getColumnData() + " (maximum supported: " + Short.MAX_VALUE + "): " + childConditions.length);
        }
        // Build child nodes
        for (int i = 0; i < childConditions.length; i++) {
            DataMemberships childMemberships = null;
            TreeNodeCondition cond = childConditions[i];
            childMemberships = dataMemberships.createChildMemberships(splitColumn.updateChildMemberships(cond, dataMemberships));
            ClassificationPriors childTargetPriors = targetColumn.getDistribution(childMemberships, config);
            // NOTE(review): the child index is narrowed to byte although up to Short.MAX_VALUE children pass the
            // check above - confirm that more than Byte.MAX_VALUE children cannot occur here.
            // NOTE(review): the surrogate branch obtains child signatures via getSignatureFactory() instead -
            // confirm that both ways yield equivalent signatures.
            TreeNodeSignature childSignature = treeNodeSignature.createChildSignature((byte) i);
            ColumnSample childColumnSample = getColSamplingStrategy().getColumnSampleForTreeNode(childSignature);
            childNodes[i] = buildTreeNode(exec, currentDepth + 1, childMemberships, childColumnSample, childSignature, childTargetPriors, forbiddenColumnSet);
            childNodes[i].setTreeNodeCondition(cond);
        }
    }
    // undo the marking made above so the shared bit set is restored for the caller
    if (markAttributeAsForbidden) {
        forbiddenColumnSet.set(attributeIndex, false);
    }
    return new TreeNodeClassification(treeNodeSignature, targetPriors, childNodes, getConfig());
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) TreeNodeClassification(org.knime.base.node.mine.treeensemble2.model.TreeNodeClassification) TreeAttributeColumnData(org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData) ColumnSample(org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample) BitSet(java.util.BitSet) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) TreeNodeCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition) TreeTargetNominalColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData) ClassificationPriors(org.knime.base.node.mine.treeensemble2.data.ClassificationPriors)

Example 23 with TreeNodeSignature

use of org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature in project knime-core by knime.

the class MGradientBoostedTreesLearner method calcCoefficientMap.

/**
 * Computes one coefficient per leaf of {@code tree}, keyed by the signature of the leaf.
 *
 * <p>
 * For each leaf the residuals of its rows are collected and their median is determined via {@code calcMedian}.
 * The coefficient is that median plus the mean of the deviations from the median, where the magnitude of each
 * deviation is capped at {@code quantile}. The value stored in the map is the coefficient multiplied by the
 * configured learning rate.
 * </p>
 *
 * @param residuals residual values, addressed by the row indices reported by the leafs
 * @param quantile upper bound for the magnitude of a single deviation from the median
 * @param tree the regression tree whose leafs are evaluated
 * @return map from leaf signature to the coefficient scaled by the learning rate
 */
private Map<TreeNodeSignature, Double> calcCoefficientMap(final double[] residuals, final double quantile, final TreeModelRegression tree) {
    final List<TreeNodeRegression> leafNodes = tree.getLeafs();
    // capacity chosen so that all leafs fit without rehashing at a load factor of 0.75
    final int initialCapacity = (int) (leafNodes.size() / 0.75 + 1);
    final Map<TreeNodeSignature, Double> result = new HashMap<TreeNodeSignature, Double>(initialCapacity);
    final double learningRate = getConfig().getLearningRate();
    for (final TreeNodeRegression leafNode : leafNodes) {
        final int[] rowIndices = leafNode.getRowIndicesInTreeData();
        final int rowCount = rowIndices.length;
        final double[] leafResiduals = new double[rowCount];
        int pos = 0;
        for (final int rowIndex : rowIndices) {
            leafResiduals[pos++] = residuals[rowIndex];
        }
        final double median = calcMedian(leafResiduals);
        double cappedDeviationSum = 0;
        for (final double residual : leafResiduals) {
            final double deviation = residual - median;
            cappedDeviationSum += Math.signum(deviation) * Math.min(quantile, Math.abs(deviation));
        }
        final double coefficient = median + (1.0 / rowCount) * cappedDeviationSum;
        result.put(leafNode.getSignature(), coefficient * learningRate);
    }
    return result;
}
Also used : HashMap(java.util.HashMap) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeNodeRegression(org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression)

Example 24 with TreeNodeSignature

use of org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature in project knime-core by knime.

the class TreeEnsembleLearner method createColumnStatisticTable.

/**
 * Creates a table with one row per attribute column that reports, for each of the first {@code REPORT_LEVEL}
 * tree levels, how often the attribute was used as split attribute and how often it was part of the column
 * sample of a node on that level, accumulated over all models of the ensemble.
 *
 * <p>
 * Each row holds {@code 2 * REPORT_LEVEL} counts: first the split counts per level, then the sample counts per
 * level.
 * </p>
 *
 * @param exec context used to create the data container and to check for cancellation after each row
 * @return the table holding the per-attribute statistics
 * @throws CanceledExecutionException if the execution was canceled
 */
public BufferedDataTable createColumnStatisticTable(final ExecutionContext exec) throws CanceledExecutionException {
    BufferedDataContainer container = exec.createDataContainer(getColumnStatisticTableSpec());
    final int modelCount = m_ensembleModel.getNrModels();
    final TreeAttributeColumnData[] attributeColumns = m_data.getColumns();
    final int attributeCount = attributeColumns.length;
    // [level][attribute index]: number of nodes on the level that split on the attribute
    int[][] splitCounts = new int[REPORT_LEVEL][attributeCount];
    // [level][attribute index]: number of nodes on the level whose column sample contained the attribute
    int[][] sampleCounts = new int[REPORT_LEVEL][attributeCount];
    for (int modelIdx = 0; modelIdx < modelCount; modelIdx++) {
        final AbstractTreeModel<?> model = m_ensembleModel.getTreeModel(modelIdx);
        for (int level = 0; level < REPORT_LEVEL; level++) {
            final int[] splitsOnLevel = splitCounts[level];
            final int[] samplesOnLevel = sampleCounts[level];
            for (AbstractTreeNode node : model.getTreeNodes(level)) {
                TreeNodeSignature signature = node.getSignature();
                ColumnSampleStrategy sampleStrategy = m_columnSampleStrategies[modelIdx];
                ColumnSample nodeSample = sampleStrategy.getColumnSampleForTreeNode(signature);
                for (TreeAttributeColumnData sampledColumn : nodeSample) {
                    samplesOnLevel[sampledColumn.getMetaData().getAttributeIndex()]++;
                }
                int splitAttributeIndex = node.getSplitAttributeIndex();
                // negative index: the node does not split on any attribute
                if (splitAttributeIndex >= 0) {
                    splitsOnLevel[splitAttributeIndex]++;
                }
            }
        }
    }
    for (int attrIdx = 0; attrIdx < attributeCount; attrIdx++) {
        String attributeName = attributeColumns[attrIdx].getMetaData().getAttributeName();
        int[] rowValues = new int[2 * REPORT_LEVEL];
        for (int level = 0; level < REPORT_LEVEL; level++) {
            rowValues[level] = splitCounts[level][attrIdx];
            rowValues[REPORT_LEVEL + level] = sampleCounts[level][attrIdx];
        }
        DataRow statisticRow = new DefaultRow(attributeName, rowValues);
        container.addRowToTable(statisticRow);
        exec.checkCanceled();
    }
    container.close();
    return container.getTable();
}
Also used : ColumnSampleStrategy(org.knime.base.node.mine.treeensemble2.sample.column.ColumnSampleStrategy) TreeAttributeColumnData(org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData) BufferedDataContainer(org.knime.core.node.BufferedDataContainer) ColumnSample(org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample) AbstractTreeNode(org.knime.base.node.mine.treeensemble2.model.AbstractTreeNode) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) DataRow(org.knime.core.data.DataRow) DefaultRow(org.knime.core.data.def.DefaultRow)

Example 25 with TreeNodeSignature

use of org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature in project knime-core by knime.

the class TreeLearnerRegression method buildTreeNode.

/**
 * Recursively builds the regression tree node identified by {@code treeNodeSignature} together with all of its
 * descendants.
 *
 * <p>
 * If no split candidate is found a leaf is returned; for a {@link GradientBoostingLearnerConfiguration} that leaf
 * additionally receives the original indices of {@code dataMemberships} and is registered via
 * {@code addToLeafList}. Otherwise, if the configured missing value handling is
 * {@link MissingValueHandling#Surrogate}, a surrogate split is learned from the candidate and exactly two children
 * are created; in all other cases one child is created per child condition of the candidate.
 * </p>
 *
 * @param exec monitor that is checked for cancellation before the node is built
 * @param currentDepth depth of the node to build; children are built with {@code currentDepth + 1}
 * @param dataMemberships memberships of the node to build, from which the child memberships are derived
 * @param columnSample column sample handed to the split search (and to the surrogate learning)
 * @param treeNodeSignature signature of the node to build
 * @param targetPriors priors stored in the created node
 * @param forbiddenColumnSet bit set indexed by attribute index; in the non-surrogate case the bit of the split
 *            attribute is set while the children are built if the column cannot be split further, and is cleared
 *            again afterwards
 * @return the created node, with children if a split was found
 * @throws CanceledExecutionException if the execution was canceled
 */
private TreeNodeRegression buildTreeNode(final ExecutionMonitor exec, final int currentDepth, final DataMemberships dataMemberships, final ColumnSample columnSample, final TreeNodeSignature treeNodeSignature, final RegressionPriors targetPriors, final BitSet forbiddenColumnSet) throws CanceledExecutionException {
    final TreeData data = getData();
    final RandomData rd = getRandomData();
    final TreeEnsembleLearnerConfiguration config = getConfig();
    exec.checkCanceled();
    final SplitCandidate candidate = findBestSplitRegression(currentDepth, dataMemberships, columnSample, targetPriors, forbiddenColumnSet);
    if (candidate == null) {
        // no split found: this node becomes a leaf
        if (config instanceof GradientBoostingLearnerConfiguration) {
            // for gradient boosting the leaf also keeps the original row indices and is added to the leaf list
            TreeNodeRegression leaf = new TreeNodeRegression(treeNodeSignature, targetPriors, dataMemberships.getOriginalIndices());
            addToLeafList(leaf);
            return leaf;
        }
        return new TreeNodeRegression(treeNodeSignature, targetPriors);
    }
    final TreeTargetNumericColumnData targetColumn = (TreeTargetNumericColumnData) data.getTargetColumn();
    boolean useSurrogates = config.getMissingValueHandling() == MissingValueHandling.Surrogate;
    TreeNodeCondition[] childConditions;
    TreeNodeRegression[] childNodes;
    if (useSurrogates) {
        SurrogateSplit surrogateSplit = Surrogates.learnSurrogates(dataMemberships, candidate, data, columnSample, config, rd);
        childConditions = surrogateSplit.getChildConditions();
        BitSet[] childMarkers = surrogateSplit.getChildMarkers();
        // sanity check (only active with assertions enabled): the two child markers together cover as many rows
        // as the parent holds
        assert childMarkers[0].cardinality() + childMarkers[1].cardinality() == dataMemberships.getRowCount() : "Sum of rows in children does not add up to number of rows in parent.";
        // this branch always creates exactly two children
        childNodes = new TreeNodeRegression[2];
        for (int i = 0; i < 2; i++) {
            DataMemberships childMemberships = dataMemberships.createChildMemberships(childMarkers[i]);
            TreeNodeSignature childSignature = getSignatureFactory().getChildSignatureFor(treeNodeSignature, (byte) i);
            ColumnSample childColumnSample = getColSamplingStrategy().getColumnSampleForTreeNode(childSignature);
            RegressionPriors childTargetPriors = targetColumn.getPriors(childMemberships, config);
            childNodes[i] = buildTreeNode(exec, currentDepth + 1, childMemberships, childColumnSample, childSignature, childTargetPriors, forbiddenColumnSet);
            childNodes[i].setTreeNodeCondition(childConditions[i]);
        }
    } else {
        SplitCandidate bestSplit = candidate;
        TreeAttributeColumnData splitColumn = bestSplit.getColumnData();
        final int attributeIndex = splitColumn.getMetaData().getAttributeIndex();
        boolean markAttributeAsForbidden = !bestSplit.canColumnBeSplitFurther();
        // writes the flag in both directions: sets the bit if the column is exhausted, clears it otherwise
        forbiddenColumnSet.set(attributeIndex, markAttributeAsForbidden);
        childConditions = bestSplit.getChildConditions();
        if (childConditions.length > Short.MAX_VALUE) {
            throw new RuntimeException("Too many children when splitting " + "attribute " + bestSplit.getColumnData() + " (maximum supported: " + Short.MAX_VALUE + "): " + childConditions.length);
        }
        childNodes = new TreeNodeRegression[childConditions.length];
        for (int i = 0; i < childConditions.length; i++) {
            TreeNodeCondition cond = childConditions[i];
            DataMemberships childMemberships = dataMemberships.createChildMemberships(splitColumn.updateChildMemberships(cond, dataMemberships));
            RegressionPriors childTargetPriors = targetColumn.getPriors(childMemberships, config);
            // NOTE(review): the child index is narrowed to byte although up to Short.MAX_VALUE children pass the
            // check above - confirm that more than Byte.MAX_VALUE children cannot occur here.
            // NOTE(review): the surrogate branch obtains child signatures via getSignatureFactory() instead -
            // confirm that both ways yield equivalent signatures.
            TreeNodeSignature childSignature = treeNodeSignature.createChildSignature((byte) i);
            ColumnSample childColumnSample = getColSamplingStrategy().getColumnSampleForTreeNode(childSignature);
            childNodes[i] = buildTreeNode(exec, currentDepth + 1, childMemberships, childColumnSample, childSignature, childTargetPriors, forbiddenColumnSet);
            childNodes[i].setTreeNodeCondition(cond);
        }
        // undo the marking made above so the shared bit set is restored for the caller
        if (markAttributeAsForbidden) {
            forbiddenColumnSet.set(attributeIndex, false);
        }
    }
    return new TreeNodeRegression(treeNodeSignature, targetPriors, childNodes);
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RandomData(org.apache.commons.math.random.RandomData) TreeAttributeColumnData(org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData) ColumnSample(org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample) RegressionPriors(org.knime.base.node.mine.treeensemble2.data.RegressionPriors) BitSet(java.util.BitSet) TreeTargetNumericColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeNodeRegression(org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) GradientBoostingLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.gradientboosting.learner.GradientBoostingLearnerConfiguration) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) TreeNodeCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition)

Aggregations

TreeNodeSignature (org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature)20 TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData)10 TreeModelRegression (org.knime.base.node.mine.treeensemble2.model.TreeModelRegression)8 ArrayList (java.util.ArrayList)6 TreeNodeRegression (org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression)6 TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration)6 Map (java.util.Map)5 RandomData (org.apache.commons.math.random.RandomData)5 TreeAttributeColumnData (org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData)5 TreeNodeSignatureFactory (org.knime.base.node.mine.treeensemble2.learner.TreeNodeSignatureFactory)5 ColumnSample (org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample)5 BitSet (java.util.BitSet)4 HashMap (java.util.HashMap)4 Segment (org.dmg.pmml.SegmentDocument.Segment)4 TreeTargetNominalColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData)4 RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships)4 List (java.util.List)3 Segmentation (org.dmg.pmml.SegmentationDocument.Segmentation)3 Test (org.junit.Test)3 TreeTargetNumericColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData)3