Search in sources:

Example 81 with DataRow

Use of org.knime.core.data.DataRow in project knime-core by KNIME.

The class TreeDataCreator, method readData.

/**
 * Reads the data from <b>learnData</b> into memory.
 * Each column is represented by a TreeColumnData object corresponding to its type
 * and whether it is an attribute or target column.
 * <p>
 * The table is first sorted by the target column (to enable equal size sampling). Rows with a
 * missing target value, or with a missing value in an attribute column whose creator does not
 * accept missing values, are skipped. If any rows are skipped, a warning message is stored in
 * {@code m_warningMessage}. The first {@code nrHilitePatterns} accepted rows are also copied into
 * the hilite container.
 *
 * @param learnData the learning table; attribute values are read from cell positions
 *            {@code 0 .. m_attrColCreators.length - 1} and the target value from the position
 *            directly after them
 * @param configuration the learner configuration handed to the attribute column data creators
 * @param exec monitor used for progress (half for sorting, half for reading) and cancellation
 * @return the TreeData object that holds all data in memory
 * @throws CanceledExecutionException if execution is canceled while sorting or reading
 * @throws IllegalArgumentException if the table has fewer than 2 rows, the configured target
 *             column is not part of the table, or no rows are left after removing rows with
 *             missing values
 */
public TreeData readData(final BufferedDataTable learnData, final TreeEnsembleLearnerConfiguration configuration, final ExecutionMonitor exec) throws CanceledExecutionException {
    if (learnData.size() <= 1) {
        throw new IllegalArgumentException("The input table must contain at least 2 rows!");
    }
    final long nrRows = learnData.size();
    final int nrLearnCols = m_attrColCreators.length;
    final boolean[] supportMissings = new boolean[nrLearnCols];
    for (int i = 0; i < nrLearnCols; i++) {
        supportMissings[i] = m_attrColCreators[i].acceptsMissing();
    }
    // NOTE(review): the target column and hilite count are read from m_configuration, while the
    // column data below is created with the 'configuration' parameter -- confirm both are the same.
    final int nrHilitePatterns = m_configuration.getNrHilitePatterns();
    // sort learnData according to the target column to enable equal size sampling
    final String targetColName = m_configuration.getTargetColumn();
    final int targetColIdx = learnData.getDataTableSpec().findColumnIndex(targetColName);
    // fail with a descriptive message instead of an index exception from getColumnSpec(-1)
    CheckUtils.checkArgument(targetColIdx >= 0, "Target column \"%s\" not found in input table", targetColName);
    final Comparator<DataCell> targetComp = learnData.getDataTableSpec().getColumnSpec(targetColIdx).getType().getComparator();
    DataTableSorter sorter = new DataTableSorter(learnData, learnData.size(), new Comparator<DataRow>() {

        @Override
        public int compare(final DataRow arg0, final DataRow arg1) {
            return targetComp.compare(arg0.getCell(targetColIdx), arg1.getCell(targetColIdx));
        }
    });
    final ExecutionMonitor sortExec = exec.createSubProgress(0.5);
    final DataTable sortedTable = sorter.sort(sortExec);
    final ExecutionMonitor readExec = exec.createSubProgress(0.5);
    // number of rows visited so far (drives progress reporting, includes rejected rows)
    long rowIndex = 0;
    // number of rows accepted into memory (rejected rows are not counted)
    int index = 0;
    int rejectedMissings = 0;
    for (DataRow r : sortedTable) {
        double progress = rowIndex / (double) nrRows;
        readExec.setProgress(progress, "Row " + rowIndex + " of " + nrRows + " (\"" + r.getKey() + "\")");
        readExec.checkCanceled();
        rowIndex++;
        boolean shouldReject = false;
        for (int i = 0; i < nrLearnCols; i++) {
            DataCell c = r.getCell(i);
            if (c.isMissing() && !supportMissings[i]) {
                shouldReject = true;
                break;
            }
        }
        // the target value is expected directly after the attribute columns
        DataCell targetCell = r.getCell(nrLearnCols);
        if (targetCell.isMissing()) {
            shouldReject = true;
        }
        if (shouldReject) {
            rejectedMissings += 1;
            continue;
        }
        if (index < nrHilitePatterns) {
            m_dataRowsForHiliteContainer.addRowToTable(r);
        }
        final RowKey key = r.getKey();
        for (int i = 0; i < nrLearnCols; i++) {
            DataCell c = r.getCell(i);
            m_attrColCreators[i].add(key, c);
        }
        m_targetColCreator.add(key, targetCell);
        index++;
    }
    if (nrHilitePatterns > 0 && index > nrHilitePatterns) {
        m_viewMessage = "Hilite (& color graphs) are based on a subset of " + "the data (" + nrHilitePatterns + "/" + index + ")";
    }
    if (rejectedMissings > 0) {
        StringBuilder warnMsgBuilder = new StringBuilder();
        warnMsgBuilder.append(rejectedMissings).append("/");
        warnMsgBuilder.append(learnData.size());
        warnMsgBuilder.append(" row(s) were ignored because they ");
        warnMsgBuilder.append("contain missing values.");
        m_warningMessage = warnMsgBuilder.toString();
    }
    CheckUtils.checkArgument(rejectedMissings < learnData.size(), "No rows left after removing missing values (table has %d row(s))", learnData.size());
    int nrLearnAttributes = 0;
    for (int i = 0; i < m_attrColCreators.length; i++) {
        nrLearnAttributes += m_attrColCreators[i].getNrAttributes();
    }
    TreeAttributeColumnData[] columns = new TreeAttributeColumnData[nrLearnAttributes];
    int learnAttributeIndex = 0;
    for (int i = 0; i < m_attrColCreators.length; i++) {
        TreeAttributeColumnDataCreator creator = m_attrColCreators[i];
        for (int a = 0; a < creator.getNrAttributes(); a++) {
            final TreeAttributeColumnData columnData = creator.createColumnData(a, configuration);
            // attribute indices are assigned consecutively across all creators
            columnData.getMetaData().setAttributeIndex(learnAttributeIndex);
            columns[learnAttributeIndex++] = columnData;
        }
    }
    TreeTargetColumnData targetCol = m_targetColCreator.createColumnData();
    return new TreeData(columns, targetCol, m_treeType);
}
Also used : DataTable(org.knime.core.data.DataTable) BufferedDataTable(org.knime.core.node.BufferedDataTable) RowKey(org.knime.core.data.RowKey) DataRow(org.knime.core.data.DataRow) DataTableSorter(org.knime.core.data.sort.DataTableSorter) DataCell(org.knime.core.data.DataCell) ExecutionMonitor(org.knime.core.node.ExecutionMonitor)

