Search in sources :

Example 31 with RandomData

use of org.apache.commons.math.random.RandomData in project knime-core by knime.

the class NormalRowSamplerTest method testCreateRowSampleNoReplacement.

/**
 * Builds a {@link RandomRowSampler} from a fraction of 0.5, the no-replacement selector and a row
 * count of 20, draws one sample, and expects the included bit set to have a cardinality of 10.
 */
@Test
public void testCreateRowSampleNoReplacement() throws Exception {
    final SubsetSelector<SubsetNoReplacementRowSample> noReplacementSelector =
        SubsetNoReplacementSelector.getInstance();
    final int totalRows = 20;
    final double sampleFraction = 0.5;
    final RandomRowSampler<SubsetNoReplacementRowSample> rowSampler =
        new RandomRowSampler<SubsetNoReplacementRowSample>(sampleFraction, noReplacementSelector, totalRows);
    final RandomData randomData = TestDataGenerator.createRandomData();
    final SubsetNoReplacementRowSample drawnSample = rowSampler.createRowSample(randomData);
    final int includedRowCount = drawnSample.getIncludedBitSet().cardinality();
    assertEquals(10, includedRowCount);
}
Also used : RandomData(org.apache.commons.math.random.RandomData) Test(org.junit.Test)

Example 32 with RandomData

use of org.apache.commons.math.random.RandomData in project knime-core by knime.

the class NormalRowSamplerTest method testCreateRowSampleWithReplacement.

/**
 * Builds a {@link RandomRowSampler} from a fraction of 0.8, the with-replacement selector and a row
 * count of 20, draws one sample, and expects {@code SamplerTestUtil.countRows} to report 16 for it.
 */
@Test
public void testCreateRowSampleWithReplacement() throws Exception {
    final SubsetSelector<SubsetWithReplacementRowSample> withReplacementSelector =
        SubsetWithReplacementSelector.getInstance();
    final int totalRows = 20;
    final double sampleFraction = 0.8;
    final RandomRowSampler<SubsetWithReplacementRowSample> rowSampler =
        new RandomRowSampler<SubsetWithReplacementRowSample>(sampleFraction, withReplacementSelector, totalRows);
    final RandomData randomData = TestDataGenerator.createRandomData();
    final SubsetWithReplacementRowSample drawnSample = rowSampler.createRowSample(randomData);
    final int sampledRowCount = SamplerTestUtil.countRows(drawnSample);
    assertEquals(16, sampledRowCount);
}
Also used : RandomData(org.apache.commons.math.random.RandomData) Test(org.junit.Test)

Example 33 with RandomData

use of org.apache.commons.math.random.RandomData in project knime-core by knime.

the class StratifiedRowSamplerTest method testCreateRowSampleNoReplacement.

/**
 * Draws two samples over {@code SamplerTestUtil.TARGET} with a {@link StratifiedRowSampler} and the
 * no-replacement selector, sharing one {@link RandomData} instance.
 * <p>
 * With fraction 0.5 the included bit set is expected to have a cardinality of 8; with fraction 1.0 it
 * is expected to equal the target's row count. In both cases the sample's row count must equal the
 * target's row count.
 */
@Test
public void testCreateRowSampleNoReplacement() throws Exception {
    final RandomData randomData = TestDataGenerator.createRandomData();
    final SubsetSelector<SubsetNoReplacementRowSample> noReplacementSelector =
        SubsetNoReplacementSelector.getInstance();

    // Phase 1: fraction 0.5.
    final StratifiedRowSampler<SubsetNoReplacementRowSample> halfSampler =
        new StratifiedRowSampler<SubsetNoReplacementRowSample>(0.5, noReplacementSelector, SamplerTestUtil.TARGET);
    final SubsetNoReplacementRowSample halfSample = halfSampler.createRowSample(randomData);
    assertEquals(8, halfSample.getIncludedBitSet().cardinality());
    assertEquals(SamplerTestUtil.TARGET.getNrRows(), halfSample.getNrRows());

    // Phase 2: fraction 1.0, drawn from the same random data instance.
    final StratifiedRowSampler<SubsetNoReplacementRowSample> fullSampler =
        new StratifiedRowSampler<SubsetNoReplacementRowSample>(1.0, noReplacementSelector, SamplerTestUtil.TARGET);
    final SubsetNoReplacementRowSample fullSample = fullSampler.createRowSample(randomData);
    assertEquals(SamplerTestUtil.TARGET.getNrRows(), fullSample.getIncludedBitSet().cardinality());
    assertEquals(SamplerTestUtil.TARGET.getNrRows(), fullSample.getNrRows());
}
Also used : RandomData(org.apache.commons.math.random.RandomData) Test(org.junit.Test)

