
Example 16 with Layer

Use of org.gitia.froog.layer.Layer in project froog by mroodschild.

In class Feedforward2Test, method testSalida.

@Test
public void testSalida() {
    // Hidden layer: 2x2 weights and 2x1 bias (values given row-major)
    SimpleMatrix w1 = new SimpleMatrix(2, 2, true, 0.1, 0.2, 0.3, 0.4);
    SimpleMatrix b1 = new SimpleMatrix(2, 1, true, 0.5, 0.1);
    // Output layer
    SimpleMatrix w2 = new SimpleMatrix(2, 2, true, 0.5, 0.6, 0.7, 0.8);
    SimpleMatrix b2 = new SimpleMatrix(2, 1, true, 0.2, 0.4);
    Feedforward net = new Feedforward();
    net.addLayer(new Layer(w1, b1, "tansig"));
    net.addLayer(new Layer(w2, b2, "purelim"));
    // entrada = input, esperado = expected output:
    // tansig(W1*x + b1) = tanh([0.64, 0.40]); then W2*h + b2 ≈ [0.7104, 1.0993]
    double[] entrada = { 0.2, 0.6 };
    double[] esperado = { 0.7104, 1.0993 };
    assertArrayEquals(esperado, net.output(entrada).getMatrix().getData(), 0.0001);
}
Also used : SimpleMatrix(org.ejml.simple.SimpleMatrix) Feedforward(org.gitia.froog.Feedforward) Layer(org.gitia.froog.layer.Layer) Test(org.junit.Test)
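The expected values in testSalida can be checked by hand. The sketch below recomputes the same two-layer forward pass with plain Java arrays instead of froog and EJML. It assumes that "tansig" is the hyperbolic tangent (as in MATLAB's tansig) and that "purelim" is the identity function, which is consistent with the expected values; the class and method names (ForwardPassCheck, affine, tansig, forward) are illustrative only and are not part of froog.

```java
// Plain-Java sketch of the forward pass checked by testSalida (no froog/EJML).
// Assumption: "tansig" = tanh and "purelim" = identity (linear output layer).
class ForwardPassCheck {

    // z = W * x + b for one dense layer; W is stored as double[rows][cols]
    static double[] affine(double[][] w, double[] b, double[] x) {
        double[] z = new double[w.length];
        for (int i = 0; i < w.length; i++) {
            double sum = b[i];
            for (int j = 0; j < x.length; j++) {
                sum += w[i][j] * x[j];
            }
            z[i] = sum;
        }
        return z;
    }

    // tansig(n) = 2 / (1 + exp(-2n)) - 1, which equals tanh(n)
    static double[] tansig(double[] z) {
        double[] a = new double[z.length];
        for (int i = 0; i < z.length; i++) {
            a[i] = Math.tanh(z[i]);
        }
        return a;
    }

    // Same weights and biases as w1, b1, w2, b2 in testSalida
    static double[] forward(double[] x) {
        double[][] w1 = {{0.1, 0.2}, {0.3, 0.4}};
        double[] b1 = {0.5, 0.1};
        double[][] w2 = {{0.5, 0.6}, {0.7, 0.8}};
        double[] b2 = {0.2, 0.4};
        double[] hidden = tansig(affine(w1, b1, x)); // tanh([0.64, 0.40])
        return affine(w2, b2, hidden);               // identity output layer
    }

    public static void main(String[] args) {
        double[] y = forward(new double[]{0.2, 0.6});
        System.out.println(y[0] + " " + y[1]);
    }
}
```

Both computed values fall within the test's 0.0001 tolerance of esperado; the second one is about 1.09939, which is why the test compares with a tolerance rather than exact equality.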

Example 17 with Layer

Use of org.gitia.froog.layer.Layer in project froog by mroodschild.

In class Test, method main.

public static void main(String[] args) {
    SimpleMatrix input = CSV.open("src/main/resources/function/train_in.csv");
    SimpleMatrix output = CSV.open("src/main/resources/function/train_out.csv");
    SimpleMatrix in_test = CSV.open("src/main/resources/function/test_in.csv");
    SimpleMatrix out_test = CSV.open("src/main/resources/function/test_out.csv");
    SimpleMatrix all_in = CSV.open("src/main/resources/function/all_in.csv");
    // Standardize: fit the statistics on the training inputs only, then apply them to every input set
    STD std = new STD();
    std.fit(input);
    input = std.eval(input);
    in_test = std.eval(in_test);
    all_in = std.eval(all_in);
    int inputSize = input.numCols();
    int outputSize = output.numCols();
    // ==================== Set up the neural network =======================
    // Fixed seed so the random weight initialization is reproducible
    Random rand = new Random(4);
    // int nn = 30;
    Feedforward net = new Feedforward();
    net.addLayer(new Layer(inputSize, 10, TransferFunction.TANSIG, rand));
    net.addLayer(new Layer(10, 5, TransferFunction.PRERELU, rand));
    net.addLayer(new Layer(5, outputSize, TransferFunction.PURELIM, rand));
    // ================= Experiment configuration ========================
    // Set up the training algorithm
    Backpropagation bp = new Backpropagation();
    bp.setEpoch(5000);
    // bp.setMomentum(0.9);
    bp.setLearningRate(0.01);
    bp.setInputTest(in_test);
    bp.setOutputTest(out_test);
    bp.setTestFrecuency(1);
    bp.setLossFunction(LossFunction.RMSE);
    input.printDimensions();
    output.printDimensions();
    // entrenar = train: runs backpropagation on the training set
    bp.entrenar(net, input, output);
    try {
        net.outputAll(all_in).saveToFileCSV("src/main/resources/function/res_train.csv");
    } catch (IOException ex) {
        Logger.getLogger(Test.class.getName()).log(Level.SEVERE, null, ex);
    }
}
Also used : Backpropagation(org.gitia.froog.trainingalgorithm.Backpropagation) SimpleMatrix(org.ejml.simple.SimpleMatrix) STD(org.gitia.jdataanalysis.data.stats.STD) Random(java.util.Random) IOException(java.io.IOException) Feedforward(org.gitia.froog.Feedforward) Layer(org.gitia.froog.layer.Layer)
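Example 17 standardizes every input set with statistics fitted on the training inputs only (std.fit(input) followed by std.eval(...)) and trains against LossFunction.RMSE. A minimal plain-Java sketch of both ideas follows. It assumes STD performs per-column z-score standardization using the population standard deviation; that is an assumption, not a description of jdataanalysis's actual implementation, and StandardizeAndRmse is a hypothetical name.

```java
import java.util.Arrays;

// Hypothetical sketch: per-column z-score standardization fitted on training data,
// plus root-mean-square error. Assumed to resemble STD and LossFunction.RMSE; not library code.
class StandardizeAndRmse {

    private double[] mean;
    private double[] std;

    // Learn per-column mean and population standard deviation from the training rows
    void fit(double[][] train) {
        int cols = train[0].length;
        mean = new double[cols];
        std = new double[cols];
        for (double[] row : train) {
            for (int j = 0; j < cols; j++) {
                mean[j] += row[j];
            }
        }
        for (int j = 0; j < cols; j++) {
            mean[j] /= train.length;
        }
        for (double[] row : train) {
            for (int j = 0; j < cols; j++) {
                double d = row[j] - mean[j];
                std[j] += d * d;
            }
        }
        for (int j = 0; j < cols; j++) {
            std[j] = Math.sqrt(std[j] / train.length);
            if (std[j] == 0.0) {
                std[j] = 1.0; // constant column: avoid division by zero
            }
        }
    }

    // Apply the training statistics to any data set (train, test, or full range)
    double[][] eval(double[][] data) {
        double[][] out = new double[data.length][];
        for (int i = 0; i < data.length; i++) {
            out[i] = new double[data[i].length];
            for (int j = 0; j < data[i].length; j++) {
                out[i][j] = (data[i][j] - mean[j]) / std[j];
            }
        }
        return out;
    }

    // Root-mean-square error over all elements of two equally sized matrices
    static double rmse(double[][] predicted, double[][] target) {
        double sum = 0.0;
        int n = 0;
        for (int i = 0; i < predicted.length; i++) {
            for (int j = 0; j < predicted[i].length; j++) {
                double d = predicted[i][j] - target[i][j];
                sum += d * d;
                n++;
            }
        }
        return Math.sqrt(sum / n);
    }

    public static void main(String[] args) {
        StandardizeAndRmse s = new StandardizeAndRmse();
        s.fit(new double[][]{{1.0}, {2.0}, {3.0}});
        System.out.println(Arrays.deepToString(s.eval(new double[][]{{4.0}})));
        System.out.println(rmse(new double[][]{{1.0, 2.0}}, new double[][]{{1.0, 4.0}}));
    }
}
```

Fitting only on the training inputs keeps test-set statistics out of the preprocessing, so the test error reported during training is not optimistically biased.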

Example 18 with Layer

Use of org.gitia.froog.layer.Layer in project froog by mroodschild.

In class Feedforward, method toString.

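@Override
public String toString() {
    // One entry per layer: index, neuron count and transfer function, tab-separated
    StringBuilder info = new StringBuilder();
    for (int i = 0; i < layers.size(); i++) {
        Layer l = layers.get(i);
        info.append("l").append(i).append(": ")
            .append(l.numNeuron()).append("\t")
            .append(l.getFunction().toString()).append("\t");
    }
    return info.toString();
}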
Also used : Layer(org.gitia.froog.layer.Layer)

Aggregations

Layer (org.gitia.froog.layer.Layer): 18
SimpleMatrix (org.ejml.simple.SimpleMatrix): 14
Test (org.junit.Test): 11
Feedforward (org.gitia.froog.Feedforward): 4
IOException (java.io.IOException): 2
Ignore (org.junit.Ignore): 2
Random (java.util.Random): 1
DocumentBuilder (javax.xml.parsers.DocumentBuilder): 1
DocumentBuilderFactory (javax.xml.parsers.DocumentBuilderFactory): 1
ParserConfigurationException (javax.xml.parsers.ParserConfigurationException): 1
Backpropagation (org.gitia.froog.trainingalgorithm.Backpropagation): 1
STD (org.gitia.jdataanalysis.data.stats.STD): 1
Document (org.w3c.dom.Document): 1
Element (org.w3c.dom.Element): 1
Node (org.w3c.dom.Node): 1
NodeList (org.w3c.dom.NodeList): 1
SAXException (org.xml.sax.SAXException): 1