Example 82 with DataRow

Use of org.knime.core.data.DataRow in project knime-core by KNIME.

The class RandomForestDistance, method computeDistance.

/**
 * Computes the random forest distance between two rows.
 * <p>
 * Both rows are reduced to the columns returned by {@code getColumnIndices()} and converted into
 * predictor records. Each tree of the ensemble then routes both records to a leaf. The proximity is
 * the fraction of trees in which both leaves have equal signatures, and the returned distance is
 * {@code 1 - proximity}.
 *
 * @param row1 the first row
 * @param row2 the second row
 * @return one minus the fraction of trees placing both rows in leaves with the same signature
 * @throws DistanceMeasurementException declared by the overridden method
 */
@Override
public double computeDistance(final DataRow row1, final DataRow row2) throws DistanceMeasurementException {
    final int[] selectedColumns = getColumnIndices().stream().mapToInt(Integer::intValue).toArray();
    final DataRow reducedFirst = new FilterColumnRow(row1, selectedColumns);
    final DataRow reducedSecond = new FilterColumnRow(row2, selectedColumns);
    final PredictorRecord firstRecord = m_ensembleModel.createPredictorRecord(reducedFirst, m_learnTableSpec);
    final PredictorRecord secondRecord = m_ensembleModel.createPredictorRecord(reducedSecond, m_learnTableSpec);
    final int treeCount = m_ensembleModel.getNrModels();
    int sharedLeafCount = 0;
    for (int treeIdx = 0; treeIdx < treeCount; treeIdx++) {
        final AbstractTreeModel<?> currentTree = m_ensembleModel.getTreeModel(treeIdx);
        final AbstractTreeNode firstLeaf = currentTree.findMatchingNode(firstRecord);
        final AbstractTreeNode secondLeaf = currentTree.findMatchingNode(secondRecord);
        if (firstLeaf.getSignature().equals(secondLeaf.getSignature())) {
            sharedLeafCount++;
        }
    }
    // proximity is the share of trees with a common leaf; distance is its complement
    final double proximity = sharedLeafCount / (double) treeCount;
    return 1 - proximity;
}
Also used : PredictorRecord(org.knime.base.node.mine.treeensemble2.data.PredictorRecord) AbstractTreeNode(org.knime.base.node.mine.treeensemble2.model.AbstractTreeNode) DataRow(org.knime.core.data.DataRow) FilterColumnRow(org.knime.base.data.filter.column.FilterColumnRow)

Example 83 with DataRow

Use of org.knime.core.data.DataRow in project knime-core by KNIME.

The class JoinerTest, method compareTables.