Example 34 with RandomData

use of org.apache.commons.math.random.RandomData in project knime-core by knime.

the class TreeEnsembleLearner method learnEnsemble.

/**
 * Learns {@code m_config.getNrModels()} tree models by enqueuing one {@code TreeLearnerCallable} per
 * model on {@link KNIMEConstants#GLOBAL_THREAD_POOL}, then assembles them into a
 * {@link TreeEnsembleModel}.
 * <p>
 * Besides returning the model, this method overwrites the fields {@code m_rowSamples},
 * {@code m_columnSampleStrategies} and {@code m_ensembleModel}.
 *
 * @param exec monitor that receives the "Tree i/n" progress messages and from which one sub-progress
 *            monitor per model is created
 * @return the newly built ensemble model (the same instance that is stored in {@code m_ensembleModel})
 * @throws CanceledExecutionException declared by this method; presumably raised via
 *             {@code checkThrowable} or the thread pool when execution is canceled - TODO confirm
 * @throws ExecutionException declared by this method; presumably raised by
 *             {@code tp.runInvisible} or {@code checkThrowable} - TODO confirm
 */
public TreeEnsembleModel learnEnsemble(final ExecutionMonitor exec) throws CanceledExecutionException, ExecutionException {
    final int nrModels = m_config.getNrModels();
    final RandomData rd = m_config.createRandomData();
    final ThreadPool tp = KNIMEConstants.GLOBAL_THREAD_POOL;
    // Shared slot for a failure; it is handed to every TreeLearnerCallable and also written below.
    final AtomicReference<Throwable> learnThrowableRef = new AtomicReference<Throwable>();
    @SuppressWarnings("unchecked") final Future<TreeLearnerResult>[] modelFutures = new Future[nrModels];
    // 1.5 times the available processors (integer arithmetic); used as the semaphore's permit count.
    final int procCount = 3 * Runtime.getRuntime().availableProcessors() / 2;
    final Semaphore semaphore = new Semaphore(procCount);
    Callable<TreeLearnerResult[]> learnCallable = new Callable<TreeLearnerResult[]>() {

        @Override
        public TreeLearnerResult[] call() throws Exception {
            final TreeLearnerResult[] results = new TreeLearnerResult[nrModels];
            for (int i = 0; i < nrModels; i++) {
                // Blocks while no permit is available. The semaphore is passed to the
                // TreeLearnerCallable, which presumably releases a permit when it ends - TODO confirm.
                semaphore.acquire();
                finishedTree(i - procCount, exec);
                checkThrowable(learnThrowableRef);
                // Each model gets its own RandomData, created from a long drawn from the shared
                // generator here in the scheduling thread (presumably used as a seed - TODO confirm).
                RandomData rdSingle = TreeEnsembleLearnerConfiguration.createRandomData(rd.nextLong(Long.MIN_VALUE, Long.MAX_VALUE));
                ExecutionMonitor subExec = exec.createSubProgress(0.0);
                modelFutures[i] = tp.enqueue(new TreeLearnerCallable(subExec, rdSingle, learnThrowableRef, semaphore));
            }
            // Acquire procCount further permits, i.e. as many as the semaphore was created with, so
            // this loop only completes once every permit taken above has been given back.
            // NOTE(review): the largest index reported here is nrModels - 2, so the progress message
            // never shows nrModels/nrModels - confirm whether that is intended.
            for (int i = 0; i < procCount; i++) {
                semaphore.acquire();
                finishedTree(nrModels - 1 + i - procCount, exec);
            }
            for (int i = 0; i < nrModels; i++) {
                try {
                    results[i] = modelFutures[i].get();
                } catch (Exception e) {
                    // Only the first failure is kept (compareAndSet against null); results[i] stays null.
                    // NOTE(review): this also catches InterruptedException from Future.get() without
                    // restoring the thread's interrupt flag - confirm that this is acceptable.
                    learnThrowableRef.compareAndSet(null, e);
                }
            }
            return results;
        }

        // Publishes "Tree treeIndex/nrModels" progress; indices <= 0 are ignored.
        private void finishedTree(final int treeIndex, final ExecutionMonitor progMon) {
            if (treeIndex > 0) {
                progMon.setProgress(treeIndex / (double) nrModels, "Tree " + treeIndex + "/" + nrModels);
            }
        }
    };
    TreeLearnerResult[] modelResults = tp.runInvisible(learnCallable);
    // Runs before the results are dereferenced below; presumably rethrows a recorded failure, which
    // would also protect against the null entries a failed Future leaves behind - TODO confirm.
    checkThrowable(learnThrowableRef);
    AbstractTreeModel[] models = new AbstractTreeModel[nrModels];
    m_rowSamples = new RowSample[nrModels];
    m_columnSampleStrategies = new ColumnSampleStrategy[nrModels];
    // Unpack each learner result into the three parallel arrays.
    for (int i = 0; i < nrModels; i++) {
        models[i] = modelResults[i].m_treeModel;
        m_rowSamples[i] = modelResults[i].m_rowSample;
        m_columnSampleStrategies[i] = modelResults[i].m_rootColumnSampleStrategy;
    }
    m_ensembleModel = new TreeEnsembleModel(m_config, m_data.getMetaData(), models, m_data.getTreeType());
    return m_ensembleModel;
}
Also used : RandomData(org.apache.commons.math.random.RandomData) TreeEnsembleModel(org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModel) ThreadPool(org.knime.core.util.ThreadPool) AtomicReference(java.util.concurrent.atomic.AtomicReference) AbstractTreeModel(org.knime.base.node.mine.treeensemble2.model.AbstractTreeModel) Semaphore(java.util.concurrent.Semaphore) Callable(java.util.concurrent.Callable) CanceledExecutionException(org.knime.core.node.CanceledExecutionException) ExecutionException(java.util.concurrent.ExecutionException) Future(java.util.concurrent.Future) ExecutionMonitor(org.knime.core.node.ExecutionMonitor)

