Search in sources:

Example 1 with LinearRegressionContent

use of org.knime.base.node.mine.regression.linear.LinearRegressionContent in project knime-core by knime.

Class LinRegLearnerNodeView, method modelChanged.

/**
 * {@inheritDoc}
 *
 * <p>Rebuilds the HTML shown in {@code m_pane} from the current model. The
 * HTML contains the learned parameters (offset plus one multiplier per
 * column of the parameter spec), if any, and the row statistics. If error
 * calculation is enabled in the model, it also contains the total and
 * per-row squared error.
 */
@Override
protected void modelChanged() {
    LinRegLearnerNodeModel model = getNodeModel();
    m_pane.setText("");
    LinearRegressionContent params = model.getParams();
    int nrRows = model.getNrRows();
    int nrSkipped = model.getNrRowsSkipped();
    // rows that actually contributed to the regression
    int nrProcessed = nrRows - nrSkipped;
    final StringBuilder buffer = new StringBuilder();
    buffer.append("<html>\n");
    buffer.append("<body>\n");
    buffer.append("<h1>Statistics on Linear Regression</h1>");
    buffer.append("<hr>\n");
    if (params == null) {
        buffer.append("No parameters available.\n");
    } else {
        DataTableSpec outSpec = params.getSpec();
        double[] multipliers = params.getMultipliers();
        buffer.append("<table>\n");
        buffer.append("<caption align=\"left\">Parameters</caption>");
        buffer.append("<tr>");
        buffer.append("<th>Column</th>");
        buffer.append("<th>Value</th>");
        buffer.append("</tr>");
        // row 0 is the offset, row i (i >= 1) is the multiplier of column i - 1
        for (int i = 0; i < multipliers.length + 1; i++) {
            buffer.append("<tr>\n");
            buffer.append("<td>\n");
            String key;
            double value;
            if (i == 0) {
                key = "offset";
                value = params.getOffset();
            } else {
                key = outSpec.getColumnSpec(i - 1).getName();
                value = multipliers[i - 1];
            }
            // column names are user-defined and may contain HTML meta characters
            buffer.append(escapeHtmlText(key));
            buffer.append("\n</td>\n");
            buffer.append("<td>\n");
            String format = DoubleFormat.formatDouble(value);
            buffer.append(format);
            buffer.append("\n</td>\n");
            buffer.append("</tr>\n");
        }
        buffer.append("</table>\n");
    }
    buffer.append("<hr>\n");
    buffer.append("<table>\n");
    buffer.append("<caption align=\"left\">Statistics</caption>\n");
    buffer.append("<tr>\n");
    buffer.append("<td>\n");
    buffer.append("Total Row Count\n");
    buffer.append("</td>\n");
    buffer.append("<td>\n");
    buffer.append(nrRows);
    buffer.append("\n</td>\n");
    buffer.append("</tr>\n");
    buffer.append("<tr>\n");
    buffer.append("<td>\n");
    buffer.append("Rows Processed\n");
    buffer.append("</td>\n");
    buffer.append("<td align=\"right\">\n");
    buffer.append(nrProcessed);
    buffer.append("\n</td>\n");
    buffer.append("</tr>\n");
    buffer.append("<tr>\n");
    buffer.append("<td>\n");
    buffer.append("Rows Skipped\n");
    buffer.append("</td>\n");
    buffer.append("<td align=\"right\">\n");
    buffer.append(nrSkipped);
    buffer.append("\n</td>\n");
    buffer.append("</tr>\n");
    buffer.append("</table>\n");
    if (model.isCalcError()) {
        double error = model.getError();
        buffer.append("<hr>\n");
        buffer.append("<table>\n");
        buffer.append("<caption align=\"left\">Error</caption>\n");
        buffer.append("<tr>\n");
        buffer.append("<td>\n");
        buffer.append("Total Squared Error\n");
        buffer.append("</td>\n");
        buffer.append("<td>\n");
        buffer.append(DoubleFormat.formatDouble(error));
        buffer.append("\n</td>\n");
        buffer.append("</tr>\n");
        buffer.append("<tr>\n");
        buffer.append("<td>\n");
        buffer.append("Squared Error per Row\n");
        buffer.append("</td>\n");
        buffer.append("<td>\n");
        // avoid 0/0 (NaN) when no row was processed, e.g. before execution
        if (nrProcessed > 0) {
            buffer.append(DoubleFormat.formatDouble(error / nrProcessed));
        } else {
            buffer.append("n/a");
        }
        buffer.append("\n</td>\n");
        buffer.append("</tr>\n");
        buffer.append("</table>\n");
    }
    buffer.append("</body>\n");
    buffer.append("</html>\n");
    m_pane.setText(buffer.toString());
    m_pane.revalidate();
}

