Search in sources :

Example 6 with TreeData

use of org.knime.base.node.mine.treeensemble2.data.TreeData in project knime-core by knime.

In class AbstractGradientBoostedTreesLearner, method adaptPreviousPrediction.

/**
 * Updates <b>previousPrediction</b> in place: for every row of the training data, the coefficient stored in
 * <b>coefficientMap</b> for the leaf of <b>tree</b> that the row is routed to is added to that row's entry.
 *
 * @param previousPrediction Prediction of the previous steps; modified in place
 * @param tree the tree of the current iteration
 * @param coefficientMap contains the coefficients for the leafs of the tree, keyed by leaf signature
 */
protected void adaptPreviousPrediction(final double[] previousPrediction, final TreeModelRegression tree, final Map<TreeNodeSignature, Double> coefficientMap) {
    final TreeData treeData = getData();
    final IDataIndexManager dataIndexManager = getIndexManager();
    int row = 0;
    while (row < treeData.getNrRows()) {
        final PredictorRecord predictorRecord = createPredictorRecord(treeData, dataIndexManager, row);
        previousPrediction[row] = previousPrediction[row]
            + coefficientMap.get(tree.findMatchingNode(predictorRecord).getSignature());
        row++;
    }
}
Also used : PredictorRecord(org.knime.base.node.mine.treeensemble2.data.PredictorRecord) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) IDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager)

Example 7 with TreeData

use of org.knime.base.node.mine.treeensemble2.data.TreeData in project knime-core by knime.

In class LKGradientBoostedTreesLearner, method calculateCoefficientMap.

/**
 * Computes the learning-rate-scaled value of every leaf of a regression tree that was fitted to the
 * pseudo-residuals of one class.
 * <p>
 * For the residuals {@code r} of the rows in a leaf and {@code K = numClasses}, the leaf value is
 * {@code learningRate * (K - 1) / K * sum(r) / sum(|r| * (1 - |r|))} (the LK multi-class leaf update).
 * If the denominator is (numerically) zero, the leaf value is 0 instead of NaN/Infinity.
 *
 * @param tree the tree fitted to the pseudo-residuals
 * @param pseudoResiduals data whose numeric target column holds the pseudo-residuals
 * @param numClasses number of classes K
 * @return map from leaf signature to the scaled leaf coefficient
 */
private Map<TreeNodeSignature, Double> calculateCoefficientMap(final TreeModelRegression tree, final TreeData pseudoResiduals, final double numClasses) {
    // Below this magnitude the denominator is treated as zero (same cut-off as scikit-learn's multinomial
    // deviance leaf update) to avoid overflowing leaf values.
    final double minDenominator = 1e-150;
    final List<TreeNodeRegression> leafs = tree.getLeafs();
    final Map<TreeNodeSignature, Double> coefficientMap = new HashMap<TreeNodeSignature, Double>();
    final TreeTargetNumericColumnData pseudoTarget = (TreeTargetNumericColumnData) pseudoResiduals.getTargetColumn();
    final double learningRate = getConfig().getLearningRate();
    // loop-invariant (K - 1) / K factor
    final double classFactor = (numClasses - 1) / numClasses;
    for (final TreeNodeRegression leaf : leafs) {
        final int[] indices = leaf.getRowIndicesInTreeData();
        double sumTop = 0;
        double sumBottom = 0;
        for (final int index : indices) {
            final double val = pseudoTarget.getValueFor(index);
            sumTop += val;
            final double absVal = Math.abs(val);
            sumBottom += absVal * (1 - absVal);
        }
        // sumBottom is 0 for empty leaves and when every residual is exactly 0 or +-1 (probabilities saturated
        // to 0/1 in floating point). Dividing would produce NaN/Infinity that poisons all later predictions,
        // so such a leaf contributes nothing.
        final double coefficient = Math.abs(sumBottom) < minDenominator ? 0.0 : classFactor * (sumTop / sumBottom);
        coefficientMap.put(leaf.getSignature(), learningRate * coefficient);
    }
    return coefficientMap;
}
Also used : HashMap(java.util.HashMap) TreeTargetNumericColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeNodeRegression(org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression)

Example 8 with TreeData

use of org.knime.base.node.mine.treeensemble2.data.TreeData in project knime-core by knime.

In class LKGradientBoostedTreesLearner, method adaptPreviousFunction.

/**
 * Shifts each entry of {@code previousFunction} by the coefficient of the leaf of {@code tree} that the
 * corresponding training row is routed to. The array is modified in place and its length determines how many
 * rows are processed.
 *
 * @param previousFunction current function values per row; modified in place
 * @param tree the tree of the current iteration
 * @param coefficientMap leaf coefficients keyed by leaf signature
 */
