Usage of `org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalBinaryCondition` in the knime-core project (by KNIME).
Class `TreeNominalColumnDataTest`, method `testCalcBestSplitClassificationBinary`.
/**
 * Tests the method
 * {@link TreeNominalColumnData#calcBestSplitClassification(DataMemberships, ClassificationPriors, TreeTargetNominalColumnData, RandomData)}
 * using binary splits on the tennis data.
 * <p>
 * The expected gain values were computed manually in LibreOffice Calc. The root split must separate the value "R"
 * from all other values. The child that excludes "R" can be split once more. The child that contains "R" yields no
 * split candidate.
 *
 * @throws Exception if creating the test data or computing a split fails
 */
@Test
public void testCalcBestSplitClassificationBinary() throws Exception {
    final TreeEnsembleLearnerConfiguration config = createConfig(false);
    final Pair<TreeNominalColumnData, TreeTargetNominalColumnData> tennisData = tennisData(config);
    final TreeNominalColumnData columnData = tennisData.getFirst();
    final TreeTargetNominalColumnData targetData = tennisData.getSecond();
    assertEquals(SplitCriterion.Gini, config.getSplitCriterion());
    // One unit weight per row, sized by the tennis target itself (as in the PCA test) so the array presumably
    // always matches the row count of the tree data created below.
    final double[] rowWeights = new double[targetData.getNrRows()];
    Arrays.fill(rowWeights, 1.0);
    final TreeData tennisTreeData = tennisTreeData(config);
    final IDataIndexManager indexManager = new DefaultDataIndexManager(tennisTreeData);
    final DataMemberships dataMemberships = new RootDataMemberships(rowWeights, tennisTreeData, indexManager);
    final ClassificationPriors priors = targetData.getDistribution(rowWeights, config);

    // Root split: {"R"} vs. everything else.
    final SplitCandidate splitCandidate =
        columnData.calcBestSplitClassification(dataMemberships, priors, targetData, null);
    assertNotNull(splitCandidate);
    assertThat(splitCandidate, instanceOf(NominalBinarySplitCandidate.class));
    assertTrue(splitCandidate.canColumnBeSplitFurther());
    // expected gain computed manually in LibreOffice Calc
    assertEquals(0.0689342404, splitCandidate.getGainValue(), 0.00001);
    final NominalBinarySplitCandidate binSplitCandidate = (NominalBinarySplitCandidate)splitCandidate;
    final TreeNodeNominalBinaryCondition[] childConditions = binSplitCandidate.getChildConditions();
    assertEquals(2, childConditions.length);
    assertArrayEquals(new String[]{"R"}, childConditions[0].getValues());
    assertArrayEquals(new String[]{"R"}, childConditions[1].getValues());
    assertEquals(SetLogic.IS_NOT_IN, childConditions[0].getSetLogic());
    assertEquals(SetLogic.IS_IN, childConditions[1].getSetLogic());

    // First child (value not in {"R"}) can be split further.
    final BitSet inChild1 = columnData.updateChildMemberships(childConditions[0], dataMemberships);
    final DataMemberships child1Memberships = dataMemberships.createChildMemberships(inChild1);
    final ClassificationPriors child1Priors = targetData.getDistribution(child1Memberships, config);
    final SplitCandidate child1Split =
        columnData.calcBestSplitClassification(child1Memberships, child1Priors, targetData, null);
    assertNotNull(child1Split);
    assertThat(child1Split, instanceOf(NominalBinarySplitCandidate.class));
    // expected gain computed manually in LibreOffice Calc
    assertEquals(0.0086419753, child1Split.getGainValue(), 0.00001);

    // Second child (value in {"R"}) yields no split candidate.
    final BitSet inChild2 = columnData.updateChildMemberships(childConditions[1], dataMemberships);
    final DataMemberships child2Memberships = dataMemberships.createChildMemberships(inChild2);
    final ClassificationPriors child2Priors = targetData.getDistribution(child2Memberships, config);
    final SplitCandidate child2Split =
        columnData.calcBestSplitClassification(child2Memberships, child2Priors, targetData, null);
    assertNull(child2Split);
}
Usage of `org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalBinaryCondition` in the knime-core project (by KNIME).
Class `TreeNominalColumnDataTest`, method `testCalcBestSplitClassificationBinaryPCA`.
/**
 * Checks
 * {@link TreeNominalColumnData#calcBestSplitClassification(DataMemberships, ClassificationPriors, TreeTargetNominalColumnData, RandomData)}
 * for binary splits when the target has more than two classes, which makes the split search PCA based.
 * <p>
 * The best split is expected to separate the value "E" from all other values.
 *
 * @throws Exception if the test data cannot be created or the split computation fails
 */
