Search in sources:

Example 6 with GradientBoostingLearnerConfiguration

use of org.knime.base.node.mine.treeensemble2.node.gradientboosting.learner.GradientBoostingLearnerConfiguration in project knime-core by knime.

Class MGradientBoostedTreesLearner, method learn.

/**
 * {@inheritDoc}
 * <p>
 * Fits {@code nrModels} regression trees sequentially. The ensemble starts from the median of the
 * target. In each round the residuals {@code target - previousPrediction} are clipped to the value
 * returned by {@code calculateAlphaQuantile(residuals, alpha)}. A tree is then learned on these
 * clipped values, its per-node coefficients are computed via {@code calcCoefficientMap}, and the
 * running predictions are updated before the next round.
 *
 * @param exec monitor used for progress reporting and cancellation
 * @return the learned gradient boosted trees model
 * @throws CanceledExecutionException if the user cancels the execution
 */
@Override
public AbstractGradientBoostingModel learn(final ExecutionMonitor exec) throws CanceledExecutionException {
    final TreeData actualData = getData();
    final GradientBoostingLearnerConfiguration config = getConfig();
    final int nrModels = config.getNrModels();
    final TreeTargetNumericColumnData actualTarget = getTarget();
    final int nrRows = actualTarget.getNrRows();
    final double initialValue = actualTarget.getMedian();
    final ArrayList<TreeModelRegression> models = new ArrayList<>(nrModels);
    final ArrayList<Map<TreeNodeSignature, Double>> coefficientMaps = new ArrayList<>(nrModels);
    // Running ensemble prediction per row; every row starts at the target median.
    final double[] previousPrediction = new double[nrRows];
    Arrays.fill(previousPrediction, initialValue);
    final RandomData rd = config.createRandomData();
    final double alpha = config.getAlpha();
    final int maxLevels = config.getMaxLevels();
    final TreeNodeSignatureFactory signatureFactory;
    if (maxLevels < TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE) {
        // Bounded depth (expected to be the default): presize the signature factory.
        // NOTE(review): IntMath.pow does not detect int overflow, so a very large maxLevels
        // would yield a wrapped capacity. Confirm that maxLevels is bounded upstream.
        final int capacity = IntMath.pow(2, maxLevels - 1);
        signatureFactory = new TreeNodeSignatureFactory(capacity);
    } else {
        signatureFactory = new TreeNodeSignatureFactory();
    }
    exec.setMessage("Learning model");
    for (int i = 0; i < nrModels; i++) {
        // Allow cancellation between boosting rounds, not only inside tree learning.
        exec.checkCanceled();
        final double[] residuals = new double[nrRows];
        for (int j = 0; j < nrRows; j++) {
            residuals[j] = actualTarget.getValueFor(j) - previousPrediction[j];
        }
        final double quantile = calculateAlphaQuantile(residuals, alpha);
        // Residuals whose magnitude exceeds the quantile are clipped to +/- quantile.
        final double[] gradients = new double[residuals.length];
        for (int j = 0; j < gradients.length; j++) {
            gradients[j] = Math.abs(residuals[j]) <= quantile ? residuals[j] : quantile * Math.signum(residuals[j]);
        }
        final TreeData residualData = createResidualDataFromArray(gradients, actualData);
        // Each tree gets its own random source, seeded from the configuration's master RandomData.
        final RandomData rdSingle = TreeEnsembleLearnerConfiguration.createRandomData(rd.nextLong(Long.MIN_VALUE, Long.MAX_VALUE));
        final RowSample rowSample = getRowSampler().createRowSample(rdSingle);
        final TreeLearnerRegression treeLearner = new TreeLearnerRegression(config, residualData, getIndexManager(), signatureFactory, rdSingle, rowSample);
        final TreeModelRegression tree = treeLearner.learnSingleTree(exec, rdSingle);
        final Map<TreeNodeSignature, Double> coefficientMap = calcCoefficientMap(residuals, quantile, tree);
        adaptPreviousPrediction(previousPrediction, tree, coefficientMap);
        models.add(tree);
        coefficientMaps.add(coefficientMap);
        // Report completed models, so that progress reaches 1.0 after the final tree.
        exec.setProgress((i + 1) / (double) nrModels, "Finished model " + (i + 1) + "/" + nrModels);
    }
    return new GradientBoostedTreesModel(config, actualData.getMetaData(), models.toArray(new TreeModelRegression[models.size()]), actualData.getTreeType(), initialValue, coefficientMaps);
}
Also used : RandomData(org.apache.commons.math.random.RandomData) ArrayList(java.util.ArrayList) TreeTargetNumericColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData) GradientBoostedTreesModel(org.knime.base.node.mine.treeensemble2.model.GradientBoostedTreesModel) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression) GradientBoostingLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.gradientboosting.learner.GradientBoostingLearnerConfiguration) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) TreeLearnerRegression(org.knime.base.node.mine.treeensemble2.learner.TreeLearnerRegression) RowSample(org.knime.base.node.mine.treeensemble2.sample.row.RowSample) HashMap(java.util.HashMap) Map(java.util.Map) TreeNodeSignatureFactory(org.knime.base.node.mine.treeensemble2.learner.TreeNodeSignatureFactory)

