Usage of `org.dmg.pmml.SegmentationDocument.Segmentation` in the knime-core project (KNIME).
Class `ClassificationGBTModelExporter`, method `doWrite`:
/**
 * {@inheritDoc}
 * <p>
 * Writes a model chain segmentation that holds one segment per class, followed by a final segment
 * that combines the per-class predictions.
 */
@Override
protected void doWrite(final MiningModel model) {
    final Segmentation chain = model.addNewSegmentation();
    chain.setMultipleModelMethod(MULTIPLEMODELMETHOD.MODEL_CHAIN);

    final int nrClasses = getGBTModel().getNrClasses();
    for (int classIdx = 0; classIdx < nrClasses; classIdx++) {
        addClassSegment(chain, classIdx);
    }

    // the last segment of the chain aggregates the class predictions written above
    addAggregationSegment(chain);
}
Usage of `org.dmg.pmml.SegmentationDocument.Segmentation` in the knime-core project (KNIME).
Class `RegressionGBTModelImporter`, method `importFromPMMLInternal`:
/**
 * {@inheritDoc}
 * <p>
 * Expects a mining model whose top-level segmentation uses the {@code SUM} multiple model method and
 * whose first target carries the initial value of the boosting model as its rescale constant.
 *
 * @throws IllegalArgumentException if the mining model has no segmentation, if the segmentation does not
 *             use the {@code SUM} method, or if no target (and therefore no initial value) is defined
 */
@Override
public GradientBoostedTreesModel importFromPMMLInternal(final MiningModel miningModel) {
    final Segmentation segmentation = miningModel.getSegmentation();
    CheckUtils.checkArgument(segmentation != null, "The provided mining model contains no segmentation.");
    CheckUtils.checkArgument(segmentation.getMultipleModelMethod() == MULTIPLEMODELMETHOD.SUM,
        "The provided segmentation has not the required sum as multiple model method but '%s' instead.",
        segmentation.getMultipleModelMethod());
    final Pair<List<TreeModelRegression>, List<Map<TreeNodeSignature, Double>>> treesCoefficientMapsPair =
        readSumSegmentation(segmentation);
    final List<TreeModelRegression> trees = treesCoefficientMapsPair.getFirst();
    // The initial value is stored as the rescale constant of the first target. Fail with a descriptive
    // message instead of an NPE / IndexOutOfBoundsException if the PMML does not define any target.
    CheckUtils.checkArgument(
        miningModel.getTargets() != null && !miningModel.getTargets().getTargetList().isEmpty(),
        "The provided mining model defines no target and therefore no initial value.");
    final double initialValue = miningModel.getTargets().getTargetList().get(0).getRescaleConstant();
    // currently only models learned on "ordinary" columns can be read back in
    return new GradientBoostedTreesModel(getMetaDataMapper().getTreeMetaData(),
        trees.toArray(new TreeModelRegression[trees.size()]), TreeType.Ordinary, initialValue,
        treesCoefficientMapsPair.getSecond());
}
Usage of `org.dmg.pmml.SegmentationDocument.Segmentation` in the knime-core project (KNIME).
Class `ClassificationGBTModelImporter`, method `importFromPMMLInternal`:
/**
 * {@inheritDoc}
 * <p>
 * Expects a {@code MODEL_CHAIN} segmentation consisting of one segment per class followed by a single
 * aggregation segment (the layout written by {@code ClassificationGBTModelExporter#doWrite}). The initial
 * value of the model is extracted from the first class segment.
 *
 * @throws IllegalArgumentException if the mining model has no segmentation, if the segmentation does not
 *             use the {@code MODEL_CHAIN} method, or if it does not contain at least one class segment
 *             followed by the aggregation segment
 */