@Test
public void testCalcBestSplitClassificationBinaryPCA() throws Exception {
    TreeEnsembleLearnerConfiguration config = createConfig(false);
    Pair<TreeNominalColumnData, TreeTargetNominalColumnData> testData = createPCATestData(config);
    TreeNominalColumnData attribute = testData.getFirst();
    TreeTargetNominalColumnData target = testData.getSecond();
    TreeData treeData = createTreeData(testData);
    assertEquals(SplitCriterion.Gini, config.getSplitCriterion());

    // every row takes part in the root node with unit weight
    double[] weights = new double[target.getNrRows()];
    Arrays.fill(weights, 1.0);
    IDataIndexManager dataIndexManager = new DefaultDataIndexManager(treeData);
    DataMemberships rootMemberships = new RootDataMemberships(weights, treeData, dataIndexManager);
    ClassificationPriors rootPriors = target.getDistribution(weights, config);

    SplitCandidate best = attribute.calcBestSplitClassification(rootMemberships, rootPriors, target, null);
    assertNotNull(best);
    assertThat(best, instanceOf(NominalBinarySplitCandidate.class));
    assertTrue(best.canColumnBeSplitFurther());
    assertEquals(0.0659, best.getGainValue(), 0.0001);

    // both children test membership in {"E"}: the first one excludes it, the second one includes it
    TreeNodeNominalBinaryCondition[] conditions = ((NominalBinarySplitCandidate)best).getChildConditions();
    assertEquals(2, conditions.length);
    String[] expectedValues = new String[]{"E"};
    assertArrayEquals(expectedValues, conditions[0].getValues());
    assertArrayEquals(expectedValues, conditions[1].getValues());
    assertEquals(SetLogic.IS_NOT_IN, conditions[0].getSetLogic());
    assertEquals(SetLogic.IS_IN, conditions[1].getSetLogic());
}
Usage of `org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalBinaryCondition` in the knime-core project (by KNIME).
Class `TreeNominalColumnDataTest`, method `testUpdateChildMemberships`.
/**
 * Tests the method
 * {@link TreeNominalColumnData#updateChildMemberships(org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition, DataMemberships)}
 * for binary ({@link TreeNodeNominalBinaryCondition}) and multiway ({@link TreeNodeNominalCondition}) conditions.
 * Each condition is checked both with and without missing values being accepted.
 *
 * @throws Exception if creating the test column fails
 */
@Test
public void testUpdateChildMemberships() throws Exception {
    // in this case it doesn't matter if we use regression or classification (as well as binary and multiway splits)
    final TreeEnsembleLearnerConfiguration config = createConfig(true);
    final TestDataGenerator dataGen = new TestDataGenerator(config);
    // rows 0-3: A, rows 4-6: B, rows 7-9: C, rows 10-11: missing
    final String dataCSV = "A, A, A, A, B, B, B, C, C, C, ?, ?";
    final TreeNominalColumnData col = dataGen.createNominalAttributeColumn(dataCSV, "test-col", 0);
    final int nrRows = dataCSV.split(",").length;
    final int[] indices = new int[nrRows];
    final double[] weights = new double[nrRows];
    for (int i = 0; i < nrRows; i++) {
        indices[i] = i;
        weights[i] = 1.0;
    }
    final DataMemberships dataMem = new MockDataColMem(indices, indices, weights);
    final String msg = "The produced BitSet is incorrect.";

    // Binary conditions (mask, in-set flag, accepts-missing flag). Judging by the expected rows, mask 2 (binary 10)
    // selects B and mask 5 (binary 101) selects A and C.
    assertEquals(msg, bitSetOfRanges(4, 7), col.updateChildMemberships(
        new TreeNodeNominalBinaryCondition(col.getMetaData(), BigInteger.valueOf(2), true, false), dataMem));
    assertEquals(msg, bitSetOfRanges(4, 7, 10, 12), col.updateChildMemberships(
        new TreeNodeNominalBinaryCondition(col.getMetaData(), BigInteger.valueOf(2), true, true), dataMem));
    assertEquals(msg, bitSetOfRanges(0, 4, 7, 10), col.updateChildMemberships(
        new TreeNodeNominalBinaryCondition(col.getMetaData(), BigInteger.valueOf(2), false, false), dataMem));
    assertEquals(msg, bitSetOfRanges(0, 4, 7, 12), col.updateChildMemberships(
        new TreeNodeNominalBinaryCondition(col.getMetaData(), BigInteger.valueOf(2), false, true), dataMem));
    assertEquals(msg, bitSetOfRanges(0, 4, 7, 10), col.updateChildMemberships(
        new TreeNodeNominalBinaryCondition(col.getMetaData(), BigInteger.valueOf(5), true, false), dataMem));
    assertEquals(msg, bitSetOfRanges(0, 4, 7, 12), col.updateChildMemberships(
        new TreeNodeNominalBinaryCondition(col.getMetaData(), BigInteger.valueOf(5), true, true), dataMem));

    // Multiway conditions (value index, accepts-missing flag): index 0 matches A, index 2 matches C.
    assertEquals(msg, bitSetOfRanges(0, 4),
        col.updateChildMemberships(new TreeNodeNominalCondition(col.getMetaData(), 0, false), dataMem));
    assertEquals(msg, bitSetOfRanges(0, 4, 10, 12),
        col.updateChildMemberships(new TreeNodeNominalCondition(col.getMetaData(), 0, true), dataMem));
    assertEquals(msg, bitSetOfRanges(7, 10),
        col.updateChildMemberships(new TreeNodeNominalCondition(col.getMetaData(), 2, false), dataMem));
    assertEquals(msg, bitSetOfRanges(7, 12),
        col.updateChildMemberships(new TreeNodeNominalCondition(col.getMetaData(), 2, true), dataMem));
}

/**
 * Creates a {@link BitSet} in which exactly the bits of the given half-open ranges are set.
 *
 * @param fromToPairs consecutive pairs of (inclusive start, exclusive end) bit indices
 * @return a new bit set containing exactly the bits of the given ranges
 * @throws IllegalArgumentException if an odd number of indices is passed
 */
private static BitSet bitSetOfRanges(final int... fromToPairs) {
    if (fromToPairs.length % 2 != 0) {
        throw new IllegalArgumentException(
            "Expected pairs of (from, to) indices but got " + fromToPairs.length + " values.");
    }
    final BitSet bits = new BitSet();
    for (int i = 0; i < fromToPairs.length; i += 2) {
        bits.set(fromToPairs[i], fromToPairs[i + 1]);
    }
    return bits;
}
Aggregations