use of org.knime.base.node.mine.treeensemble2.node.gradientboosting.predictor.GradientBoostingPredictor in project knime-core by knime.
the class GradientBoostingPredictorNodeModel method createStreamableOperator.
/**
 * {@inheritDoc}
 * <p>
 * The returned operator reads the gradient boosted trees model as a complete port object from
 * input port 0. It then streams the data through the column rearranger produced by a
 * {@link GradientBoostingPredictor}, from input port 1 to output port 0.
 */
@Override
public StreamableOperator createStreamableOperator(final PartitionInfo partitionInfo, final PortObjectSpec[] inSpecs) throws InvalidSettingsException {
    return new StreamableOperator() {
        @Override
        public void runFinal(final PortInput[] inputs, final PortOutput[] outputs, final ExecutionContext exec) throws Exception {
            // The model is delivered as a whole (non-streamed) port object.
            final PortObjectInput modelInput = (PortObjectInput) inputs[0];
            final GradientBoostingModelPortObject gbModel = (GradientBoostingModelPortObject) modelInput.getPortObject();
            final TreeEnsembleModelPortObjectSpec gbModelSpec = gbModel.getSpec();
            final DataTableSpec tableSpec = (DataTableSpec) inSpecs[1];
            final GradientBoostedTreesModel treesModel = (GradientBoostedTreesModel) gbModel.getEnsembleModel();
            final GradientBoostingPredictor<GradientBoostedTreesModel> predictor =
                new GradientBoostingPredictor<>(treesModel, gbModelSpec, tableSpec, m_configuration);
            // Streamed data: input port index 1 -> output port index 0.
            predictor.getPredictionRearranger().createStreamableFunction(1, 0).runFinal(inputs, outputs, exec);
        }
    };
}
use of org.knime.base.node.mine.treeensemble2.node.gradientboosting.predictor.GradientBoostingPredictor in project knime-core by knime.
the class GradientBoostingPredictorNodeModel method execute.
/**
 * {@inheritDoc}
 * <p>
 * Appends the prediction columns computed by a {@link GradientBoostingPredictor} to the table at
 * input port 1. The predictor uses the gradient boosted trees model from input port 0.
 */
@Override
protected PortObject[] execute(final PortObject[] inObjects, final ExecutionContext exec) throws Exception {
    final GradientBoostingModelPortObject gbModel = (GradientBoostingModelPortObject) inObjects[0];
    final TreeEnsembleModelPortObjectSpec gbModelSpec = gbModel.getSpec();
    final BufferedDataTable inTable = (BufferedDataTable) inObjects[1];
    final DataTableSpec tableSpec = inTable.getDataTableSpec();
    final GradientBoostedTreesModel treesModel = (GradientBoostedTreesModel) gbModel.getEnsembleModel();
    final GradientBoostingPredictor<GradientBoostedTreesModel> predictor =
        new GradientBoostingPredictor<>(treesModel, gbModelSpec, tableSpec, m_configuration);
    final ColumnRearranger predictionRearranger = predictor.getPredictionRearranger();
    return new BufferedDataTable[] { exec.createColumnRearrangeTable(inTable, predictionRearranger, exec) };
}
use of org.knime.base.node.mine.treeensemble2.node.gradientboosting.predictor.GradientBoostingPredictor in project knime-core by knime.
the class GradientBoostingClassificationPredictorNodeModel method configure.
/**
 * {@inheritDoc}
 * <p>
 * Initializes or refreshes the predictor configuration from the model's target column. It then
 * checks the model's target type and derives the output table spec without a model instance.
 */
@Override
protected PortObjectSpec[] configure(final PortObjectSpec[] inSpecs) throws InvalidSettingsException {
    final TreeEnsembleModelPortObjectSpec gbModelSpec = (TreeEnsembleModelPortObjectSpec) inSpecs[0];
    final String targetName = gbModelSpec.getTargetColumn().getName();
    if (m_configuration != null) {
        if (!m_configuration.isChangePredictionColumnName()) {
            // The prediction column name is not user-overridden: keep it derived from the
            // current target column name.
            m_configuration.setPredictionColumnName(TreeEnsemblePredictorConfiguration.getPredictColumnName(targetName));
        }
    } else {
        // No configuration yet: start from the defaults for this target column.
        m_configuration = TreeEnsemblePredictorConfiguration.createDefault(false, targetName);
    }
    gbModelSpec.assertTargetTypeMatches(false);
    final DataTableSpec tableSpec = (DataTableSpec) inSpecs[1];
    // Only specs are available during configure, so the predictor is built with a null model
    // and is used solely to compute the output spec.
    final ColumnRearranger specRearranger = new GradientBoostingPredictor<MultiClassGradientBoostedTreesModel>(
        null, gbModelSpec, tableSpec, m_configuration).getPredictionRearranger();
    return new PortObjectSpec[] { specRearranger.createSpec() };
}
use of org.knime.base.node.mine.treeensemble2.node.gradientboosting.predictor.GradientBoostingPredictor in project knime-core by knime.
the class LKGradientBoostingPredictorCellFactory method createFactory.
/**
 * Creates the cell factory that appends gradient boosting classification predictions.
 * <p>
 * The appended columns are, in this order:
 * <ul>
 * <li>the prediction column (string);</li>
 * <li>if confidence appending is enabled, a "Confidence" column (double);</li>
 * <li>if class confidence appending is enabled, one "P(target=value)" column (double) per key
 * of the target column's possible-value map, with the configured suffix appended.</li>
 * </ul>
 * All names are made unique with respect to the predictor's data table spec.
 *
 * @param predictor supplies the configuration, the data table spec, the model spec and the model
 * @return a new factory producing the columns listed above
 * @throws InvalidSettingsException if class confidences are to be appended but the model spec
 *             provides no possible-value map for the target column
 */
