Use of org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp in project deeplearning4j by deeplearning4j.
The class BatchNormalization, method preOutput:
public INDArray preOutput(INDArray x, TrainingMode training) {
INDArray activations;
// TODO add this directly in layer or get the layer prior...
// batchnorm true but need to clarify if activation before or after
org.deeplearning4j.nn.conf.layers.BatchNormalization layerConf = layerConf();
int[] shape = getShape(x);
// xHat = (x-xmean) / sqrt(var + epsilon)
//Note that for CNNs, mean and variance are calculated per feature map (i.e., jointly over examples and spatial locations) rather than per individual activation
//Pg5 of http://arxiv.org/pdf/1502.03167v3.pdf
// "For convolutional layers, we additionally want the normalization to obey the convolutional property – so that
// different elements of the same feature map, at different locations, are normalized in the same way. To achieve
// this, we jointly normalize all the activations in a minibatch, over all locations."
INDArray mean, var;
if (training == TrainingMode.TRAIN) {
switch(x.rank()) {
case 2:
// mean and variance over samples in batch
mean = x.mean(0);
var = x.var(false, 0);
break;
case 4:
// mean and variance over samples AND locations
mean = x.mean(0, 2, 3);
var = x.var(false, 0, 2, 3);
break;
default:
throw new IllegalStateException("Batch normalization on activations of rank " + x.rank() + " not supported");
}
var.addi(layerConf.getEps());
} else {
// Global mean and variance estimate - used after training
mean = getParam(BatchNormalizationParamInitializer.GLOBAL_MEAN);
var = getParam(BatchNormalizationParamInitializer.GLOBAL_VAR);
}
std = Transforms.sqrt(var, true);
INDArray gamma = null;
INDArray beta = null;
INDArray globalMeanView = getParam(BatchNormalizationParamInitializer.GLOBAL_MEAN);
INDArray globalVarView = getParam(BatchNormalizationParamInitializer.GLOBAL_VAR);
if (layerConf.isLockGammaBeta()) {
if (helper != null && input.rank() == 4) {
//TODO: don't create these each iteration, when using cudnn
int[] gammaBetaShape = new int[] { 1, layerConf().getNOut() };
gamma = Nd4j.valueArrayOf(gammaBetaShape, layerConf().getGamma());
beta = Nd4j.valueArrayOf(gammaBetaShape, layerConf().getBeta());
}
} else {
gamma = getParam(BatchNormalizationParamInitializer.GAMMA);
beta = getParam(BatchNormalizationParamInitializer.BETA);
}
if (helper != null && input.rank() == 4) {
//Note that cudnn does not support dense (2d) batch norm case as of v5.1
double decay = layerConf.getDecay();
INDArray ret = helper.preOutput(x, training == TrainingMode.TRAIN, shape, gamma, beta, globalMeanView, globalVarView, decay, layerConf.getEps());
if (ret != null) {
return ret;
}
}
// BN(x^k) = gamma^k * xHat + beta^k (applying gamma and beta for each activation)
if (x.rank() == 2) {
xMu = x.subRowVector(mean);
xHat = xMu.divRowVector(std);
if (layerConf.isLockGammaBeta()) {
//Special case: gamma/beta have fixed values for all outputs
//Use mul/addi(Number) here to avoid allocating temp arrays of all same value
double g = layerConf.getGamma();
double b = layerConf.getBeta();
if (g != 1.0 || b != 0.0) {
activations = xHat.mul(g).addi(b);
} else {
//Default and most common case: gamma = 1.0 and beta = 0.0. No point executing a 1 * x + 0 op
activations = xHat;
}
} else {
//Standard case: gamma and beta are learned per parameter
activations = xHat.mulRowVector(gamma).addiRowVector(beta);
}
} else if (x.rank() == 4) {
if (!Shape.strideDescendingCAscendingF(x))
//TODO: temp Workaround for broadcast bug. To be removed when fixed
x = x.dup();
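// Broadcast along dimension 1 (the channel axis of the NCHW activations): the per-feature-map
// mean/std vectors are applied to every example and every spatial location of that channel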
xMu = Nd4j.getExecutioner().execAndReturn(new BroadcastSubOp(x, mean, Nd4j.createUninitialized(x.shape(), x.ordering()), 1));
xHat = Nd4j.getExecutioner().execAndReturn(new BroadcastDivOp(xMu, std, Nd4j.createUninitialized(x.shape(), x.ordering()), 1));
if (layerConf.isLockGammaBeta()) {
//Special case: gamma/beta have fixed values for all outputs
//Use mul/addi(Number) here to avoid allocating temp arrays of all same value
double g = layerConf.getGamma();
double b = layerConf.getBeta();
if (g != 1.0 || b != 0.0) {
activations = xHat.mul(g).addi(b);
} else {
//Default and most common case: gamma = 1.0 and beta = 0.0. No point executing a 1 * x + 0 op
activations = xHat;
}
} else {
//Standard case: gamma and beta are learned per parameter
activations = Nd4j.getExecutioner().execAndReturn(new BroadcastMulOp(xHat, gamma, Nd4j.createUninitialized(x.shape(), x.ordering()), 1));
activations = Nd4j.getExecutioner().execAndReturn(new BroadcastAddOp(activations, beta, activations, 1));
}
} else {
// TODO setup BatchNorm for RNN http://arxiv.org/pdf/1510.01378v1.pdf
throw new IllegalStateException("The layer prior to BatchNorm in the configuration is not currently supported.");
}
// store mean and var if using batch mean while training
double decay;
if (training == TrainingMode.TRAIN) {
if (layerConf.isMinibatch()) {
//Standard case: Estimate global mean and variance stats by moving average
//globalMean = decay * globalMean + (1-decay) * minibatchMean
//globalVar = decay * globalVar + (1-decay) * minibatchVar
//Note that it's safe to do a muli on 'mean' and 'var' variables: they can't be the global arrays when training == TrainingMode.TRAIN
decay = layerConf.getDecay();
globalMeanView.muli(decay).addi(mean.muli(1 - decay));
globalVarView.muli(decay).addi(var.muli(1 - decay));
} else {
//Special case: doing full-batch (entire data set) training (uncommon; only tiny data sets)
//In this case, minibatch and global stats are identical. Don't want to use a moving average estimate.
globalMeanView.assign(mean);
globalVarView.assign(var);
}
}
return activations;
}
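To make the normalization above concrete, here is a minimal standalone sketch of the dense (rank-2) TRAIN branch, including the moving-average update of the global statistics. It is not part of the DL4J source; the method name and parameters are hypothetical, x is assumed to have shape [minibatch, nOut], and gamma/beta are assumed to be 1 x nOut row vectors.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.ops.transforms.Transforms;

// Hypothetical helper (not part of the DL4J source): dense (rank-2) batch norm forward pass
// plus the moving-average update of the global statistics, mirroring the TRAIN branch above.
public static INDArray denseBatchNormSketch(INDArray x, INDArray gamma, INDArray beta,
        INDArray globalMean, INDArray globalVar, double eps, double decay) {
    INDArray mean = x.mean(0);                                  // per-activation mean over the minibatch
    INDArray var = x.var(false, 0);                             // biased (population) variance, as above
    INDArray std = Transforms.sqrt(var.add(eps), true);         // sqrt(var + eps)
    INDArray xHat = x.subRowVector(mean).divRowVector(std);     // (x - mean) / sqrt(var + eps)
    INDArray out = xHat.mulRowVector(gamma).addRowVector(beta); // BN(x) = gamma * xHat + beta
    // globalMean = decay * globalMean + (1 - decay) * batchMean (and likewise for the variance)
    globalMean.muli(decay).addi(mean.mul(1 - decay));
    globalVar.muli(decay).addi(var.mul(1 - decay));
    return out;
}

The rank-4 branch in preOutput does the same thing, but uses the Broadcast*Op classes with dimension 1 so that the per-channel vectors are applied across the minibatch and both spatial dimensions.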
Use of org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp in project deeplearning4j by deeplearning4j.
The class BatchNormalizationTest, method testCnnForwardBackward:
@Test
public void testCnnForwardBackward() {
double eps = 1e-5;
int nIn = 4;
int hw = 3;
int minibatch = 2;
Nd4j.getRandom().setSeed(12345);
INDArray input = Nd4j.rand('c', new int[] { minibatch, nIn, hw, hw });
//TODO: other values for gamma/beta
INDArray gamma = Nd4j.ones(1, nIn);
INDArray beta = Nd4j.zeros(1, nIn);
Layer l = getLayer(nIn, eps, false, -1, -1);
INDArray mean = input.mean(0, 2, 3);
INDArray var = input.var(false, 0, 2, 3);
INDArray xHat = Nd4j.getExecutioner().execAndReturn(new BroadcastSubOp(input, mean, input.dup(), 1));
Nd4j.getExecutioner().execAndReturn(new BroadcastDivOp(xHat, Transforms.sqrt(var.add(eps), true), xHat, 1));
INDArray outExpected = Nd4j.getExecutioner().execAndReturn(new BroadcastMulOp(xHat, gamma, xHat.dup(), 1));
Nd4j.getExecutioner().execAndReturn(new BroadcastAddOp(outExpected, beta, outExpected, 1));
INDArray out = l.activate(input, true);
System.out.println(Arrays.toString(outExpected.data().asDouble()));
System.out.println(Arrays.toString(out.data().asDouble()));
assertEquals(outExpected, out);
//-------------------------------------------------------------
//Check backprop
//dL/dy
INDArray epsilon = Nd4j.rand('c', new int[] { minibatch, nIn, hw, hw });
int effectiveMinibatch = minibatch * hw * hw;
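// Per-feature-map statistics are computed over the minibatch and both spatial dimensions,
// so the effective number of samples per channel is minibatch * hw * hw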
INDArray dldgammaExp = epsilon.mul(xHat).sum(0, 2, 3);
INDArray dldbetaExp = epsilon.sum(0, 2, 3);
// dL/dxHat = dL/dy * gamma (rank-2 equivalent: epsilon.mulRowVector(gamma))
INDArray dldxhat = Nd4j.getExecutioner().execAndReturn(new BroadcastMulOp(epsilon, gamma, epsilon.dup(), 1));
INDArray inputSubMean = Nd4j.getExecutioner().execAndReturn(new BroadcastSubOp(input, mean, input.dup(), 1));
INDArray dldvar = dldxhat.mul(inputSubMean).mul(-0.5);
dldvar = Nd4j.getExecutioner().execAndReturn(new BroadcastMulOp(dldvar, Transforms.pow(var.add(eps), -3.0 / 2.0, true), dldvar.dup(), 1));
dldvar = dldvar.sum(0, 2, 3);
INDArray dldmu = Nd4j.getExecutioner().execAndReturn(new BroadcastMulOp(dldxhat, Transforms.pow(var.add(eps), -1.0 / 2.0, true), dldxhat.dup(), 1)).neg().sum(0, 2, 3);
dldmu = dldmu.add(dldvar.mul(inputSubMean.mul(-2.0).sum(0, 2, 3).div(effectiveMinibatch)));
INDArray dldinExp = Nd4j.getExecutioner().execAndReturn(new BroadcastMulOp(dldxhat, Transforms.pow(var.add(eps), -1.0 / 2.0, true), dldxhat.dup(), 1));
dldinExp = dldinExp.add(Nd4j.getExecutioner().execAndReturn(new BroadcastMulOp(inputSubMean.mul(2.0 / effectiveMinibatch), dldvar, inputSubMean.dup(), 1)));
dldinExp = Nd4j.getExecutioner().execAndReturn(new BroadcastAddOp(dldinExp, dldmu.mul(1.0 / effectiveMinibatch), dldinExp.dup(), 1));
Pair<Gradient, INDArray> p = l.backpropGradient(epsilon);
INDArray dldgamma = p.getFirst().getGradientFor("gamma");
INDArray dldbeta = p.getFirst().getGradientFor("beta");
assertEquals(dldgammaExp, dldgamma);
assertEquals(dldbetaExp, dldbeta);
// System.out.println("EPSILONS");
// System.out.println(Arrays.toString(dldinExp.data().asDouble()));
// System.out.println(Arrays.toString(p.getSecond().dup().data().asDouble()));
assertEquals(dldinExp, p.getSecond());
}
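For reference, the expected gradients assembled above are the standard batch normalization backprop equations (Ioffe & Szegedy, 2015), applied per feature map with m = minibatch * hw * hw and all sums taken over dimensions (0, 2, 3):

dL/dbeta  = sum(dL/dy)
dL/dgamma = sum(dL/dy * xHat)
dL/dxHat  = dL/dy * gamma
dL/dvar   = sum(dL/dxHat * (x - mean)) * -1/2 * (var + eps)^(-3/2)
dL/dmean  = -sum(dL/dxHat) / sqrt(var + eps) + dL/dvar * sum(-2 * (x - mean)) / m
dL/dx     = dL/dxHat / sqrt(var + eps) + dL/dvar * 2 * (x - mean) / m + dL/dmean / m

The assertions at the end check these hand-computed values against the gradients returned by the layer's backpropGradient call.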