Search in sources:

Example 16 with TreeNodeSignature

use of org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature in project knime-core by knime.

Class AbstractGBTModelImporter, method readTreeModel.

/**
 * Imports the PMML tree model contained in a single segment as a regression tree.
 *
 * <p>A fresh {@link GBTRegressionContentParser} is used for each call. After the import, its
 * coefficient map is returned next to the tree. The map is presumably filled while the parser
 * creates the nodes; confirm this against {@code GBTRegressionContentParser}.
 *
 * @param segment the PMML segment whose tree model is imported
 * @return a pair of (imported regression tree, coefficient map reported by the content parser)
 */
private Pair<TreeModelRegression, Map<TreeNodeSignature, Double>> readTreeModel(final Segment segment) {
    final GBTRegressionContentParser parser = new GBTRegressionContentParser();
    final TreeModelImporter<TreeNodeRegression, TreeModelRegression, TreeTargetNumericColumnMetaData> importer =
        new TreeModelImporter<>(m_metaDataMapper, m_conditionParser, m_signatureFactory, parser, m_treeFactory);
    final TreeModelRegression importedTree = importer.importFromPMML(segment.getTreeModel());
    // Read the coefficients only after the import has run through the parser.
    final Map<TreeNodeSignature, Double> coefficients = parser.getCoefficientMap();
    return new Pair<>(importedTree, coefficients);
}
Also used : TreeModel(org.dmg.pmml.TreeModelDocument.TreeModel) TreeTargetNumericColumnMetaData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnMetaData) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeNodeRegression(org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression) Pair(org.knime.core.util.Pair)

Example 17 with TreeNodeSignature

use of org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature in project knime-core by knime.

Class AbstractGBTModelImporter, method readSumSegmentation.

/**
 * Reads every segment of the given segmentation as a regression tree together with its
 * coefficient map.
 *
 * <p>The two returned lists are parallel: the tree and the coefficient map at index {@code k}
 * both come from the {@code k}-th segment of {@code segmentation.getSegmentList()}.
 *
 * @param segmentation the PMML segmentation whose segments are imported
 * @return a pair of (trees in segment order, coefficient maps in segment order)
 */
protected Pair<List<TreeModelRegression>, List<Map<TreeNodeSignature, Double>>> readSumSegmentation(final Segmentation segmentation) {
    final List<Segment> segmentList = segmentation.getSegmentList();
    final int segmentCount = segmentList.size();
    final List<TreeModelRegression> importedTrees = new ArrayList<>(segmentCount);
    final List<Map<TreeNodeSignature, Double>> importedCoefficients = new ArrayList<>(segmentCount);
    for (final Segment current : segmentList) {
        final Pair<TreeModelRegression, Map<TreeNodeSignature, Double>> treeAndCoefficients = readTreeModel(current);
        final TreeModelRegression tree = treeAndCoefficients.getFirst();
        final Map<TreeNodeSignature, Double> coefficients = treeAndCoefficients.getSecond();
        importedTrees.add(tree);
        importedCoefficients.add(coefficients);
    }
    return new Pair<>(importedTrees, importedCoefficients);
}
Also used : ArrayList(java.util.ArrayList) Map(java.util.Map) Segment(org.dmg.pmml.SegmentDocument.Segment) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression) Pair(org.knime.core.util.Pair)

Example 18 with TreeNodeSignature

use of org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature in project knime-core by knime.

Class ClassificationGBTModelExporter, method addSegmentation.

/**
 * Adds a new segmentation to {@code miningModel} and writes, as a sum segmentation, the trees
 * and coefficient maps of every level for class index {@code c}.
 *
 * @param miningModel the PMML mining model that receives the new segmentation
 * @param c the class index whose per-level trees and coefficient maps are exported
 */
