Search in sources :

Example 1 with WardLinkage

use of smile.clustering.linkage.WardLinkage in project smile by haifengl.

Source: method `partition` of class `BIRCH`.

/**
 * Assigns every leaf of the CF tree to one of {@code k} clusters.
 * <p>
 * The tree is walked breadth-first starting at {@code root}. A leaf holding
 * fewer than {@code minPts} points is labeled {@code OUTLIER}. Every other
 * leaf contributes its mean point ({@code sum[j] / n}) as a candidate center,
 * and the collected centers are stored in {@code centroids}. When there are
 * more candidate centers than {@code k}, they are grouped by Ward-linkage
 * hierarchical clustering over their pairwise {@code Math.distance}.
 * Otherwise each kept leaf simply receives its own label {@code 0..n-1}.
 *
 * @param k the number of clusters.
 * @param minPts a CF leaf will be treated as outlier if the number of its
 * points is less than minPts.
 * @return the number of non-outlier leaves.
 */
public int partition(int k, int minPts) {
    ArrayList<Leaf> keptLeaves = new ArrayList<>();
    ArrayList<double[]> means = new ArrayList<>();

    // Breadth-first walk over the tree; a null from poll() ends the traversal.
    Queue<Node> pending = new LinkedList<>();
    pending.offer(root);
    Node current;
    while ((current = pending.poll()) != null) {
        if (current.numChildren != 0) {
            // Internal node: schedule its children and keep walking.
            for (int c = 0; c < current.numChildren; c++) {
                pending.offer(current.children[c]);
            }
            continue;
        }

        if (current.n < minPts) {
            // Too sparse to form a cluster seed.
            ((Leaf) current).y = OUTLIER;
            continue;
        }

        double[] mean = new double[d];
        for (int j = 0; j < d; j++) {
            mean[j] = current.sum[j] / current.n;
        }
        means.add(mean);
        keptLeaves.add((Leaf) current);
    }

    int count = means.size();
    centroids = means.toArray(new double[count][]);

    if (count <= k) {
        // No more leaves than requested clusters: each leaf is its own cluster.
        for (int label = 0; label < count; label++) {
            keptLeaves.get(label).y = label;
        }
        return count;
    }

    // Lower-triangular pairwise distance matrix between the leaf means
    // (the diagonal entry of each row is left at 0).
    double[][] proximity = new double[count][];
    for (int row = 0; row < count; row++) {
        double[] distances = new double[row + 1];
        for (int col = 0; col < row; col++) {
            distances[col] = Math.distance(centroids[row], centroids[col]);
        }
        proximity[row] = distances;
    }

    int[] labels = new HierarchicalClustering(new WardLinkage(proximity)).partition(k);
    for (int idx = 0; idx < count; idx++) {
        keptLeaves.get(idx).y = labels[idx];
    }
    return count;
}
Also used: Linkage (smile.clustering.linkage.Linkage), WardLinkage (smile.clustering.linkage.WardLinkage), ArrayList (java.util.ArrayList), LinkedList (java.util.LinkedList)

Example 2 with WardLinkage

use of smile.clustering.linkage.WardLinkage in project smile by haifengl.

Source: method `testUSPS` of class `HierarchicalClusteringTest`.

/**
 * Clusters the USPS training set into 10 groups with each linkage supported
 * by {@link HierarchicalClustering} (single, complete, UPGMA, WPGMA, UPGMC,
 * WPGMC and Ward). Each partition is compared against the true class labels
 * using the Rand index, and for Ward linkage also the adjusted Rand index.
 * <p>
 * Any exception raised while loading the data or clustering fails the test.
 * It is not merely printed, so a missing data file cannot make the test pass
 * silently.
 */
