Use of org.apache.commons.math.random.RandomData in project knime-core by knime.
From the class NormalRowSamplerTest, method testCreateRowSampleNoReplacement.
/**
 * Draws a sample without replacement using a fraction of 0.5 over 20 rows and
 * expects the included-row bit set of the resulting sample to have exactly 10 bits set.
 *
 * @throws Exception if creating the sample fails
 */
@Test
public void testCreateRowSampleNoReplacement() throws Exception {
    final SubsetSelector<SubsetNoReplacementRowSample> noReplacementSelector =
        SubsetNoReplacementSelector.getInstance();
    final double samplingFraction = 0.5;
    final int totalRows = 20;
    final RandomRowSampler<SubsetNoReplacementRowSample> rowSampler =
        new RandomRowSampler<SubsetNoReplacementRowSample>(samplingFraction, noReplacementSelector, totalRows);

    final RandomData random = TestDataGenerator.createRandomData();
    final SubsetNoReplacementRowSample drawn = rowSampler.createRowSample(random);

    assertEquals(10, drawn.getIncludedBitSet().cardinality());
}
Use of org.apache.commons.math.random.RandomData in project knime-core by knime.
From the class NormalRowSamplerTest, method testCreateRowSampleWithReplacement.
/**
 * Draws a sample with replacement using a fraction of 0.8 over 20 rows and
 * expects {@code SamplerTestUtil.countRows} to report 16 rows for the resulting sample.
 *
 * @throws Exception if creating the sample fails
 */
@Test
public void testCreateRowSampleWithReplacement() throws Exception {
    final SubsetSelector<SubsetWithReplacementRowSample> withReplacementSelector =
        SubsetWithReplacementSelector.getInstance();
    final double samplingFraction = 0.8;
    final int totalRows = 20;
    final RandomRowSampler<SubsetWithReplacementRowSample> rowSampler =
        new RandomRowSampler<SubsetWithReplacementRowSample>(samplingFraction, withReplacementSelector, totalRows);

    final RandomData random = TestDataGenerator.createRandomData();
    final SubsetWithReplacementRowSample drawn = rowSampler.createRowSample(random);

    assertEquals(16, SamplerTestUtil.countRows(drawn));
}
Use of org.apache.commons.math.random.RandomData in project knime-core by knime.
From the class StratifiedRowSamplerTest, method testCreateRowSampleNoReplacement.
/**
 * Exercises stratified sampling without replacement against {@code SamplerTestUtil.TARGET}
 * in two phases that share one random data generator:
 * <ol>
 * <li>fraction 0.5: 8 rows must be included, and the sample must span all rows of the target;</li>
 * <li>fraction 1.0: every row of the target must be included.</li>
 * </ol>
 *
 * @throws Exception if creating a sample fails
 */
@Test
public void testCreateRowSampleNoReplacement() throws Exception {
    final RandomData random = TestDataGenerator.createRandomData();
    final SubsetSelector<SubsetNoReplacementRowSample> noReplacementSelector =
        SubsetNoReplacementSelector.getInstance();

    // Phase 1: sample half of the rows.
    final StratifiedRowSampler<SubsetNoReplacementRowSample> halfSampler =
        new StratifiedRowSampler<SubsetNoReplacementRowSample>(0.5, noReplacementSelector, SamplerTestUtil.TARGET);
    final SubsetNoReplacementRowSample halfSample = halfSampler.createRowSample(random);
    assertEquals(8, halfSample.getIncludedBitSet().cardinality());
    assertEquals(SamplerTestUtil.TARGET.getNrRows(), halfSample.getNrRows());

    // Phase 2: sample all rows, continuing with the same generator.
    final StratifiedRowSampler<SubsetNoReplacementRowSample> fullSampler =
        new StratifiedRowSampler<SubsetNoReplacementRowSample>(1.0, noReplacementSelector, SamplerTestUtil.TARGET);
    final SubsetNoReplacementRowSample fullSample = fullSampler.createRowSample(random);
    assertEquals(SamplerTestUtil.TARGET.getNrRows(), fullSample.getIncludedBitSet().cardinality());
    assertEquals(SamplerTestUtil.TARGET.getNrRows(), fullSample.getNrRows());
}
Use of org.apache.commons.math.random.RandomData in project knime-core by knime.
From the class TreeEnsembleLearner, method learnEnsemble.
/**
 * Learns all models of the ensemble and assembles them into a {@link TreeEnsembleModel}.
 *
 * <p>One {@code TreeLearnerCallable} per model is enqueued on the global KNIME thread pool. A
 * semaphore with {@code 3 * availableProcessors / 2} permits is acquired before every submission,
 * so the submitting loop blocks once all permits are taken. Each learner receives its own
 * {@link RandomData} created from a value drawn from the ensemble-level generator.
 *
 * <p>Besides returning the model, this method fills {@code m_rowSamples},
 * {@code m_columnSampleStrategies} and {@code m_ensembleModel}.
 *
 * @param exec monitor that receives progress updates and from which per-tree sub-progress is created
 * @return the learned ensemble model (the same instance that is stored in {@code m_ensembleModel})
 * @throws CanceledExecutionException NOTE(review): presumably propagated by {@code checkThrowable} or
 *             the thread pool when execution was canceled - confirm
 * @throws ExecutionException NOTE(review): presumably propagated by {@code checkThrowable} or
 *             {@code runInvisible} when learning failed - confirm
 */
