Usage of org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature in the knime-core project (by KNIME).
Class RFSubsetColumnSampleStrategyTest, method testGetColumnSampleForTreeNode.
/**
 * Tests {@link RFSubsetColumnSampleStrategy#getColumnSampleForTreeNode(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature)}.
 * <p>
 * The root node and its first two children must each receive a sample of exactly 5 columns, and it is
 * expected that the three samples are not all identical.
 *
 * @throws Exception if setting up the test data fails
 */
@Test
public void testGetColumnSampleForTreeNode() throws Exception {
    final RFSubsetColumnSampleStrategy strategy = new RFSubsetColumnSampleStrategy(createTreeData(), RD, 5);
    final TreeNodeSignatureFactory signatureFactory = createSignatureFactory();
    final TreeNodeSignature root = signatureFactory.getRootSignature();

    // Slot 0 holds the root's sample, slots 1 and 2 those of the children with index 0 and 1.
    // Samples are drawn in that order: root first, then each child right after its signature is created.
    final int[][] indicesPerNode = new int[3][];
    for (int node = 0; node < indicesPerNode.length; node++) {
        final TreeNodeSignature signature =
            node == 0 ? root : signatureFactory.getChildSignatureFor(root, (byte)(node - 1));
        final ColumnSample columnSample = strategy.getColumnSampleForTreeNode(signature);
        assertEquals("Wrong number of columns in sample.", 5, columnSample.getNumCols());
        indicesPerNode[node] = columnSample.getColumnIndices();
    }

    final int[] first = indicesPerNode[0];
    final int[] second = indicesPerNode[1];
    final int[] third = indicesPerNode[2];
    assertEquals("sample sizes differ.", first.length, second.length);
    assertEquals("sample sizes differ.", first.length, third.length);
    assertEquals("sample sizes differ.", second.length, third.length);

    // Stop comparing as soon as any position differs between the three samples.
    boolean allIdentical = true;
    for (int pos = 0; allIdentical && pos < first.length; pos++) {
        final int expected = first[pos];
        allIdentical = expected == second[pos] && expected == third[pos];
    }
    assertFalse("It is very unlikely that we get 3 times the same column sample.", allIdentical);
}
Usage of org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature in the knime-core project (by KNIME).
Class TreeLearnerClassification, method buildTreeNode.
/**
 * Recursively builds the classification subtree rooted at the node identified by {@code treeNodeSignature}.
 * <p>
 * With surrogate missing-value handling, the best split candidates are turned into a surrogate split and
 * exactly two children are built from its child markers. Otherwise, one child is built per child condition
 * of the best split. In that case the split attribute is marked in {@code forbiddenColumnSet} for the
 * duration of the subtree when the split reports it cannot be split further, and unmarked afterwards.
 * If no split is found, a leaf is returned.
 *
 * @param exec monitor used to check for cancellation
 * @param currentDepth depth of the node being built
 * @param dataMemberships rows that reach this node
 * @param columnSample columns that may be used to split this node
 * @param treeNodeSignature signature of the node being built
 * @param targetPriors target class distribution of the rows reaching this node
 * @param forbiddenColumnSet attributes that must not be split on; temporarily modified and restored
 * @return the built node (a leaf, or an inner node with its children)
 * @throws CanceledExecutionException if execution was canceled
 * @throws RuntimeException if a split produces more than {@link Short#MAX_VALUE} children
 */
