Search in sources :

Example 1 with ColumnSampleStrategy

use of org.knime.base.node.mine.treeensemble2.sample.column.ColumnSampleStrategy in project knime-core by knime.

the class TreeLearnerClassification method findBestSplitClassification.

/**
 * Searches the sampled columns for the single best classification split of the current node.
 *
 * <p>
 * {@code null} is returned without searching when the configured maximum depth has been reached, when the
 * node holds fewer records than the configured minimum node size, or when the prior impurity is below
 * {@code TreeColumnData.EPSILON}. At depth 0 a configured hard-coded root column is used directly and no other
 * column is inspected. A candidate whose gain exactly equals the best gain seen so far replaces it only if a
 * random draw from {0, 1} yields 0.
 * </p>
 *
 * @param currentDepth depth of the node to split; 0 triggers the hard-coded root column check
 * @param dataMemberships passed through to each column's split calculation
 * @param columnSample the columns that are iterated as split candidates
 * @param treeNodeSignature not read by this method
 * @param targetPriors supplies record count and prior impurity and is passed to the split calculation
 * @param forbiddenColumnSet columns whose attribute index is set in this bit set are skipped
 * @return the chosen split candidate, or {@code null} if no split was searched for or none was accepted
 */
private SplitCandidate findBestSplitClassification(final int currentDepth, final DataMemberships dataMemberships, final ColumnSample columnSample, final TreeNodeSignature treeNodeSignature, final ClassificationPriors targetPriors, final BitSet forbiddenColumnSet) {
    final TreeData treeData = getData();
    final RandomData random = getRandomData();
    final TreeEnsembleLearnerConfiguration cfg = getConfig();

    // Stop criterion 1: depth limit reached.
    final int depthLimit = cfg.getMaxLevels();
    final boolean depthLimited = depthLimit != TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE;
    if (depthLimited && currentDepth >= depthLimit) {
        return null;
    }

    // Stop criterion 2: node too small (record count is only queried if a minimum is configured).
    final int smallestNode = cfg.getMinNodeSize();
    if (smallestNode != TreeEnsembleLearnerConfiguration.MIN_NODE_SIZE_UNDEFINED && targetPriors.getNrRecords() < smallestNode) {
        return null;
    }

    // Stop criterion 3: impurity already below the tolerance.
    if (targetPriors.getPriorImpurity() < TreeColumnData.EPSILON) {
        return null;
    }

    final TreeTargetNominalColumnData target = (TreeTargetNominalColumnData) treeData.getTargetColumn();

    final boolean atRoot = currentDepth == 0;
    if (atRoot && cfg.getHardCodedRootColumn() != null) {
        // TODO discuss whether this option makes sense with surrogates
        final TreeAttributeColumnData forcedRoot = treeData.getColumn(cfg.getHardCodedRootColumn());
        return forcedRoot.calcBestSplitClassification(dataMemberships, targetPriors, target, random);
    }

    SplitCandidate winner = null;
    double winnerGain = 0.0;
    for (final TreeAttributeColumnData column : columnSample) {
        final int attributeIndex = column.getMetaData().getAttributeIndex();
        if (forbiddenColumnSet.get(attributeIndex)) {
            continue;
        }
        final SplitCandidate contender = column.calcBestSplitClassification(dataMemberships, targetPriors, target, random);
        if (contender == null) {
            continue;
        }
        final double contenderGain = contender.getGainValue();
        final boolean accepted;
        if (contenderGain == winnerGain) {
            // Exact tie: the random generator is consulted only in this case.
            accepted = random.nextInt(0, 1) == 0;
        } else {
            accepted = contenderGain > winnerGain;
        }
        if (accepted) {
            winner = contender;
            winnerGain = contenderGain;
        }
    }
    return winner;
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RandomData(org.apache.commons.math.random.RandomData) TreeAttributeColumnData(org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) TreeTargetNominalColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData)

Example 2 with ColumnSampleStrategy

use of org.knime.base.node.mine.treeensemble2.sample.column.ColumnSampleStrategy in project knime-core by knime.

the class TreeLearnerClassification method findBestSplitsClassification.

/**
 * Collects the best classification split of every admissible sampled column, sorted by descending gain.
 *
 * <p>
 * {@code null} is returned without searching when the configured maximum depth has been reached, when the
 * node holds fewer records than the configured minimum node size, or when the prior impurity is below
 * {@code TreeColumnData.EPSILON}. At depth 0 a configured hard-coded root column is the only column that is
 * evaluated.
 * </p>
 *
 * @param currentDepth depth of the node to split; 0 triggers the hard-coded root column check
 * @param dataMemberships passed through to each column's split calculation
 * @param columnSample the columns that are iterated as split candidates
 * @param treeNodeSignature not read by this method
 * @param targetPriors supplies record count and prior impurity and is passed to the split calculation
 * @param forbiddenColumnSet columns whose attribute index is set in this bit set are skipped
 * @return the non-null split candidates ordered from highest to lowest gain, or {@code null} if no split was
 *         searched for or no column produced a candidate; the array never contains {@code null} elements
 */
