Search in sources:

Example 81 with AttributeDataset

Use of smile.data.AttributeDataset in project smile by haifengl.

From class RBFNetworkTest, method testUSPS.

/**
 * Test of learn method, of class RBFNetwork, on the USPS digit data.
 *
 * <p>Trains an RBF network with 200 centers on {@code usps/zip.train} and requires at most
 * 150 misclassifications on {@code usps/zip.test}. Any exception raised while loading the data
 * or training is rethrown as a test failure. The original printed and ignored it, which let the
 * test pass without ever running.
 */
@Test
public void testUSPS() {
    System.out.println("USPS");
    DelimitedTextParser parser = new DelimitedTextParser();
    // Column 0 is the nominal "class" response.
    parser.setResponseIndex(new NominalAttribute("class"), 0);
    try {
        AttributeDataset train = parser.parse("USPS Train", smile.data.parser.IOUtils.getTestDataFile("usps/zip.train"));
        AttributeDataset test = parser.parse("USPS Test", smile.data.parser.IOUtils.getTestDataFile("usps/zip.test"));
        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);
        double[][] testx = test.toArray(new double[test.size()][]);
        int[] testy = test.toArray(new int[test.size()]);

        // The rows of centers start out null; learnGaussianRadialBasis is called for its side
        // effect of filling them in (presumably via clustering -- confirm in SmileUtils).
        // The basis it returns is intentionally not used: the network uses a fixed-width
        // Gaussian (8.0), which is what the error threshold below was tuned against.
        double[][] centers = new double[200][];
        SmileUtils.learnGaussianRadialBasis(x, centers);
        RBFNetwork<double[]> rbf = new RBFNetwork<>(x, y, new EuclideanDistance(), new GaussianRadialBasis(8.0), centers);

        int error = 0;
        for (int i = 0; i < testx.length; i++) {
            if (rbf.predict(testx[i]) != testy[i]) {
                error++;
            }
        }
        System.out.format("USPS error rate = %.2f%%%n", 100.0 * error / testx.length);
        assertTrue(error <= 150);
    } catch (Exception ex) {
        // Fail loudly: swallowing the exception would let the test pass without running.
        throw new AssertionError("RBFNetwork USPS test could not run", ex);
    }
}
Also used : DelimitedTextParser(smile.data.parser.DelimitedTextParser) RadialBasisFunction(smile.math.rbf.RadialBasisFunction) EuclideanDistance(smile.math.distance.EuclideanDistance) AttributeDataset(smile.data.AttributeDataset) NominalAttribute(smile.data.NominalAttribute) GaussianRadialBasis(smile.math.rbf.GaussianRadialBasis) Test(org.junit.Test)

Example 82 with AttributeDataset

Use of smile.data.AttributeDataset in project smile by haifengl.

From class RDATest, method testUSPS.

/**
 * Test of learn method, of class RDA, on the USPS digit data.
 *
 * <p>Trains regularized discriminant analysis with parameter 0.7 on {@code usps/zip.train}.
 * It expects exactly 235 misclassifications on {@code usps/zip.test}. Any exception raised
 * while loading the data or training is rethrown as a test failure instead of being printed
 * and ignored.
 */
@Test
public void testUSPS() {
    System.out.println("USPS");
    DelimitedTextParser parser = new DelimitedTextParser();
    // Column 0 is the nominal "class" response.
    parser.setResponseIndex(new NominalAttribute("class"), 0);
    try {
        AttributeDataset train = parser.parse("USPS Train", smile.data.parser.IOUtils.getTestDataFile("usps/zip.train"));
        AttributeDataset test = parser.parse("USPS Test", smile.data.parser.IOUtils.getTestDataFile("usps/zip.test"));
        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);
        double[][] testx = test.toArray(new double[test.size()][]);
        int[] testy = test.toArray(new int[test.size()]);

        RDA rda = new RDA(x, y, 0.7);

        int error = 0;
        for (int i = 0; i < testx.length; i++) {
            if (rda.predict(testx[i]) != testy[i]) {
                error++;
            }
        }
        System.out.format("USPS error rate = %.2f%%%n", 100.0 * error / testx.length);
        assertEquals(235, error);
    } catch (Exception ex) {
        // Fail loudly: swallowing the exception would let the test pass without running.
        throw new AssertionError("RDA USPS test could not run", ex);
    }
}
Also used : DelimitedTextParser(smile.data.parser.DelimitedTextParser) AttributeDataset(smile.data.AttributeDataset) NominalAttribute(smile.data.NominalAttribute) Test(org.junit.Test)

