Search in sources :

Example 31 with TreeData

use of org.knime.base.node.mine.treeensemble2.data.TreeData in project knime-core by knime.

From class TreeNominalColumnDataTest, method testCalcBestSplitRegressionMultiway.

/**
 * Verifies
 * {@link TreeNominalColumnData#calcBestSplitRegression(DataMemberships, RegressionPriors, TreeTargetNumericColumnData, RandomData)}
 * with binary nominal splits switched off, so that the tennis column is split multiway (one child per value).
 *
 * @throws Exception if creating the test data or computing the split fails
 */
@Test
public void testCalcBestSplitRegressionMultiway() throws Exception {
    // Regression configuration with binary nominal splits disabled -> multiway splitting.
    final TreeEnsembleLearnerConfiguration config = createConfig(true);
    config.setUseBinaryNominalSplits(false);

    final Pair<TreeNominalColumnData, TreeTargetNumericColumnData> data = tennisDataRegression(config);
    final TreeNominalColumnData attribute = data.getFirst();
    final TreeTargetNumericColumnData target = data.getSecond();
    final TreeData tree = createTreeDataRegression(data);

    // Every row takes part with unit weight.
    final double[] weights = new double[SMALL_COLUMN_DATA.length];
    Arrays.fill(weights, 1.0);
    final DataMemberships root = new RootDataMemberships(weights, tree, new DefaultDataIndexManager(tree));
    final RegressionPriors priors = target.getPriors(weights, config);

    final SplitCandidate candidate = attribute.calcBestSplitRegression(root, priors, target, null);
    assertNotNull(candidate);
    assertThat(candidate, instanceOf(NominalMultiwaySplitCandidate.class));
    // After a multiway split the column is expected to be exhausted.
    assertFalse(candidate.canColumnBeSplitFurther());
    assertEquals(36.9643, candidate.getGainValue(), 0.0001);

    // One child condition per nominal value, in this exact order.
    final TreeNodeNominalCondition[] conditions =
        ((NominalMultiwaySplitCandidate)candidate).getChildConditions();
    final String[] expectedValues = { "S", "O", "R" };
    assertEquals(expectedValues.length, conditions.length);
    for (int i = 0; i < expectedValues.length; i++) {
        assertEquals(expectedValues[i], conditions[i].getValue());
    }
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) TreeNodeNominalCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalCondition) IDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager) NominalMultiwaySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) DefaultDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) Test(org.junit.Test)

Example 32 with TreeData

use of org.knime.base.node.mine.treeensemble2.data.TreeData in project knime-core by knime.

the class TreeNominalColumnDataTest method testCalcBestSplitRegressionBinary.

/**
 * Tests the method
 * {@link TreeNominalColumnData#calcBestSplitRegression(DataMemberships, RegressionPriors, TreeTargetNumericColumnData, RandomData)}
 * using binary splits. The best split is expected to separate the value "R" from all other values.
 *
 * @throws Exception if creating the test data or computing the split fails
 */
