Usage of org.knime.base.node.mine.treeensemble2.data.TreeData in project knime-core by KNIME.
Class SubsetColumnSampleTest, method testIterator:
/**
 * Verifies that iterating a {@link SubsetColumnSample} yields exactly the selected columns, in the order given by
 * the column index array, and no more or fewer columns than were selected.
 */
@Test
public void testIterator() throws Exception {
    final TreeData data = createTreeData();
    final int[] colIndices = new int[] { 1, 3, 5 };
    final SubsetColumnSample sample = new SubsetColumnSample(data, colIndices);
    int i = 0;
    for (final TreeAttributeColumnData col : sample) {
        // an over-long iteration fails here with an index exception on colIndices
        assertEquals("Wrong column returned.", data.getColumns()[colIndices[i++]], col);
    }
    // without this check an iterator that stops early (or returns nothing at all) would pass silently
    assertEquals("Wrong number of columns returned.", colIndices.length, i);
}
Usage of org.knime.base.node.mine.treeensemble2.data.TreeData in project knime-core by KNIME.
Class Surrogates, method learnSurrogates:
/**
 * Searches the remaining columns of <b>colSample</b> for surrogate splits of <b>bestSplit</b>.
 * <p>
 * The direction (left or right) into which <b>bestSplit</b> sends each row becomes the new target. Every other
 * column of the sample then computes its best classification split on that target. Only rows that
 * <b>bestSplit</b> assigns to one of its first two children take part in this search. Rows missed by the best
 * split are excluded.
 *
 * @param dataMemberships the rows that reach the current branch
 * @param bestSplit the best split for the current node; its first two child conditions define left and right
 * @param oldData the TreeData object that contains all attributes and the target
 * @param colSample the columns to be considered as surrogates (the column of <b>bestSplit</b> itself is skipped)
 * @param config the configuration, used when computing the distribution of the new target
 * @param rd source of randomness handed to each column's split search
 * @return a SurrogateSplit that contains the conditions for both children, computed from <b>bestSplit</b>
 *         followed by all non-null surrogate candidates
 */
public static SurrogateSplit learnSurrogates(final DataMemberships dataMemberships, final SplitCandidate bestSplit,
    final TreeData oldData, final ColumnSample colSample, final TreeEnsembleLearnerConfiguration config,
    final RandomData rd) {
    final TreeAttributeColumnData primaryColumn = bestSplit.getColumnData();
    final TreeNodeCondition[] primaryConditions = bestSplit.getChildConditions();

    // rows sent left / right by the primary split form the new (two-class) target
    final BitSet goesLeft = primaryColumn.updateChildMemberships(primaryConditions[0], dataMemberships);
    final BitSet goesRight = primaryColumn.updateChildMemberships(primaryConditions[1], dataMemberships);

    // restrict the surrogate search to the rows that the primary split does not miss
    final BitSet coveredRows = (BitSet) goesLeft.clone();
    coveredRows.or(goesRight);
    final DataMemberships coveredMemberships = dataMemberships.createChildMemberships(coveredRows);
    final TreeTargetNominalColumnData directionTarget =
        createNewTargetColumn(goesLeft, goesRight, oldData.getNrRows(), coveredMemberships);
    final ClassificationPriors directionPriors = directionTarget.getDistribution(coveredMemberships, config);

    // the primary split stays at index 0, followed by the best split of every other sampled column
    final ArrayList<SplitCandidate> ranked = new ArrayList<SplitCandidate>();
    ranked.add(bestSplit);
    for (final TreeAttributeColumnData column : colSample) {
        if (column == primaryColumn) {
            continue;
        }
        final SplitCandidate surrogate =
            column.calcBestSplitClassification(coveredMemberships, directionPriors, directionTarget, rd);
        if (surrogate != null) {
            ranked.add(surrogate);
        }
    }
    return calculateSurrogates(dataMemberships, ranked.toArray(new SplitCandidate[ranked.size()]));
}
Usage of org.knime.base.node.mine.treeensemble2.data.TreeData in project knime-core by KNIME.
Class TreeLearnerClassification, method buildTreeNode:
/**
 * Recursively builds the subtree rooted at the node identified by {@code treeNodeSignature}.
 * <p>
 * If no split is found for the rows of this node, a leaf node is returned. Otherwise the children are built
 * recursively:
 * <ul>
 * <li>with {@link MissingValueHandling#Surrogate}, the best split is combined with surrogate splits learned by
 * {@code Surrogates.learnSurrogates}, and exactly two children are created;</li>
 * <li>otherwise one child is created per child condition of the best split. If that split's column cannot be split
 * further, its bit in {@code forbiddenColumnSet} is set while the subtree is built and cleared afterwards.</li>
 * </ul>
 *
 * @param exec monitor used to check for cancellation
 * @param currentDepth depth of the node being built (children are built with {@code currentDepth + 1})
 * @param dataMemberships the rows that reach this node
 * @param columnSample the columns considered when searching a split for this node
 * @param treeNodeSignature signature of this node; child signatures are derived from it
 * @param targetPriors target distribution of the rows in this node
 * @param forbiddenColumnSet bit set indexed by attribute index and passed to the split search; it may be modified
 *            in place during the recursion (presumably marks columns the search must skip — confirm)
 * @return this node including all of its descendants
 * @throws CanceledExecutionException if execution was canceled
 * @throws IllegalStateException if the best split has more than {@link Short#MAX_VALUE} children
 */
