Search in sources :

Example 11 with Updater

use of org.deeplearning4j.nn.api.Updater in project deeplearning4j by deeplearning4j.

The following example is taken from the class TestUpdaters, method testSetGetUpdater2.

@Test
public void testSetGetUpdater2() {
    // Same check as testSetGetUpdater, except the updater is attached to a
    // freshly initialized network (no prior fit() call to create one).
    Nd4j.getRandom().setSeed(12345L);
    double learningRate = 0.03;
    int numInputs = 4;
    int numOutputs = 8;
    MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
                    .learningRate(learningRate)
                    .momentum(0.6)
                    .list()
                    .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(5).updater(org.deeplearning4j.nn.conf.Updater.SGD).build())
                    .layer(1, new DenseLayer.Builder().nIn(5).nOut(6).updater(org.deeplearning4j.nn.conf.Updater.NONE).build())
                    .layer(2, new DenseLayer.Builder().nIn(6).nOut(7).updater(org.deeplearning4j.nn.conf.Updater.ADAGRAD).build())
                    .layer(3, new OutputLayer.Builder().nIn(7).nOut(numOutputs).updater(org.deeplearning4j.nn.conf.Updater.NESTEROVS).build())
                    .backprop(true).pretrain(false).build();
    MultiLayerNetwork network = new MultiLayerNetwork(configuration);
    network.init();
    Updater freshUpdater = UpdaterCreator.getUpdater(network);
    network.setUpdater(freshUpdater);
    // getUpdater() must hand back the exact instance that was just set
    assertTrue(freshUpdater == network.getUpdater());
}
Also used : MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) Updater(org.deeplearning4j.nn.api.Updater) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) Test(org.junit.Test)

Example 12 with Updater

use of org.deeplearning4j.nn.api.Updater in project deeplearning4j by deeplearning4j.

The following example is taken from the class TestDecayPolicies, method testLearningRateTorchStepDecaySingleLayer.

@Test
public void testLearningRateTorchStepDecaySingleLayer() {
    // Verifies that the per-parameter learning rate tracked by the layer's
    // configuration follows the TorchStep decay schedule as the updater is
    // applied over successive iterations.
    int iterations = 20;
    double lr = 1;
    double decayRate = .5;
    double steps = 10;
    // Single dense layer configured with plain SGD and the TorchStep policy.
    // nIn/nOut/weightGradient/biasGradient are test-class fields (not visible here).
    NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().learningRate(lr).learningRateDecayPolicy(LearningRatePolicy.TorchStep).lrPolicyDecayRate(decayRate).lrPolicySteps(steps).iterations(iterations).layer(new DenseLayer.Builder().nIn(nIn).nOut(nOut).updater(org.deeplearning4j.nn.conf.Updater.SGD).build()).build();
    int numParams = conf.getLayer().initializer().numParams(conf);
    INDArray params = Nd4j.create(1, numParams);
    Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true);
    Updater updater = UpdaterCreator.getUpdater(layer);
    Gradient gradientActual = new DefaultGradient();
    gradientActual.setGradientFor(DefaultParamInitializer.WEIGHT_KEY, weightGradient);
    gradientActual.setGradientFor(DefaultParamInitializer.BIAS_KEY, biasGradient);
    double expectedLr = lr;
    for (int i = 0; i < iterations; i++) {
        updater.update(layer, gradientActual, i, 1);
        // NOTE(review): the trigger is `steps % i == 0` (not `i % steps == 0`),
        // which fires at i = 2, 5 and 10 for steps = 10. This presumably mirrors
        // the production TorchStep decay check exactly — confirm against the
        // layer's learning-rate decay implementation before "fixing" it.
        if (i > 1 && steps % i == 0)
            expectedLr = calcTorchStepDecay(expectedLr, decayRate);
        // Both weight ("W") and bias ("b") rates must track the expected decay.
        assertEquals(expectedLr, layer.conf().getLearningRateByParam("W"), 1e-4);
        assertEquals(expectedLr, layer.conf().getLearningRateByParam("b"), 1e-4);
    }
}
Also used : Gradient(org.deeplearning4j.nn.gradient.Gradient) DefaultGradient(org.deeplearning4j.nn.gradient.DefaultGradient) DefaultGradient(org.deeplearning4j.nn.gradient.DefaultGradient) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Updater(org.deeplearning4j.nn.api.Updater) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) Layer(org.deeplearning4j.nn.api.Layer) OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) Test(org.junit.Test)

Example 13 with Updater

use of org.deeplearning4j.nn.api.Updater in project deeplearning4j by deeplearning4j.

The following example is taken from the class TestDecayPolicies, method testLearningRateScheduleSingleLayer.

