Usage of `org.knime.core.data.container.ColumnRearranger` in project knime-core by KNIME.
Class `MLPPredictorNodeModel`, method `execute`:
/**
 * Applies the neural network (multi-layer perceptron) model from the PMML input to the test data
 * and appends the prediction column(s) to it.
 *
 * <p>Port 0 must hold a {@link PMMLPortObject} with a neural network model; port 1 holds the
 * {@link BufferedDataTable} to predict. As a side effect the fields {@code m_mlp} and
 * {@code m_columns} are set from the model and the test data spec.
 *
 * @param inData the input port objects (port 0: PMML model, port 1: data table)
 * @param exec the execution context used to create the output table
 * @return a single-element array with the input table plus the appended prediction column(s)
 * @throws Exception if the PMML contains no neural network model, defines no target column, or the
 *     model's mode is neither regression nor classification
 */
@Override
public PortObject[] execute(final PortObject[] inData, final ExecutionContext exec) throws Exception {
    BufferedDataTable testdata = (BufferedDataTable) inData[1];
    PMMLPortObject pmmlPort = (PMMLPortObject) inData[0];
    List<Node> models = pmmlPort.getPMMLValue().getModels(PMMLModelType.NeuralNetwork);
    if (models.isEmpty()) {
        String msg = "Neural network evaluation failed: No neural network model found.";
        LOGGER.error(msg);
        throw new RuntimeException(msg);
    }
    PMMLNeuralNetworkTranslator trans = new PMMLNeuralNetworkTranslator();
    pmmlPort.initializeModelTranslator(trans);
    m_mlp = trans.getMLP();
    m_columns = getLearningColumnIndices(testdata.getDataTableSpec(), pmmlPort.getSpec());
    // Fail with a descriptive message instead of a bare NoSuchElementException from next().
    if (pmmlPort.getSpec().getTargetCols().isEmpty()) {
        String msg = "Neural network evaluation failed: No target column defined in the PMML model.";
        LOGGER.error(msg);
        throw new IllegalStateException(msg);
    }
    DataColumnSpec targetCol = pmmlPort.getSpec().getTargetCols().iterator().next();
    // The same factory serves both modes; only the regression flag differs.
    final boolean regression;
    if (m_mlp.getMode() == MultiLayerPerceptron.REGRESSION_MODE) {
        regression = true;
    } else if (m_mlp.getMode() == MultiLayerPerceptron.CLASSIFICATION_MODE) {
        regression = false;
    } else {
        throw new IllegalStateException("Unsupported Mode: " + m_mlp.getMode());
    }
    MLPClassificationFactory mymlp = new MLPClassificationFactory(regression, m_columns, targetCol);
    ColumnRearranger colre = new ColumnRearranger(testdata.getDataTableSpec());
    colre.append(mymlp);
    BufferedDataTable bdt = exec.createColumnRearrangeTable(testdata, colre, exec);
    return new BufferedDataTable[] {bdt};
}
Usage of `org.knime.core.data.container.ColumnRearranger` in project knime-core by KNIME.
Class `RegressionPredictorNodeModel`, method `createRearranger`:
/**
 * Validates a general regression model against the input table and builds the column rearranger
 * that appends its prediction column(s).
 *
 * <p>Supported are {@code generalLinear} models with function {@code regression} and
 * {@code multinomialLogistic} models with function {@code classification}.
 *
 * <p>Column checks:
 * <ul>
 * <li>Every factor must be a nominal column of the input table.</li>
 * <li>Every covariate must be a numeric column. A column listed in the model's vector lengths may
 * instead be a collection of numbers, a bit vector, or a byte vector.</li>
 * <li>A covariate named {@code name[i]} that is not itself a column is resolved to the column
 * {@code name}.</li>
 * </ul>
 *
 * @param content the general regression model content; may be {@code null}
 * @param pmmlSpec the PMML port spec; its first target field is used for the prediction
 * @param inDataSpec the spec of the table to predict
 * @return a rearranger appending a {@code LinReg2Predictor} or {@code LogRegPredictor}
 * @throws InvalidSettingsException if {@code content} is {@code null}, the model type/function is
 *     unsupported, or a model column is missing or has an incompatible type
 */
