Use of org.knime.base.node.mine.treeensemble2.data.TreeData in project knime-core by knime.
The class RegressionTreeLearnerNodeModel, method execute.
/**
 * {@inheritDoc}
 *
 * <p>Reads the table at the first input port into an in-memory {@link TreeData} representation,
 * learns a single regression tree on it and returns the resulting model as the only output port
 * object. Warnings reported by the learn-column filter and by the data creator are joined with a
 * newline and set as the node warning.
 */
@Override
protected PortObject[] execute(final PortObject[] inObjects, final ExecutionContext exec) throws Exception {
    final BufferedDataTable inputTable = (BufferedDataTable) inObjects[0];
    final DataTableSpec inputSpec = inputTable.getDataTableSpec();

    // Restrict the input to the columns the configuration wants to learn on.
    final FilterLearnColumnRearranger learnRearranger = m_configuration.filterLearnColumns(inputSpec);
    String warning = learnRearranger.getWarning();
    final BufferedDataTable learnTable =
        exec.createColumnRearrangeTable(inputTable, learnRearranger, exec.createSubProgress(0.0));
    final DataTableSpec learnSpec = learnTable.getDataTableSpec();

    // Progress split: 10% for reading the data, 90% for learning the tree.
    final ExecutionMonitor readInExec = exec.createSubProgress(0.1);
    final ExecutionMonitor learnExec = exec.createSubProgress(0.9);

    final TreeDataCreator dataCreator = new TreeDataCreator(m_configuration, learnSpec, learnTable.getRowCount());
    exec.setProgress("Reading data into memory");
    final TreeData data = dataCreator.readData(learnTable, m_configuration, readInExec);
    m_hiliteRowSample = dataCreator.getDataRowsForHilite();
    m_viewMessage = dataCreator.getViewMessage();

    // Append the data creator's warning (if any) to the column filter's warning (if any).
    final String dataCreationWarning = dataCreator.getAndClearWarningMessage();
    if (dataCreationWarning != null) {
        warning = (warning == null) ? dataCreationWarning : warning + "\n" + dataCreationWarning;
    }
    readInExec.setProgress(1.0);

    exec.setMessage("Learning tree");
    final RandomData rd = m_configuration.createRandomData();
    final IDataIndexManager indexManager = (data.getTreeType() == TreeType.BitVector)
        ? new BitVectorDataIndexManager(data.getNrRows())
        : new DefaultDataIndexManager(data);

    // NOTE(review): confirm that '<' is the intended comparison against the MAX_LEVEL_INFINITE
    // sentinel; the capacity passed to the factory is 2^(maxLevels - 1).
    final int maxLevels = m_configuration.getMaxLevels();
    final TreeNodeSignatureFactory signatureFactory =
        (maxLevels < TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE)
            ? new TreeNodeSignatureFactory(IntMath.pow(2, maxLevels - 1))
            : new TreeNodeSignatureFactory();

    final RowSample rowSample = m_configuration.createRowSampler(data).createRowSample(rd);
    final TreeLearnerRegression treeLearner =
        new TreeLearnerRegression(m_configuration, data, indexManager, signatureFactory, rd, rowSample);
    final TreeModelRegression regTree = treeLearner.learnSingleTree(learnExec, rd);

    // Wrap the learned tree into the port object handed to downstream nodes.
    final RegressionTreeModel model =
        new RegressionTreeModel(m_configuration, data.getMetaData(), regTree, data.getTreeType());
    final RegressionTreeModelPortObjectSpec treePortObjectSpec = new RegressionTreeModelPortObjectSpec(learnSpec);
    final RegressionTreeModelPortObject treePortObject = new RegressionTreeModelPortObject(model, treePortObjectSpec);
    learnExec.setProgress(1.0);
    m_treeModelPortObject = treePortObject;

    if (warning != null) {
        setWarningMessage(warning);
    }
    return new PortObject[] { treePortObject };
}
Use of org.knime.base.node.mine.treeensemble2.data.TreeData in project knime-core by knime.
The class AllColumnSampleTest, method testIterator.
/**
 * Verifies that iterating an {@link AllColumnSample} returns the columns of the underlying
 * {@link TreeData} in the same order as {@code data.getColumns()}, and that every column is
 * returned.
 *
 * @throws Exception if the test data cannot be created
 */
@Test
public void testIterator() throws Exception {
    final TreeData data = createTreeData();
    final AllColumnSample sample = new AllColumnSample(data);
    int i = 0;
    for (final TreeAttributeColumnData col : sample) {
        assertEquals("Returned wrong column.", data.getColumns()[i++], col);
    }
    // Without this check the test would pass vacuously if the iterator were empty or stopped early.
    assertEquals("Iterator did not return all columns.", data.getColumns().length, i);
}
Use of org.knime.base.node.mine.treeensemble2.data.TreeData in project knime-core by knime.
The class AllColumnSampleTest, method testIteratorRemove.
/**
 * Expects {@link UnsupportedOperationException} when {@code remove()} is called on the iterator
 * of an {@link AllColumnSample}.
 *
 * @throws Exception if the test data cannot be created
 */
