Search in sources:

Example 11 with TreeModelRegression

Use of org.knime.base.node.mine.treeensemble2.model.TreeModelRegression in project knime-core by KNIME.

The class AbstractGBTModelExporter, method writeSumSegmentation.

/**
 * Writes the given trees into {@code segmentation} as a SUM ensemble. Each tree gets one segment with an
 * always-true predicate. Segments are numbered from 1 in the iteration order of the two collections.
 *
 * @param segmentation the PMML segmentation to populate; its multiple model method is set to SUM
 * @param trees the regression trees to write, one segment per tree
 * @param coefficientMaps the per-tree coefficient maps, paired positionally with {@code trees}
 * @throws IllegalArgumentException if {@code trees} and {@code coefficientMaps} differ in size
 */
protected void writeSumSegmentation(final Segmentation segmentation, final Collection<TreeModelRegression> trees,
    final Collection<Map<TreeNodeSignature, Double>> coefficientMaps) {
    // Validate explicitly rather than with 'assert': assertions are disabled by default at runtime. A size
    // mismatch would otherwise either throw NoSuchElementException mid-write (leaving a partially written
    // segmentation) or silently ignore trailing coefficient maps.
    if (trees.size() != coefficientMaps.size()) {
        throw new IllegalArgumentException("The number of trees (" + trees.size()
            + ") does not match the number of coefficient maps (" + coefficientMaps.size() + ").");
    }
    segmentation.setMultipleModelMethod(MULTIPLEMODELMETHOD.SUM);
    Iterator<TreeModelRegression> treeIterator = trees.iterator();
    Iterator<Map<TreeNodeSignature, Double>> coefficientMapIterator = coefficientMaps.iterator();
    for (int i = 1; i <= trees.size(); i++) {
        Segment segment = segmentation.addNewSegment();
        // Segment ids are 1-based strings.
        segment.setId(Integer.toString(i));
        // Every tree contributes unconditionally to the sum.
        segment.addNewTrue();
        writeTreeIntoSegment(segment, treeIterator.next(), coefficientMapIterator.next());
    }
}
Also used : Map(java.util.Map) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression) Segment(org.dmg.pmml.SegmentDocument.Segment)

Example 12 with TreeModelRegression

Use of org.knime.base.node.mine.treeensemble2.model.TreeModelRegression in project knime-core by KNIME.

The class ClassificationGBTModelImporter, method importFromPMMLInternal.

/**
 * {@inheritDoc}
 * <p>
 * The top-level segmentation must use the MODEL_CHAIN multiple model method. Every segment except the last one
 * is processed as a class segment, yielding that class's trees, coefficient maps and label. The initial value
 * is read from the first segment.
 */
@Override
protected MultiClassGradientBoostedTreesModel importFromPMMLInternal(final MiningModel miningModel) {
    final Segmentation chain = miningModel.getSegmentation();
    CheckUtils.checkArgument(chain.getMultipleModelMethod() == MULTIPLEMODELMETHOD.MODEL_CHAIN,
        "The top level segmentation should have multiple model method '%s' but has '%s'",
        MULTIPLEMODELMETHOD.MODEL_CHAIN, chain.getMultipleModelMethod());

    final List<List<TreeModelRegression>> classTrees = new ArrayList<>();
    final List<List<Map<TreeNodeSignature, Double>>> classCoefficientMaps = new ArrayList<>();
    final List<String> labels = new ArrayList<>();

    final List<Segment> segmentList = chain.getSegmentList();
    // The last segment of the chain is deliberately skipped here.
    // NOTE(review): presumably it holds the final aggregation model rather than a class; confirm.
    final int classSegmentCount = segmentList.size() - 1;
    for (int c = 0; c < classSegmentCount; c++) {
        final Segment classSegment = segmentList.get(c);
        final Pair<List<TreeModelRegression>, List<Map<TreeNodeSignature, Double>>> treesAndCoefficients =
            processClassSegment(classSegment);
        classTrees.add(treesAndCoefficients.getFirst());
        classCoefficientMaps.add(treesAndCoefficients.getSecond());
        labels.add(extractClassLabel(classSegment));
    }

    final double initialValue = extractInitialValue(segmentList.get(0));
    return MultiClassGradientBoostedTreesModel.create(getMetaDataMapper().getTreeMetaData(), classTrees,
        classCoefficientMaps, initialValue, TreeType.Ordinary, labels);
}
Also used : Segmentation(org.dmg.pmml.SegmentationDocument.Segmentation) ArrayList(java.util.ArrayList) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) Segment(org.dmg.pmml.SegmentDocument.Segment) List(java.util.List)

