use of org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObjectSpec in project knime-core by knime.
the class GradientBoostingClassificationPredictorNodeModel method execute.
/**
 * Applies the gradient boosted trees classification model to the input data table.
 *
 * @param inObjects port 0 holds the {@link GradientBoostingModelPortObject}, port 1 the
 *            {@link BufferedDataTable} whose rows are predicted
 * @param exec execution context used to create the output table and to report progress
 * @return a single-element array holding the input table with the prediction column(s) appended
 * @throws IllegalStateException if the predictor cannot create a column rearranger (e.g. class confidences are
 *             requested but the possible values of the target column are unknown)
 * @throws Exception if creating the output table fails or execution is canceled
 */
@Override
protected PortObject[] execute(final PortObject[] inObjects, final ExecutionContext exec) throws Exception {
    final GradientBoostingModelPortObject model = (GradientBoostingModelPortObject) inObjects[0];
    final TreeEnsembleModelPortObjectSpec modelSpec = model.getSpec();
    final BufferedDataTable data = (BufferedDataTable) inObjects[1];
    final DataTableSpec dataSpec = data.getDataTableSpec();
    // this node predicts classes, so the ensemble is expected to be a multi-class model
    final GradientBoostingPredictor<MultiClassGradientBoostedTreesModel> predictor =
        new GradientBoostingPredictor<>((MultiClassGradientBoostedTreesModel) model.getEnsembleModel(), modelSpec,
            dataSpec, m_configuration);
    final ColumnRearranger rearranger = predictor.getPredictionRearranger();
    if (rearranger == null) {
        // The predictor yields no rearranger when class confidence columns are requested but the target's
        // possible values are unknown; fail with a clear message instead of passing null on.
        throw new IllegalStateException("Cannot append class confidence columns: "
            + "the possible values of the target column are unknown.");
    }
    final BufferedDataTable outTable = exec.createColumnRearrangeTable(data, rearranger, exec);
    return new BufferedDataTable[]{outTable};
}
use of org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObjectSpec in project knime-core by knime.
the class GradientBoostingPredictor method createPredictionRearranger.
/**
 * Builds the column rearranger that appends the prediction column(s) to the input table.
 * <p>
 * A target column whose type is compatible with {@link DoubleValue} is handled by the regression cell factory;
 * any other target is handled by the classification cell factory.
 *
 * @return the rearranger, or {@code null} if class confidence columns were requested but the model spec does not
 *         provide the possible values of the target column
 * @throws InvalidSettingsException as declared by the overridden method
 */
@Override
protected ColumnRearranger createPredictionRearranger() throws InvalidSettingsException {
    final TreeEnsembleModelPortObjectSpec spec = getModelSpec();
    final DataTableSpec inputSpec = getDataSpec();
    final boolean targetValuesKnown = spec.getTargetColumnPossibleValueMap() != null;
    final boolean numericTarget = spec.getTargetColumn().getType().isCompatible(DoubleValue.class);

    if (numericTarget) {
        final ColumnRearranger regressionRearranger = new ColumnRearranger(inputSpec);
        // numeric target: this predictor is assumed to wrap a single-output regression model
        @SuppressWarnings("unchecked")
        final GradientBoostingPredictor<GradientBoostedTreesModel> regressionPredictor =
            (GradientBoostingPredictor<GradientBoostedTreesModel>) this;
        regressionRearranger.append(GradientBoostingPredictorCellFactory.createFactory(regressionPredictor));
        return regressionRearranger;
    }

    if (getConfiguration().isAppendClassConfidences() && !targetValuesKnown) {
        // confidence columns need the class labels up front, which are not available here
        return null;
    }

    final ColumnRearranger classificationRearranger = new ColumnRearranger(inputSpec);
    // non-numeric target: this predictor is assumed to wrap a multi-class model
    @SuppressWarnings("unchecked")
    final GradientBoostingPredictor<MultiClassGradientBoostedTreesModel> classificationPredictor =
        (GradientBoostingPredictor<MultiClassGradientBoostedTreesModel>) this;
    classificationRearranger.append(LKGradientBoostingPredictorCellFactory.createFactory(classificationPredictor));
    return classificationRearranger;
}
use of org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObjectSpec in project knime-core by knime.
the class GradientBoostingPMMLPredictorNodeModel method configure.
/**
 * Validates the incoming PMML model against this node's mode (regression or classification), initializes the
 * predictor configuration and computes the output table spec.
 *
 * @param inSpecs port 0 holds the {@link PMMLPortObjectSpec}, port 1 the {@link DataTableSpec} of the data to
 *            predict
 * @return a single-element array holding the output table spec; the element is {@code null} if the spec cannot
 *         be determined before execution (class confidences requested but target values unknown)
 * @throws InvalidSettingsException if the PMML model type does not match this node or the PMML is not supported
 */
