
Example 1 with RBFNetwork

Use of smile.classification.RBFNetwork in project smile by haifengl.

From class RBFNetworkDemo, method learn.

@Override
public double[][] learn(double[] x, double[] y) {
    double[][] data = dataset[datasetIndex].toArray(new double[dataset[datasetIndex].size()][]);
    int[] label = dataset[datasetIndex].toArray(new int[dataset[datasetIndex].size()]);
    try {
        k = Integer.parseInt(kField.getText().trim());
        if (k < 2 || k > data.length) {
            JOptionPane.showMessageDialog(this, "Invalid K: " + k, "Error", JOptionPane.ERROR_MESSAGE);
            return null;
        }
    } catch (NumberFormatException ex) {
        JOptionPane.showMessageDialog(this, "Invalid K: " + kField.getText(), "Error", JOptionPane.ERROR_MESSAGE);
        return null;
    }
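    // learnGaussianRadialBasis fills `centers` with k cluster centers and
    // returns a Gaussian basis whose width is derived from those centers.
    double[][] centers = new double[k][];
    RadialBasisFunction basis = SmileUtils.learnGaussianRadialBasis(data, centers);
    RBFNetwork<double[]> rbf = new RBFNetwork<>(data, label, new EuclideanDistance(), basis, centers);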
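    // Predict into a separate array so the true labels are kept for comparison.
    int[] prediction = new int[label.length];
    for (int i = 0; i < data.length; i++) {
        prediction[i] = rbf.predict(data[i]);
    }
    double trainError = error(label, prediction);
    System.out.format("training error = %.2f%%%n", 100 * trainError);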
    // Classify every point of the x-y grid to draw the decision regions.
    double[][] z = new double[y.length][x.length];
    for (int i = 0; i < y.length; i++) {
        for (int j = 0; j < x.length; j++) {
            double[] p = { x[j], y[i] };
            z[i][j] = rbf.predict(p);
        }
    }
    return z;
}
Also used: RadialBasisFunction (smile.math.rbf.RadialBasisFunction), EuclideanDistance (smile.math.distance.EuclideanDistance), RBFNetwork (smile.classification.RBFNetwork)
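The demo hands basis selection to `SmileUtils.learnGaussianRadialBasis`. As we understand the Smile 1.x documentation, this method picks the k centers with k-means and sets the Gaussian width to d_max / sqrt(2k), where d_max is the largest distance between any two centers. The standalone Java sketch below illustrates only three pieces: that width heuristic, the Gaussian basis phi(r) = exp(-r^2 / (2 sigma^2)), and a misclassification rate like the one the demo's `error(...)` helper reports. It is not Smile's source. The class `RbfSketch` and its methods are hypothetical, k-means is omitted (the centers are given), and the real library may differ in its details.

```java
/**
 * Hypothetical illustration, not Smile's implementation: the Gaussian width
 * heuristic, the Gaussian radial basis, and a training-error computation.
 */
public class RbfSketch {

    // Euclidean distance between two points of equal dimension.
    static double euclidean(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }

    // Assumed heuristic: sigma = dMax / sqrt(2k), where dMax is the largest
    // pairwise distance among the k centers.
    static double gaussianWidth(double[][] centers) {
        int k = centers.length;
        double dMax = 0.0;
        for (int i = 0; i < k; i++) {
            for (int j = i + 1; j < k; j++) {
                dMax = Math.max(dMax, euclidean(centers[i], centers[j]));
            }
        }
        return dMax / Math.sqrt(2.0 * k);
    }

    // Gaussian radial basis evaluated at distance r: exp(-r^2 / (2 sigma^2)).
    static double gaussian(double r, double sigma) {
        return Math.exp(-0.5 * r * r / (sigma * sigma));
    }

    // Fraction of predictions that differ from the true labels.
    static double error(int[] truth, int[] pred) {
        if (truth.length != pred.length) {
            throw new IllegalArgumentException("length mismatch");
        }
        int wrong = 0;
        for (int i = 0; i < truth.length; i++) {
            if (truth[i] != pred[i]) {
                wrong++;
            }
        }
        return (double) wrong / truth.length;
    }

    public static void main(String[] args) {
        // Pairwise distances are 5, 6 and 5, so dMax = 6 and sigma = 6 / sqrt(6).
        double[][] centers = { {0, 0}, {3, 4}, {6, 0} };
        double sigma = gaussianWidth(centers);
        System.out.printf("sigma = %.4f%n", sigma);                   // sigma = 2.4495
        System.out.printf("phi(sigma) = %.4f%n", gaussian(sigma, sigma)); // phi(sigma) = 0.6065

        int[] truth = {0, 1, 1, 0};
        int[] pred = {0, 1, 0, 0};
        System.out.printf("error = %.2f%n", error(truth, pred));      // error = 0.25
        System.out.printf("self = %.2f%n", error(truth, truth));      // self = 0.00
    }
}
```

The last line shows why the predictions must be stored apart from the true labels: comparing any label array with itself always yields an error of 0.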
