Usage of `org.knime.base.node.mine.regression.linear.LinearRegressionContent` in project knime-core by KNIME.
Class `LinRegLearnerNodeView`, method `modelChanged`:
/**
 * {@inheritDoc}
 * <p>
 * Rebuilds the HTML report shown in the view's text pane: the learned
 * parameters (offset plus one multiplier per learning column), the row
 * statistics (total, processed, skipped) and, if the model computed it, the
 * total and per-row squared error.
 */
@Override
protected void modelChanged() {
    LinRegLearnerNodeModel model = getNodeModel();
    m_pane.setText("");
    LinearRegressionContent params = model.getParams();
    int nrRows = model.getNrRows();
    int nrSkipped = model.getNrRowsSkipped();
    int nrProcessed = nrRows - nrSkipped;
    final StringBuilder buffer = new StringBuilder();
    buffer.append("<html>\n");
    buffer.append("<body>\n");
    buffer.append("<h1>Statistics on Linear Regression</h1>");
    buffer.append("<hr>\n");
    if (params == null) {
        buffer.append("No parameters available.\n");
    } else {
        DataTableSpec outSpec = params.getSpec();
        double[] multipliers = params.getMultipliers();
        buffer.append("<table>\n");
        buffer.append("<caption align=\"left\">Parameters</caption>");
        buffer.append("<tr>");
        buffer.append("<th>Column</th>");
        buffer.append("<th>Value</th>");
        buffer.append("</tr>");
        // Row 0 shows the offset; row i > 0 shows the multiplier of column i - 1 of outSpec.
        for (int i = 0; i < multipliers.length + 1; i++) {
            String key;
            double value;
            if (i == 0) {
                key = "offset";
                value = params.getOffset();
            } else {
                key = outSpec.getColumnSpec(i - 1).getName();
                value = multipliers[i - 1];
            }
            buffer.append("<tr>\n");
            buffer.append("<td>\n");
            // Column names are user-defined; escape them so they cannot break the markup.
            buffer.append(escapeHtmlText(key));
            buffer.append("\n</td>\n");
            buffer.append("<td>\n");
            buffer.append(DoubleFormat.formatDouble(value));
            buffer.append("\n</td>\n");
            buffer.append("</tr>\n");
        }
        buffer.append("</table>\n");
    }
    buffer.append("<hr>\n");
    buffer.append("<table>\n");
    buffer.append("<caption align=\"left\">Statistics</caption>\n");
    appendReportRow(buffer, "Total Row Count", String.valueOf(nrRows));
    appendReportRow(buffer, "Rows Processed", String.valueOf(nrProcessed));
    appendReportRow(buffer, "Rows Skipped", String.valueOf(nrSkipped));
    buffer.append("</table>\n");
    if (model.isCalcError()) {
        double error = model.getError();
        buffer.append("<hr>\n");
        buffer.append("<table>\n");
        buffer.append("<caption align=\"left\">Error</caption>\n");
        appendReportRow(buffer, "Total Squared Error", DoubleFormat.formatDouble(error));
        // Without any processed row the per-row error is undefined; avoid printing NaN/Infinity.
        String perRow = nrProcessed > 0 ? DoubleFormat.formatDouble(error / nrProcessed) : "n/a";
        appendReportRow(buffer, "Squared Error per Row", perRow);
        buffer.append("</table>\n");
    }
    buffer.append("</body>\n");
    buffer.append("</html>\n");
    m_pane.setText(buffer.toString());
    m_pane.revalidate();
}

/**
 * Appends one two-cell table row (label cell, right-aligned value cell) to the report.
 *
 * @param buffer the report being built
 * @param label the text of the left cell (inserted verbatim)
 * @param value the text of the right cell (inserted verbatim)
 */
private static void appendReportRow(final StringBuilder buffer, final String label, final String value) {
    buffer.append("<tr>\n");
    buffer.append("<td>\n");
    buffer.append(label);
    buffer.append("\n</td>\n");
    buffer.append("<td align=\"right\">\n");
    buffer.append(value);
    buffer.append("\n</td>\n");
    buffer.append("</tr>\n");
}

/**
 * Replaces the HTML special characters {@code & < > "} with their entity references.
 *
 * @param text the raw text, may be {@code null}
 * @return the escaped text, or the empty string for {@code null}
 */