Example 7 with GradientBoostingLearnerConfiguration

use of org.knime.base.node.mine.treeensemble2.node.gradientboosting.learner.GradientBoostingLearnerConfiguration in project knime-core by knime.

Class AdvancedOptionsPanel, method saveSettings.

/**
 * Writes the advanced options currently shown in this panel into the given configuration.
 *
 * @param cfg the configuration to update
 * @throws InvalidSettingsException if surrogate missing value handling is selected without binary
 *         nominal splits, if no column sampling mode is selected, or if the seed text cannot be
 *         parsed as a {@code long}
 */
public void saveSettings(final GradientBoostingLearnerConfiguration cfg) throws InvalidSettingsException {
    cfg.setAlpha((Double) m_alphaFractionSpinner.getValue());
    cfg.setUseAverageSplitPoints(m_useAverageSplitPointsChecker.isSelected());
    final boolean useBinaryNominalSplits = m_useBinaryNominalSplitsChecker.isSelected();
    cfg.setUseBinaryNominalSplits(useBinaryNominalSplits);
    final MissingValueHandling missValHandling = (MissingValueHandling) m_missingValueHandlingComboBox.getSelectedItem();
    if (missValHandling == MissingValueHandling.Surrogate && !useBinaryNominalSplits) {
        throw new InvalidSettingsException("Surrogate missing value handling can only be used if binary nominal splits are enabled.");
    }
    // Store exactly the value that was validated above instead of re-reading the combo box.
    cfg.setMissingValueHandling(missValHandling);
    double dataFrac;
    boolean isSamplingWithReplacement;
    if (m_dataFractionPerTreeChecker.isSelected()) {
        dataFrac = (Double) m_dataFractionPerTreeSpinner.getValue();
        isSamplingWithReplacement = m_dataSamplingWithReplacementChecker.isSelected();
    } else {
        // Row sampling disabled: use all rows, without replacement.
        dataFrac = 1.0;
        isSamplingWithReplacement = false;
    }
    cfg.setDataFractionPerTree(dataFrac);
    cfg.setDataSelectionWithReplacement(isSamplingWithReplacement);
    ColumnSamplingMode cf;
    double columnFrac = 1.0;
    int columnAbsolute = TreeEnsembleLearnerConfiguration.DEF_COLUMN_ABSOLUTE;
    boolean isUseDifferentAttributesAtEachNode = m_columnUseDifferentSetOfAttributesForNodes.isSelected();
    if (m_columnFractionNoneButton.isSelected()) {
        cf = ColumnSamplingMode.None;
        // Without column sampling, the per-node attribute option is forced off.
        isUseDifferentAttributesAtEachNode = false;
    } else if (m_columnFractionLinearButton.isSelected()) {
        cf = ColumnSamplingMode.Linear;
        columnFrac = (Double) m_columnFractionLinearTreeSpinner.getValue();
    } else if (m_columnFractionAbsoluteButton.isSelected()) {
        cf = ColumnSamplingMode.Absolute;
        columnAbsolute = (Integer) m_columnFractionAbsoluteTreeSpinner.getValue();
    } else if (m_columnFractionSqrtButton.isSelected()) {
        cf = ColumnSamplingMode.SquareRoot;
    } else {
        throw new InvalidSettingsException("No column selection policy selected");
    }
    cfg.setColumnSamplingMode(cf);
    cfg.setColumnFractionLinearValue(columnFrac);
    cfg.setColumnAbsoluteValue(columnAbsolute);
    cfg.setUseDifferentAttributesAtEachNode(isUseDifferentAttributesAtEachNode);
    // A null seed is stored when the seed checkbox is not selected.
    Long seed = null;
    if (m_seedChecker.isSelected()) {
        final String seedText = m_seedTextField.getText();
        try {
            seed = Long.valueOf(seedText);
        } catch (NumberFormatException e) {
            // Long.valueOf(String) signals every parse failure (including null) with NumberFormatException.
            throw new InvalidSettingsException("Unable to parse seed \"" + seedText + "\"", e);
        }
    }
    cfg.setSeed(seed);
}
Also used: InvalidSettingsException(org.knime.core.node.InvalidSettingsException) MissingValueHandling(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration.MissingValueHandling) ColumnSamplingMode(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration.ColumnSamplingMode)

