Search in sources:

Example 1 with ModelSpecificationException

Usage of org.apache.commons.math3.stat.regression.ModelSpecificationException in project knime-core by KNIME.

In class PolyRegLearnerNodeModel, method execute.

/**
 * {@inheritDoc}
 *
 * <p>Learns the polynomial regression model from the first input table and returns three ports:
 * <ol>
 * <li>the PMML model;</li>
 * <li>the input table with the column produced by {@code getCellFactory} appended;</li>
 * <li>a statistics table for the coefficients.</li>
 * </ol>
 *
 * <p>If the regression throws a {@link ModelSpecificationException}, the node does not fail. Instead, the
 * exception message is appended to any existing warning. The statistics table is filled with zero
 * coefficients and missing statistics, and the view data is filled with NaNs.
 */
@Override
protected PortObject[] execute(final PortObject[] inData, final ExecutionContext exec) throws Exception {
    BufferedDataTable inTable = (BufferedDataTable) inData[0];
    DataTableSpec inSpec = inTable.getDataTableSpec();
    final int colCount = inSpec.getNumColumns();
    String[] selectedCols = computeSelectedColumns(inSpec);
    Set<String> hash = new HashSet<String>(Arrays.asList(selectedCols));
    m_colSelected = new boolean[colCount];
    for (int i = 0; i < colCount; i++) {
        m_colSelected[i] = hash.contains(inSpec.getColumnSpec(i).getName());
    }
    final int rowCount = inTable.getRowCount();
    // Restrict the view's data array to the regressor columns followed by the target column (last).
    String[] temp = new String[m_columnNames.length + 1];
    System.arraycopy(m_columnNames, 0, temp, 0, m_columnNames.length);
    temp[temp.length - 1] = m_settings.getTargetColumn();
    FilterColumnTable filteredTable = new FilterColumnTable(inTable, temp);
    final DataArray rowContainer = new DefaultDataArray(filteredTable, 1, m_settings.getMaxRowsForView());
    // handle the optional PMML input
    PMMLPortObject inPMMLPort = m_pmmlInEnabled ? (PMMLPortObject) inData[1] : null;
    PortObjectSpec[] outputSpec = configure((inPMMLPort == null) ? new PortObjectSpec[] { inData[0].getSpec(), null } : new PortObjectSpec[] { inData[0].getSpec(), inPMMLPort.getSpec() });
    Learner learner = new Learner((PMMLPortObjectSpec) outputSpec[0], 0d, m_settings.getMissingValueHandling() == MissingValueHandling.fail, m_settings.getDegree());
    try {
        PolyRegContent polyRegContent = learner.perform(inTable, exec);
        m_betas = fillBeta(polyRegContent);
        m_meanValues = polyRegContent.getMeans();
        ColumnRearranger crea = new ColumnRearranger(inSpec);
        crea.append(getCellFactory(inSpec.findColumnIndex(m_settings.getTargetColumn())));
        PortObject[] bdt = new PortObject[] { createPMMLModel(inPMMLPort, inSpec), exec.createColumnRearrangeTable(inTable, crea, exec.createSilentSubExecutionContext(.2)), polyRegContent.createTablePortObject(exec.createSubExecutionContext(0.2)) };
        // Normalised only after the rearranged table has been built above.
        m_squaredError /= rowCount;
        if (polyRegContent.getWarningMessage() != null) {
            setWarningMessage(polyRegContent.getWarningMessage());
        }
        double[] stdErrors = PolyRegViewData.mapToArray(polyRegContent.getStandardErrors(), m_columnNames, m_settings.getDegree(), polyRegContent.getInterceptStdErr());
        double[] tValues = PolyRegViewData.mapToArray(polyRegContent.getTValues(), m_columnNames, m_settings.getDegree(), polyRegContent.getInterceptTValue());
        double[] pValues = PolyRegViewData.mapToArray(polyRegContent.getPValues(), m_columnNames, m_settings.getDegree(), polyRegContent.getInterceptPValue());
        m_viewData = new PolyRegViewData(m_meanValues, m_betas, stdErrors, tValues, pValues, m_squaredError, polyRegContent.getAdjustedRSquared(), m_columnNames, m_settings.getDegree(), m_settings.getTargetColumn(), rowContainer);
        return bdt;
    } catch (ModelSpecificationException e) {
        // Append the exception message to any existing warning. The conditional must be evaluated on
        // its own: '+' binds tighter than '?:', so inlining it would drop the message whenever a
        // warning already exists.
        final String origWarning = getWarningMessage();
        final String prefix = (origWarning != null && !origWarning.isEmpty()) ? origWarning + "\n" : "";
        setWarningMessage(prefix + e.getMessage());
        final ExecutionContext subExec = exec.createSubExecutionContext(.1);
        final BufferedDataContainer empty = subExec.createDataContainer(STATS_SPEC);
        // One placeholder row per (column, degree) term, followed by the intercept row.
        int rowIdx = 1;
        for (final String column : m_columnNames) {
            for (int d = 1; d <= m_settings.getDegree(); ++d) {
                empty.addRowToTable(new DefaultRow("Row" + rowIdx++, new StringCell(column), new IntCell(d), new DoubleCell(0.0d), DataType.getMissingCell(), DataType.getMissingCell(), DataType.getMissingCell()));
            }
        }
        empty.addRowToTable(new DefaultRow("Row" + rowIdx, new StringCell("Intercept"), new IntCell(0), new DoubleCell(0.0d), DataType.getMissingCell(), DataType.getMissingCell(), DataType.getMissingCell()));
        double[] nans = new double[m_columnNames.length * m_settings.getDegree() + 1];
        Arrays.fill(nans, Double.NaN);
        m_betas = new double[nans.length];
        // Mean only for the linear terms.
        // NOTE(review): (columns * degree + 1) / degree equals the column count only for degree > 1
        // (integer division); for degree 1 the array has one extra entry. Confirm the intended length.
        m_meanValues = new double[nans.length / m_settings.getDegree()];
        // NOTE(review): unlike the try branch, m_squaredError is read here before the rearranged table
        // below is built, and it is never divided by rowCount. Confirm this is intended.
        m_viewData = new PolyRegViewData(m_meanValues, m_betas, nans, nans, nans, m_squaredError, Double.NaN, m_columnNames, m_settings.getDegree(), m_settings.getTargetColumn(), rowContainer);
        empty.close();
        ColumnRearranger crea = new ColumnRearranger(inSpec);
        crea.append(getCellFactory(inSpec.findColumnIndex(m_settings.getTargetColumn())));
        BufferedDataTable rearrangerTable = exec.createColumnRearrangeTable(inTable, crea, exec.createSubProgress(0.6));
        PMMLPortObject model = createPMMLModel(inPMMLPort, inSpec);
        PortObject[] bdt = new PortObject[] { model, rearrangerTable, empty.getTable() };
        return bdt;
    }
}
Also used: DataTableSpec(org.knime.core.data.DataTableSpec) DefaultDataArray(org.knime.base.node.util.DefaultDataArray) DoubleCell(org.knime.core.data.def.DoubleCell) FilterColumnTable(org.knime.base.data.filter.column.FilterColumnTable) DataArray(org.knime.base.node.util.DataArray) ModelSpecificationException(org.apache.commons.math3.stat.regression.ModelSpecificationException) IntCell(org.knime.core.data.def.IntCell) ColumnRearranger(org.knime.core.data.container.ColumnRearranger) BufferedDataTable(org.knime.core.node.BufferedDataTable) PMMLPortObjectSpec(org.knime.core.node.port.pmml.PMMLPortObjectSpec) PortObjectSpec(org.knime.core.node.port.PortObjectSpec) PortObject(org.knime.core.node.port.PortObject) PMMLPortObject(org.knime.core.node.port.pmml.PMMLPortObject) HashSet(java.util.HashSet) BufferedDataContainer(org.knime.core.node.BufferedDataContainer) ExecutionContext(org.knime.core.node.ExecutionContext) StringCell(org.knime.core.data.def.StringCell) DefaultRow(org.knime.core.data.def.DefaultRow)

