use of org.knime.core.data.DataType in project knime-core by knime.
the class RegressionTreePMMLTranslatorNodeModel method containsVector.
/**
 * Reports whether at least one column of the given spec holds vector data, i.e. whether its type is compatible
 * with {@link BitVectorValue}, {@link DoubleVectorValue} or {@link ByteVectorValue}.
 *
 * @param learnFeatureSpec the spec whose columns are inspected
 * @return {@code true} as soon as a vector-typed column is found, {@code false} if there is none
 */
private static boolean containsVector(final DataTableSpec learnFeatureSpec) {
    boolean vectorFound = false;
    for (final DataColumnSpec column : learnFeatureSpec) {
        final DataType columnType = column.getType();
        if (columnType.isCompatible(BitVectorValue.class)
            || columnType.isCompatible(DoubleVectorValue.class)
            || columnType.isCompatible(ByteVectorValue.class)) {
            vectorFound = true;
            break;
        }
    }
    return vectorFound;
}
use of org.knime.core.data.DataType in project knime-core by knime.
the class GradientBoostingPMMLPredictorNodeModel method configure.
/**
 * {@inheritDoc}
 * <p>
 * Checks that the PMML model on port 0 is of the kind this node expects. For regression, the target type must be
 * compatible with {@link DoubleValue}; for classification, it must be compatible with {@link StringValue}.
 * It then creates the predictor configuration or updates its prediction column name, and derives the output spec
 * from the data table spec on port 1.
 */
@Override
protected PortObjectSpec[] configure(final PortObjectSpec[] inSpecs) throws InvalidSettingsException {
    final PMMLPortObjectSpec pmmlSpec = (PMMLPortObjectSpec) inSpecs[0];
    final DataType targetType = extractTargetType(pmmlSpec);
    if (m_isRegression && !targetType.isCompatible(DoubleValue.class)) {
        throw new InvalidSettingsException("This node expects a regression model.");
    } else if (!m_isRegression && !targetType.isCompatible(StringValue.class)) {
        throw new InvalidSettingsException("This node expects a classification model.");
    }
    try {
        AbstractTreeModelPMMLTranslator.checkPMMLSpec(pmmlSpec);
    } catch (IllegalArgumentException e) {
        // keep the original exception as cause so its stack trace is not lost
        throw new InvalidSettingsException(e.getMessage(), e);
    }
    final TreeEnsembleModelPortObjectSpec modelSpec = translateSpec(pmmlSpec);
    // check the model's target type before m_configuration is created or modified
    modelSpec.assertTargetTypeMatches(m_isRegression);
    final String targetColName = modelSpec.getTargetColumn().getName();
    if (m_configuration == null) {
        m_configuration = TreeEnsemblePredictorConfiguration.createDefault(m_isRegression, targetColName);
    } else if (!m_configuration.isChangePredictionColumnName()) {
        // the user did not customize the prediction column name, so keep it in sync with the target column
        m_configuration.setPredictionColumnName(TreeEnsemblePredictorConfiguration.getPredictColumnName(targetColName));
    }
    final DataTableSpec dataSpec = (DataTableSpec) inSpecs[1];
    final GradientBoostingPredictor<GradientBoostedTreesModel> pred =
        new GradientBoostingPredictor<>(null, modelSpec, dataSpec, m_configuration);
    return new PortObjectSpec[]{pred.getPredictionRearranger().createSpec()};
}
use of org.knime.core.data.DataType in project knime-core by knime.
the class GradientBoostingPMMLPredictorNodeModel method importModel.
/**
 * Converts the PMML model into a gradient boosting model port object.
 * <p>
 * The translator is chosen by the type of the PMML target column:
 * <ul>
 * <li>{@link DoubleValue}-compatible targets use the regression translator;</li>
 * <li>{@link StringValue}-compatible targets use the classification translator.</li>
 * </ul>
 * A warning reported by the translator is forwarded as the node's warning message.
 *
 * @param pmmlPO the PMML port object holding the model
 * @return the imported model together with its learn spec
 * @throws IllegalArgumentException if the target type is compatible with neither {@link DoubleValue} nor
 *             {@link StringValue}
 */
