Search in sources:

Example 26 with RandomData

use of org.apache.commons.math.random.RandomData in project knime-core by knime.

The class TreeNominalColumnDataTest, method testCalcBestSplitClassificationMultiwayXGBoostMissingValueHandling.

/**
 * Verifies the XGBoost missing value handling for classification when multiway nominal splits are used.
 * <p>
 * Two scenarios are exercised with the same configuration:
 * <ol>
 * <li>a training column without any missing values, where missing values must still be routed to the majority
 * child ("b"), and</li>
 * <li>a training column containing one missing value, where the split must absorb it (no missed rows) and route
 * missing values to child "b".</li>
 * </ol>
 *
 * @throws Exception if the test setup or the split computation fails
 */
@Test
public void testCalcBestSplitClassificationMultiwayXGBoostMissingValueHandling() throws Exception {
    final TreeEnsembleLearnerConfiguration cfg = createConfig(false);
    cfg.setUseBinaryNominalSplits(false);
    cfg.setMissingValueHandling(MissingValueHandling.XGBoost);
    final TestDataGenerator generator = new TestDataGenerator(cfg);
    final RandomData random = cfg.createRandomData();
    // child conditions are expected in this order in both scenarios
    final String[] expectedChildValues = {"a", "b", "c"};
    // index of the child that is expected to receive missing values ("b")
    final int missingChildIndex = 1;

    // Scenario 1: no missing values in the training data
    final String cleanAttributeCsv = "a, a, a, b, b, b, b, c, c";
    final String cleanTargetCsv = "A, B, B, C, C, C, B, A, B";
    final TreeNominalColumnData cleanColumn =
        generator.createNominalAttributeColumn(cleanAttributeCsv, "noMissings", 0);
    final TreeTargetNominalColumnData cleanTarget = TestDataGenerator.createNominalTargetColumn(cleanTargetCsv);
    final DataMemberships cleanMemberships = createMockDataMemberships(cleanTarget.getNrRows());
    final SplitCandidate cleanSplit = cleanColumn.calcBestSplitClassification(cleanMemberships,
        cleanTarget.getDistribution(cleanMemberships, cfg), cleanTarget, random);
    assertNotNull("There is a possible split.", cleanSplit);
    assertEquals("Incorrect gain.", 0.216, cleanSplit.getGainValue(), 1e-3);
    assertThat(cleanSplit, instanceOf(NominalMultiwaySplitCandidate.class));
    final NominalMultiwaySplitCandidate cleanMultiway = (NominalMultiwaySplitCandidate) cleanSplit;
    assertTrue("No missing values in the column.", cleanMultiway.getMissedRows().isEmpty());
    final TreeNodeNominalCondition[] cleanConditions = cleanMultiway.getChildConditions();
    assertEquals("Wrong number of child conditions.", 3, cleanConditions.length);
    for (int c = 0; c < expectedChildValues.length; c++) {
        assertEquals("Wrong value in child condition.", expectedChildValues[c], cleanConditions[c].getValue());
    }
    for (int c = 0; c < expectedChildValues.length; c++) {
        final boolean accepts = cleanConditions[c].acceptsMissings();
        if (c == missingChildIndex) {
            assertTrue("Missing values should be sent to the majority child (i.e. b)", accepts);
        } else {
            assertFalse("Missing values should be sent to the majority child (i.e. b)", accepts);
        }
    }

    // Scenario 2: one missing value ("?") in the training data
    final String dirtyAttributeCsv = "a, a, a, b, b, b, b, c, c, ?";
    final String dirtyTargetCsv = "A, B, B, C, C, C, B, A, B, C";
    final TreeNominalColumnData dirtyColumn =
        generator.createNominalAttributeColumn(dirtyAttributeCsv, "missings", 0);
    final TreeTargetNominalColumnData dirtyTarget = TestDataGenerator.createNominalTargetColumn(dirtyTargetCsv);
    final DataMemberships dirtyMemberships = createMockDataMemberships(dirtyTarget.getNrRows());
    final SplitCandidate dirtySplit = dirtyColumn.calcBestSplitClassification(dirtyMemberships,
        dirtyTarget.getDistribution(dirtyMemberships, cfg), dirtyTarget, random);
    assertNotNull("There is a possible split.", dirtySplit);
    assertEquals("Incorrect gain.", 0.2467, dirtySplit.getGainValue(), 1e-3);
    assertThat(dirtySplit, instanceOf(NominalMultiwaySplitCandidate.class));
    final NominalMultiwaySplitCandidate dirtyMultiway = (NominalMultiwaySplitCandidate) dirtySplit;
    assertTrue("Split should handle missing values.", dirtyMultiway.getMissedRows().isEmpty());
    final TreeNodeNominalCondition[] dirtyConditions = dirtyMultiway.getChildConditions();
    assertEquals("Wrong number of child conditions.", 3, dirtyConditions.length);
    for (int c = 0; c < expectedChildValues.length; c++) {
        assertEquals("Wrong value in child condition.", expectedChildValues[c], dirtyConditions[c].getValue());
    }
    for (int c = 0; c < expectedChildValues.length; c++) {
        final boolean accepts = dirtyConditions[c].acceptsMissings();
        if (c == missingChildIndex) {
            assertTrue("Missing values should be sent to b", accepts);
        } else {
            assertFalse("Missing values should be sent to b", accepts);
        }
    }
}
Also used: TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RandomData(org.apache.commons.math.random.RandomData) TreeNodeNominalCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalCondition) NominalMultiwaySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) Test(org.junit.Test)

