Use of org.deeplearning4j.nn.api.layers.RecurrentLayer in project deeplearning4j by deeplearning4j.
In class ComputationGraph, method rnnActivateUsingStoredState.
/**
* Similar to rnnTimeStep and feedForward() methods. Difference here is that this method:<br>
* (a) like rnnTimeStep does forward pass using stored state for RNN layers, and<br>
* (b) unlike rnnTimeStep does not modify the RNN layer state<br>
* Therefore multiple calls to this method with the same input should have the same output.<br>
* Typically used during training only. Use rnnTimeStep for prediction/forward pass at test time.
*
* @param inputs Input to network
* @param training Whether training or not
* @param storeLastForTBPTT set to true if used as part of truncated BPTT training
* @return Activations for each layer (including input, as per feedforward() etc)
*/
public Map<String, INDArray> rnnActivateUsingStoredState(INDArray[] inputs, boolean training,
                boolean storeLastForTBPTT) {
    Map<String, INDArray> layerActivations = new HashMap<>();

    //Do forward pass according to the topological ordering of the network
    for (int currVertexIdx : topologicalOrder) {
        GraphVertex current = vertices[currVertexIdx];
        if (current.isInputVertex()) {
            VertexIndices[] inputsTo = current.getOutputVertices();
            INDArray input = inputs[current.getVertexIndex()];

            layerActivations.put(current.getVertexName(), input);

            for (VertexIndices v : inputsTo) {
                int vIdx = v.getVertexIndex();
                int vIdxInputNum = v.getVertexEdgeNumber();
                //This input: the 'vIdxInputNum'th input to vertex 'vIdx'
                //TODO When to dup?
                vertices[vIdx].setInput(vIdxInputNum, input.dup());
            }
        } else {
            INDArray out;
            if (current.hasLayer()) {
                Layer l = current.getLayer();
                if (l instanceof RecurrentLayer) {
                    out = ((RecurrentLayer) l).rnnActivateUsingStoredState(current.getInputs()[0], training,
                                    storeLastForTBPTT);
                } else if (l instanceof MultiLayerNetwork) {
                    List<INDArray> temp = ((MultiLayerNetwork) l).rnnActivateUsingStoredState(
                                    current.getInputs()[0], training, storeLastForTBPTT);
                    out = temp.get(temp.size() - 1);
                } else {
                    //Non-recurrent layer
                    out = current.doForward(training);
                }
                layerActivations.put(current.getVertexName(), out);
            } else {
                out = current.doForward(training);
            }

            //Now, set the inputs for the next vertices:
            VertexIndices[] outputsTo = current.getOutputVertices();
            if (outputsTo != null) {
                for (VertexIndices v : outputsTo) {
                    int vIdx = v.getVertexIndex();
                    int inputNum = v.getVertexEdgeNumber();
                    //This (jth) connection from the output: is the 'inputNum'th input to vertex 'vIdx'
                    vertices[vIdx].setInput(inputNum, out);
                }
            }
        }
    }
    return layerActivations;
}
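The contract is easy to check directly: because the stored state is read but never written, two consecutive calls with the same input must agree. A minimal sketch, assuming an already-initialized ComputationGraph 'graph' with a recurrent layer named "lstm" and a 3d input array 'features' (both hypothetical, not from the source):

    //Assumed setup: 'graph' is initialized, "lstm" is a recurrent layer,
    //'features' has shape [miniBatch, inSize, timeSeriesLength]
    Map<String, INDArray> stateBefore = graph.rnnGetPreviousState("lstm");

    Map<String, INDArray> acts1 = graph.rnnActivateUsingStoredState(new INDArray[] {features}, false, false);
    Map<String, INDArray> acts2 = graph.rnnActivateUsingStoredState(new INDArray[] {features}, false, false);

    //Unlike rnnTimeStep, repeated calls with the same input give identical activations...
    assert acts1.get("lstm").equals(acts2.get("lstm"));
    //...and the stored RNN state itself is left untouched
    assert stateBefore.equals(graph.rnnGetPreviousState("lstm"));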
Use of org.deeplearning4j.nn.api.layers.RecurrentLayer in project deeplearning4j by deeplearning4j.
In class ComputationGraph, method rnnSetPreviousState.
/**
* Set the state of the RNN layer, for use in {@link #rnnTimeStep(INDArray...)}
*
* @param layerName The name of the layer.
* @param state The state to set the specified layer to
*/
public void rnnSetPreviousState(String layerName, Map<String, INDArray> state) {
    Layer l = verticesMap.get(layerName).getLayer();
    if (l == null || !(l instanceof RecurrentLayer)) {
        throw new UnsupportedOperationException(
                        "Layer \"" + layerName + "\" is not a recurrent layer. Cannot set state");
    }
    ((RecurrentLayer) l).rnnSetPreviousState(state);
}
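One practical use is snapshotting the stored state so a time step can be replayed. A hypothetical sketch, again assuming a graph with a recurrent layer named "lstm"; rnnGetPreviousState returns a copy of the state map, so the snapshot survives later rnnTimeStep calls:

    INDArray out1 = graph.rnnTimeStep(step1)[0];                        //Advances the stored state
    Map<String, INDArray> snapshot = graph.rnnGetPreviousState("lstm"); //State as of the end of step1

    graph.rnnTimeStep(step2);                                           //State moves on...
    graph.rnnSetPreviousState("lstm", snapshot);                        //...rewind to the post-step1 state
    INDArray replay = graph.rnnTimeStep(step2)[0];                      //Same output as the first step2 call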
Use of org.deeplearning4j.nn.api.layers.RecurrentLayer in project deeplearning4j by deeplearning4j.
In class ComputationGraph, method rnnUpdateStateWithTBPTTState.
/**
* Update the internal state of RNN layers after a truncated BPTT fit call
*/
protected void rnnUpdateStateWithTBPTTState() {
    for (int i = 0; i < layers.length; i++) {
        if (layers[i] instanceof RecurrentLayer) {
            RecurrentLayer l = ((RecurrentLayer) layers[i]);
            l.rnnSetPreviousState(l.rnnGetTBPTTState());
        } else if (layers[i] instanceof MultiLayerNetwork) {
            ((MultiLayerNetwork) layers[i]).updateRnnStateWithTBPTTState();
        }
    }
}
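This method is invoked internally at the end of each truncated-BPTT segment during fit(), so the next segment resumes from the state the previous one ended on. It only comes into play when the network is configured for truncated BPTT; a configuration sketch (builder method names as in the 0.x-era API; layer sizes and the loss function are illustrative):

    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                    .graphBuilder()
                    .addInputs("in")
                    .addLayer("lstm", new GravesLSTM.Builder().nIn(10).nOut(20).build(), "in")
                    .addLayer("out", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                    .nIn(20).nOut(5).build(), "lstm")
                    .setOutputs("out")
                    .backpropType(BackpropType.TruncatedBPTT)
                    .tBPTTForwardLength(50).tBPTTBackwardLength(50)
                    .build();
    //During fit(), after each 50-step segment the TBPTT state is copied into the
    //regular RNN state, so the next segment continues where this one stopped.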
Use of org.deeplearning4j.nn.api.layers.RecurrentLayer in project deeplearning4j by deeplearning4j.
In class LayerVertex, method doBackward.
@Override
public Pair<Gradient, INDArray[]> doBackward(boolean tbptt) {
    if (!canDoBackward()) {
        throw new IllegalStateException("Cannot do backward pass: all epsilons not set. Layer " + vertexName
                        + " (idx " + vertexIndex + ") numInputs " + getNumInputArrays() + "; numOutputs "
                        + getNumOutputConnections());
    }

    Pair<Gradient, INDArray> pair;
    if (tbptt && layer instanceof RecurrentLayer) {
        //Truncated BPTT for recurrent layers
        pair = ((RecurrentLayer) layer).tbpttBackpropGradient(epsilon,
                        graph.getConfiguration().getTbpttBackLength());
    } else {
        //Normal backprop
        //epsTotal may be null for OutputLayers
        pair = layer.backpropGradient(epsilon);
    }

    if (layerPreProcessor != null) {
        INDArray eps = pair.getSecond();
        eps = layerPreProcessor.backprop(eps, graph.batchSize());
        pair.setSecond(eps);
    }

    //Layers always have single activations input -> always have single epsilon output during backprop
    return new Pair<>(pair.getFirst(), new INDArray[] {pair.getSecond()});
}
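The layerPreProcessor.backprop call reshapes the epsilon from the layer's activation shape back to the shape of the vertex's input. A standalone sketch of that round trip using RnnToFeedForwardPreProcessor (shapes and the random arrays are illustrative only):

    //Forward pass flattens [miniBatch, size, timeSeriesLength] -> [miniBatch*T, size];
    //backprop reverses the reshape so the epsilon matches the original RNN input shape
    InputPreProcessor proc = new RnnToFeedForwardPreProcessor();
    int miniBatch = 8, size = 16, tsLength = 10;

    INDArray rnnActivations = Nd4j.rand(new int[] {miniBatch, size, tsLength});
    INDArray ff = proc.preProcess(rnnActivations, miniBatch); //Shape [miniBatch*tsLength, size]
    INDArray epsFf = Nd4j.rand(ff.shape());                   //Epsilon in feed-forward shape
    INDArray epsRnn = proc.backprop(epsFf, miniBatch);        //Back to [miniBatch, size, tsLength]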
Use of org.deeplearning4j.nn.api.layers.RecurrentLayer in project deeplearning4j by deeplearning4j.
In class MultiLayerNetwork, method rnnSetPreviousState.
/**
 * Set the state of the RNN layer.
 *
 * @param layer The number/index of the layer.
 * @param state The state to set the specified layer to
 */
public void rnnSetPreviousState(int layer, Map<String, INDArray> state) {
    if (layer < 0 || layer >= layers.length)
        throw new IllegalArgumentException("Invalid layer number");
    if (!(layers[layer] instanceof RecurrentLayer))
        throw new IllegalArgumentException("Layer is not an RNN layer");
    RecurrentLayer r = (RecurrentLayer) layers[layer];
    r.rnnSetPreviousState(state);
}
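A hypothetical usage sketch, pairing this with rnnGetPreviousState(int) and rnnClearPreviousState() to rewind a MultiLayerNetwork 'net' whose layer 1 is recurrent (the network and inputs are assumed, not from the source):

    net.rnnClearPreviousState();                              //Start from a blank state
    net.rnnTimeStep(warmupInput);                             //Stored state now reflects warmupInput

    Map<String, INDArray> saved = net.rnnGetPreviousState(1); //Snapshot layer 1's state
    net.rnnTimeStep(nextInput);                               //State advances
    net.rnnSetPreviousState(1, saved);                        //Rewind layer 1 to the snapshot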