Search in sources :

Example 1 with ConfusionMatrix

Use of org.gitia.froog.statistics.ConfusionMatrix in project froog by mroodschild.

From the class Backpropagation, method entrenar.

// Training algorithm
/**
 * Runs the training loop over every epoch and every batch, printing one
 * progress line per batch. Set the number of epochs before calling this method.
 *
 * <p>Whenever {@code iteracion} is a multiple of {@code testFrecuency} and the
 * test input holds a matrix, the loss over the complete training set and over
 * the test set is computed as well; when {@code classification} is set, the hit
 * percentages of both sets are also reported.
 *
 * @param net    network to train; stored as this trainer's network
 * @param input  training inputs; copied into this trainer
 * @param output expected outputs; copied into this trainer
 */
public void entrenar(Feedforward net, SimpleMatrix input, SimpleMatrix output) {
    this.net = net;
    this.input = new SimpleMatrix(input);
    this.output = new SimpleMatrix(output);
    inicializar();
    inicializarBatch();

    Clock timer = new Clock();
    // Values from the most recent evaluation pass; they keep their previous
    // value on iterations where no evaluation is performed.
    double trainHitRate = 0;
    double testHitRate = 0;
    double fullTrainCost = -1;

    for (int epochIdx = 0; epochIdx < this.epoch; epochIdx++) {
        for (int batchIdx = 0; batchIdx < cantidadBach; batchIdx++) {
            timer.start();
            bachDatos(batchIdx);

            // Step 1: reset the DeltaW and DeltaB accumulators
            deltasZero();

            // Step 1.1: compute and record the cost of the current batch
            costOverall = loss(net.outputAll(inputBach), outputBach);
            this.cost.add(costOverall);

            // Step 2: compute D_w and D_b
            calcularGradientes();

            // Step 3: update the parameters
            int batchRows = this.inputBach.numRows();
            if (momentum > 0) {
                actualizarParametrosMomentum(batchRows);
            } else {
                actualizarParametros(batchRows);
            }

            // Decide once whether this iteration also evaluates the full sets.
            boolean evaluate = (iteracion % testFrecuency) == 0 && inputTest.getMatrix() != null;
            if (evaluate) {
                SimpleMatrix trainPrediction = net.outputAll(this.input);
                SimpleMatrix testPrediction = net.outputAll(this.inputTest);
                fullTrainCost = loss(trainPrediction, this.output);
                costOverallTest = loss(testPrediction, outputTest);
                this.costTest.add(costOverallTest);

                if (classification) {
                    ConfusionMatrix trainConfusion = new ConfusionMatrix();
                    trainConfusion.eval(Compite.eval(trainPrediction), this.output);
                    trainHitRate = trainConfusion.getAciertosPorc();

                    ConfusionMatrix testConfusion = new ConfusionMatrix();
                    testConfusion.eval(Compite.eval(testPrediction), outputTest);
                    testHitRate = testConfusion.getAciertosPorc();
                }
            }
            timer.stop();

            // Assemble the progress line: common prefix first, then the
            // evaluation details (if any), then the elapsed time.
            String line = "It:\t" + iteracion + "\tTrain:\t" + costOverall;
            if (!evaluate) {
                line += "\tTime:\t" + timer.timeSec() + " s.";
            } else {
                line += "\tTrain Complete:\t" + fullTrainCost + "\tTest:\t" + costOverallTest;
                if (classification) {
                    line += "\tTrain Aciertos:\t" + trainHitRate + "\t%." + "\tTest Aciertos:\t" + testHitRate + "\t%.";
                }
                line += "\tTime:\t" + timer.timeSec() + "\ts.";
            }
            System.out.println(line);

            iteracion++;
        }
    }
}
Also used : SimpleMatrix(org.ejml.simple.SimpleMatrix) ConfusionMatrix(org.gitia.froog.statistics.ConfusionMatrix) Clock(org.gitia.froog.statistics.Clock)

Aggregations

SimpleMatrix (org.ejml.simple.SimpleMatrix)1 Clock (org.gitia.froog.statistics.Clock)1 ConfusionMatrix (org.gitia.froog.statistics.ConfusionMatrix)1