Search in sources:

Example 6 with CrossValidation

Use of smile.validation.CrossValidation in project smile by haifengl.

From class GradientTreeBoostTest, method test.

/**
 * Evaluates {@link GradientTreeBoost} regression on one ARFF dataset with 10-fold
 * cross-validation and prints the pooled RMSE and mean absolute deviation over all folds.
 *
 * <p>Any exception raised while loading the data or while fitting/predicting is rethrown as an
 * {@link AssertionError} carrying the original exception as its cause. The calling test then
 * fails instead of passing after printing a one-line message.
 *
 * @param loss the loss function handed to {@code GradientTreeBoost}
 * @param dataset a display name for the dataset, used only in printed output
 * @param url the path passed to {@code IOUtils.getTestDataFile} to locate the ARFF file
 * @param response the column index of the response attribute, passed to
 *     {@code ArffParser.setResponseIndex}
 */
public void test(GradientTreeBoost.Loss loss, String dataset, String url, int response) {
    System.out.println(dataset + "\t" + loss);
    ArffParser parser = new ArffParser();
    parser.setResponseIndex(response);
    try {
        AttributeDataset data = parser.parse(smile.data.parser.IOUtils.getTestDataFile(url));
        double[] datay = data.toArray(new double[data.size()]);
        double[][] datax = data.toArray(new double[data.size()][]);
        int n = datax.length;
        int k = 10;
        CrossValidation cv = new CrossValidation(n, k);
        // Squared and absolute residuals are pooled across all folds; every sample is tested
        // exactly once, so dividing by n gives the overall cross-validated averages.
        double rss = 0.0;
        double ad = 0.0;
        for (int i = 0; i < k; i++) {
            double[][] trainx = Math.slice(datax, cv.train[i]);
            double[] trainy = Math.slice(datay, cv.train[i]);
            double[][] testx = Math.slice(datax, cv.test[i]);
            double[] testy = Math.slice(datay, cv.test[i]);
            // NOTE(review): presumably 100 trees, at most 6 nodes per tree, shrinkage 0.05 and
            // subsampling rate 0.7; confirm against the GradientTreeBoost constructor.
            GradientTreeBoost boost = new GradientTreeBoost(data.attributes(), trainx, trainy, loss, 100, 6, 0.05, 0.7);
            for (int j = 0; j < testx.length; j++) {
                double r = testy[j] - boost.predict(testx[j]);
                ad += Math.abs(r);
                rss += r * r;
            }
        }
        System.out.format("10-CV RMSE = %.4f \t AbsoluteDeviation = %.4f%n", Math.sqrt(rss / n), ad / n);
    } catch (Exception ex) {
        // Previously the exception was only printed, which let the test pass on any failure.
        throw new AssertionError("10-CV evaluation failed for dataset " + dataset, ex);
    }
}
Also used : ArffParser(smile.data.parser.ArffParser) AttributeDataset(smile.data.AttributeDataset) CrossValidation(smile.validation.CrossValidation)

Example 7 with CrossValidation

Use of smile.validation.CrossValidation in project smile by haifengl.

From class RBFNetworkTest, method testAilerons.

/**
 * Tests {@link RBFNetwork} regression on the ailerons ARFF dataset with 10-fold
 * cross-validation and prints the pooled mean squared error.
 *
 * <p>Features are standardized with {@code Math.standardize} before the folds are built. The
 * response is multiplied by 10000. For each fold, 20 Gaussian radial basis functions are learned
 * from the training data. Any exception is rethrown as an {@link AssertionError}, so the test
 * fails instead of silently passing.
 */
@Test
public void testAilerons() {
    System.out.println("ailerons");
    ArffParser parser = new ArffParser();
    parser.setResponseIndex(40);
    try {
        AttributeDataset data = parser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/regression/ailerons.arff"));
        double[][] datax = data.toArray(new double[data.size()][]);
        Math.standardize(datax);
        double[] datay = data.toArray(new double[data.size()]);
        // NOTE(review): presumably rescales very small raw targets to a readable MSE range;
        // confirm. This only changes the printed MSE's scale, by a factor of 10^8.
        for (int i = 0; i < datay.length; i++) {
            datay[i] *= 10000;
        }
        int n = datax.length;
        int k = 10;
        CrossValidation cv = new CrossValidation(n, k);
        // Squared residuals are pooled across folds; each sample is tested once, so rss / n is
        // the overall cross-validated MSE.
        double rss = 0.0;
        for (int i = 0; i < k; i++) {
            double[][] trainx = Math.slice(datax, cv.train[i]);
            double[] trainy = Math.slice(datay, cv.train[i]);
            double[][] testx = Math.slice(datax, cv.test[i]);
            double[] testy = Math.slice(datay, cv.test[i]);
            // learnGaussianRadialBasis fills the 20 center slots in place.
            double[][] centers = new double[20][];
            RadialBasisFunction[] basis = SmileUtils.learnGaussianRadialBasis(trainx, centers, 5.0);
            RBFNetwork<double[]> rbf = new RBFNetwork<>(trainx, trainy, new EuclideanDistance(), basis, centers);
            for (int j = 0; j < testx.length; j++) {
                double r = testy[j] - rbf.predict(testx[j]);
                rss += r * r;
            }
        }
        System.out.println("10-CV MSE = " + rss / n);
    } catch (Exception ex) {
        // Previously the exception was only printed, which let the test pass on any failure.
        throw new AssertionError("10-CV evaluation failed for ailerons", ex);
    }
}
Also used : RadialBasisFunction(smile.math.rbf.RadialBasisFunction) AttributeDataset(smile.data.AttributeDataset) EuclideanDistance(smile.math.distance.EuclideanDistance) ArffParser(smile.data.parser.ArffParser) CrossValidation(smile.validation.CrossValidation) Test(org.junit.Test)

