Search in sources:

Example 16 with StringCell

Use of org.knime.core.data.def.StringCell in project knime-core by knime.

Class TreeEnsembleClassificationPredictorCellFactory2, method getCells.

/**
 * Computes the prediction cells for a single input row.
 * <p>
 * Every tree that may be applied to the row contributes the class distribution of the node the
 * row is routed to, normalized so that its entries sum to one. When an out-of-bag filter is
 * present, trees for which the row was part of the training data are skipped. The predicted
 * class is the target value with the largest accumulated probability.
 * <p>
 * Output cells, in this order:
 * <ol>
 * <li>the predicted class;</li>
 * <li>the prediction confidence (if configured);</li>
 * <li>one confidence per target value (if configured);</li>
 * <li>the number of contributing trees (if configured).</li>
 * </ol>
 * If no predictor record can be created for the row, all cells are missing. If no tree
 * contributed, all cells are missing except the model count, which is 0.
 *
 * @param row the row to predict
 * @return the newly created output cells
 */
@Override
public DataCell[] getCells(final DataRow row) {
    TreeEnsembleModelPortObject modelObject = m_predictor.getModelObject();
    TreeEnsemblePredictorConfiguration cfg = m_predictor.getConfiguration();
    final TreeEnsembleModel ensembleModel = modelObject.getEnsembleModel();
    int size = 1;
    final boolean appendConfidence = cfg.isAppendPredictionConfidence();
    if (appendConfidence) {
        size += 1;
    }
    final boolean appendClassConfidences = cfg.isAppendClassConfidences();
    if (appendClassConfidences) {
        size += m_targetValueMap.size();
    }
    final boolean appendModelCount = cfg.isAppendModelCount();
    if (appendModelCount) {
        size += 1;
    }
    final boolean hasOutOfBagFilter = m_predictor.hasOutOfBagFilter();
    DataCell[] result = new DataCell[size];
    DataRow filterRow = new FilterColumnRow(row, m_learnColumnInRealDataIndices);
    PredictorRecord record = ensembleModel.createPredictorRecord(filterRow, m_learnSpec);
    if (record == null) {
        // no predictor record could be created (presumably due to missing values)
        Arrays.fill(result, DataType.getMissingCell());
        return result;
    }
    OccurrenceCounter<String> counter = new OccurrenceCounter<String>();
    final int nrModels = ensembleModel.getNrModels();
    TreeTargetNominalColumnMetaData targetMeta = (TreeTargetNominalColumnMetaData) ensembleModel.getMetaData().getTargetMetaData();
    final double[] classProbabilities = new double[targetMeta.getValues().length];
    int nrValidModels = 0;
    for (int i = 0; i < nrModels; i++) {
        if (hasOutOfBagFilter && m_predictor.isRowPartOfTrainingData(row.getKey(), i)) {
            // ignore, row was used to train the model
            continue;
        }
        TreeModelClassification m = ensembleModel.getTreeModelClassification(i);
        TreeNodeClassification match = m.findMatchingNode(record);
        String majorityClassName = match.getMajorityClassName();
        final float[] nodeClassProbs = match.getTargetDistribution();
        double instancesInNode = 0;
        for (int c = 0; c < nodeClassProbs.length; c++) {
            instancesInNode += nodeClassProbs[c];
        }
        // NOTE(review): a node with an all-zero distribution would add NaN here
        for (int c = 0; c < classProbabilities.length; c++) {
            classProbabilities[c] += nodeClassProbs[c] / instancesInNode;
        }
        counter.add(majorityClassName);
        nrValidModels += 1;
    }
    String bestValue = counter.getMostFrequent();
    int index = 0;
    if (bestValue == null) {
        assert nrValidModels == 0;
        Arrays.fill(result, DataType.getMissingCell());
        // only the (optional) model count in the last cell is overwritten below
        index = size - 1;
    } else {
        int indexBest = -1;
        double probBest = -1;
        for (int c = 0; c < classProbabilities.length; c++) {
            double prob = classProbabilities[c];
            if (prob > probBest) {
                probBest = prob;
                indexBest = c;
            }
        }
        result[index++] = new StringCell(targetMeta.getValues()[indexBest].getNominalValue());
        if (appendConfidence) {
            // average over the contributing trees (nrValidModels > 0 here) so that the confidence
            // lies in [0, 1] and is consistent with the per-class confidences below
            result[index++] = new DoubleCell(probBest / nrValidModels);
        }
        if (appendClassConfidences) {
            for (NominalValueRepresentation nomVal : targetMeta.getValues()) {
                double prob = classProbabilities[nomVal.getAssignedInteger()] / nrValidModels;
                result[index++] = new DoubleCell(prob);
            }
        }
    }
    if (appendModelCount) {
        result[index++] = new IntCell(nrValidModels);
    }
    return result;
}
Also used : TreeNodeClassification(org.knime.base.node.mine.treeensemble2.model.TreeNodeClassification) TreeEnsembleModel(org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModel) TreeTargetNominalColumnMetaData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnMetaData) DoubleCell(org.knime.core.data.def.DoubleCell) TreeEnsemblePredictorConfiguration(org.knime.base.node.mine.treeensemble2.node.predictor.TreeEnsemblePredictorConfiguration) NominalValueRepresentation(org.knime.base.node.mine.treeensemble2.data.NominalValueRepresentation) DataRow(org.knime.core.data.DataRow) IntCell(org.knime.core.data.def.IntCell) TreeEnsembleModelPortObject(org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObject) StringCell(org.knime.core.data.def.StringCell) PredictorRecord(org.knime.base.node.mine.treeensemble2.data.PredictorRecord) DataCell(org.knime.core.data.DataCell) FilterColumnRow(org.knime.base.data.filter.column.FilterColumnRow) TreeModelClassification(org.knime.base.node.mine.treeensemble2.model.TreeModelClassification)

