Search in sources :

Example 1 with MissingHandling

use of org.knime.base.node.mine.regression.RegressionTrainingRow.MissingHandling in project knime-core by knime.

In class LogRegPredictor, method getCells.

/**
 * {@inheritDoc}
 *
 * <p>Implementation notes:
 * <ul>
 * <li>If {@code hasMissingValues(row)} is true, the result of {@code createMissingOutput()} is returned
 * and no prediction is computed.</li>
 * <li>Otherwise a 1 x p design row vector is built (p = number of model parameters), multiplied with
 * {@code m_beta}, and the target category belonging to the largest entry of the product is written to the
 * last cell of the returned array.</li>
 * <li>If {@code m_includeProbs} is set, the array additionally holds one probability per target category at
 * the positions given by {@code m_targetCategoryIndex}.</li>
 * </ul>
 *
 * @param row the input row to predict
 * @return the output cells; the last cell is the predicted target category
 * @throws NumberFormatException if an exponent stored in the PP matrix is not a valid integer
 */
@Override
public DataCell[] getCells(final DataRow row) {
    if (hasMissingValues(row)) {
        return createMissingOutput();
    }
    final MissingHandling missingHandling = new MissingHandling(true);
    final DataCell[] cells = m_includeProbs ? new DataCell[1 + m_targetDomainValuesCount] : new DataCell[1];
    // Placeholder content; the prediction cell (and the probability cells, if requested) are overwritten below.
    Arrays.fill(cells, new IntCell(0));
    // row vector (1 x number of parameters) holding the design values of this input row
    final RealMatrix x = MatrixUtils.createRealMatrix(1, m_parameters.size());
    for (int i = 0; i < m_parameters.size(); i++) {
        final String parameter = m_parameters.get(i);
        // Find the first predictor that has an entry in the PP matrix for this parameter.
        String predictor = null;
        String value = null;
        for (final String candidate : m_predictors) {
            final String candidateValue = m_ppMatrix.getValue(parameter, candidate, null);
            if (candidateValue != null) {
                predictor = candidate;
                value = candidateValue;
                break;
            }
        }
        if (value == null) {
            // No predictor is associated with this parameter (presumably the intercept): constant 1.
            x.setEntry(0, i, 1);
        } else if (m_factors.contains(predictor)) {
            final List<DataCell> values = m_values.get(predictor);
            final DataCell cell = row.getCell(m_parameterI.get(parameter));
            final int index = values.indexOf(cell);
            /* When building a general regression model, for each
                categorical field, there is one category used as the
                default baseline and therefore it doesn't show up in the
                ParameterList in PMML. This design is fine for the training,
                but in the prediction, when the input (e.g. of Employment) is
                the default baseline, the parameters should all be 0.
                See the commit message for an example and more details.
                */
            if (index > 0) {
                // Set the dummy column of the observed category, then skip the remaining dummy
                // columns of this factor (assumes the factor occupies values.size() - 1 consecutive
                // parameters starting at i; the loop increment accounts for the last step).
                x.setEntry(0, i + index - 1, 1);
                i += values.size() - 2;
            }
            // index <= 0: baseline category or unknown value; the entries keep their initial value 0.
        } else if (m_baseLabelToColName.containsKey(parameter)
            && m_vectorLengths.containsKey(m_baseLabelToColName.get(parameter))) {
            // Parameter refers to a single element of a vector-valued column.
            final DataCell cell = row.getCell(m_parameterI.get(parameter));
            final Optional<NameAndIndex> vectorValue = VectorHandling.parse(predictor);
            if (vectorValue.isPresent()) {
                final int j = vectorValue.get().getIndex();
                // The value found by the search above is the PP matrix entry for (parameter, predictor).
                final double exponent = Integer.parseInt(value);
                final double radix = RegressionTrainingRow.getValue(cell, j, missingHandling);
                x.setEntry(0, i, Math.pow(radix, exponent));
            }
        } else {
            // Plain numeric covariate: the cell value raised to the exponent from the PP matrix.
            final DataCell cell = row.getCell(m_parameterI.get(parameter));
            final double radix = ((DoubleValue) cell).getDoubleValue();
            final double exponent = Integer.parseInt(value);
            x.setEntry(0, i, Math.pow(radix, exponent));
        }
    }
    // row vector with one entry per column of m_beta
    final RealMatrix r = x.multiply(m_beta);
    // determine the column with the highest value (and therefore the highest probability)
    int maxIndex = 0;
    double maxValue = r.getEntry(0, 0);
    for (int i = 1; i < r.getColumnDimension(); i++) {
        final double entry = r.getEntry(0, i);
        if (entry > maxValue) {
            maxValue = entry;
            maxIndex = i;
        }
    }
    if (m_includeProbs) {
        // compute probabilities of the target categories: p_i = 1 / sum_k exp(r_k - r_i)
        for (int i = 0; i < m_targetCategories.size(); i++) {
            final double ri = r.getEntry(0, i);
            // test if calculation would overflow
            boolean overflow = false;
            for (int k = 0; k < r.getColumnDimension(); k++) {
                if ((r.getEntry(0, k) - ri) > 700) {
                    overflow = true;
                    // one overflowing term is enough, no need to look at the rest
                    break;
                }
            }
            if (!overflow) {
                double sum = 0;
                for (int k = 0; k < r.getColumnDimension(); k++) {
                    sum += Math.exp(r.getEntry(0, k) - ri);
                }
                cells[m_targetCategoryIndex.get(i)] = new DoubleCell(1.0 / sum);
            } else {
                // the sum would be huge, hence the probability is reported as 0
                cells[m_targetCategoryIndex.get(i)] = new DoubleCell(0);
            }
        }
    }
    // the last cell is the prediction
    cells[cells.length - 1] = m_targetCategories.get(maxIndex);
    return cells;
}
Also used : Optional(java.util.Optional) DoubleCell(org.knime.core.data.def.DoubleCell) IntCell(org.knime.core.data.def.IntCell) RealMatrix(org.apache.commons.math3.linear.RealMatrix) MissingHandling(org.knime.base.node.mine.regression.RegressionTrainingRow.MissingHandling) DataCell(org.knime.core.data.DataCell)

Aggregations

Optional (java.util.Optional): 1, RealMatrix (org.apache.commons.math3.linear.RealMatrix): 1, MissingHandling (org.knime.base.node.mine.regression.RegressionTrainingRow.MissingHandling): 1, DataCell (org.knime.core.data.DataCell): 1, DoubleCell (org.knime.core.data.def.DoubleCell): 1, IntCell (org.knime.core.data.def.IntCell): 1