public TreeEnsembleModel learnEnsemble(final ExecutionMonitor exec) throws CanceledExecutionException, ExecutionException {
final int nrModels = m_config.getNrModels();
final RandomData rd = m_config.createRandomData();
final ThreadPool tp = KNIMEConstants.GLOBAL_THREAD_POOL;
// Shared slot for a failure; it is handed to every learner and inspected via checkThrowable.
final AtomicReference<Throwable> learnThrowableRef = new AtomicReference<Throwable>();
@SuppressWarnings("unchecked") final Future<TreeLearnerResult>[] modelFutures = new Future[nrModels];
// Number of permits: 1.5 times the available processors (integer arithmetic).
final int procCount = 3 * Runtime.getRuntime().availableProcessors() / 2;
final Semaphore semaphore = new Semaphore(procCount);
// Coordinator task: submits one learner per model, then collects the results in model order.
Callable<TreeLearnerResult[]> learnCallable = new Callable<TreeLearnerResult[]>() {
@Override
public TreeLearnerResult[] call() throws Exception {
final TreeLearnerResult[] results = new TreeLearnerResult[nrModels];
for (int i = 0; i < nrModels; i++) {
// Blocks while all permits are taken.
// NOTE(review): permits are presumably released by TreeLearnerCallable when a tree is done - confirm.
semaphore.acquire();
// The reported index lags procCount behind the submission index.
finishedTree(i - procCount, exec);
// NOTE(review): presumably rethrows a failure recorded by a learner, which stops further submissions - confirm.
checkThrowable(learnThrowableRef);
// Per-tree generator, created from a long drawn from the shared generator (presumably used as seed).
RandomData rdSingle = TreeEnsembleLearnerConfiguration.createRandomData(rd.nextLong(Long.MIN_VALUE, Long.MAX_VALUE));
ExecutionMonitor subExec = exec.createSubProgress(0.0);
modelFutures[i] = tp.enqueue(new TreeLearnerCallable(subExec, rdSingle, learnThrowableRef, semaphore));
}
// Acquire procCount further permits; NOTE(review): presumably this waits for the learners still running - confirm.
// NOTE(review): for i == 0 the reported index (nrModels - 1 - procCount) repeats the last index of the
// submission loop, and the final reported index is nrModels - 2, so progress never reaches nrModels.
// Confirm whether nrModels + i - procCount was intended.
for (int i = 0; i < procCount; i++) {
semaphore.acquire();
finishedTree(nrModels - 1 + i - procCount, exec);
}
for (int i = 0; i < nrModels; i++) {
try {
results[i] = modelFutures[i].get();
} catch (Exception e) {
// Keep only the first failure: compareAndSet(null, e) leaves an already recorded throwable in place.
// results[i] stays null for a failed model.
// NOTE(review): this also absorbs an InterruptedException from get() without restoring the interrupt flag.
learnThrowableRef.compareAndSet(null, e);
}
}
return results;
}
// Reports progress as treeIndex / nrModels; non-positive indices are ignored.
private void finishedTree(final int treeIndex, final ExecutionMonitor progMon) {
if (treeIndex > 0) {
progMon.setProgress(treeIndex / (double) nrModels, "Tree " + treeIndex + "/" + nrModels);
}
}
};
TreeLearnerResult[] modelResults = tp.runInvisible(learnCallable);
// Must run before the results are dereferenced below: entries of failed models are null.
// NOTE(review): presumably rethrows the recorded failure - confirm.
checkThrowable(learnThrowableRef);
AbstractTreeModel[] models = new AbstractTreeModel[nrModels];
m_rowSamples = new RowSample[nrModels];
m_columnSampleStrategies = new ColumnSampleStrategy[nrModels];
// Unpack each learner result into the model array and the per-model bookkeeping fields.
for (int i = 0; i < nrModels; i++) {
models[i] = modelResults[i].m_treeModel;
m_rowSamples[i] = modelResults[i].m_rowSample;
m_columnSampleStrategies[i] = modelResults[i].m_rootColumnSampleStrategy;
}
m_ensembleModel = new TreeEnsembleModel(m_config, m_data.getMetaData(), models, m_data.getTreeType());
return m_ensembleModel;
}
Use of org.apache.commons.math.random.RandomData in project knime-core by knime.
From the class TreeLearnerRegression, method buildTreeNode.
/**
 * Recursively builds the regression tree node identified by {@code treeNodeSignature}.
 *
 * <p>If no split candidate is found, a leaf is returned; for a
 * {@code GradientBoostingLearnerConfiguration} the leaf additionally carries the original row
 * indices and is registered via {@code addToLeafList}. Otherwise the node is split, either with
 * surrogates (always two children) or directly on the best split's child conditions, and each
 * child is built by a recursive call with {@code currentDepth + 1}.
 *
 * @param exec used to check for cancellation; passed on to the recursive calls
 * @param currentDepth depth of the node that is being built
 * @param dataMemberships the rows that belong to this node
 * @param columnSample the columns considered for splitting this node
 * @param treeNodeSignature signature of the node that is being built
 * @param targetPriors target priors of the rows in this node
 * @param forbiddenColumnSet attribute indices that are currently forbidden; in the non-surrogate
 *            branch the bit of the split attribute is set while the children are built if the
 *            column cannot be split further, and is cleared again afterwards
 * @return the leaf or inner node that was built
 * @throws CanceledExecutionException if {@code exec.checkCanceled()} signals cancellation
 */
