Search in sources:

Example 1 with Segmentation

use of org.dmg.pmml.SegmentationDocument.Segmentation in project knime-core by knime.

Class ClassificationGBTModelExporter, method doWrite.

/**
 * {@inheritDoc}
 * <p>
 * Writes a model-chain segmentation into {@code model}: one segment per class of the boosted model,
 * followed by a final aggregation segment that combines the per-class predictions.
 */
@Override
protected void doWrite(final MiningModel model) {
    final Segmentation chain = model.addNewSegmentation();
    chain.setMultipleModelMethod(MULTIPLEMODELMETHOD.MODEL_CHAIN);
    final int nrClasses = getGBTModel().getNrClasses();
    for (int classIdx = 0; classIdx < nrClasses; classIdx++) {
        addClassSegment(chain, classIdx);
    }
    // the aggregation segment must come after all class segments
    addAggregationSegment(chain);
}
Also used : Segmentation(org.dmg.pmml.SegmentationDocument.Segmentation) MultiClassGradientBoostedTreesModel(org.knime.base.node.mine.treeensemble2.model.MultiClassGradientBoostedTreesModel)

Example 2 with Segmentation

use of org.dmg.pmml.SegmentationDocument.Segmentation in project knime-core by knime.

Class RegressionGBTModelImporter, method importFromPMMLInternal.

/**
 * {@inheritDoc}
 * <p>
 * Reads the regression trees and their coefficient maps from a sum segmentation. The initial value of the
 * boosted model is taken from the rescale constant of the first target in the mining model's
 * {@code Targets} element.
 *
 * @throws IllegalArgumentException if the segmentation does not use the sum multiple model method, or if
 *             the mining model declares no target from which the initial value could be read
 */
@Override
public GradientBoostedTreesModel importFromPMMLInternal(final MiningModel miningModel) {
    Segmentation segmentation = miningModel.getSegmentation();
    CheckUtils.checkArgument(segmentation.getMultipleModelMethod() == MULTIPLEMODELMETHOD.SUM,
        "The provided segmentation has not the required sum as multiple model method but '%s' instead.",
        segmentation.getMultipleModelMethod());
    Pair<List<TreeModelRegression>, List<Map<TreeNodeSignature, Double>>> treesCoefficientMapsPair =
        readSumSegmentation(segmentation);
    List<TreeModelRegression> trees = treesCoefficientMapsPair.getFirst();
    // The initial value lives in the rescale constant of the first target. Fail with a descriptive message
    // instead of a NullPointerException / IndexOutOfBoundsException when no such target is present.
    CheckUtils.checkArgument(
        miningModel.getTargets() != null && !miningModel.getTargets().getTargetList().isEmpty(),
        "The provided mining model does not declare a target holding the initial value (rescale constant).");
    double initialValue = miningModel.getTargets().getTargetList().get(0).getRescaleConstant();
    // currently only models learned on "ordinary" columns can be read back in
    return new GradientBoostedTreesModel(getMetaDataMapper().getTreeMetaData(),
        trees.toArray(new TreeModelRegression[trees.size()]), TreeType.Ordinary, initialValue,
        treesCoefficientMapsPair.getSecond());
}
Also used : Segmentation(org.dmg.pmml.SegmentationDocument.Segmentation) List(java.util.List) GradientBoostedTreesModel(org.knime.base.node.mine.treeensemble2.model.GradientBoostedTreesModel) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression)

Example 3 with Segmentation

use of org.dmg.pmml.SegmentationDocument.Segmentation in project knime-core by knime.

Class ClassificationGBTModelImporter, method importFromPMMLInternal.

/**
 * {@inheritDoc}
 * <p>
 * Expects the top-level segmentation to be a model chain. Every segment except the last is processed as a
 * per-class segment, and the last segment is skipped. The exporter writes that last segment as the
 * aggregation segment.
 *
 * @throws IllegalArgumentException if the segmentation is not a model chain, or if it does not contain at
 *             least one class segment followed by the trailing segment
 */
