Usage of org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition in the knime-core project by KNIME.
Class TreeNumericColumnDataTest, method testXGBoostMissingValueHandling.
/**
 * Checks the child conditions of the best numeric split when XGBoost-style missing value handling is
 * configured.
 * <p>
 * The same attribute column (eight numeric values followed by two NaNs) is split against two nominal
 * targets:
 * <ul>
 * <li>In the first target, the NaN rows share the label of the low values. The first child condition
 * must accept missings and the second must not.</li>
 * <li>In the second target, the NaN rows share the label of the high values. The first child
 * condition must reject missings and the second must accept them.</li>
 * </ul>
 * Both scenarios are expected to split at 3.5 with a gain of 0.48.
 *
 * @throws Exception if building the test data or computing a split fails
 */
@Test
public void testXGBoostMissingValueHandling() throws Exception {
    final TreeEnsembleLearnerConfiguration config = createConfig();
    config.setMissingValueHandling(MissingValueHandling.XGBoost);
    final TestDataGenerator dataGen = new TestDataGenerator(config);
    final RandomData rd = config.createRandomData();

    final int[] rowIndices = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
    final double[] rowWeights = new double[rowIndices.length];
    Arrays.fill(rowWeights, 1.0);
    final MockDataColMem memberships = new MockDataColMem(rowIndices, rowIndices, rowWeights);

    final TreeNumericColumnData column =
        dataGen.createNumericAttributeColumn("1,2,2,3,4,5,6,7,NaN,NaN", "testCol", 0);

    final double expectedGain = 0.48;
    final double expectedSplitPoint = 3.5;
    final double tolerance = 1e-8;

    // One entry per scenario: the target labels and whether the first child is expected to take the
    // missing values (the second child is always expected to do the opposite).
    final String[] targetCsvs = {"A,A,A,A,B,B,B,B,A,A", "A,A,A,A,B,B,B,B,B,B"};
    final boolean[] firstChildTakesMissings = {true, false};

    for (int scenario = 0; scenario < targetCsvs.length; scenario++) {
        final TreeTargetNominalColumnData target = TestDataGenerator.createNominalTargetColumn(targetCsvs[scenario]);
        final SplitCandidate split =
            column.calcBestSplitClassification(memberships, target.getDistribution(rowWeights, config), target, rd);
        assertEquals("Wrong gain.", expectedGain, split.getGainValue(), tolerance);

        final TreeNodeCondition[] childConditions = split.getChildConditions();
        final boolean missingsGoFirst = firstChildTakesMissings[scenario];

        final TreeNodeNumericCondition first = (TreeNodeNumericCondition) childConditions[0];
        assertEquals("Wrong split point.", expectedSplitPoint, first.getSplitValue(), tolerance);
        if (missingsGoFirst) {
            assertTrue("Missings were not sent in the correct direction.", first.acceptsMissings());
        } else {
            assertFalse("Missings were not sent in the correct direction.", first.acceptsMissings());
        }

        final TreeNodeNumericCondition second = (TreeNodeNumericCondition) childConditions[1];
        assertEquals("Wrong split point.", expectedSplitPoint, second.getSplitValue(), tolerance);
        if (missingsGoFirst) {
            assertFalse("Missings were not sent in the correct direction.", second.acceptsMissings());
        } else {
            assertTrue("Missings were not sent in the correct direction.", second.acceptsMissings());
        }
    }
}
Usage of org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition in the knime-core project by KNIME.
Class TreeNumericColumnDataTest, method testCalcBestSplitRegression.
/**
 * Checks best-split computation for a regression target on a ten-row numeric attribute.
 * <p>
 * The first split runs on all rows and is expected at 1.5. The second split runs on the rows that
 * fall into the second child of the first split (rows 1..9) and is expected at 6.5. The expected
 * gains were calculated via OpenOffice Calc.
 *
 * @throws InvalidSettingsException propagated from the setup calls that declare it
 */
