use of org.knime.base.node.mine.treeensemble2.model.AbstractTreeNode in project knime-core by knime.
the class TreeEnsembleLearnerNodeView2 method getSubtree.
// private void recreateHiLite() {
// HiLiteHandler handler = m_hiLiteHdl;
// if (handler == null) {
// return;
// }
// Set<RowKey> hilited = handler.getHiLitKeys();
// Set<AbstractTreeNode> toHilite = new HashSet<AbstractTreeNode>();
// AbstractTreeNode root = m_graph.getRootNode();
//
// List<AbstractTreeNode> toProcess = new LinkedList<AbstractTreeNode>();
// if (null != root) {
// toProcess.add(0, root);
// }
// // Traverse the tree breadth first
// while (!toProcess.isEmpty()) {
// AbstractTreeNode curr = toProcess.remove(0);
// // bug 2695: if not all patterns are selected for hiliting, the
// // view would automatically hilite all branches that do not
// // cover any pattern
// if (curr.coveredPattern().isEmpty()) {
// continue;
// }
// if (hilited.containsAll(curr.coveredPattern())) {
// // hilite subtree starting from curr
// toHilite.addAll(getSubtree(curr));
// } else {
// for (int i = 0; i < curr.getChildCount(); i++) {
// toProcess.add(0, curr.getChildAt(i));
// }
// }
// }
// m_graph.hiLite(toHilite);
// }
/**
 * Collects the given node together with all of its descendants.
 *
 * @param node the root of the subtree to collect
 * @return a list that starts with {@code node} and contains every node reachable through its children
 */
private List<AbstractTreeNode> getSubtree(final AbstractTreeNode node) {
    final List<AbstractTreeNode> collected = new ArrayList<AbstractTreeNode>();
    // The pending list is used as a stack (push/pop both work on the head),
    // so the tree is walked depth first; the last child pushed is visited first.
    final LinkedList<AbstractTreeNode> pending = new LinkedList<AbstractTreeNode>();
    pending.push(node);
    while (!pending.isEmpty()) {
        final AbstractTreeNode current = pending.pop();
        collected.add(current);
        for (int childIdx = 0; childIdx < current.getNrChildren(); childIdx++) {
            pending.push(current.getChild(childIdx));
        }
    }
    return collected;
}
use of org.knime.base.node.mine.treeensemble2.model.AbstractTreeNode in project knime-core by knime.
the class AbstractTreeModelExporter method recursivelyCheckUsedFields.
/**
 * Visits {@code node} and, recursively, its children, handing each node's condition to
 * {@code addAllFieldsInCondition} together with the set of used learning fields.
 * A node is skipped entirely once the set already holds {@code numLearnCols} entries.
 *
 * @param node the tree node to inspect
 * @param numLearnCols the set size at which no further nodes need to be inspected
 * @param usedLearningFields the set that accumulates the fields found so far
 */
private void recursivelyCheckUsedFields(final AbstractTreeNode node, final int numLearnCols, final Set<String> usedLearningFields) {
    if (usedLearningFields.size() != numLearnCols) {
        addAllFieldsInCondition(node.getCondition(), usedLearningFields);
        node.getChildren().forEach(
            child -> recursivelyCheckUsedFields(child, numLearnCols, usedLearningFields));
    }
}
use of org.knime.base.node.mine.treeensemble2.model.AbstractTreeNode in project knime-core by knime.
the class RandomForestDistance method computeDistance.
/**
 * {@inheritDoc}
 */
@Override
public double computeDistance(final DataRow row1, final DataRow row2) throws DistanceMeasurementException {
    // Unbox the configured column indices into a primitive array for the row filters.
    final int[] columnIdx = getColumnIndices().stream().mapToInt(Integer::intValue).toArray();

    final PredictorRecord firstRecord =
        m_ensembleModel.createPredictorRecord(new FilterColumnRow(row1, columnIdx), m_learnTableSpec);
    final PredictorRecord secondRecord =
        m_ensembleModel.createPredictorRecord(new FilterColumnRow(row2, columnIdx), m_learnTableSpec);

    // Count the trees in which both records are matched to nodes with equal signatures.
    final int modelCount = m_ensembleModel.getNrModels();
    int sharedLeaves = 0;
    for (int treeIdx = 0; treeIdx < modelCount; treeIdx++) {
        final AbstractTreeModel<?> treeModel = m_ensembleModel.getTreeModel(treeIdx);
        final AbstractTreeNode firstLeaf = treeModel.findMatchingNode(firstRecord);
        final AbstractTreeNode secondLeaf = treeModel.findMatchingNode(secondRecord);
        if (firstLeaf.getSignature().equals(secondLeaf.getSignature())) {
            sharedLeaves++;
        }
    }

    // The proximity is the fraction of trees that agree; the distance is its complement.
    final double proximity = (double) sharedLeaves / modelCount;
    return 1 - proximity;
}
use of org.knime.base.node.mine.treeensemble2.model.AbstractTreeNode in project knime-core by knime.
the class TreeModelPMMLTranslator method addTreeNode.
/**
 * Fills the given PMML node from the given tree node and then recurses into the children,
 * creating one nested PMML node per child. Every visited node is assigned the next value of
 * the running node index as its id.
 *
 * @param pmmlNode the PMML node to populate
 * @param node the tree node providing score, condition and children
 * @throws IllegalStateException if the node's condition is of a type that is not handled here
 */
