Usage of org.gitia.froog.statistics.ConfusionMatrix in the froog project by mroodschild.
Found in class Backpropagation, method entrenar.
// Training algorithm
/**
 * Trains {@code net} with mini-batch backpropagation.
 *
 * <p>The number of epochs ({@code epoch}) must be configured before calling this method.
 * Each epoch walks over {@code cantidadBach} mini-batches. For each mini-batch it:
 * <ol>
 *   <li>selects the batch via {@code bachDatos};</li>
 *   <li>resets the deltas;</li>
 *   <li>appends the batch cost to {@code cost};</li>
 *   <li>computes the gradients;</li>
 *   <li>updates the parameters, using the momentum variant when {@code momentum > 0}.</li>
 * </ol>
 *
 * <p>Every {@code testFrecuency} iterations, provided {@code inputTest} holds a matrix, the
 * method also evaluates the network. It computes the cost over the whole training set and
 * over the test set, and appends the test cost to {@code costTest}. For classification
 * problems it also computes the hit percentage of both sets. A progress line is printed to
 * standard output after every mini-batch.
 *
 * @param net the network to train; stored in {@code this.net}
 * @param input training inputs; a copy is stored in {@code this.input}
 * @param output expected outputs for {@code input}; a copy is stored in {@code this.output}
 */
public void entrenar(Feedforward net, SimpleMatrix input, SimpleMatrix output) {
    this.net = net;
    this.input = new SimpleMatrix(input);
    this.output = new SimpleMatrix(output);
    double trainHits = 0;
    double testHits = 0;
    double fullTrainCost = -1;
    inicializar();
    inicializarBatch();
    Clock clock = new Clock();
    for (int epochIdx = 0; epochIdx < this.epoch; epochIdx++) {
        for (int batchIdx = 0; batchIdx < cantidadBach; batchIdx++) {
            // The timed region covers the update step plus any evaluation below.
            clock.start();
            bachDatos(batchIdx);
            // Step 1: reset the accumulated weight and bias deltas.
            deltasZero();
            // Step 1.1: cost of the current mini-batch, kept in the cost history.
            costOverall = loss(net.outputAll(inputBach), outputBach);
            this.cost.add(costOverall);
            // Step 2: compute the gradients.
            calcularGradientes();
            // Step 3: update the parameters, passing the mini-batch row count.
            if (momentum > 0) {
                actualizarParametrosMomentum(this.inputBach.numRows());
            } else {
                actualizarParametros(this.inputBach.numRows());
            }
            // Decided once so that the evaluation and the progress line always agree.
            boolean evaluate = (iteracion % testFrecuency) == 0 && inputTest.getMatrix() != null;
            if (evaluate) {
                SimpleMatrix trainPrediction = net.outputAll(this.input);
                SimpleMatrix testPrediction = net.outputAll(this.inputTest);
                fullTrainCost = loss(trainPrediction, this.output);
                costOverallTest = loss(testPrediction, outputTest);
                this.costTest.add(costOverallTest);
                if (classification) {
                    trainHits = hitPercentage(trainPrediction, this.output);
                    testHits = hitPercentage(testPrediction, outputTest);
                }
            }
            clock.stop();
            StringBuilder report = new StringBuilder("It:\t")
                    .append(iteracion)
                    .append("\tTrain:\t")
                    .append(costOverall);
            if (!evaluate) {
                report.append("\tTime:\t").append(clock.timeSec()).append(" s.");
            } else {
                report.append("\tTrain Complete:\t").append(fullTrainCost)
                        .append("\tTest:\t").append(costOverallTest);
                if (classification) {
                    report.append("\tTrain Aciertos:\t").append(trainHits).append("\t%.")
                            .append("\tTest Aciertos:\t").append(testHits).append("\t%.");
                }
                report.append("\tTime:\t").append(clock.timeSec()).append("\ts.");
            }
            System.out.println(report.toString());
            iteracion++;
        }
    }
}

/**
 * Returns the hit percentage that a fresh {@link ConfusionMatrix} reports for
 * {@code predicted}, after passing it through {@code Compite.eval}, compared against
 * {@code expected}.
 */
private double hitPercentage(SimpleMatrix predicted, SimpleMatrix expected) {
    ConfusionMatrix matrix = new ConfusionMatrix();
    matrix.eval(Compite.eval(predicted), expected);
    return matrix.getAciertosPorc();
}
Aggregations