Search in sources:

Example 21: usage of DataArray

Usage of `org.knime.base.node.util.DataArray` in the knime-core project by KNIME.

Class `PolyRegLearnerNodeModel`, method `execute`.

/**
 * Fits a polynomial regression model of the configured degree to the input table by solving the
 * normal equations {@code betas = (X'X)^-1 * (X'Y)}, then appends a column computed by this
 * node's cell factory and creates a PMML model.
 *
 * <p>Column 0 of the design matrix {@code X} is the constant 1 (intercept). Each selected column
 * with value {@code x} then contributes the terms {@code x, x^2, ..., x^degree}. {@code Y} holds
 * the values of the target column.
 *
 * @param inData {@code inData[0]} is the training table; {@code inData[1]} is the PMML input
 *        port (treated as optional by {@code createPMMLModel} — presumably may be {@code null})
 * @param exec execution context used for progress, cancellation checks and table creation
 * @return the input table with the appended column, followed by the PMML model
 * @throws IllegalArgumentException if the input table is empty, or a selected or target column
 *         contains a missing value
 * @throws ArithmeticException if {@code X'X} cannot be inverted, i.e. the polynomial terms are
 *         not mutually independent (the underlying exception is attached as the cause)
 * @throws Exception if execution is canceled or creating the output table fails
 */
