Use of org.knime.core.data.DataColumnSpec in project knime-core by knime.
In class MLPPredictorNodeModel, method execute.
/**
 * {@inheritDoc}
 * <p>
 * Reads the neural network (MLP) from the PMML model at port 0, applies it to the table at
 * port 1 and returns that table with the column(s) produced by an
 * {@code MLPClassificationFactory} appended. Only the first target column of the model is used.
 *
 * @throws RuntimeException if the PMML contains no neural network model or declares no target
 *         column (the reason is also logged)
 * @throws Exception if the network's mode is neither regression nor classification
 */
@Override
public PortObject[] execute(final PortObject[] inData, final ExecutionContext exec) throws Exception {
    BufferedDataTable testdata = (BufferedDataTable) inData[1];
    PMMLPortObject pmmlPort = (PMMLPortObject) inData[0];
    List<Node> models = pmmlPort.getPMMLValue().getModels(PMMLModelType.NeuralNetwork);
    if (models.isEmpty()) {
        String msg = "Neural network evaluation failed: " + "No neural network model found.";
        LOGGER.error(msg);
        throw new RuntimeException(msg);
    }
    PMMLNeuralNetworkTranslator trans = new PMMLNeuralNetworkTranslator();
    pmmlPort.initializeModelTranslator(trans);
    m_mlp = trans.getMLP();
    m_columns = getLearningColumnIndices(testdata.getDataTableSpec(), pmmlPort.getSpec());
    // Without this guard a model lacking a target column fails with a bare
    // NoSuchElementException from iterator().next() below.
    if (pmmlPort.getSpec().getTargetCols().isEmpty()) {
        String msg = "Neural network evaluation failed: " + "No target column found in the model.";
        LOGGER.error(msg);
        throw new RuntimeException(msg);
    }
    DataColumnSpec targetCol = pmmlPort.getSpec().getTargetCols().iterator().next();
    // The factory's first argument is true for regression and false for classification.
    final boolean regression;
    if (m_mlp.getMode() == MultiLayerPerceptron.REGRESSION_MODE) {
        regression = true;
    } else if (m_mlp.getMode() == MultiLayerPerceptron.CLASSIFICATION_MODE) {
        regression = false;
    } else {
        throw new Exception("Unsupported Mode: " + m_mlp.getMode());
    }
    MLPClassificationFactory mymlp = new MLPClassificationFactory(regression, m_columns, targetCol);
    ColumnRearranger colre = new ColumnRearranger(testdata.getDataTableSpec());
    colre.append(mymlp);
    BufferedDataTable bdt = exec.createColumnRearrangeTable(testdata, colre, exec);
    return new BufferedDataTable[] { bdt };
}
Use of org.knime.core.data.DataColumnSpec in project knime-core by knime.
In class LinRegLearnerNodeDialogPane, method loadSettingsFrom.
/**
 * {@inheritDoc}
 * <p>
 * The dialog can only be opened when the input table has at least two
 * double-compatible columns. The target selection is populated first, and the
 * selected target is then hidden from the variate filter.
 */
