use of org.knime.core.node.port.pmml.PMMLPortObjectSpecCreator in project knime-core by knime.
the class RegressionTreePMMLTranslatorNodeModel method createPMMLSpec.
/**
 * Derives the PMML port object spec for a regression tree from the tree model spec and,
 * when available, the learned model.
 *
 * @param treeSpec supplies the target column and the learn table spec
 * @param model the learned model, or {@code null} if none is available
 * @return the PMML spec, or {@code null} when no model is given and the learn table spec
 *         contains a vector column
 */
private PMMLPortObjectSpec createPMMLSpec(final RegressionTreeModelPortObjectSpec treeSpec, final RegressionTreeModel model) {
    final DataColumnSpec target = treeSpec.getTargetColumn();
    DataTableSpec features = treeSpec.getLearnTableSpec();

    final boolean learnedOnVector = containsVector(features);
    if (learnedOnVector) {
        setWarningMessage("The model was learned on a vector column. It's possible to export the model "
            + "to PMML but it won't be possible to import it from the exported PMML.");
    }

    if (model == null) {
        if (learnedOnVector) {
            // without a model the length of the vector column is unknown
            return null;
        }
    } else {
        // the model possibly expands vector columns
        features = model.getLearnAttributeSpec(features);
    }

    final DataTableSpec featuresAndTarget = new DataTableSpec(features, new DataTableSpec(target));
    final PMMLPortObjectSpecCreator creator = new PMMLPortObjectSpecCreator(featuresAndTarget);
    try {
        creator.setLearningCols(features);
    } catch (InvalidSettingsException ex) {
        // (as of KNIME v2.5.1) -- presumably not expected here, since the learning columns are
        // part of the spec handed to the creator; surfaced as an unchecked exception
        throw new IllegalStateException(ex);
    }
    creator.setTargetCol(target);
    return creator.createSpec();
}
use of org.knime.core.node.port.pmml.PMMLPortObjectSpecCreator in project knime-core by knime.
the class LogisticRegressionContent method createSpec.
/**
 * Creates a PMML port object spec over the given table spec.
 *
 * @param spec the table spec handed to the spec creator
 * @param target name of the target column
 * @param learningCols names of the learning columns
 * @return the spec produced by the creator
 */
private static PMMLPortObjectSpec createSpec(final DataTableSpec spec, final String target, final String[] learningCols) {
    final PMMLPortObjectSpecCreator specCreator = new PMMLPortObjectSpecCreator(spec);
    specCreator.setTargetColName(target);
    specCreator.setLearningColsNames(Arrays.asList(learningCols));
    return specCreator.createSpec();
}
use of org.knime.core.node.port.pmml.PMMLPortObjectSpecCreator in project knime-core by knime.
the class LogRegLearner method init.
/**
 * Initialize instance and check if settings are consistent.
 *
 * <p>Resolves the target column and the regressor columns from {@code inSpec} and the settings,
 * then assigns {@code m_specialColumns}, {@code m_pmmlOutSpec} and {@code m_learner}. If no
 * target is configured and the column filter excludes nothing, a nominal column of
 * {@code inSpec} is chosen as target and stored in the settings.
 *
 * @param inSpec spec of the input table
 * @param exclude names of columns that must not be used as regressors
 * @throws InvalidSettingsException if no column is included, no target column can be
 *             determined or found in {@code inSpec}, the target column is not nominal, an
 *             included column has an unsupported type, or the target domain lists fewer than
 *             two values
 */
