Use of org.apache.commons.math.random.RandomData in project knime-core by KNIME.
Class RFSubsetColumnSampleStrategy, method getColumnSampleForTreeNode.
/**
 * {@inheritDoc}
 * <p>
 * The column sample is derived deterministically from {@code m_seed} and the node's signature path.
 * <ol>
 * <li>A generator seeded with {@code m_seed} is walked along the path. It is not reset between path elements.
 * For every element {@code e}, the last of {@code e + 1} consecutive {@code nextInt()} values becomes the
 * matching component of a derived seed. A negative {@code e} causes no draw and leaves that component at 0.</li>
 * <li>The generator is then re-seeded with the derived seed.</li>
 * <li>It draws {@code m_subsetSize} distinct column indices out of all columns of {@code m_data}.</li>
 * <li>The indices are sorted ascending before being wrapped in a {@link SubsetColumnSample}.</li>
 * </ol>
 */
@Override
public ColumnSample getColumnSampleForTreeNode(final TreeNodeSignature treeNodeSignature) {
    final byte[] path = treeNodeSignature.getSignaturePath();
    final JDKRandomGenerator rng = new JDKRandomGenerator();
    rng.setSeed(m_seed);

    final int[] nodeSeed = new int[path.length];
    for (int level = 0; level < path.length; level++) {
        // Only the value of the final draw at this level is kept; earlier draws merely advance the generator.
        int draws = path[level] + 1;
        while (draws > 0) {
            nodeSeed[level] = rng.nextInt();
            draws--;
        }
    }
    rng.setSeed(nodeSeed);

    final int columnCount = m_data.getColumns().length;
    final RandomData sampler = new RandomDataImpl(rng);
    final int[] selectedColumns = sampler.nextPermutation(columnCount, m_subsetSize);
    Arrays.sort(selectedColumns);
    return new SubsetColumnSample(m_data, selectedColumns);
}
Use of org.apache.commons.math.random.RandomData in project knime-core by KNIME.
Class DefaultRowSamplerTest, method testCreateRowSample.
/**
 * Checks that a {@code DefaultRowSampler} built with 20 yields a sample of 20 rows, each with count 1.
 */
@Test
public void testCreateRowSample() throws Exception {
    final int expectedRows = 20;
    final RandomData randomData = TestDataGenerator.createRandomData();
    final DefaultRowSampler rowSampler = new DefaultRowSampler(expectedRows);
    final RowSample rowSample = rowSampler.createRowSample(randomData);

    assertEquals(expectedRows, rowSample.getNrRows());
    // Every row index must occur exactly once in the sample.
    int row = 0;
    while (row < rowSample.getNrRows()) {
        assertEquals(1, rowSample.getCountFor(row));
        row++;
    }
}
Use of org.apache.commons.math.random.RandomData in project knime-core by KNIME.
Class StratifiedRowSamplerTest, method testCreateRowSampleWithReplacement.
/**
 * Exercises {@code StratifiedRowSampler} with a with-replacement selector at fractions 0.5 and 1.0.
 * <p>
 * In both cases the sample must report as many rows as {@code SamplerTestUtil.TARGET}. The number of counted
 * rows must be 8 for fraction 0.5 and equal the target row count for fraction 1.0.
 */
@Test
public void testCreateRowSampleWithReplacement() throws Exception {
    final RandomData randomData = TestDataGenerator.createRandomData();
    final SubsetSelector<SubsetWithReplacementRowSample> withReplacement = SubsetWithReplacementSelector.getInstance();

    // fraction 0.5
    final StratifiedRowSampler<SubsetWithReplacementRowSample> halfSampler =
        new StratifiedRowSampler<>(0.5, withReplacement, SamplerTestUtil.TARGET);
    final SubsetWithReplacementRowSample halfSample = halfSampler.createRowSample(randomData);
    assertEquals(8, SamplerTestUtil.countRows(halfSample));
    assertEquals(SamplerTestUtil.TARGET.getNrRows(), halfSample.getNrRows());

    // fraction 1.0 (same random source, continuing its sequence)
    final StratifiedRowSampler<SubsetWithReplacementRowSample> fullSampler =
        new StratifiedRowSampler<>(1.0, withReplacement, SamplerTestUtil.TARGET);
    final SubsetWithReplacementRowSample fullSample = fullSampler.createRowSample(randomData);
    assertEquals(SamplerTestUtil.TARGET.getNrRows(), SamplerTestUtil.countRows(fullSample));
    assertEquals(SamplerTestUtil.TARGET.getNrRows(), fullSample.getNrRows());
}
Use of org.apache.commons.math.random.RandomData in project knime-core by KNIME.
Class SubsetWithReplacementSelectorTest, method testSelect.
/**
 * Tests {@code SubsetWithReplacementSelector.select}.
 * <p>
 * For i in 1..19, selecting with arguments (20, i) must yield {@code countRows == i}. A (1000, 1000) selection
 * is then checked to contain roughly the bootstrap-expected share of unique rows. For n draws with replacement
 * from n rows that share is 1 - (1 - 1/n)^n, about 63.2% for n = 1000.
 */
