Use of org.knime.base.node.mine.treeensemble2.data.TestDataGenerator in project knime-core by knime.
Class TreeNominalColumnDataTest, method testCalcBestSplitClassificationBinaryPCAXGBoostMissingValueHandling.
/**
 * Checks the XGBoost missing value handling for binary nominal splits that are computed with the PCA method, i.e.
 * for a nominal target with more than two classes (here A, B and C).
 * <p>
 * Scenario 1 trains on a column without missing values; the split must then route missing values to the majority
 * child (the right one). Scenario 2 adds a row whose attribute value is missing; the split must then route missing
 * values to the left child.
 *
 * @throws Exception if creating the test data or computing the split fails
 */
@Test
public void testCalcBestSplitClassificationBinaryPCAXGBoostMissingValueHandling() throws Exception {
    final TreeEnsembleLearnerConfiguration config = createConfig(false);
    config.setMissingValueHandling(MissingValueHandling.XGBoost);
    final TestDataGenerator generator = new TestDataGenerator(config);
    final RandomData random = config.createRandomData();
    // both scenarios are expected to put the same attribute values into the child conditions
    final String[] expectedValues = {"a", "c"};

    // --- scenario 1: the training column contains no missing values ---
    final TreeNominalColumnData completeColumn =
        generator.createNominalAttributeColumn("a, a, a, b, b, b, b, c, c", "noMissings", 0);
    final TreeTargetNominalColumnData completeTarget =
        TestDataGenerator.createNominalTargetColumn("A, B, B, C, C, C, B, A, B");
    final DataMemberships completeMemberships = createMockDataMemberships(completeTarget.getNrRows());
    final SplitCandidate completeSplit = completeColumn.calcBestSplitClassification(completeMemberships,
        completeTarget.getDistribution(completeMemberships, config), completeTarget, random);
    assertNotNull("There is a possible split.", completeSplit);
    assertEquals("Incorrect gain.", 0.2086, completeSplit.getGainValue(), 1e-3);
    assertThat(completeSplit, instanceOf(NominalBinarySplitCandidate.class));
    final NominalBinarySplitCandidate completeBinarySplit = (NominalBinarySplitCandidate) completeSplit;
    assertTrue("No missing values in the column.", completeBinarySplit.getMissedRows().isEmpty());
    final TreeNodeNominalBinaryCondition[] completeConditions = completeBinarySplit.getChildConditions();
    assertEquals("A binary split must have 2 child conditions.", 2, completeConditions.length);
    final TreeNodeNominalBinaryCondition completeLeft = completeConditions[0];
    final TreeNodeNominalBinaryCondition completeRight = completeConditions[1];
    assertArrayEquals("Wrong values in child condition.", expectedValues, completeLeft.getValues());
    assertArrayEquals("Wrong values in child condition.", expectedValues, completeRight.getValues());
    assertEquals("Wrong set logic.", SetLogic.IS_NOT_IN, completeLeft.getSetLogic());
    assertEquals("Wrong set logic.", SetLogic.IS_IN, completeRight.getSetLogic());
    assertFalse("Missing values should be sent to the majority child (i.e. right)", completeLeft.acceptsMissings());
    assertTrue("Missing values should be sent to the majority child (i.e. right)", completeRight.acceptsMissings());

    // --- scenario 2: the training column contains a missing value (last row) ---
    final TreeNominalColumnData partialColumn =
        generator.createNominalAttributeColumn("a, a, a, b, b, b, b, c, c, ?", "missings", 0);
    final TreeTargetNominalColumnData partialTarget =
        TestDataGenerator.createNominalTargetColumn("A, B, B, C, C, C, B, A, B, C");
    final DataMemberships partialMemberships = createMockDataMemberships(partialTarget.getNrRows());
    final SplitCandidate partialSplit = partialColumn.calcBestSplitClassification(partialMemberships,
        partialTarget.getDistribution(partialMemberships, config), partialTarget, random);
    assertNotNull("There is a possible split.", partialSplit);
    assertEquals("Incorrect gain.", 0.24, partialSplit.getGainValue(), 1e-3);
    assertThat(partialSplit, instanceOf(NominalBinarySplitCandidate.class));
    final NominalBinarySplitCandidate partialBinarySplit = (NominalBinarySplitCandidate) partialSplit;
    // every row, including the one with the missing value, must be assigned to a child
    assertTrue("Split should handle missing values.", partialBinarySplit.getMissedRows().isEmpty());
    final TreeNodeNominalBinaryCondition[] partialConditions = partialBinarySplit.getChildConditions();
    assertEquals("Wrong number of child conditions.", 2, partialConditions.length);
    final TreeNodeNominalBinaryCondition partialLeft = partialConditions[0];
    final TreeNodeNominalBinaryCondition partialRight = partialConditions[1];
    assertArrayEquals("Wrong values in child condition.", expectedValues, partialLeft.getValues());
    assertArrayEquals("Wrong values in child condition.", expectedValues, partialRight.getValues());
    assertEquals("Wrong set logic.", SetLogic.IS_NOT_IN, partialLeft.getSetLogic());
    assertEquals("Wrong set logic.", SetLogic.IS_IN, partialRight.getSetLogic());
    assertTrue("Missing values should be sent to left child", partialLeft.acceptsMissings());
    assertFalse("Missing values should be sent to left child", partialRight.acceptsMissings());
}
Use of org.knime.base.node.mine.treeensemble2.data.TestDataGenerator in project knime-core by knime.
Class TreeNominalColumnDataTest, method testCalcBestSplitClassificationMultiwayXGBoostMissingValueHandling.
/**
 * Checks the XGBoost missing value handling for classification with multiway nominal splits.
 * <p>
 * Scenario 1 trains on a column without missing values; missing values must then be routed to the majority child,
 * which is the child for value {@code b}. Scenario 2 adds a row whose attribute value is missing; the direction
 * learned from it must again be the child for {@code b}.
 *
 * @throws Exception if creating the test data or computing the split fails
 */
