Usage of org.gitia.froog.Feedforward in the froog project by mroodschild.
Class Test, method main.
/**
 * Trains a small feed-forward network on the CSV data set under
 * {@code src/main/resources/function/} and writes the network's outputs for
 * {@code all_in.csv} to {@code res_train.csv} in the same directory.
 *
 * <p>Input features are standardized with a {@code STD} fitted on the training
 * inputs only; that same transform is then applied to the test inputs and to
 * {@code all_in.csv}. Target columns are used exactly as loaded.
 *
 * @param args unused
 */
public static void main(String[] args) {
    final String dataDir = "src/main/resources/function/";

    SimpleMatrix input = CSV.open(dataDir + "train_in.csv");
    SimpleMatrix output = CSV.open(dataDir + "train_out.csv");
    SimpleMatrix testIn = CSV.open(dataDir + "test_in.csv");
    SimpleMatrix testOut = CSV.open(dataDir + "test_out.csv");
    SimpleMatrix allIn = CSV.open(dataDir + "all_in.csv");

    // Fit the scaler on the training inputs only, then reuse it everywhere so
    // test / full-range inputs are transformed with the training statistics.
    STD std = new STD();
    std.fit(input);
    input = std.eval(input);
    testIn = std.eval(testIn);
    allIn = std.eval(allIn);

    int inputSize = input.numCols();
    int outputSize = output.numCols();

    // ==================== Network setup ====================
    // Fixed seed; the same Random is handed to every Layer (presumably for
    // weight initialization), so each run draws the same random sequence.
    Random rand = new Random(4);
    Feedforward net = new Feedforward();
    net.addLayer(new Layer(inputSize, 10, TransferFunction.TANSIG, rand));
    net.addLayer(new Layer(10, 5, TransferFunction.PRERELU, rand));
    net.addLayer(new Layer(5, outputSize, TransferFunction.PURELIM, rand));

    // ================ Training configuration ================
    Backpropagation bp = new Backpropagation();
    bp.setEpoch(5000);
    // Momentum is not set here, so whatever default Backpropagation uses applies.
    bp.setLearningRate(0.01);
    bp.setInputTest(testIn);
    bp.setOutputTest(testOut);
    // NOTE(review): assumed to mean "evaluate on the test set every epoch" — confirm.
    bp.setTestFrecuency(1);
    bp.setLossFunction(LossFunction.RMSE);

    input.printDimensions();
    output.printDimensions();
    bp.entrenar(net, input, output); // "entrenar" = train

    String resultPath = dataDir + "res_train.csv";
    try {
        net.outputAll(allIn).saveToFileCSV(resultPath);
    } catch (IOException ex) {
        // Training has already finished at this point; report which file
        // could not be written rather than logging an empty message.
        Logger.getLogger(Test.class.getName())
                .log(Level.SEVERE, "Could not save network outputs to " + resultPath, ex);
    }
}
Aggregations