/**
 * Escapes {@code &}, {@code <}, {@code >} and {@code "} so that the given
 * text can be embedded verbatim as HTML text content.
 *
 * @param text the raw text, must not be {@code null}
 * @return the escaped text
 */
private static String escapeHtmlText(final String text) {
    final StringBuilder escaped = new StringBuilder(text.length());
    for (int i = 0; i < text.length(); i++) {
        final char c = text.charAt(i);
        switch (c) {
            case '&':
                escaped.append("&amp;");
                break;
            case '<':
                escaped.append("&lt;");
                break;
            case '>':
                escaped.append("&gt;");
                break;
            case '"':
                escaped.append("&quot;");
                break;
            default:
                escaped.append(c);
        }
    }
    return escaped.toString();
}
Also used: DataTableSpec(org.knime.core.data.DataTableSpec) LinearRegressionContent(org.knime.base.node.mine.regression.linear.LinearRegressionContent)

Example 2 with LinearRegressionContent

use of org.knime.base.node.mine.regression.linear.LinearRegressionContent in project knime-core by knime.

Class LinRegLearnerNodeModel, method execute.

/**
 * {@inheritDoc}
 *
 * <p>Fits the linear model by solving the normal equations
 * {@code (A^T A) w = A^T b}. {@code A} holds the included columns plus one
 * extra column, and {@code b} is the target column. Rows whose target or
 * included values are missing are skipped in every pass.
 *
 * @throws Exception if there are not more usable (non-missing) rows than
 *             unknowns, or if any of the helper computations fail
 */
