Search in sources :

Example 6 with Propagation

use of org.encog.neural.networks.training.propagation.Propagation in project shifu by ShifuML.

In class DTrainTest, method quickTest:

/**
 * Checks that the worker/master quick-propagation loop produces training errors close to
 * those reported by Encog's own {@code QuickPropagation} trainer.
 *
 * <p>Phase 1 runs {@code NUM_EPOCHS} rounds. In each round every worker gradient is run, the
 * per-worker errors are averaged, the worker gradients and subset record counts are
 * accumulated into a shared {@code NNParams}, and a {@code Weight} calculator produces the
 * weights that are handed back to every worker. Phase 2 trains the same network object with
 * Encog, one iteration per epoch. The test passes when the mean absolute difference between
 * the two error series is below 0.1.
 *
 * @throws IOException declared by the original signature and kept so the method's interface
 *                     is unchanged
 */
@Test
public void quickTest() throws IOException {
    double[] gradientError = new double[NUM_EPOCHS];
    double[] ecogError = new double[NUM_EPOCHS];
    network.reset();
    // NOTE(review): if getWeights() returns the network's live array, the weight updates in
    // the loop below and the second network.reset() may alter `weights` before it is
    // re-applied for the Encog run - confirm whether a defensive copy is intended.
    weights = network.getFlat().getWeights();
    MLDataSet[] subsets = splitDataSet(training);
    Gradient[] workers = new Gradient[numSplit];
    // Created lazily inside the first epoch, once gradient length and train size are known.
    Weight weightCalculator = null;
    for (int i = 0; i < workers.length; i++) {
        workers[i] = initGradient(subsets[i]);
        workers[i].setWeights(weights);
    }
    log.info("Running QuickPropagation testing! ");
    NNParams globalParams = new NNParams();
    globalParams.setWeights(weights);
    for (int i = 0; i < NUM_EPOCHS; i++) {
        double error = 0.0;
        // Worker side: run each gradient and sum the reported errors.
        for (Gradient worker : workers) {
            worker.run();
            error += worker.getError();
        }
        gradientError[i] = error / workers.length;
        log.info("The #" + i + " training error: " + gradientError[i]);
        // Master side: start from a clean accumulator, then merge every worker's result.
        globalParams.reset();
        for (int j = 0; j < workers.length; j++) {
            globalParams.accumulateGradients(workers[j].getGradients());
            globalParams.accumulateTrainSize(subsets[j].getRecordCount());
        }
        if (weightCalculator == null) {
            weightCalculator = new Weight(globalParams.getGradients().length, globalParams.getTrainSize(), this.rate, DTrainUtils.QUICK_PROPAGATION, 0, RegulationLevel.NONE);
        }
        double[] interWeight = weightCalculator.calculateWeights(globalParams.getWeights(), globalParams.getGradients(), -1);
        globalParams.setWeights(interWeight);
        // Broadcast the updated weights back to every worker for the next epoch.
        for (Gradient worker : workers) {
            worker.setWeights(interWeight);
        }
    }
    // Encog reference run on the same network object.
    network.reset();
    network.getFlat().setWeights(weights);
    Propagation p = new QuickPropagation(network, training, rate);
    p.setThreadCount(numSplit);
    for (int i = 0; i < NUM_EPOCHS; i++) {
        p.iteration(1);
        ecogError[i] = p.getError();
    }
    // Compare the two error series epoch by epoch.
    double diff = 0.0;
    for (int i = 0; i < NUM_EPOCHS; i++) {
        diff += Math.abs(ecogError[i] - gradientError[i]);
    }
    Assert.assertTrue(diff / NUM_EPOCHS < 0.1);
}
Also used : Propagation(org.encog.neural.networks.training.propagation.Propagation) ManhattanPropagation(org.encog.neural.networks.training.propagation.manhattan.ManhattanPropagation) ResilientPropagation(org.encog.neural.networks.training.propagation.resilient.ResilientPropagation) QuickPropagation(org.encog.neural.networks.training.propagation.quick.QuickPropagation) MLDataSet(org.encog.ml.data.MLDataSet) BasicMLDataSet(org.encog.ml.data.basic.BasicMLDataSet) QuickPropagation(org.encog.neural.networks.training.propagation.quick.QuickPropagation) NNParams(ml.shifu.shifu.core.dtrain.nn.NNParams) Test(org.testng.annotations.Test) BeforeTest(org.testng.annotations.BeforeTest)

Example 7 with Propagation

use of org.encog.neural.networks.training.propagation.Propagation in project shifu by ShifuML.

In class DTrainTest, method backTest:

/**
 * Compares the worker/master back-propagation loop with Encog's {@code Backpropagation}
 * trainer.
 *
 * <p>The distributed phase runs {@code NUM_EPOCHS} rounds: each worker gradient is run, the
 * mean worker error is recorded and logged, gradients and subset record counts are merged
 * into an {@code NNParams}, and a {@code Weight} calculator yields the weights pushed back to
 * all workers. The reference phase then trains the same network object with Encog, one
 * iteration per epoch. The test passes when the mean absolute difference between both error
 * series is below 0.2.
 */