/**
 * Asserts that {@code test} contains exactly the rows of {@code reference}, in the same order, with
 * equal row keys, equal cell counts and equal cells.
 *
 * @param reference the expected table
 * @param test the table under test
 */
private void compareTables(final BufferedDataTable reference, final BufferedDataTable test) {
    // Check if it has the same results as defaultResult
    assertThat("Unequal number of rows in result table", test.getRowCount(), is(reference.getRowCount()));
    RowIterator referenceIter = reference.iterator();
    RowIterator testIter = test.iterator();
    while (referenceIter.hasNext()) {
        DataRow refRow = referenceIter.next();
        DataRow testRow = testIter.next();
        assertThat("Unexpected row key", testRow.getKey(), is(refRow.getKey()));
        // Without this check, extra trailing cells in the test row would go unnoticed, and a
        // shorter test row would fail with NoSuchElementException instead of a clear assertion.
        assertThat("Unequal number of cells in row " + refRow.getKey(), testRow.getNumCells(), is(refRow.getNumCells()));
        Iterator<DataCell> refCell = refRow.iterator();
        Iterator<DataCell> testCell = testRow.iterator();
        while (refCell.hasNext()) {
            assertThat("Unexpected cell in row " + refRow.getKey(), testCell.next(), is(refCell.next()));
        }
    }
}
Also used : RowIterator(org.knime.core.data.RowIterator) DataCell(org.knime.core.data.DataCell) DataRow(org.knime.core.data.DataRow)

Example 84 with DataRow

Use of org.knime.core.data.DataRow in project knime-core by KNIME.

The class TreeEnsembleClassificationPredictorCellFactory, method getCells.

/**
 * Predicts the class of {@code row} by a majority vote over the trees of the ensemble.
 * <p>
 * The returned cells are, in this order:
 * <ul>
 * <li>the winning class cell (looked up in {@code m_targetValueMap});</li>
 * <li>optionally, the winner's vote ratio;</li>
 * <li>optionally, one vote ratio per key of {@code m_targetValueMap};</li>
 * <li>optionally, the number of trees that voted.</li>
 * </ul>
 * If an out-of-bag filter is active, trees for which the row was part of the training data do not vote.
 * If no predictor record can be built, every cell is missing. If no tree voted, every cell is missing
 * except the optional model count, which is then 0.
 *
 * @param row the input row
 * @return the prediction cells as described above
 */
@Override
public DataCell[] getCells(final DataRow row) {
    final TreeEnsembleModelPortObject modelObject = m_predictor.getModelObject();
    final TreeEnsemblePredictorConfiguration cfg = m_predictor.getConfiguration();
    final TreeEnsembleModel ensembleModel = modelObject.getEnsembleModel();
    final boolean appendConfidence = cfg.isAppendPredictionConfidence();
    final boolean appendClassConfidences = cfg.isAppendClassConfidences();
    final int classColumnCount = appendClassConfidences ? m_targetValueMap.size() : 0;
    final boolean appendModelCount = cfg.isAppendModelCount();
    final int cellCount = 1 + (appendConfidence ? 1 : 0) + classColumnCount + (appendModelCount ? 1 : 0);
    final boolean hasOutOfBagFilter = m_predictor.hasOutOfBagFilter();
    final DataCell[] cells = new DataCell[cellCount];
    final DataRow learnRow = new FilterColumnRow(row, m_learnColumnInRealDataIndices);
    final PredictorRecord record = ensembleModel.createPredictorRecord(learnRow, m_learnSpec);
    if (record == null) {
        // a null record signals a missing value: no prediction possible
        Arrays.fill(cells, DataType.getMissingCell());
        return cells;
    }
    final OccurrenceCounter<String> votes = new OccurrenceCounter<String>();
    final int treeCount = ensembleModel.getNrModels();
    int voterCount = 0;
    for (int treeIdx = 0; treeIdx < treeCount; treeIdx++) {
        if (hasOutOfBagFilter && m_predictor.isRowPartOfTrainingData(row.getKey(), treeIdx)) {
            // the row was used to train this tree, so the tree must not vote on it
            continue;
        }
        final TreeModelClassification tree = ensembleModel.getTreeModelClassification(treeIdx);
        final TreeNodeClassification leaf = tree.findMatchingNode(record);
        votes.add(leaf.getMajorityClassName());
        voterCount++;
    }
    final String winner = votes.getMostFrequent();
    int pos = 0;
    if (winner == null) {
        assert voterCount == 0;
        Arrays.fill(cells, DataType.getMissingCell());
        // only the trailing model-count cell may still be written below
        pos = cellCount - 1;
    } else {
        cells[pos++] = m_targetValueMap.get(winner);
        if (appendConfidence) {
            final int winnerVotes = votes.getFrequency(winner);
            cells[pos++] = new DoubleCell(winnerVotes / (double) voterCount);
        }
        if (appendClassConfidences) {
            for (String className : m_targetValueMap.keySet()) {
                final int classVotes = votes.getFrequency(className);
                cells[pos++] = new DoubleCell(classVotes / (double) voterCount);
            }
        }
    }
    if (appendModelCount) {
        cells[pos] = new IntCell(voterCount);
    }
    return cells;
}
Also used : TreeNodeClassification(org.knime.base.node.mine.treeensemble.model.TreeNodeClassification) TreeEnsembleModel(org.knime.base.node.mine.treeensemble.model.TreeEnsembleModel) DoubleCell(org.knime.core.data.def.DoubleCell) TreeEnsemblePredictorConfiguration(org.knime.base.node.mine.treeensemble.node.predictor.TreeEnsemblePredictorConfiguration) DataRow(org.knime.core.data.DataRow) IntCell(org.knime.core.data.def.IntCell) TreeEnsembleModelPortObject(org.knime.base.node.mine.treeensemble.model.TreeEnsembleModelPortObject) PredictorRecord(org.knime.base.node.mine.treeensemble.data.PredictorRecord) DataCell(org.knime.core.data.DataCell) FilterColumnRow(org.knime.base.data.filter.column.FilterColumnRow) TreeModelClassification(org.knime.base.node.mine.treeensemble.model.TreeModelClassification)

