Usage of org.knime.base.node.mine.treeensemble2.model.TreeModelRegression in project knime-core (by KNIME):
class AbstractGBTModelExporter, method writeSumSegmentation.
/**
 * Writes the given trees into {@code segmentation} as a {@code MULTIPLEMODELMETHOD.SUM} ensemble.
 * <p>
 * One segment is appended per tree, with ids {@code "1"}, {@code "2"}, ... in iteration order. Each segment receives
 * an always-true predicate and is then filled via {@code writeTreeIntoSegment} with the tree and the coefficient map
 * at the same position.
 *
 * @param segmentation the segmentation to configure and append the segments to
 * @param trees the regression trees to export, in segment order
 * @param coefficientMaps the coefficient map of each tree, in the same order as {@code trees}
 * @throws IllegalArgumentException if {@code trees} and {@code coefficientMaps} differ in size
 */
protected void writeSumSegmentation(final Segmentation segmentation, final Collection<TreeModelRegression> trees,
    final Collection<Map<TreeNodeSignature, Double>> coefficientMaps) {
    // An assert is skipped when assertions are disabled (the JVM default). A mismatch would then either fail with an
    // opaque NoSuchElementException or silently drop surplus coefficient maps, so check explicitly instead.
    if (trees.size() != coefficientMaps.size()) {
        throw new IllegalArgumentException("The number of trees (" + trees.size()
            + ") does not match the number of coefficient maps (" + coefficientMaps.size() + ").");
    }
    segmentation.setMultipleModelMethod(MULTIPLEMODELMETHOD.SUM);
    final Iterator<TreeModelRegression> treeIterator = trees.iterator();
    final Iterator<Map<TreeNodeSignature, Double>> coefficientMapIterator = coefficientMaps.iterator();
    // segment ids are 1-based
    int segmentId = 1;
    while (treeIterator.hasNext()) {
        final Segment segment = segmentation.addNewSegment();
        segment.setId(Integer.toString(segmentId++));
        // always-true predicate (PMML <True/>) so the segment is not filtered by a condition
        segment.addNewTrue();
        writeTreeIntoSegment(segment, treeIterator.next(), coefficientMapIterator.next());
    }
}
Usage of org.knime.base.node.mine.treeensemble2.model.TreeModelRegression in project knime-core (by KNIME):
class ClassificationGBTModelImporter, method importFromPMMLInternal.
/**
 * {@inheritDoc}
 * <p>
 * The top-level segmentation must use the {@code MODEL_CHAIN} method. Every segment except the last one is read
 * as the trees, coefficient maps and label of one class. The last segment is skipped here (NOTE(review): presumably
 * the segment that combines the per-class outputs; confirm against the exporter). The initial value is read from
 * the first segment.
 *
 * @throws IllegalArgumentException if the top-level segmentation is not a model chain or contains fewer than two
 *             segments
 */
@Override
protected MultiClassGradientBoostedTreesModel importFromPMMLInternal(final MiningModel miningModel) {
    final Segmentation modelChain = miningModel.getSegmentation();
    CheckUtils.checkArgument(modelChain.getMultipleModelMethod() == MULTIPLEMODELMETHOD.MODEL_CHAIN,
        "The top level segmentation should have multiple model method '%s' but has '%s'",
        MULTIPLEMODELMETHOD.MODEL_CHAIN, modelChain.getMultipleModelMethod());
    final List<Segment> segments = modelChain.getSegmentList();
    // Without this check an empty chain fails with an opaque IndexOutOfBoundsException at segments.get(0), and a
    // single-segment chain yields a model without any class trees.
    CheckUtils.checkArgument(segments.size() >= 2,
        "The model chain should contain at least one class segment followed by a final segment but has %s segment(s)",
        segments.size());
    // all segments but the last one are class segments
    final int nrClasses = segments.size() - 1;
    final List<List<TreeModelRegression>> trees = new ArrayList<>(nrClasses);
    final List<List<Map<TreeNodeSignature, Double>>> coefficientMaps = new ArrayList<>(nrClasses);
    final List<String> classLabels = new ArrayList<>(nrClasses);
    for (int i = 0; i < nrClasses; i++) {
        final Segment classSegment = segments.get(i);
        final Pair<List<TreeModelRegression>, List<Map<TreeNodeSignature, Double>>> gbtPair =
            processClassSegment(classSegment);
        trees.add(gbtPair.getFirst());
        coefficientMaps.add(gbtPair.getSecond());
        classLabels.add(extractClassLabel(classSegment));
    }
    // the initial value is taken from the first class segment
    final double initialValue = extractInitialValue(segments.get(0));
    return MultiClassGradientBoostedTreesModel.create(getMetaDataMapper().getTreeMetaData(), trees, coefficientMaps,
        initialValue, TreeType.Ordinary, classLabels);
}
Usage of org.knime.base.node.mine.treeensemble2.model.TreeModelRegression in project knime-core (by KNIME):
class AbstractGBTModelImporter, method readTreeModel.
/**
 * Imports the regression tree stored in {@code segment}.
 * <p>
 * The tree is returned together with the coefficient map that the content parser used for that import reports
 * afterwards.
 *
 * @param segment the segment whose tree model is imported
 * @return the imported tree paired with its coefficient map (keyed by node signature)
 */
