Usage of org.knime.base.node.mine.treeensemble.model.TreeNodeNumericCondition.NumericOperator in the knime-core project (by KNIME).
Class TreeNumericColumnData, method updateChildMemberships:
/**
 * Applies a numeric split condition to the membership weights of a child node.
 *
 * <p>The column is walked in sorted order. Each value is compared against the condition's split
 * value using the condition's operator. A row whose value fails the comparison has its child
 * membership reset to {@code 0.0}. A row that passes is left untouched; when assertions are
 * enabled, its child membership is checked to still equal its parent membership.
 *
 * @param childCondition the child's condition; cast to {@link TreeNodeNumericCondition}
 * @param parentMemberships membership weights of the parent, indexed by the entries of
 *            {@code m_originalIndexInColumnList} (only read by the assertion)
 * @param childMembershipsToUpdate membership weights of the child, same indexing, modified in place
 * @throws IllegalStateException if the operator is neither {@code LessThanOrEqual} nor
 *             {@code LargerThan}; only raised when the column has at least one entry
 */
@Override
public void updateChildMemberships(final TreeNodeCondition childCondition, final double[] parentMemberships,
    final double[] childMembershipsToUpdate) {
    final TreeNodeNumericCondition split = (TreeNodeNumericCondition) childCondition;
    final NumericOperator op = split.getNumericOperator();
    final double threshold = split.getSplitValue();
    for (int sortedPos = 0; sortedPos < m_originalIndexInColumnList.length; sortedPos++) {
        final double cellValue = getSorted(sortedPos);
        final int memberIdx = m_originalIndexInColumnList[sortedPos];
        final boolean keep;
        switch (op) {
            case LargerThan:
                keep = threshold < cellValue;
                break;
            case LessThanOrEqual:
                keep = threshold >= cellValue;
                break;
            default:
                throw new IllegalStateException("Unknown operator " + op);
        }
        if (keep) {
            // with -ea: a row that survives the split must still carry its parent's weight
            assert childMembershipsToUpdate[memberIdx] == parentMemberships[memberIdx];
            continue;
        }
        childMembershipsToUpdate[memberIdx] = 0.0;
    }
}
Usage of org.knime.base.node.mine.treeensemble.model.TreeNodeNumericCondition.NumericOperator in the knime-core project (by KNIME).
Class TreeModelPMMLTranslator, method addTreeNode:
/**
 * Fills {@code pmmlNode} from {@code node} and recursively appends one PMML child node per tree child.
 *
 * <p>The PMML node receives:
 * <ul>
 * <li>the next sequential id taken from {@code m_nodeIndex};</li>
 * <li>the majority class name as its score;</li>
 * <li>the sum of the target distribution as its record count;</li>
 * <li>a predicate derived from the node's condition:
 *   <ul>
 *   <li>{@code True} for a {@link TreeNodeTrueCondition};</li>
 *   <li>{@code EQUAL} for nominal conditions and for bit conditions (bit value written as "1"/"0");</li>
 *   <li>{@code GREATER_THAN} / {@code LESS_OR_EQUAL} for numeric conditions.</li>
 *   </ul>
 * </li>
 * <li>one ScoreDistribution entry (class name and count) per target class value.</li>
 * </ul>
 *
 * @param pmmlNode the PMML node to populate
 * @param node the tree node to translate; its children are translated recursively
 * @throws IllegalStateException if the condition type or numeric operator has no PMML mapping here
 */
private void addTreeNode(final Node pmmlNode, final TreeNodeClassification node) {
    int index = m_nodeIndex++;
    pmmlNode.setId(Integer.toString(index));
    pmmlNode.setScore(node.getMajorityClassName());
    double[] targetDistribution = node.getTargetDistribution();
    NominalValueRepresentation[] targetVals = node.getTargetMetaData().getValues();
    // iterate with a primitive double: a Double loop variable would box/unbox every element
    double sum = 0.0;
    for (double v : targetDistribution) {
        sum += v;
    }
    pmmlNode.setRecordCount(sum);
    TreeNodeCondition condition = node.getCondition();
    if (condition instanceof TreeNodeTrueCondition) {
        pmmlNode.addNewTrue();
    } else if (condition instanceof TreeNodeColumnCondition) {
        final TreeNodeColumnCondition colCondition = (TreeNodeColumnCondition) condition;
        final String colName = colCondition.getColumnMetaData().getAttributeName();
        final Operator.Enum operator;
        final String value;
        if (condition instanceof TreeNodeNominalCondition) {
            final TreeNodeNominalCondition nominalCondition = (TreeNodeNominalCondition) condition;
            operator = Operator.EQUAL;
            value = nominalCondition.getValue();
        } else if (condition instanceof TreeNodeBitCondition) {
            final TreeNodeBitCondition bitCondition = (TreeNodeBitCondition) condition;
            operator = Operator.EQUAL;
            value = bitCondition.getValue() ? "1" : "0";
        } else if (condition instanceof TreeNodeNumericCondition) {
            final TreeNodeNumericCondition numCondition = (TreeNodeNumericCondition) condition;
            NumericOperator numOperator = numCondition.getNumericOperator();
            switch (numOperator) {
                case LargerThan:
                    operator = Operator.GREATER_THAN;
                    break;
                case LessThanOrEqual:
                    operator = Operator.LESS_OR_EQUAL;
                    break;
                default:
                    throw new IllegalStateException("Unsupported operator (not implemented): " + numOperator);
            }
            value = Double.toString(numCondition.getSplitValue());
        } else {
            throw new IllegalStateException("Unsupported condition (not implemented): "
                + condition.getClass().getSimpleName());
        }
        SimplePredicate pmmlSimplePredicate = pmmlNode.addNewSimplePredicate();
        pmmlSimplePredicate.setField(colName);
        pmmlSimplePredicate.setOperator(operator);
        pmmlSimplePredicate.setValue(value);
    } else {
        throw new IllegalStateException("Unsupported condition (not implemented): "
            + condition.getClass().getSimpleName());
    }
    // adding score distribution (class counts); targetVals is indexed in parallel with
    // targetDistribution -- assumes both have the same length (TODO confirm)
    for (int i = 0; i < targetDistribution.length; i++) {
        String className = targetVals[i].getNominalValue();
        double freq = targetDistribution[i];
        ScoreDistribution pmmlScoreDist = pmmlNode.addNewScoreDistribution();
        pmmlScoreDist.setValue(className);
        pmmlScoreDist.setRecordCount(freq);
    }
    for (int i = 0; i < node.getNrChildren(); i++) {
        addTreeNode(pmmlNode.addNewNode(), node.getChild(i));
    }
}
Aggregations