Search in sources:

Example 1 with Layer

use of org.gitia.froog.layer.Layer in project froog by mroodschild.

In class Feedforward, method setPesos.

/**
 * Loads a flat parameter vector into the W (weight) and B (bias) matrices of
 * every layer, keeping the existing shape of each matrix.
 * <p>
 * Values are consumed in this order: the W matrix of every layer (layer 0,
 * layer 1, ...), then the B matrix of every layer. Each matrix receives the
 * next {@code getNumElements()} values, copied in the order of its backing
 * data array, i.e. row by row: <br>
 * <br>
 * 1 2 3 <br>
 * 4 5 6 <br>
 * 7 8 9 <br>
 * <p>
 * If no layers have been added yet, a message is printed to
 * {@code System.err} and nothing is modified.
 *
 * @param pesos matrix whose backing data holds all parameters in the order
 *              described above; values beyond what the layers need are ignored
 * @throws IllegalArgumentException if {@code pesos} holds fewer values than
 *                                  the W and B matrices of all layers need
 */
public void setPesos(SimpleMatrix pesos) {
    if (layers.isEmpty()) {
        System.err.println("Inicialice los capas primero");
        return;
    }
    double[] datos = pesos.getMatrix().getData();

    // Validate up front: ArrayUtils.subarray silently truncates at the end of
    // the source array, which would leave a matrix whose data array is shorter
    // than its declared dimensions.
    int requeridos = 0;
    for (Layer layer : layers) {
        requeridos += layer.getW().getNumElements() + layer.getB().getNumElements();
    }
    if (datos.length < requeridos) {
        throw new IllegalArgumentException("pesos has " + datos.length
                + " values but the layers need " + requeridos);
    }

    int posicion = 0;
    // All weight matrices first...
    for (Layer layer : layers) {
        int size = layer.getW().getNumElements();
        layer.getW().getMatrix().setData(ArrayUtils.subarray(datos, posicion, posicion + size));
        posicion += size;
    }
    // ...then all bias vectors, continuing from where the weights ended.
    for (Layer layer : layers) {
        int size = layer.getB().getNumElements();
        layer.getB().getMatrix().setData(ArrayUtils.subarray(datos, posicion, posicion + size));
        posicion += size;
    }
}
Also used : Layer(org.gitia.froog.layer.Layer)

Example 2 with Layer

use of org.gitia.froog.layer.Layer in project froog by mroodschild.

In class Feedforward2Test, method testLayer.

/**
 * Builds a two-layer network from (inputs, neurons) sizes and checks that
 * each layer's weight matrix has the expected number of elements
 * (4 x 2 = 8 for the first layer, 2 x 2 = 4 for the second).
 */
@Test
public void testLayer() {
    Feedforward red = new Feedforward();
    red.addLayer(new Layer(4, 2, "tansig"));
    red.addLayer(new Layer(2, 2, "purelim"));

    Layer primera = red.getLayers().get(0);
    Layer segunda = red.getLayers().get(1);

    assertEquals("Dimensiones Entrada", 8, primera.getW().getNumElements());
    assertEquals("Dimensiones Entrada", 4, segunda.getW().getNumElements());
}
Also used : Feedforward(org.gitia.froog.Feedforward) Layer(org.gitia.froog.layer.Layer) Test(org.junit.Test)

Example 3 with Layer

use of org.gitia.froog.layer.Layer in project froog by mroodschild.

In class Feedforward2Test, method testSalida2.

/**
 * Feeds a fixed input through a three-layer network with hand-chosen
 * parameters (2 inputs -> 2 tansig -> 3 tansig -> 2 purelim) and compares
 * the network output against known expected values.
 */
@Test
public void testSalida2() {
    Feedforward net = new Feedforward();
    // Layer 1: 2 inputs -> 2 neurons
    net.addLayer(new Layer(
            new SimpleMatrix(2, 2, true, 0.1, 0.2, 0.3, 0.4),
            new SimpleMatrix(2, 1, true, 0.5, 0.1),
            "tansig"));
    // Layer 2: 2 inputs -> 3 neurons
    net.addLayer(new Layer(
            new SimpleMatrix(3, 2, true, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6),
            new SimpleMatrix(3, 1, true, 0.5, 0.1, 0.2),
            "tansig"));
    // Layer 3: 3 inputs -> 2 outputs
    net.addLayer(new Layer(
            new SimpleMatrix(2, 3, true, 0.5, 0.6, 0.9, 0.7, 0.8, 1.0),
            new SimpleMatrix(2, 1, true, 0.2, 0.4),
            "purelim"));

    double[] input = {0.2, 0.6};
    double[] expected = {1.2686148090686, 1.7212904026294};
    double tolerance = 1e-12;

    double[] actual = net.output(input).getMatrix().getData();
    assertArrayEquals(expected, actual, tolerance);
}
Also used : SimpleMatrix(org.ejml.simple.SimpleMatrix) Feedforward(org.gitia.froog.Feedforward) Layer(org.gitia.froog.layer.Layer) Test(org.junit.Test)

