use of org.knime.base.node.mine.decisiontree2.model.DecisionTree in project knime-core by knime.
the class FromDecisionTreeNodeModel method execute.
/**
 * {@inheritDoc}
 * <p>
 * Reads the decision tree from the PMML port object at input 0, writes a new PMML document
 * containing a RuleSet model (classification, first-hit rule selection) derived from that tree,
 * and returns the new PMML port object together with the result of running
 * {@code RuleSetToTable} on it.
 *
 * @param inData the input port objects; element 0 is cast to {@code PMMLPortObject}
 * @param exec the execution context, handed on to {@code RuleSetToTable}
 * @return a two-element array: the rule set PMML port object and its table representation
 * @throws CanceledExecutionException Execution cancelled.
 * @throws InvalidSettingsException declared for the helpers called here; this method itself
 *             contains no explicit throw of it (TODO confirm the exact conditions)
 */
@Override
protected PortObject[] execute(final PortObject[] inData, final ExecutionContext exec) throws CanceledExecutionException, InvalidSettingsException {
    // Extract the decision tree from the incoming PMML model.
    PMMLPortObject decTreeModel = (PMMLPortObject) inData[0];
    PMMLDecisionTreeTranslator treeTranslator = new PMMLDecisionTreeTranslator();
    decTreeModel.initializeModelTranslator(treeTranslator);
    DecisionTree decisionTree = treeTranslator.getDecisionTree();

    // NOTE(review): this port object is only used below as the source of the output spec;
    // presumably decTreeModel.getSpec() could be used directly - confirm before removing.
    PMMLPortObject ruleSetModel = new PMMLPortObject(decTreeModel.getSpec());

    // Build the skeleton of the output PMML document: header, version, data dictionary.
    PMMLDocument document = PMMLDocument.Factory.newInstance();
    PMML pmml = document.addNewPMML();
    PMMLPortObjectSpec.writeHeader(pmml);
    pmml.setVersion(PMMLPortObject.PMML_V4_2);
    new PMMLDataDictionaryTranslator().exportTo(document, decTreeModel.getSpec());

    // Add the RuleSet model and configure it as a first-hit classification rule set.
    RuleSetModel newRuleSetModel = pmml.addNewRuleSetModel();
    PMMLMiningSchemaTranslator.writeMiningSchema(decTreeModel.getSpec(), newRuleSetModel);
    newRuleSetModel.setFunctionName(MININGFUNCTION.CLASSIFICATION);
    newRuleSetModel.setAlgorithmName("RuleSet");
    RuleSet ruleSet = newRuleSetModel.addNewRuleSet();
    ruleSet.addNewRuleSelectionMethod().setCriterion(Criterion.FIRST_HIT);

    // Convert the tree into rules, starting at the root with an empty parent path.
    addRules(ruleSet, new ArrayList<DecisionTreeNode>(), decisionTree.getRootNode());

    PMMLPortObject pmmlPortObject = new PMMLPortObject(ruleSetModel.getSpec(), document);
    return new PortObject[] { pmmlPortObject, new RuleSetToTable(m_rulesToTable).execute(exec, pmmlPortObject) };
}
use of org.knime.base.node.mine.decisiontree2.model.DecisionTree in project knime-core by knime.
the class DecTreePredictorNodeModel method execute.
/**
 * {@inheritDoc}
 * <p>
 * Reads the decision tree from the PMML model port, classifies each row of the data port with
 * it and writes an output row made of the input cells, the class distribution values (only when
 * the distribution option is switched on) and the predicted class in the last column. The tree
 * is stored in {@code m_decTree} once all rows have been processed.
 */
