use of org.knime.base.node.mine.treeensemble2.data.TestDataGenerator in project knime-core by knime.
the class RootDescendantDataMembershipsTest method testCreateChildDataMemberships.
/**
 * Checks that memberships created via {@code createChildMemberships} - first on the root, then on the resulting
 * child - are {@link BitSetDescendantDataMemberships} instances whose bit set equals the one that was passed in.
 */
@Test
public void testCreateChildDataMemberships() {
    final TreeEnsembleLearnerConfiguration learnerConfig = new TreeEnsembleLearnerConfiguration(false);
    final TestDataGenerator generator = new TestDataGenerator(learnerConfig);
    final TreeData tennis = generator.createTennisData();
    final DefaultDataIndexManager indices = new DefaultDataIndexManager(tennis);
    final int rowCount = tennis.getNrRows();
    final RowSample sample = new DefaultRowSample(rowCount);
    final RootDataMemberships root = new RootDataMemberships(sample, tennis, indices);

    // first generation: bits [0, rowCount / 2) are set
    final BitSet lowerHalf = new BitSet(rowCount);
    lowerHalf.set(0, rowCount / 2);
    final DataMemberships child = root.createChildMemberships(lowerHalf);
    assertThat(child, instanceOf(BitSetDescendantDataMemberships.class));
    assertEquals(lowerHalf, ((BitSetDescendantDataMemberships) child).getBitSet());

    // second generation: bits [0, rowCount / 4) are set, created from the child rather than from the root
    final BitSet lowerQuarter = new BitSet(rowCount);
    lowerQuarter.set(0, rowCount / 4);
    final DataMemberships grandChild = child.createChildMemberships(lowerQuarter);
    assertThat(grandChild, instanceOf(BitSetDescendantDataMemberships.class));
    assertEquals(lowerQuarter, ((BitSetDescendantDataMemberships) grandChild).getBitSet());
}
use of org.knime.base.node.mine.treeensemble2.data.TestDataGenerator in project knime-core by knime.
the class RootDescendantDataMembershipsTest method testGetColumnMemberships.
/**
 * Checks the {@link ColumnMemberships} obtained for column 0, first from the root memberships and then from a child
 * that is created with a bit set covering the second half of the rows.
 * <p>
 * For every element delivered by {@code next()} the original index, the index in the column and the index in the
 * data memberships are compared against hard-coded expectations, and the number of delivered elements must match
 * the number of expected entries.
 */
@Test
public void testGetColumnMemberships() {
    TreeEnsembleLearnerConfiguration config = new TreeEnsembleLearnerConfiguration(false);
    TestDataGenerator dataGen = new TestDataGenerator(config);
    TreeData data = dataGen.createTennisData();
    DefaultDataIndexManager indexManager = new DefaultDataIndexManager(data);
    int nrRows = data.getNrRows();
    RowSample rowSample = new DefaultRowSample(nrRows);
    RootDataMemberships rootMemberships = new RootDataMemberships(rowSample, data, indexManager);

    // --- root level ---
    ColumnMemberships rootColMem = rootMemberships.getColumnMemberships(0);
    assertThat(rootColMem, instanceOf(IntArrayColumnMemberships.class));
    assertEquals(nrRows, rootColMem.size());
    int[] expectedOriginalIndices = new int[] { 0, 1, 7, 8, 10, 2, 6, 11, 12, 3, 4, 5, 9, 13 };
    int rootVisited = 0;
    for (; rootColMem.next(); rootVisited++) {
        // in this case originalIndex and indexInDataMemberships are the same,
        // so both accessors are checked against the same expectation
        assertEquals(expectedOriginalIndices[rootVisited], rootColMem.getOriginalIndex());
        assertEquals(expectedOriginalIndices[rootVisited], rootColMem.getIndexInDataMemberships());
        assertEquals(rootVisited, rootColMem.getIndexInColumn());
    }
    // guard against the loop passing vacuously when next() stops early
    assertEquals(expectedOriginalIndices.length, rootVisited);

    // --- child level: bits [nrRows / 2, nrRows) are set ---
    BitSet lastHalf = new BitSet(nrRows);
    lastHalf.set(nrRows / 2, nrRows);
    DataMemberships lastHalfChild = rootMemberships.createChildMemberships(lastHalf);
    ColumnMemberships childColMem = lastHalfChild.getColumnMemberships(0);
    assertThat(childColMem, instanceOf(DescendantColumnMemberships.class));
    assertEquals(nrRows / 2, childColMem.size());
    expectedOriginalIndices = new int[] { 7, 8, 10, 11, 12, 9, 13 };
    int[] expectedIndexInColumn = new int[] { 2, 3, 4, 7, 8, 12, 13 };
    int[] expectedIndexInDataMemberships = new int[] { 7, 8, 10, 11, 12, 9, 13 };
    int childVisited = 0;
    for (; childColMem.next(); childVisited++) {
        assertEquals(expectedOriginalIndices[childVisited], childColMem.getOriginalIndex());
        assertEquals(expectedIndexInColumn[childVisited], childColMem.getIndexInColumn());
        assertEquals(expectedIndexInDataMemberships[childVisited], childColMem.getIndexInDataMemberships());
    }
    // guard against the loop passing vacuously when next() stops early
    assertEquals(expectedOriginalIndices.length, childVisited);
}
use of org.knime.base.node.mine.treeensemble2.data.TestDataGenerator in project knime-core by knime.
the class TreeNodeNominalConditionTest method testToPMMLPredicate.
/**
 * This method tests the {@link TreeNodeNominalCondition#toPMMLPredicate()} method.
 * <p>
 * Two conditions are built on a nominal column created from "A,A,B,C,C,D": one with value index 3 and the boolean
 * constructor argument {@code false}, which is expected to translate into a simple EQUAL predicate on "D", and one
 * with value index 0 and the boolean argument {@code true}, which is expected to translate into an OR compound of
 * an EQUAL predicate on "A" and an IS_MISSING predicate.
 *
 * @throws Exception if anything unexpected goes wrong while running the test
 */