Example 13 with TreeModelRegression

Use of org.knime.base.node.mine.treeensemble2.model.TreeModelRegression in project knime-core by KNIME.

The class AbstractGBTModelImporter, method readTreeModel.

/**
 * Imports the PMML tree model held by {@code segment} and pairs it with its coefficient map.
 *
 * @param segment the segment whose tree model is imported
 * @return the imported regression tree together with the coefficient map obtained from the content parser
 */
private Pair<TreeModelRegression, Map<TreeNodeSignature, Double>> readTreeModel(final Segment segment) {
    // Keep the parser in a local so its coefficient map can be queried after the import.
    // NOTE(review): presumably the map is filled while importFromPMML runs; confirm.
    final GBTRegressionContentParser parser = new GBTRegressionContentParser();
    final TreeModelImporter<TreeNodeRegression, TreeModelRegression, TreeTargetNumericColumnMetaData> importer =
        new TreeModelImporter<>(m_metaDataMapper, m_conditionParser, m_signatureFactory, parser, m_treeFactory);
    final TreeModelRegression importedTree = importer.importFromPMML(segment.getTreeModel());
    return new Pair<>(importedTree, parser.getCoefficientMap());
}
Also used : TreeModel(org.dmg.pmml.TreeModelDocument.TreeModel) TreeTargetNumericColumnMetaData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnMetaData) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeNodeRegression(org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression) Pair(org.knime.core.util.Pair)

Example 14 with TreeModelRegression

Use of org.knime.base.node.mine.treeensemble2.model.TreeModelRegression in project knime-core by KNIME.

The class AbstractGBTModelImporter, method readSumSegmentation.

/**
 * Reads every segment of {@code segmentation} as one tree model.
 *
 * @param segmentation the segmentation whose segments each contain a tree model
 * @return a pair of parallel lists: the imported trees and their coefficient maps, in segment order
 */
protected Pair<List<TreeModelRegression>, List<Map<TreeNodeSignature, Double>>> readSumSegmentation(
    final Segmentation segmentation) {
    final List<Segment> segmentList = segmentation.getSegmentList();
    final int segmentCount = segmentList.size();
    final List<TreeModelRegression> readTrees = new ArrayList<>(segmentCount);
    final List<Map<TreeNodeSignature, Double>> readCoefficientMaps = new ArrayList<>(segmentCount);
    for (final Segment current : segmentList) {
        final Pair<TreeModelRegression, Map<TreeNodeSignature, Double>> treeAndCoefficients =
            readTreeModel(current);
        readTrees.add(treeAndCoefficients.getFirst());
        readCoefficientMaps.add(treeAndCoefficients.getSecond());
    }
    return new Pair<>(readTrees, readCoefficientMaps);
}
Also used : ArrayList(java.util.ArrayList) Map(java.util.Map) Segment(org.dmg.pmml.SegmentDocument.Segment) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression) Pair(org.knime.core.util.Pair)

Example 15 with TreeModelRegression

Use of org.knime.base.node.mine.treeensemble2.model.TreeModelRegression in project knime-core by KNIME.

The class ClassificationGBTModelExporter, method addSegmentation.

/**
 * Adds a new segmentation to {@code miningModel}. It is filled with the trees and coefficient maps for index
 * {@code c} from every level of the gradient boosted trees model.
 *
 * @param miningModel the mining model that receives the new segmentation
 * @param c the second index passed to {@code getModel}/{@code getCoefficientMap} (presumably the class index;
 *            confirm)
 */
