Example use of org.knime.base.node.mine.treeensemble2.sample.column.ColumnSampleStrategy in the knime-core project by knime.
From class TreeLearnerClassification, method findBestSplitClassification.
/**
 * Determines the single split to use for the current node.
 *
 * <p>No split ({@code null}) is returned when the configured maximum depth is reached, when the node holds
 * fewer records than the configured minimum node size, when the prior impurity is below
 * {@link TreeColumnData#EPSILON}, or when no sampled column yields an accepted candidate.</p>
 *
 * <p>If a hard-coded root column is configured and {@code currentDepth} is 0, that column's best split is
 * returned directly (possibly {@code null}) without looking at the column sample.</p>
 *
 * @param currentDepth depth of the node, compared against the configured maximum number of levels
 * @param dataMemberships handed to each column's split calculation
 * @param columnSample the columns that are evaluated
 * @param treeNodeSignature not read by this method
 * @param targetPriors supplies the record count and prior impurity and is handed to the split calculation
 * @param forbiddenColumnSet columns whose attribute index is set in here are skipped
 * @return the accepted candidate with the highest gain, or {@code null} if there is none
 */
private SplitCandidate findBestSplitClassification(final int currentDepth, final DataMemberships dataMemberships,
    final ColumnSample columnSample, final TreeNodeSignature treeNodeSignature,
    final ClassificationPriors targetPriors, final BitSet forbiddenColumnSet) {
    final TreeData treeData = getData();
    final RandomData random = getRandomData();
    final TreeEnsembleLearnerConfiguration cfg = getConfig();

    // Stop criterion 1: depth limit reached.
    final int depthLimit = cfg.getMaxLevels();
    final boolean depthUnlimited = depthLimit == TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE;
    if (!depthUnlimited && currentDepth >= depthLimit) {
        return null;
    }

    // Stop criterion 2: node too small.
    final int smallestNodeSize = cfg.getMinNodeSize();
    if (smallestNodeSize != TreeEnsembleLearnerConfiguration.MIN_NODE_SIZE_UNDEFINED
        && targetPriors.getNrRecords() < smallestNodeSize) {
        return null;
    }

    // Stop criterion 3: prior impurity is (numerically) zero.
    if (targetPriors.getPriorImpurity() < TreeColumnData.EPSILON) {
        return null;
    }

    final TreeTargetNominalColumnData target = (TreeTargetNominalColumnData) treeData.getTargetColumn();

    if (currentDepth == 0 && cfg.getHardCodedRootColumn() != null) {
        // TODO discuss whether this option makes sense with surrogates
        final TreeAttributeColumnData forcedRoot = treeData.getColumn(cfg.getHardCodedRootColumn());
        return forcedRoot.calcBestSplitClassification(dataMemberships, targetPriors, target, random);
    }

    SplitCandidate winner = null;
    double winnerGain = 0.0;
    for (final TreeAttributeColumnData column : columnSample) {
        final boolean forbidden = forbiddenColumnSet.get(column.getMetaData().getAttributeIndex());
        if (forbidden) {
            continue;
        }
        final SplitCandidate contender =
            column.calcBestSplitClassification(dataMemberships, targetPriors, target, random);
        if (contender == null) {
            continue;
        }
        final double contenderGain = contender.getGainValue();
        // A random number is drawn only on an exact tie with the current best gain (which starts at 0.0);
        // otherwise the contender has to be strictly better.
        final boolean accept =
            contenderGain == winnerGain ? random.nextInt(0, 1) == 0 : contenderGain > winnerGain;
        if (accept) {
            winner = contender;
            winnerGain = contenderGain;
        }
    }
    return winner;
}
Example use of org.knime.base.node.mine.treeensemble2.sample.column.ColumnSampleStrategy in the knime-core project by knime.
From class TreeLearnerClassification, method findBestSplitsClassification.
/**
 * Computes the best split of every admissible column in the sample and returns all of them sorted
 * (descending) by their gain.
 *
 * <p>{@code null} is returned when the configured maximum depth is reached, when the node holds fewer
 * records than the configured minimum node size, when the prior impurity is below
 * {@link TreeColumnData#EPSILON}, or when no column produces a split candidate. The returned array never
 * contains {@code null} elements.</p>
 *
 * <p>If a hard-coded root column is configured and {@code currentDepth} is 0, only that column is
 * evaluated and the result holds its single best split.</p>
 *
 * @param currentDepth depth of the node, compared against the configured maximum number of levels
 * @param dataMemberships handed to each column's split calculation
 * @param columnSample the columns that are evaluated
 * @param treeNodeSignature not read by this method
 * @param targetPriors supplies the record count and prior impurity and is handed to the split calculation
 * @param forbiddenColumnSet columns whose attribute index is set in here are skipped
 * @return the split candidates ordered by descending gain, or {@code null} if no split is possible
 */
