Use of org.encog.neural.networks.training.propagation.scg.ScaledConjugateGradient in the project shifu by ShifuML.
The method getMLTrain of the class NNTrainer.
/**
 * Creates the propagation trainer selected by the {@code CommonConstants.PROPAGATION}
 * entry of the model parameters.
 *
 * <p>Recognised algorithm codes are {@code "B"} (Backpropagation), {@code "Q"}
 * (QuickPropagation), {@code "M"} (ManhattanPropagation), {@code "R"}
 * (ResilientPropagation) and {@code "S"} (ScaledConjugateGradient). Only B, Q and M
 * receive the learning rate; R and S are constructed without one.
 *
 * <p>The learning rate starts from the per-algorithm value in {@code defaultLearningRate}
 * and is overridden by the {@code CommonConstants.LEARNING_RATE} parameter whenever that
 * parameter is any {@link Number}. Non-numeric or missing values leave the default in place.
 *
 * @return a trainer bound to {@code network} and {@code trainSet}; never {@code null}
 * @throws RuntimeException if the configured algorithm has no default learning rate
 *         registered, or is registered but has no trainer implementation here
 */
private Propagation getMLTrain() {
    String alg = (String) modelConfig.getParams().get(CommonConstants.PROPAGATION);
    if (!defaultLearningRate.containsKey(alg)) {
        throw new RuntimeException("Learning algorithm is invalid: " + alg);
    }

    double rate = defaultLearningRate.get(alg);
    Object rateObj = modelConfig.getParams().get(CommonConstants.LEARNING_RATE);
    // Accept every numeric type (Integer, Long, Float, Double, ...): the user may write the
    // rate as an integer, and which boxed type arrives depends on how the config was parsed.
    if (rateObj instanceof Number) {
        rate = ((Number) rateObj).doubleValue();
    }

    boolean usesLearningRate = alg.equals("Q") || alg.equals("B") || alg.equals("M");
    if (toLoggingProcess) {
        LOG.info(" - Learning Algorithm: " + learningAlgMap.get(alg));
        if (usesLearningRate) {
            LOG.info(" - Learning Rate: " + rate);
        }
    }

    if (alg.equals("B")) {
        // NOTE(review): the trailing 0 is presumably the momentum term - confirm against Encog.
        return new Backpropagation(network, trainSet, rate, 0);
    } else if (alg.equals("Q")) {
        return new QuickPropagation(network, trainSet, rate);
    } else if (alg.equals("M")) {
        return new ManhattanPropagation(network, trainSet, rate);
    } else if (alg.equals("R")) {
        return new ResilientPropagation(network, trainSet);
    } else if (alg.equals("S")) {
        return new ScaledConjugateGradient(network, trainSet);
    }

    // Reaching here means defaultLearningRate knows the code but no trainer is wired up for
    // it. Fail loudly instead of returning null and deferring the failure to the caller.
    throw new RuntimeException("Learning algorithm is not supported: " + alg);
}
Aggregations