Search in sources :

Example 16 with DoubleCell

use of org.knime.core.data.def.DoubleCell in project knime-core by knime.

the class DecTreePredictorNodeModel method execute.

/**
 * {@inheritDoc}
 */
@Override
public PortObject[] execute(final PortObject[] inPorts, final ExecutionContext exec) throws CanceledExecutionException, Exception {
    exec.setMessage("Decision Tree Predictor: Loading predictor...");
    PMMLPortObject port = (PMMLPortObject) inPorts[INMODELPORT];
    List<Node> models = port.getPMMLValue().getModels(PMMLModelType.TreeModel);
    if (models.isEmpty()) {
        String msg = "Decision Tree evaluation failed: " + "No tree model found.";
        LOGGER.error(msg);
        throw new RuntimeException(msg);
    }
    PMMLDecisionTreeTranslator trans = new PMMLDecisionTreeTranslator();
    port.initializeModelTranslator(trans);
    DecisionTree decTree = trans.getDecisionTree();
    decTree.resetColorInformation();
    BufferedDataTable inData = (BufferedDataTable) inPorts[INDATAPORT];
    // get column with color information
    String colorColumn = null;
    for (DataColumnSpec s : inData.getDataTableSpec()) {
        if (s.getColorHandler() != null) {
            colorColumn = s.getName();
            break;
        }
    }
    decTree.setColorColumn(colorColumn);
    exec.setMessage("Decision Tree Predictor: start execution.");
    PortObjectSpec[] inSpecs = new PortObjectSpec[] { inPorts[0].getSpec(), inPorts[1].getSpec() };
    DataTableSpec outSpec = createOutTableSpec(inSpecs);
    BufferedDataContainer outData = exec.createDataContainer(outSpec);
    long coveredPattern = 0;
    long nrPattern = 0;
    long rowCount = 0;
    long numberRows = inData.size();
    exec.setMessage("Classifying...");
    for (DataRow thisRow : inData) {
        DataCell cl = null;
        LinkedHashMap<String, Double> classDistrib = null;
        try {
            Pair<DataCell, LinkedHashMap<DataCell, Double>> pair = decTree.getWinnerAndClasscounts(thisRow, inData.getDataTableSpec());
            cl = pair.getFirst();
            LinkedHashMap<DataCell, Double> classCounts = pair.getSecond();
            classDistrib = getDistribution(classCounts);
            if (coveredPattern < m_maxNumCoveredPattern.getIntValue()) {
                // remember this one for HiLite support
                decTree.addCoveredPattern(thisRow, inData.getDataTableSpec());
                coveredPattern++;
            } else {
                // too many patterns for HiLite - at least remember color
                decTree.addCoveredColor(thisRow, inData.getDataTableSpec());
            }
            nrPattern++;
        } catch (Exception e) {
            LOGGER.error("Decision Tree evaluation failed: " + e.getMessage());
            throw e;
        }
        if (cl == null) {
            LOGGER.error("Decision Tree evaluation failed: result empty");
            throw new Exception("Decision Tree evaluation failed.");
        }
        DataCell[] newCells = new DataCell[outSpec.getNumColumns()];
        int numInCells = thisRow.getNumCells();
        for (int i = 0; i < numInCells; i++) {
            newCells[i] = thisRow.getCell(i);
        }
        if (m_showDistribution.getBooleanValue()) {
            for (int i = numInCells; i < newCells.length - 1; i++) {
                String predClass = outSpec.getColumnSpec(i).getName();
                if (classDistrib != null && classDistrib.get(predClass) != null) {
                    newCells[i] = new DoubleCell(classDistrib.get(predClass));
                } else {
                    newCells[i] = new DoubleCell(0.0);
                }
            }
        }
        newCells[newCells.length - 1] = cl;
        outData.addRowToTable(new DefaultRow(thisRow.getKey(), newCells));
        rowCount++;
        if (rowCount % 100 == 0) {
            exec.setProgress(rowCount / (double) numberRows, "Classifying... Row " + rowCount + " of " + numberRows);
        }
        exec.checkCanceled();
    }
    if (coveredPattern < nrPattern) {
        // let the user know that we did not store all available pattern
        // for HiLiting.
        this.setWarningMessage("Tree only stored first " + m_maxNumCoveredPattern.getIntValue() + " (of " + nrPattern + ") rows for HiLiting!");
    }
    outData.close();
    m_decTree = decTree;
    exec.setMessage("Decision Tree Predictor: end execution.");
    return new BufferedDataTable[] { outData.getTable() };
}
Also used : DataTableSpec(org.knime.core.data.DataTableSpec) PMMLDecisionTreeTranslator(org.knime.base.node.mine.decisiontree2.PMMLDecisionTreeTranslator) DoubleCell(org.knime.core.data.def.DoubleCell) Node(org.w3c.dom.Node) DataRow(org.knime.core.data.DataRow) LinkedHashMap(java.util.LinkedHashMap) DataColumnSpec(org.knime.core.data.DataColumnSpec) BufferedDataTable(org.knime.core.node.BufferedDataTable) PMMLPortObjectSpec(org.knime.core.node.port.pmml.PMMLPortObjectSpec) PortObjectSpec(org.knime.core.node.port.PortObjectSpec) DecisionTree(org.knime.base.node.mine.decisiontree2.model.DecisionTree) BufferedDataContainer(org.knime.core.node.BufferedDataContainer) InvalidSettingsException(org.knime.core.node.InvalidSettingsException) CanceledExecutionException(org.knime.core.node.CanceledExecutionException) IOException(java.io.IOException) PMMLPortObject(org.knime.core.node.port.pmml.PMMLPortObject) DataCell(org.knime.core.data.DataCell) DefaultRow(org.knime.core.data.def.DefaultRow)

