use of org.knime.core.data.DataRow in project knime-core by knime.
the class TreeDataCreator method readData.
/**
 * Reads the data from <b>learnData</b> into memory.
 * Each column is represented by a TreeColumnData object corresponding to its type
 * and whether it is an attribute or target column.
 * <p>
 * The table is sorted by the target column first (to enable equal size sampling) and is then read row by
 * row. A row is skipped if its target cell is missing or if it has a missing cell in an attribute column
 * whose creator does not accept missing values; skipped rows are counted and reported through the warning
 * message. The first accepted rows, up to the configured number of hilite patterns, are also added to the
 * hilite container.
 *
 * @param learnData the learn table; attribute cells are read from the first
 *            {@code m_attrColCreators.length} cell positions and the target cell from the position
 *            directly after them
 * @param configuration the configuration passed on to the attribute column creators when the column data
 *            objects are created
 * @param exec monitor used for progress and cancellation; one half is given to sorting, the other half to
 *            reading
 * @return the TreeData object that holds all data in memory
 * @throws CanceledExecutionException if the execution is canceled while reading the rows
 *             (NOTE(review): presumably also thrown by the sorter - confirm)
 * @throws IllegalArgumentException if the table has fewer than two rows
 *             (NOTE(review): presumably also thrown by the final checkArgument when every row was
 *             rejected - confirm)
 */
public TreeData readData(final BufferedDataTable learnData, final TreeEnsembleLearnerConfiguration configuration, final ExecutionMonitor exec) throws CanceledExecutionException {
    final long nrRows = learnData.size();
    if (nrRows <= 1) {
        throw new IllegalArgumentException("The input table must contain at least 2 rows!");
    }
    final int nrLearnCols = m_attrColCreators.length;
    final boolean[] supportMissings = new boolean[nrLearnCols];
    for (int col = 0; col < nrLearnCols; col++) {
        supportMissings[col] = m_attrColCreators[col].acceptsMissing();
    }
    final int nrHilitePatterns = m_configuration.getNrHilitePatterns();

    // sort learnData according to the target column to enable equal size sampling
    final int targetColIdx = learnData.getDataTableSpec().findColumnIndex(m_configuration.getTargetColumn());
    final Comparator<DataCell> targetComp = learnData.getDataTableSpec().getColumnSpec(targetColIdx).getType().getComparator();
    final Comparator<DataRow> byTarget =
        (left, right) -> targetComp.compare(left.getCell(targetColIdx), right.getCell(targetColIdx));
    final DataTableSorter sorter = new DataTableSorter(learnData, nrRows, byTarget);
    final ExecutionMonitor sortExec = exec.createSubProgress(0.5);
    final DataTable sortedTable = sorter.sort(sortExec);
    final ExecutionMonitor readExec = exec.createSubProgress(0.5);

    // acceptedRows only counts rows that made it into the column creators
    int acceptedRows = 0;
    int rejectedMissings = 0;
    for (final DataRow row : sortedTable) {
        final RowKey key = row.getKey();
        readExec.setProgress(acceptedRows / (double) nrRows, "Row " + acceptedRows + " of " + nrRows + " (\"" + key + "\")");
        readExec.checkCanceled();

        // stop scanning the attributes at the first missing value the creator cannot handle
        boolean hasUnsupportedMissing = false;
        for (int col = 0; col < nrLearnCols && !hasUnsupportedMissing; col++) {
            hasUnsupportedMissing = row.getCell(col).isMissing() && !supportMissings[col];
        }
        final DataCell targetCell = row.getCell(nrLearnCols);
        if (targetCell.isMissing() || hasUnsupportedMissing) {
            rejectedMissings++;
            continue;
        }

        if (acceptedRows < nrHilitePatterns) {
            m_dataRowsForHiliteContainer.addRowToTable(row);
        }
        for (int col = 0; col < nrLearnCols; col++) {
            m_attrColCreators[col].add(key, row.getCell(col));
        }
        m_targetColCreator.add(key, targetCell);
        acceptedRows++;
    }

    if (nrHilitePatterns > 0 && acceptedRows > nrHilitePatterns) {
        m_viewMessage = "Hilite (& color graphs) are based on a subset of " + "the data (" + nrHilitePatterns + "/" + acceptedRows + ")";
    }
    if (rejectedMissings > 0) {
        m_warningMessage = rejectedMissings + "/" + nrRows + " row(s) were ignored because they " + "contain missing values.";
    }
    CheckUtils.checkArgument(rejectedMissings < nrRows, "No rows left after removing missing values (table has %d row(s))", nrRows);

    // a single creator may contribute several attributes, so count them before allocating
    int nrLearnAttributes = 0;
    for (final TreeAttributeColumnDataCreator creator : m_attrColCreators) {
        nrLearnAttributes += creator.getNrAttributes();
    }
    final TreeAttributeColumnData[] columns = new TreeAttributeColumnData[nrLearnAttributes];
    int nextAttribute = 0;
    for (final TreeAttributeColumnDataCreator creator : m_attrColCreators) {
        for (int a = 0; a < creator.getNrAttributes(); a++) {
            final TreeAttributeColumnData columnData = creator.createColumnData(a, configuration);
            columnData.getMetaData().setAttributeIndex(nextAttribute);
            columns[nextAttribute] = columnData;
            nextAttribute++;
        }
    }
    return new TreeData(columns, m_targetColCreator.createColumnData(), m_treeType);
}
use of org.knime.core.data.DataRow in project knime-core by knime.
the class RandomForestDistance method computeDistance.
/**
 * {@inheritDoc}
 * <p>
 * Both rows are reduced to the columns returned by {@code getColumnIndices()} and converted into
 * predictor records. The proximity is the fraction of trees in the ensemble for which the matching nodes
 * of both records have equal signatures; the returned distance is {@code 1 - proximity}. If the ensemble
 * contains no models the division by zero yields {@code NaN}.
 *
 * @throws DistanceMeasurementException if no predictor record could be created for one of the rows
 */
