Usage of org.knime.base.node.mine.treeensemble2.model.TreeNodeClassification in project knime-core by KNIME.
Class TreeModelPMMLTranslator, method addTreeNode.
/**
 * Fills {@code pmmlNode} from {@code node} and recursively appends one PMML child node per child of {@code node}.
 * <p>
 * The PMML node receives the next id taken from {@code m_nodeIndex}. For a {@link TreeNodeClassification} the
 * score is the majority class name, the record count is the sum of the target distribution and one score
 * distribution entry (class name and count) is added per class. For a {@link TreeNodeRegression} the score is the
 * mean. The node's condition is translated into the corresponding PMML predicate.
 *
 * @param pmmlNode the PMML node to populate
 * @param node the tree node to translate
 * @throws IllegalStateException if the node's condition is missing or of an unsupported type
 */
private void addTreeNode(final Node pmmlNode, final AbstractTreeNode node) {
    final int index = m_nodeIndex++;
    pmmlNode.setId(Integer.toString(index));
    if (node instanceof TreeNodeClassification) {
        final TreeNodeClassification clazNode = (TreeNodeClassification) node;
        pmmlNode.setScore(clazNode.getMajorityClassName());
        final float[] targetDistribution = clazNode.getTargetDistribution();
        final NominalValueRepresentation[] targetVals = clazNode.getTargetMetaData().getValues();
        double sum = 0.0;
        // primitive loop variable: no need to box every entry into a Float
        for (final float v : targetDistribution) {
            sum += v;
        }
        pmmlNode.setRecordCount(sum);
        // adding score distribution (class counts)
        for (int i = 0; i < targetDistribution.length; i++) {
            final ScoreDistribution pmmlScoreDist = pmmlNode.addNewScoreDistribution();
            pmmlScoreDist.setValue(targetVals[i].getNominalValue());
            pmmlScoreDist.setRecordCount(targetDistribution[i]);
        }
    } else if (node instanceof TreeNodeRegression) {
        final TreeNodeRegression regNode = (TreeNodeRegression) node;
        pmmlNode.setScore(Double.toString(regNode.getMean()));
    }
    final TreeNodeCondition condition = node.getCondition();
    if (condition instanceof TreeNodeTrueCondition) {
        pmmlNode.addNewTrue();
    } else if (condition instanceof TreeNodeColumnCondition) {
        final TreeNodeColumnCondition colCondition = (TreeNodeColumnCondition) condition;
        handleColumnCondition(colCondition, pmmlNode);
    } else if (condition instanceof AbstractTreeNodeSurrogateCondition) {
        final AbstractTreeNodeSurrogateCondition surrogateCond = (AbstractTreeNodeSurrogateCondition) condition;
        setValuesFromPMMLCompoundPredicate(pmmlNode.addNewCompoundPredicate(), surrogateCond.toPMMLPredicate());
    } else {
        // a missing condition would otherwise surface as an NPE from condition.getClass()
        final String conditionType = condition == null ? "null" : condition.getClass().getSimpleName();
        throw new IllegalStateException("Unsupported condition (not implemented): " + conditionType);
    }
    for (int i = 0; i < node.getNrChildren(); i++) {
        addTreeNode(pmmlNode.addNewNode(), node.getChild(i));
    }
}
Usage of org.knime.base.node.mine.treeensemble2.model.TreeNodeClassification in project knime-core by KNIME.
Class TreeLearnerClassification, method learnSingleTreeRecursive.
/**
 * Learns a single classification tree by recursively building it from the root node.
 * <p>
 * The root covers the rows of the current row sample. Its class priors are computed from the nominal target
 * column, and its column sample comes from the column sampling strategy for the root signature.
 *
 * @param exec monitor used for cancellation checks while the tree is built
 * @param rd random data generator; NOTE(review): not used in this method body
 * @return the learned tree; its root carries {@link TreeNodeTrueCondition#INSTANCE}
 * @throws CanceledExecutionException if execution is canceled while building the tree
 */
