
Example 1 with LayerVertex

Use of org.deeplearning4j.nn.conf.graph.LayerVertex in the deeplearning4j project.

From the class TestComputationGraphNetwork, method testPreprocessorAddition.

@Test
public void testPreprocessorAddition() {
    //Also check that nIns are set automatically
    //First: check FF -> RNN
    ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder().graphBuilder()
            .addInputs("in")
            .setInputTypes(InputType.feedForward(5))
            .addLayer("rnn", new GravesLSTM.Builder().nOut(5).build(), "in")
            .addLayer("out", new RnnOutputLayer.Builder().nOut(5).build(), "rnn")
            .setOutputs("out").build();
    assertEquals(5, ((FeedForwardLayer) ((LayerVertex) conf1.getVertices().get("rnn")).getLayerConf().getLayer()).getNIn());
    assertEquals(5, ((FeedForwardLayer) ((LayerVertex) conf1.getVertices().get("out")).getLayerConf().getLayer()).getNIn());
    LayerVertex lv1 = (LayerVertex) conf1.getVertices().get("rnn");
    assertTrue(lv1.getPreProcessor() instanceof FeedForwardToRnnPreProcessor);
    LayerVertex lv2 = (LayerVertex) conf1.getVertices().get("out");
    assertNull(lv2.getPreProcessor());
    //Check RNN -> FF -> RNN
    ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder().graphBuilder()
            .addInputs("in")
            .setInputTypes(InputType.recurrent(5))
            .addLayer("ff", new DenseLayer.Builder().nOut(5).build(), "in")
            .addLayer("out", new RnnOutputLayer.Builder().nOut(5).build(), "ff")
            .setOutputs("out").build();
    assertEquals(5, ((FeedForwardLayer) ((LayerVertex) conf2.getVertices().get("ff")).getLayerConf().getLayer()).getNIn());
    assertEquals(5, ((FeedForwardLayer) ((LayerVertex) conf2.getVertices().get("out")).getLayerConf().getLayer()).getNIn());
    lv1 = (LayerVertex) conf2.getVertices().get("ff");
    assertTrue(lv1.getPreProcessor() instanceof RnnToFeedForwardPreProcessor);
    lv2 = (LayerVertex) conf2.getVertices().get("out");
    assertTrue(lv2.getPreProcessor() instanceof FeedForwardToRnnPreProcessor);
    //CNN -> Dense
    ComputationGraphConfiguration conf3 = new NeuralNetConfiguration.Builder().graphBuilder()
            .addInputs("in")
            .setInputTypes(InputType.convolutional(28, 28, 1))
            .addLayer("cnn", new ConvolutionLayer.Builder().kernelSize(2, 2).padding(0, 0).stride(2, 2).nOut(3).build(), "in") //(28-2+0)/2+1 = 14
            .addLayer("pool", new SubsamplingLayer.Builder().kernelSize(2, 2).padding(0, 0).stride(2, 2).build(), "cnn") //(14-2+0)/2+1 = 7
            .addLayer("dense", new DenseLayer.Builder().nOut(10).build(), "pool")
            .addLayer("out", new OutputLayer.Builder().nIn(10).nOut(5).build(), "dense")
            .setOutputs("out").build();
    //Check preprocessors:
    lv1 = (LayerVertex) conf3.getVertices().get("cnn");
    //Shouldn't be adding preprocessor here
    assertNull(lv1.getPreProcessor());
    lv2 = (LayerVertex) conf3.getVertices().get("pool");
    assertNull(lv2.getPreProcessor());
    LayerVertex lv3 = (LayerVertex) conf3.getVertices().get("dense");
    assertTrue(lv3.getPreProcessor() instanceof CnnToFeedForwardPreProcessor);
    CnnToFeedForwardPreProcessor proc = (CnnToFeedForwardPreProcessor) lv3.getPreProcessor();
    assertEquals(3, proc.getNumChannels());
    assertEquals(7, proc.getInputHeight());
    assertEquals(7, proc.getInputWidth());
    LayerVertex lv4 = (LayerVertex) conf3.getVertices().get("out");
    assertNull(lv4.getPreProcessor());
    //Check nIns:
    assertEquals(7 * 7 * 3, ((FeedForwardLayer) lv3.getLayerConf().getLayer()).getNIn());
    //CNN->Dense, RNN->Dense, Dense->RNN
    ComputationGraphConfiguration conf4 = new NeuralNetConfiguration.Builder().graphBuilder()
            .addInputs("inCNN", "inRNN")
            .setInputTypes(InputType.convolutional(28, 28, 1), InputType.recurrent(5))
            .addLayer("cnn", new ConvolutionLayer.Builder().kernelSize(2, 2).padding(0, 0).stride(2, 2).nOut(3).build(), "inCNN") //(28-2+0)/2+1 = 14
            .addLayer("pool", new SubsamplingLayer.Builder().kernelSize(2, 2).padding(0, 0).stride(2, 2).build(), "cnn") //(14-2+0)/2+1 = 7
            .addLayer("dense", new DenseLayer.Builder().nOut(10).build(), "pool")
            .addLayer("dense2", new DenseLayer.Builder().nOut(10).build(), "inRNN")
            .addVertex("merge", new MergeVertex(), "dense", "dense2")
            .addLayer("out", new RnnOutputLayer.Builder().nOut(5).build(), "merge")
            .setOutputs("out").build();
    //Check preprocessors:
    lv1 = (LayerVertex) conf4.getVertices().get("cnn");
    //Expect no preprocessor: cnn data -> cnn layer
    assertNull(lv1.getPreProcessor());
    lv2 = (LayerVertex) conf4.getVertices().get("pool");
    assertNull(lv2.getPreProcessor());
    lv3 = (LayerVertex) conf4.getVertices().get("dense");
    assertTrue(lv3.getPreProcessor() instanceof CnnToFeedForwardPreProcessor);
    proc = (CnnToFeedForwardPreProcessor) lv3.getPreProcessor();
    assertEquals(3, proc.getNumChannels());
    assertEquals(7, proc.getInputHeight());
    assertEquals(7, proc.getInputWidth());
    lv4 = (LayerVertex) conf4.getVertices().get("dense2");
    assertTrue(lv4.getPreProcessor() instanceof RnnToFeedForwardPreProcessor);
    LayerVertex lv5 = (LayerVertex) conf4.getVertices().get("out");
    assertTrue(lv5.getPreProcessor() instanceof FeedForwardToRnnPreProcessor);
    //Check nIns:
    assertEquals(7 * 7 * 3, ((FeedForwardLayer) lv3.getLayerConf().getLayer()).getNIn());
    assertEquals(5, ((FeedForwardLayer) lv4.getLayerConf().getLayer()).getNIn());
    //10+10 out of the merge vertex -> 20 in to output layer vertex
    assertEquals(20, ((FeedForwardLayer) lv5.getLayerConf().getLayer()).getNIn());
    //Input to 2 CNN layers:
    ComputationGraphConfiguration conf5 = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder()
            .addInputs("input")
            .setInputTypes(InputType.convolutional(28, 28, 1))
            .addLayer("cnn_1", new ConvolutionLayer.Builder(2, 2).stride(2, 2).nIn(1).nOut(3).build(), "input")
            .addLayer("cnn_2", new ConvolutionLayer.Builder(4, 4).stride(2, 2).padding(1, 1).nIn(1).nOut(3).build(), "input")
            .addLayer("max_1", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).build(), "cnn_1", "cnn_2")
            .addLayer("output", new OutputLayer.Builder().nOut(10).build(), "max_1") //.nIn(7 * 7 * 6)
            .setOutputs("output").pretrain(false).backprop(true).build();
    lv1 = (LayerVertex) conf5.getVertices().get("cnn_1");
    //Expect no preprocessor: cnn data -> cnn layer
    assertNull(lv1.getPreProcessor());
    lv2 = (LayerVertex) conf5.getVertices().get("cnn_2");
    //Expect no preprocessor: cnn data -> cnn layer
    assertNull(lv2.getPreProcessor());
    assertNull(((LayerVertex) conf5.getVertices().get("max_1")).getPreProcessor());
    lv3 = (LayerVertex) conf5.getVertices().get("output");
    assertTrue(lv3.getPreProcessor() instanceof CnnToFeedForwardPreProcessor);
    CnnToFeedForwardPreProcessor cnnff = (CnnToFeedForwardPreProcessor) lv3.getPreProcessor();
    assertEquals(6, cnnff.getNumChannels());
    assertEquals(7, cnnff.getInputHeight());
    assertEquals(7, cnnff.getInputWidth());
}
Also used: org.deeplearning4j.nn.conf.graph.LayerVertex, org.deeplearning4j.nn.conf.graph.MergeVertex, org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor, org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor, org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor, org.junit.Test
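
The shape comments in the test above ((28-2+0)/2+1 = 14, then 7) follow the usual convolution output-size formula, out = (in - kernel + 2*padding) / stride + 1. A minimal standalone sketch (plain Java, independent of DL4J; class and method names are illustrative) reproduces the 28 -> 14 -> 7 chain and the nIn of 7 * 7 * 3 asserted above:

public class ConvShapeMath {

    //out = (in - kernel + 2*padding) / stride + 1 (assumes the division is exact)
    static int convOut(int in, int kernel, int padding, int stride) {
        return (in - kernel + 2 * padding) / stride + 1;
    }

    public static void main(String[] args) {
        int afterCnn = convOut(28, 2, 0, 2); //(28-2+0)/2+1 = 14
        int afterPool = convOut(afterCnn, 2, 0, 2); //(14-2+0)/2+1 = 7
        int channels = 3; //nOut of the "cnn" layer
        //Same value the CnnToFeedForwardPreProcessor yields as nIn for the "dense" layer:
        System.out.println(afterPool * afterPool * channels); //prints 147 = 7 * 7 * 3
    }
}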

Example 2 with LayerVertex

Use of org.deeplearning4j.nn.conf.graph.LayerVertex in the deeplearning4j project.

From the class TestUpdaters, method testUpdaters.

@Test
public void testUpdaters() throws Exception {
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(Updater.NESTEROVS).momentum(0.9).graphBuilder()
            .addInputs("input") // 40x40x1
            .addLayer("l0_cnn", new ConvolutionLayer.Builder(new int[] {3, 3}, new int[] {1, 1}, new int[] {1, 1}).nOut(100).build(), "input") // out: 40x40x100
            .addLayer("l1_max", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] {2, 2}, new int[] {2, 2}, new int[] {1, 1}).build(), "l0_cnn") // 21x21x100
            .addLayer("l2_cnn", new ConvolutionLayer.Builder(new int[] {3, 3}, new int[] {2, 2}, new int[] {1, 1}).nOut(200).build(), "l1_max") // 11x11x200
            .addLayer("l3_max", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] {3, 3}, new int[] {2, 2}, new int[] {1, 1}).build(), "l2_cnn") // 6x6x200
            .addLayer("l4_fc", new DenseLayer.Builder().nOut(1024).build(), "l3_max") // out: 1024
            .addLayer("l5_out", new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(10).activation(Activation.SOFTMAX).build(), "l4_fc")
            .setOutputs("l5_out").backprop(true).pretrain(false)
            .setInputTypes(InputType.convolutional(40, 40, 1)).build();
    //First: check that the nIns are set properly...
    Map<String, GraphVertex> map = conf.getVertices();
    LayerVertex l0_cnn = (LayerVertex) map.get("l0_cnn");
    LayerVertex l2_cnn = (LayerVertex) map.get("l2_cnn");
    LayerVertex l4_fc = (LayerVertex) map.get("l4_fc");
    LayerVertex l5_out = (LayerVertex) map.get("l5_out");
    assertEquals(1, ((FeedForwardLayer) l0_cnn.getLayerConf().getLayer()).getNIn());
    assertEquals(100, ((FeedForwardLayer) l2_cnn.getLayerConf().getLayer()).getNIn());
    assertEquals(6 * 6 * 200, ((FeedForwardLayer) l4_fc.getLayerConf().getLayer()).getNIn());
    assertEquals(1024, ((FeedForwardLayer) l5_out.getLayerConf().getLayer()).getNIn());
    //Check updaters state:
    ComputationGraph g = new ComputationGraph(conf);
    g.init();
    g.initGradientsView();
    ComputationGraphUpdater updater = g.getUpdater();
    //First: get the updaters array
    Field layerUpdatersField = updater.getClass().getDeclaredField("layerUpdaters");
    layerUpdatersField.setAccessible(true);
    org.deeplearning4j.nn.api.Updater[] layerUpdaters = (org.deeplearning4j.nn.api.Updater[]) layerUpdatersField.get(updater);
    //And get the map between names and updater indexes
    Field layerUpdatersMapField = updater.getClass().getDeclaredField("layerUpdatersMap");
    layerUpdatersMapField.setAccessible(true);
    Map<String, Integer> layerUpdatersMap = (Map<String, Integer>) layerUpdatersMapField.get(updater);
    //Go through each layer; check that the updater state size matches the parameters size
    org.deeplearning4j.nn.api.Layer[] layers = g.getLayers();
    for (org.deeplearning4j.nn.api.Layer l : layers) {
        String layerName = l.conf().getLayer().getLayerName();
        int nParams = l.numParams();
        Map<String, INDArray> paramTable = l.paramTable();
        Map<String, Integer> parameterSizeCounts = new LinkedHashMap<>();
        for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
            parameterSizeCounts.put(e.getKey(), e.getValue().length());
        }
        int updaterIdx = layerUpdatersMap.get(layerName);
        org.deeplearning4j.nn.api.Updater u = layerUpdaters[updaterIdx];
        LayerUpdater lu = (LayerUpdater) u;
        Field updaterForVariableField = LayerUpdater.class.getDeclaredField("updaterForVariable");
        updaterForVariableField.setAccessible(true);
        Map<String, GradientUpdater> updaterForVariable = (Map<String, GradientUpdater>) updaterForVariableField.get(lu);
        Map<String, Integer> updaterStateSizeCounts = new HashMap<>();
        for (Map.Entry<String, GradientUpdater> entry : updaterForVariable.entrySet()) {
            GradientUpdater gu = entry.getValue();
            Nesterovs nesterovs = (Nesterovs) gu;
            INDArray v = nesterovs.getV();
            int length = (v == null ? -1 : v.length());
            updaterStateSizeCounts.put(entry.getKey(), length);
        }
        //Check subsampling layers:
        if (l.numParams() == 0) {
            assertEquals(0, updaterForVariable.size());
        }
        System.out.println(layerName + "\t" + nParams + "\t" + parameterSizeCounts + "\t Updater size: " + updaterStateSizeCounts);
        //Now, with nesterov updater: 1 history value per parameter
        for (String s : parameterSizeCounts.keySet()) {
            int paramSize = parameterSizeCounts.get(s);
            int updaterSize = updaterStateSizeCounts.get(s);
            assertEquals(layerName + "/" + s, paramSize, updaterSize);
        }
    }
    //minibatch, depth, height, width
    INDArray in = Nd4j.create(2, 1, 40, 40);
    INDArray l = Nd4j.create(2, 10);
    DataSet ds = new DataSet(in, l);
    g.fit(ds);
}
Also used: org.deeplearning4j.nn.conf.graph.LayerVertex, org.deeplearning4j.nn.conf.graph.GraphVertex, org.deeplearning4j.nn.conf.ComputationGraphConfiguration, org.deeplearning4j.nn.conf.Updater, org.deeplearning4j.nn.updater.LayerUpdater, org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater, org.nd4j.linalg.api.ndarray.INDArray, org.nd4j.linalg.dataset.DataSet, org.nd4j.linalg.learning.GradientUpdater, org.nd4j.linalg.learning.Nesterovs, java.lang.reflect.Field, java.util.HashMap, java.util.LinkedHashMap, java.util.Map, org.junit.Test
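
The per-parameter assertion at the end of the loop holds because Nesterov momentum keeps exactly one velocity (history) value per parameter, so the updater state array always matches the parameter array in length. A simplified sketch of the update rule (plain Java; the real org.nd4j.linalg.learning.Nesterovs operates on INDArrays and is more involved):

//Simplified Nesterov momentum step: v holds one velocity per parameter, hence
//updater state size == parameter count for every parameter array.
static void nesterovStep(double[] params, double[] grad, double[] v, double lr, double momentum) {
    for (int i = 0; i < params.length; i++) {
        double vPrev = v[i];
        v[i] = momentum * v[i] - lr * grad[i]; //update velocity
        params[i] += -momentum * vPrev + (1 + momentum) * v[i]; //Nesterov correction
    }
}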

Example 3 with LayerVertex

Use of org.deeplearning4j.nn.conf.graph.LayerVertex in the deeplearning4j project.

From the class ComputationGraphConfiguration, method addPreProcessors.

/**
 * Add preprocessors automatically, given the specified types of inputs for the network. Inputs are specified using the
 * {@link InputType} class, in the same order in which the inputs were defined in the original configuration.<br>
 * For example, in a network with two inputs, a convolutional input (28x28x1 images) and a feed-forward input, use
 * {@code .addPreProcessors(InputType.convolutional(28,28,1), InputType.feedForward())}.<br>
 * For the CNN->Dense and CNN->RNN transitions, the nIn values on the Dense/RNN layers will also be set automatically.
 * <b>NOTE</b>: This method is called automatically when using the
 * {@link GraphBuilder#setInputTypes(InputType...)} functionality; see that method for details.
 */
public void addPreProcessors(InputType... inputTypes) {
    if (inputTypes == null || inputTypes.length != networkInputs.size()) {
        throw new IllegalArgumentException("Invalid number of InputTypes: cannot add preprocessors if number of InputType " + "objects differs from number of network inputs");
    }
    //Now: need to do essentially a forward pass through the network, to work out what type of preprocessors to add
    //To do this: need to know what the output types are for each GraphVertex.
    //First step: build network in reverse order (i.e., define map of a -> list(b) instead of list(a) -> b)
    //Key: vertex. Values: vertices that this node is an input for
    Map<String, List<String>> verticesOutputTo = new HashMap<>();
    for (Map.Entry<String, GraphVertex> entry : vertices.entrySet()) {
        String vertexName = entry.getKey();
        List<String> vertexInputNames = vertexInputs.get(vertexName);
        if (vertexInputNames == null)
            continue;
        //Build reverse network structure:
        for (String s : vertexInputNames) {
            List<String> list = verticesOutputTo.get(s);
            if (list == null) {
                list = new ArrayList<>();
                verticesOutputTo.put(s, list);
            }
            //Edge: s -> vertexName
            list.add(vertexName);
        }
    }
    //Now: do topological sort
    //Set of all nodes with no incoming edges
    LinkedList<String> noIncomingEdges = new LinkedList<>(networkInputs);
    List<String> topologicalOrdering = new ArrayList<>();
    Map<String, Set<String>> inputEdges = new HashMap<>();
    for (Map.Entry<String, List<String>> entry : vertexInputs.entrySet()) {
        inputEdges.put(entry.getKey(), new HashSet<>(entry.getValue()));
    }
    while (!noIncomingEdges.isEmpty()) {
        String next = noIncomingEdges.removeFirst();
        topologicalOrdering.add(next);
        //Remove edges next -> verticesOutputTo[...] from the graph
        List<String> nextEdges = verticesOutputTo.get(next);
        if (nextEdges != null && !nextEdges.isEmpty()) {
            for (String s : nextEdges) {
                Set<String> set = inputEdges.get(s);
                set.remove(next);
                if (set.isEmpty()) {
                    //No remaining input edges for vertex s -> add to list for processing
                    noIncomingEdges.add(s);
                }
            }
        }
    }
    //If any edges remain in the graph: graph has cycles:
    for (Map.Entry<String, Set<String>> entry : inputEdges.entrySet()) {
        Set<String> set = entry.getValue();
        if (set == null)
            continue;
        if (!set.isEmpty())
            throw new IllegalStateException("Invalid configuration: cycle detected in graph. Cannot calculate topological ordering with graph cycle (" + "cycle includes vertex \"" + entry.getKey() + "\")");
    }
    //Now, given the topological sort: do equivalent of forward pass
    Map<String, InputType> vertexOutputs = new HashMap<>();
    int currLayerIdx = -1;
    for (String s : topologicalOrdering) {
        int inputIdx = networkInputs.indexOf(s);
        if (inputIdx != -1) {
            vertexOutputs.put(s, inputTypes[inputIdx]);
            continue;
        }
        GraphVertex gv = vertices.get(s);
        List<InputType> inputTypeList = new ArrayList<>();
        if (gv instanceof LayerVertex) {
            //Add preprocessor, if necessary:
            String in = vertexInputs.get(s).get(0);
            InputType layerInput = vertexOutputs.get(in);
            inputTypeList.add(layerInput);
            LayerVertex lv = (LayerVertex) gv;
            Layer l = lv.getLayerConf().getLayer();
            //Preprocessors - add if necessary
            if (lv.getPreProcessor() == null) {
                //But don't override preprocessors that are manually defined; if none has been defined,
                //add the appropriate preprocessor for this input type/layer combination
                InputPreProcessor preproc = l.getPreProcessorForInputType(layerInput);
                lv.setPreProcessor(preproc);
            }
            //Set nIn value for layer (if not already set)
            InputType afterPreproc = layerInput;
            if (lv.getPreProcessor() != null) {
                InputPreProcessor ip = lv.getPreProcessor();
                afterPreproc = ip.getOutputType(layerInput);
            }
            l.setNIn(afterPreproc, false);
            currLayerIdx++;
        } else {
            List<String> inputs = vertexInputs.get(s);
            if (inputs != null) {
                for (String inputVertexName : inputs) {
                    inputTypeList.add(vertexOutputs.get(inputVertexName));
                }
            }
        }
        InputType outputFromVertex = gv.getOutputType(currLayerIdx, inputTypeList.toArray(new InputType[inputTypeList.size()]));
        vertexOutputs.put(s, outputFromVertex);
    }
}
Also used: org.deeplearning4j.nn.conf.graph.LayerVertex, org.deeplearning4j.nn.conf.graph.GraphVertex, org.deeplearning4j.nn.conf.inputs.InputType, org.deeplearning4j.nn.conf.layers.Layer, org.deeplearning4j.nn.conf.layers.OutputLayer
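
As the javadoc notes, addPreProcessors is normally invoked for you via GraphBuilder#setInputTypes(InputType...) rather than called directly. A minimal sketch of that entry point, mirroring the tests above (layer names and sizes are illustrative):

ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder()
        .addInputs("in")
        //Declaring the input type triggers addPreProcessors(...) during build():
        //preprocessors are inserted and nIn values are filled in automatically.
        .setInputTypes(InputType.convolutional(28, 28, 1))
        .addLayer("cnn", new ConvolutionLayer.Builder().kernelSize(2, 2).stride(2, 2).nOut(3).build(), "in")
        .addLayer("dense", new DenseLayer.Builder().nOut(10).build(), "cnn") //gets a CnnToFeedForwardPreProcessor
        .addLayer("out", new OutputLayer.Builder().nIn(10).nOut(5).build(), "dense")
        .setOutputs("out").build();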

Example 4 with LayerVertex

Use of org.deeplearning4j.nn.conf.graph.LayerVertex in the deeplearning4j project.

From the class RegressionTest060, method regressionTestCGLSTM1.

@Test
public void regressionTestCGLSTM1() throws Exception {
    File f = new ClassPathResource("regression_testing/060/060_ModelSerializer_Regression_CG_LSTM_1.zip").getTempFileFromArchive();
    ComputationGraph net = ModelSerializer.restoreComputationGraph(f, true);
    ComputationGraphConfiguration conf = net.getConfiguration();
    assertEquals(3, conf.getVertices().size());
    assertTrue(conf.isBackprop());
    assertFalse(conf.isPretrain());
    GravesLSTM l0 = (GravesLSTM) ((LayerVertex) conf.getVertices().get("0")).getLayerConf().getLayer();
    assertEquals("tanh", l0.getActivationFn().toString());
    assertEquals(3, l0.getNIn());
    assertEquals(4, l0.getNOut());
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, l0.getGradientNormalization());
    assertEquals(1.5, l0.getGradientNormalizationThreshold(), 1e-5);
    GravesBidirectionalLSTM l1 = (GravesBidirectionalLSTM) ((LayerVertex) conf.getVertices().get("1")).getLayerConf().getLayer();
    assertEquals("softsign", l1.getActivationFn().toString());
    assertEquals(4, l1.getNIn());
    assertEquals(4, l1.getNOut());
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, l1.getGradientNormalization());
    assertEquals(1.5, l1.getGradientNormalizationThreshold(), 1e-5);
    RnnOutputLayer l2 = (RnnOutputLayer) ((LayerVertex) conf.getVertices().get("2")).getLayerConf().getLayer();
    assertEquals(4, l2.getNIn());
    assertEquals(5, l2.getNOut());
    assertEquals("softmax", l2.getActivationFn().toString());
    assertTrue(l2.getLossFn() instanceof LossMCXENT);
}
Also used: org.deeplearning4j.nn.conf.graph.LayerVertex, org.deeplearning4j.nn.graph.ComputationGraph, org.nd4j.linalg.lossfunctions.impl.LossMCXENT, org.nd4j.linalg.io.ClassPathResource, java.io.File, org.junit.Test
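
The boolean flag on restoreComputationGraph asks ModelSerializer to load the updater state along with the weights. For context, a hedged sketch of the matching save/restore round trip (the file name is illustrative; writeModel's boolean likewise controls whether updater state is saved):

//Round trip: save a graph (true = include updater state), then restore it.
File f = new File("cg_lstm.zip");
ModelSerializer.writeModel(net, f, true);
ComputationGraph restored = ModelSerializer.restoreComputationGraph(f, true);
assertEquals(net.getConfiguration().toJson(), restored.getConfiguration().toJson());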

Example 5 with LayerVertex

Use of org.deeplearning4j.nn.conf.graph.LayerVertex in the deeplearning4j project.

From the class RegressionTest071, method regressionTestCGLSTM1.

@Test
public void regressionTestCGLSTM1() throws Exception {
    File f = new ClassPathResource("regression_testing/071/071_ModelSerializer_Regression_CG_LSTM_1.zip").getTempFileFromArchive();
    ComputationGraph net = ModelSerializer.restoreComputationGraph(f, true);
    ComputationGraphConfiguration conf = net.getConfiguration();
    assertEquals(3, conf.getVertices().size());
    assertTrue(conf.isBackprop());
    assertFalse(conf.isPretrain());
    GravesLSTM l0 = (GravesLSTM) ((LayerVertex) conf.getVertices().get("0")).getLayerConf().getLayer();
    assertEquals("tanh", l0.getActivationFn().toString());
    assertEquals(3, l0.getNIn());
    assertEquals(4, l0.getNOut());
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, l0.getGradientNormalization());
    assertEquals(1.5, l0.getGradientNormalizationThreshold(), 1e-5);
    GravesBidirectionalLSTM l1 = (GravesBidirectionalLSTM) ((LayerVertex) conf.getVertices().get("1")).getLayerConf().getLayer();
    assertEquals("softsign", l1.getActivationFn().toString());
    assertEquals(4, l1.getNIn());
    assertEquals(4, l1.getNOut());
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, l1.getGradientNormalization());
    assertEquals(1.5, l1.getGradientNormalizationThreshold(), 1e-5);
    RnnOutputLayer l2 = (RnnOutputLayer) ((LayerVertex) conf.getVertices().get("2")).getLayerConf().getLayer();
    assertEquals(4, l2.getNIn());
    assertEquals(5, l2.getNOut());
    assertEquals("softmax", l2.getActivationFn().toString());
    assertTrue(l2.getLossFn() instanceof LossMCXENT);
}
Also used: org.deeplearning4j.nn.conf.graph.LayerVertex, org.deeplearning4j.nn.graph.ComputationGraph, org.nd4j.linalg.lossfunctions.impl.LossMCXENT, org.nd4j.linalg.io.ClassPathResource, java.io.File, org.junit.Test

Aggregations

LayerVertex (org.deeplearning4j.nn.conf.graph.LayerVertex): 10 usages
GraphVertex (org.deeplearning4j.nn.conf.graph.GraphVertex): 6 usages
Test (org.junit.Test): 5 usages
Layer (org.deeplearning4j.nn.conf.layers.Layer): 3 usages
File (java.io.File): 2 usages
ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration): 2 usages
NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration): 2 usages
Updater (org.deeplearning4j.nn.conf.Updater): 2 usages
OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer): 2 usages
ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph): 2 usages
ComputationGraphUpdater (org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater): 2 usages
IActivation (org.nd4j.linalg.activations.IActivation): 2 usages
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 2 usages
ClassPathResource (org.nd4j.linalg.io.ClassPathResource): 2 usages
LossMCXENT (org.nd4j.linalg.lossfunctions.impl.LossMCXENT): 2 usages
IOException (java.io.IOException): 1 usage
Field (java.lang.reflect.Field): 1 usage
ArrayList (java.util.ArrayList): 1 usage
HashMap (java.util.HashMap): 1 usage
LinkedHashMap (java.util.LinkedHashMap): 1 usage