Use of org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor in project deeplearning4j by deeplearning4j.
From the class RegressionTest060, method regressionTestCNN1.
/**
 * Regression test for the serialized CNN model stored in
 * {@code regression_testing/060/060_ModelSerializer_Regression_CNN_1.zip}.
 *
 * <p>Restores the network (including its updater) and verifies:
 * <ul>
 *   <li>the global configuration (three layers, backprop on, pretrain off);</li>
 *   <li>the per-layer settings of the convolution, subsampling and output layers;</li>
 *   <li>the input pre-processor attached to the output layer;</li>
 *   <li>that the parameters and the updater state equal {@code linspace(1, n, n)}.</li>
 * </ul>
 *
 * @throws Exception if the archive cannot be extracted or the model cannot be restored
 */
@Test
public void regressionTestCNN1() throws Exception {
    File f = new ClassPathResource("regression_testing/060/060_ModelSerializer_Regression_CNN_1.zip").getTempFileFromArchive();
    // 'true' asks the serializer to restore the updater as well; its state is checked at the end.
    MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true);

    MultiLayerConfiguration conf = net.getLayerWiseConfigurations();
    assertEquals(3, conf.getConfs().size());
    assertTrue(conf.isBackprop());
    assertFalse(conf.isPretrain());

    // Layer 0: convolution layer
    ConvolutionLayer l0 = (ConvolutionLayer) conf.getConf(0).getLayer();
    assertEquals("tanh", l0.getActivationFn().toString());
    assertEquals(3, l0.getNIn());
    assertEquals(3, l0.getNOut());
    assertEquals(WeightInit.RELU, l0.getWeightInit());
    assertEquals(Updater.RMSPROP, l0.getUpdater());
    assertEquals(0.96, l0.getRmsDecay(), 1e-6);
    assertEquals(0.15, l0.getLearningRate(), 1e-6);
    assertArrayEquals(new int[] { 2, 2 }, l0.getKernelSize());
    assertArrayEquals(new int[] { 1, 1 }, l0.getStride());
    assertArrayEquals(new int[] { 0, 0 }, l0.getPadding());
    //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
    assertEquals(ConvolutionMode.Truncate, l0.getConvolutionMode());

    // Layer 1: subsampling (pooling) layer
    SubsamplingLayer l1 = (SubsamplingLayer) conf.getConf(1).getLayer();
    assertArrayEquals(new int[] { 2, 2 }, l1.getKernelSize());
    assertArrayEquals(new int[] { 1, 1 }, l1.getStride());
    assertArrayEquals(new int[] { 0, 0 }, l1.getPadding());
    assertEquals(PoolingType.MAX, l1.getPoolingType());
    //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
    assertEquals(ConvolutionMode.Truncate, l1.getConvolutionMode());

    // Layer 2: output layer. These assertions must target l2; checking l0/l1 again
    // here would leave the output layer's settings unverified.
    // NOTE(review): assumes the stored output layer uses the same weight init / updater
    // hyperparameters as layer 0 - confirm against the archive.
    OutputLayer l2 = (OutputLayer) conf.getConf(2).getLayer();
    assertEquals("sigmoid", l2.getActivationFn().toString());
    assertEquals(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD, l2.getLossFunction());
    //TODO
    assertTrue(l2.getLossFn() instanceof LossNegativeLogLikelihood);
    assertEquals(26 * 26 * 3, l2.getNIn());
    assertEquals(5, l2.getNOut());
    assertEquals(WeightInit.RELU, l2.getWeightInit());
    assertEquals(Updater.RMSPROP, l2.getUpdater());
    assertEquals(0.96, l2.getRmsDecay(), 1e-6);
    assertEquals(0.15, l2.getLearningRate(), 1e-6);

    assertTrue(conf.getInputPreProcess(2) instanceof CnnToFeedForwardPreProcessor);

    // Parameters are expected to be the sequence 1..numParams.
    int numParams = net.numParams();
    assertEquals(Nd4j.linspace(1, numParams, numParams), net.params());

    // The updater state view is expected to be the sequence 1..updaterSize.
    int updaterSize = net.getUpdater().stateSizeForLayer(net);
    assertEquals(Nd4j.linspace(1, updaterSize, updaterSize), net.getUpdater().getStateViewArray());
}
Aggregations