Search in sources:

Example 11 with FilterColumnRow

use of org.knime.base.data.filter.column.FilterColumnRow in project knime-core by knime.

the class BasisFunctionPredictorCellFactory method getCells.

/**
 * Computes the prediction for the given row with the underlying basis
 * function model; the row is restricted to the filtered columns first.
 * {@inheritDoc}
 */
@Override
public DataCell[] getCells(final DataRow row) {
    final DataCell[] prediction = predict(new FilterColumnRow(row, m_filteredColumns), m_model);
    if (!m_appendClassProps) {
        // class probabilities are not appended: keep only the last cell of the prediction
        return new DataCell[] { prediction[prediction.length - 1] };
    }
    // complete prediction, i.e. class probabilities and label
    return prediction;
}
Also used : DataCell(org.knime.core.data.DataCell) DataRow(org.knime.core.data.DataRow) FilterColumnRow(org.knime.base.data.filter.column.FilterColumnRow)

Example 12 with FilterColumnRow

use of org.knime.base.data.filter.column.FilterColumnRow in project knime-core by knime.

the class LKGradientBoostingPredictorCellFactory method getCells.

/**
 * Predicts the class of the given row with the gradient boosting model.
 * <p>
 * For each class, the model's initial value and the coefficients of the matching tree nodes of all levels
 * are summed up. The sums are converted to class probabilities with a softmax. Before exponentiating, the
 * largest sum is subtracted from every sum: this leaves the probabilities mathematically unchanged but keeps
 * {@code Math.exp} from overflowing to infinity, which would otherwise produce NaN probabilities and leave
 * no winning class.
 * <p>
 * The result holds the label of the most probable class, then (if configured) the probability of that class,
 * then (if configured) one probability per key of the target value map, in the map's iteration order.
 * {@inheritDoc}
 */
@Override
public DataCell[] getCells(final DataRow row) {
    final DataRow filterRow = new FilterColumnRow(row, m_learnColumnInRealDataIndices);
    final int nrClasses = m_model.getNrClasses();
    final int nrLevels = m_model.getNrLevels();
    final PredictorRecord record = m_model.createPredictorRecord(filterRow, m_learnSpec);
    final double[] classFunctionPredictions = new double[nrClasses];
    Arrays.fill(classFunctionPredictions, m_model.getInitialValue());
    for (int i = 0; i < nrLevels; i++) {
        for (int j = 0; j < nrClasses; j++) {
            final TreeNodeRegression matchingNode = m_model.getModel(i, j).findMatchingNode(record);
            classFunctionPredictions[j] += m_model.getCoefficientMap(i, j).get(matchingNode.getSignature());
        }
    }
    // largest function value; used to shift the exponents so that the biggest one becomes exp(0) = 1
    double maxPrediction = Double.NEGATIVE_INFINITY;
    for (int i = 0; i < nrClasses; i++) {
        maxPrediction = Math.max(maxPrediction, classFunctionPredictions[i]);
    }
    final double[] classProbabilities = new double[nrClasses];
    double expSum = 0;
    for (int i = 0; i < nrClasses; i++) {
        // exp(x - max) / sum(exp(x - max)) == exp(x) / sum(exp(x)), but cannot overflow
        classProbabilities[i] = Math.exp(classFunctionPredictions[i] - maxPrediction);
        expSum += classProbabilities[i];
    }
    // normalize and pick the class with the highest probability (first one wins on ties)
    int classIdx = -1;
    double classProb = -1;
    for (int i = 0; i < nrClasses; i++) {
        classProbabilities[i] /= expSum;
        if (classProbabilities[i] > classProb) {
            classIdx = i;
            classProb = classProbabilities[i];
        }
    }
    final ArrayList<DataCell> cells = new ArrayList<DataCell>();
    cells.add(new StringCell(m_model.getClassLabel(classIdx)));
    if (m_config.isAppendPredictionConfidence()) {
        cells.add(new DoubleCell(classProb));
    }
    if (m_config.isAppendClassConfidences()) {
        // the map is necessary to ensure that the probabilities are correctly associated with the column header
        final Map<String, Double> classProbMap = new HashMap<String, Double>((int) (nrClasses * 1.5));
        for (int i = 0; i < nrClasses; i++) {
            classProbMap.put(m_model.getClassLabel(i), classProbabilities[i]);
        }
        for (final String className : m_targetValueMap.keySet()) {
            cells.add(new DoubleCell(classProbMap.get(className)));
        }
    }
    return cells.toArray(new DataCell[cells.size()]);
}
Also used : HashMap(java.util.HashMap) DoubleCell(org.knime.core.data.def.DoubleCell) ArrayList(java.util.ArrayList) DataRow(org.knime.core.data.DataRow) TreeNodeRegression(org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression) StringCell(org.knime.core.data.def.StringCell) PredictorRecord(org.knime.base.node.mine.treeensemble2.data.PredictorRecord) DataCell(org.knime.core.data.DataCell) FilterColumnRow(org.knime.base.data.filter.column.FilterColumnRow)