@Test
public void testLearningRateScheduleSingleLayer() {
    // A Schedule decay policy: learning rate switches to 0.2 after iteration 1.
    Map<Integer, Double> learningRateAfter = new HashMap<>();
    learningRateAfter.put(1, 0.2);
    int iterations = 2;
    // Run the same single-layer scenario under each updater in the shared list.
    for (org.deeplearning4j.nn.conf.Updater updaterFunc : updaters) {
        double lr = 1e-2;
        NeuralNetConfiguration netConf = new NeuralNetConfiguration.Builder()
                        .learningRate(lr)
                        .learningRateSchedule(learningRateAfter)
                        .learningRateDecayPolicy(LearningRatePolicy.Schedule)
                        .iterations(iterations)
                        .layer(new DenseLayer.Builder().nIn(nIn).nOut(nOut).updater(updaterFunc).build())
                        .build();
        int paramCount = netConf.getLayer().initializer().numParams(netConf);
        INDArray paramView = Nd4j.create(1, paramCount);
        Layer layerUnderTest = netConf.getLayer().instantiate(netConf, null, 0, paramView, true);
        Updater layerUpdater = UpdaterCreator.getUpdater(layerUnderTest);
        // Stateful updaters (Adam, RMSProp, ...) need a view array for their state.
        int updaterStateSize = layerUpdater.stateSizeForLayer(layerUnderTest);
        if (updaterStateSize > 0)
            layerUpdater.setStateViewArray(layerUnderTest, Nd4j.create(1, updaterStateSize), true);
        Gradient actualGradient = new DefaultGradient();
        actualGradient.setGradientFor(DefaultParamInitializer.WEIGHT_KEY, weightGradient.dup());
        actualGradient.setGradientFor(DefaultParamInitializer.BIAS_KEY, biasGradient.dup());
        Gradient expectedGradient = new DefaultGradient();
        expectedGradient.setGradientFor(DefaultParamInitializer.WEIGHT_KEY, weightGradient.dup());
        expectedGradient.setGradientFor(DefaultParamInitializer.BIAS_KEY, biasGradient.dup());
        for (int i = 0; i < 2; i++) {
            layerUpdater.update(layerUnderTest, actualGradient, i, 1);
            // Each helper recomputes the expected gradient for its updater type
            // and returns the learning rate that should now be in effect.
            switch (updaterFunc) {
                case SGD:
                    lr = testSGDComputation(actualGradient, expectedGradient, lr, learningRateAfter, i);
                    break;
                case ADAGRAD:
                    lr = testAdaGradComputation(actualGradient, expectedGradient, lr, learningRateAfter, i);
                    break;
                case ADAM:
                    lr = testAdamComputation(actualGradient, expectedGradient, lr, learningRateAfter, i);
                    break;
                case RMSPROP:
                    lr = testRMSPropComputation(actualGradient, expectedGradient, lr, learningRateAfter, i);
                    break;
                default:
                    break;
            }
            assertEquals(lr, layerUnderTest.conf().getLearningRateByParam("W"), 1e-4);
        }
    }
}
Also used : Gradient(org.deeplearning4j.nn.gradient.Gradient) DefaultGradient(org.deeplearning4j.nn.gradient.DefaultGradient) HashMap(java.util.HashMap) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) Layer(org.deeplearning4j.nn.api.Layer) OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) DefaultGradient(org.deeplearning4j.nn.gradient.DefaultGradient) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Updater(org.deeplearning4j.nn.api.Updater) Test(org.junit.Test)

Example 14 with Updater

use of org.deeplearning4j.nn.api.Updater in project deeplearning4j by deeplearning4j.

The following example is taken from the class TestGradientNormalization, method testL2ClippingPerLayer.