Example 2 with ModelSpecificationException

Usage of org.apache.commons.math3.stat.regression.ModelSpecificationException in project knime-core by KNIME.

In class Learner, method perform.

/**
 * Fits a multiple linear regression model to the rows of {@code data}.
 *
 * <p>A {@link ModelSpecificationException} raised by the solver is not propagated. Instead, the method
 * returns a content object with:
 * <ul>
 * <li>a zero-filled coefficient row;</li>
 * <li>a NaN-filled covariance matrix;</li>
 * <li>NaN R-squared and adjusted R-squared values;</li>
 * <li>the exception message as the warning text.</li>
 * </ul>
 *
 * @param data The data table.
 * @param exec The execution context used for reporting progress.
 * @return An object which holds the results.
 * @throws CanceledExecutionException When method is cancelled
 * @throws InvalidSettingsException When settings are inconsistent with the data
 */
@Override
public LinearRegressionContent perform(final BufferedDataTable data, final ExecutionContext exec) throws CanceledExecutionException, InvalidSettingsException {
    exec.checkCanceled();
    final RegressionTrainingData training = new RegressionTrainingData(data, m_outSpec, m_failOnMissing);
    // Always allocate at least one slot so that summaryStats[0] is valid below.
    final int slots = Math.max(1, training.getRegressorCount());
    final SummaryStatistics[] summaryStats = new SummaryStatistics[slots];
    final UpdatingMultipleLinearRegression regression = initStatistics(slots, summaryStats);
    processTable(exec, training, summaryStats, regression);
    final List<String> factors = new ArrayList<String>();
    final List<String> covariates = createCovariateListAndFillFactors(data, training, factors);
    try {
        final RegressionResults fit = regression.regress();
        final RealMatrix coefficients = MatrixUtils.createRowRealMatrix(fit.getParameterEstimates());
        final RealMatrix covariance = createCovarianceMatrix(fit);
        return new LinearRegressionContent(m_outSpec, (int) summaryStats[0].getN(), factors, covariates, coefficients, m_includeConstant, m_offsetValue, covariance, fit.getRSquared(), fit.getAdjustedRSquared(), summaryStats, null);
    } catch (ModelSpecificationException e) {
        // The size of the placeholder matrices is the sum of three terms:
        // - the optional constant term,
        // - the raw regressor count,
        // - (domain size - 1) of the first factor, but at least 1, when any factor exists.
        // NOTE(review): only the first factor is consulted, and getValues() is presumably non-null for
        // factor columns. Verify both.
        int factorLevels = 0;
        if (!factors.isEmpty()) {
            final String firstFactor = factors.get(0);
            factorLevels = Math.max(1, data.getDataTableSpec().getColumnSpec(firstFactor).getDomain().getValues().size() - 1);
        }
        final int constantTerm = m_includeConstant ? 1 : 0;
        final int dim = constantTerm + training.getRegressorCount() + factorLevels;
        // The coefficient row is deliberately left zero-filled; only the covariance matrix gets NaNs.
        final RealMatrix coefficients = MatrixUtils.createRealMatrix(1, dim);
        final RealMatrix covariance = MatrixUtils.createRealMatrix(dim, dim);
        fillWithNaNs(covariance);
        return new LinearRegressionContent(m_outSpec, (int) summaryStats[0].getN(), factors, covariates, coefficients, m_includeConstant, m_offsetValue, covariance, Double.NaN, Double.NaN, summaryStats, e.getMessage());
    }
}
Also used: RealMatrix(org.apache.commons.math3.linear.RealMatrix) ArrayList(java.util.ArrayList) RegressionTrainingData(org.knime.base.node.mine.regression.RegressionTrainingData) SummaryStatistics(org.apache.commons.math3.stat.descriptive.SummaryStatistics) RegressionResults(org.apache.commons.math3.stat.regression.RegressionResults) UpdatingMultipleLinearRegression(org.apache.commons.math3.stat.regression.UpdatingMultipleLinearRegression) ModelSpecificationException(org.apache.commons.math3.stat.regression.ModelSpecificationException)

