
Example 6 with FrozenLayer

use of org.deeplearning4j.nn.layers.FrozenLayer in project deeplearning4j by deeplearning4j.

The class TransferLearningComplex, method testLessSimpleMergeBackProp.

@Test
public void testLessSimpleMergeBackProp() {
    NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder()
            .learningRate(0.9)
            .activation(Activation.IDENTITY)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(Updater.SGD);
    /*
                inCentre                inRight
                   |                        |
             denseCentre0               denseRight0
                   |                        |
                   |------ mergeRight ------|
                   |            |
                 outCentre     outRight
        
        */
    ComputationGraphConfiguration conf = overallConf.graphBuilder()
            .addInputs("inCentre", "inRight")
            .addLayer("denseCentre0", new DenseLayer.Builder().nIn(2).nOut(2).build(), "inCentre")
            .addLayer("outCentre", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(2).nOut(2).build(), "denseCentre0")
            .addLayer("denseRight0", new DenseLayer.Builder().nIn(3).nOut(2).build(), "inRight")
            .addVertex("mergeRight", new MergeVertex(), "denseCentre0", "denseRight0")
            .addLayer("outRight", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(4).nOut(2).build(), "mergeRight")
            .setOutputs("outRight")
            .setOutputs("outCentre")
            .build();
    ComputationGraph modelToTune = new ComputationGraph(conf);
    modelToTune.init();
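    //Freeze "denseCentre0" directly on the graph vertex; the TransferLearning builder below (setFeatureExtractor) produces the same FrozenLayer wrapping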
    modelToTune.getVertex("denseCentre0").setLayerAsFrozen();
    MultiDataSet randData = new MultiDataSet(new INDArray[] { Nd4j.rand(2, 2), Nd4j.rand(2, 3) }, new INDArray[] { Nd4j.rand(2, 2), Nd4j.rand(2, 2) });
    INDArray denseCentre0 = modelToTune.feedForward(randData.getFeatures(), false).get("denseCentre0");
    MultiDataSet otherRandData = new MultiDataSet(new INDArray[] { denseCentre0, randData.getFeatures(1) }, randData.getLabels());
    ComputationGraph modelNow = new TransferLearning.GraphBuilder(modelToTune).setFeatureExtractor("denseCentre0").build();
    assertTrue(modelNow.getLayer("denseCentre0") instanceof FrozenLayer);
    int n = 0;
    while (n < 5) {
        if (n == 0) {
            //confirm activations out of the merge are equivalent
            assertEquals(modelToTune.feedForward(randData.getFeatures(), false).get("mergeRight"), modelNow.feedForward(otherRandData.getFeatures(), false).get("mergeRight"));
        }
        //confirm activations out of the frozen vertex are the same as the input to the other model
        modelToTune.fit(randData);
        modelNow.fit(randData);
        assertEquals(otherRandData.getFeatures(0), modelNow.feedForward(randData.getFeatures(), false).get("denseCentre0"));
        assertEquals(otherRandData.getFeatures(0), modelToTune.feedForward(randData.getFeatures(), false).get("denseCentre0"));
        assertEquals(modelToTune.getLayer("denseRight0").params(), modelNow.getLayer("denseRight0").params());
        assertEquals(modelToTune.getLayer("outRight").params(), modelNow.getLayer("outRight").params());
        assertEquals(modelToTune.getLayer("outCentre").params(), modelNow.getLayer("outCentre").params());
        n++;
    }
}
Also used: OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) MergeVertex(org.deeplearning4j.nn.conf.graph.MergeVertex) FrozenLayer(org.deeplearning4j.nn.layers.FrozenLayer) DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) MultiDataSet(org.nd4j.linalg.dataset.MultiDataSet) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) Test(org.junit.Test)
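
A minimal sketch (not part of the original test) of what the frozen layer guarantees in practice: after setFeatureExtractor("denseCentre0"), fitting modelNow must leave that layer's parameters untouched while trainable layers still move. The variables modelNow and randData are assumed to be set up exactly as in the example above.

    INDArray frozenBefore = modelNow.getLayer("denseCentre0").params().dup();
    INDArray trainableBefore = modelNow.getLayer("outRight").params().dup();
    modelNow.fit(randData);
    //Frozen layer: parameters are identical after fitting
    assertEquals(frozenBefore, modelNow.getLayer("denseCentre0").params());
    //Non-frozen layer: parameters should have been updated by SGD
    assertNotEquals(trainableBefore, modelNow.getLayer("outRight").params());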

Example 7 with FrozenLayer

use of org.deeplearning4j.nn.layers.FrozenLayer in project deeplearning4j by deeplearning4j.

The class TransferLearningComplex, method testMergeAndFreeze.

@Test
public void testMergeAndFreeze() {
    // in1 -> A -> B -> merge, in2 -> C -> merge -> D -> out
    //Goal here: test a number of things...
    // (a) Ensure that freezing C doesn't impact A and B. Only C should be frozen in this config
    // (b) Test global override (should be selective)
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
            .updater(Updater.ADAM).learningRate(1e-4)
            .activation(Activation.LEAKYRELU)
            .graphBuilder()
            .addInputs("in1", "in2")
            .addLayer("A", new DenseLayer.Builder().nIn(10).nOut(9).build(), "in1")
            .addLayer("B", new DenseLayer.Builder().nIn(9).nOut(8).build(), "A")
            .addLayer("C", new DenseLayer.Builder().nIn(7).nOut(6).build(), "in2")
            .addLayer("D", new DenseLayer.Builder().nIn(8 + 7).nOut(5).build(), "B", "C")
            .addLayer("out", new OutputLayer.Builder().nIn(5).nOut(4).build(), "D")
            .setOutputs("out")
            .build();
    ComputationGraph graph = new ComputationGraph(conf);
    graph.init();
    int[] topologicalOrder = graph.topologicalSortOrder();
    org.deeplearning4j.nn.graph.vertex.GraphVertex[] vertices = graph.getVertices();
    for (int i = 0; i < topologicalOrder.length; i++) {
        org.deeplearning4j.nn.graph.vertex.GraphVertex v = vertices[topologicalOrder[i]];
        log.info(i + "\t" + v.getVertexName());
    }
    ComputationGraph graph2 = new TransferLearning.GraphBuilder(graph)
            .fineTuneConfiguration(new FineTuneConfiguration.Builder().learningRate(2e-2).build())
            .setFeatureExtractor("C")
            .build();
    boolean cFound = false;
    Layer[] layers = graph2.getLayers();
    for (Layer l : layers) {
        String name = l.conf().getLayer().getLayerName();
        log.info(name + "\t frozen: " + (l instanceof FrozenLayer));
        if ("C".equals(l.conf().getLayer().getLayerName())) {
            //Only C should be frozen in this config
            cFound = true;
            assertTrue(name, l instanceof FrozenLayer);
        } else {
            assertFalse(name, l instanceof FrozenLayer);
        }
        //Also check config:
        assertEquals(Updater.ADAM, l.conf().getLayer().getUpdater());
        assertEquals(2e-2, l.conf().getLayer().getLearningRate(), 1e-5);
        assertEquals(Activation.LEAKYRELU.getActivationFunction(), l.conf().getLayer().getActivationFn());
    }
    assertTrue(cFound);
}
Also used: Layer(org.deeplearning4j.nn.api.Layer) OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) FrozenLayer(org.deeplearning4j.nn.layers.FrozenLayer) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) Test(org.junit.Test)
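
The frozen/unfrozen split that the loop above asserts can also be spot-checked by layer name; a hedged sketch, assuming graph2 is the transfer-learned graph built in this example:

    assertTrue(graph2.getLayer("C") instanceof FrozenLayer);   //the declared feature extractor is frozen
    assertFalse(graph2.getLayer("A") instanceof FrozenLayer);  //layers feeding the other input branch stay trainable
    assertFalse(graph2.getLayer("D") instanceof FrozenLayer);  //everything downstream of the merge stays trainable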

Example 8 with FrozenLayer

use of org.deeplearning4j.nn.layers.FrozenLayer in project deeplearning4j by deeplearning4j.

The class ComputationGraph, method clone.

@Override
public ComputationGraph clone() {
    ComputationGraph cg = new ComputationGraph(configuration.clone());
    cg.init(params().dup(), false);
    if (solver != null) {
        //If solver is null, the updater hasn't been initialized yet, so there is no updater state to copy (calling getUpdater() here would needlessly force initialization)
        ComputationGraphUpdater u = this.getUpdater();
        INDArray updaterState = u.getStateViewArray();
        if (updaterState != null) {
            cg.getUpdater().setStateViewArray(updaterState.dup());
        }
    }
    cg.listeners = this.listeners;
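    //Re-apply frozen status: any layer wrapped as a FrozenLayer in this network is frozen again in the clone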
    for (int i = 0; i < topologicalOrder.length; i++) {
        if (!vertices[topologicalOrder[i]].hasLayer())
            continue;
        String layerName = vertices[topologicalOrder[i]].getVertexName();
        if (getLayer(layerName) instanceof FrozenLayer) {
            cg.getVertex(layerName).setLayerAsFrozen();
        }
    }
    return cg;
}
Also used: FrozenLayer(org.deeplearning4j.nn.layers.FrozenLayer) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ComputationGraphUpdater(org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater)
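
A hedged usage sketch of the re-freezing loop above: cloning a graph whose "C" layer was frozen via transfer learning (as in the previous example, whose graph2 is assumed here) should yield a copy with identical parameters and with "C" still reported as frozen.

    ComputationGraph copy = graph2.clone();
    assertEquals(graph2.params(), copy.params());
    assertTrue(copy.getLayer("C") instanceof FrozenLayer);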

Example 9 with FrozenLayer

use of org.deeplearning4j.nn.layers.FrozenLayer in project deeplearning4j by deeplearning4j.

The class MultiLayerNetwork, method summary.

/**
     * String detailing the architecture of the MultiLayerNetwork.
     * Columns are: layer index (with layer type), nIn, nOut, total number of parameters, and the shapes of those parameters.
     * Also reports which layers are frozen, if any.
     * @return Summary as a string
     */
public String summary() {
    String ret = "\n";
    ret += StringUtils.repeat("=", 140);
    ret += "\n";
    ret += String.format("%-40s%-15s%-15s%-30s\n", "LayerName (LayerType)", "nIn,nOut", "TotalParams", "ParamsShape");
    ret += StringUtils.repeat("=", 140);
    ret += "\n";
    int frozenParams = 0;
    for (Layer currentLayer : layers) {
        String name = String.valueOf(currentLayer.getIndex());
        String paramShape = "-";
        String in = "-";
        String out = "-";
        String[] classNameArr = currentLayer.getClass().getName().split("\\.");
        String className = classNameArr[classNameArr.length - 1];
        String paramCount = String.valueOf(currentLayer.numParams());
        if (currentLayer.numParams() > 0) {
            paramShape = "";
            in = String.valueOf(((FeedForwardLayer) currentLayer.conf().getLayer()).getNIn());
            out = String.valueOf(((FeedForwardLayer) currentLayer.conf().getLayer()).getNOut());
            Set<String> paraNames = currentLayer.conf().getLearningRateByParam().keySet();
            for (String aP : paraNames) {
                String paramS = ArrayUtils.toString(currentLayer.paramTable().get(aP).shape());
                paramShape += aP + ":" + paramS + ", ";
            }
            paramShape = paramShape.subSequence(0, paramShape.lastIndexOf(",")).toString();
        }
        if (currentLayer instanceof FrozenLayer) {
            frozenParams += currentLayer.numParams();
            classNameArr = ((FrozenLayer) currentLayer).getInsideLayer().getClass().getName().split("\\.");
            className = "Frozen " + classNameArr[classNameArr.length - 1];
        }
        ret += String.format("%-40s%-15s%-15s%-30s", name + " (" + className + ")", in + "," + out, paramCount, paramShape);
        ret += "\n";
    }
    ret += StringUtils.repeat("-", 140);
    ret += String.format("\n%30s %d", "Total Parameters: ", params().length());
    ret += String.format("\n%30s %d", "Trainable Parameters: ", params().length() - frozenParams);
    ret += String.format("\n%30s %d", "Frozen Parameters: ", frozenParams);
    ret += "\n";
    ret += StringUtils.repeat("=", 140);
    ret += "\n";
    return ret;
}
Also used: FrozenLayer(org.deeplearning4j.nn.layers.FrozenLayer) FeedForwardLayer(org.deeplearning4j.nn.conf.layers.FeedForwardLayer) IOutputLayer(org.deeplearning4j.nn.api.layers.IOutputLayer) RecurrentLayer(org.deeplearning4j.nn.api.layers.RecurrentLayer)
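
A hedged usage sketch (not from the deeplearning4j sources) of how summary() is typically used after transfer learning; the layer sizes, the learning rate, and the FineTuneConfiguration here are illustrative assumptions only:

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .list()
            .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build())
            .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(3).nOut(2).build())
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    MultiLayerNetwork frozen = new TransferLearning.Builder(net)
            .fineTuneConfiguration(new FineTuneConfiguration.Builder().learningRate(1e-2).build())
            .setFeatureExtractor(0)   //freeze layer 0; only the output layer remains trainable
            .build();
    //Prints the table produced by summary() above, with layer 0 listed as "Frozen DenseLayer"
    System.out.println(frozen.summary());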

Example 10 with FrozenLayer

use of org.deeplearning4j.nn.layers.FrozenLayer in project deeplearning4j by deeplearning4j.

The class MultiLayerNetwork, method calcBackpropGradients.

/** Calculate gradients and errors. Used in two places:
     * (a) backprop (for standard multi layer network learning)
     * (b) backpropGradient (layer method, for when MultiLayerNetwork is used as a layer)
     * @param epsilon Errors (technically errors .* activations). Not used if withOutputLayer = true
     * @param withOutputLayer if true: assume last layer is output layer, and calculate errors based on labels. In this
     *                        case, the epsilon input is not used (may/should be null).
     *                        If false: calculate backprop gradients
     * @return Gradients and the error (epsilon) at the input
     */
protected Pair<Gradient, INDArray> calcBackpropGradients(INDArray epsilon, boolean withOutputLayer) {
    if (flattenedGradients == null)
        initGradientsView();
    String multiGradientKey;
    Gradient gradient = new DefaultGradient(flattenedGradients);
    Layer currLayer;
    //calculate and apply the backward gradient for every layer
    /*
         * Skip the output layer for the indexing and just loop backwards, updating the coefficients for each layer
         * (when withOutputLayer == true).
         *
         * Activate applies the activation function for each layer and sets that as the input for the following layer.
         *
         * Typical literature covers only the most trivial case for the error calculation: wT * epsilon.
         * This implementation transposes a few things to handle mini batches, because ND4J organizes params as rows vs columns.
         */
    int numLayers = getnLayers();
    //Store gradients in a list; used to ensure iteration order in the DefaultGradient linked hash map, i.e., layer 0 first instead of the output layer
    LinkedList<Triple<String, INDArray, Character>> gradientList = new LinkedList<>();
    int layerFrom;
    Pair<Gradient, INDArray> currPair;
    if (withOutputLayer) {
        if (!(getOutputLayer() instanceof IOutputLayer)) {
            log.warn("Warning: final layer isn't output layer. You cannot use backprop without an output layer.");
            return null;
        }
        IOutputLayer outputLayer = (IOutputLayer) getOutputLayer();
        if (labels == null)
            throw new IllegalStateException("No labels found");
        outputLayer.setLabels(labels);
        currPair = outputLayer.backpropGradient(null);
        for (Map.Entry<String, INDArray> entry : currPair.getFirst().gradientForVariable().entrySet()) {
            String origName = entry.getKey();
            multiGradientKey = String.valueOf(numLayers - 1) + "_" + origName;
            gradientList.addLast(new Triple<>(multiGradientKey, entry.getValue(), currPair.getFirst().flatteningOrderForVariable(origName)));
        }
        if (getLayerWiseConfigurations().getInputPreProcess(numLayers - 1) != null)
            currPair = new Pair<>(currPair.getFirst(), this.layerWiseConfigurations.getInputPreProcess(numLayers - 1).backprop(currPair.getSecond(), getInputMiniBatchSize()));
        layerFrom = numLayers - 2;
    } else {
        currPair = new Pair<>(null, epsilon);
        layerFrom = numLayers - 1;
    }
    // Calculate gradients for the remaining layers, working backwards from layerFrom (the output layer, if present, was handled above)
    for (int j = layerFrom; j >= 0; j--) {
        currLayer = getLayer(j);
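        //Frozen layer reached: stop backprop here; everything at or below a frozen layer acts as a fixed feature extractor and receives no gradient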
        if (currLayer instanceof FrozenLayer)
            break;
        currPair = currLayer.backpropGradient(currPair.getSecond());
        LinkedList<Triple<String, INDArray, Character>> tempList = new LinkedList<>();
        for (Map.Entry<String, INDArray> entry : currPair.getFirst().gradientForVariable().entrySet()) {
            String origName = entry.getKey();
            multiGradientKey = String.valueOf(j) + "_" + origName;
            tempList.addFirst(new Triple<>(multiGradientKey, entry.getValue(), currPair.getFirst().flatteningOrderForVariable(origName)));
        }
        for (Triple<String, INDArray, Character> triple : tempList) gradientList.addFirst(triple);
        //Pass epsilon through input processor before passing to next layer (if applicable)
        if (getLayerWiseConfigurations().getInputPreProcess(j) != null)
            currPair = new Pair<>(currPair.getFirst(), getLayerWiseConfigurations().getInputPreProcess(j).backprop(currPair.getSecond(), getInputMiniBatchSize()));
    }
    //Add gradients to Gradients (map), in correct order
    for (Triple<String, INDArray, Character> triple : gradientList) {
        gradient.setGradientFor(triple.getFirst(), triple.getSecond(), triple.getThird());
    }
    return new Pair<>(gradient, currPair.getSecond());
}
Also used: Gradient(org.deeplearning4j.nn.gradient.Gradient) DefaultGradient(org.deeplearning4j.nn.gradient.DefaultGradient) FeedForwardLayer(org.deeplearning4j.nn.conf.layers.FeedForwardLayer) FrozenLayer(org.deeplearning4j.nn.layers.FrozenLayer) IOutputLayer(org.deeplearning4j.nn.api.layers.IOutputLayer) RecurrentLayer(org.deeplearning4j.nn.api.layers.RecurrentLayer) Triple(org.deeplearning4j.berkeley.Triple) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Pair(org.deeplearning4j.berkeley.Pair)
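
A hedged sketch tying this back to frozen layers: because the loop above breaks at the first FrozenLayer, no gradients are computed for it or for anything below it, so its parameters never change during fit(). Using the frozen network from the summary() sketch earlier (layer 0 frozen, nIn 4, nOut 2), and assuming DataSet is org.nd4j.linalg.dataset.DataSet:

    INDArray layer0Before = frozen.getLayer(0).params().dup();
    frozen.fit(new DataSet(Nd4j.rand(5, 4), Nd4j.rand(5, 2)));
    //Layer 0 is frozen, so its parameters are unchanged after fitting
    assertEquals(layer0Before, frozen.getLayer(0).params());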

Aggregations

FrozenLayer (org.deeplearning4j.nn.layers.FrozenLayer) 11
INDArray (org.nd4j.linalg.api.ndarray.INDArray) 6
IOutputLayer (org.deeplearning4j.nn.api.layers.IOutputLayer) 5
RecurrentLayer (org.deeplearning4j.nn.api.layers.RecurrentLayer) 4
FeedForwardLayer (org.deeplearning4j.nn.conf.layers.FeedForwardLayer) 4
GraphVertex (org.deeplearning4j.nn.graph.vertex.GraphVertex) 3
Triple (org.deeplearning4j.berkeley.Triple) 2
Layer (org.deeplearning4j.nn.api.Layer) 2
ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration) 2
MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration) 2
NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration) 2
DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer) 2
OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer) 2
DefaultGradient (org.deeplearning4j.nn.gradient.DefaultGradient) 2
Gradient (org.deeplearning4j.nn.gradient.Gradient) 2
ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph) 2
VertexIndices (org.deeplearning4j.nn.graph.vertex.VertexIndices) 2
Test (org.junit.Test) 2
HashMap (java.util.HashMap) 1
LinkedHashMap (java.util.LinkedHashMap) 1