use of org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample in project knime-core by knime.
the class SubsetColumnSampleStrategyTest method testGetColumnSampleForTreeNode.
/**
 * Tests {@link SubsetColumnSampleStrategy#getColumnSampleForTreeNode(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature)}
 * and, implicitly, {@link SubsetColumnSample}, because the two are always used together.
 * <p>
 * The sample requested for the root signature and the samples requested for its children 0 and 1 must each
 * contain five columns, and the child samples must report the same column indices as the root sample.
 *
 * @throws Exception if any of the invoked helpers fails
 */
@Test
public void testGetColumnSampleForTreeNode() throws Exception {
    final int expectedNumCols = 5;
    final SubsetColumnSampleStrategy strategy =
        new SubsetColumnSampleStrategy(createTreeData(), RD, expectedNumCols);
    final TreeNodeSignatureFactory signatureFactory = createSignatureFactory();
    final TreeNodeSignature root = signatureFactory.getRootSignature();

    // The root sample provides the reference indices for the child checks below.
    final ColumnSample rootSample = strategy.getColumnSampleForTreeNode(root);
    assertEquals("Wrong number of columns in sample.", expectedNumCols, rootSample.getNumCols());
    final int[] expectedIndices = rootSample.getColumnIndices();

    for (byte childIndex = 0; childIndex <= 1; childIndex++) {
        final TreeNodeSignature child = signatureFactory.getChildSignatureFor(root, childIndex);
        final ColumnSample childSample = strategy.getColumnSampleForTreeNode(child);
        assertEquals("Wrong number of columns in sample.", expectedNumCols, childSample.getNumCols());
        assertArrayEquals(expectedIndices, childSample.getColumnIndices());
    }
}
use of org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample in project knime-core by knime.
the class TreeLearnerRegression method learnSingleTree.
/**
 * {@inheritDoc}
 */
@Override
public TreeModelRegression learnSingleTree(final ExecutionMonitor exec, final RandomData rd) throws CanceledExecutionException {
    final TreeTargetNumericColumnData target = getTargetData();
    final TreeData treeData = getData();
    final RowSample rows = getRowSampling();
    final TreeEnsembleLearnerConfiguration configuration = getConfig();
    final IDataIndexManager indices = getIndexManager();

    final DataMemberships rootMemberships = new RootDataMemberships(rows, treeData, indices);
    final RegressionPriors rootPriors = target.getPriors(rootMemberships, configuration);
    // Starts out empty; the assertion after building the tree checks that it is empty again.
    final BitSet forbiddenColumns = new BitSet(treeData.getNrAttributes());

    final boolean gradientBoosting = configuration instanceof GradientBoostingLearnerConfiguration;
    if (gradientBoosting) {
        // Reset the leaf list before the tree is built; it is handed to the model at the end.
        m_leafs = new ArrayList<TreeNodeRegression>();
    }

    final ColumnSample rootColumns =
        getColSamplingStrategy().getColumnSampleForTreeNode(TreeNodeSignature.ROOT_SIGNATURE);
    final TreeNodeRegression root = buildTreeNode(exec, 0, rootMemberships, rootColumns,
        getSignatureFactory().getRootSignature(), rootPriors, forbiddenColumns);
    assert forbiddenColumns.cardinality() == 0;
    root.setTreeNodeCondition(TreeNodeTrueCondition.INSTANCE);

    return gradientBoosting ? new TreeModelRegression(root, m_leafs) : new TreeModelRegression(root);
}
use of org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample in project knime-core by knime.
the class TreeLearnerRegression method findBestSplitsRegression.
/**
 * Determines the best regression split of every eligible column for a node and returns them ordered by gain.
 * <p>
 * {@code null} is returned without evaluating any column when the configured maximum number of levels is
 * reached, when the node holds fewer records than the configured minimum node size, or when the sum of squared
 * deviations of the target is below {@link TreeColumnData#EPSILON}.
 *
 * @param currentDepth depth of the node; at depth 0 a hard-coded root column from the configuration (if set) is
 *            the only column that is evaluated
 * @param dataMemberships handed to the split calculation of each column
 * @param columnSample the columns to evaluate
 * @param targetPriors priors of the target for this node
 * @param forbiddenColumnSet columns whose attribute index is set here are skipped
 * @return the non-null split candidates sorted by descending gain, or {@code null} if there is no candidate
 */
