Example 41 with SDVariable

Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.

The class GradCheckMisc, method testPermuteGradient.

@Test
public void testPermuteGradient() {
    int[] origShape = new int[] { 3, 4, 5 };
    for (int[] perm : new int[][] { { 0, 1, 2 }, { 0, 2, 1 }, { 1, 0, 2 }, { 1, 2, 0 }, { 2, 0, 1 }, { 2, 1, 0 } }) {
        for (Pair<INDArray, String> p : NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, origShape)) {
            INDArray inArr = p.getFirst().muli(100);
            SameDiff sd = SameDiff.create();
            SDVariable in = sd.var("in", inArr);
            SDVariable permute = sd.f().permute(in, perm);
            // Using stdev here: mean/sum would backprop the same gradient for each input...
            SDVariable stdev = sd.standardDeviation("out", permute, true);
            INDArray out = sd.execAndEndResult();
            INDArray expOut = in.getArr().std(true, Integer.MAX_VALUE);
            assertEquals(expOut, out);
            String msg = "permute=" + Arrays.toString(perm) + ", source=" + p.getSecond();
            boolean ok = GradCheckUtil.checkGradients(sd);
            assertTrue(msg, ok);
        }
    }
}
Also used: SDVariable (org.nd4j.autodiff.samediff.SDVariable), INDArray (org.nd4j.linalg.api.ndarray.INDArray), SameDiff (org.nd4j.autodiff.samediff.SameDiff), Test (org.junit.Test)
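
The stdev-instead-of-sum choice noted in the in-code comment can be justified quickly (standard calculus, not taken from the nd4j source). For a sum reduction $y = \sum_i x_i$, the gradient is the same for every input,
$$\frac{\partial y}{\partial x_i} = 1 \quad \text{for all } i,$$
so a sum or mean loss would pass the gradient check even if the permute backward pass routed gradients to the wrong elements. For a standard deviation reduction (sample form shown), $s = \sqrt{\tfrac{1}{N-1}\sum_i (x_i - \mu)^2}$ with mean $\mu$,
$$\frac{\partial s}{\partial x_i} = \frac{x_i - \mu}{(N-1)\,s},$$
which differs per element, so a misrouted gradient through permute changes the numerical check and is detected.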

Example 42 with SDVariable

Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.

The class GradCheckMisc, method testGradientAutoBroadcast1.

@Test
public void testGradientAutoBroadcast1() {
    Nd4j.getRandom().setSeed(12345);
    List<String> allFailed = new ArrayList<>();
    for (int dim_sz1 : new int[] { 0, 1, 2 }) {
        int[] in2Shape = { 3, 4, 5 };
        in2Shape[dim_sz1] = 1;
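        // Note: as written, this inner loop runs only for i = 2, so only the "mul" case below is exercised;
        // the remaining switch cases are defined but not reached.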
        for (int i = 2; i < 3; i++) {
            SameDiff sd = SameDiff.create();
            SDVariable in3 = sd.var("in3", Nd4j.rand(new int[] { 3, 4, 5 }));
            SDVariable in2 = sd.var("in2", in2Shape);
            SDVariable bcOp;
            String name;
            switch(i) {
                case 0:
                    bcOp = in3.add(in2);
                    name = "add";
                    break;
                case 1:
                    bcOp = in3.sub(in2);
                    name = "sub";
                    break;
                case 2:
                    bcOp = in3.mul(in2);
                    name = "mul";
                    break;
                case 3:
                    bcOp = in3.div(in2);
                    name = "div";
                    break;
                case 4:
                    bcOp = in3.rsub(in2);
                    name = "rsub";
                    break;
                case 5:
                    bcOp = in3.rdiv(in2);
                    name = "rdiv";
                    break;
                case 6:
                    bcOp = sd.f().floorDiv(in3, in2);
                    name = "floordiv";
                    break;
                case 7:
                    bcOp = sd.f().floorMod(in3, in2);
                    name = "floormod";
                    break;
                default:
                    throw new RuntimeException();
            }
            SDVariable outVar = sd.sum(bcOp);
            String msg = "(test " + i + ": " + name + ", dimension=" + dim_sz1 + ")";
            log.info("*** Starting test: " + msg);
            INDArray in3Arr = Nd4j.randn(new int[] { 3, 4, 5 }).muli(100);
            INDArray in2Arr = Nd4j.randn(in2Shape).muli(100);
            sd.associateArrayWithVariable(in3Arr, in3);
            sd.associateArrayWithVariable(in2Arr, in2);
            try {
                INDArray out = sd.execAndEndResult();
                assertNotNull(out);
                assertArrayEquals(new int[] { 1, 1 }, out.shape());
                // System.out.println(sd.asFlatPrint());
                boolean ok = GradCheckUtil.checkGradients(sd);
                if (!ok) {
                    allFailed.add(msg);
                }
            } catch (Exception e) {
                e.printStackTrace();
                allFailed.add(msg + " - EXCEPTION");
            }
        }
    }
    assertEquals("Failed: " + allFailed, 0, allFailed.size());
}
Also used: SDVariable (org.nd4j.autodiff.samediff.SDVariable), INDArray (org.nd4j.linalg.api.ndarray.INDArray), SameDiff (org.nd4j.autodiff.samediff.SameDiff), ArrayList (java.util.ArrayList), Test (org.junit.Test)
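
The harness above packs several ops and shapes into one loop; a minimal standalone sketch of a single auto-broadcast gradient check may be easier to read. It reuses only calls that already appear in these examples (sd.var, mul, sd.sum, execAndEndResult, associateArrayWithVariable-free setup, GradCheckUtil.checkGradients); the method name, variable names, and the { 3, 1, 5 } shape are illustrative, not from the original test.

@Test
public void testSingleBroadcastMulSketch() {
    Nd4j.getRandom().setSeed(12345);
    SameDiff sd = SameDiff.create();
    // Full-rank input, plus an input with a size-1 dimension that will be broadcast over dimension 1
    SDVariable full = sd.var("full", Nd4j.rand(new int[] { 3, 4, 5 }));
    SDVariable slim = sd.var("slim", Nd4j.rand(new int[] { 3, 1, 5 }));
    // Element-wise multiply with automatic broadcasting of the size-1 dimension
    SDVariable prod = full.mul(slim);
    // Reduce to a scalar so the numerical gradient check has a single output
    SDVariable loss = sd.sum(prod);
    INDArray out = sd.execAndEndResult();
    assertNotNull(out);
    assertTrue(GradCheckUtil.checkGradients(sd));
}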

Example 43 with SDVariable

Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.

The class GradCheckMisc, method testGradientAutoBroadcast2.

@Test
public void testGradientAutoBroadcast2() {
    Nd4j.getRandom().setSeed(12345);
    List<String> allFailed = new ArrayList<>();
    for (int[] dim_sz1s : new int[][] { { 0, 1 }, { 0, 2 }, { 1, 2 }, { 0, 1, 2 } }) {
        int[] otherShape = { 3, 4, 5 };
        otherShape[dim_sz1s[0]] = 1;
        otherShape[dim_sz1s[1]] = 1;
        if (dim_sz1s.length == 3) {
            otherShape[dim_sz1s[2]] = 1;
        }
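        // Note: with i < 6, this loop exercises only cases 0-5 (add through rdiv);
        // the floordiv and floormod cases are defined but not reached.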
        for (int i = 0; i < 6; i++) {
            SameDiff sd = SameDiff.create();
            SDVariable in3 = sd.var("in3", new int[] { 3, 4, 5 });
            SDVariable in2 = sd.var("inToBc", otherShape);
            String name;
            SDVariable bcOp;
            switch(i) {
                case 0:
                    bcOp = in3.add(in2);
                    name = "add";
                    break;
                case 1:
                    bcOp = in3.sub(in2);
                    name = "sub";
                    break;
                case 2:
                    bcOp = in3.mul(in2);
                    name = "mul";
                    break;
                case 3:
                    bcOp = in3.div(in2);
                    name = "div";
                    break;
                case 4:
                    bcOp = in3.rsub(in2);
                    name = "rsub";
                    break;
                case 5:
                    bcOp = in3.rdiv(in2);
                    name = "rdiv";
                    break;
                case 6:
                    bcOp = sd.f().floorDiv(in3, in2);
                    name = "floordiv";
                    break;
                case 7:
                    bcOp = sd.f().floorMod(in3, in2);
                    name = "floormod";
                    break;
                default:
                    throw new RuntimeException();
            }
            SDVariable outVar = sd.sum(bcOp);
            String msg = "(test " + i + ": " + name + ", dimensions=" + Arrays.toString(dim_sz1s) + ")";
            log.info("*** Starting test: " + msg);
            INDArray in3Arr = Nd4j.randn(new int[] { 3, 4, 5 }).muli(100);
            INDArray in2Arr = Nd4j.randn(otherShape).muli(100);
            sd.associateArrayWithVariable(in3Arr, in3);
            sd.associateArrayWithVariable(in2Arr, in2);
            try {
                INDArray out = sd.execAndEndResult();
                assertNotNull(out);
                assertArrayEquals(new int[] { 1, 1 }, out.shape());
                // System.out.println(sd.asFlatPrint());
                boolean ok = GradCheckUtil.checkGradients(sd);
                if (!ok) {
                    allFailed.add(msg);
                }
            } catch (Exception e) {
                e.printStackTrace();
                allFailed.add(msg + " - EXCEPTION");
            }
        }
    }
    assertEquals("Failed: " + allFailed, 0, allFailed.size());
}
Also used: SDVariable (org.nd4j.autodiff.samediff.SDVariable), INDArray (org.nd4j.linalg.api.ndarray.INDArray), SameDiff (org.nd4j.autodiff.samediff.SameDiff), ArrayList (java.util.ArrayList), Test (org.junit.Test)
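
For reference, the standard rule for gradients flowing back through a broadcast dimension (textbook broadcasting, not taken from the nd4j source): every size-1 dimension that was expanded in the forward pass must be summed over in the backward pass. With $z[a,b,c] = x[a,b,c] + y[a,0,c]$, where the size-1 second dimension of $y$ is broadcast over $b$,
$$\frac{\partial L}{\partial y[a,0,c]} = \sum_b \frac{\partial L}{\partial z[a,b,c]},$$
and for an element-wise multiply the summand picks up the extra factor $x[a,b,c]$. This analytic reduction is what the gradient checks above compare against finite differences.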

Example 44 with SDVariable

Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.

The class CumProd, method doDiff.

@Override
public List<SDVariable> doDiff(List<SDVariable> grad) {
    // Output gradient is the reversed cumulative product of the reversed input gradient
    SDVariable gradient = sameDiff.setupFunction(grad.get(0));
    SDVariable reverseGrad = sameDiff.reverse(gradient, 1 - dimensions[0]);
    SDVariable ret = sameDiff.cumprod(reverseGrad, exclusive, reverse, dimensions);
    SDVariable reversedRet = sameDiff.reverse(ret, 1 - dimensions[0]);
    return Arrays.asList(reversedRet);
}
Also used: SDVariable (org.nd4j.autodiff.samediff.SDVariable)
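
The reverse-then-accumulate structure in doDiff mirrors the familiar rule for cumulative sums; as a reference point (standard derivation, not taken from the nd4j source): for $y_k = \sum_{i \le k} x_i$ along one axis,
$$\frac{\partial L}{\partial x_i} = \sum_k \frac{\partial L}{\partial y_k}\,\frac{\partial y_k}{\partial x_i} = \sum_{k \ge i} \frac{\partial L}{\partial y_k},$$
i.e. the backward pass is reverse, cumulative-sum, reverse applied to the incoming gradient. The cumulative product case carries additional per-element factors ($\partial y_k / \partial x_i = \prod_{j \le k,\ j \ne i} x_j$), but the same reversal pattern along the accumulation axis appears, which is what the comment in the code refers to.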

Example 45 with SDVariable

Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.

The class Erf, method doDiff.

@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
    // Derivative of erf(z) is 2 / sqrt(pi) * e^(-z^2)
    SDVariable gradient = i_v.get(0);
    SDVariable constant = sameDiff.onesLike(gradient).mul(2).div(Math.sqrt(Math.PI));
    SDVariable ret = constant.mul(sameDiff.exp(gradient.mul(gradient).mul(-1)));
    return Arrays.asList(ret);
}
Also used: SDVariable (org.nd4j.autodiff.samediff.SDVariable)
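
The formula in the comment follows directly from the integral definition of the error function (standard calculus):
$$\operatorname{erf}(z) = \frac{2}{\sqrt{\pi}} \int_0^z e^{-t^2}\,dt \quad\Longrightarrow\quad \frac{d}{dz}\operatorname{erf}(z) = \frac{2}{\sqrt{\pi}}\, e^{-z^2},$$
which is exactly the expression built above from onesLike, mul, div, and exp.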

Aggregations

SDVariable (org.nd4j.autodiff.samediff.SDVariable): 104
SameDiff (org.nd4j.autodiff.samediff.SameDiff): 41
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 38
Test (org.junit.Test): 36
ArrayList (java.util.ArrayList): 18
DynamicCustomOp (org.nd4j.linalg.api.ops.DynamicCustomOp): 10
lombok.val (lombok.val): 7
LossFunctions (org.nd4j.autodiff.loss.LossFunctions): 4
LossInfo (org.nd4j.autodiff.loss.LossInfo): 4
BernoulliDistribution (org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution): 4
Ignore (org.junit.Ignore): 3
DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction): 3
ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 3
Triple (org.nd4j.linalg.primitives.Triple): 2
DataOutputStream (java.io.DataOutputStream): 1
FileOutputStream (java.io.FileOutputStream): 1
ByteBuffer (java.nio.ByteBuffer): 1
NoOpNameFoundException (org.nd4j.imports.NoOpNameFoundException): 1
NdIndexIterator (org.nd4j.linalg.api.iter.NdIndexIterator): 1
TruncateDivOp (org.nd4j.linalg.api.ops.impl.transforms.arithmetic.TruncateDivOp): 1