@Test
public void testSelect() throws Exception {
    final SubsetWithReplacementSelector selector = SubsetWithReplacementSelector.getInstance();
    final RandomData rd = TestDataGenerator.createRandomData();
    // NOTE(review): assumes select(rd, nrRows, nrSelected) argument order -- confirm against the selector API.
    for (int i = 1; i < 20; i++) {
        SubsetWithReplacementRowSample sample = selector.select(rd, 20, i);
        int included = SamplerTestUtil.countRows(sample);
        assertThat("Unexpected number of included rows", included, is(i));
    }
    final String bootstrapMessage = "A bootstrap sample will usually contain about 63.2% of the rows.";
    SubsetWithReplacementRowSample sample = selector.select(rd, 1000, 1000);
    int uniqueRows = SamplerTestUtil.countUniqueRows(sample);
    // Upper bound: draws are not all distinct (i.e. sampling really is with replacement).
    assertThat(bootstrapMessage, uniqueRows, is(lessThan(700)));
    // Lower bound: ~632 expected, with a standard deviation of roughly 10. 560 is far outside normal variation
    // and catches a selector that repeats rows excessively. Written as "560 < uniqueRows".
    assertThat(bootstrapMessage, 560, is(lessThan(uniqueRows)));
}
Use of org.apache.commons.math.random.RandomData in project knime-core by KNIME.
Class TreeLearnerRegression, method findBestSplitsRegression.
/**
 * Computes the candidate splits for a node of a regression tree.
 * <p>
 * Returns {@code null}, meaning "do not split", when:
 * <ul>
 * <li>the configured maximum depth is reached,</li>
 * <li>the node holds fewer records than the configured minimum node size,</li>
 * <li>the target's sum of squared deviations is below {@link TreeColumnData#EPSILON}, or</li>
 * <li>no column produces a split.</li>
 * </ul>
 * At depth 0 with a hard-coded root column configured, only that column is evaluated. In that case the column
 * sample and the forbidden-column set are ignored.
 *
 * @param currentDepth depth of the node being split (0 = root)
 * @param dataMemberships the rows belonging to this node
 * @param columnSample the attribute columns eligible for splitting at this node
 * @param targetPriors target statistics of this node
 * @param forbiddenColumnSet attribute indices that must not be used for a split
 * @return the non-null split candidates ordered by decreasing gain, or {@code null} if the node should not be
 *         split; the array never contains {@code null} elements
 */
private SplitCandidate[] findBestSplitsRegression(final int currentDepth, final DataMemberships dataMemberships,
    final ColumnSample columnSample, final RegressionPriors targetPriors, final BitSet forbiddenColumnSet) {
    final TreeData data = getData();
    final RandomData rd = getRandomData();
    final TreeEnsembleLearnerConfiguration config = getConfig();

    final int maxLevels = config.getMaxLevels();
    if (maxLevels != TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE && currentDepth >= maxLevels) {
        return null;
    }
    final int minNodeSize = config.getMinNodeSize();
    if (minNodeSize != TreeEnsembleLearnerConfiguration.MIN_NODE_SIZE_UNDEFINED
        && targetPriors.getNrRecords() < minNodeSize) {
        return null;
    }
    // The target is (numerically) constant within this node, so there is nothing to gain from splitting.
    if (targetPriors.getSumSquaredDeviation() < TreeColumnData.EPSILON) {
        return null;
    }

    final TreeTargetNumericColumnData targetColumn = getTargetData();
    if (currentDepth == 0 && config.getHardCodedRootColumn() != null) {
        final TreeAttributeColumnData rootColumn = data.getColumn(config.getHardCodedRootColumn());
        final SplitCandidate rootSplit =
            rootColumn.calcBestSplitRegression(dataMemberships, targetPriors, targetColumn, rd);
        // A root column that yields no split is treated like "no candidates" instead of returning {null}.
        return rootSplit == null ? null : new SplitCandidate[]{rootSplit};
    }

    final ArrayList<SplitCandidate> splitCandidates = new ArrayList<>(columnSample.getNumCols());
    for (TreeAttributeColumnData col : columnSample) {
        if (forbiddenColumnSet.get(col.getMetaData().getAttributeIndex())) {
            continue;
        }
        final SplitCandidate currentColSplit =
            col.calcBestSplitRegression(dataMemberships, targetPriors, targetColumn, rd);
        if (currentColSplit != null) {
            splitCandidates.add(currentColSplit);
        }
    }
    if (splitCandidates.isEmpty()) {
        return null;
    }
    // Highest gain first.
    splitCandidates.sort((a, b) -> Double.compare(b.getGainValue(), a.getGainValue()));
    return splitCandidates.toArray(new SplitCandidate[splitCandidates.size()]);
}
Aggregations