
Example 1 with IActivation

use of org.nd4j.linalg.activations.IActivation in project deeplearning4j by deeplearning4j.

the class GradientCheckUtil method checkGradients.

/**
     * Check backprop gradients for a MultiLayerNetwork.
     * @param mln MultiLayerNetwork to test. This must be initialized.
     * @param epsilon Usually on the order of 1e-4 or so.
     * @param maxRelError Maximum relative error. Usually < 1e-5 or so, though this may need to be larger for deep networks or those with nonlinear activation functions
     * @param minAbsoluteError Minimum absolute error to cause a failure. Numerical gradients can be non-zero due to precision issues.
     *                         For example, 0.0 vs. 1e-18: relative error is 1.0, but not really a failure
     * @param print Whether to print full pass/failure details for each parameter gradient
     * @param exitOnFirstError If true: return upon first failure. If false: continue checking even if
     *  one parameter gradient has failed. Typically use false for debugging, true for unit tests.
     * @param input Input array to use for forward pass. May be mini-batch data.
     * @param labels Labels/targets to use to calculate backprop gradient. May be mini-batch data.
     * @return true if the gradient checks pass, false otherwise.
     */
public static boolean checkGradients(MultiLayerNetwork mln, double epsilon, double maxRelError, double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray input, INDArray labels) {
    //Basic sanity checks on input:
    if (epsilon <= 0.0 || epsilon > 0.1)
        throw new IllegalArgumentException("Invalid epsilon: expect epsilon in range (0,0.1], usually 1e-4 or so");
    if (maxRelError <= 0.0 || maxRelError > 0.25)
        throw new IllegalArgumentException("Invalid maxRelativeError: " + maxRelError);
    if (!(mln.getOutputLayer() instanceof IOutputLayer))
        throw new IllegalArgumentException("Cannot check backprop gradients without OutputLayer");
    //Check network configuration:
    int layerCount = 0;
    for (NeuralNetConfiguration n : mln.getLayerWiseConfigurations().getConfs()) {
        org.deeplearning4j.nn.conf.Updater u = n.getLayer().getUpdater();
        if (u == org.deeplearning4j.nn.conf.Updater.SGD) {
            //Must have LR of 1.0
            double lr = n.getLayer().getLearningRate();
            if (lr != 1.0) {
                throw new IllegalStateException("When using SGD updater, must also use lr=1.0 for layer " + layerCount + "; got " + u + " with lr=" + lr + " for layer \"" + n.getLayer().getLayerName() + "\"");
            }
        } else if (u != org.deeplearning4j.nn.conf.Updater.NONE) {
            throw new IllegalStateException("Must have Updater.NONE (or SGD + lr=1.0) for layer " + layerCount + "; got " + u);
        }
        double dropout = n.getLayer().getDropOut();
        if (n.isUseRegularization() && dropout != 0.0) {
            throw new IllegalStateException("Must have dropout == 0.0 for gradient checks - got dropout = " + dropout + " for layer " + layerCount);
        }
        IActivation activation = n.getLayer().getActivationFn();
        if (activation != null) {
            if (!VALID_ACTIVATION_FUNCTIONS.contains(activation.getClass())) {
                log.warn("Layer " + layerCount + " is possibly using an unsuitable activation function: " + activation.getClass() + ". Activation functions for gradient checks must be smooth (like sigmoid, tanh, softmax) and not " + "contain discontinuities like ReLU or LeakyReLU (these may cause spurious failures)");
            }
        }
        layerCount++; //Advance the layer index so the diagnostics above reference the correct layer
    }
    mln.setInput(input);
    mln.setLabels(labels);
    mln.computeGradientAndScore();
    Pair<Gradient, Double> gradAndScore = mln.gradientAndScore();
    Updater updater = UpdaterCreator.getUpdater(mln);
    updater.update(mln, gradAndScore.getFirst(), 0, mln.batchSize());
    //need dup: gradients are a *view* of the full gradient array (which will change every time backprop is done)
    INDArray gradientToCheck = gradAndScore.getFirst().gradient().dup();
    //need dup: params are a *view* of full parameters
    INDArray originalParams = mln.params().dup();
    int nParams = originalParams.length();
    Map<String, INDArray> paramTable = mln.paramTable();
    List<String> paramNames = new ArrayList<>(paramTable.keySet());
    int[] paramEnds = new int[paramNames.size()];
    paramEnds[0] = paramTable.get(paramNames.get(0)).length();
    for (int i = 1; i < paramEnds.length; i++) {
        paramEnds[i] = paramEnds[i - 1] + paramTable.get(paramNames.get(i)).length();
    }
    int totalNFailures = 0;
    double maxError = 0.0;
    DataSet ds = new DataSet(input, labels);
    int currParamNameIdx = 0;
    //Assumption here: params is a view that we can modify in-place
    INDArray params = mln.params();
    for (int i = 0; i < nParams; i++) {
        //Get param name
        if (i >= paramEnds[currParamNameIdx]) {
            currParamNameIdx++;
        }
        String paramName = paramNames.get(currParamNameIdx);
        //(w+epsilon): Do forward pass and score
        double origValue = params.getDouble(i);
        params.putScalar(i, origValue + epsilon);
        double scorePlus = mln.score(ds, true);
        //(w-epsilon): Do forward pass and score
        params.putScalar(i, origValue - epsilon);
        double scoreMinus = mln.score(ds, true);
        //Reset original param value
        params.putScalar(i, origValue);
        //Calculate numerical parameter gradient:
        double scoreDelta = scorePlus - scoreMinus;
        double numericalGradient = scoreDelta / (2 * epsilon);
        if (Double.isNaN(numericalGradient))
            throw new IllegalStateException("Numerical gradient was NaN for parameter " + i + " of " + nParams);
        double backpropGradient = gradientToCheck.getDouble(i);
        //http://cs231n.github.io/neural-networks-3/#gradcheck
        //Relative error with symmetric denominator: |backpropGradient - numericalGradient| / (|numericalGradient| + |backpropGradient|)
        double relError = Math.abs(backpropGradient - numericalGradient) / (Math.abs(numericalGradient) + Math.abs(backpropGradient));
        if (backpropGradient == 0.0 && numericalGradient == 0.0)
            //Edge case: e.g., RNNs with a time series length of 1
            relError = 0.0;
        if (relError > maxError)
            maxError = relError;
        if (relError > maxRelError || Double.isNaN(relError)) {
            double absError = Math.abs(backpropGradient - numericalGradient);
            if (absError < minAbsoluteError) {
                log.info("Param " + i + " (" + paramName + ") passed: grad= " + backpropGradient + ", numericalGrad= " + numericalGradient + ", relError= " + relError + "; absolute error = " + absError + " < minAbsoluteError = " + minAbsoluteError);
            } else {
                if (print)
                    log.info("Param " + i + " (" + paramName + ") FAILED: grad= " + backpropGradient + ", numericalGrad= " + numericalGradient + ", relError= " + relError + ", scorePlus=" + scorePlus + ", scoreMinus= " + scoreMinus);
                if (exitOnFirstError)
                    return false;
                totalNFailures++;
            }
        } else if (print) {
            log.info("Param " + i + " (" + paramName + ") passed: grad= " + backpropGradient + ", numericalGrad= " + numericalGradient + ", relError= " + relError);
        }
    }
    if (print) {
        int nPass = nParams - totalNFailures;
        log.info("GradientCheckUtil.checkGradients(): " + nParams + " params checked, " + nPass + " passed, " + totalNFailures + " failed. Largest relative error = " + maxError);
    }
    return totalNFailures == 0;
}
Also used : Gradient(org.deeplearning4j.nn.gradient.Gradient) DataSet(org.nd4j.linalg.dataset.DataSet) MultiDataSet(org.nd4j.linalg.dataset.MultiDataSet) ArrayList(java.util.ArrayList) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) IActivation(org.nd4j.linalg.activations.IActivation) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ComputationGraphUpdater(org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater) Updater(org.deeplearning4j.nn.api.Updater) IOutputLayer(org.deeplearning4j.nn.api.layers.IOutputLayer)
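
A minimal, hypothetical sketch of how checkGradients might be driven from a unit test, illustrating the constraints the method enforces (Updater.NONE or SGD with lr=1.0, no dropout or regularization, smooth activation functions). The network, data, and threshold values are illustrative only; builder method names follow the 0.7.x-era DL4J API and may differ slightly between versions, and gradient checks are normally run with ND4J configured for double precision.

import org.deeplearning4j.gradientcheck.GradientCheckUtil;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class GradientCheckSketch {
    public static void main(String[] args) {
        //Network must satisfy the checks above: Updater.NONE (or SGD + lr=1.0), no dropout, smooth activations
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .updater(Updater.NONE)
                .list()
                .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).activation(Activation.TANH).build())
                .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX).nIn(3).nOut(2).build())
                .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        //Small random mini-batch with one-hot labels
        INDArray input = Nd4j.rand(10, 4);
        INDArray labels = Nd4j.zeros(10, 2);
        for (int i = 0; i < 10; i++) {
            labels.putScalar(new int[] {i, i % 2}, 1.0);
        }

        //Typical magnitudes for the thresholds: epsilon ~1e-6, maxRelError ~1e-3, minAbsoluteError ~1e-8
        boolean ok = GradientCheckUtil.checkGradients(net, 1e-6, 1e-3, 1e-8, true, false, input, labels);
        System.out.println("Gradient check passed: " + ok);
    }
}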

Example 2 with IActivation

use of org.nd4j.linalg.activations.IActivation in project deeplearning4j by deeplearning4j.

the class MultiLayerConfiguration method fromJson.

/**
     * Create a neural net configuration (MultiLayerConfiguration) from its JSON representation
     * @param json the JSON string to deserialize
     * @return {@link MultiLayerConfiguration}
     */
public static MultiLayerConfiguration fromJson(String json) {
    MultiLayerConfiguration conf;
    ObjectMapper mapper = NeuralNetConfiguration.mapper();
    try {
        conf = mapper.readValue(json, MultiLayerConfiguration.class);
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    //To maintain backward compatibility after loss function refactoring (configs generated with v0.5.0 or earlier)
    // Previously: enumeration used for loss functions. Now: use classes
    // In the past, only an OutputLayer or RnnOutputLayer could have used these enums
    int layerCount = 0;
    JsonNode confs = null;
    for (NeuralNetConfiguration nnc : conf.getConfs()) {
        Layer l = nnc.getLayer();
        if (l instanceof BaseOutputLayer && ((BaseOutputLayer) l).getLossFn() == null) {
            //lossFn field null -> may be an old config format, with the lossFunction field holding the enum value
            //if so, try walking the JSON graph to extract the appropriate enum value
            BaseOutputLayer ol = (BaseOutputLayer) l;
            try {
                JsonNode jsonNode = mapper.readTree(json);
                if (confs == null) {
                    confs = jsonNode.get("confs");
                }
                if (confs instanceof ArrayNode) {
                    ArrayNode layerConfs = (ArrayNode) confs;
                    JsonNode outputLayerNNCNode = layerConfs.get(layerCount);
                    if (outputLayerNNCNode == null)
                        //Should never happen...
                        return conf;
                    JsonNode outputLayerNode = outputLayerNNCNode.get("layer");
                    JsonNode lossFunctionNode = null;
                    if (outputLayerNode.has("output")) {
                        lossFunctionNode = outputLayerNode.get("output").get("lossFunction");
                    } else if (outputLayerNode.has("rnnoutput")) {
                        lossFunctionNode = outputLayerNode.get("rnnoutput").get("lossFunction");
                    }
                    if (lossFunctionNode != null) {
                        String lossFunctionEnumStr = lossFunctionNode.asText();
                        LossFunctions.LossFunction lossFunction = null;
                        try {
                            lossFunction = LossFunctions.LossFunction.valueOf(lossFunctionEnumStr);
                        } catch (Exception e) {
                            log.warn("OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON", e);
                        }
                        if (lossFunction != null) {
                            switch(lossFunction) {
                                case MSE:
                                    ol.setLossFn(new LossMSE());
                                    break;
                                case XENT:
                                    ol.setLossFn(new LossBinaryXENT());
                                    break;
                                case NEGATIVELOGLIKELIHOOD:
                                    ol.setLossFn(new LossNegativeLogLikelihood());
                                    break;
                                case MCXENT:
                                    ol.setLossFn(new LossMCXENT());
                                    break;
                                //Remaining: TODO
                                case EXPLL:
                                case RMSE_XENT:
                                case SQUARED_LOSS:
                                case RECONSTRUCTION_CROSSENTROPY:
                                case CUSTOM:
                                default:
                                    log.warn("OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not set loss function for {}", lossFunction);
                                    break;
                            }
                        }
                    }
                } else {
                    log.warn("OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON: layer 'confs' field is not an ArrayNode (is: {})", (confs != null ? confs.getClass() : null));
                }
            } catch (IOException e) {
                log.warn("OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON", e);
                break;
            }
        }
        //Try to load the old format if necessary, and create the appropriate IActivation instance
        if (l.getActivationFn() == null) {
            try {
                JsonNode jsonNode = mapper.readTree(json);
                if (confs == null) {
                    confs = jsonNode.get("confs");
                }
                if (confs instanceof ArrayNode) {
                    ArrayNode layerConfs = (ArrayNode) confs;
                    JsonNode outputLayerNNCNode = layerConfs.get(layerCount);
                    if (outputLayerNNCNode == null)
                        //Should never happen...
                        return conf;
                    JsonNode layerWrapperNode = outputLayerNNCNode.get("layer");
                    if (layerWrapperNode == null || layerWrapperNode.size() != 1) {
                        continue;
                    }
                    //Should only have 1 element: "dense", "output", etc
                    JsonNode layerNode = layerWrapperNode.elements().next();
                    JsonNode activationFunction = layerNode.get("activationFunction");
                    if (activationFunction != null) {
                        IActivation ia = Activation.fromString(activationFunction.asText()).getActivationFunction();
                        l.setActivationFn(ia);
                    }
                }
            } catch (IOException e) {
                log.warn("Layer with null ActivationFn field or pre-0.7.2 activation function detected: could not parse JSON", e);
            }
        }
        layerCount++;
    }
    return conf;
}
Also used : LossNegativeLogLikelihood(org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood) LossMCXENT(org.nd4j.linalg.lossfunctions.impl.LossMCXENT) JsonNode(org.nd4j.shade.jackson.databind.JsonNode) IOException(java.io.IOException) IActivation(org.nd4j.linalg.activations.IActivation) IOException(java.io.IOException) LossFunctions(org.nd4j.linalg.lossfunctions.LossFunctions) LossMSE(org.nd4j.linalg.lossfunctions.impl.LossMSE) ArrayNode(org.nd4j.shade.jackson.databind.node.ArrayNode) LossBinaryXENT(org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT) ObjectMapper(org.nd4j.shade.jackson.databind.ObjectMapper)
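
As a usage note, here is a hedged sketch of the round trip this method supports. For configurations produced by a current version, fromJson is a plain Jackson deserialization; the legacy branches above only run when lossFn or activationFn deserialize as null (pre-0.6.0 / pre-0.7.2 JSON). The network below is illustrative and builder details may vary by version.

//Hypothetical round-trip sketch (illustrative single-layer network)
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .list()
        .layer(0, new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                .activation(Activation.IDENTITY).nIn(3).nOut(1).build())
        .build();

String json = conf.toJson();                                                  //serialize to JSON
MultiLayerConfiguration restored = MultiLayerConfiguration.fromJson(json);    //deserialize

//For current-format JSON the two should serialize identically; the legacy handling above is a no-op
boolean sameJson = json.equals(restored.toJson());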

Example 3 with IActivation

use of org.nd4j.linalg.activations.IActivation in project deeplearning4j by deeplearning4j.

the class VariationalAutoencoder method reconstructionLogProbability.

/**
     * Return the log reconstruction probability given the specified number of samples.<br>
     * See {@link #reconstructionLogProbability(INDArray, int)} for more details
     *
     * @param data       The data to calculate the log reconstruction probability
     * @param numSamples Number of samples on which to base the reconstruction probability
     * @return Column vector of reconstruction log probabilities for each example (shape: [numExamples,1])
     */
public INDArray reconstructionLogProbability(INDArray data, int numSamples) {
    if (numSamples <= 0) {
        throw new IllegalArgumentException("Invalid input: numSamples must be > 0. Got: " + numSamples);
    }
    if (reconstructionDistribution instanceof LossFunctionWrapper) {
        throw new UnsupportedOperationException("Cannot calculate reconstruction log probability when using " + "a LossFunction (via LossFunctionWrapper) instead of a ReconstructionDistribution: ILossFunction " + "instances are not in general probabilistic, hence it is not possible to calculate reconstruction probability");
    }
    //Forward pass through the encoder and mean for P(Z|X)
    setInput(data);
    VAEFwdHelper fwd = doForward(true, true);
    IActivation afn = conf().getLayer().getActivationFn();
    //Forward pass through logStd^2 for P(Z|X)
    INDArray pzxLogStd2W = params.get(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_W);
    INDArray pzxLogStd2b = params.get(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_B);
    INDArray meanZ = fwd.pzxMeanPreOut;
    INDArray logStdev2Z = fwd.encoderActivations[fwd.encoderActivations.length - 1].mmul(pzxLogStd2W).addiRowVector(pzxLogStd2b);
    pzxActivationFn.getActivation(meanZ, false);
    pzxActivationFn.getActivation(logStdev2Z, false);
    INDArray pzxSigma = Transforms.exp(logStdev2Z, false);
    Transforms.sqrt(pzxSigma, false);
    int minibatch = input.size(0);
    int size = fwd.pzxMeanPreOut.size(1);
    INDArray pxzw = params.get(VariationalAutoencoderParamInitializer.PXZ_W);
    INDArray pxzb = params.get(VariationalAutoencoderParamInitializer.PXZ_B);
    INDArray[] decoderWeights = new INDArray[decoderLayerSizes.length];
    INDArray[] decoderBiases = new INDArray[decoderLayerSizes.length];
    for (int i = 0; i < decoderLayerSizes.length; i++) {
        String wKey = "d" + i + WEIGHT_KEY_SUFFIX;
        String bKey = "d" + i + BIAS_KEY_SUFFIX;
        decoderWeights[i] = params.get(wKey);
        decoderBiases[i] = params.get(bKey);
    }
    INDArray sumReconstructionNegLogProbability = null;
    for (int i = 0; i < numSamples; i++) {
        INDArray e = Nd4j.randn(minibatch, size);
        //z = mu + sigma * e, with e ~ N(0,1)
        INDArray z = e.muli(pzxSigma).addi(meanZ);
        //Do forward pass through decoder
        int nDecoderLayers = decoderLayerSizes.length;
        INDArray currentActivations = z;
        for (int j = 0; j < nDecoderLayers; j++) {
            currentActivations = currentActivations.mmul(decoderWeights[j]).addiRowVector(decoderBiases[j]);
            afn.getActivation(currentActivations, false);
        }
        //And calculate reconstruction distribution preOut
        INDArray pxzDistributionPreOut = currentActivations.mmul(pxzw).addiRowVector(pxzb);
        if (i == 0) {
            sumReconstructionNegLogProbability = reconstructionDistribution.exampleNegLogProbability(data, pxzDistributionPreOut);
        } else {
            sumReconstructionNegLogProbability.addi(reconstructionDistribution.exampleNegLogProbability(data, pxzDistributionPreOut));
        }
    }
    setInput(null);
    return sumReconstructionNegLogProbability.divi(-numSamples);
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) IActivation(org.nd4j.linalg.activations.IActivation) LossFunctionWrapper(org.deeplearning4j.nn.conf.layers.variational.LossFunctionWrapper)
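
A hedged usage sketch: reconstructionLogProbability is defined on the VariationalAutoencoder layer implementation, so calling it from a MultiLayerNetwork requires fetching and casting the layer. Here net, features, and the layer index are hypothetical, and the layer is assumed to be configured with a probabilistic ReconstructionDistribution (e.g. a Gaussian or Bernoulli reconstruction distribution) rather than a LossFunctionWrapper.

//Hypothetical: 'net' is a trained MultiLayerNetwork whose layer 0 is a variational autoencoder;
//'features' is an INDArray of shape [numExamples, nIn]
org.deeplearning4j.nn.layers.variational.VariationalAutoencoder vae =
        (org.deeplearning4j.nn.layers.variational.VariationalAutoencoder) net.getLayer(0);

int numSamples = 16;                                                          //more samples: lower-variance estimate, slower call
INDArray logProb = vae.reconstructionLogProbability(features, numSamples);    //shape: [numExamples, 1]
double meanLogProb = logProb.meanNumber().doubleValue();                      //average over the mini-batch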

Example 4 with IActivation

use of org.nd4j.linalg.activations.IActivation in project deeplearning4j by deeplearning4j.

the class VariationalAutoencoder method decodeGivenLatentSpaceValues.

private INDArray decodeGivenLatentSpaceValues(INDArray latentSpaceValues) {
    if (latentSpaceValues.size(1) != params.get(VariationalAutoencoderParamInitializer.PZX_MEAN_W).size(1)) {
        throw new IllegalArgumentException("Invalid latent space values: expected size " + params.get(VariationalAutoencoderParamInitializer.PZX_MEAN_W).size(1) + ", got size (dimension 1) = " + latentSpaceValues.size(1));
    }
    //Do forward pass through decoder
    int nDecoderLayers = decoderLayerSizes.length;
    INDArray currentActivations = latentSpaceValues;
    IActivation afn = conf().getLayer().getActivationFn();
    for (int i = 0; i < nDecoderLayers; i++) {
        String wKey = "d" + i + WEIGHT_KEY_SUFFIX;
        String bKey = "d" + i + BIAS_KEY_SUFFIX;
        INDArray w = params.get(wKey);
        INDArray b = params.get(bKey);
        currentActivations = currentActivations.mmul(w).addiRowVector(b);
        afn.getActivation(currentActivations, false);
    }
    INDArray pxzw = params.get(VariationalAutoencoderParamInitializer.PXZ_W);
    INDArray pxzb = params.get(VariationalAutoencoderParamInitializer.PXZ_B);
    return currentActivations.mmul(pxzw).addiRowVector(pxzb);
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) IActivation(org.nd4j.linalg.activations.IActivation)
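
In equation form, the decoder forward pass implemented here (and in the sampling loop of reconstructionLogProbability above) is, writing \phi for the layer activation function afn and L for decoderLayerSizes.length:

h_0 = z, \qquad h_{j+1} = \phi\left(h_j W_j + b_j\right) \quad (j = 0, \dots, L-1), \qquad \mathrm{preOut} = h_L W_{pxz} + b_{pxz}

No activation is applied to preOut; interpreting it (e.g. as the parameters of the reconstruction distribution) is left to the configured ReconstructionDistribution.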

Example 5 with IActivation

use of org.nd4j.linalg.activations.IActivation in project deeplearning4j by deeplearning4j.

the class LSTMHelpers method backpropGradientHelper.

public static Pair<Gradient, INDArray> backpropGradientHelper(final NeuralNetConfiguration conf,
        final IActivation gateActivationFn, final INDArray input,
        final INDArray recurrentWeights, //Shape: [hiddenLayerSize,4*hiddenLayerSize+3]; order: [wI,wF,wO,wG,wFF,wOO,wGG]
        final INDArray inputWeights, //Shape: [n^(L-1),4*hiddenLayerSize]; order: [wi,wf,wo,wg]
        final INDArray epsilon, final boolean truncatedBPTT, final int tbpttBackwardLength,
        final FwdPassReturn fwdPass, final boolean forwards, final String inputWeightKey,
        final String recurrentWeightKey, final String biasWeightKey, final Map<String, INDArray> gradientViews,
        INDArray maskArray) { //Input mask: should only be used with bidirectional RNNs + variable length time series
    //Expect errors to have shape: [miniBatchSize,n^(L+1),timeSeriesLength]
    //i.e., n^L
    int hiddenLayerSize = recurrentWeights.size(0);
    //n^(L-1)
    int prevLayerSize = inputWeights.size(0);
    int miniBatchSize = epsilon.size(0);
    //Edge case: T=1 may have shape [miniBatchSize,n^(L+1)], equiv. to [miniBatchSize,n^(L+1),1]
    boolean is2dInput = epsilon.rank() < 3;
    int timeSeriesLength = (is2dInput ? 1 : epsilon.size(2));
    INDArray wFFTranspose = recurrentWeights.get(NDArrayIndex.all(), point(4 * hiddenLayerSize)).transpose();
    INDArray wOOTranspose = recurrentWeights.get(NDArrayIndex.all(), point(4 * hiddenLayerSize + 1)).transpose();
    INDArray wGGTranspose = recurrentWeights.get(NDArrayIndex.all(), point(4 * hiddenLayerSize + 2)).transpose();
    INDArray wIFOG = recurrentWeights.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 4 * hiddenLayerSize));
    //F order here so that content for time steps are together
    //i.e., what would be W^L*(delta^L)^T. Shape: [m,n^(L-1),T]
    INDArray epsilonNext = Nd4j.create(new int[] { miniBatchSize, prevLayerSize, timeSeriesLength }, 'f');
    INDArray nablaCellStateNext = null;
    INDArray deltaifogNext = Nd4j.create(new int[] { miniBatchSize, 4 * hiddenLayerSize }, 'f');
    INDArray deltaiNext = deltaifogNext.get(NDArrayIndex.all(), NDArrayIndex.interval(0, hiddenLayerSize));
    INDArray deltafNext = deltaifogNext.get(NDArrayIndex.all(), NDArrayIndex.interval(hiddenLayerSize, 2 * hiddenLayerSize));
    INDArray deltaoNext = deltaifogNext.get(NDArrayIndex.all(), NDArrayIndex.interval(2 * hiddenLayerSize, 3 * hiddenLayerSize));
    INDArray deltagNext = deltaifogNext.get(NDArrayIndex.all(), NDArrayIndex.interval(3 * hiddenLayerSize, 4 * hiddenLayerSize));
    Level1 l1BLAS = Nd4j.getBlasWrapper().level1();
    int endIdx = 0;
    if (truncatedBPTT) {
        endIdx = Math.max(0, timeSeriesLength - tbpttBackwardLength);
    }
    //Get gradients. Note that we have to manually zero these, as they might not be initialized (or still has data from last iteration)
    //Also note that they are in f order (as per param initializer) so can be used in gemm etc
    INDArray iwGradientsOut = gradientViews.get(inputWeightKey);
    //Order: {I,F,O,G,FF,OO,GG}
    INDArray rwGradientsOut = gradientViews.get(recurrentWeightKey);
    INDArray bGradientsOut = gradientViews.get(biasWeightKey);
    iwGradientsOut.assign(0);
    rwGradientsOut.assign(0);
    bGradientsOut.assign(0);
    INDArray rwGradientsIFOG = rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 4 * hiddenLayerSize));
    INDArray rwGradientsFF = rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.point(4 * hiddenLayerSize));
    INDArray rwGradientsOO = rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.point(4 * hiddenLayerSize + 1));
    INDArray rwGradientsGG = rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.point(4 * hiddenLayerSize + 2));
    boolean sigmoidGates = gateActivationFn instanceof ActivationSigmoid;
    IActivation afn = conf.getLayer().getActivationFn();
    INDArray timeStepMaskColumn = null;
    for (int iTimeIndex = timeSeriesLength - 1; iTimeIndex >= endIdx; iTimeIndex--) {
        int time = iTimeIndex;
        int inext = 1;
        if (!forwards) {
            time = timeSeriesLength - iTimeIndex - 1;
            inext = -1;
        }
        //First: calculate the components of nablaCellState that rely on the next time step's deltas, so we can overwrite the deltas
        INDArray nablaCellState;
        if (iTimeIndex != timeSeriesLength - 1) {
            nablaCellState = deltafNext.dup('f').muliRowVector(wFFTranspose);
            l1BLAS.axpy(nablaCellState.length(), 1.0, deltagNext.dup('f').muliRowVector(wGGTranspose), nablaCellState);
        } else {
            nablaCellState = Nd4j.create(new int[] { miniBatchSize, hiddenLayerSize }, 'f');
        }
        INDArray prevMemCellState = (iTimeIndex == 0 ? null : fwdPass.memCellState[time - inext]);
        INDArray prevHiddenUnitActivation = (iTimeIndex == 0 ? null : fwdPass.fwdPassOutputAsArrays[time - inext]);
        INDArray currMemCellState = fwdPass.memCellState[time];
        //LSTM unit output errors (dL/d(a_out)); not to be confused with \delta=dL/d(z_out)
        //(w^{L+1}*(delta^{(L+1)t})^T)^T or equiv.
        INDArray epsilonSlice = (is2dInput ? epsilon : epsilon.tensorAlongDimension(time, 1, 0));
        //Shape: [m,n^L]
        INDArray nablaOut = Shape.toOffsetZeroCopy(epsilonSlice, 'f');
        if (iTimeIndex != timeSeriesLength - 1) {
            //if t == timeSeriesLength-1 then deltaiNext etc are zeros
            Nd4j.gemm(deltaifogNext, wIFOG, nablaOut, false, true, 1.0, 1.0);
        }
        //Output gate deltas:
        INDArray sigmahOfS = fwdPass.memCellActivations[time];
        INDArray ao = fwdPass.oa[time];
        //Normally would use zo.dup() in above line, but won't be using zo again (for this time step). Ditto for zf, zg, zi
        INDArray deltao = deltaoNext;
        Nd4j.getExecutioner().exec(new MulOp(nablaOut, sigmahOfS, deltao));
        if (sigmoidGates) {
            //Equivalent to sigmoid deriv on zo
            INDArray sigmaoPrimeOfZo = Nd4j.getExecutioner().execAndReturn(new TimesOneMinus(ao.dup('f')));
            deltao.muli(sigmaoPrimeOfZo);
        } else {
            //Deltao needs to be modified in-place
            deltao.assign(gateActivationFn.backprop(fwdPass.oz[time], deltao).getFirst());
        //TODO: optimize (no assign)
        }
        //Memory cell error:
        //TODO activation functions with params
        INDArray temp = afn.backprop(currMemCellState.dup('f'), ao.muli(nablaOut)).getFirst();
        l1BLAS.axpy(nablaCellState.length(), 1.0, temp, nablaCellState);
        INDArray deltaMulRowWOO = deltao.dup('f').muliRowVector(wOOTranspose);
        //nablaCellState.addi(deltao.mulRowVector(wOOTranspose));
        l1BLAS.axpy(nablaCellState.length(), 1.0, deltaMulRowWOO, nablaCellState);
        if (iTimeIndex != timeSeriesLength - 1) {
            INDArray nextForgetGateAs = fwdPass.fa[time + inext];
            int length = nablaCellState.length();
            //nablaCellState.addi(nextForgetGateAs.mul(nablaCellStateNext))
            l1BLAS.axpy(length, 1.0, nextForgetGateAs.muli(nablaCellStateNext), nablaCellState);
        }
        //Store for use in next iteration
        nablaCellStateNext = nablaCellState;
        //Forget gate delta:
        INDArray af = fwdPass.fa[time];
        INDArray deltaf = null;
        if (iTimeIndex > 0) {
            deltaf = deltafNext;
            if (sigmoidGates) {
                Nd4j.getExecutioner().exec(new TimesOneMinus(af, deltaf));
                deltaf.muli(nablaCellState);
                deltaf.muli(prevMemCellState);
            } else {
                INDArray temp2 = nablaCellState.mul(prevMemCellState);
                //deltaf needs to be modified in-place
                deltaf.assign(gateActivationFn.backprop(fwdPass.fz[time].dup('f'), temp2).getFirst());
            //TODO activation functions with params
            }
        }
        //Shape: [m,n^L]
        //Input modulation gate delta:
        INDArray ag = fwdPass.ga[time];
        INDArray ai = fwdPass.ia[time];
        INDArray deltag = deltagNext;
        if (sigmoidGates) {
            //Equivalent to sigmoid deriv on zg
            Nd4j.getExecutioner().exec(new TimesOneMinus(ag, deltag));
            deltag.muli(ai);
            deltag.muli(nablaCellState);
        } else {
            INDArray temp2 = Nd4j.getExecutioner().execAndReturn(new MulOp(ai, nablaCellState, Nd4j.createUninitialized(ai.shape(), 'f')));
            deltag.assign(gateActivationFn.backprop(fwdPass.gz[time], temp2).getFirst());
        //TODO activation functions with params; optimize (no assign)
        }
        //Shape: [m,n^L]
        //Network input delta:
        INDArray zi = fwdPass.iz[time];
        INDArray deltai = deltaiNext;
        temp = Nd4j.getExecutioner().execAndReturn(new MulOp(ag, nablaCellState, Nd4j.createUninitialized(deltai.shape(), 'f')));
        deltai.assign(afn.backprop(zi, temp).getFirst());
        //Handle masking
        if (maskArray != null) {
            //Mask array is present: bidirectional RNN -> need to zero out these errors to avoid using errors from a masked time step
            // to calculate the parameter gradients.  Mask array has shape [minibatch, timeSeriesLength] -> get column(this time step)
            timeStepMaskColumn = maskArray.getColumn(time);
            deltaifogNext.muliColumnVector(timeStepMaskColumn);
        //Later, the deltaifogNext is used to calculate: input weight gradients, recurrent weight gradients, bias gradients
        }
        INDArray prevLayerActivationSlice = Shape.toMmulCompatible(is2dInput ? input : input.tensorAlongDimension(time, 1, 0));
        if (iTimeIndex > 0) {
            //Again, deltaifog_current == deltaifogNext at this point... same array
            Nd4j.gemm(prevLayerActivationSlice, deltaifogNext, iwGradientsOut, true, false, 1.0, 1.0);
        } else {
            INDArray iwGradients_i = iwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.interval(0, hiddenLayerSize));
            Nd4j.gemm(prevLayerActivationSlice, deltai, iwGradients_i, true, false, 1.0, 1.0);
            INDArray iwGradients_og = iwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize));
            INDArray deltaog = deltaifogNext.get(NDArrayIndex.all(), NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize));
            Nd4j.gemm(prevLayerActivationSlice, deltaog, iwGradients_og, true, false, 1.0, 1.0);
        }
        if (iTimeIndex > 0) {
            //If t==0, then prevHiddenUnitActivation==zeros(n^L,n^L), so dL/dW for recurrent weights will end up as 0 anyway
            //At this point: deltaifog and deltaifogNext are the same thing...
            //So what we are actually doing here is sum of (prevAct^transpose * deltaifog_current)
            Nd4j.gemm(prevHiddenUnitActivation, deltaifogNext, rwGradientsIFOG, true, false, 1.0, 1.0);
            //Shape: [1,n^L]. sum(0) is sum over examples in mini-batch.
            //Can use axpy here because result of sum and rwGradients[4 to 6] have order Nd4j.order(), via Nd4j.create()
            //mul not mmul because these weights are from unit j->j only (whereas other recurrent weights are i->j for all i,j)
            INDArray dLdwFF = deltaf.dup('f').muli(prevMemCellState).sum(0);
            //rwGradients[4].addi(dLdwFF);    //dL/dw_{FF}
            l1BLAS.axpy(hiddenLayerSize, 1.0, dLdwFF, rwGradientsFF);
            INDArray dLdwGG = deltag.dup('f').muli(prevMemCellState).sum(0);
            //rwGradients[6].addi(dLdwGG);
            l1BLAS.axpy(hiddenLayerSize, 1.0, dLdwGG, rwGradientsGG);
        }
        //Expected shape: [n^L,1]. sum(0) is sum over examples in mini-batch.
        INDArray dLdwOO = deltao.dup('f').muli(currMemCellState).sum(0);
        //rwGradients[5].addi(dLdwOO);    //dL/dw_{OOxy}
        l1BLAS.axpy(hiddenLayerSize, 1.0, dLdwOO, rwGradientsOO);
        if (iTimeIndex > 0) {
            l1BLAS.axpy(4 * hiddenLayerSize, 1.0, deltaifogNext.sum(0), bGradientsOut);
        } else {
            //Sneaky way to do bGradients_i += deltai.sum(0)
            l1BLAS.axpy(hiddenLayerSize, 1.0, deltai.sum(0), bGradientsOut);
            INDArray ogBiasToAdd = deltaifogNext.get(NDArrayIndex.all(), NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize)).sum(0);
            INDArray ogBiasGrad = bGradientsOut.get(NDArrayIndex.point(0), NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize));
            l1BLAS.axpy(2 * hiddenLayerSize, 1.0, ogBiasToAdd, ogBiasGrad);
        }
        //Calculate epsilonNext - i.e., equiv. to what would be (w^L*(d^(Lt))^T)^T in a normal network
        //But here, need to add 4 weights * deltas for the IFOG gates
        //This slice: f order and contiguous, due to epsilonNext being defined as f order.
        INDArray epsilonNextSlice = epsilonNext.tensorAlongDimension(time, 1, 0);
        if (iTimeIndex > 0) {
            Nd4j.gemm(deltaifogNext, inputWeights, epsilonNextSlice, false, true, 1.0, 1.0);
        } else {
            //No contribution from forget gate at t=0
            INDArray wi = inputWeights.get(NDArrayIndex.all(), NDArrayIndex.interval(0, hiddenLayerSize));
            Nd4j.gemm(deltai, wi, epsilonNextSlice, false, true, 1.0, 1.0);
            INDArray deltaog = deltaifogNext.get(NDArrayIndex.all(), NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize));
            INDArray wog = inputWeights.get(NDArrayIndex.all(), NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize));
            //epsilonNextSlice.addi(deltao.mmul(woTranspose)).addi(deltag.mmul(wgTranspose));
            Nd4j.gemm(deltaog, wog, epsilonNextSlice, false, true, 1.0, 1.0);
        }
        if (maskArray != null) {
            //Mask array is present: bidirectional RNN -> need to zero out these errors to avoid sending anything
            // but 0s to the layer below at this time step (for the given example)
            epsilonNextSlice.muliColumnVector(timeStepMaskColumn);
        }
    }
    Gradient retGradient = new DefaultGradient();
    retGradient.gradientForVariable().put(inputWeightKey, iwGradientsOut);
    retGradient.gradientForVariable().put(recurrentWeightKey, rwGradientsOut);
    retGradient.gradientForVariable().put(biasWeightKey, bGradientsOut);
    return new Pair<>(retGradient, epsilonNext);
}
Also used : Gradient(org.deeplearning4j.nn.gradient.Gradient) DefaultGradient(org.deeplearning4j.nn.gradient.DefaultGradient) DefaultGradient(org.deeplearning4j.nn.gradient.DefaultGradient) INDArray(org.nd4j.linalg.api.ndarray.INDArray) TimesOneMinus(org.nd4j.linalg.api.ops.impl.transforms.TimesOneMinus) MulOp(org.nd4j.linalg.api.ops.impl.transforms.arithmetic.MulOp) ActivationSigmoid(org.nd4j.linalg.activations.impl.ActivationSigmoid) Level1(org.nd4j.linalg.api.blas.Level1) IActivation(org.nd4j.linalg.activations.IActivation) NDArrayIndex.point(org.nd4j.linalg.indexing.NDArrayIndex.point) Pair(org.deeplearning4j.berkeley.Pair)
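
For orientation, the per-time-step quantities computed in the loop above can be summarized as follows, assuming sigmoid gate activations and using DL4J's naming, where i denotes the block input, g the input gate, \sigma the gate activation, and \phi the layer activation afn (applied both to the block input and to the cell state). This is a restatement of what the code computes rather than an independent derivation; \odot is element-wise multiplication and the (t+1) terms are the deltas carried over from the previously processed (later) time step:

\begin{aligned}
\nabla_{a_t} &= \epsilon_t + \Delta_{ifog}^{(t+1)} W_{IFOG}^{\top} \\
\delta_o &= \nabla_{a_t} \odot \phi(c_t) \odot \sigma'(z_o) \\
\nabla_{c_t} &= \nabla_{a_t} \odot a_o \odot \phi'(c_t) + \delta_o \odot w_{OO} + \delta_f^{(t+1)} \odot w_{FF} + \delta_g^{(t+1)} \odot w_{GG} + a_f^{(t+1)} \odot \nabla_{c_{t+1}} \\
\delta_f &= \nabla_{c_t} \odot c_{t-1} \odot \sigma'(z_f) \\
\delta_g &= \nabla_{c_t} \odot a_i \odot \sigma'(z_g) \\
\delta_i &= \nabla_{c_t} \odot a_g \odot \phi'(z_i)
\end{aligned}

The parameter gradients then accumulate x_t^{\top} \Delta_{ifog} (input weights), a_{t-1}^{\top} \Delta_{ifog} (recurrent IFOG weights), the mini-batch column sums of \delta_f \odot c_{t-1}, \delta_o \odot c_t and \delta_g \odot c_{t-1} (peephole weights w_{FF}, w_{OO}, w_{GG}), and the column sums of \Delta_{ifog} (biases); the epsilon passed to the layer below is \Delta_{ifog} W_{input}^{\top}.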

Aggregations

IActivation (org.nd4j.linalg.activations.IActivation): 12 usages
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 10 usages
Gradient (org.deeplearning4j.nn.gradient.Gradient): 6 usages
DefaultGradient (org.deeplearning4j.nn.gradient.DefaultGradient): 4 usages
Pair (org.deeplearning4j.berkeley.Pair): 3 usages
Level1 (org.nd4j.linalg.api.blas.Level1): 3 usages
IOException (java.io.IOException): 2 usages
ArrayList (java.util.ArrayList): 2 usages
GraphVertex (org.deeplearning4j.nn.conf.graph.GraphVertex): 2 usages
LayerVertex (org.deeplearning4j.nn.conf.graph.LayerVertex): 2 usages
ComputationGraphUpdater (org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater): 2 usages
ActivationSigmoid (org.nd4j.linalg.activations.impl.ActivationSigmoid): 2 usages
MultiDataSet (org.nd4j.linalg.dataset.MultiDataSet): 2 usages
NDArrayIndex.point (org.nd4j.linalg.indexing.NDArrayIndex.point): 2 usages
JsonNode (org.nd4j.shade.jackson.databind.JsonNode): 2 usages
ObjectMapper (org.nd4j.shade.jackson.databind.ObjectMapper): 2 usages
DL4JInvalidInputException (org.deeplearning4j.exception.DL4JInvalidInputException): 1 usage
Updater (org.deeplearning4j.nn.api.Updater): 1 usage
IOutputLayer (org.deeplearning4j.nn.api.layers.IOutputLayer): 1 usage
NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration): 1 usage