Search in sources:

Example 6 with TreeNodeRegression

use of org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression in project knime-core by knime.

Method getCells of class LKGradientBoostingPredictorCellFactory.

/**
 * {@inheritDoc}
 * <p>
 * For every class, starts from the model's initial value and adds, for each boosting level, the coefficient
 * stored for the leaf of that level's/class's regression tree that matches the row. The resulting class scores
 * are turned into probabilities with a softmax. The output consists of the label of the most probable class,
 * optionally followed by its probability and by the probability of every class in the iteration order of
 * {@code m_targetValueMap}'s key set.
 */
@Override
public DataCell[] getCells(final DataRow row) {
    final DataRow filterRow = new FilterColumnRow(row, m_learnColumnInRealDataIndices);
    final int nrClasses = m_model.getNrClasses();
    final int nrLevels = m_model.getNrLevels();
    // NOTE(review): a null record (e.g. for missing values) is not handled here - confirm that
    // createPredictorRecord cannot return null for this model.
    final PredictorRecord record = m_model.createPredictorRecord(filterRow, m_learnSpec);
    final double[] classFunctionPredictions = new double[nrClasses];
    Arrays.fill(classFunctionPredictions, m_model.getInitialValue());
    for (int i = 0; i < nrLevels; i++) {
        for (int j = 0; j < nrClasses; j++) {
            final TreeNodeRegression matchingNode = m_model.getModel(i, j).findMatchingNode(record);
            classFunctionPredictions[j] += m_model.getCoefficientMap(i, j).get(matchingNode.getSignature());
        }
    }
    // Softmax over the class scores. The maximum score is subtracted before exponentiating. This leaves the
    // normalized probabilities mathematically unchanged, but it keeps Math.exp from overflowing to Infinity
    // for large scores. Such an overflow would make every probability NaN and leave no class selected.
    double maxScore = Double.NEGATIVE_INFINITY;
    for (final double score : classFunctionPredictions) {
        maxScore = Math.max(maxScore, score);
    }
    final double[] classProbabilities = new double[nrClasses];
    double expSum = 0;
    for (int i = 0; i < nrClasses; i++) {
        classProbabilities[i] = Math.exp(classFunctionPredictions[i] - maxScore);
        expSum += classProbabilities[i];
    }
    // normalize and track the most probable class (first one wins on ties)
    int classIdx = -1;
    double classProb = -1;
    for (int i = 0; i < nrClasses; i++) {
        classProbabilities[i] /= expSum;
        if (classProbabilities[i] > classProb) {
            classIdx = i;
            classProb = classProbabilities[i];
        }
    }
    final ArrayList<DataCell> cells = new ArrayList<DataCell>();
    cells.add(new StringCell(m_model.getClassLabel(classIdx)));
    if (m_config.isAppendPredictionConfidence()) {
        cells.add(new DoubleCell(classProb));
    }
    if (m_config.isAppendClassConfidences()) {
        // the map is necessary to ensure that the probabilities are correctly associated with the column header
        final Map<String, Double> classProbMap = new HashMap<String, Double>((int) (nrClasses * 1.5));
        for (int i = 0; i < nrClasses; i++) {
            classProbMap.put(m_model.getClassLabel(i), classProbabilities[i]);
        }
        // NOTE(review): assumes every key of m_targetValueMap is a class label of the model; otherwise the
        // lookup yields null, and unboxing it for the DoubleCell presumably throws a NullPointerException.
        for (final String className : m_targetValueMap.keySet()) {
            cells.add(new DoubleCell(classProbMap.get(className)));
        }
    }
    return cells.toArray(new DataCell[cells.size()]);
}
Also used : HashMap(java.util.HashMap) DoubleCell(org.knime.core.data.def.DoubleCell) ArrayList(java.util.ArrayList) DataRow(org.knime.core.data.DataRow) TreeNodeRegression(org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression) StringCell(org.knime.core.data.def.StringCell) PredictorRecord(org.knime.base.node.mine.treeensemble2.data.PredictorRecord) DataCell(org.knime.core.data.DataCell) FilterColumnRow(org.knime.base.data.filter.column.FilterColumnRow)

Example 7 with TreeNodeRegression

use of org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression in project knime-core by knime.

Method getCells of class TreeEnsembleRegressionPredictorCellFactory.

/**
 * {@inheritDoc}
 * <p>
 * Averages the means of the leaves matched by the row in every ensemble member. If an out-of-bag filter is
 * active, members whose training data contained the row are skipped. The first output cell holds that average.
 * Depending on the configuration, it is followed by the variance of the leaf means and by the number of members
 * that contributed. The mean and variance cells are missing if no member contributed. All cells are missing if
 * no predictor record could be built for the row.
 */