@Test
public void testToPMMLPredicate() throws Exception {
    final TreeEnsembleLearnerConfiguration learnerConfig = new TreeEnsembleLearnerConfiguration(false);
    final TestDataGenerator generator = new TestDataGenerator(learnerConfig);
    final TreeNominalColumnData col = generator.createNominalAttributeColumn("A,A,B,C,C,D", "testcol", 0);

    // scenario 1: value index 3, boolean flag false -> plain "equals D"
    final TreeNodeNominalCondition plainCondition = new TreeNodeNominalCondition(col.getMetaData(), 3, false);
    final PMMLPredicate plainPredicate = plainCondition.toPMMLPredicate();
    assertThat(plainPredicate, instanceOf(PMMLSimplePredicate.class));
    final PMMLSimplePredicate equalsD = (PMMLSimplePredicate) plainPredicate;
    assertEquals("Wrong operator", PMMLOperator.EQUAL, equalsD.getOperator());
    assertEquals("Wrong split value", "D", equalsD.getThreshold());

    // scenario 2: value index 0, boolean flag true -> "equals A OR isMissing"
    final TreeNodeNominalCondition missingCondition = new TreeNodeNominalCondition(col.getMetaData(), 0, true);
    final PMMLPredicate compoundCandidate = missingCondition.toPMMLPredicate();
    assertThat(compoundCandidate, instanceOf(PMMLCompoundPredicate.class));
    final PMMLCompoundPredicate compound = (PMMLCompoundPredicate) compoundCandidate;
    assertEquals("Wrong boolean operator.", PMMLBooleanOperator.OR, compound.getBooleanOperator());
    final List<PMMLPredicate> parts = compound.getPredicates();
    assertEquals("Wrong number of predicates in compound predicate.", 2, parts.size());

    // first part of the compound: the value test
    assertThat(parts.get(0), instanceOf(PMMLSimplePredicate.class));
    final PMMLSimplePredicate equalsA = (PMMLSimplePredicate) parts.get(0);
    assertEquals("Wrong operator", PMMLOperator.EQUAL, equalsA.getOperator());
    assertEquals("Wrong split value", "A", equalsA.getThreshold());
    assertEquals("Wrong attribute.", col.getMetaData().getAttributeName(), equalsA.getSplitAttribute());

    // second part of the compound: the missing-value test
    assertThat(parts.get(1), instanceOf(PMMLSimplePredicate.class));
    final PMMLSimplePredicate isMissing = (PMMLSimplePredicate) parts.get(1);
    assertEquals("Should be isMissing", PMMLOperator.IS_MISSING, isMissing.getOperator());
}
use of org.knime.base.node.mine.treeensemble2.data.TestDataGenerator in project knime-core by knime.
the class TreeNodeNominalConditionTest method testTestCondition.
/**
 * This method tests the
 * {@link TreeNodeNominalCondition#testCondition(org.knime.base.node.mine.treeensemble2.data.PredictorRecord)}
 * method.
 * <p>
 * A single {@link PredictorRecord} is created on top of a map; the map is then cleared and refilled before each
 * assertion, so the test relies on the record seeing the later changes to that map.
 *
 * @throws Exception if anything unexpected goes wrong while running the test
 */
