Usage of org.knime.base.node.mine.treeensemble.data.TreeAttributeColumnData in the knime-core project (by KNIME).
Class TreeLearnerClassification, method findBestSplitClassification:
/**
 * Determines the best split for a node of a classification tree.
 *
 * <p>No split is attempted (the method returns {@code null}) in three cases:
 * the configured maximum depth has been reached;
 * the node holds fewer records than the configured minimum node size (if one is set);
 * or the node's prior impurity is below {@link TreeColumnData#EPSILON}.
 *
 * <p>On the root level a configured hard-coded root column short-circuits the search, and only that column
 * is evaluated. Otherwise every column of the node's column sample whose attribute index is not set in
 * {@code forbiddenColumnSet} is evaluated, and the candidate with the strictly largest positive gain wins.
 *
 * @param currentDepth depth of the node (0 for the root)
 * @param rowSampleWeights row weights/memberships of the rows reaching this node
 * @param treeNodeSignature signature of the node; used to look up its column sample
 * @param targetPriors class distribution of the rows reaching this node
 * @param forbiddenColumnSet attribute indices that must not be considered for splitting
 * @param membershipController forwarded unchanged to the columns' split search
 * @return the best split candidate, or {@code null} if no split should or can be made
 */
private SplitCandidate findBestSplitClassification(final int currentDepth, final double[] rowSampleWeights,
    final TreeNodeSignature treeNodeSignature, final ClassificationPriors targetPriors,
    final BitSet forbiddenColumnSet, final TreeNodeMembershipController membershipController) {
    final TreeData treeData = getData();
    final ColumnSampleStrategy samplingStrategy = getColSamplingStrategy();
    final TreeEnsembleLearnerConfiguration cfg = getConfig();

    // stopping criterion: depth limit (only when the depth is not unlimited)
    final int levelLimit = cfg.getMaxLevels();
    final boolean depthLimitReached =
        levelLimit != TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE && currentDepth >= levelLimit;
    if (depthLimitReached) {
        return null;
    }

    // stopping criterion: too few records (only when a minimum node size is configured)
    final int minRecords = cfg.getMinNodeSize();
    if (minRecords != TreeEnsembleLearnerConfiguration.MIN_NODE_SIZE_UNDEFINED
        && targetPriors.getNrRecords() < minRecords) {
        return null;
    }

    // stopping criterion: the node is already (numerically) pure
    if (targetPriors.getPriorImpurity() < TreeColumnData.EPSILON) {
        return null;
    }

    final TreeTargetNominalColumnData nominalTarget = (TreeTargetNominalColumnData) treeData.getTargetColumn();

    if (currentDepth == 0 && cfg.getHardCodedRootColumn() != null) {
        // a fixed root column was configured: evaluate only that column
        final TreeAttributeColumnData forcedRoot = treeData.getColumn(cfg.getHardCodedRootColumn());
        return forcedRoot.calcBestSplitClassification(membershipController, rowSampleWeights, targetPriors,
            nominalTarget);
    }

    SplitCandidate winner = null;
    double winningGain = 0.0;
    for (final TreeAttributeColumnData candidateColumn
            : samplingStrategy.getColumnSampleForTreeNode(treeNodeSignature)) {
        if (forbiddenColumnSet.get(candidateColumn.getMetaData().getAttributeIndex())) {
            continue;
        }
        final SplitCandidate candidate = candidateColumn.calcBestSplitClassification(membershipController,
            rowSampleWeights, targetPriors, nominalTarget);
        if (candidate == null) {
            continue;
        }
        final double gain = candidate.getGainValue();
        // strict comparison: ties keep the earlier column, and a gain of 0 never wins
        if (gain > winningGain) {
            winningGain = gain;
            winner = candidate;
        }
    }
    return winner;
}
Usage of org.knime.base.node.mine.treeensemble.data.TreeAttributeColumnData in the knime-core project (by KNIME).
Class TreeLearnerRegression, method findBestSplitRegression:
/**
 * Determines the best split for a node of a regression tree.
 *
 * <p>No split is attempted (the method returns {@code null}) in three cases:
 * the configured maximum depth has been reached;
 * the node holds fewer records than the configured minimum node size (if one is set);
 * or the node's sum of squared deviations is below {@link TreeColumnData#EPSILON}.
 *
 * <p>On the root level a configured hard-coded root column short-circuits the search, and only that column
 * is evaluated. Otherwise every column of the node's column sample whose attribute index is not set in
 * {@code forbiddenColumnSet} is evaluated, and the candidate with the strictly largest positive gain wins.
 *
 * @param currentDepth depth of the node (0 for the root)
 * @param rowSampleWeights row weights/memberships of the rows reaching this node
 * @param treeNodeSignature signature of the node; used to look up its column sample
 * @param targetPriors regression priors of the rows reaching this node
 * @param forbiddenColumnSet attribute indices that must not be considered for splitting
 * @param membershipController forwarded unchanged to the columns' split search
 * @return the best split candidate, or {@code null} if no split should or can be made
 */