private static String escapeHtmlText(final String text) {
    if (text == null) {
        return "";
    }
    final StringBuilder escaped = new StringBuilder(text.length());
    for (int i = 0; i < text.length(); i++) {
        final char c = text.charAt(i);
        switch (c) {
            case '&':
                escaped.append("&amp;");
                break;
            case '<':
                escaped.append("&lt;");
                break;
            case '>':
                escaped.append("&gt;");
                break;
            case '"':
                escaped.append("&quot;");
                break;
            default:
                escaped.append(c);
        }
    }
    return escaped.toString();
}
Usage of `org.knime.base.node.mine.regression.linear.LinearRegressionContent` in project knime-core by KNIME.
Class `LinRegLearnerNodeModel`, method `execute`:
/**
 * {@inheritDoc}
 * <p>
 * Learns the linear regression model by solving the normal equations
 * (A^T A) w = A^T b, where A holds the included columns of the training table
 * plus one column for the offset, and b is the target column. Rows whose
 * target is missing, or for which {@code readIntoBuffer} reports a missing
 * value, are skipped in all passes.
 *
 * @throws Exception if fewer rows without missing values than unknowns are
 *         available, or if execution is cancelled or the computation fails
 */
@Override
protected PortObject[] execute(final PortObject[] inData, final ExecutionContext exec) throws Exception {
    /*
     * What comes next is the matrix calculation, solving A \times w = b
     * where A is the matrix having the training data (as many rows as there
     * are rows in inData[0], w is the vector of weights to learn (number of
     * variables) and b is the target output
     */
    // reset was called, must be cleared
    final BufferedDataTable data = (BufferedDataTable) inData[0];
    final DataTableSpec spec = data.getDataTableSpec();
    final String[] includes = computeIncludes(spec);
    // one unknown per included column plus the offset
    final int nrUnknown = includes.length + 1;
    double[] means = new double[includes.length];
    // indices of the columns in m_includes
    final int[] colIndizes = new int[includes.length];
    for (int i = 0; i < includes.length; i++) {
        colIndizes[i] = spec.findColumnIndex(includes[i]);
    }
    // index of m_target
    final int target = spec.findColumnIndex(m_target);
    // this is the matrix (A^T x A) where A is the training data including
    // one column fixed to one.
    // (we do it here manually in order to avoid to get all the data in
    // double[][])
    double[][] ata = new double[nrUnknown][nrUnknown];
    double[] buffer = new double[nrUnknown];
    // we memorize for each row if it contains missing values.
    BitSet missingSet = new BitSet();
    m_nrRows = data.getRowCount();
    int myProgress = 0;
    // we need 2 or 3 scans on the data (first run was done already);
    // computed in double so the product cannot overflow int for huge tables
    final double totalProgress = (2 + (m_isCalcError ? 1 : 0)) * (double) m_nrRows;
    int rowCount = 0;
    boolean hasPrintedWarning = false;
    for (RowIterator it = data.iterator(); it.hasNext(); rowCount++) {
        DataRow row = it.next();
        myProgress++;
        exec.setProgress(myProgress / totalProgress, "Calculating matrix " + (rowCount + 1) + " (\"" + row.getKey().getString() + "\")");
        exec.checkCanceled();
        DataCell targetValue = row.getCell(target);
        // read data from row into buffer, skip missing value rows
        boolean containsMissing = targetValue.isMissing() || readIntoBuffer(row, buffer, colIndizes);
        missingSet.set(rowCount, containsMissing);
        if (containsMissing) {
            String errorMessage = "Row \"" + row.getKey().getString() + "\" contains missing values, skipping it.";
            if (!hasPrintedWarning) {
                LOGGER.warn(errorMessage + " Suppress further warnings.");
                hasPrintedWarning = true;
            } else {
                LOGGER.debug(errorMessage);
            }
            m_nrRowsSkipped++;
            // with next row
            continue;
        }
        updateMean(buffer, means);
        // accumulate the outer product buffer x buffer^T into A^T A
        for (int i = 0; i < nrUnknown; i++) {
            for (int j = 0; j < nrUnknown; j++) {
                ata[i][j] += buffer[i] * buffer[j];
            }
        }
    }
    assert (m_nrRows == rowCount);
    normalizeMean(means);
    // No unique solution when there are fewer usable rows than unknown
    // variables. Rows with missing values did not contribute to A^T A, so they
    // must not be counted here (otherwise A^T A is singular despite the check).
    final int nrUsedRows = rowCount - missingSet.cardinality();
    if (nrUsedRows <= nrUnknown) {
        throw new Exception("Too few rows to perform regression (" + nrUsedRows
            + " rows without missing values, but degree of freedom of " + nrUnknown + ")");
    }
    exec.setMessage("Calculating pseudo inverse...");
    double[][] ataInverse = MathUtils.inverse(ata);
    checkForNaN(ataInverse);
    // multiply with A^T and b, i.e. (A^T x A)^-1 x A^T x b
    double[] multipliers = new double[nrUnknown];
    rowCount = 0;
    for (RowIterator it = data.iterator(); it.hasNext(); rowCount++) {
        DataRow row = it.next();
        exec.setMessage("Determining output " + (rowCount + 1) + " (\"" + row.getKey().getString() + "\")");
        myProgress++;
        exec.setProgress(myProgress / totalProgress);
        exec.checkCanceled();
        // does row containing missing values?
        if (missingSet.get(rowCount)) {
            // error has printed above, silently ignore here.
            continue;
        }
        boolean containsMissing = readIntoBuffer(row, buffer, colIndizes);
        assert !containsMissing;
        DataCell targetValue = row.getCell(target);
        double b = ((DoubleValue) targetValue).getDoubleValue();
        for (int i = 0; i < nrUnknown; i++) {
            double buf = 0.0;
            for (int j = 0; j < nrUnknown; j++) {
                buf += ataInverse[i][j] * buffer[j];
            }
            multipliers[i] += buf * b;
        }
    }
    if (m_isCalcError) {
        assert m_error == 0.0;
        rowCount = 0;
        for (RowIterator it = data.iterator(); it.hasNext(); rowCount++) {
            DataRow row = it.next();
            exec.setMessage("Calculating error " + (rowCount + 1) + " (\"" + row.getKey().getString() + "\")");
            myProgress++;
            exec.setProgress(myProgress / totalProgress);
            exec.checkCanceled();
            // does row containing missing values?
            if (missingSet.get(rowCount)) {
                // error has printed above, silently ignore here.
                continue;
            }
            boolean hasMissing = readIntoBuffer(row, buffer, colIndizes);
            assert !hasMissing;
            DataCell targetValue = row.getCell(target);
            double b = ((DoubleValue) targetValue).getDoubleValue();
            double out = 0.0;
            for (int i = 0; i < nrUnknown; i++) {
                out += multipliers[i] * buffer[i];
            }
            // accumulate squared residual
            m_error += (b - out) * (b - out);
        }
    }
    // handle the optional PMML input
    PMMLPortObject inPMMLPort = (PMMLPortObject) inData[1];
    DataTableSpec outSpec = getLearningSpec(spec);
    // index 0 of the solution is the offset, the rest are per-column multipliers
    double offset = multipliers[0];
    multipliers = Arrays.copyOfRange(multipliers, 1, multipliers.length);
    m_params = new LinearRegressionContent(outSpec, offset, multipliers, means);
    // cache the entire table as otherwise the color information
    // may be lost (filtering out the "colored" column)
    m_rowContainer = new DefaultDataArray(data, m_firstRowPaint, m_rowCountPaint);
    m_actualUsedColumns = includes;
    return new PortObject[] { m_params.createPortObject(inPMMLPort, spec, outSpec) };
}
Usage of `org.knime.base.node.mine.regression.linear.LinearRegressionContent` in project knime-core by KNIME.
Class `LinRegLinePlotter`, method `updateSize`:
/**
 * First calls super then adapts the regression line.
 * <p>
 * Nothing happens while either axis or its coordinate is unset. After the
 * super call, the line end points are recomputed for the current x domain,
 * but only if the data provider supplies data, model parameters and learning
 * columns, and the selected x column is a learning column other than the
 * target.
 */
