
Example 1 with Solver

Use of org.knime.base.node.mine.regression.logistic.learner4.LogRegLearnerSettings.Solver in project knime-core by knime.

From class LogRegLearnerNodeDialogPane, method loadSettingsFrom:

/**
 * {@inheritDoc}
 */
@Override
protected void loadSettingsFrom(final NodeSettingsRO s, final PortObjectSpec[] specs) throws NotConfigurableException {
    final LogRegLearnerSettings settings = new LogRegLearnerSettings();
    m_inSpec = (DataTableSpec) specs[0];
    settings.loadSettingsForDialog(s, m_inSpec);
    final DataColumnSpecFilterConfiguration config = settings.getIncludedColumns();
    m_filterPanel.loadConfiguration(config, m_inSpec);
    String target = settings.getTargetColumn();
    m_selectionPanel.update(m_inSpec, target);
    // m_filterPanel.updateWithNewConfiguration(config) alone is not enough; we have to reload because the selection update might change the UI
    m_filterPanel.loadConfiguration(config, m_inSpec);
    // The target column must be hidden from the filter panel. Updating
    // m_filterPanel first does not work because the first element in the
    // spec would always end up in the exclude list.
    String selected = m_selectionPanel.getSelectedColumn();
    if (null == selected) {
        for (DataColumnSpec colSpec : m_inSpec) {
            if (colSpec.getType().isCompatible(NominalValue.class)) {
                selected = colSpec.getName();
                break;
            }
        }
    }
    if (selected != null) {
        DataColumnSpec colSpec = m_inSpec.getColumnSpec(selected);
        m_filterPanel.hideNames(colSpec);
    }
    updateTargetCategories(settings.getTargetReferenceCategory());
    m_notSortTarget.setSelected(settings.getUseTargetDomainOrder());
    m_notSortIncludes.setSelected(settings.getUseFeatureDomainOrder());
    Solver solver = settings.getSolver();
    m_solverComboBox.setSelectedItem(solver);
    if (solver == Solver.IRLS) {
        setEnabledSGRelated(false);
    }
    m_maxEpochSpinner.setValue(settings.getMaxEpoch());
    m_lazyCalculationCheckBox.setSelected(settings.isPerformLazy());
    m_calcCovMatrixCheckBox.setSelected(settings.isCalcCovMatrix());
    double epsilon = settings.getEpsilon();
    m_epsilonField.setText(Double.toString(epsilon));
    m_learningRateStrategyComboBox.setSelectedItem(settings.getLearningRateStrategy());
    m_initialLearningRateField.setText(Double.toString(settings.getInitialLearningRate()));
    m_priorComboBox.setSelectedItem(settings.getPrior());
    m_priorVarianceSpinner.setValue(settings.getPriorVariance());
    m_inMemoryCheckBox.setSelected(settings.isInMemory());
    Long seed = settings.getSeed();
    toggleSeedComponents();
    m_seedField.setText(Long.toString(seed != null ? seed : System.currentTimeMillis()));
    m_chunkSizeSpinner.setValue(settings.getChunkSize());
    m_chunkSizeSpinner.setEnabled(!settings.isInMemory());
}
Also used: Solver (org.knime.base.node.mine.regression.logistic.learner4.LogRegLearnerSettings.Solver), DataColumnSpec (org.knime.core.data.DataColumnSpec), DataColumnSpecFilterConfiguration (org.knime.core.node.util.filter.column.DataColumnSpecFilterConfiguration)

Example 2 with Solver

Use of org.knime.base.node.mine.regression.logistic.learner4.LogRegLearnerSettings.Solver in project knime-core by knime.

From class IrlsLearner, method irlsRls:

/**
 * Performs one IRLS (iteratively reweighted least squares) step. The result is stored in beta.
 *
 * @param data the training data
 * @param beta the parameter vector
 * @param rC the regressor count
 * @param tcC the target category count
 * @param exec the execution monitor used for progress reporting and cancellation
 * @throws CanceledExecutionException when the execution is canceled
 */
private void irlsRls(final TrainingData<ClassificationTrainingRow> data, final RealMatrix beta, final int rC, final int tcC, final ExecutionMonitor exec) throws CanceledExecutionException {
    long rowCount = 0;
    int dim = (rC + 1) * (tcC - 1);
    RealMatrix xTwx = MatrixUtils.createRealMatrix(dim, dim);
    RealMatrix xTyu = MatrixUtils.createRealMatrix(dim, 1);
    double[] eBetaTx = new double[tcC - 1];
    double[] pi = new double[tcC - 1];
    final long totalRowCount = data.getRowCount();
    for (ClassificationTrainingRow row : data) {
        rowCount++;
        exec.checkCanceled();
        exec.setProgress(rowCount / (double) totalRowCount, "Row " + rowCount + "/" + totalRowCount);
        for (int k = 0; k < tcC - 1; k++) {
            double z = 0.0;
            for (FeatureIterator iter = row.getFeatureIterator(); iter.next(); ) {
                double featureVal = iter.getFeatureValue();
                int featureIdx = iter.getFeatureIndex();
                z += featureVal * beta.getEntry(0, k * (rC + 1) + featureIdx);
            }
            eBetaTx[k] = Math.exp(z);
        }
        double sumEBetaTx = 0;
        for (int k = 0; k < tcC - 1; k++) {
            sumEBetaTx += eBetaTx[k];
        }
        for (int k = 0; k < tcC - 1; k++) {
            double pik = eBetaTx[k] / (1 + sumEBetaTx);
            pi[k] = pik;
        }
        // fill xTwx, i.e. X^T W X (the code treats this as the Hessian of the log-likelihood; up to sign it is the observed information)
        for (FeatureIterator outer = row.getFeatureIterator(); outer.next(); ) {
            int i = outer.getFeatureIndex();
            double outerVal = outer.getFeatureValue();
            for (FeatureIterator inner = outer.spawn(); inner.next(); ) {
                int ii = inner.getFeatureIndex();
                double innerVal = inner.getFeatureValue();
                for (int k = 0; k < tcC - 1; k++) {
                    for (int kk = k; kk < tcC - 1; kk++) {
                        int o1 = k * (rC + 1);
                        int o2 = kk * (rC + 1);
                        double v = xTwx.getEntry(o1 + i, o2 + ii);
                        if (k == kk) {
                            double w = pi[k] * (1 - pi[k]);
                            v += outerVal * w * innerVal;
                            assert o1 == o2;
                        } else {
                            double w = -pi[k] * pi[kk];
                            v += outerVal * w * innerVal;
                        }
                        xTwx.setEntry(o1 + i, o2 + ii, v);
                        xTwx.setEntry(o1 + ii, o2 + i, v);
                        if (k != kk) {
                            xTwx.setEntry(o2 + ii, o1 + i, v);
                            xTwx.setEntry(o2 + i, o1 + ii, v);
                        }
                    }
                }
            }
        }
        int g = row.getCategory();
        // fill xTyu, i.e. the gradient term X^T (y - pi)
        for (FeatureIterator iter = row.getFeatureIterator(); iter.next(); ) {
            int idx = iter.getFeatureIndex();
            double val = iter.getFeatureValue();
            for (int k = 0; k < tcC - 1; k++) {
                int o = k * (rC + 1);
                double v = xTyu.getEntry(o + idx, 0);
                double y = k == g ? 1 : 0;
                v += (y - pi[k]) * val;
                xTyu.setEntry(o + idx, 0, v);
            }
        }
    }
    // currently not used but could become interesting in the future
    // if (m_penaltyTerm > 0.0) {
    // RealMatrix stdError = getStdErrorMatrix(xTwx);
    // // do not penalize the constant terms
    // for (int i = 0; i < tcC - 1; i++) {
    // stdError.setEntry(i * (rC + 1), i * (rC + 1), 0);
    // }
    // xTwx = xTwx.add(stdError.scalarMultiply(-0.00001));
    // }
    exec.checkCanceled();
    b = xTwx.multiply(beta.transpose()).add(xTyu);
    A = xTwx;
    if (rowCount < A.getColumnDimension()) {
        // with fewer rows than coefficients the linear system below is underdetermined,
        // but it's important to ensure this property
        throw new IllegalStateException("The dataset must have at least " + A.getColumnDimension() + " rows, but it has only " + rowCount + " rows. It is recommended to use a larger dataset in order to increase accuracy.");
    }
    DecompositionSolver solver = new SingularValueDecomposition(A).getSolver();
    RealMatrix betaNew = solver.solve(b);
    beta.setSubMatrix(betaNew.transpose().getData(), 0, 0);
}
Also used: FeatureIterator (org.knime.base.node.mine.regression.logistic.learner4.data.TrainingRow.FeatureIterator), ClassificationTrainingRow (org.knime.base.node.mine.regression.logistic.learner4.data.ClassificationTrainingRow), RealMatrix (org.apache.commons.math3.linear.RealMatrix), DecompositionSolver (org.apache.commons.math3.linear.DecompositionSolver), SingularValueDecomposition (org.apache.commons.math3.linear.SingularValueDecomposition)
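For reference, the loop above assembles one Newton-Raphson (IRLS) step for multinomial logistic regression with the last target category as the reference. A sketch of the algebra in my own notation (K = tcC; beta is a row vector with (rC + 1)(K - 1) entries; this is a reading of the code, not text from the KNIME source):

\pi_k(x) = \frac{e^{\beta_k^{\top} x}}{1 + \sum_{j=1}^{K-1} e^{\beta_j^{\top} x}}, \qquad k = 1, \dots, K - 1

Each row contributes blocks of the weight matrix W to xTwx = X^{\top} W X, with per-row entries

W_{kk} = \pi_k (1 - \pi_k), \qquad W_{kk'} = -\pi_k \pi_{k'} \quad (k \neq k'),

and contributes (y_k - \pi_k) x to xTyu = X^{\top}(y - \pi), where y_k = 1 iff the row's category is k. The method then solves

(X^{\top} W X)\, \beta_{\mathrm{new}} = (X^{\top} W X)\, \beta_{\mathrm{old}} + X^{\top}(y - \pi),

which is exactly the Newton update \beta_{\mathrm{new}} = \beta_{\mathrm{old}} + (X^{\top} W X)^{-1} X^{\top}(y - \pi), computed via a singular value decomposition so that an ill-conditioned information matrix still yields a solution.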

Example 3 with Solver

Use of org.knime.base.node.mine.regression.logistic.learner4.LogRegLearnerSettings.Solver in project knime-core by knime.

From class LogRegLearnerNodeDialogPane, method solverChanged:

private void solverChanged(final Solver solver) {
    boolean sgMethod = solver != Solver.IRLS;
    if (sgMethod) {
        setEnabledSGRelated(true);
        m_lazyCalculationCheckBox.setEnabled(solver.supportsLazy());
        ComboBoxModel<Prior> oldPriorModel = m_priorComboBox.getModel();
        EnumSet<Prior> compatiblePriors = solver.getCompatiblePriors();
        Prior oldSelectedPrior = (Prior) oldPriorModel.getSelectedItem();
        m_priorComboBox.setModel(new DefaultComboBoxModel<>(compatiblePriors.toArray(new Prior[compatiblePriors.size()])));
        Prior newSelectedPrior;
        if (compatiblePriors.contains(oldSelectedPrior)) {
            m_priorComboBox.setSelectedItem(oldSelectedPrior);
            newSelectedPrior = oldSelectedPrior;
        } else {
            newSelectedPrior = (Prior) m_priorComboBox.getSelectedItem();
            // TODO warn user that the prior selection changed
        }
        enforcePriorCompatibilities(newSelectedPrior);
        LearningRateStrategies oldSelectedLRS = (LearningRateStrategies) m_learningRateStrategyComboBox.getSelectedItem();
        EnumSet<LearningRateStrategies> compatibleLRS = solver.getCompatibleLearningRateStrategies();
        m_learningRateStrategyComboBox.setModel(new DefaultComboBoxModel<>(compatibleLRS.toArray(new LearningRateStrategies[compatibleLRS.size()])));
        LearningRateStrategies newSelectedLRS;
        if (compatibleLRS.contains(oldSelectedLRS)) {
            m_learningRateStrategyComboBox.setSelectedItem(oldSelectedLRS);
            newSelectedLRS = oldSelectedLRS;
        } else {
            newSelectedLRS = (LearningRateStrategies) m_learningRateStrategyComboBox.getSelectedItem();
            // TODO warn user that the selected learning rate strategy changed
        }
        enforceLRSCompatibilities(newSelectedLRS);
    } else {
        setEnabledSGRelated(false);
    }
}
Also used: LearningRateStrategies (org.knime.base.node.mine.regression.logistic.learner4.LogRegLearnerSettings.LearningRateStrategies), Prior (org.knime.base.node.mine.regression.logistic.learner4.LogRegLearnerSettings.Prior)
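The model-swap idiom above is plain Swing and reusable outside KNIME: capture the old selection, install the new model, and restore the selection only if it is still among the compatible values. A minimal self-contained sketch under that reading (the Strategy enum and all names are hypothetical, not KNIME API):

import java.util.EnumSet;
import javax.swing.DefaultComboBoxModel;
import javax.swing.JComboBox;

public final class ComboModelSwap {

    // Hypothetical stand-in for an enum such as LearningRateStrategies.
    enum Strategy { FIXED, ANNEALING, LINE_SEARCH }

    /**
     * Replaces the combo box model and keeps the previous selection when it
     * is still among the compatible values; otherwise the new model's first
     * entry becomes the selection and the caller may want to warn the user.
     */
    static Strategy swapModel(final JComboBox<Strategy> box, final EnumSet<Strategy> compatible) {
        // Remember the selection before the model is replaced.
        final Strategy old = (Strategy) box.getSelectedItem();
        box.setModel(new DefaultComboBoxModel<>(compatible.toArray(new Strategy[0])));
        if (old != null && compatible.contains(old)) {
            // Restore the previous choice when the new model still offers it.
            box.setSelectedItem(old);
            return old;
        }
        return (Strategy) box.getSelectedItem();
    }

    public static void main(final String[] args) {
        JComboBox<Strategy> box = new JComboBox<>(Strategy.values());
        box.setSelectedItem(Strategy.ANNEALING);
        // ANNEALING survives the swap because it is still compatible.
        System.out.println(swapModel(box, EnumSet.of(Strategy.FIXED, Strategy.ANNEALING)));
        // LINE_SEARCH-only model: the selection falls back to the first entry.
        System.out.println(swapModel(box, EnumSet.of(Strategy.LINE_SEARCH)));
    }
}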

Aggregations

DecompositionSolver (org.apache.commons.math3.linear.DecompositionSolver): 1
RealMatrix (org.apache.commons.math3.linear.RealMatrix): 1
SingularValueDecomposition (org.apache.commons.math3.linear.SingularValueDecomposition): 1
LearningRateStrategies (org.knime.base.node.mine.regression.logistic.learner4.LogRegLearnerSettings.LearningRateStrategies): 1
Prior (org.knime.base.node.mine.regression.logistic.learner4.LogRegLearnerSettings.Prior): 1
Solver (org.knime.base.node.mine.regression.logistic.learner4.LogRegLearnerSettings.Solver): 1
ClassificationTrainingRow (org.knime.base.node.mine.regression.logistic.learner4.data.ClassificationTrainingRow): 1
FeatureIterator (org.knime.base.node.mine.regression.logistic.learner4.data.TrainingRow.FeatureIterator): 1
DataColumnSpec (org.knime.core.data.DataColumnSpec): 1
DataColumnSpecFilterConfiguration (org.knime.core.node.util.filter.column.DataColumnSpecFilterConfiguration): 1