Example 17 with StringCell

Use of org.knime.core.data.def.StringCell in project knime-core by knime.

Class RuleEngineNodeModel, method createRearranger.

/**
 * Creates the column rearranger that appends the rule engine's outcome column.
 * <p>
 * The type of the new column is the common super type of all rule outcome types and of the
 * default label. A non-native common type falls back to {@link StringCell#TYPE}.
 * <p>
 * For each row, the outcome of the first matching rule (in list order) is used. If no rule
 * matches, the default label is used instead: either a constant or, when configured, a
 * {@code $column$} reference. If the default label is empty, the result is a missing cell.
 *
 * @param inSpec the input table spec
 * @param rules the rules, evaluated in list order
 * @return a rearranger appending the outcome column
 * @throws InvalidSettingsException if the default label is configured as a column reference but
 *             is malformed or refers to a column that does not exist in {@code inSpec}
 */
private ColumnRearranger createRearranger(final DataTableSpec inSpec, final List<Rule> rules) throws InvalidSettingsException {
    ColumnRearranger crea = new ColumnRearranger(inSpec);
    String newColName = DataTableSpec.getUniqueColumnName(inSpec, m_settings.getNewColName());
    // read once so that the output spec and the per-row evaluation are based on the same value
    final String defaultLabel = m_settings.getDefaultLabel();
    final int defaultLabelColumnIndex;
    if (m_settings.getDefaultLabelIsColumn()) {
        if (defaultLabel.length() < 3) {
            throw new InvalidSettingsException("Default label is not a column reference");
        }
        if (!defaultLabel.startsWith("$") || !defaultLabel.endsWith("$")) {
            throw new InvalidSettingsException("Column references in default label must be enclosed in $");
        }
        String colRef = defaultLabel.substring(1, defaultLabel.length() - 1);
        defaultLabelColumnIndex = inSpec.findColumnIndex(colRef);
        if (defaultLabelColumnIndex == -1) {
            throw new InvalidSettingsException("Column '" + defaultLabel + "' for default label does not exist in input table");
        }
    } else {
        defaultLabelColumnIndex = -1;
    }
    // determine output type: collect the types of all rule outcomes and of the default label
    List<DataType> types = new ArrayList<DataType>();
    for (Rule r : rules) {
        Object outcome = r.getOutcome();
        if (outcome instanceof ColumnReference) {
            types.add(((ColumnReference) outcome).spec.getType());
        } else if (outcome instanceof Double) {
            types.add(DoubleCell.TYPE);
        } else if (outcome instanceof Integer) {
            types.add(IntCell.TYPE);
        } else if (outcome.toString().length() > 0) {
            types.add(StringCell.TYPE);
        }
    }
    if (defaultLabelColumnIndex >= 0) {
        types.add(inSpec.getColumnSpec(defaultLabelColumnIndex).getType());
    } else if (defaultLabel.length() > 0) {
        try {
            Integer.parseInt(defaultLabel);
            types.add(IntCell.TYPE);
        } catch (NumberFormatException ex) {
            try {
                Double.parseDouble(defaultLabel);
                types.add(DoubleCell.TYPE);
            } catch (NumberFormatException ex1) {
                types.add(StringCell.TYPE);
            }
        }
    }
    final DataType outType;
    if (!types.isEmpty()) {
        DataType temp = types.get(0);
        for (int i = 1; i < types.size(); i++) {
            temp = DataType.getCommonSuperType(temp, types.get(i));
        }
        if ((temp.getValueClasses().size() == 1) && temp.getValueClasses().contains(DataValue.class)) {
            // a non-native type, we replace it with string
            temp = StringCell.TYPE;
        }
        outType = temp;
    } else {
        outType = StringCell.TYPE;
    }
    DataColumnSpec cs = new DataColumnSpecCreator(newColName, outType).createSpec();
    crea.append(new SingleCellFactory(cs) {

        @Override
        public DataCell getCell(final DataRow row) {
            for (Rule r : rules) {
                if (r.matches(row)) {
                    Object outcome = r.getOutcome();
                    if (outcome instanceof ColumnReference) {
                        DataCell cell = row.getCell(((ColumnReference) outcome).index);
                        if (outType.equals(StringCell.TYPE) && !cell.isMissing() && !cell.getType().equals(StringCell.TYPE)) {
                            return new StringCell(cell.toString());
                        }
                        return cell;
                    } else if (outType.equals(IntCell.TYPE)) {
                        return new IntCell(((Number) outcome).intValue());
                    } else if (outType.equals(DoubleCell.TYPE)) {
                        // Integer and Double outcomes share a double column: widen via Number
                        // instead of casting to Double (which fails for Integer outcomes)
                        return new DoubleCell(((Number) outcome).doubleValue());
                    } else {
                        return new StringCell(outcome.toString());
                    }
                }
            }
            if (defaultLabelColumnIndex >= 0) {
                DataCell cell = row.getCell(defaultLabelColumnIndex);
                // keep missing cells missing, as for column-reference rule outcomes above
                if (outType.equals(StringCell.TYPE) && !cell.isMissing() && !cell.getType().equals(StringCell.TYPE)) {
                    return new StringCell(cell.toString());
                }
                return cell;
            }
            if (defaultLabel.length() == 0) {
                return DataType.getMissingCell();
            }
            if (outType.equals(StringCell.TYPE)) {
                return new StringCell(defaultLabel);
            }
            try {
                int i = Integer.parseInt(defaultLabel);
                // an integer label in a double column must still produce a double cell
                if (outType.equals(DoubleCell.TYPE)) {
                    return new DoubleCell(i);
                }
                return new IntCell(i);
            } catch (NumberFormatException ex) {
                try {
                    return new DoubleCell(Double.parseDouble(defaultLabel));
                } catch (NumberFormatException ex1) {
                    return new StringCell(defaultLabel);
                }
            }
        }
    });
    return crea;
}
Also used : DataColumnSpecCreator(org.knime.core.data.DataColumnSpecCreator) DataValue(org.knime.core.data.DataValue) DoubleCell(org.knime.core.data.def.DoubleCell) ArrayList(java.util.ArrayList) DataRow(org.knime.core.data.DataRow) IntCell(org.knime.core.data.def.IntCell) ColumnRearranger(org.knime.core.data.container.ColumnRearranger) DataColumnSpec(org.knime.core.data.DataColumnSpec) InvalidSettingsException(org.knime.core.node.InvalidSettingsException) StringCell(org.knime.core.data.def.StringCell) DataType(org.knime.core.data.DataType) DataCell(org.knime.core.data.DataCell) SingleCellFactory(org.knime.core.data.container.SingleCellFactory) ColumnReference(org.knime.base.node.rules.Rule.ColumnReference)

