
Example 1 with LogRegLearnerResult

Use of org.knime.base.node.mine.regression.logistic.learner4.LogRegLearnerResult in project knime-core by knime.

From the class SagLogRegLearner, method learn:

/**
 * {@inheritDoc}
 */
@Override
public LogRegLearnerResult learn(final TrainingData<ClassificationTrainingRow> data, final ExecutionMonitor progressMonitor) throws CanceledExecutionException, InvalidSettingsException {
    // Build the stochastic-gradient optimizer from the node settings and run it.
    AbstractSGOptimizer sgOpt = createOptimizer(m_settings, data);
    SimpleProgress progMon = new SimpleProgress(progressMonitor.getProgressMonitor());
    LogRegLearnerResult result = sgOpt.optimize(m_settings.getMaxEpoch(), data, progMon);
    // Keep any optimizer warning so callers can later retrieve it via getWarningMessage().
    Optional<String> warning = sgOpt.getWarning();
    if (warning.isPresent()) {
        m_warning = warning.get();
    }
    return result;
}
Also used : LogRegLearnerResult(org.knime.base.node.mine.regression.logistic.learner4.LogRegLearnerResult)
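For orientation, here is a minimal calling sketch. Only learn and getWarningMessage are taken from the examples on this page; the helper name learnAndReport, the import path for TrainingData, and the warning handling are assumptions for illustration, not KNIME API.

import org.knime.base.node.mine.regression.logistic.learner4.LogRegLearnerResult;
import org.knime.base.node.mine.regression.logistic.learner4.data.ClassificationTrainingRow;
import org.knime.base.node.mine.regression.logistic.learner4.data.TrainingData;
import org.knime.base.node.mine.regression.logistic.learner4.sg.SagLogRegLearner;
import org.knime.core.node.CanceledExecutionException;
import org.knime.core.node.ExecutionMonitor;
import org.knime.core.node.InvalidSettingsException;

final class LearnAndReport {

    // Hypothetical caller: runs a configured learner and surfaces its warning, if any.
    static LogRegLearnerResult learnAndReport(final SagLogRegLearner learner,
            final TrainingData<ClassificationTrainingRow> data, final ExecutionMonitor monitor)
            throws CanceledExecutionException, InvalidSettingsException {
        LogRegLearnerResult result = learner.learn(data, monitor);
        // Assumption: getWarningMessage() returns null when no warning was produced (see Example 2).
        String warning = learner.getWarningMessage();
        if (warning != null) {
            System.err.println("Learner warning: " + warning);
        }
        return result;
    }
}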

Example 2 with LogRegLearnerResult

Use of org.knime.base.node.mine.regression.logistic.learner4.LogRegLearnerResult in project knime-core by knime.

From the class LogRegCoordinator, method learn:

/**
 * Performs the learning task by creating the appropriate LogRegLearner and all other objects
 * necessary for a successful training.
 *
 * @param trainingData a DataTable that contains the data on which to learn the logistic regression model
 * @param exec the execution context of the corresponding KNIME node
 * @return the content of the logistic regression model
 * @throws InvalidSettingsException if the settings cause inconsistencies during training
 * @throws CanceledExecutionException if the training is canceled
 */
LogisticRegressionContent learn(final BufferedDataTable trainingData, final ExecutionContext exec) throws InvalidSettingsException, CanceledExecutionException {
    CheckUtils.checkArgument(trainingData.size() > 0, "The input table is empty. Please provide data to learn on.");
    CheckUtils.checkArgument(trainingData.size() <= Integer.MAX_VALUE, "The input table contains too many rows.");
    LogRegLearner learner;
    if (m_settings.getSolver() == Solver.IRLS) {
        learner = new IrlsLearner(m_settings.getMaxEpoch(), m_settings.getEpsilon(), m_settings.isCalcCovMatrix());
    } else {
        learner = new SagLogRegLearner(m_settings);
    }
    // Reserve a fraction of the progress budget for the domain recalculation below;
    // the remaining 1.0 - calcDomainTime is spent on the actual training.
    double calcDomainTime = 1.0 / (5.0 * 2.0 + 1.0);
    exec.setMessage("Analyzing categorical data");
    BufferedDataTable dataTable = recalcDomainForTargetAndLearningFields(trainingData, exec.createSubExecutionContext(calcDomainTime));
    checkConstantLearningFields(dataTable);
    exec.setMessage("Building logistic regression model");
    ExecutionMonitor trainExec = exec.createSubProgress(1.0 - calcDomainTime);
    LogRegLearnerResult result;
    TrainingRowBuilder<ClassificationTrainingRow> rowBuilder = new SparseClassificationTrainingRowBuilder(dataTable, m_pmmlOutSpec, m_settings.getTargetReferenceCategory(), m_settings.getSortTargetCategories(), m_settings.getSortIncludesCategories());
    TrainingData<ClassificationTrainingRow> data;
    Long seed = m_settings.getSeed();
    if (m_settings.isInMemory()) {
        // Keep all rows on the heap for fast random access during the epochs.
        data = new InMemoryData<ClassificationTrainingRow>(dataTable, seed, rowBuilder);
    } else {
        // Read rows chunk-wise from the table to bound the memory footprint.
        data = new DataTableTrainingData<ClassificationTrainingRow>(trainingData, seed, rowBuilder, m_settings.getChunkSize(), exec.createSilentSubExecutionContext(0.0));
    }
    checkShapeCompatibility(data);
    result = learner.learn(data, trainExec);
    LogisticRegressionContent content = createContentFromLearnerResult(result, rowBuilder, trainingData.getDataTableSpec());
    addToWarning(learner.getWarningMessage());
    return content;
}
Also used : ClassificationTrainingRow(org.knime.base.node.mine.regression.logistic.learner4.data.ClassificationTrainingRow) SparseClassificationTrainingRowBuilder(org.knime.base.node.mine.regression.logistic.learner4.data.SparseClassificationTrainingRowBuilder) SagLogRegLearner(org.knime.base.node.mine.regression.logistic.learner4.sg.SagLogRegLearner) BufferedDataTable(org.knime.core.node.BufferedDataTable) ExecutionMonitor(org.knime.core.node.ExecutionMonitor)
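The in-memory/chunked branch is the part most callers need to adapt. The sketch below isolates it; the two constructor calls are copied from the example above, while the factory class, its parameters, and the import paths for InMemoryData, DataTableTrainingData, and TrainingRowBuilder are assumptions made for illustration.

import org.knime.base.node.mine.regression.logistic.learner4.data.ClassificationTrainingRow;
import org.knime.base.node.mine.regression.logistic.learner4.data.DataTableTrainingData;
import org.knime.base.node.mine.regression.logistic.learner4.data.InMemoryData;
import org.knime.base.node.mine.regression.logistic.learner4.data.TrainingData;
import org.knime.base.node.mine.regression.logistic.learner4.data.TrainingRowBuilder;
import org.knime.core.node.BufferedDataTable;
import org.knime.core.node.CanceledExecutionException;
import org.knime.core.node.ExecutionContext;

final class TrainingDataFactory {

    // Hypothetical factory mirroring the branch in LogRegCoordinator.learn.
    static TrainingData<ClassificationTrainingRow> create(final BufferedDataTable table, final Long seed,
            final TrainingRowBuilder<ClassificationTrainingRow> rowBuilder, final boolean inMemory,
            final int chunkSize, final ExecutionContext exec) throws CanceledExecutionException {
        if (inMemory) {
            // All rows materialized on the heap: fastest random access.
            return new InMemoryData<ClassificationTrainingRow>(table, seed, rowBuilder);
        }
        // Chunked access keeps memory bounded at the cost of extra table scans.
        return new DataTableTrainingData<ClassificationTrainingRow>(table, seed, rowBuilder, chunkSize,
            exec.createSilentSubExecutionContext(0.0));
    }
}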

Example 3 with LogRegLearnerResult

Use of org.knime.base.node.mine.regression.logistic.learner4.LogRegLearnerResult in project knime-core by knime.

From the class AbstractSGOptimizer, method optimize:

public LogRegLearnerResult optimize(final int maxEpoch, final TrainingData<T> data, final Progress progress) throws CanceledExecutionException {
    final int nRows = data.getRowCount();
    final int nFets = data.getFeatureCount();
    final int nCats = data.getTargetDimension();
    final U updater = m_updaterFactory.create();
    final WeightMatrix<T> beta = new SimpleWeightMatrix<>(nFets, nCats, true);
    int epoch = 0;
    for (; epoch < maxEpoch; epoch++) {
        // notify learning rate strategy that a new epoch starts
        m_lrStrategy.startNewEpoch(epoch);
        progress.setProgress(((double) epoch) / maxEpoch, "Start epoch " + epoch + " of " + maxEpoch);
        for (int k = 0; k < nRows; k++) {
            progress.checkCanceled();
            T x = data.getRandomRow();
            prepareIteration(beta, x, updater, m_regUpdater, k);
            double[] prediction = beta.predict(x);
            double[] sig = m_loss.gradient(x, prediction);
            double stepSize = m_lrStrategy.getCurrentLearningRate(x, prediction, sig);
            // beta is updated in two steps
            m_regUpdater.update(beta, stepSize, k);
            performUpdate(x, updater, sig, beta, stepSize, k);
            double scale = beta.getScale();
            // Renormalize once the lazily maintained scale factor risks numerical
            // overflow or underflow; this leaves the represented weights unchanged.
            if (scale > 1e10 || scale < -1e10 || (scale > 0 && scale < 1e-10) || (scale < 0 && scale > -1e-10)) {
                normalize(beta, updater, k);
                beta.normalize();
            }
        }
        postProcessEpoch(beta, updater, m_regUpdater);
        if (m_stoppingCriterion.checkConvergence(beta)) {
            break;
        }
    }
    StringBuilder warnBuilder = new StringBuilder();
    if (epoch >= maxEpoch) {
        warnBuilder.append("The algorithm did not reach convergence after the specified number of epochs. " + "Setting the epoch limit higher might result in a better model.");
    }
    double lossSum = totalLoss(beta);
    RealMatrix betaMat = MatrixUtils.createRealMatrix(beta.getWeightVector());
    RealMatrix covMat = null;
    if (m_calcCovMatrix) {
        try {
            covMat = calculateCovariateMatrix(beta);
        } catch (SingularMatrixException e) {
            if (warnBuilder.length() > 0) {
                warnBuilder.append("\n");
            }
            warnBuilder.append("The covariance matrix could not be calculated because the" + " observed fisher information matrix was singular. Did you properly normalize the numerical features?");
            covMat = null;
        }
    }
    m_warning = warnBuilder.length() > 0 ? warnBuilder.toString() : null;
    // The optimizer minimizes the negative log-likelihood, so -lossSum is the
    // (maximized) log-likelihood of the returned model in a maximum likelihood sense.
    return new LogRegLearnerResult(betaMat, covMat, epoch, -lossSum);
}
Also used : RealMatrix(org.apache.commons.math3.linear.RealMatrix) SingularMatrixException(org.apache.commons.math3.linear.SingularMatrixException) LogRegLearnerResult(org.knime.base.node.mine.regression.logistic.learner4.LogRegLearnerResult)
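The SingularMatrixException branch is easy to reproduce in isolation. The following self-contained Commons Math sketch shows the invert-or-warn pattern; FisherInversionDemo and tryInvert are illustrative names, and using an LU decomposition for the inverse is an assumption, not necessarily what calculateCovariateMatrix does internally.

import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.SingularMatrixException;

public final class FisherInversionDemo {

    // Invert an observed Fisher information matrix; a singular input
    // (e.g. from perfectly collinear features) yields null instead of an exception.
    static RealMatrix tryInvert(final RealMatrix fisher) {
        try {
            return new LUDecomposition(fisher).getSolver().getInverse();
        } catch (SingularMatrixException e) {
            return null; // caller appends a warning instead, as in the example above
        }
    }

    public static void main(final String[] args) {
        RealMatrix regular = MatrixUtils.createRealMatrix(new double[][]{{2, 0}, {0, 4}});
        RealMatrix singular = MatrixUtils.createRealMatrix(new double[][]{{1, 2}, {2, 4}});
        System.out.println(tryInvert(regular)); // inverse with diagonal 0.5, 0.25
        System.out.println(tryInvert(singular)); // null
    }
}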

Aggregations

LogRegLearnerResult (org.knime.base.node.mine.regression.logistic.learner4.LogRegLearnerResult) 2
RealMatrix (org.apache.commons.math3.linear.RealMatrix) 1
SingularMatrixException (org.apache.commons.math3.linear.SingularMatrixException) 1
ClassificationTrainingRow (org.knime.base.node.mine.regression.logistic.learner4.data.ClassificationTrainingRow) 1
SparseClassificationTrainingRowBuilder (org.knime.base.node.mine.regression.logistic.learner4.data.SparseClassificationTrainingRowBuilder) 1
SagLogRegLearner (org.knime.base.node.mine.regression.logistic.learner4.sg.SagLogRegLearner) 1
BufferedDataTable (org.knime.core.node.BufferedDataTable) 1
ExecutionMonitor (org.knime.core.node.ExecutionMonitor) 1