@SuppressWarnings("unchecked")
private GradientBoostingModelPortObject importModel(final PMMLPortObject pmmlPO) {
    final AbstractGBTModelPMMLTranslator<M> pmmlTranslator;
    final DataType targetType = extractTargetType(pmmlPO.getSpec());
    if (targetType.isCompatible(DoubleValue.class)) {
        pmmlTranslator = (AbstractGBTModelPMMLTranslator<M>) new RegressionGBTModelPMMLTranslator();
    } else if (targetType.isCompatible(StringValue.class)) {
        pmmlTranslator = (AbstractGBTModelPMMLTranslator<M>) new ClassificationGBTModelPMMLTranslator();
    } else {
        // both regression and classification are handled above; name the offending type instead
        throw new IllegalArgumentException("Unsupported target type \"" + targetType
            + "\": only regression (numeric) and classification (string) models are supported.");
    }
    pmmlPO.initializeModelTranslator(pmmlTranslator);
    if (pmmlTranslator.hasWarning()) {
        setWarningMessage(pmmlTranslator.getWarning());
    }
    return new GradientBoostingModelPortObject(new TreeEnsembleModelPortObjectSpec(pmmlTranslator.getLearnSpec()),
        pmmlTranslator.getGBTModel());
}
use of org.knime.core.data.DataType in project knime-core by knime.
the class TreeEnsembleLearnerConfiguration method checkColumnSelection.
/**
 * Sanity check of the column selection, meant to be called from the configure method of the learner nodes.
 * <p>
 * If a fingerprint column is configured, it must be present in {@code inSpec} and compatible with a bit, byte or
 * double vector type. Otherwise, the column filter must include at least one attribute.
 *
 * @param inSpec Spec of the incoming table
 * @throws InvalidSettingsException thrown if the column selection makes no sense
 */
public void checkColumnSelection(final DataTableSpec inSpec) throws InvalidSettingsException {
    final FilterResult selection = m_columnFilterConfig.applyTo(inSpec);
    if (m_fingerprintColumn == null) {
        // ordinary attributes are used for learning, so at least one must be selected
        if (selection.getIncludes().length == 0) {
            throw new InvalidSettingsException("No attributes are selected.");
        }
        return;
    }
    final DataColumnSpec fingerprintSpec = inSpec.getColumnSpec(m_fingerprintColumn);
    if (fingerprintSpec == null) {
        throw new InvalidSettingsException("The fingerprint column is not contained in the incoming table.");
    }
    final DataType fingerprintType = fingerprintSpec.getType();
    final boolean vectorTyped = fingerprintType.isCompatible(BitVectorValue.class)
        || fingerprintType.isCompatible(ByteVectorValue.class)
        || fingerprintType.isCompatible(DoubleVectorValue.class);
    if (!vectorTyped) {
        throw new InvalidSettingsException("The specified fingerprint column is not of a compatible vector type.");
    }
}
use of org.knime.core.data.DataType in project knime-core by knime.
the class TreeEnsembleLearnerConfiguration method filterLearnColumns.
/**
 * Creates a rearranger that keeps only the columns used for learning, with the target column moved to the last
 * position.
 * <p>
 * Without a fingerprint column, the attribute columns selected by the column filter are kept. Each selected column
 * must be compatible with {@link DoubleValue} or {@link NominalValue}. Selected columns for which
 * {@code shouldIgnoreLearnColumn} returns {@code true} are dropped and listed in the rearranger's warning.
 * With a fingerprint column, only the target and the fingerprint column (bit, byte or double vector) are kept.
 *
 * @param spec spec of the incoming table
 * @return rearranger that filters out all columns not part of the learning columns
 * @throws InvalidSettingsException if any of the following holds:
 *             <ul>
 *             <li>the target column is not set, does not exist or has the wrong type;</li>
 *             <li>a selected attribute has an unsupported type or is not present in the table;</li>
 *             <li>no valid learning column remains;</li>
 *             <li>the fingerprint column does not exist or is not a vector column.</li>
 *             </ul>
 */
public FilterLearnColumnRearranger filterLearnColumns(final DataTableSpec spec) throws InvalidSettingsException {
    // (ColumnRearranger is a final class in v2.5)
    if (m_targetColumn == null) {
        throw new InvalidSettingsException("Target column not set");
    }
    final DataColumnSpec targetCol = spec.getColumnSpec(m_targetColumn);
    if (targetCol == null || !targetCol.getType().isCompatible(getRequiredTargetClass())) {
        throw new InvalidSettingsException("Target column \"" + m_targetColumn + "\" does not exist or is not of the "
            + "correct type");
    }
    final FilterResult filterResult = m_columnFilterConfig.applyTo(spec);
    final List<String> noDomainColumns = new ArrayList<>();
    final FilterLearnColumnRearranger rearranger = new FilterLearnColumnRearranger(spec);
    if (m_fingerprintColumn == null) {
        keepSelectedAttributeColumns(spec, filterResult, rearranger, noDomainColumns);
    } else {
        keepTargetAndFingerprintColumn(spec, rearranger);
    }
    rearranger.move(m_targetColumn, rearranger.getColumnCount());
    rearranger.setWarning(createNoDomainWarning(noDomainColumns));
    return rearranger;
}

