Search in sources:

Example 6 with FilterColumnRow

use of org.knime.base.data.filter.column.FilterColumnRow in project knime-core by knime.

the class RegressionTreePredictorCellFactory method getCells.

/**
 * Produces the single prediction cell for one input row.
 * {@inheritDoc}
 *
 * @param row the input row; it is wrapped in a {@link FilterColumnRow} restricted to
 *            {@code m_learnColumnInRealDataIndices} before being handed to the model
 * @return a one-element array containing the mean of the matching tree node as a
 *         {@link DoubleCell}, or a missing cell if the model could not create a predictor
 *         record for the row
 */
@Override
public DataCell[] getCells(final DataRow row) {
    final RegressionTreeModel model = m_predictor.getModelObject().getModel();
    final DataCell[] cells = new DataCell[1];
    final DataRow learnColumnsRow = new FilterColumnRow(row, m_learnColumnInRealDataIndices);
    final PredictorRecord predictorRecord = model.createPredictorRecord(learnColumnsRow, m_learnSpec);
    if (predictorRecord == null) {
        // no record available (missing value): answer with missing cells only
        Arrays.fill(cells, DataType.getMissingCell());
        return cells;
    }
    final TreeNodeRegression node = model.getTreeModel().findMatchingNode(predictorRecord);
    cells[0] = new DoubleCell(node.getMean());
    return cells;
}
Also used : RegressionTreeModelPortObject(org.knime.base.node.mine.treeensemble.model.RegressionTreeModelPortObject) RegressionTreeModel(org.knime.base.node.mine.treeensemble.model.RegressionTreeModel) DoubleCell(org.knime.core.data.def.DoubleCell) PredictorRecord(org.knime.base.node.mine.treeensemble.data.PredictorRecord) DataCell(org.knime.core.data.DataCell) DataRow(org.knime.core.data.DataRow) TreeNodeRegression(org.knime.base.node.mine.treeensemble.model.TreeNodeRegression) FilterColumnRow(org.knime.base.data.filter.column.FilterColumnRow) TreeModelRegression(org.knime.base.node.mine.treeensemble.model.TreeModelRegression)

Example 7 with FilterColumnRow

use of org.knime.base.data.filter.column.FilterColumnRow in project knime-core by knime.

the class BasisFunctionPredictor2CellFactory method getCells.

/**
 * Runs the basis function model on the given row and returns the predicted cells.
 * {@inheritDoc}
 *
 * @param row the row to predict; it is wrapped in a {@link FilterColumnRow} restricted to
 *            {@code m_filteredColumns} before prediction
 * @return the complete prediction array if {@code m_appendClassProps} is set, otherwise a
 *         one-element array holding only the last cell of the prediction
 */
@Override
public DataCell[] getCells(final DataRow row) {
    final DataCell[] prediction = predict(new FilterColumnRow(row, m_filteredColumns), m_model);
    if (!m_appendClassProps) {
        // class probabilities are not requested: keep only the trailing cell
        return new DataCell[] { prediction[prediction.length - 1] };
    }
    // full prediction: class probabilities plus label
    return prediction;
}
Also used : DataCell(org.knime.core.data.DataCell) DataRow(org.knime.core.data.DataRow) FilterColumnRow(org.knime.base.data.filter.column.FilterColumnRow)

Example 8 with FilterColumnRow

use of org.knime.base.data.filter.column.FilterColumnRow in project knime-core by knime.

the class Unpivot2NodeModel method execute.

/**
 * {@inheritDoc}
 * <p>
 * For every input row, one output row is written per selected value column. Each output row
 * consists of the original row key (as string), the value column's name and the value cell,
 * combined with the retained columns through an {@link AppendedColumnRow}. Value cells that are
 * missing are skipped when the missing-value setting is enabled. If hiliting is enabled, the
 * mapping from each input row key to the keys of the rows generated from it is handed to the
 * hilite translator; otherwise the translator's mapper is reset to {@code null}.
 *
 * @param inData the input tables; only the first one is read
 * @param exec used to create the output container, report progress and check for cancellation
 * @return a single-element array with the unpivoted table
 * @throws Exception if the execution is canceled or writing the output fails
 */