@Test
public void backTest() {
    final double[] distributedError = new double[NUM_EPOCHS];
    network.reset();
    // NOTE(review): presumably getWeights() exposes the network's own array; if so, later
    // updates and the second reset() would also be visible through `weights` - confirm.
    weights = network.getFlat().getWeights();

    final MLDataSet[] parts = splitDataSet(training);
    final Gradient[] pool = new Gradient[numSplit];
    for (int idx = 0; idx < pool.length; idx++) {
        Gradient worker = initGradient(parts[idx]);
        worker.setWeights(weights);
        pool[idx] = worker;
    }

    log.info("Starting back propagation testing!");
    final NNParams master = new NNParams();
    master.setWeights(weights);
    // Built during the first epoch, after the first gradients have been accumulated.
    Weight updater = null;

    for (int epoch = 0; epoch < NUM_EPOCHS; epoch++) {
        // Worker side: run every gradient and collect its error.
        double errorSum = 0.0;
        for (Gradient worker : pool) {
            worker.run();
            errorSum += worker.getError();
        }
        distributedError[epoch] = errorSum / pool.length;
        log.info("The #" + epoch + " training error: " + distributedError[epoch]);

        // Master side: clear the accumulator, then fold in each worker's contribution.
        master.reset();
        for (int idx = 0; idx < pool.length; idx++) {
            master.accumulateGradients(pool[idx].getGradients());
            master.accumulateTrainSize(parts[idx].getRecordCount());
        }
        if (updater == null) {
            updater = new Weight(master.getGradients().length, master.getTrainSize(), this.rate, DTrainUtils.BACK_PROPAGATION, 0, RegulationLevel.NONE);
        }
        final double[] refreshed = updater.calculateWeights(master.getWeights(), master.getGradients(), -1);
        master.setWeights(refreshed);

        // Hand the new weights to every worker before the next epoch.
        for (Gradient worker : pool) {
            worker.setWeights(refreshed);
        }
    }

    // Encog reference run; the trailing 0.5 is presumably the momentum term - TODO confirm.
    network.reset();
    network.getFlat().setWeights(weights);
    final Propagation trainer = new Backpropagation(network, training, rate, 0.5);
    trainer.setThreadCount(numSplit);
    final double[] encogError = new double[NUM_EPOCHS];
    for (int epoch = 0; epoch < NUM_EPOCHS; epoch++) {
        trainer.iteration(1);
        encogError[epoch] = trainer.getError();
    }

    // Mean absolute gap between the two error curves must stay under the tolerance.
    double totalGap = 0.0;
    for (int epoch = 0; epoch < NUM_EPOCHS; epoch++) {
        totalGap += Math.abs(encogError[epoch] - distributedError[epoch]);
    }
    Assert.assertTrue(totalGap / NUM_EPOCHS < 0.2);
}
Also used : Backpropagation(org.encog.neural.networks.training.propagation.back.Backpropagation) Propagation(org.encog.neural.networks.training.propagation.Propagation) ManhattanPropagation(org.encog.neural.networks.training.propagation.manhattan.ManhattanPropagation) ResilientPropagation(org.encog.neural.networks.training.propagation.resilient.ResilientPropagation) QuickPropagation(org.encog.neural.networks.training.propagation.quick.QuickPropagation) MLDataSet(org.encog.ml.data.MLDataSet) BasicMLDataSet(org.encog.ml.data.basic.BasicMLDataSet) NNParams(ml.shifu.shifu.core.dtrain.nn.NNParams) Test(org.testng.annotations.Test) BeforeTest(org.testng.annotations.BeforeTest)

Aggregations

Propagation (org.encog.neural.networks.training.propagation.Propagation)7 QuickPropagation (org.encog.neural.networks.training.propagation.quick.QuickPropagation)7 ManhattanPropagation (org.encog.neural.networks.training.propagation.manhattan.ManhattanPropagation)5 ResilientPropagation (org.encog.neural.networks.training.propagation.resilient.ResilientPropagation)5 Test (org.testng.annotations.Test)5 NNParams (ml.shifu.shifu.core.dtrain.nn.NNParams)4 MLDataSet (org.encog.ml.data.MLDataSet)4 BasicMLDataSet (org.encog.ml.data.basic.BasicMLDataSet)4 BeforeTest (org.testng.annotations.BeforeTest)4 File (java.io.File)1 ActivationLinear (org.encog.engine.network.activation.ActivationLinear)1 ActivationSigmoid (org.encog.engine.network.activation.ActivationSigmoid)1 MLDataPair (org.encog.ml.data.MLDataPair)1 BasicMLDataPair (org.encog.ml.data.basic.BasicMLDataPair)1 BasicNetwork (org.encog.neural.networks.BasicNetwork)1 BasicLayer (org.encog.neural.networks.layers.BasicLayer)1 Backpropagation (org.encog.neural.networks.training.propagation.back.Backpropagation)1