@Test
public void testL2ClippingPerLayer() {
    // ClipL2PerLayer: gradients are rescaled only when the layer-wide L2 norm
    // exceeds the threshold; otherwise they pass through untouched.
    Nd4j.getRandom().setSeed(12345);
    double threshold = 3;
    for (int t = 0; t < 2; t++) {
        // pass t=0: small gradients, no clipping; pass t=1: large, clipped
        NeuralNetConfiguration layerConf = new NeuralNetConfiguration.Builder()
                        .layer(new DenseLayer.Builder().nIn(10).nOut(20)
                                        .updater(org.deeplearning4j.nn.conf.Updater.NONE)
                                        .gradientNormalization(GradientNormalization.ClipL2PerLayer)
                                        .gradientNormalizationThreshold(threshold)
                                        .build())
                        .build();
        int paramCount = layerConf.getLayer().initializer().numParams(layerConf);
        INDArray paramView = Nd4j.create(1, paramCount);
        Layer layerUnderTest = layerConf.getLayer().instantiate(layerConf, null, 0, paramView, true);
        Updater layerUpdater = UpdaterCreator.getUpdater(layerUnderTest);
        double gradScale = (t == 0 ? 0.05 : 10);
        INDArray weightGrad = Nd4j.rand(10, 20).muli(gradScale);
        INDArray biasGrad = Nd4j.rand(1, 10).muli(gradScale);
        INDArray weightGradCopy = weightGrad.dup();
        INDArray biasGradCopy = biasGrad.dup();
        Gradient gradient = new DefaultGradient();
        gradient.setGradientFor(DefaultParamInitializer.WEIGHT_KEY, weightGrad);
        gradient.setGradientFor(DefaultParamInitializer.BIAS_KEY, biasGrad);
        // Sanity-check the scenario before applying the updater.
        double layerGradL2 = gradient.gradient().norm2Number().doubleValue();
        if (t == 0) {
            assertTrue(layerGradL2 < threshold);
        } else {
            assertTrue(layerGradL2 > threshold);
        }
        layerUpdater.update(layerUnderTest, gradient, 0, 1);
        if (t == 0) {
            // Below the threshold: gradients must be byte-for-byte unchanged.
            assertEquals(weightGradCopy, weightGrad);
            assertEquals(biasGradCopy, biasGrad);
        } else {
            // Above the threshold: gradients must have been rescaled by
            // threshold / norm2, uniformly across weight and bias arrays.
            assertNotEquals(weightGradCopy, weightGrad);
            assertNotEquals(biasGradCopy, biasGrad);
            double scalingFactor = threshold / layerGradL2;
            INDArray expectedWeightGrad = weightGradCopy.mul(scalingFactor);
            INDArray expectedBiasGrad = biasGradCopy.mul(scalingFactor);
            assertEquals(expectedWeightGrad, gradient.getGradientFor(DefaultParamInitializer.WEIGHT_KEY));
            assertEquals(expectedBiasGrad, gradient.getGradientFor(DefaultParamInitializer.BIAS_KEY));
        }
    }
}
Also used : DefaultGradient(org.deeplearning4j.nn.gradient.DefaultGradient) Gradient(org.deeplearning4j.nn.gradient.Gradient) DefaultGradient(org.deeplearning4j.nn.gradient.DefaultGradient) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Updater(org.deeplearning4j.nn.api.Updater) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) Layer(org.deeplearning4j.nn.api.Layer) DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) Test(org.junit.Test)

Example 15 with Updater

use of org.deeplearning4j.nn.api.Updater in project deeplearning4j by deeplearning4j.

The following example is taken from the class TestUpdaters, method testSetGetUpdater.

@Test
public void testSetGetUpdater() {
    // After fitting once (which lazily creates a MultiLayerUpdater), a
    // subsequently set updater must be returned as the very same instance.
    Nd4j.getRandom().setSeed(12345L);
    double learningRate = 0.03;
    int numInputs = 4;
    int numOutputs = 8;
    MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
                    .learningRate(learningRate)
                    .momentum(0.6)
                    .list()
                    .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(5).updater(org.deeplearning4j.nn.conf.Updater.SGD).build())
                    .layer(1, new DenseLayer.Builder().nIn(5).nOut(6).updater(org.deeplearning4j.nn.conf.Updater.NONE).build())
                    .layer(2, new DenseLayer.Builder().nIn(6).nOut(7).updater(org.deeplearning4j.nn.conf.Updater.ADAGRAD).build())
                    .layer(3, new OutputLayer.Builder().nIn(7).nOut(numOutputs).updater(org.deeplearning4j.nn.conf.Updater.NESTEROVS).build())
                    .backprop(true).pretrain(false).build();
    MultiLayerNetwork network = new MultiLayerNetwork(configuration);
    network.init();
    // Fit once so the optimizer/updater is initialized.
    network.fit(Nd4j.rand(5, numInputs), Nd4j.rand(5, numOutputs));
    Updater fittedUpdater = network.getUpdater();
    assertTrue(fittedUpdater instanceof MultiLayerUpdater);
    Updater replacementUpdater = UpdaterCreator.getUpdater(network);
    network.setUpdater(replacementUpdater);
    // getUpdater() must hand back the exact instance that was just set
    assertTrue(replacementUpdater == network.getUpdater());
}
Also used : MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) Updater(org.deeplearning4j.nn.api.Updater) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) Test(org.junit.Test)

Aggregations

Updater (org.deeplearning4j.nn.api.Updater)37 Test (org.junit.Test)28 DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer)27 INDArray (org.nd4j.linalg.api.ndarray.INDArray)27 NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration)25 Gradient (org.deeplearning4j.nn.gradient.Gradient)25 DefaultGradient (org.deeplearning4j.nn.gradient.DefaultGradient)23 Layer (org.deeplearning4j.nn.api.Layer)21 OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer)18 MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)9 MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration)8 ComputationGraphUpdater (org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater)5 HashMap (java.util.HashMap)4 Solver (org.deeplearning4j.optimize.Solver)4 ArrayList (java.util.ArrayList)2 Field (java.lang.reflect.Field)1 AtomicInteger (java.util.concurrent.atomic.AtomicInteger)1 ZipEntry (java.util.zip.ZipEntry)1 ZipFile (java.util.zip.ZipFile)1 Persistable (org.deeplearning4j.api.storage.Persistable)1