Usage of smile.vq.SOM in the smile project (by haifengl).
Class SOMDemo, method learn.
/**
 * Trains a self-organizing map on the currently selected dataset and returns a panel that
 * visualizes the result: the SOM grid drawn over the data, the data colored by the
 * {@code clusterNumber}-way partition of the map, a histogram of the U-matrix values with a
 * Gaussian-mixture density overlay, and the U-matrix drawn as a hexmap.
 *
 * <p>The map dimensions are read from {@code widthField} and {@code heightField}. If either is
 * not a positive integer, an error dialog is shown and {@code null} is returned; the
 * corresponding field is then left unchanged.
 *
 * @return the visualization panel, or {@code null} if a grid dimension is invalid
 */
@Override
public JComponent learn() {
    int parsedWidth = parseGridDimension(widthField.getText(), "width");
    if (parsedWidth < 1) {
        return null;
    }
    width = parsedWidth;

    int parsedHeight = parseGridDimension(heightField.getText(), "height");
    if (parsedHeight < 1) {
        return null;
    }
    height = parsedHeight;

    double[][] data = dataset[datasetIndex];

    long clock = System.currentTimeMillis();
    SOM som = new SOM(data, width, height);
    System.out.format("SOM clusterings %d samples in %dms%n", data.length, System.currentTimeMillis() - clock);

    JPanel pane = new JPanel(new GridLayout(2, 3));

    // The trained map drawn as a grid over the data points.
    PlotCanvas plot = ScatterPlot.plot(data, pointLegend);
    plot.grid(som.map());
    plot.setTitle("SOM Grid");
    pane.add(plot);

    // The data colored by the cluster index that som.partition assigns to each sample.
    int[] membership = som.partition(clusterNumber);
    int[] clusterSize = new int[clusterNumber];
    for (int c : membership) {
        clusterSize[c]++;
    }
    plot = ScatterPlot.plot(data, pointLegend);
    plot.setTitle("Hierarchical Clustering");
    for (int k = 0; k < clusterNumber; k++) {
        double[][] cluster = new double[clusterSize[k]][];
        for (int i = 0, j = 0; i < data.length; i++) {
            if (membership[i] == k) {
                cluster[j++] = data[i];
            }
        }
        plot.points(cluster, pointLegend, Palette.COLORS[k % Palette.COLORS.length]);
    }
    pane.add(plot);

    // Histogram of all U-matrix values, flattened row by row into one array.
    double[][] umatrix = som.umatrix();
    double[] umatrix1 = new double[umatrix.length * umatrix[0].length];
    int offset = 0;
    for (double[] row : umatrix) {
        System.arraycopy(row, 0, umatrix1, offset, row.length);
        offset += row.length;
    }
    plot = Histogram.plot(null, umatrix1, 20);
    plot.setTitle("U-Matrix Histogram");
    pane.add(plot);

    // Overlay a Gaussian-mixture density fitted to the U-matrix values, sampled every `step`
    // from the minimum value and multiplied by `step` (presumably to approximate per-bin
    // heights of the histogram).
    // NOTE(review): the step is range/24 while the histogram uses 20 bins, and the 50 samples
    // reach about twice the data range; confirm these constants are intended.
    GaussianMixture mixture = new GaussianMixture(umatrix1);
    double min = Math.min(umatrix1);
    double step = (Math.max(umatrix1) - min) / 24;
    double[][] p = new double[50][2];
    for (int i = 0; i < p.length; i++) {
        p[i][0] = min + i * step;
        p[i][1] = mixture.p(p[i][0]) * step;
    }
    plot.line(p, Color.RED);

    // The U-matrix itself, drawn as a hexagonal map.
    plot = Hexmap.plot(umatrix, Palette.jet(256));
    plot.setTitle("U-Matrix");
    pane.add(plot);

    return pane;
}

/**
 * Parses a grid dimension typed by the user. If the text is not an integer or the value is
 * less than 1, an error dialog is shown and -1 is returned.
 *
 * @param text  the raw text of the input field (may be {@code null})
 * @param label the dimension name used in the error message, e.g. {@code "width"}
 * @return the parsed value (at least 1), or -1 if the input was invalid
 */
private int parseGridDimension(String text, String label) {
    int value;
    try {
        value = Integer.parseInt(text.trim());
    } catch (Exception e) {
        // Deliberately broad: null text and malformed numbers are both reported to the user.
        JOptionPane.showMessageDialog(this, "Invalid " + label + ": " + text, "Error", JOptionPane.ERROR_MESSAGE);
        return -1;
    }
    if (value < 1) {
        JOptionPane.showMessageDialog(this, "Invalid " + label + ": " + value, "Error", JOptionPane.ERROR_MESSAGE);
        return -1;
    }
    return value;
}
Usage of smile.vq.SOM in the smile project (by haifengl).
Class HexmapDemo, method main.
/**
 * Trains a 20 x 20 self-organizing map on the USPS training data ({@code usps/zip.train}) and
 * opens a window with four plots:
 * <ul>
 *   <li>a hexmap of the U-matrix, labelled per neuron with the class distribution of the
 *       samples that {@code som.predict} maps to it;</li>
 *   <li>surface plots of the neurons' weight vectors embedded by MDS, isotonic MDS and Sammon
 *       mapping.</li>
 * </ul>
 * Any exception is printed to standard error.
 *
 * @param args ignored
 */