Example 17 with DoubleCell

use of org.knime.core.data.def.DoubleCell in project knime-core by knime.

the class DecTreePredictorNodeModel method createOutTableSpec.

private DataTableSpec createOutTableSpec(final PortObjectSpec[] inSpecs) {
    LinkedList<DataCell> predValues = null;
    if (m_showDistribution.getBooleanValue()) {
        predValues = getPredictionValues((PMMLPortObjectSpec) inSpecs[INMODELPORT]);
        if (predValues == null) {
            // no out spec can be determined
            return null;
        }
    }
    int numCols = (predValues == null ? 0 : predValues.size()) + 1;
    DataTableSpec inSpec = (DataTableSpec) inSpecs[INDATAPORT];
    UniqueNameGenerator nameGenerator = new UniqueNameGenerator(inSpec);
    DataColumnSpec[] newCols = new DataColumnSpec[numCols];
    /* Set bar renderer and domain [0,1] as default for the double cells
         * containing the distribution */
    // DataColumnProperties propsRendering = new DataColumnProperties(
    // Collections.singletonMap(
    // DataValueRenderer.PROPERTY_PREFERRED_RENDERER,
    // DoubleBarRenderer.DESCRIPTION));
    DataColumnDomain domain = new DataColumnDomainCreator(new DoubleCell(0.0), new DoubleCell(1.0)).createDomain();
    // add all distribution columns
    for (int i = 0; i < numCols - 1; i++) {
        DataColumnSpecCreator colSpecCreator = nameGenerator.newCreator(predValues.get(i).toString(), DoubleCell.TYPE);
        // colSpecCreator.setProperties(propsRendering);
        colSpecCreator.setDomain(domain);
        newCols[i] = colSpecCreator.createSpec();
    }
    // add the prediction column
    newCols[numCols - 1] = nameGenerator.newColumn("Prediction (DecTree)", StringCell.TYPE);
    DataTableSpec newColSpec = new DataTableSpec(newCols);
    return new DataTableSpec(inSpec, newColSpec);
}
Also used : PMMLPortObjectSpec(org.knime.core.node.port.pmml.PMMLPortObjectSpec) DataTableSpec(org.knime.core.data.DataTableSpec) DataColumnSpec(org.knime.core.data.DataColumnSpec) DataColumnDomain(org.knime.core.data.DataColumnDomain) DataColumnSpecCreator(org.knime.core.data.DataColumnSpecCreator) DoubleCell(org.knime.core.data.def.DoubleCell) DataCell(org.knime.core.data.DataCell) DataColumnDomainCreator(org.knime.core.data.DataColumnDomainCreator) UniqueNameGenerator(org.knime.core.util.UniqueNameGenerator)

Example 18 with DoubleCell

use of org.knime.core.data.def.DoubleCell in project knime-core by knime.

the class LinearRegressionContent method predict.

/**
 * Predicts the target value for the given row.
 *
 * @param row a data row to predict
 * @return the predicted value in a data cell
 */
public DataCell predict(final DataRow row) {
    double sum = m_offset;
    for (int i = 0; i < row.getNumCells(); i++) {
        DataCell c = row.getCell(i);
        if (c.isMissing()) {
            return DataType.getMissingCell();
        }
        double d = ((DoubleCell) c).getDoubleValue();
        sum += m_multipliers[i] * d;
    }
    return new DoubleCell(sum);
}
Also used : DoubleCell(org.knime.core.data.def.DoubleCell) DataCell(org.knime.core.data.DataCell)

Example 19 with DoubleCell

use of org.knime.core.data.def.DoubleCell in project knime-core by knime.

the class LogRegPredictor method getCells.

/**
 * {@inheritDoc}
 */