private void init(final DataTableSpec inSpec, final Set<String> exclude) throws InvalidSettingsException {
    // Apply the column filter once; the result is reused for the auto-configuration check.
    final FilterResult includedColumns = m_settings.getIncludedColumns().applyTo(inSpec);
    final List<String> inputCols = new ArrayList<>();
    for (String column : includedColumns.getIncludes()) {
        inputCols.add(column);
    }
    inputCols.remove(m_settings.getTargetColumn());
    if (inputCols.isEmpty()) {
        throw new InvalidSettingsException("At least one column must " + "be included.");
    }

    // Auto configuration when target is not set
    if (null == m_settings.getTargetColumn() && includedColumns.getExcludes().length == 0) {
        // NOTE(review): this loop removes every column name of inSpec from inputCols, so no
        // regressors remain after auto configuration -- confirm this is intended.
        for (int i = 0; i < inSpec.getNumColumns(); i++) {
            final DataColumnSpec colSpec = inSpec.getColumnSpec(i);
            final String colName = colSpec.getName();
            inputCols.remove(colName);
            if (colSpec.getType().isCompatible(NominalValue.class)) {
                // no break: the last nominal column wins
                m_settings.setTargetColumn(colName);
            }
        }
        // when there is no column with nominal data
        if (null == m_settings.getTargetColumn()) {
            throw new InvalidSettingsException("No column in " + "spec compatible to \"NominalValue\".");
        }
    }

    // remove all columns that should not be used
    inputCols.removeAll(exclude);

    // May still be null here (no target configured while the filter has excludes). The
    // comparison below is null-safe so that this case ends in the "not in the input"
    // InvalidSettingsException instead of a NullPointerException.
    final String targetColumn = m_settings.getTargetColumn();

    DataColumnSpec targetColSpec = null;
    final List<DataColumnSpec> regressorColSpecs = new ArrayList<>();
    m_specialColumns = new LinkedList<>();
    for (int i = 0; i < inSpec.getNumColumns(); i++) {
        final DataColumnSpec colSpec = inSpec.getColumnSpec(i);
        final String colName = colSpec.getName();
        final DataType type = colSpec.getType();
        if (targetColumn != null && targetColumn.equals(colName)) {
            if (!type.isCompatible(NominalValue.class)) {
                throw new InvalidSettingsException("Type of column \"" + colName + "\" is not nominal.");
            }
            targetColSpec = colSpec;
        } else if (inputCols.contains(colName)) {
            if (type.isCompatible(DoubleValue.class) || type.isCompatible(NominalValue.class)) {
                regressorColSpecs.add(colSpec);
            } else if (type.isCompatible(BitVectorValue.class) || type.isCompatible(ByteVectorValue.class)
                || (type.isCollectionType() && type.getCollectionElementType().isCompatible(DoubleValue.class))) {
                m_specialColumns.add(colSpec);
                // We change the table spec later to encode it as a string.
                regressorColSpecs.add(new DataColumnSpecCreator(colSpec.getName(), StringCell.TYPE).createSpec());
            } else {
                throw new InvalidSettingsException("Type of column \"" + colName
                    + "\" is not one of the allowed types, " + "which are numeric or nominal.");
            }
        }
    }

    if (targetColSpec == null) {
        throw new InvalidSettingsException("The target is " + "not in the input.");
    }

    // Check if target has at least two categories (only possible when the domain lists values).
    final Set<DataCell> targetValues = targetColSpec.getDomain().getValues();
    if (targetValues != null && targetValues.size() < 2) {
        throw new InvalidSettingsException("The target column \"" + targetColSpec.getName()
            + "\" has one value, only. " + "At least two target categories are expected.");
    }

    // Table spec for the PMML creator: bit/byte vector columns are replaced by string columns
    // that record the original kind in the "realType" property.
    // NOTE(review): collection columns accepted above are registered as string regressors but
    // are not converted here -- confirm the creator accepts that.
    final DataColumnSpec[] updatedSpecs = new DataColumnSpec[inSpec.getNumColumns()];
    for (int i = 0; i < updatedSpecs.length; i++) {
        final DataColumnSpec columnSpec = inSpec.getColumnSpec(i);
        final DataType type = columnSpec.getType();
        if (type.isCompatible(BitVectorValue.class) || type.isCompatible(ByteVectorValue.class)) {
            final DataColumnSpecCreator colSpecCreator =
                new DataColumnSpecCreator(columnSpec.getName(), StringCell.TYPE);
            colSpecCreator.setProperties(new DataColumnProperties(Collections.singletonMap("realType",
                type.isCompatible(BitVectorValue.class) ? "BitVector" : "ByteVector")));
            updatedSpecs[i] = colSpecCreator.createSpec();
        } else {
            updatedSpecs[i] = columnSpec;
        }
    }

    final DataTableSpec updated = new DataTableSpec(updatedSpecs);
    final PMMLPortObjectSpecCreator creator = new PMMLPortObjectSpecCreator(updated);
    creator.setTargetCols(Arrays.asList(targetColSpec));
    creator.setLearningCols(regressorColSpecs);
    // creator.addPreprocColNames(m_specialColumns.stream().flatMap(spec -> ));
    m_pmmlOutSpec = creator.createSpec();
    m_learner = new Learner(m_pmmlOutSpec, m_specialColumns, m_settings.getTargetReferenceCategory(),
        m_settings.getSortTargetCategories(), m_settings.getSortIncludesCategories());
}
use of org.knime.core.node.port.pmml.PMMLPortObjectSpecCreator in project knime-core by knime.
the class RegressionTreeModelPortObject method createDecisionTreePMMLPortObject.
/**
 * Wraps the regression tree model of this port object into a PMML port object.
 *
 * @return a new PMML port object whose spec consists of the model's learn attributes plus the
 *         target column and which carries a translator for the tree model
 */
