use of org.knime.core.node.port.pmml.PMMLPortObjectSpec in project knime-core by knime.
the class DecisionTreeLearnerNodeModel2 method configure.
/**
 * Validates the learner settings against the input specs and derives the PMML output spec.
 *
 * <p>The configured class column must exist in the input table and be compatible with
 * {@link NominalValue}. If no class column is configured, the last nominal column of the table is
 * selected and a warning is set. If a root split column is requested, it must be specified, exist in
 * the table and differ from the class column.</p>
 *
 * @param inSpecs the table specs on the input ports to use for configuration
 * @see NodeModel#configure(DataTableSpec[])
 * @throws InvalidSettingsException thrown if the configuration is not
 * correct
 * @return the specs for the output ports
 */
@Override
protected PortObjectSpec[] configure(final PortObjectSpec[] inSpecs) throws InvalidSettingsException {
    DataTableSpec inSpec = (DataTableSpec) inSpecs[DATA_INPORT];
    PMMLPortObjectSpec modelSpec = m_pmmlInEnabled ? (PMMLPortObjectSpec) inSpecs[MODEL_INPORT] : null;
    // check spec with selected column
    String classifyColumn = m_classifyColumn.getStringValue();
    DataColumnSpec columnSpec = inSpec.getColumnSpec(classifyColumn);
    boolean isValid = columnSpec != null && columnSpec.getType().isCompatible(NominalValue.class);
    if (classifyColumn != null && !isValid) {
        throw new InvalidSettingsException("Class column \"" + classifyColumn + "\" not found or incompatible");
    }
    if (classifyColumn == null) {
        // auto-guessing
        assert !isValid : "No class column set but valid configuration";
        // get the first useful one starting at the end of the table
        for (int i = inSpec.getNumColumns() - 1; i >= 0; i--) {
            DataColumnSpec candidate = inSpec.getColumnSpec(i);
            if (candidate.getType().isCompatible(NominalValue.class)) {
                m_classifyColumn.setStringValue(candidate.getName());
                super.setWarningMessage("Guessing target column: \"" + m_classifyColumn.getStringValue() + "\".");
                // remember the guessed column so that the root split check below
                // compares against the class column that is actually used
                columnSpec = candidate;
                break;
            }
        }
        if (m_classifyColumn.getStringValue() == null) {
            throw new InvalidSettingsException("Table contains no nominal" + " attribute for classification.");
        }
    }
    if (m_useFirstSplitCol.getBooleanValue()) {
        String firstSplitCol = m_firstSplitCol.getStringValue();
        if (firstSplitCol == null) {
            throw new InvalidSettingsException("Root split column should be used but is not specified.");
        }
        // look the column up only after the name is known to be non-null
        DataColumnSpec firstSplitSpec = inSpec.getColumnSpec(firstSplitCol);
        if (firstSplitSpec == null) {
            throw new InvalidSettingsException("The selected column for the root split \"" + firstSplitCol + "\" is not in the table.");
        }
        if (firstSplitSpec.equals(columnSpec)) {
            throw new InvalidSettingsException("The class column can not be selected as the" + " first column to split on.");
        }
    }
    return new PortObjectSpec[] { createPMMLPortObjectSpec(modelSpec, inSpec) };
}
use of org.knime.core.node.port.pmml.PMMLPortObjectSpec in project knime-core by knime.
the class DecisionTreeLearnerNodeModel2 method execute.
/**
 * Start of decision tree induction.
 *
 * <p>Creates an in-memory table from the input data, builds the tree with the configured split quality
 * measure, prunes it, adds hilite and color information and wraps the result into a PMML port
 * object.</p>
 *
 * @param data the input port objects; the training table is read from {@code DATA_INPORT} and, if the
 *            PMML input is enabled, the incoming PMML port object from {@code MODEL_INPORT}
 * @param exec the execution context for this run
 * @return a single-element array holding the PMML port object that contains the learned decision tree
 * @throws Exception any type of exception, e.g. for cancellation,
 * invalid input,...
 * @see NodeModel#execute(BufferedDataTable[],ExecutionContext)
 */