Example 18 with StringCell

Use of org.knime.core.data.def.StringCell in project knime-core by knime.

Class LogisticRegressionContent, method createTablePortObject.

/**
 * Creates a table with the coefficient statistics of the logistic regression model.
 * <p>
 * For every logit, the table contains one row per model parameter, followed by one row named
 * "Constant" holding the intercept statistics. The columns are "Logit", "Variable", "Coeff.",
 * "Std. Err.", "z-score" and "P&gt;|z|". Rows are keyed "Row1", "Row2", ... in output order.
 *
 * @param exec the execution context used to create the output container
 * @return the filled table
 */
public BufferedDataTable createTablePortObject(final ExecutionContext exec) {
    DataTableSpec tableOutSpec = new DataTableSpec("Coefficients and Statistics", new String[] { "Logit", "Variable", "Coeff.", "Std. Err.", "z-score", "P>|z|" }, new DataType[] { StringCell.TYPE, StringCell.TYPE, DoubleCell.TYPE, DoubleCell.TYPE, DoubleCell.TYPE, DoubleCell.TYPE });
    BufferedDataContainer dc = exec.createDataContainer(tableOutSpec);
    List<DataCell> logits = this.getLogits();
    List<String> parameters = this.getParameters();
    int c = 0;
    for (DataCell logit : logits) {
        // cells are immutable, so one logit cell can be shared by all rows of this logit
        final DataCell logitCell = new StringCell(logit.toString());
        Map<String, Double> coefficients = this.getCoefficients(logit);
        Map<String, Double> stdErrs = this.getStandardErrors(logit);
        Map<String, Double> zScores = this.getZScores(logit);
        Map<String, Double> pValues = this.getPValues(logit);
        for (String parameter : parameters) {
            c++;
            dc.addRowToTable(new DefaultRow("Row" + c,
                logitCell,
                new StringCell(parameter),
                new DoubleCell(coefficients.get(parameter)),
                new DoubleCell(stdErrs.get(parameter)),
                new DoubleCell(zScores.get(parameter)),
                new DoubleCell(pValues.get(parameter))));
        }
        c++;
        dc.addRowToTable(new DefaultRow("Row" + c,
            logitCell,
            new StringCell("Constant"),
            new DoubleCell(this.getIntercept(logit)),
            new DoubleCell(this.getInterceptStdErr(logit)),
            new DoubleCell(this.getInterceptZScore(logit)),
            new DoubleCell(this.getInterceptPValue(logit))));
    }
    dc.close();
    return dc.getTable();
}
Also used : DataTableSpec(org.knime.core.data.DataTableSpec) BufferedDataContainer(org.knime.core.node.BufferedDataContainer) DoubleCell(org.knime.core.data.def.DoubleCell) ArrayList(java.util.ArrayList) StringCell(org.knime.core.data.def.StringCell) DataCell(org.knime.core.data.DataCell) DefaultRow(org.knime.core.data.def.DefaultRow)