@Override
public double computeDistance(final DataRow row1, final DataRow row2) throws DistanceMeasurementException {
    final List<Integer> filterIndicesList = getColumnIndices();
    final int[] filterIndices = new int[filterIndicesList.size()];
    int i = 0;
    for (Integer index : filterIndicesList) {
        filterIndices[i++] = index;
    }
    final DataRow filterRow1 = new FilterColumnRow(row1, filterIndices);
    final DataRow filterRow2 = new FilterColumnRow(row2, filterIndices);
    final PredictorRecord record1 = m_ensembleModel.createPredictorRecord(filterRow1, m_learnTableSpec);
    final PredictorRecord record2 = m_ensembleModel.createPredictorRecord(filterRow2, m_learnTableSpec);
    // NOTE(review): createPredictorRecord appears to return null for rows it cannot convert
    // (e.g. rows with missing values) - confirm. Fail with the declared checked exception instead
    // of passing a null record on to the trees.
    if (record1 == null || record2 == null) {
        final DataRow offendingRow = record1 == null ? row1 : row2;
        throw new DistanceMeasurementException("Unable to compute the distance: no predictor record could be"
            + " created for row \"" + offendingRow.getKey() + "\" (possibly due to missing values).");
    }
    final int nrModels = m_ensembleModel.getNrModels();
    double proximity = 0.0;
    for (int t = 0; t < nrModels; t++) {
        final AbstractTreeModel<?> tree = m_ensembleModel.getTreeModel(t);
        final AbstractTreeNode leaf1 = tree.findMatchingNode(record1);
        final AbstractTreeNode leaf2 = tree.findMatchingNode(record2);
        // the two rows agree in this tree if they are matched to nodes with the same signature
        if (leaf1.getSignature().equals(leaf2.getSignature())) {
            proximity += 1.0;
        }
    }
    proximity /= nrModels;
    // to get a distance measure, we have to subtract the proximity from 1
    return 1 - proximity;
}
use of org.knime.core.data.DataRow in project knime-core by knime.
the class JoinerTest method compareTables.
/**
 * Asserts that {@code test} matches {@code reference}: both tables must have the same number of rows,
 * and rows at the same position must have equal row keys, the same number of cells and equal cells at
 * every position.
 *
 * @param reference the table holding the expected content
 * @param test the table to verify
 */
