Search in sources :

Example 1 with ColumnSample

use of org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample in project knime-core by knime.

the class SubsetColumnSampleStrategyTest method testGetColumnSampleForTreeNode.

/**
 * Verifies {@link SubsetColumnSampleStrategy#getColumnSampleForTreeNode(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature)}
 * and, implicitly, {@link SubsetColumnSample}, since the two are always used in combination.
 * <p>
 * A strategy is created for a sample size of 5. The test checks that the sample drawn for the root
 * signature contains 5 columns, and that the samples drawn for the root's children 0 and 1 have the
 * same size and exactly the same column indices as the root sample.
 *
 * @throws Exception if any of the fixture helpers or sampled calls throws
 */
@Test
public void testGetColumnSampleForTreeNode() throws Exception {
    final int expectedNumCols = 5;
    final SubsetColumnSampleStrategy strategy = new SubsetColumnSampleStrategy(createTreeData(), RD, expectedNumCols);
    final TreeNodeSignatureFactory signatureFactory = createSignatureFactory();
    final TreeNodeSignature root = signatureFactory.getRootSignature();

    // The root sample is the reference every child sample is compared against.
    final ColumnSample rootSample = strategy.getColumnSampleForTreeNode(root);
    assertEquals("Wrong number of columns in sample.", expectedNumCols, rootSample.getNumCols());
    final int[] expectedIndices = rootSample.getColumnIndices();

    // Both children of the root must yield the same columns as the root itself.
    for (byte childIndex = 0; childIndex <= 1; childIndex++) {
        final ColumnSample childSample =
            strategy.getColumnSampleForTreeNode(signatureFactory.getChildSignatureFor(root, childIndex));
        assertEquals("Wrong number of columns in sample.", expectedNumCols, childSample.getNumCols());
        assertArrayEquals(expectedIndices, childSample.getColumnIndices());
    }
}
Also used : TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeNodeSignatureFactory(org.knime.base.node.mine.treeensemble2.learner.TreeNodeSignatureFactory) Test(org.junit.Test)

Example 2 with ColumnSample

use of org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample in project knime-core by knime.

the class TreeLearnerRegression method learnSingleTree.

/**
 * {@inheritDoc}
 * <p>
 * Builds a single regression tree: root data memberships are created from the row sample, the target
 * priors are computed on them, and the tree is grown from depth 0 via {@code buildTreeNode}. If the
 * learner configuration is a {@link GradientBoostingLearnerConfiguration}, {@code m_leafs} is replaced by
 * a fresh list before the tree is built and is passed to the returned model.
 */
@Override
public TreeModelRegression learnSingleTree(final ExecutionMonitor exec, final RandomData rd) throws CanceledExecutionException {
    final TreeTargetNumericColumnData target = getTargetData();
    final TreeData treeData = getData();
    final RowSample rows = getRowSampling();
    final TreeEnsembleLearnerConfiguration learnerConfig = getConfig();
    final IDataIndexManager indices = getIndexManager();

    final DataMemberships rootMemberships = new RootDataMemberships(rows, treeData, indices);
    final RegressionPriors rootPriors = target.getPriors(rootMemberships, learnerConfig);
    final BitSet forbiddenColumns = new BitSet(treeData.getNrAttributes());

    final boolean gradientBoosting = learnerConfig instanceof GradientBoostingLearnerConfiguration;
    if (gradientBoosting) {
        // Start from an empty leaf list; presumably buildTreeNode fills it - TODO confirm.
        m_leafs = new ArrayList<TreeNodeRegression>();
    }

    // NOTE(review): the column sample is drawn for TreeNodeSignature.ROOT_SIGNATURE while the node itself is
    // built with the signature factory's root signature - confirm both denote the same signature.
    final ColumnSample rootColumns =
        getColSamplingStrategy().getColumnSampleForTreeNode(TreeNodeSignature.ROOT_SIGNATURE);
    final TreeNodeRegression root = buildTreeNode(exec, 0, rootMemberships, rootColumns,
        getSignatureFactory().getRootSignature(), rootPriors, forbiddenColumns);

    // Building the tree must leave no column marked as forbidden.
    assert forbiddenColumns.cardinality() == 0;
    root.setTreeNodeCondition(TreeNodeTrueCondition.INSTANCE);

    return gradientBoosting ? new TreeModelRegression(root, m_leafs) : new TreeModelRegression(root);
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) ColumnSample(org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample) RegressionPriors(org.knime.base.node.mine.treeensemble2.data.RegressionPriors) BitSet(java.util.BitSet) TreeTargetNumericColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData) IDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeNodeRegression(org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression) GradientBoostingLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.gradientboosting.learner.GradientBoostingLearnerConfiguration) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) RowSample(org.knime.base.node.mine.treeensemble2.sample.row.RowSample)

