Use of org.gitia.froog.trainingalgorithm.Backpropagation in project froog by mroodschild.
Class BackpropagationBachTest, method testEntrenar.
/**
 * Prototype test for the {@code entrenar} method of class {@code Backpropagation}.
 *
 * <p>Configures a trainer for zero epochs and invokes it with a {@code null}
 * network and {@code null} input/output data, then calls {@code fail}. It is
 * disabled via {@code @Ignore} until real fixtures are written.
 */
@Ignore
@Test
public void testEntrenar() {
    System.out.println("entrenar");
    Feedforward network = null;
    double[][] samplesIn = null;
    double[][] samplesOut = null;
    Backpropagation trainer = new Backpropagation();
    trainer.setEpoch(0);
    trainer.entrenar(network, samplesIn, samplesOut);
    // Placeholder left by the test generator: remove once the test is reviewed.
    fail("The test case is a prototype.");
}
Use of org.gitia.froog.trainingalgorithm.Backpropagation in project froog by mroodschild.
Class BackpropagationBachTest, method testComputeCost_doubleArr_doubleArr.
/**
 * Prototype test for the {@code cost} overload of class {@code Backpropagation}
 * that takes two {@code double[]} arguments.
 *
 * <p>Passes {@code null} arrays, expects a cost of exactly {@code 0.0}, then
 * calls {@code fail}. It is disabled via {@code @Ignore} until real fixtures
 * are written.
 */
@Ignore
@Test
public void testComputeCost_doubleArr_doubleArr() {
    System.out.println("computeCost");
    double[] predicted = null;
    double[] observed = null;
    double actual = new Backpropagation().cost(predicted, observed);
    assertEquals(0.0, actual, 0.0);
    // Placeholder left by the test generator: remove once the test is reviewed.
    fail("The test case is a prototype.");
}
Use of org.gitia.froog.trainingalgorithm.Backpropagation in project froog by mroodschild.
Class BackpropagationBachTest, method testComputeCost_SimpleMatrix_SimpleMatrix.
// /**
// * Test of calcularGradientes method, of class BackpropagationBach.
// */
// @Ignore
// @Test
// public void testCalcularGradientes() {
// System.out.println("calcularGradientes");
// Backpropagation instance = new Backpropagation();
// double expResult = 0.0;
// double result = instance.calcularGradientes();
// assertEquals(expResult, result, 0.0);
// // TODO review the generated test code and remove the default call to fail.
// fail("The test case is a prototype.");
// }
/**
 * Prototype test for the {@code cost} overload of class {@code Backpropagation}
 * that takes two {@code SimpleMatrix} arguments.
 *
 * <p>Passes {@code null} matrices, expects a cost of exactly {@code 0.0}, then
 * calls {@code fail}. It is disabled via {@code @Ignore} until real fixtures
 * are written.
 */
@Ignore
@Test
public void testComputeCost_SimpleMatrix_SimpleMatrix() {
    System.out.println("computeCost");
    SimpleMatrix predicted = null;
    SimpleMatrix observed = null;
    double actual = new Backpropagation().cost(predicted, observed);
    assertEquals(0.0, actual, 0.0);
    // Placeholder left by the test generator: remove once the test is reviewed.
    fail("The test case is a prototype.");
}
Use of org.gitia.froog.trainingalgorithm.Backpropagation in project froog by mroodschild.
Class Test, method main.
/**
 * Trains a small feedforward network on the CSV function-approximation data
 * set. It then writes the network's output for every row of {@code all_in.csv}
 * to {@code res_train.csv} in the same directory.
 *
 * <p>Inputs are standardized with an {@code STD} fitted on the training inputs
 * only. The same transform is then applied to the test inputs and to the full
 * input set. Target columns are used unscaled.
 *
 * @param args optional; {@code args[0]} overrides the data directory
 *             (default {@code src/main/resources/function/})
 */
public static void main(String[] args) {
    // Directory holding the train/test/all CSV files; the result file is written here too.
    String dataDir = args.length > 0 ? args[0] : "src/main/resources/function/";
    if (!dataDir.endsWith("/")) {
        dataDir = dataDir + "/";
    }

    SimpleMatrix input = CSV.open(dataDir + "train_in.csv");
    SimpleMatrix output = CSV.open(dataDir + "train_out.csv");
    SimpleMatrix in_test = CSV.open(dataDir + "test_in.csv");
    SimpleMatrix out_test = CSV.open(dataDir + "test_out.csv");
    SimpleMatrix all_in = CSV.open(dataDir + "all_in.csv");

    // Fit the scaler on the training inputs only, then apply it everywhere.
    STD std = new STD();
    std.fit(input);
    input = std.eval(input);
    in_test = std.eval(in_test);
    all_in = std.eval(all_in);

    int inputSize = input.numCols();
    int outputSize = output.numCols();

    // ==================== Build the neural network =======================
    // Fixed seed so repeated runs build the same network (presumably the
    // Layer constructors use it for weight initialization).
    Random rand = new Random(4);
    Feedforward net = new Feedforward();
    net.addLayer(new Layer(inputSize, 10, TransferFunction.TANSIG, rand));
    net.addLayer(new Layer(10, 5, TransferFunction.PRERELU, rand));
    net.addLayer(new Layer(5, outputSize, TransferFunction.PURELIM, rand));

    // ================= Training configuration ========================
    // Momentum is intentionally left at the trainer's default for this run.
    Backpropagation bp = new Backpropagation();
    bp.setEpoch(5000);
    bp.setLearningRate(0.01);
    bp.setInputTest(in_test);
    bp.setOutputTest(out_test);
    bp.setTestFrecuency(1);
    bp.setLossFunction(LossFunction.RMSE);

    input.printDimensions();
    output.printDimensions();
    bp.entrenar(net, input, output);

    String resultPath = dataDir + "res_train.csv";
    try {
        net.outputAll(all_in).saveToFileCSV(resultPath);
    } catch (IOException ex) {
        // Include the target path so the failure is diagnosable from the log alone.
        Logger.getLogger(Test.class.getName())
                .log(Level.SEVERE, "Could not write results to " + resultPath, ex);
    }
}
Aggregations