Search in sources:

Example 1 with DecisionTree

use of org.knime.base.node.mine.decisiontree2.model.DecisionTree in project knime-core by knime.

Method execute of class FromDecisionTreeNodeModel.

/**
 * Converts the decision tree contained in the PMML input into a PMML RuleSet model and also
 * returns the rules in tabular form.
 *
 * <p>A new PMML document is created containing the header, the data dictionary and mining schema
 * of the incoming model's spec, and a classification {@code RuleSetModel} whose rule selection
 * criterion is {@code firstHit}. The rules are filled in by {@code addRules}, starting at the
 * tree's root node.
 *
 * @param inData the input port objects; index 0 must be the decision tree as a
 *        {@link PMMLPortObject}
 * @param exec the execution context passed on to {@link RuleSetToTable}
 * @return a two-element array: the rule set PMML port object and the table produced by
 *         {@link RuleSetToTable} from it
 * @throws CanceledExecutionException Execution cancelled.
 * @throws InvalidSettingsException propagated from the called translators / {@link RuleSetToTable}
 *         (NOTE(review): the previous doc said "no or more than one RuleSet model in the PMML
 *         input" — confirm which callee actually raises it)
 */
@Override
protected PortObject[] execute(final PortObject[] inData, final ExecutionContext exec) throws CanceledExecutionException, InvalidSettingsException {
    final PMMLPortObject decTreeModel = (PMMLPortObject) inData[0];
    final PMMLPortObjectSpec treeSpec = decTreeModel.getSpec();

    // Read the decision tree out of the incoming PMML document.
    final PMMLDecisionTreeTranslator treeTranslator = new PMMLDecisionTreeTranslator();
    decTreeModel.initializeModelTranslator(treeTranslator);
    final DecisionTree decisionTree = treeTranslator.getDecisionTree();

    // Build a fresh PMML document: header, data dictionary and a single RuleSetModel.
    final PMMLDocument document = PMMLDocument.Factory.newInstance();
    final PMML pmml = document.addNewPMML();
    PMMLPortObjectSpec.writeHeader(pmml);
    pmml.setVersion(PMMLPortObject.PMML_V4_2);
    new PMMLDataDictionaryTranslator().exportTo(document, treeSpec);
    final RuleSetModel newRuleSetModel = pmml.addNewRuleSetModel();
    PMMLMiningSchemaTranslator.writeMiningSchema(treeSpec, newRuleSetModel);
    newRuleSetModel.setFunctionName(MININGFUNCTION.CLASSIFICATION);
    newRuleSetModel.setAlgorithmName("RuleSet");
    final RuleSet ruleSet = newRuleSetModel.addNewRuleSet();
    // firstHit: the first rule whose predicate matches determines the prediction.
    ruleSet.addNewRuleSelectionMethod().setCriterion(Criterion.FIRST_HIT);
    // Recursion starts at the root with an empty list (presumably the ancestor path — see addRules).
    addRules(ruleSet, new ArrayList<DecisionTreeNode>(), decisionTree.getRootNode());

    // The rule set shares the spec of the incoming tree model; no intermediate port object is
    // needed just to obtain it.
    final PMMLPortObject pmmlPortObject = new PMMLPortObject(treeSpec, document);
    return new PortObject[] { pmmlPortObject, new RuleSetToTable(m_rulesToTable).execute(exec, pmmlPortObject) };
}
Also used : RuleSetModel(org.dmg.pmml.RuleSetModelDocument.RuleSetModel) RuleSet(org.dmg.pmml.RuleSetDocument.RuleSet) DecisionTree(org.knime.base.node.mine.decisiontree2.model.DecisionTree) PMMLDecisionTreeTranslator(org.knime.base.node.mine.decisiontree2.PMMLDecisionTreeTranslator) PMMLPortObject(org.knime.core.node.port.pmml.PMMLPortObject) PMML(org.dmg.pmml.PMMLDocument.PMML) DecisionTreeNodeSplitPMML(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitPMML) PMMLDocument(org.dmg.pmml.PMMLDocument) PMMLDataDictionaryTranslator(org.knime.core.node.port.pmml.PMMLDataDictionaryTranslator) PortObject(org.knime.core.node.port.PortObject) PMMLPortObject(org.knime.core.node.port.pmml.PMMLPortObject) RuleSetToTable(org.knime.base.node.rules.engine.totable.RuleSetToTable) DecisionTreeNode(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode)

