use of org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature in project knime-core by knime.
the class AbstractGBTModelImporter method readTreeModel.
/**
 * Imports the regression tree held by a single PMML segment and returns it together with the
 * coefficient map that the content parser exposes after the import.
 *
 * @param segment the PMML segment whose tree model is imported
 * @return a pair of the imported tree (first) and the parser's coefficient map (second)
 */
private Pair<TreeModelRegression, Map<TreeNodeSignature, Double>> readTreeModel(final Segment segment) {
    // A fresh parser is created per call, so the coefficient map read below belongs to this import only.
    final GBTRegressionContentParser gbtParser = new GBTRegressionContentParser();
    final TreeModelImporter<TreeNodeRegression, TreeModelRegression, TreeTargetNumericColumnMetaData> importer =
        new TreeModelImporter<TreeNodeRegression, TreeModelRegression, TreeTargetNumericColumnMetaData>(
            m_metaDataMapper, m_conditionParser, m_signatureFactory, gbtParser, m_treeFactory);
    final TreeModelRegression importedTree = importer.importFromPMML(segment.getTreeModel());
    // Read the map only after the import has run.
    // NOTE(review): presumably the parser fills it while nodes are created - confirm.
    final Map<TreeNodeSignature, Double> coefficients = gbtParser.getCoefficientMap();
    return new Pair<>(importedTree, coefficients);
}
use of org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature in project knime-core by knime.
the class AbstractGBTModelImporter method readSumSegmentation.
/**
 * Imports every segment of the given segmentation via {@code readTreeModel} and collects the
 * results into two lists that share the segment order.
 *
 * @param segmentation the PMML segmentation whose segments are imported one by one
 * @return a pair of the imported trees (first) and their coefficient maps (second); entry
 *         {@code k} of both lists stems from the {@code k}-th segment
 */
protected Pair<List<TreeModelRegression>, List<Map<TreeNodeSignature, Double>>> readSumSegmentation(final Segmentation segmentation) {
    final List<Segment> segmentList = segmentation.getSegmentList();
    final int segmentCount = segmentList.size();
    // Both result lists are presized to the number of segments.
    final List<TreeModelRegression> importedTrees = new ArrayList<>(segmentCount);
    final List<Map<TreeNodeSignature, Double>> importedCoefficients = new ArrayList<>(segmentCount);
    for (final Segment current : segmentList) {
        final Pair<TreeModelRegression, Map<TreeNodeSignature, Double>> imported = readTreeModel(current);
        importedTrees.add(imported.getFirst());
        importedCoefficients.add(imported.getSecond());
    }
    return new Pair<>(importedTrees, importedCoefficients);
}
use of org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature in project knime-core by knime.
the class ClassificationGBTModelExporter method addSegmentation.
/**
 * Adds a new segmentation to the mining model and writes into it the trees and coefficient
 * maps that the gradient boosted trees model returns for index {@code c}.
 *
 * @param miningModel the PMML mining model that receives the new segmentation
 * @param c second argument passed to {@code getModel} and {@code getCoefficientMap}
 *          (NOTE(review): presumably the class index - confirm against the model API)
 */
private void addSegmentation(final MiningModel miningModel, final int c) {
    final Segmentation segmentation = miningModel.addNewSegmentation();
    final MultiClassGradientBoostedTreesModel model = getGBTModel();
    // One tree per level for the given index c.
    final Collection<TreeModelRegression> levelTrees = IntStream
        .range(0, model.getNrLevels())
        .mapToObj(level -> model.getModel(level, c))
        .collect(Collectors.toList());
    // Matching coefficient maps, gathered in the same level order.
    final Collection<Map<TreeNodeSignature, Double>> levelCoefficients = IntStream
        .range(0, model.getNrLevels())
        .mapToObj(level -> model.getCoefficientMap(level, c))
        .collect(Collectors.toList());
    writeSumSegmentation(segmentation, levelTrees, levelCoefficients);
}
use of org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature in project knime-core by knime.
the class TreeModelImporter method createNodeFromPMML.
/**
 * Recursively converts a PMML node and its descendants into a tree node of type {@code N}.
 * The children are converted first, then the node's condition is parsed, and finally the
 * node itself is created by the content parser and given that condition.
 *
 * @param pmmlNode the PMML node to convert
 * @param signature the signature assigned to the node created for {@code pmmlNode}
 * @return the created node with its condition set
 */
private N createNodeFromPMML(final Node pmmlNode, final TreeNodeSignature signature) {
    final List<N> childNodes = new ArrayList<>();
    // Position of the child among its siblings; it is used to derive the child's signature.
    // NOTE(review): the counter is a byte and wraps to a negative value past 127 -
    // confirm that imported PMML nodes never have more than 128 children.
    byte childIndex = 0;
    for (final Node pmmlChild : pmmlNode.getNodeList()) {
        // The signature is derived with the current index; the index is advanced before recursing.
        childNodes.add(createNodeFromPMML(pmmlChild,
            m_signatureFactory.getChildSignatureFor(signature, childIndex++)));
    }
    // Parse the condition before creating the node so that the original call order is kept.
    final TreeNodeCondition parsedCondition = m_conditionParser.parseCondition(pmmlNode);
    final N created =
        m_contentParser.createNode(pmmlNode, m_metaDataMapper.getTargetColumnHelper(), signature, childNodes);
    created.setTreeNodeCondition(parsedCondition);
    return created;
}
use of org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature in project knime-core by knime.
the class AllColumnSampleStrategyTest method testGetColumnSampleForTreeNodeTest.
/**
 * Checks {@link AllColumnSampleStrategy#getColumnSampleForTreeNode(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature)}
 * for the root signature and for its first child signature: in both cases the returned sample
 * must report {@code TREE_DATA_SIZE} columns and the indices {@code 0 .. TREE_DATA_SIZE - 1}.
 * This also exercises the class {@link AllColumnSample}.
 *
 * @throws Exception declared for the fixture helpers; not thrown by the visible statements themselves
 */
@Test
public void testGetColumnSampleForTreeNodeTest() throws Exception {
    // Expected indices: every column of the tree data, in ascending order.
    final int[] expectedIndices = new int[TREE_DATA_SIZE];
    for (int col = 0; col < TREE_DATA_SIZE; col++) {
        expectedIndices[col] = col;
    }

    final AllColumnSampleStrategy strategy = new AllColumnSampleStrategy(createTreeData());
    final TreeNodeSignatureFactory signatureFactory = createSignatureFactory();

    // Sample requested for the root node.
    final TreeNodeSignature rootSignature = signatureFactory.getRootSignature();
    final ColumnSample rootSample = strategy.getColumnSampleForTreeNode(rootSignature);
    assertEquals("Wrong number of columns in sample.", TREE_DATA_SIZE, rootSample.getNumCols());
    assertArrayEquals(expectedIndices, rootSample.getColumnIndices());

    // Sample requested for the first child of the root must look exactly the same.
    final TreeNodeSignature firstChildSignature = signatureFactory.getChildSignatureFor(rootSignature, (byte) 0);
    final ColumnSample childSample = strategy.getColumnSampleForTreeNode(firstChildSignature);
    assertEquals("Wrong number of columns in sample.", TREE_DATA_SIZE, childSample.getNumCols());
    assertArrayEquals(expectedIndices, childSample.getColumnIndices());
}
Aggregations