@Override
protected PortObject[] execute(final PortObject[] inData, final ExecutionContext exec) throws Exception {
    final BufferedDataTable inTable = (BufferedDataTable) inData[0];
    final DataTableSpec inSpec = inTable.getDataTableSpec();
    final String targetColumn = m_settings.getTargetColumn();
    final int colCount = inSpec.getNumColumns();
    final String[] selectedCols = computeSelectedColumns(inSpec);
    final Set<String> selectedNames = new HashSet<String>(Arrays.asList(selectedCols));
    m_colSelected = new boolean[colCount];
    for (int i = 0; i < colCount; i++) {
        m_colSelected[i] = selectedNames.contains(inSpec.getColumnSpec(i).getName());
    }
    final int rowCount = inTable.getRowCount();
    if (rowCount == 0) {
        // Without samples the normal equations have no solution. Fail here with a clear message
        // instead of later inside the matrix computations (which also divide by rowCount).
        throw new IllegalArgumentException("The input table is empty.");
    }
    final int independentVariables = selectedCols.length;
    final int degree = m_settings.getDegree();
    final int dependentIndex = inSpec.findColumnIndex(targetColumn);

    // Build the design matrix X (intercept + powers of each selected column) and the target vector Y.
    final double[][] xMat = new double[rowCount][1 + independentVariables * degree];
    final double[][] yMat = new double[rowCount][1];
    int rowIndex = 0;
    for (DataRow row : inTable) {
        exec.checkCanceled();
        exec.setProgress(0.2 * rowIndex / rowCount);
        xMat[rowIndex][0] = 1;
        int colIndex = 1;
        for (int i = 0; i < row.getNumCells(); i++) {
            if ((m_colSelected[i] || (i == dependentIndex)) && row.getCell(i).isMissing()) {
                throw new IllegalArgumentException("Missing values are not supported by this node.");
            }
            if (m_colSelected[i]) {
                final double val = ((DoubleValue) row.getCell(i)).getDoubleValue();
                // write val, val^2, ..., val^degree into consecutive columns
                double poly = val;
                xMat[rowIndex][colIndex] = poly;
                colIndex++;
                for (int d = 2; d <= degree; d++) {
                    poly *= val;
                    xMat[rowIndex][colIndex] = poly;
                    colIndex++;
                }
            } else if (i == dependentIndex) {
                yMat[rowIndex][0] = ((DoubleValue) row.getCell(i)).getDoubleValue();
            }
        }
        rowIndex++;
    }

    // compute X'
    final double[][] xTransMat = MathUtils.transpose(xMat);
    exec.setProgress(0.24);
    exec.checkCanceled();
    // compute X'X
    final double[][] xxMat = MathUtils.multiply(xTransMat, xMat);
    exec.setProgress(0.28);
    exec.checkCanceled();
    // compute X'Y
    final double[][] xyMat = MathUtils.multiply(xTransMat, yMat);
    exec.setProgress(0.32);
    exec.checkCanceled();
    // compute (X'X)^-1
    final double[][] xxInverse;
    try {
        xxInverse = MathUtils.inverse(xxMat);
    } catch (ArithmeticException ex) {
        final ArithmeticException singular =
            new ArithmeticException("The attributes of the data samples are not mutually independent.");
        // ArithmeticException has no (String, Throwable) constructor; attach the cause explicitly
        singular.initCause(ex);
        throw singular;
    }
    exec.setProgress(0.36);
    exec.checkCanceled();
    // compute (X'X)^-1 * (X'Y)
    final double[][] betas = MathUtils.multiply(xxInverse, xyMat);
    exec.setProgress(0.4);
    m_betas = new double[independentVariables * degree + 1];
    for (int i = 0; i < betas.length; i++) {
        m_betas[i] = betas[i][0];
    }
    m_columnNames = selectedCols;

    // Keep the selected columns plus the target column (last) for the view.
    final String[] viewColumns = new String[m_columnNames.length + 1];
    System.arraycopy(m_columnNames, 0, viewColumns, 0, m_columnNames.length);
    viewColumns[viewColumns.length - 1] = targetColumn;
    final FilterColumnTable filteredTable = new FilterColumnTable(inTable, viewColumns);
    final DataArray rowContainer = new DefaultDataArray(filteredTable, 1, m_settings.getMaxRowsForView());

    // Mean of each selected column over the rows held in rowContainer. This is presumably capped
    // at getMaxRowsForView() rows, so it may cover only part of the table — TODO confirm.
    final int ignore = rowContainer.getDataTableSpec().findColumnIndex(targetColumn);
    m_meanValues = new double[independentVariables];
    for (DataRow row : rowContainer) {
        int k = 0;
        for (int i = 0; i < row.getNumCells(); i++) {
            if (i != ignore) {
                m_meanValues[k++] += ((DoubleValue) row.getCell(i)).getDoubleValue();
            }
        }
    }
    for (int i = 0; i < m_meanValues.length; i++) {
        m_meanValues[i] /= rowContainer.size();
    }

    final ColumnRearranger crea = new ColumnRearranger(inSpec);
    crea.append(getCellFactory(dependentIndex));
    // handle the optional PMML input
    final PMMLPortObject inPMMLPort = (PMMLPortObject) inData[1];
    final PortObject[] bdt = new PortObject[] {
        exec.createColumnRearrangeTable(inTable, crea, exec.createSubProgress(0.6)),
        createPMMLModel(inPMMLPort, inSpec) };
    // NOTE(review): m_squaredError is presumably accumulated by the cell factory while the output
    // table above is created; it is turned into a mean here — confirm where it is reset.
    m_squaredError /= rowCount;
    m_viewData = new PolyRegViewData(m_meanValues, m_betas, m_squaredError, m_columnNames, degree, targetColumn);
    m_rowContainer = rowContainer;
    return bdt;
}
Also used: DataTableSpec (org.knime.core.data.DataTableSpec), DefaultDataArray (org.knime.base.node.util.DefaultDataArray), FilterColumnTable (org.knime.base.data.filter.column.FilterColumnTable), DataRow (org.knime.core.data.DataRow), DataArray (org.knime.base.node.util.DataArray), ColumnRearranger (org.knime.core.data.container.ColumnRearranger), DoubleValue (org.knime.core.data.DoubleValue), PMMLPortObject (org.knime.core.node.port.pmml.PMMLPortObject), BufferedDataTable (org.knime.core.node.BufferedDataTable), PortObject (org.knime.core.node.port.PortObject), HashSet (java.util.HashSet)

Example 22: usage of DataArray

Usage of `org.knime.base.node.util.DataArray` in the knime-core project by KNIME.

Class `ScatterPlotter`, method `getYmin`.

