Use of org.encog.neural.networks.training.propagation.quick.QuickPropagation in project shifu by ShifuML.
The class LogisticRegressionTrainer, method train:
/**
 * {@inheritDoc}
 * <p>
 * No <code>regularization</code> is applied yet; regularization support will be provided later.
 *
 * @throws IOException
 *             if an I/O error occurs while saving the model
 */
@Override
public double train() throws IOException {
    classifier = new BasicNetwork();
    classifier.addLayer(new BasicLayer(new ActivationLinear(), true, trainSet.getInputSize()));
    classifier.addLayer(new BasicLayer(new ActivationSigmoid(), false, trainSet.getIdealSize()));
    classifier.getStructure().finalizeStructure();
    classifier.reset();

    Double learningRate = (Double) modelConfig.getParams().get("LearningRate");
    Propagation propagation = new QuickPropagation(classifier, trainSet, learningRate);

    int epochs = modelConfig.getNumTrainEpochs();
    // Get the convergence threshold from modelConfig; default to 0.0, which never triggers early stopping.
    double threshold = modelConfig.getTrain().getConvergenceThreshold() == null
            ? 0.0
            : modelConfig.getTrain().getConvergenceThreshold().doubleValue();
    String formattedThreshold = df.format(threshold);
    LOG.info("Using learning rate {}", learningRate);

    for (int i = 0; i < epochs; i++) {
        propagation.iteration();
        double trainError = propagation.getError();
        double validError = classifier.calculateError(this.validSet);
        LOG.info("Epoch #{} Train Error: {} Validation Error: {}", (i + 1), df.format(trainError), df.format(validError));
        // Convergence check: stop once the average of train and validation error drops below the threshold.
        double avgErr = (trainError + validError) / 2;
        if (judger.judge(avgErr, threshold)) {
            LOG.info("Trainer-{}> Epoch #{} converged! Average Error: {}, Threshold: {}",
                    trainerID, (i + 1), df.format(avgErr), formattedThreshold);
            break;
        }
    }

    propagation.finishTraining();
    LOG.info("#{} finished training", this.trainerID);
    saveLR();
    return 0.0d;
}
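For reference, here is a minimal, self-contained sketch of the same QuickPropagation pattern against a toy in-memory dataset. The dataset values, layer sizes, epoch count, and learning rate below are illustrative assumptions, not values from shifu:

import org.encog.engine.network.activation.ActivationLinear;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.training.propagation.Propagation;
import org.encog.neural.networks.training.propagation.quick.QuickPropagation;

public class QuickPropagationSketch {
    public static void main(String[] args) {
        // Toy two-feature dataset; values are illustrative only.
        double[][] input = { { 0.2, 0.7 }, { 0.9, 0.1 }, { 0.4, 0.4 } };
        double[][] ideal = { { 1.0 }, { 0.0 }, { 0.0 } };
        MLDataSet trainSet = new BasicMLDataSet(input, ideal);

        // Same topology as LogisticRegressionTrainer: linear input layer with bias,
        // sigmoid output layer, no hidden layer.
        BasicNetwork network = new BasicNetwork();
        network.addLayer(new BasicLayer(new ActivationLinear(), true, 2));
        network.addLayer(new BasicLayer(new ActivationSigmoid(), false, 1));
        network.getStructure().finalizeStructure();
        network.reset();

        // The third constructor argument is the learning rate.
        Propagation train = new QuickPropagation(network, trainSet, 0.1);
        for (int epoch = 0; epoch < 100; epoch++) {
            train.iteration();
            System.out.println("Epoch " + (epoch + 1) + " error: " + train.getError());
        }
        train.finishTraining();
    }
}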
Use of org.encog.neural.networks.training.propagation.quick.QuickPropagation in project shifu by ShifuML.
The class NNTrainerTest, method testAndOperation:
@Test
public void testAndOperation() throws IOException {
    // Training data for the logical AND operation: output is 1.0 only for input (1.0, 1.0).
    MLDataPair dataPair0 = BasicMLDataPair.createPair(2, 1);
    dataPair0.setInputArray(new double[] { 0.0, 0.0 });
    dataPair0.setIdealArray(new double[] { 0.0 });
    trainSet.add(dataPair0);

    MLDataPair dataPair1 = BasicMLDataPair.createPair(2, 1);
    dataPair1.setInputArray(new double[] { 0.0, 1.0 });
    dataPair1.setIdealArray(new double[] { 0.0 });
    trainSet.add(dataPair1);

    MLDataPair dataPair2 = BasicMLDataPair.createPair(2, 1);
    dataPair2.setInputArray(new double[] { 1.0, 0.0 });
    dataPair2.setIdealArray(new double[] { 0.0 });
    trainSet.add(dataPair2);

    MLDataPair dataPair3 = BasicMLDataPair.createPair(2, 1);
    dataPair3.setInputArray(new double[] { 1.0, 1.0 });
    dataPair3.setIdealArray(new double[] { 1.0 });
    trainSet.add(dataPair3);

    Propagation propagation = new QuickPropagation(network, trainSet, 0.1);
    double error = 0.0;
    double lastError = Double.MAX_VALUE;
    int iterCnt = 0;
    // Train until the error changes by no more than 0.001 between iterations.
    do {
        propagation.iteration();
        lastError = error;
        error = propagation.getError();
        System.out.println("The #" + (++iterCnt) + " error is " + error);
    } while (Math.abs(lastError - error) > 0.001);
    propagation.finishTraining();

    // Persist the trained network and verify the model file was written.
    File tmp = new File("model_folder");
    if (!tmp.exists()) {
        FileUtils.forceMkdir(tmp);
    }
    File modelFile = new File("model_folder/model6.nn");
    EncogDirectoryPersistence.saveObject(modelFile, network);
    Assert.assertTrue(modelFile.exists());
    FileUtils.deleteQuietly(modelFile);
}
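Reading a persisted model back takes a single EncogDirectoryPersistence call. A minimal sketch, assuming the same model_folder/model6.nn path as the test above (it must run before the test deletes the file):

import java.io.File;
import org.encog.ml.data.MLData;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.neural.networks.BasicNetwork;
import org.encog.persist.EncogDirectoryPersistence;

public class LoadModelSketch {
    public static void main(String[] args) {
        // Reload the persisted network and query it for AND(1, 1).
        File modelFile = new File("model_folder/model6.nn");
        BasicNetwork loaded = (BasicNetwork) EncogDirectoryPersistence.loadObject(modelFile);
        MLData output = loaded.compute(new BasicMLData(new double[] { 1.0, 1.0 }));
        System.out.println("AND(1,1) ~ " + output.getData(0)); // should be close to 1.0
    }
}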
Use of org.encog.neural.networks.training.propagation.quick.QuickPropagation in project shifu by ShifuML.
The class DTrainTest, method quickTest:
@Test
public void quickTest() throws IOException {
    double[] gradientError = new double[NUM_EPOCHS];
    double[] encogError = new double[NUM_EPOCHS];
    network.reset();
    weights = network.getFlat().getWeights();

    MLDataSet[] subsets = splitDataSet(training);
    Gradient[] workers = new Gradient[numSplit];
    Weight weightCalculator = null;
    for (int i = 0; i < workers.length; i++) {
        workers[i] = initGradient(subsets[i]);
        workers[i].setWeights(weights);
    }

    log.info("Running QuickPropagation test!");
    NNParams globalParams = new NNParams();
    globalParams.setWeights(weights);

    for (int i = 0; i < NUM_EPOCHS; i++) {
        double error = 0.0;
        // Each worker computes gradients and error on its own data split.
        for (int j = 0; j < workers.length; j++) {
            workers[j].run();
            error += workers[j].getError();
        }
        gradientError[i] = error / workers.length;
        log.info("The #" + i + " training error: " + gradientError[i]);

        // Master step: accumulate the workers' gradients and training sizes.
        globalParams.reset();
        for (int j = 0; j < workers.length; j++) {
            globalParams.accumulateGradients(workers[j].getGradients());
            globalParams.accumulateTrainSize(subsets[j].getRecordCount());
        }
        if (weightCalculator == null) {
            weightCalculator = new Weight(globalParams.getGradients().length, globalParams.getTrainSize(),
                    this.rate, DTrainUtils.QUICK_PROPAGATION, 0, RegulationLevel.NONE);
        }
        double[] interWeight = weightCalculator.calculateWeights(globalParams.getWeights(),
                globalParams.getGradients(), -1);
        globalParams.setWeights(interWeight);

        // Broadcast the updated weights back to every worker.
        for (int j = 0; j < workers.length; j++) {
            workers[j].setWeights(interWeight);
        }
    }

    // Baseline: train the same network with Encog's built-in QuickPropagation.
    network.reset();
    network.getFlat().setWeights(weights);
    Propagation p = new QuickPropagation(network, training, rate);
    p.setThreadCount(numSplit);
    for (int i = 0; i < NUM_EPOCHS; i++) {
        p.iteration(1);
        encogError[i] = p.getError();
    }

    // Assert the distributed and single-process error curves stay close on average.
    double diff = 0.0;
    for (int i = 0; i < NUM_EPOCHS; i++) {
        diff += Math.abs(encogError[i] - gradientError[i]);
    }
    Assert.assertTrue(diff / NUM_EPOCHS < 0.1);
}
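The test mirrors shifu's distributed training loop: every Gradient worker computes gradients on its own data split, the master accumulates them and recomputes the shared weights, and the per-epoch errors are compared against single-process Encog QuickPropagation. The stripped-down sketch below shows only the master-side aggregate-and-update step; aggregateAndStep is hypothetical, the plain gradient step stands in for shifu's Weight class (which applies the QuickPropagation update rule), and the sign convention is an assumption:

import java.util.Arrays;

public class MasterStepSketch {
    // Hypothetical stand-in for the master step in quickTest: sum the per-worker
    // gradients, then take one plain gradient step on the shared weights.
    static double[] aggregateAndStep(double[] weights, double[][] workerGradients, double learningRate) {
        double[] summed = new double[weights.length];
        for (double[] g : workerGradients) {
            for (int k = 0; k < g.length; k++) {
                summed[k] += g[k];
            }
        }
        double[] updated = new double[weights.length];
        for (int k = 0; k < weights.length; k++) {
            // Assumption: gradients follow Encog's convention, where adding
            // rate * gradient moves the weight downhill on the error surface.
            updated[k] = weights[k] + learningRate * summed[k];
        }
        return updated;
    }

    public static void main(String[] args) {
        // Illustrative values only: three weights, gradients from two workers.
        double[] weights = { 0.1, -0.2, 0.3 };
        double[][] gradients = { { 0.05, 0.01, -0.02 }, { 0.03, -0.01, 0.00 } };
        System.out.println(Arrays.toString(aggregateAndStep(weights, gradients, 0.1)));
    }
}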
Use of org.encog.neural.networks.training.propagation.quick.QuickPropagation in project shifu by ShifuML.
The class NNTrainer, method getMLTrain:
private Propagation getMLTrain() {
    String alg = (String) modelConfig.getParams().get(CommonConstants.PROPAGATION);
    if (!defaultLearningRate.containsKey(alg)) {
        throw new RuntimeException("Learning algorithm is invalid: " + alg);
    }

    // Start from the algorithm's default learning rate and override it with the
    // user-configured value; the user may have set it as an Integer or Float.
    double rate = defaultLearningRate.get(alg);
    Object rateObj = modelConfig.getParams().get(CommonConstants.LEARNING_RATE);
    if (rateObj instanceof Double) {
        rate = (Double) rateObj;
    } else if (rateObj instanceof Integer) {
        rate = ((Integer) rateObj).doubleValue();
    } else if (rateObj instanceof Float) {
        rate = ((Float) rateObj).doubleValue();
    }

    if (toLoggingProcess) {
        LOG.info(" - Learning Algorithm: " + learningAlgMap.get(alg));
        // Only backpropagation (B), quick propagation (Q), and Manhattan propagation (M)
        // use a learning rate.
        if (alg.equals("Q") || alg.equals("B") || alg.equals("M")) {
            LOG.info(" - Learning Rate: " + rate);
        }
    }

    if (alg.equals("B")) {
        return new Backpropagation(network, trainSet, rate, 0);
    } else if (alg.equals("Q")) {
        return new QuickPropagation(network, trainSet, rate);
    } else if (alg.equals("M")) {
        return new ManhattanPropagation(network, trainSet, rate);
    } else if (alg.equals("R")) {
        return new ResilientPropagation(network, trainSet);
    } else if (alg.equals("S")) {
        return new ScaledConjugateGradient(network, trainSet);
    } else {
        return null;
    }
}
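getMLTrain dispatches on a single-letter code: B = Backpropagation, Q = QuickPropagation, M = ManhattanPropagation, R = ResilientPropagation, S = ScaledConjugateGradient; only B, Q, and M consume a learning rate. A minimal standalone sketch of the same dispatch (createTrainer and the throw on an unknown code are assumptions; the Encog constructors match those used above):

import org.encog.ml.data.MLDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.training.propagation.Propagation;
import org.encog.neural.networks.training.propagation.back.Backpropagation;
import org.encog.neural.networks.training.propagation.manhattan.ManhattanPropagation;
import org.encog.neural.networks.training.propagation.quick.QuickPropagation;
import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation;
import org.encog.neural.networks.training.propagation.scg.ScaledConjugateGradient;

public class TrainerFactorySketch {
    // Hypothetical standalone version of the dispatch in NNTrainer.getMLTrain.
    static Propagation createTrainer(String alg, BasicNetwork network, MLDataSet trainSet, double rate) {
        switch (alg) {
            case "B": return new Backpropagation(network, trainSet, rate, 0);   // rate plus zero momentum
            case "Q": return new QuickPropagation(network, trainSet, rate);
            case "M": return new ManhattanPropagation(network, trainSet, rate);
            case "R": return new ResilientPropagation(network, trainSet);       // no learning rate
            case "S": return new ScaledConjugateGradient(network, trainSet);    // no learning rate
            default: throw new IllegalArgumentException("Unknown learning algorithm: " + alg);
        }
    }
}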