private void addSegmentation(final MiningModel miningModel, final int c) {
    // Create the segmentation first; this matches the original order, including when a later call throws.
    final Segmentation segmentation = miningModel.addNewSegmentation();
    final MultiClassGradientBoostedTreesModel model = getGBTModel();
    final int levelCount = model.getNrLevels();
    final Collection<TreeModelRegression> levelTrees = IntStream.range(0, levelCount)
        .mapToObj(level -> model.getModel(level, c))
        .collect(Collectors.toList());
    final Collection<Map<TreeNodeSignature, Double>> levelCoefficients = IntStream.range(0, levelCount)
        .mapToObj(level -> model.getCoefficientMap(level, c))
        .collect(Collectors.toList());
    writeSumSegmentation(segmentation, levelTrees, levelCoefficients);
}
Also used : IntStream(java.util.stream.IntStream) MININGFUNCTION(org.dmg.pmml.MININGFUNCTION) Enum(org.dmg.pmml.MININGFUNCTION.Enum) Targets(org.dmg.pmml.TargetsDocument.Targets) DATATYPE(org.dmg.pmml.DATATYPE) PMMLMiningSchemaTranslator(org.knime.core.node.port.pmml.PMMLMiningSchemaTranslator) RegressionTable(org.dmg.pmml.RegressionTableDocument.RegressionTable) Output(org.dmg.pmml.OutputDocument.Output) RESULTFEATURE(org.dmg.pmml.RESULTFEATURE) MiningSchema(org.dmg.pmml.MiningSchemaDocument.MiningSchema) Map(java.util.Map) Target(org.dmg.pmml.TargetDocument.Target) FIELDUSAGETYPE(org.dmg.pmml.FIELDUSAGETYPE) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) Collection(java.util.Collection) Segmentation(org.dmg.pmml.SegmentationDocument.Segmentation) RegressionModel(org.dmg.pmml.RegressionModelDocument.RegressionModel) Collectors(java.util.stream.Collectors) MiningField(org.dmg.pmml.MiningFieldDocument.MiningField) OPTYPE(org.dmg.pmml.OPTYPE) MULTIPLEMODELMETHOD(org.dmg.pmml.MULTIPLEMODELMETHOD) NumericPredictor(org.dmg.pmml.NumericPredictorDocument.NumericPredictor) MultiClassGradientBoostedTreesModel(org.knime.base.node.mine.treeensemble2.model.MultiClassGradientBoostedTreesModel) REGRESSIONNORMALIZATIONMETHOD(org.dmg.pmml.REGRESSIONNORMALIZATIONMETHOD) DerivedFieldMapper(org.knime.core.node.port.pmml.preproc.DerivedFieldMapper) MiningModel(org.dmg.pmml.MiningModelDocument.MiningModel) Segment(org.dmg.pmml.SegmentDocument.Segment) OutputField(org.dmg.pmml.OutputFieldDocument.OutputField) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression) Segmentation(org.dmg.pmml.SegmentationDocument.Segmentation) MultiClassGradientBoostedTreesModel(org.knime.base.node.mine.treeensemble2.model.MultiClassGradientBoostedTreesModel) Map(java.util.Map) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression)

Example 19 with TreeNodeSignature

use of org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature in project knime-core by knime.

Class TreeModelImporter, method createNodeFromPMML.

/**
 * Recursively converts a PMML node and its whole subtree into a tree node of type {@code N}.
 *
 * <p>The children are converted first, in PMML document order. The {@code k}-th child gets the
 * signature {@code m_signatureFactory.getChildSignatureFor(signature, (byte) k)}. After that,
 * the node's condition is parsed and the node itself is created by the content parser.
 *
 * @param pmmlNode the PMML node to convert
 * @param signature the signature assigned to {@code pmmlNode}
 * @return the converted node with its children and tree-node condition set
 */
