Search in sources:

Example 1 with Backpropagation

Use of org.encog.neural.networks.training.propagation.back.Backpropagation in the ShifuML/shifu project.

From the class DTrainTest, method backTest:

@Test
public void backTest() {
    double[] gradientError = new double[NUM_EPOCHS];
    double[] ecogError = new double[NUM_EPOCHS];
    network.reset();
    weights = network.getFlat().getWeights();
    MLDataSet[] subsets = splitDataSet(training);
    Gradient[] workers = new Gradient[numSplit];
    Weight weightCalculator = null;
    for (int i = 0; i < workers.length; i++) {
        workers[i] = initGradient(subsets[i]);
        workers[i].setWeights(weights);
    }
    log.info("Starting back propagation testing!");
    NNParams globalParams = new NNParams();
    globalParams.setWeights(weights);
    for (int i = 0; i < NUM_EPOCHS; i++) {
        double error = 0.0;
        // each worker do the job
        for (int j = 0; j < workers.length; j++) {
            workers[j].run();
            error += workers[j].getError();
        }
        gradientError[i] = error / workers.length;
        log.info("The #" + i + " training error: " + gradientError[i]);
        // master
        globalParams.reset();
        for (int j = 0; j < workers.length; j++) {
            globalParams.accumulateGradients(workers[j].getGradients());
            globalParams.accumulateTrainSize(subsets[j].getRecordCount());
        }
        if (weightCalculator == null) {
            weightCalculator = new Weight(globalParams.getGradients().length, globalParams.getTrainSize(), this.rate, DTrainUtils.BACK_PROPAGATION, 0, RegulationLevel.NONE);
        }
        double[] interWeight = weightCalculator.calculateWeights(globalParams.getWeights(), globalParams.getGradients(), -1);
        globalParams.setWeights(interWeight);
        // set weights
        for (int j = 0; j < workers.length; j++) {
            workers[j].setWeights(interWeight);
        }
    }
    // encog
    network.reset();
    // NNUtils.randomize(numSplit, weights);
    network.getFlat().setWeights(weights);
    Propagation p = null;
    p = new Backpropagation(network, training, rate, 0.5);
    p.setThreadCount(numSplit);
    for (int i = 0; i < NUM_EPOCHS; i++) {
        p.iteration(1);
        // System.out.println("the #" + i + " training error: " + p.getError());
        ecogError[i] = p.getError();
    }
    // assert
    double diff = 0.0;
    for (int i = 0; i < NUM_EPOCHS; i++) {
        diff += Math.abs(ecogError[i] - gradientError[i]);
    }
    Assert.assertTrue(diff / NUM_EPOCHS < 0.2);
}
Also used : Backpropagation(org.encog.neural.networks.training.propagation.back.Backpropagation) Propagation(org.encog.neural.networks.training.propagation.Propagation) ManhattanPropagation(org.encog.neural.networks.training.propagation.manhattan.ManhattanPropagation) ResilientPropagation(org.encog.neural.networks.training.propagation.resilient.ResilientPropagation) QuickPropagation(org.encog.neural.networks.training.propagation.quick.QuickPropagation) MLDataSet(org.encog.ml.data.MLDataSet) BasicMLDataSet(org.encog.ml.data.basic.BasicMLDataSet) NNParams(ml.shifu.shifu.core.dtrain.nn.NNParams) Test(org.testng.annotations.Test) BeforeTest(org.testng.annotations.BeforeTest)

Example 2 with Backpropagation

Use of org.encog.neural.networks.training.propagation.back.Backpropagation in the ShifuML/shifu project.

From the class NNTrainer, method getMLTrain:

/**
 * Builds the Encog {@link Propagation} trainer selected by the model
 * configuration.
 *
 * <p>The algorithm code is read from {@code CommonConstants.PROPAGATION}
 * ("B" = Backpropagation, "Q" = QuickPropagation, "M" = ManhattanPropagation,
 * "R" = ResilientPropagation, "S" = ScaledConjugateGradient). The learning
 * rate defaults per algorithm and may be overridden by the
 * {@code CommonConstants.LEARNING_RATE} parameter.
 *
 * @return the configured trainer, or {@code null} if the algorithm code maps
 *         to no known trainer (callers must handle {@code null})
 * @throws RuntimeException if the configured algorithm has no default
 *         learning rate registered (i.e. it is not a recognized algorithm)
 */
private Propagation getMLTrain() {
    // String alg = this.modelConfig.getLearningAlgorithm();
    String alg = (String) modelConfig.getParams().get(CommonConstants.PROPAGATION);
    if (!(defaultLearningRate.containsKey(alg))) {
        throw new RuntimeException("Learning algorithm is invalid: " + alg);
    }
    // Double rate = this.modelConfig.getLearningRate();
    double rate = defaultLearningRate.get(alg);
    Object rateObj = modelConfig.getParams().get(CommonConstants.LEARNING_RATE);
    // Users may configure the rate as any numeric type (Double, Integer,
    // Float, Long, ...). Accepting any Number generalizes the previous
    // Double/Integer/Float-only handling without changing those cases.
    if (rateObj instanceof Number) {
        rate = ((Number) rateObj).doubleValue();
    }
    if (toLoggingProcess) {
        LOG.info("    - Learning Algorithm: " + learningAlgMap.get(alg));
    }
    // Only the rate-driven algorithms log the learning rate.
    if (alg.equals("Q") || alg.equals("B") || alg.equals("M")) {
        if (toLoggingProcess) {
            LOG.info("    - Learning Rate: " + rate);
        }
    }
    if (alg.equals("B")) {
        return new Backpropagation(network, trainSet, rate, 0);
    } else if (alg.equals("Q")) {
        return new QuickPropagation(network, trainSet, rate);
    } else if (alg.equals("M")) {
        return new ManhattanPropagation(network, trainSet, rate);
    } else if (alg.equals("R")) {
        return new ResilientPropagation(network, trainSet);
    } else if (alg.equals("S")) {
        return new ScaledConjugateGradient(network, trainSet);
    } else {
        return null;
    }
}
Also used : Backpropagation(org.encog.neural.networks.training.propagation.back.Backpropagation) ScaledConjugateGradient(org.encog.neural.networks.training.propagation.scg.ScaledConjugateGradient) QuickPropagation(org.encog.neural.networks.training.propagation.quick.QuickPropagation) ResilientPropagation(org.encog.neural.networks.training.propagation.resilient.ResilientPropagation) ModelInitInputObject(ml.shifu.shifu.container.ModelInitInputObject) ManhattanPropagation(org.encog.neural.networks.training.propagation.manhattan.ManhattanPropagation)

Aggregations

Backpropagation (org.encog.neural.networks.training.propagation.back.Backpropagation)2 ManhattanPropagation (org.encog.neural.networks.training.propagation.manhattan.ManhattanPropagation)2 QuickPropagation (org.encog.neural.networks.training.propagation.quick.QuickPropagation)2 ResilientPropagation (org.encog.neural.networks.training.propagation.resilient.ResilientPropagation)2 ModelInitInputObject (ml.shifu.shifu.container.ModelInitInputObject)1 NNParams (ml.shifu.shifu.core.dtrain.nn.NNParams)1 MLDataSet (org.encog.ml.data.MLDataSet)1 BasicMLDataSet (org.encog.ml.data.basic.BasicMLDataSet)1 Propagation (org.encog.neural.networks.training.propagation.Propagation)1 ScaledConjugateGradient (org.encog.neural.networks.training.propagation.scg.ScaledConjugateGradient)1 BeforeTest (org.testng.annotations.BeforeTest)1 Test (org.testng.annotations.Test)1