Search in sources:

Example 1 with RandomData

use of org.apache.commons.math.random.RandomData in project knime-core by knime.

In class RFSubsetColumnSampleStrategy, method getColumnSampleForTreeNode.

/**
 * Derives a reproducible column subset for the given tree node.
 * <p>
 * A generator seeded with {@code m_seed} is walked along the node's signature path: for every
 * path element the generator is drawn {@code element + 1} times and the last value drawn becomes
 * that element's seed component. The resulting seed array re-seeds the generator, which then
 * selects {@code m_subsetSize} distinct column indices; these are returned in ascending order.
 * </p>
 * NOTE(review): a negative signature byte yields no draws and leaves its seed component at 0 —
 * confirm signature bytes are always non-negative.
 *
 * @param treeNodeSignature signature of the tree node the sample is drawn for
 * @return a {@link SubsetColumnSample} over {@code m_data} with sorted column indices
 */
@Override
public ColumnSample getColumnSampleForTreeNode(final TreeNodeSignature treeNodeSignature) {
    final byte[] path = treeNodeSignature.getSignaturePath();
    final JDKRandomGenerator rng = new JDKRandomGenerator();
    rng.setSeed(m_seed);

    final int[] derivedSeed = new int[path.length];
    int level = 0;
    for (final byte branch : path) {
        // draw (branch + 1) values and keep only the last one for this level
        int remainingDraws = branch + 1;
        while (remainingDraws > 0) {
            derivedSeed[level] = rng.nextInt();
            remainingDraws--;
        }
        level++;
    }
    rng.setSeed(derivedSeed);

    final int columnCount = m_data.getColumns().length;
    final RandomData randomData = new RandomDataImpl(rng);
    final int[] columnIndices = randomData.nextPermutation(columnCount, m_subsetSize);
    Arrays.sort(columnIndices);
    return new SubsetColumnSample(m_data, columnIndices);
}
Also used : RandomData(org.apache.commons.math.random.RandomData) RandomDataImpl(org.apache.commons.math.random.RandomDataImpl) JDKRandomGenerator(org.apache.commons.math.random.JDKRandomGenerator)

Example 2 with RandomData

use of org.apache.commons.math.random.RandomData in project knime-core by knime.

In class DefaultRowSamplerTest, method testCreateRowSample.

/**
 * Checks that a {@link DefaultRowSampler} built for 20 rows yields a sample of 20 rows in which
 * every row has a count of exactly one.
 */
@Test
public void testCreateRowSample() throws Exception {
    final int expectedRows = 20;
    final RandomData randomData = TestDataGenerator.createRandomData();
    final DefaultRowSampler rowSampler = new DefaultRowSampler(expectedRows);
    final RowSample rowSample = rowSampler.createRowSample(randomData);

    assertEquals(expectedRows, rowSample.getNrRows());
    // each row index must be counted exactly once
    for (int row = 0; row < rowSample.getNrRows(); row++) {
        assertEquals(1, rowSample.getCountFor(row));
    }
}
Also used : RandomData(org.apache.commons.math.random.RandomData) Test(org.junit.Test)

Example 3 with RandomData

use of org.apache.commons.math.random.RandomData in project knime-core by knime.

In class StratifiedRowSamplerTest, method testCreateRowSampleWithReplacement.

/**
 * Checks stratified sampling with replacement at fractions 0.5 and 1.0. It verifies the number
 * of counted rows and that the sample always spans all rows of {@code SamplerTestUtil.TARGET}.
 */
@Test
public void testCreateRowSampleWithReplacement() throws Exception {
    final RandomData randomData = TestDataGenerator.createRandomData();
    final SubsetSelector<SubsetWithReplacementRowSample> replacementSelector =
        SubsetWithReplacementSelector.getInstance();

    // fraction 0.5
    final StratifiedRowSampler<SubsetWithReplacementRowSample> halfSampler =
        new StratifiedRowSampler<SubsetWithReplacementRowSample>(0.5, replacementSelector, SamplerTestUtil.TARGET);
    final SubsetWithReplacementRowSample halfSample = halfSampler.createRowSample(randomData);
    assertEquals(8, SamplerTestUtil.countRows(halfSample));
    assertEquals(SamplerTestUtil.TARGET.getNrRows(), halfSample.getNrRows());

    // fraction 1.0: the counted rows must match the target's row count
    final StratifiedRowSampler<SubsetWithReplacementRowSample> fullSampler =
        new StratifiedRowSampler<SubsetWithReplacementRowSample>(1.0, replacementSelector, SamplerTestUtil.TARGET);
    final SubsetWithReplacementRowSample fullSample = fullSampler.createRowSample(randomData);
    assertEquals(SamplerTestUtil.TARGET.getNrRows(), SamplerTestUtil.countRows(fullSample));
    assertEquals(SamplerTestUtil.TARGET.getNrRows(), fullSample.getNrRows());
}
Also used : RandomData(org.apache.commons.math.random.RandomData) Test(org.junit.Test)

Example 4 with RandomData

use of org.apache.commons.math.random.RandomData in project knime-core by knime.

In class SubsetWithReplacementSelectorTest, method testSelect.

/**
 * Checks that selecting k of 20 rows counts exactly k rows for every k in [1, 19]. It also
 * checks that a 1000-of-1000 bootstrap selection contains fewer than 700 unique rows.
 */
