Use of org.nd4j.linalg.api.ops.impl.transforms.TimesOneMinus in project deeplearning4j by deeplearning4j.
From the class LSTMHelpers, method backpropGradientHelper:
public static Pair<Gradient, INDArray> backpropGradientHelper(final NeuralNetConfiguration conf,
final IActivation gateActivationFn, final INDArray input,
final INDArray recurrentWeights, //Shape: [hiddenLayerSize,4*hiddenLayerSize+3]; order: [wI,wF,wO,wG,wFF,wOO,wGG]
final INDArray inputWeights, //Shape: [n^(L-1),4*hiddenLayerSize]; order: [wi,wf,wo,wg]
final INDArray epsilon, final boolean truncatedBPTT, final int tbpttBackwardLength,
final FwdPassReturn fwdPass, final boolean forwards, final String inputWeightKey,
final String recurrentWeightKey, final String biasWeightKey,
final Map<String, INDArray> gradientViews,
INDArray maskArray) { //Input mask: should only be used with bidirectional RNNs + variable length
//Expect errors to have shape: [miniBatchSize,n^(L+1),timeSeriesLength]
int hiddenLayerSize = recurrentWeights.size(0); //i.e., n^L
int prevLayerSize = inputWeights.size(0); //n^(L-1)
int miniBatchSize = epsilon.size(0);
boolean is2dInput = epsilon.rank() < 3; //Edge case: T=1 may have shape [miniBatchSize,n^(L+1)], equiv. to [miniBatchSize,n^(L+1),1]
int timeSeriesLength = (is2dInput ? 1 : epsilon.size(2));
INDArray wFFTranspose = recurrentWeights.get(NDArrayIndex.all(), point(4 * hiddenLayerSize)).transpose();
INDArray wOOTranspose = recurrentWeights.get(NDArrayIndex.all(), point(4 * hiddenLayerSize + 1)).transpose();
INDArray wGGTranspose = recurrentWeights.get(NDArrayIndex.all(), point(4 * hiddenLayerSize + 2)).transpose();
INDArray wIFOG = recurrentWeights.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 4 * hiddenLayerSize));
//F order here so that content for time steps are together
INDArray epsilonNext = Nd4j.create(new int[] { miniBatchSize, prevLayerSize, timeSeriesLength }, 'f'); //i.e., what would be W^L*(delta^L)^T. Shape: [m,n^(L-1),T]
INDArray nablaCellStateNext = null;
INDArray deltaifogNext = Nd4j.create(new int[] { miniBatchSize, 4 * hiddenLayerSize }, 'f');
INDArray deltaiNext = deltaifogNext.get(NDArrayIndex.all(), NDArrayIndex.interval(0, hiddenLayerSize));
INDArray deltafNext = deltaifogNext.get(NDArrayIndex.all(), NDArrayIndex.interval(hiddenLayerSize, 2 * hiddenLayerSize));
INDArray deltaoNext = deltaifogNext.get(NDArrayIndex.all(), NDArrayIndex.interval(2 * hiddenLayerSize, 3 * hiddenLayerSize));
INDArray deltagNext = deltaifogNext.get(NDArrayIndex.all(), NDArrayIndex.interval(3 * hiddenLayerSize, 4 * hiddenLayerSize));
Level1 l1BLAS = Nd4j.getBlasWrapper().level1();
int endIdx = 0;
if (truncatedBPTT) {
endIdx = Math.max(0, timeSeriesLength - tbpttBackwardLength);
}
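//With truncated BPTT, gradients are backpropagated through at most tbpttBackwardLength time steps: the loop below stops at endIdx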
//Get gradients. Note that we have to manually zero these, as they might not be initialized (or might still have data from the last iteration)
//Also note that they are in f order (as per param initializer) so can be used in gemm etc
INDArray iwGradientsOut = gradientViews.get(inputWeightKey);
INDArray rwGradientsOut = gradientViews.get(recurrentWeightKey); //Order: {I,F,O,G,FF,OO,GG}
INDArray bGradientsOut = gradientViews.get(biasWeightKey);
iwGradientsOut.assign(0);
rwGradientsOut.assign(0);
bGradientsOut.assign(0);
INDArray rwGradientsIFOG = rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 4 * hiddenLayerSize));
INDArray rwGradientsFF = rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.point(4 * hiddenLayerSize));
INDArray rwGradientsOO = rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.point(4 * hiddenLayerSize + 1));
INDArray rwGradientsGG = rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.point(4 * hiddenLayerSize + 2));
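//When the gate activation is sigmoid, the gate deltas below reuse the stored gate activations (oa/fa/ga) via
//TimesOneMinus (a -> a*(1-a), i.e. the sigmoid derivative), rather than calling gateActivationFn.backprop on the
//stored pre-activations (oz/fz/gz)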
boolean sigmoidGates = gateActivationFn instanceof ActivationSigmoid;
IActivation afn = conf.getLayer().getActivationFn();
INDArray timeStepMaskColumn = null;
for (int iTimeIndex = timeSeriesLength - 1; iTimeIndex >= endIdx; iTimeIndex--) {
int time = iTimeIndex;
int inext = 1;
if (!forwards) {
time = timeSeriesLength - iTimeIndex - 1;
inext = -1;
}
//First: calculate the components of nablaCellState that rely on the next time step's deltas, so we can overwrite those deltas
INDArray nablaCellState;
if (iTimeIndex != timeSeriesLength - 1) {
nablaCellState = deltafNext.dup('f').muliRowVector(wFFTranspose);
l1BLAS.axpy(nablaCellState.length(), 1.0, deltagNext.dup('f').muliRowVector(wGGTranspose), nablaCellState);
} else {
nablaCellState = Nd4j.create(new int[] { miniBatchSize, hiddenLayerSize }, 'f');
}
INDArray prevMemCellState = (iTimeIndex == 0 ? null : fwdPass.memCellState[time - inext]);
INDArray prevHiddenUnitActivation = (iTimeIndex == 0 ? null : fwdPass.fwdPassOutputAsArrays[time - inext]);
INDArray currMemCellState = fwdPass.memCellState[time];
//LSTM unit output errors (dL/d(a_out)); not to be confused with \delta=dL/d(z_out)
INDArray epsilonSlice = (is2dInput ? epsilon : epsilon.tensorAlongDimension(time, 1, 0)); //(w^{L+1}*(delta^{(L+1)t})^T)^T or equiv.
INDArray nablaOut = Shape.toOffsetZeroCopy(epsilonSlice, 'f'); //Shape: [m,n^L]
if (iTimeIndex != timeSeriesLength - 1) {
//if t == timeSeriesLength-1 then deltaiNext etc are zeros
Nd4j.gemm(deltaifogNext, wIFOG, nablaOut, false, true, 1.0, 1.0);
}
//Output gate deltas:
INDArray sigmahOfS = fwdPass.memCellActivations[time];
INDArray ao = fwdPass.oa[time];
//Normally would use zo.dup() in above line, but won't be using zo again (for this time step). Ditto for zf, zg, zi
INDArray deltao = deltaoNext;
Nd4j.getExecutioner().exec(new MulOp(nablaOut, sigmahOfS, deltao));
if (sigmoidGates) {
INDArray sigmaoPrimeOfZo = Nd4j.getExecutioner().execAndReturn(new TimesOneMinus(ao.dup('f'))); //Equivalent to sigmoid deriv on zo
deltao.muli(sigmaoPrimeOfZo);
} else {
deltao.assign(gateActivationFn.backprop(fwdPass.oz[time], deltao).getFirst()); //Deltao needs to be modified in-place
//TODO: optimize (no assign)
}
//Memory cell error:
INDArray temp = afn.backprop(currMemCellState.dup('f'), ao.muli(nablaOut)).getFirst(); //TODO activation functions with params
l1BLAS.axpy(nablaCellState.length(), 1.0, temp, nablaCellState);
INDArray deltaMulRowWOO = deltao.dup('f').muliRowVector(wOOTranspose);
l1BLAS.axpy(nablaCellState.length(), 1.0, deltaMulRowWOO, nablaCellState); //nablaCellState.addi(deltao.mulRowVector(wOOTranspose))
if (iTimeIndex != timeSeriesLength - 1) {
INDArray nextForgetGateAs = fwdPass.fa[time + inext];
int length = nablaCellState.length();
l1BLAS.axpy(length, 1.0, nextForgetGateAs.muli(nablaCellStateNext), nablaCellState); //nablaCellState.addi(nextForgetGateAs.mul(nablaCellStateNext))
}
//Store for use in next iteration
nablaCellStateNext = nablaCellState;
//Forget gate delta:
INDArray af = fwdPass.fa[time];
INDArray deltaf = null;
if (iTimeIndex > 0) {
deltaf = deltafNext;
if (sigmoidGates) {
Nd4j.getExecutioner().exec(new TimesOneMinus(af, deltaf));
deltaf.muli(nablaCellState);
deltaf.muli(prevMemCellState);
} else {
INDArray temp2 = nablaCellState.mul(prevMemCellState);
deltaf.assign(gateActivationFn.backprop(fwdPass.fz[time].dup('f'), temp2).getFirst()); //deltaf needs to be modified in-place
//TODO activation functions with params
}
}
//Shape: [m,n^L]
//Input modulation gate delta:
INDArray ag = fwdPass.ga[time];
INDArray ai = fwdPass.ia[time];
INDArray deltag = deltagNext;
if (sigmoidGates) {
Nd4j.getExecutioner().exec(new TimesOneMinus(ag, deltag)); //Equivalent to sigmoid deriv on zg
deltag.muli(ai);
deltag.muli(nablaCellState);
} else {
INDArray temp2 = Nd4j.getExecutioner().execAndReturn(new MulOp(ai, nablaCellState, Nd4j.createUninitialized(ai.shape(), 'f')));
deltag.assign(gateActivationFn.backprop(fwdPass.gz[time], temp2).getFirst());
//TODO activation functions with params; optimize (no assign)
}
//Shape: [m,n^L]
//Network input delta:
INDArray zi = fwdPass.iz[time];
INDArray deltai = deltaiNext;
temp = Nd4j.getExecutioner().execAndReturn(new MulOp(ag, nablaCellState, Nd4j.createUninitialized(deltai.shape(), 'f')));
deltai.assign(afn.backprop(zi, temp).getFirst());
//Handle masking
if (maskArray != null) {
//Mask array is present: bidirectional RNN -> need to zero out these errors to avoid using errors from a masked time step
// to calculate the parameter gradients. Mask array has shape [minibatch, timeSeriesLength] -> get column(this time step)
timeStepMaskColumn = maskArray.getColumn(time);
deltaifogNext.muliColumnVector(timeStepMaskColumn);
//Later, the deltaifogNext is used to calculate: input weight gradients, recurrent weight gradients, bias gradients
}
INDArray prevLayerActivationSlice = Shape.toMmulCompatible(is2dInput ? input : input.tensorAlongDimension(time, 1, 0));
if (iTimeIndex > 0) {
//Again, deltaifog_current == deltaifogNext at this point... same array
Nd4j.gemm(prevLayerActivationSlice, deltaifogNext, iwGradientsOut, true, false, 1.0, 1.0);
} else {
INDArray iwGradients_i = iwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.interval(0, hiddenLayerSize));
Nd4j.gemm(prevLayerActivationSlice, deltai, iwGradients_i, true, false, 1.0, 1.0);
INDArray iwGradients_og = iwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize));
INDArray deltaog = deltaifogNext.get(NDArrayIndex.all(), NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize));
Nd4j.gemm(prevLayerActivationSlice, deltaog, iwGradients_og, true, false, 1.0, 1.0);
}
if (iTimeIndex > 0) {
//If t==0, then prevHiddenUnitActivation==zeros(n^L,n^L), so dL/dW for recurrent weights will end up as 0 anyway
//At this point: deltaifog and deltaifogNext are the same thing...
//So what we are actually doing here is sum of (prevAct^transpose * deltaifog_current)
Nd4j.gemm(prevHiddenUnitActivation, deltaifogNext, rwGradientsIFOG, true, false, 1.0, 1.0);
//Shape: [1,n^L]. sum(0) is sum over examples in mini-batch.
//Can use axpy here because result of sum and rwGradients[4 to 6] have order Nd4j.order(), via Nd4j.create()
//mul not mmul because these weights are from unit j->j only (whereas other recurrent weights are i->j for all i,j)
INDArray dLdwFF = deltaf.dup('f').muli(prevMemCellState).sum(0);
l1BLAS.axpy(hiddenLayerSize, 1.0, dLdwFF, rwGradientsFF); //rwGradients[4].addi(dLdwFF); //dL/dw_{FF}
INDArray dLdwGG = deltag.dup('f').muli(prevMemCellState).sum(0);
l1BLAS.axpy(hiddenLayerSize, 1.0, dLdwGG, rwGradientsGG); //rwGradients[6].addi(dLdwGG);
}
//Expected shape: [1,n^L]. sum(0) is sum over examples in mini-batch.
INDArray dLdwOO = deltao.dup('f').muli(currMemCellState).sum(0);
l1BLAS.axpy(hiddenLayerSize, 1.0, dLdwOO, rwGradientsOO); //rwGradients[5].addi(dLdwOO); //dL/dw_{OO}
if (iTimeIndex > 0) {
l1BLAS.axpy(4 * hiddenLayerSize, 1.0, deltaifogNext.sum(0), bGradientsOut);
} else {
//Sneaky way to do bGradients_i += deltai.sum(0)
l1BLAS.axpy(hiddenLayerSize, 1.0, deltai.sum(0), bGradientsOut);
INDArray ogBiasToAdd = deltaifogNext.get(NDArrayIndex.all(), NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize)).sum(0);
INDArray ogBiasGrad = bGradientsOut.get(NDArrayIndex.point(0), NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize));
l1BLAS.axpy(2 * hiddenLayerSize, 1.0, ogBiasToAdd, ogBiasGrad);
}
//Calculate epsilonNext - i.e., equiv. to what would be (w^L*(d^(Lt))^T)^T in a normal network
//But here, need to add 4 weights * deltas for the IFOG gates
INDArray epsilonNextSlice = epsilonNext.tensorAlongDimension(time, 1, 0); //This slice: f order and contiguous, due to epsilonNext being defined as f order
if (iTimeIndex > 0) {
Nd4j.gemm(deltaifogNext, inputWeights, epsilonNextSlice, false, true, 1.0, 1.0);
} else {
//No contribution from forget gate at t=0
INDArray wi = inputWeights.get(NDArrayIndex.all(), NDArrayIndex.interval(0, hiddenLayerSize));
Nd4j.gemm(deltai, wi, epsilonNextSlice, false, true, 1.0, 1.0);
INDArray deltaog = deltaifogNext.get(NDArrayIndex.all(), NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize));
INDArray wog = inputWeights.get(NDArrayIndex.all(), NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize));
Nd4j.gemm(deltaog, wog, epsilonNextSlice, false, true, 1.0, 1.0); //epsilonNextSlice.addi(deltao.mmul(woTranspose)).addi(deltag.mmul(wgTranspose))
}
if (maskArray != null) {
//Mask array is present: bidirectional RNN -> need to zero out these errors to avoid sending anything
// but 0s to the layer below at this time step (for the given example)
epsilonNextSlice.muliColumnVector(timeStepMaskColumn);
}
}
Gradient retGradient = new DefaultGradient();
retGradient.gradientForVariable().put(inputWeightKey, iwGradientsOut);
retGradient.gradientForVariable().put(recurrentWeightKey, rwGradientsOut);
retGradient.gradientForVariable().put(biasWeightKey, bGradientsOut);
return new Pair<>(retGradient, epsilonNext);
}
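In the gate-delta code above, TimesOneMinus is the element-wise map x -> x*(1-x): for a sigmoid gate with activation a = sigmoid(z), that equals the derivative sigmoid'(z) = a*(1-a), so the backward pass can work from the stored activations (ao, af, ag) without touching the pre-activations again. The two-argument form, e.g. new TimesOneMinus(af, deltaf), writes the result into the second array, as the subsequent in-place muli calls rely on. Below is a minimal, self-contained sketch checking that identity with the same op; the class name and random test data are illustrative only and are not part of either project.
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.TimesOneMinus;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

public class TimesOneMinusSketch {
    public static void main(String[] args) {
        INDArray z = Nd4j.rand(2, 3).subi(0.5); //arbitrary pre-activations
        INDArray a = Transforms.sigmoid(z, true); //gate activations: a = sigmoid(z)

        //TimesOneMinus computes a * (1 - a) element-wise; run it on a copy so 'a' is preserved
        INDArray viaOp = Nd4j.getExecutioner().execAndReturn(new TimesOneMinus(a.dup()));

        //The same quantity computed directly: (1 - a) * a, i.e. the sigmoid derivative at z
        INDArray direct = a.rsub(1.0).muli(a);

        System.out.println(viaOp.equalsWithEps(direct, 1e-6)); //expect: true
    }
}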
Use of org.nd4j.linalg.api.ops.impl.transforms.TimesOneMinus in project nd4j by deeplearning4j.
From the class LossBinaryXENT, method computeGradient:
@Override
public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
if (labels.size(1) != preOutput.size(1)) {
throw new IllegalArgumentException("Labels array numColumns (size(1) = " + labels.size(1) + ") does not match output layer" + " number of outputs (nOut = " + preOutput.size(1) + ") ");
}
INDArray output = activationFn.getActivation(preOutput.dup(), true);
if (clipEps > 0.0) {
CustomOp op = DynamicCustomOp.builder("clipbyvalue").addInputs(output).callInplace(true).addFloatingPointArguments(clipEps, 1.0 - clipEps).build();
Nd4j.getExecutioner().exec(op);
}
INDArray numerator = output.sub(labels);
INDArray denominator = Nd4j.getExecutioner().execAndReturn(new TimesOneMinus(output)); // output * (1 - output)
INDArray dLda = numerator.divi(denominator);
if (mask != null && LossUtil.isPerOutputMasking(dLda, mask)) {
// For *most* activation functions: we don't actually need to mask dL/da in addition to masking dL/dz later
// but: some, like softmax, require both (due to dL/dz_i being a function of dL/da_j, for i != j)
// We could add a special case for softmax (activationFn instanceof ActivationSoftmax), which would buy us a tiny
// bit of performance, but that would be error prone
LossUtil.applyMask(dLda, mask);
}
// TODO activation functions with weights
INDArray grad = activationFn.backprop(preOutput, dLda).getFirst();
// Weighted loss function
if (weights != null) {
if (weights.length() != output.size(1)) {
throw new IllegalStateException("Weights vector (length " + weights.length() + ") does not match output.size(1)=" + output.size(1));
}
grad.muliRowVector(weights);
}
if (mask != null) {
LossUtil.applyMask(grad, mask);
}
return grad;
}
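Here the TimesOneMinus term is the denominator of dL/da = (output - labels) / (output * (1 - output)). When the output activation is sigmoid, backpropagating through the activation multiplies by exactly output * (1 - output), so the denominator cancels and the end-to-end gradient reduces to output - labels. The sketch below checks that simplification numerically; the class name and test data are illustrative only and are not part of nd4j.
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT;
import org.nd4j.linalg.ops.transforms.Transforms;

public class BinaryXentGradientSketch {
    public static void main(String[] args) {
        INDArray labels = Nd4j.create(new double[][] {{1, 0, 1}, {0, 1, 0}});
        INDArray preOutput = Nd4j.rand(2, 3).subi(0.5).muli(4.0); //arbitrary logits, away from the clipping range
        IActivation sigmoid = new ActivationSigmoid();

        INDArray grad = new LossBinaryXENT().computeGradient(labels, preOutput, sigmoid, null);

        //Expected simplification for sigmoid + binary cross-entropy: dL/dz = sigmoid(z) - labels
        INDArray expected = Transforms.sigmoid(preOutput, true).subi(labels);
        System.out.println(grad.equalsWithEps(expected, 1e-5)); //expect: true
    }
}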