Search in sources :

Example 6 with RadialBasisFunction

use of smile.math.rbf.RadialBasisFunction in project smile by haifengl.

the class ValidationTest method testTest_3args_2.

/**
     * Test of test method, of class Validation.
     */
@Test
public void testTest_3args_2() {
    System.out.println("test");
    ArffParser parser = new ArffParser();
    parser.setResponseIndex(6);
    try {
        AttributeDataset data = parser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/cpu.arff"));
        double[] datay = data.toArray(new double[data.size()]);
        double[][] datax = data.toArray(new double[data.size()][]);
        Math.standardize(datax);
        int n = datax.length;
        int m = 3 * n / 4;
        double[][] x = new double[m][];
        double[] y = new double[m];
        double[][] testx = new double[n - m][];
        double[] testy = new double[n - m];
        int[] index = Math.permutate(n);
        for (int i = 0; i < m; i++) {
            x[i] = datax[index[i]];
            y[i] = datay[index[i]];
        }
        for (int i = m; i < n; i++) {
            testx[i - m] = datax[index[i]];
            testy[i - m] = datay[index[i]];
        }
        double[][] centers = new double[20][];
        RadialBasisFunction[] rbf = SmileUtils.learnGaussianRadialBasis(x, centers, 2);
        RBFNetwork<double[]> rkhs = new RBFNetwork<>(x, y, new EuclideanDistance(), rbf, centers);
        double rmse = Validation.test(rkhs, testx, testy);
        System.out.println("RMSE = " + rmse);
    } catch (Exception ex) {
        System.err.println(ex);
    }
}
Also used : RadialBasisFunction(smile.math.rbf.RadialBasisFunction) AttributeDataset(smile.data.AttributeDataset) RBFNetwork(smile.regression.RBFNetwork) EuclideanDistance(smile.math.distance.EuclideanDistance) ArffParser(smile.data.parser.ArffParser) Test(org.junit.Test)

Example 7 with RadialBasisFunction

use of smile.math.rbf.RadialBasisFunction in project smile by haifengl.

the class RBFNetworkTest method testCPU.

/**
     * Test of learn method, of class RBFNetwork.
     */
@Test
public void testCPU() {
    System.out.println("CPU");
    ArffParser parser = new ArffParser();
    parser.setResponseIndex(6);
    try {
        AttributeDataset data = parser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/cpu.arff"));
        double[] datay = data.toArray(new double[data.size()]);
        double[][] datax = data.toArray(new double[data.size()][]);
        Math.standardize(datax);
        int n = datax.length;
        int k = 10;
        CrossValidation cv = new CrossValidation(n, k);
        double rss = 0.0;
        for (int i = 0; i < k; i++) {
            double[][] trainx = Math.slice(datax, cv.train[i]);
            double[] trainy = Math.slice(datay, cv.train[i]);
            double[][] testx = Math.slice(datax, cv.test[i]);
            double[] testy = Math.slice(datay, cv.test[i]);
            double[][] centers = new double[20][];
            RadialBasisFunction[] basis = SmileUtils.learnGaussianRadialBasis(trainx, centers, 5.0);
            RBFNetwork<double[]> rbf = new RBFNetwork<>(trainx, trainy, new EuclideanDistance(), basis, centers);
            for (int j = 0; j < testx.length; j++) {
                double r = testy[j] - rbf.predict(testx[j]);
                rss += r * r;
            }
        }
        System.out.println("10-CV MSE = " + rss / n);
    } catch (Exception ex) {
        System.err.println(ex);
    }
}
Also used : RadialBasisFunction(smile.math.rbf.RadialBasisFunction) AttributeDataset(smile.data.AttributeDataset) EuclideanDistance(smile.math.distance.EuclideanDistance) ArffParser(smile.data.parser.ArffParser) CrossValidation(smile.validation.CrossValidation) Test(org.junit.Test)

Example 8 with RadialBasisFunction

use of smile.math.rbf.RadialBasisFunction in project smile by haifengl.

the class RBFNetworkTest method test2DPlanes.

/**
     * Test of learn method, of class RBFNetwork.
     */
@Test
public void test2DPlanes() {
    System.out.println("2dplanes");
    ArffParser parser = new ArffParser();
    parser.setResponseIndex(10);
    try {
        AttributeDataset data = parser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/regression/2dplanes.arff"));
        double[] datay = data.toArray(new double[data.size()]);
        double[][] datax = data.toArray(new double[data.size()][]);
        //Math.normalize(datax);
        int n = datax.length;
        int k = 10;
        CrossValidation cv = new CrossValidation(n, k);
        double rss = 0.0;
        for (int i = 0; i < k; i++) {
            double[][] trainx = Math.slice(datax, cv.train[i]);
            double[] trainy = Math.slice(datay, cv.train[i]);
            double[][] testx = Math.slice(datax, cv.test[i]);
            double[] testy = Math.slice(datay, cv.test[i]);
            double[][] centers = new double[20][];
            RadialBasisFunction[] basis = SmileUtils.learnGaussianRadialBasis(trainx, centers, 5.0);
            RBFNetwork<double[]> rbf = new RBFNetwork<>(trainx, trainy, new EuclideanDistance(), basis, centers);
            for (int j = 0; j < testx.length; j++) {
                double r = testy[j] - rbf.predict(testx[j]);
                rss += r * r;
            }
        }
        System.out.println("10-CV MSE = " + rss / n);
    } catch (Exception ex) {
        System.err.println(ex);
    }
}
Also used : RadialBasisFunction(smile.math.rbf.RadialBasisFunction) AttributeDataset(smile.data.AttributeDataset) EuclideanDistance(smile.math.distance.EuclideanDistance) ArffParser(smile.data.parser.ArffParser) CrossValidation(smile.validation.CrossValidation) Test(org.junit.Test)