private SplitCandidate findBestSplitRegression(final int currentDepth, final double[] rowSampleWeights,
    final TreeNodeSignature treeNodeSignature, final RegressionPriors targetPriors,
    final BitSet forbiddenColumnSet, final TreeNodeMembershipController membershipController) {
    final TreeData treeData = getData();
    final ColumnSampleStrategy samplingStrategy = getColSamplingStrategy();
    final TreeEnsembleLearnerConfiguration cfg = getConfig();

    // stopping criterion: depth limit (only when the depth is not unlimited)
    final int levelLimit = cfg.getMaxLevels();
    final boolean depthLimitReached =
        levelLimit != TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE && currentDepth >= levelLimit;
    if (depthLimitReached) {
        return null;
    }

    // stopping criterion: too few records (only when a minimum node size is configured)
    final int minRecords = cfg.getMinNodeSize();
    if (minRecords != TreeEnsembleLearnerConfiguration.MIN_NODE_SIZE_UNDEFINED
        && targetPriors.getNrRecords() < minRecords) {
        return null;
    }

    // stopping criterion: target values in this node are (numerically) constant
    if (targetPriors.getSumSquaredDeviation() < TreeColumnData.EPSILON) {
        return null;
    }

    final TreeTargetNumericColumnData numericTarget = getTargetData();

    if (currentDepth == 0 && cfg.getHardCodedRootColumn() != null) {
        // a fixed root column was configured: evaluate only that column
        final TreeAttributeColumnData forcedRoot = treeData.getColumn(cfg.getHardCodedRootColumn());
        return forcedRoot.calcBestSplitRegression(membershipController, rowSampleWeights, targetPriors,
            numericTarget);
    }

    SplitCandidate winner = null;
    double winningGain = 0.0;
    for (final TreeAttributeColumnData candidateColumn
            : samplingStrategy.getColumnSampleForTreeNode(treeNodeSignature)) {
        if (forbiddenColumnSet.get(candidateColumn.getMetaData().getAttributeIndex())) {
            continue;
        }
        final SplitCandidate candidate = candidateColumn.calcBestSplitRegression(membershipController,
            rowSampleWeights, targetPriors, numericTarget);
        if (candidate == null) {
            continue;
        }
        final double gain = candidate.getGainValue();
        // strict comparison: ties keep the earlier column, and a gain of 0 never wins
        if (gain > winningGain) {
            winningGain = gain;
            winner = candidate;
        }
    }
    return winner;
}
Usage of org.knime.base.node.mine.treeensemble.data.TreeAttributeColumnData in the knime-core project (by KNIME).
Class TreeLearnerClassification, method buildTreeNode:
/**
 * Recursively grows the classification (sub)tree for the node identified by {@code treeNodeSignature}.
 *
 * <p>If {@link #findBestSplitClassification} yields no split, the node becomes a leaf. Otherwise one child is
 * built per child condition of the best split. Each child's row weights start as a copy of this node's weights
 * and are then adjusted by the split column for that condition.
 *
 * <p>A split column that reports it cannot be split further is marked in {@code forbiddenColumnSet} for the
 * duration of the subtree and un-marked again afterwards.
 *
 * @param exec monitor used for cancellation checks
 * @param currentDepth depth of this node (0 for the root)
 * @param rowSampleWeights row weights/memberships of the rows reaching this node; only read here
 * @param treeNodeSignature signature of this node
 * @param targetPriors class distribution of the rows reaching this node
 * @param forbiddenColumnSet attribute indices excluded from splitting; temporarily modified during recursion
 * @param membershipController membership controller of this node; may be {@code null}, in which case the
 *            children also receive {@code null}
 * @return the built node (leaf or inner node with children)
 * @throws CanceledExecutionException if execution was canceled
 */