private SplitCandidate[] findBestSplitsClassification(final int currentDepth, final DataMemberships dataMemberships,
    final ColumnSample columnSample, final TreeNodeSignature treeNodeSignature,
    final ClassificationPriors targetPriors, final BitSet forbiddenColumnSet) {
    final TreeData data = getData();
    final RandomData rd = getRandomData();
    final TreeEnsembleLearnerConfiguration config = getConfig();

    final int maxLevels = config.getMaxLevels();
    if (maxLevels != TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE && currentDepth >= maxLevels) {
        return null;
    }
    final int minNodeSize = config.getMinNodeSize();
    if (minNodeSize != TreeEnsembleLearnerConfiguration.MIN_NODE_SIZE_UNDEFINED
        && targetPriors.getNrRecords() < minNodeSize) {
        return null;
    }
    if (targetPriors.getPriorImpurity() < TreeColumnData.EPSILON) {
        return null;
    }

    final TreeTargetNominalColumnData targetColumn = (TreeTargetNominalColumnData) data.getTargetColumn();

    if (currentDepth == 0 && config.getHardCodedRootColumn() != null) {
        final TreeAttributeColumnData rootColumn = data.getColumn(config.getHardCodedRootColumn());
        // TODO discuss whether this option makes sense with surrogates
        final SplitCandidate rootSplit =
            rootColumn.calcBestSplitClassification(dataMemberships, targetPriors, targetColumn, rd);
        // The column may not offer any split; report that the same way as every other "no split" case
        // instead of handing back an array that holds a null element.
        return rootSplit == null ? null : new SplitCandidate[] { rootSplit };
    }

    final ArrayList<SplitCandidate> candidates = new ArrayList<SplitCandidate>(columnSample.getNumCols());
    for (final TreeAttributeColumnData col : columnSample) {
        if (forbiddenColumnSet.get(col.getMetaData().getAttributeIndex())) {
            continue;
        }
        final SplitCandidate currentColSplit =
            col.calcBestSplitClassification(dataMemberships, targetPriors, targetColumn, rd);
        if (currentColSplit != null) {
            candidates.add(currentColSplit);
        }
    }
    if (candidates.isEmpty()) {
        return null;
    }

    // Highest gain first; the sort is stable, so candidates with equal gain keep their column-sample order.
    final Comparator<SplitCandidate> byGainDescending =
        (o1, o2) -> Double.compare(o2.getGainValue(), o1.getGainValue());
    candidates.sort(byGainDescending);
    return candidates.toArray(new SplitCandidate[candidates.size()]);
}
Example use of org.knime.base.node.mine.treeensemble2.sample.column.ColumnSampleStrategy in the knime-core project by knime.
From class TreeEnsembleLearner, method createColumnStatisticTable.
/**
 * Creates a table with one row per attribute column that summarizes, for each of the first
 * {@code REPORT_LEVEL} tree levels, how the attribute was used across all models of the ensemble.
 *
 * <p>Each row consists of {@code 2 * REPORT_LEVEL} integer cells: the first {@code REPORT_LEVEL} cells count
 * how often the attribute was the split attribute of a node on the respective level, the remaining
 * {@code REPORT_LEVEL} cells count how often the attribute was part of the column sample of a node on
 * that level.</p>
 *
 * @param exec used to create the data container and to check for cancellation after each written row
 * @return the filled and closed table
 * @throws CanceledExecutionException if the execution was canceled
 */
