Usage of org.deeplearning4j.nn.conf.ConvolutionMode in the deeplearning4j project (by deeplearning4j).
Class TestConvolutionModes, method testGlobalLocalConfigCompGraph.
/**
 * Verifies how convolution modes are resolved in a ComputationGraph configuration.
 *
 * <p>For each global mode passed to {@code NeuralNetConfiguration.Builder#convolutionMode}, the graph
 * contains four ConvolutionLayers (vertices "0".."3") followed by four SubsamplingLayers (vertices
 * "4".."7"). The first layer of each group sets no mode of its own and is expected to report the
 * global mode; the other three set Strict, Truncate and Same explicitly and are expected to keep them.
 */
@Test
public void testGlobalLocalConfigCompGraph() {
    ConvolutionMode[] globalModes = {ConvolutionMode.Strict, ConvolutionMode.Truncate, ConvolutionMode.Same};

    for (ConvolutionMode globalMode : globalModes) {
        ComputationGraphConfiguration graphConf = new NeuralNetConfiguration.Builder()
                .weightInit(WeightInit.XAVIER)
                .convolutionMode(globalMode)
                .graphBuilder()
                .addInputs("in")
                // Convolution layers: "0" has no local mode; "1".."3" each set one explicitly.
                .addLayer("0", new ConvolutionLayer.Builder()
                        .kernelSize(3, 3).stride(3, 3).padding(0, 0).nIn(3).nOut(3).build(), "in")
                .addLayer("1", new ConvolutionLayer.Builder().convolutionMode(ConvolutionMode.Strict)
                        .kernelSize(3, 3).stride(3, 3).padding(0, 0).nIn(3).nOut(3).build(), "0")
                .addLayer("2", new ConvolutionLayer.Builder().convolutionMode(ConvolutionMode.Truncate)
                        .kernelSize(3, 3).stride(3, 3).padding(0, 0).nIn(3).nOut(3).build(), "1")
                .addLayer("3", new ConvolutionLayer.Builder().convolutionMode(ConvolutionMode.Same)
                        .kernelSize(3, 3).stride(3, 3).padding(0, 0).nIn(3).nOut(3).build(), "2")
                // Subsampling layers: "4" has no local mode; "5".."7" each set one explicitly.
                .addLayer("4", new SubsamplingLayer.Builder()
                        .kernelSize(3, 3).stride(3, 3).padding(0, 0).build(), "3")
                .addLayer("5", new SubsamplingLayer.Builder().convolutionMode(ConvolutionMode.Strict)
                        .kernelSize(3, 3).stride(3, 3).padding(0, 0).build(), "4")
                .addLayer("6", new SubsamplingLayer.Builder().convolutionMode(ConvolutionMode.Truncate)
                        .kernelSize(3, 3).stride(3, 3).padding(0, 0).build(), "5")
                .addLayer("7", new SubsamplingLayer.Builder().convolutionMode(ConvolutionMode.Same)
                        .kernelSize(3, 3).stride(3, 3).padding(0, 0).build(), "6")
                .addLayer("8", new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
                        .nOut(3).build(), "7")
                .setOutputs("8")
                .build();

        // Expected mode for the i-th layer of each group: the inherited global mode first,
        // then the three explicit per-layer overrides, in declaration order.
        ConvolutionMode[] expected =
                {globalMode, ConvolutionMode.Strict, ConvolutionMode.Truncate, ConvolutionMode.Same};
        String[] convolutionVertices = {"0", "1", "2", "3"};
        String[] subsamplingVertices = {"4", "5", "6", "7"};

        for (int i = 0; i < convolutionVertices.length; i++) {
            LayerVertex vertex = (LayerVertex) graphConf.getVertices().get(convolutionVertices[i]);
            ConvolutionLayer layer = (ConvolutionLayer) vertex.getLayerConf().getLayer();
            assertEquals(expected[i], layer.getConvolutionMode());
        }
        for (int i = 0; i < subsamplingVertices.length; i++) {
            LayerVertex vertex = (LayerVertex) graphConf.getVertices().get(subsamplingVertices[i]);
            SubsamplingLayer layer = (SubsamplingLayer) vertex.getLayerConf().getLayer();
            assertEquals(expected[i], layer.getConvolutionMode());
        }
    }
}
Aggregations