use of org.knime.base.util.HalfDoubleMatrix in project knime-core by knime.
the class RankCorrelationComputeNodeModel method execute.
/**
 * {@inheritDoc}
 *
 * <p>Restricts the input table to the columns chosen by the column filter, drops rows via
 * {@code filterMissings}, ranks the remaining data and derives the rank correlation matrix
 * selected by {@code m_corrType}. Returns the correlation table, the correlation model and
 * the rank table, in that order.
 */
@Override
protected PortObject[] execute(final PortObject[] inData, final ExecutionContext exec) throws Exception {
    final BufferedDataTable inputTable = (BufferedDataTable) inData[0];
    final DataTableSpec inputSpec = inputTable.getDataTableSpec();

    // Keep only the columns accepted by the column filter.
    final ColumnRearranger keepIncluded = new ColumnRearranger(inputSpec);
    final String[] includeNames = m_columnFilterModel.applyTo(inputSpec).getIncludes();
    keepIncluded.keepOnly(includeNames);
    final BufferedDataTable filteredTable =
        exec.createColumnRearrangeTable(inputTable, keepIncluded, exec.createSilentSubExecutionContext(0.0));

    // Warn when the missing value filter removed at least one row.
    final BufferedDataTable noMissTable = filterMissings(filteredTable, exec);
    final int keptRows = noMissTable.getRowCount();
    final int candidateRows = filteredTable.getRowCount();
    if (keptRows < candidateRows) {
        setWarningMessage("Rows containing missing values are filtered. Please resolve them" + " with the Missing Value node.");
    }

    // Progress shares of the three phases; the last phase receives whatever is left.
    final double rankShare = 0.48;
    final double correlationShare = 0.48;
    final double assemblyShare = 1.0 - rankShare - correlationShare;

    // Phase 1: ranking.
    final SortedCorrelationComputer computer = new SortedCorrelationComputer();
    exec.setMessage("Generate ranking");
    final ExecutionContext rankExec = exec.createSubExecutionContext(rankShare);
    computer.generateRank(noMissTable, rankExec);
    rankExec.setProgress(1.0);

    // Phase 2: correlation values; Spearman has its own routine, every other type goes
    // through the in-memory Kendall computation.
    exec.setMessage("Calculating correlation values");
    final ExecutionContext correlationExec = exec.createSubExecutionContext(correlationShare);
    final String corrType = m_corrType.getStringValue();
    final HalfDoubleMatrix correlationMatrix = corrType.equals(CFG_SPEARMAN)
        ? computer.calculateSpearman(correlationExec)
        : computer.calculateKendallInMemory(corrType, correlationExec);
    correlationExec.setProgress(1.0);

    // Phase 3: build the output model and table.
    exec.setMessage("Assembling output");
    final ExecutionContext assemblyExec = exec.createSubExecutionContext(assemblyShare);
    final PMCCPortObjectAndSpec pmccModel = new PMCCPortObjectAndSpec(includeNames, correlationMatrix);
    final BufferedDataTable correlationTable = pmccModel.createCorrelationMatrix(assemblyExec);
    m_correlationTable = correlationTable;

    if (inputTable.getRowCount() == 0) {
        setWarningMessage("Empty input table! Generating missing values as correlation values.");
    }
    return new PortObject[] { correlationTable, pmccModel, computer.getRankTable() };
}
use of org.knime.base.util.HalfDoubleMatrix in project knime-core by knime.
the class SortedCorrelationComputer method calculateKendallInMemory.
/**
 * Calculates the Kendall rank correlation (Tau a or Tau b) or Kruskal's Gamma for all pairs of data
 * table columns, based on the previously calculated ranks. The complete rank table is copied into
 * memory and every ordered pair of rows is compared, so the run time is quadratic in the row count.
 *
 * <p>If {@code corrType} matches none of the three supported constants, the matrix is returned
 * without any value having been set.
 *
 * @param corrType the type of correlation used, as defined in RankCorrelationComputeNodeModel
 * @param exec the execution monitor used for progress reporting and cancellation checks
 * @return the output matrix to be turned into the output model
 * @throws CanceledExecutionException if canceled by users
 */
