Search in sources:

Example 71 with DataCell

Use of org.knime.core.data.DataCell in project knime-core by KNIME.

The class DecTreePredictorNodeModel, method execute.

/**
 * Applies the decision tree held by the PMML model port to every row of the data port and
 * appends the predicted class (last column) and, if the distribution option is enabled, one
 * probability column per class value.
 *
 * <p>Up to {@code m_maxNumCoveredPattern} rows are registered with the tree as covered patterns
 * (HiLite support); for all further rows only the color information is recorded and a warning
 * is set after execution.
 *
 * @param inPorts the model port (index {@code INMODELPORT}, a {@link PMMLPortObject}) and the
 *            data port (index {@code INDATAPORT}, a {@link BufferedDataTable})
 * @param exec execution context used for messages, progress, cancellation and table creation
 * @return a single-element array holding the input table with the appended columns
 * @throws CanceledExecutionException if execution is canceled
 * @throws Exception if the model holds no tree model, the output spec cannot be determined,
 *             or the tree cannot evaluate a row
 */
@Override
public PortObject[] execute(final PortObject[] inPorts, final ExecutionContext exec) throws CanceledExecutionException, Exception {
    exec.setMessage("Decision Tree Predictor: Loading predictor...");
    PMMLPortObject port = (PMMLPortObject) inPorts[INMODELPORT];
    List<Node> models = port.getPMMLValue().getModels(PMMLModelType.TreeModel);
    if (models.isEmpty()) {
        String msg = "Decision Tree evaluation failed: " + "No tree model found.";
        LOGGER.error(msg);
        throw new RuntimeException(msg);
    }
    PMMLDecisionTreeTranslator trans = new PMMLDecisionTreeTranslator();
    port.initializeModelTranslator(trans);
    DecisionTree decTree = trans.getDecisionTree();
    decTree.resetColorInformation();
    BufferedDataTable inData = (BufferedDataTable) inPorts[INDATAPORT];
    DataTableSpec inDataSpec = inData.getDataTableSpec();
    // use the first column that carries color information (if any)
    String colorColumn = null;
    for (DataColumnSpec s : inDataSpec) {
        if (s.getColorHandler() != null) {
            colorColumn = s.getName();
            break;
        }
    }
    decTree.setColorColumn(colorColumn);
    exec.setMessage("Decision Tree Predictor: start execution.");
    // Keep each spec at its port index: createOutTableSpec reads them via INMODELPORT/INDATAPORT.
    PortObjectSpec[] inSpecs = new PortObjectSpec[inPorts.length];
    for (int p = 0; p < inPorts.length; p++) {
        inSpecs[p] = inPorts[p].getSpec();
    }
    DataTableSpec outSpec = createOutTableSpec(inSpecs);
    if (outSpec == null) {
        // createOutTableSpec returns null when the distribution is requested but the model
        // yields no class values; fail with a clear message instead of an NPE below.
        String msg = "Decision Tree evaluation failed: could not determine the output table structure from the model.";
        LOGGER.error(msg);
        throw new InvalidSettingsException(msg);
    }
    // Class values in the same order as the distribution columns created by createOutTableSpec.
    // The output column names must not be used for the lookup: UniqueNameGenerator renames a
    // distribution column whose class value clashes with an input column name.
    String[] distributionClasses = new String[0];
    if (m_showDistribution.getBooleanValue()) {
        List<DataCell> predValues = getPredictionValues((PMMLPortObjectSpec) inSpecs[INMODELPORT]);
        distributionClasses = new String[predValues.size()];
        int c = 0;
        for (DataCell value : predValues) {
            distributionClasses[c++] = value.toString();
        }
    }
    final int maxCoveredPattern = m_maxNumCoveredPattern.getIntValue();
    BufferedDataContainer outData = exec.createDataContainer(outSpec);
    long coveredPattern = 0;
    long nrPattern = 0;
    long rowCount = 0;
    long numberRows = inData.size();
    exec.setMessage("Classifying...");
    for (DataRow thisRow : inData) {
        DataCell cl = null;
        LinkedHashMap<String, Double> classDistrib = null;
        try {
            Pair<DataCell, LinkedHashMap<DataCell, Double>> pair = decTree.getWinnerAndClasscounts(thisRow, inDataSpec);
            cl = pair.getFirst();
            LinkedHashMap<DataCell, Double> classCounts = pair.getSecond();
            classDistrib = getDistribution(classCounts);
            if (coveredPattern < maxCoveredPattern) {
                // remember this one for HiLite support
                decTree.addCoveredPattern(thisRow, inDataSpec);
                coveredPattern++;
            } else {
                // too many patterns for HiLite - at least remember color
                decTree.addCoveredColor(thisRow, inDataSpec);
            }
            nrPattern++;
        } catch (Exception e) {
            // log with the cause so the stack trace is not lost
            LOGGER.error("Decision Tree evaluation failed: " + e.getMessage(), e);
            throw e;
        }
        if (cl == null) {
            LOGGER.error("Decision Tree evaluation failed: result empty");
            throw new Exception("Decision Tree evaluation failed.");
        }
        DataCell[] newCells = new DataCell[outSpec.getNumColumns()];
        int numInCells = thisRow.getNumCells();
        for (int i = 0; i < numInCells; i++) {
            newCells[i] = thisRow.getCell(i);
        }
        // distributionClasses is empty unless the distribution is shown
        for (int d = 0; d < distributionClasses.length; d++) {
            Double prob = classDistrib == null ? null : classDistrib.get(distributionClasses[d]);
            newCells[numInCells + d] = new DoubleCell(prob == null ? 0.0 : prob);
        }
        newCells[newCells.length - 1] = cl;
        outData.addRowToTable(new DefaultRow(thisRow.getKey(), newCells));
        rowCount++;
        if (rowCount % 100 == 0) {
            exec.setProgress(rowCount / (double) numberRows, "Classifying... Row " + rowCount + " of " + numberRows);
        }
        exec.checkCanceled();
    }
    if (coveredPattern < nrPattern) {
        // let the user know that we did not store all available pattern
        // for HiLiting.
        this.setWarningMessage("Tree only stored first " + maxCoveredPattern + " (of " + nrPattern + ") rows for HiLiting!");
    }
    outData.close();
    m_decTree = decTree;
    exec.setMessage("Decision Tree Predictor: end execution.");
    return new BufferedDataTable[] { outData.getTable() };
}
Also used : DataTableSpec(org.knime.core.data.DataTableSpec) PMMLDecisionTreeTranslator(org.knime.base.node.mine.decisiontree2.PMMLDecisionTreeTranslator) DoubleCell(org.knime.core.data.def.DoubleCell) Node(org.w3c.dom.Node) DataRow(org.knime.core.data.DataRow) LinkedHashMap(java.util.LinkedHashMap) DataColumnSpec(org.knime.core.data.DataColumnSpec) BufferedDataTable(org.knime.core.node.BufferedDataTable) PMMLPortObjectSpec(org.knime.core.node.port.pmml.PMMLPortObjectSpec) PortObjectSpec(org.knime.core.node.port.PortObjectSpec) DecisionTree(org.knime.base.node.mine.decisiontree2.model.DecisionTree) BufferedDataContainer(org.knime.core.node.BufferedDataContainer) InvalidSettingsException(org.knime.core.node.InvalidSettingsException) CanceledExecutionException(org.knime.core.node.CanceledExecutionException) IOException(java.io.IOException) PMMLPortObject(org.knime.core.node.port.pmml.PMMLPortObject) DataCell(org.knime.core.data.DataCell) DefaultRow(org.knime.core.data.def.DefaultRow)