private ColumnRearranger createRearranger(final PMMLGeneralRegressionContent content, final PMMLPortObjectSpec pmmlSpec, final DataTableSpec inDataSpec) throws InvalidSettingsException {
    if (content == null) {
        throw new InvalidSettingsException("No input");
    }
    // the predictor can only predict linear and (multinomial) logistic regression models
    if (!(content.getModelType().equals(ModelType.multinomialLogistic) || content.getModelType().equals(ModelType.generalLinear))) {
        throw new InvalidSettingsException("Model Type: " + content.getModelType() + " is not supported.");
    }
    if (content.getModelType().equals(ModelType.generalLinear) && !content.getFunctionName().equals(FunctionName.regression)) {
        throw new InvalidSettingsException("Function Name: " + content.getFunctionName() + " is not supported for linear regression.");
    }
    if (content.getModelType().equals(ModelType.multinomialLogistic) && !content.getFunctionName().equals(FunctionName.classification)) {
        throw new InvalidSettingsException("Function Name: " + content.getFunctionName() + " is not supported for logistic regression.");
    }
    // check if all factors are in the given data table and that they
    // are nominal values
    for (PMMLPredictor factor : content.getFactorList()) {
        DataColumnSpec columnSpec = inDataSpec.getColumnSpec(factor.getName());
        if (null == columnSpec) {
            throw new InvalidSettingsException("The column \"" + factor.getName() + "\" is in the model but not in given table.");
        }
        if (!columnSpec.getType().isCompatible(NominalValue.class)) {
            throw new InvalidSettingsException("The column \"" + factor.getName() + "\" is supposed to be nominal.");
        }
    }
    // check if all covariates are in the given data table and that they
    // are numeric values (or, for vector columns, contain numeric values)
    Pattern pattern = Pattern.compile("(.*)\\[\\d+\\]");
    for (PMMLPredictor covariate : content.getCovariateList()) {
        DataColumnSpec columnSpec = inDataSpec.getColumnSpec(covariate.getName());
        if (null == columnSpec) {
            // The covariate may name one element "name[i]" of a vector column "name".
            Matcher matcher = pattern.matcher(covariate.getName());
            // Matcher.group(int) throws IllegalStateException unless a match succeeded,
            // so only look up the base column name after matches() returned true.
            if (matcher.matches()) {
                columnSpec = inDataSpec.getColumnSpec(matcher.group(1));
            }
            if (null == columnSpec) {
                throw new InvalidSettingsException("The column \"" + covariate.getName() + "\" is in the model but not in given table.");
            }
        }
        // columnSpec is non-null here: every path that leaves it null has thrown above.
        if (!columnSpec.getType().isCompatible(DoubleValue.class)
            && !(content.getVectorLengths().containsKey(columnSpec.getName())
                && ((columnSpec.getType().isCollectionType()
                    && columnSpec.getType().getCollectionElementType().isCompatible(DoubleValue.class))
                    || columnSpec.getType().isCompatible(BitVectorValue.class)
                    || columnSpec.getType().isCompatible(ByteVectorValue.class)))) {
            throw new InvalidSettingsException("The column \"" + covariate.getName() + "\" is supposed to be numeric.");
        }
    }
    ColumnRearranger c = new ColumnRearranger(inDataSpec);
    if (content.getModelType().equals(ModelType.generalLinear)) {
        c.append(new LinReg2Predictor(content, inDataSpec, pmmlSpec, pmmlSpec.getTargetFields().get(0), m_settings));
    } else {
        c.append(new LogRegPredictor(content, inDataSpec, pmmlSpec, pmmlSpec.getTargetFields().get(0), m_settings));
    }
    return c;
}
Usage of `org.knime.core.data.container.ColumnRearranger` in project knime-core by KNIME.
Class `RegressionPredictorNodeModel`, method `execute`:
/**
 * Predicts the table at port 1 with the general regression model at port 0.
 *
 * <p>If the PMML document holds no {@code GeneralRegressionModel}, a warning is logged and the
 * deprecated legacy regression predictor is run instead. When that predictor yields a data table,
 * the name of its prediction column (the table's last column) is replaced via
 * {@code adjustSpecOfRegressionPredictorTable}. Any other legacy output is returned unchanged.
 *
 * @param inData port 0: the PMML model, port 1: the table to predict
 * @param exec execution context used to build the output table
 * @return a single-element array with the input table plus the appended prediction column(s), or
 *     the legacy predictor's output
 * @throws Exception if validation or prediction fails
 */