HalfDoubleMatrix calculateKendallInMemory(final String corrType, final ExecutionMonitor exec) throws CanceledExecutionException {
    // the ranking must have been calculated before
    assert (m_rank != null);
    final int coCount = m_rank.getDataTableSpec().getNumColumns();
    final int rowCount = m_rank.getRowCount();

    // copy all ranks into memory: rank[row][column]
    final double[][] rank = new double[rowCount][coCount];
    int c = 0;
    for (DataRow row : m_rank) {
        for (int k = 0; k < coCount; k++) {
            rank[c][k] = ((DoubleValue) row.getCell(k)).getDoubleValue();
        }
        c++;
    }

    final HalfDoubleMatrix nominatorMatrix = new HalfDoubleMatrix(coCount, /*includeDiagonal=*/false);
    // pair counts per column pair (i, j); only the upper triangle (j > i) is filled and read
    final double[][] cMatrix = new double[coCount][coCount];  // concordant pairs
    final double[][] dMatrix = new double[coCount][coCount];  // discordant pairs
    final double[][] txMatrix = new double[coCount][coCount]; // tied in x only
    final double[][] tyMatrix = new double[coCount][coCount]; // tied in y only

    for (int rowIn1 = 0; rowIn1 < rowCount; rowIn1++) {
        for (int rowIn2 = 0; rowIn2 < rowCount; rowIn2++) {
            exec.checkCanceled();
            for (int i = 0; i < coCount; i++) {
                final double x1 = rank[rowIn1][i];
                final double x2 = rank[rowIn2][i];
                // the result only uses column pairs with j > i, so the rest is not counted
                for (int j = i + 1; j < coCount; j++) {
                    final double y1 = rank[rowIn1][j];
                    final double y2 = rank[rowIn2][j];
                    if (x1 < x2 && y1 < y2) {
                        // values are concordant (seen once per unordered row pair, as x1 < x2)
                        cMatrix[i][j]++;
                    } else if (x1 < x2 && y1 > y2) {
                        // values are discordant (seen once per unordered row pair, as x1 < x2)
                        dMatrix[i][j]++;
                    } else if (x1 != x2 && y1 == y2) {
                        // values are bounded in y (seen twice, once per row order)
                        tyMatrix[i][j]++;
                    } else if (x1 == x2 && y1 != y2) {
                        // values are bounded in x (seen twice, once per row order)
                        txMatrix[i][j]++;
                    }
                    // pairs bounded in both x and y are not counted: no measure needs them
                }
            }
        }
        exec.checkCanceled();
        exec.setProgress(0.95 * rowIn1 / rowCount, String.format("Calculating - %d/%d ", rowIn1, rowCount));
    }

    // Number of unordered row pairs, n * (n - 1) / 2. Computed in floating point: the int product
    // rowCount * (rowCount - 1) overflows for more than 46341 rows.
    final double nrOfRows = rowCount;
    final double totalPairs = (nrOfRows * (nrOfRows - 1.0)) * 0.5;

    // The counting loop above covers the first 95% of the progress; the final 5% are reported
    // on top of that so the progress never jumps backwards.
    if (corrType.equals(RankCorrelationComputeNodeModel.CFG_KENDALLA)) {
        // kendalls Tau a
        for (int i = 0; i < coCount; i++) {
            for (int j = i + 1; j < coCount; j++) {
                nominatorMatrix.set(i, j, (cMatrix[i][j] - dMatrix[i][j]) / totalPairs);
            }
            exec.setProgress(0.95 + 0.05 * i / coCount, "Calculating correlations");
        }
    } else if (corrType.equals(RankCorrelationComputeNodeModel.CFG_KENDALLB)) {
        // kendalls Tau b
        for (int i = 0; i < coCount; i++) {
            for (int j = i + 1; j < coCount; j++) {
                // we divide tx and ty by 2, as each of the pairs was counted twice
                final double n1 = txMatrix[i][j] * 0.5;
                final double n2 = tyMatrix[i][j] * 0.5;
                final double div = Math.sqrt((totalPairs - n1) * (totalPairs - n2));
                nominatorMatrix.set(i, j, (cMatrix[i][j] - dMatrix[i][j]) / div);
            }
            exec.setProgress(0.95 + 0.05 * i / coCount, "Calculating correlations");
        }
    } else if (corrType.equals(RankCorrelationComputeNodeModel.CFG_KRUSKALAL)) {
        // Kruskals Gamma; yields NaN when a column pair has neither concordant nor discordant pairs
        for (int i = 0; i < coCount; i++) {
            for (int j = i + 1; j < coCount; j++) {
                nominatorMatrix.set(i, j, (cMatrix[i][j] - dMatrix[i][j]) / (cMatrix[i][j] + dMatrix[i][j]));
            }
            exec.setProgress(0.95 + 0.05 * i / coCount, "Calculating correlations");
        }
    }
    return nominatorMatrix;
}
use of org.knime.base.util.HalfDoubleMatrix in project knime-core by knime.
the class PMCCNodeModel method execute.
/**
 * {@inheritDoc}
 *
 * <p>Computes one value per pair of included columns and returns it as a table plus model:
 * <ul>
 * <li>two columns compatible with {@code DoubleValue}: the sum over all rows of the product of the
 * z-score normalized values, divided by (row count - 1); NaN if either column has a variance below
 * {@code PMCCPortObjectAndSpec.ROUND_ERROR_OK},</li>
 * <li>two nominal columns whose number of distinct values stays within the configured upper bound:
 * the result of {@code computeCramersV} on their contingency table,</li>
 * <li>any other combination: NaN.</li>
 * </ul>
 */