Example 72 with DataCell

Use of org.knime.core.data.DataCell in project knime-core by KNIME.

The class DecTreePredictorNodeModel, method createOutTableSpec.

/**
 * Builds the output table spec: the input data spec, followed by one double column with domain
 * [0, 1] per class value (only if the distribution option is enabled), followed by the string
 * column "Prediction (DecTree)". New column names are generated by a UniqueNameGenerator seeded
 * with the input data spec.
 *
 * @param inSpecs the input specs, read at {@code INMODELPORT} (PMML) and {@code INDATAPORT} (data)
 * @return the output spec, or {@code null} if the distribution is requested but no class values
 *         can be obtained from the model
 */
private DataTableSpec createOutTableSpec(final PortObjectSpec[] inSpecs) {
    LinkedList<DataCell> classValues = null;
    if (m_showDistribution.getBooleanValue()) {
        classValues = getPredictionValues((PMMLPortObjectSpec) inSpecs[INMODELPORT]);
        if (classValues == null) {
            // without class values the distribution columns are unknown
            return null;
        }
    }
    final DataTableSpec dataSpec = (DataTableSpec) inSpecs[INDATAPORT];
    final UniqueNameGenerator namer = new UniqueNameGenerator(dataSpec);
    final int distributionCount = (classValues != null) ? classValues.size() : 0;
    final DataColumnSpec[] appended = new DataColumnSpec[distributionCount + 1];
    // probabilities always lie in [0, 1]
    final DataColumnDomain unitInterval = new DataColumnDomainCreator(new DoubleCell(0.0), new DoubleCell(1.0)).createDomain();
    int pos = 0;
    if (classValues != null) {
        for (DataCell classValue : classValues) {
            final DataColumnSpecCreator creator = namer.newCreator(classValue.toString(), DoubleCell.TYPE);
            creator.setDomain(unitInterval);
            appended[pos++] = creator.createSpec();
        }
    }
    // the prediction column always comes last
    appended[distributionCount] = namer.newColumn("Prediction (DecTree)", StringCell.TYPE);
    return new DataTableSpec(dataSpec, new DataTableSpec(appended));
}
Also used : PMMLPortObjectSpec(org.knime.core.node.port.pmml.PMMLPortObjectSpec) DataTableSpec(org.knime.core.data.DataTableSpec) DataColumnSpec(org.knime.core.data.DataColumnSpec) DataColumnDomain(org.knime.core.data.DataColumnDomain) DataColumnSpecCreator(org.knime.core.data.DataColumnSpecCreator) DoubleCell(org.knime.core.data.def.DoubleCell) DataCell(org.knime.core.data.DataCell) DataColumnDomainCreator(org.knime.core.data.DataColumnDomainCreator) UniqueNameGenerator(org.knime.core.util.UniqueNameGenerator)

