use of org.deeplearning4j.nn.api.layers.RecurrentLayer in project deeplearning4j by deeplearning4j.
the class ComputationGraph method rnnTimeStep.
//------------------------------------------------------------------------------
//RNN-specific functionality
/**
 * If this ComputationGraph contains one or more RNN layers: conduct forward pass (prediction)
 * but using previous stored state for any RNN layers. The activations for the final step are
 * also stored in the RNN layers for use next time rnnTimeStep() is called.<br>
 * This method can be used to generate output one or more steps at a time instead of always having to do
 * forward pass from t=0. Example uses are for streaming data, and for generating samples from network output
 * one step at a time (where samples are then fed back into the network as input)<br>
 * If no previous state is present in RNN layers (i.e., initially or after calling rnnClearPreviousState()),
 * the default initialization (usually 0) is used.<br>
 * Supports mini-batch operation (i.e., multiple predictions/forward passes in parallel) as well as single examples.<br>
 *
 * @param inputs Input to network. May be for one or multiple time steps. For single time step:
 *               input has shape [miniBatchSize,inputSize] or [miniBatchSize,inputSize,1]. miniBatchSize=1 for single example.<br>
 *               For multiple time steps: [miniBatchSize,inputSize,inputTimeSeriesLength]
 * @return Output activations. If the output layer is an RNN layer (such as RnnOutputLayer): if all inputs are 2d
 *         with shape [miniBatchSize,inputSize], then outputs are also 2d with shape [miniBatchSize,outputSize]
 *         instead of [miniBatchSize,outputSize,1].<br>
 *         Otherwise, outputs are 3d with shape [miniBatchSize,outputSize,inputTimeSeriesLength] when using
 *         RnnOutputLayer (or unmodified otherwise).
 */
public INDArray[] rnnTimeStep(INDArray... inputs) {
    this.inputs = inputs;
    //Idea: if 2d in, want 2d out
    boolean inputIs2d = true;
    for (INDArray i : inputs) {
        if (i.rank() != 2) {
            inputIs2d = false;
            break;
        }
    }
    INDArray[] outputs = new INDArray[this.numOutputArrays];
    //Based on: feedForward()
    for (int currVertexIdx : topologicalOrder) {
        GraphVertex current = vertices[currVertexIdx];
        if (current.isInputVertex()) {
            VertexIndices[] inputsTo = current.getOutputVertices();
            INDArray input = inputs[current.getVertexIndex()];
            for (VertexIndices v : inputsTo) {
                int vIdx = v.getVertexIndex();
                int vIdxInputNum = v.getVertexEdgeNumber();
                //This input: the 'vIdxInputNum'th input to vertex 'vIdx'
                //TODO When to dup?
                vertices[vIdx].setInput(vIdxInputNum, input.dup());
            }
        } else {
            INDArray out;
            if (current.hasLayer()) {
                //Layer
                Layer l = current.getLayer();
                if (l instanceof RecurrentLayer) {
                    out = ((RecurrentLayer) l).rnnTimeStep(current.getInputs()[0]);
                } else if (l instanceof MultiLayerNetwork) {
                    out = ((MultiLayerNetwork) l).rnnTimeStep(current.getInputs()[0]);
                } else {
                    //Non-recurrent layer
                    out = current.doForward(false);
                }
            } else {
                //GraphNode
                out = current.doForward(false);
            }
            if (current.isOutputVertex()) {
                //Get the index of this output vertex...
                int idx = configuration.getNetworkOutputs().indexOf(current.getVertexName());
                outputs[idx] = out;
            }
            //Now, set the inputs for the next vertices:
            VertexIndices[] outputsTo = current.getOutputVertices();
            if (outputsTo != null) {
                for (VertexIndices v : outputsTo) {
                    int vIdx = v.getVertexIndex();
                    int inputNum = v.getVertexEdgeNumber();
                    //This (jth) connection from the output: is the 'inputNum'th input to vertex 'vIdx'
                    vertices[vIdx].setInput(inputNum, out);
                }
            }
        }
    }
    //As per MultiLayerNetwork.rnnTimeStep(): if inputs are all 2d, then outputs are all 2d
    if (inputIs2d) {
        for (int i = 0; i < outputs.length; i++) {
            if (outputs[i].rank() == 3 && outputs[i].size(2) == 1) {
                //Return 2d output with shape [miniBatchSize,nOut]
                // instead of 3d output with shape [miniBatchSize,nOut,1]
                outputs[i] = outputs[i].tensorAlongDimension(0, 1, 0);
            }
        }
    }
    this.inputs = null;
    return outputs;
}
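A minimal usage sketch of the generation pattern described in the javadoc (feeding each prediction back in as the next input). This is not from the deeplearning4j sources above; the helper name, shapes, and the assumption that the output can be fed back in directly (e.g., output size equals input size, as in a character-level model) are illustrative.
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;

public class RnnTimeStepExample {
    /**
     * Illustrative helper (not part of the DL4J API): generates nSteps of output from a
     * single-input, single-output graph by feeding each step's output back in as the next input.
     */
    public static INDArray[] generate(ComputationGraph net, INDArray firstStepInput, int nSteps) {
        net.rnnClearPreviousState();              //Start from the default (zero) initial state
        INDArray[] outputs = new INDArray[nSteps];
        INDArray currentInput = firstStepInput;   //Shape [miniBatchSize, inputSize] for a single step
        for (int t = 0; t < nSteps; t++) {
            //rnnTimeStep stores the RNN state internally, so each call advances by one time step
            INDArray out = net.rnnTimeStep(currentInput)[0];
            outputs[t] = out;
            currentInput = out;                   //Feed the prediction back in as the next input
        }
        return outputs;
    }
}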
use of org.deeplearning4j.nn.api.layers.RecurrentLayer in project deeplearning4j by deeplearning4j.
the class MultiLayerNetwork method truncatedBPTTGradient.
/** Equivalent to backprop(), but calculates gradient for truncated BPTT instead. */
protected void truncatedBPTTGradient() {
    if (flattenedGradients == null)
        initGradientsView();
    String multiGradientKey;
    gradient = new DefaultGradient();
    Layer currLayer;
    if (!(getOutputLayer() instanceof IOutputLayer)) {
        log.warn("Warning: final layer isn't output layer. You cannot use backprop (truncated BPTT) without an output layer.");
        return;
    }
    IOutputLayer outputLayer = (IOutputLayer) getOutputLayer();
    if (labels == null)
        throw new IllegalStateException("No labels found");
    if (outputLayer.conf().getLayer().getWeightInit() == WeightInit.ZERO) {
        throw new IllegalStateException("Output layer weights cannot be initialized to zero when using backprop.");
    }
    outputLayer.setLabels(labels);
    //Calculate and apply the backward gradient for every layer
    int numLayers = getnLayers();
    //Gradients are stored in a list to preserve iteration order in the DefaultGradient linked hash map
    // (i.e., layer 0 first instead of the output layer)
    LinkedList<Pair<String, INDArray>> gradientList = new LinkedList<>();
    Pair<Gradient, INDArray> currPair = outputLayer.backpropGradient(null);
    for (Map.Entry<String, INDArray> entry : currPair.getFirst().gradientForVariable().entrySet()) {
        multiGradientKey = String.valueOf(numLayers - 1) + "_" + entry.getKey();
        gradientList.addLast(new Pair<>(multiGradientKey, entry.getValue()));
    }
    if (getLayerWiseConfigurations().getInputPreProcess(numLayers - 1) != null)
        currPair = new Pair<>(currPair.getFirst(), this.layerWiseConfigurations.getInputPreProcess(numLayers - 1)
                        .backprop(currPair.getSecond(), getInputMiniBatchSize()));
    //Calculate gradients for the remaining layers (the output layer, index numLayers - 1, is already done)
    for (int j = numLayers - 2; j >= 0; j--) {
        currLayer = getLayer(j);
        if (currLayer instanceof RecurrentLayer) {
            currPair = ((RecurrentLayer) currLayer).tbpttBackpropGradient(currPair.getSecond(),
                            layerWiseConfigurations.getTbpttBackLength());
        } else {
            currPair = currLayer.backpropGradient(currPair.getSecond());
        }
        LinkedList<Pair<String, INDArray>> tempList = new LinkedList<>();
        for (Map.Entry<String, INDArray> entry : currPair.getFirst().gradientForVariable().entrySet()) {
            multiGradientKey = String.valueOf(j) + "_" + entry.getKey();
            tempList.addFirst(new Pair<>(multiGradientKey, entry.getValue()));
        }
        for (Pair<String, INDArray> pair : tempList)
            gradientList.addFirst(pair);
        //Pass epsilon through the input preprocessor before passing to the next layer (if applicable)
        if (getLayerWiseConfigurations().getInputPreProcess(j) != null)
            currPair = new Pair<>(currPair.getFirst(), getLayerWiseConfigurations().getInputPreProcess(j)
                            .backprop(currPair.getSecond(), getInputMiniBatchSize()));
    }
    //Add gradients to the Gradient object, in the correct order
    for (Pair<String, INDArray> pair : gradientList)
        gradient.setGradientFor(pair.getFirst(), pair.getSecond());
}
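For context, truncated BPTT (and hence the value returned by getTbpttBackLength() above) is enabled through the network configuration rather than in this method. A sketch of a typical configuration, assuming the standard DL4J configuration builder API; the layer choices and sizes are illustrative.
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

int nIn = 50, nHidden = 200, nOut = 50, tbpttLength = 50;   //Illustrative sizes
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .list()
        .layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(nHidden).activation(Activation.TANH).build())
        .layer(1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX).nIn(nHidden).nOut(nOut).build())
        .backpropType(BackpropType.TruncatedBPTT)   //Use truncated BPTT instead of full backprop()
        .tBPTTForwardLength(tbpttLength)
        .tBPTTBackwardLength(tbpttLength)           //Corresponds to getTbpttBackLength() used above
        .build();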
use of org.deeplearning4j.nn.api.layers.RecurrentLayer in project deeplearning4j by deeplearning4j.
the class MultiLayerNetwork method rnnActivateUsingStoredState.
/**
 * Similar to rnnTimeStep and feedForward() methods. Difference here is that this method:<br>
 * (a) like rnnTimeStep does forward pass using stored state for RNN layers, and<br>
 * (b) unlike rnnTimeStep does not modify the RNN layer state<br>
 * Therefore multiple calls to this method with the same input should have the same output.<br>
 * Typically used during training only. Use rnnTimeStep for prediction/forward pass at test time.
 *
 * @param input              Input to network
 * @param training           Whether training or not
 * @param storeLastForTBPTT  set to true if used as part of truncated BPTT training
 * @return Activations for each layer (including input, as per feedforward() etc)
 */
public List<INDArray> rnnActivateUsingStoredState(INDArray input, boolean training, boolean storeLastForTBPTT) {
    INDArray currInput = input;
    List<INDArray> activations = new ArrayList<>();
    activations.add(currInput);
    for (int i = 0; i < layers.length; i++) {
        if (getLayerWiseConfigurations().getInputPreProcess(i) != null)
            currInput = getLayerWiseConfigurations().getInputPreProcess(i).preProcess(currInput, input.size(0));
        if (layers[i] instanceof RecurrentLayer) {
            currInput = ((RecurrentLayer) layers[i]).rnnActivateUsingStoredState(currInput, training, storeLastForTBPTT);
        } else if (layers[i] instanceof MultiLayerNetwork) {
            List<INDArray> temp = ((MultiLayerNetwork) layers[i]).rnnActivateUsingStoredState(currInput, training, storeLastForTBPTT);
            currInput = temp.get(temp.size() - 1);
        } else {
            currInput = layers[i].activate(currInput, training);
        }
        activations.add(currInput);
    }
    return activations;
}
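A small sketch illustrating the property stated in the javadoc: because the stored RNN state is read but never modified, repeated calls with the same input return the same activations. The helper name and the equality check are illustrative, not part of the DL4J API.
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.List;

public class StoredStateExample {
    /** Illustrative only: returns true if two calls with the same input give the same final-layer output. */
    public static boolean isDeterministicForStoredState(MultiLayerNetwork net, INDArray input) {
        List<INDArray> first = net.rnnActivateUsingStoredState(input, false, false);
        List<INDArray> second = net.rnnActivateUsingStoredState(input, false, false);
        //Compare the final layer's activations; rnnTimeStep(input), by contrast, would advance the
        //stored state, so a second call with the same input would generally give a different result.
        return first.get(first.size() - 1).equals(second.get(second.size() - 1));
    }
}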
use of org.deeplearning4j.nn.api.layers.RecurrentLayer in project deeplearning4j by deeplearning4j.
the class MultiLayerNetwork method updateRnnStateWithTBPTTState.
public void updateRnnStateWithTBPTTState() {
    for (int i = 0; i < layers.length; i++) {
        if (layers[i] instanceof RecurrentLayer) {
            RecurrentLayer l = ((RecurrentLayer) layers[i]);
            l.rnnSetPreviousState(l.rnnGetTBPTTState());
        } else if (layers[i] instanceof MultiLayerNetwork) {
            ((MultiLayerNetwork) layers[i]).updateRnnStateWithTBPTTState();
        }
    }
}
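A hedged sketch (not the internal deeplearning4j implementation) of how rnnActivateUsingStoredState(..., storeLastForTBPTT=true) and updateRnnStateWithTBPTTState() can be combined to process a long sequence in fixed-length segments while keeping the RNN state continuous across segment boundaries; the helper name and segmenting scheme are illustrative.
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.NDArrayIndex;

public class SegmentedForwardExample {
    /** Illustrative helper: forward-pass a long time series in segments of segmentLength steps. */
    public static void forwardInSegments(MultiLayerNetwork net, INDArray features, int segmentLength) {
        net.rnnClearPreviousState();
        int totalLength = features.size(2);   //Features shape: [miniBatchSize, nIn, timeSeriesLength]
        for (int start = 0; start < totalLength; start += segmentLength) {
            int end = Math.min(start + segmentLength, totalLength);
            INDArray segment = features.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(start, end));
            //Forward pass using the stored state; storeLastForTBPTT=true records the state at the
            //end of this segment without overwriting the regular stored state
            net.rnnActivateUsingStoredState(segment, true, true);
            //Copy the recorded TBPTT state into the regular stored state, so the next segment
            //(and any later rnnTimeStep call) continues from where this segment ended
            net.updateRnnStateWithTBPTTState();
        }
    }
}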