private void adaptPreviousFunction(final double[] previousFunction, final TreeModelRegression tree, final Map<TreeNodeSignature, Double> coefficientMap) {
    final TreeData trainingData = getData();
    final IDataIndexManager rowIndexManager = getIndexManager();
    for (int row = 0, rowCount = previousFunction.length; row < rowCount; row++) {
        final TreeNodeSignature leafSignature =
            tree.findMatchingNode(createPredictorRecord(trainingData, rowIndexManager, row)).getSignature();
        previousFunction[row] = previousFunction[row] + coefficientMap.get(leafSignature);
    }
}
Also used : PredictorRecord(org.knime.base.node.mine.treeensemble2.data.PredictorRecord) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) IDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature)

Example 9 with TreeData

use of org.knime.base.node.mine.treeensemble2.data.TreeData in project knime-core by knime.

In class LKGradientBoostedTreesLearner, method createNumericDataFromArray.

/**
 * Creates a regression view of the training data. It keeps the same attribute columns and tree type, but its
 * target is a numeric column holding {@code numericData`}. That column reuses the name and row keys of the
 * original nominal target.
 *
 * @param numericData the new numeric target values, one per row
 * @return new tree data with the numeric target
 */
private TreeData createNumericDataFromArray(final double[] numericData) {
    final TreeData original = getData();
    final TreeTargetNominalColumnData classColumn = (TreeTargetNominalColumnData) original.getTargetColumn();
    final TreeTargetNumericColumnMetaData numericMeta =
        new TreeTargetNumericColumnMetaData(classColumn.getMetaData().getAttributeName());
    final TreeTargetNumericColumnData numericTarget =
        new TreeTargetNumericColumnData(numericMeta, classColumn.getRowKeys(), numericData);
    return new TreeData(original.getColumns(), numericTarget, original.getTreeType());
}
Also used : TreeTargetNumericColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) TreeTargetNumericColumnMetaData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnMetaData) TreeTargetNominalColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData)

Example 10 with TreeData

use of org.knime.base.node.mine.treeensemble2.data.TreeData in project knime-core by knime.

In class LKGradientBoostedTreesLearner, method learn.

/**
 * {@inheritDoc}
 * <p>
 * The k-class classification problem is transformed into k regression problems. In every boosting iteration one
 * regression tree per class is learned as a task on the global thread pool. At most
 * {@code 3 * availableProcessors / 2} class tasks are submitted at a time.
 *
 * @throws CanceledExecutionException if the user canceled the execution
 * @throws InterruptedException if the thread is interrupted while waiting for a tree learner
 * @throws ExecutionException if learning one of the trees failed
 */
@Override
public MultiClassGradientBoostedTreesModel learn(final ExecutionMonitor exec) throws CanceledExecutionException, InterruptedException, ExecutionException {
    final TreeData data = getData();
    final TreeTargetNominalColumnData target = (TreeTargetNominalColumnData) data.getTargetColumn();
    final NominalValueRepresentation[] classNomVals = target.getMetaData().getValues();
    final int numClasses = classNomVals.length;
    final String[] classLabels = new String[numClasses];
    final int nrModels = getConfig().getNrModels();
    final int nrRows = target.getNrRows();
    final TreeModelRegression[][] models = new TreeModelRegression[nrModels][numClasses];
    final ArrayList<ArrayList<Map<TreeNodeSignature, Double>>> coefficientMaps = new ArrayList<ArrayList<Map<TreeNodeSignature, Double>>>(nrModels);
    // variables for parallelization
    final ThreadPool tp = KNIMEConstants.GLOBAL_THREAD_POOL;
    final AtomicReference<Throwable> learnThrowableRef = new AtomicReference<Throwable>();
    // number of semaphore permits, bounding how many class tasks are submitted at a time
    final int procCount = 3 * Runtime.getRuntime().availableProcessors() / 2;
    exec.setMessage("Transforming problem");
    // transform the original k class classification problem into k regression problems
    final TreeData[] actual = new TreeData[numClasses];
    for (int i = 0; i < numClasses; i++) {
        final double[] newTarget = calculateNewTarget(target, i);
        actual[i] = createNumericDataFromArray(newTarget);
        classLabels[i] = classNomVals[i].getNominalValue();
    }
    final RandomData rd = getConfig().createRandomData();
    final double[][] previousFunctions = new double[numClasses][nrRows];
    TreeNodeSignatureFactory signatureFactory = null;
    final int maxLevels = getConfig().getMaxLevels();
    if (maxLevels < TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE) {
        int capacity = IntMath.pow(2, maxLevels - 1);
        signatureFactory = new TreeNodeSignatureFactory(capacity);
    } else {
        signatureFactory = new TreeNodeSignatureFactory();
    }
    exec.setMessage("Learn trees");
    for (int i = 0; i < nrModels; i++) {
        exec.checkCanceled();
        final Semaphore semaphore = new Semaphore(procCount);
        final ArrayList<Map<TreeNodeSignature, Double>> classCoefficientMaps = new ArrayList<Map<TreeNodeSignature, Double>>(numClasses);
        // prepare calculation of pseudoResiduals
        final double[][] probs = computeSoftmaxProbabilities(previousFunctions, nrRows);
        final Future<?>[] treeCoefficientMapPairs = new Future<?>[numClasses];
        for (int j = 0; j < numClasses; j++) {
            checkThrowable(learnThrowableRef);
            final RandomData rdSingle = TreeEnsembleLearnerConfiguration.createRandomData(rd.nextLong(Long.MIN_VALUE, Long.MAX_VALUE));
            final ExecutionMonitor subExec = exec.createSubProgress(0.0);
            semaphore.acquire();
            treeCoefficientMapPairs[j] = tp.enqueue(new TreeLearnerCallable(rdSingle, probs[j], actual[j], subExec, numClasses, previousFunctions[j], semaphore, learnThrowableRef, signatureFactory));
        }
        for (int j = 0; j < numClasses; j++) {
            checkThrowable(learnThrowableRef);
            semaphore.acquire();
            // the futures were created from TreeLearnerCallable instances, which yield this pair type
            @SuppressWarnings("unchecked")
            final Pair<TreeModelRegression, Map<TreeNodeSignature, Double>> pair = (Pair<TreeModelRegression, Map<TreeNodeSignature, Double>>) treeCoefficientMapPairs[j].get();
            models[i][j] = pair.getFirst();
            classCoefficientMaps.add(pair.getSecond());
            semaphore.release();
        }
        checkThrowable(learnThrowableRef);
        coefficientMaps.add(classCoefficientMaps);
        // iteration i is complete, so i + 1 of nrModels levels are finished
        exec.setProgress((double) (i + 1) / nrModels, "Finished level " + (i + 1) + "/" + nrModels);
    }
    return MultiClassGradientBoostedTreesModel.createMultiClassGradientBoostedTreesModel(getConfig(), data.getMetaData(), models, data.getTreeType(), 0, numClasses, coefficientMaps, classLabels);
}