private SplitCandidate[] findBestSplitsClassification(final int currentDepth, final DataMemberships dataMemberships, final ColumnSample columnSample, final TreeNodeSignature treeNodeSignature, final ClassificationPriors targetPriors, final BitSet forbiddenColumnSet) {
    final TreeData data = getData();
    final RandomData rd = getRandomData();
    final TreeEnsembleLearnerConfiguration config = getConfig();
    final int maxLevels = config.getMaxLevels();
    if (maxLevels != TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE && currentDepth >= maxLevels) {
        return null;
    }
    final int minNodeSize = config.getMinNodeSize();
    if (minNodeSize != TreeEnsembleLearnerConfiguration.MIN_NODE_SIZE_UNDEFINED && targetPriors.getNrRecords() < minNodeSize) {
        return null;
    }
    final double priorImpurity = targetPriors.getPriorImpurity();
    if (priorImpurity < TreeColumnData.EPSILON) {
        return null;
    }
    final TreeTargetNominalColumnData targetColumn = (TreeTargetNominalColumnData) data.getTargetColumn();
    if (currentDepth == 0 && config.getHardCodedRootColumn() != null) {
        final TreeAttributeColumnData rootColumn = data.getColumn(config.getHardCodedRootColumn());
        // TODO discuss whether this option makes sense with surrogates
        final SplitCandidate rootSplit = rootColumn.calcBestSplitClassification(dataMemberships, targetPriors, targetColumn, rd);
        // The split calculation may yield null (the loop below guards against it as well); report "no split"
        // the same way as the regular path instead of handing out an array with a null element.
        return rootSplit == null ? null : new SplitCandidate[] { rootSplit };
    }
    final ArrayList<SplitCandidate> candidates = new ArrayList<>(columnSample.getNumCols());
    for (final TreeAttributeColumnData col : columnSample) {
        if (forbiddenColumnSet.get(col.getMetaData().getAttributeIndex())) {
            continue;
        }
        final SplitCandidate currentColSplit = col.calcBestSplitClassification(dataMemberships, targetPriors, targetColumn, rd);
        if (currentColSplit != null) {
            candidates.add(currentColSplit);
        }
    }
    if (candidates.isEmpty()) {
        return null;
    }
    // Descending by gain: the arguments are swapped relative to the natural (ascending) comparison.
    candidates.sort((left, right) -> Double.compare(right.getGainValue(), left.getGainValue()));
    return candidates.toArray(new SplitCandidate[candidates.size()]);
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RandomData(org.apache.commons.math.random.RandomData) TreeAttributeColumnData(org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData) ArrayList(java.util.ArrayList) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) TreeTargetNominalColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData) Comparator(java.util.Comparator)

Example 3 with ColumnSampleStrategy

use of org.knime.base.node.mine.treeensemble2.sample.column.ColumnSampleStrategy in project knime-core by knime.

the class TreeEnsembleLearner method createColumnStatisticTable.

/**
 * Builds a table with one row per attribute column describing how the attribute was used in the first
 * {@code REPORT_LEVEL} levels of all trees of the learned ensemble.
 *
 * <p>
 * Each row is keyed by the attribute name and holds {@code 2 * REPORT_LEVEL} integer cells: the first
 * {@code REPORT_LEVEL} cells count, per level, the tree nodes that split on the attribute; the remaining
 * {@code REPORT_LEVEL} cells count, per level, the tree nodes whose column sample contained the attribute.
 * </p>
 *
 * @param exec used to create the output container and to check for cancellation after each written row
 * @return the table obtained from the filled and closed container
 * @throws CanceledExecutionException declared by the method; {@code exec.checkCanceled()} is the call that is
 *             expected to raise it
 */