public static LKGradientBoostingPredictorCellFactory createFactory(final GradientBoostingPredictor<MultiClassGradientBoostedTreesModel> predictor) throws InvalidSettingsException {
    final TreeEnsemblePredictorConfiguration config = predictor.getConfiguration();
    final DataTableSpec testSpec = predictor.getDataSpec();
    final TreeEnsembleModelPortObjectSpec modelSpec = predictor.getModelSpec();
    // Fetch the map once. The per-class column names and the factory then use the same instance,
    // so their iteration order is guaranteed to agree.
    final Map<String, DataCell> targetValueMap = modelSpec.getTargetColumnPossibleValueMap();

    final ArrayList<DataColumnSpec> newColSpecs = new ArrayList<>();
    final UniqueNameGenerator nameGen = new UniqueNameGenerator(testSpec);
    newColSpecs.add(nameGen.newColumn(config.getPredictionColumnName(), StringCell.TYPE));
    if (config.isAppendPredictionConfidence()) {
        newColSpecs.add(nameGen.newColumn("Confidence", DoubleCell.TYPE));
    }
    if (config.isAppendClassConfidences()) {
        final String targetColName = modelSpec.getTargetColumn().getName();
        // NOTE(review): the possible-value map may be null, presumably when the target column
        // domain lists no possible values. Fail with a settings error rather than an NPE.
        if (targetValueMap == null) {
            throw new InvalidSettingsException("Cannot append class probabilities: no possible values are known for target column \""
                + targetColName + "\"");
        }
        final String suffix = config.getSuffixForClassProbabilities();
        for (final String val : targetValueMap.keySet()) {
            final String colName = "P(" + targetColName + "=" + val + ")" + suffix;
            newColSpecs.add(nameGen.newColumn(colName, DoubleCell.TYPE));
        }
    }
    return new LKGradientBoostingPredictorCellFactory(newColSpecs.toArray(new DataColumnSpec[0]), predictor.getModel(),
        modelSpec.getLearnTableSpec(), modelSpec.calculateFilterIndices(testSpec), config, targetValueMap);
}
use of org.knime.base.node.mine.treeensemble2.node.gradientboosting.predictor.GradientBoostingPredictor in project knime-core by knime.
the class GradientBoostingPMMLPredictorNodeModel method createStreamableOperator.
/**
 * {@inheritDoc}
 * <p>
 * The returned operator converts the PMML model at input port 0 into a gradient boosting model
 * port object. It then streams the data from input port 1 to output port 0 through the
 * prediction column rearranger.
 */
@Override
public StreamableOperator createStreamableOperator(final PartitionInfo partitionInfo, final PortObjectSpec[] inSpecs) throws InvalidSettingsException {
    return new StreamableOperator() {
        @Override
        public void runFinal(final PortInput[] inputs, final PortOutput[] outputs, final ExecutionContext exec) throws Exception {
            // The PMML model is delivered as a whole (non-streamed) port object.
            final PortObjectInput pmmlInput = (PortObjectInput) inputs[0];
            final PMMLPortObject pmmlModel = (PMMLPortObject) pmmlInput.getPortObject();
            final DataTableSpec tableSpec = (DataTableSpec) inSpecs[1];
            // importModel (declared elsewhere in this class) yields the gradient boosting port object.
            final GradientBoostingModelPortObject gbModel = importModel(pmmlModel);
            final GradientBoostingPredictor<?> predictor =
                new GradientBoostingPredictor<>(gbModel.getEnsembleModel(), gbModel.getSpec(), tableSpec, m_configuration);
            final StreamableFunction streamFunction = predictor.getPredictionRearranger().createStreamableFunction(1, 0);
            streamFunction.runFinal(inputs, outputs, exec);
        }
    };
}
Aggregations