Search in sources:

Example 11 with MultiClassGradientBoostedTreesModel

Use of org.knime.base.node.mine.treeensemble2.model.MultiClassGradientBoostedTreesModel in project knime-core by KNIME.

Class LKGradientBoostingPredictorCellFactory, method createFactory.

/**
 * Creates a cell factory that appends the prediction columns of a multi-class gradient boosted
 * trees model to the predictor's input table.
 *
 * <p>The appended columns are, in order:
 * <ol>
 * <li>the prediction column (string cells), named by the configured prediction column name;</li>
 * <li>if prediction confidence is requested, a "Confidence" column (double cells);</li>
 * <li>if class confidences are requested, one {@code P(<target>=<value>)<suffix>} column (double
 * cells) per possible value of the model's target column.</li>
 * </ol>
 * All column names are made unique with respect to the columns already in the input table spec.
 *
 * @param predictor the predictor supplying the configuration, the input table spec, the model spec
 *            and the model itself
 * @return a new cell factory producing the columns described above
 * @throws InvalidSettingsException if class confidences are requested but the model's target column
 *             has no possible-value map. NOTE(review): this may also be propagated from
 *             {@code calculateFilterIndices}; confirm.
 */
public static LKGradientBoostingPredictorCellFactory createFactory(final GradientBoostingPredictor<MultiClassGradientBoostedTreesModel> predictor) throws InvalidSettingsException {
    final TreeEnsemblePredictorConfiguration config = predictor.getConfiguration();
    final DataTableSpec testSpec = predictor.getDataSpec();
    final TreeEnsembleModelPortObjectSpec modelSpec = predictor.getModelSpec();
    // Fetch once, so the probability columns are built from the same map the factory receives.
    final Map<String, DataCell> targetValueMap = modelSpec.getTargetColumnPossibleValueMap();

    final ArrayList<DataColumnSpec> newColSpecs = new ArrayList<>();
    // Avoids name clashes with columns that already exist in the input table.
    final UniqueNameGenerator nameGen = new UniqueNameGenerator(testSpec);
    newColSpecs.add(nameGen.newColumn(config.getPredictionColumnName(), StringCell.TYPE));
    if (config.isAppendPredictionConfidence()) {
        newColSpecs.add(nameGen.newColumn("Confidence", DoubleCell.TYPE));
    }
    if (config.isAppendClassConfidences()) {
        final String targetColName = modelSpec.getTargetColumn().getName();
        // Without possible values there is nothing to create probability columns for; report a
        // settings problem instead of failing with a NullPointerException on the map.
        if (targetValueMap == null) {
            throw new InvalidSettingsException("Target column \"" + targetColName
                + "\" has no possible values; cannot append class probability columns.");
        }
        final String suffix = config.getSuffixForClassProbabilities();
        for (final String val : targetValueMap.keySet()) {
            final String colName = "P(" + targetColName + "=" + val + ")" + suffix;
            newColSpecs.add(nameGen.newColumn(colName, DoubleCell.TYPE));
        }
    }
    return new LKGradientBoostingPredictorCellFactory(newColSpecs.toArray(new DataColumnSpec[newColSpecs.size()]), predictor.getModel(), modelSpec.getLearnTableSpec(), modelSpec.calculateFilterIndices(testSpec), config, targetValueMap);
}
Also used: DataTableSpec (org.knime.core.data.DataTableSpec), DataColumnSpec (org.knime.core.data.DataColumnSpec), TreeEnsembleModelPortObjectSpec (org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObjectSpec), TreeEnsemblePredictorConfiguration (org.knime.base.node.mine.treeensemble2.node.predictor.TreeEnsemblePredictorConfiguration), ArrayList (java.util.ArrayList), DataCell (org.knime.core.data.DataCell), UniqueNameGenerator (org.knime.core.util.UniqueNameGenerator)

Aggregations

MultiClassGradientBoostedTreesModel (org.knime.base.node.mine.treeensemble2.model.MultiClassGradientBoostedTreesModel): 7; TreeEnsembleModelPortObjectSpec (org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObjectSpec): 5; DataTableSpec (org.knime.core.data.DataTableSpec): 5; ColumnRearranger (org.knime.core.data.container.ColumnRearranger): 4; ArrayList (java.util.ArrayList): 3; Segmentation (org.dmg.pmml.SegmentationDocument.Segmentation): 3; GradientBoostingModelPortObject (org.knime.base.node.mine.treeensemble2.model.GradientBoostingModelPortObject): 3; TreeNodeSignature (org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature): 3; GradientBoostingPredictor (org.knime.base.node.mine.treeensemble2.node.gradientboosting.predictor.GradientBoostingPredictor): 3; Map (java.util.Map): 2; Segment (org.dmg.pmml.SegmentDocument.Segment): 2; GradientBoostedTreesModel (org.knime.base.node.mine.treeensemble2.model.GradientBoostedTreesModel): 2; TreeModelRegression (org.knime.base.node.mine.treeensemble2.model.TreeModelRegression): 2; BufferedInputStream (java.io.BufferedInputStream): 1; IOException (java.io.IOException): 1; Collection (java.util.Collection): 1; HashMap (java.util.HashMap): 1; List (java.util.List): 1; Future (java.util.concurrent.Future): 1; Semaphore (java.util.concurrent.Semaphore): 1