private TreeNodeClassification buildTreeNode(final ExecutionMonitor exec, final int currentDepth,
    final double[] rowSampleWeights, final TreeNodeSignature treeNodeSignature,
    final ClassificationPriors targetPriors, final BitSet forbiddenColumnSet,
    final TreeNodeMembershipController membershipController) throws CanceledExecutionException {
    final TreeData data = getData();
    final TreeEnsembleLearnerConfiguration config = getConfig();
    exec.checkCanceled();
    final SplitCandidate bestSplit = findBestSplitClassification(currentDepth, rowSampleWeights,
        treeNodeSignature, targetPriors, forbiddenColumnSet, membershipController);
    if (bestSplit == null) {
        // no admissible split: this node becomes a leaf
        return new TreeNodeClassification(treeNodeSignature, targetPriors, config);
    }
    final TreeAttributeColumnData splitColumn = bestSplit.getColumnData();
    final int attributeIndex = splitColumn.getMetaData().getAttributeIndex();
    // a column that cannot be split any further must not be reused within this subtree
    final boolean markAttributeAsForbidden = !bestSplit.canColumnBeSplitFurther();
    forbiddenColumnSet.set(attributeIndex, markAttributeAsForbidden);
    final TreeNodeCondition[] childConditions = bestSplit.getChildConditions();
    // child signatures are indexed by short, see createChildSignature below
    if (childConditions.length > Short.MAX_VALUE) {
        throw new RuntimeException("Too many children when splitting " + "attribute " + bestSplit.getColumnData()
            + " (maximum supported: " + Short.MAX_VALUE + "): " + childConditions.length);
    }
    final TreeNodeClassification[] childNodes = new TreeNodeClassification[childConditions.length];
    // one scratch buffer reused for all children; it is refilled before each recursive call
    final double[] childMemberships = new double[rowSampleWeights.length];
    final TreeTargetNominalColumnData targetColumn = (TreeTargetNominalColumnData) data.getTargetColumn();
    for (int i = 0; i < childConditions.length; i++) {
        final TreeNodeCondition cond = childConditions[i];
        System.arraycopy(rowSampleWeights, 0, childMemberships, 0, rowSampleWeights.length);
        splitColumn.updateChildMemberships(cond, rowSampleWeights, childMemberships);
        final ClassificationPriors childTargetPriors = targetColumn.getDistribution(childMemberships, config);
        final TreeNodeSignature childSignature = treeNodeSignature.createChildSignature((short) i);
        // derive the child's controller from the split column, as the regression learner does,
        // instead of handing every child a null controller
        final TreeNodeMembershipController childMembershipController = membershipController == null ? null
            : splitColumn.getChildNodeMembershipController(cond, membershipController);
        childNodes[i] = buildTreeNode(exec, currentDepth + 1, childMemberships, childSignature, childTargetPriors,
            forbiddenColumnSet, childMembershipController);
        childNodes[i].setTreeNodeCondition(cond);
    }
    if (markAttributeAsForbidden) {
        // re-enable the column for sibling subtrees further up
        forbiddenColumnSet.set(attributeIndex, false);
    }
    return new TreeNodeClassification(treeNodeSignature, targetPriors, childNodes, config);
}
Usage of org.knime.base.node.mine.treeensemble.data.TreeAttributeColumnData in the knime-core project (by KNIME).
Class TreeEnsembleLearner, method createColumnStatisticTable:
/**
 * Creates a statistics table with one row per attribute column, keyed by the attribute name.
 *
 * <p>Each row holds {@code 2 * REPORT_LEVEL} int cells:
 * <ul>
 * <li>cell {@code level} counts how many tree nodes on that level, over all models of the ensemble, split on
 * the attribute;</li>
 * <li>cell {@code REPORT_LEVEL + level} counts how many tree nodes on that level had the attribute in their
 * column sample.</li>
 * </ul>
 *
 * @param exec execution context used to create the output container and to check for cancellation
 * @return the statistics table
 * @throws CanceledExecutionException if execution was canceled
 */