private TreeModelClassification learnSingleTreeRecursive(final ExecutionMonitor exec, final RandomData rd)
    throws CanceledExecutionException {
    final TreeData data = getData();
    final RowSample rowSampling = getRowSampling();
    final TreeEnsembleLearnerConfiguration config = getConfig();
    final TreeTargetNominalColumnData targetColumn = (TreeTargetNominalColumnData) data.getTargetColumn();
    final DataMemberships rootDataMemberships = new RootDataMemberships(rowSampling, data, getIndexManager());
    final ClassificationPriors targetPriors = targetColumn.getDistribution(rootDataMemberships, config);
    // buildTreeNode marks non-splittable columns in this set while descending and clears them on the way back
    final BitSet forbiddenColumnSet = new BitSet(data.getNrAttributes());
    final TreeNodeSignature rootSignature = TreeNodeSignature.ROOT_SIGNATURE;
    final ColumnSample rootColumnSample = getColSamplingStrategy().getColumnSampleForTreeNode(rootSignature);
    final TreeNodeClassification rootNode = buildTreeNode(exec, 0, rootDataMemberships, rootColumnSample,
        rootSignature, targetPriors, forbiddenColumnSet);
    // every column marked during recursion must have been released again
    assert forbiddenColumnSet.cardinality() == 0;
    rootNode.setTreeNodeCondition(TreeNodeTrueCondition.INSTANCE);
    return new TreeModelClassification(rootNode);
}
Usage of org.knime.base.node.mine.treeensemble2.model.TreeNodeClassification in project knime-core by KNIME.
Class TreeNodeClassification, method createDecisionTreeNode.
/**
 * Converts this node, and recursively its whole subtree, into the {@link DecisionTreeNode} representation used by
 * the KNIME Decision Tree model.
 * <p>
 * Every resulting node carries the majority class and a class-name to class-count map that keeps the order of the
 * target values.
 *
 * @param idGenerator source of node ids; this node takes its id before any of its children do
 * @param metaData tree meta data used to look up the name of the split attribute
 * @return a {@link DecisionTreeNodeLeaf} if this node has no children, otherwise a
 *         {@link DecisionTreeNodeSplitPMML} holding the converted children and their PMML predicates
 */
public DecisionTreeNode createDecisionTreeNode(final MutableInteger idGenerator, final TreeMetaData metaData) {
    final DataCell majorityCell = new StringCell(getMajorityClassName());
    final float[] classCounts = getTargetDistribution();
    final NominalValueRepresentation[] classValues = getTargetMetaData().getValues();
    // capacity chosen so that the map does not need to grow at the default load factor of 0.75
    final int capacity = (int) (classCounts.length / 0.75 + 1.0);
    final LinkedHashMap<DataCell, Double> countsByClass = new LinkedHashMap<DataCell, Double>(capacity);
    int k = 0;
    for (final float count : classCounts) {
        countsByClass.put(new StringCell(classValues[k++].getNominalValue()), (double) count);
    }

    final int childCount = getNrChildren();
    if (childCount == 0) {
        return new DecisionTreeNodeLeaf(idGenerator.inc(), majorityCell, countsByClass);
    }

    // take this node's id before recursing into the children
    final int nodeId = idGenerator.inc();
    final int splitIndex = getSplitAttributeIndex();
    assert splitIndex >= 0 : "non-leaf node has no split";
    final String splitAttributeName = metaData.getAttributeMetaData(splitIndex).getAttributeName();
    final PMMLPredicate[] predicates = new PMMLPredicate[childCount];
    final DecisionTreeNode[] children = new DecisionTreeNode[childCount];
    for (int c = 0; c < childCount; c++) {
        final TreeNodeClassification child = getChild(c);
        predicates[c] = child.getCondition().toPMMLPredicate();
        children[c] = child.createDecisionTreeNode(idGenerator, metaData);
    }
    return new DecisionTreeNodeSplitPMML(nodeId, majorityCell, countsByClass, splitAttributeName, predicates,
        children);
}
Usage of org.knime.base.node.mine.treeensemble2.model.TreeNodeClassification in project knime-core by KNIME.
Class TreeLearnerClassification, method buildTreeNode.
/**
 * Recursively builds the subtree for the node identified by {@code treeNodeSignature}.
 * <p>
 * If the split search finds no split, a leaf carrying {@code targetPriors} is returned. Otherwise the rows in
 * {@code dataMemberships} are partitioned by the chosen split, and a child node is built for each partition. A
 * surrogate split is learned when missing value handling is {@link MissingValueHandling#Surrogate}; otherwise the
 * best single split is used.
 *
 * @param exec monitor checked for cancellation on every call
 * @param currentDepth depth of the node being built (the root is built with 0)
 * @param dataMemberships rows that reach this node
 * @param columnSample columns offered to the split search for this node
 * @param treeNodeSignature signature of the node being built
 * @param targetPriors class distribution of the rows reaching this node
 * @param forbiddenColumnSet columns passed to the split search as forbidden. In the non-surrogate case, a split
 *            column that cannot be split further is marked here while its subtree is built and cleared again
 *            before this method returns.
 * @return the built node (a leaf or an inner node with children)
 * @throws CanceledExecutionException if execution was canceled
 * @throws IllegalStateException if a split produces more than {@link Short#MAX_VALUE} children
 */