@Override
protected PortObject[] execute(final PortObject[] inData, final ExecutionContext exec) throws Exception {
    final BufferedDataTable in = (BufferedDataTable) inData[0];
    // floating point operation
    final double rC = in.getRowCount();
    int[] includes = getIncludes(in.getDataTableSpec());
    String[] includeNames = m_columnIncludesList.getIncludeList().toArray(new String[0]);
    double progNormalize = 0.3;
    double progDetermine = 0.65;
    double progFinish = 1.0 - progNormalize - progDetermine;
    exec.setMessage("Normalizing data");
    final ExecutionMonitor normProg = exec.createSubProgress(progNormalize);
    FilterColumnTable filterTable = new FilterColumnTable(in, includes);
    final int l = includes.length;
    // number of column pairs (i, j) with i < j
    int nomCount = (l - 1) * l / 2;
    final HalfDoubleMatrix nominatorMatrix = new HalfDoubleMatrix(includes.length, /*withDiagonal*/ false);
    // NaN marks "no value available" until a pair is found to be computable
    nominatorMatrix.fill(Double.NaN);
    // possibleValues[i] != null <--> column i is treated as categorical; maps value -> index
    @SuppressWarnings("unchecked")
    final LinkedHashMap<DataCell, Integer>[] possibleValues = new LinkedHashMap[l];
    DataTableSpec filterTableSpec = filterTable.getDataTableSpec();
    for (int i = 0; i < l; i++) {
        DataColumnSpec cs = filterTableSpec.getColumnSpec(i);
        if (cs.getType().isCompatible(NominalValue.class)) {
            possibleValues[i] = new LinkedHashMap<DataCell, Integer>();
        }
    }
    final int possValueUpperBound = m_maxPossValueCountModel.getIntValue();
    // determines possible values. We can't use those from the domain
    // as the domain can also contain values not present in the data
    // but in the contingency table we need rows/columns to have at least
    // one cell with a value >= 1
    StatisticsTable statTable = new StatisticsTable(filterTable) {
        // that is sort of the constructor in this derived class
        {
            calculateAllMoments(in.getRowCount(), normProg);
        }

        @Override
        protected void calculateMomentInSubClass(final DataRow row) {
            for (int i = 0; i < l; i++) {
                if (possibleValues[i] != null) {
                    DataCell c = row.getCell(i);
                    // note: also take missing value as possible value
                    possibleValues[i].put(c, null);
                    // too many distinct values: stop treating the column as categorical
                    if (possibleValues[i].size() > possValueUpperBound) {
                        possibleValues[i] = null;
                    }
                }
            }
        }
    };
    // assign each distinct value its index (in encounter order) within the contingency table
    for (LinkedHashMap<DataCell, Integer> map : possibleValues) {
        if (map != null) {
            int index = 0;
            for (Map.Entry<DataCell, Integer> entry : map.entrySet()) {
                entry.setValue(index++);
            }
        }
    }
    // stores all pair-wise contingency tables,
    // contingencyTables[i] == null <--> either column of the corresponding
    // pair is non-categorical.
    // What is a contingency table?
    // http://en.wikipedia.org/wiki/Contingency_table
    int[][][] contingencyTables = new int[nomCount][][];
    // column which only contain one value - no correlation available
    LinkedHashSet<String> constantColumns = new LinkedHashSet<String>();
    // valIndex enumerates the pairs (i, j), i < j, in the same order in every loop below
    int valIndex = 0;
    for (int i = 0; i < l; i++) {
        for (int j = i + 1; j < l; j++) {
            if (possibleValues[i] != null && possibleValues[j] != null) {
                int iSize = possibleValues[i].size();
                int jSize = possibleValues[j].size();
                contingencyTables[valIndex] = new int[iSize][jSize];
            }
            DataColumnSpec colSpecI = filterTableSpec.getColumnSpec(i);
            DataColumnSpec colSpecJ = filterTableSpec.getColumnSpec(j);
            DataType ti = colSpecI.getType();
            DataType tj = colSpecJ.getType();
            if (ti.isCompatible(DoubleValue.class) && tj.isCompatible(DoubleValue.class)) {
                // Both columns are checked independently so that every (near-)constant column is
                // reported, including one whose numeric partner columns are all constant as well.
                final boolean constantI = statTable.getVariance(i) < PMCCPortObjectAndSpec.ROUND_ERROR_OK;
                final boolean constantJ = statTable.getVariance(j) < PMCCPortObjectAndSpec.ROUND_ERROR_OK;
                if (constantI) {
                    constantColumns.add(colSpecI.getName());
                }
                if (constantJ) {
                    constantColumns.add(colSpecJ.getName());
                }
                // 0.0 is the start value of the running sum of products for this pair
                nominatorMatrix.set(i, j, (constantI || constantJ) ? Double.NaN : 0.0);
            }
            valIndex++;
        }
    }
    // a column with only one value cannot be correlated
    // to other column (will be a missing value)
    if (!constantColumns.isEmpty()) {
        String[] constantColumnNames = constantColumns.toArray(new String[constantColumns.size()]);
        NodeLogger.getLogger(getClass()).info("The following numeric " + "columns contain only one distinct value or have " + "otherwise a low standard deviation: " + Arrays.toString(constantColumnNames));
        // the warning shows at most maxLength entries, the last one being "..." when truncated
        int maxLength = 4;
        if (constantColumns.size() > maxLength) {
            constantColumnNames = Arrays.copyOf(constantColumnNames, maxLength);
            constantColumnNames[maxLength - 1] = "...";
        }
        setWarningMessage("Some columns contain only one distinct value: " + Arrays.toString(constantColumnNames));
    }
    DataTable att;
    if (statTable.getNrRows() > 0) {
        att = new Normalizer(statTable, includeNames).doZScoreNorm(// no iteration needed
            exec.createSubProgress(0.0));
    } else {
        att = statTable;
    }
    normProg.setProgress(1.0);
    exec.setMessage("Calculating correlation measure");
    ExecutionMonitor detProg = exec.createSubProgress(progDetermine);
    int rowIndex = 0;
    // per-row buffers: numeric value (NaN = not usable) and categorical cell (null = not categorical)
    double[] buf = new double[l];
    DataCell[] catBuf = new DataCell[l];
    boolean containsMissing = false;
    for (DataRow r : att) {
        detProg.checkCanceled();
        for (int i = 0; i < l; i++) {
            catBuf[i] = null;
            buf[i] = Double.NaN;
            DataCell c = r.getCell(i);
            // missing value is also a possible value here
            if (possibleValues[i] != null) {
                catBuf[i] = c;
            } else if (c.isMissing()) {
                containsMissing = true;
            } else if (filterTableSpec.getColumnSpec(i).getType().isCompatible(DoubleValue.class)) {
                buf[i] = ((DoubleValue) c).getDoubleValue();
            }
        }
        valIndex = 0;
        for (int i = 0; i < l; i++) {
            for (int j = i + 1; j < l; j++) {
                double b1 = buf[i];
                double b2 = buf[j];
                if (!Double.isNaN(b1) && !Double.isNaN(b2)) {
                    // pairs marked NaN above stay NaN, as NaN + x is NaN
                    double old = nominatorMatrix.get(i, j);
                    nominatorMatrix.set(i, j, old + b1 * b2);
                } else if (catBuf[i] != null && catBuf[j] != null) {
                    int iIndex = possibleValues[i].get(catBuf[i]);
                    assert iIndex >= 0 : "Value unknown in value list " + "of column " + includeNames[i] + ": " + catBuf[i];
                    int jIndex = possibleValues[j].get(catBuf[j]);
                    assert jIndex >= 0 : "Value unknown in value list " + "of column " + includeNames[j] + ": " + catBuf[j];
                    contingencyTables[valIndex][iIndex][jIndex]++;
                }
                valIndex++;
            }
        }
        rowIndex++;
        detProg.setProgress(rowIndex / rC, "Processing row " + rowIndex + " (\"" + r.getKey() + "\")");
    }
    if (containsMissing) {
        // NOTE(review): if setWarningMessage replaces an earlier message, this hides the
        // constant-column warning set above - confirm.
        setWarningMessage("Some row(s) contained missing values.");
    }
    detProg.setProgress(1.0);
    // NOTE(review): infinite for a single row and -1 for an empty table - confirm this is acceptable.
    double normalizer = 1.0 / (rC - 1.0);
    valIndex = 0;
    for (int i = 0; i < l; i++) {
        for (int j = i + 1; j < l; j++) {
            if (contingencyTables[valIndex] != null) {
                nominatorMatrix.set(i, j, computeCramersV(contingencyTables[valIndex]));
            } else if (!Double.isNaN(nominatorMatrix.get(i, j))) {
                double old = nominatorMatrix.get(i, j);
                nominatorMatrix.set(i, j, old * normalizer);
            }
            // else pair of columns is double - string (for instance)
            valIndex++;
        }
    }
    // NOTE(review): normProg was already set to 1.0 above; the intent of this call is unclear.
    normProg.setProgress(progDetermine);
    PMCCPortObjectAndSpec pmccModel = new PMCCPortObjectAndSpec(includeNames, nominatorMatrix);
    ExecutionContext subExec = exec.createSubExecutionContext(progFinish);
    BufferedDataTable out = pmccModel.createCorrelationMatrix(subExec);
    m_correlationTable = out;
    return new PortObject[] { out, pmccModel };
}
Aggregations