Usage of `org.knime.base.node.mine.treeensemble.sample.row.RowSample` in the knime-core project by KNIME.
Class `TreeLearnerClassification`, method `learnSingleTree`:
/**
 * {@inheritDoc}
 * <p>
 * Implementation notes: each row's weight is the value reported by the row sample's
 * {@code getCountFor(int)}. The class priors are computed from these weights, and the tree is
 * built by {@code buildTreeNode} starting at {@link TreeNodeSignature#ROOT_SIGNATURE}. The
 * {@code rd} argument is not used by this implementation.
 */
@Override
public TreeModelClassification learnSingleTree(final ExecutionMonitor exec, final RandomData rd) throws CanceledExecutionException {
    final TreeData treeData = getData();
    final RowSample sample = getRowSampling();
    final TreeEnsembleLearnerConfiguration learnerConfig = getConfig();
    final TreeTargetNominalColumnData target = (TreeTargetNominalColumnData) treeData.getTargetColumn();

    // One weight per row: the count the row sample reports for that row.
    final int rowCount = treeData.getNrRows();
    final double[] weights = new double[rowCount];
    for (int row = 0; row < rowCount; row++) {
        weights[row] = sample.getCountFor(row);
    }
    final ClassificationPriors priors = target.getDistribution(weights, learnerConfig);

    // Forbidden-column set handed to buildTreeNode. It starts empty and is asserted
    // to be empty again once the whole tree has been built.
    final BitSet forbiddenColumns = new BitSet(treeData.getNrAttributes());

    // No membership controller is passed for the root; a TreeNodeMembershipController-based
    // variant exists but is currently disabled.
    final TreeNodeMembershipController noController = null;

    final TreeNodeClassification root = buildTreeNode(exec, 0, weights,
        TreeNodeSignature.ROOT_SIGNATURE, priors, forbiddenColumns, noController);
    assert forbiddenColumns.cardinality() == 0;

    // The root node is reached unconditionally.
    root.setTreeNodeCondition(TreeNodeTrueCondition.INSTANCE);
    return new TreeModelClassification(root);
}
Usage of `org.knime.base.node.mine.treeensemble.sample.row.RowSample` in the knime-core project by KNIME.
Class `TreeLearnerRegression`, method `learnSingleTree`:
/**
 * {@inheritDoc}
 * <p>
 * Implementation notes: each row's weight is the value reported by the row sample's
 * {@code getCountFor(int)}. The regression priors of the numeric target are computed from
 * these weights, and the tree is built by {@code buildTreeNode} starting at
 * {@link TreeNodeSignature#ROOT_SIGNATURE}. The {@code rd} argument is not used by this
 * implementation.
 */
@Override
public TreeModelRegression learnSingleTree(final ExecutionMonitor exec, final RandomData rd) throws CanceledExecutionException {
    final TreeTargetNumericColumnData target = getTargetData();
    final TreeData treeData = getData();
    final RowSample sample = getRowSampling();
    final TreeEnsembleLearnerConfiguration learnerConfig = getConfig();

    // One weight per row: the count the row sample reports for that row.
    final int rowCount = treeData.getNrRows();
    final double[] weights = new double[rowCount];
    for (int row = 0; row < rowCount; row++) {
        weights[row] = sample.getCountFor(row);
    }
    final RegressionPriors priors = target.getPriors(weights, learnerConfig);

    // Forbidden-column set handed to buildTreeNode. It starts empty and is asserted
    // to be empty again once the whole tree has been built.
    final BitSet forbiddenColumns = new BitSet(treeData.getNrAttributes());

    // No membership controller is passed for the root; a TreeNodeMembershipController-based
    // variant exists but is currently disabled.
    final TreeNodeMembershipController noController = null;

    final TreeNodeRegression root = buildTreeNode(exec, 0, weights,
        TreeNodeSignature.ROOT_SIGNATURE, priors, forbiddenColumns, noController);
    assert forbiddenColumns.cardinality() == 0;

    // The root node is reached unconditionally.
    root.setTreeNodeCondition(TreeNodeTrueCondition.INSTANCE);
    return new TreeModelRegression(root);
}
Aggregations