Use of org.knime.base.node.mine.treeensemble2.model.MultiClassGradientBoostedTreesModel in project knime-core by knime.
Class GBTPMMLExporterNodeModel, method execute.
/**
 * {@inheritDoc}
 * <p>
 * Selects a PMML translator matching the concrete type of the gradient boosted trees model found at input port 0
 * (regression via {@code GradientBoostedTreesModel}, classification via
 * {@code MultiClassGradientBoostedTreesModel}) and registers it with a newly created PMML port object.
 *
 * @throws IllegalArgumentException if the ensemble model is of neither supported type
 */
@Override
protected PortObject[] execute(final PortObject[] inObjects, final ExecutionContext exec) throws Exception {
    final GradientBoostingModelPortObject modelPortObject = (GradientBoostingModelPortObject)inObjects[0];
    final AbstractGradientBoostingModel ensemble = modelPortObject.getEnsembleModel();

    // pick the translator by model type; the type checks keep their original order
    final AbstractGBTModelPMMLTranslator<?> pmmlTranslator;
    if (ensemble instanceof GradientBoostedTreesModel) {
        final GradientBoostedTreesModel regressionModel = (GradientBoostedTreesModel)ensemble;
        pmmlTranslator =
            new RegressionGBTModelPMMLTranslator(regressionModel, modelPortObject.getSpec().getLearnTableSpec());
    } else if (ensemble instanceof MultiClassGradientBoostedTreesModel) {
        final MultiClassGradientBoostedTreesModel classificationModel = (MultiClassGradientBoostedTreesModel)ensemble;
        pmmlTranslator = new ClassificationGBTModelPMMLTranslator(classificationModel,
            modelPortObject.getSpec().getLearnTableSpec());
    } else {
        final String typeName = ensemble.getClass().getSimpleName();
        throw new IllegalArgumentException("Unknown gradient boosted trees model type '" + typeName + "'.");
    }

    final PMMLPortObject pmmlPortObject = new PMMLPortObject(createPMMLSpec(modelPortObject.getSpec(), ensemble));
    pmmlPortObject.addModelTranslater(pmmlTranslator);
    return new PortObject[]{pmmlPortObject};
}
Use of org.knime.base.node.mine.treeensemble2.model.MultiClassGradientBoostedTreesModel in project knime-core by knime.
Class GradientBoostingClassificationPredictorNodeModel, method createStreamableOperator.
/**
 * {@inheritDoc}
 * <p>
 * The returned operator reads the gradient boosting model from input port 0 and streams the data through the
 * prediction column rearranger, using {@code createStreamableFunction(1, 0)}, i.e. presumably data input port 1 and
 * output port 0.
 */
@Override
public StreamableOperator createStreamableOperator(final PartitionInfo partitionInfo,
    final PortObjectSpec[] inSpecs) throws InvalidSettingsException {
    return new StreamableOperator() {
        @Override
        public void runFinal(final PortInput[] inputs, final PortOutput[] outputs, final ExecutionContext exec)
            throws Exception {
            final GradientBoostingModelPortObject model =
                (GradientBoostingModelPortObject)((PortObjectInput)inputs[0]).getPortObject();
            final TreeEnsembleModelPortObjectSpec modelSpec = model.getSpec();
            // the data spec comes from the specs handed to createStreamableOperator, not from the streamed input
            final DataTableSpec dataSpec = (DataTableSpec)inSpecs[1];
            final GradientBoostingPredictor<MultiClassGradientBoostedTreesModel> pred =
                new GradientBoostingPredictor<>((MultiClassGradientBoostedTreesModel)model.getEnsembleModel(),
                    modelSpec, dataSpec, m_configuration);
            final ColumnRearranger rearranger = pred.getPredictionRearranger();
            if (rearranger == null) {
                // NOTE(review): GradientBoostingPredictor#createPredictionRearranger returns null when class
                // confidences are requested but the target's possible values are unknown; assumes
                // getPredictionRearranger propagates that null. Fail with a clear message instead of an NPE.
                throw new InvalidSettingsException("Cannot create prediction columns: class confidences were "
                    + "requested but the possible values of the target column are unknown.");
            }
            final StreamableFunction func = rearranger.createStreamableFunction(1, 0);
            func.runFinal(inputs, outputs, exec);
        }
    };
}
Use of org.knime.base.node.mine.treeensemble2.model.MultiClassGradientBoostedTreesModel in project knime-core by knime.
Class GradientBoostingClassificationPredictorNodeModel, method execute.
/**
 * {@inheritDoc}
 * <p>
 * Applies the multi-class gradient boosting model at input port 0 to the table at input port 1 and returns that
 * table with the prediction columns appended.
 *
 * @throws InvalidSettingsException if no prediction rearranger can be created (for instance when class confidences
 *             are requested but the possible values of the target column are unknown)
 */