@Override
protected PortObjectSpec[] configure(final PortObjectSpec[] inSpecs) throws InvalidSettingsException {
    final PMMLPortObjectSpec pmmlSpec = (PMMLPortObjectSpec) inSpecs[0];
    final DataType targetType = extractTargetType(pmmlSpec);
    if (m_isRegression && !targetType.isCompatible(DoubleValue.class)) {
        throw new InvalidSettingsException("This node expects a regression model.");
    } else if (!m_isRegression && !targetType.isCompatible(StringValue.class)) {
        throw new InvalidSettingsException("This node expects a classification model.");
    }
    final TreeEnsembleModelPortObjectSpec modelSpec;
    try {
        AbstractTreeModelPMMLTranslator.checkPMMLSpec(pmmlSpec);
        // translateSpec validates the target declaration via CheckUtils, which throws IllegalArgumentException
        modelSpec = translateSpec(pmmlSpec);
    } catch (IllegalArgumentException e) {
        throw new InvalidSettingsException(e.getMessage(), e);
    }
    final String targetColName = modelSpec.getTargetColumn().getName();
    if (m_configuration == null) {
        m_configuration = TreeEnsemblePredictorConfiguration.createDefault(m_isRegression, targetColName);
    } else if (!m_configuration.isChangePredictionColumnName()) {
        // keep the default prediction column name in sync with the current target column
        m_configuration.setPredictionColumnName(TreeEnsemblePredictorConfiguration.getPredictColumnName(targetColName));
    }
    modelSpec.assertTargetTypeMatches(m_isRegression);
    final DataTableSpec dataSpec = (DataTableSpec) inSpecs[1];
    // no model is available at configure time; only the specs are needed to derive the output spec
    final GradientBoostingPredictor<GradientBoostedTreesModel> pred =
        new GradientBoostingPredictor<>(null, modelSpec, dataSpec, m_configuration);
    final ColumnRearranger rearranger = pred.getPredictionRearranger();
    if (rearranger == null) {
        // class confidence columns requested but the target's possible values are unknown:
        // the output spec can only be determined during execution
        return new PortObjectSpec[]{null};
    }
    return new PortObjectSpec[]{rearranger.createSpec()};
}
use of org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObjectSpec in project knime-core by knime.
the class GradientBoostingPMMLPredictorNodeModel method importModel.
/**
 * Translates the PMML port object into a gradient boosting model port object.
 * <p>
 * A target compatible with {@link DoubleValue} selects the regression translator; a target compatible with
 * {@link StringValue} selects the classification translator. Any translator warning is forwarded as the node's
 * warning message.
 *
 * @param pmmlPO the PMML port object holding the gradient boosted trees model
 * @return a model port object built from the translator's learn spec and model
 * @throws IllegalArgumentException if the target column type is neither numeric nor string
 */
@SuppressWarnings("unchecked")
private GradientBoostingModelPortObject importModel(final PMMLPortObject pmmlPO) {
    final AbstractGBTModelPMMLTranslator<M> pmmlTranslator;
    final DataType targetType = extractTargetType(pmmlPO.getSpec());
    // NOTE(review): the unchecked casts assume M matches the translator chosen by the target type — confirm
    if (targetType.isCompatible(DoubleValue.class)) {
        pmmlTranslator = (AbstractGBTModelPMMLTranslator<M>) new RegressionGBTModelPMMLTranslator();
    } else if (targetType.isCompatible(StringValue.class)) {
        pmmlTranslator = (AbstractGBTModelPMMLTranslator<M>) new ClassificationGBTModelPMMLTranslator();
    } else {
        throw new IllegalArgumentException("Currently only regression (numeric target) and classification "
            + "(string target) models are supported, but the target column has type " + targetType + ".");
    }
    pmmlPO.initializeModelTranslator(pmmlTranslator);
    if (pmmlTranslator.hasWarning()) {
        setWarningMessage(pmmlTranslator.getWarning());
    }
    return new GradientBoostingModelPortObject(new TreeEnsembleModelPortObjectSpec(pmmlTranslator.getLearnSpec()),
        pmmlTranslator.getGBTModel());
}
use of org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObjectSpec in project knime-core by knime.
the class GradientBoostingPMMLPredictorNodeModel method translateSpec.
/**
 * Creates the tree ensemble model spec from the PMML spec's data table spec, with the single PMML target column
 * moved to the end.
 *
 * @param pmmlSpec the PMML port object spec
 * @return the model spec built from the rearranged data table spec
 * @throws IllegalArgumentException if the PMML declares no target field or more than one
 */
private TreeEnsembleModelPortObjectSpec translateSpec(final PMMLPortObjectSpec pmmlSpec) {
    final List<DataColumnSpec> targets = pmmlSpec.getTargetCols();
    CheckUtils.checkArgument(!targets.isEmpty(), "The provided PMML does not declare a target field.");
    CheckUtils.checkArgument(targets.size() == 1,
        "The provided PMML declares multiple targets. This behavior is currently not supported.");
    final DataTableSpec pmmlDataSpec = pmmlSpec.getDataTableSpec();
    final ColumnRearranger cr = new ColumnRearranger(pmmlDataSpec);
    // moving to position getNumColumns() is intended to place the target after all other columns;
    // presumably TreeEnsembleModelPortObjectSpec expects the target as the last column — verify
    cr.move(targets.get(0).getName(), pmmlDataSpec.getNumColumns());
    return new TreeEnsembleModelPortObjectSpec(cr.createSpec());
}
Aggregations