private TreeNodeClassification buildTreeNode(final ExecutionMonitor exec, final int currentDepth,
    final DataMemberships dataMemberships, final ColumnSample columnSample,
    final TreeNodeSignature treeNodeSignature, final ClassificationPriors targetPriors,
    final BitSet forbiddenColumnSet) throws CanceledExecutionException {
    final TreeData data = getData();
    final TreeEnsembleLearnerConfiguration config = getConfig();
    exec.checkCanceled();
    final boolean useSurrogates = config.getMissingValueHandling() == MissingValueHandling.Surrogate;
    final TreeTargetNominalColumnData targetColumn = (TreeTargetNominalColumnData) data.getTargetColumn();
    final TreeNodeClassification[] childNodes;
    boolean markAttributeAsForbidden = false;
    int attributeIndex = -1;
    if (useSurrogates) {
        final SplitCandidate[] candidates = findBestSplitsClassification(currentDepth, dataMemberships,
            columnSample, treeNodeSignature, targetPriors, forbiddenColumnSet);
        if (candidates == null) {
            return new TreeNodeClassification(treeNodeSignature, targetPriors, config);
        }
        final SurrogateSplit surrogateSplit =
            Surrogates.learnSurrogates(dataMemberships, candidates[0], data, columnSample, config, getRandomData());
        final TreeNodeCondition[] childConditions = surrogateSplit.getChildConditions();
        final BitSet[] childMarkers = surrogateSplit.getChildMarkers();
        // this branch assumes a binary surrogate split (two child markers / conditions)
        childNodes = new TreeNodeClassification[2];
        for (int i = 0; i < 2; i++) {
            final DataMemberships childMemberships = dataMemberships.createChildMemberships(childMarkers[i]);
            final ClassificationPriors childTargetPriors = targetColumn.getDistribution(childMemberships, config);
            final TreeNodeSignature childSignature =
                getSignatureFactory().getChildSignatureFor(treeNodeSignature, (byte) i);
            final ColumnSample childColumnSample =
                getColSamplingStrategy().getColumnSampleForTreeNode(childSignature);
            childNodes[i] = buildTreeNode(exec, currentDepth + 1, childMemberships, childColumnSample,
                childSignature, childTargetPriors, forbiddenColumnSet);
            childNodes[i].setTreeNodeCondition(childConditions[i]);
        }
    } else {
        // handle non surrogate case
        final SplitCandidate bestSplit = findBestSplitClassification(currentDepth, dataMemberships, columnSample,
            treeNodeSignature, targetPriors, forbiddenColumnSet);
        if (bestSplit == null) {
            return new TreeNodeClassification(treeNodeSignature, targetPriors, config);
        }
        final TreeAttributeColumnData splitColumn = bestSplit.getColumnData();
        final TreeNodeCondition[] childConditions = bestSplit.getChildConditions();
        // validate before touching forbiddenColumnSet so that a failure leaves the shared set unchanged
        if (childConditions.length > Short.MAX_VALUE) {
            throw new IllegalStateException("Too many children when splitting attribute " + splitColumn
                + " (maximum supported: " + Short.MAX_VALUE + "): " + childConditions.length);
        }
        attributeIndex = splitColumn.getMetaData().getAttributeIndex();
        markAttributeAsForbidden = !bestSplit.canColumnBeSplitFurther();
        forbiddenColumnSet.set(attributeIndex, markAttributeAsForbidden);
        childNodes = new TreeNodeClassification[childConditions.length];
        // Build child nodes
        for (int i = 0; i < childConditions.length; i++) {
            final TreeNodeCondition cond = childConditions[i];
            final DataMemberships childMemberships =
                dataMemberships.createChildMemberships(splitColumn.updateChildMemberships(cond, dataMemberships));
            final ClassificationPriors childTargetPriors = targetColumn.getDistribution(childMemberships, config);
            // NOTE(review): (byte) i wraps for i > 127 although up to Short.MAX_VALUE children are accepted
            // above, and unlike the surrogate branch this does not use getSignatureFactory(). Confirm whether
            // child signatures must be unique.
            final TreeNodeSignature childSignature = treeNodeSignature.createChildSignature((byte) i);
            final ColumnSample childColumnSample =
                getColSamplingStrategy().getColumnSampleForTreeNode(childSignature);
            childNodes[i] = buildTreeNode(exec, currentDepth + 1, childMemberships, childColumnSample,
                childSignature, childTargetPriors, forbiddenColumnSet);
            childNodes[i].setTreeNodeCondition(cond);
        }
    }
    if (markAttributeAsForbidden) {
        // release the column for sibling subtrees
        forbiddenColumnSet.set(attributeIndex, false);
    }
    return new TreeNodeClassification(treeNodeSignature, targetPriors, childNodes, config);
}
Usage of org.knime.base.node.mine.treeensemble2.model.TreeNodeClassification in project knime-core by KNIME.
Class RandomForestClassificationTreeNodeWidget, method createTablePanel.
/**
 * Creates the class table shown inside a tree node widget.
 * <p>
 * The table has a header row ("Category", "%", "n"), then one row per class with its relative frequency and
 * count. The majority class row is highlighted. A final "Total" row shows the node's coverage and its total
 * count. Coverage is relative to the root node's total, or to the node itself if no root is available.
 *
 * @param scale zoom factor applied to the labels and the cell insets
 * @return the populated panel
 */