private void compareTables(final BufferedDataTable reference, final BufferedDataTable test) {
    assertThat("Unequal number of rows in result table", test.getRowCount(), is(reference.getRowCount()));
    RowIterator referenceIter = reference.iterator();
    RowIterator testIter = test.iterator();
    // the row counts were asserted to be equal, so both iterators advance in lock step
    while (referenceIter.hasNext()) {
        DataRow refRow = referenceIter.next();
        DataRow testRow = testIter.next();
        assertThat("Unexpected row key", testRow.getKey(), is(refRow.getKey()));
        Iterator<DataCell> refCell = refRow.iterator();
        Iterator<DataCell> testCell = testRow.iterator();
        while (refCell.hasNext()) {
            // fail with an assertion (not a NoSuchElementException) if the test row is too short
            assertThat("Missing cell in row " + refRow.getKey(), testCell.hasNext(), is(true));
            assertThat("Unexpected cell in row " + refRow.getKey(), testCell.next(), is(refCell.next()));
        }
        // cells beyond the width of the reference row would otherwise go unnoticed
        assertThat("Additional cell in row " + refRow.getKey(), testCell.hasNext(), is(false));
    }
}
use of org.knime.core.data.DataRow in project knime-core by knime.
the class TreeEnsembleClassificationPredictorCellFactory method getCells.
/**
 * {@inheritDoc}
 * <p>
 * The result always starts with the predicted class, i.e. the majority class name that was voted for
 * most often by the trees of the ensemble. Depending on the predictor configuration it is followed by
 * the prediction confidence, one confidence per entry of the target value map and the number of models
 * that took part in the vote. Models are left out of the vote when an out-of-bag filter is present and
 * the row was part of their training data. If no predictor record can be created for the row, or no
 * model voted, the prediction cells are missing cells.
 */
@Override
public DataCell[] getCells(final DataRow row) {
    final TreeEnsembleModelPortObject modelObject = m_predictor.getModelObject();
    final TreeEnsemblePredictorConfiguration cfg = m_predictor.getConfiguration();
    final TreeEnsembleModel ensembleModel = modelObject.getEnsembleModel();

    final boolean appendConfidence = cfg.isAppendPredictionConfidence();
    final boolean appendClassConfidences = cfg.isAppendClassConfidences();
    final boolean appendModelCount = cfg.isAppendModelCount();
    // one cell for the prediction itself plus whatever optional output is enabled
    final int size = 1
        + (appendConfidence ? 1 : 0)
        + (appendClassConfidences ? m_targetValueMap.size() : 0)
        + (appendModelCount ? 1 : 0);

    final boolean hasOutOfBagFilter = m_predictor.hasOutOfBagFilter();
    final DataCell[] result = new DataCell[size];
    final DataRow filterRow = new FilterColumnRow(row, m_learnColumnInRealDataIndices);
    final PredictorRecord record = ensembleModel.createPredictorRecord(filterRow, m_learnSpec);
    if (record == null) {
        // missing value
        Arrays.fill(result, DataType.getMissingCell());
        return result;
    }

    final OccurrenceCounter<String> votes = new OccurrenceCounter<>();
    final int nrModels = ensembleModel.getNrModels();
    int nrValidModels = 0;
    for (int model = 0; model < nrModels; model++) {
        if (hasOutOfBagFilter && m_predictor.isRowPartOfTrainingData(row.getKey(), model)) {
            // row was used to train this model, so the model must not vote
            continue;
        }
        final TreeNodeClassification match = ensembleModel.getTreeModelClassification(model).findMatchingNode(record);
        votes.add(match.getMajorityClassName());
        nrValidModels++;
    }

    final String bestValue = votes.getMostFrequent();
    int index;
    if (bestValue != null) {
        index = 0;
        result[index++] = m_targetValueMap.get(bestValue);
        if (appendConfidence) {
            result[index++] = new DoubleCell(votes.getFrequency(bestValue) / (double) nrValidModels);
        }
        if (appendClassConfidences) {
            for (final String className : m_targetValueMap.keySet()) {
                result[index++] = new DoubleCell(votes.getFrequency(className) / (double) nrValidModels);
            }
        }
    } else {
        assert nrValidModels == 0;
        Arrays.fill(result, DataType.getMissingCell());
        // only the model count (always the last cell) may still be written below
        index = size - 1;
    }
    if (appendModelCount) {
        result[index] = new IntCell(nrValidModels);
    }
    return result;
}
use of org.knime.core.data.DataRow in project knime-core by knime.
the class RegressionTreePredictorCellFactory method getCells.
/**
 * {@inheritDoc}
 * <p>
 * Returns a single cell: the mean of the regression tree node that matches the row, or a missing cell
 * if no predictor record can be created for the row.
 */
@Override
public DataCell[] getCells(final DataRow row) {
    final RegressionTreeModel treeModel = m_predictor.getModelObject().getModel();
    final DataRow filterRow = new FilterColumnRow(row, m_learnColumnInRealDataIndices);
    final PredictorRecord record = treeModel.createPredictorRecord(filterRow, m_learnSpec);
    if (record == null) {
        // missing value
        return new DataCell[]{DataType.getMissingCell()};
    }
    final TreeNodeRegression match = treeModel.getTreeModel().findMatchingNode(record);
    return new DataCell[]{new DoubleCell(match.getMean())};
}
Aggregations