Search in sources :

Example 6 with HalfDoubleMatrix

use of org.knime.base.util.HalfDoubleMatrix in project knime-core by knime.

the class RankCorrelationComputeNodeModel method execute.

/**
 * {@inheritDoc}
 */
@Override
protected PortObject[] execute(final PortObject[] inData, final ExecutionContext exec) throws Exception {
    final BufferedDataTable in = (BufferedDataTable) inData[0];
    final DataTableSpec inSpec = in.getDataTableSpec();
    ColumnRearranger filteredTableRearranger = new ColumnRearranger(inSpec);
    String[] includeNames = m_columnFilterModel.applyTo(inSpec).getIncludes();
    filteredTableRearranger.keepOnly(includeNames);
    final BufferedDataTable filteredTable = exec.createColumnRearrangeTable(in, filteredTableRearranger, exec.createSilentSubExecutionContext(0.0));
    final BufferedDataTable noMissTable = filterMissings(filteredTable, exec);
    if (noMissTable.getRowCount() < filteredTable.getRowCount()) {
        setWarningMessage("Rows containing missing values are filtered. Please resolve them" + " with the Missing Value node.");
    }
    double progStep1 = 0.48;
    double progStep2 = 0.48;
    double progFinish = 1.0 - progStep1 - progStep2;
    SortedCorrelationComputer calculator = new SortedCorrelationComputer();
    exec.setMessage("Generate ranking");
    ExecutionContext execStep1 = exec.createSubExecutionContext(progStep1);
    calculator.generateRank(noMissTable, execStep1);
    execStep1.setProgress(1.0);
    exec.setMessage("Calculating correlation values");
    ExecutionContext execStep2 = exec.createSubExecutionContext(progStep2);
    HalfDoubleMatrix correlationMatrix;
    if (m_corrType.getStringValue().equals(CFG_SPEARMAN)) {
        correlationMatrix = calculator.calculateSpearman(execStep2);
    } else {
        correlationMatrix = calculator.calculateKendallInMemory(m_corrType.getStringValue(), execStep2);
    }
    execStep2.setProgress(1.0);
    exec.setMessage("Assembling output");
    ExecutionContext execFinish = exec.createSubExecutionContext(progFinish);
    PMCCPortObjectAndSpec pmccModel = new PMCCPortObjectAndSpec(includeNames, correlationMatrix);
    BufferedDataTable out = pmccModel.createCorrelationMatrix(execFinish);
    m_correlationTable = out;
    if (in.getRowCount() == 0) {
        setWarningMessage("Empty input table! Generating missing values as correlation values.");
    }
    return new PortObject[] { out, pmccModel, calculator.getRankTable() };
}
Also used : DataTableSpec(org.knime.core.data.DataTableSpec) PMCCPortObjectAndSpec(org.knime.base.node.preproc.correlation.pmcc.PMCCPortObjectAndSpec) ColumnRearranger(org.knime.core.data.container.ColumnRearranger) ExecutionContext(org.knime.core.node.ExecutionContext) HalfDoubleMatrix(org.knime.base.util.HalfDoubleMatrix) BufferedDataTable(org.knime.core.node.BufferedDataTable) SettingsModelString(org.knime.core.node.defaultnodesettings.SettingsModelString) PortObject(org.knime.core.node.port.PortObject)
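
A minimal sketch (not part of the KNIME sources) of the HalfDoubleMatrix calls the examples on this page rely on: the constructor takes the number of columns and a flag for whether the diagonal is stored, fill presets every pair, and set/get address one unordered column pair. The column count and values below are made up for illustration.

import org.knime.base.util.HalfDoubleMatrix;

public class HalfDoubleMatrixSketch {

    public static void main(final String[] args) {
        // hypothetical number of columns
        final int numColumns = 3;
        // symmetric matrix without the diagonal, as used in the examples above
        HalfDoubleMatrix m = new HalfDoubleMatrix(numColumns, /*includeDiagonal=*/ false);
        // mark every pair as "not computed yet"
        m.fill(Double.NaN);
        // only the pairs (i, j) with i < j need to be written
        m.set(0, 1, 0.87);
        m.set(0, 2, -0.15);
        m.set(1, 2, 0.42);
        // read a pair back
        System.out.println(m.get(0, 1));
    }
}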

Example 7 with HalfDoubleMatrix

use of org.knime.base.util.HalfDoubleMatrix in project knime-core by knime.

the class SortedCorrelationComputer method calculateKendallInMemory.

/**
 * Calculates the Kendall rank correlation for all pairs of data table columns, based on previously calculated ranks.
 *
 * @param corrType the type of correlation used, as defined in RankCorrelationComputeNodeModel
 * @param exec the execution monitor
 * @return the output matrix to be turned into the output model
 * @throws CanceledExecutionException if canceled by the user
 */
HalfDoubleMatrix calculateKendallInMemory(final String corrType, final ExecutionMonitor exec) throws CanceledExecutionException {
    // the ranking must have been calculated before
    assert (m_rank != null);
    final int coCount = m_rank.getDataTableSpec().getNumColumns();
    final int rowCount = m_rank.getRowCount();
    double[][] rank = new double[rowCount][coCount];
    int c = 0;
    for (DataRow row : m_rank) {
        for (int k = 0; k < coCount; k++) {
            rank[c][k] = ((DoubleValue) row.getCell(k)).getDoubleValue();
        }
        c++;
    }
    HalfDoubleMatrix nominatorMatrix = new HalfDoubleMatrix(coCount, /*includeDiagonal=*/ false);
    double[][] cMatrix = new double[coCount][coCount];
    double[][] dMatrix = new double[coCount][coCount];
    double[][] txMatrix = new double[coCount][coCount];
    double[][] tyMatrix = new double[coCount][coCount];
    for (int rowIn1 = 0; rowIn1 < rowCount; rowIn1++) {
        for (int rowIn2 = 0; rowIn2 < rowCount; rowIn2++) {
            exec.checkCanceled();
            for (int i = 0; i < coCount; i++) {
                final double x1 = rank[rowIn1][i];
                final double x2 = rank[rowIn2][i];
                for (int j = 0; j < coCount; j++) {
                    final double y1 = rank[rowIn1][j];
                    final double y2 = rank[rowIn2][j];
                    if (x1 < x2 && y1 < y2) {
                        // values are concordant
                        cMatrix[i][j]++;
                    } else if (x1 < x2 && y1 > y2) {
                        // values are discordant
                        dMatrix[i][j]++;
                    } else if (x1 != x2 && y1 == y2) {
                        // values are tied in y
                        tyMatrix[i][j]++;
                    } else if (x1 == x2 && y1 != y2) {
                        // values are tied in x
                        txMatrix[i][j]++;
                    } else if (x1 == x2 && y1 == y2) {
                        // values are tied in both x and y
                        // txyMatrix[i][j]++; // no measure needs this count
                    }
                }
            }
        }
        exec.checkCanceled();
        exec.setProgress(0.95 * rowIn1 / rowCount, String.format("Calculating - %d/%d ", rowIn1, rowCount));
    }
    if (corrType.equals(RankCorrelationComputeNodeModel.CFG_KENDALLA)) {
        double nrOfRows = m_rank.getRowCount();
        // Kendall's Tau-a
        double divisor = (nrOfRows * (nrOfRows - 1.0)) * 0.5;
        for (int i = 0; i < coCount; i++) {
            for (int j = i + 1; j < coCount; j++) {
                nominatorMatrix.set(i, j, (cMatrix[i][j] - dMatrix[i][j]) / divisor);
            }
            exec.setProgress(0.05 * i / coCount, "Calculating correlations");
        }
    } else if (corrType.equals(RankCorrelationComputeNodeModel.CFG_KENDALLB)) {
        // use double arithmetic to avoid int overflow for large tables
        double n0 = rowCount * (rowCount - 1.0) * 0.5;
        // Kendall's Tau-b
        for (int i = 0; i < coCount; i++) {
            for (int j = i + 1; j < coCount; j++) {
                // we divide tx and ty by 2, as each of the pairs was counted twice
                double n1 = txMatrix[i][j] * 0.5;
                double n2 = tyMatrix[i][j] * 0.5;
                double div = Math.sqrt((n0 - n1) * (n0 - n2));
                nominatorMatrix.set(i, j, (cMatrix[i][j] - dMatrix[i][j]) / div);
            }
            exec.setProgress(0.05 * i / coCount, "Calculating correlations");
        }
    } else if (corrType.equals(RankCorrelationComputeNodeModel.CFG_KRUSKALAL)) {
        // Goodman and Kruskal's Gamma
        for (int i = 0; i < coCount; i++) {
            for (int j = i + 1; j < coCount; j++) {
                nominatorMatrix.set(i, j, (cMatrix[i][j] - dMatrix[i][j]) / (cMatrix[i][j] + dMatrix[i][j]));
            }
            exec.setProgress(0.05 * i / coCount, "Calculating correlations");
        }
    }
    return nominatorMatrix;
}
Also used : HalfDoubleMatrix(org.knime.base.util.HalfDoubleMatrix) DataRow(org.knime.core.data.DataRow)
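
For orientation, the branches above reduce, per column pair, to the standard definitions: Tau-a divides the concordant/discordant difference by the number of row pairs, Tau-b additionally discounts pairs tied in either column, and the gamma coefficient ignores ties entirely. The following is a hypothetical sketch of just those formulas, not KNIME code; c, d, tx and ty are assumed to be counts per unordered row pair, i.e. after the halving of the tie counts done in the method above.

final class RankCorrelationFormulas {

    /** Kendall's Tau-a: (C - D) / (number of row pairs). */
    static double tauA(final double c, final double d, final double n) {
        final double n0 = n * (n - 1.0) * 0.5;
        return (c - d) / n0;
    }

    /** Kendall's Tau-b: (C - D) / sqrt((n0 - tiesX) * (n0 - tiesY)). */
    static double tauB(final double c, final double d, final double tx, final double ty, final double n) {
        final double n0 = n * (n - 1.0) * 0.5;
        return (c - d) / Math.sqrt((n0 - tx) * (n0 - ty));
    }

    /** Goodman and Kruskal's gamma: (C - D) / (C + D), ties ignored. */
    static double gamma(final double c, final double d) {
        return (c - d) / (c + d);
    }
}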

Example 8 with HalfDoubleMatrix

use of org.knime.base.util.HalfDoubleMatrix in project knime-core by knime.

the class PMCCNodeModel method execute.

/**
 * {@inheritDoc}
 */
@Override
protected PortObject[] execute(final PortObject[] inData, final ExecutionContext exec) throws Exception {
    final BufferedDataTable in = (BufferedDataTable) inData[0];
    // row count as a double, so that later divisions use floating-point arithmetic
    final double rC = in.getRowCount();
    int[] includes = getIncludes(in.getDataTableSpec());
    String[] includeNames = m_columnIncludesList.getIncludeList().toArray(new String[0]);
    double progNormalize = 0.3;
    double progDetermine = 0.65;
    double progFinish = 1.0 - progNormalize - progDetermine;
    exec.setMessage("Normalizing data");
    final ExecutionMonitor normProg = exec.createSubProgress(progNormalize);
    FilterColumnTable filterTable = new FilterColumnTable(in, includes);
    final int l = includes.length;
    int nomCount = (l - 1) * l / 2;
    final HalfDoubleMatrix nominatorMatrix = new HalfDoubleMatrix(includes.length, /*withDiagonal*/ false);
    nominatorMatrix.fill(Double.NaN);
    @SuppressWarnings("unchecked") final LinkedHashMap<DataCell, Integer>[] possibleValues = new LinkedHashMap[l];
    DataTableSpec filterTableSpec = filterTable.getDataTableSpec();
    for (int i = 0; i < l; i++) {
        DataColumnSpec cs = filterTableSpec.getColumnSpec(i);
        if (cs.getType().isCompatible(NominalValue.class)) {
            possibleValues[i] = new LinkedHashMap<DataCell, Integer>();
        }
    }
    final int possValueUpperBound = m_maxPossValueCountModel.getIntValue();
    // determine the possible values; we can't use those from the domain,
    // as the domain can also contain values not present in the data,
    // but in the contingency table we need every row/column to have at
    // least one cell with a value >= 1
    StatisticsTable statTable = new StatisticsTable(filterTable) {

        // instance initializer - effectively the constructor of this anonymous subclass
        {
            calculateAllMoments(in.getRowCount(), normProg);
        }

        @Override
        protected void calculateMomentInSubClass(final DataRow row) {
            for (int i = 0; i < l; i++) {
                if (possibleValues[i] != null) {
                    DataCell c = row.getCell(i);
                    // note: also take missing value as possible value
                    possibleValues[i].put(c, null);
                    if (possibleValues[i].size() > possValueUpperBound) {
                        possibleValues[i] = null;
                    }
                }
            }
        }
    };
    for (LinkedHashMap<DataCell, Integer> map : possibleValues) {
        if (map != null) {
            int index = 0;
            for (Map.Entry<DataCell, Integer> entry : map.entrySet()) {
                entry.setValue(index++);
            }
        }
    }
    // stores all pair-wise contingency tables,
    // contingencyTables[i] == null <--> either column of the corresponding
    // pair is non-categorical.
    // What is a contingency table?
    // http://en.wikipedia.org/wiki/Contingency_table
    int[][][] contingencyTables = new int[nomCount][][];
    // columns which contain only one value - no correlation available
    LinkedHashSet<String> constantColumns = new LinkedHashSet<String>();
    int valIndex = 0;
    for (int i = 0; i < l; i++) {
        for (int j = i + 1; j < l; j++) {
            if (possibleValues[i] != null && possibleValues[j] != null) {
                int iSize = possibleValues[i].size();
                int jSize = possibleValues[j].size();
                contingencyTables[valIndex] = new int[iSize][jSize];
            }
            DataColumnSpec colSpecI = filterTableSpec.getColumnSpec(i);
            DataColumnSpec colSpecJ = filterTableSpec.getColumnSpec(j);
            DataType ti = colSpecI.getType();
            DataType tj = colSpecJ.getType();
            if (ti.isCompatible(DoubleValue.class) && tj.isCompatible(DoubleValue.class)) {
                // one of the two columns contains only one value
                if (statTable.getVariance(i) < PMCCPortObjectAndSpec.ROUND_ERROR_OK) {
                    constantColumns.add(colSpecI.getName());
                    nominatorMatrix.set(i, j, Double.NaN);
                } else if (statTable.getVariance(j) < PMCCPortObjectAndSpec.ROUND_ERROR_OK) {
                    constantColumns.add(colSpecJ.getName());
                    nominatorMatrix.set(i, j, Double.NaN);
                } else {
                    nominatorMatrix.set(i, j, 0.0);
                }
            }
            valIndex++;
        }
    }
    // constant columns cannot be correlated to any other column (the value will be missing)
    if (!constantColumns.isEmpty()) {
        String[] constantColumnNames = constantColumns.toArray(new String[constantColumns.size()]);
        NodeLogger.getLogger(getClass()).info("The following numeric columns contain only one distinct value or otherwise have a low standard deviation: " + Arrays.toString(constantColumnNames));
        int maxLength = 4;
        if (constantColumns.size() > maxLength) {
            constantColumnNames = Arrays.copyOf(constantColumnNames, maxLength);
            constantColumnNames[maxLength - 1] = "...";
        }
        setWarningMessage("Some columns contain only one distinct value: " + Arrays.toString(constantColumnNames));
    }
    DataTable att;
    if (statTable.getNrRows() > 0) {
        // no iteration needed
        att = new Normalizer(statTable, includeNames).doZScoreNorm(exec.createSubProgress(0.0));
    } else {
        att = statTable;
    }
    normProg.setProgress(1.0);
    exec.setMessage("Calculating correlation measure");
    ExecutionMonitor detProg = exec.createSubProgress(progDetermine);
    int rowIndex = 0;
    double[] buf = new double[l];
    DataCell[] catBuf = new DataCell[l];
    boolean containsMissing = false;
    for (DataRow r : att) {
        detProg.checkCanceled();
        for (int i = 0; i < l; i++) {
            catBuf[i] = null;
            buf[i] = Double.NaN;
            DataCell c = r.getCell(i);
            // missing value is also a possible value here
            if (possibleValues[i] != null) {
                catBuf[i] = c;
            } else if (c.isMissing()) {
                containsMissing = true;
            } else if (filterTableSpec.getColumnSpec(i).getType().isCompatible(DoubleValue.class)) {
                buf[i] = ((DoubleValue) c).getDoubleValue();
            }
        }
        valIndex = 0;
        for (int i = 0; i < l; i++) {
            for (int j = i + 1; j < l; j++) {
                double b1 = buf[i];
                double b2 = buf[j];
                if (!Double.isNaN(b1) && !Double.isNaN(b2)) {
                    double old = nominatorMatrix.get(i, j);
                    nominatorMatrix.set(i, j, old + b1 * b2);
                } else if (catBuf[i] != null && catBuf[j] != null) {
                    int iIndex = possibleValues[i].get(catBuf[i]);
                    assert iIndex >= 0 : "Value unknown in value list " + "of column " + includeNames[i] + ": " + catBuf[i];
                    int jIndex = possibleValues[j].get(catBuf[j]);
                    assert jIndex >= 0 : "Value unknown in value list " + "of column " + includeNames[j] + ": " + catBuf[j];
                    contingencyTables[valIndex][iIndex][jIndex]++;
                }
                valIndex++;
            }
        }
        rowIndex++;
        detProg.setProgress(rowIndex / rC, "Processing row " + rowIndex + " (\"" + r.getKey() + "\")");
    }
    if (containsMissing) {
        setWarningMessage("Some row(s) contained missing values.");
    }
    detProg.setProgress(1.0);
    double normalizer = 1.0 / (rC - 1.0);
    valIndex = 0;
    for (int i = 0; i < l; i++) {
        for (int j = i + 1; j < l; j++) {
            if (contingencyTables[valIndex] != null) {
                nominatorMatrix.set(i, j, computeCramersV(contingencyTables[valIndex]));
            } else if (!Double.isNaN(nominatorMatrix.get(i, j))) {
                double old = nominatorMatrix.get(i, j);
                nominatorMatrix.set(i, j, old * normalizer);
            }
            // else: the pair mixes incompatible types, e.g. a numeric and a string column
            valIndex++;
        }
    }
    normProg.setProgress(progDetermine);
    PMCCPortObjectAndSpec pmccModel = new PMCCPortObjectAndSpec(includeNames, nominatorMatrix);
    ExecutionContext subExec = exec.createSubExecutionContext(progFinish);
    BufferedDataTable out = pmccModel.createCorrelationMatrix(subExec);
    m_correlationTable = out;
    return new PortObject[] { out, pmccModel };
}
Also used : LinkedHashSet(java.util.LinkedHashSet) DataTable(org.knime.core.data.DataTable) BufferedDataTable(org.knime.core.node.BufferedDataTable) DataTableSpec(org.knime.core.data.DataTableSpec) FilterColumnTable(org.knime.base.data.filter.column.FilterColumnTable) StatisticsTable(org.knime.base.data.statistics.StatisticsTable) SettingsModelFilterString(org.knime.core.node.defaultnodesettings.SettingsModelFilterString) DataRow(org.knime.core.data.DataRow) LinkedHashMap(java.util.LinkedHashMap) DataColumnSpec(org.knime.core.data.DataColumnSpec) BufferedDataTable(org.knime.core.node.BufferedDataTable) DataType(org.knime.core.data.DataType) ExecutionMonitor(org.knime.core.node.ExecutionMonitor) PortObject(org.knime.core.node.port.PortObject) Normalizer(org.knime.base.data.normalize.Normalizer) ExecutionContext(org.knime.core.node.ExecutionContext) DoubleValue(org.knime.core.data.DoubleValue) HalfDoubleMatrix(org.knime.base.util.HalfDoubleMatrix) DataCell(org.knime.core.data.DataCell) LinkedHashMap(java.util.LinkedHashMap) Map(java.util.Map)
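
The nominal/nominal branch above fills a contingency table per column pair and then calls computeCramersV, whose body is not shown on this page. As a rough orientation only, and not necessarily the exact PMCCNodeModel implementation, Cramer's V can be derived from such a table with the textbook formula V = sqrt(chi^2 / (n * (min(rows, cols) - 1))):

final class CramersVSketch {

    /** Hypothetical sketch of Cramer's V; the actual computeCramersV may differ. */
    static double cramersV(final int[][] table) {
        final int rows = table.length;
        final int cols = table[0].length;
        final double[] rowSums = new double[rows];
        final double[] colSums = new double[cols];
        double n = 0;
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                rowSums[i] += table[i][j];
                colSums[j] += table[i][j];
                n += table[i][j];
            }
        }
        double chiSquare = 0;
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                // expected count under independence of the two columns
                final double expected = rowSums[i] * colSums[j] / n;
                if (expected > 0) {
                    final double diff = table[i][j] - expected;
                    chiSquare += diff * diff / expected;
                }
            }
        }
        return Math.sqrt(chiSquare / (n * (Math.min(rows, cols) - 1)));
    }
}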

Aggregations

HalfDoubleMatrix (org.knime.base.util.HalfDoubleMatrix)8 DataRow (org.knime.core.data.DataRow)5 BufferedDataTable (org.knime.core.node.BufferedDataTable)4 PMCCPortObjectAndSpec (org.knime.base.node.preproc.correlation.pmcc.PMCCPortObjectAndSpec)3 DataCell (org.knime.core.data.DataCell)3 DataTableSpec (org.knime.core.data.DataTableSpec)3 DoubleValue (org.knime.core.data.DoubleValue)3 ExecutionContext (org.knime.core.node.ExecutionContext)3 PortObject (org.knime.core.node.port.PortObject)3 ColumnRearranger (org.knime.core.data.container.ColumnRearranger)2 ExecutionMonitor (org.knime.core.node.ExecutionMonitor)2 IOException (java.io.IOException)1 LinkedHashMap (java.util.LinkedHashMap)1 LinkedHashSet (java.util.LinkedHashSet)1 Map (java.util.Map)1 FilterColumnTable (org.knime.base.data.filter.column.FilterColumnTable)1 Normalizer (org.knime.base.data.normalize.Normalizer)1 StatisticsTable (org.knime.base.data.statistics.StatisticsTable)1 DataColumnSpec (org.knime.core.data.DataColumnSpec)1 DataTable (org.knime.core.data.DataTable)1