Use of org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration in project knime-core by knime.
Class TreeNominalColumnDataTest, method testCalcBestSplitClassificationBinaryTwoClass.
/**
 * Checks
 * {@link TreeNominalColumnData#calcBestSplitClassification(DataMemberships, ClassificationPriors, TreeTargetNominalColumnData, RandomData)}
 * on the two class tennis data, with the configuration's missing value handling set to
 * {@link MissingValueHandling#Surrogate}.
 *
 * <p>
 * The expected result is a binary split on the single value "R" (first child: not in the set, second child: in the
 * set) where neither child accepts missing values.
 * </p>
 *
 * @throws Exception not expected; any exception makes the test fail
 */
@Test
public void testCalcBestSplitClassificationBinaryTwoClass() throws Exception {
    final TreeEnsembleLearnerConfiguration config = createConfig(false);
    config.setMissingValueHandling(MissingValueHandling.Surrogate);

    final Pair<TreeNominalColumnData, TreeTargetNominalColumnData> columnAndTarget = twoClassTennisData(config);
    final TreeNominalColumnData nominalColumn = columnAndTarget.getFirst();
    final TreeTargetNominalColumnData target = columnAndTarget.getSecond();
    final TreeData treeData = twoClassTennisTreeData(config);
    final IDataIndexManager indexManager = new DefaultDataIndexManager(treeData);
    assertEquals(SplitCriterion.Gini, config.getSplitCriterion());

    // every row gets a weight of one
    final double[] weights = new double[TWO_CLASS_INDICES.length];
    Arrays.fill(weights, 1.0);
    // root memberships are built from the tree data itself rather than from mock memberships
    final DataMemberships rootMemberships = new RootDataMemberships(weights, treeData, indexManager);
    final ClassificationPriors priors = target.getDistribution(weights, config);

    // the split is calculated without a RandomData instance
    final SplitCandidate candidate =
        nominalColumn.calcBestSplitClassification(rootMemberships, priors, target, null);
    assertNotNull(candidate);
    assertThat(candidate, instanceOf(NominalBinarySplitCandidate.class));
    assertTrue(candidate.canColumnBeSplitFurther());
    // the expected gain was calculated by hand in OpenOffice Calc
    assertEquals(0.1371428, candidate.getGainValue(), 0.00001);

    final NominalBinarySplitCandidate binaryCandidate = (NominalBinarySplitCandidate)candidate;
    final TreeNodeNominalBinaryCondition[] conditions = binaryCandidate.getChildConditions();
    assertEquals(2, conditions.length);
    final TreeNodeNominalBinaryCondition firstChild = conditions[0];
    final TreeNodeNominalBinaryCondition secondChild = conditions[1];

    // both children refer to the same value set and differ only in their set logic
    final String[] expectedValues = {"R"};
    assertArrayEquals(expectedValues, firstChild.getValues());
    assertArrayEquals(expectedValues, secondChild.getValues());
    assertEquals(SetLogic.IS_NOT_IN, firstChild.getSetLogic());
    assertEquals(SetLogic.IS_IN, secondChild.getSetLogic());
    assertFalse(firstChild.acceptsMissings());
    assertFalse(secondChild.acceptsMissings());
}
Use of org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration in project knime-core by knime.
Class TreeNominalColumnDataTest, method testCalcBestSplitClassificationBinaryPCAXGBoostMissingValueHandling.
/**
 * Tests the XGBoost missing value handling for binary splits that are calculated with the pca method (multiple
 * classes).
 *
 * <p>
 * Two scenarios are checked with the same configuration and the same {@link RandomData} instance:
 * </p>
 * <ol>
 * <li>training data without missing values: the second child (IS_IN) is expected to accept missing values,</li>
 * <li>training data with one additional "?" row: the first child (IS_NOT_IN) is expected to accept missing
 * values.</li>
 * </ol>
 *
 * @throws Exception not expected; any exception makes the test fail
 */
