Use of org.deeplearning4j.nn.layers.normalization.CudnnBatchNormalizationHelper in project deeplearning4j by deeplearning4j.
The class CuDNNGradientChecks, method testBatchNormCnn.
/**
 * Gradient check for a small network (convolution, batch normalization, tanh activation,
 * softmax output) fed with random 4d input and one-hot labels.
 *
 * <p>Before running the gradient check, the test reads the private {@code helper} field of the
 * batch normalization layer via reflection and asserts that it is a
 * {@link CudnnBatchNormalizationHelper}.
 *
 * @throws Exception if the reflective field lookup/access fails or the checks themselves throw
 */
@Test
public void testBatchNormCnn() throws Exception {
    // Note: CuDNN batch norm supports 4d only, as per 5.1 (according to api reference documentation)
    Nd4j.getRandom().setSeed(12345);

    final int minibatch = 10;
    final int depth = 1;
    final int hw = 4;
    final int nOut = 4;

    // Random input of shape [minibatch, depth, hw, hw].
    INDArray features = Nd4j.rand(new int[] {minibatch, depth, hw, hw});

    // One-hot labels: each row gets a single 1.0 in a column drawn from a fixed-seed RNG.
    INDArray oneHotLabels = Nd4j.zeros(minibatch, nOut);
    Random labelRng = new Random(12345);
    int row = 0;
    while (row < minibatch) {
        oneHotLabels.putScalar(row, labelRng.nextInt(nOut), 1.0);
        row++;
    }

    // Updater.NONE and no regularization: only the raw gradients are of interest here.
    MultiLayerConfiguration.Builder confBuilder = new NeuralNetConfiguration.Builder()
            .learningRate(1.0)
            .regularization(false)
            .updater(Updater.NONE)
            .seed(12345L)
            .weightInit(WeightInit.DISTRIBUTION)
            .dist(new NormalDistribution(0, 2))
            .list()
            .layer(0, new ConvolutionLayer.Builder()
                    .kernelSize(2, 2)
                    .stride(1, 1)
                    .nIn(depth)
                    .nOut(2)
                    .activation(Activation.IDENTITY)
                    .build())
            .layer(1, new BatchNormalization.Builder().build())
            .layer(2, new ActivationLayer.Builder()
                    .activation(Activation.TANH)
                    .build())
            .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX)
                    .nOut(nOut)
                    .build())
            .setInputType(InputType.convolutional(hw, hw, depth))
            .pretrain(false)
            .backprop(true);

    MultiLayerNetwork network = new MultiLayerNetwork(confBuilder.build());
    network.init();

    // The helper is a private field of the layer implementation, so reflection is needed to inspect it.
    Field helperField =
            org.deeplearning4j.nn.layers.normalization.BatchNormalization.class.getDeclaredField("helper");
    helperField.setAccessible(true);
    org.deeplearning4j.nn.layers.normalization.BatchNormalization batchNormLayer =
            (org.deeplearning4j.nn.layers.normalization.BatchNormalization) network.getLayer(1);
    BatchNormalizationHelper helper = (BatchNormalizationHelper) helperField.get(batchNormLayer);
    assertTrue(helper instanceof CudnnBatchNormalizationHelper);

    if (PRINT_RESULTS) {
        for (int layerIdx = 0; layerIdx < network.getnLayers(); layerIdx++) {
            System.out.println("Layer " + layerIdx + " # params: " + network.getLayer(layerIdx).numParams());
        }
    }

    boolean gradientsOk = GradientCheckUtil.checkGradients(
            network,
            DEFAULT_EPS,
            DEFAULT_MAX_REL_ERROR,
            DEFAULT_MIN_ABS_ERROR,
            PRINT_RESULTS,
            RETURN_ON_FIRST_FAILURE,
            features,
            oneHotLabels);
    assertTrue(gradientsOk);
}
Aggregations