Example 8 with CrossValidation

Use of smile.validation.CrossValidation in project smile by haifengl.

From class RBFNetworkTest, method testBank32nh.

/**
 * Tests {@link RBFNetwork} regression on the bank32nh ARFF dataset with 10-fold
 * cross-validation and prints the pooled mean squared error.
 *
 * <p>Features are standardized with {@code Math.standardize}; the response is used unscaled.
 * For each fold, 20 Gaussian radial basis functions are learned from the training data. Any
 * exception is rethrown as an {@link AssertionError}, so the test fails instead of silently
 * passing.
 */
@Test
public void testBank32nh() {
    System.out.println("bank32nh");
    ArffParser parser = new ArffParser();
    parser.setResponseIndex(31);
    try {
        AttributeDataset data = parser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/regression/bank32nh.arff"));
        double[] datay = data.toArray(new double[data.size()]);
        double[][] datax = data.toArray(new double[data.size()][]);
        Math.standardize(datax);
        int n = datax.length;
        int k = 10;
        CrossValidation cv = new CrossValidation(n, k);
        // Squared residuals are pooled across folds; each sample is tested once, so rss / n is
        // the overall cross-validated MSE.
        double rss = 0.0;
        for (int i = 0; i < k; i++) {
            double[][] trainx = Math.slice(datax, cv.train[i]);
            double[] trainy = Math.slice(datay, cv.train[i]);
            double[][] testx = Math.slice(datax, cv.test[i]);
            double[] testy = Math.slice(datay, cv.test[i]);
            // learnGaussianRadialBasis fills the 20 center slots in place.
            double[][] centers = new double[20][];
            RadialBasisFunction[] basis = SmileUtils.learnGaussianRadialBasis(trainx, centers, 5.0);
            RBFNetwork<double[]> rbf = new RBFNetwork<>(trainx, trainy, new EuclideanDistance(), basis, centers);
            for (int j = 0; j < testx.length; j++) {
                double r = testy[j] - rbf.predict(testx[j]);
                rss += r * r;
            }
        }
        System.out.println("10-CV MSE = " + rss / n);
    } catch (Exception ex) {
        // Previously the exception was only printed, which let the test pass on any failure.
        throw new AssertionError("10-CV evaluation failed for bank32nh", ex);
    }
}
Also used : RadialBasisFunction(smile.math.rbf.RadialBasisFunction) AttributeDataset(smile.data.AttributeDataset) EuclideanDistance(smile.math.distance.EuclideanDistance) ArffParser(smile.data.parser.ArffParser) CrossValidation(smile.validation.CrossValidation) Test(org.junit.Test)

Example 9 with CrossValidation

Use of smile.validation.CrossValidation in project smile by haifengl.

From class RandomForestTest, method test.

/**
 * Evaluates {@link RandomForest} regression on one ARFF dataset with 10-fold cross-validation.
 * Prints each fold's out-of-bag error and, at the end, the pooled RMSE and mean absolute
 * deviation.
 *
 * <p>Any exception raised while loading the data or while fitting/predicting is rethrown as an
 * {@link AssertionError} carrying the original exception as its cause. The calling test then
 * fails instead of passing after printing a one-line message.
 *
 * @param dataset a display name for the dataset, used only in printed output
 * @param url the path passed to {@code IOUtils.getTestDataFile} to locate the ARFF file
 * @param response the column index of the response attribute, passed to
 *     {@code ArffParser.setResponseIndex}
 */