private Pair<TreeModelRegression, Map<TreeNodeSignature, Double>> readTreeModel(final Segment segment) {
    // a new parser per call, so every imported tree gets its own coefficient map
    final GBTRegressionContentParser parser = new GBTRegressionContentParser();
    final TreeModelImporter<TreeNodeRegression, TreeModelRegression, TreeTargetNumericColumnMetaData> importer =
        new TreeModelImporter<>(m_metaDataMapper, m_conditionParser, m_signatureFactory, parser, m_treeFactory);
    final TreeModelRegression importedTree = importer.importFromPMML(segment.getTreeModel());
    // the coefficients are queried only after the import
    // (NOTE(review): presumably the parser fills its map while the tree is imported)
    return new Pair<>(importedTree, parser.getCoefficientMap());
}
Usage of org.knime.base.node.mine.treeensemble2.model.TreeModelRegression in project knime-core (by KNIME):
class AbstractGBTModelImporter, method readSumSegmentation.
/**
 * Reads every segment of {@code segmentation} as one regression tree plus its coefficient map.
 *
 * @param segmentation the segmentation whose segments are read, in list order
 * @return a pair of parallel lists: the trees, and at the same positions their coefficient maps
 */
protected Pair<List<TreeModelRegression>, List<Map<TreeNodeSignature, Double>>> readSumSegmentation(
    final Segmentation segmentation) {
    final List<Segment> segmentList = segmentation.getSegmentList();
    final int nrSegments = segmentList.size();
    final List<TreeModelRegression> treeList = new ArrayList<>(nrSegments);
    final List<Map<TreeNodeSignature, Double>> coefficientMapList = new ArrayList<>(nrSegments);
    for (final Segment currentSegment : segmentList) {
        final Pair<TreeModelRegression, Map<TreeNodeSignature, Double>> treeAndCoefficients =
            readTreeModel(currentSegment);
        // keep both lists index-aligned
        treeList.add(treeAndCoefficients.getFirst());
        coefficientMapList.add(treeAndCoefficients.getSecond());
    }
    return new Pair<>(treeList, coefficientMapList);
}
Usage of org.knime.base.node.mine.treeensemble2.model.TreeModelRegression in project knime-core (by KNIME):
class ClassificationGBTModelExporter, method addSegmentation.
/**
 * Appends a new segmentation to {@code miningModel} for one class of the boosted model.
 * <p>
 * The segmentation holds, for every level of the model, the tree and coefficient map stored at
 * {@code (level, c)}. Both are handed to {@code writeSumSegmentation}.
 *
 * @param miningModel the mining model that receives the new segmentation
 * @param c the second index passed to {@code getModel}/{@code getCoefficientMap} (presumably the class index)
 */
private void addSegmentation(final MiningModel miningModel, final int c) {
    final Segmentation classSegmentation = miningModel.addNewSegmentation();
    final MultiClassGradientBoostedTreesModel model = getGBTModel();
    final int nrLevels = model.getNrLevels();
    // one tree per level, all for the same class index c
    final Collection<TreeModelRegression> classTrees = IntStream.range(0, nrLevels)
        .mapToObj(level -> model.getModel(level, c))
        .collect(Collectors.toList());
    // the matching coefficient map for each of those trees, in the same level order
    final Collection<Map<TreeNodeSignature, Double>> classCoefficientMaps = IntStream.range(0, nrLevels)
        .mapToObj(level -> model.getCoefficientMap(level, c))
        .collect(Collectors.toList());
    writeSumSegmentation(classSegmentation, classTrees, classCoefficientMaps);
}
Aggregations