Search in sources:

Example 1 with GaussianMixture

Use of smile.stat.distribution.GaussianMixture in project smile by haifengl.

In class SOMDemo, method learn.

/**
 * Trains a self-organizing map on the currently selected dataset and builds the result panel.
 *
 * <p>The grid width and height are read from {@code widthField} and {@code heightField}. Each must
 * parse as an integer of at least 1. Otherwise an error dialog is shown and {@code null} is
 * returned.
 *
 * <p>On success the panel holds four plots:
 * <ul>
 *   <li>the SOM grid drawn over the data;
 *   <li>the data colored by the {@code clusterNumber}-way partition returned by
 *       {@code som.partition};
 *   <li>a histogram of the U-matrix values overlaid with a {@code GaussianMixture} curve;
 *   <li>the U-matrix drawn as a hex map.
 * </ul>
 *
 * @return the panel of plots, or {@code null} if the width or height is invalid
 */
@Override
public JComponent learn() {
    try {
        width = Integer.parseInt(widthField.getText().trim());
        if (width < 1) {
            JOptionPane.showMessageDialog(this, "Invalid width: " + width, "Error", JOptionPane.ERROR_MESSAGE);
            return null;
        }
    } catch (NumberFormatException e) {
        JOptionPane.showMessageDialog(this, "Invalid width: " + widthField.getText(), "Error", JOptionPane.ERROR_MESSAGE);
        return null;
    }
    try {
        height = Integer.parseInt(heightField.getText().trim());
        if (height < 1) {
            JOptionPane.showMessageDialog(this, "Invalid height: " + height, "Error", JOptionPane.ERROR_MESSAGE);
            return null;
        }
    } catch (NumberFormatException e) {
        JOptionPane.showMessageDialog(this, "Invalid height: " + heightField.getText(), "Error", JOptionPane.ERROR_MESSAGE);
        return null;
    }

    double[][] data = dataset[datasetIndex];

    long clock = System.currentTimeMillis();
    SOM som = new SOM(data, width, height);
    System.out.format("SOM clusterings %d samples in %dms%n", data.length, System.currentTimeMillis() - clock);

    // 2x3 layout has room for six plots; four are added below.
    JPanel pane = new JPanel(new GridLayout(2, 3));

    PlotCanvas plot = ScatterPlot.plot(data, pointLegend);
    plot.grid(som.map());
    plot.setTitle("SOM Grid");
    pane.add(plot);

    // Count the members of each cluster so every cluster's points can be
    // copied into an exactly-sized array below.
    int[] membership = som.partition(clusterNumber);
    int[] clusterSize = new int[clusterNumber];
    for (int label : membership) {
        clusterSize[label]++;
    }

    plot = ScatterPlot.plot(data, pointLegend);
    plot.setTitle("Hierarchical Clustering");
    for (int k = 0; k < clusterNumber; k++) {
        double[][] cluster = new double[clusterSize[k]][];
        for (int i = 0, j = 0; i < data.length; i++) {
            if (membership[i] == k) {
                cluster[j++] = data[i];
            }
        }
        plot.points(cluster, pointLegend, Palette.COLORS[k % Palette.COLORS.length]);
    }
    pane.add(plot);

    // Flatten the U-matrix row by row for the histogram and the mixture fit.
    double[][] umatrix = som.umatrix();
    double[] umatrix1 = new double[umatrix.length * umatrix[0].length];
    for (int i = 0, k = 0; i < umatrix.length; i++) {
        for (int j = 0; j < umatrix[i].length; j++, k++) {
            umatrix1[k] = umatrix[i][j];
        }
    }
    plot = Histogram.plot(null, umatrix1, 20);
    plot.setTitle("U-Matrix Histogram");
    pane.add(plot);

    GaussianMixture mixture = new GaussianMixture(umatrix1);
    // Compute the value range once instead of rescanning the array per point.
    double umin = Math.min(umatrix1);
    double umax = Math.max(umatrix1);
    // NOTE(review): the step is range/24 but 50 points are drawn, so the curve
    // extends to roughly twice the data range; confirm this is intended.
    double w = (umax - umin) / 24;
    double[][] p = new double[50][2];
    for (int i = 0; i < p.length; i++) {
        p[i][0] = umin + i * w;
        // Presumably scales the mixture density by the step width to match the
        // histogram's per-bin scale -- TODO confirm against Histogram normalization.
        p[i][1] = mixture.p(p[i][0]) * w;
    }
    plot.line(p, Color.RED);

    plot = Hexmap.plot(umatrix, Palette.jet(256));
    plot.setTitle("U-Matrix");
    pane.add(plot);

    return pane;
}
Also used: SOM (smile.vq.SOM), JPanel (javax.swing.JPanel), GridLayout (java.awt.GridLayout), GaussianMixture (smile.stat.distribution.GaussianMixture), PlotCanvas (smile.plot.PlotCanvas)