@Test
public void testCalcBestSplitClassificationBinaryPCAXGBoostMissingValueHandling() throws Exception {
    final TreeEnsembleLearnerConfiguration config = createConfig(false);
    config.setMissingValueHandling(MissingValueHandling.XGBoost);
    final TestDataGenerator generator = new TestDataGenerator(config);
    final RandomData randomData = config.createRandomData();
    // both scenarios expect the same value set in the two child conditions
    final String[] expectedValues = {"a", "c"};

    // scenario 1: the training data contains no missing values
    final TreeNominalColumnData completeColumn =
        generator.createNominalAttributeColumn("a, a, a, b, b, b, b, c, c", "noMissings", 0);
    final TreeTargetNominalColumnData completeTarget =
        TestDataGenerator.createNominalTargetColumn("A, B, B, C, C, C, B, A, B");
    final DataMemberships completeMemberships = createMockDataMemberships(completeTarget.getNrRows());
    final SplitCandidate completeSplit = completeColumn.calcBestSplitClassification(completeMemberships,
        completeTarget.getDistribution(completeMemberships, config), completeTarget, randomData);
    assertNotNull("There is a possible split.", completeSplit);
    assertEquals("Incorrect gain.", 0.2086, completeSplit.getGainValue(), 1e-3);
    assertThat(completeSplit, instanceOf(NominalBinarySplitCandidate.class));
    final NominalBinarySplitCandidate completeBinarySplit = (NominalBinarySplitCandidate)completeSplit;
    assertTrue("No missing values in the column.", completeBinarySplit.getMissedRows().isEmpty());
    final TreeNodeNominalBinaryCondition[] completeConditions = completeBinarySplit.getChildConditions();
    assertEquals("A binary split must have 2 child conditions.", 2, completeConditions.length);
    final TreeNodeNominalBinaryCondition completeLeft = completeConditions[0];
    final TreeNodeNominalBinaryCondition completeRight = completeConditions[1];
    assertArrayEquals("Wrong values in child condition.", expectedValues, completeLeft.getValues());
    assertArrayEquals("Wrong values in child condition.", expectedValues, completeRight.getValues());
    assertEquals("Wrong set logic.", SetLogic.IS_NOT_IN, completeLeft.getSetLogic());
    assertEquals("Wrong set logic.", SetLogic.IS_IN, completeRight.getSetLogic());
    assertFalse("Missing values should be sent to the majority child (i.e. right)",
        completeLeft.acceptsMissings());
    assertTrue("Missing values should be sent to the majority child (i.e. right)",
        completeRight.acceptsMissings());

    // scenario 2: the training data contains a missing value (presumably encoded by the trailing "?")
    final TreeNominalColumnData gappyColumn =
        generator.createNominalAttributeColumn("a, a, a, b, b, b, b, c, c, ?", "missings", 0);
    final TreeTargetNominalColumnData gappyTarget =
        TestDataGenerator.createNominalTargetColumn("A, B, B, C, C, C, B, A, B, C");
    final DataMemberships gappyMemberships = createMockDataMemberships(gappyTarget.getNrRows());
    final SplitCandidate gappySplit = gappyColumn.calcBestSplitClassification(gappyMemberships,
        gappyTarget.getDistribution(gappyMemberships, config), gappyTarget, randomData);
    assertNotNull("There is a possible split.", gappySplit);
    assertEquals("Incorrect gain.", 0.24, gappySplit.getGainValue(), 1e-3);
    assertThat(gappySplit, instanceOf(NominalBinarySplitCandidate.class));
    final NominalBinarySplitCandidate gappyBinarySplit = (NominalBinarySplitCandidate)gappySplit;
    assertTrue("Split should handle missing values.", gappyBinarySplit.getMissedRows().isEmpty());
    final TreeNodeNominalBinaryCondition[] gappyConditions = gappyBinarySplit.getChildConditions();
    assertEquals("Wrong number of child conditions.", 2, gappyConditions.length);
    final TreeNodeNominalBinaryCondition gappyLeft = gappyConditions[0];
    final TreeNodeNominalBinaryCondition gappyRight = gappyConditions[1];
    assertArrayEquals("Wrong values in child condition.", expectedValues, gappyLeft.getValues());
    assertArrayEquals("Wrong values in child condition.", expectedValues, gappyRight.getValues());
    assertEquals("Wrong set logic.", SetLogic.IS_NOT_IN, gappyLeft.getSetLogic());
    assertEquals("Wrong set logic.", SetLogic.IS_IN, gappyRight.getSetLogic());
    assertTrue("Missing values should be sent to left child", gappyLeft.acceptsMissings());
    assertFalse("Missing values should be sent to left child", gappyRight.acceptsMissings());
}
Use of org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration in project knime-core by knime.
Class TreeNominalColumnDataTest, method testCalcBestSplitRegressionMultiway.
/**
 * Checks
 * {@link TreeNominalColumnData#calcBestSplitRegression(DataMemberships, RegressionPriors, TreeTargetNumericColumnData, RandomData)}
 * with binary nominal splits switched off, i.e. for multiway splits.
 *
 * <p>
 * The expected result is a multiway split with one child per value, in the order "S", "O", "R", after which the
 * column cannot be split any further.
 * </p>
 *
 * @throws Exception not expected; any exception makes the test fail
 */