@Override
public PortObject[] execute(final PortObject[] inData, final ExecutionContext exec) throws Exception {
    final PMMLPortObject pmmlPort = (PMMLPortObject) inData[0];
    final boolean noRegressionModel =
        pmmlPort.getPMMLValue().getModels(PMMLModelType.GeneralRegressionModel).isEmpty();
    if (noRegressionModel) {
        LOGGER.warn("No regression models in the input PMML.");
        @SuppressWarnings("deprecation")
        final org.knime.base.node.mine.regression.predict.RegressionPredictorNodeModel legacyPredictor =
            new org.knime.base.node.mine.regression.predict.RegressionPredictorNodeModel();
        @SuppressWarnings("deprecation")
        final PortObject[] legacyResult = legacyPredictor.execute(inData, exec);
        if (legacyResult.length == 0 || !(legacyResult[0] instanceof BufferedDataTable)) {
            return legacyResult;
        }
        // rename the prediction column, which is the last column of the legacy table
        final BufferedDataTable legacyTable = (BufferedDataTable) legacyResult[0];
        return new PortObject[] {adjustSpecOfRegressionPredictorTable(legacyTable, inData, exec)};
    }

    final PMMLGeneralRegressionTranslator translator = new PMMLGeneralRegressionTranslator();
    pmmlPort.initializeModelTranslator(translator);
    final BufferedDataTable table = (BufferedDataTable) inData[1];
    final DataTableSpec tableSpec = table.getDataTableSpec();
    final ColumnRearranger rearranger =
        createRearranger(translator.getContent(), pmmlPort.getSpec(), tableSpec);
    return new BufferedDataTable[] {exec.createColumnRearrangeTable(table, rearranger, exec)};
}
Usage of `org.knime.core.data.container.ColumnRearranger` in project knime-core by KNIME.
Class `RegressionPredictorNodeModel`, method `configure`:
/**
 * Computes the output spec: the data spec with the regression prediction column(s) appended.
 *
 * <p>Warns if "append probabilities" is enabled although the first target column is numeric.
 *
 * @param inSpecs port 0: the PMML model spec, port 1: the data table spec
 * @return the output spec, or {@code null} when
 *     {@code RegressionPredictorCellFactory.createColumnSpec} yields {@code null}
 *     (NOTE(review): presumably meaning the output spec cannot be determined yet — confirm)
 * @throws InvalidSettingsException if an input spec is missing or the model defines no target
 *     column
 */