Example 2 with DecisionTree

use of org.knime.base.node.mine.decisiontree2.model.DecisionTree in project knime-core by knime.

Method execute of class DecTreePredictorNodeModel.

/**
 * Classifies every row of the data input with the decision tree taken from the PMML model input.
 *
 * <p>The output table contains the input cells, optionally followed by class-distribution
 * columns, and the predicted class as its last column. When {@code m_showDistribution} is set,
 * each column between the input columns and the last column is looked up by its name in the
 * row's class distribution (0.0 when absent). Up to {@code m_maxNumCoveredPattern} rows are
 * registered with the tree as covered patterns for HiLite support; later rows only contribute
 * their color. The evaluated tree is stored in {@code m_decTree}.
 *
 * @param inPorts the model port ({@code INMODELPORT}, a {@link PMMLPortObject}) and the data port
 *        ({@code INDATAPORT}, a {@link BufferedDataTable})
 * @param exec the execution context used for messages, progress, cancellation and table creation
 * @return a single-element array holding the classified table
 * @throws CanceledExecutionException if execution is cancelled
 * @throws Exception if the PMML input has no tree model, or a row cannot be evaluated or yields
 *         no prediction
 */
@Override
public PortObject[] execute(final PortObject[] inPorts, final ExecutionContext exec) throws CanceledExecutionException, Exception {
    exec.setMessage("Decision Tree Predictor: Loading predictor...");
    PMMLPortObject port = (PMMLPortObject) inPorts[INMODELPORT];
    List<Node> models = port.getPMMLValue().getModels(PMMLModelType.TreeModel);
    if (models.isEmpty()) {
        String msg = "Decision Tree evaluation failed: " + "No tree model found.";
        LOGGER.error(msg);
        throw new RuntimeException(msg);
    }
    PMMLDecisionTreeTranslator trans = new PMMLDecisionTreeTranslator();
    port.initializeModelTranslator(trans);
    DecisionTree decTree = trans.getDecisionTree();
    decTree.resetColorInformation();
    BufferedDataTable inData = (BufferedDataTable) inPorts[INDATAPORT];
    // The input spec and the node settings do not change during the run: read them once
    // instead of once (or several times) per row.
    final DataTableSpec inSpec = inData.getDataTableSpec();
    final int maxCoveredPattern = m_maxNumCoveredPattern.getIntValue();
    final boolean showDistribution = m_showDistribution.getBooleanValue();
    // use the first column that carries a color handler as color source
    String colorColumn = null;
    for (DataColumnSpec s : inSpec) {
        if (s.getColorHandler() != null) {
            colorColumn = s.getName();
            break;
        }
    }
    decTree.setColorColumn(colorColumn);
    exec.setMessage("Decision Tree Predictor: start execution.");
    PortObjectSpec[] inSpecs = new PortObjectSpec[] { inPorts[0].getSpec(), inPorts[1].getSpec() };
    DataTableSpec outSpec = createOutTableSpec(inSpecs);
    BufferedDataContainer outData = exec.createDataContainer(outSpec);
    long coveredPattern = 0;
    long nrPattern = 0;
    long rowCount = 0;
    long numberRows = inData.size();
    exec.setMessage("Classifying...");
    for (DataRow thisRow : inData) {
        DataCell cl = null;
        LinkedHashMap<String, Double> classDistrib = null;
        try {
            Pair<DataCell, LinkedHashMap<DataCell, Double>> pair = decTree.getWinnerAndClasscounts(thisRow, inSpec);
            cl = pair.getFirst();
            classDistrib = getDistribution(pair.getSecond());
            if (coveredPattern < maxCoveredPattern) {
                // remember this one for HiLite support
                decTree.addCoveredPattern(thisRow, inSpec);
                coveredPattern++;
            } else {
                // too many patterns for HiLite - at least remember color
                decTree.addCoveredColor(thisRow, inSpec);
            }
            nrPattern++;
        } catch (Exception e) {
            // pass the exception along so the log keeps the stack trace of the failure
            LOGGER.error("Decision Tree evaluation failed: " + e.getMessage(), e);
            throw e;
        }
        if (cl == null) {
            LOGGER.error("Decision Tree evaluation failed: result empty");
            throw new Exception("Decision Tree evaluation failed.");
        }
        DataCell[] newCells = new DataCell[outSpec.getNumColumns()];
        int numInCells = thisRow.getNumCells();
        for (int i = 0; i < numInCells; i++) {
            newCells[i] = thisRow.getCell(i);
        }
        if (showDistribution) {
            // columns between the input cells and the last (prediction) column hold the distribution
            for (int i = numInCells; i < newCells.length - 1; i++) {
                String predClass = outSpec.getColumnSpec(i).getName();
                Double classValue = classDistrib == null ? null : classDistrib.get(predClass);
                newCells[i] = new DoubleCell(classValue != null ? classValue : 0.0);
            }
        }
        newCells[newCells.length - 1] = cl;
        outData.addRowToTable(new DefaultRow(thisRow.getKey(), newCells));
        rowCount++;
        if (rowCount % 100 == 0) {
            exec.setProgress(rowCount / (double) numberRows, "Classifying... Row " + rowCount + " of " + numberRows);
        }
        exec.checkCanceled();
    }
    if (coveredPattern < nrPattern) {
        // let the user know that we did not store all available pattern
        // for HiLiting.
        this.setWarningMessage("Tree only stored first " + maxCoveredPattern + " (of " + nrPattern + ") rows for HiLiting!");
    }
    outData.close();
    m_decTree = decTree;
    exec.setMessage("Decision Tree Predictor: end execution.");
    return new BufferedDataTable[] { outData.getTable() };
}
Also used : DataTableSpec(org.knime.core.data.DataTableSpec) PMMLDecisionTreeTranslator(org.knime.base.node.mine.decisiontree2.PMMLDecisionTreeTranslator) DoubleCell(org.knime.core.data.def.DoubleCell) Node(org.w3c.dom.Node) DataRow(org.knime.core.data.DataRow) LinkedHashMap(java.util.LinkedHashMap) DataColumnSpec(org.knime.core.data.DataColumnSpec) BufferedDataTable(org.knime.core.node.BufferedDataTable) PMMLPortObjectSpec(org.knime.core.node.port.pmml.PMMLPortObjectSpec) PortObjectSpec(org.knime.core.node.port.PortObjectSpec) DecisionTree(org.knime.base.node.mine.decisiontree2.model.DecisionTree) BufferedDataContainer(org.knime.core.node.BufferedDataContainer) InvalidSettingsException(org.knime.core.node.InvalidSettingsException) CanceledExecutionException(org.knime.core.node.CanceledExecutionException) IOException(java.io.IOException) PMMLPortObject(org.knime.core.node.port.pmml.PMMLPortObject) DataCell(org.knime.core.data.DataCell) DefaultRow(org.knime.core.data.def.DefaultRow)

