Example 6 with InputPreProcessor

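Use of org.deeplearning4j.nn.conf.InputPreProcessor in the deeplearning4j project by deeplearning4j.

Source: the feedForwardMaskArray method of class MultiLayerNetwork. The method propagates a mask array through the network. At each layer index it first applies the configured InputPreProcessor, if there is one, and then the layer itself. Either step may modify the mask, pass it through unchanged, or return null. A null return means no mask (and no MaskState) is passed on to the next step. When no mask is supplied at all, every layer is simply called with a null mask.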

@Override
public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
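    if (maskArray == null) {
        // No mask: notify every layer with a null mask and a null mask state
        for (int i = 0; i < layers.length; i++) {
            layers[i].feedForwardMaskArray(null, null, minibatchSize);
        }
    } else {
        // Propagate the mask through each layer index: first the optional input preprocessor,
        // then the layer itself. A null result from either means no mask is passed on.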
        for (int i = 0; i < layers.length; i++) {
            InputPreProcessor preProcessor = getLayerWiseConfigurations().getInputPreProcess(i);
            if (preProcessor != null) {
                Pair<INDArray, MaskState> p = preProcessor.feedForwardMaskArray(maskArray, currentMaskState, minibatchSize);
                if (p != null) {
                    maskArray = p.getFirst();
                    currentMaskState = p.getSecond();
                } else {
                    maskArray = null;
                    currentMaskState = null;
                }
            }
            Pair<INDArray, MaskState> p = layers[i].feedForwardMaskArray(maskArray, currentMaskState, minibatchSize);
            if (p != null) {
                maskArray = p.getFirst();
                currentMaskState = p.getSecond();
            } else {
                maskArray = null;
                currentMaskState = null;
            }
        }
    }
    return new Pair<>(maskArray, currentMaskState);
}
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), InputPreProcessor (org.deeplearning4j.nn.conf.InputPreProcessor), Pair (org.deeplearning4j.berkeley.Pair)
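To make the propagation order concrete without a DL4J dependency, here is a minimal, self-contained sketch. Everything in it is hypothetical. `Mask` stands in for the (INDArray, MaskState) pair, and `MaskStep` stands in for the feedForwardMaskArray contract shared by preprocessors and layers. `PASS_THROUGH`, `DROP` and `FLATTEN` are illustrative steps, not DL4J classes. The `propagate` method mirrors the loop in the method above: the optional preprocessor runs first, then the layer, and a null return means the following steps receive no mask.

```java
import java.util.Arrays;
import java.util.List;

public class MaskPropagationSketch {

    // Hypothetical stand-in for DL4J's (INDArray, MaskState) pair; not a DL4J type.
    static final class Mask {
        final double[][] values; // mask entries: 1 = valid, 0 = padded
        final String state;      // stands in for MaskState, e.g. "Active"

        Mask(double[][] values, String state) {
            this.values = values;
            this.state = state;
        }
    }

    // Hypothetical stand-in for the feedForwardMaskArray contract of a preprocessor or layer:
    // return the mask for the next step, or null when there is no mask to pass on.
    interface MaskStep {
        Mask feedForwardMask(Mask mask, int minibatchSize);
    }

    // Layer-like step that passes the mask through unchanged.
    static final MaskStep PASS_THROUGH = (mask, minibatchSize) -> mask;

    // Layer-like step that consumes the mask and passes nothing on.
    static final MaskStep DROP = (mask, minibatchSize) -> null;

    // Preprocessor-like step: reshape a [minibatch][timeSteps] mask into a
    // [minibatch * timeSteps][1] column, row index = example * timeSteps + step.
    static final MaskStep FLATTEN = (mask, minibatchSize) -> {
        if (mask == null) {
            return null;
        }
        int steps = mask.values[0].length;
        double[][] column = new double[minibatchSize * steps][1];
        for (int i = 0; i < minibatchSize; i++) {
            for (int t = 0; t < steps; t++) {
                column[i * steps + t][0] = mask.values[i][t];
            }
        }
        return new Mask(column, mask.state);
    };

    // For each index: apply the optional preprocessor, then the layer.
    // A null result means the following steps receive a null mask.
    static Mask propagate(Mask mask, List<MaskStep> preProcessors, List<MaskStep> layers, int minibatchSize) {
        if (mask == null) {
            // No mask at all: still call every layer, with a null mask.
            for (MaskStep layer : layers) {
                layer.feedForwardMask(null, minibatchSize);
            }
            return null;
        }
        for (int i = 0; i < layers.size(); i++) {
            MaskStep preProcessor = preProcessors.get(i);
            if (preProcessor != null) {
                mask = preProcessor.feedForwardMask(mask, minibatchSize);
            }
            mask = layers.get(i).feedForwardMask(mask, minibatchSize);
        }
        return mask;
    }

    public static void main(String[] args) {
        // Two examples, three time steps; the second example is padded at its last step.
        Mask input = new Mask(new double[][] {{1, 1, 1}, {1, 1, 0}}, "Active");

        // Recurrent-style layer, then a flattening preprocessor in front of a dense-style layer.
        Mask flattened = propagate(input,
                Arrays.asList(null, FLATTEN),
                Arrays.asList(PASS_THROUGH, PASS_THROUGH), 2);
        System.out.println(Arrays.deepToString(flattened.values) + " " + flattened.state);

        // The first layer drops the mask, so later steps only ever see null.
        Mask dropped = propagate(input,
                Arrays.asList(null, FLATTEN),
                Arrays.asList(DROP, PASS_THROUGH), 2);
        System.out.println(dropped == null);
    }
}
```

Running `main` prints the six-row column mask `[[1.0], [1.0], [1.0], [1.0], [1.0], [0.0]] Active`, then `true`, because `DROP` leaves nothing to propagate. The example-major flattening order is an arbitrary choice for this sketch. A real preprocessor defines its own layout.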