private TreeNodeClassification buildTreeNode(final ExecutionMonitor exec, final int currentDepth, final DataMemberships dataMemberships, final ColumnSample columnSample, final TreeNodeSignature treeNodeSignature, final ClassificationPriors targetPriors, final BitSet forbiddenColumnSet) throws CanceledExecutionException {
    final TreeData data = getData();
    final TreeEnsembleLearnerConfiguration config = getConfig();
    exec.checkCanceled();
    final boolean useSurrogates = config.getMissingValueHandling() == MissingValueHandling.Surrogate;
    final TreeTargetNominalColumnData targetColumn = (TreeTargetNominalColumnData) data.getTargetColumn();
    final TreeNodeClassification[] childNodes;
    if (useSurrogates) {
        final SplitCandidate[] candidates = findBestSplitsClassification(currentDepth, dataMemberships, columnSample, treeNodeSignature, targetPriors, forbiddenColumnSet);
        if (candidates == null) {
            // no usable split: this node becomes a leaf
            return new TreeNodeClassification(treeNodeSignature, targetPriors, config);
        }
        final SurrogateSplit surrogateSplit = Surrogates.learnSurrogates(dataMemberships, candidates[0], data, columnSample, config, getRandomData());
        final TreeNodeCondition[] childConditions = surrogateSplit.getChildConditions();
        final BitSet[] childMarkers = surrogateSplit.getChildMarkers();
        childNodes = new TreeNodeClassification[2];
        for (int i = 0; i < 2; i++) {
            final DataMemberships childMemberships = dataMemberships.createChildMemberships(childMarkers[i]);
            final ClassificationPriors childTargetPriors = targetColumn.getDistribution(childMemberships, config);
            final TreeNodeSignature childSignature = getSignatureFactory().getChildSignatureFor(treeNodeSignature, (byte) i);
            final ColumnSample childColumnSample = getColSamplingStrategy().getColumnSampleForTreeNode(childSignature);
            childNodes[i] = buildTreeNode(exec, currentDepth + 1, childMemberships, childColumnSample, childSignature, childTargetPriors, forbiddenColumnSet);
            childNodes[i].setTreeNodeCondition(childConditions[i]);
        }
    } else {
        // handle non surrogate case
        final SplitCandidate bestSplit = findBestSplitClassification(currentDepth, dataMemberships, columnSample, treeNodeSignature, targetPriors, forbiddenColumnSet);
        if (bestSplit == null) {
            // no usable split: this node becomes a leaf
            return new TreeNodeClassification(treeNodeSignature, targetPriors, config);
        }
        final TreeAttributeColumnData splitColumn = bestSplit.getColumnData();
        final TreeNodeCondition[] childConditions = bestSplit.getChildConditions();
        // validate before touching shared state or allocating the children array
        if (childConditions.length > Short.MAX_VALUE) {
            throw new RuntimeException("Too many children when splitting " + "attribute " + bestSplit.getColumnData() + " (maximum supported: " + Short.MAX_VALUE + "): " + childConditions.length);
        }
        final int attributeIndex = splitColumn.getMetaData().getAttributeIndex();
        // if the split exhausts the attribute, forbid it for the whole subtree below this node
        final boolean markAttributeAsForbidden = !bestSplit.canColumnBeSplitFurther();
        forbiddenColumnSet.set(attributeIndex, markAttributeAsForbidden);
        childNodes = new TreeNodeClassification[childConditions.length];
        // Build child nodes
        for (int i = 0; i < childConditions.length; i++) {
            final TreeNodeCondition cond = childConditions[i];
            final DataMemberships childMemberships = dataMemberships.createChildMemberships(splitColumn.updateChildMemberships(cond, dataMemberships));
            final ClassificationPriors childTargetPriors = targetColumn.getDistribution(childMemberships, config);
            // NOTE(review): child indices are narrowed with (byte) while the guard above allows up to
            // Short.MAX_VALUE children, so indices above Byte.MAX_VALUE wrap around - confirm the index
            // type expected by createChildSignature and tighten the guard if necessary.
            final TreeNodeSignature childSignature = treeNodeSignature.createChildSignature((byte) i);
            final ColumnSample childColumnSample = getColSamplingStrategy().getColumnSampleForTreeNode(childSignature);
            childNodes[i] = buildTreeNode(exec, currentDepth + 1, childMemberships, childColumnSample, childSignature, childTargetPriors, forbiddenColumnSet);
            childNodes[i].setTreeNodeCondition(cond);
        }
        // siblings and ancestors may use the attribute again
        if (markAttributeAsForbidden) {
            forbiddenColumnSet.set(attributeIndex, false);
        }
    }
    return new TreeNodeClassification(treeNodeSignature, targetPriors, childNodes, config);
}
Usage of org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature in the knime-core project (by KNIME).
Class MGradientBoostedTreesLearner, method calcCoefficientMap.
/**
 * Computes one coefficient per leaf of {@code tree} and maps it to the leaf's signature.
 * <p>
 * For each leaf, the residuals of the rows recorded in that leaf are collected and their median
 * {@code m} is computed. The leaf coefficient is then
 * {@code m + (1/N) * sum(signum(r - m) * min(quantile, |r - m|))}, scaled by the configured learning rate.
 *
 * @param residuals residual per row, indexed by the rows' indices in the tree data
 * @param quantile upper bound applied to the absolute deviation of each residual from the median
 * @param tree the regression tree whose leaves receive coefficients
 * @return map from leaf signature to its learning-rate-scaled coefficient
 */