@Override
protected PortObject[] execute(final PortObject[] inObjects, final ExecutionContext exec) throws Exception {
    final GradientBoostingModelPortObject model = (GradientBoostingModelPortObject)inObjects[0];
    final TreeEnsembleModelPortObjectSpec modelSpec = model.getSpec();
    final BufferedDataTable data = (BufferedDataTable)inObjects[1];
    final DataTableSpec dataSpec = data.getDataTableSpec();
    final GradientBoostingPredictor<MultiClassGradientBoostedTreesModel> predictor =
        new GradientBoostingPredictor<>((MultiClassGradientBoostedTreesModel)model.getEnsembleModel(), modelSpec,
            dataSpec, m_configuration);
    final ColumnRearranger rearranger = predictor.getPredictionRearranger();
    if (rearranger == null) {
        // NOTE(review): GradientBoostingPredictor#createPredictionRearranger returns null when class confidences
        // are requested but the target's possible values are unknown; assumes getPredictionRearranger propagates
        // that null. Fail with a clear message instead of an NPE inside createColumnRearrangeTable.
        throw new InvalidSettingsException("Cannot create prediction columns: class confidences were "
            + "requested but the possible values of the target column are unknown.");
    }
    final BufferedDataTable outTable = exec.createColumnRearrangeTable(data, rearranger, exec);
    return new BufferedDataTable[]{outTable};
}
Use of org.knime.base.node.mine.treeensemble2.model.MultiClassGradientBoostedTreesModel in project knime-core by knime.
Class GradientBoostingPredictor, method createPredictionRearranger.
/**
 * {@inheritDoc}
 * <p>
 * If the target column is compatible with {@code DoubleValue}, the prediction columns are produced by
 * {@code GradientBoostingPredictorCellFactory}; otherwise by {@code LKGradientBoostingPredictorCellFactory}.
 *
 * @return the rearranger appending the prediction columns to the data spec, or {@code null} if class confidences
 *         should be appended but the model spec carries no possible-value map for the target column
 */
@Override
protected ColumnRearranger createPredictionRearranger() throws InvalidSettingsException {
    final TreeEnsembleModelPortObjectSpec spec = getModelSpec();
    final DataTableSpec inputSpec = getDataSpec();
    final boolean possibleValuesKnown = spec.getTargetColumnPossibleValueMap() != null;

    if (spec.getTargetColumn().getType().isCompatible(DoubleValue.class)) {
        // numeric target: this predictor wraps a (single-output) GradientBoostedTreesModel
        final ColumnRearranger regressionRearranger = new ColumnRearranger(inputSpec);
        @SuppressWarnings("unchecked")
        final GradientBoostingPredictor<GradientBoostedTreesModel> regressionPredictor =
            (GradientBoostingPredictor<GradientBoostedTreesModel>)this;
        regressionRearranger.append(GradientBoostingPredictorCellFactory.createFactory(regressionPredictor));
        return regressionRearranger;
    }

    if (getConfiguration().isAppendClassConfidences() && !possibleValuesKnown) {
        // confidence columns need the target's possible values, which are not available
        return null;
    }

    // non-numeric target: this predictor wraps a MultiClassGradientBoostedTreesModel
    final ColumnRearranger classificationRearranger = new ColumnRearranger(inputSpec);
    @SuppressWarnings("unchecked")
    final GradientBoostingPredictor<MultiClassGradientBoostedTreesModel> classificationPredictor =
        (GradientBoostingPredictor<MultiClassGradientBoostedTreesModel>)this;
    classificationRearranger.append(LKGradientBoostingPredictorCellFactory.createFactory(classificationPredictor));
    return classificationRearranger;
}
Use of org.knime.base.node.mine.treeensemble2.model.MultiClassGradientBoostedTreesModel in project knime-core by knime.
Class ClassificationGBTModelExporter, method doWrite.
/**
 * {@inheritDoc}
 * <p>
 * Writes the multi-class model as a single {@code modelChain} segmentation. It adds one segment per class in
 * class-index order, followed by one segment that aggregates the per-class predictions.
 */
@Override
protected void doWrite(final MiningModel model) {
    final Segmentation chain = model.addNewSegmentation();
    chain.setMultipleModelMethod(MULTIPLEMODELMETHOD.MODEL_CHAIN);

    final int classCount = getGBTModel().getNrClasses();
    for (int classIdx = 0; classIdx < classCount; classIdx++) {
        addClassSegment(chain, classIdx);
    }

    // appended after all class segments; in a PMML model chain, later segments may consume earlier segments' output
    addAggregationSegment(chain);
}
Aggregations