Use of org.knime.base.node.mine.treeensemble.data.NominalValueRepresentation in project knime-core by knime.
Class NominalSplitCandidate, method getChildConditions.
/**
 * {@inheritDoc}
 *
 * <p>One {@link TreeNodeNominalCondition} is created per nominal value of the split column whose
 * accumulated attribute weight is at least {@link TreeColumnData#EPSILON}. Values below that
 * threshold produce no child. Conditions are returned in ascending value-index order.
 */
@Override
public TreeNodeCondition[] getChildConditions() {
    final TreeNominalColumnMetaData meta = getColumnData().getMetaData();
    final int valueCount = meta.getValues().length;
    final List<TreeNodeCondition> conditions = new ArrayList<TreeNodeCondition>(valueCount);
    int valueIndex = 0;
    while (valueIndex < valueCount) {
        // keep the positive comparison so that a NaN weight is skipped exactly as before
        final boolean hasWeight = m_sumWeightsAttributes[valueIndex] >= TreeColumnData.EPSILON;
        if (hasWeight) {
            conditions.add(new TreeNodeNominalCondition(meta, valueIndex));
        }
        valueIndex++;
    }
    return conditions.toArray(new TreeNodeCondition[conditions.size()]);
}
Use of org.knime.base.node.mine.treeensemble.data.NominalValueRepresentation in project knime-core by knime.
Class TreeNodeClassification, method createDecisionTreeNode.
/**
 * Converts this node, and recursively all of its children, into the {@link DecisionTreeNode}
 * model used by the KNIME Decision Tree.
 *
 * <p>The class distribution is copied into an insertion-ordered map keyed by class name. A node
 * without children becomes a {@link DecisionTreeNodeLeaf}. Any other node becomes a
 * {@link DecisionTreeNodeSplitPMML} that carries each child's condition as a PMML predicate.
 *
 * @param idGenerator source of node ids; {@code inc()} is called once for this node before any
 *            child is converted, and the same generator is passed down to the children
 * @param metaData tree meta data, used to look up the name of the split attribute
 * @return a DecisionTreeNode (leaf or split) representing this subtree
 */
public DecisionTreeNode createDecisionTreeNode(final MutableInteger idGenerator, final TreeMetaData metaData) {
    final DataCell majority = new StringCell(getMajorityClassName());
    final double[] distribution = getTargetDistribution();
    // sized so that all classes fit below LinkedHashMap's default 0.75 load factor
    final int capacity = (int) (distribution.length / 0.75 + 1.0);
    final LinkedHashMap<DataCell, Double> scores = new LinkedHashMap<DataCell, Double>(capacity);
    final NominalValueRepresentation[] classValues = getTargetMetaData().getValues();
    int classIndex = 0;
    for (final double weight : distribution) {
        final String className = classValues[classIndex++].getNominalValue();
        scores.put(new StringCell(className), weight);
    }

    final int childCount = getNrChildren();
    if (childCount == 0) {
        return new DecisionTreeNodeLeaf(idGenerator.inc(), majority, scores);
    }

    // this node draws its id before any descendant is converted
    final int nodeId = idGenerator.inc();
    final DecisionTreeNode[] children = new DecisionTreeNode[childCount];
    final int splitIndex = getSplitAttributeIndex();
    assert splitIndex >= 0 : "non-leaf node has no split";
    final String splitColumn = metaData.getAttributeMetaData(splitIndex).getAttributeName();
    final PMMLPredicate[] predicates = new PMMLPredicate[childCount];
    for (int c = 0; c < childCount; c++) {
        final TreeNodeClassification child = getChild(c);
        predicates[c] = child.getCondition().toPMMLPredicate();
        children[c] = child.createDecisionTreeNode(idGenerator, metaData);
    }
    return new DecisionTreeNodeSplitPMML(nodeId, majority, scores, splitColumn, predicates, children);
}
Use of org.knime.base.node.mine.treeensemble.data.NominalValueRepresentation in project knime-core by knime.
Class TreeModelPMMLTranslator, method addTreeNode.
/**
 * Writes {@code node} into {@code pmmlNode} and then recurses into its children.
 *
 * <p>The following are written, in order:
 * <ul>
 * <li>the id (taken from {@code m_nodeIndex} before the children are visited, so ids follow a
 * pre-order traversal);</li>
 * <li>the majority class as score, and the record count (sum of the target distribution);</li>
 * <li>the node's predicate;</li>
 * <li>one score distribution entry per target class;</li>
 * <li>one nested PMML node per child.</li>
 * </ul>
 *
 * @param pmmlNode the (freshly created) PMML node to populate
 * @param node the tree node whose content is translated
 * @throws IllegalStateException if the node's condition type or numeric operator has no PMML
 *             translation implemented here
 */