private Map<TreeNodeSignature, Double> calcCoefficientMap(final double[] residuals, final double quantile, final TreeModelRegression tree) {
    final List<TreeNodeRegression> leaves = tree.getLeafs();
    // sized so all leaves fit without rehashing at the default load factor of 0.75
    final Map<TreeNodeSignature, Double> coefficients = new HashMap<>((int) (leaves.size() / 0.75 + 1));
    final double learningRate = getConfig().getLearningRate();
    for (final TreeNodeRegression leaf : leaves) {
        final int[] rows = leaf.getRowIndicesInTreeData();
        final double[] leafResiduals = new double[rows.length];
        int k = 0;
        for (final int row : rows) {
            leafResiduals[k++] = residuals[row];
        }
        final double median = calcMedian(leafResiduals);
        double clippedDeviationSum = 0;
        for (final double residual : leafResiduals) {
            final double deviation = residual - median;
            clippedDeviationSum += Math.signum(deviation) * Math.min(quantile, Math.abs(deviation));
        }
        final double coefficient = median + (1.0 / leafResiduals.length) * clippedDeviationSum;
        coefficients.put(leaf.getSignature(), coefficient * learningRate);
    }
    return coefficients;
}
Usage of org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature in the knime-core project (by KNIME).
Class TreeEnsembleLearner, method createColumnStatisticTable.
/**
 * Creates a table with one row per attribute column, named after the attribute.
 * <p>
 * Each row holds {@code 2 * REPORT_LEVEL} counts, summed over all models of the ensemble. The first
 * {@code REPORT_LEVEL} values count, per tree level, how many nodes on that level split on the attribute.
 * The remaining {@code REPORT_LEVEL} values count how many nodes on that level had the attribute in
 * their column sample.
 *
 * @param exec execution context used to create the output container and to check for cancellation
 * @return the column statistics table
 * @throws CanceledExecutionException if execution was canceled
 */
public BufferedDataTable createColumnStatisticTable(final ExecutionContext exec) throws CanceledExecutionException {
    final int nrModels = m_ensembleModel.getNrModels();
    final TreeAttributeColumnData[] columns = m_data.getColumns();
    final int nrAttributes = columns.length;
    // [level][attributeIndex]: number of nodes on that level that split on the attribute
    final int[][] columnOnLevelCounts = new int[REPORT_LEVEL][nrAttributes];
    // [level][attributeIndex]: number of nodes on that level whose column sample contained the attribute
    final int[][] columnInLevelSampleCounts = new int[REPORT_LEVEL][nrAttributes];
    for (int i = 0; i < nrModels; i++) {
        // counting visits every node of every model, so allow cancelling between models
        exec.checkCanceled();
        final AbstractTreeModel<?> treeModel = m_ensembleModel.getTreeModel(i);
        // the sampling strategy depends only on the model, not on the node
        final ColumnSampleStrategy colStrat = m_columnSampleStrategies[i];
        for (int level = 0; level < REPORT_LEVEL; level++) {
            for (AbstractTreeNode treeNodeOnLevel : treeModel.getTreeNodes(level)) {
                final TreeNodeSignature sig = treeNodeOnLevel.getSignature();
                final ColumnSample cs = colStrat.getColumnSampleForTreeNode(sig);
                for (TreeAttributeColumnData col : cs) {
                    columnInLevelSampleCounts[level][col.getMetaData().getAttributeIndex()] += 1;
                }
                final int splitAttIdx = treeNodeOnLevel.getSplitAttributeIndex();
                // a negative index means the node has no split attribute
                if (splitAttIdx >= 0) {
                    columnOnLevelCounts[level][splitAttIdx] += 1;
                }
            }
        }
    }
    // create the container only after counting, so a cancellation above leaves no open container behind
    final BufferedDataContainer container = exec.createDataContainer(getColumnStatisticTableSpec());
    for (int i = 0; i < nrAttributes; i++) {
        final String name = columns[i].getMetaData().getAttributeName();
        final int[] counts = new int[2 * REPORT_LEVEL];
        for (int level = 0; level < REPORT_LEVEL; level++) {
            counts[level] = columnOnLevelCounts[level][i];
            counts[REPORT_LEVEL + level] = columnInLevelSampleCounts[level][i];
        }
        final DataRow row = new DefaultRow(name, counts);
        container.addRowToTable(row);
        exec.checkCanceled();
    }
    container.close();
    return container.getTable();
}
Usage of org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature in the knime-core project (by KNIME).
Class TreeLearnerRegression, method buildTreeNode.
/**
 * Recursively builds the regression subtree rooted at the node identified by {@code treeNodeSignature}.
 * <p>
 * If {@code findBestSplitRegression} yields no candidate, a leaf is returned. For a
 * {@code GradientBoostingLearnerConfiguration}, that leaf is created with the original row indices of
 * {@code dataMemberships} and registered via {@code addToLeafList}.
 * <p>
 * With surrogate missing-value handling, two children are built from the surrogate split's child markers.
 * Otherwise, one child is built per child condition of the split. In that case the split attribute is set
 * in {@code forbiddenColumnSet} for the subtree when the split cannot be refined further, and cleared
 * again afterwards.
 *
 * @param exec monitor used to check for cancellation
 * @param currentDepth depth of the node being built
 * @param dataMemberships rows that reach this node
 * @param columnSample columns that may be used to split this node
 * @param treeNodeSignature signature of the node being built
 * @param targetPriors target priors of the rows reaching this node
 * @param forbiddenColumnSet attributes that must not be split on; temporarily modified and restored
 * @return the built node (a leaf, or an inner node with its children)
 * @throws CanceledExecutionException if execution was canceled
 * @throws RuntimeException if a split produces more than {@link Short#MAX_VALUE} children
 */