public void test(String dataset, String url, int response) {
    System.out.println(dataset);
    ArffParser parser = new ArffParser();
    parser.setResponseIndex(response);
    try {
        AttributeDataset data = parser.parse(smile.data.parser.IOUtils.getTestDataFile(url));
        double[] datay = data.toArray(new double[data.size()]);
        double[][] datax = data.toArray(new double[data.size()][]);
        int n = datax.length;
        int k = 10;
        CrossValidation cv = new CrossValidation(n, k);
        // Residuals are pooled across folds; each sample is tested once, so dividing by n gives
        // the overall cross-validated averages.
        double rss = 0.0;
        double ad = 0.0;
        for (int i = 0; i < k; i++) {
            double[][] trainx = Math.slice(datax, cv.train[i]);
            double[] trainy = Math.slice(datay, cv.train[i]);
            double[][] testx = Math.slice(datax, cv.test[i]);
            double[] testy = Math.slice(datay, cv.test[i]);
            // dim / 3 truncates to 0 when there are fewer than 3 features; keep at least one
            // candidate variable per split. java.lang.Math is qualified because the unqualified
            // name refers to smile's Math here.
            int mtry = java.lang.Math.max(1, trainx[0].length / 3);
            // NOTE(review): presumably 200 trees, maxNodes = n, nodeSize = 5; confirm against the
            // RandomForest constructor. maxNodes uses the full dataset size, not the fold's.
            RandomForest forest = new RandomForest(data.attributes(), trainx, trainy, 200, n, 5, mtry);
            System.out.format("OOB error rate = %.4f%n", forest.error());
            for (int j = 0; j < testx.length; j++) {
                double r = testy[j] - forest.predict(testx[j]);
                rss += r * r;
                ad += Math.abs(r);
            }
        }
        System.out.format("10-CV RMSE = %.4f \t AbsoluteDeviation = %.4f%n", Math.sqrt(rss / n), ad / n);
    } catch (Exception ex) {
        // Previously the exception was only printed, which let the test pass on any failure.
        throw new AssertionError("10-CV evaluation failed for dataset " + dataset, ex);
    }
}
Also used : ArffParser(smile.data.parser.ArffParser) AttributeDataset(smile.data.AttributeDataset) CrossValidation(smile.validation.CrossValidation)

Example 10 with CrossValidation

Use of smile.validation.CrossValidation in project smile by haifengl.

From class NaiveBayesTest, method testLearnBernoulli2.

/**
 * Tests incremental learning of a Bernoulli {@link NaiveBayes} model with 10-fold
 * cross-validation on the {@code moviex}/{@code moviey} fixture.
 *
 * <p>For each fold, a fresh model is fed training samples one at a time via {@code learn} and
 * then used to classify the held-out samples. The test asserts fewer than 270 misclassifications
 * in total. Test samples for which {@code predict} returns -1 are excluded from both the error
 * and total counts. NOTE(review): presumably -1 means "no decision"; confirm against NaiveBayes.
 */
@Test
public void testLearnBernoulli2() {
    System.out.println("online learn Bernoulli");
    double[][] x = moviex;
    int[] y = moviey;
    int n = x.length;
    int k = 10;
    CrossValidation cv = new CrossValidation(n, k);
    int error = 0;
    int total = 0;
    for (int i = 0; i < k; i++) {
        double[][] trainx = Math.slice(x, cv.train[i]);
        int[] trainy = Math.slice(y, cv.train[i]);
        // NOTE(review): presumably 2 classes and feature.length input dimensions; confirm.
        NaiveBayes bayes = new NaiveBayes(NaiveBayes.Model.BERNOULLI, 2, feature.length);
        // Incremental (sample-by-sample) training.
        for (int j = 0; j < trainx.length; j++) {
            bayes.learn(trainx[j], trainy[j]);
        }
        double[][] testx = Math.slice(x, cv.test[i]);
        int[] testy = Math.slice(y, cv.test[i]);
        for (int j = 0; j < testx.length; j++) {
            int label = bayes.predict(testx[j]);
            if (label != -1) {
                total++;
                if (testy[j] != label) {
                    error++;
                }
            }
        }
    }
    System.out.format("Bernoulli error = %d of %d%n", error, total);
    assertTrue(error < 270);
}
Also used : CrossValidation(smile.validation.CrossValidation) Test(org.junit.Test)

Aggregations

- CrossValidation (smile.validation.CrossValidation): 23
- Test (org.junit.Test): 20
- AttributeDataset (smile.data.AttributeDataset): 17
- ArffParser (smile.data.parser.ArffParser): 17
- KMeans (smile.clustering.KMeans): 6
- GaussianKernel (smile.math.kernel.GaussianKernel): 6
- EuclideanDistance (smile.math.distance.EuclideanDistance): 4
- RadialBasisFunction (smile.math.rbf.RadialBasisFunction): 4
- PolynomialKernel (smile.math.kernel.PolynomialKernel): 1