private TreeNodeClassification buildTreeNode(final ExecutionMonitor exec, final int currentDepth,
    final DataMemberships dataMemberships, final ColumnSample columnSample, final TreeNodeSignature treeNodeSignature,
    final ClassificationPriors targetPriors, final BitSet forbiddenColumnSet) throws CanceledExecutionException {
    final TreeData data = getData();
    final TreeEnsembleLearnerConfiguration config = getConfig();
    exec.checkCanceled();
    final boolean useSurrogates = config.getMissingValueHandling() == MissingValueHandling.Surrogate;
    final TreeTargetNominalColumnData targetColumn = (TreeTargetNominalColumnData) data.getTargetColumn();
    final TreeNodeClassification[] childNodes;
    // only the non-surrogate branch may mark a column; remember what has to be undone after the recursion
    boolean markAttributeAsForbidden = false;
    int attributeIndex = -1;
    if (useSurrogates) {
        final SplitCandidate[] candidates = findBestSplitsClassification(currentDepth, dataMemberships,
            columnSample, treeNodeSignature, targetPriors, forbiddenColumnSet);
        if (candidates == null) {
            // no split found: this node becomes a leaf
            return new TreeNodeClassification(treeNodeSignature, targetPriors, config);
        }
        final SurrogateSplit surrogateSplit = Surrogates.learnSurrogates(dataMemberships, candidates[0], data,
            columnSample, config, getRandomData());
        final TreeNodeCondition[] childConditions = surrogateSplit.getChildConditions();
        final BitSet[] childMarkers = surrogateSplit.getChildMarkers();
        childNodes = new TreeNodeClassification[2];
        for (int i = 0; i < 2; i++) {
            final DataMemberships childMemberships = dataMemberships.createChildMemberships(childMarkers[i]);
            final ClassificationPriors childTargetPriors = targetColumn.getDistribution(childMemberships, config);
            final TreeNodeSignature childSignature =
                getSignatureFactory().getChildSignatureFor(treeNodeSignature, (byte) i);
            final ColumnSample childColumnSample =
                getColSamplingStrategy().getColumnSampleForTreeNode(childSignature);
            childNodes[i] = buildTreeNode(exec, currentDepth + 1, childMemberships, childColumnSample,
                childSignature, childTargetPriors, forbiddenColumnSet);
            childNodes[i].setTreeNodeCondition(childConditions[i]);
        }
    } else {
        // handle non surrogate case
        final SplitCandidate bestSplit = findBestSplitClassification(currentDepth, dataMemberships, columnSample,
            treeNodeSignature, targetPriors, forbiddenColumnSet);
        if (bestSplit == null) {
            // no split found: this node becomes a leaf
            return new TreeNodeClassification(treeNodeSignature, targetPriors, config);
        }
        final TreeAttributeColumnData splitColumn = bestSplit.getColumnData();
        attributeIndex = splitColumn.getMetaData().getAttributeIndex();
        // set the column's bit for the subtree if it cannot be split any further (cleared again below)
        markAttributeAsForbidden = !bestSplit.canColumnBeSplitFurther();
        forbiddenColumnSet.set(attributeIndex, markAttributeAsForbidden);
        final TreeNodeCondition[] childConditions = bestSplit.getChildConditions();
        // validate before allocating the child array
        if (childConditions.length > Short.MAX_VALUE) {
            throw new IllegalStateException("Too many children when splitting " + "attribute "
                + bestSplit.getColumnData() + " (maximum supported: " + Short.MAX_VALUE + "): "
                + childConditions.length);
        }
        childNodes = new TreeNodeClassification[childConditions.length];
        // Build child nodes
        for (int i = 0; i < childConditions.length; i++) {
            final TreeNodeCondition cond = childConditions[i];
            final DataMemberships childMemberships =
                dataMemberships.createChildMemberships(splitColumn.updateChildMemberships(cond, dataMemberships));
            final ClassificationPriors childTargetPriors = targetColumn.getDistribution(childMemberships, config);
            // NOTE(review): (byte) i wraps for i > 127 although up to Short.MAX_VALUE children pass the check
            // above — confirm how createChildSignature handles large child indices
            final TreeNodeSignature childSignature = treeNodeSignature.createChildSignature((byte) i);
            final ColumnSample childColumnSample =
                getColSamplingStrategy().getColumnSampleForTreeNode(childSignature);
            childNodes[i] = buildTreeNode(exec, currentDepth + 1, childMemberships, childColumnSample,
                childSignature, childTargetPriors, forbiddenColumnSet);
            childNodes[i].setTreeNodeCondition(cond);
        }
    }
    if (markAttributeAsForbidden) {
        // restore the caller's view of the forbidden columns
        forbiddenColumnSet.set(attributeIndex, false);
    }
    return new TreeNodeClassification(treeNodeSignature, targetPriors, childNodes, config);
}
Usage of org.knime.base.node.mine.treeensemble2.data.TreeData in project knime-core by KNIME.
Class AbstractGradientBoostingLearner, method createPredictorRecord:
/**
 * Builds a PredictorRecord for a single row of the in-memory TreeData object.
 * <p>
 * For each attribute column, the row's value is read at the position that <b>indexManager</b> reports for that
 * column. The value is passed through {@code handleMissingValues} and stored under the column's attribute name.
 *
 * @param data the in-memory data whose attribute columns are read
 * @param indexManager supplies, per attribute index, the positions of the rows inside the column
 * @param rowIdx index of the row to convert
 * @return a PredictorRecord for the row at <b>rowIdx</b> in <b>data</b>
 */