@Override
public DataCell[] getCells(final DataRow row) {
    final TreeEnsembleModelPortObject portObject = m_predictor.getModelObject();
    final TreeEnsemblePredictorConfiguration config = m_predictor.getConfiguration();
    final TreeEnsembleModel model = portObject.getEnsembleModel();
    final boolean withVariance = config.isAppendPredictionConfidence();
    final boolean withModelCount = config.isAppendModelCount();
    final int cellCount = 1 + (withVariance ? 1 : 0) + (withModelCount ? 1 : 0);
    final boolean useOutOfBag = m_predictor.hasOutOfBagFilter();
    final DataCell[] cells = new DataCell[cellCount];
    final DataRow learnColumns = new FilterColumnRow(row, m_learnColumnInRealDataIndices);
    final PredictorRecord predictorRecord = model.createPredictorRecord(learnColumns, m_learnSpec);
    if (predictorRecord == null) {
        // missing value: every output cell is missing
        Arrays.fill(cells, DataType.getMissingCell());
        return cells;
    }
    final Mean meanStat = new Mean();
    final Variance varianceStat = new Variance();
    for (int modelIdx = 0, modelCount = model.getNrModels(); modelIdx < modelCount; modelIdx++) {
        if (useOutOfBag && m_predictor.isRowPartOfTrainingData(row.getKey(), modelIdx)) {
            // the row was used to train this member, so it does not contribute
            continue;
        }
        final double leafMean = model.getTreeModelRegression(modelIdx).findMatchingNode(predictorRecord).getMean();
        meanStat.increment(leafMean);
        varianceStat.increment(leafMean);
    }
    final int contributing = (int) meanStat.getN();
    final boolean noneContributed = contributing == 0;
    int pos = 0;
    cells[pos++] = noneContributed ? DataType.getMissingCell() : new DoubleCell(meanStat.getResult());
    if (withVariance) {
        cells[pos++] = noneContributed ? DataType.getMissingCell() : new DoubleCell(varianceStat.getResult());
    }
    if (withModelCount) {
        cells[pos] = new IntCell(contributing);
    }
    return cells;
}
Also used : Mean(org.apache.commons.math.stat.descriptive.moment.Mean) TreeEnsembleModel(org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModel) DoubleCell(org.knime.core.data.def.DoubleCell) TreeEnsemblePredictorConfiguration(org.knime.base.node.mine.treeensemble2.node.predictor.TreeEnsemblePredictorConfiguration) DataRow(org.knime.core.data.DataRow) TreeNodeRegression(org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression) Variance(org.apache.commons.math.stat.descriptive.moment.Variance) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression) IntCell(org.knime.core.data.def.IntCell) TreeEnsembleModelPortObject(org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObject) PredictorRecord(org.knime.base.node.mine.treeensemble2.data.PredictorRecord) DataCell(org.knime.core.data.DataCell) FilterColumnRow(org.knime.base.data.filter.column.FilterColumnRow)

Example 8 with TreeNodeRegression

use of org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression in project knime-core by knime.

Method calcCoefficientMap of class MGradientBoostedTreesLearner.

/**
 * Computes one coefficient per leaf of {@code tree}, keyed by the leaf's signature.
 * <p>
 * For the residuals of the rows that reached a leaf, the coefficient is their median plus the mean of their
 * deviations from that median, with each deviation clipped in magnitude to {@code quantile}. The result is
 * scaled by the configured learning rate. (This looks like the terminal-node update of Huber-loss boosting;
 * confirm against the learner's documentation.)
 *
 * @param residuals residuals indexed by row index in the tree data
 * @param quantile clipping threshold applied to the absolute deviation from the leaf median
 * @param tree the regression tree whose leaves receive coefficients
 * @return map from leaf signature to its learning-rate scaled coefficient
 */
private Map<TreeNodeSignature, Double> calcCoefficientMap(final double[] residuals, final double quantile, final TreeModelRegression tree) {
    final List<TreeNodeRegression> leafNodes = tree.getLeafs();
    final Map<TreeNodeSignature, Double> result = new HashMap<TreeNodeSignature, Double>((int) (leafNodes.size() / 0.75 + 1));
    final double shrinkage = getConfig().getLearningRate();
    for (final TreeNodeRegression leafNode : leafNodes) {
        final int[] rowIndices = leafNode.getRowIndicesInTreeData();
        final int n = rowIndices.length;
        // gather the residuals of the rows that ended up in this leaf
        final double[] leafResiduals = new double[n];
        int k = 0;
        for (final int rowIdx : rowIndices) {
            leafResiduals[k++] = residuals[rowIdx];
        }
        final double median = calcMedian(leafResiduals);
        // sum of deviations from the median, each clipped to at most 'quantile' in magnitude
        double clippedSum = 0;
        for (final double r : leafResiduals) {
            final double deviation = r - median;
            clippedSum += Math.signum(deviation) * Math.min(quantile, Math.abs(deviation));
        }
        final double coefficient = median + (1.0 / n) * clippedSum;
        result.put(leafNode.getSignature(), coefficient * shrinkage);
    }
    return result;
}
Also used : HashMap(java.util.HashMap) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeNodeRegression(org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression)