@Test
public void testTestCondition() throws Exception {
    final TreeEnsembleLearnerConfiguration learnerConfig = new TreeEnsembleLearnerConfiguration(false);
    final TestDataGenerator generator = new TestDataGenerator(learnerConfig);
    final TreeNominalColumnData col = generator.createNominalAttributeColumn("A,A,B,C,C,D", "testcol", 0);

    // scenario 1: value index 3, boolean flag false -> only D is expected to pass
    final TreeNodeNominalCondition onlyD = new TreeNodeNominalCondition(col.getMetaData(), 3, false);
    final Map<String, Object> values = Maps.newHashMap();
    final String attribute = col.getMetaData().getAttributeName();
    values.put(attribute, 0);
    final PredictorRecord probe = new PredictorRecord(values);
    assertFalse("The value A was falsely accepted", onlyD.testCondition(probe));

    values.clear();
    values.put(attribute, 1);
    assertFalse("The value B was falsely accepted", onlyD.testCondition(probe));

    values.clear();
    values.put(attribute, 2);
    assertFalse("The value C was falsely accepted", onlyD.testCondition(probe));

    values.clear();
    values.put(attribute, 3);
    assertTrue("The value D was falsely rejected", onlyD.testCondition(probe));

    values.clear();
    values.put(attribute, PredictorRecord.NULL);
    assertFalse("Missing values were falsely accepted", onlyD.testCondition(probe));

    // scenario 2: value index 0, boolean flag true -> A and missing values are expected to pass
    final TreeNodeNominalCondition aOrMissing = new TreeNodeNominalCondition(col.getMetaData(), 0, true);
    values.clear();
    values.put(attribute, 0);
    assertTrue("The value A was falsely rejected", aOrMissing.testCondition(probe));

    values.clear();
    values.put(attribute, 1);
    assertFalse("The value B was falsely accepted", aOrMissing.testCondition(probe));

    values.clear();
    values.put(attribute, 2);
    assertFalse("The value C was falsely accepted", aOrMissing.testCondition(probe));

    values.clear();
    values.put(attribute, 3);
    assertFalse("The value D was falsely accepted", aOrMissing.testCondition(probe));

    values.clear();
    values.put(attribute, PredictorRecord.NULL);
    assertTrue("Missing values were falsely rejected", aOrMissing.testCondition(probe));
}
use of org.knime.base.node.mine.treeensemble2.data.TestDataGenerator in project knime-core by knime.
the class TreeNominalColumnDataTest method testCalcBestSplitCassificationBinaryTwoClassXGBoostMissingValue1.
/**
 * Tests the XGBoost missing value handling variant, where for each split it is tried which direction for missing
 * values provides the better gain.
 * <p>
 * The split is searched on a nominal column that contains missing values ("?"); only the target column is taken
 * from the two-class tennis data.
 *
 * @throws Exception if anything unexpected goes wrong while running the test
 */
@Test
public void testCalcBestSplitCassificationBinaryTwoClassXGBoostMissingValue1() throws Exception {
    final TreeEnsembleLearnerConfiguration config = createConfig(false);
    config.setMissingValueHandling(MissingValueHandling.XGBoost);
    final TestDataGenerator generator = new TestDataGenerator(config);

    // the two-class tennis data only contributes the target column
    final Pair<TreeNominalColumnData, TreeTargetNominalColumnData> tennis = twoClassTennisData(config);
    final String csvWithMissings = "S,?,O,R,S,R,S,?,O,?";
    final TreeNominalColumnData attribute =
        generator.createNominalAttributeColumn(csvWithMissings, "column containing missing values", 0);
    final TreeTargetNominalColumnData target = tennis.getSecond();

    // every row gets weight 1
    final double[] weights = new double[TWO_CLASS_INDICES.length];
    Arrays.fill(weights, 1.0);

    // based on the ordering in the columnData
    final int[] originalIndices = { 0, 4, 6, 2, 8, 3, 5, 1, 7, 9 };
    final int[] columnIndices = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
    final DataMemberships memberships = new MockDataColMem(originalIndices, columnIndices, weights);

    final SplitCandidate candidate = attribute.calcBestSplitClassification(memberships,
        target.getDistribution(weights, config), target, TestDataGenerator.createRandomData());
    assertThat(candidate, instanceOf(NominalBinarySplitCandidate.class));
    final NominalBinarySplitCandidate binarySplit = (NominalBinarySplitCandidate) candidate;
    final TreeNodeNominalBinaryCondition[] children = binarySplit.getChildConditions();
    assertEquals("Wrong gain value.", 0.18, binarySplit.getGainValue(), 1e-8);

    // both children are expected to report the same value set, with opposite set logic
    final String[] expectedValues = { "S", "R" };
    assertArrayEquals("Values in nominal condition did not match", expectedValues, children[0].getValues());
    assertArrayEquals("Values in nominal condition did not match", expectedValues, children[1].getValues());
    assertEquals("Wrong set logic.", SetLogic.IS_NOT_IN, children[0].getSetLogic());
    assertEquals("Wrong set logic.", SetLogic.IS_IN, children[1].getSetLogic());

    // missing values are expected to go to the first child only
    assertTrue("Missing values are not sent to the correct child.", children[0].acceptsMissings());
    assertFalse("Missing values are not sent to the correct child.", children[1].acceptsMissings());
}
Aggregations