/**
 * Determines the lower limit of the Y axis.
 *
 * <p>Yields {@code 0.0} when no Y column is chosen, when there is no data, or when the chosen
 * column is absent from the data. Otherwise the result is the first valid value among the
 * user-defined minimum, the column domain's lower bound, and the minimum found in the data
 * (falling back to {@code Double.NaN}).
 *
 * @return the lower limit of the Y scale
 */
public double getYmin() {
    if (getYColName() != null) {
        final DataArray data = m_rowContainer;
        if (data != null && data.size() != 0) {
            final DataTableSpec spec = data.getDataTableSpec();
            final int yIndex = spec.findColumnIndex(getYColName());
            if (yIndex >= 0) {
                // candidates in priority order; getDoubleValue picks the first valid one
                return getDoubleValue(m_userYmin,
                    spec.getColumnSpec(yIndex).getDomain().getLowerBound(),
                    data.getMinValue(yIndex),
                    Double.NaN);
            }
        }
    }
    return 0.0;
}
Also used: DataTableSpec (org.knime.core.data.DataTableSpec), DataArray (org.knime.base.node.util.DataArray)

Example 23: usage of DataArray

Usage of `org.knime.base.node.util.DataArray` in the knime-core project by KNIME.

Class `ScatterPlotter`, method `updateDotsAndPaint`.

/*
     * takes the data from the private row container and recalculates the dots,
     * the coordinates, adjusts sizes and repaints.
     */
private void updateDotsAndPaint() {
    getScatterPlotterDrawingPane().clearSelection();
    // get the rowInfo from the model
    DataArray rowsCont = m_rowContainer;
    if (rowsCont != null) {
        // and create a new DotInfo array with the rowKeys in the DotInfos.
        DotInfo[] newDots = new DotInfo[rowsCont.size()];
        for (int r = 0; r < rowsCont.size(); r++) {
            DataRow row = rowsCont.getRow(r);
            double size = getSize(row);
            ColorAttr colorAttr = getColorAttr(row);
            newDots[r] = new DotInfo(0, 0, row.getKey(), false, colorAttr, size, r);
        }
        // now create a new DotInfoArray
        DotInfoArray newDotArray = new DotInfoArray(newDots);
        // store it in the drawing pane
        getScatterPlotterDrawingPane().setDotInfoArray(newDotArray);
        // update hilit and colors.
        updateDotHiLiting();
        // and get the coordinates calculated.
        adjustSizes();
        calculateCoordinates(newDotArray);
        repaint();
    }
}
Also used: ColorAttr (org.knime.core.data.property.ColorAttr), DataRow (org.knime.core.data.DataRow), DataArray (org.knime.base.node.util.DataArray)

Aggregations

Classes used together with DataArray, with the number of examples using each:
- DataArray (org.knime.base.node.util.DataArray): 23
- DataTableSpec (org.knime.core.data.DataTableSpec): 8
- DataRow (org.knime.core.data.DataRow): 7
- DefaultDataArray (org.knime.base.node.util.DefaultDataArray): 6
- DataCell (org.knime.core.data.DataCell): 6
- Point (java.awt.Point): 5
- ArrayList (java.util.ArrayList): 5
- DataProvider (org.knime.base.node.viz.plotter.DataProvider): 5
- StringCell (org.knime.core.data.def.StringCell): 5
- BufferedDataTable (org.knime.core.node.BufferedDataTable): 4
- Color (java.awt.Color): 3
- ColorAttr (org.knime.core.data.property.ColorAttr): 3
- File (java.io.File): 2
- FileInputStream (java.io.FileInputStream): 2
- IOException (java.io.IOException): 2
- HashSet (java.util.HashSet): 2
- LinkedHashSet (java.util.LinkedHashSet): 2
- FilterColumnTable (org.knime.base.data.filter.column.FilterColumnTable): 2
- LinearRegressionContent (org.knime.base.node.mine.regression.linear.LinearRegressionContent): 2
- DotInfo (org.knime.base.node.viz.plotter.scatter.DotInfo): 2