Example 4 with Layer

use of org.gitia.froog.layer.Layer in project froog by mroodschild.

In class Backpropagation, method actualizarParametros.

/**
 * Applies one plain gradient-descent step to every layer of {@code net},
 * using the gradients accumulated in {@code deltasW} / {@code deltasB}: <br>
 * W(+1) = W - learningRate * [(dW / m) + W * regularization] <br>
 * B(+1) = B - learningRate * (dB / m) <br>
 * The bias is not regularized.
 *
 * @param m number of samples the gradients were accumulated over; the
 *          gradients are divided by it to average them
 */
protected void actualizarParametros(double m) {
    for (int idx = 0; idx < net.getLayers().size(); idx++) {
        Layer current = net.getLayers().get(idx);
        SimpleMatrix weights = current.getW();
        SimpleMatrix bias = current.getB();

        SimpleMatrix stepW = deltasW.get(idx).divide(m)
                .plus(weights.scale(regularization))
                .scale(learningRate);
        SimpleMatrix stepB = deltasB.get(idx).divide(m).scale(learningRate);

        current.setW(weights.minus(stepW));
        current.setB(bias.minus(stepB));
        net.getLayers().set(idx, current);
    }
}
Also used : SimpleMatrix(org.ejml.simple.SimpleMatrix) Layer(org.gitia.froog.layer.Layer)

Example 5 with Layer

use of org.gitia.froog.layer.Layer in project froog by mroodschild.

In class Backpropagation, method actualizarParametrosMomentum.

/**
 * Applies one gradient-descent step with momentum to every layer of
 * {@code net}, using the gradients accumulated in {@code deltasW} /
 * {@code deltasB} and the previous velocities in {@code deltasWprev} /
 * {@code deltasBprev}: <br>
 * vW = momentum * vW(prev) - learningRate * [(dW / m) + W * regularization] <br>
 * W(+1) = W + vW <br>
 * vB = momentum * vB(prev) - learningRate * (dB / m) <br>
 * B(+1) = B + vB <br>
 * The new velocities vW and vB are stored back into {@code deltasWprev} /
 * {@code deltasBprev} for the next call. The bias is not regularized.
 * <p>
 * NOTE(review): assumes {@code deltasWprev} / {@code deltasBprev} already hold
 * one matrix per layer with the same shapes as W / B (presumably zeros before
 * the first step) — confirm against the initialization code.
 *
 * @param m number of samples the gradients were accumulated over; the
 *          gradients are divided by it to average them
 */
protected void actualizarParametrosMomentum(double m) {
    for (int i = 0; i < net.getLayers().size(); i++) {
        Layer layer = net.getLayers().get(i);
        SimpleMatrix W = layer.getW();
        SimpleMatrix B = layer.getB();

        // Regularization term is taken from the weights before this update.
        SimpleMatrix reg = W.scale(regularization);
        SimpleMatrix stepW = deltasW.get(i).divide(m).plus(reg).scale(learningRate);
        SimpleMatrix stepB = deltasB.get(i).divide(m).scale(learningRate);

        // New velocity = decayed previous velocity minus the gradient step.
        SimpleMatrix vW = deltasWprev.get(i).scale(momentum).minus(stepW);
        SimpleMatrix vB = deltasBprev.get(i).scale(momentum).minus(stepB);

        layer.setW(W.plus(vW));
        layer.setB(B.plus(vB));
        net.getLayers().set(i, layer);

        // Remember the velocities for the next update.
        deltasBprev.set(i, vB);
        deltasWprev.set(i, vW);
    }
}
Also used : SimpleMatrix(org.ejml.simple.SimpleMatrix) Layer(org.gitia.froog.layer.Layer)

Aggregations

Layer (org.gitia.froog.layer.Layer)18 SimpleMatrix (org.ejml.simple.SimpleMatrix)14 Test (org.junit.Test)11 Feedforward (org.gitia.froog.Feedforward)4 IOException (java.io.IOException)2 Ignore (org.junit.Ignore)2 Random (java.util.Random)1 DocumentBuilder (javax.xml.parsers.DocumentBuilder)1 DocumentBuilderFactory (javax.xml.parsers.DocumentBuilderFactory)1 ParserConfigurationException (javax.xml.parsers.ParserConfigurationException)1 Backpropagation (org.gitia.froog.trainingalgorithm.Backpropagation)1 STD (org.gitia.jdataanalysis.data.stats.STD)1 Document (org.w3c.dom.Document)1 Element (org.w3c.dom.Element)1 Node (org.w3c.dom.Node)1 NodeList (org.w3c.dom.NodeList)1 SAXException (org.xml.sax.SAXException)1