Search in sources :

Example 1 with MulOp

Use of org.nd4j.linalg.api.ops.impl.transforms.arithmetic.MulOp in the deeplearning4j project (by deeplearning4j).

From the class LSTMHelpers, method backpropGradientHelper.

/**
 * Backward pass through time for an LSTM layer whose recurrent weight matrix carries, in addition
 * to the four gate blocks [wI,wF,wO,wG], three extra columns (wFF, wOO, wGG) holding per-unit
 * (j->j) weights applied to the memory cell state.
 *
 * <p>The method walks the time steps from last to first (optionally stopping early for truncated
 * BPTT), and for each step computes the output/forget/modulation/input deltas, accumulates the
 * input-weight, recurrent-weight and bias gradients into the arrays found in {@code gradientViews},
 * and fills the corresponding time slice of the error array that is returned for the layer below.
 *
 * <p>All four delta arrays are views of a single {@code [miniBatchSize, 4*hiddenLayerSize]} buffer
 * ({@code deltaifogNext}); that buffer is overwritten at each step, so "next step" and "current
 * step" deltas share storage and the statement order inside the loop matters.
 *
 * <p>NOTE(review): {@code fwdPass.oa[time]} and {@code fwdPass.fa[time + inext]} are passed through
 * {@code muli(...)}, which by the ND4J "i"-suffix convention presumably works in place; if so the
 * forward-pass arrays are not reusable after this call -- confirm against callers.
 *
 * @param conf                 layer configuration; used here only for {@code conf.getLayer().getActivationFn()}
 * @param gateActivationFn     activation function of the gates; an {@code ActivationSigmoid} instance
 *                             takes a shortcut via {@code TimesOneMinus}, anything else goes through
 *                             {@code gateActivationFn.backprop(...)}
 * @param input                layer input activations; sliced per time step unless {@code epsilon} has rank < 3
 * @param recurrentWeights     shape [hiddenLayerSize, 4*hiddenLayerSize+3]; column order [wI,wF,wO,wG,wFF,wOO,wGG]
 * @param inputWeights         shape [n^(L-1), 4*hiddenLayerSize]; column order [wi,wf,wo,wg]
 * @param epsilon              errors from the layer above, expected shape
 *                             [miniBatchSize, n^(L+1), timeSeriesLength]; a rank-2 array is treated as T=1
 * @param truncatedBPTT        if true, only indices down to
 *                             {@code max(0, timeSeriesLength - tbpttBackwardLength)} are processed
 * @param tbpttBackwardLength  number of steps to go back when {@code truncatedBPTT} is true
 * @param fwdPass              arrays cached by the forward pass (cell states, gate activations, pre-activations, outputs)
 * @param forwards             if false, the time index is mirrored ({@code timeSeriesLength - iTimeIndex - 1})
 *                             and the "next" step is at {@code time - 1} instead of {@code time + 1}
 * @param inputWeightKey       key of the input-weight gradient in {@code gradientViews} and in the returned gradient
 * @param recurrentWeightKey   key of the recurrent-weight gradient in {@code gradientViews} and in the returned gradient
 * @param biasWeightKey        key of the bias gradient in {@code gradientViews} and in the returned gradient
 * @param gradientViews        gradient arrays to write into; the three looked-up arrays are zeroed and then accumulated into
 * @param maskArray            optional mask, shape [minibatch, timeSeriesLength]; may be null. When present, the
 *                             column for the current time step is multiplied into the deltas and into the
 *                             outgoing error slice
 * @return pair of (gradient holding the three arrays from {@code gradientViews}, errors for the layer
 *         below with shape [miniBatchSize, prevLayerSize, timeSeriesLength] in 'f' order)
 */