Example 9 with TreeNodeRegression

use of org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression in project knime-core by knime.

Method getCells of class RegressionTreePredictorCellFactory.

/**
 * {@inheritDoc}
 * <p>
 * Produces exactly one cell. The cell holds the mean of the tree leaf that matches the row. It is a missing
 * cell if no predictor record could be built for the row (treated as a missing value).
 */
@Override
public DataCell[] getCells(final DataRow row) {
    final RegressionTreeModel model = m_predictor.getModel();
    final DataRow learnColumns = new FilterColumnRow(row, m_learnColumnInRealDataIndices);
    final PredictorRecord predictorRecord = model.createPredictorRecord(learnColumns, m_learnSpec);
    final DataCell prediction;
    if (predictorRecord == null) {
        // missing value
        prediction = DataType.getMissingCell();
    } else {
        prediction = new DoubleCell(model.getTreeModel().findMatchingNode(predictorRecord).getMean());
    }
    return new DataCell[]{prediction};
}
Also used : RegressionTreeModel(org.knime.base.node.mine.treeensemble2.model.RegressionTreeModel) DoubleCell(org.knime.core.data.def.DoubleCell) PredictorRecord(org.knime.base.node.mine.treeensemble2.data.PredictorRecord) DataCell(org.knime.core.data.DataCell) DataRow(org.knime.core.data.DataRow) TreeNodeRegression(org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression) FilterColumnRow(org.knime.base.data.filter.column.FilterColumnRow) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression)

Example 10 with TreeNodeRegression

use of org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression in project knime-core by knime.

Method buildTreeNode of class TreeLearnerRegression.

/**
 * Recursively builds the regression (sub)tree rooted at the node identified by {@code treeNodeSignature}.
 * <p>
 * If {@code findBestSplitRegression} finds no split candidate, a leaf is returned. For a
 * {@link GradientBoostingLearnerConfiguration}, that leaf also keeps the original row indices of its data and
 * is registered via {@code addToLeafList}. Otherwise, the rows are partitioned by the best split and one child is
 * built per child condition. With {@code MissingValueHandling.Surrogate}, surrogate splits define the partition
 * and always produce two children.
 *
 * @param exec monitor, checked for cancellation at the start of every node
 * @param currentDepth depth of the node being built; passed to the split search and incremented for children
 * @param dataMemberships the rows that reach this node
 * @param columnSample the columns considered when searching a split for this node
 * @param treeNodeSignature signature of the node being built
 * @param targetPriors target statistics of the rows reaching this node
 * @param forbiddenColumnSet columns excluded from splitting. In the non-surrogate path, the split column is
 *            marked forbidden while the children are built if it cannot be split further, and cleared again
 *            afterwards.
 * @return the built node: a leaf, or an inner node holding its children
 * @throws CanceledExecutionException if execution was canceled
 */