@Test
public void testCalcBestSplitRegressionBinary() throws Exception {
    TreeEnsembleLearnerConfiguration config = new TreeEnsembleLearnerConfiguration(true);
    // Make the precondition explicit instead of relying on the configuration default; this mirrors
    // testCalcBestSplitRegressionMultiway, which explicitly switches binary nominal splits off.
    config.setUseBinaryNominalSplits(true);
    Pair<TreeNominalColumnData, TreeTargetNumericColumnData> tennisDataRegression = tennisDataRegression(config);
    TreeNominalColumnData columnData = tennisDataRegression.getFirst();
    TreeTargetNumericColumnData targetData = tennisDataRegression.getSecond();
    TreeData treeData = createTreeDataRegression(tennisDataRegression);
    // unit weight for every row
    double[] rowWeights = new double[SMALL_COLUMN_DATA.length];
    Arrays.fill(rowWeights, 1.0);
    IDataIndexManager indexManager = new DefaultDataIndexManager(treeData);
    DataMemberships dataMemberships = new RootDataMemberships(rowWeights, treeData, indexManager);
    RegressionPriors priors = targetData.getPriors(rowWeights, config);
    SplitCandidate splitCandidate = columnData.calcBestSplitRegression(dataMemberships, priors, targetData, null);
    assertNotNull(splitCandidate);
    assertThat(splitCandidate, instanceOf(NominalBinarySplitCandidate.class));
    // unlike a multiway split, a binary split leaves the column available for further splits
    assertTrue(splitCandidate.canColumnBeSplitFurther());
    assertEquals(32.9143, splitCandidate.getGainValue(), 0.0001);
    NominalBinarySplitCandidate binarySplitCandidate = (NominalBinarySplitCandidate) splitCandidate;
    TreeNodeNominalBinaryCondition[] childConditions = binarySplitCandidate.getChildConditions();
    assertEquals(2, childConditions.length);
    // both children reference the same value set {"R"}, distinguished only by their set logic
    assertArrayEquals(new String[] { "R" }, childConditions[0].getValues());
    assertArrayEquals(new String[] { "R" }, childConditions[1].getValues());
    assertEquals(SetLogic.IS_NOT_IN, childConditions[0].getSetLogic());
    assertEquals(SetLogic.IS_IN, childConditions[1].getSetLogic());
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) IDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager) NominalMultiwaySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) DefaultDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) TreeNodeNominalBinaryCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalBinaryCondition) Test(org.junit.Test)

Example 33 with TreeData

use of org.knime.base.node.mine.treeensemble2.data.TreeData in project knime-core by knime.

the class TreeNominalColumnDataTest method testCalcBestSplitClassificationBinary.

/**
 * Verifies
 * {@link TreeNominalColumnData#calcBestSplitClassification(DataMemberships, ClassificationPriors, TreeTargetNominalColumnData, RandomData)}
 * with binary splits on the tennis data, first at the root and then on both resulting children.
 *
 * @throws Exception if creating the test data or computing a split fails
 */
@Test
public void testCalcBestSplitClassificationBinary() throws Exception {
    final TreeEnsembleLearnerConfiguration config = createConfig(false);
    final Pair<TreeNominalColumnData, TreeTargetNominalColumnData> data = tennisData(config);
    final TreeNominalColumnData attribute = data.getFirst();
    final TreeTargetNominalColumnData target = data.getSecond();
    assertEquals(SplitCriterion.Gini, config.getSplitCriterion());

    final double[] weights = new double[SMALL_COLUMN_DATA.length];
    Arrays.fill(weights, 1.0);
    final TreeData tree = tennisTreeData(config);
    final DataMemberships root = new RootDataMemberships(weights, tree, new DefaultDataIndexManager(tree));
    final ClassificationPriors rootPriors = target.getDistribution(weights, config);

    // --- root split ---
    final SplitCandidate rootCandidate = attribute.calcBestSplitClassification(root, rootPriors, target, null);
    assertNotNull(rootCandidate);
    assertThat(rootCandidate, instanceOf(NominalBinarySplitCandidate.class));
    assertTrue(rootCandidate.canColumnBeSplitFurther());
    // expected gain computed manually in LibreOffice Calc
    assertEquals(0.0689342404, rootCandidate.getGainValue(), 0.00001);
    final TreeNodeNominalBinaryCondition[] conditions =
        ((NominalBinarySplitCandidate)rootCandidate).getChildConditions();
    assertEquals(2, conditions.length);
    assertArrayEquals(new String[] { "R" }, conditions[0].getValues());
    assertArrayEquals(new String[] { "R" }, conditions[1].getValues());
    assertEquals(SetLogic.IS_NOT_IN, conditions[0].getSetLogic());
    assertEquals(SetLogic.IS_IN, conditions[1].getSetLogic());

    // --- child for rows whose value is not "R": another binary split is expected ---
    final BitSet notRRows = attribute.updateChildMemberships(conditions[0], root);
    final DataMemberships notRMemberships = root.createChildMemberships(notRRows);
    final ClassificationPriors notRPriors = target.getDistribution(notRMemberships, config);
    final SplitCandidate notRCandidate =
        attribute.calcBestSplitClassification(notRMemberships, notRPriors, target, null);
    assertNotNull(notRCandidate);
    assertThat(notRCandidate, instanceOf(NominalBinarySplitCandidate.class));
    // expected gain computed manually in LibreOffice Calc
    assertEquals(0.0086419753, notRCandidate.getGainValue(), 0.00001);

    // --- child for rows whose value is "R": all rows share one value, so no split is expected ---
    final BitSet rRows = attribute.updateChildMemberships(conditions[1], root);
    final DataMemberships rMemberships = root.createChildMemberships(rRows);
    final ClassificationPriors rPriors = target.getDistribution(rMemberships, config);
    assertNull(attribute.calcBestSplitClassification(rMemberships, rPriors, target, null));
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) BitSet(java.util.BitSet) IDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager) NominalMultiwaySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) DefaultDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) TreeNodeNominalBinaryCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalBinaryCondition) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) Test(org.junit.Test)

