Search in sources :

Example 11 with TreeEnsembleModelPortObjectSpec

use of org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObjectSpec in project knime-core by knime.

the class GradientBoostingClassificationPredictorNodeModel method execute.

/**
 * Applies the gradient boosted trees classification model (first input port) to the data table
 * (second input port) and appends the prediction columns produced by the predictor's column rearranger.
 *
 * @param inObjects index 0: the {@link GradientBoostingModelPortObject}; index 1: the
 *            {@link BufferedDataTable} to predict
 * @param exec execution context used to create the output table
 * @return a single-element array holding the input table with the prediction columns appended
 * @throws IllegalStateException if the predictor cannot provide a column rearranger (e.g. class
 *             confidences are requested but the target column's possible values are unknown)
 * @throws Exception if creating the output table fails
 */
@Override
protected PortObject[] execute(final PortObject[] inObjects, final ExecutionContext exec) throws Exception {
    final GradientBoostingModelPortObject model = (GradientBoostingModelPortObject) inObjects[0];
    final TreeEnsembleModelPortObjectSpec modelSpec = model.getSpec();
    final BufferedDataTable data = (BufferedDataTable) inObjects[1];
    final DataTableSpec dataSpec = data.getDataTableSpec();
    // This is the classification predictor, so the ensemble is treated as a multi-class model.
    final GradientBoostingPredictor<MultiClassGradientBoostedTreesModel> predictor = new GradientBoostingPredictor<>(
        (MultiClassGradientBoostedTreesModel) model.getEnsembleModel(), modelSpec, dataSpec, m_configuration);
    final ColumnRearranger rearranger = predictor.getPredictionRearranger();
    if (rearranger == null) {
        // createPredictionRearranger() yields null when confidences are requested but the target's
        // possible values are unknown; assuming getPredictionRearranger() passes that through, fail
        // here with a clear message instead of a NullPointerException inside the table creation.
        throw new IllegalStateException("Cannot append class confidences: "
            + "the possible values of the target column are unknown.");
    }
    final BufferedDataTable outTable = exec.createColumnRearrangeTable(data, rearranger, exec);
    return new BufferedDataTable[] { outTable };
}
Also used : DataTableSpec(org.knime.core.data.DataTableSpec) ColumnRearranger(org.knime.core.data.container.ColumnRearranger) GradientBoostingModelPortObject(org.knime.base.node.mine.treeensemble2.model.GradientBoostingModelPortObject) TreeEnsembleModelPortObjectSpec(org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObjectSpec) MultiClassGradientBoostedTreesModel(org.knime.base.node.mine.treeensemble2.model.MultiClassGradientBoostedTreesModel) GradientBoostingPredictor(org.knime.base.node.mine.treeensemble2.node.gradientboosting.predictor.GradientBoostingPredictor) BufferedDataTable(org.knime.core.node.BufferedDataTable)

Example 12 with TreeEnsembleModelPortObjectSpec

use of org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObjectSpec in project knime-core by knime.

the class GradientBoostingPredictor method createPredictionRearranger.

/**
 * Builds the column rearranger that appends the prediction column(s) to the input data.
 * <p>
 * A numeric (DoubleValue-compatible) target produces a regression cell factory. Any other target
 * produces a classification cell factory. Exception: when class confidences are requested but the
 * target column's possible values are unknown, no rearranger can be built and {@code null} is returned.
 *
 * @return the prediction rearranger, or {@code null} if confidence columns cannot be determined
 * @throws InvalidSettingsException declared by the overridden method
 */