@Override
public DataCell[] getCells(final DataRow row) {
    if (hasMissingValues(row)) {
        return createMissingOutput();
    }
    final MissingHandling missingHandling = new MissingHandling(true);
    DataCell[] cells = m_includeProbs ? new DataCell[1 + m_targetDomainValuesCount] : new DataCell[1];
    Arrays.fill(cells, new IntCell(0));
    // column vector
    final RealMatrix x = MatrixUtils.createRealMatrix(1, m_parameters.size());
    for (int i = 0; i < m_parameters.size(); i++) {
        String parameter = m_parameters.get(i);
        String predictor = null;
        String value = null;
        boolean rowIsEmpty = true;
        for (final Iterator<String> iter = m_predictors.iterator(); iter.hasNext(); ) {
            predictor = iter.next();
            value = m_ppMatrix.getValue(parameter, predictor, null);
            if (null != value) {
                rowIsEmpty = false;
                break;
            }
        }
        if (rowIsEmpty) {
            x.setEntry(0, i, 1);
        } else {
            if (m_factors.contains(predictor)) {
                List<DataCell> values = m_values.get(predictor);
                DataCell cell = row.getCell(m_parameterI.get(parameter));
                int index = values.indexOf(cell);
                /* When building a general regression model, for each
                    categorical fields, there is one category used as the
                    default baseline and therefore it didn't show in the
                    ParameterList in PMML. This design for the training is fine,
                    but in the prediction, when the input of Employment is
                    the default baseline, the parameters should all be 0.
                    See the commit message for an example and more details.
                    */
                if (index > 0) {
                    x.setEntry(0, i + index - 1, 1);
                    i += values.size() - 2;
                }
            } else if (m_baseLabelToColName.containsKey(parameter) && m_vectorLengths.containsKey(m_baseLabelToColName.get(parameter))) {
                final DataCell cell = row.getCell(m_parameterI.get(parameter));
                Optional<NameAndIndex> vectorValue = VectorHandling.parse(predictor);
                if (vectorValue.isPresent()) {
                    int j = vectorValue.get().getIndex();
                    value = m_ppMatrix.getValue(parameter, predictor, null);
                    double exponent = Integer.valueOf(value);
                    double radix = RegressionTrainingRow.getValue(cell, j, missingHandling);
                    x.setEntry(0, i, Math.pow(radix, exponent));
                }
            } else {
                DataCell cell = row.getCell(m_parameterI.get(parameter));
                double radix = ((DoubleValue) cell).getDoubleValue();
                double exponent = Integer.valueOf(value);
                x.setEntry(0, i, Math.pow(radix, exponent));
            }
        }
    }
    // column vector
    RealMatrix r = x.multiply(m_beta);
    // determine the column with highest probability
    int maxIndex = 0;
    double maxValue = r.getEntry(0, 0);
    for (int i = 1; i < r.getColumnDimension(); i++) {
        if (r.getEntry(0, i) > maxValue) {
            maxValue = r.getEntry(0, i);
            maxIndex = i;
        }
    }
    if (m_includeProbs) {
        // compute probabilities of the target categories
        for (int i = 0; i < m_targetCategories.size(); i++) {
            // test if calculation would overflow
            boolean overflow = false;
            for (int k = 0; k < r.getColumnDimension(); k++) {
                if ((r.getEntry(0, k) - r.getEntry(0, i)) > 700) {
                    overflow = true;
                }
            }
            if (!overflow) {
                double sum = 0;
                for (int k = 0; k < r.getColumnDimension(); k++) {
                    sum += Math.exp(r.getEntry(0, k) - r.getEntry(0, i));
                }
                cells[m_targetCategoryIndex.get(i)] = new DoubleCell(1.0 / sum);
            } else {
                cells[m_targetCategoryIndex.get(i)] = new DoubleCell(0);
            }
        }
    }
    // the last cell is the prediction
    cells[cells.length - 1] = m_targetCategories.get(maxIndex);
    return cells;
}
Also used : Optional(java.util.Optional) DoubleCell(org.knime.core.data.def.DoubleCell) IntCell(org.knime.core.data.def.IntCell) RealMatrix(org.apache.commons.math3.linear.RealMatrix) MissingHandling(org.knime.base.node.mine.regression.RegressionTrainingRow.MissingHandling) DataCell(org.knime.core.data.DataCell)

Example 20 with DoubleCell

use of org.knime.core.data.def.DoubleCell in project knime-core by knime.

the class AutoBinner method calcDomainBoundsIfNeccessary.

/**
 * Determines the per column min/max values of the given data if not already
 * present in the domain.
 * @param data the data
 * @param exec the execution context
 * @param recalcValuesFor The columns
 * @return The data with extended domain information
 * @throws InvalidSettingsException
 * @throws CanceledExecutionException
 */
