Search in sources:

Example 1 with AbstractTreeModel

Use of org.knime.base.node.mine.treeensemble.model.AbstractTreeModel in project knime-core by KNIME.

Method learnEnsemble of class TreeEnsembleLearner.

/**
 * Learns the configured number of trees in parallel and combines them into a {@link TreeEnsembleModel}.
 *
 * <p>Each tree is learned by a {@code TreeLearnerCallable} enqueued on {@code KNIMEConstants.GLOBAL_THREAD_POOL}.
 * Each tree gets its own {@code RandomData}, seeded from the configuration's random generator.
 *
 * <p>A semaphore with {@code 3 * availableProcessors / 2} permits limits how many trees are in flight. A permit is
 * acquired before each tree is submitted, and the semaphore is handed to the callable. NOTE(review): this assumes
 * {@code TreeLearnerCallable} releases exactly one permit when its tree is done, also on failure. Otherwise the
 * drain loop below can block forever. Confirm.
 *
 * <p>Side effects: sets {@code m_rowSamples}, {@code m_columnSampleStrategies} and {@code m_ensembleModel}.
 *
 * @param exec monitor used for progress reporting; each tree also receives a sub-progress of weight 0
 * @return the learned ensemble model (also stored in {@code m_ensembleModel})
 * @throws CanceledExecutionException declared by this method; presumably raised by {@code checkThrowable} or the
 *             thread pool when execution is canceled (TODO confirm)
 * @throws ExecutionException declared by this method; presumably raised when the learning callable fails inside
 *             {@code runInvisible} (TODO confirm)
 */
public TreeEnsembleModel learnEnsemble(final ExecutionMonitor exec) throws CanceledExecutionException, ExecutionException {
    final int nrModels = m_config.getNrModels();
    final RandomData rd = m_config.createRandomData();
    final ThreadPool tp = KNIMEConstants.GLOBAL_THREAD_POOL;
    // first failure of any tree learner; checked before each submission and after learning
    final AtomicReference<Throwable> learnThrowableRef = new AtomicReference<Throwable>();
    // generic array creation is not allowed in Java, hence the raw Future[] and the unchecked suppression
    @SuppressWarnings("unchecked") final Future<TreeLearnerResult>[] modelFutures = new Future[nrModels];
    // allow somewhat more trees in flight than there are cores
    final int procCount = 3 * Runtime.getRuntime().availableProcessors() / 2;
    final Semaphore semaphore = new Semaphore(procCount);
    Callable<TreeLearnerResult[]> learnCallable = new Callable<TreeLearnerResult[]>() {

        @Override
        public TreeLearnerResult[] call() throws Exception {
            final TreeLearnerResult[] results = new TreeLearnerResult[nrModels];
            for (int i = 0; i < nrModels; i++) {
                semaphore.acquire();
                // i + 1 permits have now been taken from an initial pool of procCount, so (given one release per
                // finished tree) at least i + 1 - procCount trees have finished
                finishedTree(i + 1 - procCount, exec);
                checkThrowable(learnThrowableRef);
                RandomData rdSingle = TreeEnsembleLearnerConfiguration.createRandomData(rd.nextLong(Long.MIN_VALUE, Long.MAX_VALUE));
                ExecutionMonitor subExec = exec.createSubProgress(0.0);
                modelFutures[i] = tp.enqueue(new TreeLearnerCallable(subExec, rdSingle, learnThrowableRef, semaphore));
            }
            // drain: re-acquiring all procCount permits waits for the outstanding trees; after the last acquire
            // all nrModels trees have released their permits, so progress reaches nrModels/nrModels
            for (int i = 0; i < procCount; i++) {
                semaphore.acquire();
                finishedTree(nrModels + i + 1 - procCount, exec);
            }
            for (int i = 0; i < nrModels; i++) {
                try {
                    results[i] = modelFutures[i].get();
                } catch (InterruptedException e) {
                    // keep the interrupt visible to the owner of this thread; the failure is reported below
                    Thread.currentThread().interrupt();
                    learnThrowableRef.compareAndSet(null, e);
                } catch (Exception e) {
                    // only the first failure is kept; it is rethrown by checkThrowable after runInvisible returns
                    learnThrowableRef.compareAndSet(null, e);
                }
            }
            return results;
        }

        /**
         * Reports progress for the given number of finished trees. Non-positive counts (nothing finished yet)
         * are ignored.
         */
        private void finishedTree(final int finishedCount, final ExecutionMonitor progMon) {
            if (finishedCount > 0) {
                progMon.setProgress(finishedCount / (double) nrModels, "Tree " + finishedCount + "/" + nrModels);
            }
        }
    };
    TreeLearnerResult[] modelResults = tp.runInvisible(learnCallable);
    // surface any failure recorded by a tree learner or by Future.get() above
    checkThrowable(learnThrowableRef);
    AbstractTreeModel[] models = new AbstractTreeModel[nrModels];
    m_rowSamples = new RowSample[nrModels];
    m_columnSampleStrategies = new ColumnSampleStrategy[nrModels];
    for (int i = 0; i < nrModels; i++) {
        models[i] = modelResults[i].m_treeModel;
        m_rowSamples[i] = modelResults[i].m_rowSample;
        m_columnSampleStrategies[i] = modelResults[i].m_rootColumnSampleStrategy;
    }
    m_ensembleModel = new TreeEnsembleModel(m_config, m_data.getMetaData(), models, m_data.getTreeType());
    return m_ensembleModel;
}
Also used: RandomData (org.apache.commons.math.random.RandomData), TreeEnsembleModel (org.knime.base.node.mine.treeensemble.model.TreeEnsembleModel), ThreadPool (org.knime.core.util.ThreadPool), AtomicReference (java.util.concurrent.atomic.AtomicReference), AbstractTreeModel (org.knime.base.node.mine.treeensemble.model.AbstractTreeModel), Semaphore (java.util.concurrent.Semaphore), Callable (java.util.concurrent.Callable), CanceledExecutionException (org.knime.core.node.CanceledExecutionException), ExecutionException (java.util.concurrent.ExecutionException), Future (java.util.concurrent.Future), ExecutionMonitor (org.knime.core.node.ExecutionMonitor)