Example 13 with FilterColumnRow

use of org.knime.base.data.filter.column.FilterColumnRow in project knime-core by knime.

the class TreeEnsembleRegressionPredictorCellFactory method getCells.

/**
 * {@inheritDoc}
 * <p>
 * The first cell is the mean of the matching nodes' means over all models that take part in the prediction.
 * Depending on the predictor configuration it is followed by the variance of those values and by the number
 * of models that took part. If no predictor record can be created for the row, all cells are missing.
 */
@Override
public DataCell[] getCells(final DataRow row) {
    final TreeEnsembleModelPortObject portObject = m_predictor.getModelObject();
    final TreeEnsemblePredictorConfiguration config = m_predictor.getConfiguration();
    final TreeEnsembleModel ensemble = portObject.getEnsembleModel();
    final boolean withConfidence = config.isAppendPredictionConfidence();
    final boolean withModelCount = config.isAppendModelCount();
    // one cell for the prediction plus one for each optional output
    final int cellCount = 1 + (withConfidence ? 1 : 0) + (withModelCount ? 1 : 0);
    final boolean outOfBag = m_predictor.hasOutOfBagFilter();
    final DataCell[] cells = new DataCell[cellCount];
    final DataRow learnColumnsRow = new FilterColumnRow(row, m_learnColumnInRealDataIndices);
    final PredictorRecord record = ensemble.createPredictorRecord(learnColumnsRow, m_learnSpec);
    if (record == null) {
        // missing value
        Arrays.fill(cells, DataType.getMissingCell());
        return cells;
    }
    final Mean meanStat = new Mean();
    final Variance varianceStat = new Variance();
    final int modelCount = ensemble.getNrModels();
    for (int modelIdx = 0; modelIdx < modelCount; modelIdx++) {
        if (outOfBag && m_predictor.isRowPartOfTrainingData(row.getKey(), modelIdx)) {
            // row was used to train this model, so the model must not vote
            continue;
        }
        final TreeNodeRegression matchingNode = ensemble.getTreeModelRegression(modelIdx).findMatchingNode(record);
        final double value = matchingNode.getMean();
        meanStat.increment(value);
        varianceStat.increment(value);
    }
    final int votingModels = (int) meanStat.getN();
    final boolean nothingVoted = votingModels == 0;
    int pos = 0;
    cells[pos++] = nothingVoted ? DataType.getMissingCell() : new DoubleCell(meanStat.getResult());
    if (withConfidence) {
        cells[pos++] = nothingVoted ? DataType.getMissingCell() : new DoubleCell(varianceStat.getResult());
    }
    if (withModelCount) {
        cells[pos++] = new IntCell(votingModels);
    }
    return cells;
}
Also used : Mean(org.apache.commons.math.stat.descriptive.moment.Mean) TreeEnsembleModel(org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModel) DoubleCell(org.knime.core.data.def.DoubleCell) TreeEnsemblePredictorConfiguration(org.knime.base.node.mine.treeensemble2.node.predictor.TreeEnsemblePredictorConfiguration) DataRow(org.knime.core.data.DataRow) TreeNodeRegression(org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression) Variance(org.apache.commons.math.stat.descriptive.moment.Variance) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression) IntCell(org.knime.core.data.def.IntCell) TreeEnsembleModelPortObject(org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObject) PredictorRecord(org.knime.base.node.mine.treeensemble2.data.PredictorRecord) DataCell(org.knime.core.data.DataCell) FilterColumnRow(org.knime.base.data.filter.column.FilterColumnRow)