public BufferedDataTable createColumnStatisticTable(final ExecutionContext exec) throws CanceledExecutionException {
    final BufferedDataContainer container = exec.createDataContainer(getColumnStatisticTableSpec());
    final int modelCount = m_ensembleModel.getNrModels();
    final TreeAttributeColumnData[] attributes = m_data.getColumns();
    final int attributeCount = attributes.length;

    // Both matrices are indexed as [level][attribute index].
    final int[][] splitCounts = new int[REPORT_LEVEL][attributeCount];
    final int[][] sampledCounts = new int[REPORT_LEVEL][attributeCount];

    for (int model = 0; model < modelCount; model++) {
        final AbstractTreeModel<?> tree = m_ensembleModel.getTreeModel(model);
        for (int level = 0; level < REPORT_LEVEL; level++) {
            final int[] splitOnLevel = splitCounts[level];
            final int[] sampledOnLevel = sampledCounts[level];
            for (final AbstractTreeNode node : tree.getTreeNodes(level)) {
                final TreeNodeSignature signature = node.getSignature();
                final ColumnSampleStrategy strategy = m_columnSampleStrategies[model];
                final ColumnSample sample = strategy.getColumnSampleForTreeNode(signature);
                for (final TreeAttributeColumnData sampledColumn : sample) {
                    sampledOnLevel[sampledColumn.getMetaData().getAttributeIndex()]++;
                }
                // A negative split attribute index is not counted.
                final int splitAttribute = node.getSplitAttributeIndex();
                if (splitAttribute >= 0) {
                    splitOnLevel[splitAttribute]++;
                }
            }
        }
    }

    for (int attribute = 0; attribute < attributeCount; attribute++) {
        final String rowKey = attributes[attribute].getMetaData().getAttributeName();
        final int[] cells = new int[2 * REPORT_LEVEL];
        for (int level = 0; level < REPORT_LEVEL; level++) {
            cells[level] = splitCounts[level][attribute];
            cells[REPORT_LEVEL + level] = sampledCounts[level][attribute];
        }
        container.addRowToTable(new DefaultRow(rowKey, cells));
        exec.checkCanceled();
    }
    container.close();
    return container.getTable();
}
Also used : ColumnSampleStrategy(org.knime.base.node.mine.treeensemble2.sample.column.ColumnSampleStrategy) TreeAttributeColumnData(org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData) BufferedDataContainer(org.knime.core.node.BufferedDataContainer) ColumnSample(org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample) AbstractTreeNode(org.knime.base.node.mine.treeensemble2.model.AbstractTreeNode) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) DataRow(org.knime.core.data.DataRow) DefaultRow(org.knime.core.data.def.DefaultRow)

Example 4 with ColumnSampleStrategy

use of org.knime.base.node.mine.treeensemble2.sample.column.ColumnSampleStrategy in project knime-core by knime.

the class TreeEnsembleLearner method learnEnsemble.

/**
 * Learns the configured number of tree models and combines them into a {@link TreeEnsembleModel}.
 *
 * <p>
 * One {@code TreeLearnerCallable} per model is enqueued on the global KNIME thread pool. A semaphore with
 * {@code 3 * availableProcessors / 2} permits is acquired before each enqueue and handed to the callable, which
 * throttles how many learners are scheduled at a time (presumably each callable releases its permit when it is
 * done - confirm in {@code TreeLearnerCallable}). Besides returning the ensemble, the method stores the
 * per-model row samples and column sample strategies in {@code m_rowSamples} and
 * {@code m_columnSampleStrategies} and the ensemble itself in {@code m_ensembleModel}.
 * </p>
 *
 * @param exec monitor that receives the "Tree i/n" progress messages and from which the per-tree sub-progress
 *            monitors are created
 * @return the newly created ensemble model
 * @throws CanceledExecutionException declared by the method; presumably raised via {@code checkThrowable} or
 *             {@code runInvisible} - confirm
 * @throws ExecutionException declared by the method; presumably raised via {@code checkThrowable} or
 *             {@code runInvisible} - confirm
 */
