Usage of org.deeplearning4j.nn.conf.LearningRatePolicy in the deeplearning4j project (by deeplearning4j).
The following example is the update method of the LayerUpdater class. For each parameter of a layer, it applies the configured learning rate decay policy and then lets the per-parameter GradientUpdater transform the raw gradient.
@Override
public void update(Layer layer, Gradient gradient, int iteration, int miniBatchSize) {
    String paramName;
    INDArray gradientOrig, gradient2;
    GradientUpdater updater;

    // Frozen layers are not trained, so their gradients are never applied
    if (layer instanceof FrozenLayer)
        return;

    // Gradient normalization/clipping, applied before the per-parameter updaters
    preApply(layer, gradient, iteration);

    for (Map.Entry<String, INDArray> gradientPair : gradient.gradientForVariable().entrySet()) {
        paramName = gradientPair.getKey();

        // Skip the visible bias parameter unless the layer is in pretrain mode
        if (!layer.conf().isPretrain()
                        && PretrainParamInitializer.VISIBLE_BIAS_KEY.equals(paramName.split("_")[0]))
            continue;

        gradientOrig = gradientPair.getValue();

        // Apply the configured learning rate decay schedule; Nesterov updaters are
        // included even without a decay policy, since their momentum schedule is
        // handled in the same step
        LearningRatePolicy decay = layer.conf().getLearningRatePolicy();
        if (decay != LearningRatePolicy.None
                        || layer.conf().getLayer().getUpdater() == org.deeplearning4j.nn.conf.Updater.NESTEROVS)
            applyLrDecayPolicy(decay, layer, iteration, paramName);

        // Look up (or lazily create) the GradientUpdater for this parameter and let
        // it transform the raw gradient (e.g. momentum, Adam, RMSProp)
        updater = init(paramName, layer);
        gradient2 = updater.getGradient(gradientOrig, iteration);

        // Regularization and mini-batch scaling, then store the final gradient
        postApply(layer, gradient2, paramName, miniBatchSize);
        gradient.setGradientFor(paramName, gradient2);
    }
}
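The decay policy consulted above comes from the layer's NeuralNetConfiguration. As a minimal sketch of how such a configuration might be built, assuming the 0.x-era builder methods learningRateDecayPolicy, lrPolicyDecayRate and momentum (names and availability vary across DL4J versions):

import org.deeplearning4j.nn.conf.LearningRatePolicy;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;

// Illustrative sketch only: an exponential learning rate decay policy plus
// Nesterov momentum, i.e. the two cases the update() method above checks for.
NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
        .learningRate(0.1)
        .learningRateDecayPolicy(LearningRatePolicy.Exponential) // assumed 0.x builder method
        .lrPolicyDecayRate(0.99)                                 // per-iteration decay rate (assumed)
        .updater(Updater.NESTEROVS)
        .momentum(0.9)
        .layer(new DenseLayer.Builder().nIn(10).nOut(10).build())
        .build();

With this configuration, getLearningRatePolicy() returns LearningRatePolicy.Exponential, so applyLrDecayPolicy is invoked for every parameter on each call to update.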