
Example 1 with BroadcastSubOp

Use of org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp in project deeplearning4j by deeplearning4j.

From the class BatchNormalization, method preOutput:

public INDArray preOutput(INDArray x, TrainingMode training) {
    INDArray activations;
    // TODO add this directly in layer or get the layer prior...
    // batchnorm true but need to clarify if activation before or after
    org.deeplearning4j.nn.conf.layers.BatchNormalization layerConf = layerConf();
    int[] shape = getShape(x);
    // xHat = (x-xmean) / sqrt(var + epsilon)
    //Note that for CNNs, mean and variance are calculated per feature map (i.e., per channel) rather than per activation
    //Pg5 of http://arxiv.org/pdf/1502.03167v3.pdf
    // "For convolutional layers, we additionally want the normalization to obey the convolutional property – so that
    //  different elements of the same feature map, at different locations, are normalized in the same way. To achieve
    //  this, we jointly normalize all the activations in a minibatch, over all locations."
    INDArray mean, var;
    if (training == TrainingMode.TRAIN) {
        switch(x.rank()) {
            case 2:
                // mean and variance over samples in batch
                mean = x.mean(0);
                var = x.var(false, 0);
                break;
            case 4:
                // mean and variance over samples AND locations
                mean = x.mean(0, 2, 3);
                var = x.var(false, 0, 2, 3);
                break;
            default:
                throw new IllegalStateException("Batch normalization on activations of rank " + x.rank() + " not supported");
        }
        var.addi(layerConf.getEps());
    } else {
        // Global mean and variance estimate - used after training
        mean = getParam(BatchNormalizationParamInitializer.GLOBAL_MEAN);
        var = getParam(BatchNormalizationParamInitializer.GLOBAL_VAR);
    }
    std = Transforms.sqrt(var, true);
    INDArray gamma = null;
    INDArray beta = null;
    INDArray globalMeanView = getParam(BatchNormalizationParamInitializer.GLOBAL_MEAN);
    INDArray globalVarView = getParam(BatchNormalizationParamInitializer.GLOBAL_VAR);
    if (layerConf.isLockGammaBeta()) {
        if (helper != null && input.rank() == 4) {
            //TODO: don't create these each iteration, when using cudnn
            int[] gammaBetaShape = new int[] { 1, layerConf().getNOut() };
            gamma = Nd4j.valueArrayOf(gammaBetaShape, layerConf().getGamma());
            beta = Nd4j.valueArrayOf(gammaBetaShape, layerConf().getBeta());
        }
    } else {
        gamma = getParam(BatchNormalizationParamInitializer.GAMMA);
        beta = getParam(BatchNormalizationParamInitializer.BETA);
    }
    if (helper != null && input.rank() == 4) {
        //Note that cudnn does not support dense (2d) batch norm case as of v5.1
        double decay = layerConf.getDecay();
        INDArray ret = helper.preOutput(x, training == TrainingMode.TRAIN, shape, gamma, beta, globalMeanView, globalVarView, decay, layerConf.getEps());
        if (ret != null) {
            return ret;
        }
    }
    // BN(x_k) = gamma * xHat + beta (applying gamma and beta to each activation)
    if (x.rank() == 2) {
        xMu = x.subRowVector(mean);
        xHat = xMu.divRowVector(std);
        if (layerConf.isLockGammaBeta()) {
            //Special case: gamma/beta have fixed values for all outputs
            //Use mul/addi(Number) here to avoid allocating temp arrays of all same value
            double g = layerConf.getGamma();
            double b = layerConf.getBeta();
            if (g != 1.0 || b != 0.0) {
                activations = xHat.mul(g).addi(b);
            } else {
                //Default and most common case: gamma == 1.0 and beta == 0.0. No point executing a 1*x + 0 op
                activations = xHat;
            }
        } else {
            //Standard case: gamma and beta are learned per parameter
            activations = xHat.mulRowVector(gamma).addiRowVector(beta);
        }
    } else if (x.rank() == 4) {
        if (!Shape.strideDescendingCAscendingF(x))
            //TODO: temp Workaround for broadcast bug. To be removed when fixed
            x = x.dup();
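        // Broadcast along dimension 1, the channel axis of the NCHW activations: each channel's
        // mean and std is applied at every spatial location of that feature map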
        xMu = Nd4j.getExecutioner().execAndReturn(new BroadcastSubOp(x, mean, Nd4j.createUninitialized(x.shape(), x.ordering()), 1));
        xHat = Nd4j.getExecutioner().execAndReturn(new BroadcastDivOp(xMu, std, Nd4j.createUninitialized(x.shape(), x.ordering()), 1));
        if (layerConf.isLockGammaBeta()) {
            //Special case: gamma/beta have fixed values for all outputs
            //Use mul/addi(Number) here to avoid allocating temp arrays of all same value
            double g = layerConf.getGamma();
            double b = layerConf.getBeta();
            if (g != 1.0 || b != 0.0) {
                activations = xHat.mul(g).addi(b);
            } else {
                //Default and most common case: gamma == 1.0 and beta == 0.0. No point executing a 1*x + 0 op
                activations = xHat;
            }
        } else {
            //Standard case: gamma and beta are learned per parameter
            activations = Nd4j.getExecutioner().execAndReturn(new BroadcastMulOp(xHat, gamma, Nd4j.createUninitialized(x.shape(), x.ordering()), 1));
            activations = Nd4j.getExecutioner().execAndReturn(new BroadcastAddOp(activations, beta, activations, 1));
        }
    } else {
        // TODO setup BatchNorm for RNN http://arxiv.org/pdf/1510.01378v1.pdf
        throw new IllegalStateException("The layer prior to BatchNorm in the configuration is not currently supported.");
    }
    // store mean and var if using batch mean while training
    double decay;
    if (training == TrainingMode.TRAIN) {
        if (layerConf.isMinibatch()) {
            //Standard case: Estimate global mean and variance stats by moving average
            //globalMean = decay * globalMean + (1-decay) * minibatchMean
            //globalVar  = decay * globalVar  + (1-decay) * minibatchVar
            //Note that it's safe to do a muli on the 'mean' and 'var' variables: they can't be the global arrays when training == TrainingMode.TRAIN
            decay = layerConf.getDecay();
            globalMeanView.muli(decay).addi(mean.muli(1 - decay));
            globalVarView.muli(decay).addi(var.muli(1 - decay));
        } else {
            //Special case: doing full-batch (entire data set) training (uncommon; only tiny data sets)
            //In this case, minibatch and global stats are identical. Don't want to use a moving average estimate.
            globalMeanView.assign(mean);
            globalVarView.assign(var);
        }
    }
    return activations;
}
Also used: BroadcastAddOp (org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAddOp), INDArray (org.nd4j.linalg.api.ndarray.INDArray), BroadcastSubOp (org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp), BroadcastMulOp (org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp), BroadcastDivOp (org.nd4j.linalg.api.ops.impl.broadcast.BroadcastDivOp)
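
The rank-4 branch above uses nd4j broadcast ops rather than row-vector methods, because mean and std hold one value per channel of the NCHW activations. A minimal standalone sketch of that pattern (assuming the same nd4j API as the example above; the class name and shape values here are illustrative only, not part of the project):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastDivOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

public class ChannelBroadcastSketch {

    public static void main(String[] args) {
        double eps = 1e-5;
        // NCHW activations: [minibatch, channels, height, width]
        INDArray x = Nd4j.rand('c', new int[] { 2, 4, 3, 3 });
        // One mean and variance value per channel (dimensions 0, 2, 3 are reduced away)
        INDArray mean = x.mean(0, 2, 3);
        INDArray std = Transforms.sqrt(x.var(false, 0, 2, 3).addi(eps), true);
        // xHat = (x - mean) / std, broadcast along dimension 1 (the channel axis),
        // so each channel's statistics apply at every location of that feature map
        INDArray xMu = Nd4j.getExecutioner().execAndReturn(
                new BroadcastSubOp(x, mean, Nd4j.createUninitialized(x.shape(), x.ordering()), 1));
        INDArray xHat = Nd4j.getExecutioner().execAndReturn(
                new BroadcastDivOp(xMu, std, Nd4j.createUninitialized(x.shape(), x.ordering()), 1));
        // Each channel of xHat now has approximately zero mean and unit variance
        System.out.println(xHat.mean(0, 2, 3));
        System.out.println(xHat.var(false, 0, 2, 3));
    }
}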

Example 2 with BroadcastSubOp

Use of org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp in project deeplearning4j by deeplearning4j.

From the class BatchNormalizationTest, method testCnnForwardBackward:

@Test
public void testCnnForwardBackward() {
    double eps = 1e-5;
    int nIn = 4;
    int hw = 3;
    int minibatch = 2;
    Nd4j.getRandom().setSeed(12345);
    INDArray input = Nd4j.rand('c', new int[] { minibatch, nIn, hw, hw });
    //TODO: other values for gamma/beta
    INDArray gamma = Nd4j.ones(1, nIn);
    INDArray beta = Nd4j.zeros(1, nIn);
    Layer l = getLayer(nIn, eps, false, -1, -1);
    INDArray mean = input.mean(0, 2, 3);
    INDArray var = input.var(false, 0, 2, 3);
    INDArray xHat = Nd4j.getExecutioner().execAndReturn(new BroadcastSubOp(input, mean, input.dup(), 1));
    Nd4j.getExecutioner().execAndReturn(new BroadcastDivOp(xHat, Transforms.sqrt(var.add(eps), true), xHat, 1));
    INDArray outExpected = Nd4j.getExecutioner().execAndReturn(new BroadcastMulOp(xHat, gamma, xHat.dup(), 1));
    Nd4j.getExecutioner().execAndReturn(new BroadcastAddOp(outExpected, beta, outExpected, 1));
    INDArray out = l.activate(input, true);
    System.out.println(Arrays.toString(outExpected.data().asDouble()));
    System.out.println(Arrays.toString(out.data().asDouble()));
    assertEquals(outExpected, out);
    //-------------------------------------------------------------
    //Check backprop
    //dL/dy
    INDArray epsilon = Nd4j.rand('c', new int[] { minibatch, nIn, hw, hw });
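    // m in the backprop formulas below: each channel's statistics are computed over minibatch * hw * hw values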
    int effectiveMinibatch = minibatch * hw * hw;
    INDArray dldgammaExp = epsilon.mul(xHat).sum(0, 2, 3);
    INDArray dldbetaExp = epsilon.sum(0, 2, 3);
    //epsilon.mulRowVector(gamma);
    INDArray dldxhat = Nd4j.getExecutioner().execAndReturn(new BroadcastMulOp(epsilon, gamma, epsilon.dup(), 1));
    INDArray inputSubMean = Nd4j.getExecutioner().execAndReturn(new BroadcastSubOp(input, mean, input.dup(), 1));
    INDArray dldvar = dldxhat.mul(inputSubMean).mul(-0.5);
    dldvar = Nd4j.getExecutioner().execAndReturn(new BroadcastMulOp(dldvar, Transforms.pow(var.add(eps), -3.0 / 2.0, true), dldvar.dup(), 1));
    dldvar = dldvar.sum(0, 2, 3);
    INDArray dldmu = Nd4j.getExecutioner().execAndReturn(new BroadcastMulOp(dldxhat, Transforms.pow(var.add(eps), -1.0 / 2.0, true), dldxhat.dup(), 1)).neg().sum(0, 2, 3);
    dldmu = dldmu.add(dldvar.mul(inputSubMean.mul(-2.0).sum(0, 2, 3).div(effectiveMinibatch)));
    INDArray dldinExp = Nd4j.getExecutioner().execAndReturn(new BroadcastMulOp(dldxhat, Transforms.pow(var.add(eps), -1.0 / 2.0, true), dldxhat.dup(), 1));
    dldinExp = dldinExp.add(Nd4j.getExecutioner().execAndReturn(new BroadcastMulOp(inputSubMean.mul(2.0 / effectiveMinibatch), dldvar, inputSubMean.dup(), 1)));
    dldinExp = Nd4j.getExecutioner().execAndReturn(new BroadcastAddOp(dldinExp, dldmu.mul(1.0 / effectiveMinibatch), dldinExp.dup(), 1));
    Pair<Gradient, INDArray> p = l.backpropGradient(epsilon);
    INDArray dldgamma = p.getFirst().getGradientFor("gamma");
    INDArray dldbeta = p.getFirst().getGradientFor("beta");
    assertEquals(dldgammaExp, dldgamma);
    assertEquals(dldbetaExp, dldbeta);
    //        System.out.println("EPSILONS");
    //        System.out.println(Arrays.toString(dldinExp.data().asDouble()));
    //        System.out.println(Arrays.toString(p.getSecond().dup().data().asDouble()));
    assertEquals(dldinExp, p.getSecond());
}
Also used: BroadcastAddOp (org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAddOp), Gradient (org.deeplearning4j.nn.gradient.Gradient), INDArray (org.nd4j.linalg.api.ndarray.INDArray), BroadcastSubOp (org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp), BroadcastMulOp (org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp), Layer (org.deeplearning4j.nn.api.Layer), BroadcastDivOp (org.nd4j.linalg.api.ops.impl.broadcast.BroadcastDivOp), Test (org.junit.Test)
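
For reference, the expected gradients assembled in the test follow the batch-norm backprop equations from the paper cited in Example 1 (a sketch in the test's notation: m is effectiveMinibatch, the sums run over the minibatch and spatial positions of each channel, and dL/dy is epsilon):

\frac{\partial L}{\partial \hat{x}_i} = \frac{\partial L}{\partial y_i} \cdot \gamma
\frac{\partial L}{\partial \sigma^2} = \sum_i \frac{\partial L}{\partial \hat{x}_i} (x_i - \mu) \cdot \left(-\tfrac{1}{2}\right) (\sigma^2 + \epsilon)^{-3/2}
\frac{\partial L}{\partial \mu} = -\sum_i \frac{\partial L}{\partial \hat{x}_i} (\sigma^2 + \epsilon)^{-1/2} + \frac{\partial L}{\partial \sigma^2} \cdot \frac{1}{m} \sum_i -2 (x_i - \mu)
\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial \hat{x}_i} (\sigma^2 + \epsilon)^{-1/2} + \frac{\partial L}{\partial \sigma^2} \cdot \frac{2 (x_i - \mu)}{m} + \frac{\partial L}{\partial \mu} \cdot \frac{1}{m}
\frac{\partial L}{\partial \gamma} = \sum_i \frac{\partial L}{\partial y_i} \, \hat{x}_i
\frac{\partial L}{\partial \beta} = \sum_i \frac{\partial L}{\partial y_i}

These correspond line by line to dldxhat, dldvar, dldmu, dldinExp, dldgammaExp, and dldbetaExp in the test above.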

Aggregations

INDArray (org.nd4j.linalg.api.ndarray.INDArray): 2 uses
BroadcastAddOp (org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAddOp): 2 uses
BroadcastDivOp (org.nd4j.linalg.api.ops.impl.broadcast.BroadcastDivOp): 2 uses
BroadcastMulOp (org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp): 2 uses
BroadcastSubOp (org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp): 2 uses
Layer (org.deeplearning4j.nn.api.Layer): 1 use
Gradient (org.deeplearning4j.nn.gradient.Gradient): 1 use
Test (org.junit.Test): 1 use