@Test(expected = UnsupportedOperationException.class)
public void testIteratorRemove() throws Exception {
    final AllColumnSample allColumns = new AllColumnSample(createTreeData());
    // remove() is invoked on a fresh iterator, without a preceding next().
    allColumns.iterator().remove();
}
Use of org.knime.base.node.mine.treeensemble2.data.TreeData in project knime-core by knime.
The class TreeLearnerRegression, method buildTreeNode.
/**
 * Recursively builds the regression (sub)tree rooted at the node identified by
 * {@code treeNodeSignature}.
 *
 * <p>If no split candidate is found the node becomes a leaf; for a
 * {@link GradientBoostingLearnerConfiguration} the leaf additionally receives the original row
 * indices and is registered via {@code addToLeafList}. Otherwise the rows are distributed to child
 * nodes (using surrogate splits when the configured missing value handling is
 * {@code MissingValueHandling.Surrogate}) and each child is built by a recursive call with
 * {@code currentDepth + 1}.
 *
 * @param exec monitor that is checked for cancellation before a split is searched
 * @param currentDepth depth of the node being built
 * @param dataMemberships rows belonging to this node
 * @param columnSample columns considered for splitting this node
 * @param treeNodeSignature signature of the node being built
 * @param targetPriors target priors of the rows in this node
 * @param forbiddenColumnSet attribute indices that must not be used for splitting; modified while
 *            the children are built and reset afterwards for the attribute set by this call
 * @return the node, either a leaf or an inner node with its children attached
 * @throws CanceledExecutionException if execution was canceled
 */
private TreeNodeRegression buildTreeNode(final ExecutionMonitor exec, final int currentDepth,
    final DataMemberships dataMemberships, final ColumnSample columnSample,
    final TreeNodeSignature treeNodeSignature, final RegressionPriors targetPriors,
    final BitSet forbiddenColumnSet) throws CanceledExecutionException {
    final TreeData treeData = getData();
    final RandomData random = getRandomData();
    final TreeEnsembleLearnerConfiguration cfg = getConfig();
    exec.checkCanceled();

    final SplitCandidate split =
        findBestSplitRegression(currentDepth, dataMemberships, columnSample, targetPriors, forbiddenColumnSet);
    if (split == null) {
        // No split available: this node is a leaf.
        if (!(cfg instanceof GradientBoostingLearnerConfiguration)) {
            return new TreeNodeRegression(treeNodeSignature, targetPriors);
        }
        final TreeNodeRegression gbtLeaf =
            new TreeNodeRegression(treeNodeSignature, targetPriors, dataMemberships.getOriginalIndices());
        addToLeafList(gbtLeaf);
        return gbtLeaf;
    }

    final TreeTargetNumericColumnData target = (TreeTargetNumericColumnData) treeData.getTargetColumn();
    final TreeNodeRegression[] children;
    if (cfg.getMissingValueHandling() != MissingValueHandling.Surrogate) {
        final TreeAttributeColumnData splitColumn = split.getColumnData();
        final int attributeIndex = splitColumn.getMetaData().getAttributeIndex();
        // While the children are built, the split attribute is forbidden iff it cannot be split further.
        final boolean columnExhausted = !split.canColumnBeSplitFurther();
        forbiddenColumnSet.set(attributeIndex, columnExhausted);

        final TreeNodeCondition[] conditions = split.getChildConditions();
        final int nrChildren = conditions.length;
        if (nrChildren > Short.MAX_VALUE) {
            throw new RuntimeException("Too many children when splitting " + "attribute " + split.getColumnData()
                + " (maximum supported: " + Short.MAX_VALUE + "): " + nrChildren);
        }

        children = new TreeNodeRegression[nrChildren];
        for (int childIdx = 0; childIdx < nrChildren; childIdx++) {
            final TreeNodeCondition condition = conditions[childIdx];
            final DataMemberships childRows = dataMemberships
                .createChildMemberships(splitColumn.updateChildMemberships(condition, dataMemberships));
            final RegressionPriors childPriors = target.getPriors(childRows, cfg);
            // NOTE(review): the limit above is Short.MAX_VALUE but the child index is narrowed to
            // byte here - confirm that more than 128 children are handled as intended.
            final TreeNodeSignature childSignature = treeNodeSignature.createChildSignature((byte) childIdx);
            final ColumnSample childColumns = getColSamplingStrategy().getColumnSampleForTreeNode(childSignature);
            final TreeNodeRegression child = buildTreeNode(exec, currentDepth + 1, childRows, childColumns,
                childSignature, childPriors, forbiddenColumnSet);
            children[childIdx] = child;
            child.setTreeNodeCondition(condition);
        }

        // Release the attribute again so that sibling subtrees may use it.
        if (columnExhausted) {
            forbiddenColumnSet.set(attributeIndex, false);
        }
    } else {
        final SurrogateSplit surrogateSplit =
            Surrogates.learnSurrogates(dataMemberships, split, treeData, columnSample, cfg, random);
        final TreeNodeCondition[] conditions = surrogateSplit.getChildConditions();
        final BitSet[] childMarkers = surrogateSplit.getChildMarkers();
        assert childMarkers[0].cardinality() + childMarkers[1].cardinality() == dataMemberships.getRowCount() : "Sum of rows in children does not add up to number of rows in parent.";

        // This branch always creates exactly two children.
        children = new TreeNodeRegression[2];
        for (int childIdx = 0; childIdx < 2; childIdx++) {
            final DataMemberships childRows = dataMemberships.createChildMemberships(childMarkers[childIdx]);
            final TreeNodeSignature childSignature =
                getSignatureFactory().getChildSignatureFor(treeNodeSignature, (byte) childIdx);
            final ColumnSample childColumns = getColSamplingStrategy().getColumnSampleForTreeNode(childSignature);
            final RegressionPriors childPriors = target.getPriors(childRows, cfg);
            final TreeNodeRegression child = buildTreeNode(exec, currentDepth + 1, childRows, childColumns,
                childSignature, childPriors, forbiddenColumnSet);
            children[childIdx] = child;
            child.setTreeNodeCondition(conditions[childIdx]);
        }
    }
    return new TreeNodeRegression(treeNodeSignature, targetPriors, children);
}
Use of org.knime.base.node.mine.treeensemble2.data.TreeData in project knime-core by knime.
The class TreeLearnerRegression, method findBestSplitRegression.
/**
 * Determines the split to apply at the current node, or {@code null} if the node should not be
 * split.
 *
 * <p>No split is searched when the configured maximum number of levels is reached, when the node
 * holds fewer records than the configured minimum node size, or when the sum of squared deviations
 * of the target is below {@code TreeColumnData.EPSILON}. At depth 0 with a hard-coded root column
 * configured, that column's best split is returned directly (the forbidden set and the gain
 * comparison are not consulted). Otherwise the split with the largest gain strictly greater than
 * zero among the sampled, non-forbidden columns is returned.
 *
 * @param currentDepth depth of the node to split
 * @param dataMemberships rows belonging to the node
 * @param columnSample candidate columns for the split
 * @param targetPriors target priors of the rows in the node
 * @param forbiddenColumnSet attribute indices that are skipped during the search
 * @return the chosen split candidate, or {@code null} if there is none
 */