Example 35 with RandomData

use of org.apache.commons.math.random.RandomData in project knime-core by knime.

the class TreeLearnerRegression method buildTreeNode.

/**
 * Recursively builds the regression (sub)tree rooted at the node identified by
 * {@code treeNodeSignature}.
 * <p>
 * If {@code findBestSplitRegression} yields no candidate, a leaf is returned. For a
 * {@link GradientBoostingLearnerConfiguration} that leaf additionally receives the original row
 * indices of {@code dataMemberships} and is registered via {@code addToLeafList}. Otherwise the
 * children are built either from a surrogate split (when the configured missing value handling is
 * {@code MissingValueHandling.Surrogate}) or directly from the child conditions of the best split.
 *
 * @param exec monitor checked for cancellation before a split is searched
 * @param currentDepth depth of the node being built; children are built with {@code currentDepth + 1}
 * @param dataMemberships memberships of the rows that belong to this node
 * @param columnSample column sample handed to the split search for this node
 * @param treeNodeSignature signature of the node being built
 * @param targetPriors target priors stored in the node being built
 * @param forbiddenColumnSet bit set handed to the split search; the bit of the split attribute is set
 *            while the children are built if the column cannot be split further, and cleared again
 *            afterwards
 * @return the node for this signature, either a leaf or an inner node holding its children
 * @throws CanceledExecutionException if {@code exec.checkCanceled()} signals cancellation
 */
