Search in sources:

Example 6: usage of ConvolutionMode

Use of org.deeplearning4j.nn.conf.ConvolutionMode in the deeplearning4j project (by deeplearning4j).

From the method testGlobalLocalConfigCompGraph of the class TestConvolutionModes.

@Test
public void testGlobalLocalConfigCompGraph() {
    // For each global convolution mode, build a graph whose convolution and subsampling layers either
    // leave the mode unset (vertices "0" and "4") or set it explicitly (the remaining vertices), then
    // check the mode each layer configuration ends up with.
    for (ConvolutionMode cm : new ConvolutionMode[] { ConvolutionMode.Strict, ConvolutionMode.Truncate, ConvolutionMode.Same }) {
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                .weightInit(WeightInit.XAVIER)
                .convolutionMode(cm)
                .graphBuilder()
                .addInputs("in")
                // "0": no layer-level mode -> expected to carry the global mode cm
                .addLayer("0", new ConvolutionLayer.Builder().kernelSize(3, 3).stride(3, 3).padding(0, 0).nIn(3).nOut(3).build(), "in")
                // "1".."3": layer-level modes that are expected to take precedence over cm
                .addLayer("1", new ConvolutionLayer.Builder().convolutionMode(ConvolutionMode.Strict).kernelSize(3, 3).stride(3, 3).padding(0, 0).nIn(3).nOut(3).build(), "0")
                .addLayer("2", new ConvolutionLayer.Builder().convolutionMode(ConvolutionMode.Truncate).kernelSize(3, 3).stride(3, 3).padding(0, 0).nIn(3).nOut(3).build(), "1")
                .addLayer("3", new ConvolutionLayer.Builder().convolutionMode(ConvolutionMode.Same).kernelSize(3, 3).stride(3, 3).padding(0, 0).nIn(3).nOut(3).build(), "2")
                // "4": no layer-level mode -> expected to carry the global mode cm
                .addLayer("4", new SubsamplingLayer.Builder().kernelSize(3, 3).stride(3, 3).padding(0, 0).build(), "3")
                // "5".."7": layer-level modes that are expected to take precedence over cm
                .addLayer("5", new SubsamplingLayer.Builder().convolutionMode(ConvolutionMode.Strict).kernelSize(3, 3).stride(3, 3).padding(0, 0).build(), "4")
                .addLayer("6", new SubsamplingLayer.Builder().convolutionMode(ConvolutionMode.Truncate).kernelSize(3, 3).stride(3, 3).padding(0, 0).build(), "5")
                .addLayer("7", new SubsamplingLayer.Builder().convolutionMode(ConvolutionMode.Same).kernelSize(3, 3).stride(3, 3).padding(0, 0).build(), "6")
                .addLayer("8", new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT).nOut(3).build(), "7")
                .setOutputs("8")
                .build();

        // Position i of both vertex groups expects expectedModes[i]: index 0 has no override (global cm),
        // indices 1-3 override with Strict, Truncate, Same respectively.
        ConvolutionMode[] expectedModes = { cm, ConvolutionMode.Strict, ConvolutionMode.Truncate, ConvolutionMode.Same };
        String[] convVertexNames = { "0", "1", "2", "3" };
        String[] subsamplingVertexNames = { "4", "5", "6", "7" };

        // Messages name the global mode and vertex, since a bare failure inside this loop would not
        // reveal which iteration or layer broke.
        for (int i = 0; i < expectedModes.length; i++) {
            ConvolutionLayer convLayer = (ConvolutionLayer) ((LayerVertex) conf.getVertices().get(convVertexNames[i])).getLayerConf().getLayer();
            assertEquals("global mode " + cm + ", vertex " + convVertexNames[i], expectedModes[i], convLayer.getConvolutionMode());
        }
        for (int i = 0; i < expectedModes.length; i++) {
            SubsamplingLayer poolLayer = (SubsamplingLayer) ((LayerVertex) conf.getVertices().get(subsamplingVertexNames[i])).getLayerConf().getLayer();
            assertEquals("global mode " + cm + ", vertex " + subsamplingVertexNames[i], expectedModes[i], poolLayer.getConvolutionMode());
        }
    }
}
Also used: ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration), NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration), ConvolutionMode (org.deeplearning4j.nn.conf.ConvolutionMode), ConvolutionLayer (org.deeplearning4j.nn.conf.layers.ConvolutionLayer), Test (org.junit.Test)

Aggregations

ConvolutionMode (org.deeplearning4j.nn.conf.ConvolutionMode): 6 · NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration): 5 · ConvolutionLayer (org.deeplearning4j.nn.conf.layers.ConvolutionLayer): 5 · Test (org.junit.Test): 5 · MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration): 3 · INDArray (org.nd4j.linalg.api.ndarray.INDArray): 3 · DL4JException (org.deeplearning4j.exception.DL4JException): 2 · ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration): 2 · MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork): 2 · Field (java.lang.reflect.Field): 1 · Layer (org.deeplearning4j.nn.api.Layer): 1 · OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer): 1 · SubsamplingLayer (org.deeplearning4j.nn.conf.layers.SubsamplingLayer): 1 · Gradient (org.deeplearning4j.nn.gradient.Gradient): 1 · ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph): 1