@Test
public void testUSPS() {
    System.out.println("USPS");
    DelimitedTextParser parser = new DelimitedTextParser();
    // Column 0 is the nominal "class" response.
    parser.setResponseIndex(new NominalAttribute("class"), 0);
    try {
        AttributeDataset train = parser.parse("USPS Train", smile.data.parser.IOUtils.getTestDataFile("usps/zip.train"));
        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);
        int n = x.length;
        // Lower-triangular pairwise distance matrix shared by all linkages.
        double[][] proximity = new double[n][];
        for (int i = 0; i < n; i++) {
            proximity[i] = new double[i + 1];
            for (int j = 0; j < i; j++) {
                proximity[i][j] = Math.distance(x[i], x[j]);
            }
        }
        AdjustedRandIndex ari = new AdjustedRandIndex();
        RandIndex rand = new RandIndex();
        HierarchicalClustering hc = new HierarchicalClustering(new SingleLinkage(proximity));
        int[] label = hc.partition(10);
        double r = rand.measure(y, label);
        double r2 = ari.measure(y, label);
        System.out.format("SingleLinkage rand index = %.2f%%\tadjusted rand index = %.2f%%%n", 100.0 * r, 100.0 * r2);
        assertTrue(r > 0.1);
        hc = new HierarchicalClustering(new CompleteLinkage(proximity));
        label = hc.partition(10);
        r = rand.measure(y, label);
        r2 = ari.measure(y, label);
        System.out.format("CompleteLinkage rand index = %.2f%%\tadjusted rand index = %.2f%%%n", 100.0 * r, 100.0 * r2);
        assertTrue(r > 0.75);
        hc = new HierarchicalClustering(new UPGMALinkage(proximity));
        label = hc.partition(10);
        r = rand.measure(y, label);
        r2 = ari.measure(y, label);
        System.out.format("UPGMA rand index = %.2f%%\tadjusted rand index = %.2f%%%n", 100.0 * r, 100.0 * r2);
        assertTrue(r > 0.1);
        hc = new HierarchicalClustering(new WPGMALinkage(proximity));
        label = hc.partition(10);
        r = rand.measure(y, label);
        r2 = ari.measure(y, label);
        System.out.format("WPGMA rand index = %.2f%%\tadjusted rand index = %.2f%%%n", 100.0 * r, 100.0 * r2);
        assertTrue(r > 0.2);
        hc = new HierarchicalClustering(new UPGMCLinkage(proximity));
        label = hc.partition(10);
        r = rand.measure(y, label);
        r2 = ari.measure(y, label);
        System.out.format("UPGMC rand index = %.2f%%\tadjusted rand index = %.2f%%%n", 100.0 * r, 100.0 * r2);
        assertTrue(r > 0.1);
        hc = new HierarchicalClustering(new WPGMCLinkage(proximity));
        label = hc.partition(10);
        r = rand.measure(y, label);
        r2 = ari.measure(y, label);
        System.out.format("WPGMC rand index = %.2f%%\tadjusted rand index = %.2f%%%n", 100.0 * r, 100.0 * r2);
        assertTrue(r > 0.1);
        hc = new HierarchicalClustering(new WardLinkage(proximity));
        label = hc.partition(10);
        r = rand.measure(y, label);
        r2 = ari.measure(y, label);
        System.out.format("Ward rand index = %.2f%%\tadjusted rand index = %.2f%%%n", 100.0 * r, 100.0 * r2);
        assertTrue(r > 0.9);
        assertTrue(r2 > 0.5);
    } catch (Exception ex) {
        // Fail loudly: printing and continuing would report a passing test even
        // though nothing was verified.
        throw new AssertionError("USPS hierarchical clustering test failed: " + ex, ex);
    }
}
Also used: DelimitedTextParser (smile.data.parser.DelimitedTextParser), WPGMCLinkage (smile.clustering.linkage.WPGMCLinkage), AttributeDataset (smile.data.AttributeDataset), CompleteLinkage (smile.clustering.linkage.CompleteLinkage), AdjustedRandIndex (smile.validation.AdjustedRandIndex), RandIndex (smile.validation.RandIndex), WardLinkage (smile.clustering.linkage.WardLinkage), NominalAttribute (smile.data.NominalAttribute), SingleLinkage (smile.clustering.linkage.SingleLinkage), WPGMALinkage (smile.clustering.linkage.WPGMALinkage), UPGMALinkage (smile.clustering.linkage.UPGMALinkage), UPGMCLinkage (smile.clustering.linkage.UPGMCLinkage), Test (org.junit.Test)

Example 3 with WardLinkage

use of smile.clustering.linkage.WardLinkage in project smile by haifengl.

Source: method `learn` of class `HierarchicalClusteringDemo`.

/**
 * Runs agglomerative clustering on the currently selected dataset, using the
 * linkage chosen in {@code linkageBox}, and cuts the tree into
 * {@code clusterNumber} clusters.
 * <p>
 * The returned panel holds a scatter plot of the data, with each cluster
 * overlaid in its own palette color, next to the dendrogram.
 *
 * @return a panel containing the scatter plot and the dendrogram.
 * @throws IllegalStateException if the selected linkage index is not 0-6.
 */
