
Example 1 with Layer

use of org.deeplearning4j.nn.conf.layers.Layer in project deeplearning4j by deeplearning4j.

From the class ComputationGraphConfiguration, method addPreProcessors.

/**
 * Add preprocessors automatically, given the specified types of inputs for the network. Inputs are specified using the
 * {@link InputType} class, in the same order in which the inputs were defined in the original configuration.<br>
 * For example, in a network with two inputs: a convolutional input (28x28x1 images) and feed forward inputs, use
 * {@code .addPreProcessors(InputType.convolutional(1,28,28),InputType.feedForward())}.<br>
 * For the CNN->Dense and CNN->RNN transitions, the nIns on the Dense/RNN layers will also be added automatically.
 * <b>NOTE</b>: This method will be called automatically when using the
 * {@link GraphBuilder#setInputTypes(InputType...)} functionality.
 * See that method for details.
 */
public void addPreProcessors(InputType... inputTypes) {
    if (inputTypes == null || inputTypes.length != networkInputs.size()) {
        throw new IllegalArgumentException("Invalid number of InputTypes: cannot add preprocessors if number of InputType " + "objects differs from number of network inputs");
    }
    //Now: need to do essentially a forward pass through the network, to work out what type of preprocessors to add
    //To do this: need to know what the output types are for each GraphVertex.
    //First step: build network in reverse order (i.e., define map of a -> list(b) instead of list(a) -> b)
    //Key: vertex. Values: vertices that this node is an input for
    Map<String, List<String>> verticesOutputTo = new HashMap<>();
    for (Map.Entry<String, GraphVertex> entry : vertices.entrySet()) {
        String vertexName = entry.getKey();
        List<String> vertexInputNames = vertexInputs.get(vertexName);
        if (vertexInputNames == null)
            continue;
        //Build reverse network structure:
        for (String s : vertexInputNames) {
            List<String> list = verticesOutputTo.get(s);
            if (list == null) {
                list = new ArrayList<>();
                verticesOutputTo.put(s, list);
            }
            //Edge: s -> vertexName
            list.add(vertexName);
        }
    }
    //Now: do topological sort
    //Set of all nodes with no incoming edges
    LinkedList<String> noIncomingEdges = new LinkedList<>(networkInputs);
    List<String> topologicalOrdering = new ArrayList<>();
    Map<String, Set<String>> inputEdges = new HashMap<>();
    for (Map.Entry<String, List<String>> entry : vertexInputs.entrySet()) {
        inputEdges.put(entry.getKey(), new HashSet<>(entry.getValue()));
    }
    while (!noIncomingEdges.isEmpty()) {
        String next = noIncomingEdges.removeFirst();
        topologicalOrdering.add(next);
        //Remove edges next -> verticesOutputTo[...] from the graph
        List<String> nextEdges = verticesOutputTo.get(next);
        if (nextEdges != null && !nextEdges.isEmpty()) {
            for (String s : nextEdges) {
                Set<String> set = inputEdges.get(s);
                set.remove(next);
                if (set.isEmpty()) {
                    //No remaining incoming edges for vertex s -> add to list for processing
                    noIncomingEdges.add(s);
                }
            }
        }
    }
    //If any edges remain in the graph: graph has cycles:
    for (Map.Entry<String, Set<String>> entry : inputEdges.entrySet()) {
        Set<String> set = entry.getValue();
        if (set == null)
            continue;
        if (!set.isEmpty())
            throw new IllegalStateException("Invalid configuration: cycle detected in graph. Cannot calculate topological ordering with graph cycle (" + "cycle includes vertex \"" + entry.getKey() + "\")");
    }
    //Now, given the topological sort: do equivalent of forward pass
    Map<String, InputType> vertexOutputs = new HashMap<>();
    int currLayerIdx = -1;
    for (String s : topologicalOrdering) {
        int inputIdx = networkInputs.indexOf(s);
        if (inputIdx != -1) {
            vertexOutputs.put(s, inputTypes[inputIdx]);
            continue;
        }
        GraphVertex gv = vertices.get(s);
        List<InputType> inputTypeList = new ArrayList<>();
        if (gv instanceof LayerVertex) {
            //Add preprocessor, if necessary:
            String in = vertexInputs.get(s).get(0);
            InputType layerInput = vertexOutputs.get(in);
            inputTypeList.add(layerInput);
            LayerVertex lv = (LayerVertex) gv;
            Layer l = lv.getLayerConf().getLayer();
            //Preprocessors - add if necessary
            if (lv.getPreProcessor() == null) {
                //But don't override preprocessors that are manually defined; if none has been defined,
                //add the appropriate preprocessor for this input type/layer combination
                InputPreProcessor preproc = l.getPreProcessorForInputType(layerInput);
                lv.setPreProcessor(preproc);
            }
            //Set nIn value for layer (if not already set)
            InputType afterPreproc = layerInput;
            if (lv.getPreProcessor() != null) {
                InputPreProcessor ip = lv.getPreProcessor();
                afterPreproc = ip.getOutputType(layerInput);
            }
            l.setNIn(afterPreproc, false);
            currLayerIdx++;
        } else {
            List<String> inputs = vertexInputs.get(s);
            if (inputs != null) {
                for (String inputVertexName : inputs) {
                    inputTypeList.add(vertexOutputs.get(inputVertexName));
                }
            }
        }
        InputType outputFromVertex = gv.getOutputType(currLayerIdx, inputTypeList.toArray(new InputType[inputTypeList.size()]));
        vertexOutputs.put(s, outputFromVertex);
    }
}
Also used : LayerVertex(org.deeplearning4j.nn.conf.graph.LayerVertex) GraphVertex(org.deeplearning4j.nn.conf.graph.GraphVertex) OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) Layer(org.deeplearning4j.nn.conf.layers.Layer) InputType(org.deeplearning4j.nn.conf.inputs.InputType)
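
The javadoc above notes that addPreProcessors is normally invoked for you through GraphBuilder.setInputTypes(InputType...). Below is a minimal sketch of both routes; the layer names, sizes, and builder options are illustrative rather than taken from the DL4J sources, and the InputType.convolutional argument order follows the javadoc above.

ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
        .graphBuilder()
        .addInputs("in")
        .addLayer("cnn", new ConvolutionLayer.Builder(5, 5).nOut(20).build(), "in")
        .addLayer("out", new OutputLayer.Builder().nOut(10).build(), "cnn")
        .setOutputs("out")
        //setInputTypes(...) records the input types; preprocessors and nIn values
        //are then filled in automatically (via addPreProcessors) when building
        .setInputTypes(InputType.convolutional(1, 28, 28))
        .build();

//Equivalent manual call on an already-built configuration:
conf.addPreProcessors(InputType.convolutional(1, 28, 28));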

Example 2 with Layer

use of org.deeplearning4j.nn.conf.layers.Layer in project deeplearning4j by deeplearning4j.

From the class TrainModule, method getLayerInfoTable.

private String[][] getLayerInfoTable(int layerIdx, TrainModuleUtils.GraphInfo gi, I18N i18N, boolean noData, StatsStorage ss, String wid) {
    List<String[]> layerInfoRows = new ArrayList<>();
    layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerName"), gi.getVertexNames().get(layerIdx) });
    layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerType"), "" });
    if (!noData) {
        Persistable p = ss.getStaticInfo(currentSessionID, StatsListener.TYPE_ID, wid);
        if (p != null) {
            StatsInitializationReport initReport = (StatsInitializationReport) p;
            String configJson = initReport.getModelConfigJson();
            String modelClass = initReport.getModelClassName();
            //TODO error handling...
            String layerType = "";
            Layer layer = null;
            NeuralNetConfiguration nnc = null;
            if (modelClass.endsWith("MultiLayerNetwork")) {
                MultiLayerConfiguration conf = MultiLayerConfiguration.fromJson(configJson);
                //-1 because of input
                int confIdx = layerIdx - 1;
                if (confIdx >= 0) {
                    nnc = conf.getConf(confIdx);
                    layer = nnc.getLayer();
                } else {
                    //Input layer
                    layerType = "Input";
                }
            } else if (modelClass.endsWith("ComputationGraph")) {
                ComputationGraphConfiguration conf = ComputationGraphConfiguration.fromJson(configJson);
                String vertexName = gi.getVertexNames().get(layerIdx);
                Map<String, GraphVertex> vertices = conf.getVertices();
                if (vertices.containsKey(vertexName) && vertices.get(vertexName) instanceof LayerVertex) {
                    LayerVertex lv = (LayerVertex) vertices.get(vertexName);
                    nnc = lv.getLayerConf();
                    layer = nnc.getLayer();
                } else if (conf.getNetworkInputs().contains(vertexName)) {
                    layerType = "Input";
                } else {
                    GraphVertex gv = conf.getVertices().get(vertexName);
                    if (gv != null) {
                        layerType = gv.getClass().getSimpleName();
                    }
                }
            } else if (modelClass.endsWith("VariationalAutoencoder")) {
                layerType = gi.getVertexTypes().get(layerIdx);
                Map<String, String> map = gi.getVertexInfo().get(layerIdx);
                for (Map.Entry<String, String> entry : map.entrySet()) {
                    layerInfoRows.add(new String[] { entry.getKey(), entry.getValue() });
                }
            }
            if (layer != null) {
                layerType = getLayerType(layer);
            }
            if (layer != null) {
                String activationFn = null;
                if (layer instanceof FeedForwardLayer) {
                    FeedForwardLayer ffl = (FeedForwardLayer) layer;
                    layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerNIn"), String.valueOf(ffl.getNIn()) });
                    layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerSize"), String.valueOf(ffl.getNOut()) });
                    activationFn = layer.getActivationFn().toString();
                }
                int nParams = layer.initializer().numParams(nnc);
                layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerNParams"), String.valueOf(nParams) });
                if (nParams > 0) {
                    WeightInit wi = layer.getWeightInit();
                    String str = wi.toString();
                    if (wi == WeightInit.DISTRIBUTION) {
                        str += layer.getDist();
                    }
                    layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerWeightInit"), str });
                    Updater u = layer.getUpdater();
                    String us = (u == null ? "" : u.toString());
                    layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerUpdater"), us });
                //TODO: Maybe L1/L2, dropout, updater-specific values etc
                }
                if (layer instanceof ConvolutionLayer || layer instanceof SubsamplingLayer) {
                    int[] kernel;
                    int[] stride;
                    int[] padding;
                    if (layer instanceof ConvolutionLayer) {
                        ConvolutionLayer cl = (ConvolutionLayer) layer;
                        kernel = cl.getKernelSize();
                        stride = cl.getStride();
                        padding = cl.getPadding();
                    } else {
                        SubsamplingLayer ssl = (SubsamplingLayer) layer;
                        kernel = ssl.getKernelSize();
                        stride = ssl.getStride();
                        padding = ssl.getPadding();
                        activationFn = null;
                        layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerSubsamplingPoolingType"), ssl.getPoolingType().toString() });
                    }
                    layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerCnnKernel"), Arrays.toString(kernel) });
                    layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerCnnStride"), Arrays.toString(stride) });
                    layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerCnnPadding"), Arrays.toString(padding) });
                }
                if (activationFn != null) {
                    layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerActivationFn"), activationFn });
                }
            }
            layerInfoRows.get(1)[1] = layerType;
        }
    }
    return layerInfoRows.toArray(new String[layerInfoRows.size()][0]);
}
Also used : StatsInitializationReport(org.deeplearning4j.ui.stats.api.StatsInitializationReport) LayerVertex(org.deeplearning4j.nn.conf.graph.LayerVertex) Persistable(org.deeplearning4j.api.storage.Persistable) SubsamplingLayer(org.deeplearning4j.nn.conf.layers.SubsamplingLayer) WeightInit(org.deeplearning4j.nn.weights.WeightInit) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) ConvolutionLayer(org.deeplearning4j.nn.conf.layers.ConvolutionLayer) FeedForwardLayer(org.deeplearning4j.nn.conf.layers.FeedForwardLayer) Layer(org.deeplearning4j.nn.conf.layers.Layer) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) GraphVertex(org.deeplearning4j.nn.conf.graph.GraphVertex) Updater(org.deeplearning4j.nn.conf.Updater) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration)
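
The table above is populated from fields on the Layer configuration objects. Here is a hedged helper sketch that reads the same fields from a single NeuralNetConfiguration (for example conf.getConf(i) on a MultiLayerConfiguration, as in the MultiLayerNetwork branch above); the helper name and printed labels are hypothetical.

private static void printLayerSummary(NeuralNetConfiguration nnc) {
    Layer layer = nnc.getLayer();
    //Same fields the layer info table displays:
    System.out.println("type:    " + layer.getClass().getSimpleName());
    if (layer instanceof FeedForwardLayer) {
        FeedForwardLayer ffl = (FeedForwardLayer) layer;
        System.out.println("nIn:     " + ffl.getNIn());
        System.out.println("size:    " + ffl.getNOut());
        System.out.println("act fn:  " + layer.getActivationFn());
    }
    System.out.println("nParams: " + layer.initializer().numParams(nnc));
}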

Example 3 with Layer

use of org.deeplearning4j.nn.conf.layers.Layer in project deeplearning4j by deeplearning4j.

From the class TransferLearningMLNTest, method testFineTuneOverride.

@Test
public void testFineTuneOverride() {
    //Check that fine-tune overrides are selective - i.e., if I only specify a new LR, only the LR should be modified
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().learningRate(1e-4).updater(Updater.ADAM)
                    .activation(Activation.TANH).weightInit(WeightInit.RELU)
                    .regularization(true).l1(0.1).l2(0.2)
                    .list()
                    .layer(0, new DenseLayer.Builder().nIn(10).nOut(5).build())
                    .layer(1, new OutputLayer.Builder().nIn(5).nOut(4).activation(Activation.HARDSIGMOID).build())
                    .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    MultiLayerNetwork net2 = new TransferLearning.Builder(net)
                    .fineTuneConfiguration(new FineTuneConfiguration.Builder()
                                    .learningRate(2e-2) //Should be set on layers
                                    .backpropType(BackpropType.TruncatedBPTT) //Should be set on MLC
                                    .build())
                    .build();
    //Check original net isn't modified:
    Layer l0 = net.getLayer(0).conf().getLayer();
    assertEquals(Updater.ADAM, l0.getUpdater());
    assertEquals(Activation.TANH.getActivationFunction(), l0.getActivationFn());
    assertEquals(1e-4, l0.getLearningRate(), 1e-8);
    assertEquals(WeightInit.RELU, l0.getWeightInit());
    assertEquals(0.1, l0.getL1(), 1e-6);
    Layer l1 = net.getLayer(1).conf().getLayer();
    assertEquals(Updater.ADAM, l1.getUpdater());
    assertEquals(Activation.HARDSIGMOID.getActivationFunction(), l1.getActivationFn());
    assertEquals(1e-4, l1.getLearningRate(), 1e-8);
    assertEquals(WeightInit.RELU, l1.getWeightInit());
    assertEquals(0.2, l1.getL2(), 1e-6);
    assertEquals(BackpropType.Standard, conf.getBackpropType());
    //Check new net has only the appropriate things modified (i.e., LR)
    l0 = net2.getLayer(0).conf().getLayer();
    assertEquals(Updater.ADAM, l0.getUpdater());
    assertEquals(Activation.TANH.getActivationFunction(), l0.getActivationFn());
    assertEquals(2e-2, l0.getLearningRate(), 1e-8);
    assertEquals(WeightInit.RELU, l0.getWeightInit());
    assertEquals(0.1, l0.getL1(), 1e-6);
    l1 = net2.getLayer(1).conf().getLayer();
    assertEquals(Updater.ADAM, l1.getUpdater());
    assertEquals(Activation.HARDSIGMOID.getActivationFunction(), l1.getActivationFn());
    assertEquals(2e-2, l1.getLearningRate(), 1e-8);
    assertEquals(WeightInit.RELU, l1.getWeightInit());
    assertEquals(0.2, l1.getL2(), 1e-6);
    assertEquals(BackpropType.TruncatedBPTT, net2.getLayerWiseConfigurations().getBackpropType());
}
Also used : MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) Layer(org.deeplearning4j.nn.conf.layers.Layer) Test(org.junit.Test)
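
The same selective-override idea in compact form, as a sketch reusing the 'net' built in the test above: only the learning rate is supplied to the FineTuneConfiguration, so every other hyperparameter is expected to carry over unchanged.

FineTuneConfiguration ftc = new FineTuneConfiguration.Builder()
        .learningRate(2e-2) //the only value being overridden
        .build();

MultiLayerNetwork tuned = new TransferLearning.Builder(net)
        .fineTuneConfiguration(ftc)
        .build();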

Example 4 with Layer

use of org.deeplearning4j.nn.conf.layers.Layer in project deeplearning4j by deeplearning4j.

From the class ComputationGraphConfiguration, method fromJson.

/**
     * Create a computation graph configuration from json
     *
     * @param json the neural net configuration from json
     * @return {@link ComputationGraphConfiguration}
     */
public static ComputationGraphConfiguration fromJson(String json) {
    //As per MultiLayerConfiguration.fromJson()
    ObjectMapper mapper = NeuralNetConfiguration.mapper();
    ComputationGraphConfiguration conf;
    try {
        conf = mapper.readValue(json, ComputationGraphConfiguration.class);
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    //To maintain backward compatibility after activation function refactoring (configs generated with v0.7.1 or earlier)
    // Previously: enumeration used for activation functions. Now: use classes
    int layerCount = 0;
    Map<String, GraphVertex> vertexMap = conf.getVertices();
    JsonNode vertices = null;
    for (Map.Entry<String, GraphVertex> entry : vertexMap.entrySet()) {
        if (!(entry.getValue() instanceof LayerVertex)) {
            continue;
        }
        LayerVertex lv = (LayerVertex) entry.getValue();
        if (lv.getLayerConf() != null && lv.getLayerConf().getLayer() != null) {
            Layer layer = lv.getLayerConf().getLayer();
            if (layer.getActivationFn() == null) {
                String layerName = layer.getLayerName();
                try {
                    if (vertices == null) {
                        JsonNode jsonNode = mapper.readTree(json);
                        vertices = jsonNode.get("vertices");
                    }
                    JsonNode vertexNode = vertices.get(layerName);
                    JsonNode layerVertexNode = vertexNode.get("LayerVertex");
                    if (layerVertexNode == null || !layerVertexNode.has("layerConf") || !layerVertexNode.get("layerConf").has("layer")) {
                        continue;
                    }
                    JsonNode layerWrapperNode = layerVertexNode.get("layerConf").get("layer");
                    if (layerWrapperNode == null || layerWrapperNode.size() != 1) {
                        continue;
                    }
                    JsonNode layerNode = layerWrapperNode.elements().next();
                    //Should only have 1 element: "dense", "output", etc
                    JsonNode activationFunction = layerNode.get("activationFunction");
                    if (activationFunction != null) {
                        IActivation ia = Activation.fromString(activationFunction.asText()).getActivationFunction();
                        layer.setActivationFn(ia);
                    }
                } catch (IOException e) {
                    log.warn("Layer with null ActivationFn field or pre-0.7.2 activation function detected: could not parse JSON", e);
                }
            }
        }
    }
    return conf;
}
Also used : LayerVertex(org.deeplearning4j.nn.conf.graph.LayerVertex) JsonNode(org.nd4j.shade.jackson.databind.JsonNode) IOException(java.io.IOException) IActivation(org.nd4j.linalg.activations.IActivation) OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) Layer(org.deeplearning4j.nn.conf.layers.Layer) GraphVertex(org.deeplearning4j.nn.conf.graph.GraphVertex) ObjectMapper(org.nd4j.shade.jackson.databind.ObjectMapper)
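
A minimal round-trip sketch for fromJson; it assumes the matching toJson() serializer on the configuration, and the graph layout and layer sizes are illustrative.

ComputationGraphConfiguration original = new NeuralNetConfiguration.Builder()
        .graphBuilder()
        .addInputs("in")
        .addLayer("out", new OutputLayer.Builder().nIn(10).nOut(3).build(), "in")
        .setOutputs("out")
        .build();

String json = original.toJson();
ComputationGraphConfiguration restored = ComputationGraphConfiguration.fromJson(json);
//For legacy (pre-0.7.2) JSON, fromJson also repairs enum-style activation
//functions, as the example above shows.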

Example 5 with Layer

use of org.deeplearning4j.nn.conf.layers.Layer in project deeplearning4j by deeplearning4j.

From the class FineTuneConfiguration, method applyToNeuralNetConfiguration.

public void applyToNeuralNetConfiguration(NeuralNetConfiguration nnc) {
    Layer l = nnc.getLayer();
    Updater originalUpdater = null;
    WeightInit origWeightInit = null;
    if (l != null) {
        originalUpdater = l.getUpdater();
        origWeightInit = l.getWeightInit();
        if (activationFn != null)
            l.setActivationFn(activationFn);
        if (weightInit != null)
            l.setWeightInit(weightInit);
        if (biasInit != null)
            l.setBiasInit(biasInit);
        if (dist != null)
            l.setDist(dist);
        if (learningRate != null) {
            //Usually the same learning rate is applied to both bias and weights,
            //so overwrite both unless a separate bias learning rate is given below
            l.setLearningRate(learningRate);
            l.setBiasLearningRate(learningRate);
        }
        if (biasLearningRate != null)
            l.setBiasLearningRate(biasLearningRate);
        if (learningRateSchedule != null)
            l.setLearningRateSchedule(learningRateSchedule);
        //        if(lrScoreBasedDecay != null)
        if (l1 != null)
            l.setL1(l1);
        if (l2 != null)
            l.setL2(l2);
        if (l1Bias != null)
            l.setL1Bias(l1Bias);
        if (l2Bias != null)
            l.setL2Bias(l2Bias);
        if (dropOut != null)
            l.setDropOut(dropOut);
        if (updater != null)
            l.setUpdater(updater);
        if (momentum != null)
            l.setMomentum(momentum);
        if (momentumSchedule != null)
            l.setMomentumSchedule(momentumSchedule);
        if (epsilon != null)
            l.setEpsilon(epsilon);
        if (rho != null)
            l.setRho(rho);
        if (rmsDecay != null)
            l.setRmsDecay(rmsDecay);
        if (adamMeanDecay != null)
            l.setAdamMeanDecay(adamMeanDecay);
        if (adamVarDecay != null)
            l.setAdamVarDecay(adamVarDecay);
    }
    if (miniBatch != null)
        nnc.setMiniBatch(miniBatch);
    if (numIterations != null)
        nnc.setNumIterations(numIterations);
    if (maxNumLineSearchIterations != null)
        nnc.setMaxNumLineSearchIterations(maxNumLineSearchIterations);
    if (seed != null)
        nnc.setSeed(seed);
    if (useRegularization != null)
        nnc.setUseRegularization(useRegularization);
    if (optimizationAlgo != null)
        nnc.setOptimizationAlgo(optimizationAlgo);
    if (stepFunction != null)
        nnc.setStepFunction(stepFunction);
    if (useDropConnect != null)
        nnc.setUseDropConnect(useDropConnect);
    if (minimize != null)
        nnc.setMinimize(minimize);
    //These two are layer-level settings; guard against a null layer, as above
    if (l != null && gradientNormalization != null)
        l.setGradientNormalization(gradientNormalization);
    if (l != null && gradientNormalizationThreshold != null)
        l.setGradientNormalizationThreshold(gradientNormalizationThreshold);
    if (learningRatePolicy != null)
        nnc.setLearningRatePolicy(learningRatePolicy);
    if (lrPolicySteps != null)
        nnc.setLrPolicySteps(lrPolicySteps);
    if (lrPolicyPower != null)
        nnc.setLrPolicyPower(lrPolicyPower);
    if (convolutionMode != null && l instanceof ConvolutionLayer) {
        ((ConvolutionLayer) l).setConvolutionMode(convolutionMode);
    }
    if (convolutionMode != null && l instanceof SubsamplingLayer) {
        ((SubsamplingLayer) l).setConvolutionMode(convolutionMode);
    }
    //Check the updater config. If we change updaters, we want to remove the old config to avoid warnings
    if (l != null && updater != null && originalUpdater != null && updater != originalUpdater) {
        switch(originalUpdater) {
            case ADAM:
                if (adamMeanDecay == null)
                    l.setAdamMeanDecay(Double.NaN);
                if (adamVarDecay == null)
                    l.setAdamVarDecay(Double.NaN);
                break;
            case ADADELTA:
                if (rho == null)
                    l.setRho(Double.NaN);
                if (epsilon == null)
                    l.setEpsilon(Double.NaN);
                break;
            case NESTEROVS:
                if (momentum == null)
                    l.setMomentum(Double.NaN);
                if (momentumSchedule == null)
                    l.setMomentumSchedule(null);
                if (epsilon == null)
                    l.setEpsilon(Double.NaN);
                break;
            case ADAGRAD:
                if (epsilon == null)
                    l.setEpsilon(Double.NaN);
                break;
            case RMSPROP:
                if (rmsDecay == null)
                    l.setRmsDecay(Double.NaN);
                if (epsilon == null)
                    l.setEpsilon(Double.NaN);
                break;
        }
    }
    //Check weight init. Remove dist if originally was DISTRIBUTION, and isn't now -> remove no longer needed distribution
    if (l != null && origWeightInit == WeightInit.DISTRIBUTION && weightInit != null && weightInit != WeightInit.DISTRIBUTION) {
        l.setDist(null);
    }
    //Perform validation. This also sets the defaults for updaters. For example, Updater.RMSProp -> set rmsDecay
    if (l != null) {
        LayerValidation.updaterValidation(l.getLayerName(), l, momentum, momentumSchedule, adamMeanDecay, adamVarDecay, rho, rmsDecay, epsilon);
        boolean useDropCon = (useDropConnect == null ? nnc.isUseDropConnect() : useDropConnect);
        LayerValidation.generalValidation(l.getLayerName(), l, nnc.isUseRegularization(), useDropCon, dropOut, l2, l2Bias, l1, l1Bias, dist);
    }
    //Also: update the LR, L1 and L2 maps, based on current config (which might be different to original config)
    if (nnc.variables(false) != null) {
        for (String s : nnc.variables(false)) {
            nnc.setLayerParamLR(s);
        }
    }
}
Also used : SubsamplingLayer(org.deeplearning4j.nn.conf.layers.SubsamplingLayer) WeightInit(org.deeplearning4j.nn.weights.WeightInit) ConvolutionLayer(org.deeplearning4j.nn.conf.layers.ConvolutionLayer) Layer(org.deeplearning4j.nn.conf.layers.Layer)
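
A hedged usage sketch: applying a FineTuneConfiguration directly to one layer's NeuralNetConfiguration, which is what the transfer-learning builders do internally. The builder methods are assumed to mirror the fields applied in the method above, and the values are illustrative.

FineTuneConfiguration ftc = new FineTuneConfiguration.Builder()
        .updater(Updater.RMSPROP)
        .learningRate(1e-3)
        .build();

NeuralNetConfiguration layerConf = net.getLayer(0).conf(); //'net' is an initialized MultiLayerNetwork
ftc.applyToNeuralNetConfiguration(layerConf);
//Leftover updater-specific values from the old updater (e.g. ADAM decay rates)
//are cleared, and defaults for RMSPROP (rmsDecay) are filled in by the
//validation step at the end of the method.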

Aggregations

Layer (org.deeplearning4j.nn.conf.layers.Layer): 5
GraphVertex (org.deeplearning4j.nn.conf.graph.GraphVertex): 3
LayerVertex (org.deeplearning4j.nn.conf.graph.LayerVertex): 3
ConvolutionLayer (org.deeplearning4j.nn.conf.layers.ConvolutionLayer): 2
OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer): 2
SubsamplingLayer (org.deeplearning4j.nn.conf.layers.SubsamplingLayer): 2
WeightInit (org.deeplearning4j.nn.weights.WeightInit): 2
IOException (java.io.IOException): 1
Persistable (org.deeplearning4j.api.storage.Persistable): 1
ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration): 1
MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration): 1
NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration): 1
Updater (org.deeplearning4j.nn.conf.Updater): 1
InputType (org.deeplearning4j.nn.conf.inputs.InputType): 1
FeedForwardLayer (org.deeplearning4j.nn.conf.layers.FeedForwardLayer): 1
MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork): 1
StatsInitializationReport (org.deeplearning4j.ui.stats.api.StatsInitializationReport): 1
Test (org.junit.Test): 1
IActivation (org.nd4j.linalg.activations.IActivation): 1
JsonNode (org.nd4j.shade.jackson.databind.JsonNode): 1