Search in sources:

Example 1 with ClassificationTrainingRow

use of org.knime.base.node.mine.regression.logistic.learner4.data.ClassificationTrainingRow in project knime-core by knime.

the class SagLogRegLearner method learn.

/**
 * {@inheritDoc}
 */
@Override
public LogRegLearnerResult learn(final TrainingData<ClassificationTrainingRow> data, final ExecutionMonitor progressMonitor) throws CanceledExecutionException, InvalidSettingsException {
    // Build the stochastic-gradient optimizer from the current settings and data.
    AbstractSGOptimizer sgOpt = createOptimizer(m_settings, data);
    SimpleProgress progMon = new SimpleProgress(progressMonitor.getProgressMonitor());
    // Run the optimization for the configured number of epochs.
    LogRegLearnerResult result = sgOpt.optimize(m_settings.getMaxEpoch(), data, progMon);
    // Propagate an optimizer warning (if any) to this learner; idiomatic
    // Optional.ifPresent instead of the isPresent()/get() pair.
    sgOpt.getWarning().ifPresent(w -> m_warning = w);
    return result;
}
Also used : LogRegLearnerResult(org.knime.base.node.mine.regression.logistic.learner4.LogRegLearnerResult)

Example 2 with ClassificationTrainingRow

use of org.knime.base.node.mine.regression.logistic.learner4.data.ClassificationTrainingRow in project knime-core by knime.

the class SparseClassificationTrainingRowTest method testFeatureIterator.

/**
 * Tests the {@link FeatureIterator} returned by {@link ClassificationTrainingRow#getFeatureIterator()}.
 *
 * @throws Exception if the test fails unexpectedly
 */
@Test
public void testFeatureIterator() throws Exception {
    final SparseClassificationTrainingRow row = createRow();
    final FeatureIterator featIt = row.getFeatureIterator();
    int pos = 0;
    while (pos < INDICES.length) {
        assertTrue(featIt.hasNext());
        assertTrue(featIt.next());
        assertEquals(INDICES[pos], featIt.getFeatureIndex());
        // exact comparison intended, hence delta 0
        assertEquals(VALUES[pos], featIt.getFeatureValue(), 0);
        if (pos == 2) {
            // per the asserted contract, a spawned iterator reports the
            // feature one position behind its parent
            final FeatureIterator spawned = featIt.spawn();
            assertEquals(INDICES[pos - 1], spawned.getFeatureIndex());
            assertEquals(VALUES[pos - 1], spawned.getFeatureValue(), 0);
        }
        pos++;
    }
    // the iterator must be exhausted after all features were visited
    assertFalse(featIt.hasNext());
    assertFalse(featIt.next());
}
Also used : FeatureIterator(org.knime.base.node.mine.regression.logistic.learner4.data.TrainingRow.FeatureIterator) Test(org.junit.Test)

Example 3 with ClassificationTrainingRow

use of org.knime.base.node.mine.regression.logistic.learner4.data.ClassificationTrainingRow in project knime-core by knime.

the class LogRegCoordinator method learn.

/**
 * Performs the learning task by creating the appropriate LogRegLearner and all other objects
 * necessary for a successful training.
 *
 * @param trainingData a DataTable that contains the data on which to learn the logistic regression model
 * @param exec the execution context of the corresponding KNIME node
 * @return the content of the logistic regression model
 * @throws InvalidSettingsException if the settings cause inconsistencies during training
 * @throws CanceledExecutionException if the training is canceled
 */
