Use of org.nd4j.linalg.activations.IActivation in project deeplearning4j by deeplearning4j.
The class GradientCheckUtil, method checkGradients.
/**
* Check backprop gradients for a MultiLayerNetwork.
* @param mln MultiLayerNetwork to test. This must be initialized.
* @param epsilon Usually on the order of 1e-4 or so.
* @param maxRelError Maximum relative error. Usually < 1e-5 or so, though possibly larger for deep networks or those with nonlinear activation functions
* @param minAbsoluteError Minimum absolute error to cause a failure. Numerical gradients can be non-zero due to precision issues.
* For example, 0.0 vs. 1e-18: relative error is 1.0, but not really a failure
* @param print Whether to print full pass/failure details for each parameter gradient
* @param exitOnFirstError If true: return upon first failure. If false: continue checking even if
* one parameter gradient has failed. Typically use false for debugging, true for unit tests.
* @param input Input array to use for forward pass. May be mini-batch data.
* @param labels Labels/targets to use to calculate backprop gradient. May be mini-batch data.
* @return true if all gradient checks pass, false otherwise.
*/
public static boolean checkGradients(MultiLayerNetwork mln, double epsilon, double maxRelError, double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray input, INDArray labels) {
//Basic sanity checks on input:
if (epsilon <= 0.0 || epsilon > 0.1)
throw new IllegalArgumentException("Invalid epsilon: expect epsilon in range (0,0.1], usually 1e-4 or so");
if (maxRelError <= 0.0 || maxRelError > 0.25)
throw new IllegalArgumentException("Invalid maxRelativeError: " + maxRelError);
if (!(mln.getOutputLayer() instanceof IOutputLayer))
throw new IllegalArgumentException("Cannot check backprop gradients without OutputLayer");
//Check network configuration:
int layerCount = 0;
for (NeuralNetConfiguration n : mln.getLayerWiseConfigurations().getConfs()) {
org.deeplearning4j.nn.conf.Updater u = n.getLayer().getUpdater();
if (u == org.deeplearning4j.nn.conf.Updater.SGD) {
//Must have LR of 1.0
double lr = n.getLayer().getLearningRate();
if (lr != 1.0) {
throw new IllegalStateException("When using SGD updater, must also use lr=1.0 for layer " + layerCount + "; got " + u + " with lr=" + lr + " for layer \"" + n.getLayer().getLayerName() + "\"");
}
} else if (u != org.deeplearning4j.nn.conf.Updater.NONE) {
throw new IllegalStateException("Must have Updater.NONE (or SGD + lr=1.0) for layer " + layerCount + "; got " + u);
}
double dropout = n.getLayer().getDropOut();
if (n.isUseRegularization() && dropout != 0.0) {
throw new IllegalStateException("Must have dropout == 0.0 for gradient checks - got dropout = " + dropout + " for layer " + layerCount);
}
IActivation activation = n.getLayer().getActivationFn();
if (activation != null) {
if (!VALID_ACTIVATION_FUNCTIONS.contains(activation.getClass())) {
log.warn("Layer " + layerCount + " is possibly using an unsuitable activation function: " + activation.getClass() + ". Activation functions for gradient checks must be smooth (like sigmoid, tanh, softmax) and not " + "contain discontinuities like ReLU or LeakyReLU (these may cause spurious failures)");
}
}
}
mln.setInput(input);
mln.setLabels(labels);
mln.computeGradientAndScore();
Pair<Gradient, Double> gradAndScore = mln.gradientAndScore();
Updater updater = UpdaterCreator.getUpdater(mln);
updater.update(mln, gradAndScore.getFirst(), 0, mln.batchSize());
//need dup: gradients are a *view* of the full gradient array (which will change every time backprop is done)
INDArray gradientToCheck = gradAndScore.getFirst().gradient().dup();
//need dup: params are a *view* of full parameters
INDArray originalParams = mln.params().dup();
int nParams = originalParams.length();
Map<String, INDArray> paramTable = mln.paramTable();
List<String> paramNames = new ArrayList<>(paramTable.keySet());
int[] paramEnds = new int[paramNames.size()];
paramEnds[0] = paramTable.get(paramNames.get(0)).length();
for (int i = 1; i < paramEnds.length; i++) {
paramEnds[i] = paramEnds[i - 1] + paramTable.get(paramNames.get(i)).length();
}
int totalNFailures = 0;
double maxError = 0.0;
DataSet ds = new DataSet(input, labels);
int currParamNameIdx = 0;
//Assumption here: params is a view that we can modify in-place
INDArray params = mln.params();
for (int i = 0; i < nParams; i++) {
//Get param name
if (i >= paramEnds[currParamNameIdx]) {
currParamNameIdx++;
}
String paramName = paramNames.get(currParamNameIdx);
//(w+epsilon): Do forward pass and score
double origValue = params.getDouble(i);
params.putScalar(i, origValue + epsilon);
double scorePlus = mln.score(ds, true);
//(w-epsilon): Do forward pass and score
params.putScalar(i, origValue - epsilon);
double scoreMinus = mln.score(ds, true);
//Reset original param value
params.putScalar(i, origValue);
//Calculate numerical parameter gradient:
double scoreDelta = scorePlus - scoreMinus;
double numericalGradient = scoreDelta / (2 * epsilon);
if (Double.isNaN(numericalGradient))
throw new IllegalStateException("Numerical gradient was NaN for parameter " + i + " of " + nParams);
double backpropGradient = gradientToCheck.getDouble(i);
//http://cs231n.github.io/neural-networks-3/#gradcheck
//use mean centered
double relError = Math.abs(backpropGradient - numericalGradient) / (Math.abs(numericalGradient) + Math.abs(backpropGradient));
if (backpropGradient == 0.0 && numericalGradient == 0.0)
//Edge case: i.e., RNNs with time series length of 1
relError = 0.0;
if (relError > maxError)
maxError = relError;
if (relError > maxRelError || Double.isNaN(relError)) {
double absError = Math.abs(backpropGradient - numericalGradient);
if (absError < minAbsoluteError) {
log.info("Param " + i + " (" + paramName + ") passed: grad= " + backpropGradient + ", numericalGrad= " + numericalGradient + ", relError= " + relError + "; absolute error = " + absError + " < minAbsoluteError = " + minAbsoluteError);
} else {
if (print)
log.info("Param " + i + " (" + paramName + ") FAILED: grad= " + backpropGradient + ", numericalGrad= " + numericalGradient + ", relError= " + relError + ", scorePlus=" + scorePlus + ", scoreMinus= " + scoreMinus);
if (exitOnFirstError)
return false;
totalNFailures++;
}
} else if (print) {
log.info("Param " + i + " (" + paramName + ") passed: grad= " + backpropGradient + ", numericalGrad= " + numericalGradient + ", relError= " + relError);
}
}
if (print) {
int nPass = nParams - totalNFailures;
log.info("GradientCheckUtil.checkGradients(): " + nParams + " params checked, " + nPass + " passed, " + totalNFailures + " failed. Largest relative error = " + maxError);
}
return totalNFailures == 0;
}
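As a usage illustration, the following is a minimal sketch of invoking checkGradients on a tiny network that satisfies the checks above (Updater.NONE, no dropout, smooth activations). It assumes a 0.7.x-era DL4J builder API (NeuralNetConfiguration.Builder, DenseLayer, OutputLayer, Nd4j); the builder method names, shapes, and tolerance values are illustrative rather than taken from the snippet, and gradient checks are normally run with ND4J in double precision.
//Sketch only: small network with Updater.NONE, no regularization/dropout, tanh + softmax activations
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .updater(org.deeplearning4j.nn.conf.Updater.NONE)
        .regularization(false)
        .list()
        .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).activation(Activation.TANH).build())
        .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                .nIn(3).nOut(2).activation(Activation.SOFTMAX).build())
        .build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
//Small mini-batch of random features and one-hot labels (shapes are illustrative)
INDArray features = Nd4j.rand(5, 4);
INDArray labels = Nd4j.zeros(5, 2);
for (int i = 0; i < 5; i++) labels.putScalar(new int[] {i, i % 2}, 1.0);
//epsilon, maxRelError, minAbsoluteError chosen as typical values per the javadoc above
boolean gradOK = GradientCheckUtil.checkGradients(net, 1e-6, 1e-3, 1e-8, true, false, features, labels);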
Use of org.nd4j.linalg.activations.IActivation in project deeplearning4j by deeplearning4j.
The class MultiLayerConfiguration, method fromJson.
/**
* Create a neural net configuration from json
* @param json the JSON string representing the neural net configuration
* @return {@link MultiLayerConfiguration}
*/
public static MultiLayerConfiguration fromJson(String json) {
MultiLayerConfiguration conf;
ObjectMapper mapper = NeuralNetConfiguration.mapper();
try {
conf = mapper.readValue(json, MultiLayerConfiguration.class);
} catch (IOException e) {
throw new RuntimeException(e);
}
//To maintain backward compatibility after loss function refactoring (configs generated with v0.5.0 or earlier)
// Previously: enumeration used for loss functions. Now: use classes
// In the past, only an OutputLayer or RnnOutputLayer could have used these enums
int layerCount = 0;
JsonNode confs = null;
for (NeuralNetConfiguration nnc : conf.getConfs()) {
Layer l = nnc.getLayer();
if (l instanceof BaseOutputLayer && ((BaseOutputLayer) l).getLossFn() == null) {
//lossFn field null -> may be an old config format, with lossFunction field being for the enum
//if so, try walking the JSON graph to extract out the appropriate enum value
BaseOutputLayer ol = (BaseOutputLayer) l;
try {
JsonNode jsonNode = mapper.readTree(json);
if (confs == null) {
confs = jsonNode.get("confs");
}
if (confs instanceof ArrayNode) {
ArrayNode layerConfs = (ArrayNode) confs;
JsonNode outputLayerNNCNode = layerConfs.get(layerCount);
if (outputLayerNNCNode == null)
//Should never happen...
return conf;
JsonNode outputLayerNode = outputLayerNNCNode.get("layer");
JsonNode lossFunctionNode = null;
if (outputLayerNode.has("output")) {
lossFunctionNode = outputLayerNode.get("output").get("lossFunction");
} else if (outputLayerNode.has("rnnoutput")) {
lossFunctionNode = outputLayerNode.get("rnnoutput").get("lossFunction");
}
if (lossFunctionNode != null) {
String lossFunctionEnumStr = lossFunctionNode.asText();
LossFunctions.LossFunction lossFunction = null;
try {
lossFunction = LossFunctions.LossFunction.valueOf(lossFunctionEnumStr);
} catch (Exception e) {
log.warn("OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON", e);
}
if (lossFunction != null) {
switch(lossFunction) {
case MSE:
ol.setLossFn(new LossMSE());
break;
case XENT:
ol.setLossFn(new LossBinaryXENT());
break;
case NEGATIVELOGLIKELIHOOD:
ol.setLossFn(new LossNegativeLogLikelihood());
break;
case MCXENT:
ol.setLossFn(new LossMCXENT());
break;
//Remaining: TODO
case EXPLL:
case RMSE_XENT:
case SQUARED_LOSS:
case RECONSTRUCTION_CROSSENTROPY:
case CUSTOM:
default:
log.warn("OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not set loss function for {}", lossFunction);
break;
}
}
}
} else {
log.warn("OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON: layer 'confs' field is not an ArrayNode (is: {})", (confs != null ? confs.getClass() : null));
}
} catch (IOException e) {
log.warn("OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON", e);
break;
}
}
//Try to load the old format if necessary, and create the appropriate IActivation instance
if (l.getActivationFn() == null) {
try {
JsonNode jsonNode = mapper.readTree(json);
if (confs == null) {
confs = jsonNode.get("confs");
}
if (confs instanceof ArrayNode) {
ArrayNode layerConfs = (ArrayNode) confs;
JsonNode outputLayerNNCNode = layerConfs.get(layerCount);
if (outputLayerNNCNode == null)
//Should never happen...
return conf;
JsonNode layerWrapperNode = outputLayerNNCNode.get("layer");
if (layerWrapperNode == null || layerWrapperNode.size() != 1) {
continue;
}
JsonNode layerNode = layerWrapperNode.elements().next();
//Should only have 1 element: "dense", "output", etc
JsonNode activationFunction = layerNode.get("activationFunction");
if (activationFunction != null) {
IActivation ia = Activation.fromString(activationFunction.asText()).getActivationFunction();
l.setActivationFn(ia);
}
}
} catch (IOException e) {
log.warn("Layer with null ActivationFn field or pre-0.7.2 activation function detected: could not parse JSON", e);
}
}
layerCount++;
}
return conf;
}
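A brief round-trip sketch: for configurations produced by the current version, the legacy branches above are skipped and this is a plain deserialization; they only come into play when the JSON still contains the old lossFunction / activationFunction string fields. The builder calls and toJson() are assumed from the same DL4J API as above and are illustrative.
//Sketch: serialize a configuration and restore it with fromJson
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .list()
        .layer(0, new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                .nIn(3).nOut(1).activation(Activation.IDENTITY).build())
        .build();
String json = conf.toJson();
MultiLayerConfiguration restored = MultiLayerConfiguration.fromJson(json);
//For same-version configurations, restored is expected to equal conf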
Use of org.nd4j.linalg.activations.IActivation in project deeplearning4j by deeplearning4j.
The class VariationalAutoencoder, method reconstructionLogProbability.
/**
* Return the log reconstruction probability given the specified number of samples.<br>
* See {@link #reconstructionLogProbability(INDArray, int)} for more details
*
* @param data The data to calculate the log reconstruction probability
* @param numSamples Number of samples on which to base the reconstruction probability.
* @return Column vector of reconstruction log probabilities for each example (shape: [numExamples,1])
*/
public INDArray reconstructionLogProbability(INDArray data, int numSamples) {
if (numSamples <= 0) {
throw new IllegalArgumentException("Invalid input: numSamples must be > 0. Got: " + numSamples);
}
if (reconstructionDistribution instanceof LossFunctionWrapper) {
throw new UnsupportedOperationException("Cannot calculate reconstruction log probability when using " + "a LossFunction (via LossFunctionWrapper) instead of a ReconstructionDistribution: ILossFunction " + "instances are not in general probabilistic, hence it is not possible to calculate reconstruction probability");
}
//Forward pass through the encoder and mean for P(Z|X)
setInput(data);
VAEFwdHelper fwd = doForward(true, true);
IActivation afn = conf().getLayer().getActivationFn();
//Forward pass through logStd^2 for P(Z|X)
INDArray pzxLogStd2W = params.get(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_W);
INDArray pzxLogStd2b = params.get(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_B);
INDArray meanZ = fwd.pzxMeanPreOut;
INDArray logStdev2Z = fwd.encoderActivations[fwd.encoderActivations.length - 1].mmul(pzxLogStd2W).addiRowVector(pzxLogStd2b);
pzxActivationFn.getActivation(meanZ, false);
pzxActivationFn.getActivation(logStdev2Z, false);
INDArray pzxSigma = Transforms.exp(logStdev2Z, false);
Transforms.sqrt(pzxSigma, false);
int minibatch = input.size(0);
int size = fwd.pzxMeanPreOut.size(1);
INDArray pxzw = params.get(VariationalAutoencoderParamInitializer.PXZ_W);
INDArray pxzb = params.get(VariationalAutoencoderParamInitializer.PXZ_B);
INDArray[] decoderWeights = new INDArray[decoderLayerSizes.length];
INDArray[] decoderBiases = new INDArray[decoderLayerSizes.length];
for (int i = 0; i < decoderLayerSizes.length; i++) {
String wKey = "d" + i + WEIGHT_KEY_SUFFIX;
String bKey = "d" + i + BIAS_KEY_SUFFIX;
decoderWeights[i] = params.get(wKey);
decoderBiases[i] = params.get(bKey);
}
INDArray sumReconstructionNegLogProbability = null;
for (int i = 0; i < numSamples; i++) {
INDArray e = Nd4j.randn(minibatch, size);
//z = mu + sigma * e, with e ~ N(0,1)
INDArray z = e.muli(pzxSigma).addi(meanZ);
//Do forward pass through decoder
int nDecoderLayers = decoderLayerSizes.length;
INDArray currentActivations = z;
for (int j = 0; j < nDecoderLayers; j++) {
currentActivations = currentActivations.mmul(decoderWeights[j]).addiRowVector(decoderBiases[j]);
afn.getActivation(currentActivations, false);
}
//And calculate reconstruction distribution preOut
INDArray pxzDistributionPreOut = currentActivations.mmul(pxzw).addiRowVector(pxzb);
if (i == 0) {
sumReconstructionNegLogProbability = reconstructionDistribution.exampleNegLogProbability(data, pxzDistributionPreOut);
} else {
sumReconstructionNegLogProbability.addi(reconstructionDistribution.exampleNegLogProbability(data, pxzDistributionPreOut));
}
}
setInput(null);
return sumReconstructionNegLogProbability.divi(-numSamples);
}
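A hedged usage sketch follows, assuming a MultiLayerNetwork net whose layer 0 is a variational autoencoder; the cast target is the layer implementation class in org.deeplearning4j.nn.layers.variational, and the feature shape and sample count are illustrative.
//Sketch: estimate per-example reconstruction log probability for a trained VAE layer
org.deeplearning4j.nn.layers.variational.VariationalAutoencoder vae =
        (org.deeplearning4j.nn.layers.variational.VariationalAutoencoder) net.getLayer(0);
INDArray features = Nd4j.rand(10, 784);    //illustrative shape: columns must match the VAE's nIn
int numSamples = 16;                       //more samples -> lower-variance Monte Carlo estimate
INDArray logProb = vae.reconstructionLogProbability(features, numSamples);  //shape [10, 1]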
Use of org.nd4j.linalg.activations.IActivation in project deeplearning4j by deeplearning4j.
The class VariationalAutoencoder, method decodeGivenLatentSpaceValues.
private INDArray decodeGivenLatentSpaceValues(INDArray latentSpaceValues) {
if (latentSpaceValues.size(1) != params.get(VariationalAutoencoderParamInitializer.PZX_MEAN_W).size(1)) {
throw new IllegalArgumentException("Invalid latent space values: expected size " + params.get(VariationalAutoencoderParamInitializer.PZX_MEAN_W).size(1) + ", got size (dimension 1) = " + latentSpaceValues.size(1));
}
//Do forward pass through decoder
int nDecoderLayers = decoderLayerSizes.length;
INDArray currentActivations = latentSpaceValues;
IActivation afn = conf().getLayer().getActivationFn();
for (int i = 0; i < nDecoderLayers; i++) {
String wKey = "d" + i + WEIGHT_KEY_SUFFIX;
String bKey = "d" + i + BIAS_KEY_SUFFIX;
INDArray w = params.get(wKey);
INDArray b = params.get(bKey);
currentActivations = currentActivations.mmul(w).addiRowVector(b);
afn.getActivation(currentActivations, false);
}
INDArray pxzw = params.get(VariationalAutoencoderParamInitializer.PXZ_W);
INDArray pxzb = params.get(VariationalAutoencoderParamInitializer.PXZ_B);
return currentActivations.mmul(pxzw).addiRowVector(pxzb);
}
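Since decodeGivenLatentSpaceValues is private, generation from latent-space values is normally reached through the public wrappers on the same layer (generateAtMeanGivenZ / generateRandomGivenZ in this version), which decode as above and then apply the reconstruction distribution. A hedged sketch, continuing the vae reference from the previous example; the latent size is an assumption and must match the configured latent space.
//Sketch: generate from latent-space values via the public wrappers around the decoder above
int latentSize = 32;                              //must equal the latent space size (column count of PZX_MEAN_W)
INDArray z = Nd4j.randn(4, latentSize);           //4 latent points drawn from N(0, I)
INDArray atMean = vae.generateAtMeanGivenZ(z);    //decode, then take the reconstruction distribution mean
INDArray sampled = vae.generateRandomGivenZ(z);   //decode, then sample from p(x|z)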
Use of org.nd4j.linalg.activations.IActivation in project deeplearning4j by deeplearning4j.
The class LSTMHelpers, method backpropGradientHelper.
public static Pair<Gradient, INDArray> backpropGradientHelper(final NeuralNetConfiguration conf, final IActivation gateActivationFn, final INDArray input,
//Shape: [hiddenLayerSize,4*hiddenLayerSize+3]; order: [wI,wF,wO,wG,wFF,wOO,wGG]
final INDArray recurrentWeights,
//Shape: [n^(L-1),4*hiddenLayerSize]; order: [wi,wf,wo,wg]
final INDArray inputWeights, final INDArray epsilon, final boolean truncatedBPTT, final int tbpttBackwardLength, final FwdPassReturn fwdPass, final boolean forwards, final String inputWeightKey, final String recurrentWeightKey, final String biasWeightKey, final Map<String, INDArray> gradientViews,
//Input mask: should only be used with bidirectional RNNs + variable length
INDArray maskArray) {
//Expect errors to have shape: [miniBatchSize,n^(L+1),timeSeriesLength]
//i.e., n^L
int hiddenLayerSize = recurrentWeights.size(0);
//n^(L-1)
int prevLayerSize = inputWeights.size(0);
int miniBatchSize = epsilon.size(0);
//Edge case: T=1 may have shape [miniBatchSize,n^(L+1)], equiv. to [miniBatchSize,n^(L+1),1]
boolean is2dInput = epsilon.rank() < 3;
int timeSeriesLength = (is2dInput ? 1 : epsilon.size(2));
INDArray wFFTranspose = recurrentWeights.get(NDArrayIndex.all(), point(4 * hiddenLayerSize)).transpose();
INDArray wOOTranspose = recurrentWeights.get(NDArrayIndex.all(), point(4 * hiddenLayerSize + 1)).transpose();
INDArray wGGTranspose = recurrentWeights.get(NDArrayIndex.all(), point(4 * hiddenLayerSize + 2)).transpose();
INDArray wIFOG = recurrentWeights.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 4 * hiddenLayerSize));
//F order here so that the content for each time step is stored together
//i.e., what would be W^L*(delta^L)^T. Shape: [m,n^(L-1),T]
INDArray epsilonNext = Nd4j.create(new int[] { miniBatchSize, prevLayerSize, timeSeriesLength }, 'f');
INDArray nablaCellStateNext = null;
INDArray deltaifogNext = Nd4j.create(new int[] { miniBatchSize, 4 * hiddenLayerSize }, 'f');
INDArray deltaiNext = deltaifogNext.get(NDArrayIndex.all(), NDArrayIndex.interval(0, hiddenLayerSize));
INDArray deltafNext = deltaifogNext.get(NDArrayIndex.all(), NDArrayIndex.interval(hiddenLayerSize, 2 * hiddenLayerSize));
INDArray deltaoNext = deltaifogNext.get(NDArrayIndex.all(), NDArrayIndex.interval(2 * hiddenLayerSize, 3 * hiddenLayerSize));
INDArray deltagNext = deltaifogNext.get(NDArrayIndex.all(), NDArrayIndex.interval(3 * hiddenLayerSize, 4 * hiddenLayerSize));
Level1 l1BLAS = Nd4j.getBlasWrapper().level1();
int endIdx = 0;
if (truncatedBPTT) {
endIdx = Math.max(0, timeSeriesLength - tbpttBackwardLength);
}
//Get gradients. Note that we have to manually zero these, as they might not be initialized (or may still contain data from the last iteration)
//Also note that they are in f order (as per param initializer) so can be used in gemm etc
INDArray iwGradientsOut = gradientViews.get(inputWeightKey);
//Order: {I,F,O,G,FF,OO,GG}
INDArray rwGradientsOut = gradientViews.get(recurrentWeightKey);
INDArray bGradientsOut = gradientViews.get(biasWeightKey);
iwGradientsOut.assign(0);
rwGradientsOut.assign(0);
bGradientsOut.assign(0);
INDArray rwGradientsIFOG = rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 4 * hiddenLayerSize));
INDArray rwGradientsFF = rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.point(4 * hiddenLayerSize));
INDArray rwGradientsOO = rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.point(4 * hiddenLayerSize + 1));
INDArray rwGradientsGG = rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.point(4 * hiddenLayerSize + 2));
boolean sigmoidGates = gateActivationFn instanceof ActivationSigmoid;
IActivation afn = conf.getLayer().getActivationFn();
INDArray timeStepMaskColumn = null;
for (int iTimeIndex = timeSeriesLength - 1; iTimeIndex >= endIdx; iTimeIndex--) {
int time = iTimeIndex;
int inext = 1;
if (!forwards) {
time = timeSeriesLength - iTimeIndex - 1;
inext = -1;
}
//First: calculate the components of nablaCellState that rely on the next time step's deltas, so we can overwrite the deltas
INDArray nablaCellState;
if (iTimeIndex != timeSeriesLength - 1) {
nablaCellState = deltafNext.dup('f').muliRowVector(wFFTranspose);
l1BLAS.axpy(nablaCellState.length(), 1.0, deltagNext.dup('f').muliRowVector(wGGTranspose), nablaCellState);
} else {
nablaCellState = Nd4j.create(new int[] { miniBatchSize, hiddenLayerSize }, 'f');
}
INDArray prevMemCellState = (iTimeIndex == 0 ? null : fwdPass.memCellState[time - inext]);
INDArray prevHiddenUnitActivation = (iTimeIndex == 0 ? null : fwdPass.fwdPassOutputAsArrays[time - inext]);
INDArray currMemCellState = fwdPass.memCellState[time];
//LSTM unit output errors (dL/d(a_out)); not to be confused with \delta=dL/d(z_out)
//(w^{L+1}*(delta^{(L+1)t})^T)^T or equiv.
INDArray epsilonSlice = (is2dInput ? epsilon : epsilon.tensorAlongDimension(time, 1, 0));
//Shape: [m,n^L]
INDArray nablaOut = Shape.toOffsetZeroCopy(epsilonSlice, 'f');
if (iTimeIndex != timeSeriesLength - 1) {
//if t == timeSeriesLength-1 then deltaiNext etc are zeros
Nd4j.gemm(deltaifogNext, wIFOG, nablaOut, false, true, 1.0, 1.0);
}
//Output gate deltas:
INDArray sigmahOfS = fwdPass.memCellActivations[time];
INDArray ao = fwdPass.oa[time];
//Normally would use zo.dup() in above line, but won't be using zo again (for this time step). Ditto for zf, zg, zi
INDArray deltao = deltaoNext;
Nd4j.getExecutioner().exec(new MulOp(nablaOut, sigmahOfS, deltao));
if (sigmoidGates) {
//Equivalent to sigmoid deriv on zo
INDArray sigmaoPrimeOfZo = Nd4j.getExecutioner().execAndReturn(new TimesOneMinus(ao.dup('f')));
deltao.muli(sigmaoPrimeOfZo);
} else {
//Deltao needs to be modified in-place
deltao.assign(gateActivationFn.backprop(fwdPass.oz[time], deltao).getFirst());
//TODO: optimize (no assign)
}
//Memory cell error:
//TODO activation functions with params
INDArray temp = afn.backprop(currMemCellState.dup('f'), ao.muli(nablaOut)).getFirst();
l1BLAS.axpy(nablaCellState.length(), 1.0, temp, nablaCellState);
INDArray deltaMulRowWOO = deltao.dup('f').muliRowVector(wOOTranspose);
//nablaCellState.addi(deltao.mulRowVector(wOOTranspose));
l1BLAS.axpy(nablaCellState.length(), 1.0, deltaMulRowWOO, nablaCellState);
if (iTimeIndex != timeSeriesLength - 1) {
INDArray nextForgetGateAs = fwdPass.fa[time + inext];
int length = nablaCellState.length();
//nablaCellState.addi(nextForgetGateAs.mul(nablaCellStateNext))
l1BLAS.axpy(length, 1.0, nextForgetGateAs.muli(nablaCellStateNext), nablaCellState);
}
//Store for use in next iteration
nablaCellStateNext = nablaCellState;
//Forget gate delta:
INDArray af = fwdPass.fa[time];
INDArray deltaf = null;
if (iTimeIndex > 0) {
deltaf = deltafNext;
if (sigmoidGates) {
Nd4j.getExecutioner().exec(new TimesOneMinus(af, deltaf));
deltaf.muli(nablaCellState);
deltaf.muli(prevMemCellState);
} else {
INDArray temp2 = nablaCellState.mul(prevMemCellState);
//deltaf needs to be modified in-place
deltaf.assign(gateActivationFn.backprop(fwdPass.fz[time].dup('f'), temp2).getFirst());
//TODO activation functions with params
}
}
//Shape: [m,n^L]
//Input modulation gate delta:
INDArray ag = fwdPass.ga[time];
INDArray ai = fwdPass.ia[time];
INDArray deltag = deltagNext;
if (sigmoidGates) {
//Equivalent to sigmoid deriv on zg
Nd4j.getExecutioner().exec(new TimesOneMinus(ag, deltag));
deltag.muli(ai);
deltag.muli(nablaCellState);
} else {
INDArray temp2 = Nd4j.getExecutioner().execAndReturn(new MulOp(ai, nablaCellState, Nd4j.createUninitialized(ai.shape(), 'f')));
deltag.assign(gateActivationFn.backprop(fwdPass.gz[time], temp2).getFirst());
//TODO activation functions with params; optimize (no assign)
}
//Shape: [m,n^L]
//Network input delta:
INDArray zi = fwdPass.iz[time];
INDArray deltai = deltaiNext;
temp = Nd4j.getExecutioner().execAndReturn(new MulOp(ag, nablaCellState, Nd4j.createUninitialized(deltai.shape(), 'f')));
deltai.assign(afn.backprop(zi, temp).getFirst());
//Handle masking
if (maskArray != null) {
//Mask array is present: bidirectional RNN -> need to zero out these errors to avoid using errors from a masked time step
// to calculate the parameter gradients. Mask array has shape [minibatch, timeSeriesLength] -> get column(this time step)
timeStepMaskColumn = maskArray.getColumn(time);
deltaifogNext.muliColumnVector(timeStepMaskColumn);
//Later, the deltaifogNext is used to calculate: input weight gradients, recurrent weight gradients, bias gradients
}
INDArray prevLayerActivationSlice = Shape.toMmulCompatible(is2dInput ? input : input.tensorAlongDimension(time, 1, 0));
if (iTimeIndex > 0) {
//Again, deltaifog_current == deltaifogNext at this point... same array
Nd4j.gemm(prevLayerActivationSlice, deltaifogNext, iwGradientsOut, true, false, 1.0, 1.0);
} else {
INDArray iwGradients_i = iwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.interval(0, hiddenLayerSize));
Nd4j.gemm(prevLayerActivationSlice, deltai, iwGradients_i, true, false, 1.0, 1.0);
INDArray iwGradients_og = iwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize));
INDArray deltaog = deltaifogNext.get(NDArrayIndex.all(), NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize));
Nd4j.gemm(prevLayerActivationSlice, deltaog, iwGradients_og, true, false, 1.0, 1.0);
}
if (iTimeIndex > 0) {
//If t==0, then prevHiddenUnitActivation==zeros(n^L,n^L), so dL/dW for recurrent weights will end up as 0 anyway
//At this point: deltaifog and deltaifogNext are the same thing...
//So what we are actually doing here is sum of (prevAct^transpose * deltaifog_current)
Nd4j.gemm(prevHiddenUnitActivation, deltaifogNext, rwGradientsIFOG, true, false, 1.0, 1.0);
//Shape: [1,n^L]. sum(0) is sum over examples in mini-batch.
//Can use axpy here because result of sum and rwGradients[4 to 6] have order Nd4j.order(), via Nd4j.create()
//mul not mmul because these weights are from unit j->j only (whereas other recurrent weights are i->j for all i,j)
INDArray dLdwFF = deltaf.dup('f').muli(prevMemCellState).sum(0);
//rwGradients[4].addi(dLdwFF); //dL/dw_{FF}
l1BLAS.axpy(hiddenLayerSize, 1.0, dLdwFF, rwGradientsFF);
INDArray dLdwGG = deltag.dup('f').muli(prevMemCellState).sum(0);
//rwGradients[6].addi(dLdwGG);
l1BLAS.axpy(hiddenLayerSize, 1.0, dLdwGG, rwGradientsGG);
}
//Expected shape: [n^L,1]. sum(0) is sum over examples in mini-batch.
INDArray dLdwOO = deltao.dup('f').muli(currMemCellState).sum(0);
//rwGradients[5].addi(dLdwOO); //dL/dw_{OOxy}
l1BLAS.axpy(hiddenLayerSize, 1.0, dLdwOO, rwGradientsOO);
if (iTimeIndex > 0) {
l1BLAS.axpy(4 * hiddenLayerSize, 1.0, deltaifogNext.sum(0), bGradientsOut);
} else {
//Sneaky way to do bGradients_i += deltai.sum(0)
l1BLAS.axpy(hiddenLayerSize, 1.0, deltai.sum(0), bGradientsOut);
INDArray ogBiasToAdd = deltaifogNext.get(NDArrayIndex.all(), NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize)).sum(0);
INDArray ogBiasGrad = bGradientsOut.get(NDArrayIndex.point(0), NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize));
l1BLAS.axpy(2 * hiddenLayerSize, 1.0, ogBiasToAdd, ogBiasGrad);
}
//Calculate epsilonNext - i.e., equiv. to what would be (w^L*(d^(Lt))^T)^T in a normal network
//But here, need to add 4 weights * deltas for the IFOG gates
//This slice: f order and contiguous, due to epsilonNext being defined as f order.
INDArray epsilonNextSlice = epsilonNext.tensorAlongDimension(time, 1, 0);
if (iTimeIndex > 0) {
Nd4j.gemm(deltaifogNext, inputWeights, epsilonNextSlice, false, true, 1.0, 1.0);
} else {
//No contribution from forget gate at t=0
INDArray wi = inputWeights.get(NDArrayIndex.all(), NDArrayIndex.interval(0, hiddenLayerSize));
Nd4j.gemm(deltai, wi, epsilonNextSlice, false, true, 1.0, 1.0);
INDArray deltaog = deltaifogNext.get(NDArrayIndex.all(), NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize));
INDArray wog = inputWeights.get(NDArrayIndex.all(), NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize));
//epsilonNextSlice.addi(deltao.mmul(woTranspose)).addi(deltag.mmul(wgTranspose));
Nd4j.gemm(deltaog, wog, epsilonNextSlice, false, true, 1.0, 1.0);
}
if (maskArray != null) {
//Mask array is present: bidirectional RNN -> need to zero out these errors to avoid sending anything
// but 0s to the layer below at this time step (for the given example)
epsilonNextSlice.muliColumnVector(timeStepMaskColumn);
}
}
Gradient retGradient = new DefaultGradient();
retGradient.gradientForVariable().put(inputWeightKey, iwGradientsOut);
retGradient.gradientForVariable().put(recurrentWeightKey, rwGradientsOut);
retGradient.gradientForVariable().put(biasWeightKey, bGradientsOut);
return new Pair<>(retGradient, epsilonNext);
}
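To make the sigmoid-gate branch above concrete, here is a scalar sketch of the output-gate delta for a single unit at a single time step; in the method this is computed array-wise with MulOp and TimesOneMinus, using the identity sigmoid'(z_o) = a_o * (1 - a_o). The numeric values are arbitrary, and tanh is assumed as the layer (cell) activation.
//Scalar sketch (arbitrary values) of: deltao = nablaOut .* sigmah(c_t) .* sigma'(z_o)
double nablaOut = 0.3;                     //dL/d(a_out) for this unit
double cellState = 0.8;                    //c_t
double sigmahOfS = Math.tanh(cellState);   //cell activation, as in fwdPass.memCellActivations (tanh assumed)
double ao = 0.6;                           //output gate activation a_o = sigmoid(z_o)
double deltao = nablaOut * sigmahOfS * ao * (1 - ao);
System.out.println("deltao = " + deltao);  //matches the MulOp + TimesOneMinus path above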