Search in sources :

Example 1 with FrozenLayer

use of org.deeplearning4j.nn.layers.FrozenLayer in project deeplearning4j by deeplearning4j.

the class TransferLearningHelper method initHelperMLN.

/**
 * Sets up the helper for a MultiLayerNetwork.
 * <p>
 * When {@code applyFrozen} is set, layers {@code 0..frozenTill} of {@code origMLN} are wrapped in
 * {@link FrozenLayer}. The index of the last frozen layer is then stored in {@code frozenInputLayer},
 * and a new network ({@code unFrozenSubsetMLN}) is built from the configurations of the layers after
 * it, with their parameters copied over from {@code origMLN}.
 */
private void initHelperMLN() {
    if (applyFrozen) {
        org.deeplearning4j.nn.api.Layer[] netLayers = origMLN.getLayers();
        // Wrap from frozenTill down to the first layer
        int idx = frozenTill;
        while (idx >= 0) {
            netLayers[idx] = new FrozenLayer(netLayers[idx]);
            idx--;
        }
        origMLN.setLayers(netLayers);
    }

    // Remember the highest index that holds a frozen layer (later matches overwrite earlier ones)
    int scan = 0;
    while (scan < origMLN.getnLayers()) {
        if (origMLN.getLayer(scan) instanceof FrozenLayer) {
            frozenInputLayer = scan;
        }
        scan++;
    }

    // Gather the configurations of every layer after the last frozen one
    final int firstUnfrozen = frozenInputLayer + 1;
    List<NeuralNetConfiguration> unfrozenConfs = new ArrayList<>();
    for (int src = firstUnfrozen; src < origMLN.getnLayers(); src++) {
        unfrozenConfs.add(origMLN.getLayer(src).conf());
    }

    // Build the subset network with the same global settings as the original
    MultiLayerConfiguration origConf = origMLN.getLayerWiseConfigurations();
    MultiLayerConfiguration subsetConf = new MultiLayerConfiguration.Builder()
            .backprop(origConf.isBackprop())
            .inputPreProcessors(origConf.getInputPreProcessors())
            .pretrain(origConf.isPretrain())
            .backpropType(origConf.getBackpropType())
            .tBPTTForwardLength(origConf.getTbpttFwdLength())
            .tBPTTBackwardLength(origConf.getTbpttBackLength())
            .confs(unfrozenConfs)
            .build();
    unFrozenSubsetMLN = new MultiLayerNetwork(subsetConf);
    unFrozenSubsetMLN.init();

    // Copy parameters: layer src in the original maps to layer (src - firstUnfrozen) in the subset
    for (int src = firstUnfrozen; src < origMLN.getnLayers(); src++) {
        unFrozenSubsetMLN.getLayer(src - firstUnfrozen).setParams(origMLN.getLayer(src).params());
    }
    // NOTE: listeners of origMLN are not copied to the subset network.
}
Also used : FrozenLayer(org.deeplearning4j.nn.layers.FrozenLayer) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) FrozenLayer(org.deeplearning4j.nn.layers.FrozenLayer) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork)

Example 2 with FrozenLayer

use of org.deeplearning4j.nn.layers.FrozenLayer in project deeplearning4j by deeplearning4j.

the class TransferLearningHelper method initHelperGraph.

/**
 * Runs through the computation graph and builds a new model that is simply the "unfrozen" part of
 * the original model. This "unfrozen" model is then used for training with featurized data.
 * <p>
 * Steps performed:
 * <ol>
 *   <li>Walk the vertices in reverse topological order. Vertices named in {@code frozenOutputAt}
 *       (when {@code applyFrozen} is set) are frozen, and vertices that already hold a
 *       {@link FrozenLayer} are recorded; in both cases their input vertices are marked frozen too.</li>
 *   <li>Collect into {@code frozenInputVertices} every frozen vertex that feeds an unfrozen vertex.</li>
 *   <li>Remove the frozen vertices from a copy of the graph, turning the boundary vertices into
 *       network inputs, and copy the original parameters into the resulting subset graph.</li>
 * </ol>
 * The inputs of the subset graph are added in the same sorted order that is stored in
 * {@code graphInputs}.
 *
 * @throws IllegalArgumentException if no frozen vertex feeds an unfrozen vertex
 */