@Override
protected ColumnRearranger createPredictionRearranger() throws InvalidSettingsException {
    final TreeEnsembleModelPortObjectSpec spec = getModelSpec();
    final DataTableSpec inputSpec = getDataSpec();
    final boolean possibleValuesKnown = spec.getTargetColumnPossibleValueMap() != null;
    final boolean numericTarget = spec.getTargetColumn().getType().isCompatible(DoubleValue.class);

    if (numericTarget) {
        final ColumnRearranger regressionRearranger = new ColumnRearranger(inputSpec);
        @SuppressWarnings("unchecked")
        final GradientBoostingPredictor<GradientBoostedTreesModel> regressionPredictor =
            (GradientBoostingPredictor<GradientBoostedTreesModel>) this;
        regressionRearranger.append(GradientBoostingPredictorCellFactory.createFactory(regressionPredictor));
        return regressionRearranger;
    }

    if (getConfiguration().isAppendClassConfidences() && !possibleValuesKnown) {
        // Confidence columns need the set of possible target values, which is missing here.
        return null;
    }

    final ColumnRearranger classificationRearranger = new ColumnRearranger(inputSpec);
    @SuppressWarnings("unchecked")
    final GradientBoostingPredictor<MultiClassGradientBoostedTreesModel> classificationPredictor =
        (GradientBoostingPredictor<MultiClassGradientBoostedTreesModel>) this;
    classificationRearranger.append(LKGradientBoostingPredictorCellFactory.createFactory(classificationPredictor));
    return classificationRearranger;
}
Also used : DataTableSpec(org.knime.core.data.DataTableSpec) ColumnRearranger(org.knime.core.data.container.ColumnRearranger) TreeEnsembleModelPortObjectSpec(org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObjectSpec) MultiClassGradientBoostedTreesModel(org.knime.base.node.mine.treeensemble2.model.MultiClassGradientBoostedTreesModel) GradientBoostedTreesModel(org.knime.base.node.mine.treeensemble2.model.GradientBoostedTreesModel)

Example 13 with TreeEnsembleModelPortObjectSpec

use of org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObjectSpec in project knime-core by knime.

the class GradientBoostingPMMLPredictorNodeModel method configure.

/**
 * Validates the incoming PMML model against this node's mode (regression vs. classification),
 * initializes or updates the predictor configuration, and computes the output table spec.
 *
 * @param inSpecs index 0: the {@link PMMLPortObjectSpec} of the model; index 1: the
 *            {@link DataTableSpec} of the data to predict
 * @return a single-element array with the output spec; the element is {@code null} if the predictor
 *         cannot build a rearranger at configure time (output spec unknown)
 * @throws InvalidSettingsException if the PMML target type does not fit this node's mode or the
 *             PMML spec is invalid (e.g. no or multiple target fields)
 */
