Usage of org.knime.base.util.kdtree.KDTreeBuilder in the knime-core project by KNIME.
Class KnnNodeModel, method createRearranger.
/**
 * Creates a column rearranger that classifies the rows described by {@code inSpec2} using a kd-tree
 * built from the training data.
 * <p>
 * NOTE: This call possibly involves heavier calculations since the kd-tree is built here from the
 * training data. Reading the training rows reports up to 10% progress; building the tree uses a
 * sub-progress of 30%.
 * <p>
 * Rows of the training table whose feature vector cannot be created (missing values) are skipped
 * and a warning is set. The class counts of the remaining rows are accumulated into
 * {@code m_classDistribution} (used for breaking ties later).
 *
 * @param trainData the training table; the class column is taken from the node settings
 * @param inSpec2 the spec of the table to classify
 * @param exec the execution context used for cancellation checks and progress reporting
 * @param numRowsTable2 the number of rows of the table to classify; can be -1 if it can't be
 *            determined (streaming)
 * @return the column rearranger that performs the classification
 * @throws CanceledExecutionException if execution is canceled while reading the training data or
 *             building the tree
 * @throws InvalidSettingsException if the configured class column is not found in the training
 *             table (presumably also thrown by {@code checkInputTables} for incompatible inputs)
 */
private ColumnRearranger createRearranger(final BufferedDataTable trainData, final DataTableSpec inSpec2,
    final ExecutionContext exec, final long numRowsTable2)
    throws CanceledExecutionException, InvalidSettingsException {
    final DataTableSpec trainSpec = trainData.getDataTableSpec();
    final int classColIndex = trainSpec.findColumnIndex(m_settings.classColumn());
    if (classColIndex == -1) {
        throw new InvalidSettingsException("Invalid class column chosen.");
    }
    final List<Integer> featureColumns = new ArrayList<Integer>();
    final Map<Integer, Integer> firstToSecond = new HashMap<Integer, Integer>();
    checkInputTables(new DataTableSpec[]{trainSpec, inSpec2}, featureColumns, firstToSecond);

    final KDTreeBuilder<DataCell> treeBuilder = new KDTreeBuilder<DataCell>(featureColumns.size());
    final long rowCount = trainData.size();
    long count = 0;
    boolean missingValueWarningSet = false;
    for (DataRow currentRow : trainData) {
        exec.checkCanceled();
        // Fraction of training rows read so far, scaled into the first 10% of the progress.
        // rowCount > 0 whenever this loop body runs, so the division is safe.
        exec.setProgress(0.1 * count / rowCount, "Reading row " + currentRow.getKey());
        count++;
        double[] features = createFeatureVector(currentRow, featureColumns);
        if (features == null) {
            // Set the warning only once instead of re-setting the identical message per row.
            if (!missingValueWarningSet) {
                setWarningMessage("Input table contains missing values, the " + "affected rows are ignored.");
                missingValueWarningSet = true;
            }
        } else {
            DataCell thisClassCell = currentRow.getCell(classColIndex);
            // and finally add data
            treeBuilder.addPattern(features, thisClassCell);
            // compute the majority class for breaking possible ties later
            MutableInteger t = m_classDistribution.get(thisClassCell);
            if (t == null) {
                m_classDistribution.put(thisClassCell, new MutableInteger(1));
            } else {
                t.inc();
            }
        }
    }
    // and now use it to classify the test data...
    final DataColumnSpec classColumnSpec = trainSpec.getColumnSpec(classColIndex);
    exec.setMessage("Building kd-tree");
    final KDTree<DataCell> tree = treeBuilder.buildTree(exec.createSubProgress(0.3));
    if (tree.size() < m_settings.k()) {
        setWarningMessage("There are only " + tree.size() + " patterns in the input table, but "
            + m_settings.k() + " nearest neighbours were requested for classification."
            + " The prediction will be the majority class for all" + " input patterns.");
    }
    exec.setMessage("Classifying");
    return createRearranger(inSpec2, classColumnSpec, featureColumns, firstToSecond, tree, numRowsTable2);
}
Aggregations