private void initHelperGraph() {
    int[] backPropOrder = origGraph.topologicalSortOrder().clone();
    ArrayUtils.reverse(backPropOrder);

    Set<String> allFrozen = new HashSet<>();
    if (applyFrozen) {
        Collections.addAll(allFrozen, frozenOutputAt);
    }

    // Pass 1: propagate "frozen" backwards from the requested / already frozen vertices
    for (int vertexIdx : backPropOrder) {
        org.deeplearning4j.nn.graph.vertex.GraphVertex gv = origGraph.getVertices()[vertexIdx];
        if (applyFrozen && allFrozen.contains(gv.getVertexName())) {
            if (gv.hasLayer()) {
                //Need to freeze this layer
                org.deeplearning4j.nn.api.Layer l = gv.getLayer();
                gv.setLayerAsFrozen();
                //We also need to place the layer in the CompGraph Layer[] (replacing the old one)
                //This could no doubt be done more efficiently
                org.deeplearning4j.nn.api.Layer[] layers = origGraph.getLayers();
                for (int j = 0; j < layers.length; j++) {
                    if (layers[j] == l) {
                        //Place the new frozen layer to replace the original layer
                        layers[j] = gv.getLayer();
                        break;
                    }
                }
            }
        } else if (gv.hasLayer() && gv.getLayer() instanceof FrozenLayer) {
            // Layer was frozen before this helper was created
            allFrozen.add(gv.getVertexName());
        } else {
            // Not frozen: nothing to propagate
            continue;
        }
        //Also: mark any inputs (parents) of a frozen vertex as frozen
        VertexIndices[] inputs = gv.getInputVertices();
        if (inputs != null) {
            for (VertexIndices input : inputs) {
                allFrozen.add(origGraph.getVertices()[input.getVertexIndex()].getVertexName());
            }
        }
    }

    // Pass 2: find frozen vertices that directly feed an unfrozen, non-input vertex
    for (int vertexIdx : backPropOrder) {
        org.deeplearning4j.nn.graph.vertex.GraphVertex gv = origGraph.getVertices()[vertexIdx];
        String gvName = gv.getVertexName();
        //is it an unfrozen vertex that has an input vertex that is frozen?
        if (!allFrozen.contains(gvName) && !gv.isInputVertex()) {
            for (VertexIndices input : gv.getInputVertices()) {
                String inputVertex = origGraph.getVertices()[input.getVertexIndex()].getVertexName();
                if (allFrozen.contains(inputVertex)) {
                    frozenInputVertices.add(inputVertex);
                }
            }
        }
    }

    TransferLearning.GraphBuilder builder = new TransferLearning.GraphBuilder(origGraph);
    for (String toRemove : allFrozen) {
        if (frozenInputVertices.contains(toRemove)) {
            // Boundary vertex: keep its outgoing connections so it can be re-added as an input
            builder.removeVertexKeepConnections(toRemove);
        } else {
            builder.removeVertexAndConnections(toRemove);
        }
    }

    // Original network inputs that were not frozen stay inputs of the subset graph
    Set<String> frozenInputVerticesSorted = new HashSet<>();
    frozenInputVerticesSorted.addAll(origGraph.getConfiguration().getNetworkInputs());
    frozenInputVerticesSorted.removeAll(allFrozen);
    //remove input vertices - just to add back in a predictable order
    for (String existingInput : frozenInputVerticesSorted) {
        builder.removeVertexKeepConnections(existingInput);
    }
    frozenInputVerticesSorted.addAll(frozenInputVertices);

    //Sort all inputs to the computation graph - in order to have a predictable order
    List<String> sortedInputs = new ArrayList<>(frozenInputVerticesSorted);
    Collections.sort(sortedInputs);
    graphInputs = sortedInputs;
    // Iterate the sorted list (not the HashSet, whose iteration order is unspecified) so the
    // subset graph's input order matches graphInputs.
    for (String asInput : sortedInputs) {
        //add back in the right order
        builder.addInputs(asInput);
    }

    unFrozenSubsetGraph = builder.build();
    copyOrigParamsToSubsetGraph();
    if (frozenInputVertices.isEmpty()) {
        throw new IllegalArgumentException("No frozen layers found");
    }
}
Also used : GraphVertex(org.deeplearning4j.nn.graph.vertex.GraphVertex) FrozenLayer(org.deeplearning4j.nn.layers.FrozenLayer) FrozenLayer(org.deeplearning4j.nn.layers.FrozenLayer) VertexIndices(org.deeplearning4j.nn.graph.vertex.VertexIndices)