Example 73 with DataCell

Use of org.knime.core.data.DataCell in project knime-core by KNIME.

The class LinearRegressionContent, method predict.

/**
 * Predicts the target value for the given row as
 * {@code m_offset + sum_i(m_multipliers[i] * value(cell_i))}.
 *
 * <p>Every cell of {@code row} is used as a numeric input and weighted with the multiplier at
 * the same index. Any cell implementing {@code DoubleValue} is accepted, not only
 * {@code DoubleCell}.
 *
 * @param row a data row to predict; presumably contains exactly the learning columns in model
 *            order (one per multiplier) — TODO confirm with callers
 * @return the predicted value in a data cell, or a missing cell if any input cell is missing
 */
public DataCell predict(final DataRow row) {
    double sum = m_offset;
    for (int i = 0; i < row.getNumCells(); i++) {
        DataCell c = row.getCell(i);
        if (c.isMissing()) {
            return DataType.getMissingCell();
        }
        // Cast to the value interface rather than DoubleCell: integer/long cells also implement
        // DoubleValue and would otherwise fail with a ClassCastException.
        double d = ((org.knime.core.data.DoubleValue) c).getDoubleValue();
        sum += m_multipliers[i] * d;
    }
    return new DoubleCell(sum);
}
Also used : DoubleCell(org.knime.core.data.def.DoubleCell) DataCell(org.knime.core.data.DataCell)

