Use of org.deeplearning4j.nn.api.Updater in project deeplearning4j by deeplearning4j,
in the class TestUpdaters, method testSetGetUpdater2:
@Test
public void testSetGetUpdater2() {
    //Same as above test, except that we are doing setUpdater on a new network
    Nd4j.getRandom().setSeed(12345L);
    double lr = 0.03;
    int nIn = 4;
    int nOut = 8;
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().learningRate(lr).momentum(0.6).list()
                    .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(5).updater(org.deeplearning4j.nn.conf.Updater.SGD).build())
                    .layer(1, new DenseLayer.Builder().nIn(5).nOut(6).updater(org.deeplearning4j.nn.conf.Updater.NONE).build())
                    .layer(2, new DenseLayer.Builder().nIn(6).nOut(7).updater(org.deeplearning4j.nn.conf.Updater.ADAGRAD).build())
                    .layer(3, new OutputLayer.Builder().nIn(7).nOut(nOut).updater(org.deeplearning4j.nn.conf.Updater.NESTEROVS).build())
                    .backprop(true).pretrain(false).build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    Updater newUpdater = UpdaterCreator.getUpdater(net);
    net.setUpdater(newUpdater);
    //Should be identical object
    assertTrue(newUpdater == net.getUpdater());
}
Use of org.deeplearning4j.nn.api.Updater in project deeplearning4j by deeplearning4j,
in the class TestDecayPolicies, method testLearningRateTorchStepDecaySingleLayer:
@Test
public void testLearningRateTorchStepDecaySingleLayer() {
    int iterations = 20;
    double lr = 1;
    double decayRate = .5;
    double steps = 10;
    NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().learningRate(lr)
                    .learningRateDecayPolicy(LearningRatePolicy.TorchStep).lrPolicyDecayRate(decayRate)
                    .lrPolicySteps(steps).iterations(iterations)
                    .layer(new DenseLayer.Builder().nIn(nIn).nOut(nOut).updater(org.deeplearning4j.nn.conf.Updater.SGD).build())
                    .build();
    int numParams = conf.getLayer().initializer().numParams(conf);
    INDArray params = Nd4j.create(1, numParams);
    Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true);
    Updater updater = UpdaterCreator.getUpdater(layer);
    Gradient gradientActual = new DefaultGradient();
    gradientActual.setGradientFor(DefaultParamInitializer.WEIGHT_KEY, weightGradient);
    gradientActual.setGradientFor(DefaultParamInitializer.BIAS_KEY, biasGradient);
    double expectedLr = lr;
    for (int i = 0; i < iterations; i++) {
        updater.update(layer, gradientActual, i, 1);
        if (i > 1 && steps % i == 0)
            expectedLr = calcTorchStepDecay(expectedLr, decayRate);
        assertEquals(expectedLr, layer.conf().getLearningRateByParam("W"), 1e-4);
        assertEquals(expectedLr, layer.conf().getLearningRateByParam("b"), 1e-4);
    }
}
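The helper calcTorchStepDecay referenced above is defined elsewhere in TestDecayPolicies and is not part of this snippet. Assuming it mirrors the TorchStep policy, it simply multiplies the current rate by the decay rate each time the step condition fires; the body below is a guess at that helper, not the project's actual source.

private static double calcTorchStepDecay(double lr, double decayRate) {
    //One TorchStep decay application: new lr = old lr * decayRate
    return lr * decayRate;
}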
Use of org.deeplearning4j.nn.api.Updater in project deeplearning4j by deeplearning4j,
in the class TestDecayPolicies, method testLearningRateScheduleSingleLayer:
@Test
public void testLearningRateScheduleSingleLayer() {
    Map<Integer, Double> learningRateAfter = new HashMap<>();
    learningRateAfter.put(1, 0.2);
    int iterations = 2;
    for (org.deeplearning4j.nn.conf.Updater updaterFunc : updaters) {
        double lr = 1e-2;
        NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().learningRate(lr)
                        .learningRateSchedule(learningRateAfter)
                        .learningRateDecayPolicy(LearningRatePolicy.Schedule).iterations(iterations)
                        .layer(new DenseLayer.Builder().nIn(nIn).nOut(nOut).updater(updaterFunc).build()).build();
        int numParams = conf.getLayer().initializer().numParams(conf);
        INDArray params = Nd4j.create(1, numParams);
        Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true);
        Updater updater = UpdaterCreator.getUpdater(layer);
        int stateSize = updater.stateSizeForLayer(layer);
        if (stateSize > 0)
            updater.setStateViewArray(layer, Nd4j.create(1, stateSize), true);
        Gradient gradientActual = new DefaultGradient();
        gradientActual.setGradientFor(DefaultParamInitializer.WEIGHT_KEY, weightGradient.dup());
        gradientActual.setGradientFor(DefaultParamInitializer.BIAS_KEY, biasGradient.dup());
        Gradient gradientExpected = new DefaultGradient();
        gradientExpected.setGradientFor(DefaultParamInitializer.WEIGHT_KEY, weightGradient.dup());
        gradientExpected.setGradientFor(DefaultParamInitializer.BIAS_KEY, biasGradient.dup());
        for (int i = 0; i < 2; i++) {
            updater.update(layer, gradientActual, i, 1);
            if (updaterFunc.equals(org.deeplearning4j.nn.conf.Updater.SGD))
                lr = testSGDComputation(gradientActual, gradientExpected, lr, learningRateAfter, i);
            else if (updaterFunc.equals(org.deeplearning4j.nn.conf.Updater.ADAGRAD))
                lr = testAdaGradComputation(gradientActual, gradientExpected, lr, learningRateAfter, i);
            else if (updaterFunc.equals(org.deeplearning4j.nn.conf.Updater.ADAM))
                lr = testAdamComputation(gradientActual, gradientExpected, lr, learningRateAfter, i);
            else if (updaterFunc.equals(org.deeplearning4j.nn.conf.Updater.RMSPROP))
                lr = testRMSPropComputation(gradientActual, gradientExpected, lr, learningRateAfter, i);
            assertEquals(lr, layer.conf().getLearningRateByParam("W"), 1e-4);
        }
    }
}
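The comparison helpers (testSGDComputation, testAdaGradComputation, testAdamComputation, testRMSPropComputation) are likewise defined elsewhere in TestDecayPolicies. For the simplest case, SGD, a plausible shape of the helper is sketched below as an assumption rather than the project's actual code: it picks up the scheduled rate for the current iteration, scales the expected gradients in place (mirroring what the updater does to gradientActual), and asserts the two match.

private double testSGDComputation(Gradient actual, Gradient expected, double lr,
                Map<Integer, Double> learningRateAfter, int iteration) {
    //Schedule decay policy: a new learning rate takes effect at the listed iteration
    if (learningRateAfter.containsKey(iteration))
        lr = learningRateAfter.get(iteration);
    for (String key : expected.gradientForVariable().keySet()) {
        //Plain SGD scales the gradient by the current learning rate
        INDArray updated = expected.getGradientFor(key).mul(lr);
        expected.setGradientFor(key, updated);
        assertEquals(updated, actual.getGradientFor(key));
    }
    return lr;
}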
Use of org.deeplearning4j.nn.api.Updater in project deeplearning4j by deeplearning4j,
in the class TestGradientNormalization, method testL2ClippingPerLayer:
@Test
public void testL2ClippingPerLayer() {
    Nd4j.getRandom().setSeed(12345);
    double threshold = 3;
    for (int t = 0; t < 2; t++) {
        //t=0: small -> no clipping
        //t=1: large -> clipping
        NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
                        .layer(new DenseLayer.Builder().nIn(10).nOut(20)
                                        .updater(org.deeplearning4j.nn.conf.Updater.NONE)
                                        .gradientNormalization(GradientNormalization.ClipL2PerLayer)
                                        .gradientNormalizationThreshold(threshold).build())
                        .build();
        int numParams = conf.getLayer().initializer().numParams(conf);
        INDArray params = Nd4j.create(1, numParams);
        Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true);
        Updater updater = UpdaterCreator.getUpdater(layer);
        INDArray weightGrad = Nd4j.rand(10, 20).muli((t == 0 ? 0.05 : 10));
        INDArray biasGrad = Nd4j.rand(1, 10).muli((t == 0 ? 0.05 : 10));
        INDArray weightGradCopy = weightGrad.dup();
        INDArray biasGradCopy = biasGrad.dup();
        Gradient gradient = new DefaultGradient();
        gradient.setGradientFor(DefaultParamInitializer.WEIGHT_KEY, weightGrad);
        gradient.setGradientFor(DefaultParamInitializer.BIAS_KEY, biasGrad);
        double layerGradL2 = gradient.gradient().norm2Number().doubleValue();
        if (t == 0)
            assertTrue(layerGradL2 < threshold);
        else
            assertTrue(layerGradL2 > threshold);
        updater.update(layer, gradient, 0, 1);
        if (t == 0) {
            //norm2 < threshold -> no change
            assertEquals(weightGradCopy, weightGrad);
            assertEquals(biasGradCopy, biasGrad);
            continue;
        } else {
            //norm2 > threshold -> rescale
            assertNotEquals(weightGradCopy, weightGrad);
            assertNotEquals(biasGradCopy, biasGrad);
        }
        //for above threshold only...
        double scalingFactor = threshold / layerGradL2;
        INDArray expectedWeightGrad = weightGradCopy.mul(scalingFactor);
        INDArray expectedBiasGrad = biasGradCopy.mul(scalingFactor);
        assertEquals(expectedWeightGrad, gradient.getGradientFor(DefaultParamInitializer.WEIGHT_KEY));
        assertEquals(expectedBiasGrad, gradient.getGradientFor(DefaultParamInitializer.BIAS_KEY));
    }
}
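For reference, the rescaling that the last two assertions verify can be written out directly. The snippet below is an illustrative sketch of the ClipL2PerLayer rule, not DL4J's internal implementation: when the layer-wide gradient L2 norm exceeds the threshold, every per-parameter gradient in that layer is multiplied by threshold / norm, so the clipped layer norm equals the threshold.

double l2 = gradient.gradient().norm2Number().doubleValue();
if (l2 > threshold) {
    double scale = threshold / l2; //same quantity as scalingFactor in the test
    for (INDArray g : gradient.gradientForVariable().values())
        g.muli(scale); //rescale each gradient array in place
}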
Use of org.deeplearning4j.nn.api.Updater in project deeplearning4j by deeplearning4j,
in the class TestUpdaters, method testSetGetUpdater:
@Test
public void testSetGetUpdater() {
    Nd4j.getRandom().setSeed(12345L);
    double lr = 0.03;
    int nIn = 4;
    int nOut = 8;
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().learningRate(lr).momentum(0.6).list()
                    .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(5).updater(org.deeplearning4j.nn.conf.Updater.SGD).build())
                    .layer(1, new DenseLayer.Builder().nIn(5).nOut(6).updater(org.deeplearning4j.nn.conf.Updater.NONE).build())
                    .layer(2, new DenseLayer.Builder().nIn(6).nOut(7).updater(org.deeplearning4j.nn.conf.Updater.ADAGRAD).build())
                    .layer(3, new OutputLayer.Builder().nIn(7).nOut(nOut).updater(org.deeplearning4j.nn.conf.Updater.NESTEROVS).build())
                    .backprop(true).pretrain(false).build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    //Fit, to initialize optimizer/updater
    net.fit(Nd4j.rand(5, nIn), Nd4j.rand(5, nOut));
    Updater updater = net.getUpdater();
    assertTrue(updater instanceof MultiLayerUpdater);
    Updater newUpdater = UpdaterCreator.getUpdater(net);
    net.setUpdater(newUpdater);
    //Should be identical object
    assertTrue(newUpdater == net.getUpdater());
}
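The identity check matters because the updater owns the per-layer optimizer state (Nesterov momentum, AdaGrad history, and so on), and after the fit call above that state has been populated. A minimal sketch of carrying that state over to a second, identically configured network is shown below; it reuses the getStateViewArray/setStateViewArray methods seen in these tests, but the restored network and the variable names are illustrative assumptions, not part of the original test.

//Sketch (assumption): transfer updater state from net to a freshly initialized copy
MultiLayerNetwork restored = new MultiLayerNetwork(conf);
restored.init();
Updater restoredUpdater = UpdaterCreator.getUpdater(restored);
INDArray state = net.getUpdater().getStateViewArray();
if (state != null)
    restoredUpdater.setStateViewArray(restored, state.dup(), false);
restored.setUpdater(restoredUpdater);
//restored.getUpdater() now returns restoredUpdater, mirroring the assertion above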