private TreeNodeRegression buildTreeNode(final ExecutionMonitor exec, final int currentDepth, final DataMemberships dataMemberships, final ColumnSample columnSample, final TreeNodeSignature treeNodeSignature, final RegressionPriors targetPriors, final BitSet forbiddenColumnSet) throws CanceledExecutionException {
    final TreeData data = getData();
    final RandomData rd = getRandomData();
    final TreeEnsembleLearnerConfiguration config = getConfig();
    exec.checkCanceled();
    final SplitCandidate candidate = findBestSplitRegression(currentDepth, dataMemberships, columnSample, targetPriors, forbiddenColumnSet);
    if (candidate == null) {
        // no split possible/worthwhile: this node becomes a leaf
        if (config instanceof GradientBoostingLearnerConfiguration) {
            // gradient boosting needs the rows of every leaf later on, so keep their indices and register the leaf
            TreeNodeRegression leaf = new TreeNodeRegression(treeNodeSignature, targetPriors, dataMemberships.getOriginalIndices());
            addToLeafList(leaf);
            return leaf;
        }
        return new TreeNodeRegression(treeNodeSignature, targetPriors);
    }
    final TreeTargetNumericColumnData targetColumn = (TreeTargetNumericColumnData) data.getTargetColumn();
    boolean useSurrogates = config.getMissingValueHandling() == MissingValueHandling.Surrogate;
    TreeNodeCondition[] childConditions;
    TreeNodeRegression[] childNodes;
    if (useSurrogates) {
        // binary split whose child row sets are decided by the learned surrogate splits
        SurrogateSplit surrogateSplit = Surrogates.learnSurrogates(dataMemberships, candidate, data, columnSample, config, rd);
        childConditions = surrogateSplit.getChildConditions();
        BitSet[] childMarkers = surrogateSplit.getChildMarkers();
        assert childMarkers[0].cardinality() + childMarkers[1].cardinality() == dataMemberships.getRowCount() : "Sum of rows in children does not add up to number of rows in parent.";
        childNodes = new TreeNodeRegression[2];
        for (int i = 0; i < 2; i++) {
            DataMemberships childMemberships = dataMemberships.createChildMemberships(childMarkers[i]);
            // NOTE(review): child signatures come from the signature factory here, but from
            // treeNodeSignature.createChildSignature in the non-surrogate branch - confirm both agree.
            TreeNodeSignature childSignature = getSignatureFactory().getChildSignatureFor(treeNodeSignature, (byte) i);
            ColumnSample childColumnSample = getColSamplingStrategy().getColumnSampleForTreeNode(childSignature);
            RegressionPriors childTargetPriors = targetColumn.getPriors(childMemberships, config);
            childNodes[i] = buildTreeNode(exec, currentDepth + 1, childMemberships, childColumnSample, childSignature, childTargetPriors, forbiddenColumnSet);
            childNodes[i].setTreeNodeCondition(childConditions[i]);
        }
    } else {
        SplitCandidate bestSplit = candidate;
        TreeAttributeColumnData splitColumn = bestSplit.getColumnData();
        final int attributeIndex = splitColumn.getMetaData().getAttributeIndex();
        // a column that cannot be split further is excluded while this node's subtrees are built
        boolean markAttributeAsForbidden = !bestSplit.canColumnBeSplitFurther();
        forbiddenColumnSet.set(attributeIndex, markAttributeAsForbidden);
        childConditions = bestSplit.getChildConditions();
        // NOTE(review): the bound checked is Short.MAX_VALUE, yet child indices are cast to byte for
        // createChildSignature below; indices above Byte.MAX_VALUE wrap around - confirm this is intended.
        if (childConditions.length > Short.MAX_VALUE) {
            throw new RuntimeException("Too many children when splitting " + "attribute " + bestSplit.getColumnData() + " (maximum supported: " + Short.MAX_VALUE + "): " + childConditions.length);
        }
        childNodes = new TreeNodeRegression[childConditions.length];
        for (int i = 0; i < childConditions.length; i++) {
            TreeNodeCondition cond = childConditions[i];
            DataMemberships childMemberships = dataMemberships.createChildMemberships(splitColumn.updateChildMemberships(cond, dataMemberships));
            RegressionPriors childTargetPriors = targetColumn.getPriors(childMemberships, config);
            TreeNodeSignature childSignature = treeNodeSignature.createChildSignature((byte) i);
            ColumnSample childColumnSample = getColSamplingStrategy().getColumnSampleForTreeNode(childSignature);
            childNodes[i] = buildTreeNode(exec, currentDepth + 1, childMemberships, childColumnSample, childSignature, childTargetPriors, forbiddenColumnSet);
            childNodes[i].setTreeNodeCondition(cond);
        }
        // undo the temporary exclusion so sibling subtrees may still split on this column
        if (markAttributeAsForbidden) {
            forbiddenColumnSet.set(attributeIndex, false);
        }
    }
    return new TreeNodeRegression(treeNodeSignature, targetPriors, childNodes);
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RandomData(org.apache.commons.math.random.RandomData) TreeAttributeColumnData(org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData) ColumnSample(org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample) RegressionPriors(org.knime.base.node.mine.treeensemble2.data.RegressionPriors) BitSet(java.util.BitSet) TreeTargetNumericColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeNodeRegression(org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) GradientBoostingLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.gradientboosting.learner.GradientBoostingLearnerConfiguration) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) TreeNodeCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition)

Aggregations

TreeNodeRegression (org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression)9 TreeNodeSignature (org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature)6 TreeModelRegression (org.knime.base.node.mine.treeensemble2.model.TreeModelRegression)4 HashMap (java.util.HashMap)3 FilterColumnRow (org.knime.base.data.filter.column.FilterColumnRow)3 PredictorRecord (org.knime.base.node.mine.treeensemble2.data.PredictorRecord)3 TreeTargetNumericColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData)3 DataCell (org.knime.core.data.DataCell)3 DataRow (org.knime.core.data.DataRow)3 DoubleCell (org.knime.core.data.def.DoubleCell)3 BitSet (java.util.BitSet)2 RegressionPriors (org.knime.base.node.mine.treeensemble2.data.RegressionPriors)2 TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData)2 TreeTargetNumericColumnMetaData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnMetaData)2 DataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships)2 RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships)2 GradientBoostingLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.gradientboosting.learner.GradientBoostingLearnerConfiguration)2 TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration)2 ColumnSample (org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample)2 ArrayList (java.util.ArrayList)1