/**
 * Removes from {@code rearranger} every non-target column that is not a valid selected attribute.
 * Selected columns ignored by {@code shouldIgnoreLearnColumn} are appended to {@code noDomainColumns}.
 */
private void keepSelectedAttributeColumns(final DataTableSpec spec, final FilterResult filterResult,
    final FilterLearnColumnRearranger rearranger, final List<String> noDomainColumns)
    throws InvalidSettingsException {
    // insertion-ordered so that the "not present" message lists columns in filter order, deterministically
    final Set<String> incl = new java.util.LinkedHashSet<>(Arrays.asList(filterResult.getIncludes()));
    // the target column can possibly show up in the include list of the filter result, therefore remove it
    incl.remove(m_targetColumn);
    for (final DataColumnSpec col : spec) {
        final String colName = col.getName();
        if (colName.equals(m_targetColumn)) {
            continue;
        }
        final boolean ignoreColumn;
        if (incl.remove(colName)) {
            // selected: accept unless type mismatch or shouldIgnoreLearnColumn says otherwise
            final DataType type = col.getType();
            if (!(type.isCompatible(DoubleValue.class) || type.isCompatible(NominalValue.class))) {
                throw new InvalidSettingsException("Attribute column \"" + colName + "\" is "
                    + "not of the expected type (must be numeric or nominal).");
            }
            ignoreColumn = shouldIgnoreLearnColumn(col);
            if (ignoreColumn) {
                noDomainColumns.add(colName);
            }
        } else {
            // not selected by the column filter
            ignoreColumn = true;
        }
        if (ignoreColumn) {
            rearranger.remove(colName);
        }
    }
    // whatever is left in incl was selected but does not exist in the table
    if (!incl.isEmpty()) {
        final StringBuilder missings = new StringBuilder();
        int shown = 0;
        for (final String name : incl) {
            if (shown == 4) {
                missings.append(", ... ").append(incl.size() - shown).append(" more");
                break;
            }
            if (shown > 0) {
                missings.append(", ");
            }
            missings.append(name);
            shown++;
        }
        throw new InvalidSettingsException("Some selected attributes are not present in the input table: "
            + missings);
    }
    // only the target column is left: there is nothing to learn from
    if (rearranger.getColumnCount() <= 1) {
        final StringBuilder b = new StringBuilder("Input table has no valid learning columns "
            + "(need one additional numeric or nominal column).");
        if (!noDomainColumns.isEmpty()) {
            b.append(" ").append(noDomainColumns.size());
            b.append(" column(s) were ignored due to missing domain ");
            b.append("information -- execute predecessor and/or ");
            b.append("use Domain Calculator node.");
        }
        throw new InvalidSettingsException(b.toString());
    }
}

/** Checks the configured fingerprint column and keeps only it and the target column in {@code rearranger}. */
private void keepTargetAndFingerprintColumn(final DataTableSpec spec, final FilterLearnColumnRearranger rearranger)
    throws InvalidSettingsException {
    final DataColumnSpec fpCol = spec.getColumnSpec(m_fingerprintColumn);
    if (fpCol == null || !(fpCol.getType().isCompatible(BitVectorValue.class)
        || fpCol.getType().isCompatible(ByteVectorValue.class)
        || fpCol.getType().isCompatible(DoubleVectorValue.class))) {
        throw new InvalidSettingsException("Fingerprint column \"" + m_fingerprintColumn + "\" does not exist or is "
            + "not of correct type.");
    }
    rearranger.keepOnly(m_targetColumn, m_fingerprintColumn);
}

/**
 * Builds the warning for columns ignored due to missing domain information (at most four names are listed), or
 * returns {@code null} if there are none.
 */
private static String createNoDomainWarning(final List<String> noDomainColumns) {
    if (noDomainColumns.isEmpty()) {
        return null;
    }
    final StringBuilder b = new StringBuilder();
    b.append(noDomainColumns.size());
    b.append(" column(s) were ignored due to missing domain");
    b.append(" information: [");
    int index = 0;
    for (final String s : noDomainColumns) {
        if (index > 3) {
            b.append(", ...");
            break;
        }
        if (index > 0) {
            b.append(", ");
        }
        b.append("\"").append(s).append("\"");
        index++;
    }
    b.append("] -- change the node configuration or use a");
    b.append(" Domain Calculator node to fix it");
    return b.toString();
}
Aggregations