private TreeNodeRegression buildTreeNode(final ExecutionMonitor exec, final int currentDepth, final DataMemberships dataMemberships, final ColumnSample columnSample, final TreeNodeSignature treeNodeSignature, final RegressionPriors targetPriors, final BitSet forbiddenColumnSet) throws CanceledExecutionException {
    final TreeData treeData = getData();
    final RandomData randomData = getRandomData();
    final TreeEnsembleLearnerConfiguration learnerConfig = getConfig();
    exec.checkCanceled();
    final SplitCandidate split = findBestSplitRegression(currentDepth, dataMemberships, columnSample, targetPriors, forbiddenColumnSet);

    // No usable split: this node becomes a leaf.
    if (split == null) {
        if (!(learnerConfig instanceof GradientBoostingLearnerConfiguration)) {
            return new TreeNodeRegression(treeNodeSignature, targetPriors);
        }
        // Gradient boosting leaves carry the original row indices and are tracked in the leaf list.
        final TreeNodeRegression boostingLeaf = new TreeNodeRegression(treeNodeSignature, targetPriors, dataMemberships.getOriginalIndices());
        addToLeafList(boostingLeaf);
        return boostingLeaf;
    }

    final TreeTargetNumericColumnData target = (TreeTargetNumericColumnData) treeData.getTargetColumn();
    final TreeNodeRegression[] children;
    if (learnerConfig.getMissingValueHandling() != MissingValueHandling.Surrogate) {
        final TreeAttributeColumnData splitColumn = split.getColumnData();
        final int splitAttribute = splitColumn.getMetaData().getAttributeIndex();
        // A column that cannot be split further is marked as forbidden while the children are built.
        final boolean forbidDuringDescent = !split.canColumnBeSplitFurther();
        forbiddenColumnSet.set(splitAttribute, forbidDuringDescent);
        final TreeNodeCondition[] conditions = split.getChildConditions();
        if (conditions.length > Short.MAX_VALUE) {
            throw new RuntimeException("Too many children when splitting " + "attribute " + split.getColumnData() + " (maximum supported: " + Short.MAX_VALUE + "): " + conditions.length);
        }
        children = new TreeNodeRegression[conditions.length];
        int childIndex = 0;
        for (final TreeNodeCondition condition : conditions) {
            final DataMemberships childRows = dataMemberships.createChildMemberships(splitColumn.updateChildMemberships(condition, dataMemberships));
            final RegressionPriors childPriors = target.getPriors(childRows, learnerConfig);
            final TreeNodeSignature childSignature = treeNodeSignature.createChildSignature((byte) childIndex);
            final ColumnSample childColumns = getColSamplingStrategy().getColumnSampleForTreeNode(childSignature);
            final TreeNodeRegression child = buildTreeNode(exec, currentDepth + 1, childRows, childColumns, childSignature, childPriors, forbiddenColumnSet);
            children[childIndex] = child;
            child.setTreeNodeCondition(condition);
            childIndex++;
        }
        // Undo the temporary mark once all children have been built.
        if (forbidDuringDescent) {
            forbiddenColumnSet.set(splitAttribute, false);
        }
    } else {
        final SurrogateSplit surrogates = Surrogates.learnSurrogates(dataMemberships, split, treeData, columnSample, learnerConfig, randomData);
        final TreeNodeCondition[] conditions = surrogates.getChildConditions();
        final BitSet[] markers = surrogates.getChildMarkers();
        assert markers[0].cardinality() + markers[1].cardinality() == dataMemberships.getRowCount() : "Sum of rows in children does not add up to number of rows in parent.";
        // The surrogate path always builds exactly two children.
        children = new TreeNodeRegression[2];
        for (int side = 0; side < 2; side++) {
            final DataMemberships childRows = dataMemberships.createChildMemberships(markers[side]);
            final TreeNodeSignature childSignature = getSignatureFactory().getChildSignatureFor(treeNodeSignature, (byte) side);
            final ColumnSample childColumns = getColSamplingStrategy().getColumnSampleForTreeNode(childSignature);
            final RegressionPriors childPriors = target.getPriors(childRows, learnerConfig);
            final TreeNodeRegression child = buildTreeNode(exec, currentDepth + 1, childRows, childColumns, childSignature, childPriors, forbiddenColumnSet);
            children[side] = child;
            child.setTreeNodeCondition(conditions[side]);
        }
    }
    return new TreeNodeRegression(treeNodeSignature, targetPriors, children);
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RandomData(org.apache.commons.math.random.RandomData) TreeAttributeColumnData(org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData) ColumnSample(org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample) RegressionPriors(org.knime.base.node.mine.treeensemble2.data.RegressionPriors) BitSet(java.util.BitSet) TreeTargetNumericColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeNodeRegression(org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) GradientBoostingLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.gradientboosting.learner.GradientBoostingLearnerConfiguration) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) TreeNodeCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition)

Aggregations

RandomData (org.apache.commons.math.random.RandomData): 36; Test (org.junit.Test): 21; TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration): 16; DataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships): 11; RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships): 11; SplitCandidate (org.knime.base.node.mine.treeensemble2.learner.SplitCandidate): 11; TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData): 8; DefaultDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager): 7; IDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager): 6; NumericMissingSplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NumericMissingSplitCandidate): 6; NumericSplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate): 6; TreeNodeNumericCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeNumericCondition): 6; TreeAttributeColumnData (org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData): 5; NominalBinarySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate): 5; NominalMultiwaySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate): 5; ExecutionMonitor (org.knime.core.node.ExecutionMonitor): 5; BitSet (java.util.BitSet): 4; TreeTargetNumericColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData): 4; ArrayList (java.util.ArrayList): 3; Future (java.util.concurrent.Future): 3