@Override
public PortObject[] execute(final PortObject[] inPorts, final ExecutionContext exec) throws CanceledExecutionException, Exception {
    exec.setMessage("Decision Tree Predictor: Loading predictor...");
    PMMLPortObject modelPort = (PMMLPortObject) inPorts[INMODELPORT];
    List<Node> treeModels = modelPort.getPMMLValue().getModels(PMMLModelType.TreeModel);
    if (treeModels.isEmpty()) {
        String msg = "Decision Tree evaluation failed: " + "No tree model found.";
        LOGGER.error(msg);
        throw new RuntimeException(msg);
    }
    PMMLDecisionTreeTranslator translator = new PMMLDecisionTreeTranslator();
    modelPort.initializeModelTranslator(translator);
    DecisionTree tree = translator.getDecisionTree();
    tree.resetColorInformation();

    BufferedDataTable table = (BufferedDataTable) inPorts[INDATAPORT];
    DataTableSpec tableSpec = table.getDataTableSpec();

    // the first column that carries a color handler (if any) supplies the color information
    String colorColumn = null;
    for (DataColumnSpec colSpec : tableSpec) {
        if (colSpec.getColorHandler() == null) {
            continue;
        }
        colorColumn = colSpec.getName();
        break;
    }
    tree.setColorColumn(colorColumn);

    exec.setMessage("Decision Tree Predictor: start execution.");
    PortObjectSpec[] inSpecs = new PortObjectSpec[] { inPorts[0].getSpec(), inPorts[1].getSpec() };
    DataTableSpec outSpec = createOutTableSpec(inSpecs);
    BufferedDataContainer container = exec.createDataContainer(outSpec);

    long coveredPattern = 0;
    long nrPattern = 0;
    long processedRows = 0;
    long totalRows = table.size();
    exec.setMessage("Classifying...");
    for (DataRow row : table) {
        DataCell winner = null;
        LinkedHashMap<String, Double> distribution = null;
        try {
            Pair<DataCell, LinkedHashMap<DataCell, Double>> result = tree.getWinnerAndClasscounts(row, tableSpec);
            winner = result.getFirst();
            distribution = getDistribution(result.getSecond());
            if (coveredPattern >= m_maxNumCoveredPattern.getIntValue()) {
                // HiLite limit reached - only the color of this row is recorded
                tree.addCoveredColor(row, tableSpec);
            } else {
                // still below the limit: keep the row itself for HiLite support
                tree.addCoveredPattern(row, tableSpec);
                coveredPattern++;
            }
            nrPattern++;
        } catch (Exception e) {
            LOGGER.error("Decision Tree evaluation failed: " + e.getMessage());
            throw e;
        }
        if (winner == null) {
            LOGGER.error("Decision Tree evaluation failed: result empty");
            throw new Exception("Decision Tree evaluation failed.");
        }

        // output row layout: input cells, then distribution cells, then the predicted class
        DataCell[] outCells = new DataCell[outSpec.getNumColumns()];
        int lastIndex = outCells.length - 1;
        int inputCellCount = row.getNumCells();
        for (int c = 0; c < inputCellCount; c++) {
            outCells[c] = row.getCell(c);
        }
        if (m_showDistribution.getBooleanValue()) {
            for (int c = inputCellCount; c < lastIndex; c++) {
                // the output column name is used as key into the class distribution
                String className = outSpec.getColumnSpec(c).getName();
                Double probability = distribution == null ? null : distribution.get(className);
                outCells[c] = new DoubleCell(probability == null ? 0.0 : probability);
            }
        }
        outCells[lastIndex] = winner;
        container.addRowToTable(new DefaultRow(row.getKey(), outCells));

        processedRows++;
        if (processedRows % 100 == 0) {
            exec.setProgress(processedRows / (double) totalRows, "Classifying... Row " + processedRows + " of " + totalRows);
        }
        exec.checkCanceled();
    }

    if (coveredPattern < nrPattern) {
        // tell the user that not every row was stored for HiLiting
        this.setWarningMessage("Tree only stored first " + m_maxNumCoveredPattern.getIntValue() + " (of " + nrPattern + ") rows for HiLiting!");
    }
    container.close();
    m_decTree = tree;
    exec.setMessage("Decision Tree Predictor: end execution.");
    return new BufferedDataTable[] { container.getTable() };
}
use of org.knime.base.node.mine.decisiontree2.model.DecisionTree in project knime-core by knime.
the class DecTreePredictorNodeView method modelChanged.
/**
 * {@inheritDoc}
 * <p>
 * Shows the node model's current decision tree in the tree component, or clears the component
 * when the node model has no tree.
 */