Aggregations

ModelSpecificationException (org.apache.commons.math3.stat.regression.ModelSpecificationException)2 ArrayList (java.util.ArrayList)1 HashSet (java.util.HashSet)1 RealMatrix (org.apache.commons.math3.linear.RealMatrix)1 SummaryStatistics (org.apache.commons.math3.stat.descriptive.SummaryStatistics)1 RegressionResults (org.apache.commons.math3.stat.regression.RegressionResults)1 UpdatingMultipleLinearRegression (org.apache.commons.math3.stat.regression.UpdatingMultipleLinearRegression)1 FilterColumnTable (org.knime.base.data.filter.column.FilterColumnTable)1 RegressionTrainingData (org.knime.base.node.mine.regression.RegressionTrainingData)1 DataArray (org.knime.base.node.util.DataArray)1 DefaultDataArray (org.knime.base.node.util.DefaultDataArray)1 DataTableSpec (org.knime.core.data.DataTableSpec)1 ColumnRearranger (org.knime.core.data.container.ColumnRearranger)1 DefaultRow (org.knime.core.data.def.DefaultRow)1 DoubleCell (org.knime.core.data.def.DoubleCell)1 IntCell (org.knime.core.data.def.IntCell)1 StringCell (org.knime.core.data.def.StringCell)1 BufferedDataContainer (org.knime.core.node.BufferedDataContainer)1 BufferedDataTable (org.knime.core.node.BufferedDataTable)1 ExecutionContext (org.knime.core.node.ExecutionContext)1