@Override
protected PortObject[] execute(final PortObject[] inData, final ExecutionContext exec) throws Exception {
    /*
     * What comes next is the matrix calculation, solving A \times w = b
     * where A is the matrix having the training data (as many rows as there
     * are rows in inData[0]), w is the vector of weights to learn (number of
     * variables) and b is the target output
     */
    // reset was called, must be cleared
    final BufferedDataTable data = (BufferedDataTable) inData[0];
    final DataTableSpec spec = data.getDataTableSpec();
    final String[] includes = computeIncludes(spec);
    final int nrUnknown = includes.length + 1;
    double[] means = new double[includes.length];
    // indices of the columns in m_includes
    final int[] colIndizes = new int[includes.length];
    for (int i = 0; i < includes.length; i++) {
        colIndizes[i] = spec.findColumnIndex(includes[i]);
    }
    // index of m_target
    final int target = spec.findColumnIndex(m_target);
    // this is the matrix (A^T x A) where A is the training data including
    // one column fixed to one.
    // (we do it here manually in order to avoid to get all the data in
    // double[][])
    double[][] ata = new double[nrUnknown][nrUnknown];
    double[] buffer = new double[nrUnknown];
    // we memorize for each row if it contains missing values.
    BitSet missingSet = new BitSet();
    m_nrRows = data.getRowCount();
    int myProgress = 0;
    // this method scans the data twice, plus once more if the error is calculated
    final double totalProgress = (2 + (m_isCalcError ? 1 : 0)) * m_nrRows;
    int rowCount = 0;
    boolean hasPrintedWarning = false;
    for (RowIterator it = data.iterator(); it.hasNext(); rowCount++) {
        DataRow row = it.next();
        myProgress++;
        exec.setProgress(myProgress / totalProgress, "Calculating matrix " + (rowCount + 1) + " (\"" + row.getKey().getString() + "\")");
        exec.checkCanceled();
        DataCell targetValue = row.getCell(target);
        // read data from row into buffer, skip missing value rows
        boolean containsMissing = targetValue.isMissing() || readIntoBuffer(row, buffer, colIndizes);
        missingSet.set(rowCount, containsMissing);
        if (containsMissing) {
            String errorMessage = "Row \"" + row.getKey().getString() + "\" contains missing values, skipping it.";
            if (!hasPrintedWarning) {
                LOGGER.warn(errorMessage + " Suppress further warnings.");
                hasPrintedWarning = true;
            } else {
                LOGGER.debug(errorMessage);
            }
            m_nrRowsSkipped++;
            // with next row
            continue;
        }
        updateMean(buffer, means);
        // the matrix is symmetric: accumulate only the upper triangle here
        // and mirror it once after the scan
        for (int i = 0; i < nrUnknown; i++) {
            for (int j = i; j < nrUnknown; j++) {
                ata[i][j] += buffer[i] * buffer[j];
            }
        }
    }
    // fill the lower triangle from the upper one
    for (int i = 1; i < nrUnknown; i++) {
        for (int j = 0; j < i; j++) {
            ata[i][j] = ata[j][i];
        }
    }
    assert (m_nrRows == rowCount);
    normalizeMean(means);
    // rows with missing values did not contribute to A^T x A, so only the
    // remaining rows count when deciding whether a unique solution exists
    final int nrUsedRows = rowCount - missingSet.cardinality();
    if (nrUsedRows <= nrUnknown) {
        throw new Exception("Too few rows to perform regression (" + nrUsedRows
            + " rows without missing values, but degree of freedom of " + nrUnknown + ")");
    }
    exec.setMessage("Calculating pseudo inverse...");
    double[][] ataInverse = MathUtils.inverse(ata);
    checkForNaN(ataInverse);
    // multiply with A^T and b, i.e. (A^T x A)^-1 x A^T x b
    double[] multipliers = new double[nrUnknown];
    rowCount = 0;
    for (RowIterator it = data.iterator(); it.hasNext(); rowCount++) {
        DataRow row = it.next();
        exec.setMessage("Determining output " + (rowCount + 1) + " (\"" + row.getKey().getString() + "\")");
        myProgress++;
        exec.setProgress(myProgress / totalProgress);
        exec.checkCanceled();
        // does row containing missing values?
        if (missingSet.get(rowCount)) {
            // error has printed above, silently ignore here.
            continue;
        }
        boolean containsMissing = readIntoBuffer(row, buffer, colIndizes);
        assert !containsMissing;
        DataCell targetValue = row.getCell(target);
        double b = ((DoubleValue) targetValue).getDoubleValue();
        for (int i = 0; i < nrUnknown; i++) {
            double buf = 0.0;
            for (int j = 0; j < nrUnknown; j++) {
                buf += ataInverse[i][j] * buffer[j];
            }
            multipliers[i] += buf * b;
        }
    }
    if (m_isCalcError) {
        assert m_error == 0.0;
        rowCount = 0;
        for (RowIterator it = data.iterator(); it.hasNext(); rowCount++) {
            DataRow row = it.next();
            exec.setMessage("Calculating error " + (rowCount + 1) + " (\"" + row.getKey().getString() + "\")");
            myProgress++;
            exec.setProgress(myProgress / totalProgress);
            exec.checkCanceled();
            // does row containing missing values?
            if (missingSet.get(rowCount)) {
                // error has printed above, silently ignore here.
                continue;
            }
            boolean hasMissing = readIntoBuffer(row, buffer, colIndizes);
            assert !hasMissing;
            DataCell targetValue = row.getCell(target);
            double b = ((DoubleValue) targetValue).getDoubleValue();
            double out = 0.0;
            for (int i = 0; i < nrUnknown; i++) {
                out += multipliers[i] * buffer[i];
            }
            m_error += (b - out) * (b - out);
        }
    }
    // handle the optional PMML input
    PMMLPortObject inPMMLPort = (PMMLPortObject) inData[1];
    DataTableSpec outSpec = getLearningSpec(spec);
    // the first weight belongs to the constant column, i.e. it is the offset
    double offset = multipliers[0];
    multipliers = Arrays.copyOfRange(multipliers, 1, multipliers.length);
    m_params = new LinearRegressionContent(outSpec, offset, multipliers, means);
    // cache the entire table as otherwise the color information
    // may be lost (filtering out the "colored" column)
    m_rowContainer = new DefaultDataArray(data, m_firstRowPaint, m_rowCountPaint);
    m_actualUsedColumns = includes;
    return new PortObject[] { m_params.createPortObject(inPMMLPort, spec, outSpec) };
}
Also used: DataTableSpec(org.knime.core.data.DataTableSpec) DefaultDataArray(org.knime.base.node.util.DefaultDataArray) BitSet(java.util.BitSet) DataRow(org.knime.core.data.DataRow) InvalidSettingsException(org.knime.core.node.InvalidSettingsException) CanceledExecutionException(org.knime.core.node.CanceledExecutionException) IOException(java.io.IOException) DoubleValue(org.knime.core.data.DoubleValue) PMMLPortObject(org.knime.core.node.port.pmml.PMMLPortObject) RowIterator(org.knime.core.data.RowIterator) LinearRegressionContent(org.knime.base.node.mine.regression.linear.LinearRegressionContent) BufferedDataTable(org.knime.core.node.BufferedDataTable) DataCell(org.knime.core.data.DataCell) PortObject(org.knime.core.node.port.PortObject)