/**
 * Converts per-class function values into class probabilities using the softmax function.
 * <p>
 * The row-wise maximum is subtracted before exponentiation. This leaves the result mathematically unchanged,
 * but it prevents {@link Math#exp(double)} from overflowing to Infinity for large function values. Such an
 * overflow would turn the probabilities into NaN.
 *
 * @param previousFunctions function values indexed by [class][row]
 * @param nrRows number of rows
 * @return probabilities indexed by [class][row]
 */
private static double[][] computeSoftmaxProbabilities(final double[][] previousFunctions, final int nrRows) {
    final int numClasses = previousFunctions.length;
    final double[][] probs = new double[numClasses][nrRows];
    for (int r = 0; r < nrRows; r++) {
        double maxF = Double.NEGATIVE_INFINITY;
        for (int j = 0; j < numClasses; j++) {
            maxF = Math.max(maxF, previousFunctions[j][r]);
        }
        double sumExpF = 0;
        for (int j = 0; j < numClasses; j++) {
            final double expF = Math.exp(previousFunctions[j][r] - maxF);
            probs[j][r] = expF;
            sumExpF += expF;
        }
        for (int j = 0; j < numClasses; j++) {
            probs[j][r] /= sumExpF;
        }
    }
    return probs;
}
Also used : RandomData(org.apache.commons.math.random.RandomData) ArrayList(java.util.ArrayList) ThreadPool(org.knime.core.util.ThreadPool) NominalValueRepresentation(org.knime.base.node.mine.treeensemble2.data.NominalValueRepresentation) Semaphore(java.util.concurrent.Semaphore) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression) ExecutionMonitor(org.knime.core.node.ExecutionMonitor) TreeNodeSignatureFactory(org.knime.base.node.mine.treeensemble2.learner.TreeNodeSignatureFactory) Pair(org.knime.core.util.Pair) AtomicReference(java.util.concurrent.atomic.AtomicReference) Future(java.util.concurrent.Future) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) HashMap(java.util.HashMap) Map(java.util.Map) TreeTargetNominalColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData)

Aggregations

TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData)27 TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration)25 Test (org.junit.Test)18 DataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships)18 RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships)18 DefaultDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager)15 IDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager)15 RandomData (org.apache.commons.math.random.RandomData)14 TreeAttributeColumnData (org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData)12 SplitCandidate (org.knime.base.node.mine.treeensemble2.learner.SplitCandidate)12 BitSet (java.util.BitSet)11 TreeTargetNumericColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData)8 TreeNodeSignature (org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature)8 ExecutionMonitor (org.knime.core.node.ExecutionMonitor)8 TreeDataCreator (org.knime.base.node.mine.treeensemble2.data.TreeDataCreator)7 TreeTargetNominalColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData)7 NominalBinarySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate)7 NominalMultiwaySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate)7 FilterLearnColumnRearranger (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration.FilterLearnColumnRearranger)7 TreeEnsembleModelPortObjectSpec (org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObjectSpec)6