Usage of `org.knime.base.node.mine.treeensemble2.data.TestDataGenerator` in project knime-core by KNIME.
Class `TreeNodeNominalBinaryConditionTest`, method `testTestCondition`.
/**
 * Tests the
 * {@link TreeNodeNominalBinaryCondition#testCondition(org.knime.base.node.mine.treeensemble2.data.PredictorRecord)}
 * method.
 * <p>
 * The nominal column "A,A,B,C,C,D" is queried with the value indices 0 to 3, which the assertion messages below
 * identify as A, B, C and D, plus {@link PredictorRecord#NULL} for a missing value. Four conditions are checked:
 * <ul>
 * <li>mask 1 (A), flags {@code true, false}: only A is accepted, missing values are rejected;</li>
 * <li>mask 5 = 0b101 (A and C), flags {@code true, true}: A, C and missing values are accepted, B and D are
 * rejected;</li>
 * <li>mask 5, flags {@code false, true}: B, D and missing values are accepted, A and C are rejected;</li>
 * <li>mask 5, flags {@code false, false}: B and D are accepted, A, C and missing values are rejected.</li>
 * </ul>
 * Judging by these expectations, the first flag selects whether the mask is the accepted set ({@code true}) or its
 * complement ({@code false}), and the second flag whether missing values are accepted. NOTE(review): confirm
 * against the {@link TreeNodeNominalBinaryCondition} constructor.
 *
 * @throws Exception propagated from the test-data setup or the condition under test
 */