@Test
public void testSelect() throws Exception {
    final SubsetWithReplacementSelector replacementSelector = SubsetWithReplacementSelector.getInstance();
    final RandomData randomData = TestDataGenerator.createRandomData();

    for (int requested = 1; requested < 20; requested++) {
        final SubsetWithReplacementRowSample subset = replacementSelector.select(randomData, 20, requested);
        assertThat("Unexpected number of included rows", SamplerTestUtil.countRows(subset), is(requested));
    }

    // full-size bootstrap: duplicates are expected, so the unique count should stay well below 1000
    final SubsetWithReplacementRowSample bootstrap = replacementSelector.select(randomData, 1000, 1000);
    assertThat("A bootstrap sample will usually contain about 63.2% of the rows.",
        SamplerTestUtil.countUniqueRows(bootstrap), is(lessThan(700)));
}
Also used : RandomData(org.apache.commons.math.random.RandomData) Test(org.junit.Test)

Example 5 with RandomData

use of org.apache.commons.math.random.RandomData in project knime-core by knime.

In class TreeLearnerRegression, method findBestSplitsRegression.

/**
 * Computes the candidate splits for the current regression tree node, ordered by decreasing gain.
 * <p>
 * Returns {@code null} (no split) when the configured maximum depth is reached, when the node
 * has fewer records than the configured minimum node size, when the target's sum of squared
 * deviations is below {@code TreeColumnData.EPSILON}, or when no column yields a split. At depth
 * 0 with a hard-coded root column configured, only that column is evaluated. Otherwise every
 * column in {@code columnSample} whose attribute index is not set in
 * {@code forbiddenColumnSet} is evaluated.
 * </p>
 *
 * @param currentDepth depth of the node being split (0 for the root)
 * @param dataMemberships memberships of the rows in the current node
 * @param columnSample columns to evaluate at this node
 * @param targetPriors regression priors of the target in the current node
 * @param forbiddenColumnSet attribute indices that must not be used for splitting
 * @return non-null split candidates sorted by descending gain, or {@code null} if none exist
 */
private SplitCandidate[] findBestSplitsRegression(final int currentDepth, final DataMemberships dataMemberships,
    final ColumnSample columnSample, final RegressionPriors targetPriors, final BitSet forbiddenColumnSet) {
    final TreeData data = getData();
    final RandomData rd = getRandomData();
    final TreeEnsembleLearnerConfiguration config = getConfig();

    final int maxLevels = config.getMaxLevels();
    if (maxLevels != TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE && currentDepth >= maxLevels) {
        return null;
    }
    final int minNodeSize = config.getMinNodeSize();
    if (minNodeSize != TreeEnsembleLearnerConfiguration.MIN_NODE_SIZE_UNDEFINED
        && targetPriors.getNrRecords() < minNodeSize) {
        return null;
    }
    // (nearly) no variance in the target: nothing to gain from splitting
    final double priorSquaredDeviation = targetPriors.getSumSquaredDeviation();
    if (priorSquaredDeviation < TreeColumnData.EPSILON) {
        return null;
    }

    final TreeTargetNumericColumnData targetColumn = getTargetData();
    if (currentDepth == 0 && config.getHardCodedRootColumn() != null) {
        final TreeAttributeColumnData rootColumn = data.getColumn(config.getHardCodedRootColumn());
        final SplitCandidate rootSplit =
            rootColumn.calcBestSplitRegression(dataMemberships, targetPriors, targetColumn, rd);
        // calcBestSplitRegression may return null (see the loop below); report "no split"
        // rather than an array containing a null element
        return rootSplit == null ? null : new SplitCandidate[]{rootSplit};
    }

    final ArrayList<SplitCandidate> splitCandidates = new ArrayList<SplitCandidate>(columnSample.getNumCols());
    for (TreeAttributeColumnData col : columnSample) {
        if (forbiddenColumnSet.get(col.getMetaData().getAttributeIndex())) {
            continue;
        }
        final SplitCandidate currentColSplit =
            col.calcBestSplitRegression(dataMemberships, targetPriors, targetColumn, rd);
        if (currentColSplit != null) {
            splitCandidates.add(currentColSplit);
        }
    }
    if (splitCandidates.isEmpty()) {
        return null;
    }
    // highest gain first
    splitCandidates.sort(Comparator.comparingDouble(SplitCandidate::getGainValue).reversed());
    return splitCandidates.toArray(new SplitCandidate[splitCandidates.size()]);
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RandomData(org.apache.commons.math.random.RandomData) TreeAttributeColumnData(org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData) TreeTargetNumericColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) Comparator(java.util.Comparator)

Aggregations

RandomData (org.apache.commons.math.random.RandomData)36 Test (org.junit.Test)21 TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration)16 DataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships)11 RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships)11 SplitCandidate (org.knime.base.node.mine.treeensemble2.learner.SplitCandidate)11 TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData)8 DefaultDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager)7 IDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager)6 NumericMissingSplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NumericMissingSplitCandidate)6 NumericSplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate)6 TreeNodeNumericCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeNumericCondition)6 TreeAttributeColumnData (org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData)5 NominalBinarySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate)5 NominalMultiwaySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate)5 ExecutionMonitor (org.knime.core.node.ExecutionMonitor)5 BitSet (java.util.BitSet)4 TreeTargetNumericColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData)4 ArrayList (java.util.ArrayList)3 Future (java.util.concurrent.Future)3