Example 3 with DecisionTree

use of org.knime.base.node.mine.decisiontree2.model.DecisionTree in project knime-core by knime.

Method modelChanged of class DecTreePredictorNodeView.

/**
 * Refreshes the tree display whenever the underlying node model changes.
 *
 * <p>If the model holds a decision tree, its root node becomes the root of the displayed tree,
 * nodes are painted with a {@link DecisionTreeNodeRenderer}, and the HiLite menu is enabled only
 * if the data input port supplies a HiLite handler. Without a tree the display is cleared.
 */
@Override
protected void modelChanged() {
    final DecTreePredictorNodeModel predictorModel = getNodeModel();
    final DecisionTree tree = predictorModel.getDecisionTree();
    if (tree == null) {
        // no tree available: show an empty display
        m_jTree.setModel(null);
        return;
    }
    m_jTree.setModel(new DefaultTreeModel(tree.getRootNode()));
    m_jTree.setCellRenderer(new DecisionTreeNodeRenderer());
    // a row height <= 0 makes the JTree ask the renderer for each row's preferred height
    m_jTree.setRowHeight(0);
    // HiLite support depends on the handler of the data input port
    m_hiLiteHdl = predictorModel.getInHiLiteHandler(DecTreePredictorNodeModel.INDATAPORT);
    final boolean hiLiteAvailable = m_hiLiteHdl != null;
    m_hiLiteMenu.setEnabled(hiLiteAvailable);
}
Also used : DecisionTree(org.knime.base.node.mine.decisiontree2.model.DecisionTree) DecisionTreeNodeRenderer(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeRenderer) DefaultTreeModel(javax.swing.tree.DefaultTreeModel)

