Search in sources :

Example 11 with ColumnSample

use of org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample in project knime-core by knime.

the class TreeEnsembleLearner method createColumnStatisticTable.

public BufferedDataTable createColumnStatisticTable(final ExecutionContext exec) throws CanceledExecutionException {
    BufferedDataContainer c = exec.createDataContainer(getColumnStatisticTableSpec());
    final int nrModels = m_ensembleModel.getNrModels();
    final TreeAttributeColumnData[] columns = m_data.getColumns();
    final int nrAttributes = columns.length;
    int[][] columnOnLevelCounts = new int[REPORT_LEVEL][nrAttributes];
    int[][] columnInLevelSampleCounts = new int[REPORT_LEVEL][nrAttributes];
    for (int i = 0; i < nrModels; i++) {
        final AbstractTreeModel<?> treeModel = m_ensembleModel.getTreeModel(i);
        for (int level = 0; level < REPORT_LEVEL; level++) {
            for (AbstractTreeNode treeNodeOnLevel : treeModel.getTreeNodes(level)) {
                TreeNodeSignature sig = treeNodeOnLevel.getSignature();
                ColumnSampleStrategy colStrat = m_columnSampleStrategies[i];
                ColumnSample cs = colStrat.getColumnSampleForTreeNode(sig);
                for (TreeAttributeColumnData col : cs) {
                    final int index = col.getMetaData().getAttributeIndex();
                    columnInLevelSampleCounts[level][index] += 1;
                }
                int splitAttIdx = treeNodeOnLevel.getSplitAttributeIndex();
                if (splitAttIdx >= 0) {
                    columnOnLevelCounts[level][splitAttIdx] += 1;
                }
            }
        }
    }
    for (int i = 0; i < nrAttributes; i++) {
        String name = columns[i].getMetaData().getAttributeName();
        int[] counts = new int[2 * REPORT_LEVEL];
        for (int level = 0; level < REPORT_LEVEL; level++) {
            counts[level] = columnOnLevelCounts[level][i];
            counts[REPORT_LEVEL + level] = columnInLevelSampleCounts[level][i];
        }
        DataRow row = new DefaultRow(name, counts);
        c.addRowToTable(row);
        exec.checkCanceled();
    }
    c.close();
    return c.getTable();
}
Also used : ColumnSampleStrategy(org.knime.base.node.mine.treeensemble2.sample.column.ColumnSampleStrategy) TreeAttributeColumnData(org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData) BufferedDataContainer(org.knime.core.node.BufferedDataContainer) ColumnSample(org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample) AbstractTreeNode(org.knime.base.node.mine.treeensemble2.model.AbstractTreeNode) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) DataRow(org.knime.core.data.DataRow) DefaultRow(org.knime.core.data.def.DefaultRow)

Example 12 with ColumnSample

use of org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample in project knime-core by knime.

the class TreeLearnerRegression method buildTreeNode.

private TreeNodeRegression buildTreeNode(final ExecutionMonitor exec, final int currentDepth, final DataMemberships dataMemberships, final ColumnSample columnSample, final TreeNodeSignature treeNodeSignature, final RegressionPriors targetPriors, final BitSet forbiddenColumnSet) throws CanceledExecutionException {
    final TreeData data = getData();
    final RandomData rd = getRandomData();
    final TreeEnsembleLearnerConfiguration config = getConfig();
    exec.checkCanceled();
    final SplitCandidate candidate = findBestSplitRegression(currentDepth, dataMemberships, columnSample, targetPriors, forbiddenColumnSet);
    if (candidate == null) {
        if (config instanceof GradientBoostingLearnerConfiguration) {
            TreeNodeRegression leaf = new TreeNodeRegression(treeNodeSignature, targetPriors, dataMemberships.getOriginalIndices());
            addToLeafList(leaf);
            return leaf;
        }
        return new TreeNodeRegression(treeNodeSignature, targetPriors);
    }
    final TreeTargetNumericColumnData targetColumn = (TreeTargetNumericColumnData) data.getTargetColumn();
    boolean useSurrogates = config.getMissingValueHandling() == MissingValueHandling.Surrogate;
    TreeNodeCondition[] childConditions;
    TreeNodeRegression[] childNodes;
    if (useSurrogates) {
        SurrogateSplit surrogateSplit = Surrogates.learnSurrogates(dataMemberships, candidate, data, columnSample, config, rd);
        childConditions = surrogateSplit.getChildConditions();
        BitSet[] childMarkers = surrogateSplit.getChildMarkers();
        assert childMarkers[0].cardinality() + childMarkers[1].cardinality() == dataMemberships.getRowCount() : "Sum of rows in children does not add up to number of rows in parent.";
        childNodes = new TreeNodeRegression[2];
        for (int i = 0; i < 2; i++) {
            DataMemberships childMemberships = dataMemberships.createChildMemberships(childMarkers[i]);
            TreeNodeSignature childSignature = getSignatureFactory().getChildSignatureFor(treeNodeSignature, (byte) i);
            ColumnSample childColumnSample = getColSamplingStrategy().getColumnSampleForTreeNode(childSignature);
            RegressionPriors childTargetPriors = targetColumn.getPriors(childMemberships, config);
            childNodes[i] = buildTreeNode(exec, currentDepth + 1, childMemberships, childColumnSample, childSignature, childTargetPriors, forbiddenColumnSet);
            childNodes[i].setTreeNodeCondition(childConditions[i]);
        }
    } else {
        SplitCandidate bestSplit = candidate;
        TreeAttributeColumnData splitColumn = bestSplit.getColumnData();
        final int attributeIndex = splitColumn.getMetaData().getAttributeIndex();
        boolean markAttributeAsForbidden = !bestSplit.canColumnBeSplitFurther();
        forbiddenColumnSet.set(attributeIndex, markAttributeAsForbidden);
        childConditions = bestSplit.getChildConditions();
        if (childConditions.length > Short.MAX_VALUE) {
            throw new RuntimeException("Too many children when splitting " + "attribute " + bestSplit.getColumnData() + " (maximum supported: " + Short.MAX_VALUE + "): " + childConditions.length);
        }
        childNodes = new TreeNodeRegression[childConditions.length];
        for (int i = 0; i < childConditions.length; i++) {
            TreeNodeCondition cond = childConditions[i];
            DataMemberships childMemberships = dataMemberships.createChildMemberships(splitColumn.updateChildMemberships(cond, dataMemberships));
            RegressionPriors childTargetPriors = targetColumn.getPriors(childMemberships, config);
            TreeNodeSignature childSignature = treeNodeSignature.createChildSignature((byte) i);
            ColumnSample childColumnSample = getColSamplingStrategy().getColumnSampleForTreeNode(childSignature);
            childNodes[i] = buildTreeNode(exec, currentDepth + 1, childMemberships, childColumnSample, childSignature, childTargetPriors, forbiddenColumnSet);
            childNodes[i].setTreeNodeCondition(cond);
        }
        if (markAttributeAsForbidden) {
            forbiddenColumnSet.set(attributeIndex, false);
        }
    }
    return new TreeNodeRegression(treeNodeSignature, targetPriors, childNodes);
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RandomData(org.apache.commons.math.random.RandomData) TreeAttributeColumnData(org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData) ColumnSample(org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample) RegressionPriors(org.knime.base.node.mine.treeensemble2.data.RegressionPriors) BitSet(java.util.BitSet) TreeTargetNumericColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeNodeRegression(org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) GradientBoostingLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.gradientboosting.learner.GradientBoostingLearnerConfiguration) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) TreeNodeCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition)