@Override
protected void loadSettingsFrom(final NodeSettingsRO settings, final PortObjectSpec[] specs) throws NotConfigurableException {
    final DataTableSpec tableSpec = (DataTableSpec) specs[0];
    // count double-compatible columns, but stop as soon as two have been seen
    int numericCount = 0;
    for (DataColumnSpec colSpec : tableSpec) {
        if (!colSpec.getType().isCompatible(DoubleValue.class)) {
            continue;
        }
        numericCount++;
        if (numericCount == 2) {
            break;
        }
    }
    if (numericCount < 2) {
        throw new NotConfigurableException("Too few numeric columns " + "(need at least 2): " + numericCount);
    }

    final boolean useAllVariates = settings.getBoolean(LinRegLearnerNodeModel.CFG_VARIATES_USE_ALL, false);
    final String[] variates = settings.getStringArray(LinRegLearnerNodeModel.CFG_VARIATES, new String[0]);
    final String targetName = settings.getString(LinRegLearnerNodeModel.CFG_TARGET, null);
    final boolean calcError = settings.getBoolean(LinRegLearnerNodeModel.CFG_CALC_ERROR, true);
    final int firstRow = settings.getInt(LinRegLearnerNodeModel.CFG_FROMROW, 1);
    final int rowCount = settings.getInt(LinRegLearnerNodeModel.CFG_ROWCNT, 10000);

    m_selectionPanel.update(tableSpec, targetName);
    m_filterPanel.setKeepAllSelected(useAllVariates);
    // an empty stored include list means "put every column into the include list"
    m_filterPanel.update(tableSpec, variates.length == 0, variates);

    // The target must not appear in the filter panel. This is done after updating
    // the selection panel, because updating the filter panel first always leaves
    // the spec's first element in the exclude list.
    final String selectedTarget = m_selectionPanel.getSelectedColumn();
    if (selectedTarget != null) {
        m_filterPanel.hideColumns(tableSpec.getColumnSpec(selectedTarget));
    }

    m_isCalcErrorChecker.setSelected(calcError);
    m_firstSpinner.setValue(firstRow);
    m_countSpinner.setValue(rowCount);
}
Use of org.knime.core.data.DataColumnSpec in project knime-core by knime.
In class LinRegLearnerNodeModel, method computeIncludes.
/**
 * Determines the variate (learning) columns.
 * <p>
 * If {@code m_includeAll} is not set, a copy of the configured
 * {@code m_includes} array is returned. Otherwise the result is every
 * double-compatible column of {@code in}, in spec order, except the column
 * named by {@code m_target}.
 *
 * @param in spec providing the candidate columns
 * @return a new array containing the variate column names
 * @throws InvalidSettingsException if {@code m_includeAll} is set and no
 *         double-compatible learning column exists, or if it is not set and
 *         {@code m_includes} is {@code null}
 */
private String[] computeIncludes(final DataTableSpec in) throws InvalidSettingsException {
    if (!m_includeAll) {
        if (m_includes == null) {
            throw new InvalidSettingsException("No settings available");
        }
        return m_includes.clone();
    }
    final List<String> candidates = new ArrayList<String>();
    for (DataColumnSpec colSpec : in) {
        if (colSpec.getType().isCompatible(DoubleValue.class) && !colSpec.getName().equals(m_target)) {
            candidates.add(colSpec.getName());
        }
    }
    if (candidates.isEmpty()) {
        throw new InvalidSettingsException("No double-compatible " + "variables (learning columns) in input table");
    }
    return candidates.toArray(new String[candidates.size()]);
}
Use of org.knime.core.data.DataColumnSpec in project knime-core by knime.
In class RegressionPredictorNodeDialogPane, method loadSettingsFrom.
/**
 * {@inheritDoc}
 * <p>
 * If no custom prediction column name is stored, the text field is filled
 * with the name of the last column that
 * {@code RegressionPredictorCellFactory} generates for default settings,
 * provided that name can be computed for the given inputs.
 */
@Override
protected void loadSettingsFrom(final NodeSettingsRO settings, final PortObjectSpec[] specs) throws NotConfigurableException {
    final RegressionPredictorSettings loaded = new RegressionPredictorSettings();
    loaded.loadSettingsForDialog(settings);
    m_hasCustomPredictionName.setSelected(loaded.getHasCustomPredictionName());

    final PMMLPortObjectSpec modelSpec = (PMMLPortObjectSpec) specs[0];
    final DataTableSpec dataSpec = (DataTableSpec) specs[1];
    final String storedName = loaded.getCustomPredictionName();
    if (storedName == null) {
        try {
            // default settings are used on purpose: this yields the generated default name
            final DataColumnSpec[] generated = RegressionPredictorCellFactory.createColumnSpec(modelSpec, dataSpec, new RegressionPredictorSettings());
            m_customPredictionName.setText(generated[generated.length - 1].getName());
        } catch (InvalidSettingsException ignored) {
            // The default name cannot be derived from the current inputs. Still open the
            // dialog so the user gets a chance to define the settings.
        }
    } else {
        m_customPredictionName.setText(storedName);
    }

    m_includeProbs.setSelected(loaded.getIncludeProbabilities());
    m_probColumnSuffix.setText(loaded.getPropColumnSuffix());
    updateEnableState();
}
Use of org.knime.core.data.DataColumnSpec in project knime-core by knime.
In class RegressionPredictorNodeModel, method createRearranger.
/**
 * Validates the general regression model against the input table spec and
 * creates the rearranger that appends the predictor column(s).
 *
 * @param content the general regression content of the model; may be {@code null}
 * @param pmmlSpec the model spec; its first target field is the one predicted
 * @param inDataSpec the spec of the table to predict
 * @return a rearranger appending a {@code LinReg2Predictor} for general linear
 *         models or a {@code LogRegPredictor} for multinomial logistic models
 * @throws InvalidSettingsException if {@code content} is {@code null}, the model
 *         type or function name is unsupported, a factor or covariate column is
 *         missing from the table or has an incompatible type, or the model
 *         declares no target field
 */