Example 3 with FrozenLayer

use of org.deeplearning4j.nn.layers.FrozenLayer in project deeplearning4j by deeplearning4j.

the class LayerUpdater method update.

/**
 * Runs the per-parameter updaters over every gradient in {@code gradient} and writes the
 * resulting values back into it. Frozen layers are skipped entirely.
 *
 * @param layer         layer whose gradients are being updated
 * @param gradient      gradients keyed by parameter name; modified in place
 * @param iteration     current iteration, passed to the updaters and the decay policy
 * @param miniBatchSize minibatch size, passed to {@code postApply}
 */
@Override
public void update(Layer layer, Gradient gradient, int iteration, int miniBatchSize) {
    if (layer instanceof FrozenLayer) {
        return;
    }
    preApply(layer, gradient, iteration);

    for (Map.Entry<String, INDArray> entry : gradient.gradientForVariable().entrySet()) {
        final String paramName = entry.getKey();

        // Outside of pretraining, gradients for the visible bias are left untouched
        if (!layer.conf().isPretrain()) {
            String keyPrefix = paramName.split("_")[0];
            if (PretrainParamInitializer.VISIBLE_BIAS_KEY.equals(keyPrefix)) {
                continue;
            }
        }

        final INDArray rawGradient = entry.getValue();

        // Apply the learning-rate schedule when a policy is set, or when the updater is NESTEROVS
        final LearningRatePolicy policy = layer.conf().getLearningRatePolicy();
        if (policy != LearningRatePolicy.None
                || layer.conf().getLayer().getUpdater() == org.deeplearning4j.nn.conf.Updater.NESTEROVS) {
            applyLrDecayPolicy(policy, layer, iteration, paramName);
        }

        final GradientUpdater paramUpdater = init(paramName, layer);
        final INDArray adjusted = paramUpdater.getGradient(rawGradient, iteration);
        postApply(layer, adjusted, paramName, miniBatchSize);
        gradient.setGradientFor(paramName, adjusted);
    }
}
Also used : FrozenLayer(org.deeplearning4j.nn.layers.FrozenLayer) INDArray(org.nd4j.linalg.api.ndarray.INDArray) LearningRatePolicy(org.deeplearning4j.nn.conf.LearningRatePolicy) HashMap(java.util.HashMap) LinkedHashMap(java.util.LinkedHashMap) Map(java.util.Map)

Example 4 with FrozenLayer

use of org.deeplearning4j.nn.layers.FrozenLayer in project deeplearning4j by deeplearning4j.

the class ComputationGraph method summary.

/**
 * String detailing the architecture of the computation graph.
 * Vertices are printed in a topological sort order.
 * Columns are Vertex Names with layer/vertex type, nIn, nOut, Total number of parameters and the
 * Shapes of the parameters, and the inputs to the vertex.
 * Will also give information about frozen layers/vertices, if any.
 * <p>
 * nIn/nOut are shown as "-" for layers whose configuration is not a {@link FeedForwardLayer}.
 *
 * @return Summary as a string
 */
