Usage of org.encog.neural.networks.training.propagation.Propagation in the project shifu by ShifuML.
In class DTrainTest, method quickTest:
/**
 * Checks that distributed QuickPropagation (per-split {@link Gradient} workers whose gradients are
 * aggregated by a master-side {@link Weight} update) tracks Encog's multi-threaded
 * {@link QuickPropagation} on the same training data.
 *
 * <p>Both runs start from the same initial weights. For each of {@code NUM_EPOCHS} epochs the
 * distributed run records the unweighted mean of the workers' errors, while the Encog run records
 * {@link Propagation#getError()} after one iteration. The test passes when the mean absolute
 * difference between the two error curves is below 0.1.
 *
 * @throws IOException kept from the original test signature
 */
@Test
public void quickTest() throws IOException {
    double[] gradientError = new double[NUM_EPOCHS];
    double[] encogError = new double[NUM_EPOCHS];
    network.reset();
    weights = network.getFlat().getWeights();
    // Snapshot the starting point. In Encog 3, getFlat().getWeights() returns the network's live
    // array, so the network.reset() before the Encog run (and possibly the distributed updates)
    // can overwrite it in place; without a copy the two runs would not start from the same weights.
    double[] initialWeights = weights.clone();

    MLDataSet[] subsets = splitDataSet(training);
    Gradient[] workers = new Gradient[numSplit];
    for (int i = 0; i < workers.length; i++) {
        workers[i] = initGradient(subsets[i]);
        workers[i].setWeights(weights);
    }
    log.info("Running QuickPropagation testing! ");

    NNParams globalParams = new NNParams();
    globalParams.setWeights(weights);
    // Created lazily: it needs the gradient count and total train size, which are only known
    // after the first round of aggregation.
    Weight weightCalculator = null;
    for (int i = 0; i < NUM_EPOCHS; i++) {
        // Each worker computes gradients and error on its own split.
        double error = 0.0;
        for (int j = 0; j < workers.length; j++) {
            workers[j].run();
            error += workers[j].getError();
        }
        gradientError[i] = error / workers.length;
        log.info("The #" + i + " training error: " + gradientError[i]);

        // Master: aggregate gradients and training sizes, then compute the new weights.
        globalParams.reset();
        for (int j = 0; j < workers.length; j++) {
            globalParams.accumulateGradients(workers[j].getGradients());
            globalParams.accumulateTrainSize(subsets[j].getRecordCount());
        }
        if (weightCalculator == null) {
            weightCalculator = new Weight(globalParams.getGradients().length, globalParams.getTrainSize(), this.rate, DTrainUtils.QUICK_PROPAGATION, 0, RegulationLevel.NONE);
        }
        double[] interWeight = weightCalculator.calculateWeights(globalParams.getWeights(), globalParams.getGradients(), -1);
        globalParams.setWeights(interWeight);

        // Broadcast the updated weights back to every worker.
        for (int j = 0; j < workers.length; j++) {
            workers[j].setWeights(interWeight);
        }
    }

    // Encog reference run, restarted from the same initial weights as the distributed run.
    network.reset();
    network.getFlat().setWeights(initialWeights);
    Propagation p = new QuickPropagation(network, training, rate);
    p.setThreadCount(numSplit);
    for (int i = 0; i < NUM_EPOCHS; i++) {
        p.iteration(1);
        encogError[i] = p.getError();
    }

    // Compare the two error curves epoch by epoch.
    double diff = 0.0;
    for (int i = 0; i < NUM_EPOCHS; i++) {
        diff += Math.abs(encogError[i] - gradientError[i]);
    }
    double meanDiff = diff / NUM_EPOCHS;
    log.info("Mean absolute error difference between distributed and Encog training: " + meanDiff);
    Assert.assertTrue(meanDiff < 0.1);
}
Usage of org.encog.neural.networks.training.propagation.Propagation in the project shifu by ShifuML.
In class DTrainTest, method backTest:
/**
 * Checks that distributed back propagation (per-split {@link Gradient} workers whose gradients are
 * aggregated by a master-side {@link Weight} update) tracks Encog's multi-threaded
 * {@link Backpropagation} on the same training data.
 *
 * <p>Both runs start from the same initial weights. For each of {@code NUM_EPOCHS} epochs the
 * distributed run records the unweighted mean of the workers' errors, while the Encog run records
 * {@link Propagation#getError()} after one iteration. The test passes when the mean absolute
 * difference between the two error curves is below 0.2.
 */
@Test
public void backTest() {
    double[] gradientError = new double[NUM_EPOCHS];
    double[] encogError = new double[NUM_EPOCHS];
    network.reset();
    weights = network.getFlat().getWeights();
    // Snapshot the starting point. In Encog 3, getFlat().getWeights() returns the network's live
    // array, so the network.reset() before the Encog run (and possibly the distributed updates)
    // can overwrite it in place; without a copy the two runs would not start from the same weights.
    double[] initialWeights = weights.clone();

    MLDataSet[] subsets = splitDataSet(training);
    Gradient[] workers = new Gradient[numSplit];
    for (int i = 0; i < workers.length; i++) {
        workers[i] = initGradient(subsets[i]);
        workers[i].setWeights(weights);
    }
    log.info("Starting back propagation testing!");

    NNParams globalParams = new NNParams();
    globalParams.setWeights(weights);
    // Created lazily: it needs the gradient count and total train size, which are only known
    // after the first round of aggregation.
    Weight weightCalculator = null;
    for (int i = 0; i < NUM_EPOCHS; i++) {
        // Each worker computes gradients and error on its own split.
        double error = 0.0;
        for (int j = 0; j < workers.length; j++) {
            workers[j].run();
            error += workers[j].getError();
        }
        gradientError[i] = error / workers.length;
        log.info("The #" + i + " training error: " + gradientError[i]);

        // Master: aggregate gradients and training sizes, then compute the new weights.
        globalParams.reset();
        for (int j = 0; j < workers.length; j++) {
            globalParams.accumulateGradients(workers[j].getGradients());
            globalParams.accumulateTrainSize(subsets[j].getRecordCount());
        }
        if (weightCalculator == null) {
            weightCalculator = new Weight(globalParams.getGradients().length, globalParams.getTrainSize(), this.rate, DTrainUtils.BACK_PROPAGATION, 0, RegulationLevel.NONE);
        }
        double[] interWeight = weightCalculator.calculateWeights(globalParams.getWeights(), globalParams.getGradients(), -1);
        globalParams.setWeights(interWeight);

        // Broadcast the updated weights back to every worker.
        for (int j = 0; j < workers.length; j++) {
            workers[j].setWeights(interWeight);
        }
    }

    // Encog reference run, restarted from the same initial weights as the distributed run.
    network.reset();
    network.getFlat().setWeights(initialWeights);
    // NOTE(review): Encog is given momentum 0.5 here; confirm whether the Shifu BACK_PROPAGATION
    // update applies a matching momentum term (this may explain the looser 0.2 tolerance).
    Propagation p = new Backpropagation(network, training, rate, 0.5);
    p.setThreadCount(numSplit);
    for (int i = 0; i < NUM_EPOCHS; i++) {
        p.iteration(1);
        encogError[i] = p.getError();
    }

    // Compare the two error curves epoch by epoch.
    double diff = 0.0;
    for (int i = 0; i < NUM_EPOCHS; i++) {
        diff += Math.abs(encogError[i] - gradientError[i]);
    }
    double meanDiff = diff / NUM_EPOCHS;
    log.info("Mean absolute error difference between distributed and Encog training: " + meanDiff);
    Assert.assertTrue(meanDiff < 0.2);
}
Aggregations