@Override
protected BufferedDataTable[] execute(final BufferedDataTable[] inData, final ExecutionContext exec) throws Exception {
    final DataTableSpec inSpec = inData[0].getSpec();
    final String[] retainedColumns = m_retainedColumns.applyTo(inSpec).getIncludes();
    final String[] valueColumns = m_valueColumns.applyTo(inSpec).getIncludes();

    // resolve the selected column names to their positions in the input spec
    final int[] valueColumnIndices = new int[valueColumns.length];
    for (int v = 0; v < valueColumnIndices.length; v++) {
        valueColumnIndices[v] = inSpec.findColumnIndex(valueColumns[v]);
    }
    final int[] retainedColumnIndices = new int[retainedColumns.length];
    for (int r = 0; r < retainedColumnIndices.length; r++) {
        retainedColumnIndices[r] = inSpec.findColumnIndex(retainedColumns[r]);
    }

    // upper bound of generated rows, used as the progress denominator
    final double maxOutputRows = inData[0].size() * valueColumns.length;
    final boolean enableHilite = m_enableHilite.getBooleanValue();
    final LinkedHashMap<RowKey, Set<RowKey>> hiliteMapping = new LinkedHashMap<RowKey, Set<RowKey>>();
    final DataTableSpec outSpec = createOutSpec(inSpec);
    final BufferedDataContainer container = exec.createDataContainer(outSpec);
    final boolean skipMissings = m_missingValues.getBooleanValue();

    for (final DataRow inputRow : inData[0]) {
        final Set<RowKey> generatedKeys = new LinkedHashSet<RowKey>();
        final FilterColumnRow retainedPart = new FilterColumnRow(inputRow, retainedColumnIndices);
        for (int v = 0; v < valueColumns.length; v++) {
            final String valueColumnName = valueColumns[v];
            final DataCell valueCell = inputRow.getCell(valueColumnIndices[v]);
            if (valueCell.isMissing() && skipMissings) {
                // missing value cells do not produce an output row
                continue;
            }
            // the new key is derived from the number of rows written so far
            final RowKey outputKey = RowKey.createRowKey(container.size());
            if (enableHilite) {
                generatedKeys.add(outputKey);
            }
            final StringCell originalKeyCell = new StringCell(inputRow.getKey().getString());
            final StringCell columnNameCell = new StringCell(valueColumnName);
            final DefaultRow unpivotedPart = new DefaultRow(outputKey, originalKeyCell, columnNameCell, valueCell);
            container.addRowToTable(new AppendedColumnRow(outputKey, unpivotedPart, retainedPart));
            exec.checkCanceled();
            exec.setProgress(container.size() / maxOutputRows);
        }
        if (enableHilite) {
            hiliteMapping.put(retainedPart.getKey(), generatedKeys);
        }
    }
    container.close();

    m_trans.setMapper(enableHilite ? new DefaultHiLiteMapper(hiliteMapping) : null);
    return new BufferedDataTable[] { container.getTable() };
}
Also used : LinkedHashSet(java.util.LinkedHashSet) DataTableSpec(org.knime.core.data.DataTableSpec) LinkedHashSet(java.util.LinkedHashSet) Set(java.util.Set) BufferedDataContainer(org.knime.core.node.BufferedDataContainer) RowKey(org.knime.core.data.RowKey) DataRow(org.knime.core.data.DataRow) LinkedHashMap(java.util.LinkedHashMap) StringCell(org.knime.core.data.def.StringCell) BufferedDataTable(org.knime.core.node.BufferedDataTable) DataCell(org.knime.core.data.DataCell) DefaultRow(org.knime.core.data.def.DefaultRow) DefaultHiLiteMapper(org.knime.core.node.property.hilite.DefaultHiLiteMapper) FilterColumnRow(org.knime.base.data.filter.column.FilterColumnRow) AppendedColumnRow(org.knime.base.data.append.column.AppendedColumnRow)

Example 9 with FilterColumnRow

use of org.knime.base.data.filter.column.FilterColumnRow in project knime-core by knime.

the class SotaPredictorCellFactory method getCells.

/**
 * {@inheritDoc}
 * <p>
 * Predicts the class of a row by looking for the tree cell with the smallest distance to the
 * row (restricted to {@code m_includedColsIndices}). If any of the included cells is missing, a
 * single missing cell is returned. When probabilities are appended, each class gets the
 * normalized inverse of its smallest distance (distances are clamped to at least
 * {@code EPSILON}), and the predicted class is written into the last column.
 *
 * @param row the row to predict, may be {@code null}
 * @return the predicted cells, or {@code null} if {@code row} is {@code null}
 */