@Test
public void testTestCondition() throws Exception {
    final TreeEnsembleLearnerConfiguration config = new TreeEnsembleLearnerConfiguration(false);
    final TestDataGenerator dataGen = new TestDataGenerator(config);
    final TreeNominalColumnData col = dataGen.createNominalAttributeColumn("A,A,B,C,C,D", "testcol", 0);
    // mask 1 selects only A; accepted set, missing values rejected
    TreeNodeNominalBinaryCondition cond = new TreeNodeNominalBinaryCondition(col.getMetaData(), BigInteger.valueOf(1), true, false);
    final Map<String, Object> map = Maps.newHashMap();
    final String colName = col.getMetaData().getAttributeName();
    map.put(colName, 0);
    // 'record' is created once and every check below only mutates 'map', so this test relies on the record
    // observing later map updates (NOTE(review): presumably PredictorRecord keeps a reference, not a copy).
    PredictorRecord record = new PredictorRecord(map);
    assertTrue("The value A was not accepted but should have been.", cond.testCondition(record));
    map.clear();
    map.put(colName, 1);
    assertFalse("The value B was falsely accepted", cond.testCondition(record));
    map.clear();
    map.put(colName, 2);
    assertFalse("The value C was falsely accepted", cond.testCondition(record));
    map.clear();
    map.put(colName, 3);
    assertFalse("The value D was falsely accepted", cond.testCondition(record));
    map.clear();
    map.put(colName, PredictorRecord.NULL);
    assertFalse("The condition falsely accepted missing values", cond.testCondition(record));
    // mask 5 = 0b101 selects A and C; accepted set, missing values accepted
    cond = new TreeNodeNominalBinaryCondition(col.getMetaData(), BigInteger.valueOf(5), true, true);
    map.clear();
    map.put(colName, 0);
    assertTrue("The value A was falsely rejected.", cond.testCondition(record));
    map.clear();
    map.put(colName, 2);
    assertTrue("The value C was falsely rejected.", cond.testCondition(record));
    map.clear();
    map.put(colName, PredictorRecord.NULL);
    assertTrue("Missing values were falsely rejected.", cond.testCondition(record));
    map.clear();
    map.put(colName, 1);
    assertFalse("The value B was falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, 3);
    // index 3 is D (the original message wrongly named B here)
    assertFalse("The value D was falsely accepted.", cond.testCondition(record));
    // same mask, but the complement (B and D) is accepted; missing values accepted
    cond = new TreeNodeNominalBinaryCondition(col.getMetaData(), BigInteger.valueOf(5), false, true);
    map.clear();
    map.put(colName, 0);
    assertFalse("The value A was falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, 2);
    assertFalse("The value C was falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, PredictorRecord.NULL);
    assertTrue("Missing values were falsely rejected.", cond.testCondition(record));
    map.clear();
    map.put(colName, 1);
    assertTrue("The value B was falsely rejected.", cond.testCondition(record));
    map.clear();
    map.put(colName, 3);
    assertTrue("The value D was falsely rejected.", cond.testCondition(record));
    // complement (B and D) accepted; missing values rejected
    cond = new TreeNodeNominalBinaryCondition(col.getMetaData(), BigInteger.valueOf(5), false, false);
    map.clear();
    map.put(colName, 0);
    assertFalse("The value A was falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, 2);
    assertFalse("The value C was falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, PredictorRecord.NULL);
    assertFalse("Missing values were falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, 1);
    assertTrue("The value B was falsely rejected.", cond.testCondition(record));
    map.clear();
    map.put(colName, 3);
    assertTrue("The value D was falsely rejected.", cond.testCondition(record));
}
Usage of `org.knime.base.node.mine.treeensemble2.data.TestDataGenerator` in project knime-core by KNIME.
Class `TreeNumericColumnDataTest`, method `testUpdateChildMemberships`.
/**
 * Verifies {@link TreeNumericColumnData#updateChildMemberships(TreeNodeCondition, DataMemberships)} for
 * less-than-or-equal and larger-than conditions. It uses one column without missing values and one column whose
 * last row is NaN. The latter is checked both with and without missing values being routed into the child.
 * All seven rows have weight 1.
 *
 * @throws Exception propagated from the test-data setup or the method under test
 */
@Test
public void testUpdateChildMemberships() throws Exception {
    final TestDataGenerator generator = new TestDataGenerator(createConfig());
    final int[] rowIndices = { 0, 1, 2, 3, 4, 5, 6 };
    final double[] rowWeights = new double[rowIndices.length];
    Arrays.fill(rowWeights, 1.0);
    final DataMemberships memberships = new MockDataColMem(rowIndices, rowIndices, rowWeights);

    // --- column without missing values ---
    final TreeNumericColumnData plainCol =
        generator.createNumericAttributeColumn("-50, -3, -2, 2, 25, 100, 101", "noMissings-col", 0);

    // x <= -2 keeps -50, -3, -2 (rows 0..2)
    final BitSet plainAtMost = plainCol.updateChildMemberships(
        new TreeNodeNumericCondition(plainCol.getMetaData(), -2, NumericOperator.LessThanOrEqual, false),
        memberships);
    final BitSet wantPlainAtMost = new BitSet(3);
    wantPlainAtMost.set(0, 3);
    assertEquals("The produced BitSet is incorrect.", wantPlainAtMost, plainAtMost);

    // x > 10 keeps 25, 100, 101 (rows 4..6)
    final BitSet plainAbove = plainCol.updateChildMemberships(
        new TreeNodeNumericCondition(plainCol.getMetaData(), 10, NumericOperator.LargerThan, false), memberships);
    final BitSet wantPlainAbove = new BitSet();
    wantPlainAbove.set(4, 7);
    assertEquals("The produced BitSet is incorrect", wantPlainAbove, plainAbove);

    // --- column whose last row (6) is missing ---
    final TreeNumericColumnData gappyCol =
        generator.createNumericAttributeColumn("-2, 0, 1, 43, 61, 66, NaN", "missings-col", 0);

    // x <= 12, missing values go along: rows 0..2 plus the NaN row 6
    final BitSet atMostWithMissing = gappyCol.updateChildMemberships(
        new TreeNodeNumericCondition(gappyCol.getMetaData(), 12, NumericOperator.LessThanOrEqual, true),
        memberships);
    final BitSet wantAtMostWithMissing = new BitSet();
    wantAtMostWithMissing.set(0, 3);
    wantAtMostWithMissing.set(6);
    assertEquals("The produced BitSet is incorrect", wantAtMostWithMissing, atMostWithMissing);

    // x <= 12, missing values stay out: rows 0..2
    final BitSet atMostWithoutMissing = gappyCol.updateChildMemberships(
        new TreeNodeNumericCondition(gappyCol.getMetaData(), 12, NumericOperator.LessThanOrEqual, false),
        memberships);
    final BitSet wantAtMostWithoutMissing = new BitSet();
    wantAtMostWithoutMissing.set(0, 3);
    assertEquals("The produced BitSet is incorrect", wantAtMostWithoutMissing, atMostWithoutMissing);

    // x > 43, missing values go along: 61, 66 and the NaN row (rows 4..6)
    final BitSet aboveWithMissing = gappyCol.updateChildMemberships(
        new TreeNodeNumericCondition(gappyCol.getMetaData(), 43, NumericOperator.LargerThan, true), memberships);
    final BitSet wantAboveWithMissing = new BitSet();
    wantAboveWithMissing.set(4, 7);
    assertEquals("The produced BitSet is incorrect", wantAboveWithMissing, aboveWithMissing);

    // x > 12, missing values stay out: 43, 61, 66 (rows 3..5)
    final BitSet aboveWithoutMissing = gappyCol.updateChildMemberships(
        new TreeNodeNumericCondition(gappyCol.getMetaData(), 12, NumericOperator.LargerThan, false), memberships);
    final BitSet wantAboveWithoutMissing = new BitSet();
    wantAboveWithoutMissing.set(3, 6);
    assertEquals("The produced BitSet is incorrect", wantAboveWithoutMissing, aboveWithoutMissing);
}
Usage of `org.knime.base.node.mine.treeensemble2.data.TestDataGenerator` in project knime-core by KNIME.
Class `TreeNumericColumnDataTest`, method `testXGBoostMissingValueHandling`.
/**
 * Checks the child conditions of the best classification split on a numeric column when missing values are handled
 * XGBoost-style ({@link MissingValueHandling#XGBoost}).
 * <p>
 * The column "1,2,2,3,4,5,6,7,NaN,NaN" is split against two targets that differ only in the class of the two
 * missing rows. Both splits are expected at 3.5 with gain 0.48. When the missing rows share the left side's class
 * (A), the left child must accept missing values and the right child must not. When they share the right side's
 * class (B), it is the other way round.
 *
 * @throws Exception propagated from the test-data setup or the split computation
 */
@Test
public void testXGBoostMissingValueHandling() throws Exception {
    final TreeEnsembleLearnerConfiguration config = createConfig();
    config.setMissingValueHandling(MissingValueHandling.XGBoost);
    final TestDataGenerator generator = new TestDataGenerator(config);
    final RandomData random = config.createRandomData();

    final int rowCount = 10;
    final int[] rows = new int[rowCount];
    for (int r = 0; r < rowCount; r++) {
        rows[r] = r;
    }
    final double[] rowWeights = new double[rowCount];
    Arrays.fill(rowWeights, 1.0);
    final MockDataColMem memberships = new MockDataColMem(rows, rows, rowWeights);

    // 0.48 = 1 - 0.6^2 - 0.4^2, i.e. a 6:4 class mix split into two pure children
    // (NOTE(review): presumably Gini is the criterion from createConfig() - confirm)
    final double wantedGain = 0.48;
    final double wantedSplit = 3.5;
    final TreeNumericColumnData column =
        generator.createNumericAttributeColumn("1,2,2,3,4,5,6,7,NaN,NaN", "testCol", 0);

    // Scenario 1: the missing rows are class A, like the rows <= 3.5, so they belong in the left child.
    final TreeTargetNominalColumnData missingsAreA =
        TestDataGenerator.createNominalTargetColumn("A,A,A,A,B,B,B,B,A,A");
    final SplitCandidate splitA = column.calcBestSplitClassification(memberships,
        missingsAreA.getDistribution(rowWeights, config), missingsAreA, random);
    assertEquals("Wrong gain.", wantedGain, splitA.getGainValue(), 1e-8);
    final TreeNodeCondition[] condsA = splitA.getChildConditions();
    final TreeNodeNumericCondition leftA = (TreeNodeNumericCondition) condsA[0];
    assertEquals("Wrong split point.", wantedSplit, leftA.getSplitValue(), 1e-8);
    assertTrue("Missings were not sent in the correct direction.", leftA.acceptsMissings());
    final TreeNodeNumericCondition rightA = (TreeNodeNumericCondition) condsA[1];
    assertEquals("Wrong split point.", wantedSplit, rightA.getSplitValue(), 1e-8);
    assertFalse("Missings were not sent in the correct direction.", rightA.acceptsMissings());

    // Scenario 2: the missing rows are class B, like the rows > 3.5, so they belong in the right child.
    final TreeTargetNominalColumnData missingsAreB =
        TestDataGenerator.createNominalTargetColumn("A,A,A,A,B,B,B,B,B,B");
    final SplitCandidate splitB = column.calcBestSplitClassification(memberships,
        missingsAreB.getDistribution(rowWeights, config), missingsAreB, random);
    assertEquals("Wrong gain.", wantedGain, splitB.getGainValue(), 1e-8);
    final TreeNodeCondition[] condsB = splitB.getChildConditions();
    final TreeNodeNumericCondition leftB = (TreeNodeNumericCondition) condsB[0];
    assertEquals("Wrong split point.", wantedSplit, leftB.getSplitValue(), 1e-8);
    assertFalse("Missings were not sent in the correct direction.", leftB.acceptsMissings());
    final TreeNodeNumericCondition rightB = (TreeNodeNumericCondition) condsB[1];
    assertEquals("Wrong split point.", wantedSplit, rightB.getSplitValue(), 1e-8);
    assertTrue("Missings were not sent in the correct direction.", rightB.acceptsMissings());
}
Usage of `org.knime.base.node.mine.treeensemble2.data.TestDataGenerator` in project knime-core by KNIME.
Class `TreeNumericColumnDataTest`, method `testCalcBestSplitRegression`.
/**
 * Checks the best regression split on the numeric attribute "1,..,10" against the target
 * "1,5,4,4.3,6.5,6.5,4,3,3,4", using one model, all rows, no column sampling and unit weights.
 * <p>
 * The first split over all rows is expected at 1.5. Its left child contains only row 0, so the test continues with
 * the right child (rows 1..9), whose best split is expected at 6.5. The expected gains were calculated externally
 * (OpenOffice Calc).
 *
 * @throws InvalidSettingsException presumably from the configuration setters
 */
@Test
public void testCalcBestSplitRegression() throws InvalidSettingsException {
    final TreeEnsembleLearnerConfiguration config = new TreeEnsembleLearnerConfiguration(true);
    config.setNrModels(1);
    config.setDataSelectionWithReplacement(false);
    config.setUseDifferentAttributesAtEachNode(false);
    config.setDataFractionPerTree(1.0);
    config.setColumnSamplingMode(ColumnSamplingMode.None);

    final TestDataGenerator generator = new TestDataGenerator(config);
    final RandomData random = config.createRandomData();
    final TreeTargetNumericColumnData target =
        TestDataGenerator.createNumericTargetColumn("1,5,4,4.3,6.5,6.5,4,3,3,4");
    final TreeNumericColumnData attribute =
        generator.createNumericAttributeColumn("1,2,3,4,5,6,7,8,9,10", "test-col", 0);
    final TreeData data = new TreeData(new TreeAttributeColumnData[] { attribute }, target, TreeType.Ordinary);
    final double[] unitWeights = new double[10];
    Arrays.fill(unitWeights, 1.0);
    final DataMemberships rootMem = new RootDataMemberships(unitWeights, data, new DefaultDataIndexManager(data));

    // split over all rows (expected gain calculated via OpenOffice calc)
    final SplitCandidate rootSplit =
        attribute.calcBestSplitRegression(rootMem, target.getPriors(rootMem, config), target, random);
    assertEquals(10.885444, rootSplit.getGainValue(), 1e-5);
    final TreeNodeCondition[] rootChildConds = rootSplit.getChildConditions();
    assertEquals(2, rootChildConds.length);
    for (final TreeNodeCondition childCond : rootChildConds) {
        assertThat(childCond, instanceOf(TreeNodeNumericCondition.class));
        assertEquals(1.5, ((TreeNodeNumericCondition) childCond).getSplitValue(), 0);
    }

    // the left child holds only row 0, so descend into the right child (rows 1..9)
    final BitSet rightRows = attribute.updateChildMemberships(rootChildConds[1], rootMem);
    final BitSet expectedRightRows = new BitSet(10);
    expectedRightRows.set(1, 10);
    assertEquals(expectedRightRows, rightRows);

    final DataMemberships rightMem = rootMem.createChildMemberships(rightRows);
    final SplitCandidate rightSplit =
        attribute.calcBestSplitRegression(rightMem, target.getPriors(rightMem, config), target, random);
    assertEquals(6.883555, rightSplit.getGainValue(), 1e-5);
    for (final TreeNodeCondition childCond : rightSplit.getChildConditions()) {
        assertThat(childCond, instanceOf(TreeNodeNumericCondition.class));
        assertEquals(6.5, ((TreeNodeNumericCondition) childCond).getSplitValue(), 0);
    }
}
Usage of `org.knime.base.node.mine.treeensemble2.data.TestDataGenerator` in project knime-core by KNIME.
Class `TreeTargetNumericColumnDataTest`, method `testGetPriors`.
/**
 * Tests {@link TreeTargetNumericColumnData#getPriors(DataMemberships, TreeEnsembleLearnerConfiguration)} on the
 * root memberships of a ten-row data set with unit weights.
 * <p>
 * The expected values are computed by hand from the target "1,4,3,5,6,7,8,12,22,1": the sum is 69, the mean is
 * 6.9, and the sum of squared deviations from the mean is 352.9.
 * <p>
 * NOTE(review): the overload
 * {@link TreeTargetNumericColumnData#getPriors(double[], TreeEnsembleLearnerConfiguration)} is not exercised by
 * this test, although the previous documentation claimed it was. Add coverage for it or leave it documented as
 * uncovered.
 */
@Test
public void testGetPriors() {
    String targetCSV = "1,4,3,5,6,7,8,12,22,1";
    // irrelevant for the priors, but a TreeData object needs at least one attribute column
    String someAttributeCSV = "A,B,A,B,A,A,B,A,A,B";
    TreeEnsembleLearnerConfiguration config = new TreeEnsembleLearnerConfiguration(true);
    TestDataGenerator dataGen = new TestDataGenerator(config);
    TreeTargetNumericColumnData target = TestDataGenerator.createNumericTargetColumn(targetCSV);
    TreeNominalColumnData attribute = dataGen.createNominalAttributeColumn(someAttributeCSV, "test-col", 0);
    TreeData data = new TreeData(new TreeAttributeColumnData[] { attribute }, target, TreeType.Ordinary);
    double[] weights = new double[10];
    Arrays.fill(weights, 1.0);
    DataMemberships rootMem = new RootDataMemberships(weights, data, new DefaultDataIndexManager(data));
    RegressionPriors dataMemPriors = target.getPriors(rootMem, config);
    assertEquals(6.9, dataMemPriors.getMean(), DELTA);
    assertEquals(69, dataMemPriors.getYSum(), DELTA);
    assertEquals(352.9, dataMemPriors.getSumSquaredDeviation(), DELTA);
}
Aggregations