@Override
protected PortObjectSpec[] configure(final PortObjectSpec[] inSpecs) throws InvalidSettingsException {
    final PMMLPortObjectSpec pmmlSpec = (PMMLPortObjectSpec) inSpecs[0];
    final DataType targetType = extractTargetType(pmmlSpec);
    if (m_isRegression && !targetType.isCompatible(DoubleValue.class)) {
        throw new InvalidSettingsException("This node expects a regression model.");
    } else if (!m_isRegression && !targetType.isCompatible(StringValue.class)) {
        throw new InvalidSettingsException("This node expects a classification model.");
    }
    final TreeEnsembleModelPortObjectSpec modelSpec;
    try {
        AbstractTreeModelPMMLTranslator.checkPMMLSpec(pmmlSpec);
        // translateSpec signals an invalid target declaration via IllegalArgumentException as well.
        modelSpec = translateSpec(pmmlSpec);
    } catch (IllegalArgumentException e) {
        // Keep the original exception as cause so its stack trace is not lost.
        throw new InvalidSettingsException(e.getMessage(), e);
    }
    final String targetColName = modelSpec.getTargetColumn().getName();
    if (m_configuration == null) {
        m_configuration = TreeEnsemblePredictorConfiguration.createDefault(m_isRegression, targetColName);
    } else if (!m_configuration.isChangePredictionColumnName()) {
        // The user did not customize the column name, so derive it from the (possibly new) target.
        m_configuration.setPredictionColumnName(TreeEnsemblePredictorConfiguration.getPredictColumnName(targetColName));
    }
    modelSpec.assertTargetTypeMatches(m_isRegression);
    final DataTableSpec dataSpec = (DataTableSpec) inSpecs[1];
    // No model is available at configure time; the predictor only needs the specs here.
    final GradientBoostingPredictor<GradientBoostedTreesModel> pred =
        new GradientBoostingPredictor<>(null, modelSpec, dataSpec, m_configuration);
    final ColumnRearranger rearranger = pred.getPredictionRearranger();
    // A null rearranger (e.g. confidences requested but possible target values unknown) means the
    // output spec cannot be determined yet; report it as unknown instead of throwing an NPE.
    final DataTableSpec outSpec = rearranger == null ? null : rearranger.createSpec();
    return new PortObjectSpec[] { outSpec };
}
Also used : PMMLPortObjectSpec(org.knime.core.node.port.pmml.PMMLPortObjectSpec) DataTableSpec(org.knime.core.data.DataTableSpec) InvalidSettingsException(org.knime.core.node.InvalidSettingsException) TreeEnsembleModelPortObjectSpec(org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObjectSpec) GradientBoostingPredictor(org.knime.base.node.mine.treeensemble2.node.gradientboosting.predictor.GradientBoostingPredictor) PortObjectSpec(org.knime.core.node.port.PortObjectSpec) DataType(org.knime.core.data.DataType) GradientBoostedTreesModel(org.knime.base.node.mine.treeensemble2.model.GradientBoostedTreesModel) MultiClassGradientBoostedTreesModel(org.knime.base.node.mine.treeensemble2.model.MultiClassGradientBoostedTreesModel) StringValue(org.knime.core.data.StringValue)

Example 14 with TreeEnsembleModelPortObjectSpec

use of org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObjectSpec in project knime-core by knime.

the class GradientBoostingPMMLPredictorNodeModel method importModel.

/**
 * Converts the given PMML port object into a gradient boosting model port object.
 * <p>
 * The translator is chosen from the target column's type: a DoubleValue-compatible target uses the
 * regression translator, a StringValue-compatible target uses the classification translator. A
 * translation warning, if any, is set as the node's warning message.
 *
 * @param pmmlPO the PMML port object holding the gradient boosted trees model
 * @return the model port object, with a spec built from the translator's learn spec
 * @throws IllegalArgumentException if the target type is neither numeric nor string
 */
@SuppressWarnings("unchecked")
private GradientBoostingModelPortObject importModel(final PMMLPortObject pmmlPO) {
    final AbstractGBTModelPMMLTranslator<M> pmmlTranslator;
    final DataType targetType = extractTargetType(pmmlPO.getSpec());
    // The casts below are unchecked: M is presumably the enclosing class's model type parameter.
    if (targetType.isCompatible(DoubleValue.class)) {
        pmmlTranslator = (AbstractGBTModelPMMLTranslator<M>) new RegressionGBTModelPMMLTranslator();
    } else if (targetType.isCompatible(StringValue.class)) {
        pmmlTranslator = (AbstractGBTModelPMMLTranslator<M>) new ClassificationGBTModelPMMLTranslator();
    } else {
        throw new IllegalArgumentException("Only regression (numeric target) and classification (string target) "
            + "models are supported, but the target column has type " + targetType + ".");
    }
    pmmlPO.initializeModelTranslator(pmmlTranslator);
    if (pmmlTranslator.hasWarning()) {
        setWarningMessage(pmmlTranslator.getWarning());
    }
    return new GradientBoostingModelPortObject(new TreeEnsembleModelPortObjectSpec(pmmlTranslator.getLearnSpec()),
        pmmlTranslator.getGBTModel());
}
Also used : AbstractGBTModelPMMLTranslator(org.knime.base.node.mine.treeensemble2.model.pmml.AbstractGBTModelPMMLTranslator) ClassificationGBTModelPMMLTranslator(org.knime.base.node.mine.treeensemble2.model.pmml.ClassificationGBTModelPMMLTranslator) GradientBoostingModelPortObject(org.knime.base.node.mine.treeensemble2.model.GradientBoostingModelPortObject) TreeEnsembleModelPortObjectSpec(org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObjectSpec) DataType(org.knime.core.data.DataType) RegressionGBTModelPMMLTranslator(org.knime.base.node.mine.treeensemble2.model.pmml.RegressionGBTModelPMMLTranslator) StringValue(org.knime.core.data.StringValue)

