Example 16 with DynamicCustomOp

Use of org.nd4j.linalg.api.ops.DynamicCustomOp in project nd4j by deeplearning4j, in the class GradCheckTransforms, method testSpaceToBatch.

@Test
public void testSpaceToBatch() {
    Nd4j.getRandom().setSeed(7331);
    int miniBatch = 4;
    int[] inputShape = new int[] { 1, 2, 2, 1 };
    int M = 2;
    int[] blockShape = new int[] { M, 1 };
    int[] paddingShape = new int[] { M, 2 };
    INDArray input = Nd4j.randn(inputShape);
    INDArray blocks = Nd4j.create(new float[] { 2, 2 }, blockShape);
    INDArray padding = Nd4j.create(new float[] { 0, 0, 0, 0 }, paddingShape);
    SameDiff sd = SameDiff.create();
    SDVariable sdInput = sd.var("in", inputShape);
    // Expected output: execute the raw space_to_batch op directly. With 2x2 blocks
    // and zero padding, the [1, 2, 2, 1] input becomes [4, 1, 1, 1].
    INDArray expOut = Nd4j.create(miniBatch, 1, 1, 1);
    DynamicCustomOp op = DynamicCustomOp.builder("space_to_batch").addInputs(input, blocks, padding).addOutputs(expOut).build();
    Nd4j.getExecutioner().exec(op);
    sd.associateArrayWithVariable(input, sdInput);
    SDVariable t = sd.spaceToBatch(sdInput, new int[] { 2, 2 }, new int[][] { { 0, 0 }, { 0, 0 } });
    SDVariable loss = sd.mean("loss", t);
    sd.exec();
    INDArray out = t.getArr();
    if (!expOut.equals(out)) {
        log.info("space to batch failed on forward");
    }
    try {
        GradCheckUtil.checkGradients(sd);
    } catch (Exception e) {
        e.printStackTrace();
    }
}
Also used: SDVariable (org.nd4j.autodiff.samediff.SDVariable), INDArray (org.nd4j.linalg.api.ndarray.INDArray), SameDiff (org.nd4j.autodiff.samediff.SameDiff), DynamicCustomOp (org.nd4j.linalg.api.ops.DynamicCustomOp), Test (org.junit.Test)
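
For reference, space_to_batch moves spatial blocks into the batch dimension: with block shape [bH, bW] and padding [[top, bottom], [left, right]], an NHWC input of shape [N, H, W, C] becomes [N * bH * bW, (H + top + bottom) / bH, (W + left + right) / bW, C]. A minimal sketch of that shape arithmetic (the helper below is hypothetical, not nd4j API):

// Hypothetical helper: output shape of space_to_batch for an NHWC input.
// Assumes the padded spatial dims divide evenly by the block sizes.
static int[] spaceToBatchShape(int[] in, int[] block, int[][] pad) {
    int h = (in[1] + pad[0][0] + pad[0][1]) / block[0];
    int w = (in[2] + pad[1][0] + pad[1][1]) / block[1];
    return new int[] { in[0] * block[0] * block[1], h, w, in[3] };
}

For the test above, an input of shape [1, 2, 2, 1] with 2x2 blocks and zero padding becomes [4, 1, 1, 1], which is why expOut is created with shape (miniBatch, 1, 1, 1).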

Example 17 with DynamicCustomOp

Use of org.nd4j.linalg.api.ops.DynamicCustomOp in project nd4j by deeplearning4j, in the class GradCheckTransforms, method testPairwiseTransforms.

@Test
public void testPairwiseTransforms() {
    /*
        add, sub, mul, div, rsub, rdiv
        eq, neq, gt, lt, gte, lte, or, and, xor
        min, max
        mmul
        tensormmul
         */
    // Test transforms (pairwise)
    Nd4j.getRandom().setSeed(12345);
    List<String> allSkipped = new ArrayList<>();
    List<String> allFailed = new ArrayList<>();
    for (int i = 0; i < 23; i++) {
        boolean skipBackward = false;
        SameDiff sd = SameDiff.create();
        int nOut = 4;
        int minibatch = 5;
        SDVariable in1 = sd.var("in1", new int[] { -1, nOut });
        SDVariable in2 = sd.var("in2", new int[] { -1, nOut });
        INDArray ia = Nd4j.randn(minibatch, nOut);
        INDArray ib = Nd4j.randn(minibatch, nOut);
        SDVariable t;
        INDArray expOut;
        switch(i) {
            case 0:
                t = in1.add(in2);
                expOut = ia.add(ib);
                break;
            case 1:
                t = in1.sub(in2);
                expOut = ia.sub(ib);
                break;
            case 2:
                t = in1.mul(in2);
                expOut = ia.mul(ib);
                break;
            case 3:
                // div is skipped in this test; continue to the next case
                continue;
            case 4:
                t = in1.rsub(in2);
                expOut = ia.rsub(ib);
                break;
            case 5:
                t = in1.rdiv(in2);
                expOut = ia.rdiv(ib);
                break;
            case 6:
                t = sd.eq(in1, in2);
                expOut = ia.eq(ib);
                break;
            case 7:
                t = sd.neq(in1, in2);
                expOut = ia.neq(ib);
                break;
            case 8:
                t = sd.gt(in1, in2);
                expOut = ia.gt(ib);
                break;
            case 9:
                t = sd.lt(in1, in2);
                expOut = ia.lt(ib);
                break;
            case 10:
                t = sd.gte(in1, in2);
                expOut = ia.dup();
                Nd4j.getExecutioner().exec(new GreaterThanOrEqual(new INDArray[] { ia, ib }, new INDArray[] { expOut }));
                break;
            case 11:
                t = sd.lte(in1, in2);
                expOut = ia.dup();
                Nd4j.getExecutioner().exec(new LessThanOrEqual(new INDArray[] { ia, ib }, new INDArray[] { expOut }));
                break;
            case 12:
                ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5));
                ib = Nd4j.getExecutioner().exec(new BernoulliDistribution(ib, 0.5));
                t = sd.or(in1, in2);
                expOut = Transforms.or(ia, ib);
                break;
            case 13:
                ib = Nd4j.randn(nOut, nOut);
                t = sd.mmul(in1, in2);
                expOut = ia.mmul(ib);
                break;
            case 14:
                t = sd.max(in1, in2);
                expOut = Nd4j.getExecutioner().execAndReturn(new OldMax(ia, ib, ia.dup(), ia.length()));
                break;
            case 15:
                t = sd.min(in1, in2);
                expOut = Nd4j.getExecutioner().execAndReturn(new OldMin(ia, ib, ia.dup(), ia.length()));
                break;
            case 16:
                ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5));
                ib = Nd4j.getExecutioner().exec(new BernoulliDistribution(ib, 0.5));
                t = sd.and(in1, in2);
                expOut = Transforms.and(ia, ib);
                break;
            case 17:
                ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5));
                ib = Nd4j.getExecutioner().exec(new BernoulliDistribution(ib, 0.5));
                t = sd.xor(in1, in2);
                expOut = Transforms.xor(ia, ib);
                break;
            case 18:
                t = sd.assign(in1, in2);
                expOut = ib;
                break;
            case 19:
                t = sd.atan2(in1, in2);
                // Note: y,x order for samediff; x,y order for transforms
                expOut = Transforms.atan2(ib, ia);
                skipBackward = true;
                break;
            case 20:
                t = sd.mergeAdd(in1, in2, in2);
                expOut = ia.add(ib).add(ib);
                break;
            case 21:
                ia = Nd4j.create(new float[] { 2, 4 });
                ib = Nd4j.create(new float[] { 42, 2 });
                in1 = sd.var("in1", new int[] { 1, 2 });
                in2 = sd.var("in2", new int[] { 1, 2 });
                t = in1.truncatedDiv(in2);
                expOut = Nd4j.create(ia.shape(), ia.ordering());
                Nd4j.getExecutioner().exec(new TruncateDivOp(ia, ib, expOut));
                skipBackward = true;
                break;
            case 22:
                t = in1.squaredDifference(in2);
                expOut = Nd4j.create(ia.shape(), ia.ordering());
                DynamicCustomOp squareDiff = DynamicCustomOp.builder("squaredsubtract").addInputs(ia, ib).addOutputs(expOut).build();
                Nd4j.getExecutioner().exec(squareDiff);
                skipBackward = true;
                break;
            default:
                throw new RuntimeException("Unexpected case: " + i);
        }
        DifferentialFunction[] funcs = sd.functions();
        String name = funcs[0].opName();
        String msg = "test: " + i + " - " + name;
        log.info("*** Starting test: " + msg);
        SDVariable loss = sd.mean("loss", t);
        sd.associateArrayWithVariable(ia, in1);
        sd.associateArrayWithVariable(ib, in2);
        sd.exec();
        INDArray out = t.getArr();
        assertEquals(msg, expOut, out);
        boolean ok;
        if (skipBackward) {
            ok = true;
            msg += " - SKIPPED";
            allSkipped.add(msg);
        } else {
            try {
                ok = GradCheckUtil.checkGradients(sd);
            } catch (Exception e) {
                e.printStackTrace();
                msg += " - EXCEPTION";
                ok = false;
            }
        }
        if (!ok) {
            allFailed.add(msg);
        }
    }
    if (allSkipped.size() > 0) {
        log.info("All backward skipped transforms: " + allSkipped);
        log.info(allSkipped.size() + " backward passes were skipped.");
    }
    if (allFailed.size() > 0) {
        log.error("All failed transforms: " + allFailed);
        fail(allFailed.size() + " transforms failed");
    }
}
Also used: LessThanOrEqual (org.nd4j.linalg.api.ops.impl.transforms.comparison.LessThanOrEqual), OldMin (org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMin), SameDiff (org.nd4j.autodiff.samediff.SameDiff), ArrayList (java.util.ArrayList), GreaterThanOrEqual (org.nd4j.linalg.api.ops.impl.transforms.comparison.GreaterThanOrEqual), SDVariable (org.nd4j.autodiff.samediff.SDVariable), OldMax (org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMax), INDArray (org.nd4j.linalg.api.ndarray.INDArray), BernoulliDistribution (org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution), DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction), DynamicCustomOp (org.nd4j.linalg.api.ops.DynamicCustomOp), TruncateDivOp (org.nd4j.linalg.api.ops.impl.transforms.arithmetic.TruncateDivOp), Test (org.junit.Test)
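
GradCheckUtil.checkGradients validates the analytic gradients computed by SameDiff against finite-difference estimates. A minimal sketch of the central-difference idea behind such a check (plain Java for illustration, not the nd4j implementation):

import java.util.function.DoubleUnaryOperator;

// Central-difference estimate of df/dx at x; an epsilon around 1e-5 is a
// common choice for double-precision gradient checks.
static double numericGrad(DoubleUnaryOperator f, double x, double eps) {
    return (f.applyAsDouble(x + eps) - f.applyAsDouble(x - eps)) / (2 * eps);
}

A check of this kind passes when the relative error |analytic - numeric| / max(|analytic|, |numeric|) stays below a small threshold for every input; the cases above that set skipBackward (atan2, truncatedDiv, squaredDifference) bypass it entirely, presumably because their backward passes were not yet supported.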

Example 18 with DynamicCustomOp

Use of org.nd4j.linalg.api.ops.DynamicCustomOp in project nd4j by deeplearning4j, in the class GradCheckTransforms, method testDiag.

@Test
public void testDiag() {
    SameDiff sd = SameDiff.create();
    INDArray ia = Nd4j.create(new float[] { 4, 2 });
    SDVariable in = sd.var("in", new int[] { 1, 2 });
    // Expected output: execute the raw diag op directly. The vector [4, 2]
    // becomes the 2x2 matrix [[4, 0], [0, 2]].
    INDArray expOut = Nd4j.create(new int[] { 2, 2 });
    DynamicCustomOp diag = DynamicCustomOp.builder("diag").addInputs(ia).addOutputs(expOut).build();
    Nd4j.getExecutioner().exec(diag);
    SDVariable t = sd.diag(in);
    SDVariable loss = sd.max("loss", t, 0, 1);
    sd.associateArrayWithVariable(ia, in);
    sd.exec();
    INDArray out = t.getArr();
    if (!expOut.equals(out)) {
        log.info("diag failed on forward");
    }
    try {
        GradCheckUtil.checkGradients(sd);
    } catch (Exception e) {
        e.printStackTrace();
    }
}
Also used: SDVariable (org.nd4j.autodiff.samediff.SDVariable), INDArray (org.nd4j.linalg.api.ndarray.INDArray), SameDiff (org.nd4j.autodiff.samediff.SameDiff), DynamicCustomOp (org.nd4j.linalg.api.ops.DynamicCustomOp), Test (org.junit.Test)
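
The diag op places the input vector on the main diagonal of an otherwise-zero square matrix, so the [4, 2] input above produces [[4, 0], [0, 2]]. A plain-Java sketch of that semantics (illustrative only, not the nd4j implementation):

// Dense equivalent of diag for a 1-D input vector.
static float[][] diag(float[] v) {
    float[][] out = new float[v.length][v.length];
    for (int i = 0; i < v.length; i++) {
        out[i][i] = v[i];
    }
    return out;
}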

Example 19 with DynamicCustomOp

Use of org.nd4j.linalg.api.ops.DynamicCustomOp in project nd4j by deeplearning4j, in the class ConvolutionTests, method testPooling9.

@Test
public void testPooling9() {
    for (char outputOrder : new char[] { 'c', 'f' }) {
        INDArray exp = Nd4j.create(new float[] { 0.25f, 1.25f, 2.25f, 4.25f, 10.f, 12.f, 9.25f, 20.f, 22.f, 6.5f, 13.75f, 14.75f, 16.75f, 35.f, 37.f, 21.75f, 45.f, 47.f, 12.75f, 26.25f, 27.25f, 29.25f, 60.f, 62.f, 34.25f, 70.f, 72.f, 19.f, 38.75f, 39.75f, 41.75f, 85.f, 87.f, 46.75f, 95.f, 97.f }, new int[] { 2, 2, 3, 3 }, 'c');
        int len = 2 * 2 * 5 * 5;
        INDArray x = Nd4j.linspace(1, len, len).reshape('c', 2, 2, 5, 5);
        DynamicCustomOp op = DynamicCustomOp.builder("avgpool2d").addIntegerArguments(new int[] { 2, 2, 2, 2, 1, 1, 1, 1, 0, 1, 0 }).addInputs(x).addOutputs(Nd4j.create(new int[] { 2, 2, 3, 3 }, outputOrder)).build();
        Nd4j.getExecutioner().exec(op);
        INDArray out = op.getOutputArgument(0);
        assertEquals("Output order: " + outputOrder, exp, out);
    }
}
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), DynamicCustomOp (org.nd4j.linalg.api.ops.DynamicCustomOp), NDArrayIndex.point (org.nd4j.linalg.indexing.NDArrayIndex.point), Test (org.junit.Test), BaseNd4jTest (org.nd4j.linalg.BaseNd4jTest)
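
The integer arguments to avgpool2d appear to encode, in order: kernel (kH, kW), stride (sH, sW), padding (pH, pW), dilation (dH, dW), a same-mode flag, a divisor-mode flag (extraParam0), and a data-format flag. That reading is inferred from the values in these two pooling tests, not from documented API; under it, the expected spatial size follows the usual pooling arithmetic:

// Standard pooling output-size formula (floor division, dilation 1):
// out = (in + 2 * pad - kernel) / stride + 1
static int pooledSize(int in, int kernel, int stride, int pad) {
    return (in + 2 * pad - kernel) / stride + 1;
}

Here pooledSize(5, 2, 2, 1) == 3, matching the expected [2, 2, 3, 3] output shape.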

Example 20 with DynamicCustomOp

Use of org.nd4j.linalg.api.ops.DynamicCustomOp in project nd4j by deeplearning4j, in the class ConvolutionTests, method testPooling2.

@Test
public void testPooling2() {
    for (char outputOrder : new char[] { 'c', 'f' }) {
        INDArray exp = Nd4j.create(new float[] { 6.f, 7.f, 10.f, 11.f, 22.f, 23.f, 26.f, 27.f, 38.f, 39.f, 42.f, 43.f, 54.f, 55.f, 58.f, 59.f }, new int[] { 2, 2, 2, 2 }, 'c');
        int len = 2 * 4 * 4 * 2;
        INDArray x = Nd4j.linspace(1, len, len).reshape('c', 2, 4, 4, 2);
        DynamicCustomOp op = DynamicCustomOp.builder("avgpool2d").addIntegerArguments(new int[] { 2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 1 }).addInputs(x).addOutputs(Nd4j.create(new int[] { 2, 2, 2, 2 }, outputOrder)).build();
        Nd4j.getExecutioner().exec(op);
        INDArray out = op.getOutputArgument(0);
        assertEquals("Output order: " + outputOrder, exp, out);
    }
}
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), DynamicCustomOp (org.nd4j.linalg.api.ops.DynamicCustomOp), NDArrayIndex.point (org.nd4j.linalg.indexing.NDArrayIndex.point), Test (org.junit.Test), BaseNd4jTest (org.nd4j.linalg.BaseNd4jTest)
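
Under the same inferred argument layout, this test uses zero padding and a data-format flag of 1 (NHWC, consistent with the [2, 4, 4, 2] input), and pooledSize(4, 2, 2, 0) == 2 from the sketch above gives the expected [2, 2, 2, 2] output.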

Aggregations

Test (org.junit.Test): 26
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 26
DynamicCustomOp (org.nd4j.linalg.api.ops.DynamicCustomOp): 26
BaseNd4jTest (org.nd4j.linalg.BaseNd4jTest): 15
NDArrayIndex.point (org.nd4j.linalg.indexing.NDArrayIndex.point): 14
SDVariable (org.nd4j.autodiff.samediff.SDVariable): 10
SameDiff (org.nd4j.autodiff.samediff.SameDiff): 10
ArrayList (java.util.ArrayList): 2
DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction): 2
Ignore (org.junit.Ignore): 1
MMulTranspose (org.nd4j.linalg.api.blas.params.MMulTranspose): 1
Mmul (org.nd4j.linalg.api.ops.impl.accum.Mmul): 1
TruncateDivOp (org.nd4j.linalg.api.ops.impl.transforms.arithmetic.TruncateDivOp): 1
GreaterThanOrEqual (org.nd4j.linalg.api.ops.impl.transforms.comparison.GreaterThanOrEqual): 1
LessThanOrEqual (org.nd4j.linalg.api.ops.impl.transforms.comparison.LessThanOrEqual): 1
OldMax (org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMax): 1
OldMin (org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMin): 1
BernoulliDistribution (org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution): 1
Pair (org.nd4j.linalg.primitives.Pair): 1