@Override
public void updateSize() {
    final boolean axesReady = getXAxis() != null && getXAxis().getCoordinate() != null
            && getYAxis() != null && getYAxis().getCoordinate() != null;
    if (!axesReady) {
        return;
    }
    super.updateSize();
    final DataProvider provider = getDataProvider();
    if (provider == null || provider.getDataArray(0) == null) {
        return;
    }
    final LinRegDataProvider linRegProvider = (LinRegDataProvider) provider;
    final LinearRegressionContent content = linRegProvider.getParams();
    if (content == null) {
        return;
    }
    final NumericCoordinate xCoordinate = (NumericCoordinate) getXAxis().getCoordinate();
    final double lowX = xCoordinate.getMinDomainValue();
    final double highX = xCoordinate.getMaxDomainValue();
    final String selectedX = getSelectedXColumn().getName();
    final String[] learningColumns = linRegProvider.getLearningColumns();
    if (learningColumns == null) {
        return;
    }
    final boolean isTarget = selectedX.equals(content.getTargetColumnName());
    if (isTarget || !Arrays.asList(learningColumns).contains(selectedX)) {
        return;
    }
    final double lowY = content.getApproximationFor(selectedX, lowX);
    final double highY = content.getApproximationFor(selectedX, highX);
    ((LinRegLineDrawingPane) getDrawingPane()).setLineFirstPoint(
            getMappedXValue(new DoubleCell(lowX)), getMappedYValue(new DoubleCell(lowY)));
    ((LinRegLineDrawingPane) getDrawingPane()).setLineLastPoint(
            getMappedXValue(new DoubleCell(highX)), getMappedYValue(new DoubleCell(highY)));
}
Usage of `org.knime.base.node.mine.regression.linear.LinearRegressionContent` in project knime-core by KNIME.
Class `LinRegLinePlotter`, method `updatePaintModel`:
/**
 * Retrieves the linear regression params, updates the column selection
 * boxes appropriately and adds the regression line to the scatterplot.
 * <p>
 * Returns early without touching the plot when no data, parameters or
 * learning columns are available. After the super call, the line is only
 * (re)computed when both axes carry a coordinate and an x column is selected.
 * The line is drawn only if that column is a learning column other than the
 * target.
 */
