Use of org.knime.base.node.mine.treeensemble2.model.pmml.ClassificationGBTModelPMMLTranslator in the project knime-core by KNIME.
Class GBTPMMLExporterNodeModel, method execute.
/**
 * Exports the gradient boosted trees model arriving at input port 0 as a PMML port object.
 *
 * <p>A regression or classification PMML translator is picked based on the concrete model class.
 * That translator is attached to a freshly created {@link PMMLPortObject}.
 *
 * @param inObjects the input port objects; index 0 must be a {@link GradientBoostingModelPortObject}
 * @param exec the execution context (not used by this method)
 * @return a single-element array holding the PMML port object
 * @throws IllegalArgumentException if the ensemble model is neither a {@link GradientBoostedTreesModel}
 *             nor a {@link MultiClassGradientBoostedTreesModel}
 */
@Override
protected PortObject[] execute(final PortObject[] inObjects, final ExecutionContext exec) throws Exception {
    final GradientBoostingModelPortObject modelPortObject = (GradientBoostingModelPortObject) inObjects[0];
    final AbstractGradientBoostingModel ensemble = modelPortObject.getEnsembleModel();
    // Select the translator first so an unsupported model type fails before any PMML spec is built.
    final AbstractGBTModelPMMLTranslator<?> modelTranslator = selectTranslator(modelPortObject, ensemble);
    final PMMLPortObject pmmlPortObject = new PMMLPortObject(createPMMLSpec(modelPortObject.getSpec(), ensemble));
    pmmlPortObject.addModelTranslater(modelTranslator);
    return new PortObject[]{pmmlPortObject};
}

/**
 * Creates the PMML translator that matches the concrete type of {@code ensemble}.
 *
 * @param modelPortObject the port object whose spec supplies the learn table spec
 * @param ensemble the model extracted from {@code modelPortObject}
 * @return a regression or classification translator wrapping {@code ensemble}
 * @throws IllegalArgumentException for any other model type
 */
private static AbstractGBTModelPMMLTranslator<?> selectTranslator(
    final GradientBoostingModelPortObject modelPortObject, final AbstractGradientBoostingModel ensemble) {
    // Order of the checks is kept deliberately: regression is tested before multi-class.
    if (ensemble instanceof GradientBoostedTreesModel) {
        return new RegressionGBTModelPMMLTranslator((GradientBoostedTreesModel) ensemble,
            modelPortObject.getSpec().getLearnTableSpec());
    }
    if (ensemble instanceof MultiClassGradientBoostedTreesModel) {
        return new ClassificationGBTModelPMMLTranslator((MultiClassGradientBoostedTreesModel) ensemble,
            modelPortObject.getSpec().getLearnTableSpec());
    }
    throw new IllegalArgumentException(
        "Unknown gradient boosted trees model type '" + ensemble.getClass().getSimpleName() + "'.");
}
Use of org.knime.base.node.mine.treeensemble2.model.pmml.ClassificationGBTModelPMMLTranslator in the project knime-core by KNIME.
Class GradientBoostingPMMLPredictorNodeModel, method importModel.
/**
 * Rebuilds a gradient boosted trees model port object from a PMML port object.
 *
 * <p>The translator is chosen from the target column type found in the PMML spec:
 * <ul>
 * <li>{@link DoubleValue}-compatible targets use {@link RegressionGBTModelPMMLTranslator}.</li>
 * <li>{@link StringValue}-compatible targets use {@link ClassificationGBTModelPMMLTranslator}.</li>
 * </ul>
 * Any translator warning is forwarded to the node's warning message.
 *
 * @param pmmlPO the PMML port object to import
 * @return a port object holding the learn spec and the model read by the translator
 * @throws IllegalArgumentException if the target type is neither numeric nor string compatible
 */
// The casts to AbstractGBTModelPMMLTranslator<M> are unchecked; M is presumably the enclosing
// class's model type parameter (declared outside this view).
@SuppressWarnings("unchecked")
private GradientBoostingModelPortObject importModel(final PMMLPortObject pmmlPO) {
    final DataType targetType = extractTargetType(pmmlPO.getSpec());
    final AbstractGBTModelPMMLTranslator<M> pmmlTranslator;
    // DoubleValue is checked first, so a type compatible with both is imported as regression.
    if (targetType.isCompatible(DoubleValue.class)) {
        pmmlTranslator = (AbstractGBTModelPMMLTranslator<M>) new RegressionGBTModelPMMLTranslator();
    } else if (targetType.isCompatible(StringValue.class)) {
        pmmlTranslator = (AbstractGBTModelPMMLTranslator<M>) new ClassificationGBTModelPMMLTranslator();
    } else {
        // Previously claimed "only regression models are supported", which contradicts the
        // classification branch above; report the actual supported target kinds instead.
        throw new IllegalArgumentException("Unsupported target type '" + targetType
            + "'. Only numeric (regression) and string (classification) targets are supported.");
    }
    // Reads the model out of the PMML document into the translator.
    pmmlPO.initializeModelTranslator(pmmlTranslator);
    if (pmmlTranslator.hasWarning()) {
        setWarningMessage(pmmlTranslator.getWarning());
    }
    return new GradientBoostingModelPortObject(
        new TreeEnsembleModelPortObjectSpec(pmmlTranslator.getLearnSpec()), pmmlTranslator.getGBTModel());
}
Aggregations