Example 27 with RandomData

use of org.apache.commons.math.random.RandomData in project knime-core by knime.

The class SubsetNoReplacementSelectorTest, method testSelectValidParameters.

/**
 * Checks that selecting {@code i} rows out of a fixed total without replacement yields a sample containing
 * exactly {@code i} included rows, for every {@code i} from 1 up to the total.
 *
 * @throws Exception if the selection fails unexpectedly
 */
@Test
public void testSelectValidParameters() throws Exception {
    final RandomData rd = TestDataGenerator.createRandomData();
    final SubsetNoReplacementSelector selector = SubsetNoReplacementSelector.getInstance();
    // number of rows to draw from; every subset size from 1 to nrTotal is exercised
    final int nrTotal = 20;
    for (int i = 1; i <= nrTotal; i++) {
        final SubsetNoReplacementRowSample sample = selector.select(rd, nrTotal, i);
        final int included = sample.getIncludedBitSet().cardinality();
        // the message previously lacked spaces ("contain 5rows but contained 4rows")
        assertEquals("The sample was expected to contain " + i + " rows but contained " + included
            + " rows instead.", i, included);
    }
}
Also used : RandomData(org.apache.commons.math.random.RandomData) Test(org.junit.Test)

Example 28 with RandomData

use of org.apache.commons.math.random.RandomData in project knime-core by knime.

The class SubsetNoReplacementSelectorTest, method testSelectNrTotalSmallerNrSelect.

/**
 * Requesting more rows than are available must be rejected with an {@link IllegalArgumentException}.
 *
 * @throws Exception expected to be an {@link IllegalArgumentException}
 */
@Test(expected = IllegalArgumentException.class)
public void testSelectNrTotalSmallerNrSelect() throws Exception {
    final RandomData random = TestDataGenerator.createRandomData();
    final int nrTotal = 10;
    final int nrSelect = 20;
    SubsetNoReplacementSelector.getInstance().select(random, nrTotal, nrSelect);
}
Also used : RandomData(org.apache.commons.math.random.RandomData) Test(org.junit.Test)

Example 29 with RandomData

use of org.apache.commons.math.random.RandomData in project knime-core by knime.

The class SubsetNoReplacementSelectorTest, method testSelectNrSelectSmallerZero.

/**
 * A negative number of rows to select must be rejected with an {@link IllegalArgumentException}.
 *
 * @throws Exception expected to be an {@link IllegalArgumentException}
 */
@Test(expected = IllegalArgumentException.class)
public void testSelectNrSelectSmallerZero() throws Exception {
    final RandomData random = TestDataGenerator.createRandomData();
    final int nrTotal = 10;
    final int nrSelect = -5;
    SubsetNoReplacementSelector.getInstance().select(random, nrTotal, nrSelect);
}
Also used : RandomData(org.apache.commons.math.random.RandomData) Test(org.junit.Test)

Example 30 with RandomData

use of org.apache.commons.math.random.RandomData in project knime-core by knime.

The class RegressionTreeLearnerNodeModel, method execute.

/**
 * {@inheritDoc}
 * <p>
 * Learns a single regression tree from the first input table and returns it wrapped in a
 * {@code RegressionTreeModelPortObject}. As a side effect it stores the hilite row sample, the view message and
 * the resulting port object in this node model's fields, and sets a node warning if column filtering or data
 * reading produced one.
 */
