Example 1 with Segment

Use of org.dmg.pmml.SegmentDocument.Segment in project knime-core by knime.

From class ClassificationGBTModelExporter, method addClassSegment:

private void addClassSegment(final Segmentation modelChain, final int classIdx) {
    Segment cs = modelChain.addNewSegment();
    cs.setId(Integer.toString(classIdx + 1));
    cs.addNewTrue();
    MiningModel cm = cs.addNewMiningModel();
    cm.setFunctionName(MININGFUNCTION.REGRESSION);
    // write mining schema
    PMMLMiningSchemaTranslator.writeMiningSchema(getPMMLSpec(), cm);
    addOutput(cm, classIdx);
    addTarget(cm);
    addSegmentation(cm, classIdx);
}
Also used: MiningModel (org.dmg.pmml.MiningModelDocument.MiningModel), Segment (org.dmg.pmml.SegmentDocument.Segment)
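In PMML terms, `addClassSegment` appends one `Segment` per class to the model chain. That segment has a 1-based `id` (`classIdx + 1`), an always-true `<True/>` predicate, and a nested `MiningModel` with `functionName="regression"`. The sketch below is a rough illustration of that element shape only. It builds the XML with a plain `StringBuilder` instead of the XMLBeans-generated `org.dmg.pmml` classes. `ClassSegmentXmlSketch` and `classSegmentXml` are hypothetical names. The mining schema, output, target and inner segmentation that the real method also writes are left out.

```java
/** Hypothetical helper that mimics the XML shape produced by addClassSegment (not the KNIME API). */
class ClassSegmentXmlSketch {

    /** Builds the PMML fragment for the class at zero-based index classIdx. */
    static String classSegmentXml(int classIdx) {
        StringBuilder xml = new StringBuilder();
        // Segment ids are 1-based, hence classIdx + 1 (as in cs.setId(...)).
        xml.append("<Segment id=\"").append(classIdx + 1).append("\">");
        // <True/> makes the segment apply to every record (cs.addNewTrue()).
        xml.append("<True/>");
        // Nested regression mining model (cm.setFunctionName(MININGFUNCTION.REGRESSION)).
        xml.append("<MiningModel functionName=\"regression\">");
        // MiningSchema, Output, Targets and the inner Segmentation would go here; omitted in this sketch.
        xml.append("</MiningModel>");
        xml.append("</Segment>");
        return xml.toString();
    }

    public static void main(String[] args) {
        System.out.println(classSegmentXml(0));
    }
}
```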

Example 2 with Segment

Use of org.dmg.pmml.SegmentDocument.Segment in project knime-core by knime.

From class AbstractGBTModelExporter, method writeSumSegmentation:

protected void writeSumSegmentation(final Segmentation segmentation, final Collection<TreeModelRegression> trees, final Collection<Map<TreeNodeSignature, Double>> coefficientMaps) {
    assert trees.size() == coefficientMaps.size() : "The number of trees does not match the number of coefficient maps.";
    segmentation.setMultipleModelMethod(MULTIPLEMODELMETHOD.SUM);
    Iterator<TreeModelRegression> treeIterator = trees.iterator();
    Iterator<Map<TreeNodeSignature, Double>> coefficientMapIterator = coefficientMaps.iterator();
    for (int i = 1; i <= trees.size(); i++) {
        Segment segment = segmentation.addNewSegment();
        segment.setId(Integer.toString(i));
        segment.addNewTrue();
        writeTreeIntoSegment(segment, treeIterator.next(), coefficientMapIterator.next());
    }
}
Also used: Map (java.util.Map), TreeModelRegression (org.knime.base.node.mine.treeensemble2.model.TreeModelRegression), Segment (org.dmg.pmml.SegmentDocument.Segment)
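With `MULTIPLEMODELMETHOD.SUM`, a PMML segmentation's result is the sum of its segments' results. `writeSumSegmentation` walks two collections in lockstep with two iterators, so each tree is written together with its coefficient map. The segments are numbered from 1. The plain-Java sketch below models that pairing and the sum. It assumes, as the parameter types suggest, that each coefficient map assigns an output value to every leaf of its tree. The following names are all hypothetical: `SumSegmentationSketch`, its nested `Segment` class, and the use of `Function<double[], String>` as a stand-in tree that returns a leaf signature. None of this is the KNIME or PMML API.

```java
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

/** Hypothetical plain-Java model of a SUM segmentation; not the KNIME or PMML API. */
class SumSegmentationSketch {

    /** Hypothetical stand-in for a PMML Segment holding one tree and its leaf coefficients. */
    static final class Segment {
        final String id;
        final Function<double[], String> tree;   // maps a record to a leaf signature
        final Map<String, Double> coefficients;  // leaf signature -> output value

        Segment(String id, Function<double[], String> tree, Map<String, Double> coefficients) {
            this.id = id;
            this.tree = tree;
            this.coefficients = coefficients;
        }
    }

    /** Pairs trees with coefficient maps in iteration order, assigning 1-based ids. */
    static List<Segment> buildSegments(Collection<Function<double[], String>> trees,
                                       Collection<Map<String, Double>> coefficientMaps) {
        if (trees.size() != coefficientMaps.size()) {
            throw new IllegalArgumentException(
                    "The number of trees does not match the number of coefficient maps.");
        }
        List<Segment> segments = new ArrayList<>(trees.size());
        Iterator<Function<double[], String>> treeIt = trees.iterator();
        Iterator<Map<String, Double>> coeffIt = coefficientMaps.iterator();
        for (int i = 1; i <= trees.size(); i++) {
            segments.add(new Segment(Integer.toString(i), treeIt.next(), coeffIt.next()));
        }
        return segments;
    }

    /** SUM semantics: the ensemble output is the sum of every segment's output. */
    static double predict(List<Segment> segments, double[] record) {
        double sum = 0.0;
        for (Segment s : segments) {
            sum += s.coefficients.get(s.tree.apply(record));
        }
        return sum;
    }

    public static void main(String[] args) {
        // Two hypothetical one-split "trees" on feature 0.
        Function<double[], String> t1 = r -> r[0] < 0.5 ? "L" : "R";
        Function<double[], String> t2 = r -> r[0] < 2.0 ? "L" : "R";
        List<Segment> segs = buildSegments(
                List.of(t1, t2),
                List.of(Map.of("L", 1.0, "R", 2.0), Map.of("L", 0.25, "R", -0.5)));
        System.out.println(segs.get(1).id + " " + predict(segs, new double[] {1.0})); // 2 2.25
    }
}
```

The size check throws an `IllegalArgumentException` instead of using `assert`. Java assertions are skipped unless the JVM is started with `-ea`.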

Example 3 with Segment

Use of org.dmg.pmml.SegmentDocument.Segment in project knime-core by knime.

From class ClassificationGBTModelImporter, method importFromPMMLInternal:

/**
 * {@inheritDoc}
 */
@Override
protected MultiClassGradientBoostedTreesModel importFromPMMLInternal(final MiningModel miningModel) {
    Segmentation modelChain = miningModel.getSegmentation();
    CheckUtils.checkArgument(modelChain.getMultipleModelMethod() == MULTIPLEMODELMETHOD.MODEL_CHAIN,
        "The top level segmentation should have multiple model method '%s' but has '%s'",
        MULTIPLEMODELMETHOD.MODEL_CHAIN, modelChain.getMultipleModelMethod());
    List<List<TreeModelRegression>> trees = new ArrayList<>();
    List<List<Map<TreeNodeSignature, Double>>> coefficientMaps = new ArrayList<>();
    List<String> classLabels = new ArrayList<>();
    List<Segment> segments = modelChain.getSegmentList();
    // The last segment is the aggregation segment, so only the class segments are read here.
    for (int i = 0; i < segments.size() - 1; i++) {
        Pair<List<TreeModelRegression>, List<Map<TreeNodeSignature, Double>>> gbtPair = processClassSegment(segments.get(i));
        trees.add(gbtPair.getFirst());
        coefficientMaps.add(gbtPair.getSecond());
        classLabels.add(extractClassLabel(segments.get(i)));
    }
    // The model takes a single initial value, so it is read from the first class segment.
    double initialValue = extractInitialValue(segments.get(0));
    return MultiClassGradientBoostedTreesModel.create(getMetaDataMapper().getTreeMetaData(), trees, coefficientMaps, initialValue, TreeType.Ordinary, classLabels);
}
Also used: Segmentation (org.dmg.pmml.SegmentationDocument.Segmentation), ArrayList (java.util.ArrayList), TreeNodeSignature (org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature), Segment (org.dmg.pmml.SegmentDocument.Segment), List (java.util.List)
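The importer accepts only a top-level segmentation whose multiple model method is `MODEL_CHAIN`, and its loop stops at `segments.size() - 1`. That bound is consistent with Example 5, where the exporter appends the aggregation segment last with id `nrClasses + 1`. The sketch below isolates that check and split. `ModelChainSplitSketch`, the `MultipleModelMethod` enum and `classSegments` are hypothetical stand-ins for the `org.dmg.pmml` types.

```java
import java.util.List;

/** Hypothetical sketch of splitting a model chain into class segments and an aggregation segment. */
class ModelChainSplitSketch {

    /** Hypothetical subset of PMML's multipleModelMethod values. */
    enum MultipleModelMethod { SUM, MODEL_CHAIN }

    /** Returns every segment except the trailing aggregation segment. */
    static <T> List<T> classSegments(MultipleModelMethod method, List<T> segments) {
        if (method != MultipleModelMethod.MODEL_CHAIN) {
            throw new IllegalArgumentException(String.format(
                    "The top level segmentation should have multiple model method '%s' but has '%s'",
                    MultipleModelMethod.MODEL_CHAIN, method));
        }
        if (segments.isEmpty()) {
            throw new IllegalArgumentException("A model chain needs at least the aggregation segment.");
        }
        // Mirrors the importer's loop bound i < segments.size() - 1.
        return segments.subList(0, segments.size() - 1);
    }

    public static void main(String[] args) {
        // Three class segments (ids 1-3) followed by the aggregation segment (id 4).
        List<String> ids = List.of("1", "2", "3", "4");
        System.out.println(classSegments(MultipleModelMethod.MODEL_CHAIN, ids)); // [1, 2, 3]
    }
}
```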

Example 4 with Segment

Use of org.dmg.pmml.SegmentDocument.Segment in project knime-core by knime.

From class AbstractGBTModelImporter, method readSumSegmentation:

protected Pair<List<TreeModelRegression>, List<Map<TreeNodeSignature, Double>>> readSumSegmentation(final Segmentation segmentation) {
    List<Segment> segments = segmentation.getSegmentList();
    List<TreeModelRegression> trees = new ArrayList<>(segments.size());
    List<Map<TreeNodeSignature, Double>> coefficientMaps = new ArrayList<>(segments.size());
    for (Segment segment : segments) {
        Pair<TreeModelRegression, Map<TreeNodeSignature, Double>> treeCoefficientMapPair = readTreeModel(segment);
        trees.add(treeCoefficientMapPair.getFirst());
        coefficientMaps.add(treeCoefficientMapPair.getSecond());
    }
    return new Pair<>(trees, coefficientMaps);
}
Also used: ArrayList (java.util.ArrayList), Map (java.util.Map), Segment (org.dmg.pmml.SegmentDocument.Segment), TreeModelRegression (org.knime.base.node.mine.treeensemble2.model.TreeModelRegression), Pair (org.knime.core.util.Pair)
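`readSumSegmentation` reverses `writeSumSegmentation`. It reads each segment into a (tree, coefficient map) pair, then splits the pairs into two parallel lists presized to the number of segments. The generic helper below sketches just that split. `UnzipSketch` and `unzip` are hypothetical. `Map.Entry` stands in for KNIME's `org.knime.core.util.Pair`, so the example needs nothing beyond the JDK (Java 9 or later for `List.of` and `Map.entry`).

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/** Hypothetical generic helper mirroring the loop in readSumSegmentation. */
class UnzipSketch {

    /** Splits a list of (first, second) pairs into two parallel lists, keeping their order. */
    static <A, B> Map.Entry<List<A>, List<B>> unzip(List<? extends Map.Entry<A, B>> pairs) {
        // Presize both lists, as readSumSegmentation does with segments.size().
        List<A> firsts = new ArrayList<>(pairs.size());
        List<B> seconds = new ArrayList<>(pairs.size());
        for (Map.Entry<A, B> pair : pairs) {
            firsts.add(pair.getKey());
            seconds.add(pair.getValue());
        }
        // Map.entry stands in for a Pair; it rejects nulls, and both lists are non-null here.
        return Map.entry(firsts, seconds);
    }

    public static void main(String[] args) {
        Map.Entry<List<String>, List<Double>> result =
                unzip(List.of(Map.entry("tree1", 0.5), Map.entry("tree2", -0.25)));
        System.out.println(result.getKey() + " " + result.getValue()); // [tree1, tree2] [0.5, -0.25]
    }
}
```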

Example 5 with Segment

Use of org.dmg.pmml.SegmentDocument.Segment in project knime-core by knime.

From class ClassificationGBTModelExporter, method addAggregationSegment:

private void addAggregationSegment(final Segmentation modelChain) {
    Segment seg = modelChain.addNewSegment();
    // Placed after the class segments, whose ids run from 1 to the number of classes.
    seg.setId(Integer.toString(getGBTModel().getNrClasses() + 1));
    seg.addNewTrue();
    addSoftmaxRegression(seg);
}
Also used: Segment (org.dmg.pmml.SegmentDocument.Segment)
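`addAggregationSegment` closes the model chain with one more segment. Its id, `nrClasses + 1`, places it after the class segments, and it delegates its content to `addSoftmaxRegression`. That method's body is not shown here. Judging by its name, it turns the per-class outputs into class probabilities with a softmax. The sketch below shows only that arithmetic, not the PMML the method actually writes. `SoftmaxSketch`, `aggregationSegmentId` and `softmax` are hypothetical names.

```java
import java.util.Arrays;

/** Hypothetical sketch of softmax aggregation over per-class scores (not the KNIME API). */
class SoftmaxSketch {

    /** Id of the aggregation segment, placed after class segments 1..nrClasses. */
    static String aggregationSegmentId(int nrClasses) {
        return Integer.toString(nrClasses + 1);
    }

    /** Turns raw per-class scores into probabilities that sum to 1. */
    static double[] softmax(double[] scores) {
        double max = Double.NEGATIVE_INFINITY;
        for (double s : scores) {
            max = Math.max(max, s);
        }
        double[] probs = new double[scores.length];
        double sum = 0.0;
        for (int i = 0; i < scores.length; i++) {
            // Subtracting the max keeps Math.exp from overflowing; the ratios are unchanged.
            probs[i] = Math.exp(scores[i] - max);
            sum += probs[i];
        }
        for (int i = 0; i < probs.length; i++) {
            probs[i] /= sum;
        }
        return probs;
    }

    public static void main(String[] args) {
        double[] p = softmax(new double[] {2.0, 1.0, 0.1});
        System.out.println(aggregationSegmentId(p.length) + " " + Arrays.toString(p));
    }
}
```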

Aggregations

Number of the examples above that use each class:

Segment (org.dmg.pmml.SegmentDocument.Segment): 5
ArrayList (java.util.ArrayList): 2
Map (java.util.Map): 2
TreeModelRegression (org.knime.base.node.mine.treeensemble2.model.TreeModelRegression): 2
List (java.util.List): 1
MiningModel (org.dmg.pmml.MiningModelDocument.MiningModel): 1
Segmentation (org.dmg.pmml.SegmentationDocument.Segmentation): 1
TreeNodeSignature (org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature): 1
Pair (org.knime.core.util.Pair): 1