
Example 1 with LearningRatePolicy

Use of org.deeplearning4j.nn.conf.LearningRatePolicy in project deeplearning4j by deeplearning4j.

From class LayerUpdater, method update.

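@Override
public void update(Layer layer, Gradient gradient, int iteration, int miniBatchSize) {
    String paramName;
    INDArray gradientOrig, gradient2;
    GradientUpdater updater;
    // Frozen layers are not trained, so their gradients are left untouched
    if (layer instanceof FrozenLayer)
        return;
    // Layer-wide pre-processing, run once before the per-parameter loop
    preApply(layer, gradient, iteration);
    for (Map.Entry<String, INDArray> gradientPair : gradient.gradientForVariable().entrySet()) {
        paramName = gradientPair.getKey();
        // Outside pretraining, skip the visible-bias parameter of pretrainable layers
        if (!layer.conf().isPretrain() && PretrainParamInitializer.VISIBLE_BIAS_KEY.equals(paramName.split("_")[0]))
            continue;
        gradientOrig = gradientPair.getValue();
        // applyLrDecayPolicy runs when a decay policy is configured,
        // and also whenever the layer uses the NESTEROVS updater
        LearningRatePolicy decay = layer.conf().getLearningRatePolicy();
        if (decay != LearningRatePolicy.None || layer.conf().getLayer().getUpdater() == org.deeplearning4j.nn.conf.Updater.NESTEROVS)
            applyLrDecayPolicy(decay, layer, iteration, paramName);
        // Obtain the GradientUpdater for this parameter and turn the raw gradient into an update
        updater = init(paramName, layer);
        gradient2 = updater.getGradient(gradientOrig, iteration);
        // Per-parameter post-processing of the update, which also receives the minibatch size
        postApply(layer, gradient2, paramName, miniBatchSize);
        // Store the processed update back into the Gradient object
        gradient.setGradientFor(paramName, gradient2);
    }
}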
Also used: FrozenLayer (org.deeplearning4j.nn.layers.FrozenLayer), INDArray (org.nd4j.linalg.api.ndarray.INDArray), LearningRatePolicy (org.deeplearning4j.nn.conf.LearningRatePolicy), HashMap (java.util.HashMap), LinkedHashMap (java.util.LinkedHashMap), Map (java.util.Map)
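The control flow of update() can be followed outside DL4J with a minimal, self-contained Java sketch. This is not DeepLearning4j's implementation. Everything in the sketch is an assumption made for illustration: DecayPolicy, decayedLr, the constants, the "vb" key standing in for the visible bias, plain SGD standing in for the GradientUpdater, and division by the minibatch size standing in for postApply. The three decay schedules are common textbook forms, not formulas copied from applyLrDecayPolicy. The NESTEROVS momentum branch and any regularization are omitted.

```java
import java.util.LinkedHashMap;
import java.util.Map;

public class UpdateLoopSketch {

    // Hypothetical stand-in for LearningRatePolicy; only a few schedules are modelled
    enum DecayPolicy { NONE, EXPONENTIAL, INVERSE, STEP }

    // Hypothetical decay settings (illustrative names, not DL4J configuration keys)
    static final double BASE_LR = 0.1;
    static final double DECAY_RATE = 0.5;
    static final double POWER = 1.0;
    static final double STEPS = 10.0;

    // Common decay schedules (assumed forms): learning rate for a given iteration
    static double decayedLr(DecayPolicy policy, int iteration) {
        switch (policy) {
            case EXPONENTIAL: return BASE_LR * Math.pow(DECAY_RATE, iteration);
            case INVERSE:     return BASE_LR / Math.pow(1 + DECAY_RATE * iteration, POWER);
            case STEP:        return BASE_LR * Math.pow(DECAY_RATE, Math.floor(iteration / STEPS));
            default:          return BASE_LR; // NONE: keep the base learning rate
        }
    }

    // Mirrors the shape of LayerUpdater.update: skip frozen layers, skip the
    // visible bias outside pretraining, decay the learning rate, then replace
    // each raw gradient with the processed update in the same map
    static void update(Map<String, double[]> gradients, boolean frozen, boolean pretrain,
                       DecayPolicy policy, int iteration, int miniBatchSize) {
        if (frozen) return;
        double lr = decayedLr(policy, iteration);
        for (Map.Entry<String, double[]> entry : gradients.entrySet()) {
            String paramName = entry.getKey();
            // "vb" is an assumed visible-bias key; the prefix check mirrors split("_")[0]
            if (!pretrain && "vb".equals(paramName.split("_")[0])) continue;
            double[] g = entry.getValue();
            double[] step = new double[g.length];
            for (int i = 0; i < g.length; i++) {
                // Plain SGD step, averaged over the minibatch
                step[i] = lr * g[i] / miniBatchSize;
            }
            // Replacing a value is not a structural change, so iteration stays valid
            entry.setValue(step);
        }
    }

    public static void main(String[] args) {
        Map<String, double[]> grads = new LinkedHashMap<>();
        grads.put("W", new double[] {2.0, 4.0});
        grads.put("b", new double[] {1.0});
        grads.put("vb", new double[] {8.0});
        // STEP policy at iteration 10: lr = 0.1 * 0.5^1 = 0.05
        update(grads, false, false, DecayPolicy.STEP, 10, 2);
        // prints 0.05 0.1 0.025 8.0 (the visible bias is skipped)
        System.out.println(grads.get("W")[0] + " " + grads.get("W")[1] + " "
                + grads.get("b")[0] + " " + grads.get("vb")[0]);
    }
}
```

Keeping the schedule in a separate pure function makes each policy easy to check in isolation. Writing the update back into the same map matches the way update() overwrites each entry through setGradientFor.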
