Example 1 with BatchNormalization

Use of org.deeplearning4j.nn.conf.layers.BatchNormalization in the project deeplearning4j by deeplearning4j.

From the class BatchNormalizationTest, method getLayer:

protected static Layer getLayer(int nOut, double epsilon, boolean lockGammaBeta, double gamma, double beta) {
    BatchNormalization.Builder b = new BatchNormalization.Builder().nOut(nOut).eps(epsilon);
    if (lockGammaBeta) {
        // Fix gamma/beta to constant values instead of learning them
        b.lockGammaBeta(true).gamma(gamma).beta(beta);
    }
    BatchNormalization bN = b.build();
    NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().iterations(1).layer(bN).build();
    // Allocate the flattened parameter view only if the layer actually has parameters
    int numParams = conf.getLayer().initializer().numParams(conf);
    INDArray params = null;
    if (numParams > 0) {
        params = Nd4j.create(1, numParams);
    }
    Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true);
    if (numParams > 0) {
        // The backprop gradient view must have the same shape as the parameter view
        layer.setBackpropGradientsViewArray(Nd4j.create(1, numParams));
    }
    return layer;
}
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration), Layer (org.deeplearning4j.nn.api.Layer), BatchNormalization (org.deeplearning4j.nn.conf.layers.BatchNormalization)
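
The helper above is not exercised here, but a minimal usage sketch follows, assuming JUnit's assertEquals, the imports listed above, and org.nd4j.linalg.factory.Nd4j. The expected parameter counts follow from the layout shown in Examples 2 and 3: a layer with locked gamma/beta stores only the global mean and variance estimates (2 * nOut values), while an unlocked layer also stores gamma and beta (4 * nOut values).

int nOut = 10;

// Locked gamma/beta: only the global mean and variance estimates are stored as parameters.
Layer locked = getLayer(nOut, 1e-5, true, 2.0, 1.0);
assertEquals(2 * nOut, locked.numParams());

// Trainable gamma/beta: gamma + beta + global mean + global variance = 4 * nOut parameters.
Layer unlocked = getLayer(nOut, 1e-5, false, 0.0, 0.0);
assertEquals(4 * nOut, unlocked.numParams());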

Example 2 with BatchNormalization

Use of org.deeplearning4j.nn.conf.layers.BatchNormalization in the project deeplearning4j by deeplearning4j.

From the class BatchNormalizationParamInitializer, method getGradientsFromFlattened:

@Override
public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
    BatchNormalization layer = (BatchNormalization) conf.getLayer();
    int nOut = layer.getNOut();
    Map<String, INDArray> out = new LinkedHashMap<>();
    int meanOffset = 0;
    if (!layer.isLockGammaBeta()) {
        INDArray gammaView = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, nOut));
        INDArray betaView = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(nOut, 2 * nOut));
        out.put(GAMMA, gammaView);
        out.put(BETA, betaView);
        meanOffset = 2 * nOut;
    }
    out.put(GLOBAL_MEAN, gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(meanOffset, meanOffset + nOut)));
    out.put(GLOBAL_VAR, gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(meanOffset + nOut, meanOffset + 2 * nOut)));
    return out;
}
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), BatchNormalization (org.deeplearning4j.nn.conf.layers.BatchNormalization), LinkedHashMap (java.util.LinkedHashMap)
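
The interval offsets above imply a fixed column layout for the flattened gradient view, and the returned map holds views rather than copies, so writing to a view updates the flattened array in place. The following sketch is illustrative only; the choice of nOut, the Nd4j.linspace stand-in, and the shapes are assumptions, not part of the source.

// Assumed layout for nOut = 3 with gamma/beta trainable (meanOffset = 2 * nOut = 6):
//   columns [0, 3)  -> gamma gradient
//   columns [3, 6)  -> beta gradient
//   columns [6, 9)  -> global mean
//   columns [9, 12) -> global variance
// With gamma/beta locked, meanOffset stays 0 and only mean/variance occupy the view.
int nOut = 3;
INDArray flatGradient = Nd4j.linspace(1, 4 * nOut, 4 * nOut);  // row vector standing in for the real gradient view
INDArray gammaGrad = flatGradient.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, nOut));
gammaGrad.assign(0);  // writes through to flatGradient: its first nOut entries are now 0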

Example 3 with BatchNormalization

Use of org.deeplearning4j.nn.conf.layers.BatchNormalization in the project deeplearning4j by deeplearning4j.

From the class BatchNormalizationParamInitializer, method init:

@Override
public Map<String, INDArray> init(NeuralNetConfiguration conf, INDArray paramView, boolean initializeParams) {
    Map<String, INDArray> params = Collections.synchronizedMap(new LinkedHashMap<String, INDArray>());
    // TODO setup for RNN
    BatchNormalization layer = (BatchNormalization) conf.getLayer();
    int nOut = layer.getNOut();
    int meanOffset = 0;
    if (!layer.isLockGammaBeta()) {
        //No gamma/beta parameters when gamma/beta are locked
        INDArray gammaView = paramView.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, nOut));
        INDArray betaView = paramView.get(NDArrayIndex.point(0), NDArrayIndex.interval(nOut, 2 * nOut));
        params.put(GAMMA, createGamma(conf, gammaView, initializeParams));
        conf.addVariable(GAMMA);
        params.put(BETA, createBeta(conf, betaView, initializeParams));
        conf.addVariable(BETA);
        meanOffset = 2 * nOut;
    }
    INDArray globalMeanView = paramView.get(NDArrayIndex.point(0), NDArrayIndex.interval(meanOffset, meanOffset + nOut));
    INDArray globalVarView = paramView.get(NDArrayIndex.point(0), NDArrayIndex.interval(meanOffset + nOut, meanOffset + 2 * nOut));
    if (initializeParams) {
        globalMeanView.assign(0);
        globalVarView.assign(1);
    }
    params.put(GLOBAL_MEAN, globalMeanView);
    conf.addVariable(GLOBAL_MEAN);
    params.put(GLOBAL_VAR, globalVarView);
    conf.addVariable(GLOBAL_VAR);
    return params;
}
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), BatchNormalization (org.deeplearning4j.nn.conf.layers.BatchNormalization)
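
For context, a hedged sketch of driving this initializer directly, mirroring what the test helper in Example 1 does through conf.getLayer().initializer(); it assumes the same imports as Example 1 plus java.util.Map and org.nd4j.linalg.factory.Nd4j, and the comment about gamma/beta defaults is an assumption about the Builder defaults of this API version, not something shown above.

// Sketch only: a BatchNormalization layer with trainable gamma/beta, initialized into one flattened row vector.
BatchNormalization bn = new BatchNormalization.Builder().nOut(4).eps(1e-5).build();
NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().layer(bn).build();

int numParams = conf.getLayer().initializer().numParams(conf);  // 4 * nOut = 16 here
INDArray paramView = Nd4j.create(1, numParams);
Map<String, INDArray> params = conf.getLayer().initializer().init(conf, paramView, true);

// GLOBAL_MEAN is assigned 0 and GLOBAL_VAR is assigned 1 (see init above); gamma and beta are
// filled by createGamma/createBeta, which are assumed to use the Builder defaults (1 and 0).
System.out.println(params.keySet());  // the four parameter keys, in insertion order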

Aggregations

BatchNormalization (org.deeplearning4j.nn.conf.layers.BatchNormalization): 3 uses
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 3 uses
LinkedHashMap (java.util.LinkedHashMap): 1 use
Layer (org.deeplearning4j.nn.api.Layer): 1 use
NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration): 1 use