Example 3 with ColumnSample

use of org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample in project knime-core by knime.

the class TreeLearnerRegression method findBestSplitsRegression.

/**
 * Collects the best regression split of every sampled column and returns them ordered by descending gain.
 *
 * @param currentDepth depth of the node to split; no split is searched once the configured maximum level is reached
 * @param dataMemberships memberships passed on to the per-column split calculation
 * @param columnSample columns to consider as split attributes
 * @param targetPriors target statistics of the node (record count and sum of squared deviations are read)
 * @param forbiddenColumnSet attribute indices whose columns are skipped
 * @return the split candidates sorted by descending gain, or {@code null} if a stopping criterion applies
 *         or no column produced a candidate
 */
private SplitCandidate[] findBestSplitsRegression(final int currentDepth, final DataMemberships dataMemberships, final ColumnSample columnSample, final RegressionPriors targetPriors, final BitSet forbiddenColumnSet) {
    final TreeData data = getData();
    final RandomData rd = getRandomData();
    final TreeEnsembleLearnerConfiguration config = getConfig();

    // Stopping criterion: depth limit reached.
    final int depthLimit = config.getMaxLevels();
    final boolean depthIsLimited = depthLimit != TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE;
    if (depthIsLimited && currentDepth >= depthLimit) {
        return null;
    }

    // Stopping criterion: node holds fewer records than the configured minimum.
    final int minNodeSize = config.getMinNodeSize();
    final boolean minSizeIsSet = minNodeSize != TreeEnsembleLearnerConfiguration.MIN_NODE_SIZE_UNDEFINED;
    if (minSizeIsSet && targetPriors.getNrRecords() < minNodeSize) {
        return null;
    }

    // Stopping criterion: the target shows (almost) no deviation in this node.
    if (targetPriors.getSumSquaredDeviation() < TreeColumnData.EPSILON) {
        return null;
    }

    final TreeTargetNumericColumnData target = getTargetData();

    if (currentDepth == 0 && config.getHardCodedRootColumn() != null) {
        // A hard-coded root column bypasses the column sample entirely.
        // NOTE(review): unlike the loop below, a null result is not filtered out here - confirm callers cope.
        final TreeAttributeColumnData rootColumn = data.getColumn(config.getHardCodedRootColumn());
        return new SplitCandidate[] { rootColumn.calcBestSplitRegression(dataMemberships, targetPriors, target, rd) };
    }

    final ArrayList<SplitCandidate> ranked = new ArrayList<SplitCandidate>(columnSample.getNumCols());
    for (TreeAttributeColumnData column : columnSample) {
        if (!forbiddenColumnSet.get(column.getMetaData().getAttributeIndex())) {
            final SplitCandidate candidate = column.calcBestSplitRegression(dataMemberships, targetPriors, target, rd);
            if (candidate != null) {
                ranked.add(candidate);
            }
        }
    }
    if (ranked.isEmpty()) {
        return null;
    }

    // Highest gain first.
    ranked.sort((left, right) -> Double.compare(right.getGainValue(), left.getGainValue()));
    return ranked.toArray(new SplitCandidate[ranked.size()]);
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RandomData(org.apache.commons.math.random.RandomData) TreeAttributeColumnData(org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData) TreeTargetNumericColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) Comparator(java.util.Comparator)