Example 13 with ColumnSample

use of org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample in project knime-core by knime.

the class TreeLearnerRegression method findBestSplitRegression.

private SplitCandidate findBestSplitRegression(final int currentDepth, final DataMemberships dataMemberships, final ColumnSample columnSample, final RegressionPriors targetPriors, final BitSet forbiddenColumnSet) {
    final TreeData data = getData();
    final RandomData rd = getRandomData();
    final TreeEnsembleLearnerConfiguration config = getConfig();
    final int maxLevels = config.getMaxLevels();
    if (maxLevels != TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE && currentDepth >= maxLevels) {
        return null;
    }
    final int minNodeSize = config.getMinNodeSize();
    if (minNodeSize != TreeEnsembleLearnerConfiguration.MIN_NODE_SIZE_UNDEFINED) {
        if (targetPriors.getNrRecords() < minNodeSize) {
            return null;
        }
    }
    final double priorSquaredDeviation = targetPriors.getSumSquaredDeviation();
    if (priorSquaredDeviation < TreeColumnData.EPSILON) {
        return null;
    }
    final TreeTargetNumericColumnData targetColumn = getTargetData();
    SplitCandidate splitCandidate = null;
    if (currentDepth == 0 && config.getHardCodedRootColumn() != null) {
        final TreeAttributeColumnData rootColumn = data.getColumn(config.getHardCodedRootColumn());
        return rootColumn.calcBestSplitRegression(dataMemberships, targetPriors, targetColumn, rd);
    } else {
        double bestGainValue = 0.0;
        for (TreeAttributeColumnData col : columnSample) {
            if (forbiddenColumnSet.get(col.getMetaData().getAttributeIndex())) {
                continue;
            }
            SplitCandidate currentColSplit = col.calcBestSplitRegression(dataMemberships, targetPriors, targetColumn, rd);
            if (currentColSplit != null) {
                double gainValue = currentColSplit.getGainValue();
                if (gainValue > bestGainValue) {
                    bestGainValue = gainValue;
                    splitCandidate = currentColSplit;
                }
            }
        }
        return splitCandidate;
    }
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RandomData(org.apache.commons.math.random.RandomData) TreeAttributeColumnData(org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData) TreeTargetNumericColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData)

Aggregations

TreeAttributeColumnData (org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData)8 TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData)8 TreeNodeSignature (org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature)8 TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration)8 BitSet (java.util.BitSet)5 RandomData (org.apache.commons.math.random.RandomData)5 TreeTargetNominalColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData)5 ColumnSample (org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample)5 TreeTargetNumericColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData)4 DataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships)4 RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships)4 Test (org.junit.Test)3 ClassificationPriors (org.knime.base.node.mine.treeensemble2.data.ClassificationPriors)3 TreeNodeSignatureFactory (org.knime.base.node.mine.treeensemble2.learner.TreeNodeSignatureFactory)3 TreeNodeCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition)3 ArrayList (java.util.ArrayList)2 Comparator (java.util.Comparator)2 RegressionPriors (org.knime.base.node.mine.treeensemble2.data.RegressionPriors)2 TreeNodeClassification (org.knime.base.node.mine.treeensemble2.model.TreeNodeClassification)2 TreeNodeRegression (org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression)2