public static Pair<Gradient, INDArray> backpropGradientHelper(final NeuralNetConfiguration conf, //Layer configuration; only the layer activation function is read from it
final IActivation gateActivationFn, //Gate activation function (sigmoid gets a fast path below)
final INDArray input, //Layer input activations
final INDArray recurrentWeights, //Shape: [hiddenLayerSize,4*hiddenLayerSize+3]; order: [wI,wF,wO,wG,wFF,wOO,wGG]
final INDArray inputWeights, final INDArray epsilon, final boolean truncatedBPTT, final int tbpttBackwardLength, final FwdPassReturn fwdPass, final boolean forwards, final String inputWeightKey, final String recurrentWeightKey, final String biasWeightKey, //inputWeights shape: [n^(L-1),4*hiddenLayerSize]; order: [wi,wf,wo,wg]
final Map<String, INDArray> gradientViews, //Gradient arrays, looked up with the three weight keys above
//Input mask: should only be used with bidirectional RNNs + variable length (may be null)
INDArray maskArray) {
    //Expect errors to have shape: [miniBatchSize,n^(L+1),timeSeriesLength]
    //i.e., n^L
    int hiddenLayerSize = recurrentWeights.size(0);
    //n^(L-1)
    int prevLayerSize = inputWeights.size(0);
    int miniBatchSize = epsilon.size(0);
    //Edge case: T=1 may have shape [miniBatchSize,n^(L+1)], equiv. to [miniBatchSize,n^(L+1),1]
    boolean is2dInput = epsilon.rank() < 3;
    int timeSeriesLength = (is2dInput ? 1 : epsilon.size(2));
    //The three extra columns after the 4 gate blocks (wFF, wOO, wGG), transposed for use with muliRowVector
    INDArray wFFTranspose = recurrentWeights.get(NDArrayIndex.all(), point(4 * hiddenLayerSize)).transpose();
    INDArray wOOTranspose = recurrentWeights.get(NDArrayIndex.all(), point(4 * hiddenLayerSize + 1)).transpose();
    INDArray wGGTranspose = recurrentWeights.get(NDArrayIndex.all(), point(4 * hiddenLayerSize + 2)).transpose();
    //The first 4*hiddenLayerSize columns: the I,F,O,G recurrent weights as one block
    INDArray wIFOG = recurrentWeights.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 4 * hiddenLayerSize));
    //F order here so that content for time steps are together
    //i.e., what would be W^L*(delta^L)^T. Shape: [m,n^(L-1),T]
    INDArray epsilonNext = Nd4j.create(new int[] { miniBatchSize, prevLayerSize, timeSeriesLength }, 'f');
    INDArray nablaCellStateNext = null;
    //Single buffer for all four gate deltas; the delta*Next arrays below are column-interval views of it.
    //It is overwritten every iteration, so before the overwrite it holds the deltas of the previously
    //processed ("next") time step and afterwards those of the current one.
    INDArray deltaifogNext = Nd4j.create(new int[] { miniBatchSize, 4 * hiddenLayerSize }, 'f');
    INDArray deltaiNext = deltaifogNext.get(NDArrayIndex.all(), NDArrayIndex.interval(0, hiddenLayerSize));
    INDArray deltafNext = deltaifogNext.get(NDArrayIndex.all(), NDArrayIndex.interval(hiddenLayerSize, 2 * hiddenLayerSize));
    INDArray deltaoNext = deltaifogNext.get(NDArrayIndex.all(), NDArrayIndex.interval(2 * hiddenLayerSize, 3 * hiddenLayerSize));
    INDArray deltagNext = deltaifogNext.get(NDArrayIndex.all(), NDArrayIndex.interval(3 * hiddenLayerSize, 4 * hiddenLayerSize));
    Level1 l1BLAS = Nd4j.getBlasWrapper().level1();
    //Lowest loop index to process; stays 0 unless truncated BPTT limits how far back we go
    int endIdx = 0;
    if (truncatedBPTT) {
        endIdx = Math.max(0, timeSeriesLength - tbpttBackwardLength);
    }
    //Get gradients. Note that we have to manually zero these, as they might not be initialized (or still has data from last iteration)
    //Also note that they are in f order (as per param initializer) so can be used in gemm etc
    INDArray iwGradientsOut = gradientViews.get(inputWeightKey);
    //Order: {I,F,O,G,FF,OO,GG}
    INDArray rwGradientsOut = gradientViews.get(recurrentWeightKey);
    INDArray bGradientsOut = gradientViews.get(biasWeightKey);
    iwGradientsOut.assign(0);
    rwGradientsOut.assign(0);
    bGradientsOut.assign(0);
    //Views into the recurrent weight gradient: the IFOG block and the three extra columns
    INDArray rwGradientsIFOG = rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 4 * hiddenLayerSize));
    INDArray rwGradientsFF = rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.point(4 * hiddenLayerSize));
    INDArray rwGradientsOO = rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.point(4 * hiddenLayerSize + 1));
    INDArray rwGradientsGG = rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.point(4 * hiddenLayerSize + 2));
    //Sigmoid gates use a(1-a) computed from the stored activations instead of calling backprop on the pre-activations
    boolean sigmoidGates = gateActivationFn instanceof ActivationSigmoid;
    IActivation afn = conf.getLayer().getActivationFn();
    INDArray timeStepMaskColumn = null;
    for (int iTimeIndex = timeSeriesLength - 1; iTimeIndex >= endIdx; iTimeIndex--) {
        //'time' is the actual index into the arrays; 'inext' is the offset to the step processed just before this one
        int time = iTimeIndex;
        int inext = 1;
        if (!forwards) {
            time = timeSeriesLength - iTimeIndex - 1;
            inext = -1;
        }
        //First: calculate the components of nablaCellState that relies on the next time step deltas, so we can overwrite the deltas
        INDArray nablaCellState;
        if (iTimeIndex != timeSeriesLength - 1) {
            //nablaCellState = deltafNext .* wFF + deltagNext .* wGG (row vectors applied to each example)
            nablaCellState = deltafNext.dup('f').muliRowVector(wFFTranspose);
            l1BLAS.axpy(nablaCellState.length(), 1.0, deltagNext.dup('f').muliRowVector(wGGTranspose), nablaCellState);
        } else {
            //First processed step: no later deltas exist yet
            nablaCellState = Nd4j.create(new int[] { miniBatchSize, hiddenLayerSize }, 'f');
        }
        //At iTimeIndex == 0 there is no previous step, hence null
        INDArray prevMemCellState = (iTimeIndex == 0 ? null : fwdPass.memCellState[time - inext]);
        INDArray prevHiddenUnitActivation = (iTimeIndex == 0 ? null : fwdPass.fwdPassOutputAsArrays[time - inext]);
        INDArray currMemCellState = fwdPass.memCellState[time];
        //LSTM unit output errors (dL/d(a_out)); not to be confused with \delta=dL/d(z_out)
        //(w^{L+1}*(delta^{(L+1)t})^T)^T or equiv.
        INDArray epsilonSlice = (is2dInput ? epsilon : epsilon.tensorAlongDimension(time, 1, 0));
        //Shape: [m,n^L]
        INDArray nablaOut = Shape.toOffsetZeroCopy(epsilonSlice, 'f');
        if (iTimeIndex != timeSeriesLength - 1) {
            //if t == timeSeriesLength-1 then deltaiNext etc are zeros
            //nablaOut += deltaifogNext * wIFOG^T (error flowing back through the recurrent connections)
            Nd4j.gemm(deltaifogNext, wIFOG, nablaOut, false, true, 1.0, 1.0);
        }
        //Output gate deltas:
        INDArray sigmahOfS = fwdPass.memCellActivations[time];
        INDArray ao = fwdPass.oa[time];
        //deltao aliases the output-gate view of deltaifogNext, so the next-step values are overwritten here;
        //everything that needed them (nablaCellState, nablaOut) has already been computed above
        INDArray deltao = deltaoNext;
        //deltao = nablaOut .* sigmahOfS, written directly into the view
        Nd4j.getExecutioner().exec(new MulOp(nablaOut, sigmahOfS, deltao));
        if (sigmoidGates) {
            //Equivalent to sigmoid deriv on zo
            INDArray sigmaoPrimeOfZo = Nd4j.getExecutioner().execAndReturn(new TimesOneMinus(ao.dup('f')));
            deltao.muli(sigmaoPrimeOfZo);
        } else {
            //Deltao needs to be modified in-place
            deltao.assign(gateActivationFn.backprop(fwdPass.oz[time], deltao).getFirst());
            //TODO: optimize (no assign)
        }
        //Memory cell error:
        //TODO activation functions with params
        //NOTE(review): ao.muli(nablaOut) presumably overwrites fwdPass.oa[time] in place; ao is not read again in this iteration
        INDArray temp = afn.backprop(currMemCellState.dup('f'), ao.muli(nablaOut)).getFirst();
        l1BLAS.axpy(nablaCellState.length(), 1.0, temp, nablaCellState);
        INDArray deltaMulRowWOO = deltao.dup('f').muliRowVector(wOOTranspose);
        //nablaCellState.addi(deltao.mulRowVector(wOOTranspose));
        l1BLAS.axpy(nablaCellState.length(), 1.0, deltaMulRowWOO, nablaCellState);
        if (iTimeIndex != timeSeriesLength - 1) {
            INDArray nextForgetGateAs = fwdPass.fa[time + inext];
            int length = nablaCellState.length();
            //nablaCellState.addi(nextForgetGateAs.mul(nablaCellStateNext))
            //NOTE(review): the code uses muli, not mul, so fwdPass.fa[time + inext] is presumably overwritten here
            l1BLAS.axpy(length, 1.0, nextForgetGateAs.muli(nablaCellStateNext), nablaCellState);
        }
        //Store for use in next iteration
        nablaCellStateNext = nablaCellState;
        //Forget gate delta:
        INDArray af = fwdPass.fa[time];
        //Stays null at iTimeIndex == 0 (there is no previous cell state to gate)
        INDArray deltaf = null;
        if (iTimeIndex > 0) {
            deltaf = deltafNext;
            if (sigmoidGates) {
                //deltaf = af(1-af) .* nablaCellState .* prevMemCellState
                Nd4j.getExecutioner().exec(new TimesOneMinus(af, deltaf));
                deltaf.muli(nablaCellState);
                deltaf.muli(prevMemCellState);
            } else {
                INDArray temp2 = nablaCellState.mul(prevMemCellState);
                //deltaf needs to be modified in-place
                deltaf.assign(gateActivationFn.backprop(fwdPass.fz[time].dup('f'), temp2).getFirst());
                //TODO activation functions with params
            }
        }
        //Shape: [m,n^L]
        //Input modulation gate delta:
        INDArray ag = fwdPass.ga[time];
        INDArray ai = fwdPass.ia[time];
        INDArray deltag = deltagNext;
        if (sigmoidGates) {
            //Equivalent to sigmoid deriv on zg
            Nd4j.getExecutioner().exec(new TimesOneMinus(ag, deltag));
            deltag.muli(ai);
            deltag.muli(nablaCellState);
        } else {
            INDArray temp2 = Nd4j.getExecutioner().execAndReturn(new MulOp(ai, nablaCellState, Nd4j.createUninitialized(ai.shape(), 'f')));
            deltag.assign(gateActivationFn.backprop(fwdPass.gz[time], temp2).getFirst());
            //TODO activation functions with params; optimize (no assign)
        }
        //Shape: [m,n^L]
        //Network input delta:
        INDArray zi = fwdPass.iz[time];
        INDArray deltai = deltaiNext;
        //deltai = afn'(zi) applied to (ag .* nablaCellState); uses the layer activation, not the gate activation
        temp = Nd4j.getExecutioner().execAndReturn(new MulOp(ag, nablaCellState, Nd4j.createUninitialized(deltai.shape(), 'f')));
        deltai.assign(afn.backprop(zi, temp).getFirst());
        //Handle masking
        if (maskArray != null) {
            //Mask array is present: bidirectional RNN -> need to zero out these errors to avoid using errors from a masked time step
            // to calculate the parameter gradients.  Mask array has shape [minibatch, timeSeriesLength] -> get column(this time step)
            timeStepMaskColumn = maskArray.getColumn(time);
            deltaifogNext.muliColumnVector(timeStepMaskColumn);
            //Later, the deltaifogNext is used to calculate: input weight gradients, recurrent weight gradients, bias gradients
        }
        INDArray prevLayerActivationSlice = Shape.toMmulCompatible(is2dInput ? input : input.tensorAlongDimension(time, 1, 0));
        //Input weight gradients: iwGradientsOut += input^T * deltas
        if (iTimeIndex > 0) {
            //Again, deltaifog_current == deltaifogNext at this point... same array
            Nd4j.gemm(prevLayerActivationSlice, deltaifogNext, iwGradientsOut, true, false, 1.0, 1.0);
        } else {
            //At index 0 the forget-gate columns are skipped: accumulate the I block and the O,G blocks separately
            INDArray iwGradients_i = iwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.interval(0, hiddenLayerSize));
            Nd4j.gemm(prevLayerActivationSlice, deltai, iwGradients_i, true, false, 1.0, 1.0);
            INDArray iwGradients_og = iwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize));
            INDArray deltaog = deltaifogNext.get(NDArrayIndex.all(), NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize));
            Nd4j.gemm(prevLayerActivationSlice, deltaog, iwGradients_og, true, false, 1.0, 1.0);
        }
        if (iTimeIndex > 0) {
            //If t==0, then prevHiddenUnitActivation==zeros(n^L,n^L), so dL/dW for recurrent weights will end up as 0 anyway
            //At this point: deltaifog and deltaifogNext are the same thing...
            //So what we are actually doing here is sum of (prevAct^transpose * deltaifog_current)
            Nd4j.gemm(prevHiddenUnitActivation, deltaifogNext, rwGradientsIFOG, true, false, 1.0, 1.0);
            //Shape: [1,n^L]. sum(0) is sum over examples in mini-batch.
            //Can use axpy here because result of sum and rwGradients[4 to 6] have order Nd4j.order(), via Nd4j.create()
            //mul not mmul because these weights are from unit j->j only (whereas other recurrent weights are i->j for all i,j)
            INDArray dLdwFF = deltaf.dup('f').muli(prevMemCellState).sum(0);
            //rwGradients[4].addi(dLdwFF);    //dL/dw_{FF}
            l1BLAS.axpy(hiddenLayerSize, 1.0, dLdwFF, rwGradientsFF);
            INDArray dLdwGG = deltag.dup('f').muli(prevMemCellState).sum(0);
            //rwGradients[6].addi(dLdwGG);
            l1BLAS.axpy(hiddenLayerSize, 1.0, dLdwGG, rwGradientsGG);
        }
        //Expected shape: [n^L,1]. sum(0) is sum over examples in mini-batch.
        //NOTE(review): the analogous sums above are described as [1,n^L]; confirm which orientation is correct
        INDArray dLdwOO = deltao.dup('f').muli(currMemCellState).sum(0);
        //rwGradients[5].addi(dLdwOO);    //dL/dw_{OOxy}
        l1BLAS.axpy(hiddenLayerSize, 1.0, dLdwOO, rwGradientsOO);
        //Bias gradients: sum of deltas over the mini-batch
        if (iTimeIndex > 0) {
            l1BLAS.axpy(4 * hiddenLayerSize, 1.0, deltaifogNext.sum(0), bGradientsOut);
        } else {
            //Sneaky way to do bGradients_i += deltai.sum(0)
            //(axpy over only the first hiddenLayerSize elements of bGradientsOut)
            l1BLAS.axpy(hiddenLayerSize, 1.0, deltai.sum(0), bGradientsOut);
            //Then the O and G blocks; the forget-gate block is skipped at index 0
            INDArray ogBiasToAdd = deltaifogNext.get(NDArrayIndex.all(), NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize)).sum(0);
            INDArray ogBiasGrad = bGradientsOut.get(NDArrayIndex.point(0), NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize));
            l1BLAS.axpy(2 * hiddenLayerSize, 1.0, ogBiasToAdd, ogBiasGrad);
        }
        //Calculate epsilonNext - i.e., equiv. to what would be (w^L*(d^(Lt))^T)^T in a normal network
        //But here, need to add 4 weights * deltas for the IFOG gates
        //This slice: f order and contiguous, due to epsilonNext being defined as f order.
        INDArray epsilonNextSlice = epsilonNext.tensorAlongDimension(time, 1, 0);
        if (iTimeIndex > 0) {
            Nd4j.gemm(deltaifogNext, inputWeights, epsilonNextSlice, false, true, 1.0, 1.0);
        } else {
            //No contribution from forget gate at t=0
            INDArray wi = inputWeights.get(NDArrayIndex.all(), NDArrayIndex.interval(0, hiddenLayerSize));
            Nd4j.gemm(deltai, wi, epsilonNextSlice, false, true, 1.0, 1.0);
            INDArray deltaog = deltaifogNext.get(NDArrayIndex.all(), NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize));
            INDArray wog = inputWeights.get(NDArrayIndex.all(), NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize));
            //epsilonNextSlice.addi(deltao.mmul(woTranspose)).addi(deltag.mmul(wgTranspose));
            Nd4j.gemm(deltaog, wog, epsilonNextSlice, false, true, 1.0, 1.0);
        }
        if (maskArray != null) {
            //Mask array is present: bidirectional RNN -> need to zero out these errors to avoid sending anything
            // but 0s to the layer below at this time step (for the given example)
            epsilonNextSlice.muliColumnVector(timeStepMaskColumn);
        }
    }
    //Package the (already accumulated) gradient view arrays under their parameter keys
    Gradient retGradient = new DefaultGradient();
    retGradient.gradientForVariable().put(inputWeightKey, iwGradientsOut);
    retGradient.gradientForVariable().put(recurrentWeightKey, rwGradientsOut);
    retGradient.gradientForVariable().put(biasWeightKey, bGradientsOut);
    return new Pair<>(retGradient, epsilonNext);
}
Also used : Gradient(org.deeplearning4j.nn.gradient.Gradient) DefaultGradient(org.deeplearning4j.nn.gradient.DefaultGradient) INDArray(org.nd4j.linalg.api.ndarray.INDArray) TimesOneMinus(org.nd4j.linalg.api.ops.impl.transforms.TimesOneMinus) MulOp(org.nd4j.linalg.api.ops.impl.transforms.arithmetic.MulOp) ActivationSigmoid(org.nd4j.linalg.activations.impl.ActivationSigmoid) Level1(org.nd4j.linalg.api.blas.Level1) IActivation(org.nd4j.linalg.activations.IActivation) NDArrayIndex.point(org.nd4j.linalg.indexing.NDArrayIndex.point) Pair(org.deeplearning4j.berkeley.Pair)

Aggregations

Pair (org.deeplearning4j.berkeley.Pair)1 DefaultGradient (org.deeplearning4j.nn.gradient.DefaultGradient)1 Gradient (org.deeplearning4j.nn.gradient.Gradient)1 IActivation (org.nd4j.linalg.activations.IActivation)1 ActivationSigmoid (org.nd4j.linalg.activations.impl.ActivationSigmoid)1 Level1 (org.nd4j.linalg.api.blas.Level1)1 INDArray (org.nd4j.linalg.api.ndarray.INDArray)1 TimesOneMinus (org.nd4j.linalg.api.ops.impl.transforms.TimesOneMinus)1 MulOp (org.nd4j.linalg.api.ops.impl.transforms.arithmetic.MulOp)1 NDArrayIndex.point (org.nd4j.linalg.indexing.NDArrayIndex.point)1