Search in sources:

Example 1 with RecurrentLayer

use of org.deeplearning4j.nn.api.layers.RecurrentLayer in project deeplearning4j by deeplearning4j.

Example taken from the class ComputationGraph, method rnnActivateUsingStoredState.

/**
 * Similar to rnnTimeStep and feedForward() methods. Difference here is that this method:<br>
 * (a) like rnnTimeStep does forward pass using stored state for RNN layers, and<br>
 * (b) unlike rnnTimeStep does not modify the RNN layer state<br>
 * Therefore multiple calls to this method with the same input should have the same output.<br>
 * Typically used during training only. Use rnnTimeStep for prediction/forward pass at test time.
 *
 * @param inputs            Input to network, one INDArray per network input vertex (indexed by vertex index)
 * @param training          Whether training or not (affects e.g. dropout behavior in layers)
 * @param storeLastForTBPTT set to true if used as part of truncated BPTT training
 * @return Activations for each layer (including input, as per feedforward() etc), keyed by vertex name.
 *         Note: vertices without a layer (e.g. merge/element-wise ops) are NOT added to the returned map.
 */
public Map<String, INDArray> rnnActivateUsingStoredState(INDArray[] inputs, boolean training, boolean storeLastForTBPTT) {
    Map<String, INDArray> layerActivations = new HashMap<>();
    //Do forward pass according to the topological ordering of the network
    //(guarantees all of a vertex's inputs are set before that vertex is processed)
    for (int currVertexIdx : topologicalOrder) {
        GraphVertex current = vertices[currVertexIdx];
        if (current.isInputVertex()) {
            //Input vertex: no computation; just fan the provided input array out to consumers
            VertexIndices[] inputsTo = current.getOutputVertices();
            INDArray input = inputs[current.getVertexIndex()];
            layerActivations.put(current.getVertexName(), input);
            for (VertexIndices v : inputsTo) {
                int vIdx = v.getVertexIndex();
                int vIdxInputNum = v.getVertexEdgeNumber();
                //This input: the 'vIdxInputNum'th input to vertex 'vIdx'
                //TODO When to dup? (dup here defensively so downstream vertices cannot mutate the caller's arrays)
                vertices[vIdx].setInput(vIdxInputNum, input.dup());
            }
        } else {
            INDArray out;
            if (current.hasLayer()) {
                Layer l = current.getLayer();
                if (l instanceof RecurrentLayer) {
                    //RNN layer: forward pass from stored state, without modifying that state
                    out = ((RecurrentLayer) l).rnnActivateUsingStoredState(current.getInputs()[0], training, storeLastForTBPTT);
                } else if (l instanceof MultiLayerNetwork) {
                    //Nested network: delegate; its method returns per-layer activations, last one is the output
                    List<INDArray> temp = ((MultiLayerNetwork) l).rnnActivateUsingStoredState(current.getInputs()[0], training, storeLastForTBPTT);
                    out = temp.get(temp.size() - 1);
                } else {
                    //non-recurrent layer
                    out = current.doForward(training);
                }
                layerActivations.put(current.getVertexName(), out);
            } else {
                //Non-layer vertex (merge, subset, etc): compute but do not record in the activations map
                out = current.doForward(training);
            }
            //Now, set the inputs for the next vertices:
            VertexIndices[] outputsTo = current.getOutputVertices();
            if (outputsTo != null) {
                for (VertexIndices v : outputsTo) {
                    int vIdx = v.getVertexIndex();
                    int inputNum = v.getVertexEdgeNumber();
                    //This (jth) connection from the output: is the 'inputNum'th input to vertex 'vIdx'
                    vertices[vIdx].setInput(inputNum, out);
                }
            }
        }
    }
    return layerActivations;
}
Also used : GraphVertex(org.deeplearning4j.nn.graph.vertex.GraphVertex) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VertexIndices(org.deeplearning4j.nn.graph.vertex.VertexIndices) Layer(org.deeplearning4j.nn.api.Layer) FrozenLayer(org.deeplearning4j.nn.layers.FrozenLayer) RecurrentLayer(org.deeplearning4j.nn.api.layers.RecurrentLayer) FeedForwardLayer(org.deeplearning4j.nn.conf.layers.FeedForwardLayer) IOutputLayer(org.deeplearning4j.nn.api.layers.IOutputLayer) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) RecurrentLayer(org.deeplearning4j.nn.api.layers.RecurrentLayer)

Example 2 with RecurrentLayer

use of org.deeplearning4j.nn.api.layers.RecurrentLayer in project deeplearning4j by deeplearning4j.

Example taken from the class ComputationGraph, method rnnSetPreviousState.

/**
 * Set the state of the RNN layer, for use in {@link #rnnTimeStep(INDArray...)}
 *
 * @param layerName The name of the layer.
 * @param state     The state to set the specified layer to
 * @throws UnsupportedOperationException if no layer with the given name exists,
 *                                       or the named layer is not a recurrent layer
 */
public void rnnSetPreviousState(String layerName, Map<String, INDArray> state) {
    //Look up the vertex first: Map.get returns null for unknown names, and dereferencing
    //that directly would throw an uninformative NullPointerException
    GraphVertex vertex = verticesMap.get(layerName);
    if (vertex == null) {
        throw new UnsupportedOperationException("Layer \"" + layerName + "\" does not exist in this network. Cannot set state");
    }
    Layer l = vertex.getLayer();
    //instanceof is false for null, so no separate null check is required here
    if (!(l instanceof RecurrentLayer)) {
        throw new UnsupportedOperationException("Layer \"" + layerName + "\" is not a recurrent layer. Cannot set state");
    }
    ((RecurrentLayer) l).rnnSetPreviousState(state);
}
Also used : Layer(org.deeplearning4j.nn.api.Layer) FrozenLayer(org.deeplearning4j.nn.layers.FrozenLayer) RecurrentLayer(org.deeplearning4j.nn.api.layers.RecurrentLayer) FeedForwardLayer(org.deeplearning4j.nn.conf.layers.FeedForwardLayer) IOutputLayer(org.deeplearning4j.nn.api.layers.IOutputLayer) RecurrentLayer(org.deeplearning4j.nn.api.layers.RecurrentLayer)

