Search in sources :

Example 16 with DefaultDataIndexManager

use of org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager in project knime-core by knime.

In the class TreeTargetNominalColumnDataTest, the method testGetDistribution:

/**
 * Checks both overloads of {@code getDistribution} on {@link TreeTargetNominalColumnData} -- the one taking
 * {@link DataMemberships} and the one taking a plain weight array -- for every split criterion.
 * <p>
 * The target column holds four "A" and three "B" rows, all with weight 1.0, so the expected class distribution is
 * {4, 3}. The expected prior impurity is the Gini index for {@code Gini} and the entropy for both
 * {@code InformationGain} and {@code InformationGainRatio}.
 *
 * @throws InvalidSettingsException declared by the test signature
 */
@Test
public void testGetDistribution() throws InvalidSettingsException {
    final TreeEnsembleLearnerConfiguration config = new TreeEnsembleLearnerConfiguration(false);
    final TestDataGenerator dataGen = new TestDataGenerator(config);
    final TreeTargetNominalColumnData target = TestDataGenerator.createNominalTargetColumn("A,A,A,B,B,B,A");
    final TreeNumericColumnData attribute = dataGen.createNumericAttributeColumn("1,2,3,4,5,6,7", "test-col", 0);
    final TreeData data = new TreeData(new TreeAttributeColumnData[] { attribute }, target, TreeType.Ordinary);

    // one unit weight per row
    final double[] weights = { 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 };
    final DataMemberships rootMemberships =
        new RootDataMemberships(weights, data, new DefaultDataIndexManager(data));

    final double[] expectedDistribution = { 4.0, 3.0 };
    final double expectedGini = 0.4897959184;
    final double expectedEntropy = 0.985228136;

    // criterion i is expected to yield impurity i; the gain ratio shares its prior impurity with information gain
    final SplitCriterion[] criteria =
        { SplitCriterion.Gini, SplitCriterion.InformationGain, SplitCriterion.InformationGainRatio };
    final double[] expectedImpurities = { expectedGini, expectedEntropy, expectedEntropy };

    for (int i = 0; i < criteria.length; i++) {
        config.setSplitCriterion(criteria[i]);

        // overload driven by the data memberships
        final ClassificationPriors fromMemberships = target.getDistribution(rootMemberships, config);
        assertEquals(expectedImpurities[i], fromMemberships.getPriorImpurity(), DELTA);
        assertArrayEquals(expectedDistribution, fromMemberships.getDistribution(), DELTA);

        // overload driven by the raw weight array
        final ClassificationPriors fromWeights = target.getDistribution(weights, config);
        assertEquals(expectedImpurities[i], fromWeights.getPriorImpurity(), DELTA);
        assertArrayEquals(expectedDistribution, fromWeights.getDistribution(), DELTA);
    }
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) DefaultDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) Test(org.junit.Test)

Example 17 with DefaultDataIndexManager

use of org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager in project knime-core by knime.

In the class RegressionTreeLearnerNodeModel, the method execute:

/**
 * {@inheritDoc}
 * <p>
 * Filters the first input table down to the learn columns, reads it into a {@link TreeData} structure, learns a
 * single regression tree from it and returns that tree wrapped in a {@link RegressionTreeModelPortObject}.
 * Warnings produced by the column filter and by the data creator are joined with a newline and set as the node
 * warning.
 */