@Override
public DataCell[] getCells(final DataRow row) {
    if (row == null) {
        return null;
    }
    final DataRow filteredRow = new FilterColumnRow(row, m_includedColsIndices);
    for (final DataCell cell : filteredRow) {
        if (cell.isMissing()) {
            return new DataCell[] { DataType.getMissingCell() };
        }
    }

    // find the overall closest cell and remember the smallest distance seen per class
    SotaTreeCell closest = null;
    double closestDist = Double.MAX_VALUE;
    final Map<String, Double> bestDistPerClass = new HashMap<String, Double>();
    for (int j = 0; j < m_cells.size(); j++) {
        final double dist = m_distanceManager.getDistance(filteredRow, m_cells.get(j));
        final String cellClass = m_cells.get(j).getTreeCellClass();
        final Double known = bestDistPerClass.get(cellClass);
        if (known == null || known.doubleValue() > dist) {
            bestDistPerClass.put(cellClass, dist);
        }
        if (dist < closestDist) {
            closest = m_cells.get(j);
            closestDist = dist;
        }
    }
    final String predClass = closest != null ? closest.getTreeCellClass() : SotaTreeCell.DEFAULT_CLASS;

    final DataCell[] cells = new DataCell[m_newColumns.getNumColumns()];
    if (!m_appendProbs) {
        // only the predicted class is requested
        return new DataCell[] { new StringCell(predClass) };
    }

    // normalization constant: sum of the inverse per-class distances
    double inverseSum = 0d;
    for (final Double d : bestDistPerClass.values()) {
        inverseSum += 1d / Math.max(EPSILON, d.doubleValue());
    }
    // columns without a matching class stay missing
    for (int i = 0; i < cells.length; i++) {
        cells[i] = DataType.getMissingCell();
    }
    for (final Entry<String, Double> perClass : bestDistPerClass.entrySet()) {
        final String target = perClass.getKey();
        final String colName =
            PredictorHelper.getInstance().probabilityColumnName(m_targetColumn.getName(), target, m_suffix);
        final int colIndex = m_newColumns.findColumnIndex(colName);
        if (colIndex >= 0) {
            final double inverse = 1d / Math.max(EPSILON, perClass.getValue().doubleValue());
            cells[colIndex] = new DoubleCell(inverse / inverseSum);
        }
    }
    cells[cells.length - 1] = new StringCell(predClass);
    return cells;
}
Also used : SotaTreeCell(org.knime.base.node.mine.sota.logic.SotaTreeCell) HashMap(java.util.HashMap) DoubleCell(org.knime.core.data.def.DoubleCell) DataRow(org.knime.core.data.DataRow) StringCell(org.knime.core.data.def.StringCell) DataCell(org.knime.core.data.DataCell) FilterColumnRow(org.knime.base.data.filter.column.FilterColumnRow)

Example 10 with FilterColumnRow

use of org.knime.base.data.filter.column.FilterColumnRow in project knime-core by knime.

the class TreeEnsembleRegressionPredictorCellFactory method getCells.

/**
 * {@inheritDoc}
 * <p>
 * Aggregates the node means of all ensemble trees that apply to the row. The result always
 * starts with the mean of the node means; depending on the predictor configuration it is
 * followed by the variance of the node means (prediction confidence) and by the number of models
 * that contributed. If an out-of-bag filter is present, models that were trained on the row are
 * left out.
 *
 * @param row the input row; it is wrapped in a {@link FilterColumnRow} restricted to
 *            {@code m_learnColumnInRealDataIndices} before being handed to the model
 * @return the prediction cells; all cells are missing if no predictor record could be created,
 *         and the mean/variance cells are missing if no model contributed
 */