Example 85 with DataRow

Use of org.knime.core.data.DataRow in project knime-core by KNIME.

The class RegressionTreePredictorCellFactory, method getCells.

/**
 * Predicts the target value for {@code row}: the mean stored in the regression tree leaf that the
 * row's predictor record is routed to.
 *
 * @param row the input row; learn columns are selected via {@code m_learnColumnInRealDataIndices}
 * @return a one-element array holding the predicted mean as a {@link DoubleCell}, or a missing cell
 *         if no predictor record could be created for the row
 */
@Override
public DataCell[] getCells(final DataRow row) {
    final RegressionTreeModel treeModel = m_predictor.getModelObject().getModel();
    final DataRow learnRow = new FilterColumnRow(row, m_learnColumnInRealDataIndices);
    final PredictorRecord record = treeModel.createPredictorRecord(learnRow, m_learnSpec);
    if (record == null) {
        // a null record signals a missing value: no prediction possible
        return new DataCell[]{DataType.getMissingCell()};
    }
    final TreeNodeRegression leaf = treeModel.getTreeModel().findMatchingNode(record);
    return new DataCell[]{new DoubleCell(leaf.getMean())};
}
Also used : RegressionTreeModelPortObject(org.knime.base.node.mine.treeensemble.model.RegressionTreeModelPortObject) RegressionTreeModel(org.knime.base.node.mine.treeensemble.model.RegressionTreeModel) DoubleCell(org.knime.core.data.def.DoubleCell) PredictorRecord(org.knime.base.node.mine.treeensemble.data.PredictorRecord) DataCell(org.knime.core.data.DataCell) DataRow(org.knime.core.data.DataRow) TreeNodeRegression(org.knime.base.node.mine.treeensemble.model.TreeNodeRegression) FilterColumnRow(org.knime.base.data.filter.column.FilterColumnRow) TreeModelRegression(org.knime.base.node.mine.treeensemble.model.TreeModelRegression)

Aggregations

DataRow (org.knime.core.data.DataRow)482 DataCell (org.knime.core.data.DataCell)268 DataTableSpec (org.knime.core.data.DataTableSpec)159 BufferedDataTable (org.knime.core.node.BufferedDataTable)125 DataColumnSpec (org.knime.core.data.DataColumnSpec)109 RowKey (org.knime.core.data.RowKey)88 DefaultRow (org.knime.core.data.def.DefaultRow)88 BufferedDataContainer (org.knime.core.node.BufferedDataContainer)80 InvalidSettingsException (org.knime.core.node.InvalidSettingsException)76 ColumnRearranger (org.knime.core.data.container.ColumnRearranger)73 DoubleValue (org.knime.core.data.DoubleValue)72 ArrayList (java.util.ArrayList)65 DataColumnSpecCreator (org.knime.core.data.DataColumnSpecCreator)65 RowIterator (org.knime.core.data.RowIterator)62 DataType (org.knime.core.data.DataType)61 DoubleCell (org.knime.core.data.def.DoubleCell)57 StringCell (org.knime.core.data.def.StringCell)53 SingleCellFactory (org.knime.core.data.container.SingleCellFactory)48 ExecutionMonitor (org.knime.core.node.ExecutionMonitor)44 CanceledExecutionException (org.knime.core.node.CanceledExecutionException)43