public static void main(String[] args) {
    DelimitedTextParser parser = new DelimitedTextParser();
    parser.setResponseIndex(new NominalAttribute("class"), 0);
    try {
        AttributeDataset train = parser.parse("USPS Train", smile.data.parser.IOUtils.getTestDataFile("usps/zip.train"));
        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);

        int m = 20;
        int n = 20;
        SOM som = new SOM(x, m, n);

        String[][] labels = neuronLabels(som, x, y, m, n);

        double[][] umatrix = som.umatrix();
        double[][][] map = som.map();

        // Pairwise Math.distance between neuron weight vectors; neuron i is map[i / n][i % n].
        // The distance is symmetric and zero on the diagonal, so only the upper triangle is
        // computed and mirrored (the array starts zero-filled).
        double[][] proximity = new double[m * n][m * n];
        for (int i = 0; i < m * n; i++) {
            for (int j = i + 1; j < m * n; j++) {
                double d = Math.distance(map[i / n][i % n], map[j / n][j % n]);
                proximity[i][j] = d;
                proximity[j][i] = d;
            }
        }

        MDS mds = new MDS(proximity, 3);
        double[][] coords = mds.getCoordinates();
        double[][][] mdsgrid = reshapeToGrid(mds.getCoordinates(), m, n);

        // Both refinements start from the MDS coordinates.
        SammonMapping sammon = new SammonMapping(proximity, coords);
        double[][][] sammongrid = reshapeToGrid(sammon.getCoordinates(), m, n);

        IsotonicMDS isomds = new IsotonicMDS(proximity, coords);
        double[][][] isomdsgrid = reshapeToGrid(isomds.getCoordinates(), m, n);

        JFrame frame = new JFrame("Hexmap");
        frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        // JFrame's default BorderLayout lays out only the last component added to CENTER, so
        // the four plots are arranged in a 2 x 2 grid instead.
        frame.setLayout(new java.awt.GridLayout(2, 2));
        frame.add(Hexmap.plot(labels, umatrix));

        PlotCanvas canvas = Surface.plot(mdsgrid);
        canvas.setTitle("MDS");
        frame.add(canvas);

        canvas = Surface.plot(isomdsgrid);
        canvas.setTitle("Isotonic MDS");
        frame.add(canvas);

        canvas = Surface.plot(sammongrid);
        canvas.setTitle("Sammon Mapping");
        frame.add(canvas);

        // Size before centering: setLocationRelativeTo uses the current size, and an unsized
        // frame would open at its minimal size.
        frame.setSize(1000, 1000);
        frame.setLocationRelativeTo(null);
        frame.setVisible(true);
    } catch (Exception ex) {
        System.err.println(ex);
    }
}

/**
 * Builds an HTML label for every neuron of an m x n map. Each label summarizes the classes of
 * the samples that {@code som.predict} assigns to that neuron (index {@code i * n + j}); a
 * neuron with no samples is labelled "no samples".
 *
 * @param som the trained map
 * @param x   the samples
 * @param y   the class label of each sample (non-negative integers)
 * @param m   number of map rows
 * @param n   number of map columns
 * @return labels indexed as {@code [row][column]}
 */
private static String[][] neuronLabels(SOM som, double[][] x, int[] y, int m, int n) {
    // Size the per-neuron histogram from the data instead of assuming exactly 10 classes.
    int numClasses = 0;
    for (int label : y) {
        if (label >= numClasses) {
            numClasses = label + 1;
        }
    }

    // One pass over the samples, instead of rescanning all of them for every neuron.
    int[][] counts = new int[m * n][numClasses];
    for (int k = 0; k < x.length; k++) {
        counts[som.predict(x[k])][y[k]]++;
    }

    String[][] labels = new String[m][n];
    for (int i = 0; i < m; i++) {
        for (int j = 0; j < n; j++) {
            int[] count = counts[i * n + j];
            int sum = Math.sum(count);
            if (sum == 0) {
                labels[i][j] = "no samples";
                continue;
            }
            StringBuilder html = new StringBuilder(String.format("<table border=\"1\"><tr><td>Total</td><td align=\"right\">%d</td></tr>", sum));
            for (int l = 0; l < count.length; l++) {
                if (count[l] > 0) {
                    html.append(String.format("<tr><td>class %d</td><td align=\"right\">%.1f%%</td></tr>", l, 100.0 * count[l] / sum));
                }
            }
            html.append("</table>");
            labels[i][j] = html.toString();
        }
    }
    return labels;
}

/**
 * Arranges the first {@code m * n} rows of {@code points} into an m x n grid in row-major
 * order: row i goes to cell {@code [i / n][i % n]}. The rows are shared, not copied.
 */
private static double[][][] reshapeToGrid(double[][] points, int m, int n) {
    double[][][] grid = new double[m][n][];
    for (int i = 0; i < m * n; i++) {
        grid[i / n][i % n] = points[i];
    }
    return grid;
}
Aggregations