@Test
public void testCalcBestSplitClassificationMultiwayXGBoostMissingValueHandling() throws Exception {
    final TreeEnsembleLearnerConfiguration config = createConfig(false);
    config.setUseBinaryNominalSplits(false);
    config.setMissingValueHandling(MissingValueHandling.XGBoost);
    final TestDataGenerator generator = new TestDataGenerator(config);
    final RandomData random = config.createRandomData();
    // one child per attribute value, expected in exactly this order
    final String[] expectedValues = {"a", "b", "c"};
    // position of the child (value b) that must receive missing values in both scenarios
    final int missingChild = 1;

    // --- scenario 1: the training column contains no missing values ---
    final TreeNominalColumnData completeColumn =
        generator.createNominalAttributeColumn("a, a, a, b, b, b, b, c, c", "noMissings", 0);
    final TreeTargetNominalColumnData completeTarget =
        TestDataGenerator.createNominalTargetColumn("A, B, B, C, C, C, B, A, B");
    final DataMemberships completeMemberships = createMockDataMemberships(completeTarget.getNrRows());
    final SplitCandidate completeSplit = completeColumn.calcBestSplitClassification(completeMemberships,
        completeTarget.getDistribution(completeMemberships, config), completeTarget, random);
    assertNotNull("There is a possible split.", completeSplit);
    assertEquals("Incorrect gain.", 0.216, completeSplit.getGainValue(), 1e-3);
    assertThat(completeSplit, instanceOf(NominalMultiwaySplitCandidate.class));
    final NominalMultiwaySplitCandidate completeMultiwaySplit = (NominalMultiwaySplitCandidate) completeSplit;
    assertTrue("No missing values in the column.", completeMultiwaySplit.getMissedRows().isEmpty());
    final TreeNodeNominalCondition[] completeConditions = completeMultiwaySplit.getChildConditions();
    assertEquals("Wrong number of child conditions.", 3, completeConditions.length);
    for (int child = 0; child < expectedValues.length; child++) {
        assertEquals("Wrong value in child condition.", expectedValues[child], completeConditions[child].getValue());
    }
    for (int child = 0; child < expectedValues.length; child++) {
        if (child == missingChild) {
            assertTrue("Missing values should be sent to the majority child (i.e. b)",
                completeConditions[child].acceptsMissings());
        } else {
            assertFalse("Missing values should be sent to the majority child (i.e. b)",
                completeConditions[child].acceptsMissings());
        }
    }

    // --- scenario 2: the training column contains a missing value (last row) ---
    final TreeNominalColumnData partialColumn =
        generator.createNominalAttributeColumn("a, a, a, b, b, b, b, c, c, ?", "missings", 0);
    final TreeTargetNominalColumnData partialTarget =
        TestDataGenerator.createNominalTargetColumn("A, B, B, C, C, C, B, A, B, C");
    final DataMemberships partialMemberships = createMockDataMemberships(partialTarget.getNrRows());
    final SplitCandidate partialSplit = partialColumn.calcBestSplitClassification(partialMemberships,
        partialTarget.getDistribution(partialMemberships, config), partialTarget, random);
    assertNotNull("There is a possible split.", partialSplit);
    assertEquals("Incorrect gain.", 0.2467, partialSplit.getGainValue(), 1e-3);
    assertThat(partialSplit, instanceOf(NominalMultiwaySplitCandidate.class));
    final NominalMultiwaySplitCandidate partialMultiwaySplit = (NominalMultiwaySplitCandidate) partialSplit;
    // every row, including the one with the missing value, must be assigned to a child
    assertTrue("Split should handle missing values.", partialMultiwaySplit.getMissedRows().isEmpty());
    final TreeNodeNominalCondition[] partialConditions = partialMultiwaySplit.getChildConditions();
    assertEquals("Wrong number of child conditions.", 3, partialConditions.length);
    for (int child = 0; child < expectedValues.length; child++) {
        assertEquals("Wrong value in child condition.", expectedValues[child], partialConditions[child].getValue());
    }
    for (int child = 0; child < expectedValues.length; child++) {
        if (child == missingChild) {
            assertTrue("Missing values should be sent to b", partialConditions[child].acceptsMissings());
        } else {
            assertFalse("Missing values should be sent to b", partialConditions[child].acceptsMissings());
        }
    }
}
Use of org.knime.base.node.mine.treeensemble2.data.TestDataGenerator in project knime-core by knime.
Class TreeNominalColumnDataTest, method testUpdateChildMemberships.
/**
 * Tests
 * {@link TreeNominalColumnData#updateChildMemberships(org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition, DataMemberships)}
 * for binary and multiway nominal conditions, each with and without acceptance of missing values.
 * <p>
 * Row layout of the test column: {@code A} in rows 0-3, {@code B} in rows 4-6, {@code C} in rows 7-9 and missing
 * values in rows 10-11. Every row is a member with weight 1.
 *
 * @throws Exception if the test column cannot be created
 */