@Override
protected MultiClassGradientBoostedTreesModel importFromPMMLInternal(final MiningModel miningModel) {
    Segmentation modelChain = miningModel.getSegmentation();
    CheckUtils.checkArgument(modelChain.getMultipleModelMethod() == MULTIPLEMODELMETHOD.MODEL_CHAIN,
        "The top level segmentation should have multiple model method '%s' but has '%s'",
        MULTIPLEMODELMETHOD.MODEL_CHAIN, modelChain.getMultipleModelMethod());
    List<Segment> segments = modelChain.getSegmentList();
    // Without this check an empty chain fails with an IndexOutOfBoundsException at segments.get(0), and a
    // single-segment chain would read the trailing (aggregation) segment as if it were a class segment.
    CheckUtils.checkArgument(segments.size() >= 2,
        "The model chain should contain at least one class segment followed by an aggregation segment "
            + "but contains %d segment(s)",
        segments.size());
    int nrClassSegments = segments.size() - 1;
    List<List<TreeModelRegression>> trees = new ArrayList<>(nrClassSegments);
    List<List<Map<TreeNodeSignature, Double>>> coefficientMaps = new ArrayList<>(nrClassSegments);
    List<String> classLabels = new ArrayList<>(nrClassSegments);
    for (Segment classSegment : segments.subList(0, nrClassSegments)) {
        Pair<List<TreeModelRegression>, List<Map<TreeNodeSignature, Double>>> gbtPair =
            processClassSegment(classSegment);
        trees.add(gbtPair.getFirst());
        coefficientMaps.add(gbtPair.getSecond());
        classLabels.add(extractClassLabel(classSegment));
    }
    // NOTE(review): the initial value is read from the first class segment only; presumably it is identical
    // for all classes - confirm against the exporter.
    double initialValue = extractInitialValue(segments.get(0));
    return MultiClassGradientBoostedTreesModel.create(getMetaDataMapper().getTreeMetaData(), trees,
        coefficientMaps, initialValue, TreeType.Ordinary, classLabels);
}
Also used : Segmentation(org.dmg.pmml.SegmentationDocument.Segmentation) ArrayList(java.util.ArrayList) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) Segment(org.dmg.pmml.SegmentDocument.Segment) ArrayList(java.util.ArrayList) List(java.util.List)

Example 4 with Segmentation

use of org.dmg.pmml.SegmentationDocument.Segmentation in project knime-core by knime.

Class ClassificationGBTModelExporter, method addSegmentation.

/**
 * Adds a new segmentation to {@code miningModel} and fills it, via {@code writeSumSegmentation}, with the
 * tree and coefficient map of class {@code c} from every level of the boosted model.
 *
 * @param miningModel the mining model that receives the new segmentation
 * @param c the index of the class whose trees are written
 */
private void addSegmentation(final MiningModel miningModel, final int c) {
    final Segmentation sumSegmentation = miningModel.addNewSegmentation();
    final MultiClassGradientBoostedTreesModel boostedModel = getGBTModel();
    final int nrLevels = boostedModel.getNrLevels();
    // element k of both collections belongs to level k
    final Collection<TreeModelRegression> classTrees = IntStream.range(0, nrLevels)
        .mapToObj(level -> boostedModel.getModel(level, c))
        .collect(Collectors.toList());
    final Collection<Map<TreeNodeSignature, Double>> classCoefficientMaps = IntStream.range(0, nrLevels)
        .mapToObj(level -> boostedModel.getCoefficientMap(level, c))
        .collect(Collectors.toList());
    writeSumSegmentation(sumSegmentation, classTrees, classCoefficientMaps);
}
Also used : IntStream(java.util.stream.IntStream) MININGFUNCTION(org.dmg.pmml.MININGFUNCTION) Enum(org.dmg.pmml.MININGFUNCTION.Enum) Targets(org.dmg.pmml.TargetsDocument.Targets) DATATYPE(org.dmg.pmml.DATATYPE) PMMLMiningSchemaTranslator(org.knime.core.node.port.pmml.PMMLMiningSchemaTranslator) RegressionTable(org.dmg.pmml.RegressionTableDocument.RegressionTable) Output(org.dmg.pmml.OutputDocument.Output) RESULTFEATURE(org.dmg.pmml.RESULTFEATURE) MiningSchema(org.dmg.pmml.MiningSchemaDocument.MiningSchema) Map(java.util.Map) Target(org.dmg.pmml.TargetDocument.Target) FIELDUSAGETYPE(org.dmg.pmml.FIELDUSAGETYPE) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) Collection(java.util.Collection) Segmentation(org.dmg.pmml.SegmentationDocument.Segmentation) RegressionModel(org.dmg.pmml.RegressionModelDocument.RegressionModel) Collectors(java.util.stream.Collectors) MiningField(org.dmg.pmml.MiningFieldDocument.MiningField) OPTYPE(org.dmg.pmml.OPTYPE) MULTIPLEMODELMETHOD(org.dmg.pmml.MULTIPLEMODELMETHOD) NumericPredictor(org.dmg.pmml.NumericPredictorDocument.NumericPredictor) MultiClassGradientBoostedTreesModel(org.knime.base.node.mine.treeensemble2.model.MultiClassGradientBoostedTreesModel) REGRESSIONNORMALIZATIONMETHOD(org.dmg.pmml.REGRESSIONNORMALIZATIONMETHOD) DerivedFieldMapper(org.knime.core.node.port.pmml.preproc.DerivedFieldMapper) MiningModel(org.dmg.pmml.MiningModelDocument.MiningModel) Segment(org.dmg.pmml.SegmentDocument.Segment) OutputField(org.dmg.pmml.OutputFieldDocument.OutputField) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression) Segmentation(org.dmg.pmml.SegmentationDocument.Segmentation) MultiClassGradientBoostedTreesModel(org.knime.base.node.mine.treeensemble2.model.MultiClassGradientBoostedTreesModel) Map(java.util.Map) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression)