LogisticRegressionContent learn(final BufferedDataTable trainingData, final ExecutionContext exec) throws InvalidSettingsException, CanceledExecutionException {
    CheckUtils.checkArgument(trainingData.size() > 0, "The input table is empty. Please provide data to learn on.");
    CheckUtils.checkArgument(trainingData.size() <= Integer.MAX_VALUE, "The input table contains too many rows.");
    // Select the solver implementation based on the configured settings.
    LogRegLearner learner;
    if (m_settings.getSolver() == Solver.IRLS) {
        learner = new IrlsLearner(m_settings.getMaxEpoch(), m_settings.getEpsilon(), m_settings.isCalcCovMatrix());
    } else {
        learner = new SagLogRegLearner(m_settings);
    }
    // Fraction of the overall progress reserved for domain recalculation (1/11).
    double calcDomainTime = 1.0 / (5.0 * 2.0 + 1.0);
    exec.setMessage("Analyzing categorical data");
    BufferedDataTable dataTable = recalcDomainForTargetAndLearningFields(trainingData, exec.createSubExecutionContext(calcDomainTime));
    checkConstantLearningFields(dataTable);
    exec.setMessage("Building logistic regression model");
    ExecutionMonitor trainExec = exec.createSubProgress(1.0 - calcDomainTime);
    TrainingRowBuilder<ClassificationTrainingRow> rowBuilder = new SparseClassificationTrainingRowBuilder(dataTable, m_pmmlOutSpec, m_settings.getTargetReferenceCategory(), m_settings.getSortTargetCategories(), m_settings.getSortIncludesCategories());
    TrainingData<ClassificationTrainingRow> data;
    Long seed = m_settings.getSeed();
    if (m_settings.isInMemory()) {
        data = new InMemoryData<ClassificationTrainingRow>(dataTable, seed, rowBuilder);
    } else {
        // NOTE(review): use the domain-recalculated 'dataTable' here (was
        // 'trainingData'), so this branch is consistent with the in-memory
        // branch and with the table the rowBuilder was constructed from.
        data = new DataTableTrainingData<ClassificationTrainingRow>(dataTable, seed, rowBuilder, m_settings.getChunkSize(), exec.createSilentSubExecutionContext(0.0));
    }
    checkShapeCompatibility(data);
    LogRegLearnerResult result = learner.learn(data, trainExec);
    LogisticRegressionContent content = createContentFromLearnerResult(result, rowBuilder, trainingData.getDataTableSpec());
    // Surface any warning produced by the learner (e.g. non-convergence).
    addToWarning(learner.getWarningMessage());
    return content;
}
Also used : ClassificationTrainingRow(org.knime.base.node.mine.regression.logistic.learner4.data.ClassificationTrainingRow) SparseClassificationTrainingRowBuilder(org.knime.base.node.mine.regression.logistic.learner4.data.SparseClassificationTrainingRowBuilder) SagLogRegLearner(org.knime.base.node.mine.regression.logistic.learner4.sg.SagLogRegLearner) SagLogRegLearner(org.knime.base.node.mine.regression.logistic.learner4.sg.SagLogRegLearner) BufferedDataTable(org.knime.core.node.BufferedDataTable) ExecutionMonitor(org.knime.core.node.ExecutionMonitor)

Example 4 with ClassificationTrainingRow

use of org.knime.base.node.mine.regression.logistic.learner4.data.ClassificationTrainingRow in project knime-core by knime.

the class IrlsLearner method likelihood.

// private RealMatrix getStdErrorMatrix(final RealMatrix xTwx) {
// RealMatrix covMat = new QRDecomposition(xTwx).getSolver().getInverse().scalarMultiply(-1);
// // the standard error estimate
// RealMatrix stdErr = MatrixUtils.createRealMatrix(covMat.getColumnDimension(),
// covMat.getRowDimension());
// for (int i = 0; i < covMat.getRowDimension(); i++) {
// stdErr.setEntry(i, i, sqrt(abs(covMat.getEntry(i, i))));
// }
// return stdErr;
// }
/**
 * Compute the log-likelihood of the training data at the given parameter vector beta.
 *
 * @param iter iterator over trainings data.
 * @param beta parameter vector
 * @param rC regressors count
 * @param tcC target category count
 * @param exec monitor used to check for cancellation
 * @return the accumulated log-likelihood over all rows
 * @throws CanceledExecutionException when method is cancelled
 */