Example 14 with FilterColumnRow

use of org.knime.base.data.filter.column.FilterColumnRow in project knime-core by knime.

the class RegressionTreePredictorCellFactory method getCells.

/**
 * {@inheritDoc}
 * <p>
 * Returns a single cell: the mean of the tree node matching the given row, or a missing cell if no predictor
 * record can be created for the row.
 */
@Override
public DataCell[] getCells(final DataRow row) {
    final RegressionTreeModel model = m_predictor.getModel();
    final DataRow learnColumnsRow = new FilterColumnRow(row, m_learnColumnInRealDataIndices);
    final PredictorRecord record = model.createPredictorRecord(learnColumnsRow, m_learnSpec);
    if (record == null) {
        // missing value
        return new DataCell[] { DataType.getMissingCell() };
    }
    final TreeNodeRegression matchingNode = model.getTreeModel().findMatchingNode(record);
    return new DataCell[] { new DoubleCell(matchingNode.getMean()) };
}
Also used : RegressionTreeModel(org.knime.base.node.mine.treeensemble2.model.RegressionTreeModel) DoubleCell(org.knime.core.data.def.DoubleCell) PredictorRecord(org.knime.base.node.mine.treeensemble2.data.PredictorRecord) DataCell(org.knime.core.data.DataCell) DataRow(org.knime.core.data.DataRow) TreeNodeRegression(org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression) FilterColumnRow(org.knime.base.data.filter.column.FilterColumnRow) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression)

Example 15 with FilterColumnRow

use of org.knime.base.data.filter.column.FilterColumnRow in project knime-core by knime.

the class UnpivotNodeModel method execute.

/**
 * {@inheritDoc}
 * <p>
 * Unpivots the first input table: for every input row and every selected value column one output row is
 * created from the original row key (as string), the value column name and the value cell, combined with
 * the row's cells of the selected order columns. If skipping of missing values is enabled, value cells that
 * are missing do not produce an output row. When hiliting is enabled, a mapping from each input row key to
 * the keys of the output rows created from it is handed to the hilite translator; otherwise the mapper is
 * reset to {@code null}.
 */