Example 83 with AttributeDataset

Use of smile.data.AttributeDataset in project smile by haifengl.

From class RandomForestTest, method testWeather.

/**
 * Test of learn method, of class RandomForest, on the nominal weather data set.
 *
 * <p>Runs leave-one-out cross-validation over {@code weka/weather.nominal.arff}, with the
 * response in column 4. In each fold, a forest is trained on the remaining rows and asked to
 * predict the held-out row. At most 7 held-out rows may be misclassified. Any exception is
 * rethrown as a test failure instead of being printed and ignored.
 */
@Test
public void testWeather() {
    System.out.println("Weather");
    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);
    try {
        AttributeDataset weather = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/weather.nominal.arff"));
        double[][] x = weather.toArray(new double[weather.size()][]);
        int[] y = weather.toArray(new int[weather.size()]);
        int n = x.length;
        LOOCV loocv = new LOOCV(n);

        int error = 0;
        for (int i = 0; i < n; i++) {
            // Math here is a project class, not java.lang.Math, which has no slice method.
            double[][] trainx = Math.slice(x, loocv.train[i]);
            int[] trainy = Math.slice(y, loocv.train[i]);
            // 100 is presumably the number of trees -- confirm against the RandomForest constructor.
            RandomForest forest = new RandomForest(weather.attributes(), trainx, trainy, 100);
            if (y[loocv.test[i]] != forest.predict(x[loocv.test[i]])) {
                error++;
            }
        }
        System.out.println("Random Forest error = " + error);
        assertTrue(error <= 7);
    } catch (Exception ex) {
        // Fail loudly: swallowing the exception would let the test pass without running.
        throw new AssertionError("RandomForest weather test could not run", ex);
    }
}
Also used : ArffParser(smile.data.parser.ArffParser) AttributeDataset(smile.data.AttributeDataset) LOOCV(smile.validation.LOOCV) Test(org.junit.Test)

Example 84 with AttributeDataset

Use of smile.data.AttributeDataset in project smile by haifengl.

From class RandomForestTest, method testUSPS.

/**
 * Test of learn method, of class RandomForest, on the USPS digit data.
 *
 * <p>Trains a forest with parameter 200 on {@code usps/zip.train}. It reports both the
 * forest's own {@code error()} (printed as the out-of-bag rate) and the test-set error rate,
 * and requires at most 140 misclassifications on {@code usps/zip.test}. Any exception is
 * rethrown as a test failure instead of being printed and ignored.
 */
@Test
public void testUSPS() {
    System.out.println("USPS");
    DelimitedTextParser parser = new DelimitedTextParser();
    // Column 0 is the nominal "class" response.
    parser.setResponseIndex(new NominalAttribute("class"), 0);
    try {
        AttributeDataset train = parser.parse("USPS Train", smile.data.parser.IOUtils.getTestDataFile("usps/zip.train"));
        AttributeDataset test = parser.parse("USPS Test", smile.data.parser.IOUtils.getTestDataFile("usps/zip.test"));
        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);
        double[][] testx = test.toArray(new double[test.size()][]);
        int[] testy = test.toArray(new int[test.size()]);

        RandomForest forest = new RandomForest(x, y, 200);

        int error = 0;
        for (int i = 0; i < testx.length; i++) {
            if (forest.predict(testx[i]) != testy[i]) {
                error++;
            }
        }
        System.out.println("USPS errors = " + error);
        System.out.format("USPS OOB error rate = %.2f%%%n", 100.0 * forest.error());
        System.out.format("USPS error rate = %.2f%%%n", 100.0 * error / testx.length);
        assertTrue(error <= 140);
    } catch (Exception ex) {
        // Fail loudly: swallowing the exception would let the test pass without running.
        throw new AssertionError("RandomForest USPS test could not run", ex);
    }
}
Also used : DelimitedTextParser(smile.data.parser.DelimitedTextParser) AttributeDataset(smile.data.AttributeDataset) NominalAttribute(smile.data.NominalAttribute) Test(org.junit.Test)