private SplitCandidate[] findBestSplitsRegression(final int currentDepth, final DataMemberships dataMemberships, final ColumnSample columnSample, final RegressionPriors targetPriors, final BitSet forbiddenColumnSet) {
    final TreeData data = getData();
    final RandomData rd = getRandomData();
    final TreeEnsembleLearnerConfiguration config = getConfig();
    final int maxLevels = config.getMaxLevels();
    if (maxLevels != TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE && currentDepth >= maxLevels) {
        return null;
    }
    final int minNodeSize = config.getMinNodeSize();
    if (minNodeSize != TreeEnsembleLearnerConfiguration.MIN_NODE_SIZE_UNDEFINED
        && targetPriors.getNrRecords() < minNodeSize) {
        return null;
    }
    final double priorSquaredDeviation = targetPriors.getSumSquaredDeviation();
    if (priorSquaredDeviation < TreeColumnData.EPSILON) {
        return null;
    }
    final TreeTargetNumericColumnData targetColumn = getTargetData();
    if (currentDepth == 0 && config.getHardCodedRootColumn() != null) {
        final TreeAttributeColumnData rootColumn = data.getColumn(config.getHardCodedRootColumn());
        final SplitCandidate rootSplit =
            rootColumn.calcBestSplitRegression(dataMemberships, targetPriors, targetColumn, rd);
        // The column may not yield a split (the loop below guards against null as well); report that as
        // "no candidates" instead of handing out an array that contains a null element.
        return rootSplit == null ? null : new SplitCandidate[]{rootSplit};
    }
    final ArrayList<SplitCandidate> splitCandidates = new ArrayList<>(columnSample.getNumCols());
    for (TreeAttributeColumnData col : columnSample) {
        if (forbiddenColumnSet.get(col.getMetaData().getAttributeIndex())) {
            continue;
        }
        final SplitCandidate currentColSplit =
            col.calcBestSplitRegression(dataMemberships, targetPriors, targetColumn, rd);
        if (currentColSplit != null) {
            splitCandidates.add(currentColSplit);
        }
    }
    if (splitCandidates.isEmpty()) {
        return null;
    }
    // Swapped arguments give descending order: the candidate with the highest gain comes first.
    splitCandidates.sort((first, second) -> Double.compare(second.getGainValue(), first.getGainValue()));
    return splitCandidates.toArray(new SplitCandidate[splitCandidates.size()]);
}
use of org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample in project knime-core by knime.
the class TreeLearnerClassification method findBestSplitClassification.
/**
 * Searches the sampled columns for the single classification split with the highest gain.
 * <p>
 * {@code null} is returned without evaluating any column when the configured maximum number of levels is
 * reached, when the node holds fewer records than the configured minimum node size, or when the prior impurity
 * is below {@link TreeColumnData#EPSILON}.
 *
 * @param currentDepth depth of the node; at depth 0 a hard-coded root column from the configuration (if set) is
 *            the only column that is evaluated
 * @param dataMemberships handed to the split calculation of each column
 * @param columnSample the columns to evaluate
 * @param treeNodeSignature not used by this method
 * @param targetPriors priors of the target for this node
 * @param forbiddenColumnSet columns whose attribute index is set here are skipped
 * @return the selected split candidate, or {@code null} if no split was selected
 */