@Test
public void testCalcBestSplitRegressionMultiway() throws Exception {
    final TreeEnsembleLearnerConfiguration config = createConfig(true);
    config.setUseBinaryNominalSplits(false);

    final Pair<TreeNominalColumnData, TreeTargetNumericColumnData> columnAndTarget = tennisDataRegression(config);
    final TreeNominalColumnData nominalColumn = columnAndTarget.getFirst();
    final TreeTargetNumericColumnData target = columnAndTarget.getSecond();
    final TreeData treeData = createTreeDataRegression(columnAndTarget);

    // every row gets a weight of one
    final double[] weights = new double[SMALL_COLUMN_DATA.length];
    Arrays.fill(weights, 1.0);
    final DataMemberships rootMemberships =
        new RootDataMemberships(weights, treeData, new DefaultDataIndexManager(treeData));
    final RegressionPriors priors = target.getPriors(weights, config);

    // the split is calculated without a RandomData instance
    final SplitCandidate candidate = nominalColumn.calcBestSplitRegression(rootMemberships, priors, target, null);
    assertNotNull(candidate);
    assertThat(candidate, instanceOf(NominalMultiwaySplitCandidate.class));
    assertFalse(candidate.canColumnBeSplitFurther());
    assertEquals(36.9643, candidate.getGainValue(), 0.0001);

    final NominalMultiwaySplitCandidate multiwayCandidate = (NominalMultiwaySplitCandidate)candidate;
    final TreeNodeNominalCondition[] conditions = multiwayCandidate.getChildConditions();
    // one child condition per expected value, compared position by position
    final String[] expectedValues = {"S", "O", "R"};
    assertEquals(expectedValues.length, conditions.length);
    for (int i = 0; i < expectedValues.length; i++) {
        assertEquals(expectedValues[i], conditions[i].getValue());
    }
}
Use of org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration in project knime-core by knime.
Class TreeNominalColumnDataTest, method testCalcBestSplitRegressionBinary.
/**
 * Checks
 * {@link TreeNominalColumnData#calcBestSplitRegression(DataMemberships, RegressionPriors, TreeTargetNumericColumnData, RandomData)}
 * for binary splits.
 *
 * <p>
 * The expected result is a binary split on the single value "R" (first child: not in the set, second child: in the
 * set) that leaves the column open for further splits.
 * </p>
 *
 * @throws Exception not expected; any exception makes the test fail
 */
@Test
public void testCalcBestSplitRegressionBinary() throws Exception {
    // NOTE(review): the configuration is constructed directly here, whereas the other tests of this class
    // obtain it from createConfig(...) - confirm that this difference is intended
    final TreeEnsembleLearnerConfiguration config = new TreeEnsembleLearnerConfiguration(true);

    final Pair<TreeNominalColumnData, TreeTargetNumericColumnData> columnAndTarget = tennisDataRegression(config);
    final TreeNominalColumnData nominalColumn = columnAndTarget.getFirst();
    final TreeTargetNumericColumnData target = columnAndTarget.getSecond();
    final TreeData treeData = createTreeDataRegression(columnAndTarget);

    // every row gets a weight of one
    final double[] weights = new double[SMALL_COLUMN_DATA.length];
    Arrays.fill(weights, 1.0);
    final DataMemberships rootMemberships =
        new RootDataMemberships(weights, treeData, new DefaultDataIndexManager(treeData));
    final RegressionPriors priors = target.getPriors(weights, config);

    // the split is calculated without a RandomData instance
    final SplitCandidate candidate = nominalColumn.calcBestSplitRegression(rootMemberships, priors, target, null);
    assertNotNull(candidate);
    assertThat(candidate, instanceOf(NominalBinarySplitCandidate.class));
    assertTrue(candidate.canColumnBeSplitFurther());
    assertEquals(32.9143, candidate.getGainValue(), 0.0001);

    final NominalBinarySplitCandidate binaryCandidate = (NominalBinarySplitCandidate)candidate;
    final TreeNodeNominalBinaryCondition[] conditions = binaryCandidate.getChildConditions();
    assertEquals(2, conditions.length);
    final TreeNodeNominalBinaryCondition firstChild = conditions[0];
    final TreeNodeNominalBinaryCondition secondChild = conditions[1];

    // both children refer to the same value set and differ only in their set logic
    final String[] expectedValues = {"R"};
    assertArrayEquals(expectedValues, firstChild.getValues());
    assertArrayEquals(expectedValues, secondChild.getValues());
    assertEquals(SetLogic.IS_NOT_IN, firstChild.getSetLogic());
    assertEquals(SetLogic.IS_IN, secondChild.getSetLogic());
}
Use of org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration in project knime-core by knime.
Class TreeNominalColumnDataTest, method testCalcBestSplitClassificationBinary.
/**
 * Checks
 * {@link TreeNominalColumnData#calcBestSplitClassification(DataMemberships, ClassificationPriors, TreeTargetNominalColumnData, RandomData)}
 * for binary splits, first on the root memberships and then on the memberships of each of the two children.
 *
 * <p>
 * The root is expected to be split on the single value "R". The rows matching the first child condition (not in
 * the set) are expected to allow another binary split, whereas for the rows matching the second child condition
 * (in the set) no split candidate is expected.
 * </p>
 *
 * @throws Exception not expected; any exception makes the test fail
 */
