Search in sources :

Example 1 with StorelessCovariance

Use of `org.apache.commons.math3.stat.correlation.StorelessCovariance` in project knime-core by KNIME.

Class `CovarianceMatrixCalculator`, method `calculateCovarianceMatrix`:

/**
 * Computes the covariance matrix and puts the result in the given (optional) data container and additionally
 * returns an in-memory representation. The data container is expected to have the data table spec returned by
 * {@link #getResultSpec()}. The implementation traverses the data once.
 * <p>
 * Missing values are handled pairwise: a row contributes to the covariance of a column pair only if both cells of
 * that pair are non-missing in that row.
 *
 * @param exec the execution monitor used for progress reporting and cancellation checks
 * @param inTable input data; its spec must have the same structure as the target spec given in the constructor
 * @param tableSize the number of rows in {@code inTable}; used only for progress reporting. If it is not positive
 *            (e.g. unknown), only progress messages without a fraction are reported.
 * @param resultDataContainer optional result data container, may be {@code null}
 * @return the symmetric covariance matrix with one row/column per selected column
 * @throws CanceledExecutionException if the user canceled the execution
 * @throws IllegalArgumentException if there were not enough valid (non-missing) value pairs to compute the
 *             covariance of some column pair
 */
public RealMatrix calculateCovarianceMatrix(final ExecutionMonitor exec, final DataTable inTable,
    final long tableSize, final DataContainer resultDataContainer) throws CanceledExecutionException {
    // Validate the specs up front, before any data is read.
    checkArgument(m_targetSpec.equalStructure(inTable.getDataTableSpec()),
        "Target tables spec is different from the one given in the constructor!");
    if (resultDataContainer != null) {
        checkArgument(m_resultSpec.equalStructure(resultDataContainer.getTableSpec()),
            "Result tables spec is invalid!");
    }
    final int numColumns = m_indexes.length;
    final ExecutionMonitor computingProgress = exec.createSubProgress(resultDataContainer != null ? 0.8 : 1);

    // One bivariate accumulator per column pair (i <= j), added in the same i/j order that index(...) addresses.
    final List<StorelessCovariance> covariancesList = new ArrayList<>();
    for (int i = 0; i < numColumns; i++) {
        for (int j = i; j < numColumns; j++) {
            covariancesList.add(new StorelessCovariance(2));
        }
    }

    // Single pass over the data. long, because tableSize is a long and tables may exceed Integer.MAX_VALUE rows.
    long rowCount = 0;
    final double[] buffer = new double[2];
    for (DataRow dataRow : inTable) {
        for (int i = 0; i < numColumns; i++) {
            final DataCell outerCell = dataRow.getCell(m_indexes[i]);
            if (outerCell.isMissing()) {
                // skip missing values: no pair involving this column gets this row
                continue;
            }
            final double outerDouble = ((DoubleValue) outerCell).getDoubleValue();
            for (int j = i; j < numColumns; j++) {
                final DataCell innerCell = dataRow.getCell(m_indexes[j]);
                if (innerCell.isMissing()) {
                    // skip missing values for this pair only
                    continue;
                }
                buffer[0] = outerDouble;
                buffer[1] = ((DoubleValue) innerCell).getDoubleValue();
                covariancesList.get(index(numColumns, i, j)).increment(buffer);
            }
        }
        rowCount++;
        final String rowMessage = "Calculate covariance values, processing row: '" + dataRow.getKey() + "'";
        if (tableSize > 0) {
            // Clamp in case the table holds more rows than announced by tableSize.
            computingProgress.setProgress(Math.min(1.0, rowCount / (double) tableSize), rowMessage);
        } else {
            // Unknown/zero size: avoid dividing by zero or reporting a negative fraction.
            computingProgress.setProgress(rowMessage);
        }
        computingProgress.checkCanceled();
    }

    // Copy the storeless covariances into a symmetric real matrix.
    final RealMatrix covMatrix = new Array2DRowRealMatrix(numColumns, numColumns);
    for (int i = 0; i < numColumns; i++) {
        for (int j = i; j < numColumns; j++) {
            final StorelessCovariance pairCovariance = covariancesList.get(index(numColumns, i, j));
            final double covValue;
            try {
                // On the diagonal both buffer slots held the same column's value, so (1, 1) is its variance.
                covValue = i == j ? pairCovariance.getCovariance(1, 1) : pairCovariance.getCovariance(0, 1);
            } catch (NumberIsTooSmallException e) {
                throw new IllegalArgumentException(String.format("There were not enough valid values to "
                    + "compute covariance between columns: '%s' and '%s'.",
                    inTable.getDataTableSpec().getColumnSpec(m_indexes[i]).getName(),
                    inTable.getDataTableSpec().getColumnSpec(m_indexes[j]).getName()), e);
            }
            covMatrix.setEntry(i, j, covValue);
            covMatrix.setEntry(j, i, covValue);
        }
    }

    if (resultDataContainer != null) {
        exec.setProgress("Writing matrix to data table");
        final ExecutionMonitor writingProgress = exec.createSubProgress(0.2);
        final int rowDimension = covMatrix.getRowDimension();
        for (int i = 0; i < rowDimension; i++) {
            // Row i of the output is keyed by the name of output column i.
            final String columnName = resultDataContainer.getTableSpec().getColumnSpec(i).getName();
            resultDataContainer.addRowToTable(new DefaultRow(RowKey.toRowKeys(columnName)[0], covMatrix.getRow(i)));
            exec.checkCanceled();
            // Report completion of row i, so the sub-progress reaches 1.0 after the last row.
            writingProgress.setProgress((i + 1) / (double) rowDimension, "Writing row: " + columnName);
        }
    }
    return covMatrix;
}
Also used: ArrayList (java.util.ArrayList), NumberIsTooSmallException (org.apache.commons.math3.exception.NumberIsTooSmallException), StorelessCovariance (org.apache.commons.math3.stat.correlation.StorelessCovariance), DataRow (org.knime.core.data.DataRow), Array2DRowRealMatrix (org.apache.commons.math3.linear.Array2DRowRealMatrix), RealMatrix (org.apache.commons.math3.linear.RealMatrix), DoubleValue (org.knime.core.data.DoubleValue), DataCell (org.knime.core.data.DataCell), ExecutionMonitor (org.knime.core.node.ExecutionMonitor), DefaultRow (org.knime.core.data.def.DefaultRow)

Aggregations

ArrayList (java.util.ArrayList): 1 · NumberIsTooSmallException (org.apache.commons.math3.exception.NumberIsTooSmallException): 1 · Array2DRowRealMatrix (org.apache.commons.math3.linear.Array2DRowRealMatrix): 1 · RealMatrix (org.apache.commons.math3.linear.RealMatrix): 1 · StorelessCovariance (org.apache.commons.math3.stat.correlation.StorelessCovariance): 1 · DataCell (org.knime.core.data.DataCell): 1 · DataRow (org.knime.core.data.DataRow): 1 · DoubleValue (org.knime.core.data.DoubleValue): 1 · DefaultRow (org.knime.core.data.def.DefaultRow): 1 · ExecutionMonitor (org.knime.core.node.ExecutionMonitor): 1