private void addTreeNode(final Node pmmlNode, final TreeNodeClassification node) {
    int index = m_nodeIndex++;
    pmmlNode.setId(Integer.toString(index));
    pmmlNode.setScore(node.getMajorityClassName());
    double[] targetDistribution = node.getTargetDistribution();
    NominalValueRepresentation[] targetVals = node.getTargetMetaData().getValues();
    // primitive loop variable: the previous 'Double v' boxed every array element only to unbox it
    double sum = 0.0;
    for (double v : targetDistribution) {
        sum += v;
    }
    pmmlNode.setRecordCount(sum);
    TreeNodeCondition condition = node.getCondition();
    if (condition instanceof TreeNodeTrueCondition) {
        pmmlNode.addNewTrue();
    } else if (condition instanceof TreeNodeColumnCondition) {
        final TreeNodeColumnCondition colCondition = (TreeNodeColumnCondition) condition;
        final String colName = colCondition.getColumnMetaData().getAttributeName();
        final Operator.Enum operator;
        final String value;
        if (condition instanceof TreeNodeNominalCondition) {
            final TreeNodeNominalCondition nominalCondition = (TreeNodeNominalCondition) condition;
            operator = Operator.EQUAL;
            value = nominalCondition.getValue();
        } else if (condition instanceof TreeNodeBitCondition) {
            final TreeNodeBitCondition bitCondition = (TreeNodeBitCondition) condition;
            operator = Operator.EQUAL;
            // bit conditions are encoded as the strings "1" / "0"
            value = bitCondition.getValue() ? "1" : "0";
        } else if (condition instanceof TreeNodeNumericCondition) {
            final TreeNodeNumericCondition numCondition = (TreeNodeNumericCondition) condition;
            NumericOperator numOperator = numCondition.getNumericOperator();
            switch (numOperator) {
                case LargerThan:
                    operator = Operator.GREATER_THAN;
                    break;
                case LessThanOrEqual:
                    operator = Operator.LESS_OR_EQUAL;
                    break;
                default:
                    throw new IllegalStateException("Unsupported operator (not " + "implemented): " + numOperator);
            }
            value = Double.toString(numCondition.getSplitValue());
        } else {
            throw new IllegalStateException("Unsupported condition (not " + "implemented): " + condition.getClass().getSimpleName());
        }
        SimplePredicate pmmlSimplePredicate = pmmlNode.addNewSimplePredicate();
        pmmlSimplePredicate.setField(colName);
        pmmlSimplePredicate.setOperator(operator);
        pmmlSimplePredicate.setValue(value);
    } else {
        // NOTE(review): a null condition reaches this branch and fails with a NullPointerException
        // on getClass() rather than the IllegalStateException below
        throw new IllegalStateException("Unsupported condition (not " + "implemented): " + condition.getClass().getSimpleName());
    }
    // adding score distribution (class counts), one entry per target class
    for (int i = 0; i < targetDistribution.length; i++) {
        String className = targetVals[i].getNominalValue();
        double freq = targetDistribution[i];
        ScoreDistribution pmmlScoreDist = pmmlNode.addNewScoreDistribution();
        pmmlScoreDist.setValue(className);
        pmmlScoreDist.setRecordCount(freq);
    }
    for (int i = 0; i < node.getNrChildren(); i++) {
        addTreeNode(pmmlNode.addNewNode(), node.getChild(i));
    }
}
Aggregations