public PMMLPortObject createDecisionTreePMMLPortObject() {
    final RegressionTreeModel treeModel = getModel();
    final DataTableSpec learnAttributes = treeModel.getLearnAttributeSpec(m_spec.getLearnTableSpec());
    final DataColumnSpec target = m_spec.getTargetColumn();

    final DataTableSpec attributesAndTarget = new DataTableSpec(learnAttributes, new DataTableSpec(target));
    final PMMLPortObjectSpecCreator creator = new PMMLPortObjectSpecCreator(attributesAndTarget);
    try {
        creator.setLearningCols(learnAttributes);
    } catch (InvalidSettingsException ex) {
        // (as of KNIME v2.5.1) -- presumably not expected here, since the learning columns are
        // part of the spec handed to the creator; surfaced as an unchecked exception
        throw new IllegalStateException(ex);
    }
    creator.setTargetCol(target);

    final PMMLPortObject result = new PMMLPortObject(creator.createSpec());
    final TreeModelRegression regressionTree = treeModel.getTreeModel();
    result.addModelTranslater(
        new RegressionTreeModelPMMLTranslator(regressionTree, treeModel.getMetaData(), m_spec.getLearnTableSpec()));
    return result;
}
use of org.knime.core.node.port.pmml.PMMLPortObjectSpecCreator in project knime-core by knime.
the class RuleEngine2PortsNodeModel method createPMMLPortObjectSpec.
/**
 * Initializes the {@link PMMLPortObjectSpec} based on the input and the used columns.
 *
 * <p>The target is the configured replace column if replacing is enabled, otherwise the last
 * column of {@code spec}. A column becomes a learning column when it is not the target, is
 * listed in {@code usedColumns}, and is either double-compatible or nominal with domain values.
 *
 * @param spec The input table spec.
 * @param usedColumns The columns used by the rules.
 * @return The {@link PMMLPortObjectSpec} filled with proper data for configuration.
 */
private PMMLPortObjectSpec createPMMLPortObjectSpec(final DataTableSpec spec, final List<String> usedColumns) {
    final String targetCol;
    if (m_settings.isReplaceColumn()) {
        targetCol = m_settings.getReplaceColumn();
    } else {
        // this assumes that the new column is always the last column in the spec; which is the
        // case if #createRearranger uses ColumnRearranger.append.
        targetCol = spec.getColumnSpec(spec.getNumColumns() - 1).getName();
    }

    final Set<String> referenced = new LinkedHashSet<String>(usedColumns);
    final List<String> learnCols = new LinkedList<String>();
    for (int idx = 0; idx < spec.getNumColumns(); idx++) {
        final DataColumnSpec column = spec.getColumnSpec(idx);
        final String name = column.getName();
        if (name.equals(targetCol) || !referenced.contains(name)) {
            continue;
        }
        if (column.getType().isCompatible(DoubleValue.class)) {
            learnCols.add(name);
        } else if (column.getType().isCompatible(NominalValue.class) && column.getDomain().hasValues()) {
            // nominal columns qualify only when their domain lists values
            learnCols.add(name);
        }
    }

    final PMMLPortObjectSpecCreator creator = new PMMLPortObjectSpecCreator(spec);
    creator.setLearningColsNames(learnCols);
    creator.setTargetColName(targetCol);
    return creator.createSpec();
}
Aggregations