Example 4 with ColumnSample

use of org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample in project knime-core by knime.

the class TreeLearnerClassification method findBestSplitClassification.

/**
 * Finds the single best classification split among the sampled columns.
 * <p>
 * A candidate replaces the current best if its gain is strictly larger, or - when the gain equals the
 * current best (initially {@code 0.0}) - if a random draw from {@code getRandomData()} decides in its favour.
 *
 * @param currentDepth depth of the node to split; no split is searched once the configured maximum level is reached
 * @param dataMemberships memberships passed on to the per-column split calculation
 * @param columnSample columns to consider as split attributes
 * @param treeNodeSignature signature of the node; not used by this method
 * @param targetPriors target statistics of the node (record count and prior impurity are read)
 * @param forbiddenColumnSet attribute indices whose columns are skipped
 * @return the chosen split candidate, or {@code null} if a stopping criterion applies or no candidate was selected
 */
private SplitCandidate findBestSplitClassification(final int currentDepth, final DataMemberships dataMemberships, final ColumnSample columnSample, final TreeNodeSignature treeNodeSignature, final ClassificationPriors targetPriors, final BitSet forbiddenColumnSet) {
    final TreeData data = getData();
    final RandomData rd = getRandomData();
    final TreeEnsembleLearnerConfiguration config = getConfig();

    // Stopping criterion: depth limit reached.
    final int maxLevels = config.getMaxLevels();
    if (maxLevels != TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE && currentDepth >= maxLevels) {
        return null;
    }

    // Stopping criterion: node holds fewer records than the configured minimum.
    final int minNodeSize = config.getMinNodeSize();
    if (minNodeSize != TreeEnsembleLearnerConfiguration.MIN_NODE_SIZE_UNDEFINED
        && targetPriors.getNrRecords() < minNodeSize) {
        return null;
    }

    // Stopping criterion: the node is already (almost) pure.
    if (targetPriors.getPriorImpurity() < TreeColumnData.EPSILON) {
        return null;
    }

    final TreeTargetNominalColumnData targetColumn = (TreeTargetNominalColumnData) data.getTargetColumn();

    if (currentDepth == 0 && config.getHardCodedRootColumn() != null) {
        final TreeAttributeColumnData rootColumn = data.getColumn(config.getHardCodedRootColumn());
        // TODO discuss whether this option makes sense with surrogates
        return rootColumn.calcBestSplitClassification(dataMemberships, targetPriors, targetColumn, rd);
    }

    SplitCandidate bestSplit = null;
    double bestGainValue = 0.0;
    for (TreeAttributeColumnData col : columnSample) {
        if (forbiddenColumnSet.get(col.getMetaData().getAttributeIndex())) {
            continue;
        }
        final SplitCandidate currentColSplit =
            col.calcBestSplitClassification(dataMemberships, targetPriors, targetColumn, rd);
        if (currentColSplit == null) {
            continue;
        }
        // Read the gain once and reuse it for both the comparison and the bookkeeping.
        final double currentGain = currentColSplit.getGainValue();
        // The random generator is only consulted on an exact tie, so the sequence of draws stays unchanged.
        if (currentGain > bestGainValue || (currentGain == bestGainValue && rd.nextInt(0, 1) == 0)) {
            bestSplit = currentColSplit;
            bestGainValue = currentGain;
        }
    }
    return bestSplit;
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RandomData(org.apache.commons.math.random.RandomData) TreeAttributeColumnData(org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) TreeTargetNominalColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData)

Example 5 with ColumnSample

use of org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample in project knime-core by knime.

the class TreeLearnerClassification method findBestSplitsClassification.

/**
 * Returns the best classification split of every sampled column, sorted by descending gain.
 *
 * @param currentDepth depth of the node to split; no split is searched once the configured maximum level is reached
 * @param dataMemberships memberships passed on to the per-column split calculation
 * @param columnSample columns to consider as split attributes
 * @param treeNodeSignature signature of the node; not used by this method
 * @param targetPriors target statistics of the node (record count and prior impurity are read)
 * @param forbiddenColumnSet attribute indices whose columns are skipped
 * @return the split candidates sorted by descending gain, or {@code null} if a stopping criterion applies
 *         or no column produced a candidate
 */
