Use of org.deeplearning4j.nn.api.MaskState in the project deeplearning4j (by deeplearning4j).
Class ComputationGraph, method setLayerMaskArrays.
/**
 * Set the mask arrays for features and labels. Mask arrays are typically used in situations such as one-to-many
 * and many-to-one learning with recurrent neural networks, as well as for supporting time series of varying lengths
 * within the same minibatch.<br>
 * For example, with RNN data sets with input of shape [miniBatchSize,nIn,timeSeriesLength] and outputs of shape
 * [miniBatchSize,nOut,timeSeriesLength], the feature and label mask arrays will have shape
 * [miniBatchSize,timeSeriesLength] and contain values 0 or 1 at each element (to specify whether a given
 * input/example is present - or merely padding - at a given time step).<br>
 * Feature masks are propagated through the graph in topological order via each vertex's
 * {@code feedForwardMaskArrays}. Each non-null label mask is set directly on the layer of the corresponding
 * network output.<br>
 * Both array lengths are validated before this method modifies any state, so an invalid call throws without
 * clearing or replacing the masks that were previously set.<br>
 * <b>NOTE</b>: This method is not usually used directly. Instead, the various feedForward and fit methods handle setting
 * of masking internally.
 *
 * @param featureMaskArrays Mask arrays for features (inputs); may be null. If non-null, it must contain exactly one
 *                          (possibly null) entry per network input.
 * @param labelMaskArrays   Mask arrays for labels (outputs); may be null. If non-null, it must contain exactly one
 *                          (possibly null) entry per network output.
 * @throws IllegalArgumentException if a non-null array does not have one entry per network input/output
 * @see #clearLayerMaskArrays()
 */
public void setLayerMaskArrays(INDArray[] featureMaskArrays, INDArray[] labelMaskArrays) {
    // Validate everything first: a bad call must not leave half-applied or invalid mask state behind.
    if (featureMaskArrays != null && featureMaskArrays.length != numInputArrays) {
        throw new IllegalArgumentException("Invalid number of feature mask arrays");
    }
    if (labelMaskArrays != null && labelMaskArrays.length != numOutputArrays) {
        throw new IllegalArgumentException("Invalid number of label mask arrays");
    }

    this.clearLayerMaskArrays();
    this.inputMaskArrays = featureMaskArrays;
    this.labelMaskArrays = labelMaskArrays;

    if (featureMaskArrays != null) {
        // Minibatch size is size(0) of the last non-null feature mask; it stays -1 if every feature mask is null.
        int minibatchSize = -1;
        for (INDArray mask : featureMaskArrays) {
            if (mask != null) {
                minibatchSize = mask.size(0);
            }
        }

        // Forward pass of the masks through the network, following its topological ordering,
        // so every vertex sees the masks produced by its inputs.
        Map<Integer, Pair<INDArray, MaskState>> map = new HashMap<>();
        for (int i = 0; i < topologicalOrder.length; i++) {
            GraphVertex current = vertices[topologicalOrder[i]];
            if (current.isInputVertex()) {
                INDArray fMask = featureMaskArrays[current.getVertexIndex()];
                map.put(current.getVertexIndex(), new Pair<>(fMask, MaskState.Active));
            } else {
                VertexIndices[] inputVertices = current.getInputVertices();

                // Gather the masks of this vertex's inputs. The array is allocated lazily, so it stays
                // null when none of the inputs has a recorded mask entry.
                INDArray[] inputMasks = null;
                MaskState maskState = null;
                for (int j = 0; j < inputVertices.length; j++) {
                    Pair<INDArray, MaskState> p = map.get(inputVertices[j].getVertexIndex());
                    if (p != null) {
                        if (inputMasks == null) {
                            inputMasks = new INDArray[inputVertices.length];
                        }
                        inputMasks[j] = p.getFirst();
                        // A Passthrough state may be overridden by a later input's state; any other state is kept.
                        if (maskState == null || maskState == MaskState.Passthrough) {
                            maskState = p.getSecond();
                        }
                    }
                }

                Pair<INDArray, MaskState> outPair = current.feedForwardMaskArrays(inputMasks, maskState, minibatchSize);
                map.put(topologicalOrder[i], outPair);
            }
        }
    }

    if (labelMaskArrays != null) {
        for (int i = 0; i < labelMaskArrays.length; i++) {
            if (labelMaskArrays[i] == null) {
                // This output doesn't have a mask, we can skip it.
                continue;
            }
            String outputName = configuration.getNetworkOutputs().get(i);
            GraphVertex v = verticesMap.get(outputName);
            // NOTE(review): assumes every network output vertex wraps a Layer (getLayer() non-null) - confirm.
            Layer ol = v.getLayer();
            ol.setMaskArray(labelMaskArrays[i]);
        }
    }
}
Use of org.deeplearning4j.nn.api.MaskState in the project deeplearning4j (by deeplearning4j).
Class ComposableInputPreProcessor, method feedForwardMaskArray.
/**
 * Propagates a mask through every wrapped preprocessor, in iteration order. Each stage receives the mask and
 * mask state produced by the previous stage. The first stage receives the arguments passed to this method.
 *
 * @param maskArray        mask entering the first preprocessor (passed through unchanged, possibly null)
 * @param currentMaskState mask state entering the first preprocessor
 * @param minibatchSize    minibatch size, forwarded unchanged to every stage
 * @return a new pair holding the mask and state produced by the last stage, or the inputs if there are no stages
 */
@Override
public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
    INDArray runningMask = maskArray;
    MaskState runningState = currentMaskState;
    for (InputPreProcessor stage : inputPreProcessors) {
        Pair<INDArray, MaskState> stageOutput = stage.feedForwardMaskArray(runningMask, runningState, minibatchSize);
        runningMask = stageOutput.getFirst();
        runningState = stageOutput.getSecond();
    }
    // Always hand back a fresh Pair rather than the last stage's instance.
    return new Pair<>(runningMask, runningState);
}
Aggregations