public static PredictorRecord createPredictorRecord(final TreeData data, final IDataIndexManager indexManager,
    final int rowIdx) {
    final Map<String, Object> valuesByName = new HashMap<String, Object>();
    final TreeAttributeColumnData[] columns = data.getColumns();
    for (int c = 0; c < columns.length; c++) {
        final TreeAttributeColumnData column = columns[c];
        final TreeAttributeColumnMetaData meta = column.getMetaData();
        final String attributeName = meta.getAttributeName();
        final int attributeIndex = meta.getAttributeIndex();
        valuesByName.put(attributeName,
            handleMissingValues(column.getValueAt(indexManager.getPositionsInColumn(attributeIndex)[rowIdx]), column));
    }
    return new PredictorRecord(valuesByName);
}
Usage of org.knime.base.node.mine.treeensemble2.data.TreeData in project knime-core by KNIME.
Class RandomForestClassificationLearnerNodeModel, method execute:
/**
 * {@inheritDoc}
 * <p>
 * Reads the learning columns of the first input table into memory and learns the tree ensemble with XGBoost-style
 * missing value handling. It then predicts the out-of-bag rows of the input table. Warnings from column filtering
 * and from data reading are joined with a newline and set as the node's warning message.
 */
@Override
protected PortObject[] execute(final PortObject[] inObjects, final ExecutionContext exec) throws Exception {
    final BufferedDataTable inTable = (BufferedDataTable) inObjects[0];
    final DataTableSpec inSpec = inTable.getDataTableSpec();

    // keep only the columns that take part in learning
    final FilterLearnColumnRearranger learnRearranger = m_configuration.filterLearnColumns(inSpec);
    String warning = learnRearranger.getWarning();
    final BufferedDataTable learnTable =
        exec.createColumnRearrangeTable(inTable, learnRearranger, exec.createSubProgress(0.0));
    final DataTableSpec learnSpec = learnTable.getDataTableSpec();
    final TreeEnsembleModelPortObjectSpec ensembleSpec = m_configuration.createPortObjectSpec(learnSpec);
    final Map<String, DataCell> possibleTargetValues = ensembleSpec.getTargetColumnPossibleValueMap();
    if (possibleTargetValues == null) {
        throw new InvalidSettingsException("The target column does not " + "have possible values assigned. Most likely it "
            + "has too many different distinct values (learning an ID " + "column?) Fix it by preprocessing the table using "
            + "a \"Domain Calculator\".");
    }

    // progress weights: 10% reading, 80% learning, 10% out-of-bag prediction
    final ExecutionMonitor readInExec = exec.createSubProgress(0.1);
    final ExecutionMonitor learnExec = exec.createSubProgress(0.8);
    final ExecutionMonitor outOfBagExec = exec.createSubProgress(0.1);

    final TreeDataCreator dataCreator = new TreeDataCreator(m_configuration, learnSpec, learnTable.getRowCount());
    exec.setProgress("Reading data into memory");
    final TreeData data = dataCreator.readData(learnTable, m_configuration, readInExec);
    m_hiliteRowSample = dataCreator.getDataRowsForHilite();
    m_viewMessage = dataCreator.getViewMessage();
    final String dataCreationWarning = dataCreator.getAndClearWarningMessage();
    if (dataCreationWarning != null) {
        warning = (warning == null) ? dataCreationWarning : warning + "\n" + dataCreationWarning;
    }
    readInExec.setProgress(1.0);

    exec.setMessage("Learning trees");
    // Use xgboost missing value handling
    m_configuration.setMissingValueHandling(MissingValueHandling.XGBoost);
    final TreeEnsembleLearner learner = new TreeEnsembleLearner(m_configuration, data);
    final TreeEnsembleModel model;
    try {
        model = learner.learnEnsemble(learnExec);
    } catch (ExecutionException e) {
        // rethrow the wrapped Exception itself; anything else (e.g. an Error cause) keeps the wrapper
        final Throwable cause = e.getCause();
        if (!(cause instanceof Exception)) {
            throw e;
        }
        throw (Exception) cause;
    }
    final TreeEnsembleModelPortObject modelPortObject =
        TreeEnsembleModelPortObject.createPortObject(ensembleSpec, model, exec.createFileStore("TreeEnsemble"));
    learnExec.setProgress(1.0);

    exec.setMessage("Out of bag prediction");
    final TreeEnsemblePredictor outOfBagPredictor = createOutOfBagPredictor(ensembleSpec, modelPortObject, inSpec);
    outOfBagPredictor.setOutofBagFilter(learner.getRowSamples(), data.getTargetColumn());
    final ColumnRearranger outOfBagRearranger = outOfBagPredictor.getPredictionRearranger();
    final BufferedDataTable outOfBagTable = exec.createColumnRearrangeTable(inTable, outOfBagRearranger, outOfBagExec);
    final BufferedDataTable colStatsTable = learner.createColumnStatisticTable(exec.createSubExecutionContext(0.0));

    m_ensembleModelPortObject = modelPortObject;
    if (warning != null) {
        setWarningMessage(warning);
    }
    return new PortObject[] { outOfBagTable, colStatsTable, modelPortObject };
}
Aggregations