Example 5 with Segmentation

use of org.dmg.pmml.SegmentationDocument.Segmentation in project knime-core by knime.

Class RegressionGBTModelExporter, method doWrite.

/**
 * {@inheritDoc}
 * <p>
 * Records the model's initial value as the rescale constant of a new target, which is named after the
 * target attribute. It then writes the regression tree of every model index, together with the model's
 * coefficient maps, as a sum segmentation.
 */
@Override
protected void doWrite(final MiningModel model) {
    final Target initialValueTarget = model.addNewTargets().addNewTarget();
    final GradientBoostedTreesModel gbt = getGBTModel();
    initialValueTarget.setField(gbt.getMetaData().getTargetMetaData().getAttributeName());
    initialValueTarget.setRescaleConstant(gbt.getInitialValue());

    final Segmentation sumSegmentation = model.addNewSegmentation();
    final List<TreeModelRegression> regressionTrees = IntStream.range(0, gbt.getNrModels())
        .mapToObj(idx -> gbt.getTreeModelRegression(idx))
        .collect(Collectors.toList());
    writeSumSegmentation(sumSegmentation, regressionTrees, gbt.getCoeffientMaps());
}
Also used : Target(org.dmg.pmml.TargetDocument.Target) Segmentation(org.dmg.pmml.SegmentationDocument.Segmentation) Targets(org.dmg.pmml.TargetsDocument.Targets) GradientBoostedTreesModel(org.knime.base.node.mine.treeensemble2.model.GradientBoostedTreesModel) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression)

Aggregations

Segmentation (org.dmg.pmml.SegmentationDocument.Segmentation): 5; TreeModelRegression (org.knime.base.node.mine.treeensemble2.model.TreeModelRegression): 3; TreeNodeSignature (org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature): 3; List (java.util.List): 2; Segment (org.dmg.pmml.SegmentDocument.Segment): 2; Target (org.dmg.pmml.TargetDocument.Target): 2; Targets (org.dmg.pmml.TargetsDocument.Targets): 2; GradientBoostedTreesModel (org.knime.base.node.mine.treeensemble2.model.GradientBoostedTreesModel): 2; MultiClassGradientBoostedTreesModel (org.knime.base.node.mine.treeensemble2.model.MultiClassGradientBoostedTreesModel): 2; ArrayList (java.util.ArrayList): 1; Collection (java.util.Collection): 1; Map (java.util.Map): 1; Collectors (java.util.stream.Collectors): 1; IntStream (java.util.stream.IntStream): 1; DATATYPE (org.dmg.pmml.DATATYPE): 1; FIELDUSAGETYPE (org.dmg.pmml.FIELDUSAGETYPE): 1; MININGFUNCTION (org.dmg.pmml.MININGFUNCTION): 1; Enum (org.dmg.pmml.MININGFUNCTION.Enum): 1; MULTIPLEMODELMETHOD (org.dmg.pmml.MULTIPLEMODELMETHOD): 1; MiningField (org.dmg.pmml.MiningFieldDocument.MiningField): 1