Search in sources:

Example 16 with TreeEnsembleModel

Use of org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModel in project knime-core by KNIME.

In class RandomForestClassificationLearnerNodeModel, method printEnsembleStatistics:

/**
 * Dumps summary statistics of the given ensemble to standard output, one {@code label: value} pair per line
 * (tree depth min/max/average, node count min/max/average, average number of surrogates).
 *
 * @param ensembleModel the ensemble whose statistics are printed
 */
private void printEnsembleStatistics(final TreeEnsembleModel ensembleModel) {
    final EnsembleStatistic statistic = new EnsembleStatistic(ensembleModel);
    // depth of the trees
    printEnsembleStatisticLine("minLevel", statistic.getMinLevel());
    printEnsembleStatisticLine("maxLevel", statistic.getMaxLevel());
    printEnsembleStatisticLine("avgLevel", statistic.getAvgLevel());
    // size of the trees
    printEnsembleStatisticLine("minNumNodes", statistic.getMinNumNodes());
    printEnsembleStatisticLine("maxNumNodes", statistic.getMaxNumNodes());
    printEnsembleStatisticLine("avgNumNodes", statistic.getAvgNumNodes());
    printEnsembleStatisticLine("avgNumSurrogates", statistic.getAvgNumSurrogates());
}

/**
 * Prints a single {@code label: value} line to standard output.
 *
 * @param label name of the statistic
 * @param value the statistic's value (rendered via string concatenation)
 */
private static void printEnsembleStatisticLine(final String label, final Object value) {
    System.out.println(label + ": " + value);
}
Also used: EnsembleStatistic (org.knime.base.node.mine.treeensemble2.statistics.EnsembleStatistic)

Example 17 with TreeEnsembleModel

Use of org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModel in project knime-core by KNIME.

In class TreeEnsembleLearner, method learnEnsemble:

/**
 * Learns all trees of the ensemble concurrently on the global KNIME thread pool and assembles them into a
 * {@link TreeEnsembleModel}.
 *
 * <p>A semaphore initialised with {@code procCount} permits (1.5 x the available processors) bounds how many tree
 * learners are in flight. One permit is acquired before each tree is enqueued. The permit is presumably released
 * by {@code TreeLearnerCallable} when its tree is done (it receives the semaphore); the final drain loop relies on
 * that to wait for all learners.
 *
 * <p>As side effects, the per-tree row samples and root column sample strategies are stored in
 * {@code m_rowSamples} / {@code m_columnSampleStrategies}, and the resulting ensemble in {@code m_ensembleModel}.
 *
 * @param exec monitor used for progress reporting; one zero-weight sub-progress is created per tree
 * @return the learned ensemble (also stored in {@code m_ensembleModel})
 * @throws CanceledExecutionException declared for cancellation; presumably raised via {@code checkThrowable}
 *             when a learner recorded a cancellation (TODO confirm)
 * @throws ExecutionException declared for learner failures; presumably raised via {@code checkThrowable} when a
 *             learner recorded a throwable (TODO confirm)
 */