@Override
public void updatePaintModel() {
    DataProvider dataProvider = getDataProvider();
    if (dataProvider == null) {
        return;
    }
    DataArray data = dataProvider.getDataArray(0);
    if (data == null) {
        return;
    }
    LinRegDataProvider linRegProvider = (LinRegDataProvider) dataProvider;
    LinearRegressionContent params = linRegProvider.getParams();
    if (params == null) {
        return;
    }
    LinRegLinePlotterProperties properties = (LinRegLinePlotterProperties) getProperties();
    // set the target column to fix
    properties.setTargetColumn(params.getTargetColumnName());
    // get the included columns
    String[] includedCols = linRegProvider.getLearningColumns();
    if (includedCols == null) {
        return;
    }
    properties.setIncludedColumns(includedCols);
    // update the combo boxes
    DataTableSpec spec = data.getDataTableSpec();
    properties.update(spec);
    super.updatePaintModel();
    // Same guard as updateSize(): the coordinates may not be set yet, and the
    // casts/dereferences below would otherwise throw.
    if (getXAxis() == null || getXAxis().getCoordinate() == null
            || getYAxis() == null || getYAxis().getCoordinate() == null
            || getSelectedXColumn() == null) {
        return;
    }
    double xMin = ((NumericCoordinate) getXAxis().getCoordinate()).getMinDomainValue();
    double xMax = ((NumericCoordinate) getXAxis().getCoordinate()).getMaxDomainValue();
    String xName = getSelectedXColumn().getName();
    List<String> includedList = Arrays.asList(includedCols);
    if (!xName.equals(params.getTargetColumnName()) && includedList.contains(xName)) {
        double yMin = params.getApproximationFor(xName, xMin);
        double yMax = params.getApproximationFor(xName, xMax);
        LinRegLineDrawingPane pane = (LinRegLineDrawingPane) getDrawingPane();
        pane.setLineFirstPoint(getMappedXValue(new DoubleCell(xMin)), getMappedYValue(new DoubleCell(yMin)));
        pane.setLineLastPoint(getMappedXValue(new DoubleCell(xMax)), getMappedYValue(new DoubleCell(yMax)));
        pane.repaint();
    }
}
Aggregations