Example 85 with AttributeDataset

Use of smile.data.AttributeDataset in project smile by haifengl.

From class SVMTest, method testUSPS.

/**
 * Test of learn method, of class SVM, on the USPS digit data.
 *
 * <p>Trains a one-vs-one multiclass SVM with a Gaussian kernel (8.0) and C = 5.0 on
 * {@code usps/zip.train}, then requires fewer than 95 misclassifications on
 * {@code usps/zip.test}. It then feeds {@code x.length} randomly drawn training rows through
 * incremental {@code learn} and checks the same bound again. Any exception is rethrown as a
 * test failure. The original printed the stack trace and passed.
 */
@Test
public void testUSPS() {
    System.out.println("USPS");
    DelimitedTextParser parser = new DelimitedTextParser();
    // Column 0 is the nominal "class" response.
    parser.setResponseIndex(new NominalAttribute("class"), 0);
    try {
        AttributeDataset train = parser.parse("USPS Train", smile.data.parser.IOUtils.getTestDataFile("usps/zip.train"));
        AttributeDataset test = parser.parse("USPS Test", smile.data.parser.IOUtils.getTestDataFile("usps/zip.test"));
        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);
        double[][] testx = test.toArray(new double[test.size()][]);
        int[] testy = test.toArray(new int[test.size()]);

        // Assumes labels are 0..k-1, so the class count is the largest label plus one.
        SVM<double[]> svm = new SVM<>(new GaussianKernel(8.0), 5.0, Math.max(y) + 1, SVM.Multiclass.ONE_VS_ONE);
        svm.learn(x, y);
        svm.finish();

        int error = 0;
        for (int i = 0; i < testx.length; i++) {
            if (svm.predict(testx[i]) != testy[i]) {
                error++;
            }
        }
        System.out.format("USPS error rate = %.2f%%%n", 100.0 * error / testx.length);
        assertTrue(error < 95);

        System.out.println("USPS one more epoch...");
        // Not a true epoch: rows are drawn at random (presumably with replacement, via the
        // project Math.randomInt), so some rows may be seen more than once and some never.
        for (int i = 0; i < x.length; i++) {
            int j = Math.randomInt(x.length);
            svm.learn(x[j], y[j]);
        }
        svm.finish();

        error = 0;
        for (int i = 0; i < testx.length; i++) {
            if (svm.predict(testx[i]) != testy[i]) {
                error++;
            }
        }
        System.out.format("USPS error rate = %.2f%%%n", 100.0 * error / testx.length);
        assertTrue(error < 95);
    } catch (Exception ex) {
        // Fail loudly: printing the stack trace alone would let the test pass without running.
        throw new AssertionError("SVM USPS test could not run", ex);
    }
}
Also used : DelimitedTextParser(smile.data.parser.DelimitedTextParser) AttributeDataset(smile.data.AttributeDataset) NominalAttribute(smile.data.NominalAttribute) GaussianKernel(smile.math.kernel.GaussianKernel) Test(org.junit.Test)

Aggregations

AttributeDataset (smile.data.AttributeDataset): 140
Test (org.junit.Test): 125
ArffParser (smile.data.parser.ArffParser): 75
NominalAttribute (smile.data.NominalAttribute): 50
DelimitedTextParser (smile.data.parser.DelimitedTextParser): 48
Attribute (smile.data.Attribute): 29
EuclideanDistance (smile.math.distance.EuclideanDistance): 19
LOOCV (smile.validation.LOOCV): 18
CrossValidation (smile.validation.CrossValidation): 17
AdjustedRandIndex (smile.validation.AdjustedRandIndex): 14
RandIndex (smile.validation.RandIndex): 14
ClassifierTrainer (smile.classification.ClassifierTrainer): 13
GaussianKernel (smile.math.kernel.GaussianKernel): 11
IOException (java.io.IOException): 10
RadialBasisFunction (smile.math.rbf.RadialBasisFunction): 9
RBFNetwork (smile.regression.RBFNetwork): 8
ArrayList (java.util.ArrayList): 6
KMeans (smile.clustering.KMeans): 6
Datum (smile.data.Datum): 6
NumericAttribute (smile.data.NumericAttribute): 6