Example 11 with QRDecomposition

use of org.apache.commons.math3.linear.QRDecomposition in project knime-core by knime.

the class Learner method irlsRls.

/**
 * Performs one IRLS (iteratively reweighted least squares) step. The result is stored in beta.
 *
 * @param data the training data
 * @param beta the parameter vector
 * @param rC the regressor count
 * @param tcC the target category count
 * @param exec the execution monitor used for progress reporting and cancellation
 * @throws CanceledExecutionException when the execution is cancelled
 */
private void irlsRls(final RegressionTrainingData data, final RealMatrix beta, final int rC, final int tcC, final ExecutionMonitor exec) throws CanceledExecutionException {
    Iterator<RegressionTrainingRow> iter = data.iterator();
    long rowCount = 0;
    int dim = (rC + 1) * (tcC - 1);
    RealMatrix xTwx = new Array2DRowRealMatrix(dim, dim);
    RealMatrix xTyu = new Array2DRowRealMatrix(dim, 1);
    RealMatrix x = new Array2DRowRealMatrix(1, rC + 1);
    RealMatrix eBetaTx = new Array2DRowRealMatrix(1, tcC - 1);
    RealMatrix pi = new Array2DRowRealMatrix(1, tcC - 1);
    final long totalRowCount = data.getRowCount();
    while (iter.hasNext()) {
        rowCount++;
        RegressionTrainingRow row = iter.next();
        exec.checkCanceled();
        exec.setProgress(rowCount / (double) totalRowCount, "Row " + rowCount + "/" + totalRowCount);
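        // design row x = [1, x_1, ..., x_rC]: intercept term followed by the regressor values of this row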
        x.setEntry(0, 0, 1);
        x.setSubMatrix(row.getParameter().getData(), 0, 1);
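        // exp(beta_k^T x) for each of the tcC - 1 non-reference target categories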
        for (int k = 0; k < tcC - 1; k++) {
            RealMatrix betaITx = x.multiply(beta.getSubMatrix(0, 0, k * (rC + 1), (k + 1) * (rC + 1) - 1).transpose());
            eBetaTx.setEntry(0, k, Math.exp(betaITx.getEntry(0, 0)));
        }
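        // multinomial logit probabilities: pi_k = exp(beta_k^T x) / (1 + sum_j exp(beta_j^T x))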
        double sumEBetaTx = 0;
        for (int k = 0; k < tcC - 1; k++) {
            sumEBetaTx += eBetaTx.getEntry(0, k);
        }
        for (int k = 0; k < tcC - 1; k++) {
            double pik = eBetaTx.getEntry(0, k) / (1 + sumEBetaTx);
            pi.setEntry(0, k, pik);
        }
        // fill the diagonal blocks of matrix xTwx (k = k')
        for (int k = 0; k < tcC - 1; k++) {
            for (int i = 0; i < rC + 1; i++) {
                for (int ii = i; ii < rC + 1; ii++) {
                    int o = k * (rC + 1);
                    double v = xTwx.getEntry(o + i, o + ii);
                    double w = pi.getEntry(0, k) * (1 - pi.getEntry(0, k));
                    v += x.getEntry(0, i) * w * x.getEntry(0, ii);
                    xTwx.setEntry(o + i, o + ii, v);
                    xTwx.setEntry(o + ii, o + i, v);
                }
            }
        }
        // fill the rest of xTwx (k != k')
        for (int k = 0; k < tcC - 1; k++) {
            for (int kk = k + 1; kk < tcC - 1; kk++) {
                for (int i = 0; i < rC + 1; i++) {
                    for (int ii = i; ii < rC + 1; ii++) {
                        int o1 = k * (rC + 1);
                        int o2 = kk * (rC + 1);
                        double v = xTwx.getEntry(o1 + i, o2 + ii);
                        double w = -pi.getEntry(0, k) * pi.getEntry(0, kk);
                        v += x.getEntry(0, i) * w * x.getEntry(0, ii);
                        xTwx.setEntry(o1 + i, o2 + ii, v);
                        xTwx.setEntry(o1 + ii, o2 + i, v);
                        xTwx.setEntry(o2 + ii, o1 + i, v);
                        xTwx.setEntry(o2 + i, o1 + ii, v);
                    }
                }
            }
        }
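        // g = index of the observed target category for this row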
        int g = (int) row.getTarget();
        // fill matrix xTyu
        for (int k = 0; k < tcC - 1; k++) {
            for (int i = 0; i < rC + 1; i++) {
                int o = k * (rC + 1);
                double v = xTyu.getEntry(o + i, 0);
                double y = k == g ? 1 : 0;
                v += (y - pi.getEntry(0, k)) * x.getEntry(0, i);
                xTyu.setEntry(o + i, 0, v);
            }
        }
    }
    if (m_penaltyTerm > 0.0) {
        RealMatrix stdError = getStdErrorMatrix(xTwx);
        // do not penalize the constant terms
        for (int i = 0; i < tcC - 1; i++) {
            stdError.setEntry(i * (rC + 1), i * (rC + 1), 0);
        }
        xTwx = xTwx.add(stdError.scalarMultiply(-0.00001));
    }
    exec.checkCanceled();
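    // the IRLS update solves A * betaNew = b, with A = X^T W X and
    // b = X^T W X * beta + X^T (y - pi)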
    b = xTwx.multiply(beta.transpose()).add(xTyu);
    A = xTwx;
    if (rowCount < A.getColumnDimension()) {
        throw new IllegalStateException("The dataset must have at least " + A.getColumnDimension() + " rows, but it has only " + rowCount + " rows. It is recommended to use a " + "larger dataset in order to increase accuracy.");
    }
    DecompositionSolver solver = new QRDecomposition(A).getSolver();
    boolean isNonSingular = solver.isNonSingular();
    if (isNonSingular) {
        RealMatrix betaNew = solver.solve(b);
        beta.setSubMatrix(betaNew.transpose().getData(), 0, 0);
    } else {
        throw new RuntimeException(FAILING_MSG);
    }
}
Also used : QRDecomposition(org.apache.commons.math3.linear.QRDecomposition) Array2DRowRealMatrix(org.apache.commons.math3.linear.Array2DRowRealMatrix) RealMatrix(org.apache.commons.math3.linear.RealMatrix) DecompositionSolver(org.apache.commons.math3.linear.DecompositionSolver) RegressionTrainingRow(org.knime.base.node.mine.regression.RegressionTrainingRow)
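
For context, the update above reduces to a single linear solve A * betaNew = b via QR decomposition. The following is a minimal, self-contained sketch of that solve pattern on a small illustrative 2x2 system (the class name QrSolveSketch and the numeric values are placeholders, not part of the KNIME code):

import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.DecompositionSolver;
import org.apache.commons.math3.linear.QRDecomposition;
import org.apache.commons.math3.linear.RealMatrix;

public class QrSolveSketch {

    public static void main(final String[] args) {
        // small, well-conditioned system standing in for A (xTwx) and the right-hand side b above
        RealMatrix a = new Array2DRowRealMatrix(new double[][] { { 4, 1 }, { 1, 3 } });
        RealMatrix b = new Array2DRowRealMatrix(new double[][] { { 1 }, { 2 } });
        DecompositionSolver solver = new QRDecomposition(a).getSolver();
        if (!solver.isNonSingular()) {
            // mirrors the failure path of the learner when the system cannot be solved
            throw new IllegalStateException("Matrix is singular; the IRLS step cannot be solved.");
        }
        RealMatrix x = solver.solve(b);
        System.out.println("x = [" + x.getEntry(0, 0) + ", " + x.getEntry(1, 0) + "]");
    }
}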

Example 12 with QRDecomposition

use of org.apache.commons.math3.linear.QRDecomposition in project knime-core by knime.

the class Learner method getStdErrorMatrix.

private RealMatrix getStdErrorMatrix(final RealMatrix xTwx) {
    RealMatrix covMat = new QRDecomposition(xTwx).getSolver().getInverse().scalarMultiply(-1);
    // the standard error estimate
    RealMatrix stdErr = new Array2DRowRealMatrix(covMat.getColumnDimension(), covMat.getRowDimension());
    for (int i = 0; i < covMat.getRowDimension(); i++) {
        stdErr.setEntry(i, i, sqrt(abs(covMat.getEntry(i, i))));
    }
    return stdErr;
}
Also used : QRDecomposition(org.apache.commons.math3.linear.QRDecomposition) Array2DRowRealMatrix(org.apache.commons.math3.linear.Array2DRowRealMatrix) RealMatrix(org.apache.commons.math3.linear.RealMatrix)
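
The helper above inverts xTwx through the same QR solver and takes the square roots of the absolute diagonal entries of the negated inverse as standard-error estimates. A standalone sketch of that pattern with an illustrative matrix (the class name StdErrorSketch is a placeholder):

import static java.lang.Math.abs;
import static java.lang.Math.sqrt;

import java.util.Arrays;

import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.QRDecomposition;
import org.apache.commons.math3.linear.RealMatrix;

public class StdErrorSketch {

    public static void main(final String[] args) {
        // illustrative symmetric matrix standing in for xTwx
        RealMatrix xTwx = new Array2DRowRealMatrix(new double[][] { { 4, 1 }, { 1, 3 } });
        // invert via the QR solver; the -1 factor mirrors the covariance convention used above
        RealMatrix covMat = new QRDecomposition(xTwx).getSolver().getInverse().scalarMultiply(-1);
        double[] stdErr = new double[covMat.getRowDimension()];
        for (int i = 0; i < covMat.getRowDimension(); i++) {
            stdErr[i] = sqrt(abs(covMat.getEntry(i, i)));
        }
        System.out.println("standard errors: " + Arrays.toString(stdErr));
    }
}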

Example 13 with QRDecomposition

use of org.apache.commons.math3.linear.QRDecomposition in project knime-core by knime.

the class Learner method perform.

/**
 * @param data The data table.
 * @param exec The execution context used for reporting progress.
 * @return An object which holds the results.
 * @throws CanceledExecutionException when the execution is cancelled
 * @throws InvalidSettingsException when the settings are inconsistent with the data
 */
public LogisticRegressionContent perform(final BufferedDataTable data, final ExecutionContext exec) throws CanceledExecutionException, InvalidSettingsException {
    exec.checkCanceled();
    int iter = 0;
    boolean converged = false;
    final RegressionTrainingData trainingData = new RegressionTrainingData(data, m_outSpec, true, m_targetReferenceCategory, m_sortTargetCategories, m_sortFactorsCategories);
    int targetIndex = data.getDataTableSpec().findColumnIndex(m_outSpec.getTargetCols().get(0).getName());
    final int tcC = trainingData.getDomainValues().get(targetIndex).size();
    final int rC = trainingData.getRegressorCount();
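    // beta holds all (tcC - 1) * (rC + 1) coefficients of the multinomial model in a single row vector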
    final RealMatrix beta = new Array2DRowRealMatrix(1, (tcC - 1) * (rC + 1));
    Double loglike = 0.0;
    Double loglikeOld = 0.0;
    exec.setMessage("Iterative optimization. Processing iteration 1.");
    // main loop
    while (iter < m_maxIter && !converged) {
        RealMatrix betaOld = beta.copy();
        loglikeOld = loglike;
        // Do the heavy work in a separate thread, which allows interrupting it.
        // Note that the queue may block if no more threads are available (e.g. thread count = 1);
        // as soon as this thread stalls in 'get' it reduces the pool's count of running threads.
        Future<Double> future = ThreadPool.currentPool().enqueue(new Callable<Double>() {

            @Override
            public Double call() throws Exception {
                final ExecutionMonitor progMon = exec.createSubProgress(1.0 / m_maxIter);
                irlsRls(trainingData, beta, rC, tcC, progMon);
                progMon.setProgress(1.0);
                return likelihood(trainingData.iterator(), beta, rC, tcC, exec);
            }
        });
        try {
            loglike = future.get();
        } catch (InterruptedException e) {
            future.cancel(true);
            exec.checkCanceled();
            throw new RuntimeException(e);
        } catch (ExecutionException e) {
            if (e.getCause() instanceof RuntimeException) {
                throw (RuntimeException) e.getCause();
            } else {
                throw new RuntimeException(e.getCause());
            }
        }
        if (Double.isInfinite(loglike) || Double.isNaN(loglike)) {
            throw new RuntimeException(FAILING_MSG);
        }
        exec.checkCanceled();
        // test for decreasing likelihood
        while ((Double.isInfinite(loglike) || Double.isNaN(loglike) || loglike < loglikeOld) && iter > 0) {
            converged = true;
            for (int k = 0; k < beta.getRowDimension(); k++) {
                if (abs(beta.getEntry(k, 0) - betaOld.getEntry(k, 0)) > m_eps * abs(betaOld.getEntry(k, 0))) {
                    converged = false;
                    break;
                }
            }
            if (converged) {
                break;
            }
            // halve the step size by averaging beta with betaOld
            beta.setSubMatrix((beta.add(betaOld)).scalarMultiply(0.5).getData(), 0, 0);
            exec.checkCanceled();
            loglike = likelihood(trainingData.iterator(), beta, rC, tcC, exec);
            exec.checkCanceled();
        }
        // test for convergence
        converged = true;
        for (int k = 0; k < beta.getRowDimension(); k++) {
            if (abs(beta.getEntry(k, 0) - betaOld.getEntry(k, 0)) > m_eps * abs(betaOld.getEntry(k, 0))) {
                converged = false;
                break;
            }
        }
        iter++;
        LOGGER.debug("#Iterations: " + iter);
        LOGGER.debug("Log Likelihood: " + loglike);
        StringBuilder betaBuilder = new StringBuilder();
        for (int i = 0; i < beta.getRowDimension() - 1; i++) {
            betaBuilder.append(Double.toString(beta.getEntry(i, 0)));
            betaBuilder.append(", ");
        }
        if (beta.getRowDimension() > 0) {
            betaBuilder.append(Double.toString(beta.getEntry(beta.getRowDimension() - 1, 0)));
        }
        LOGGER.debug("beta: " + betaBuilder.toString());
        exec.checkCanceled();
        exec.setMessage("Iterative optimization. #Iterations: " + iter + " | Log-likelihood: " + DoubleFormat.formatDouble(loglike) + ". Processing iteration " + (iter + 1) + ".");
    }
    // The covariance matrix
    RealMatrix covMat = new QRDecomposition(A).getSolver().getInverse().scalarMultiply(-1);
    List<String> factorList = new ArrayList<String>();
    List<String> covariateList = new ArrayList<String>();
    Map<String, List<DataCell>> factorDomainValues = new HashMap<String, List<DataCell>>();
    for (int i : trainingData.getActiveCols()) {
        if (trainingData.getIsNominal().get(i)) {
            String factor = data.getDataTableSpec().getColumnSpec(i).getName();
            factorList.add(factor);
            List<DataCell> values = trainingData.getDomainValues().get(i);
            factorDomainValues.put(factor, values);
        } else {
            covariateList.add(data.getDataTableSpec().getColumnSpec(i).getName());
        }
    }
    Matrix betaJama = new Matrix(beta.getData());
    Matrix covMatJama = new Matrix(covMat.getData());
    // create content
    LogisticRegressionContent content = new LogisticRegressionContent(m_outSpec, factorList, covariateList, m_targetReferenceCategory, m_sortTargetCategories, m_sortFactorsCategories, betaJama, loglike, covMatJama, iter);
    return content;
}
Also used : HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) Array2DRowRealMatrix(org.apache.commons.math3.linear.Array2DRowRealMatrix) Matrix(Jama.Matrix) RealMatrix(org.apache.commons.math3.linear.RealMatrix) RegressionTrainingData(org.knime.base.node.mine.regression.RegressionTrainingData) List(java.util.List) ExecutionMonitor(org.knime.core.node.ExecutionMonitor) CanceledExecutionException(org.knime.core.node.CanceledExecutionException) ExecutionException(java.util.concurrent.ExecutionException) InvalidSettingsException(org.knime.core.node.InvalidSettingsException) QRDecomposition(org.apache.commons.math3.linear.QRDecomposition) DataCell(org.knime.core.data.DataCell)
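
The loop above stops iterating once every coefficient changes by no more than m_eps relative to its previous value; the same test also ends the step-halving loop. A minimal sketch of such a relative-change convergence test on plain arrays (the helper name hasConverged and the parameter eps are illustrative):

/** Returns true when every coefficient moved by at most eps relative to its previous value. */
static boolean hasConverged(final double[] beta, final double[] betaOld, final double eps) {
    for (int k = 0; k < beta.length; k++) {
        if (Math.abs(beta[k] - betaOld[k]) > eps * Math.abs(betaOld[k])) {
            return false;
        }
    }
    return true;
}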

Example 14 with QRDecomposition

use of org.apache.commons.math3.linear.QRDecomposition in project knime-core by knime.

the class Learner method getStdErrorMatrix.

private RealMatrix getStdErrorMatrix(final RealMatrix xTwx) {
    RealMatrix covMat = new QRDecomposition(xTwx).getSolver().getInverse().scalarMultiply(-1);
    // the standard error estimate
    RealMatrix stdErr = new Array2DRowRealMatrix(covMat.getColumnDimension(), covMat.getRowDimension());
    for (int i = 0; i < covMat.getRowDimension(); i++) {
        stdErr.setEntry(i, i, sqrt(abs(covMat.getEntry(i, i))));
    }
    return stdErr;
}
Also used : QRDecomposition(org.apache.commons.math3.linear.QRDecomposition) Array2DRowRealMatrix(org.apache.commons.math3.linear.Array2DRowRealMatrix) RealMatrix(org.apache.commons.math3.linear.RealMatrix)

Aggregations

RealMatrix (org.apache.commons.math3.linear.RealMatrix): 14 usages
QRDecomposition (org.apache.commons.math3.linear.QRDecomposition): 13 usages
Array2DRowRealMatrix (org.apache.commons.math3.linear.Array2DRowRealMatrix): 9 usages
ExecutionException (java.util.concurrent.ExecutionException): 3 usages
DecompositionSolver (org.apache.commons.math3.linear.DecompositionSolver): 3 usages
CanceledExecutionException (org.knime.core.node.CanceledExecutionException): 3 usages
ExecutionMonitor (org.knime.core.node.ExecutionMonitor): 3 usages
InvalidSettingsException (org.knime.core.node.InvalidSettingsException): 3 usages
ArrayList (java.util.ArrayList): 2 usages
HashMap (java.util.HashMap): 2 usages
List (java.util.List): 2 usages
RegressionTrainingData (org.knime.base.node.mine.regression.RegressionTrainingData): 2 usages
DataCell (org.knime.core.data.DataCell): 2 usages
Matrix (Jama.Matrix): 1 usage
LinkedHashMap (java.util.LinkedHashMap): 1 usage
SingularMatrixException (org.apache.commons.math3.linear.SingularMatrixException): 1 usage
DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException): 1 usage
RegressionTrainingRow (org.knime.base.node.mine.regression.RegressionTrainingRow): 1 usage
ClassificationTrainingRow (org.knime.base.node.mine.regression.logistic.learner4.data.ClassificationTrainingRow): 1 usage
DataColumnSpec (org.knime.core.data.DataColumnSpec): 1 usage