Search in sources :

Example 1 with PolynomialKernel

use of smile.math.kernel.PolynomialKernel in project smile by haifengl.

the class SVRTest method testCPU.

/**
 * Test of learn method, of class SVR.
 *
 * <p>Loads {@code weka/cpu.arff} (response at attribute index 6), standardizes
 * the feature matrix, and evaluates an {@code SVR} with a
 * {@code PolynomialKernel} using a {@code CrossValidation} split into
 * {@code k = 10} parts. The accumulated squared residuals over all held-out
 * samples are reported as a root-mean-square error on standard output.
 *
 * <p>Any exception raised while loading data or training is rethrown as an
 * {@link AssertionError} so that the test fails instead of silently passing.
 */
@Test
public void testCPU() {
    System.out.println("CPU");
    ArffParser parser = new ArffParser();
    parser.setResponseIndex(6);
    try {
        AttributeDataset data = parser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/cpu.arff"));
        double[] datay = data.toArray(new double[data.size()]);
        double[][] datax = data.toArray(new double[data.size()][]);
        Math.standardize(datax);

        int n = datax.length;
        int k = 10;
        CrossValidation cv = new CrossValidation(n, k);

        // Residual sum of squares accumulated over the test part of every split.
        double rss = 0.0;
        for (int i = 0; i < k; i++) {
            double[][] trainx = Math.slice(datax, cv.train[i]);
            double[] trainy = Math.slice(datay, cv.train[i]);
            double[][] testx = Math.slice(datax, cv.test[i]);
            double[] testy = Math.slice(datay, cv.test[i]);

            SVR<double[]> svr = new SVR<>(trainx, trainy, new PolynomialKernel(3, 1.0, 1.0), 0.1, 1.0);
            for (int j = 0; j < testx.length; j++) {
                double r = testy[j] - svr.predict(testx[j]);
                rss += r * r;
            }
        }
        System.out.println("10-CV RMSE = " + Math.sqrt(rss / n));
    } catch (Exception ex) {
        // Do not swallow: a failure to load data or train must fail the test.
        throw new AssertionError("testCPU failed with an unexpected exception", ex);
    }
}
Also used : PolynomialKernel(smile.math.kernel.PolynomialKernel) ArffParser(smile.data.parser.ArffParser) AttributeDataset(smile.data.AttributeDataset) CrossValidation(smile.validation.CrossValidation) Test(org.junit.Test)

Example 2 with PolynomialKernel

use of smile.math.kernel.PolynomialKernel in project smile by haifengl.

the class SVMTest method testLearn.

/**
 * Test of learn method, of class SVM.
 *
 * <p>Loads {@code weka/iris.arff} (response at attribute index 4) and trains
 * four SVM configurations on the full data set, each with two calls to
 * {@code learn} followed by {@code finish}:
 * <ul>
 *   <li>linear kernel, one-vs-all (at most 10 training errors allowed);</li>
 *   <li>Gaussian kernel, one-vs-all, with Platt scaling (at most 5);</li>
 *   <li>Gaussian kernel, one-vs-one, with Platt scaling (at most 5);</li>
 *   <li>polynomial kernel, one-vs-all (at most 5).</li>
 * </ul>
 * Errors are counted by predicting the same samples used for training.
 *
 * <p>Any exception raised while loading data or training is rethrown as an
 * {@link AssertionError} so that the test fails instead of silently passing.
 */
@Test
public void testLearn() {
    System.out.println("learn");
    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);
    try {
        AttributeDataset iris = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/iris.arff"));
        double[][] x = iris.toArray(new double[iris.size()][]);
        int[] y = iris.toArray(new int[iris.size()]);
        // Labels are used as zero-based class indices, so the class count is max + 1.
        int numClasses = Math.max(y) + 1;

        SVM<double[]> svm = new SVM<>(new LinearKernel(), 10.0, numClasses, SVM.Multiclass.ONE_VS_ALL);
        svm.learn(x, y);
        svm.learn(x, y);
        svm.finish();
        int error = 0;
        for (int i = 0; i < x.length; i++) {
            if (svm.predict(x[i]) != y[i]) {
                error++;
            }
        }
        System.out.println("Linear ONE vs. ALL error = " + error);
        assertTrue(error <= 10);

        svm = new SVM<>(new GaussianKernel(1), 1.0, numClasses, SVM.Multiclass.ONE_VS_ALL);
        svm.learn(x, y);
        svm.learn(x, y);
        svm.finish();
        svm.trainPlattScaling(x, y);
        error = 0;
        for (int i = 0; i < x.length; i++) {
            if (svm.predict(x[i]) != y[i]) {
                error++;
            }
            // Exercise the probability-output overload; its result is not asserted.
            double[] prob = new double[numClasses];
            svm.predict(x[i], prob);
        }
        System.out.println("Gaussian ONE vs. ALL error = " + error);
        assertTrue(error <= 5);

        svm = new SVM<>(new GaussianKernel(1), 1.0, numClasses, SVM.Multiclass.ONE_VS_ONE);
        svm.learn(x, y);
        svm.learn(x, y);
        svm.finish();
        svm.trainPlattScaling(x, y);
        error = 0;
        for (int i = 0; i < x.length; i++) {
            if (svm.predict(x[i]) != y[i]) {
                error++;
            }
            // Exercise the probability-output overload; its result is not asserted.
            double[] prob = new double[numClasses];
            svm.predict(x[i], prob);
        }
        System.out.println("Gaussian ONE vs. ONE error = " + error);
        assertTrue(error <= 5);

        svm = new SVM<>(new PolynomialKernel(2), 1.0, numClasses, SVM.Multiclass.ONE_VS_ALL);
        svm.learn(x, y);
        svm.learn(x, y);
        svm.finish();
        error = 0;
        for (int i = 0; i < x.length; i++) {
            if (svm.predict(x[i]) != y[i]) {
                error++;
            }
        }
        System.out.println("Polynomial ONE vs. ALL error = " + error);
        assertTrue(error <= 5);
    } catch (Exception ex) {
        // Do not swallow: a failure to load data or train must fail the test.
        throw new AssertionError("testLearn failed with an unexpected exception", ex);
    }
}
Also used : AttributeDataset(smile.data.AttributeDataset) PolynomialKernel(smile.math.kernel.PolynomialKernel) ArffParser(smile.data.parser.ArffParser) LinearKernel(smile.math.kernel.LinearKernel) GaussianKernel(smile.math.kernel.GaussianKernel) Test(org.junit.Test)

Aggregations

Test (org.junit.Test): 2, AttributeDataset (smile.data.AttributeDataset): 2, ArffParser (smile.data.parser.ArffParser): 2, PolynomialKernel (smile.math.kernel.PolynomialKernel): 2, GaussianKernel (smile.math.kernel.GaussianKernel): 1, LinearKernel (smile.math.kernel.LinearKernel): 1, CrossValidation (smile.validation.CrossValidation): 1