@Override
protected PortObject[] execute(final PortObject[] data, final ExecutionContext exec) throws Exception {
    // holds the warning message displayed after execution
    m_warningMessageSb = new StringBuilder();
    final ParallelProcessing parallelProcessing = new ParallelProcessing(m_parallelProcessing.getIntValue());
    if (LOGGER.isDebugEnabled()) {
        LOGGER.debug("Number available threads: " + parallelProcessing.getMaxNumberThreads() + " used threads: " + parallelProcessing.getCurrentThreadsInUse());
    }
    exec.setProgress("Preparing...");
    // check input data
    assert (data != null && data[DATA_INPORT] != null);
    BufferedDataTable inData = (BufferedDataTable) data[DATA_INPORT];
    // get column with color information (first column that has a color handler)
    String colorColumn = null;
    for (DataColumnSpec s : inData.getDataTableSpec()) {
        if (s.getColorHandler() != null) {
            colorColumn = s.getName();
            break;
        }
    }
    // the data table must have at least 2 records
    if (inData.size() <= 1) {
        throw new IllegalArgumentException("Input data table must have at least 2 records!");
    }
    // get class column index
    int classColumnIndex = inData.getDataTableSpec().findColumnIndex(m_classifyColumn.getStringValue());
    assert classColumnIndex > -1;
    // create initial In-Memory table
    exec.setProgress("Create initial In-Memory table...");
    InMemoryTableCreator tableCreator = new InMemoryTableCreator(inData, classColumnIndex, m_minNumberRecordsPerNode.getIntValue(), m_skipColumns.getBooleanValue());
    InMemoryTable initialTable = tableCreator.createInMemoryTable(exec.createSubExecutionContext(0.05));
    int removedRows = tableCreator.getRemovedRowsDueToMissingClassValue();
    if (removedRows == inData.size()) {
        throw new IllegalArgumentException("Class column contains only " + "missing values");
    }
    if (removedRows > 0) {
        m_warningMessageSb.append(removedRows);
        m_warningMessageSb.append(" rows removed due to missing class value;");
    }
    // the all over row count is used to report progress
    m_alloverRowCount = initialTable.getSumOfWeights();
    // set the finishing counter
    // this counter will always be incremented when a leaf node is
    // created, as this determines the recursion end and can thus
    // be used for progress indication
    m_finishedCounter = new AtomicDouble(0);
    // get the number of attributes
    m_numberAttributes = initialTable.getNumAttributes();
    // create the quality measure: Gini if selected, gain ratio otherwise
    final SplitQualityMeasure splitQualityMeasure;
    if (m_splitQualityMeasureType.getStringValue().equals(SPLIT_QUALITY_GINI)) {
        splitQualityMeasure = new SplitQualityGini();
    } else {
        splitQualityMeasure = new SplitQualityGainRatio();
    }
    // build the tree
    // before this set the node counter to 0
    m_counter.set(0);
    exec.setMessage("Building tree...");
    // NOTE(review): the root split column setting is read regardless of m_useFirstSplitCol;
    // presumably it is unset (or ignored by buildTree) when the option is disabled - confirm
    final int firstSplitColIdx = initialTable.getAttributeIndex(m_firstSplitCol.getStringValue());
    DecisionTreeNode root = buildTree(initialTable, exec, 0, splitQualityMeasure, parallelProcessing, firstSplitColIdx);
    boolean isBinaryNominal = m_binaryNominalSplitMode.getBooleanValue();
    boolean isFilterInvalidAttributeValues = m_filterNominalValuesFromParent.getBooleanValue();
    if (isBinaryNominal && isFilterInvalidAttributeValues) {
        // traverse tree nodes and remove from the children the attribute
        // values that were filtered out further up in the tree. "Bug" 3124
        root.filterIllegalAttributes(Collections.<String, Set<String>>emptyMap());
    }
    // the decision tree model that is later exported as PMML.
    // The missing value strategy has to be set explicitly as the default in PMML is
    // none, which means rows with missing values are not classified.
    DecisionTree decisionTree = new DecisionTree(root, m_classifyColumn.getStringValue(),
        PMMLMissingValueStrategy.get(m_missingValues.getStringValue()), PMMLNoTrueChildStrategy.get(m_noTrueChild.getStringValue()));
    decisionTree.setColorColumn(colorColumn);
    // prune the tree
    exec.setMessage("Prune tree with " + m_pruningMethod.getStringValue() + "...");
    pruneTree(decisionTree);
    // add highlight patterns and color information
    exec.setMessage("Adding hilite and color info to tree...");
    addHiliteAndColorInfo(inData, decisionTree);
    LOGGER.info("Decision tree consisting of " + decisionTree.getNumberNodes() + " nodes created with pruning method " + m_pruningMethod.getStringValue());
    // set the warning message if available
    if (m_warningMessageSb.length() > 0) {
        setWarningMessage(m_warningMessageSb.toString());
    }
    // reset the number available threads
    parallelProcessing.reset();
    // no data table is produced; the only output is the PMML model
    exec.setMessage("Creating PMML decision tree model...");
    // handle the optional PMML input (same port constant as used in configure)
    PMMLPortObject inPMMLPort = m_pmmlInEnabled ? (PMMLPortObject) data[MODEL_INPORT] : null;
    DataTableSpec inSpec = inData.getSpec();
    PMMLPortObjectSpec outPortSpec = createPMMLPortObjectSpec(inPMMLPort == null ? null : inPMMLPort.getSpec(), inSpec);
    PMMLPortObject outPMMLPort = new PMMLPortObject(outPortSpec, inPMMLPort, inSpec);
    outPMMLPort.addModelTranslater(new PMMLDecisionTreeTranslator(decisionTree));
    // keep the tree in the node model
    m_decisionTree = decisionTree;
    return new PortObject[] { outPMMLPort };
}
use of org.knime.core.node.port.pmml.PMMLPortObjectSpec in project knime-core by knime.
the class DecTreePredictorNodeModel method execute.
/**
 * {@inheritDoc}
 *
 * <p>Reads the decision tree from the PMML model port and classifies every row of the data port. Each
 * output row consists of the input cells, optionally one distribution cell per prediction value, and
 * the winning class as last cell.</p>
 */