@Test
public void testUpdateChildMemberships() throws Exception {
    // in this case it doesn't matter if we use regression or classification (as well as binary and multiway splits)
    final TreeEnsembleLearnerConfiguration config = createConfig(true);
    final TestDataGenerator generator = new TestDataGenerator(config);
    final TreeNominalColumnData col =
        generator.createNominalAttributeColumn("A, A, A, A, B, B, B, C, C, C, ?, ?", "test-col", 0);
    final int rowCount = 12;
    final int[] rowIndices = new int[rowCount];
    final double[] rowWeights = new double[rowCount];
    for (int row = 0; row < rowCount; row++) {
        rowIndices[row] = row;
        rowWeights[row] = 1.0;
    }
    final DataMemberships dataMem = new MockDataColMem(rowIndices, rowIndices, rowWeights);

    // binary mask 2: per the expected rows, the third argument true selects the B rows,
    // and the fourth argument true additionally admits the missing rows 10-11
    BitSet expected = new BitSet(rowCount);
    expected.set(4, 7);
    assertEquals("The produced BitSet is incorrect.", expected, col.updateChildMemberships(
        new TreeNodeNominalBinaryCondition(col.getMetaData(), BigInteger.valueOf(2), true, false), dataMem));

    expected = new BitSet(rowCount);
    expected.set(4, 7);
    expected.set(10, 12);
    assertEquals("The produced BitSet is incorrect.", expected, col.updateChildMemberships(
        new TreeNodeNominalBinaryCondition(col.getMetaData(), BigInteger.valueOf(2), true, true), dataMem));

    // binary mask 2 with third argument false: everything except B (A and C rows)
    expected = new BitSet(rowCount);
    expected.set(0, 4);
    expected.set(7, 10);
    assertEquals("The produced BitSet is incorrect.", expected, col.updateChildMemberships(
        new TreeNodeNominalBinaryCondition(col.getMetaData(), BigInteger.valueOf(2), false, false), dataMem));

    expected = new BitSet(rowCount);
    expected.set(0, 4);
    expected.set(7, 12);
    assertEquals("The produced BitSet is incorrect.", expected, col.updateChildMemberships(
        new TreeNodeNominalBinaryCondition(col.getMetaData(), BigInteger.valueOf(2), false, true), dataMem));

    // binary mask 5 with third argument true: the A and C rows
    expected = new BitSet(rowCount);
    expected.set(0, 4);
    expected.set(7, 10);
    assertEquals("The produced BitSet is incorrect.", expected, col.updateChildMemberships(
        new TreeNodeNominalBinaryCondition(col.getMetaData(), BigInteger.valueOf(5), true, false), dataMem));

    expected = new BitSet(rowCount);
    expected.set(0, 4);
    expected.set(7, 12);
    assertEquals("The produced BitSet is incorrect.", expected, col.updateChildMemberships(
        new TreeNodeNominalBinaryCondition(col.getMetaData(), BigInteger.valueOf(5), true, true), dataMem));

    // multiway condition on value index 0: the A rows (plus missing rows if accepted)
    expected = new BitSet(rowCount);
    expected.set(0, 4);
    assertEquals("The produced BitSet is incorrect.", expected,
        col.updateChildMemberships(new TreeNodeNominalCondition(col.getMetaData(), 0, false), dataMem));

    expected = new BitSet(rowCount);
    expected.set(0, 4);
    expected.set(10, 12);
    assertEquals("The produced BitSet is incorrect.", expected,
        col.updateChildMemberships(new TreeNodeNominalCondition(col.getMetaData(), 0, true), dataMem));

    // multiway condition on value index 2: the C rows (plus missing rows if accepted)
    expected = new BitSet(rowCount);
    expected.set(7, 10);
    assertEquals("The produced BitSet is incorrect.", expected,
        col.updateChildMemberships(new TreeNodeNominalCondition(col.getMetaData(), 2, false), dataMem));

    expected = new BitSet(rowCount);
    expected.set(7, 12);
    assertEquals("The produced BitSet is incorrect.", expected,
        col.updateChildMemberships(new TreeNodeNominalCondition(col.getMetaData(), 2, true), dataMem));
}
Use of org.knime.base.node.mine.treeensemble2.data.TestDataGenerator in project knime-core by knime.
Class TreeTargetNominalColumnDataTest, method testGetDistribution.
/**
 * Tests {@link TreeTargetNominalColumnData#getDistribution(DataMemberships, TreeEnsembleLearnerConfiguration)}
 * and {@link TreeTargetNominalColumnData#getDistribution(double[], TreeEnsembleLearnerConfiguration)}.
 * <p>
 * For each split criterion, both overloads must report the class counts {4, 3} of the target
 * {@code A,A,A,B,B,B,A}, together with the matching prior impurity. For Gini this is the Gini index 24/49. For
 * information gain and information gain ratio it is the same entropy value.
 *
 * @throws InvalidSettingsException if setting up the configuration or the test data fails
 */