Example 9 with RadialBasisFunction

use of smile.math.rbf.RadialBasisFunction in project smile by haifengl.

the class RBFNetworkDemo method learn.

@Override
public double[][] learn(double[] x, double[] y) {
    double[][] data = dataset[datasetIndex].toArray(new double[dataset[datasetIndex].size()][]);
    int[] label = dataset[datasetIndex].toArray(new int[dataset[datasetIndex].size()]);
    try {
        k = Integer.parseInt(kField.getText().trim());
        if (k < 2 || k > data.length) {
            JOptionPane.showMessageDialog(this, "Invalid K: " + k, "Error", JOptionPane.ERROR_MESSAGE);
            return null;
        }
    } catch (Exception ex) {
        JOptionPane.showMessageDialog(this, "Invalid K: " + kField.getText(), "Error", JOptionPane.ERROR_MESSAGE);
        return null;
    }
    double[][] centers = new double[k][];
    RadialBasisFunction basis = SmileUtils.learnGaussianRadialBasis(data, centers);
    RBFNetwork<double[]> rbf = new RBFNetwork<>(data, label, new EuclideanDistance(), basis, centers);
    for (int i = 0; i < label.length; i++) {
        label[i] = rbf.predict(data[i]);
    }
    double trainError = error(label, label);
    System.out.format("training error = %.2f%%\n", 100 * trainError);
    double[][] z = new double[y.length][x.length];
    for (int i = 0; i < y.length; i++) {
        for (int j = 0; j < x.length; j++) {
            double[] p = { x[j], y[i] };
            z[i][j] = rbf.predict(p);
        }
    }
    return z;
}
Also used : RadialBasisFunction(smile.math.rbf.RadialBasisFunction) EuclideanDistance(smile.math.distance.EuclideanDistance) RBFNetwork(smile.classification.RBFNetwork)

Example 10 with RadialBasisFunction

use of smile.math.rbf.RadialBasisFunction in project smile by haifengl.

the class RBFNetworkTest method testLearn.

/**
     * Test of learn method, of class RBFNetwork.
     */
@Test
public void testLearn() {
    System.out.println("learn");
    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);
    try {
        AttributeDataset iris = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/iris.arff"));
        double[][] x = iris.toArray(new double[iris.size()][]);
        int[] y = iris.toArray(new int[iris.size()]);
        int n = x.length;
        LOOCV loocv = new LOOCV(n);
        int error = 0;
        for (int i = 0; i < n; i++) {
            double[][] trainx = Math.slice(x, loocv.train[i]);
            int[] trainy = Math.slice(y, loocv.train[i]);
            double[][] centers = new double[10][];
            RadialBasisFunction[] basis = SmileUtils.learnGaussianRadialBasis(trainx, centers, 5.0);
            RBFNetwork<double[]> rbf = new RBFNetwork<>(trainx, trainy, new EuclideanDistance(), basis, centers);
            if (y[loocv.test[i]] != rbf.predict(x[loocv.test[i]]))
                error++;
        }
        System.out.println("RBF network error = " + error);
        assertTrue(error <= 6);
    } catch (Exception ex) {
        System.err.println(ex);
    }
}
Also used : RadialBasisFunction(smile.math.rbf.RadialBasisFunction) EuclideanDistance(smile.math.distance.EuclideanDistance) ArffParser(smile.data.parser.ArffParser) AttributeDataset(smile.data.AttributeDataset) LOOCV(smile.validation.LOOCV) Test(org.junit.Test)

Aggregations

EuclideanDistance (smile.math.distance.EuclideanDistance)11 RadialBasisFunction (smile.math.rbf.RadialBasisFunction)11 Test (org.junit.Test)10 AttributeDataset (smile.data.AttributeDataset)9 ArffParser (smile.data.parser.ArffParser)8 CrossValidation (smile.validation.CrossValidation)4 RBFNetwork (smile.regression.RBFNetwork)2 LOOCV (smile.validation.LOOCV)2 RBFNetwork (smile.classification.RBFNetwork)1 NominalAttribute (smile.data.NominalAttribute)1 DelimitedTextParser (smile.data.parser.DelimitedTextParser)1 GaussianRadialBasis (smile.math.rbf.GaussianRadialBasis)1