@Test
public void testCalcBestSplitRegression() throws InvalidSettingsException {
    String dataCSV = "1,2,3,4,5,6,7,8,9,10";
    String targetCSV = "1,5,4,4.3,6.5,6.5,4,3,3,4";
    TreeEnsembleLearnerConfiguration config = new TreeEnsembleLearnerConfiguration(true);
    config.setNrModels(1);
    config.setDataSelectionWithReplacement(false);
    config.setUseDifferentAttributesAtEachNode(false);
    config.setDataFractionPerTree(1.0);
    config.setColumnSamplingMode(ColumnSamplingMode.None);
    TestDataGenerator dataGen = new TestDataGenerator(config);
    RandomData rd = config.createRandomData();
    TreeTargetNumericColumnData target = TestDataGenerator.createNumericTargetColumn(targetCSV);
    TreeNumericColumnData attribute = dataGen.createNumericAttributeColumn(dataCSV, "test-col", 0);
    TreeData data = new TreeData(new TreeAttributeColumnData[] { attribute }, target, TreeType.Ordinary);
    double[] weights = new double[10];
    Arrays.fill(weights, 1.0);
    DataMemberships rootMem = new RootDataMemberships(weights, data, new DefaultDataIndexManager(data));
    SplitCandidate firstSplit = attribute.calcBestSplitRegression(rootMem, target.getPriors(rootMem, config), target, rd);
    // calculated via OpenOffice calc
    assertEquals(10.885444, firstSplit.getGainValue(), 1e-5);
    TreeNodeCondition[] firstConditions = firstSplit.getChildConditions();
    assertEquals(2, firstConditions.length);
    for (int i = 0; i < firstConditions.length; i++) {
        assertThat(firstConditions[i], instanceOf(TreeNodeNumericCondition.class));
        TreeNodeNumericCondition numCond = (TreeNodeNumericCondition) firstConditions[i];
        assertEquals(1.5, numCond.getSplitValue(), 0);
    }
    // left child contains only one row therefore only look at right child
    BitSet expectedInChild = new BitSet(10);
    expectedInChild.set(1, 10);
    BitSet inChild = attribute.updateChildMemberships(firstConditions[1], rootMem);
    assertEquals(expectedInChild, inChild);
    DataMemberships childMem = rootMem.createChildMemberships(inChild);
    SplitCandidate secondSplit = attribute.calcBestSplitRegression(childMem, target.getPriors(childMem, config), target, rd);
    // calculated via OpenOffice calc
    assertEquals(6.883555, secondSplit.getGainValue(), 1e-5);
    TreeNodeCondition[] secondConditions = secondSplit.getChildConditions();
    // Mirror the check on the first split: without it, the loop below would pass without checking
    // anything if no child conditions were returned.
    assertEquals(2, secondConditions.length);
    for (int i = 0; i < secondConditions.length; i++) {
        assertThat(secondConditions[i], instanceOf(TreeNodeNumericCondition.class));
        TreeNodeNumericCondition numCond = (TreeNodeNumericCondition) secondConditions[i];
        assertEquals(6.5, numCond.getSplitValue(), 0);
    }
}
Usage of org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition in the knime-core project by KNIME.
Class TreeNodeClassification, method createDecisionTreeNode.
/**
 * Converts this node, and recursively its whole subtree, into the {@link DecisionTreeNode} model
 * used by the KNIME Decision Tree.
 * <p>
 * Each created node takes exactly one id from {@code idGenerator}. A split node takes its id before
 * any of its children are converted.
 *
 * @param idGenerator source of node ids; {@code inc()} is called once per created node
 * @param metaData tree meta data, used to look up the name of the split attribute
 * @return a {@link DecisionTreeNodeLeaf} when this node has no children, otherwise a
 *         {@link DecisionTreeNodeSplitPMML} containing the converted children and their predicates
 */