@Test
public void testGetDistribution() throws InvalidSettingsException {
    final TreeEnsembleLearnerConfiguration config = new TreeEnsembleLearnerConfiguration(false);
    final TestDataGenerator generator = new TestDataGenerator(config);
    final TreeTargetNominalColumnData target = TestDataGenerator.createNominalTargetColumn("A,A,A,B,B,B,A");
    final TreeNumericColumnData attribute = generator.createNumericAttributeColumn("1,2,3,4,5,6,7", "test-col", 0);
    final TreeData data = new TreeData(new TreeAttributeColumnData[]{attribute}, target, TreeType.Ordinary);
    final double[] weights = new double[7];
    Arrays.fill(weights, 1.0);
    final DataMemberships rootMemberships =
        new RootDataMemberships(weights, data, new DefaultDataIndexManager(data));

    final double[] expectedDistribution = {4.0, 3.0};
    final SplitCriterion[] criteria =
        {SplitCriterion.Gini, SplitCriterion.InformationGain, SplitCriterion.InformationGainRatio};
    // information gain ratio shares its prior impurity (the entropy) with information gain
    final double[] expectedImpurities = {0.4897959184, 0.985228136, 0.985228136};

    for (int c = 0; c < criteria.length; c++) {
        config.setSplitCriterion(criteria[c]);
        final ClassificationPriors fromMemberships = target.getDistribution(rootMemberships, config);
        assertEquals(expectedImpurities[c], fromMemberships.getPriorImpurity(), DELTA);
        assertArrayEquals(expectedDistribution, fromMemberships.getDistribution(), DELTA);
        final ClassificationPriors fromWeights = target.getDistribution(weights, config);
        assertEquals(expectedImpurities[c], fromWeights.getPriorImpurity(), DELTA);
        assertArrayEquals(expectedDistribution, fromWeights.getDistribution(), DELTA);
    }
}
Use of org.knime.base.node.mine.treeensemble2.data.TestDataGenerator in project knime-core by knime.
Class TreeNodeNumericConditionTest, method testTestCondition.
/**
 * Tests the
 * {@link TreeNodeNumericCondition#testCondition(org.knime.base.node.mine.treeensemble2.data.PredictorRecord)}
 * method for the {@code LessThanOrEqual} and {@code LargerThan} operators, each with and without acceptance of
 * missing values. Each section checks values below, at, and above the split point, as well as a missing value.
 * <p>
 * A single {@link PredictorRecord} is created over {@code map}. The map is cleared and refilled before every check,
 * so the test relies on the record reading the map it was constructed with rather than a copy of it.
 *
 * @throws Exception if the test column cannot be created
 */
