
Example 6 with RecurrentLayer

Use of org.deeplearning4j.nn.api.layers.RecurrentLayer in the deeplearning4j project.

From class ComputationGraph, method rnnTimeStep:

//------------------------------------------------------------------------------
//RNN-specific functionality
/**
     * If this ComputationGraph contains one or more RNN layers: conduct a forward pass (prediction),
     * using the previously stored state for any RNN layers. The activations for the final step are
     * also stored in the RNN layers, for use the next time rnnTimeStep() is called.<br>
     * This method can be used to generate output one or more steps at a time, instead of always having
     * to do a forward pass from t=0. Example uses include streaming data, and generating samples from
     * network output one step at a time (where samples are then fed back into the network as input).<br>
     * If no previous state is present in the RNN layers (i.e., initially, or after calling rnnClearPreviousState()),
     * the default initialization (usually 0) is used.<br>
     * Supports mini-batches (i.e., multiple predictions/forward passes in parallel) as well as single examples.<br>
     *
     * @param inputs Input to the network. May be for one or multiple time steps. For a single time step:
     *               input has shape [miniBatchSize,inputSize] or [miniBatchSize,inputSize,1]; use miniBatchSize=1 for a single example.<br>
     *               For multiple time steps: [miniBatchSize,inputSize,inputTimeSeriesLength]
     * @return Output activations. If the final layer is an RNN layer (such as RnnOutputLayer): when all inputs
     * are 2d with shape [miniBatchSize,inputSize], the outputs are also 2d with shape [miniBatchSize,outputSize]
     * instead of [miniBatchSize,outputSize,1].<br>
     * Otherwise, the output is 3d with shape [miniBatchSize,outputSize,inputTimeSeriesLength] when using
     * RnnOutputLayer (or is returned unmodified for other output layer types).
     */
public INDArray[] rnnTimeStep(INDArray... inputs) {
    this.inputs = inputs;
    //Idea: if 2d in, want 2d out
    boolean inputIs2d = true;
    for (INDArray i : inputs) {
        if (i.rank() != 2) {
            inputIs2d = false;
            break;
        }
    }
    INDArray[] outputs = new INDArray[this.numOutputArrays];
    //Based on: feedForward()
    for (int currVertexIdx : topologicalOrder) {
        GraphVertex current = vertices[currVertexIdx];
        if (current.isInputVertex()) {
            VertexIndices[] inputsTo = current.getOutputVertices();
            INDArray input = inputs[current.getVertexIndex()];
            for (VertexIndices v : inputsTo) {
                int vIdx = v.getVertexIndex();
                int vIdxInputNum = v.getVertexEdgeNumber();
                //This input: the 'vIdxInputNum'th input to vertex 'vIdx'
                //TODO When to dup?
                vertices[vIdx].setInput(vIdxInputNum, input.dup());
            }
        } else {
            INDArray out;
            if (current.hasLayer()) {
                //Layer
                Layer l = current.getLayer();
                if (l instanceof RecurrentLayer) {
                    out = ((RecurrentLayer) l).rnnTimeStep(current.getInputs()[0]);
                } else if (l instanceof MultiLayerNetwork) {
                    out = ((MultiLayerNetwork) l).rnnTimeStep(current.getInputs()[0]);
                } else {
                    //non-recurrent layer
                    out = current.doForward(false);
                }
            } else {
                //GraphNode
                out = current.doForward(false);
            }
            if (current.isOutputVertex()) {
                //Get the index of this output vertex...
                int idx = configuration.getNetworkOutputs().indexOf(current.getVertexName());
                outputs[idx] = out;
            }
            //Now, set the inputs for the next vertices:
            VertexIndices[] outputsTo = current.getOutputVertices();
            if (outputsTo != null) {
                for (VertexIndices v : outputsTo) {
                    int vIdx = v.getVertexIndex();
                    int inputNum = v.getVertexEdgeNumber();
                    //This (jth) connection from the output: is the 'inputNum'th input to vertex 'vIdx'
                    vertices[vIdx].setInput(inputNum, out);
                }
            }
        }
    }
    //As per MultiLayerNetwork.rnnTimeStep(): if inputs are all 2d, then outputs are all 2d
    if (inputIs2d) {
        for (int i = 0; i < outputs.length; i++) {
            if (outputs[i].rank() == 3 && outputs[i].size(2) == 1) {
                //Return 2d output with shape [miniBatchSize,nOut]
                // instead of 3d output with shape [miniBatchSize,nOut,1]
                outputs[i] = outputs[i].tensorAlongDimension(0, 1, 0);
            }
        }
    }
    this.inputs = null;
    return outputs;
}
Also used: GraphVertex (org.deeplearning4j.nn.graph.vertex.GraphVertex), INDArray (org.nd4j.linalg.api.ndarray.INDArray), VertexIndices (org.deeplearning4j.nn.graph.vertex.VertexIndices), Layer (org.deeplearning4j.nn.api.Layer), FrozenLayer (org.deeplearning4j.nn.layers.FrozenLayer), RecurrentLayer (org.deeplearning4j.nn.api.layers.RecurrentLayer), FeedForwardLayer (org.deeplearning4j.nn.conf.layers.FeedForwardLayer), IOutputLayer (org.deeplearning4j.nn.api.layers.IOutputLayer), MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)
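
As a usage sketch, rnnTimeStep() makes step-at-a-time generation straightforward. The following is a minimal, hypothetical driver: the graph setup, the array shapes, and the assumption that the output size equals the input size are all illustrative, and the tensorAlongDimension pattern mirrors the one used in the method body above.

import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;

public class RnnTimeStepSketch {
    /**
     * Prime the stored RNN state with a seed sequence, then generate further
     * steps one at a time, feeding each output back in as the next input.
     * Assumes a single-input, single-output graph whose output size equals
     * its input size (hypothetical setup).
     */
    public static void generate(ComputationGraph graph, INDArray seed, int steps) {
        graph.rnnClearPreviousState(); //start from the default (zero) state
        //Prime with the full seed: shape [miniBatchSize,inputSize,seedLength]
        INDArray[] out = graph.rnnTimeStep(seed);
        //Take the last time step of the 3d output as the next 2d input
        INDArray next = out[0].tensorAlongDimension((int) out[0].size(2) - 1, 1, 0);
        for (int i = 0; i < steps; i++) {
            //2d input [miniBatchSize,inputSize] -> 2d output [miniBatchSize,outputSize]
            out = graph.rnnTimeStep(next);
            next = out[0]; //in practice, sample from out[0] and re-encode instead
        }
    }
}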

Example 7 with RecurrentLayer

Use of org.deeplearning4j.nn.api.layers.RecurrentLayer in the deeplearning4j project.

From class MultiLayerNetwork, method truncatedBPTTGradient:

/** Equivalent to backprop(), but calculates the gradient using truncated BPTT instead. */
protected void truncatedBPTTGradient() {
    if (flattenedGradients == null)
        initGradientsView();
    String multiGradientKey;
    gradient = new DefaultGradient();
    Layer currLayer;
    if (!(getOutputLayer() instanceof IOutputLayer)) {
        log.warn("Warning: final layer isn't output layer. You cannot use backprop (truncated BPTT) without an output layer.");
        return;
    }
    IOutputLayer outputLayer = (IOutputLayer) getOutputLayer();
    if (labels == null)
        throw new IllegalStateException("No labels found");
    if (outputLayer.conf().getLayer().getWeightInit() == WeightInit.ZERO) {
        throw new IllegalStateException("Output layer weights cannot be initialized to zero when using backprop.");
    }
    outputLayer.setLabels(labels);
    //calculate and apply the backward gradient for every layer
    int numLayers = getnLayers();
    //Gradients are stored in a list to preserve iteration order in the DefaultGradient linked hash map, i.e., layer 0 first instead of the output layer
    LinkedList<Pair<String, INDArray>> gradientList = new LinkedList<>();
    Pair<Gradient, INDArray> currPair = outputLayer.backpropGradient(null);
    for (Map.Entry<String, INDArray> entry : currPair.getFirst().gradientForVariable().entrySet()) {
        multiGradientKey = String.valueOf(numLayers - 1) + "_" + entry.getKey();
        gradientList.addLast(new Pair<>(multiGradientKey, entry.getValue()));
    }
    if (getLayerWiseConfigurations().getInputPreProcess(numLayers - 1) != null)
        currPair = new Pair<>(currPair.getFirst(), this.layerWiseConfigurations.getInputPreProcess(numLayers - 1).backprop(currPair.getSecond(), getInputMiniBatchSize()));
    // Calculate gradients for the remaining layers (the output layer, index numLayers-1, was handled above)
    for (int j = numLayers - 2; j >= 0; j--) {
        currLayer = getLayer(j);
        if (currLayer instanceof RecurrentLayer) {
            currPair = ((RecurrentLayer) currLayer).tbpttBackpropGradient(currPair.getSecond(), layerWiseConfigurations.getTbpttBackLength());
        } else {
            currPair = currLayer.backpropGradient(currPair.getSecond());
        }
        LinkedList<Pair<String, INDArray>> tempList = new LinkedList<>();
        for (Map.Entry<String, INDArray> entry : currPair.getFirst().gradientForVariable().entrySet()) {
            multiGradientKey = String.valueOf(j) + "_" + entry.getKey();
            tempList.addFirst(new Pair<>(multiGradientKey, entry.getValue()));
        }
        for (Pair<String, INDArray> pair : tempList) gradientList.addFirst(pair);
        //Pass epsilon through input processor before passing to next layer (if applicable)
        if (getLayerWiseConfigurations().getInputPreProcess(j) != null)
            currPair = new Pair<>(currPair.getFirst(), getLayerWiseConfigurations().getInputPreProcess(j).backprop(currPair.getSecond(), getInputMiniBatchSize()));
    }
    //Add gradients to the Gradient object, in the correct order
    for (Pair<String, INDArray> pair : gradientList) gradient.setGradientFor(pair.getFirst(), pair.getSecond());
}
Also used: Gradient (org.deeplearning4j.nn.gradient.Gradient), DefaultGradient (org.deeplearning4j.nn.gradient.DefaultGradient), FeedForwardLayer (org.deeplearning4j.nn.conf.layers.FeedForwardLayer), FrozenLayer (org.deeplearning4j.nn.layers.FrozenLayer), IOutputLayer (org.deeplearning4j.nn.api.layers.IOutputLayer), RecurrentLayer (org.deeplearning4j.nn.api.layers.RecurrentLayer), INDArray (org.nd4j.linalg.api.ndarray.INDArray), Pair (org.deeplearning4j.berkeley.Pair)
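
Note that truncatedBPTTGradient() is not called directly by user code; it runs inside fit() when the network is configured with BackpropType.TruncatedBPTT. A minimal configuration sketch using the DL4J 0.x builder API follows; the layer sizes, seed, and TBPTT length are illustrative assumptions.

import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class TbpttConfigSketch {
    public static MultiLayerNetwork build() {
        int nIn = 50, nHidden = 200, nOut = 50, tbpttLength = 50; //illustrative sizes
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(12345)
                .list()
                .layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(nHidden)
                        .activation(Activation.TANH).build())
                .layer(1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX).nIn(nHidden).nOut(nOut).build())
                //With TruncatedBPTT, fit() segments each sequence and computes the
                //gradient via truncatedBPTTGradient() instead of backprop()
                .backpropType(BackpropType.TruncatedBPTT)
                .tBPTTForwardLength(tbpttLength)
                .tBPTTBackwardLength(tbpttLength)
                .pretrain(false).backprop(true)
                .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        return net;
    }
}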

Example 8 with RecurrentLayer

Use of org.deeplearning4j.nn.api.layers.RecurrentLayer in the deeplearning4j project.

From class MultiLayerNetwork, method rnnActivateUsingStoredState:

/** Similar to the rnnTimeStep() and feedForward() methods. The differences are that this method:<br>
     * (a) like rnnTimeStep, does its forward pass using the stored state for any RNN layers, and<br>
     * (b) unlike rnnTimeStep, does not modify the RNN layer state.<br>
     * Therefore, multiple calls to this method with the same input should produce the same output.<br>
     * Typically used during training only; use rnnTimeStep for prediction/forward pass at test time.
     * @param input Input to the network
     * @param training Whether this is a training (true) or test (false) pass
     * @param storeLastForTBPTT Set to true if used as part of truncated BPTT training
     * @return Activations for each layer (including the input, as per feedForward() etc.)
     */
public List<INDArray> rnnActivateUsingStoredState(INDArray input, boolean training, boolean storeLastForTBPTT) {
    INDArray currInput = input;
    List<INDArray> activations = new ArrayList<>();
    activations.add(currInput);
    for (int i = 0; i < layers.length; i++) {
        if (getLayerWiseConfigurations().getInputPreProcess(i) != null)
            currInput = getLayerWiseConfigurations().getInputPreProcess(i).preProcess(currInput, input.size(0));
        if (layers[i] instanceof RecurrentLayer) {
            currInput = ((RecurrentLayer) layers[i]).rnnActivateUsingStoredState(currInput, training, storeLastForTBPTT);
        } else if (layers[i] instanceof MultiLayerNetwork) {
            List<INDArray> temp = ((MultiLayerNetwork) layers[i]).rnnActivateUsingStoredState(currInput, training, storeLastForTBPTT);
            currInput = temp.get(temp.size() - 1);
        } else {
            currInput = layers[i].activate(currInput, training);
        }
        activations.add(currInput);
    }
    return activations;
}
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), RecurrentLayer (org.deeplearning4j.nn.api.layers.RecurrentLayer)
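
A small sketch of the contract described in the javadoc above: because the stored state is not modified, two calls with identical input should agree, whereas rnnTimeStep() would advance the state between calls. The network and input here are assumed to be supplied by the caller.

import java.util.List;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;

public class StoredStateSketch {
    public static void checkIdempotent(MultiLayerNetwork net, INDArray input) {
        //Two forward passes with the same input and unchanged stored state...
        List<INDArray> a1 = net.rnnActivateUsingStoredState(input, false, false);
        List<INDArray> a2 = net.rnnActivateUsingStoredState(input, false, false);
        //...should yield identical final activations
        INDArray out1 = a1.get(a1.size() - 1);
        INDArray out2 = a2.get(a2.size() - 1);
        System.out.println("Outputs equal: " + out1.equals(out2));
    }
}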

Example 9 with RecurrentLayer

Use of org.deeplearning4j.nn.api.layers.RecurrentLayer in the deeplearning4j project.

From class MultiLayerNetwork, method updateRnnStateWithTBPTTState:

/** Copy each recurrent layer's stored TBPTT state into its regular RNN state,
     * recursing into any nested MultiLayerNetwork layers. Called at the end of each
     * truncated BPTT segment so the next segment's forward pass continues from it. */
public void updateRnnStateWithTBPTTState() {
    for (int i = 0; i < layers.length; i++) {
        if (layers[i] instanceof RecurrentLayer) {
            RecurrentLayer l = ((RecurrentLayer) layers[i]);
            l.rnnSetPreviousState(l.rnnGetTBPTTState());
        } else if (layers[i] instanceof MultiLayerNetwork) {
            ((MultiLayerNetwork) layers[i]).updateRnnStateWithTBPTTState();
        }
    }
}
Also used: RecurrentLayer (org.deeplearning4j.nn.api.layers.RecurrentLayer)
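
This method is the state hand-off step in DL4J's internal truncated BPTT loop. The sketch below shows where it sits in the per-segment sequence of calls; the gradient/update step is elided, and segmentFeatures is assumed to be a [miniBatchSize,inputSize,segmentLength] slice of a longer time series.

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;

public class TbpttSegmentSketch {
    public static void processSegment(MultiLayerNetwork net, INDArray segmentFeatures) {
        //Forward pass using the state carried over from the previous segment;
        //storeLastForTBPTT=true records each RNN layer's segment-final state
        net.rnnActivateUsingStoredState(segmentFeatures, true, true);
        //... truncated backprop and parameter update for this segment ...
        //Copy the recorded TBPTT state into the regular RNN state so the
        //next segment's forward pass starts where this one ended
        net.updateRnnStateWithTBPTTState();
    }
}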

Aggregations

RecurrentLayer (org.deeplearning4j.nn.api.layers.RecurrentLayer): 9 uses
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 5 uses
IOutputLayer (org.deeplearning4j.nn.api.layers.IOutputLayer): 4 uses
FeedForwardLayer (org.deeplearning4j.nn.conf.layers.FeedForwardLayer): 4 uses
FrozenLayer (org.deeplearning4j.nn.layers.FrozenLayer): 4 uses
Layer (org.deeplearning4j.nn.api.Layer): 3 uses
MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork): 3 uses
Pair (org.deeplearning4j.berkeley.Pair): 2 uses
Gradient (org.deeplearning4j.nn.gradient.Gradient): 2 uses
GraphVertex (org.deeplearning4j.nn.graph.vertex.GraphVertex): 2 uses
VertexIndices (org.deeplearning4j.nn.graph.vertex.VertexIndices): 2 uses
DefaultGradient (org.deeplearning4j.nn.gradient.DefaultGradient): 1 use