private void addTreeNode(final Node pmmlNode, final AbstractTreeNode node) {
    int index = m_nodeIndex++;
    pmmlNode.setId(Integer.toString(index));
    if (node instanceof TreeNodeClassification) {
        final TreeNodeClassification clazNode = (TreeNodeClassification) node;
        pmmlNode.setScore(clazNode.getMajorityClassName());
        float[] targetDistribution = clazNode.getTargetDistribution();
        NominalValueRepresentation[] targetVals = clazNode.getTargetMetaData().getValues();
        // The record count of the node is the sum over the target distribution.
        // Iterate with a primitive float to avoid boxing every array element.
        double sum = 0.0;
        for (float v : targetDistribution) {
            sum += v;
        }
        pmmlNode.setRecordCount(sum);
        // adding score distribution (class counts)
        for (int i = 0; i < targetDistribution.length; i++) {
            String className = targetVals[i].getNominalValue();
            double freq = targetDistribution[i];
            ScoreDistribution pmmlScoreDist = pmmlNode.addNewScoreDistribution();
            pmmlScoreDist.setValue(className);
            pmmlScoreDist.setRecordCount(freq);
        }
    } else if (node instanceof TreeNodeRegression) {
        final TreeNodeRegression regNode = (TreeNodeRegression) node;
        pmmlNode.setScore(Double.toString(regNode.getMean()));
    }
    // Translate the node's condition into the matching PMML predicate.
    TreeNodeCondition condition = node.getCondition();
    if (condition instanceof TreeNodeTrueCondition) {
        pmmlNode.addNewTrue();
    } else if (condition instanceof TreeNodeColumnCondition) {
        final TreeNodeColumnCondition colCondition = (TreeNodeColumnCondition) condition;
        handleColumnCondition(colCondition, pmmlNode);
    } else if (condition instanceof AbstractTreeNodeSurrogateCondition) {
        final AbstractTreeNodeSurrogateCondition surrogateCond = (AbstractTreeNodeSurrogateCondition) condition;
        setValuesFromPMMLCompoundPredicate(pmmlNode.addNewCompoundPredicate(), surrogateCond.toPMMLPredicate());
    } else {
        throw new IllegalStateException("Unsupported condition (not " + "implemented): " + condition.getClass().getSimpleName());
    }
    // Children are appended in index order, each as a freshly created nested PMML node.
    for (int i = 0; i < node.getNrChildren(); i++) {
        addTreeNode(pmmlNode.addNewNode(), node.getChild(i));
    }
}
use of org.knime.base.node.mine.treeensemble2.model.AbstractTreeNode in project knime-core by knime.
the class RandomForestClassificationTreeGraphView method mousePressed.
/**
 * {@inheritDoc}
 */
@Override
public void mousePressed(final MouseEvent e) {
    final AbstractTreeNode clicked = nodeAtPoint(e.getPoint());
    if (clicked == null) {
        // No node under the cursor: only the default handling applies.
        super.mousePressed(e);
        return;
    }

    // Compare the widget's preferred size before and after the default handling.
    final Dimension sizeBefore = getWidgets().get(clicked).getPreferredSize();
    super.mousePressed(e);
    final Dimension sizeAfter = getWidgets().get(clicked).getPreferredSize();
    if (sizeBefore.equals(sizeAfter)) {
        return;
    }

    // The clicked node changed its preferred size, so the graph must be laid out again.
    layoutGraph();
    getView().revalidate();
    getView().repaint();
    // Keep the clicked node inside the visible area.
    getView().scrollRectToVisible(getVisible().get(clicked));
}
Aggregations