Search in sources :

Example 1 with HalfIntMatrix

Use of org.knime.base.util.HalfIntMatrix in the project knime-core by knime.

The example is the method calculateStatistics of the class CorrelationComputer.

/**
 * First pass over the data. Accumulates pair-wise statistics for the numeric
 * columns and collects the distinct values of every categorical column.
 *
 * <p>For each ordered pair {@code (i, j)} of numeric columns only rows in which
 * both columns are non-missing contribute. After the pass:
 * <ul>
 * <li>{@code m_numericMeanMatrix[i][j]} is the mean of column {@code i} over
 * those rows ({@code NaN} if there is no such row),</li>
 * <li>{@code m_numericStdDevMatrix[i][j]} is the sample standard deviation
 * (divisor {@code n - 1}) of column {@code i} over those rows, or {@code 0.0}
 * if fewer than two such rows exist,</li>
 * <li>{@code m_numericValidCountMatrix} holds the number of such rows per
 * (unordered) column pair.</li>
 * </ul>
 * Numeric columns in which at least one missing cell was seen are recorded in
 * {@code m_numericsWithMissings}.
 *
 * <p>For each categorical column the distinct cells (a missing cell counts as
 * a value of its own) are mapped to their index in order of first occurrence.
 * Once a column exceeds {@code m_maxPossibleValues} distinct values, its entry
 * in {@code m_possibleValues} is set to {@code null} and no longer tracked.
 *
 * @param table the table to scan; its spec is only checked against
 *            {@code m_tableSpec} by an assertion
 * @param exec used for the per-row cancellation check and progress reporting
 * @throws CanceledExecutionException propagated from
 *             {@code exec.checkCanceled()}, which is invoked once per row
 */
@SuppressWarnings("unchecked")
public void calculateStatistics(final BufferedDataTable table, final ExecutionContext exec) throws CanceledExecutionException {
    DataTableSpec filterTableSpec = table.getDataTableSpec();
    assert filterTableSpec.equalStructure(m_tableSpec);

    // one (initially empty) value map per categorical column
    m_possibleValues = new LinkedHashMap[m_categoricalColIndexMap.length];
    for (int k = 0; k < m_possibleValues.length; k++) {
        m_possibleValues[k] = new LinkedHashMap<DataCell, Integer>();
    }

    final int numericColCount = m_numericColIndexMap.length;
    // sums[i][j]: sum (resp. sum of squares) of column i over all rows in
    // which column j is non-missing as well
    double[][] sums = new double[numericColCount][numericColCount];
    double[][] squareSums = new double[numericColCount][numericColCount];
    HalfIntMatrix pairCounts = new HalfIntMatrix(numericColCount, true);

    final DataCell[] cells = new DataCell[m_tableSpec.getNumColumns()];
    final long rowCount = table.size();
    long rowIndex = 0;
    for (DataRow row : table) {
        // cells are accessed multiple times below, so they are buffered once
        // per row
        for (int c = 0; c < cells.length; c++) {
            cells[c] = row.getCell(c);
        }

        for (int a = 0; a < numericColCount; a++) {
            final int colA = m_numericColIndexMap[a];
            final DataCell cellA = cells[colA];
            if (cellA.isMissing()) {
                m_numericsWithMissings.add(colA);
                continue;
            }
            final double value = ((DoubleValue) cellA).getDoubleValue();
            final double squared = value * value;
            final double[] sumRow = sums[a];
            final double[] squareSumRow = squareSums[a];
            for (int b = 0; b < numericColCount; b++) {
                if (cells[m_numericColIndexMap[b]].isMissing()) {
                    continue;
                }
                sumRow[b] += value;
                squareSumRow[b] += squared;
                if (b >= a) {
                    // the count matrix is symmetric: count each pair once
                    pairCounts.add(a, b, 1);
                }
            }
        }

        for (int k = 0; k < m_categoricalColIndexMap.length; k++) {
            final DataCell cell = row.getCell(m_categoricalColIndexMap[k]);
            final LinkedHashMap<DataCell, Integer> seenValues = m_possibleValues[k];
            if (seenValues == null) {
                // column already exceeded the limit in an earlier row
                continue;
            }
            // note: a missing cell is recorded as a possible value, too
            seenValues.put(cell, null);
            if (seenValues.size() > m_maxPossibleValues) {
                m_possibleValues[k] = null;
            }
        }

        exec.checkCanceled();
        final double progress = rowIndex / (double) rowCount;
        final String message = String.format("Calculating statistics - %d/%d (\"%s\")", rowIndex, rowCount, row.getKey());
        exec.setProgress(progress, message);
        rowIndex++;
    }

    // replace the placeholder null values by the index of first occurrence
    for (LinkedHashMap<DataCell, Integer> seenValues : m_possibleValues) {
        if (seenValues == null) {
            continue;
        }
        int nextIndex = 0;
        for (Map.Entry<DataCell, Integer> entry : seenValues.entrySet()) {
            entry.setValue(nextIndex);
            nextIndex++;
        }
    }

    // in place: sums --> means, squareSums --> standard deviations
    for (int a = 0; a < numericColCount; a++) {
        for (int b = 0; b < numericColCount; b++) {
            final int n = pairCounts.get(a, b);
            final double sum = sums[a][b];

            double stdDev = 0.0;
            if (n > 1) {
                final double variance = (squareSums[a][b] - (sum * sum) / n) / (n - 1);
                // tiny (possibly negative) values are rounding noise
                stdDev = Math.sqrt(variance < PMCCPortObjectAndSpec.ROUND_ERROR_OK ? 0.0 : variance);
            }
            squareSums[a][b] = stdDev;

            if (n > 0) {
                sums[a][b] = sum / n;
            } else {
                sums[a][b] = Double.NaN;
            }
        }
    }

    m_numericMeanMatrix = sums;
    m_numericStdDevMatrix = squareSums;
    m_numericValidCountMatrix = pairCounts;
}
Also used : DataTableSpec(org.knime.core.data.DataTableSpec) DataRow(org.knime.core.data.DataRow) HalfIntMatrix(org.knime.base.util.HalfIntMatrix) DoubleValue(org.knime.core.data.DoubleValue) DataCell(org.knime.core.data.DataCell) LinkedHashMap(java.util.LinkedHashMap) Map(java.util.Map)

Aggregations

LinkedHashMap (java.util.LinkedHashMap)1 Map (java.util.Map)1 HalfIntMatrix (org.knime.base.util.HalfIntMatrix)1 DataCell (org.knime.core.data.DataCell)1 DataRow (org.knime.core.data.DataRow)1 DataTableSpec (org.knime.core.data.DataTableSpec)1 DoubleValue (org.knime.core.data.DoubleValue)1