private double likelihood(final Iterator<ClassificationTrainingRow> iter, final RealMatrix beta, final int rC, final int tcC, final ExecutionMonitor exec) throws CanceledExecutionException {
    // each per-category coefficient block spans rC + 1 entries (incl. intercept)
    final int betaDim = rC + 1;
    // row vector holding the current row's feature values
    final RealMatrix x = MatrixUtils.createRealMatrix(1, betaDim);
    double logLikelihood = 0;
    while (iter.hasNext()) {
        exec.checkCanceled();
        final ClassificationTrainingRow trainingRow = iter.next();
        fillXFromRow(x, trainingRow);
        // sum of exp(beta_i^T x) over all non-reference categories
        double expSum = 0;
        for (int cat = 0; cat < tcC - 1; cat++) {
            final RealMatrix betaTx = x.multiply(beta.getSubMatrix(0, 0, cat * betaDim, (cat + 1) * betaDim - 1).transpose());
            expSum += Math.exp(betaTx.getEntry(0, 0));
        }
        final int category = trainingRow.getCategory();
        // the last category (tcC - 1) is the reference and contributes 0
        double linearTerm = 0;
        if (category < tcC - 1) {
            linearTerm = x.multiply(beta.getSubMatrix(0, 0, category * betaDim, (category + 1) * betaDim - 1).transpose()).getEntry(0, 0);
        }
        logLikelihood += linearTerm - Math.log(1 + expSum);
    }
    return logLikelihood;
}
Also used : ClassificationTrainingRow(org.knime.base.node.mine.regression.logistic.learner4.data.ClassificationTrainingRow) RealMatrix(org.apache.commons.math3.linear.RealMatrix)

Example 5 with ClassificationTrainingRow

use of org.knime.base.node.mine.regression.logistic.learner4.data.ClassificationTrainingRow in project knime-core by knime.

the class IrlsLearner method fillXFromRow.

/**
 * Copies the sparse feature values of the given row into the dense row vector {@code x}.
 * Columns for which the row's feature iterator holds no value are set to 0.
 *
 * @param x 1-by-n matrix to fill (overwritten completely)
 * @param row the training row providing the sparse features
 */
private static void fillXFromRow(final RealMatrix x, final ClassificationTrainingRow row) {
    final FeatureIterator features = row.getFeatureIterator();
    boolean more = features.next();
    for (int col = 0; col < x.getColumnDimension(); col++) {
        if (more && features.getFeatureIndex() == col) {
            // the iterator holds a value for this column; consume it
            x.setEntry(0, col, features.getFeatureValue());
            more = features.next();
        } else {
            // sparse gap: column absent from the row
            x.setEntry(0, col, 0.0);
        }
    }
}
Also used : FeatureIterator(org.knime.base.node.mine.regression.logistic.learner4.data.TrainingRow.FeatureIterator)

Aggregations

ClassificationTrainingRow (org.knime.base.node.mine.regression.logistic.learner4.data.ClassificationTrainingRow)3 FeatureIterator (org.knime.base.node.mine.regression.logistic.learner4.data.TrainingRow.FeatureIterator)3 RealMatrix (org.apache.commons.math3.linear.RealMatrix)2 DecompositionSolver (org.apache.commons.math3.linear.DecompositionSolver)1 SingularValueDecomposition (org.apache.commons.math3.linear.SingularValueDecomposition)1 Test (org.junit.Test)1 LogRegLearnerResult (org.knime.base.node.mine.regression.logistic.learner4.LogRegLearnerResult)1 SparseClassificationTrainingRowBuilder (org.knime.base.node.mine.regression.logistic.learner4.data.SparseClassificationTrainingRowBuilder)1 SagLogRegLearner (org.knime.base.node.mine.regression.logistic.learner4.sg.SagLogRegLearner)1 BufferedDataTable (org.knime.core.node.BufferedDataTable)1 ExecutionMonitor (org.knime.core.node.ExecutionMonitor)1