public DecisionTreeNode createDecisionTreeNode(final MutableInteger idGenerator, final TreeMetaData metaData) {
    final DataCell majorityCell = new StringCell(getMajorityClassName());

    // Class value -> weight, in the order of the target's nominal values. The capacity is sized so the
    // map does not need to rehash at the default 0.75 load factor.
    final float[] classDistribution = getTargetDistribution();
    final LinkedHashMap<DataCell, Double> scoreDistributionMap =
        new LinkedHashMap<DataCell, Double>((int) (classDistribution.length / 0.75 + 1.0));
    final NominalValueRepresentation[] classValues = getTargetMetaData().getValues();
    for (int c = 0; c < classDistribution.length; c++) {
        scoreDistributionMap.put(new StringCell(classValues[c].getNominalValue()), Double.valueOf(classDistribution[c]));
    }

    final int childCount = getNrChildren();
    if (childCount == 0) {
        return new DecisionTreeNodeLeaf(idGenerator.inc(), majorityCell, scoreDistributionMap);
    }

    // Reserve this node's id before recursing so ids are assigned parent-first.
    final int nodeId = idGenerator.inc();
    final DecisionTreeNode[] childNodes = new DecisionTreeNode[childCount];
    final int splitAttributeIndex = getSplitAttributeIndex();
    assert splitAttributeIndex >= 0 : "non-leaf node has no split";
    final String splitAttribute = metaData.getAttributeMetaData(splitAttributeIndex).getAttributeName();
    final PMMLPredicate[] childPredicates = new PMMLPredicate[childCount];
    for (int c = 0; c < childCount; c++) {
        final TreeNodeClassification child = getChild(c);
        final TreeNodeCondition childCondition = child.getCondition();
        childPredicates[c] = childCondition.toPMMLPredicate();
        childNodes[c] = child.createDecisionTreeNode(idGenerator, metaData);
    }
    return new DecisionTreeNodeSplitPMML(nodeId, majorityCell, scoreDistributionMap, splitAttribute, childPredicates,
        childNodes);
}
Usage of org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition in the knime-core project by KNIME.
Class AbstractTreeModelExporter, method addTreeNode.
/**
 * Writes {@code node} into {@code pmmlNode} and then recursively appends a new PMML child node for
 * each of its children.
 * <p>
 * Every visited node uses the current value of {@code m_nodeIndex} as its PMML id and content index,
 * and the counter is advanced by one. Nodes are therefore numbered in pre-order (parent before
 * children).
 *
 * @param pmmlNode the PMML node to fill
 * @param node the tree node whose id, content, condition and children are exported
 */
@SuppressWarnings("unchecked")
private void addTreeNode(final Node pmmlNode, final T node) {
    final int nodeIndex = m_nodeIndex++;
    pmmlNode.setId(String.valueOf(nodeIndex));
    addNodeContent(nodeIndex, pmmlNode, node);

    final TreeNodeCondition nodeCondition = node.getCondition();
    m_conditionExporter.exportCondition(nodeCondition, pmmlNode);

    // The unchecked cast assumes every child has the same node type T as its parent.
    for (int c = 0; c < node.getNrChildren(); c++) {
        final T child = (T) node.getChild(c);
        addTreeNode(pmmlNode.addNewNode(), child);
    }
}
Usage of org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition in the knime-core project by KNIME.
Class ConditionExporter, method exportCondition.
void exportCondition(final TreeNodeCondition condition, final Node pmmlNode) {
    if (condition instanceof TreeNodeTrueCondition) {
        pmmlNode.addNewTrue();
        return;
    }
    if (condition instanceof TreeNodeColumnCondition) {
        exportColumnCondition((TreeNodeColumnCondition) condition, pmmlNode);
        return;
    }
    if (condition instanceof AbstractTreeNodeSurrogateCondition) {
        // Surrogate conditions are exported as a compound predicate built from the condition's own
        // PMML predicate.
        setValuesFromPMMLCompoundPredicate(pmmlNode.addNewCompoundPredicate(),
            ((AbstractTreeNodeSurrogateCondition) condition).toPMMLPredicate());
        return;
    }
    throw new IllegalStateException("Unsupported condition (not implemented): " + condition.getClass().getSimpleName());
}
Aggregations