private SplitCandidate[] findBestSplitsClassification(final int currentDepth, final DataMemberships dataMemberships, final ColumnSample columnSample, final TreeNodeSignature treeNodeSignature, final ClassificationPriors targetPriors, final BitSet forbiddenColumnSet) {
    final TreeData data = getData();
    final RandomData rd = getRandomData();
    final TreeEnsembleLearnerConfiguration config = getConfig();

    // Stopping criterion: depth limit reached.
    final int maxLevels = config.getMaxLevels();
    if (maxLevels != TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE && currentDepth >= maxLevels) {
        return null;
    }

    // Stopping criterion: node holds fewer records than the configured minimum.
    final int minNodeSize = config.getMinNodeSize();
    if (minNodeSize != TreeEnsembleLearnerConfiguration.MIN_NODE_SIZE_UNDEFINED
        && targetPriors.getNrRecords() < minNodeSize) {
        return null;
    }

    // Stopping criterion: the node is already (almost) pure.
    if (targetPriors.getPriorImpurity() < TreeColumnData.EPSILON) {
        return null;
    }

    final TreeTargetNominalColumnData targetColumn = (TreeTargetNominalColumnData) data.getTargetColumn();

    if (currentDepth == 0 && config.getHardCodedRootColumn() != null) {
        final TreeAttributeColumnData rootColumn = data.getColumn(config.getHardCodedRootColumn());
        // TODO discuss whether this option makes sense with surrogates
        // NOTE(review): unlike the loop below, a null result is not filtered out here - confirm callers cope.
        return new SplitCandidate[] { rootColumn.calcBestSplitClassification(dataMemberships, targetPriors, targetColumn, rd) };
    }

    final ArrayList<SplitCandidate> candidates = new ArrayList<SplitCandidate>(columnSample.getNumCols());
    for (TreeAttributeColumnData col : columnSample) {
        if (forbiddenColumnSet.get(col.getMetaData().getAttributeIndex())) {
            continue;
        }
        final SplitCandidate currentColSplit =
            col.calcBestSplitClassification(dataMemberships, targetPriors, targetColumn, rd);
        if (currentColSplit != null) {
            candidates.add(currentColSplit);
        }
    }
    if (candidates.isEmpty()) {
        return null;
    }

    // Highest gain first.
    candidates.sort((o1, o2) -> Double.compare(o2.getGainValue(), o1.getGainValue()));
    return candidates.toArray(new SplitCandidate[candidates.size()]);
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RandomData(org.apache.commons.math.random.RandomData) TreeAttributeColumnData(org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData) ArrayList(java.util.ArrayList) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) TreeTargetNominalColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData) Comparator(java.util.Comparator)

Aggregations

TreeAttributeColumnData (org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData)8 TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData)8 TreeNodeSignature (org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature)8 TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration)8 BitSet (java.util.BitSet)5 RandomData (org.apache.commons.math.random.RandomData)5 TreeTargetNominalColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData)5 ColumnSample (org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample)5 TreeTargetNumericColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData)4 DataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships)4 RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships)4 Test (org.junit.Test)3 ClassificationPriors (org.knime.base.node.mine.treeensemble2.data.ClassificationPriors)3 TreeNodeSignatureFactory (org.knime.base.node.mine.treeensemble2.learner.TreeNodeSignatureFactory)3 TreeNodeCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition)3 ArrayList (java.util.ArrayList)2 Comparator (java.util.Comparator)2 RegressionPriors (org.knime.base.node.mine.treeensemble2.data.RegressionPriors)2 TreeNodeClassification (org.knime.base.node.mine.treeensemble2.model.TreeNodeClassification)2 TreeNodeRegression (org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression)2