public String summary() {
    final String doubleRule = StringUtils.repeat("=", 140);
    StringBuilder ret = new StringBuilder("\n");
    ret.append(doubleRule).append("\n");
    ret.append(String.format("%-40s%-15s%-15s%-30s %s\n", "VertexName (VertexType)", "nIn,nOut", "TotalParams", "ParamsShape", "Vertex Inputs"));
    ret.append(doubleRule).append("\n");

    int frozenParams = 0;
    for (int currVertexIdx : topologicalOrder) {
        GraphVertex current = vertices[currVertexIdx];
        String name = current.getVertexName();
        // Default type label: last dot-separated segment of the vertex class
        String[] classNameArr = current.getClass().toString().split("\\.");
        String className = classNameArr[classNameArr.length - 1];
        String connections = "-";
        if (!current.isInputVertex()) {
            connections = configuration.getVertexInputs().get(name).toString();
        }
        String paramCount = "-";
        String in = "-";
        String out = "-";
        String paramShape = "-";
        if (current.hasLayer()) {
            Layer currentLayer = ((LayerVertex) current).getLayer();
            // For layer vertices, report the layer class rather than the vertex class
            classNameArr = currentLayer.getClass().getName().split("\\.");
            className = classNameArr[classNameArr.length - 1];
            paramCount = String.valueOf(currentLayer.numParams());
            if (currentLayer.numParams() > 0) {
                // Only feed-forward style configurations expose nIn/nOut; avoid a ClassCastException
                // for parameterised layers with any other configuration type.
                if (currentLayer.conf().getLayer() instanceof FeedForwardLayer) {
                    FeedForwardLayer ffConf = (FeedForwardLayer) currentLayer.conf().getLayer();
                    in = String.valueOf(ffConf.getNIn());
                    out = String.valueOf(ffConf.getNOut());
                }
                // Join "name:shape" entries with ", " (no trailing separator to strip, so an empty
                // parameter-name set yields "" instead of throwing)
                StringBuilder shapes = new StringBuilder();
                Set<String> paraNames = currentLayer.conf().getLearningRateByParam().keySet();
                for (String aP : paraNames) {
                    if (shapes.length() > 0) {
                        shapes.append(", ");
                    }
                    String paramS = ArrayUtils.toString(currentLayer.paramTable().get(aP).shape());
                    shapes.append(aP).append(":").append(paramS);
                }
                paramShape = shapes.toString();
            }
            if (currentLayer instanceof FrozenLayer) {
                frozenParams += currentLayer.numParams();
                // Show the wrapped layer's class, prefixed with "Frozen"
                classNameArr = ((FrozenLayer) currentLayer).getInsideLayer().getClass().getName().split("\\.");
                className = "Frozen " + classNameArr[classNameArr.length - 1];
            }
        }
        ret.append(String.format("%-40s%-15s%-15s%-30s %s", name + " (" + className + ")", in + "," + out, paramCount, paramShape, connections));
        ret.append("\n");
    }

    ret.append(StringUtils.repeat("-", 140));
    ret.append(String.format("\n%30s %d", "Total Parameters: ", params().length()));
    ret.append(String.format("\n%30s %d", "Trainable Parameters: ", params().length() - frozenParams));
    ret.append(String.format("\n%30s %d", "Frozen Parameters: ", frozenParams));
    ret.append("\n");
    ret.append(doubleRule);
    ret.append("\n");
    return ret.toString();
}
Also used : LayerVertex(org.deeplearning4j.nn.graph.vertex.impl.LayerVertex) FrozenLayer(org.deeplearning4j.nn.layers.FrozenLayer) GraphVertex(org.deeplearning4j.nn.graph.vertex.GraphVertex) FeedForwardLayer(org.deeplearning4j.nn.conf.layers.FeedForwardLayer) Layer(org.deeplearning4j.nn.api.Layer) FrozenLayer(org.deeplearning4j.nn.layers.FrozenLayer) RecurrentLayer(org.deeplearning4j.nn.api.layers.RecurrentLayer) FeedForwardLayer(org.deeplearning4j.nn.conf.layers.FeedForwardLayer) IOutputLayer(org.deeplearning4j.nn.api.layers.IOutputLayer)