@Override
public DataCell[] getCells(final DataRow row) {
    final TreeEnsembleModelPortObject portObject = m_predictor.getModelObject();
    final TreeEnsemblePredictorConfiguration config = m_predictor.getConfiguration();
    final TreeEnsembleModel ensemble = portObject.getEnsembleModel();
    final boolean withConfidence = config.isAppendPredictionConfidence();
    final boolean withModelCount = config.isAppendModelCount();
    // one cell for the prediction plus one per optional output
    final int cellCount = 1 + (withConfidence ? 1 : 0) + (withModelCount ? 1 : 0);
    final boolean hasOutOfBagFilter = m_predictor.hasOutOfBagFilter();
    final DataCell[] cells = new DataCell[cellCount];

    final DataRow learnColumnsRow = new FilterColumnRow(row, m_learnColumnInRealDataIndices);
    final PredictorRecord predictorRecord = ensemble.createPredictorRecord(learnColumnsRow, m_learnSpec);
    if (predictorRecord == null) {
        // no record available (missing value): answer with missing cells only
        Arrays.fill(cells, DataType.getMissingCell());
        return cells;
    }

    final Mean meanOfNodeMeans = new Mean();
    final Variance varianceOfNodeMeans = new Variance();
    final int modelCount = ensemble.getNrModels();
    for (int m = 0; m < modelCount; m++) {
        if (hasOutOfBagFilter && m_predictor.isRowPartOfTrainingData(row.getKey(), m)) {
            // the row was used to train this model, so the model must not vote
            continue;
        }
        final TreeNodeRegression node = ensemble.getTreeModelRegression(m).findMatchingNode(predictorRecord);
        final double nodeMean = node.getMean();
        meanOfNodeMeans.increment(nodeMean);
        varianceOfNodeMeans.increment(nodeMean);
    }

    final int contributingModels = (int) meanOfNodeMeans.getN();
    final boolean noModelContributed = contributingModels == 0;
    int next = 0;
    cells[next++] = noModelContributed ? DataType.getMissingCell() : new DoubleCell(meanOfNodeMeans.getResult());
    if (withConfidence) {
        cells[next++] =
            noModelContributed ? DataType.getMissingCell() : new DoubleCell(varianceOfNodeMeans.getResult());
    }
    if (withModelCount) {
        cells[next++] = new IntCell(contributingModels);
    }
    return cells;
}
Also used : Mean(org.apache.commons.math.stat.descriptive.moment.Mean) TreeEnsembleModel(org.knime.base.node.mine.treeensemble.model.TreeEnsembleModel) DoubleCell(org.knime.core.data.def.DoubleCell) TreeEnsemblePredictorConfiguration(org.knime.base.node.mine.treeensemble.node.predictor.TreeEnsemblePredictorConfiguration) DataRow(org.knime.core.data.DataRow) TreeNodeRegression(org.knime.base.node.mine.treeensemble.model.TreeNodeRegression) Variance(org.apache.commons.math.stat.descriptive.moment.Variance) TreeModelRegression(org.knime.base.node.mine.treeensemble.model.TreeModelRegression) IntCell(org.knime.core.data.def.IntCell) TreeEnsembleModelPortObject(org.knime.base.node.mine.treeensemble.model.TreeEnsembleModelPortObject) PredictorRecord(org.knime.base.node.mine.treeensemble.data.PredictorRecord) DataCell(org.knime.core.data.DataCell) FilterColumnRow(org.knime.base.data.filter.column.FilterColumnRow)

Aggregations

FilterColumnRow (org.knime.base.data.filter.column.FilterColumnRow): 15, DataRow (org.knime.core.data.DataRow): 15, DataCell (org.knime.core.data.DataCell): 13, DoubleCell (org.knime.core.data.def.DoubleCell): 10, PredictorRecord (org.knime.base.node.mine.treeensemble2.data.PredictorRecord): 6, IntCell (org.knime.core.data.def.IntCell): 5, StringCell (org.knime.core.data.def.StringCell): 5, PredictorRecord (org.knime.base.node.mine.treeensemble.data.PredictorRecord): 3, TreeEnsembleModel (org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModel): 3, TreeEnsembleModelPortObject (org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObject): 3, TreeNodeRegression (org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression): 3, TreeEnsemblePredictorConfiguration (org.knime.base.node.mine.treeensemble2.node.predictor.TreeEnsemblePredictorConfiguration): 3, HashMap (java.util.HashMap): 2, LinkedHashMap (java.util.LinkedHashMap): 2, LinkedHashSet (java.util.LinkedHashSet): 2, Set (java.util.Set): 2, Mean (org.apache.commons.math.stat.descriptive.moment.Mean): 2, Variance (org.apache.commons.math.stat.descriptive.moment.Variance): 2, AppendedColumnRow (org.knime.base.data.append.column.AppendedColumnRow): 2, TreeEnsembleModel (org.knime.base.node.mine.treeensemble.model.TreeEnsembleModel): 2