Example 19 with StringCell

Use of org.knime.core.data.def.StringCell in project knime-core by knime.

Class LogisticRegressionContent, method createTablePortObject.

/**
 * Builds the "Coefficients and Statistics" table of the logistic regression model.
 * <p>
 * Each logit contributes one row per model parameter and a trailing "Constant" row with the
 * intercept statistics. Row keys are "Row1", "Row2", ... in the order the rows are written.
 *
 * @param exec execution context that provides the data container
 * @return the closed table containing all coefficient rows
 */
public BufferedDataTable createTablePortObject(final ExecutionContext exec) {
    final String[] columnNames = { "Logit", "Variable", "Coeff.", "Std. Err.", "z-score", "P>|z|" };
    final DataType[] columnTypes = { StringCell.TYPE, StringCell.TYPE, DoubleCell.TYPE, DoubleCell.TYPE, DoubleCell.TYPE, DoubleCell.TYPE };
    final BufferedDataContainer container = exec.createDataContainer(new DataTableSpec("Coefficients and Statistics", columnNames, columnTypes));
    final List<DataCell> allLogits = getLogits();
    final List<String> allParameters = getParameters();
    int rowNumber = 0;
    for (final DataCell logit : allLogits) {
        final String logitName = logit.toString();
        final Map<String, Double> coeffByParam = getCoefficients(logit);
        final Map<String, Double> stdErrByParam = getStandardErrors(logit);
        final Map<String, Double> zScoreByParam = getZScores(logit);
        final Map<String, Double> pValueByParam = getPValues(logit);
        for (final String param : allParameters) {
            final List<DataCell> rowCells = new ArrayList<DataCell>(columnNames.length);
            rowCells.add(new StringCell(logitName));
            rowCells.add(new StringCell(param));
            rowCells.add(new DoubleCell(coeffByParam.get(param)));
            rowCells.add(new DoubleCell(stdErrByParam.get(param)));
            rowCells.add(new DoubleCell(zScoreByParam.get(param)));
            rowCells.add(new DoubleCell(pValueByParam.get(param)));
            rowNumber += 1;
            container.addRowToTable(new DefaultRow("Row" + rowNumber, rowCells));
        }
        // intercept row closes the block of this logit
        final List<DataCell> interceptCells = new ArrayList<DataCell>(columnNames.length);
        interceptCells.add(new StringCell(logitName));
        interceptCells.add(new StringCell("Constant"));
        interceptCells.add(new DoubleCell(getIntercept(logit)));
        interceptCells.add(new DoubleCell(getInterceptStdErr(logit)));
        interceptCells.add(new DoubleCell(getInterceptZScore(logit)));
        interceptCells.add(new DoubleCell(getInterceptPValue(logit)));
        rowNumber += 1;
        container.addRowToTable(new DefaultRow("Row" + rowNumber, interceptCells));
    }
    container.close();
    return container.getTable();
}
Also used : DataTableSpec(org.knime.core.data.DataTableSpec) BufferedDataContainer(org.knime.core.node.BufferedDataContainer) DoubleCell(org.knime.core.data.def.DoubleCell) ArrayList(java.util.ArrayList) StringCell(org.knime.core.data.def.StringCell) DataCell(org.knime.core.data.DataCell) DefaultRow(org.knime.core.data.def.DefaultRow)