Example 34 with TreeData

use of org.knime.base.node.mine.treeensemble2.data.TreeData in project knime-core by knime.

the class TreeNominalColumnDataTest method testCalcBestSplitClassificationBinaryPCA.

/**
 * Verifies
 * {@link TreeNominalColumnData#calcBestSplitClassification(DataMemberships, ClassificationPriors, TreeTargetNominalColumnData, RandomData)}
 * for binary splits when the target has more than two classes. According to the original test description,
 * this case is handled by the PCA-based algorithm.
 *
 * @throws Exception if creating the test data or computing the split fails
 */
@Test
public void testCalcBestSplitClassificationBinaryPCA() throws Exception {
    final TreeEnsembleLearnerConfiguration config = createConfig(false);
    final Pair<TreeNominalColumnData, TreeTargetNominalColumnData> data = createPCATestData(config);
    final TreeNominalColumnData attribute = data.getFirst();
    final TreeTargetNominalColumnData target = data.getSecond();
    final TreeData tree = createTreeData(data);
    assertEquals(SplitCriterion.Gini, config.getSplitCriterion());

    // One unit weight per target row.
    final double[] weights = new double[target.getNrRows()];
    Arrays.fill(weights, 1.0);
    final DataMemberships root = new RootDataMemberships(weights, tree, new DefaultDataIndexManager(tree));
    final ClassificationPriors priors = target.getDistribution(weights, config);

    final SplitCandidate candidate = attribute.calcBestSplitClassification(root, priors, target, null);
    assertNotNull(candidate);
    assertThat(candidate, instanceOf(NominalBinarySplitCandidate.class));
    assertTrue(candidate.canColumnBeSplitFurther());
    assertEquals(0.0659, candidate.getGainValue(), 0.0001);

    // Expected partition: {"E"} versus all remaining values.
    final TreeNodeNominalBinaryCondition[] conditions =
        ((NominalBinarySplitCandidate)candidate).getChildConditions();
    assertEquals(2, conditions.length);
    final String[] isolatedValues = { "E" };
    assertArrayEquals(isolatedValues, conditions[0].getValues());
    assertArrayEquals(isolatedValues, conditions[1].getValues());
    assertEquals(SetLogic.IS_NOT_IN, conditions[0].getSetLogic());
    assertEquals(SetLogic.IS_IN, conditions[1].getSetLogic());
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) IDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager) NominalMultiwaySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) DefaultDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) TreeNodeNominalBinaryCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalBinaryCondition) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) Test(org.junit.Test)

Example 35 with TreeData

use of org.knime.base.node.mine.treeensemble2.data.TreeData in project knime-core by knime.

the class TreeTargetNominalColumnDataTest method testGetDistribution.

/**
 * Tests the {@link TreeTargetNominalColumnData#getDistribution(DataMemberships, TreeEnsembleLearnerConfiguration)}
 * and {@link TreeTargetNominalColumnData#getDistribution(double[], TreeEnsembleLearnerConfiguration)} methods for
 * the Gini, information gain and information gain ratio split criteria. Both overloads must report the same prior
 * impurity and class distribution.
 *
 * @throws InvalidSettingsException if the configuration rejects one of the settings
 */