@Override
protected void modelChanged() {
    DecTreePredictorNodeModel nodeModel = this.getNodeModel();
    DecisionTree tree = nodeModel.getDecisionTree();
    if (tree == null) {
        // no tree available: empty the view
        m_jTree.setModel(null);
        return;
    }

    // install the new tree together with the decision tree renderer
    m_jTree.setModel(new DefaultTreeModel(tree.getRootNode()));
    m_jTree.setCellRenderer(new DecisionTreeNodeRenderer());
    // a row height of 0 makes sure no default height is assumed, so the
    // renderer's preferred size is used instead
    m_jTree.setRowHeight(0);

    // fetch the HiLiteHandler of the data input port and enable the
    // HiLite menu only if such a handler exists
    m_hiLiteHdl = nodeModel.getInHiLiteHandler(DecTreePredictorNodeModel.INDATAPORT);
    m_hiLiteMenu.setEnabled(m_hiLiteHdl != null);
}
use of org.knime.base.node.mine.decisiontree2.model.DecisionTree in project knime-core by knime.
the class RegressionTreeModel method createDecisionTree.
/**
 * Creates a {@link DecisionTree} from the regression tree model and, if a sample table is given,
 * registers the sample rows on the tree as covered patterns.
 *
 * @param sampleForHiliting rows to register as covered patterns; may be {@code null}, in which
 *            case the tree is returned without any covered patterns added here
 * @return the newly created decision tree
 */
public DecisionTree createDecisionTree(final DataTable sampleForHiliting) {
    final DecisionTree tree = getTreeModelRegression().createDecisionTree(getMetaData());
    if (sampleForHiliting == null) {
        return tree;
    }

    final DataTableSpec sampleSpec = sampleForHiliting.getDataTableSpec();
    final DataTableSpec learnSpec = getLearnAttributeSpec(sampleSpec);
    for (DataRow sampleRow : sampleForHiliting) {
        try {
            // convert the sample row into a learn-attribute row before adding it
            DataRow learnRow = createLearnAttributeRow(sampleRow, learnSpec);
            tree.addCoveredPattern(learnRow, learnSpec);
        } catch (Exception e) {
            // best effort: log the problem and stop processing the remaining rows
            NodeLogger.getLogger(getClass()).error("Error updating hilite info in tree view", e);
            break;
        }
    }
    return tree;
}
use of org.knime.base.node.mine.decisiontree2.model.DecisionTree in project knime-core by knime.
the class PMMLDecisionTreeTranslator method exportTo.
/**
 * {@inheritDoc}
 * <p>
 * Adds a new TreeModel element to the PMML of the given document and fills it from the tree
 * held by this translator: mining schema, model name, function name, split characteristic,
 * missing value strategy, no-true-child strategy and finally the node hierarchy.
 */
@Override
public SchemaType exportTo(final PMMLDocument pmmlDoc, final PMMLPortObjectSpec spec) {
    m_nameMapper = new DerivedFieldMapper(pmmlDoc);
    PMML pmmlRoot = pmmlDoc.getPMML();
    TreeModelDocument.TreeModel pmmlTree = pmmlRoot.addNewTreeModel();
    PMMLMiningSchemaTranslator.writeMiningSchema(spec, pmmlTree);
    pmmlTree.setModelName("DecisionTree");
    pmmlTree.setFunctionName(m_isClassification ? MININGFUNCTION.CLASSIFICATION : MININGFUNCTION.REGRESSION);

    // split characteristic: multi split vs. binary split
    pmmlTree.setSplitCharacteristic(treeIsMultisplit(m_tree.getRootNode())
        ? SplitCharacteristic.MULTI_SPLIT : SplitCharacteristic.BINARY_SPLIT);

    // missing value strategy; NONE is used when the tree does not define one
    PMMLMissingValueStrategy mvStrategy = m_tree.getMVStrategy();
    if (mvStrategy == null) {
        mvStrategy = PMMLMissingValueStrategy.NONE;
    }
    pmmlTree.setMissingValueStrategy(MV_STRATEGY_TO_PMML_MAP.get(mvStrategy));

    // no true child strategy; left unset when neither of the two known values matches
    PMMLNoTrueChildStrategy ntcStrategy = m_tree.getNTCStrategy();
    if (PMMLNoTrueChildStrategy.RETURN_LAST_PREDICTION.equals(ntcStrategy)) {
        pmmlTree.setNoTrueChildStrategy(NOTRUECHILDSTRATEGY.RETURN_LAST_PREDICTION);
    } else if (PMMLNoTrueChildStrategy.RETURN_NULL_PREDICTION.equals(ntcStrategy)) {
        pmmlTree.setNoTrueChildStrategy(NOTRUECHILDSTRATEGY.RETURN_NULL_PREDICTION);
    }

    // node hierarchy, written starting from the root of the tree
    NodeDocument.Node pmmlRootNode = pmmlTree.addNewNode();
    addTreeNode(pmmlRootNode, m_tree.getRootNode(), new DerivedFieldMapper(pmmlDoc));
    return TreeModel.type;
}
Aggregations