public BufferedDataTable createColumnStatisticTable(final ExecutionContext exec) throws CanceledExecutionException {
    final BufferedDataContainer container = exec.createDataContainer(getColumnStatisticTableSpec());
    final int modelCount = m_ensembleModel.getNrModels();
    final TreeAttributeColumnData[] attributes = m_data.getColumns();
    final int attributeCount = attributes.length;

    // [level][attribute index] counters
    final int[][] splitCounts = new int[REPORT_LEVEL][attributeCount];
    final int[][] sampledCounts = new int[REPORT_LEVEL][attributeCount];

    for (int modelIdx = 0; modelIdx < modelCount; modelIdx++) {
        final AbstractTreeModel<?> tree = m_ensembleModel.getTreeModel(modelIdx);
        for (int depth = 0; depth < REPORT_LEVEL; depth++) {
            final int[] splitsAtDepth = splitCounts[depth];
            final int[] sampledAtDepth = sampledCounts[depth];
            for (final AbstractTreeNode node : tree.getTreeNodes(depth)) {
                // Re-create the column sample that was available when this node was learned.
                final TreeNodeSignature signature = node.getSignature();
                final ColumnSample sample =
                    m_columnSampleStrategies[modelIdx].getColumnSampleForTreeNode(signature);
                for (final TreeAttributeColumnData sampledColumn : sample) {
                    sampledAtDepth[sampledColumn.getMetaData().getAttributeIndex()]++;
                }
                // A negative index means the node has no split attribute to count.
                final int splitAttribute = node.getSplitAttributeIndex();
                if (splitAttribute >= 0) {
                    splitsAtDepth[splitAttribute]++;
                }
            }
        }
    }

    for (int attr = 0; attr < attributeCount; attr++) {
        final String rowName = attributes[attr].getMetaData().getAttributeName();
        final int[] cells = new int[2 * REPORT_LEVEL];
        for (int depth = 0; depth < REPORT_LEVEL; depth++) {
            cells[depth] = splitCounts[depth][attr];
            cells[REPORT_LEVEL + depth] = sampledCounts[depth][attr];
        }
        container.addRowToTable(new DefaultRow(rowName, cells));
        exec.checkCanceled();
    }
    container.close();
    return container.getTable();
}
Example use of org.knime.base.node.mine.treeensemble2.sample.column.ColumnSampleStrategy in the knime-core project by knime.
From class TreeEnsembleLearner, method learnEnsemble.
/**
 * Learns all tree models of the ensemble and assembles them into a {@link TreeEnsembleModel}.
 *
 * <p>One {@code TreeLearnerCallable} per model is enqueued on the global KNIME thread pool. A semaphore with
 * {@code 3 * availableProcessors / 2} permits limits how many trees are scheduled at a time. After all trees
 * have been collected, the per-tree row samples and root column sample strategies are stored in
 * {@code m_rowSamples} and {@code m_columnSampleStrategies}, and the resulting model is stored in
 * {@code m_ensembleModel}.</p>
 *
 * @param exec progress monitor; receives "Tree i/n" progress updates and provides the sub-progress
 *            monitors handed to the individual tree learners
 * @return the learned ensemble model (the same instance that is stored in {@code m_ensembleModel})
 * @throws CanceledExecutionException declared by this method; NOTE(review): presumably raised via
 *             checkThrowable when a tree learner was canceled -- confirm
 * @throws ExecutionException declared by this method; NOTE(review): presumably raised via checkThrowable
 *             or ThreadPool#runInvisible when learning failed -- confirm
 */
