Search in sources :

Example 1 with BatchNormalizationHelper

Use of org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper in the project deeplearning4j by deeplearning4j.

The following example shows the method testBatchNormCnn of the class CuDNNGradientChecks.

@Test
public void testBatchNormCnn() throws Exception {
    // CuDNN batch norm supports 4d input only, as of cuDNN 5.1 (per the API
    // reference documentation), hence the CNN-shaped input used here.
    Nd4j.getRandom().setSeed(12345);
    final int batchSize = 10;
    final int channels = 1;
    final int imageSize = 4;
    final int numClasses = 4;

    INDArray features = Nd4j.rand(new int[] { batchSize, channels, imageSize, imageSize });

    // One-hot labels: each example is assigned a single randomly chosen class.
    INDArray oneHotLabels = Nd4j.zeros(batchSize, numClasses);
    Random rng = new Random(12345);
    for (int example = 0; example < batchSize; example++) {
        oneHotLabels.putScalar(example, rng.nextInt(numClasses), 1.0);
    }

    // Conv -> BatchNorm -> Tanh -> softmax output. No regularization and no
    // updater, so the numerical gradient check sees the raw gradients.
    MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
            .learningRate(1.0)
            .regularization(false)
            .updater(Updater.NONE)
            .seed(12345L)
            .weightInit(WeightInit.DISTRIBUTION)
            .dist(new NormalDistribution(0, 2))
            .list()
            .layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1)
                    .nIn(channels).nOut(2).activation(Activation.IDENTITY).build())
            .layer(1, new BatchNormalization.Builder().build())
            .layer(2, new ActivationLayer.Builder().activation(Activation.TANH).build())
            .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX).nOut(numClasses).build())
            .setInputType(InputType.convolutional(imageSize, imageSize, channels))
            .pretrain(false)
            .backprop(true);

    MultiLayerNetwork net = new MultiLayerNetwork(builder.build());
    net.init();

    // Reflectively read the private "helper" field of the batch-norm layer to
    // confirm that the cuDNN-backed helper implementation was actually wired in.
    Field helperField = org.deeplearning4j.nn.layers.normalization.BatchNormalization.class
            .getDeclaredField("helper");
    helperField.setAccessible(true);
    org.deeplearning4j.nn.layers.normalization.BatchNormalization bnLayer =
            (org.deeplearning4j.nn.layers.normalization.BatchNormalization) net.getLayer(1);
    BatchNormalizationHelper helper = (BatchNormalizationHelper) helperField.get(bnLayer);
    assertTrue(helper instanceof CudnnBatchNormalizationHelper);

    if (PRINT_RESULTS) {
        for (int layerIdx = 0; layerIdx < net.getnLayers(); layerIdx++) {
            System.out.println("Layer " + layerIdx + " # params: " + net.getLayer(layerIdx).numParams());
        }
    }

    boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
            DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, features, oneHotLabels);
    assertTrue(gradOK);
}
Also used : Field(java.lang.reflect.Field) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) Random(java.util.Random) BatchNormalizationHelper(org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper) CudnnBatchNormalizationHelper(org.deeplearning4j.nn.layers.normalization.CudnnBatchNormalizationHelper) org.deeplearning4j.nn.conf.layers(org.deeplearning4j.nn.conf.layers) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) CudnnBatchNormalizationHelper(org.deeplearning4j.nn.layers.normalization.CudnnBatchNormalizationHelper) INDArray(org.nd4j.linalg.api.ndarray.INDArray) NormalDistribution(org.deeplearning4j.nn.conf.distribution.NormalDistribution) Test(org.junit.Test)

Aggregations

Field (java.lang.reflect.Field)1 Random (java.util.Random)1 MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration)1 NormalDistribution (org.deeplearning4j.nn.conf.distribution.NormalDistribution)1 org.deeplearning4j.nn.conf.layers (org.deeplearning4j.nn.conf.layers)1 BatchNormalizationHelper (org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper)1 CudnnBatchNormalizationHelper (org.deeplearning4j.nn.layers.normalization.CudnnBatchNormalizationHelper)1 MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)1 Test (org.junit.Test)1 INDArray (org.nd4j.linalg.api.ndarray.INDArray)1