public TreeEnsembleModel learnEnsemble(final ExecutionMonitor exec) throws CanceledExecutionException, ExecutionException {
    final int nrModels = m_config.getNrModels();
    final RandomData rd = m_config.createRandomData();
    final ThreadPool tp = KNIMEConstants.GLOBAL_THREAD_POOL;
    // First problem reported by any learner (or by collecting its result); inspected via checkThrowable.
    final AtomicReference<Throwable> learnThrowableRef = new AtomicReference<Throwable>();
    // Generic array creation is not allowed in Java, hence the raw allocation.
    @SuppressWarnings("unchecked") final Future<TreeLearnerResult>[] modelFutures = new Future[nrModels];
    // Upper bound on concurrently queued/running tree learners.
    final int procCount = 3 * Runtime.getRuntime().availableProcessors() / 2;
    final Semaphore semaphore = new Semaphore(procCount);
    Callable<TreeLearnerResult[]> learnCallable = new Callable<TreeLearnerResult[]>() {

        @Override
        public TreeLearnerResult[] call() throws Exception {
            final TreeLearnerResult[] results = new TreeLearnerResult[nrModels];
            for (int i = 0; i < nrModels; i++) {
                semaphore.acquire();
                // This is permit acquisition number i + 1. Only procCount permits existed initially, so every
                // acquisition beyond those implies one more finished learner.
                finishedTree(i + 1 - procCount, exec);
                checkThrowable(learnThrowableRef);
                // Each tree gets its own RNG seeded from the ensemble RNG (seeds drawn sequentially here).
                RandomData rdSingle = TreeEnsembleLearnerConfiguration.createRandomData(rd.nextLong(Long.MIN_VALUE, Long.MAX_VALUE));
                ExecutionMonitor subExec = exec.createSubProgress(0.0);
                modelFutures[i] = tp.enqueue(new TreeLearnerCallable(subExec, rdSingle, learnThrowableRef, semaphore));
            }
            // Drain: acquiring procCount further permits blocks until every learner has released its permit.
            // Acquisition number nrModels + i + 1 implies nrModels + i + 1 - procCount finished learners; the
            // last iteration therefore reports all nrModels trees as finished.
            for (int i = 0; i < procCount; i++) {
                semaphore.acquire();
                finishedTree(nrModels + i + 1 - procCount, exec);
            }
            for (int i = 0; i < nrModels; i++) {
                try {
                    results[i] = modelFutures[i].get();
                } catch (Exception e) {
                    // Keep only the first problem; the caller re-checks learnThrowableRef after this returns.
                    learnThrowableRef.compareAndSet(null, e);
                }
            }
            return results;
        }

        /**
         * Reports progress as the number of finished trees out of {@code nrModels}.
         *
         * @param nrFinished number of trees finished so far; values {@code <= 0} are not reported
         * @param progMon monitor receiving the progress update
         */
        private void finishedTree(final int nrFinished, final ExecutionMonitor progMon) {
            if (nrFinished > 0) {
                progMon.setProgress(nrFinished / (double) nrModels, "Tree " + nrFinished + "/" + nrModels);
            }
        }
    };
    TreeLearnerResult[] modelResults = tp.runInvisible(learnCallable);
    // Surface any failure recorded by the learners or while collecting their results.
    checkThrowable(learnThrowableRef);
    AbstractTreeModel[] models = new AbstractTreeModel[nrModels];
    m_rowSamples = new RowSample[nrModels];
    m_columnSampleStrategies = new ColumnSampleStrategy[nrModels];
    for (int i = 0; i < nrModels; i++) {
        models[i] = modelResults[i].m_treeModel;
        m_rowSamples[i] = modelResults[i].m_rowSample;
        m_columnSampleStrategies[i] = modelResults[i].m_rootColumnSampleStrategy;
    }
    m_ensembleModel = new TreeEnsembleModel(m_config, m_data.getMetaData(), models, m_data.getTreeType());
    return m_ensembleModel;
}
Also used: RandomData (org.apache.commons.math.random.RandomData), TreeEnsembleModel (org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModel), ThreadPool (org.knime.core.util.ThreadPool), AtomicReference (java.util.concurrent.atomic.AtomicReference), AbstractTreeModel (org.knime.base.node.mine.treeensemble2.model.AbstractTreeModel), Semaphore (java.util.concurrent.Semaphore), Callable (java.util.concurrent.Callable), CanceledExecutionException (org.knime.core.node.CanceledExecutionException), ExecutionException (java.util.concurrent.ExecutionException), Future (java.util.concurrent.Future), ExecutionMonitor (org.knime.core.node.ExecutionMonitor)

Aggregations

TreeEnsembleModel (org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModel)14 TreeEnsembleModelPortObject (org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObject)9 DataCell (org.knime.core.data.DataCell)7 ExecutionException (java.util.concurrent.ExecutionException)6 CanceledExecutionException (org.knime.core.node.CanceledExecutionException)6 ExecutionMonitor (org.knime.core.node.ExecutionMonitor)6 IOException (java.io.IOException)5 TreeEnsembleModelPortObjectSpec (org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObjectSpec)5 DataTableSpec (org.knime.core.data.DataTableSpec)5 BufferedDataTable (org.knime.core.node.BufferedDataTable)5 InvalidSettingsException (org.knime.core.node.InvalidSettingsException)5 PortObject (org.knime.core.node.port.PortObject)5 TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData)4 TreeDataCreator (org.knime.base.node.mine.treeensemble2.data.TreeDataCreator)4 TreeEnsembleLearner (org.knime.base.node.mine.treeensemble2.learner.TreeEnsembleLearner)4 FilterLearnColumnRearranger (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration.FilterLearnColumnRearranger)4 TreeEnsemblePredictor (org.knime.base.node.mine.treeensemble2.node.predictor.TreeEnsemblePredictor)4 ColumnRearranger (org.knime.core.data.container.ColumnRearranger)4 DoubleCell (org.knime.core.data.def.DoubleCell)4 IntCell (org.knime.core.data.def.IntCell)4