private N createNodeFromPMML(final Node pmmlNode, final TreeNodeSignature signature) {
    final List<N> childNodes = new ArrayList<>();
    int position = 0;
    for (final Node pmmlChild : pmmlNode.getNodeList()) {
        // The signature factory takes a byte index. The narrowing cast wraps exactly like
        // incrementing a byte counter would (past 127 it becomes negative).
        final byte childIndex = (byte) position;
        position++;
        final TreeNodeSignature childSignature = m_signatureFactory.getChildSignatureFor(signature, childIndex);
        childNodes.add(createNodeFromPMML(pmmlChild, childSignature));
    }
    final TreeNodeCondition nodeCondition = m_conditionParser.parseCondition(pmmlNode);
    final N created = m_contentParser.createNode(pmmlNode, m_metaDataMapper.getTargetColumnHelper(), signature, childNodes);
    created.setTreeNodeCondition(nodeCondition);
    return created;
}
Also used : AbstractTreeNode(org.knime.base.node.mine.treeensemble2.model.AbstractTreeNode) Node(org.dmg.pmml.NodeDocument.Node) ArrayList(java.util.ArrayList) TreeNodeCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature)

Example 20 with TreeNodeSignature

use of org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature in project knime-core by knime.

Class AllColumnSampleStrategyTest, method testGetColumnSampleForTreeNodeTest.

/**
 * Tests {@link AllColumnSampleStrategy#getColumnSampleForTreeNode(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature)}
 * and, through it, the class {@link AllColumnSample}.
 *
 * <p>For both the root signature and its first child signature, the returned sample must
 * contain all {@code TREE_DATA_SIZE} columns, with indices {@code 0 .. TREE_DATA_SIZE - 1}
 * in ascending order.
 *
 * @throws Exception if creating the test data or the strategy fails
 */
@Test
public void testGetColumnSampleForTreeNodeTest() throws Exception {
    // Expected column indices: every column, in order.
    final int[] expectedIndices = new int[TREE_DATA_SIZE];
    for (int col = 0; col < expectedIndices.length; col++) {
        expectedIndices[col] = col;
    }

    final AllColumnSampleStrategy strategy = new AllColumnSampleStrategy(createTreeData());
    final TreeNodeSignatureFactory signatureFactory = createSignatureFactory();

    final TreeNodeSignature root = signatureFactory.getRootSignature();
    final ColumnSample rootSample = strategy.getColumnSampleForTreeNode(root);
    assertEquals("Wrong number of columns in sample.", TREE_DATA_SIZE, rootSample.getNumCols());
    assertArrayEquals(expectedIndices, rootSample.getColumnIndices());

    final TreeNodeSignature firstChild = signatureFactory.getChildSignatureFor(root, (byte) 0);
    final ColumnSample childSample = strategy.getColumnSampleForTreeNode(firstChild);
    assertEquals("Wrong number of columns in sample.", TREE_DATA_SIZE, childSample.getNumCols());
    assertArrayEquals(expectedIndices, childSample.getColumnIndices());
}
Also used : TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeNodeSignatureFactory(org.knime.base.node.mine.treeensemble2.learner.TreeNodeSignatureFactory) Test(org.junit.Test)

Aggregations

TreeNodeSignature (org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature)20 TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData)10 TreeModelRegression (org.knime.base.node.mine.treeensemble2.model.TreeModelRegression)8 ArrayList (java.util.ArrayList)6 TreeNodeRegression (org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression)6 TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration)6 Map (java.util.Map)5 RandomData (org.apache.commons.math.random.RandomData)5 TreeAttributeColumnData (org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData)5 TreeNodeSignatureFactory (org.knime.base.node.mine.treeensemble2.learner.TreeNodeSignatureFactory)5 ColumnSample (org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample)5 BitSet (java.util.BitSet)4 HashMap (java.util.HashMap)4 Segment (org.dmg.pmml.SegmentDocument.Segment)4 TreeTargetNominalColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData)4 RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships)4 List (java.util.List)3 Segmentation (org.dmg.pmml.SegmentationDocument.Segmentation)3 Test (org.junit.Test)3 TreeTargetNumericColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData)3