@Test
public void testTestCondition() throws Exception {
    final TreeEnsembleLearnerConfiguration config = new TreeEnsembleLearnerConfiguration(false);
    final TestDataGenerator dataGen = new TestDataGenerator(config);
    final TreeNumericColumnData col = dataGen.createNumericAttributeColumn("1,2,3,4,4,5,6,7", "testCol", 0);

    // value <= 3, missing values rejected
    TreeNodeNumericCondition cond =
        new TreeNodeNumericCondition(col.getMetaData(), 3, NumericOperator.LessThanOrEqual, false);
    final Map<String, Object> map = Maps.newHashMap();
    final String colName = col.getMetaData().getAttributeName();
    map.put(colName, 2.5);
    final PredictorRecord record = new PredictorRecord(map);
    assertTrue("2.5 was falsely rejected.", cond.testCondition(record));
    map.clear();
    map.put(colName, 3);
    assertTrue("3 was falsely rejected.", cond.testCondition(record));
    map.clear();
    map.put(colName, 4);
    assertFalse("4 was falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, PredictorRecord.NULL);
    assertFalse("Missing values were falsely accepted.", cond.testCondition(record));

    // value <= 3, missing values accepted
    cond = new TreeNodeNumericCondition(col.getMetaData(), 3, NumericOperator.LessThanOrEqual, true);
    map.clear();
    map.put(colName, 2.5);
    assertTrue("2.5 was falsely rejected.", cond.testCondition(record));
    map.clear();
    map.put(colName, 3);
    assertTrue("3 was falsely rejected.", cond.testCondition(record));
    map.clear();
    map.put(colName, 4);
    assertFalse("4 was falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, PredictorRecord.NULL);
    assertTrue("Missing values were falsely rejected.", cond.testCondition(record));

    // value > 4, missing values rejected
    cond = new TreeNodeNumericCondition(col.getMetaData(), 4, NumericOperator.LargerThan, false);
    map.clear();
    map.put(colName, 2.5);
    assertFalse("2.5 was falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, 3);
    assertFalse("3 was falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, 4);
    assertFalse("4 was falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, 4.01);
    assertTrue("4.01 was falsely rejected.", cond.testCondition(record));
    map.clear();
    map.put(colName, PredictorRecord.NULL);
    assertFalse("Missing values were falsely accepted.", cond.testCondition(record));

    // value > 4, missing values accepted
    cond = new TreeNodeNumericCondition(col.getMetaData(), 4, NumericOperator.LargerThan, true);
    map.clear();
    map.put(colName, 2.5);
    assertFalse("2.5 was falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, 3);
    assertFalse("3 was falsely accepted.", cond.testCondition(record));
    // the split point itself must be rejected by a strict comparison, regardless of missing value handling
    map.clear();
    map.put(colName, 4);
    assertFalse("4 was falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, 4.01);
    assertTrue("4.01 was falsely rejected.", cond.testCondition(record));
    map.clear();
    map.put(colName, PredictorRecord.NULL);
    assertTrue("Missing values were falsely rejected.", cond.testCondition(record));
}
Aggregations