Use of org.knime.base.util.kdtree.NearestNeighbour in project knime-core by knime.
Class KnnNodeModel, method classify.
/**
 * Classifies a single row by a (optionally distance-weighted) vote among its
 * nearest neighbours in the given tree.
 *
 * <p>The first element of the returned list is the winner class. If class
 * probabilities are enabled in the settings, one probability per entry of
 * {@code allClassValues} follows, in the same order. If
 * {@code createQueryVector} returns {@code null} for the row (presumably
 * because a feature value is unusable -- confirm against that helper), the
 * list consists of {@code 1 + allClassValues.length} missing cells.</p>
 *
 * <p>When weighting by distance, a neighbour at distance zero would receive
 * the weight {@code 1 / 0.0 = Infinity}, which turns the weight sum into
 * infinity and the class probabilities into {@code NaN}. To avoid that, as
 * soon as at least one neighbour is an exact match only the exact matches
 * vote, each with weight 1.</p>
 *
 * @param row the row to classify
 * @param tree the tree that is queried for the nearest neighbours; the data
 *            attached to each neighbour is used as its class value
 * @param featureColumns forwarded to {@code createQueryVector}
 * @param firstToSecond forwarded to {@code createQueryVector}
 * @param allClassValues the class values for which probabilities are reported
 * @return the winner class, followed by the class probabilities if enabled
 */
private List<DataCell> classify(final DataRow row, final KDTree<DataCell> tree, final List<Integer> featureColumns, final Map<Integer, Integer> firstToSecond, final DataCell[] allClassValues) {
    double[] features = createQueryVector(row, featureColumns, firstToSecond);
    List<DataCell> output = new ArrayList<DataCell>();
    if (features == null) {
        for (int i = 0; i < 1 + allClassValues.length; i++) {
            output.add(DataType.getMissingCell());
        }
        return output;
    }

    // never ask for more neighbours than the tree holds
    List<NearestNeighbour<DataCell>> nearestN = tree.getKNearestNeighbours(features, Math.min(m_settings.k(), tree.size()));

    // an exact match (distance 0) cannot be weighted by 1/distance; detect
    // that case up front so that only the exact matches are counted below
    final boolean weightByDistance = m_settings.weightByDistance();
    boolean exactMatchFound = false;
    if (weightByDistance) {
        for (NearestNeighbour<DataCell> n : nearestN) {
            if (n.getDistance() == 0) {
                exactMatchFound = true;
                break;
            }
        }
    }

    // insertion-ordered so that iteration follows the neighbour order
    Map<DataCell, MutableDouble> classWeights = new LinkedHashMap<DataCell, MutableDouble>();
    for (NearestNeighbour<DataCell> n : nearestN) {
        if (exactMatchFound && n.getDistance() != 0) {
            // exact matches outvote every other neighbour
            continue;
        }
        MutableDouble count = classWeights.get(n.getData());
        if (count == null) {
            count = new MutableDouble(0);
            classWeights.put(n.getData(), count);
        }
        if (weightByDistance && !exactMatchFound) {
            count.add(1 / n.getDistance());
        } else {
            count.inc();
        }
    }

    double winnerWeight = 0;
    double weightSum = 0;
    DataCell winnerCell = DataType.getMissingCell();
    for (Map.Entry<DataCell, MutableDouble> e : classWeights.entrySet()) {
        double weight = e.getValue().doubleValue();
        if (weight > winnerWeight) {
            winnerWeight = weight;
            winnerCell = e.getKey();
        }
        weightSum += weight;
    }

    // check if there are other classes with the same weight; such ties go to
    // the class with the larger count in m_classDistribution
    for (Map.Entry<DataCell, MutableDouble> e : classWeights.entrySet()) {
        double weight = e.getValue().doubleValue();
        if (weight == winnerWeight) {
            if (m_classDistribution.get(winnerCell).intValue() < m_classDistribution.get(e.getKey()).intValue()) {
                winnerCell = e.getKey();
            }
        }
    }

    output.add(winnerCell);
    if (m_settings.outputClassProbabilities()) {
        for (DataCell classVal : allClassValues) {
            MutableDouble v = classWeights.get(classVal);
            if (v == null) {
                // class did not occur among the counted neighbours
                output.add(new DoubleCell(0));
            } else {
                output.add(new DoubleCell(v.doubleValue() / weightSum));
            }
        }
    }
    return output;
}
Aggregations