public BufferedDataTable calcDomainBoundsIfNeccessary(final BufferedDataTable data, final ExecutionContext exec, final List<String> recalcValuesFor) throws InvalidSettingsException, CanceledExecutionException {
    if (null == recalcValuesFor || recalcValuesFor.isEmpty()) {
        return data;
    }
    List<Integer> valuesI = new ArrayList<Integer>();
    for (String colName : recalcValuesFor) {
        DataColumnSpec colSpec = data.getDataTableSpec().getColumnSpec(colName);
        if (!colSpec.getType().isCompatible(DoubleValue.class)) {
            throw new InvalidSettingsException("Can only process numeric " + "data. The column \"" + colSpec.getName() + "\" is not numeric.");
        }
        if (recalcValuesFor.contains(colName) && !colSpec.getDomain().hasBounds()) {
            valuesI.add(data.getDataTableSpec().findColumnIndex(colName));
        }
    }
    if (valuesI.isEmpty()) {
        return data;
    }
    Map<Integer, Double> min = new HashMap<Integer, Double>();
    Map<Integer, Double> max = new HashMap<Integer, Double>();
    for (int col : valuesI) {
        min.put(col, Double.MAX_VALUE);
        max.put(col, Double.MIN_VALUE);
    }
    int c = 0;
    for (DataRow row : data) {
        c++;
        exec.checkCanceled();
        exec.setProgress(c / (double) data.getRowCount());
        for (int col : valuesI) {
            double val = ((DoubleValue) row.getCell(col)).getDoubleValue();
            if (min.get(col) > val) {
                min.put(col, val);
            }
            if (max.get(col) < val) {
                min.put(col, val);
            }
        }
    }
    List<DataColumnSpec> newColSpecList = new ArrayList<DataColumnSpec>();
    int cc = 0;
    for (DataColumnSpec columnSpec : data.getDataTableSpec()) {
        if (recalcValuesFor.contains(columnSpec.getName())) {
            DataColumnSpecCreator specCreator = new DataColumnSpecCreator(columnSpec);
            DataColumnDomainCreator domainCreator = new DataColumnDomainCreator(new DoubleCell(min.get(cc)), new DoubleCell(max.get(cc)));
            specCreator.setDomain(domainCreator.createDomain());
            DataColumnSpec newColSpec = specCreator.createSpec();
            newColSpecList.add(newColSpec);
        } else {
            newColSpecList.add(columnSpec);
        }
        cc++;
    }
    DataTableSpec spec = new DataTableSpec(newColSpecList.toArray(new DataColumnSpec[0]));
    BufferedDataTable newDataTable = exec.createSpecReplacerTable(data, spec);
    return newDataTable;
}
Also used : DataTableSpec(org.knime.core.data.DataTableSpec) DataColumnSpecCreator(org.knime.core.data.DataColumnSpecCreator) HashMap(java.util.HashMap) LinkedHashMap(java.util.LinkedHashMap) DoubleCell(org.knime.core.data.def.DoubleCell) ArrayList(java.util.ArrayList) DataColumnDomainCreator(org.knime.core.data.DataColumnDomainCreator) DataRow(org.knime.core.data.DataRow) DataColumnSpec(org.knime.core.data.DataColumnSpec) DoubleValue(org.knime.core.data.DoubleValue) InvalidSettingsException(org.knime.core.node.InvalidSettingsException) BufferedDataTable(org.knime.core.node.BufferedDataTable)

Aggregations

DoubleCell (org.knime.core.data.def.DoubleCell)189 DataCell (org.knime.core.data.DataCell)129 IntCell (org.knime.core.data.def.IntCell)67 DefaultRow (org.knime.core.data.def.DefaultRow)66 StringCell (org.knime.core.data.def.StringCell)65 DataRow (org.knime.core.data.DataRow)57 DataTableSpec (org.knime.core.data.DataTableSpec)55 ArrayList (java.util.ArrayList)42 DataColumnSpec (org.knime.core.data.DataColumnSpec)42 DataColumnSpecCreator (org.knime.core.data.DataColumnSpecCreator)41 BufferedDataContainer (org.knime.core.node.BufferedDataContainer)39 RowKey (org.knime.core.data.RowKey)37 DoubleValue (org.knime.core.data.DoubleValue)35 BufferedDataTable (org.knime.core.node.BufferedDataTable)28 DataColumnDomainCreator (org.knime.core.data.DataColumnDomainCreator)26 InvalidSettingsException (org.knime.core.node.InvalidSettingsException)22 DataType (org.knime.core.data.DataType)20 LinkedHashMap (java.util.LinkedHashMap)17 HashMap (java.util.HashMap)13 Point (java.awt.Point)12