Example 5 with FrozenLayer

use of org.deeplearning4j.nn.layers.FrozenLayer in project deeplearning4j by deeplearning4j.

the class ComputationGraph method calcBackpropGradients.

/**
 * Do backprop (gradient calculation).
 * <p>
 * Vertices are visited in the reverse of {@code topologicalOrder}. Each visited vertex runs
 * {@code doBackward}, its returned epsilons are handed to the vertices that feed it, and its
 * parameter gradients (if any) are collected. The collected gradients are finally stored in
 * {@code this.gradient}, backed by {@code flattenedGradients}.
 * <p>
 * Note: the loop stops completely at the first vertex (in reverse order) that holds a
 * {@link FrozenLayer}; vertices earlier in the topological order are not processed.
 *
 * @param truncatedBPTT    false: normal backprop. true: calculate gradients using truncated BPTT for RNN layers
 * @param externalEpsilons null usually (for typical supervised learning). If not null (and length > 0) then assume that
 *                         the user has provided some errors externally, as they would do for example in reinforcement
 *                         learning situations. It is indexed by network output number and is read without a
 *                         null check for any output vertex whose layer is not an IOutputLayer.
 */
protected void calcBackpropGradients(boolean truncatedBPTT, INDArray... externalEpsilons) {
    // Lazily create the flattened gradient view on first use
    if (flattenedGradients == null)
        initGradientsView();
    // Collected as (vertexName_paramName, gradient array, flattening order)
    LinkedList<Triple<String, INDArray, Character>> gradients = new LinkedList<>();
    //Do backprop according to the reverse of the topological ordering of the network
    //If true: already set epsilon for this vertex; later epsilons should be *added* to the existing one, not set
    //(indexed by vertex index, not by position in the topological order)
    boolean[] setVertexEpsilon = new boolean[topologicalOrder.length];
    for (int i = topologicalOrder.length - 1; i >= 0; i--) {
        GraphVertex current = vertices[topologicalOrder[i]];
        if (current.isInputVertex())
            //No op
            continue;
        //FIXME: make the frozen vertex feature extraction more flexible
        //A frozen layer ends the whole backward pass (break, not continue)
        if (current.hasLayer() && current.getLayer() instanceof FrozenLayer)
            break;
        if (current.isOutputVertex()) {
            //Two reasons for a vertex to be an output vertex:
            //(a) it's an output layer (i.e., instanceof IOutputLayer), or
            //(b) it's a normal layer, but it has been marked as an output layer for use in external errors - for reinforcement learning, for example
            int thisOutputNumber = configuration.getNetworkOutputs().indexOf(current.getVertexName());
            if (current.getLayer() instanceof IOutputLayer) {
                //Case (a): give the output layer its labels
                IOutputLayer outputLayer = (IOutputLayer) current.getLayer();
                INDArray currLabels = labels[thisOutputNumber];
                outputLayer.setLabels(currLabels);
            } else {
                //Case (b): the error signal comes from the caller-supplied epsilons
                current.setEpsilon(externalEpsilons[thisOutputNumber]);
                setVertexEpsilon[topologicalOrder[i]] = true;
            }
        }
        //pair: (parameter gradients for this vertex, may be null; epsilons for this vertex's inputs)
        Pair<Gradient, INDArray[]> pair = current.doBackward(truncatedBPTT);
        INDArray[] epsilons = pair.getSecond();
        //Inputs to the current GraphVertex:
        VertexIndices[] inputVertices = current.getInputVertices();
        //Set epsilons for the vertices that provide inputs to this vertex:
        if (inputVertices != null) {
            //j walks epsilons in step with inputVertices
            int j = 0;
            for (VertexIndices v : inputVertices) {
                GraphVertex gv = vertices[v.getVertexIndex()];
                if (setVertexEpsilon[gv.getVertexIndex()]) {
                    //This vertex: must output to multiple vertices... we want to add the epsilons here
                    INDArray currentEps = gv.getEpsilon();
                    //TODO: in some circumstances, it may be safe  to do in-place add (but not always)
                    gv.setEpsilon(currentEps.add(epsilons[j++]));
                } else {
                    gv.setEpsilon(epsilons[j++]);
                }
                setVertexEpsilon[gv.getVertexIndex()] = true;
            }
        }
        if (pair.getFirst() != null) {
            Gradient g = pair.getFirst();
            Map<String, INDArray> map = g.gradientForVariable();
            //Reversed into tempList, then reversed again when pushed onto the front of gradients:
            //this vertex's entries end up at the head of the list, in the map's iteration order
            LinkedList<Triple<String, INDArray, Character>> tempList = new LinkedList<>();
            for (Map.Entry<String, INDArray> entry : map.entrySet()) {
                String origName = entry.getKey();
                //Gradient names are prefixed with the vertex name
                String newName = current.getVertexName() + "_" + origName;
                tempList.addFirst(new Triple<>(newName, entry.getValue(), g.flatteningOrderForVariable(origName)));
            }
            for (Triple<String, INDArray, Character> t : tempList) gradients.addFirst(t);
        }
    }
    //Now, add the gradients in the order we need them in for flattening (same as params order)
    Gradient gradient = new DefaultGradient(flattenedGradients);
    for (Triple<String, INDArray, Character> t : gradients) {
        gradient.setGradientFor(t.getFirst(), t.getSecond(), t.getThird());
    }
    this.gradient = gradient;
}
Also used : DefaultGradient(org.deeplearning4j.nn.gradient.DefaultGradient) Gradient(org.deeplearning4j.nn.gradient.Gradient) Triple(org.deeplearning4j.berkeley.Triple) FrozenLayer(org.deeplearning4j.nn.layers.FrozenLayer) DefaultGradient(org.deeplearning4j.nn.gradient.DefaultGradient) GraphVertex(org.deeplearning4j.nn.graph.vertex.GraphVertex) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VertexIndices(org.deeplearning4j.nn.graph.vertex.VertexIndices) IOutputLayer(org.deeplearning4j.nn.api.layers.IOutputLayer)

Aggregations

FrozenLayer (org.deeplearning4j.nn.layers.FrozenLayer)11 INDArray (org.nd4j.linalg.api.ndarray.INDArray)6 IOutputLayer (org.deeplearning4j.nn.api.layers.IOutputLayer)5 RecurrentLayer (org.deeplearning4j.nn.api.layers.RecurrentLayer)4 FeedForwardLayer (org.deeplearning4j.nn.conf.layers.FeedForwardLayer)4 GraphVertex (org.deeplearning4j.nn.graph.vertex.GraphVertex)3 Triple (org.deeplearning4j.berkeley.Triple)2 Layer (org.deeplearning4j.nn.api.Layer)2 ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration)2 MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration)2 NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration)2 DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer)2 OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer)2 DefaultGradient (org.deeplearning4j.nn.gradient.DefaultGradient)2 Gradient (org.deeplearning4j.nn.gradient.Gradient)2 ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph)2 VertexIndices (org.deeplearning4j.nn.graph.vertex.VertexIndices)2 Test (org.junit.Test)2 HashMap (java.util.HashMap)1 LinkedHashMap (java.util.LinkedHashMap)1