@Override
protected PortObject[] execute(final PortObject[] inObjects, final ExecutionContext exec) throws Exception {
    final BufferedDataTable inputTable = (BufferedDataTable) inObjects[0];
    final FilterLearnColumnRearranger learnRearranger =
        m_configuration.filterLearnColumns(inputTable.getDataTableSpec());
    String warning = learnRearranger.getWarning();

    final BufferedDataTable learnTable =
        exec.createColumnRearrangeTable(inputTable, learnRearranger, exec.createSubProgress(0.0));
    final DataTableSpec learnSpec = learnTable.getDataTableSpec();

    // progress is split 10% / 90% between reading the data and learning the tree
    final ExecutionMonitor readInExec = exec.createSubProgress(0.1);
    final ExecutionMonitor learnExec = exec.createSubProgress(0.9);

    final TreeDataCreator dataCreator =
        new TreeDataCreator(m_configuration, learnSpec, learnTable.getRowCount());
    exec.setProgress("Reading data into memory");
    final TreeData data = dataCreator.readData(learnTable, m_configuration, readInExec);
    m_hiliteRowSample = dataCreator.getDataRowsForHilite();
    m_viewMessage = dataCreator.getViewMessage();

    // append the data creator's warning (if any) to the column filter's warning
    final String creatorWarning = dataCreator.getAndClearWarningMessage();
    if (creatorWarning != null) {
        warning = (warning == null) ? creatorWarning : warning + "\n" + creatorWarning;
    }
    readInExec.setProgress(1.0);

    exec.setMessage("Learning tree");
    final RandomData randomData = m_configuration.createRandomData();

    // bit vector data gets its own index manager implementation
    final IDataIndexManager indexManager = (data.getTreeType() == TreeType.BitVector)
        ? new BitVectorDataIndexManager(data.getNrRows())
        : new DefaultDataIndexManager(data);

    // with a finite level limit the signature factory is created with a capacity of 2^(maxLevels - 1)
    final int maxLevels = m_configuration.getMaxLevels();
    final TreeNodeSignatureFactory signatureFactory;
    if (maxLevels >= TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE) {
        signatureFactory = new TreeNodeSignatureFactory();
    } else {
        signatureFactory = new TreeNodeSignatureFactory(IntMath.pow(2, maxLevels - 1));
    }

    final RowSample rowSample = m_configuration.createRowSampler(data).createRowSample(randomData);
    final TreeLearnerRegression learner =
        new TreeLearnerRegression(m_configuration, data, indexManager, signatureFactory, randomData, rowSample);
    final TreeModelRegression tree = learner.learnSingleTree(learnExec, randomData);

    final RegressionTreeModel model =
        new RegressionTreeModel(m_configuration, data.getMetaData(), tree, data.getTreeType());
    final RegressionTreeModelPortObjectSpec modelSpec = new RegressionTreeModelPortObjectSpec(learnSpec);
    final RegressionTreeModelPortObject result = new RegressionTreeModelPortObject(model, modelSpec);
    learnExec.setProgress(1.0);

    m_treeModelPortObject = result;
    if (warning != null) {
        setWarningMessage(warning);
    }
    return new PortObject[] { result };
}
Also used : RegressionTreeModelPortObject(org.knime.base.node.mine.treeensemble2.model.RegressionTreeModelPortObject) DataTableSpec(org.knime.core.data.DataTableSpec) RandomData(org.apache.commons.math.random.RandomData) RegressionTreeModel(org.knime.base.node.mine.treeensemble2.model.RegressionTreeModel) IDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager) BitVectorDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.BitVectorDataIndexManager) RegressionTreeModelPortObjectSpec(org.knime.base.node.mine.treeensemble2.model.RegressionTreeModelPortObjectSpec) DefaultDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression) BufferedDataTable(org.knime.core.node.BufferedDataTable) FilterLearnColumnRearranger(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration.FilterLearnColumnRearranger) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) TreeLearnerRegression(org.knime.base.node.mine.treeensemble2.learner.TreeLearnerRegression) RowSample(org.knime.base.node.mine.treeensemble2.sample.row.RowSample) ExecutionMonitor(org.knime.core.node.ExecutionMonitor) TreeDataCreator(org.knime.base.node.mine.treeensemble2.data.TreeDataCreator) PortObject(org.knime.core.node.port.PortObject) TreeNodeSignatureFactory(org.knime.base.node.mine.treeensemble2.learner.TreeNodeSignatureFactory)

Aggregations

TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) 16, DefaultDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) 15, Test (org.junit.Test) 14, DataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) 14, RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) 14, IDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager) 12, SplitCandidate (org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) 12, RandomData (org.apache.commons.math.random.RandomData) 7, NominalBinarySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) 7, NominalMultiwaySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate) 7, BitSet (java.util.BitSet) 6, NumericMissingSplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NumericMissingSplitCandidate) 5, NumericSplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate) 5, TreeNodeNominalBinaryCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalBinaryCondition) 5, TreeNodeNumericCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeNumericCondition) 5, TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData) 3, RowSample (org.knime.base.node.mine.treeensemble2.sample.row.RowSample) 3, TestDataGenerator (org.knime.base.node.mine.treeensemble2.data.TestDataGenerator) 2, TreeNodeNominalCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalCondition) 2, DefaultRowSample (org.knime.base.node.mine.treeensemble2.sample.row.DefaultRowSample) 2