private JPanel createTablePanel(final float scale) {
    final TreeNodeClassification node = (TreeNodeClassification) getUserObject();
    final float[] targetDistribution = node.getTargetDistribution();
    final double totalClassCount = sumClassCounts(targetDistribution);

    final JPanel p = new JPanel(new GridBagLayout());
    final GridBagConstraints c = new GridBagConstraints();
    final int gridwidth = 3;
    c.fill = GridBagConstraints.HORIZONTAL;
    c.anchor = GridBagConstraints.NORTHWEST;
    final int bw = Math.round(1 * scale);
    c.insets = new Insets(bw, bw, bw, bw);
    c.gridx = 0;
    c.gridy = 0;
    c.weightx = 1;
    c.weighty = 1;
    c.gridwidth = 1;

    // header row
    p.add(scaledLabel("Category", scale), c);
    c.gridx++;
    p.add(scaledLabel("% ", scale, SwingConstants.RIGHT), c);
    c.gridx++;
    p.add(scaledLabel("n ", scale, SwingConstants.RIGHT), c);
    c.gridy++;
    c.gridx = 0;
    c.gridwidth = GridBagConstraints.REMAINDER;
    p.add(new MyJSeparator(), c);
    c.gridwidth = 1;

    // one row per class
    final int majorityClassIndex = node.getMajorityClassIndex();
    final NominalValueRepresentation[] classNomVals = node.getTargetMetaData().getValues();
    for (int i = 0; i < targetDistribution.length; i++) {
        final JLabel classLabel = scaledLabel(classNomVals[i].getNominalValue(), scale);
        c.gridy++;
        c.gridx = 0;
        p.add(classLabel, c);
        c.gridx++;
        final double classFreq = targetDistribution[i] / totalClassCount;
        p.add(scaledLabel(convertPercentage(classFreq), scale, SwingConstants.RIGHT), c);
        c.gridx++;
        // NOTE(review): kept boxed as in the original; convertCount's parameter type is not visible here
        final Float classCountValue = targetDistribution[i];
        p.add(scaledLabel(convertCount(classCountValue), scale, SwingConstants.RIGHT), c);
        if (i == majorityClassIndex) {
            // grey panel spanning the whole row, added into the same cells to highlight the majority class
            c.gridx = 0;
            final JComponent comp = new JPanel();
            comp.setMinimumSize(classLabel.getPreferredSize());
            comp.setPreferredSize(classLabel.getPreferredSize());
            comp.setBackground(new Color(225, 225, 225));
            c.gridwidth = gridwidth;
            p.add(comp, c);
            c.gridwidth = 1;
        }
    }

    // total row
    c.gridy++;
    c.gridx = 0;
    c.gridwidth = gridwidth;
    p.add(new MyJSeparator(), c);
    c.gridwidth = 1;
    c.gridy++;
    c.gridx = 0;
    p.add(scaledLabel("Total", scale), c);
    c.gridx++;
    final TreeNodeClassification root = (TreeNodeClassification) getGraphView().getRootNode();
    final double denominator = root != null ? sumClassCounts(root.getTargetDistribution()) : totalClassCount;
    final double coverage = totalClassCount / denominator;
    p.add(scaledLabel(convertPercentage(coverage), scale, SwingConstants.RIGHT), c);
    c.gridx++;
    p.add(scaledLabel(convertCount(totalClassCount), scale, SwingConstants.RIGHT), c);
    return p;
}

/**
 * Sums the per-class counts of a target distribution.
 *
 * @param classCounts per-class counts
 * @return the sum, accumulated in double precision in array order
 */
private static double sumClassCounts(final float[] classCounts) {
    double total = 0.0;
    for (final float classCount : classCounts) {
        total += classCount;
    }
    return total;
}
Aggregations