private TreeNodeRegression buildTreeNode(final ExecutionMonitor exec, final int currentDepth,
    final DataMemberships dataMemberships, final ColumnSample columnSample,
    final TreeNodeSignature treeNodeSignature, final RegressionPriors targetPriors,
    final BitSet forbiddenColumnSet) throws CanceledExecutionException {
    final TreeData treeData = getData();
    final RandomData random = getRandomData();
    final TreeEnsembleLearnerConfiguration learnerConfig = getConfig();
    exec.checkCanceled();

    final SplitCandidate split =
        findBestSplitRegression(currentDepth, dataMemberships, columnSample, targetPriors, forbiddenColumnSet);
    if (split == null) {
        // No split candidate: this node becomes a leaf.
        if (!(learnerConfig instanceof GradientBoostingLearnerConfiguration)) {
            return new TreeNodeRegression(treeNodeSignature, targetPriors);
        }
        // For gradient boosting the leaf keeps its row indices and is registered in the leaf list.
        final TreeNodeRegression boostingLeaf =
            new TreeNodeRegression(treeNodeSignature, targetPriors, dataMemberships.getOriginalIndices());
        addToLeafList(boostingLeaf);
        return boostingLeaf;
    }

    final TreeTargetNumericColumnData target = (TreeTargetNumericColumnData) treeData.getTargetColumn();
    final TreeNodeRegression[] children;
    if (learnerConfig.getMissingValueHandling() == MissingValueHandling.Surrogate) {
        final SurrogateSplit surrogates =
            Surrogates.learnSurrogates(dataMemberships, split, treeData, columnSample, learnerConfig, random);
        final TreeNodeCondition[] conditions = surrogates.getChildConditions();
        final BitSet[] markers = surrogates.getChildMarkers();
        assert markers[0].cardinality() + markers[1].cardinality() == dataMemberships.getRowCount() : "Sum of rows in children does not add up to number of rows in parent.";
        // A surrogate split always yields exactly two children.
        children = new TreeNodeRegression[2];
        for (int side = 0; side < 2; side++) {
            final DataMemberships sideMemberships = dataMemberships.createChildMemberships(markers[side]);
            final TreeNodeSignature sideSignature =
                getSignatureFactory().getChildSignatureFor(treeNodeSignature, (byte) side);
            final ColumnSample sideColumns = getColSamplingStrategy().getColumnSampleForTreeNode(sideSignature);
            final RegressionPriors sidePriors = target.getPriors(sideMemberships, learnerConfig);
            children[side] = buildTreeNode(exec, currentDepth + 1, sideMemberships, sideColumns, sideSignature,
                sidePriors, forbiddenColumnSet);
            children[side].setTreeNodeCondition(conditions[side]);
        }
    } else {
        final TreeAttributeColumnData splitColumn = split.getColumnData();
        final int attributeIndex = splitColumn.getMetaData().getAttributeIndex();
        final boolean forbidWhileDescending = !split.canColumnBeSplitFurther();
        // Sets or clears the bit of the split attribute for the duration of the recursion below.
        forbiddenColumnSet.set(attributeIndex, forbidWhileDescending);
        final TreeNodeCondition[] conditions = split.getChildConditions();
        final int childCount = conditions.length;
        if (childCount > Short.MAX_VALUE) {
            throw new RuntimeException("Too many children when splitting attribute " + split.getColumnData()
                + " (maximum supported: " + Short.MAX_VALUE + "): " + childCount);
        }
        children = new TreeNodeRegression[childCount];
        for (int c = 0; c < childCount; c++) {
            final TreeNodeCondition condition = conditions[c];
            final DataMemberships childMemberships = dataMemberships
                .createChildMemberships(splitColumn.updateChildMemberships(condition, dataMemberships));
            final RegressionPriors childPriors = target.getPriors(childMemberships, learnerConfig);
            final TreeNodeSignature childSignature = treeNodeSignature.createChildSignature((byte) c);
            final ColumnSample childColumns = getColSamplingStrategy().getColumnSampleForTreeNode(childSignature);
            children[c] = buildTreeNode(exec, currentDepth + 1, childMemberships, childColumns, childSignature,
                childPriors, forbiddenColumnSet);
            children[c].setTreeNodeCondition(condition);
        }
        if (forbidWhileDescending) {
            // Lift the restriction again once the subtree below this node is complete.
            forbiddenColumnSet.clear(attributeIndex);
        }
    }
    return new TreeNodeRegression(treeNodeSignature, targetPriors, children);
}
Aggregations