Example 3 with RecurrentLayer

use of org.deeplearning4j.nn.api.layers.RecurrentLayer in project deeplearning4j by deeplearning4j.

Example taken from the class ComputationGraph, method rnnUpdateStateWithTBPTTState.

/**
 * Update the internal state of RNN layers after a truncated BPTT fit call.
 * Copies each recurrent layer's TBPTT state into its regular stored state,
 * recursing into any nested MultiLayerNetwork layers.
 */
protected void rnnUpdateStateWithTBPTTState() {
    for (Layer current : layers) {
        if (current instanceof RecurrentLayer) {
            RecurrentLayer rnn = (RecurrentLayer) current;
            rnn.rnnSetPreviousState(rnn.rnnGetTBPTTState());
        } else if (current instanceof MultiLayerNetwork) {
            ((MultiLayerNetwork) current).updateRnnStateWithTBPTTState();
        }
    }
}
Also used : MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) RecurrentLayer(org.deeplearning4j.nn.api.layers.RecurrentLayer)

Example 4 with RecurrentLayer

use of org.deeplearning4j.nn.api.layers.RecurrentLayer in project deeplearning4j by deeplearning4j.

Example taken from the class LayerVertex, method doBackward.

/**
 * Backpropagate the stored epsilon through this vertex's layer.
 *
 * @param tbptt whether truncated BPTT is in use (affects recurrent layers only)
 * @return gradient for this layer plus the epsilon array(s) to pass to input vertices
 * @throws IllegalStateException if the epsilons for this vertex have not all been set
 */
@Override
public Pair<Gradient, INDArray[]> doBackward(boolean tbptt) {
    if (!canDoBackward()) {
        throw new IllegalStateException("Cannot do backward pass: all epsilons not set. Layer " + vertexName + " (idx " + vertexIndex + ") numInputs " + getNumInputArrays() + "; numOutputs " + getNumOutputConnections());
    }

    //Recurrent layers get the truncated-BPTT gradient when tbptt is enabled; everything else gets standard backprop
    final Pair<Gradient, INDArray> gradAndEpsilon;
    if (tbptt && layer instanceof RecurrentLayer) {
        gradAndEpsilon = ((RecurrentLayer) layer).tbpttBackpropGradient(epsilon, graph.getConfiguration().getTbpttBackLength());
    } else {
        //epsTotal may be null for OutputLayers
        gradAndEpsilon = layer.backpropGradient(epsilon);
    }

    //If a preprocessor sits between this layer and its input, backprop the epsilon through it as well
    if (layerPreProcessor != null) {
        gradAndEpsilon.setSecond(layerPreProcessor.backprop(gradAndEpsilon.getSecond(), graph.batchSize()));
    }

    //Layers always have a single activations input -> always a single epsilon output during backprop
    return new Pair<>(gradAndEpsilon.getFirst(), new INDArray[] { gradAndEpsilon.getSecond() });
}
Also used : Gradient(org.deeplearning4j.nn.gradient.Gradient) INDArray(org.nd4j.linalg.api.ndarray.INDArray) RecurrentLayer(org.deeplearning4j.nn.api.layers.RecurrentLayer) Pair(org.deeplearning4j.berkeley.Pair)

Example 5 with RecurrentLayer

use of org.deeplearning4j.nn.api.layers.RecurrentLayer in project deeplearning4j by deeplearning4j.

Example taken from the class MultiLayerNetwork, method rnnSetPreviousState.

/**Set the state of the RNN layer.
 * @param layer The number/index of the layer.
 * @param state The state to set the specified layer to
 * @throws IllegalArgumentException if the index is out of range, or the layer is not recurrent
 */
public void rnnSetPreviousState(int layer, Map<String, INDArray> state) {
    //Include the offending index and valid range in the message so failures are diagnosable
    if (layer < 0 || layer >= layers.length) {
        throw new IllegalArgumentException("Invalid layer number: " + layer
                        + " (must be in range 0 to " + (layers.length - 1) + " inclusive)");
    }
    if (!(layers[layer] instanceof RecurrentLayer)) {
        throw new IllegalArgumentException("Layer " + layer + " is not an RNN layer");
    }
    RecurrentLayer r = (RecurrentLayer) layers[layer];
    r.rnnSetPreviousState(state);
}
Also used : RecurrentLayer(org.deeplearning4j.nn.api.layers.RecurrentLayer)

Aggregations

RecurrentLayer (org.deeplearning4j.nn.api.layers.RecurrentLayer)9 INDArray (org.nd4j.linalg.api.ndarray.INDArray)5 IOutputLayer (org.deeplearning4j.nn.api.layers.IOutputLayer)4 FeedForwardLayer (org.deeplearning4j.nn.conf.layers.FeedForwardLayer)4 FrozenLayer (org.deeplearning4j.nn.layers.FrozenLayer)4 Layer (org.deeplearning4j.nn.api.Layer)3 MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)3 Pair (org.deeplearning4j.berkeley.Pair)2 Gradient (org.deeplearning4j.nn.gradient.Gradient)2 GraphVertex (org.deeplearning4j.nn.graph.vertex.GraphVertex)2 VertexIndices (org.deeplearning4j.nn.graph.vertex.VertexIndices)2 DefaultGradient (org.deeplearning4j.nn.gradient.DefaultGradient)1