Example 26 with DynamicCustomOp

Use of org.nd4j.linalg.api.ops.DynamicCustomOp in project nd4j by deeplearning4j.

From the class ConvolutionTestsC, method testMaxPoolBackprop.

@Test
@Ignore
public void testMaxPoolBackprop() {
    Nd4j.getRandom().setSeed(12345);
    for (int i = 0; i < 5; i++) {
        int[] inputShape = { 1, 1, 4, 3 };
        int[] kernel = { 2, 2 };
        int[] strides = { 1, 1 };
        int[] pad = { 0, 0 };
        // TODO non 1-1 dilation
        int[] dilation = { 1, 1 };
        boolean same = true;
        String fn = "maxpool2d_bp";
        int nIArgs = 11;
        int[] a = new int[nIArgs];
        a[0] = kernel[0];
        a[1] = kernel[1];
        a[2] = strides[0];
        a[3] = strides[1];
        a[4] = pad[0];
        a[5] = pad[1];
        a[6] = dilation[0];
        a[7] = dilation[1];
        a[8] = same ? 1 : 0;
        // a[9] is not used with max pooling
        // a[10] = 0 selects the NCHW data format
        a[10] = 0;
        List<Pair<INDArray, String>> inputs = NDArrayCreationUtil.getAll4dTestArraysWithShape(12345, inputShape);
        for (Pair<INDArray, String> pIn : inputs) {
            INDArray input = pIn.getFirst();
            int[] outShapeHW = getOutputSize(input, kernel, strides, pad, same);
            List<Pair<INDArray, String>> eps = NDArrayCreationUtil.getAll4dTestArraysWithShape(12345, inputShape[0], inputShape[1], outShapeHW[0], outShapeHW[1]);
            for (Pair<INDArray, String> pEps : eps) {
                INDArray epsilon = pEps.getFirst();
                INDArray epsNext = Nd4j.create(inputShape, 'c');
                // Runs fine with dups:
                // input = input.dup('c');
                epsilon = epsilon.dup('c');
                DynamicCustomOp op = DynamicCustomOp.builder(fn).addInputs(input, epsilon).addOutputs(epsNext).addIntegerArguments(a).build();
                Nd4j.getExecutioner().exec(op);
                INDArray expEpsNext = expGradMaxPoolBackPropSame(input, epsilon, kernel, strides, same);
                String msg = "input=" + pIn.getSecond() + ", eps=" + pEps.getSecond();
                assertEquals(msg, expEpsNext, epsNext);
            }
        }
    }
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) DynamicCustomOp(org.nd4j.linalg.api.ops.DynamicCustomOp) Pair(org.nd4j.linalg.primitives.Pair) Ignore(org.junit.Ignore) Test(org.junit.Test) BaseNd4jTest(org.nd4j.linalg.BaseNd4jTest)
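
The test above calls two helpers that the excerpt does not show, getOutputSize and expGradMaxPoolBackPropSame. The sketch below is one plausible plain-Java version of what such helpers compute, with no ND4J dependency. The class MaxPoolSketch and both of its methods are hypothetical names introduced here for illustration. It assumes the common SAME-mode convention (output size is ceil(in / stride), with any extra padding placed at the bottom and right), a dilation of 1, and that ties go to the first maximum in the window. The real helpers in ConvolutionTestsC may differ on any of these points.

```java
import java.util.Arrays;

// Hypothetical helper class for illustration; not part of ND4J.
public class MaxPoolSketch {

    // Output size along one spatial dimension.
    // Assumption: SAME mode gives ceil(in / stride); otherwise the usual
    // floor((in + 2*pad - kernel) / stride) + 1.
    public static int outputSize(int in, int kernel, int stride, int pad, boolean same) {
        if (same) {
            return (in + stride - 1) / stride;
        }
        return (in + 2 * pad - kernel) / stride + 1;
    }

    // Naive max-pool backprop for a single [height][width] channel, dilation 1.
    // Each epsilon value is routed to the first maximum in its window.
    // Positions that fall in the padding are skipped. Explicit padding in
    // non-SAME mode is not handled (the test above uses pad = 0).
    public static double[][] maxPoolBackprop(double[][] in, double[][] eps,
                                             int[] kernel, int[] strides, boolean same) {
        int h = in.length, w = in[0].length;
        int outH = eps.length, outW = eps[0].length;
        // Assumption: in SAME mode the smaller half of the padding goes on top/left.
        int padTop = same ? Math.max((outH - 1) * strides[0] + kernel[0] - h, 0) / 2 : 0;
        int padLeft = same ? Math.max((outW - 1) * strides[1] + kernel[1] - w, 0) / 2 : 0;
        double[][] grad = new double[h][w];
        for (int i = 0; i < outH; i++) {
            for (int j = 0; j < outW; j++) {
                int bestR = -1, bestC = -1;
                for (int kr = 0; kr < kernel[0]; kr++) {
                    for (int kc = 0; kc < kernel[1]; kc++) {
                        int r = i * strides[0] - padTop + kr;
                        int c = j * strides[1] - padLeft + kc;
                        if (r < 0 || r >= h || c < 0 || c >= w) continue;
                        if (bestR < 0 || in[r][c] > in[bestR][bestC]) {
                            bestR = r;
                            bestC = c;
                        }
                    }
                }
                if (bestR >= 0) grad[bestR][bestC] += eps[i][j];
            }
        }
        return grad;
    }

    public static void main(String[] args) {
        // Same settings as the test: 4x3 input, 2x2 kernel, stride 1, SAME mode.
        int[] out = { outputSize(4, 2, 1, 0, true), outputSize(3, 2, 1, 0, true) };
        System.out.println(Arrays.toString(out)); // prints [4, 3]

        double[][] in = { { 1, 2 }, { 3, 4 } };
        double[][] eps = { { 1, 1 }, { 1, 1 } };
        double[][] grad = maxPoolBackprop(in, eps, new int[] { 2, 2 }, new int[] { 1, 1 }, true);
        System.out.println(Arrays.deepToString(grad)); // prints [[0.0, 0.0], [0.0, 4.0]]
    }
}
```

In the small example every 2x2 window (after SAME padding) contains the value 4, so all four epsilon values accumulate at that one position. An expected-gradient array built this way is what the assertEquals in the test would compare against the output of maxpool2d_bp.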

Aggregations

Test (org.junit.Test) 26
INDArray (org.nd4j.linalg.api.ndarray.INDArray) 26
DynamicCustomOp (org.nd4j.linalg.api.ops.DynamicCustomOp) 26
BaseNd4jTest (org.nd4j.linalg.BaseNd4jTest) 15
NDArrayIndex.point (org.nd4j.linalg.indexing.NDArrayIndex.point) 14
SDVariable (org.nd4j.autodiff.samediff.SDVariable) 10
SameDiff (org.nd4j.autodiff.samediff.SameDiff) 10
ArrayList (java.util.ArrayList) 2
DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction) 2
Ignore (org.junit.Ignore) 1
MMulTranspose (org.nd4j.linalg.api.blas.params.MMulTranspose) 1
Mmul (org.nd4j.linalg.api.ops.impl.accum.Mmul) 1
TruncateDivOp (org.nd4j.linalg.api.ops.impl.transforms.arithmetic.TruncateDivOp) 1
GreaterThanOrEqual (org.nd4j.linalg.api.ops.impl.transforms.comparison.GreaterThanOrEqual) 1
LessThanOrEqual (org.nd4j.linalg.api.ops.impl.transforms.comparison.LessThanOrEqual) 1
OldMax (org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMax) 1
OldMin (org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMin) 1
BernoulliDistribution (org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution) 1
Pair (org.nd4j.linalg.primitives.Pair) 1