@Override
protected PortObject[] execute(final PortObject[] inObjects, final ExecutionContext exec) throws Exception {
    BufferedDataTable t = (BufferedDataTable) inObjects[0];
    DataTableSpec spec = t.getDataTableSpec();
    // restrict the input to the columns configured for learning; the rearranger may report a warning
    final FilterLearnColumnRearranger learnRearranger = m_configuration.filterLearnColumns(spec);
    String warn = learnRearranger.getWarning();
    // the column rearrangement is given a 0.0 share of the overall progress
    BufferedDataTable learnTable = exec.createColumnRearrangeTable(t, learnRearranger, exec.createSubProgress(0.0));
    DataTableSpec learnSpec = learnTable.getDataTableSpec();
    // progress split: 10% for reading the data into memory, 90% for learning the tree
    ExecutionMonitor readInExec = exec.createSubProgress(0.1);
    ExecutionMonitor learnExec = exec.createSubProgress(0.9);
    TreeDataCreator dataCreator = new TreeDataCreator(m_configuration, learnSpec, learnTable.getRowCount());
    exec.setProgress("Reading data into memory");
    TreeData data = dataCreator.readData(learnTable, m_configuration, readInExec);
    // keep hilite rows and view message from the data creator for the node's view
    m_hiliteRowSample = dataCreator.getDataRowsForHilite();
    m_viewMessage = dataCreator.getViewMessage();
    // merge a data-creation warning (if any) with the column-filter warning, separated by a newline
    String dataCreationWarning = dataCreator.getAndClearWarningMessage();
    if (dataCreationWarning != null) {
        if (warn == null) {
            warn = dataCreationWarning;
        } else {
            warn = warn + "\n" + dataCreationWarning;
        }
    }
    readInExec.setProgress(1.0);
    exec.setMessage("Learning tree");
    // the same RandomData instance is used for row sampling, the learner and tree learning
    RandomData rd = m_configuration.createRandomData();
    // bit-vector data uses a dedicated index manager; all other tree types use the default one
    final IDataIndexManager indexManager;
    if (data.getTreeType() == TreeType.BitVector) {
        indexManager = new BitVectorDataIndexManager(data.getNrRows());
    } else {
        indexManager = new DefaultDataIndexManager(data);
    }
    TreeNodeSignatureFactory signatureFactory = null;
    int maxLevels = m_configuration.getMaxLevels();
    if (maxLevels < TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE) {
        // with a finite depth limit the factory is pre-sized to 2^(maxLevels - 1)
        // NOTE(review): Guava's IntMath.pow silently overflows; for maxLevels > 31 the capacity becomes
        // negative or 0 - confirm how TreeNodeSignatureFactory handles that or consider IntMath.checkedPow
        int capacity = IntMath.pow(2, maxLevels - 1);
        signatureFactory = new TreeNodeSignatureFactory(capacity);
    } else {
        signatureFactory = new TreeNodeSignatureFactory();
    }
    final RowSample rowSample = m_configuration.createRowSampler(data).createRowSample(rd);
    TreeLearnerRegression treeLearner = new TreeLearnerRegression(m_configuration, data, indexManager, signatureFactory, rd, rowSample);
    TreeModelRegression regTree = treeLearner.learnSingleTree(learnExec, rd);
    // wrap the learned tree into a model and port object whose spec is derived from the filtered learn spec
    RegressionTreeModel model = new RegressionTreeModel(m_configuration, data.getMetaData(), regTree, data.getTreeType());
    RegressionTreeModelPortObjectSpec treePortObjectSpec = new RegressionTreeModelPortObjectSpec(learnSpec);
    RegressionTreeModelPortObject treePortObject = new RegressionTreeModelPortObject(model, treePortObjectSpec);
    learnExec.setProgress(1.0);
    m_treeModelPortObject = treePortObject;
    if (warn != null) {
        setWarningMessage(warn);
    }
    return new PortObject[] { treePortObject };
}
Also used: RegressionTreeModelPortObject(org.knime.base.node.mine.treeensemble2.model.RegressionTreeModelPortObject) DataTableSpec(org.knime.core.data.DataTableSpec) RandomData(org.apache.commons.math.random.RandomData) RegressionTreeModel(org.knime.base.node.mine.treeensemble2.model.RegressionTreeModel) IDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager) BitVectorDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.BitVectorDataIndexManager) RegressionTreeModelPortObjectSpec(org.knime.base.node.mine.treeensemble2.model.RegressionTreeModelPortObjectSpec) DefaultDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression) BufferedDataTable(org.knime.core.node.BufferedDataTable) FilterLearnColumnRearranger(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration.FilterLearnColumnRearranger) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) TreeLearnerRegression(org.knime.base.node.mine.treeensemble2.learner.TreeLearnerRegression) RowSample(org.knime.base.node.mine.treeensemble2.sample.row.RowSample) ExecutionMonitor(org.knime.core.node.ExecutionMonitor) TreeDataCreator(org.knime.base.node.mine.treeensemble2.data.TreeDataCreator) PortObject(org.knime.core.node.port.PortObject) TreeNodeSignatureFactory(org.knime.base.node.mine.treeensemble2.learner.TreeNodeSignatureFactory)

Aggregations

RandomData (org.apache.commons.math.random.RandomData)36 Test (org.junit.Test)21 TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration)16 DataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships)11 RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships)11 SplitCandidate (org.knime.base.node.mine.treeensemble2.learner.SplitCandidate)11 TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData)8 DefaultDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager)7 IDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager)6 NumericMissingSplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NumericMissingSplitCandidate)6 NumericSplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate)6 TreeNodeNumericCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeNumericCondition)6 TreeAttributeColumnData (org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData)5 NominalBinarySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate)5 NominalMultiwaySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate)5 ExecutionMonitor (org.knime.core.node.ExecutionMonitor)5 BitSet (java.util.BitSet)4 TreeTargetNumericColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData)4 ArrayList (java.util.ArrayList)3 Future (java.util.concurrent.Future)3