Example use of smile.validation.LOOCV in the project smile by haifengl.
Taken from the class NeuralNetworkTest, method testIris.
/**
 * Test of learn method, of class NeuralNetwork.
 *
 * <p>Loads {@code weka/iris.arff} (response at attribute index 4), rescales
 * every column as {@code (value - colMean) / colSd}, and then evaluates the
 * network with leave-one-out cross validation: for each split a fresh
 * network is built (cross-entropy error, softmax activation, unit counts
 * {@code p, 10, 3}), {@code learn} is invoked 20 times on the training rows,
 * and the held-out row is predicted. The test passes when at most 8
 * held-out rows are misclassified.
 *
 * <p>Any exception raised while loading the data or training fails the test;
 * it is not merely printed, so a missing or unparsable data file can no
 * longer make the test pass without checking anything.
 */
@Test
public void testIris() {
    System.out.println("Iris");
    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);
    try {
        AttributeDataset iris = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/iris.arff"));
        double[][] x = iris.toArray(new double[iris.size()][]);
        int[] y = iris.toArray(new int[iris.size()]);

        int n = x.length;
        int p = x[0].length;

        // Rescale each column in place using statistics of the full data set.
        double[] mu = Math.colMean(x);
        double[] sd = Math.colSd(x);
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < p; j++) {
                x[i][j] = (x[i][j] - mu[j]) / sd[j];
            }
        }

        LOOCV loocv = new LOOCV(n);
        int error = 0;
        for (int i = 0; i < n; i++) {
            double[][] trainx = Math.slice(x, loocv.train[i]);
            int[] trainy = Math.slice(y, loocv.train[i]);

            // A new network per split so no state carries over between folds.
            NeuralNetwork net = new NeuralNetwork(NeuralNetwork.ErrorFunction.CROSS_ENTROPY, NeuralNetwork.ActivationFunction.SOFTMAX, p, 10, 3);

            // Repeated learn() calls over the same training rows.
            for (int j = 0; j < 20; j++) {
                net.learn(trainx, trainy);
            }

            if (y[loocv.test[i]] != net.predict(x[loocv.test[i]])) {
                error++;
            }
        }

        System.out.println("Neural network error = " + error);
        assertTrue(error <= 8);
    } catch (Exception ex) {
        // Surface the failure to the test runner and keep the original cause.
        // AssertionError is an Error, so it is not re-caught by this handler.
        throw new AssertionError("Iris neural network test could not complete: " + ex, ex);
    }
}
Aggregations