@Override
public PortObject[] execute(final PortObject[] inPorts, final ExecutionContext exec) throws CanceledExecutionException, Exception {
    exec.setMessage("Decision Tree Predictor: Loading predictor...");
    final PMMLPortObject modelPort = (PMMLPortObject) inPorts[INMODELPORT];
    final List<Node> treeModels = modelPort.getPMMLValue().getModels(PMMLModelType.TreeModel);
    if (treeModels.isEmpty()) {
        final String failure = "Decision Tree evaluation failed: " + "No tree model found.";
        LOGGER.error(failure);
        throw new RuntimeException(failure);
    }
    final PMMLDecisionTreeTranslator translator = new PMMLDecisionTreeTranslator();
    modelPort.initializeModelTranslator(translator);
    final DecisionTree tree = translator.getDecisionTree();
    tree.resetColorInformation();

    final BufferedDataTable table = (BufferedDataTable) inPorts[INDATAPORT];
    final DataTableSpec tableSpec = table.getDataTableSpec();
    // the first column carrying a color handler (if any) provides the color information
    String colorColumnName = null;
    for (DataColumnSpec column : tableSpec) {
        if (column.getColorHandler() != null) {
            colorColumnName = column.getName();
            break;
        }
    }
    tree.setColorColumn(colorColumnName);

    exec.setMessage("Decision Tree Predictor: start execution.");
    final PortObjectSpec[] portSpecs = new PortObjectSpec[] { inPorts[0].getSpec(), inPorts[1].getSpec() };
    final DataTableSpec resultSpec = createOutTableSpec(portSpecs);
    final BufferedDataContainer container = exec.createDataContainer(resultSpec);
    final int outWidth = resultSpec.getNumColumns();
    final long totalRows = table.size();
    long storedPatterns = 0;
    long classifiedRows = 0;
    long processedRows = 0;
    exec.setMessage("Classifying...");
    final List<String> predictionValues = getPredictionStrings((PMMLPortObjectSpec) inPorts[INMODELPORT].getSpec());

    for (DataRow row : table) {
        DataCell winner = null;
        LinkedHashMap<String, Double> distribution = null;
        try {
            final Pair<DataCell, LinkedHashMap<DataCell, Double>> result = tree.getWinnerAndClasscounts(row, tableSpec);
            winner = result.getFirst();
            distribution = getDistribution(result.getSecond());
            if (storedPatterns >= m_maxNumCoveredPattern.getIntValue()) {
                // too many patterns for HiLite - at least remember color
                tree.addCoveredColor(row, tableSpec);
            } else {
                // remember this one for HiLite support
                tree.addCoveredPattern(row, tableSpec);
                storedPatterns++;
            }
            classifiedRows++;
        } catch (Exception e) {
            LOGGER.error("Decision Tree evaluation failed: " + e.getMessage());
            throw e;
        }
        if (winner == null) {
            LOGGER.error("Decision Tree evaluation failed: result empty");
            throw new Exception("Decision Tree evaluation failed.");
        }

        // input cells first, then (optionally) the distribution, the prediction goes last
        final DataCell[] outCells = new DataCell[outWidth];
        final int inWidth = row.getNumCells();
        for (int c = 0; c < inWidth; c++) {
            outCells[c] = row.getCell(c);
        }
        final int predictionIdx = outCells.length - 1;
        if (m_showDistribution.getBooleanValue()) {
            assert predictionValues.size() >= outCells.length - 1 - inWidth : "Could not determine the prediction values: " + outCells.length + "; " + inWidth + "; " + predictionValues;
            for (int c = inWidth; c < predictionIdx; c++) {
                final String className = predictionValues.get(c - inWidth);
                final Double share = distribution == null ? null : distribution.get(className);
                // classes without an entry in the distribution are reported as 0.0
                outCells[c] = new DoubleCell(share == null ? 0.0 : share.doubleValue());
            }
        }
        outCells[predictionIdx] = winner;
        container.addRowToTable(new DefaultRow(row.getKey(), outCells));

        processedRows++;
        if (processedRows % 100 == 0) {
            exec.setProgress(processedRows / (double) totalRows, "Classifying... Row " + processedRows + " of " + totalRows);
        }
        exec.checkCanceled();
    }

    if (storedPatterns < classifiedRows) {
        // let the user know that we did not store all available pattern
        // for HiLiting.
        this.setWarningMessage("Tree only stored first " + m_maxNumCoveredPattern.getIntValue() + " (of " + classifiedRows + ") rows for HiLiting!");
    }
    container.close();
    m_decTree = tree;
    exec.setMessage("Decision Tree Predictor: end execution.");
    return new BufferedDataTable[] { container.getTable() };
}
use of org.knime.core.node.port.pmml.PMMLPortObjectSpec in project knime-core by knime.
the class DecTreePredictorNodeModel method configure.
/**
 * {@inheritDoc}
 *
 * <p>Checks that a prediction column name is given when the prediction column is overridden and that
 * every learning field of the incoming PMML spec is present in the data table spec.</p>
 *
 * @throws InvalidSettingsException if the prediction column name is empty although it should be
 *             overridden, or if a learning column is missing in the input data
 */