Example 3 with LinearRegressionContent

use of org.knime.base.node.mine.regression.linear.LinearRegressionContent in project knime-core by knime.

Class LinRegLinePlotter, method updateSize.

/**
 * Lets the super class resize the plot and then recomputes the end points
 * of the regression line for the current x range.
 *
 * <p>Nothing happens if either axis or its coordinate is missing. The
 * line is left untouched if the data provider, its data, the regression
 * parameters or the learning columns are unavailable. It is also left
 * untouched if the selected x column is the target column or was not used
 * for learning.
 */
@Override
public void updateSize() {
    final boolean axesReady = getXAxis() != null && getXAxis().getCoordinate() != null
        && getYAxis() != null && getYAxis().getCoordinate() != null;
    if (!axesReady) {
        return;
    }
    super.updateSize();
    final DataProvider provider = getDataProvider();
    if (provider == null || provider.getDataArray(0) == null) {
        return;
    }
    final LinRegDataProvider linRegProvider = (LinRegDataProvider) provider;
    final LinearRegressionContent regressionParams = linRegProvider.getParams();
    if (regressionParams == null) {
        return;
    }
    final NumericCoordinate xCoordinate = (NumericCoordinate) getXAxis().getCoordinate();
    final double lowX = xCoordinate.getMinDomainValue();
    final double highX = xCoordinate.getMaxDomainValue();
    final String selectedX = getSelectedXColumn().getName();
    final String[] learningColumns = linRegProvider.getLearningColumns();
    if (learningColumns == null) {
        return;
    }
    // only explanatory columns that took part in learning get a line
    if (selectedX.equals(regressionParams.getTargetColumnName())
            || !Arrays.asList(learningColumns).contains(selectedX)) {
        return;
    }
    final double lowY = regressionParams.getApproximationFor(selectedX, lowX);
    final double highY = regressionParams.getApproximationFor(selectedX, highX);
    final LinRegLineDrawingPane linePane = (LinRegLineDrawingPane) getDrawingPane();
    linePane.setLineFirstPoint(getMappedXValue(new DoubleCell(lowX)), getMappedYValue(new DoubleCell(lowY)));
    linePane.setLineLastPoint(getMappedXValue(new DoubleCell(highX)), getMappedYValue(new DoubleCell(highY)));
}
Also used: DataProvider(org.knime.base.node.viz.plotter.DataProvider) DoubleCell(org.knime.core.data.def.DoubleCell) LinearRegressionContent(org.knime.base.node.mine.regression.linear.LinearRegressionContent) NumericCoordinate(org.knime.base.util.coordinate.NumericCoordinate) DataArray(org.knime.base.node.util.DataArray)

Example 4 with LinearRegressionContent

use of org.knime.base.node.mine.regression.linear.LinearRegressionContent in project knime-core by knime.

Class LinRegLinePlotter, method updatePaintModel.