@Test
public void testCalcBestSplitClassificationBinary() throws Exception {
    final double gainTolerance = 0.00001;
    final TreeEnsembleLearnerConfiguration config = createConfig(false);

    final Pair<TreeNominalColumnData, TreeTargetNominalColumnData> columnAndTarget = tennisData(config);
    final TreeNominalColumnData nominalColumn = columnAndTarget.getFirst();
    final TreeTargetNominalColumnData target = columnAndTarget.getSecond();
    assertEquals(SplitCriterion.Gini, config.getSplitCriterion());

    // every row gets a weight of one
    final double[] weights = new double[SMALL_COLUMN_DATA.length];
    Arrays.fill(weights, 1.0);
    final TreeData treeData = tennisTreeData(config);
    final IDataIndexManager indexManager = new DefaultDataIndexManager(treeData);
    final DataMemberships rootMemberships = new RootDataMemberships(weights, treeData, indexManager);
    final ClassificationPriors rootPriors = target.getDistribution(weights, config);

    // split at the root; no RandomData instance is passed to any of the split calculations
    final SplitCandidate rootCandidate =
        nominalColumn.calcBestSplitClassification(rootMemberships, rootPriors, target, null);
    assertNotNull(rootCandidate);
    assertThat(rootCandidate, instanceOf(NominalBinarySplitCandidate.class));
    assertTrue(rootCandidate.canColumnBeSplitFurther());
    // the expected gain was calculated by hand in LibreOffice Calc
    assertEquals(0.0689342404, rootCandidate.getGainValue(), gainTolerance);

    final NominalBinarySplitCandidate rootBinaryCandidate = (NominalBinarySplitCandidate)rootCandidate;
    final TreeNodeNominalBinaryCondition[] conditions = rootBinaryCandidate.getChildConditions();
    assertEquals(2, conditions.length);
    final TreeNodeNominalBinaryCondition firstChild = conditions[0];
    final TreeNodeNominalBinaryCondition secondChild = conditions[1];
    final String[] expectedValues = {"R"};
    assertArrayEquals(expectedValues, firstChild.getValues());
    assertArrayEquals(expectedValues, secondChild.getValues());
    assertEquals(SetLogic.IS_NOT_IN, firstChild.getSetLogic());
    assertEquals(SetLogic.IS_IN, secondChild.getSetLogic());

    // first child: another binary split is expected
    final BitSet inFirstChild = nominalColumn.updateChildMemberships(firstChild, rootMemberships);
    final DataMemberships firstChildMemberships = rootMemberships.createChildMemberships(inFirstChild);
    final ClassificationPriors firstChildPriors = target.getDistribution(firstChildMemberships, config);
    final SplitCandidate firstChildCandidate =
        nominalColumn.calcBestSplitClassification(firstChildMemberships, firstChildPriors, target, null);
    assertNotNull(firstChildCandidate);
    assertThat(firstChildCandidate, instanceOf(NominalBinarySplitCandidate.class));
    // the expected gain was calculated by hand in LibreOffice Calc
    assertEquals(0.0086419753, firstChildCandidate.getGainValue(), gainTolerance);

    // second child: no split candidate is expected
    final BitSet inSecondChild = nominalColumn.updateChildMemberships(secondChild, rootMemberships);
    final DataMemberships secondChildMemberships = rootMemberships.createChildMemberships(inSecondChild);
    final ClassificationPriors secondChildPriors = target.getDistribution(secondChildMemberships, config);
    final SplitCandidate secondChildCandidate =
        nominalColumn.calcBestSplitClassification(secondChildMemberships, secondChildPriors, target, null);
    assertNull(secondChildCandidate);
}
Aggregations