@Override
protected BufferedDataTable[] execute(final BufferedDataTable[] inData, final ExecutionContext exec) throws Exception {
    DataTableSpec inSpec = inData[0].getSpec();
    List<String> orderColumns = m_orderColumns.getIncludeList();
    List<String> valueColumns = m_valueColumns.getIncludeList();
    int[] orderColumnIdx = new int[orderColumns.size()];
    for (int i = 0; i < orderColumnIdx.length; i++) {
        orderColumnIdx[i] = inSpec.findColumnIndex(orderColumns.get(i));
    }
    // resolve the value column positions once instead of once per input row
    final int[] valueColumnIdx = new int[valueColumns.size()];
    for (int i = 0; i < valueColumnIdx.length; i++) {
        valueColumnIdx[i] = inSpec.findColumnIndex(valueColumns.get(i));
    }
    // widen to double BEFORE multiplying so that large tables cannot overflow the integer product
    final double newRowCnt = (double) inData[0].getRowCount() * valueColumns.size();
    final boolean enableHilite = m_enableHilite.getBooleanValue();
    final boolean skipMissingValues = m_missingValues.getBooleanValue();
    LinkedHashMap<RowKey, Set<RowKey>> map = new LinkedHashMap<RowKey, Set<RowKey>>();
    DataTableSpec outSpec = createOutSpec(inSpec);
    BufferedDataContainer buf = exec.createDataContainer(outSpec);
    for (DataRow row : inData[0]) {
        LinkedHashSet<RowKey> set = new LinkedHashSet<RowKey>();
        FilterColumnRow crow = new FilterColumnRow(row, orderColumnIdx);
        for (int i = 0; i < valueColumnIdx.length; i++) {
            String colName = valueColumns.get(i);
            DataCell acell = row.getCell(valueColumnIdx[i]);
            if (acell.isMissing() && skipMissingValues) {
                // skip rows containing missing cells (in Value column(s))
                continue;
            }
            // output row keys are generated from the current size of the output container
            RowKey rowKey = RowKey.createRowKey(buf.size());
            if (enableHilite) {
                set.add(rowKey);
            }
            DefaultRow drow = new DefaultRow(rowKey, new StringCell(row.getKey().getString()), new StringCell(colName), acell);
            buf.addRowToTable(new AppendedColumnRow(rowKey, drow, crow));
            exec.checkCanceled();
            exec.setProgress(buf.size() / newRowCnt);
        }
        if (enableHilite) {
            map.put(crow.getKey(), set);
        }
    }
    buf.close();
    if (enableHilite) {
        m_trans.setMapper(new DefaultHiLiteMapper(map));
    } else {
        m_trans.setMapper(null);
    }
    return new BufferedDataTable[] { buf.getTable() };
}
Also used : LinkedHashSet(java.util.LinkedHashSet) DataTableSpec(org.knime.core.data.DataTableSpec) LinkedHashSet(java.util.LinkedHashSet) Set(java.util.Set) BufferedDataContainer(org.knime.core.node.BufferedDataContainer) RowKey(org.knime.core.data.RowKey) SettingsModelFilterString(org.knime.core.node.defaultnodesettings.SettingsModelFilterString) DataRow(org.knime.core.data.DataRow) LinkedHashMap(java.util.LinkedHashMap) StringCell(org.knime.core.data.def.StringCell) BufferedDataTable(org.knime.core.node.BufferedDataTable) DataCell(org.knime.core.data.DataCell) DefaultRow(org.knime.core.data.def.DefaultRow) DefaultHiLiteMapper(org.knime.core.node.property.hilite.DefaultHiLiteMapper) FilterColumnRow(org.knime.base.data.filter.column.FilterColumnRow) AppendedColumnRow(org.knime.base.data.append.column.AppendedColumnRow)

Aggregations

FilterColumnRow (org.knime.base.data.filter.column.FilterColumnRow)15 DataRow (org.knime.core.data.DataRow)15 DataCell (org.knime.core.data.DataCell)13 DoubleCell (org.knime.core.data.def.DoubleCell)10 PredictorRecord (org.knime.base.node.mine.treeensemble2.data.PredictorRecord)6 IntCell (org.knime.core.data.def.IntCell)5 StringCell (org.knime.core.data.def.StringCell)5 PredictorRecord (org.knime.base.node.mine.treeensemble.data.PredictorRecord)3 TreeEnsembleModel (org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModel)3 TreeEnsembleModelPortObject (org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObject)3 TreeNodeRegression (org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression)3 TreeEnsemblePredictorConfiguration (org.knime.base.node.mine.treeensemble2.node.predictor.TreeEnsemblePredictorConfiguration)3 HashMap (java.util.HashMap)2 LinkedHashMap (java.util.LinkedHashMap)2 LinkedHashSet (java.util.LinkedHashSet)2 Set (java.util.Set)2 Mean (org.apache.commons.math.stat.descriptive.moment.Mean)2 Variance (org.apache.commons.math.stat.descriptive.moment.Variance)2 AppendedColumnRow (org.knime.base.data.append.column.AppendedColumnRow)2 TreeEnsembleModel (org.knime.base.node.mine.treeensemble.model.TreeEnsembleModel)2