private ColumnRearranger createRearranger(final PMMLGeneralRegressionContent content, final PMMLPortObjectSpec pmmlSpec, final DataTableSpec inDataSpec) throws InvalidSettingsException {
    if (content == null) {
        throw new InvalidSettingsException("No input");
    }
    // the predictor handles only general linear (regression) and
    // multinomial logistic (classification) models
    if (!(content.getModelType().equals(ModelType.multinomialLogistic) || content.getModelType().equals(ModelType.generalLinear))) {
        throw new InvalidSettingsException("Model Type: " + content.getModelType() + " is not supported.");
    }
    if (content.getModelType().equals(ModelType.generalLinear) && !content.getFunctionName().equals(FunctionName.regression)) {
        throw new InvalidSettingsException("Function Name: " + content.getFunctionName() + " is not supported for linear regression.");
    }
    if (content.getModelType().equals(ModelType.multinomialLogistic) && !content.getFunctionName().equals(FunctionName.classification)) {
        throw new InvalidSettingsException("Function Name: " + content.getFunctionName() + " is not supported for logistic regression.");
    }
    // check that all factors are in the given data table and that they
    // are nominal values
    for (PMMLPredictor factor : content.getFactorList()) {
        DataColumnSpec columnSpec = inDataSpec.getColumnSpec(factor.getName());
        if (null == columnSpec) {
            throw new InvalidSettingsException("The column \"" + factor.getName() + "\" is in the model but not in given table.");
        }
        if (!columnSpec.getType().isCompatible(NominalValue.class)) {
            throw new InvalidSettingsException("The column \"" + factor.getName() + "\" is supposed to be nominal.");
        }
    }
    // check that all covariates are in the given data table and that they
    // are numeric values (or vector columns listed in the model's vector lengths)
    Pattern pattern = Pattern.compile("(.*)\\[\\d+\\]");
    for (PMMLPredictor covariate : content.getCovariateList()) {
        DataColumnSpec columnSpec = inDataSpec.getColumnSpec(covariate.getName());
        if (null == columnSpec) {
            // A name like "col[3]" presumably refers to an element of column "col", so look up
            // the base name. group() may only be called after a successful match; otherwise it
            // throws IllegalStateException instead of the intended InvalidSettingsException.
            Matcher matcher = pattern.matcher(covariate.getName());
            if (matcher.matches()) {
                columnSpec = inDataSpec.getColumnSpec(matcher.group(1));
            }
            if (null == columnSpec) {
                throw new InvalidSettingsException("The column \"" + covariate.getName() + "\" is in the model but not in given table.");
            }
        }
        if (!columnSpec.getType().isCompatible(DoubleValue.class)) {
            // Non-double columns are accepted only when the model lists a vector length for
            // them and they hold doubles (collection) or are bit/byte vectors.
            // Short-circuit order matters: getCollectionElementType() is only queried for
            // collection types.
            boolean supportedVector = content.getVectorLengths().containsKey(columnSpec.getName())
                && ((columnSpec.getType().isCollectionType()
                        && columnSpec.getType().getCollectionElementType().isCompatible(DoubleValue.class))
                    || columnSpec.getType().isCompatible(BitVectorValue.class)
                    || columnSpec.getType().isCompatible(ByteVectorValue.class));
            if (!supportedVector) {
                throw new InvalidSettingsException("The column \"" + covariate.getName() + "\" is supposed to be numeric.");
            }
        }
    }
    // Without this guard, get(0) below fails with an IndexOutOfBoundsException.
    if (pmmlSpec.getTargetFields().isEmpty()) {
        throw new InvalidSettingsException("The model does not specify a target field.");
    }
    ColumnRearranger c = new ColumnRearranger(inDataSpec);
    if (content.getModelType().equals(ModelType.generalLinear)) {
        c.append(new LinReg2Predictor(content, inDataSpec, pmmlSpec, pmmlSpec.getTargetFields().get(0), m_settings));
    } else {
        c.append(new LogRegPredictor(content, inDataSpec, pmmlSpec, pmmlSpec.getTargetFields().get(0), m_settings));
    }
    return c;
}
Aggregations