
Example 1 with MaskState

Use of org.deeplearning4j.nn.api.MaskState in project deeplearning4j by deeplearning4j.

From the class ComputationGraph, method setLayerMaskArrays.

/**
 * Set the mask arrays for features and labels. Mask arrays are typically used in situations such as one-to-many
 * and many-to-one learning with recurrent neural networks, as well as for supporting time series of varying lengths
 * within the same minibatch.<br>
 * For example, with RNN data sets with input of shape [miniBatchSize,nIn,timeSeriesLength] and outputs of shape
 * [miniBatchSize,nOut,timeSeriesLength], the feature and label mask arrays will have shape [miniBatchSize,timeSeriesLength]
 * and contain values 0 or 1 at each element (to specify whether a given input/example is present, or merely padding,
 * at a given time step).<br>
 * <b>NOTE</b>: This method is not usually used directly. Instead, the various feedForward and fit methods handle setting
 * of masking internally.
 *
 * @param featureMaskArrays Mask arrays for features (input), one per network input
 * @param labelMaskArrays   Mask arrays for labels (output), one per network output
 * @see #clearLayerMaskArrays()
 */
public void setLayerMaskArrays(INDArray[] featureMaskArrays, INDArray[] labelMaskArrays) {
    this.clearLayerMaskArrays();
    this.inputMaskArrays = featureMaskArrays;
    this.labelMaskArrays = labelMaskArrays;
    if (featureMaskArrays != null) {
        if (featureMaskArrays.length != numInputArrays) {
            throw new IllegalArgumentException("Invalid number of feature mask arrays");
        }
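        // Infer the minibatch size from dimension 0 of a non-null feature mask (-1 if every mask is null)
        int minibatchSize = -1;
        for (INDArray i : featureMaskArrays) {
            if (i != null) {
                minibatchSize = i.size(0);
            }
        }
        // Feed the masks forward through the network in topological order,
        // recording each vertex's output mask and mask state by vertex index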
        Map<Integer, Pair<INDArray, MaskState>> map = new HashMap<>();
        for (int i = 0; i < topologicalOrder.length; i++) {
            GraphVertex current = vertices[topologicalOrder[i]];
            if (current.isInputVertex()) {
                INDArray fMask = featureMaskArrays[current.getVertexIndex()];
                map.put(current.getVertexIndex(), new Pair<>(fMask, MaskState.Active));
            } else {
                VertexIndices[] inputVertices = current.getInputVertices();
                // Collect the masks produced by this vertex's inputs; the array is
                // allocated lazily, so it stays null when no input has a recorded mask
                INDArray[] inputMasks = null;
                MaskState maskState = null;
                for (int j = 0; j < inputVertices.length; j++) {
                    Pair<INDArray, MaskState> p = map.get(inputVertices[j].getVertexIndex());
                    if (p != null) {
                        if (inputMasks == null) {
                            inputMasks = new INDArray[inputVertices.length];
                        }
                        inputMasks[j] = p.getFirst();
                        // Keep replacing the state until one other than Passthrough is found
                        if (maskState == null || maskState == MaskState.Passthrough) {
                            maskState = p.getSecond();
                        }
                    }
                }
                Pair<INDArray, MaskState> outPair = current.feedForwardMaskArrays(inputMasks, maskState, minibatchSize);
                map.put(topologicalOrder[i], outPair);
            }
        }
    }
    if (labelMaskArrays != null) {
        if (labelMaskArrays.length != numOutputArrays) {
            throw new IllegalArgumentException("Invalid number of label mask arrays");
        }
        for (int i = 0; i < labelMaskArrays.length; i++) {
            if (labelMaskArrays[i] == null) {
                // This output doesn't have a mask, we can skip it.
                continue;
            }
            String outputName = configuration.getNetworkOutputs().get(i);
            GraphVertex v = verticesMap.get(outputName);
            Layer ol = v.getLayer();
            ol.setMaskArray(labelMaskArrays[i]);
        }
    }
}
Also used: Layer (org.deeplearning4j.nn.api.Layer), FrozenLayer (org.deeplearning4j.nn.layers.FrozenLayer), RecurrentLayer (org.deeplearning4j.nn.api.layers.RecurrentLayer), FeedForwardLayer (org.deeplearning4j.nn.conf.layers.FeedForwardLayer), IOutputLayer (org.deeplearning4j.nn.api.layers.IOutputLayer), GraphVertex (org.deeplearning4j.nn.graph.vertex.GraphVertex), INDArray (org.nd4j.linalg.api.ndarray.INDArray), VertexIndices (org.deeplearning4j.nn.graph.vertex.VertexIndices), MaskState (org.deeplearning4j.nn.api.MaskState), Pair (org.deeplearning4j.berkeley.Pair)
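
The Javadoc for setLayerMaskArrays describes masks of shape [miniBatchSize,timeSeriesLength] holding 0 or 1 per time step. The following is a minimal sketch of how such a mask could be built for sequences of different lengths padded to a common length. It uses plain double[][] arrays rather than INDArray so that it runs without ND4J; the class PaddingMaskSketch and the method buildMask are hypothetical names introduced for illustration and are not part of deeplearning4j.

```java
import java.util.Arrays;

// Hypothetical helper, not part of deeplearning4j: illustrates the mask layout only.
final class PaddingMaskSketch {

    /**
     * Builds a [miniBatchSize, timeSeriesLength] mask.
     * Entry [example][t] is 1.0 when time step t holds real data and 0.0 when it is padding.
     */
    static double[][] buildMask(int[] sequenceLengths, int timeSeriesLength) {
        double[][] mask = new double[sequenceLengths.length][timeSeriesLength];
        for (int example = 0; example < sequenceLengths.length; example++) {
            int length = sequenceLengths[example];
            if (length < 0 || length > timeSeriesLength) {
                throw new IllegalArgumentException("Invalid sequence length: " + length);
            }
            // Java zero-initialises the rows, so only the real time steps need marking
            Arrays.fill(mask[example], 0, length, 1.0);
        }
        return mask;
    }

    public static void main(String[] args) {
        // Three examples of lengths 3, 1 and 4, padded to 4 time steps
        double[][] mask = buildMask(new int[] {3, 1, 4}, 4);
        for (double[] row : mask) {
            System.out.println(Arrays.toString(row));
        }
        // prints:
        // [1.0, 1.0, 1.0, 0.0]
        // [1.0, 0.0, 0.0, 0.0]
        // [1.0, 1.0, 1.0, 1.0]
    }
}
```

In a real network, an array like this would first have to be converted to an INDArray, with one mask per network input or output, before being passed as an element of featureMaskArrays or labelMaskArrays.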

Example 2 with MaskState

Use of org.deeplearning4j.nn.api.MaskState in project deeplearning4j by deeplearning4j.

From the class ComposableInputPreProcessor, method feedForwardMaskArray.

@Override
public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
    for (InputPreProcessor preproc : inputPreProcessors) {
        Pair<INDArray, MaskState> p = preproc.feedForwardMaskArray(maskArray, currentMaskState, minibatchSize);
        maskArray = p.getFirst();
        currentMaskState = p.getSecond();
    }
    return new Pair<>(maskArray, currentMaskState);
}
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), InputPreProcessor (org.deeplearning4j.nn.conf.InputPreProcessor), MaskState (org.deeplearning4j.nn.api.MaskState), Pair (org.deeplearning4j.berkeley.Pair)
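
The method above threads a mask and its state through a chain of preprocessors, with each step receiving the output of the previous one. The same pattern can be sketched without any DL4J dependency. Everything in this sketch is a hypothetical stand-in: State mirrors the two MaskState values used on this page (Active and Passthrough), MaskAndState stands in for Pair<INDArray, MaskState>, and the two example steps (repeatTwice, markActive) are invented for illustration and do not correspond to real preprocessors.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical stand-ins, not part of deeplearning4j: illustrates the chaining pattern only.
final class ComposedMaskSketch {

    enum State { Active, Passthrough }

    /** Stand-in for Pair<INDArray, MaskState>. */
    static final class MaskAndState {
        final double[] mask;
        final State state;

        MaskAndState(double[] mask, State state) {
            this.mask = mask;
            this.state = state;
        }
    }

    /** Stand-in for a single preprocessor's mask-handling method. */
    interface MaskStep {
        MaskAndState apply(double[] mask, State state, int minibatchSize);
    }

    /** Applies each step in order, feeding every step the previous step's output. */
    static MaskAndState applyAll(List<MaskStep> steps, double[] mask, State state, int minibatchSize) {
        for (MaskStep step : steps) {
            MaskAndState result = step.apply(mask, state, minibatchSize);
            mask = result.mask;
            state = result.state;
        }
        return new MaskAndState(mask, state);
    }

    /** Invented step: duplicates every mask entry and leaves the state unchanged. */
    static MaskAndState repeatTwice(double[] mask, State state, int minibatchSize) {
        double[] out = new double[mask.length * 2];
        for (int i = 0; i < mask.length; i++) {
            out[2 * i] = mask[i];
            out[2 * i + 1] = mask[i];
        }
        return new MaskAndState(out, state);
    }

    /** Invented step: leaves the mask unchanged and sets the state to Active. */
    static MaskAndState markActive(double[] mask, State state, int minibatchSize) {
        return new MaskAndState(mask, State.Active);
    }

    public static void main(String[] args) {
        List<MaskStep> steps = new ArrayList<>();
        steps.add(ComposedMaskSketch::repeatTwice);
        steps.add(ComposedMaskSketch::markActive);
        MaskAndState out = applyAll(steps, new double[] {1.0, 0.0}, State.Passthrough, 1);
        System.out.println(Arrays.toString(out.mask) + " " + out.state);
        // prints: [1.0, 1.0, 0.0, 0.0] Active
    }
}
```

With an empty list of steps, applyAll returns the input mask and state unchanged, which matches the behaviour of the loop in feedForwardMaskArray when inputPreProcessors is empty.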

Aggregations

Pair (org.deeplearning4j.berkeley.Pair): 2
MaskState (org.deeplearning4j.nn.api.MaskState): 2
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 2
Layer (org.deeplearning4j.nn.api.Layer): 1
IOutputLayer (org.deeplearning4j.nn.api.layers.IOutputLayer): 1
RecurrentLayer (org.deeplearning4j.nn.api.layers.RecurrentLayer): 1
InputPreProcessor (org.deeplearning4j.nn.conf.InputPreProcessor): 1
FeedForwardLayer (org.deeplearning4j.nn.conf.layers.FeedForwardLayer): 1
GraphVertex (org.deeplearning4j.nn.graph.vertex.GraphVertex): 1
VertexIndices (org.deeplearning4j.nn.graph.vertex.VertexIndices): 1
FrozenLayer (org.deeplearning4j.nn.layers.FrozenLayer): 1