@Override
protected PortObjectSpec[] configure(final PortObjectSpec[] inSpecs) throws InvalidSettingsException {
    String predCol = m_predictionColumn.getStringValue();
    CheckUtils.checkSetting(!m_overridePrediction.getBooleanValue() || (predCol != null && !predCol.trim().isEmpty()), "Prediction column name cannot be empty");
    PMMLPortObjectSpec treeSpec = (PMMLPortObjectSpec) inSpecs[INMODELPORT];
    // use the same port constant as execute instead of a hard-coded index
    DataTableSpec inSpec = (DataTableSpec) inSpecs[INDATAPORT];
    for (String learnColName : treeSpec.getLearningFields()) {
        if (!inSpec.containsName(learnColName)) {
            throw new InvalidSettingsException("Learning column \"" + learnColName + "\" not found in input " + "data to be predicted");
        }
    }
    return new PortObjectSpec[] { createOutTableSpec(inSpecs) };
}
use of org.knime.core.node.port.pmml.PMMLPortObjectSpec in project knime-core by knime.
the class RPropNodeModel method configure.
/**
 * Validates the input table against the configured class column and creates the PMML output spec.
 *
 * <p>Every column other than the class column must be compatible with {@link DoubleValue} and is
 * registered as learning column; a warning is set if its domain bounds leave [0,1]. The class column is
 * registered as target column; if it is double compatible its domain bounds must lie within [0,1].</p>
 *
 * {@inheritDoc}
 *
 * @throws InvalidSettingsException if no class column is set, the class column is not part of the input
 *             table, an input column is not double compatible, or the class column's domain is not in
 *             range [0,1]
 */
