Use of org.encog.neural.networks.training.propagation.resilient.ResilientPropagation in project shifu by ShifuML.
The class DTrainTest, method resilientPropagationTest:
    @Test
    public void resilientPropagationTest() {
        double[] gradientError = new double[NUM_EPOCHS];
        double[] ecogError = new double[NUM_EPOCHS];
        network.reset();
        weights = network.getFlat().getWeights();

        MLDataSet[] subsets = splitDataSet(training);
        Gradient[] workers = new Gradient[numSplit];
        Weight weightCalculator = null;
        for (int i = 0; i < workers.length; i++) {
            workers[i] = initGradient(subsets[i]);
            workers[i].setWeights(weights);
        }

        log.info("Starting resilient propagation testing!");
        NNParams globalParams = new NNParams();
        globalParams.setWeights(weights);

        for (int i = 0; i < NUM_EPOCHS; i++) {
            double error = 0.0;
            // each worker computes gradients and error on its own subset
            for (int j = 0; j < workers.length; j++) {
                workers[j].run();
                error += workers[j].getError();
            }
            gradientError[i] = error / workers.length;
            log.info("The #" + i + " training error: " + gradientError[i]);

            // master: accumulate gradients and record counts from all workers
            globalParams.reset();
            for (int j = 0; j < workers.length; j++) {
                globalParams.accumulateGradients(workers[j].getGradients());
                globalParams.accumulateTrainSize(subsets[j].getRecordCount());
            }
            if (weightCalculator == null) {
                weightCalculator = new Weight(globalParams.getGradients().length, globalParams.getTrainSize(),
                        this.rate, DTrainUtils.RESILIENTPROPAGATION, 0, RegulationLevel.NONE);
            }
            double[] interWeight = weightCalculator.calculateWeights(globalParams.getWeights(),
                    globalParams.getGradients(), -1);
            globalParams.setWeights(interWeight);

            // broadcast the updated weights back to the workers
            for (int j = 0; j < workers.length; j++) {
                workers[j].setWeights(interWeight);
            }
        }

        // encog: train the same network with Encog's built-in ResilientPropagation
        network.reset();
        // NNUtils.randomize(numSplit, weights);
        network.getFlat().setWeights(weights);
        Propagation p = new ResilientPropagation(network, training);
        p.setThreadCount(numSplit);
        for (int i = 0; i < NUM_EPOCHS; i++) {
            p.iteration(1);
            ecogError[i] = p.getError();
        }

        // assert: the per-epoch errors of the two trainers should stay close on average
        double diff = 0.0;
        for (int i = 0; i < NUM_EPOCHS; i++) {
            diff += Math.abs(ecogError[i] - gradientError[i]);
        }
        Assert.assertTrue(diff / NUM_EPOCHS < 0.2);
    }
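The test depends on a splitDataSet helper whose body is not shown here. A minimal sketch of what such a partitioner could look like, assuming numSplit is the field used above and a simple round-robin split over Encog's MLDataSet; the method body is illustrative, not the project's actual implementation:

    // Hypothetical round-robin partitioner; the real splitDataSet in shifu may differ.
    // Uses org.encog.ml.data.MLDataSet, org.encog.ml.data.MLDataPair,
    // and org.encog.ml.data.basic.BasicMLDataSet.
    private MLDataSet[] splitDataSet(MLDataSet data) {
        MLDataSet[] subsets = new MLDataSet[numSplit];
        for (int i = 0; i < numSplit; i++) {
            subsets[i] = new BasicMLDataSet();
        }
        int index = 0;
        for (MLDataPair pair : data) {
            // distribute records evenly across the numSplit subsets
            subsets[index++ % numSplit].add(pair);
        }
        return subsets;
    }

With a split like this, each Gradient worker sees a disjoint share of the records, which is why the master sums getRecordCount() over the subsets to recover the total training size.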
Use of org.encog.neural.networks.training.propagation.resilient.ResilientPropagation in project shifu by ShifuML.
The class NNTrainer, method getMLTrain:
    private Propagation getMLTrain() {
        // String alg = this.modelConfig.getLearningAlgorithm();
        String alg = (String) modelConfig.getParams().get(CommonConstants.PROPAGATION);
        if (!defaultLearningRate.containsKey(alg)) {
            throw new RuntimeException("Learning algorithm is invalid: " + alg);
        }

        // Double rate = this.modelConfig.getLearningRate();
        double rate = defaultLearningRate.get(alg);
        Object rateObj = modelConfig.getParams().get(CommonConstants.LEARNING_RATE);
        if (rateObj instanceof Double) {
            rate = (Double) rateObj;
        } else if (rateObj instanceof Integer) {
            // the user may have set the learning rate as an integer in the config
            rate = ((Integer) rateObj).doubleValue();
        } else if (rateObj instanceof Float) {
            rate = ((Float) rateObj).doubleValue();
        }

        if (toLoggingProcess) {
            LOG.info(" - Learning Algorithm: " + learningAlgMap.get(alg));
        }
        // only Quick, Back and Manhattan propagation use the learning rate
        if (alg.equals("Q") || alg.equals("B") || alg.equals("M")) {
            if (toLoggingProcess) {
                LOG.info(" - Learning Rate: " + rate);
            }
        }

        if (alg.equals("B")) {
            return new Backpropagation(network, trainSet, rate, 0);
        } else if (alg.equals("Q")) {
            return new QuickPropagation(network, trainSet, rate);
        } else if (alg.equals("M")) {
            return new ManhattanPropagation(network, trainSet, rate);
        } else if (alg.equals("R")) {
            return new ResilientPropagation(network, trainSet);
        } else if (alg.equals("S")) {
            return new ScaledConjugateGradient(network, trainSet);
        } else {
            return null;
        }
    }
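getMLTrain only constructs the trainer; iterating it is left to the caller. A minimal sketch of how a returned Propagation instance is typically driven in Encog, with an illustrative epoch count:

    // Hypothetical caller; the epoch count and logging are illustrative.
    Propagation trainer = getMLTrain();
    for (int epoch = 0; epoch < 100; epoch++) {
        trainer.iteration();
        LOG.info("Epoch #" + epoch + " error: " + trainer.getError());
    }
    // release any resources held by the trainer once training is done
    trainer.finishTraining();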