Search in sources :

Example 1 with ScaledConjugateGradient

Use of org.encog.neural.networks.training.propagation.scg.ScaledConjugateGradient in project shifu by ShifuML.

The class NNTrainer, method getMLTrain.

/**
 * Builds the propagation trainer selected by the model configuration.
 *
 * <p>The algorithm code is read from the model parameters under
 * {@code CommonConstants.PROPAGATION} and must be a key of
 * {@code defaultLearningRate}. The learning rate starts at the default
 * registered for that algorithm and is overridden by the
 * {@code CommonConstants.LEARNING_RATE} parameter whenever that parameter
 * holds any numeric value.
 *
 * @return a trainer over {@code network} and {@code trainSet}, or
 *         {@code null} when the algorithm code is registered in
 *         {@code defaultLearningRate} but is not one of "B", "Q", "M",
 *         "R" or "S"
 * @throws RuntimeException if the configured algorithm code is not a key of
 *         {@code defaultLearningRate}
 */
private Propagation getMLTrain() {
    String alg = (String) modelConfig.getParams().get(CommonConstants.PROPAGATION);
    if (!defaultLearningRate.containsKey(alg)) {
        throw new RuntimeException("Learning algorithm is invalid: " + alg);
    }

    double rate = defaultLearningRate.get(alg);
    Object rateObj = modelConfig.getParams().get(CommonConstants.LEARNING_RATE);
    // Accept every numeric type (Double, Integer, Float, Long, ...): the user
    // may write the rate as an integer, and a non-Double value must not be
    // silently replaced by the default.
    if (rateObj instanceof Number) {
        rate = ((Number) rateObj).doubleValue();
    }

    if (toLoggingProcess) {
        LOG.info("    - Learning Algorithm: " + learningAlgMap.get(alg));
        // Only these three trainers take a learning rate, so only they report it.
        if (alg.equals("Q") || alg.equals("B") || alg.equals("M")) {
            LOG.info("    - Learning Rate: " + rate);
        }
    }

    if (alg.equals("B")) {
        // The last argument (momentum) is fixed at 0, as before.
        return new Backpropagation(network, trainSet, rate, 0);
    } else if (alg.equals("Q")) {
        return new QuickPropagation(network, trainSet, rate);
    } else if (alg.equals("M")) {
        return new ManhattanPropagation(network, trainSet, rate);
    } else if (alg.equals("R")) {
        return new ResilientPropagation(network, trainSet);
    } else if (alg.equals("S")) {
        return new ScaledConjugateGradient(network, trainSet);
    } else {
        // Registered in defaultLearningRate but no trainer mapping here;
        // kept as null so existing callers that check for it still work.
        return null;
    }
}
Also used : Backpropagation(org.encog.neural.networks.training.propagation.back.Backpropagation) ScaledConjugateGradient(org.encog.neural.networks.training.propagation.scg.ScaledConjugateGradient) QuickPropagation(org.encog.neural.networks.training.propagation.quick.QuickPropagation) ResilientPropagation(org.encog.neural.networks.training.propagation.resilient.ResilientPropagation) ModelInitInputObject(ml.shifu.shifu.container.ModelInitInputObject) ManhattanPropagation(org.encog.neural.networks.training.propagation.manhattan.ManhattanPropagation)

Aggregations

ModelInitInputObject (ml.shifu.shifu.container.ModelInitInputObject): 1, Backpropagation (org.encog.neural.networks.training.propagation.back.Backpropagation): 1, ManhattanPropagation (org.encog.neural.networks.training.propagation.manhattan.ManhattanPropagation): 1, QuickPropagation (org.encog.neural.networks.training.propagation.quick.QuickPropagation): 1, ResilientPropagation (org.encog.neural.networks.training.propagation.resilient.ResilientPropagation): 1, ScaledConjugateGradient (org.encog.neural.networks.training.propagation.scg.ScaledConjugateGradient): 1