private SplitCandidate findBestSplitClassification(final int currentDepth, final DataMemberships dataMemberships, final ColumnSample columnSample, final TreeNodeSignature treeNodeSignature, final ClassificationPriors targetPriors, final BitSet forbiddenColumnSet) {
    final TreeData data = getData();
    final RandomData rd = getRandomData();
    final TreeEnsembleLearnerConfiguration config = getConfig();

    final int maxLevels = config.getMaxLevels();
    final boolean depthIsLimited = maxLevels != TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE;
    if (depthIsLimited && currentDepth >= maxLevels) {
        return null;
    }

    final int minNodeSize = config.getMinNodeSize();
    if (minNodeSize != TreeEnsembleLearnerConfiguration.MIN_NODE_SIZE_UNDEFINED
        && targetPriors.getNrRecords() < minNodeSize) {
        return null;
    }

    if (targetPriors.getPriorImpurity() < TreeColumnData.EPSILON) {
        return null;
    }

    final TreeTargetNominalColumnData targetColumn = (TreeTargetNominalColumnData) data.getTargetColumn();
    final boolean isRoot = currentDepth == 0;
    if (isRoot && config.getHardCodedRootColumn() != null) {
        // TODO discuss whether this option makes sense with surrogates
        final TreeAttributeColumnData hardCodedColumn = data.getColumn(config.getHardCodedRootColumn());
        return hardCodedColumn.calcBestSplitClassification(dataMemberships, targetPriors, targetColumn, rd);
    }

    SplitCandidate best = null;
    double bestGain = 0.0;
    for (TreeAttributeColumnData column : columnSample) {
        final boolean forbidden = forbiddenColumnSet.get(column.getMetaData().getAttributeIndex());
        if (!forbidden) {
            final SplitCandidate candidate =
                column.calcBestSplitClassification(dataMemberships, targetPriors, targetColumn, rd);
            if (candidate == null) {
                continue;
            }
            final double gain = candidate.getGainValue();
            // The random generator is only consulted on an exact tie with the best gain so far;
            // the tied candidate takes over when the draw is 0.
            final boolean winsTie = gain == bestGain && rd.nextInt(0, 1) == 0;
            if (gain > bestGain || winsTie) {
                best = candidate;
                bestGain = gain;
            }
        }
    }
    return best;
}
use of org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample in project knime-core by knime.
the class TreeLearnerClassification method findBestSplitsClassification.
/**
 * Determines the best classification split of every eligible column for a node and returns them sorted
 * (descending) by their gain.
 * <p>
 * {@code null} is returned without evaluating any column when the configured maximum number of levels is
 * reached, when the node holds fewer records than the configured minimum node size, or when the prior impurity
 * is below {@link TreeColumnData#EPSILON}.
 *
 * @param currentDepth depth of the node; at depth 0 a hard-coded root column from the configuration (if set) is
 *            the only column that is evaluated
 * @param dataMemberships handed to the split calculation of each column
 * @param columnSample the columns to evaluate
 * @param treeNodeSignature not used by this method
 * @param targetPriors priors of the target for this node
 * @param forbiddenColumnSet columns whose attribute index is set here are skipped
 * @return the non-null split candidates sorted by descending gain, or {@code null} if there is no candidate
 */
private SplitCandidate[] findBestSplitsClassification(final int currentDepth, final DataMemberships dataMemberships, final ColumnSample columnSample, final TreeNodeSignature treeNodeSignature, final ClassificationPriors targetPriors, final BitSet forbiddenColumnSet) {
    final TreeData data = getData();
    final RandomData rd = getRandomData();
    final TreeEnsembleLearnerConfiguration config = getConfig();
    final int maxLevels = config.getMaxLevels();
    if (maxLevels != TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE && currentDepth >= maxLevels) {
        return null;
    }
    final int minNodeSize = config.getMinNodeSize();
    if (minNodeSize != TreeEnsembleLearnerConfiguration.MIN_NODE_SIZE_UNDEFINED
        && targetPriors.getNrRecords() < minNodeSize) {
        return null;
    }
    final double priorImpurity = targetPriors.getPriorImpurity();
    if (priorImpurity < TreeColumnData.EPSILON) {
        return null;
    }
    final TreeTargetNominalColumnData targetColumn = (TreeTargetNominalColumnData) data.getTargetColumn();
    if (currentDepth == 0 && config.getHardCodedRootColumn() != null) {
        final TreeAttributeColumnData rootColumn = data.getColumn(config.getHardCodedRootColumn());
        // TODO discuss whether this option makes sense with surrogates
        final SplitCandidate rootSplit =
            rootColumn.calcBestSplitClassification(dataMemberships, targetPriors, targetColumn, rd);
        // The column may not yield a split (the loop below guards against null as well); report that as
        // "no candidates" instead of handing out an array that contains a null element.
        return rootSplit == null ? null : new SplitCandidate[]{rootSplit};
    }
    final ArrayList<SplitCandidate> candidates = new ArrayList<>(columnSample.getNumCols());
    for (TreeAttributeColumnData col : columnSample) {
        if (forbiddenColumnSet.get(col.getMetaData().getAttributeIndex())) {
            continue;
        }
        final SplitCandidate currentColSplit =
            col.calcBestSplitClassification(dataMemberships, targetPriors, targetColumn, rd);
        if (currentColSplit != null) {
            candidates.add(currentColSplit);
        }
    }
    if (candidates.isEmpty()) {
        return null;
    }
    // Swapped arguments give descending order: the candidate with the highest gain comes first.
    candidates.sort((first, second) -> Double.compare(second.getGainValue(), first.getGainValue()));
    return candidates.toArray(new SplitCandidate[candidates.size()]);
}
Aggregations