Example 8 with GradientBoostingLearnerConfiguration

use of org.knime.base.node.mine.treeensemble2.node.gradientboosting.learner.GradientBoostingLearnerConfiguration in project knime-core by knime.

Class AdvancedOptionsPanel, method loadSettings.

/**
 * Updates this panel's components from the given configuration.
 * <p>
 * Several toggles are set via {@code doClick()} rather than {@code setSelected(...)}. Presumably this
 * is so that their registered listeners run and enable/disable dependent components (TODO confirm).
 * The order of the statements below may therefore matter.
 *
 * @param cfg the configuration whose values are displayed
 */
public void loadSettings(final GradientBoostingLearnerConfiguration cfg) {
    m_alphaFractionSpinner.setValue(cfg.getAlpha());
    m_useAverageSplitPointsChecker.setSelected(cfg.isUseAverageSplitPoints());
    m_useBinaryNominalSplitsChecker.setSelected(cfg.isUseBinaryNominalSplits());
    m_missingValueHandlingComboBox.setSelectedItem(cfg.getMissingValueHandling());
    double dataFrac = cfg.getDataFractionPerTree();
    boolean isDataWithReplacement = cfg.isDataSelectionWithReplacement();
    // Row sampling counts as enabled when less than all data is used or rows are drawn with replacement.
    boolean doesSampling = dataFrac < 1.0 || isDataWithReplacement;
    m_dataFractionPerTreeSpinner.setValue(dataFrac);
    // Clicking a check box toggles it, so click only when the current state differs from the target.
    if (m_dataFractionPerTreeChecker.isSelected() != doesSampling) {
        m_dataFractionPerTreeChecker.doClick();
    }
    if (isDataWithReplacement) {
        m_dataSamplingWithReplacementChecker.doClick();
    } else {
        m_dataSamplingWithOutReplacementChecker.doClick();
    }
    double colFrac = cfg.getColumnFractionLinearValue();
    int colAbsolute = cfg.getColumnAbsoluteValue();
    boolean useDifferentAttributesAtEachNode = cfg.isUseDifferentAttributesAtEachNode();
    ColumnSamplingMode columnFraction = cfg.getColumnSamplingMode();
    // NOTE(review): there is no default branch. A null mode throws NullPointerException at the switch,
    // and any enum constant not listed here leaves the column sampling radio buttons unchanged.
    switch(columnFraction) {
        case None:
            m_columnFractionNoneButton.doClick();
            // No column sampling: show the same-attributes option and a linear fraction of 1.0.
            useDifferentAttributesAtEachNode = false;
            colFrac = 1.0;
            break;
        case Linear:
            m_columnFractionLinearButton.doClick();
            break;
        case Absolute:
            m_columnFractionAbsoluteButton.doClick();
            break;
        case SquareRoot:
            m_columnFractionSqrtButton.doClick();
            // The linear fraction is unused in this mode; reset its spinner to 1.0.
            colFrac = 1.0;
            break;
    }
    m_columnFractionLinearTreeSpinner.setValue(colFrac);
    m_columnFractionAbsoluteTreeSpinner.setValue(colAbsolute);
    if (useDifferentAttributesAtEachNode) {
        m_columnUseDifferentSetOfAttributesForNodes.doClick();
    } else {
        m_columnUseSameSetOfAttributesForNodes.doClick();
    }
    Long seed = cfg.getSeed();
    // The seed check box is selected exactly when a seed is configured.
    if (m_seedChecker.isSelected() != (seed != null)) {
        m_seedChecker.doClick();
    }
    // Without a stored seed, pre-fill the field with the current time (presumably as a suggested value).
    m_seedTextField.setText(Long.toString(seed != null ? seed : System.currentTimeMillis()));
}
Also used : ColumnSamplingMode(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration.ColumnSamplingMode)