Example 15 with TreeEnsembleModelPortObjectSpec

use of org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObjectSpec in project knime-core by knime.

the class GradientBoostingPMMLPredictorNodeModel method translateSpec.

/**
 * Derives a tree ensemble model spec from a PMML spec by moving the single declared PMML target
 * column to the end of the PMML data table spec.
 *
 * @param pmmlSpec the PMML port object spec
 * @return a model spec built from the rearranged table spec
 * @throws IllegalArgumentException if the PMML declares no target field or more than one target
 */
private TreeEnsembleModelPortObjectSpec translateSpec(final PMMLPortObjectSpec pmmlSpec) {
    final DataTableSpec pmmlDataSpec = pmmlSpec.getDataTableSpec();
    final List<DataColumnSpec> targets = pmmlSpec.getTargetCols();
    CheckUtils.checkArgument(!targets.isEmpty(), "The provided PMML does not declare a target field.");
    CheckUtils.checkArgument(targets.size() == 1,
        "The provided PMML declares multiple targets. This behavior is currently not supported.");
    final ColumnRearranger cr = new ColumnRearranger(pmmlDataSpec);
    // Moving to index getNumColumns() is intended to place the target last; the model spec
    // presumably treats its last column as the target (TODO confirm in TreeEnsembleModelPortObjectSpec).
    cr.move(targets.get(0).getName(), pmmlDataSpec.getNumColumns());
    return new TreeEnsembleModelPortObjectSpec(cr.createSpec());
}
Also used : DataTableSpec(org.knime.core.data.DataTableSpec) ColumnRearranger(org.knime.core.data.container.ColumnRearranger) DataColumnSpec(org.knime.core.data.DataColumnSpec) TreeEnsembleModelPortObjectSpec(org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObjectSpec)

Aggregations

TreeEnsembleModelPortObjectSpec (org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObjectSpec)40 DataTableSpec (org.knime.core.data.DataTableSpec)38 ColumnRearranger (org.knime.core.data.container.ColumnRearranger)25 TreeEnsemblePredictor (org.knime.base.node.mine.treeensemble2.node.predictor.TreeEnsemblePredictor)22 TreeEnsembleModelPortObject (org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObject)12 FilterLearnColumnRearranger (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration.FilterLearnColumnRearranger)12 BufferedDataTable (org.knime.core.node.BufferedDataTable)12 InvalidSettingsException (org.knime.core.node.InvalidSettingsException)12 PortObjectSpec (org.knime.core.node.port.PortObjectSpec)9 TreeEnsemblePredictorConfiguration (org.knime.base.node.mine.treeensemble2.node.predictor.TreeEnsemblePredictorConfiguration)8 GradientBoostingModelPortObject (org.knime.base.node.mine.treeensemble2.model.GradientBoostingModelPortObject)7 GradientBoostingPredictor (org.knime.base.node.mine.treeensemble2.node.gradientboosting.predictor.GradientBoostingPredictor)7 ExecutionMonitor (org.knime.core.node.ExecutionMonitor)7 TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData)6 TreeDataCreator (org.knime.base.node.mine.treeensemble2.data.TreeDataCreator)6 DataColumnSpec (org.knime.core.data.DataColumnSpec)6 PortObject (org.knime.core.node.port.PortObject)6 ExecutionException (java.util.concurrent.ExecutionException)5 MultiClassGradientBoostedTreesModel (org.knime.base.node.mine.treeensemble2.model.MultiClassGradientBoostedTreesModel)5 TreeEnsembleModel (org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModel)5