Example 2 with AbstractTreeModel

Use of org.knime.base.node.mine.treeensemble.model.AbstractTreeModel in project knime-core by KNIME.

Method getModel of class TreeEnsembleShrinker.

/**
 * Builds the shrunk tree ensemble model from the trees that are currently selected.
 *
 * <p>The new model reuses the meta data of the initial ensemble. If neither autoShrink() nor shrinkTo() has been
 * called yet, it therefore contains the same trees as the initial model. A fresh learner configuration is used
 * with target distributions saved in the nodes.
 *
 * <p>NOTE(review): the tree type is always {@code TreeType.Ordinary}, whatever type the initial ensemble had.
 * Confirm this is intended.
 *
 * @return a new tree ensemble model over the currently selected trees
 */
public TreeEnsembleModel getModel() {
    final TreeEnsembleLearnerConfiguration shrunkConfig = new TreeEnsembleLearnerConfiguration(false);
    shrunkConfig.setSaveTargetDistributionInNodes(true);
    final AbstractTreeModel[] selectedTrees = new AbstractTreeModel[m_currentTrees.size()];
    return new TreeEnsembleModel(shrunkConfig, m_initialEnsemble.getMetaData(),
        m_currentTrees.toArray(selectedTrees), TreeType.Ordinary);
}
Also used: TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble.node.learner.TreeEnsembleLearnerConfiguration), TreeEnsembleModel (org.knime.base.node.mine.treeensemble.model.TreeEnsembleModel), AbstractTreeModel (org.knime.base.node.mine.treeensemble.model.AbstractTreeModel)

Aggregations

AbstractTreeModel (org.knime.base.node.mine.treeensemble.model.AbstractTreeModel): 2, TreeEnsembleModel (org.knime.base.node.mine.treeensemble.model.TreeEnsembleModel): 2, Callable (java.util.concurrent.Callable): 1, ExecutionException (java.util.concurrent.ExecutionException): 1, Future (java.util.concurrent.Future): 1, Semaphore (java.util.concurrent.Semaphore): 1, AtomicReference (java.util.concurrent.atomic.AtomicReference): 1, RandomData (org.apache.commons.math.random.RandomData): 1, TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble.node.learner.TreeEnsembleLearnerConfiguration): 1, CanceledExecutionException (org.knime.core.node.CanceledExecutionException): 1, ExecutionMonitor (org.knime.core.node.ExecutionMonitor): 1, ThreadPool (org.knime.core.util.ThreadPool): 1