Example 74 with DataCell

Use of org.knime.core.data.DataCell in project knime-core by KNIME.

The class LogRegPredictor, method determineTargetCategories.

/**
 * Retrieve the target values from the PMML model.
 *
 * <p>The result lists the target categories referenced by the model's parameter matrix (in order
 * of first appearance), followed by the target reference category. Category names are mapped to
 * the {@link DataCell}s of the target column's domain by their string representation.
 *
 * <p>If the PMML does not set a target reference category, one is inferred only when the domain
 * has exactly one value more than the model references. The domain is sorted with the column
 * type's comparator. Starting from the last sorted value, the first value not referenced by the
 * model is chosen.
 *
 * @param targetCol spec of the target column; its domain must list the possible values
 * @param content the general regression content of the PMML model
 * @return the target categories, with the reference category last
 * @throws InvalidSettingsException if PMML model is inconsistent or ambiguous, or refers to a
 *             category that is not part of the target column's domain
 */
private static List<DataCell> determineTargetCategories(final DataColumnSpec targetCol, final PMMLGeneralRegressionContent content) throws InvalidSettingsException {
    if (!targetCol.getDomain().hasValues()) {
        throw new InvalidSettingsException("The domain of the target column \"" + targetCol.getName() + "\" does not list its possible values.");
    }
    Map<String, DataCell> domainValues = new HashMap<String, DataCell>();
    for (DataCell cell : targetCol.getDomain().getValues()) {
        domainValues.put(cell.toString(), cell);
    }
    // Collect target categories from model
    Set<DataCell> modelTargetCategories = new LinkedHashSet<DataCell>();
    for (PMMLPCell cell : content.getParamMatrix()) {
        modelTargetCategories.add(toDomainCell(domainValues, cell.getTargetCategory(), targetCol));
    }
    String targetReferenceCategory = content.getTargetReferenceCategory();
    if (targetReferenceCategory == null || targetReferenceCategory.isEmpty()) {
        List<DataCell> targetCategories = new ArrayList<DataCell>(targetCol.getDomain().getValues());
        Collections.sort(targetCategories, targetCol.getType().getComparator());
        if (targetCategories.size() == modelTargetCategories.size() + 1) {
            // Prefer the last sorted category, but never one the model already references:
            // otherwise the Set below would swallow it and the real reference category is lost.
            DataCell reference = targetCategories.get(targetCategories.size() - 1);
            for (int i = targetCategories.size() - 1; i >= 0; i--) {
                if (!modelTargetCategories.contains(targetCategories.get(i))) {
                    reference = targetCategories.get(i);
                    break;
                }
            }
            targetReferenceCategory = reference.toString();
            LOGGER.debug("The target reference category is not explicitly set in PMML. Automatically choose : " + targetReferenceCategory);
        } else {
            throw new InvalidSettingsException("Please set the attribute \"targetReferenceCategory\" of the" + "\"GeneralRegression\" element in the PMML file.");
        }
    }
    modelTargetCategories.add(toDomainCell(domainValues, targetReferenceCategory, targetCol));
    return new ArrayList<DataCell>(modelTargetCategories);
}

/**
 * Maps a category name from the PMML model to the matching domain cell of the target column.
 *
 * @throws InvalidSettingsException if the domain contains no value with that string form
 */
private static DataCell toDomainCell(final Map<String, DataCell> domainValues, final String category, final DataColumnSpec targetCol) throws InvalidSettingsException {
    DataCell cell = domainValues.get(category);
    if (cell == null) {
        throw new InvalidSettingsException("The target category \"" + category + "\" of the PMML model is not part of the domain of column \"" + targetCol.getName() + "\".");
    }
    return cell;
}
Also used : LinkedHashSet(java.util.LinkedHashSet) PMMLPCell(org.knime.base.node.mine.regression.pmmlgreg.PMMLPCell) HashMap(java.util.HashMap) InvalidSettingsException(org.knime.core.node.InvalidSettingsException) ArrayList(java.util.ArrayList) DataCell(org.knime.core.data.DataCell)