Example 4 with DecisionTree

use of org.knime.base.node.mine.decisiontree2.model.DecisionTree in project knime-core by knime.

Method createDecisionTree of class RegressionTreeModel.

/**
 * Creates a {@link DecisionTree} from this regression tree model and, optionally, registers
 * sample rows with it for HiLite support.
 *
 * <p>Each sample row is projected onto the learn attributes before being added as a covered
 * pattern. If adding a row fails, the error is logged and no further rows are added; the rows
 * added so far are kept.
 *
 * @param sampleForHiliting rows to register as covered patterns, or {@code null} to skip that step
 * @return the created decision tree
 */
public DecisionTree createDecisionTree(final DataTable sampleForHiliting) {
    final DecisionTree tree = getTreeModelRegression().createDecisionTree(getMetaData());
    if (sampleForHiliting == null) {
        return tree;
    }
    final DataTableSpec learnSpec = getLearnAttributeSpec(sampleForHiliting.getDataTableSpec());
    for (final DataRow row : sampleForHiliting) {
        try {
            tree.addCoveredPattern(createLearnAttributeRow(row, learnSpec), learnSpec);
        } catch (Exception e) {
            // best effort: HiLite info is optional, so log and stop instead of failing
            NodeLogger.getLogger(getClass()).error("Error updating hilite info in tree view", e);
            break;
        }
    }
    return tree;
}
Also used : DecisionTree(org.knime.base.node.mine.decisiontree2.model.DecisionTree) DataTableSpec(org.knime.core.data.DataTableSpec) DataRow(org.knime.core.data.DataRow) CanceledExecutionException(org.knime.core.node.CanceledExecutionException) IOException(java.io.IOException)

Example 5 with DecisionTree

use of org.knime.base.node.mine.decisiontree2.model.DecisionTree in project knime-core by knime.

Method exportTo of class PMMLDecisionTreeTranslator.

/**
 * {@inheritDoc}
 *
 * <p>Appends a PMML {@code TreeModel} named "DecisionTree" to {@code pmmlDoc}. It writes the
 * mining schema from {@code spec}, the mining function (classification or regression), the split
 * characteristic (multi-split if {@code treeIsMultisplit} holds for the root, otherwise binary),
 * the missing-value strategy ({@code NONE} when the tree has none), the no-true-child strategy
 * (only for the two recognised values, otherwise left unset) and finally the node hierarchy.
 *
 * @return {@code TreeModel.type}, the schema type of the written model
 */
