Usage of `org.apache.commons.math.stat.descriptive.summary.SumOfSquares` in the knime-core project by KNIME.
Class `NumericScorerNodeModel`, method `execute`.
/**
 * {@inheritDoc}
 * <p>
 * Scores the prediction column against the reference column of the first input table and returns a
 * single-column table with the rows "R^2", "mean absolute error", "mean squared error",
 * "root mean squared deviation" and "mean signed difference" (prediction minus reference). The
 * computed values are also stored in the corresponding fields and pushed as flow variables.
 * <p>
 * Rows whose reference value is missing are skipped; if any were skipped, a warning reports how many.
 * The input table is read twice: once for the means and error statistics, once for the sums of squares
 * needed for R^2.
 *
 * @throws IllegalArgumentException if the configured reference or prediction column is not part of the
 *             input spec, or if a row has a non-missing reference value but a missing prediction
 */
@Override
protected BufferedDataTable[] execute(final BufferedDataTable[] inData, final ExecutionContext exec) throws Exception {
    final BufferedDataTable table = inData[0];
    final DataTableSpec spec = table.getSpec();
    final String referenceColumn = m_numericScorerSettings.getReferenceColumnName();
    final String predictionColumn = m_numericScorerSettings.getPredictionColumnName();
    final int referenceIdx = spec.findColumnIndex(referenceColumn);
    final int predictionIdx = spec.findColumnIndex(predictionColumn);
    // Fail early with a readable message instead of an index error from getCell(-1).
    if (referenceIdx < 0) {
        throw new IllegalArgumentException("Reference column not found in input table: " + referenceColumn);
    }
    if (predictionIdx < 0) {
        throw new IllegalArgumentException("Prediction column not found in input table: " + predictionColumn);
    }

    final Mean meanObserved = new Mean();
    final Mean absError = new Mean();
    final Mean squaredError = new Mean();
    final Mean signedDiff = new Mean();
    // SumOfSquares squares each increment, so these accumulate sum((ref - mean)^2) and sum((ref - pred)^2).
    final SumOfSquares ssTot = new SumOfSquares();
    final SumOfSquares ssRes = new SumOfSquares();

    // First pass: means of the observed values and of the error measures.
    int skippedRowCount = 0;
    for (DataRow row : table) {
        exec.checkCanceled();
        final DataCell refCell = row.getCell(referenceIdx);
        if (refCell.isMissing()) {
            skippedRowCount++;
            continue;
        }
        final DataCell predCell = row.getCell(predictionIdx);
        if (predCell.isMissing()) {
            throw new IllegalArgumentException("Missing value in prediction column in row: " + row.getKey());
        }
        final double ref = ((DoubleValue) refCell).getDoubleValue();
        final double pred = ((DoubleValue) predCell).getDoubleValue();
        final double diff = ref - pred;
        meanObserved.increment(ref);
        absError.increment(Math.abs(diff));
        squaredError.increment(diff * diff);
        signedDiff.increment(pred - ref);
    }

    // Second pass: total and residual sums of squares. The observed mean is fixed after the first pass.
    // Missing predictions cannot occur here for non-missing references: the first pass would have thrown.
    final double observedMean = meanObserved.getResult();
    for (DataRow row : table) {
        exec.checkCanceled();
        final DataCell refCell = row.getCell(referenceIdx);
        if (refCell.isMissing()) {
            continue;
        }
        final double ref = ((DoubleValue) refCell).getDoubleValue();
        final double pred = ((DoubleValue) row.getCell(predictionIdx)).getDoubleValue();
        ssTot.increment(ref - observedMean);
        ssRes.increment(ref - pred);
    }

    m_rSquare = 1 - ssRes.getResult() / ssTot.getResult();
    m_meanAbsError = absError.getResult();
    m_meanSquaredError = squaredError.getResult();
    m_rmsd = Math.sqrt(m_meanSquaredError);
    m_meanSignedDifference = signedDiff.getResult();

    // Create the container only after all validation, so an exception above cannot leave it unclosed.
    final BufferedDataContainer container = exec.createDataContainer(createOutputSpec(spec));
    container.addRowToTable(new DefaultRow("R^2", m_rSquare));
    container.addRowToTable(new DefaultRow("mean absolute error", m_meanAbsError));
    container.addRowToTable(new DefaultRow("mean squared error", m_meanSquaredError));
    container.addRowToTable(new DefaultRow("root mean squared deviation", m_rmsd));
    container.addRowToTable(new DefaultRow("mean signed difference", m_meanSignedDifference));
    container.close();

    if (skippedRowCount > 0) {
        setWarningMessage("Skipped " + skippedRowCount + " rows, because the reference column contained missing values there.");
    }
    pushFlowVars(false);
    return new BufferedDataTable[] { container.getTable() };
}
Aggregations