private SplitCandidate findBestSplitRegression(final int currentDepth, final DataMemberships dataMemberships,
    final ColumnSample columnSample, final RegressionPriors targetPriors, final BitSet forbiddenColumnSet) {
    final TreeData treeData = getData();
    final RandomData random = getRandomData();
    final TreeEnsembleLearnerConfiguration cfg = getConfig();

    // Stop criterion: depth limit.
    final int depthLimit = cfg.getMaxLevels();
    final boolean hasDepthLimit = depthLimit != TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE;
    if (hasDepthLimit && currentDepth >= depthLimit) {
        return null;
    }

    // Stop criterion: too few records in the node.
    final int minNodeSize = cfg.getMinNodeSize();
    if (minNodeSize != TreeEnsembleLearnerConfiguration.MIN_NODE_SIZE_UNDEFINED
        && targetPriors.getNrRecords() < minNodeSize) {
        return null;
    }

    // Stop criterion: target deviation is below the EPSILON threshold.
    if (targetPriors.getSumSquaredDeviation() < TreeColumnData.EPSILON) {
        return null;
    }

    final TreeTargetNumericColumnData target = getTargetData();
    if (currentDepth == 0 && cfg.getHardCodedRootColumn() != null) {
        final TreeAttributeColumnData rootColumn = treeData.getColumn(cfg.getHardCodedRootColumn());
        return rootColumn.calcBestSplitRegression(dataMemberships, targetPriors, target, random);
    }

    SplitCandidate winner = null;
    double winnerGain = 0.0;
    for (final TreeAttributeColumnData column : columnSample) {
        final boolean forbidden = forbiddenColumnSet.get(column.getMetaData().getAttributeIndex());
        if (!forbidden) {
            final SplitCandidate proposal =
                column.calcBestSplitRegression(dataMemberships, targetPriors, target, random);
            if (proposal != null) {
                final double proposalGain = proposal.getGainValue();
                // Strict comparison: the first candidate wins ties and zero gain is never accepted.
                if (proposalGain > winnerGain) {
                    winnerGain = proposalGain;
                    winner = proposal;
                }
            }
        }
    }
    return winner;
}
Aggregations