@Override
protected MultiClassGradientBoostedTreesModel importFromPMMLInternal(final MiningModel miningModel) {
    final Segmentation modelChain = miningModel.getSegmentation();
    CheckUtils.checkArgument(modelChain != null, "The provided mining model contains no segmentation.");
    CheckUtils.checkArgument(modelChain.getMultipleModelMethod() == MULTIPLEMODELMETHOD.MODEL_CHAIN,
        "The top level segmentation should have multiple model method '%s' but has '%s'",
        MULTIPLEMODELMETHOD.MODEL_CHAIN, modelChain.getMultipleModelMethod());
    final List<Segment> segments = modelChain.getSegmentList();
    // All segments but the last describe one class each; the last one aggregates the class predictions.
    // Without this check an empty chain fails with an IndexOutOfBoundsException and a chain holding only
    // the aggregation segment would have its initial value read from the aggregation segment.
    CheckUtils.checkArgument(segments.size() >= 2,
        "The model chain must contain at least one class segment followed by an aggregation segment "
            + "but has %s segment(s).",
        segments.size());
    final int nrClassSegments = segments.size() - 1;
    final List<List<TreeModelRegression>> trees = new ArrayList<>(nrClassSegments);
    final List<List<Map<TreeNodeSignature, Double>>> coefficientMaps = new ArrayList<>(nrClassSegments);
    final List<String> classLabels = new ArrayList<>(nrClassSegments);
    for (int i = 0; i < nrClassSegments; i++) {
        final Segment classSegment = segments.get(i);
        final Pair<List<TreeModelRegression>, List<Map<TreeNodeSignature, Double>>> gbtPair =
            processClassSegment(classSegment);
        trees.add(gbtPair.getFirst());
        coefficientMaps.add(gbtPair.getSecond());
        classLabels.add(extractClassLabel(classSegment));
    }
    final double initialValue = extractInitialValue(segments.get(0));
    return MultiClassGradientBoostedTreesModel.create(getMetaDataMapper().getTreeMetaData(), trees,
        coefficientMaps, initialValue, TreeType.Ordinary, classLabels);
}
Usage of `org.dmg.pmml.SegmentationDocument.Segmentation` in the knime-core project (KNIME).
Class `ClassificationGBTModelExporter`, method `addSegmentation`:
/**
 * Adds a new segmentation to {@code miningModel} and writes into it, as a sum segmentation, the trees and
 * coefficient maps of class {@code c} for every level of the multi-class gradient boosted trees model.
 *
 * @param miningModel the mining model that receives the new segmentation
 * @param c the index of the class whose trees are written
 */
private void addSegmentation(final MiningModel miningModel, final int c) {
    final Segmentation classSegmentation = miningModel.addNewSegmentation();
    final MultiClassGradientBoostedTreesModel model = getGBTModel();
    final int nrLevels = model.getNrLevels();

    final Collection<TreeModelRegression> classTrees = IntStream.range(0, nrLevels)
        .mapToObj(level -> model.getModel(level, c))
        .collect(Collectors.toList());
    final Collection<Map<TreeNodeSignature, Double>> classCoefficientMaps = IntStream.range(0, nrLevels)
        .mapToObj(level -> model.getCoefficientMap(level, c))
        .collect(Collectors.toList());

    writeSumSegmentation(classSegmentation, classTrees, classCoefficientMaps);
}
Usage of `org.dmg.pmml.SegmentationDocument.Segmentation` in the knime-core project (KNIME).
Class `RegressionGBTModelExporter`, method `doWrite`:
/**
 * {@inheritDoc}
 * <p>
 * Stores the initial value of the boosting model as the rescale constant of a target for the model's
 * target attribute, then writes all regression trees together with their coefficient maps as a sum
 * segmentation.
 */
@Override
protected void doWrite(final MiningModel model) {
    // initial value -> rescale constant of the (single) target
    final Target initialValueTarget = model.addNewTargets().addNewTarget();
    final GradientBoostedTreesModel boostedModel = getGBTModel();
    final String targetName = boostedModel.getMetaData().getTargetMetaData().getAttributeName();
    initialValueTarget.setField(targetName);
    initialValueTarget.setRescaleConstant(boostedModel.getInitialValue());

    // boosted trees -> sum segmentation
    final Segmentation sumSegmentation = model.addNewSegmentation();
    final int nrModels = boostedModel.getNrModels();
    final List<TreeModelRegression> regressionTrees = IntStream.range(0, nrModels)
        .mapToObj(boostedModel::getTreeModelRegression)
        .collect(Collectors.toList());
    writeSumSegmentation(sumSegmentation, regressionTrees, boostedModel.getCoeffientMaps());
}
Aggregations