Example 2 with GaussianMixture

Use of smile.stat.distribution.GaussianMixture in project smile by haifengl.

In class NaiveBayesTest, method testPredict.

/**
 * Test of predict method, of class NaiveBayes.
 *
 * <p>Runs leave-one-out cross-validation on the iris data set. For each fold, a NaiveBayes model is
 * built with uniform class priors. Each (class, attribute) pair gets a 3-component GaussianMixture
 * fitted to that attribute's training values within the class. The test expects exactly 5
 * misclassifications.
 *
 * <p>Any exception raised while loading the data or fitting the models fails the test. A missing or
 * unreadable data file therefore cannot let the test pass vacuously.
 */
@Test
public void testPredict() {
    System.out.println("predict");
    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);
    try {
        AttributeDataset iris = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/iris.arff"));
        double[][] x = iris.toArray(new double[iris.size()][]);
        int[] y = iris.toArray(new int[iris.size()]);
        int n = x.length;
        LOOCV loocv = new LOOCV(n);
        int error = 0;
        for (int l = 0; l < n; l++) {
            double[][] trainx = Math.slice(x, loocv.train[l]);
            int[] trainy = Math.slice(y, loocv.train[l]);
            int p = trainx[0].length;
            int k = Math.max(trainy) + 1;
            double[] priori = new double[k];
            Distribution[][] condprob = new Distribution[k][p];
            for (int i = 0; i < k; i++) {
                priori[i] = 1.0 / k;
                // Count the class-i training rows once, then gather each attribute
                // straight into a primitive array (no boxing through List<Double>).
                int count = 0;
                for (int m = 0; m < trainx.length; m++) {
                    if (trainy[m] == i) {
                        count++;
                    }
                }
                for (int j = 0; j < p; j++) {
                    double[] xi = new double[count];
                    for (int m = 0, c = 0; m < trainx.length; m++) {
                        if (trainy[m] == i) {
                            xi[c++] = trainx[m][j];
                        }
                    }
                    condprob[i][j] = new GaussianMixture(xi, 3);
                }
            }
            NaiveBayes bayes = new NaiveBayes(priori, condprob);
            if (y[loocv.test[l]] != bayes.predict(x[loocv.test[l]])) {
                error++;
            }
        }
        System.out.format("Iris error rate = %.2f%%%n", 100.0 * error / x.length);
        assertEquals(5, error);
    } catch (Exception ex) {
        // Do not swallow: a load/parse/fit failure must fail the test, with the cause kept.
        throw new AssertionError("NaiveBayes predict test failed: " + ex, ex);
    }
}
Also used: AttributeDataset (smile.data.AttributeDataset), ArrayList (java.util.ArrayList), GaussianMixture (smile.stat.distribution.GaussianMixture), LOOCV (smile.validation.LOOCV), IOException (java.io.IOException), ArffParser (smile.data.parser.ArffParser), Distribution (smile.stat.distribution.Distribution), Test (org.junit.Test)

Aggregations

GaussianMixture (smile.stat.distribution.GaussianMixture): 2; GridLayout (java.awt.GridLayout): 1; IOException (java.io.IOException): 1; ArrayList (java.util.ArrayList): 1; JPanel (javax.swing.JPanel): 1; Test (org.junit.Test): 1; AttributeDataset (smile.data.AttributeDataset): 1; ArffParser (smile.data.parser.ArffParser): 1; PlotCanvas (smile.plot.PlotCanvas): 1; Distribution (smile.stat.distribution.Distribution): 1; LOOCV (smile.validation.LOOCV): 1; SOM (smile.vq.SOM): 1