@Override
protected PortObjectSpec[] configure(final PortObjectSpec[] inSpecs) throws InvalidSettingsException {
    PMMLPortObjectSpec regModelSpec = (PMMLPortObjectSpec) inSpecs[0];
    DataTableSpec dataSpec = (DataTableSpec) inSpecs[1];
    if (dataSpec == null || regModelSpec == null) {
        throw new InvalidSettingsException("No input specification available");
    }
    // Report a missing target as a configuration problem instead of an IndexOutOfBoundsException.
    if (regModelSpec.getTargetCols().isEmpty()) {
        throw new InvalidSettingsException("The PMML model does not define a target column.");
    }
    if (regModelSpec.getTargetCols().get(0).getType().isCompatible(DoubleValue.class) && m_settings.getIncludeProbabilities()) {
        setWarningMessage("The option \"Append columns with predicted probabilities\" has only an effect for nominal targets");
    }
    if (null == RegressionPredictorCellFactory.createColumnSpec(regModelSpec, dataSpec, m_settings)) {
        return null;
    }
    ColumnRearranger c = new ColumnRearranger(dataSpec);
    // Only the spec of the factory is needed here; the anonymous subclass stubs out getCells.
    c.append(new RegressionPredictorCellFactory(regModelSpec, dataSpec, m_settings) {
        @Override
        public DataCell[] getCells(final DataRow row) {
            // not called during configure
            return null;
        }
    });
    return new DataTableSpec[] {c.createSpec()};
}
Usage of `org.knime.core.data.container.ColumnRearranger` in project knime-core by KNIME.
Class `BitVectorGeneratorNodeModel`, method `configure`:
/**
 * Validates the settings against the input spec and computes the output spec.
 *
 * <p>String mode ({@code m_fromString}): the selected column must exist and be a string column.
 * The output spec is the one produced by {@code createColumnRearranger}.
 *
 * <p>Numeric mode: all included columns must exist and be numeric. The output is one column of
 * type BitVector, appended to the input spec. If {@code m_replace} is set, the included columns are
 * removed first.
 *
 * @param inSpecs the input table specs; only {@code inSpecs[0]} is used
 * @return the single output table spec
 * @throws InvalidSettingsException if the selected columns are missing or have the wrong type
 */
@Override
protected DataTableSpec[] configure(final DataTableSpec[] inSpecs) throws InvalidSettingsException {
    DataTableSpec spec = inSpecs[0];
    if (m_fromString) {
        // parse from string column
        String stringColName = m_stringColumn.getStringValue();
        if (stringColName == null) {
            throw new InvalidSettingsException("No string column selected. " + "Please (re-)configure the node.");
        }
        if (!spec.containsName(stringColName)) {
            throw new InvalidSettingsException("Selected string column " + stringColName + " not in the input table");
        }
        // Distinguish "wrong type" from "missing" so the message is not misleading.
        if (!spec.getColumnSpec(stringColName).getType().isCompatible(StringValue.class)) {
            throw new InvalidSettingsException("Selected column " + stringColName + " is not a string column");
        }
        int stringColIdx = spec.findColumnIndex(stringColName);
        ColumnRearranger c = createColumnRearranger(spec, stringColIdx);
        return new DataTableSpec[] {c.createSpec()};
    }

    // numeric input
    if (m_includedColumns.isEnabled() && m_includedColumns.getIncludeList().isEmpty()) {
        // the includeColumns model cannot be empty
        // through the dialog (see #validateSettings)
        // only case where !m_fromString and includeColumns evaluates
        // to true is for old workflows.
        // For backward compatibility include all numeric columns
        // which was the behavior before 2.0
        List<String> allNumericColumns = new ArrayList<String>();
        for (DataColumnSpec colSpec : spec) {
            if (colSpec.getType().isCompatible(DoubleValue.class)) {
                allNumericColumns.add(colSpec.getName());
            }
        }
        m_includedColumns.setIncludeList(allNumericColumns);
        m_loadedSettingsDontHaveIncludeColumns = false;
    }
    // every included column must exist and be numeric
    for (String inclColName : m_includedColumns.getIncludeList()) {
        DataColumnSpec colSpec = spec.getColumnSpec(inclColName);
        if (colSpec == null) {
            throw new InvalidSettingsException("Column " + inclColName + " not found in input table. " + "Please re-configure the node.");
        }
        if (!colSpec.getType().isCompatible(DoubleValue.class)) {
            throw new InvalidSettingsException("Column " + inclColName + " is not a numeric column");
        }
    }
    DataTableSpec newSpec;
    DataColumnSpec newColSpec = createNumericOutputSpec(spec);
    if (m_replace) {
        ColumnRearranger colR = new ColumnRearranger(spec);
        colR.remove(m_includedColumns.getIncludeList().toArray(new String[0]));
        newSpec = new DataTableSpec(colR.createSpec(), new DataTableSpec(newColSpec));
    } else {
        newSpec = new DataTableSpec(spec, new DataTableSpec(newColSpec));
    }
    return new DataTableSpec[] {newSpec};
}
Aggregations