Example 75 with DataCell

Use of org.knime.core.data.DataCell in project knime-core by KNIME.

The class LogRegPredictor, method getCells.

/**
 * {@inheritDoc}
 *
 * <p>Builds the design row {@code x} for this data row from the parameter/predictor matrix and
 * computes the linear predictors {@code r = x * m_beta}. The target category at the column of
 * {@code r} with the largest value is written to the last cell. If {@code m_includeProbs} is set,
 * the cells at the positions given by {@code m_targetCategoryIndex} receive the probabilities
 * {@code 1 / sum_k exp(r_k - r_i)} (equivalently {@code exp(r_i) / sum_k exp(r_k)}).
 * Rows with missing values yield {@code createMissingOutput()}.
 */
@Override
public DataCell[] getCells(final DataRow row) {
    if (hasMissingValues(row)) {
        return createMissingOutput();
    }
    // passed to RegressionTrainingRow.getValue for vector elements; the meaning of the flag is
    // defined by MissingHandling — TODO confirm
    final MissingHandling missingHandling = new MissingHandling(true);
    DataCell[] cells = m_includeProbs ? new DataCell[1 + m_targetDomainValuesCount] : new DataCell[1];
    // placeholder content; cells that are not overwritten below keep IntCell(0)
    Arrays.fill(cells, new IntCell(0));
    // row vector (1 x number of parameters) holding the design values of this row
    final RealMatrix x = MatrixUtils.createRealMatrix(1, m_parameters.size());
    for (int i = 0; i < m_parameters.size(); i++) {
        String parameter = m_parameters.get(i);
        String predictor = null;
        String value = null;
        // find the first predictor that has an entry for this parameter in the PP matrix
        boolean rowIsEmpty = true;
        for (final Iterator<String> iter = m_predictors.iterator(); iter.hasNext(); ) {
            predictor = iter.next();
            value = m_ppMatrix.getValue(parameter, predictor, null);
            if (null != value) {
                rowIsEmpty = false;
                break;
            }
        }
        if (rowIsEmpty) {
            // no predictor contributes to this parameter: constant 1 (presumably the intercept)
            x.setEntry(0, i, 1);
        } else {
            if (m_factors.contains(predictor)) {
                // categorical predictor: one-hot encoding over the category values
                List<DataCell> values = m_values.get(predictor);
                DataCell cell = row.getCell(m_parameterI.get(parameter));
                int index = values.indexOf(cell);
                /* When a general regression model is built, each categorical field has one
                   category that serves as the default baseline and therefore does not appear
                   in the ParameterList of the PMML. This is fine for training, but during
                   prediction, when the input value is the baseline category, all parameters
                   of that field must be 0 (their entries are simply left at 0 here).
                   See the commit message for an example and more details.
                   If the value is not found (index == -1) nothing is set either. */
                if (index > 0) {
                    x.setEntry(0, i + index - 1, 1);
                    // together with the loop's i++ this skips values.size() - 1 parameters,
                    // presumably the consecutive ones of this factor — TODO confirm ordering
                    i += values.size() - 2;
                }
            } else if (m_baseLabelToColName.containsKey(parameter) && m_vectorLengths.containsKey(m_baseLabelToColName.get(parameter))) {
                // predictor refers to one element of a vector column
                final DataCell cell = row.getCell(m_parameterI.get(parameter));
                Optional<NameAndIndex> vectorValue = VectorHandling.parse(predictor);
                if (vectorValue.isPresent()) {
                    int j = vectorValue.get().getIndex();
                    // re-reads the same PP matrix entry found in the search above
                    value = m_ppMatrix.getValue(parameter, predictor, null);
                    // the PP matrix value is the integer exponent of the predictor
                    double exponent = Integer.valueOf(value);
                    double radix = RegressionTrainingRow.getValue(cell, j, missingHandling);
                    x.setEntry(0, i, Math.pow(radix, exponent));
                }
            } else {
                // plain numeric predictor raised to the exponent from the PP matrix
                DataCell cell = row.getCell(m_parameterI.get(parameter));
                double radix = ((DoubleValue) cell).getDoubleValue();
                double exponent = Integer.valueOf(value);
                x.setEntry(0, i, Math.pow(radix, exponent));
            }
        }
    }
    // 1 x (columns of m_beta) matrix of linear predictor values
    RealMatrix r = x.multiply(m_beta);
    // determine the column with highest probability
    int maxIndex = 0;
    double maxValue = r.getEntry(0, 0);
    for (int i = 1; i < r.getColumnDimension(); i++) {
        if (r.getEntry(0, i) > maxValue) {
            maxValue = r.getEntry(0, i);
            maxIndex = i;
        }
    }
    if (m_includeProbs) {
        // compute probabilities of the target categories
        for (int i = 0; i < m_targetCategories.size(); i++) {
            // test if calculation would overflow: exp(d) exceeds the double range for d above
            // about 709, and for such a large difference the probability is effectively 0
            boolean overflow = false;
            for (int k = 0; k < r.getColumnDimension(); k++) {
                if ((r.getEntry(0, k) - r.getEntry(0, i)) > 700) {
                    overflow = true;
                }
            }
            if (!overflow) {
                // 1 / sum_k exp(r_k - r_i) == exp(r_i) / sum_k exp(r_k)
                double sum = 0;
                for (int k = 0; k < r.getColumnDimension(); k++) {
                    sum += Math.exp(r.getEntry(0, k) - r.getEntry(0, i));
                }
                cells[m_targetCategoryIndex.get(i)] = new DoubleCell(1.0 / sum);
            } else {
                cells[m_targetCategoryIndex.get(i)] = new DoubleCell(0);
            }
        }
    }
    // the last cell is the prediction
    cells[cells.length - 1] = m_targetCategories.get(maxIndex);
    return cells;
}
Also used : Optional(java.util.Optional) DoubleCell(org.knime.core.data.def.DoubleCell) IntCell(org.knime.core.data.def.IntCell) RealMatrix(org.apache.commons.math3.linear.RealMatrix) MissingHandling(org.knime.base.node.mine.regression.RegressionTrainingRow.MissingHandling) DataCell(org.knime.core.data.DataCell)

