Example 26 with LOOCV

Use of smile.validation.LOOCV in the project smile by haifengl.

From the class NeuralNetworkTest, method testIris.

/**
 * Tests the learn method of class NeuralNetwork on the Iris data set,
 * using leave-one-out cross-validation.
 */
@Test
public void testIris() {
    System.out.println("Iris");
    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);
    try {
        AttributeDataset iris = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/iris.arff"));
        double[][] x = iris.toArray(new double[iris.size()][]);
        int[] y = iris.toArray(new int[iris.size()]);
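        int n = x.length;
        int p = x[0].length;
        // Math here is presumably smile.math.Math (colMean, colSd, slice), not java.lang.Math.
        // Standardize each column: subtract its mean and divide by its standard deviation.
        double[] mu = Math.colMean(x);
        double[] sd = Math.colSd(x);
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < p; j++) {
                x[i][j] = (x[i][j] - mu[j]) / sd[j];
            }
        }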
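        // Leave-one-out cross-validation: round i trains on all samples but one and tests on the held-out sample.
        LOOCV loocv = new LOOCV(n);
        int error = 0;
        for (int i = 0; i < n; i++) {
            double[][] trainx = Math.slice(x, loocv.train[i]);
            int[] trainy = Math.slice(y, loocv.train[i]);
            // Layer sizes: p inputs, 10 hidden units, 3 output classes.
            NeuralNetwork net = new NeuralNetwork(NeuralNetwork.ErrorFunction.CROSS_ENTROPY, NeuralNetwork.ActivationFunction.SOFTMAX, p, 10, 3);
            // Make 20 passes over the training fold.
            for (int j = 0; j < 20; j++) {
                net.learn(trainx, trainy);
            }
            // Count a misclassification of the held-out sample.
            if (y[loocv.test[i]] != net.predict(x[loocv.test[i]])) {
                error++;
            }
        }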
        System.out.println("Neural network error = " + error);
        assertTrue(error <= 8);
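    } catch (Exception ex) {
        // Fail the test instead of only logging; otherwise a parse or I/O error would pass silently.
        throw new AssertionError("testIris failed with an exception", ex);
    }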
}
Also used: ArffParser (smile.data.parser.ArffParser), AttributeDataset (smile.data.AttributeDataset), LOOCV (smile.validation.LOOCV), Test (org.junit.Test)
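
The test above indexes loocv.train[i] as an array of row indices and loocv.test[i] as a single held-out index. The following is a minimal, library-free sketch of that split, not Smile's implementation: the class LoocvSketch and its methods trainIndices and slice are hypothetical names introduced here, and the ascending index order is an assumption.

```java
import java.util.Arrays;

// Hypothetical illustration of a leave-one-out split; not part of Smile.
class LoocvSketch {

    /** For each round i, returns the indices of every sample except i. */
    static int[][] trainIndices(int n) {
        if (n < 1) {
            throw new IllegalArgumentException("n must be at least 1");
        }
        int[][] train = new int[n][n - 1];
        for (int i = 0; i < n; i++) {
            int k = 0;
            for (int j = 0; j < n; j++) {
                if (j != i) {
                    train[i][k++] = j;
                }
            }
        }
        return train;
    }

    /** Selects the rows of x listed in index (rows are shared, not copied). */
    static double[][] slice(double[][] x, int[] index) {
        double[][] out = new double[index.length][];
        for (int i = 0; i < index.length; i++) {
            out[i] = x[index[i]];
        }
        return out;
    }

    public static void main(String[] args) {
        int n = 4;
        int[][] train = trainIndices(n);
        for (int i = 0; i < n; i++) {
            // In round i the held-out test sample is simply sample i.
            System.out.println("test " + i + ", train " + Arrays.toString(train[i]));
        }
    }
}
```

With n samples this yields n rounds, each training on n - 1 samples, which is why the test trains 150 separate networks on Iris and counts at most one error per round.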

Aggregations

Test (org.junit.Test): 26
LOOCV (smile.validation.LOOCV): 26
AttributeDataset (smile.data.AttributeDataset): 18
ArffParser (smile.data.parser.ArffParser): 18
EuclideanDistance (smile.math.distance.EuclideanDistance): 2
RadialBasisFunction (smile.math.rbf.RadialBasisFunction): 2
IOException (java.io.IOException): 1
ArrayList (java.util.ArrayList): 1
GaussianKernel (smile.math.kernel.GaussianKernel): 1
Distribution (smile.stat.distribution.Distribution): 1
GaussianMixture (smile.stat.distribution.GaussianMixture): 1