Usage of org.knime.base.node.mine.treeensemble2.node.gradientboosting.predictor.GradientBoostingPredictor in the knime-core project (KNIME).
Class GradientBoostingPredictorCellFactory, method createFactory:
/**
 * Builds a cell factory that appends a double-valued prediction column to the predictor's input table.
 *
 * @param predictor the predictor supplying the model, its model spec, the input data spec and the
 *            configuration (which provides the requested prediction column name)
 * @return a new {@link GradientBoostingPredictorCellFactory} for the given predictor
 * @throws InvalidSettingsException if the factory cannot be set up for the given specs
 */
public static GradientBoostingPredictorCellFactory createFactory(final GradientBoostingPredictor<GradientBoostedTreesModel> predictor) throws InvalidSettingsException {
    final TreeEnsembleModelPortObjectSpec spec = predictor.getModelSpec();
    final DataTableSpec learnTableSpec = spec.getLearnTableSpec();
    final DataTableSpec inputSpec = predictor.getDataSpec();
    // Derive a name for the new column that is unique with respect to the input table's columns.
    final UniqueNameGenerator generator = new UniqueNameGenerator(inputSpec);
    final String requestedName = predictor.getConfiguration().getPredictionColumnName();
    final DataColumnSpec predictionColumn = generator.newColumn(requestedName, DoubleCell.TYPE);
    return new GradientBoostingPredictorCellFactory(predictionColumn, predictor.getModel(), learnTableSpec,
        spec.calculateFilterIndices(inputSpec));
}
Usage of org.knime.base.node.mine.treeensemble2.node.gradientboosting.predictor.GradientBoostingPredictor in the knime-core project (KNIME).
Class GradientBoostingPredictorNodeModel, method configure:
/**
 * {@inheritDoc}
 * <p>
 * Expects the tree ensemble model spec at port 0 and the input data table spec at port 1. Creates a
 * default configuration if none exists. Otherwise, if the prediction column name has not been
 * changed in the configuration, re-derives it from the model's target column. Returns the spec of
 * the table produced by the predictor's prediction rearranger.
 */
@Override
protected PortObjectSpec[] configure(final PortObjectSpec[] inSpecs) throws InvalidSettingsException {
    final TreeEnsembleModelPortObjectSpec modelSpec = (TreeEnsembleModelPortObjectSpec) inSpecs[0];
    final String targetColName = modelSpec.getTargetColumn().getName();
    if (m_configuration == null) {
        m_configuration = TreeEnsemblePredictorConfiguration.createDefault(false, targetColName);
    } else if (!m_configuration.isChangePredictionColumnName()) {
        // Presumably the user kept the default name, so keep it in sync with the
        // target column of the (possibly different) incoming model.
        m_configuration.setPredictionColumnName(TreeEnsemblePredictorConfiguration.getPredictColumnName(targetColName));
    }
    modelSpec.assertTargetTypeMatches(true);
    final DataTableSpec dataSpec = (DataTableSpec) inSpecs[1];
    // The first constructor argument is null here; presumably no model is available at configure
    // time, and only the output spec is derived. Use a wildcard rather than the raw type.
    final GradientBoostingPredictor<?> pred = new GradientBoostingPredictor<>(null, modelSpec, dataSpec, m_configuration);
    return new PortObjectSpec[] {pred.getPredictionRearranger().createSpec()};
}
Aggregations