Example 20 with StringCell

Use of org.knime.core.data.def.StringCell in project knime-core by knime.

Class AbstractTreeEnsembleModel, method createLearnAttributeRow.

/**
 * Converts a learn row into the attribute row the trees of this ensemble operate on.
 * <p>
 * For {@code Ordinary} ensembles the row is returned unchanged. For the vector types
 * ({@code BitVector}, {@code ByteVector}, {@code DoubleVector}), the first cell of the row holds
 * the vector. That vector is expanded into one cell per attribute.
 *
 * @param learnRow the row to convert
 * @param learnSpec the spec of {@code learnRow}; not read by this method
 * @return the attribute row, or {@code null} for vector types if the vector cell is missing or
 *         its length differs from the number of attributes in the model meta data
 * @throws IllegalStateException if the tree type is not supported
 */
public DataRow createLearnAttributeRow(final DataRow learnRow, final DataTableSpec learnSpec) {
    final TreeType type = getType();
    switch(type) {
        case Ordinary:
            // nothing to expand; cell 0 is not read (it is not needed and may not exist)
            return learnRow;
        case BitVector:
        case ByteVector:
        case DoubleVector:
            return createVectorAttributeRow(learnRow, type);
        default:
            throw new IllegalStateException("Type unknown (not implemented): " + type);
    }
}