@Override
public JComponent learn() {
    long start = System.currentTimeMillis();
    double[][] points = dataset[datasetIndex];
    int size = points.length;

    // Lower-triangular pairwise distance matrix expected by the linkages.
    double[][] proximity = new double[size][];
    for (int row = 0; row < size; row++) {
        double[] distances = new double[row + 1];
        for (int col = 0; col < row; col++) {
            distances[col] = Math.distance(points[row], points[col]);
        }
        proximity[row] = distances;
    }

    HierarchicalClustering hac;
    switch (linkageBox.getSelectedIndex()) {
        case 0: hac = new HierarchicalClustering(new SingleLinkage(proximity)); break;
        case 1: hac = new HierarchicalClustering(new CompleteLinkage(proximity)); break;
        case 2: hac = new HierarchicalClustering(new UPGMALinkage(proximity)); break;
        case 3: hac = new HierarchicalClustering(new WPGMALinkage(proximity)); break;
        case 4: hac = new HierarchicalClustering(new UPGMCLinkage(proximity)); break;
        case 5: hac = new HierarchicalClustering(new WPGMCLinkage(proximity)); break;
        case 6: hac = new HierarchicalClustering(new WardLinkage(proximity)); break;
        default: throw new IllegalStateException("Unsupported Linkage");
    }
    System.out.format("Hierarchical clusterings %d samples in %dms\n", size, System.currentTimeMillis() - start);

    int[] membership = hac.partition(clusterNumber);

    // Bucket the points by cluster label, keeping their original order.
    int[] clusterSize = new int[clusterNumber];
    for (int label : membership) {
        clusterSize[label]++;
    }
    double[][][] clusters = new double[clusterNumber][][];
    for (int c = 0; c < clusterNumber; c++) {
        clusters[c] = new double[clusterSize[c]][];
    }
    int[] cursor = new int[clusterNumber];
    for (int i = 0; i < size; i++) {
        int label = membership[i];
        clusters[label][cursor[label]++] = points[i];
    }

    JPanel pane = new JPanel(new GridLayout(1, 3));
    PlotCanvas plot = ScatterPlot.plot(points, pointLegend);
    plot.setTitle("Data");
    pane.add(plot);
    for (int c = 0; c < clusterNumber; c++) {
        plot.points(clusters[c], pointLegend, Palette.COLORS[c % Palette.COLORS.length]);
    }

    PlotCanvas dendrogram = Dendrogram.plot("Dendrogram", hac.getTree(), hac.getHeight());
    dendrogram.setTitle("Dendrogram");
    pane.add(dendrogram);
    return pane;
}
Also used : WPGMCLinkage(smile.clustering.linkage.WPGMCLinkage) JPanel(javax.swing.JPanel) CompleteLinkage(smile.clustering.linkage.CompleteLinkage) WardLinkage(smile.clustering.linkage.WardLinkage) HierarchicalClustering(smile.clustering.HierarchicalClustering) GridLayout(java.awt.GridLayout) SingleLinkage(smile.clustering.linkage.SingleLinkage) WPGMALinkage(smile.clustering.linkage.WPGMALinkage) UPGMALinkage(smile.clustering.linkage.UPGMALinkage) UPGMCLinkage(smile.clustering.linkage.UPGMCLinkage) PlotCanvas(smile.plot.PlotCanvas)

Aggregations

WardLinkage (smile.clustering.linkage.WardLinkage): 3; CompleteLinkage (smile.clustering.linkage.CompleteLinkage): 2; SingleLinkage (smile.clustering.linkage.SingleLinkage): 2; UPGMALinkage (smile.clustering.linkage.UPGMALinkage): 2; UPGMCLinkage (smile.clustering.linkage.UPGMCLinkage): 2; WPGMALinkage (smile.clustering.linkage.WPGMALinkage): 2; WPGMCLinkage (smile.clustering.linkage.WPGMCLinkage): 2; GridLayout (java.awt.GridLayout): 1; ArrayList (java.util.ArrayList): 1; LinkedList (java.util.LinkedList): 1; JPanel (javax.swing.JPanel): 1; Test (org.junit.Test): 1; HierarchicalClustering (smile.clustering.HierarchicalClustering): 1; Linkage (smile.clustering.linkage.Linkage): 1; AttributeDataset (smile.data.AttributeDataset): 1; NominalAttribute (smile.data.NominalAttribute): 1; DelimitedTextParser (smile.data.parser.DelimitedTextParser): 1; PlotCanvas (smile.plot.PlotCanvas): 1; AdjustedRandIndex (smile.validation.AdjustedRandIndex): 1; RandIndex (smile.validation.RandIndex): 1