public TreeEnsembleModel learnEnsemble(final ExecutionMonitor exec) throws CanceledExecutionException, ExecutionException {
final int nrModels = m_config.getNrModels();
final RandomData rd = m_config.createRandomData();
final ThreadPool tp = KNIMEConstants.GLOBAL_THREAD_POOL;
// Holds a throwable recorded during learning; shared with every TreeLearnerCallable.
final AtomicReference<Throwable> learnThrowableRef = new AtomicReference<Throwable>();
@SuppressWarnings("unchecked") final Future<TreeLearnerResult>[] modelFutures = new Future[nrModels];
// Number of semaphore permits: 1.5 times the available processors (integer arithmetic).
final int procCount = 3 * Runtime.getRuntime().availableProcessors() / 2;
final Semaphore semaphore = new Semaphore(procCount);
Callable<TreeLearnerResult[]> learnCallable = new Callable<TreeLearnerResult[]>() {
@Override
public TreeLearnerResult[] call() throws Exception {
final TreeLearnerResult[] results = new TreeLearnerResult[nrModels];
// Phase 1: schedule one learner per model, taking a permit before each enqueue.
for (int i = 0; i < nrModels; i++) {
// NOTE(review): the semaphore is passed to TreeLearnerCallable, which presumably releases a permit
// when its tree is done -- confirm in TreeLearnerCallable.
semaphore.acquire();
finishedTree(i - procCount, exec);
// NOTE(review): presumably aborts scheduling if an already running learner recorded a failure -- confirm.
checkThrowable(learnThrowableRef);
// Every tree gets its own random source, seeded from the shared one. The seeds are drawn here in the
// scheduling loop, i.e. sequentially and in model order.
RandomData rdSingle = TreeEnsembleLearnerConfiguration.createRandomData(rd.nextLong(Long.MIN_VALUE, Long.MAX_VALUE));
ExecutionMonitor subExec = exec.createSubProgress(0.0);
modelFutures[i] = tp.enqueue(new TreeLearnerCallable(subExec, rdSingle, learnThrowableRef, semaphore));
}
// Phase 2: take procCount more permits (as many as the semaphore was created with), reporting
// progress after each one.
for (int i = 0; i < procCount; i++) {
semaphore.acquire();
finishedTree(nrModels - 1 + i - procCount, exec);
}
// Phase 3: collect the results. An exception from get() is recorded only if no throwable has been
// recorded before; the corresponding results entry stays null.
for (int i = 0; i < nrModels; i++) {
try {
results[i] = modelFutures[i].get();
} catch (Exception e) {
learnThrowableRef.compareAndSet(null, e);
}
}
return results;
}
// Reports progress as treeIndex / nrModels; calls with treeIndex <= 0 are ignored.
private void finishedTree(final int treeIndex, final ExecutionMonitor progMon) {
if (treeIndex > 0) {
progMon.setProgress(treeIndex / (double) nrModels, "Tree " + treeIndex + "/" + nrModels);
}
}
};
TreeLearnerResult[] modelResults = tp.runInvisible(learnCallable);
// NOTE(review): presumably rethrows a recorded failure so that null entries in modelResults are never
// dereferenced below -- confirm in checkThrowable.
checkThrowable(learnThrowableRef);
AbstractTreeModel[] models = new AbstractTreeModel[nrModels];
m_rowSamples = new RowSample[nrModels];
m_columnSampleStrategies = new ColumnSampleStrategy[nrModels];
// Unpack the per-tree results into the model array and the learner's fields.
for (int i = 0; i < nrModels; i++) {
models[i] = modelResults[i].m_treeModel;
m_rowSamples[i] = modelResults[i].m_rowSample;
m_columnSampleStrategies[i] = modelResults[i].m_rootColumnSampleStrategy;
}
m_ensembleModel = new TreeEnsembleModel(m_config, m_data.getMetaData(), models, m_data.getTreeType());
return m_ensembleModel;
}
Aggregations