use of org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition in project knime-core by knime.
the class RandomForestClassificationTreeNodeWidget method getConnectorLabelBelow.
/**
 * {@inheritDoc}
 *
 * <p>The label is the attribute name of the condition on this node's first child. If that condition is a
 * {@link TreeNodeSurrogateCondition}, the attribute name of its first condition is used instead. Returns
 * {@code null} when the node has no children or no column condition is found.</p>
 */
@Override
public String getConnectorLabelBelow() {
    final TreeNodeClassification treeNode = (TreeNodeClassification) getUserObject();
    if (treeNode.getNrChildren() == 0) {
        // leaf node: there is no split below it to label
        return null;
    }
    TreeNodeCondition splitCondition = treeNode.getChild(0).getCondition();
    if (!(splitCondition instanceof TreeNodeColumnCondition)
        && splitCondition instanceof TreeNodeSurrogateCondition) {
        // unwrap exactly one level: a surrogate condition is labeled by its first condition
        splitCondition = ((TreeNodeSurrogateCondition) splitCondition).getFirstCondition();
    }
    return splitCondition instanceof TreeNodeColumnCondition
        ? ((TreeNodeColumnCondition) splitCondition).getAttributeName() : null;
}
use of org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition in project knime-core by knime.
the class AbstractTreeModelExporter method recursivelyCheckUsedFields.
/**
 * Walks the subtree rooted at {@code node} depth-first and adds every field referenced by a node condition
 * to {@code usedLearningFields}. The walk stops descending once the set size equals {@code numLearnCols}.
 *
 * @param node root of the subtree to inspect
 * @param numLearnCols number of learning columns; reaching this many collected fields ends the search
 * @param usedLearningFields accumulator that is filled in place
 */
private void recursivelyCheckUsedFields(final AbstractTreeNode node, final int numLearnCols, final Set<String> usedLearningFields) {
    final boolean allFieldsFound = usedLearningFields.size() == numLearnCols;
    if (!allFieldsFound) {
        addAllFieldsInCondition(node.getCondition(), usedLearningFields);
        node.getChildren().forEach(child -> recursivelyCheckUsedFields(child, numLearnCols, usedLearningFields));
    }
}
use of org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition in project knime-core by knime.
the class BitSplitCandidate method getChildConditions.
/**
 * {@inheritDoc}
 *
 * <p>Returns exactly two bit conditions on this candidate's column: index 0 for the value {@code true} and
 * index 1 for the value {@code false}.</p>
 */
@Override
public TreeNodeCondition[] getChildConditions() {
    final TreeBitColumnMetaData bitMeta = getColumnData().getMetaData();
    return new TreeNodeCondition[] {
        new TreeNodeBitCondition(bitMeta, true),
        new TreeNodeBitCondition(bitMeta, false)};
}
use of org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition in project knime-core by knime.
the class NominalMultiwaySplitCandidate method getChildConditions.
/**
 * {@inheritDoc}
 *
 * <p>Creates one {@link TreeNodeNominalCondition} per non-missing nominal value of the column whose entry in
 * {@code m_sumWeightsAttributes} is at least {@link TreeColumnData#EPSILON}. If the last entry of the
 * column's value list is the missing-value representation, it is never turned into a child condition.</p>
 *
 * @return the child conditions in value order; empty if no value qualifies (including when the column
 *         meta data lists no values at all)
 */
@Override
public TreeNodeNominalCondition[] getChildConditions() {
    final TreeNominalColumnMetaData columnMeta = getColumnData().getMetaData();
    final NominalValueRepresentation[] values = columnMeta.getValues();
    // guard the length check: an empty value array previously caused an ArrayIndexOutOfBoundsException
    final boolean lastIsMissing = values.length > 0
        && values[values.length - 1].getNominalValue().equals(NominalValueRepresentation.MISSING_VALUE);
    final int lengthNonMissing = lastIsMissing ? values.length - 1 : values.length;
    // typed with the concrete condition class so the toArray call below is compiler-checked
    final List<TreeNodeNominalCondition> resultList = new ArrayList<>(lengthNonMissing);
    for (int i = 0; i < lengthNonMissing; i++) {
        if (m_sumWeightsAttributes[i] >= TreeColumnData.EPSILON) {
            // NOTE(review): i is the value index, not the index within resultList; confirm that
            // m_missingsGoToChildIdx is likewise a value index
            resultList.add(new TreeNodeNominalCondition(columnMeta, i, i == m_missingsGoToChildIdx));
        }
    }
    return resultList.toArray(new TreeNodeNominalCondition[resultList.size()]);
}
use of org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition in project knime-core by knime.
the class TreeModelPMMLTranslator method addTreeNode.
/**
 * Recursively translates {@code node} and its whole subtree into {@code pmmlNode}.
 *
 * <p>Assigns the next running id from {@code m_nodeIndex} and sets the score. For classification nodes it
 * also sets the record count (sum of the target distribution) and one score distribution entry per target
 * value. It then adds the node's predicate and recurses into all children in order.</p>
 *
 * @param pmmlNode the PMML node to fill
 * @param node the tree node to translate
 * @throws IllegalStateException if the node's condition is {@code null} or of an unsupported type
 */
private void addTreeNode(final Node pmmlNode, final AbstractTreeNode node) {
    final int index = m_nodeIndex++;
    pmmlNode.setId(Integer.toString(index));
    if (node instanceof TreeNodeClassification) {
        final TreeNodeClassification clazNode = (TreeNodeClassification) node;
        pmmlNode.setScore(clazNode.getMajorityClassName());
        final float[] targetDistribution = clazNode.getTargetDistribution();
        final NominalValueRepresentation[] targetVals = clazNode.getTargetMetaData().getValues();
        // primitive loop variable: avoids boxing every element into a Float
        double sum = 0.0;
        for (final float v : targetDistribution) {
            sum += v;
        }
        pmmlNode.setRecordCount(sum);
        // adding score distribution (class counts)
        for (int i = 0; i < targetDistribution.length; i++) {
            final String className = targetVals[i].getNominalValue();
            final double freq = targetDistribution[i];
            final ScoreDistribution pmmlScoreDist = pmmlNode.addNewScoreDistribution();
            pmmlScoreDist.setValue(className);
            pmmlScoreDist.setRecordCount(freq);
        }
    } else if (node instanceof TreeNodeRegression) {
        final TreeNodeRegression regNode = (TreeNodeRegression) node;
        pmmlNode.setScore(Double.toString(regNode.getMean()));
    }
    final TreeNodeCondition condition = node.getCondition();
    if (condition instanceof TreeNodeTrueCondition) {
        pmmlNode.addNewTrue();
    } else if (condition instanceof TreeNodeColumnCondition) {
        final TreeNodeColumnCondition colCondition = (TreeNodeColumnCondition) condition;
        handleColumnCondition(colCondition, pmmlNode);
    } else if (condition instanceof AbstractTreeNodeSurrogateCondition) {
        final AbstractTreeNodeSurrogateCondition surrogateCond = (AbstractTreeNodeSurrogateCondition) condition;
        setValuesFromPMMLCompoundPredicate(pmmlNode.addNewCompoundPredicate(), surrogateCond.toPMMLPredicate());
    } else {
        // a null condition previously crashed here with an uninformative NullPointerException
        final String conditionName = condition == null ? "null" : condition.getClass().getSimpleName();
        throw new IllegalStateException("Unsupported condition (not implemented): " + conditionName);
    }
    for (int i = 0; i < node.getNrChildren(); i++) {
        addTreeNode(pmmlNode.addNewNode(), node.getChild(i));
    }
}
Aggregations