Example 9 with GradientBoostingLearnerConfiguration

use of org.knime.base.node.mine.treeensemble2.node.gradientboosting.learner.GradientBoostingLearnerConfiguration in project knime-core by knime.

Class GradientBoostingRegressionLearnerNodeDialogPane, method saveSettingsTo.

/**
 * {@inheritDoc}
 * <p>
 * Builds a fresh configuration and lets the main options panel and then the advanced options panel
 * write into it. The result is then stored in {@code settings}. The constructor flag {@code true} is
 * passed through unchanged (presumably selecting the regression variant; confirm).
 */
@Override
protected void saveSettingsTo(final NodeSettingsWO settings) throws InvalidSettingsException {
    final GradientBoostingLearnerConfiguration config = new GradientBoostingLearnerConfiguration(true);
    m_optionsPanel.saveSettings(config);
    m_advancedOptionsPanel.saveSettings(config);
    config.save(settings);
}
Also used : GradientBoostingLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.gradientboosting.learner.GradientBoostingLearnerConfiguration)

Example 10 with GradientBoostingLearnerConfiguration

use of org.knime.base.node.mine.treeensemble2.node.gradientboosting.learner.GradientBoostingLearnerConfiguration in project knime-core by knime.

Class GradientBoostingRegressionLearnerNodeDialogPane, method loadSettingsFrom.

/**
 * {@inheritDoc}
 * <p>
 * Reads the stored configuration against the spec of the first input port. It is then displayed in
 * the main options panel and the advanced options panel.
 */
@Override
protected void loadSettingsFrom(final NodeSettingsRO settings, final DataTableSpec[] specs) throws NotConfigurableException {
    final DataTableSpec tableSpec = specs[0];
    final GradientBoostingLearnerConfiguration config = new GradientBoostingLearnerConfiguration(true);
    config.loadInDialog(settings, tableSpec);
    m_optionsPanel.loadSettingsFrom(tableSpec, config);
    m_advancedOptionsPanel.loadSettings(config);
}
Also used : DataTableSpec(org.knime.core.data.DataTableSpec) GradientBoostingLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.gradientboosting.learner.GradientBoostingLearnerConfiguration)

Aggregations

GradientBoostingLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.gradientboosting.learner.GradientBoostingLearnerConfiguration): 9
TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData): 3
TreeTargetNumericColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData): 3
TreeNodeSignature (org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature): 3
BitSet (java.util.BitSet): 2
RandomData (org.apache.commons.math.random.RandomData): 2
RegressionPriors (org.knime.base.node.mine.treeensemble2.data.RegressionPriors): 2
DataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships): 2
RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships): 2
TreeModelRegression (org.knime.base.node.mine.treeensemble2.model.TreeModelRegression): 2
TreeNodeRegression (org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression): 2
TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration): 2
ColumnSamplingMode (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration.ColumnSamplingMode): 2
ColumnSample (org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample): 2
RowSample (org.knime.base.node.mine.treeensemble2.sample.row.RowSample): 2
DataTableSpec (org.knime.core.data.DataTableSpec): 2
ArrayList (java.util.ArrayList): 1
HashMap (java.util.HashMap): 1
Map (java.util.Map): 1
TreeAttributeColumnData (org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData): 1