/**
 * Expands the vector stored in the first cell of {@code learnRow} into one cell per attribute.
 * Bits become {@link StringCell}s "1"/"0", bytes become {@link IntCell}s and doubles become
 * {@link DoubleCell}s.
 *
 * @param learnRow row whose first cell holds the vector
 * @param type one of {@code BitVector}, {@code ByteVector} or {@code DoubleVector}
 * @return the expanded row (with the key of {@code learnRow}), or {@code null} if the vector cell
 *         is missing or its length differs from the number of attributes
 * @throws IllegalStateException if {@code type} is not a vector type
 */
private DataRow createVectorAttributeRow(final DataRow learnRow, final TreeType type) {
    final DataCell c = learnRow.getCell(0);
    if (c.isMissing()) {
        return null;
    }
    final int nrAttributes = getMetaData().getNrAttributes();
    final DataCell[] cells = new DataCell[nrAttributes];
    switch(type) {
        case BitVector: {
            final BitVectorValue bv = (BitVectorValue) c;
            if (bv.length() != nrAttributes) {
                // TODO indicate error message
                return null;
            }
            // shared immutable cells for set / unset bits
            final DataCell trueCell = new StringCell("1");
            final DataCell falseCell = new StringCell("0");
            for (int i = 0; i < nrAttributes; i++) {
                cells[i] = bv.get(i) ? trueCell : falseCell;
            }
            break;
        }
        case ByteVector: {
            final ByteVectorValue byteVector = (ByteVectorValue) c;
            if (byteVector.length() != nrAttributes) {
                return null;
            }
            for (int i = 0; i < nrAttributes; i++) {
                cells[i] = new IntCell(byteVector.get(i));
            }
            break;
        }
        case DoubleVector: {
            final DoubleVectorValue doubleVector = (DoubleVectorValue) c;
            if (doubleVector.getLength() != nrAttributes) {
                return null;
            }
            for (int i = 0; i < nrAttributes; i++) {
                cells[i] = new DoubleCell(doubleVector.getValue(i));
            }
            break;
        }
        default:
            throw new IllegalStateException("Type unknown (not implemented): " + type);
    }
    return new DefaultRow(learnRow.getKey(), cells);
}
Also used : DoubleVectorValue(org.knime.core.data.vector.doublevector.DoubleVectorValue) DoubleCell(org.knime.core.data.def.DoubleCell) ByteVectorValue(org.knime.core.data.vector.bytevector.ByteVectorValue) IntCell(org.knime.core.data.def.IntCell) StringCell(org.knime.core.data.def.StringCell) DataCell(org.knime.core.data.DataCell) BitVectorValue(org.knime.core.data.vector.bitvector.BitVectorValue) DefaultRow(org.knime.core.data.def.DefaultRow)

Aggregations

StringCell (org.knime.core.data.def.StringCell)176 DataCell (org.knime.core.data.DataCell)130 DoubleCell (org.knime.core.data.def.DoubleCell)67 DefaultRow (org.knime.core.data.def.DefaultRow)65 IntCell (org.knime.core.data.def.IntCell)55 DataRow (org.knime.core.data.DataRow)52 DataTableSpec (org.knime.core.data.DataTableSpec)49 ArrayList (java.util.ArrayList)41 DataColumnSpec (org.knime.core.data.DataColumnSpec)37 RowKey (org.knime.core.data.RowKey)36 DataColumnSpecCreator (org.knime.core.data.DataColumnSpecCreator)26 BufferedDataContainer (org.knime.core.node.BufferedDataContainer)26 DataType (org.knime.core.data.DataType)22 LinkedHashSet (java.util.LinkedHashSet)21 BufferedDataTable (org.knime.core.node.BufferedDataTable)20 ColumnRearranger (org.knime.core.data.container.ColumnRearranger)19 InvalidSettingsException (org.knime.core.node.InvalidSettingsException)16 LinkedHashMap (java.util.LinkedHashMap)15 Test (org.junit.Test)15 HashMap (java.util.HashMap)11