Aggregations

DataCell (org.knime.core.data.DataCell)780 DataRow (org.knime.core.data.DataRow)268 DataTableSpec (org.knime.core.data.DataTableSpec)175 DataColumnSpec (org.knime.core.data.DataColumnSpec)170 DefaultRow (org.knime.core.data.def.DefaultRow)169 ArrayList (java.util.ArrayList)141 StringCell (org.knime.core.data.def.StringCell)131 DoubleCell (org.knime.core.data.def.DoubleCell)129 DoubleValue (org.knime.core.data.DoubleValue)111 InvalidSettingsException (org.knime.core.node.InvalidSettingsException)109 DataType (org.knime.core.data.DataType)97 RowKey (org.knime.core.data.RowKey)94 BufferedDataTable (org.knime.core.node.BufferedDataTable)93 BufferedDataContainer (org.knime.core.node.BufferedDataContainer)91 DataColumnSpecCreator (org.knime.core.data.DataColumnSpecCreator)84 LinkedHashMap (java.util.LinkedHashMap)81 IntCell (org.knime.core.data.def.IntCell)79 HashMap (java.util.HashMap)60 SettingsModelString (org.knime.core.node.defaultnodesettings.SettingsModelString)57 ColumnRearranger (org.knime.core.data.container.ColumnRearranger)56