/**
 * Retrieves the linear regression params, updates the column selection
 * boxes appropriately and adds the regression line to the scatterplot.
 *
 * <p>Returns early, without drawing a line, in these cases:
 * <ul>
 * <li>no data provider, data, parameters or learning columns are available;</li>
 * <li>after the super class update, either axis or its coordinate is unset;</li>
 * <li>the selected x column is the target or was not used for learning.</li>
 * </ul>
 */
@Override
public void updatePaintModel() {
    DataProvider dataProvider = getDataProvider();
    if (dataProvider == null) {
        return;
    }
    DataArray data = dataProvider.getDataArray(0);
    if (data == null) {
        return;
    }
    LinearRegressionContent params = ((LinRegDataProvider) dataProvider).getParams();
    if (params == null) {
        return;
    }
    LinRegLinePlotterProperties properties = (LinRegLinePlotterProperties) getProperties();
    // set the target column to fix
    properties.setTargetColumn(params.getTargetColumnName());
    // get the included columns
    String[] includedCols = ((LinRegDataProvider) dataProvider).getLearningColumns();
    if (includedCols == null) {
        return;
    }
    properties.setIncludedColumns(includedCols);
    // update the combo boxes
    DataTableSpec spec = data.getDataTableSpec();
    properties.update(spec);
    super.updatePaintModel();
    // the axes/coordinates may be unset; updateSize() guards against the same
    // condition, without it the casts below would throw a NullPointerException
    if (getXAxis() == null || getXAxis().getCoordinate() == null
            || getYAxis() == null || getYAxis().getCoordinate() == null) {
        return;
    }
    double xMin = ((NumericCoordinate) getXAxis().getCoordinate()).getMinDomainValue();
    double xMax = ((NumericCoordinate) getXAxis().getCoordinate()).getMaxDomainValue();
    String xName = getSelectedXColumn().getName();
    List<String> includedList = Arrays.asList(includedCols);
    if (!xName.equals(params.getTargetColumnName()) && includedList.contains(xName)) {
        double yMin = params.getApproximationFor(xName, xMin);
        double yMax = params.getApproximationFor(xName, xMax);
        LinRegLineDrawingPane linePane = (LinRegLineDrawingPane) getDrawingPane();
        linePane.setLineFirstPoint(getMappedXValue(new DoubleCell(xMin)), getMappedYValue(new DoubleCell(yMin)));
        linePane.setLineLastPoint(getMappedXValue(new DoubleCell(xMax)), getMappedYValue(new DoubleCell(yMax)));
        getDrawingPane().repaint();
    }
}
Also used: DataTableSpec(org.knime.core.data.DataTableSpec) DoubleCell(org.knime.core.data.def.DoubleCell) NumericCoordinate(org.knime.base.util.coordinate.NumericCoordinate) DataArray(org.knime.base.node.util.DataArray) DataProvider(org.knime.base.node.viz.plotter.DataProvider) LinearRegressionContent(org.knime.base.node.mine.regression.linear.LinearRegressionContent)

Aggregations

LinearRegressionContent (org.knime.base.node.mine.regression.linear.LinearRegressionContent)4 DataTableSpec (org.knime.core.data.DataTableSpec)3 DataArray (org.knime.base.node.util.DataArray)2 DataProvider (org.knime.base.node.viz.plotter.DataProvider)2 NumericCoordinate (org.knime.base.util.coordinate.NumericCoordinate)2 DoubleCell (org.knime.core.data.def.DoubleCell)2 IOException (java.io.IOException)1 BitSet (java.util.BitSet)1 DefaultDataArray (org.knime.base.node.util.DefaultDataArray)1 DataCell (org.knime.core.data.DataCell)1 DataRow (org.knime.core.data.DataRow)1 DoubleValue (org.knime.core.data.DoubleValue)1 RowIterator (org.knime.core.data.RowIterator)1 BufferedDataTable (org.knime.core.node.BufferedDataTable)1 CanceledExecutionException (org.knime.core.node.CanceledExecutionException)1 InvalidSettingsException (org.knime.core.node.InvalidSettingsException)1 PortObject (org.knime.core.node.port.PortObject)1 PMMLPortObject (org.knime.core.node.port.pmml.PMMLPortObject)1