private void addSegmentation(final MiningModel miningModel, final int c) {
    final Segmentation segmentation = miningModel.addNewSegmentation();
    final MultiClassGradientBoostedTreesModel gbtModel = getGBTModel();
    final int levelCount = gbtModel.getNrLevels();
    final Collection<TreeModelRegression> levelTrees = IntStream.range(0, levelCount)
        .mapToObj(level -> gbtModel.getModel(level, c))
        .collect(Collectors.toList());
    final Collection<Map<TreeNodeSignature, Double>> levelCoefficientMaps = IntStream.range(0, levelCount)
        .mapToObj(level -> gbtModel.getCoefficientMap(level, c))
        .collect(Collectors.toList());
    writeSumSegmentation(segmentation, levelTrees, levelCoefficientMaps);
}
Also used : IntStream(java.util.stream.IntStream) MININGFUNCTION(org.dmg.pmml.MININGFUNCTION) Enum(org.dmg.pmml.MININGFUNCTION.Enum) Targets(org.dmg.pmml.TargetsDocument.Targets) DATATYPE(org.dmg.pmml.DATATYPE) PMMLMiningSchemaTranslator(org.knime.core.node.port.pmml.PMMLMiningSchemaTranslator) RegressionTable(org.dmg.pmml.RegressionTableDocument.RegressionTable) Output(org.dmg.pmml.OutputDocument.Output) RESULTFEATURE(org.dmg.pmml.RESULTFEATURE) MiningSchema(org.dmg.pmml.MiningSchemaDocument.MiningSchema) Map(java.util.Map) Target(org.dmg.pmml.TargetDocument.Target) FIELDUSAGETYPE(org.dmg.pmml.FIELDUSAGETYPE) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) Collection(java.util.Collection) Segmentation(org.dmg.pmml.SegmentationDocument.Segmentation) RegressionModel(org.dmg.pmml.RegressionModelDocument.RegressionModel) Collectors(java.util.stream.Collectors) MiningField(org.dmg.pmml.MiningFieldDocument.MiningField) OPTYPE(org.dmg.pmml.OPTYPE) MULTIPLEMODELMETHOD(org.dmg.pmml.MULTIPLEMODELMETHOD) NumericPredictor(org.dmg.pmml.NumericPredictorDocument.NumericPredictor) MultiClassGradientBoostedTreesModel(org.knime.base.node.mine.treeensemble2.model.MultiClassGradientBoostedTreesModel) REGRESSIONNORMALIZATIONMETHOD(org.dmg.pmml.REGRESSIONNORMALIZATIONMETHOD) DerivedFieldMapper(org.knime.core.node.port.pmml.preproc.DerivedFieldMapper) MiningModel(org.dmg.pmml.MiningModelDocument.MiningModel) Segment(org.dmg.pmml.SegmentDocument.Segment) OutputField(org.dmg.pmml.OutputFieldDocument.OutputField) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression)

Aggregations

TreeModelRegression (org.knime.base.node.mine.treeensemble2.model.TreeModelRegression): 13; TreeNodeSignature (org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature): 10; TreeNodeRegression (org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression): 6; Map (java.util.Map): 5; TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData): 5; ArrayList (java.util.ArrayList): 4; HashMap (java.util.HashMap): 4; Segment (org.dmg.pmml.SegmentDocument.Segment): 4; Segmentation (org.dmg.pmml.SegmentationDocument.Segmentation): 4; PredictorRecord (org.knime.base.node.mine.treeensemble2.data.PredictorRecord): 4; RandomData (org.apache.commons.math.random.RandomData): 3; TreeTargetNumericColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData): 3; IDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager): 3; RegressionTreeModel (org.knime.base.node.mine.treeensemble2.model.RegressionTreeModel): 3; DataRow (org.knime.core.data.DataRow): 3; DataTableSpec (org.knime.core.data.DataTableSpec): 3; IOException (java.io.IOException): 2; List (java.util.List): 2; Target (org.dmg.pmml.TargetDocument.Target): 2; Targets (org.dmg.pmml.TargetsDocument.Targets): 2