@Override
public SchemaType exportTo(final PMMLDocument pmmlDoc, final PMMLPortObjectSpec spec) {
    // One mapper serves both the field and the node export; the document gains no derived
    // fields in between, so building a second one was redundant.
    final DerivedFieldMapper mapper = new DerivedFieldMapper(pmmlDoc);
    m_nameMapper = mapper;
    PMML pmml = pmmlDoc.getPMML();
    TreeModelDocument.TreeModel treeModel = pmml.addNewTreeModel();
    PMMLMiningSchemaTranslator.writeMiningSchema(spec, treeModel);
    treeModel.setModelName("DecisionTree");
    if (m_isClassification) {
        treeModel.setFunctionName(MININGFUNCTION.CLASSIFICATION);
    } else {
        treeModel.setFunctionName(MININGFUNCTION.REGRESSION);
    }
    // ----------------------------------------------
    // set up splitCharacteristic
    if (treeIsMultisplit(m_tree.getRootNode())) {
        treeModel.setSplitCharacteristic(SplitCharacteristic.MULTI_SPLIT);
    } else {
        treeModel.setSplitCharacteristic(SplitCharacteristic.BINARY_SPLIT);
    }
    // ----------------------------------------------
    // set up missing value strategy (NONE if the tree defines none)
    final PMMLMissingValueStrategy treeMvStrategy = m_tree.getMVStrategy();
    final PMMLMissingValueStrategy mvStrategy = treeMvStrategy != null ? treeMvStrategy : PMMLMissingValueStrategy.NONE;
    treeModel.setMissingValueStrategy(MV_STRATEGY_TO_PMML_MAP.get(mvStrategy));
    // -------------------------------------------------
    // set up no true child strategy; any other value leaves the attribute unset
    PMMLNoTrueChildStrategy ntcStrategy = m_tree.getNTCStrategy();
    if (PMMLNoTrueChildStrategy.RETURN_LAST_PREDICTION.equals(ntcStrategy)) {
        treeModel.setNoTrueChildStrategy(NOTRUECHILDSTRATEGY.RETURN_LAST_PREDICTION);
    } else if (PMMLNoTrueChildStrategy.RETURN_NULL_PREDICTION.equals(ntcStrategy)) {
        treeModel.setNoTrueChildStrategy(NOTRUECHILDSTRATEGY.RETURN_NULL_PREDICTION);
    }
    // --------------------------------------------------
    // set up tree node
    NodeDocument.Node rootNode = treeModel.addNewNode();
    addTreeNode(rootNode, m_tree.getRootNode(), mapper);
    return TreeModel.type;
}
Also used : DerivedFieldMapper(org.knime.core.node.port.pmml.preproc.DerivedFieldMapper) Node(org.dmg.pmml.NodeDocument.Node) PMML(org.dmg.pmml.PMMLDocument.PMML) DecisionTreeNodeSplitPMML(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitPMML) TreeModel(org.dmg.pmml.TreeModelDocument.TreeModel) NodeDocument(org.dmg.pmml.NodeDocument) TreeModelDocument(org.dmg.pmml.TreeModelDocument)

Aggregations

DecisionTree (org.knime.base.node.mine.decisiontree2.model.DecisionTree)24 DecisionTreeNode (org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode)11 IOException (java.io.IOException)9 CanceledExecutionException (org.knime.core.node.CanceledExecutionException)9 DataTableSpec (org.knime.core.data.DataTableSpec)8 InvalidSettingsException (org.knime.core.node.InvalidSettingsException)7 DataRow (org.knime.core.data.DataRow)6 BufferedInputStream (java.io.BufferedInputStream)5 File (java.io.File)5 FileInputStream (java.io.FileInputStream)5 GZIPInputStream (java.util.zip.GZIPInputStream)5 PMMLDecisionTreeTranslator (org.knime.base.node.mine.decisiontree2.PMMLDecisionTreeTranslator)5 PMMLPortObject (org.knime.core.node.port.pmml.PMMLPortObject)5 DefaultTreeModel (javax.swing.tree.DefaultTreeModel)4 DecisionTreeNodeRenderer (org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeRenderer)4 DataColumnSpec (org.knime.core.data.DataColumnSpec)4 BufferedDataTable (org.knime.core.node.BufferedDataTable)4 ModelContentRO (org.knime.core.node.ModelContentRO)4 PMMLPortObjectSpec (org.knime.core.node.port.pmml.PMMLPortObjectSpec)4 DecisionTreeNodeLeaf (org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeLeaf)3