@Override
protected PortObjectSpec[] configure(final PortObjectSpec[] inSpecs) throws InvalidSettingsException {
    final String classColName = m_classcol.getStringValue();
    if (classColName == null) {
        throw new InvalidSettingsException("Class column not set");
    }
    final DataTableSpec dataSpec = (DataTableSpec) inSpecs[INDATA];
    List<String> learningCols = new LinkedList<String>();
    List<String> targetCols = new LinkedList<String>();
    boolean classcolinspec = false;
    for (DataColumnSpec colspec : dataSpec) {
        if (colspec.getName().equals(classColName)) {
            targetCols.add(colspec.getName());
            classcolinspec = true;
            // TODO: Check what happens to other values than double
            if (colspec.getType().isCompatible(DoubleValue.class) && isOutsideUnitInterval(colspec.getDomain())) {
                throw new InvalidSettingsException("Domain range for regression in column " + colspec.getName() + " not in range [0,1]");
            }
        } else {
            if (!colspec.getType().isCompatible(DoubleValue.class)) {
                throw new InvalidSettingsException("Only double columns for input");
            }
            learningCols.add(colspec.getName());
            if (isOutsideUnitInterval(colspec.getDomain())) {
                setWarningMessage("Input data not normalized." + " Please consider using the " + "Normalizer Node first.");
            }
        }
    }
    if (!classcolinspec) {
        throw new InvalidSettingsException("Class column " + classColName + " not found in DataTableSpec");
    }
    return new PortObjectSpec[] { createPMMLPortObjectSpec(m_pmmlInEnabled ? (PMMLPortObjectSpec) inSpecs[1] : null, dataSpec, learningCols, targetCols) };
}

/**
 * Tells whether a double column's domain has bounds that leave the interval [0,1].
 *
 * @param domain the domain of a column that is compatible with {@link DoubleValue}
 * @return {@code true} if bounds are available and the lower bound is below 0 or the upper bound is
 *         above 1, {@code false} otherwise (in particular if no bounds are available)
 */
private static boolean isOutsideUnitInterval(final DataColumnDomain domain) {
    if (!domain.hasBounds()) {
        return false;
    }
    double lower = ((DoubleValue) domain.getLowerBound()).getDoubleValue();
    double upper = ((DoubleValue) domain.getUpperBound()).getDoubleValue();
    return lower < 0 || upper > 1;
}
Aggregations