
Example 6 with RowSample

Use of org.knime.base.node.mine.treeensemble2.sample.row.RowSample in project knime-core by knime.

The class TreeNumericColumnData, method calcBestSplitRegression.

@Override
public SplitCandidate calcBestSplitRegression(final DataMemberships dataMemberships, final RegressionPriors targetPriors, final TreeTargetNumericColumnData targetColumn, final RandomData rd) {
    final TreeEnsembleLearnerConfiguration config = getConfiguration();
    final boolean useAverageSplitPoints = config.isUseAverageSplitPoints();
    final int minChildNodeSize = config.getMinChildSize();
    // get columnMemberships
    final ColumnMemberships columnMemberships = dataMemberships.getColumnMemberships(getMetaData().getAttributeIndex());
    final int lengthNonMissing = getLengthNonMissing();
    // missing value handling
    final boolean useXGBoostMissingValueHandling = config.getMissingValueHandling() == MissingValueHandling.XGBoost;
    // are there missing values in this column (complete column)
    boolean branchContainsMissingValues = containsMissingValues();
    boolean missingsGoLeft = true;
    double missingWeight = 0.0;
    double missingY = 0.0;
    // check if there are missing values in this rowsample
    if (branchContainsMissingValues) {
        columnMemberships.goToLast();
        while (columnMemberships.getIndexInColumn() >= lengthNonMissing) {
            missingWeight += columnMemberships.getRowWeight();
            missingY += targetColumn.getValueFor(columnMemberships.getOriginalIndex());
            if (!columnMemberships.previous()) {
                break;
            }
        }
        columnMemberships.reset();
        branchContainsMissingValues = missingWeight > 0.0;
    }
    final double ySumTotal = targetPriors.getYSum() - missingY;
    final double nrRecordsTotal = targetPriors.getNrRecords() - missingWeight;
    final double criterionTotal = useXGBoostMissingValueHandling ? (ySumTotal + missingY) * (ySumTotal + missingY) / (nrRecordsTotal + missingWeight) : ySumTotal * ySumTotal / nrRecordsTotal;
    double ySumLeft = 0.0;
    double nrRecordsLeft = 0.0;
    double ySumRight = ySumTotal;
    double nrRecordsRight = nrRecordsTotal;
    // all values in the current branch are missing
    if (nrRecordsRight == 0) {
        // it is impossible to determine a split
        return null;
    }
    double bestSplit = Double.NEGATIVE_INFINITY;
    double bestImprovement = 0.0;
    double lastSeenY = Double.NaN;
    double lastSeenValue = Double.NEGATIVE_INFINITY;
    double lastSeenWeight = -1.0;
    // compute the gain for each candidate split and keep the one that maximizes it
    while (columnMemberships.next()) {
        final double weight = columnMemberships.getRowWeight();
        if (weight < EPSILON) {
            // ignore record: not in current branch or not in sample
            continue;
        } else if (Math.floor(weight) != weight) {
            throw new UnsupportedOperationException("weighted records (missing values?) not supported, " + "weight is " + weight);
        }
        final double value = getSorted(columnMemberships.getIndexInColumn());
        if (lastSeenWeight > 0.0) {
            ySumLeft += lastSeenWeight * lastSeenY;
            ySumRight -= lastSeenWeight * lastSeenY;
            nrRecordsLeft += lastSeenWeight;
            nrRecordsRight -= lastSeenWeight;
            if (nrRecordsLeft >= minChildNodeSize && nrRecordsRight >= minChildNodeSize && lastSeenValue < value) {
                boolean tempMissingsGoLeft = true;
                double childrenSquaredSum;
                if (branchContainsMissingValues && useXGBoostMissingValueHandling) {
                    final double[] tempChildrenSquaredSum = new double[2];
                    tempChildrenSquaredSum[0] = ((ySumLeft + missingY) * (ySumLeft + missingY) / (nrRecordsLeft + missingWeight)) + (ySumRight * ySumRight / nrRecordsRight);
                    tempChildrenSquaredSum[1] = (ySumLeft * ySumLeft / nrRecordsLeft) + ((ySumRight + missingY) * (ySumRight + missingY) / (nrRecordsRight + missingWeight));
                    if (tempChildrenSquaredSum[0] >= tempChildrenSquaredSum[1]) {
                        childrenSquaredSum = tempChildrenSquaredSum[0];
                        tempMissingsGoLeft = true;
                    } else {
                        childrenSquaredSum = tempChildrenSquaredSum[1];
                        tempMissingsGoLeft = false;
                    }
                } else {
                    childrenSquaredSum = (ySumLeft * ySumLeft / nrRecordsLeft) + (ySumRight * ySumRight / nrRecordsRight);
                }
                double criterion = childrenSquaredSum - criterionTotal;
                boolean randomTieBreaker = criterion == bestImprovement ? rd.nextInt(0, 1) == 1 : false;
                if (criterion > bestImprovement || randomTieBreaker) {
                    bestImprovement = criterion;
                    bestSplit = useAverageSplitPoints ? getCenter(lastSeenValue, value) : lastSeenValue;
                    // if there are no missing values go with majority
                    missingsGoLeft = branchContainsMissingValues ? tempMissingsGoLeft : nrRecordsLeft >= nrRecordsRight;
                }
            }
        }
        lastSeenY = targetColumn.getValueFor(columnMemberships.getOriginalIndex());
        lastSeenValue = value;
        lastSeenWeight = weight;
    }
    // + " but was " + lastSeenY * lastSeenWeight;
    if (bestImprovement > 0.0) {
        if (useXGBoostMissingValueHandling) {
            // return new NumericMissingSplitCandidate(this, bestSplit, bestImprovement, missingsGoLeft);
            return new NumericSplitCandidate(this, bestSplit, bestImprovement, new BitSet(), missingsGoLeft ? NumericSplitCandidate.MISSINGS_GO_LEFT : NumericSplitCandidate.MISSINGS_GO_RIGHT);
        }
        return new NumericSplitCandidate(this, bestSplit, bestImprovement, getMissedRows(columnMemberships), NumericSplitCandidate.NO_MISSINGS);
    } else {
        return null;
    }
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) NumericSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate) BitSet(java.util.BitSet) ColumnMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.ColumnMemberships)
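
For reference, the split criterion evaluated above is the classic variance-reduction gain: the sum of ySum²/count over the two children minus the same quantity for the parent node. The stand-alone sketch below (not the KNIME API; it assumes unit row weights, no missing-value handling, and uses names chosen purely for illustration) shows that computation on a pre-sorted numeric column, including the mid-point split analogous to isUseAverageSplitPoints():

// Minimal, self-contained sketch of the variance-reduction criterion:
// gain = ySumLeft^2/nLeft + ySumRight^2/nRight - ySumTotal^2/nTotal.
// All names are hypothetical and for illustration only.
public class RegressionSplitSketch {

    /** Returns the best split threshold on pre-sorted values, or NaN if no split improves. */
    static double bestSplit(double[] sortedValues, double[] targets) {
        double ySumTotal = 0;
        for (double y : targets) {
            ySumTotal += y;
        }
        final int n = sortedValues.length;
        final double criterionTotal = ySumTotal * ySumTotal / n;
        double ySumLeft = 0;
        double bestGain = 0;
        double bestThreshold = Double.NaN;
        for (int i = 1; i < n; i++) {
            ySumLeft += targets[i - 1];
            final double ySumRight = ySumTotal - ySumLeft;
            // only split between distinct attribute values
            if (sortedValues[i - 1] < sortedValues[i]) {
                double gain = ySumLeft * ySumLeft / i
                    + ySumRight * ySumRight / (n - i) - criterionTotal;
                if (gain > bestGain) {
                    bestGain = gain;
                    // mid-point between the two adjacent values, like the average-split-point option
                    bestThreshold = (sortedValues[i - 1] + sortedValues[i]) / 2.0;
                }
            }
        }
        return bestThreshold;
    }

    public static void main(String[] args) {
        double[] x = {1.0, 2.0, 3.0, 10.0, 11.0, 12.0};
        double[] y = {1.1, 0.9, 1.0, 5.2, 4.8, 5.0};
        System.out.println("best threshold: " + bestSplit(x, y)); // prints 6.5
    }
}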

Example 7 with RowSample

Use of org.knime.base.node.mine.treeensemble2.sample.row.RowSample in project knime-core by knime.

The class RegressionTreeLearnerNodeModel, method execute.

/**
 * {@inheritDoc}
 */
@Override
protected PortObject[] execute(final PortObject[] inObjects, final ExecutionContext exec) throws Exception {
    BufferedDataTable t = (BufferedDataTable) inObjects[0];
    DataTableSpec spec = t.getDataTableSpec();
    final FilterLearnColumnRearranger learnRearranger = m_configuration.filterLearnColumns(spec);
    String warn = learnRearranger.getWarning();
    BufferedDataTable learnTable = exec.createColumnRearrangeTable(t, learnRearranger, exec.createSubProgress(0.0));
    DataTableSpec learnSpec = learnTable.getDataTableSpec();
    ExecutionMonitor readInExec = exec.createSubProgress(0.1);
    ExecutionMonitor learnExec = exec.createSubProgress(0.9);
    TreeDataCreator dataCreator = new TreeDataCreator(m_configuration, learnSpec, learnTable.getRowCount());
    exec.setProgress("Reading data into memory");
    TreeData data = dataCreator.readData(learnTable, m_configuration, readInExec);
    m_hiliteRowSample = dataCreator.getDataRowsForHilite();
    m_viewMessage = dataCreator.getViewMessage();
    String dataCreationWarning = dataCreator.getAndClearWarningMessage();
    if (dataCreationWarning != null) {
        if (warn == null) {
            warn = dataCreationWarning;
        } else {
            warn = warn + "\n" + dataCreationWarning;
        }
    }
    readInExec.setProgress(1.0);
    exec.setMessage("Learning tree");
    RandomData rd = m_configuration.createRandomData();
    final IDataIndexManager indexManager;
    if (data.getTreeType() == TreeType.BitVector) {
        indexManager = new BitVectorDataIndexManager(data.getNrRows());
    } else {
        indexManager = new DefaultDataIndexManager(data);
    }
    TreeNodeSignatureFactory signatureFactory = null;
    int maxLevels = m_configuration.getMaxLevels();
    if (maxLevels < TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE) {
        int capacity = IntMath.pow(2, maxLevels - 1);
        signatureFactory = new TreeNodeSignatureFactory(capacity);
    } else {
        signatureFactory = new TreeNodeSignatureFactory();
    }
    final RowSample rowSample = m_configuration.createRowSampler(data).createRowSample(rd);
    TreeLearnerRegression treeLearner = new TreeLearnerRegression(m_configuration, data, indexManager, signatureFactory, rd, rowSample);
    TreeModelRegression regTree = treeLearner.learnSingleTree(learnExec, rd);
    RegressionTreeModel model = new RegressionTreeModel(m_configuration, data.getMetaData(), regTree, data.getTreeType());
    RegressionTreeModelPortObjectSpec treePortObjectSpec = new RegressionTreeModelPortObjectSpec(learnSpec);
    RegressionTreeModelPortObject treePortObject = new RegressionTreeModelPortObject(model, treePortObjectSpec);
    learnExec.setProgress(1.0);
    m_treeModelPortObject = treePortObject;
    if (warn != null) {
        setWarningMessage(warn);
    }
    return new PortObject[] { treePortObject };
}
Also used : RegressionTreeModelPortObject(org.knime.base.node.mine.treeensemble2.model.RegressionTreeModelPortObject) DataTableSpec(org.knime.core.data.DataTableSpec) RandomData(org.apache.commons.math.random.RandomData) RegressionTreeModel(org.knime.base.node.mine.treeensemble2.model.RegressionTreeModel) IDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager) BitVectorDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.BitVectorDataIndexManager) RegressionTreeModelPortObjectSpec(org.knime.base.node.mine.treeensemble2.model.RegressionTreeModelPortObjectSpec) DefaultDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression) BufferedDataTable(org.knime.core.node.BufferedDataTable) FilterLearnColumnRearranger(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration.FilterLearnColumnRearranger) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) TreeLearnerRegression(org.knime.base.node.mine.treeensemble2.learner.TreeLearnerRegression) RowSample(org.knime.base.node.mine.treeensemble2.sample.row.RowSample) ExecutionMonitor(org.knime.core.node.ExecutionMonitor) TreeDataCreator(org.knime.base.node.mine.treeensemble2.data.TreeDataCreator) RegressionTreeModelPortObject(org.knime.base.node.mine.treeensemble2.model.RegressionTreeModelPortObject) PortObject(org.knime.core.node.port.PortObject) TreeNodeSignatureFactory(org.knime.base.node.mine.treeensemble2.learner.TreeNodeSignatureFactory)
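
Example 7 draws a single RowSample via m_configuration.createRowSampler(data).createRowSample(rd) before learning the tree. As a rough, self-contained illustration of what such a sample can look like under a with-replacement (bootstrap) strategy, the hypothetical sketch below draws per-row inclusion counts; the class and method names are made up for the example and are not the KNIME API:

// Hypothetical sketch of bootstrap row sampling: each row receives an integer
// inclusion count from drawing nrRows indices with replacement.
import java.util.Random;

public class BootstrapRowSampleSketch {

    /** Returns per-row inclusion counts for a bootstrap sample of size nrRows. */
    static int[] drawBootstrapCounts(int nrRows, Random random) {
        int[] counts = new int[nrRows];
        for (int i = 0; i < nrRows; i++) {
            counts[random.nextInt(nrRows)]++; // sample row indices with replacement
        }
        return counts;
    }

    public static void main(String[] args) {
        int[] counts = drawBootstrapCounts(10, new Random(42L));
        for (int row = 0; row < counts.length; row++) {
            System.out.println("row " + row + " included " + counts[row] + " times");
        }
    }
}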

Example 8 with RowSample

Use of org.knime.base.node.mine.treeensemble2.sample.row.RowSample in project knime-core by knime.

The class TreeEnsembleLearner, method learnEnsemble.

public TreeEnsembleModel learnEnsemble(final ExecutionMonitor exec) throws CanceledExecutionException, ExecutionException {
    final int nrModels = m_config.getNrModels();
    final RandomData rd = m_config.createRandomData();
    final ThreadPool tp = KNIMEConstants.GLOBAL_THREAD_POOL;
    final AtomicReference<Throwable> learnThrowableRef = new AtomicReference<Throwable>();
    @SuppressWarnings("unchecked") final Future<TreeLearnerResult>[] modelFutures = new Future[nrModels];
    final int procCount = 3 * Runtime.getRuntime().availableProcessors() / 2;
    final Semaphore semaphore = new Semaphore(procCount);
    Callable<TreeLearnerResult[]> learnCallable = new Callable<TreeLearnerResult[]>() {

        @Override
        public TreeLearnerResult[] call() throws Exception {
            final TreeLearnerResult[] results = new TreeLearnerResult[nrModels];
            for (int i = 0; i < nrModels; i++) {
                semaphore.acquire();
                finishedTree(i - procCount, exec);
                checkThrowable(learnThrowableRef);
                RandomData rdSingle = TreeEnsembleLearnerConfiguration.createRandomData(rd.nextLong(Long.MIN_VALUE, Long.MAX_VALUE));
                ExecutionMonitor subExec = exec.createSubProgress(0.0);
                modelFutures[i] = tp.enqueue(new TreeLearnerCallable(subExec, rdSingle, learnThrowableRef, semaphore));
            }
            for (int i = 0; i < procCount; i++) {
                semaphore.acquire();
                finishedTree(nrModels - 1 + i - procCount, exec);
            }
            for (int i = 0; i < nrModels; i++) {
                try {
                    results[i] = modelFutures[i].get();
                } catch (Exception e) {
                    learnThrowableRef.compareAndSet(null, e);
                }
            }
            return results;
        }

        private void finishedTree(final int treeIndex, final ExecutionMonitor progMon) {
            if (treeIndex > 0) {
                progMon.setProgress(treeIndex / (double) nrModels, "Tree " + treeIndex + "/" + nrModels);
            }
        }
    };
    TreeLearnerResult[] modelResults = tp.runInvisible(learnCallable);
    checkThrowable(learnThrowableRef);
    AbstractTreeModel[] models = new AbstractTreeModel[nrModels];
    m_rowSamples = new RowSample[nrModels];
    m_columnSampleStrategies = new ColumnSampleStrategy[nrModels];
    for (int i = 0; i < nrModels; i++) {
        models[i] = modelResults[i].m_treeModel;
        m_rowSamples[i] = modelResults[i].m_rowSample;
        m_columnSampleStrategies[i] = modelResults[i].m_rootColumnSampleStrategy;
    }
    m_ensembleModel = new TreeEnsembleModel(m_config, m_data.getMetaData(), models, m_data.getTreeType());
    return m_ensembleModel;
}
Also used : RandomData(org.apache.commons.math.random.RandomData) TreeEnsembleModel(org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModel) ThreadPool(org.knime.core.util.ThreadPool) AtomicReference(java.util.concurrent.atomic.AtomicReference) AbstractTreeModel(org.knime.base.node.mine.treeensemble2.model.AbstractTreeModel) Semaphore(java.util.concurrent.Semaphore) Callable(java.util.concurrent.Callable) CanceledExecutionException(org.knime.core.node.CanceledExecutionException) ExecutionException(java.util.concurrent.ExecutionException) Future(java.util.concurrent.Future) ExecutionMonitor(org.knime.core.node.ExecutionMonitor)
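
The loop in learnEnsemble throttles tree learning with a Semaphore sized to 3/2 of the available processors: a permit is acquired before each tree is enqueued and released by the worker once that tree is finished. The stand-alone sketch below reproduces just this submission pattern with a plain ExecutorService standing in for KNIME's ThreadPool; all names and the sleep-based "work" are illustrative only:

// Minimal sketch of semaphore-throttled task submission: at most procCount trees
// are learned concurrently while all nrModels tasks are enqueued from one loop.
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.Semaphore;

public class ThrottledEnsembleSketch {

    public static void main(String[] args) throws Exception {
        final int nrModels = 8;
        final int procCount = 3 * Runtime.getRuntime().availableProcessors() / 2;
        final Semaphore semaphore = new Semaphore(procCount);
        final ExecutorService pool = Executors.newCachedThreadPool();
        final List<Future<String>> futures = new ArrayList<>();
        for (int i = 0; i < nrModels; i++) {
            semaphore.acquire(); // block until a slot frees up, throttling submission
            final int treeIndex = i;
            futures.add(pool.submit(() -> {
                try {
                    Thread.sleep(100); // stand-in for learning one tree
                    return "tree " + treeIndex;
                } finally {
                    semaphore.release(); // free the slot when this tree is done
                }
            }));
        }
        for (Future<String> f : futures) {
            System.out.println(f.get() + " finished");
        }
        pool.shutdown();
    }
}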

Aggregations

TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData): 6
RowSample (org.knime.base.node.mine.treeensemble2.sample.row.RowSample): 6
BitSet (java.util.BitSet): 5
TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration): 5
RandomData (org.apache.commons.math.random.RandomData): 3
TreeModelRegression (org.knime.base.node.mine.treeensemble2.model.TreeModelRegression): 3
TreeNodeSignature (org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature): 3
Test (org.junit.Test): 2
TestDataGenerator (org.knime.base.node.mine.treeensemble2.data.TestDataGenerator): 2
TreeTargetNumericColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData): 2
IDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager): 2
RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships): 2
TreeLearnerRegression (org.knime.base.node.mine.treeensemble2.learner.TreeLearnerRegression): 2
TreeNodeSignatureFactory (org.knime.base.node.mine.treeensemble2.learner.TreeNodeSignatureFactory): 2
GradientBoostingLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.gradientboosting.learner.GradientBoostingLearnerConfiguration): 2
ColumnSample (org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample): 2
DefaultRowSample (org.knime.base.node.mine.treeensemble2.sample.row.DefaultRowSample): 2
ExecutionMonitor (org.knime.core.node.ExecutionMonitor): 2
ArrayList (java.util.ArrayList): 1
HashMap (java.util.HashMap): 1