Use of org.knime.base.util.HalfFloatMatrix in project knime-core by knime.
The class HierarchicalClusterNodeModel, method execute.
/**
 * {@inheritDoc}
 *
 * <p>
 * Clusters the rows of the first input table bottom-up: every row starts as its own cluster and, in each
 * iteration, the two clusters with the smallest linkage distance are merged until a single cluster is left.
 * The distance of every merge is recorded in a fusion table, and the result table is created as soon as the
 * number of remaining clusters equals the configured number of output clusters.
 *
 * @param data the input tables; only {@code data[0]} is read
 * @param exec execution context used for progress reporting, cancellation checks and table creation
 * @return a single-element array holding the result table
 * @throws RuntimeException if the input table has more than 65,500 rows
 * @throws IllegalStateException if no closest pair of clusters can be determined in an iteration (all pairwise
 *             distances are NaN or not smaller than {@link Float#MAX_VALUE})
 * @throws Exception anything propagated by the cancellation check or by the helper methods called here
 */
@Override
protected BufferedDataTable[] execute(final BufferedDataTable[] data, final ExecutionContext exec) throws Exception {
    final BufferedDataTable inputData = data[0];

    // determine the indices of the selected columns
    final List<String> includedCols = m_selectedColumns.getIncludeList();
    final int[] selectedColIndices = new int[includedCols.size()];
    for (int count = 0; count < selectedColIndices.length; count++) {
        selectedColIndices[count] = inputData.getDataTableSpec().findColumnIndex(includedCols.get(count));
    }

    if (inputData.size() > 65500) {
        throw new RuntimeException("At most 65,500 patterns can be clustered");
    }

    // Manhattan if explicitly selected, Euclidean otherwise
    if (DistanceFunction.Names.Manhattan.toString().equals(m_distFunctionName.getStringValue())) {
        m_distFunction = ManhattanDist.MANHATTEN_DISTANCE;
    } else {
        m_distFunction = EuclideanDist.EUCLIDEAN_DISTANCE;
    }

    // The linkage type and the requested number of clusters do not change during execution, so they are
    // resolved once here instead of once per cluster pair. Anything that is neither SINGLE nor AVERAGE
    // falls back to complete linkage.
    final String linkageName = m_linkageType.getStringValue();
    final boolean singleLinkage = Linkage.SINGLE.name().equals(linkageName);
    final boolean averageLinkage = Linkage.AVERAGE.name().equals(linkageName);
    final int numClustersForOutput = m_numClustersForOutput.getIntValue();

    // generate initial clustering
    // which means that every data point is one cluster
    final List<ClusterNode> clusters = initClusters(inputData, exec);

    // store the distance per each fusion step
    final DataContainer fusionCont = exec.createDataContainer(createFusionSpec());

    // optional distance cache; NaN marks entries that have not been filled yet
    final HalfFloatMatrix cache;
    if (m_cacheDistances.getBooleanValue()) {
        cache = new HalfFloatMatrix((int) inputData.size(), false);
        cache.fill(Float.NaN);
    } else {
        cache = null;
    }

    final double max = inputData.size();
    // the number of clusters at the beginning is equal to the number
    // of data rows (each row is a cluster)
    final int numberDataRows = clusters.size();
    DataTable outputData = null;
    int iterationStep = 0;

    while (clusters.size() > 1) {
        // checks if number clusters to generate output table is reached
        if (numClustersForOutput == clusters.size()) {
            outputData = createResultTable(inputData, clusters, exec);
        }
        iterationStep++;
        exec.setProgress(iterationStep / max, "Iteration " + iterationStep + ", " + clusters.size() + " clusters remaining");

        // sub-context covering the share of a single merge step; only its completion is reported
        final double availableProgress = 1.0 / numberDataRows;
        final ExecutionContext subexec = exec.createSubExecutionContext(availableProgress);

        // find the pair of clusters with the smallest linkage distance
        float currentSmallestDist = Float.MAX_VALUE;
        ClusterNode currentClosestCluster1 = null;
        ClusterNode currentClosestCluster2 = null;
        for (int i = 0; i < clusters.size(); i++) {
            exec.checkCanceled();
            final ClusterNode node1 = clusters.get(i);
            for (int j = i + 1; j < clusters.size(); j++) {
                final ClusterNode node2 = clusters.get(j);
                // single, average and complete linkage are supported
                final float dist;
                if (singleLinkage) {
                    dist = calculateSingleLinkageDist(node1, node2, cache, selectedColIndices);
                } else if (averageLinkage) {
                    dist = calculateAverageLinkageDist(node1, node2, cache, selectedColIndices);
                } else {
                    dist = calculateCompleteLinkageDist(node1, node2, cache, selectedColIndices);
                }
                if (dist < currentSmallestDist) {
                    currentClosestCluster1 = node1;
                    currentClosestCluster2 = node2;
                    currentSmallestDist = dist;
                }
            }
        }
        subexec.setProgress(1.0);

        // Without a pair nothing would be removed from the list below while a new node is still added,
        // so the loop would never shrink the cluster list. Fail with a clear message instead.
        if (currentClosestCluster1 == null || currentClosestCluster2 == null) {
            throw new IllegalStateException("Unable to determine the closest pair of clusters in iteration "
                + iterationStep + ": all pairwise distances are NaN or not smaller than " + Float.MAX_VALUE);
        }

        // make one cluster of the two closest
        final ClusterNode newNode =
            new ClusterNode(currentClosestCluster1, currentClosestCluster2, currentSmallestDist);
        clusters.remove(currentClosestCluster1);
        clusters.remove(currentClosestCluster2);
        clusters.add(newNode);

        // store the distance per each fusion step
        fusionCont.addRowToTable(new DefaultRow(// row key
            Integer.toString(clusters.size()), // x-axis scatter plotter
            new IntCell(clusters.size()), // y-axis scatter plotter
            new DoubleCell(newNode.getDist())));
    }

    // the single remaining cluster (if any) is the root of the cluster hierarchy
    if (!clusters.isEmpty()) {
        m_rootNode = clusters.get(0);
    }
    fusionCont.close();

    // No result table was created inside the loop: either the loop never ran (fewer than two rows) or the
    // requested number of clusters was never equal to the number of remaining clusters. Build the table
    // from whatever clusters are left.
    if (outputData == null) {
        outputData = createResultTable(inputData, clusters, exec);
    }
    m_dataArray = new DefaultDataArray(inputData, 1, (int) inputData.size());
    m_fusionTable = new DefaultDataArray(fusionCont.getTable(), 1, iterationStep);
    return new BufferedDataTable[] { exec.createBufferedDataTable(outputData, exec) };
}
Aggregations