private TreeNodeRegression buildTreeNode(final ExecutionMonitor exec, final int currentDepth, final DataMemberships dataMemberships, final ColumnSample columnSample, final TreeNodeSignature treeNodeSignature, final RegressionPriors targetPriors, final BitSet forbiddenColumnSet) throws CanceledExecutionException {
    final TreeData data = getData();
    final RandomData rd = getRandomData();
    final TreeEnsembleLearnerConfiguration config = getConfig();
    exec.checkCanceled();
    final SplitCandidate split = findBestSplitRegression(currentDepth, dataMemberships, columnSample, targetPriors, forbiddenColumnSet);

    if (split == null) {
        if (!(config instanceof GradientBoostingLearnerConfiguration)) {
            return new TreeNodeRegression(treeNodeSignature, targetPriors);
        }
        // boosting leaves keep the original row indices and are registered in the leaf list
        final TreeNodeRegression boostingLeaf = new TreeNodeRegression(treeNodeSignature, targetPriors, dataMemberships.getOriginalIndices());
        addToLeafList(boostingLeaf);
        return boostingLeaf;
    }

    final TreeTargetNumericColumnData target = (TreeTargetNumericColumnData) data.getTargetColumn();
    final TreeNodeCondition[] conditions;
    final TreeNodeRegression[] children;

    if (config.getMissingValueHandling() == MissingValueHandling.Surrogate) {
        final SurrogateSplit surrogate = Surrogates.learnSurrogates(dataMemberships, split, data, columnSample, config, rd);
        conditions = surrogate.getChildConditions();
        final BitSet[] markers = surrogate.getChildMarkers();
        assert markers[0].cardinality() + markers[1].cardinality() == dataMemberships.getRowCount() : "Sum of rows in children does not add up to number of rows in parent.";
        children = new TreeNodeRegression[2];
        for (int child = 0; child < children.length; child++) {
            final DataMemberships memberships = dataMemberships.createChildMemberships(markers[child]);
            final TreeNodeSignature signature = getSignatureFactory().getChildSignatureFor(treeNodeSignature, (byte) child);
            final ColumnSample sample = getColSamplingStrategy().getColumnSampleForTreeNode(signature);
            final RegressionPriors priors = target.getPriors(memberships, config);
            final TreeNodeRegression childNode = buildTreeNode(exec, currentDepth + 1, memberships, sample, signature, priors, forbiddenColumnSet);
            childNode.setTreeNodeCondition(conditions[child]);
            children[child] = childNode;
        }
    } else {
        final TreeAttributeColumnData splitColumn = split.getColumnData();
        final int splitAttribute = splitColumn.getMetaData().getAttributeIndex();
        // an exhausted attribute is forbidden for the whole subtree below this node
        final boolean exhausted = !split.canColumnBeSplitFurther();
        forbiddenColumnSet.set(splitAttribute, exhausted);
        conditions = split.getChildConditions();
        final int childCount = conditions.length;
        if (childCount > Short.MAX_VALUE) {
            throw new RuntimeException("Too many children when splitting attribute " + split.getColumnData() + " (maximum supported: " + Short.MAX_VALUE + "): " + childCount);
        }
        children = new TreeNodeRegression[childCount];
        for (int child = 0; child < childCount; child++) {
            final TreeNodeCondition condition = conditions[child];
            final DataMemberships memberships = dataMemberships.createChildMemberships(splitColumn.updateChildMemberships(condition, dataMemberships));
            final RegressionPriors priors = target.getPriors(memberships, config);
            final TreeNodeSignature signature = treeNodeSignature.createChildSignature((byte) child);
            final ColumnSample sample = getColSamplingStrategy().getColumnSampleForTreeNode(signature);
            final TreeNodeRegression childNode = buildTreeNode(exec, currentDepth + 1, memberships, sample, signature, priors, forbiddenColumnSet);
            childNode.setTreeNodeCondition(condition);
            children[child] = childNode;
        }
        // siblings and ancestors may use the attribute again
        if (exhausted) {
            forbiddenColumnSet.set(splitAttribute, false);
        }
    }
    return new TreeNodeRegression(treeNodeSignature, targetPriors, children);
}
Aggregations