public TreeEnsembleModel learnEnsemble(final ExecutionMonitor exec) throws CanceledExecutionException, ExecutionException {
    final int nrModels = m_config.getNrModels();
    final RandomData rd = m_config.createRandomData();
    final ThreadPool tp = KNIMEConstants.GLOBAL_THREAD_POOL;
    // Shared slot for a failure; it is handed to every TreeLearnerCallable and also filled by the collector below.
    final AtomicReference<Throwable> learnThrowableRef = new AtomicReference<Throwable>();
    @SuppressWarnings("unchecked") final Future<TreeLearnerResult>[] modelFutures = new Future[nrModels];
    // Number of semaphore permits: 1.5 times the available processors (integer arithmetic).
    final int procCount = 3 * Runtime.getRuntime().availableProcessors() / 2;
    final Semaphore semaphore = new Semaphore(procCount);
    // Coordinator: enqueues all learners, waits for the permits to come back and collects the results.
    Callable<TreeLearnerResult[]> learnCallable = new Callable<TreeLearnerResult[]>() {

        @Override
        public TreeLearnerResult[] call() throws Exception {
            final TreeLearnerResult[] results = new TreeLearnerResult[nrModels];
            for (int i = 0; i < nrModels; i++) {
                // Blocks once procCount permits are taken and none has been returned yet.
                semaphore.acquire();
                // i - procCount is <= 0 for the first procCount + 1 iterations, so no progress is reported then.
                finishedTree(i - procCount, exec);
                checkThrowable(learnThrowableRef);
                // Every tree gets its own RandomData, created from a long drawn from the shared generator
                // (presumably used as seed - confirm in TreeEnsembleLearnerConfiguration.createRandomData).
                RandomData rdSingle = TreeEnsembleLearnerConfiguration.createRandomData(rd.nextLong(Long.MIN_VALUE, Long.MAX_VALUE));
                ExecutionMonitor subExec = exec.createSubProgress(0.0);
                modelFutures[i] = tp.enqueue(new TreeLearnerCallable(subExec, rdSingle, learnThrowableRef, semaphore));
            }
            // Take all procCount permits; this only completes once every permit is available again.
            for (int i = 0; i < procCount; i++) {
                semaphore.acquire();
                finishedTree(nrModels - 1 + i - procCount, exec);
            }
            for (int i = 0; i < nrModels; i++) {
                try {
                    results[i] = modelFutures[i].get();
                } catch (Exception e) {
                    // Only the first failure is kept (compareAndSet against null); results[i] stays null here.
                    learnThrowableRef.compareAndSet(null, e);
                }
            }
            return results;
        }

        /** Reports progress for the given tree index; indices that are not positive are ignored. */
        private void finishedTree(final int treeIndex, final ExecutionMonitor progMon) {
            if (treeIndex > 0) {
                progMon.setProgress(treeIndex / (double) nrModels, "Tree " + treeIndex + "/" + nrModels);
            }
        }
    };
    TreeLearnerResult[] modelResults = tp.runInvisible(learnCallable);
    // NOTE(review): the loop below dereferences every modelResults[i]; this relies on checkThrowable aborting
    // when a failure was recorded (entries of failed trees are null) - confirm.
    checkThrowable(learnThrowableRef);
    AbstractTreeModel[] models = new AbstractTreeModel[nrModels];
    m_rowSamples = new RowSample[nrModels];
    m_columnSampleStrategies = new ColumnSampleStrategy[nrModels];
    // Unpack the per-tree results into the model array and the learner's sample bookkeeping fields.
    for (int i = 0; i < nrModels; i++) {
        models[i] = modelResults[i].m_treeModel;
        m_rowSamples[i] = modelResults[i].m_rowSample;
        m_columnSampleStrategies[i] = modelResults[i].m_rootColumnSampleStrategy;
    }
    m_ensembleModel = new TreeEnsembleModel(m_config, m_data.getMetaData(), models, m_data.getTreeType());
    return m_ensembleModel;
}
Also used : RandomData(org.apache.commons.math.random.RandomData) TreeEnsembleModel(org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModel) ThreadPool(org.knime.core.util.ThreadPool) AtomicReference(java.util.concurrent.atomic.AtomicReference) AbstractTreeModel(org.knime.base.node.mine.treeensemble2.model.AbstractTreeModel) Semaphore(java.util.concurrent.Semaphore) Callable(java.util.concurrent.Callable) CanceledExecutionException(org.knime.core.node.CanceledExecutionException) ExecutionException(java.util.concurrent.ExecutionException) Future(java.util.concurrent.Future) ExecutionMonitor(org.knime.core.node.ExecutionMonitor)

Aggregations

RandomData (org.apache.commons.math.random.RandomData)3 TreeAttributeColumnData (org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData)3 TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData)2 TreeTargetNominalColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData)2 TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration)2 ArrayList (java.util.ArrayList)1 Comparator (java.util.Comparator)1 Callable (java.util.concurrent.Callable)1 ExecutionException (java.util.concurrent.ExecutionException)1 Future (java.util.concurrent.Future)1 Semaphore (java.util.concurrent.Semaphore)1 AtomicReference (java.util.concurrent.atomic.AtomicReference)1 AbstractTreeModel (org.knime.base.node.mine.treeensemble2.model.AbstractTreeModel)1 AbstractTreeNode (org.knime.base.node.mine.treeensemble2.model.AbstractTreeNode)1 TreeEnsembleModel (org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModel)1 TreeNodeSignature (org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature)1 ColumnSample (org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample)1 ColumnSampleStrategy (org.knime.base.node.mine.treeensemble2.sample.column.ColumnSampleStrategy)1 DataRow (org.knime.core.data.DataRow)1 DefaultRow (org.knime.core.data.def.DefaultRow)1