public BufferedDataTable createColumnStatisticTable(final ExecutionContext exec) throws CanceledExecutionException {
    final BufferedDataContainer c = exec.createDataContainer(getColumnStatisticTableSpec());
    final int nrModels = m_ensembleModel.getNrModels();
    final TreeAttributeColumnData[] columns = m_data.getColumns();
    final int nrAttributes = columns.length;
    // both arrays are indexed [level][attributeIndex]
    final int[][] columnOnLevelCounts = new int[REPORT_LEVEL][nrAttributes];
    final int[][] columnInLevelSampleCounts = new int[REPORT_LEVEL][nrAttributes];
    for (int i = 0; i < nrModels; i++) {
        // honour cancellation during counting too, not only while writing rows
        exec.checkCanceled();
        final AbstractTreeModel<?> treeModel = m_ensembleModel.getTreeModel(i);
        // the sampling strategy depends only on the model, so look it up once per model
        final ColumnSampleStrategy colStrat = m_columnSampleStrategies[i];
        for (int level = 0; level < REPORT_LEVEL; level++) {
            for (AbstractTreeNode treeNodeOnLevel : treeModel.getTreeNodes(level)) {
                final ColumnSample cs = colStrat.getColumnSampleForTreeNode(treeNodeOnLevel.getSignature());
                for (TreeAttributeColumnData col : cs) {
                    columnInLevelSampleCounts[level][col.getMetaData().getAttributeIndex()] += 1;
                }
                final int splitAttIdx = treeNodeOnLevel.getSplitAttributeIndex();
                // negative index: the node has no split attribute (presumably a leaf)
                if (splitAttIdx >= 0) {
                    columnOnLevelCounts[level][splitAttIdx] += 1;
                }
            }
        }
    }
    for (int i = 0; i < nrAttributes; i++) {
        final String name = columns[i].getMetaData().getAttributeName();
        // first half: split counts per level; second half: in-sample counts per level
        final int[] counts = new int[2 * REPORT_LEVEL];
        for (int level = 0; level < REPORT_LEVEL; level++) {
            counts[level] = columnOnLevelCounts[level][i];
            counts[REPORT_LEVEL + level] = columnInLevelSampleCounts[level][i];
        }
        final DataRow row = new DefaultRow(name, counts);
        c.addRowToTable(row);
        exec.checkCanceled();
    }
    c.close();
    return c.getTable();
}
Usage of org.knime.base.node.mine.treeensemble.data.TreeAttributeColumnData in the knime-core project (by KNIME).
Class TreeLearnerRegression, method buildTreeNode:
/**
 * Recursively grows the regression (sub)tree for the node identified by {@code treeNodeSignature}.
 *
 * <p>If {@link #findBestSplitRegression} yields no split, the node becomes a leaf. Otherwise one child is built
 * per child condition of the best split. Each child's row weights start as a copy of this node's weights and
 * are then adjusted by the split column for that condition. The child's membership controller is derived from
 * the split column.
 *
 * <p>A split column that reports it cannot be split further is marked in {@code forbiddenColumnSet} while its
 * subtree is built and cleared again afterwards.
 *
 * @param exec monitor used for cancellation checks
 * @param currentDepth depth of this node (0 for the root)
 * @param rowSampleWeights row weights/memberships of the rows reaching this node; only read here
 * @param treeNodeSignature signature of this node
 * @param targetPriors regression priors of the rows reaching this node
 * @param forbiddenColumnSet attribute indices excluded from splitting; temporarily modified during recursion
 * @param membershipController membership controller of this node
 * @return the built node (leaf or inner node with children)
 * @throws CanceledExecutionException if execution was canceled
 */
private TreeNodeRegression buildTreeNode(final ExecutionMonitor exec, final int currentDepth,
    final double[] rowSampleWeights, final TreeNodeSignature treeNodeSignature, final RegressionPriors targetPriors,
    final BitSet forbiddenColumnSet, final TreeNodeMembershipController membershipController)
    throws CanceledExecutionException {
    final TreeData treeData = getData();
    final TreeEnsembleLearnerConfiguration learnerConfig = getConfig();
    exec.checkCanceled();

    final SplitCandidate split = findBestSplitRegression(currentDepth, rowSampleWeights, treeNodeSignature,
        targetPriors, forbiddenColumnSet, membershipController);
    if (split == null) {
        // nothing worth splitting: leaf node
        return new TreeNodeRegression(treeNodeSignature, targetPriors);
    }

    final TreeAttributeColumnData column = split.getColumnData();
    final int columnIndex = column.getMetaData().getAttributeIndex();
    final boolean exhausted = !split.canColumnBeSplitFurther();
    forbiddenColumnSet.set(columnIndex, exhausted);

    final TreeNodeCondition[] conditions = split.getChildConditions();
    final int childCount = conditions.length;
    // children are addressed via short indices in their signatures
    if (childCount > Short.MAX_VALUE) {
        throw new RuntimeException("Too many children when splitting " + "attribute " + split.getColumnData()
            + " (maximum supported: " + Short.MAX_VALUE + "): " + childCount);
    }

    final TreeNodeRegression[] children = new TreeNodeRegression[childCount];
    // scratch buffer shared by all children, refilled from this node's weights before each child
    final double[] childWeights = new double[rowSampleWeights.length];
    final TreeTargetNumericColumnData target = (TreeTargetNumericColumnData) treeData.getTargetColumn();
    for (int idx = 0; idx < childCount; idx++) {
        System.arraycopy(rowSampleWeights, 0, childWeights, 0, childWeights.length);
        final TreeNodeCondition condition = conditions[idx];
        column.updateChildMemberships(condition, rowSampleWeights, childWeights);
        final RegressionPriors childPriors = target.getPriors(childWeights, learnerConfig);
        final TreeNodeSignature childSig = treeNodeSignature.createChildSignature((short) idx);
        final TreeNodeMembershipController childController =
            column.getChildNodeMembershipController(condition, membershipController);
        final TreeNodeRegression child = buildTreeNode(exec, currentDepth + 1, childWeights, childSig,
            childPriors, forbiddenColumnSet, childController);
        child.setTreeNodeCondition(condition);
        children[idx] = child;
    }

    if (exhausted) {
        // make the column available again outside this subtree
        forbiddenColumnSet.clear(columnIndex);
    }
    return new TreeNodeRegression(treeNodeSignature, targetPriors, children);
}
Aggregations