@Test
public void testGetDistribution() throws InvalidSettingsException {
    String targetCSV = "A,A,A,B,B,B,A";
    String attributeCSV = "1,2,3,4,5,6,7";
    TreeEnsembleLearnerConfiguration config = new TreeEnsembleLearnerConfiguration(false);
    TestDataGenerator dataGen = new TestDataGenerator(config);
    TreeTargetNominalColumnData target = TestDataGenerator.createNominalTargetColumn(targetCSV);
    TreeNumericColumnData attribute = dataGen.createNumericAttributeColumn(attributeCSV, "test-col", 0);
    TreeData data = new TreeData(new TreeAttributeColumnData[] { attribute }, target, TreeType.Ordinary);
    // Size the weights from the target column rather than hard-coding the row count, so that
    // editing targetCSV cannot silently desynchronize the two.
    double[] weights = new double[target.getNrRows()];
    Arrays.fill(weights, 1.0);
    DataMemberships rootMemberships = new RootDataMemberships(weights, data, new DefaultDataIndexManager(data));
    // the target contains four "A" rows and three "B" rows
    double[] expectedDistribution = new double[] { 4.0, 3.0 };
    // Gini: 1 - (4/7)^2 - (3/7)^2 = 24/49
    config.setSplitCriterion(SplitCriterion.Gini);
    double expectedGini = 0.4897959184;
    assertBothDistributionsMatch(target, rootMemberships, weights, config, expectedGini, expectedDistribution);
    // Information Gain; entropy: -(4/7)*log2(4/7) - (3/7)*log2(3/7)
    config.setSplitCriterion(SplitCriterion.InformationGain);
    double expectedEntropy = 0.985228136;
    assertBothDistributionsMatch(target, rootMemberships, weights, config, expectedEntropy, expectedDistribution);
    // Information Gain Ratio: prior impurity is the same as for IG
    config.setSplitCriterion(SplitCriterion.InformationGainRatio);
    assertBothDistributionsMatch(target, rootMemberships, weights, config, expectedEntropy, expectedDistribution);
}

/**
 * Asserts that both {@code getDistribution} overloads (memberships-based and weights-based) return the expected
 * prior impurity and class distribution under the split criterion currently set on {@code config}.
 *
 * @param target the target column under test
 * @param memberships the memberships passed to the {@link DataMemberships} overload
 * @param weights the row weights passed to the {@code double[]} overload
 * @param config the configuration carrying the split criterion to test
 * @param expectedImpurity the expected prior impurity
 * @param expectedDistribution the expected class distribution
 */
private void assertBothDistributionsMatch(final TreeTargetNominalColumnData target,
    final DataMemberships memberships, final double[] weights, final TreeEnsembleLearnerConfiguration config,
    final double expectedImpurity, final double[] expectedDistribution) {
    ClassificationPriors fromMemberships = target.getDistribution(memberships, config);
    assertEquals(expectedImpurity, fromMemberships.getPriorImpurity(), DELTA);
    assertArrayEquals(expectedDistribution, fromMemberships.getDistribution(), DELTA);
    ClassificationPriors fromWeights = target.getDistribution(weights, config);
    assertEquals(expectedImpurity, fromWeights.getPriorImpurity(), DELTA);
    assertArrayEquals(expectedDistribution, fromWeights.getDistribution(), DELTA);
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) DefaultDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) Test(org.junit.Test)

Aggregations

TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData)27 TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration)25 Test (org.junit.Test)18 DataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships)18 RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships)18 DefaultDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager)15 IDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager)15 RandomData (org.apache.commons.math.random.RandomData)14 TreeAttributeColumnData (org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData)12 SplitCandidate (org.knime.base.node.mine.treeensemble2.learner.SplitCandidate)12 BitSet (java.util.BitSet)11 TreeTargetNumericColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData)8 TreeNodeSignature (org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature)8 ExecutionMonitor (org.knime.core.node.ExecutionMonitor)8 TreeDataCreator (org.knime.base.node.mine.treeensemble2.data.TreeDataCreator)7 TreeTargetNominalColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData)7 NominalBinarySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate)7 NominalMultiwaySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate)7 FilterLearnColumnRearranger (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration.FilterLearnColumnRearranger)7 TreeEnsembleModelPortObjectSpec (org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObjectSpec)6