Search in sources:

Example 36 with SDVariable

Use of org.nd4j.autodiff.samediff.SDVariable in the project nd4j by deeplearning4j.

Source: class GradCheckReductions, method testZeroCount.

/**
 * Checks {@code countNonZero} and {@code countZero} on a 2x2 input holding two zeros and two ones.
 *
 * <p>Uses JUnit assertions rather than the Java {@code assert} keyword: {@code assert} is a no-op
 * unless the JVM runs with {@code -ea}, which would let this test pass without checking anything.
 */
@Test
public void testZeroCount() {
    SameDiff sd = SameDiff.create();
    INDArray ia = Nd4j.create(new int[] { 2, 2 }, new float[] { 0, 1, 0, 1 });
    SDVariable input = sd.var("in", new int[] { 2, 2 });
    sd.associateArrayWithVariable(ia, input);
    SDVariable nonZero = sd.countNonZero(input);
    SDVariable zero = sd.countZero(input);
    sd.exec();
    // Counts are small integers, so they are exactly representable; an exact comparison is safe.
    assertEquals("countNonZero", 2.0, nonZero.getArr().getDouble(0), 0.0);
    assertEquals("countZero", 2.0, zero.getArr().getDouble(0), 0.0);
}
Also used: SDVariable(org.nd4j.autodiff.samediff.SDVariable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) SameDiff(org.nd4j.autodiff.samediff.SameDiff) Test(org.junit.Test)

Example 37 with SDVariable

Use of org.nd4j.autodiff.samediff.SDVariable in the project nd4j by deeplearning4j.

Source: class GradCheckReductions, method testZeroFraction.

/**
 * Checks {@code zeroFraction} on a 2x2 input in which half of the elements are zero.
 *
 * <p>Uses a JUnit assertion instead of the {@code assert} keyword, which is skipped unless the JVM
 * runs with {@code -ea} and would otherwise let the test pass vacuously.
 */
@Test
public void testZeroFraction() {
    SameDiff sd = SameDiff.create();
    INDArray ia = Nd4j.create(new int[] { 2, 2 }, new float[] { 0, 1, 0, 1 });
    SDVariable input = sd.var("in", new int[] { 2, 2 });
    sd.associateArrayWithVariable(ia, input);
    SDVariable zeroFraction = sd.zeroFraction(input);
    sd.exec();
    // Small tolerance: the input is float-backed; the output dtype is not fixed by this test.
    assertEquals("zeroFraction", 0.5, zeroFraction.getArr().getDouble(0), 1e-6);
}
Also used: SDVariable(org.nd4j.autodiff.samediff.SDVariable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) SameDiff(org.nd4j.autodiff.samediff.SameDiff) Test(org.junit.Test)

Example 38 with SDVariable

Use of org.nd4j.autodiff.samediff.SDVariable in the project nd4j by deeplearning4j.

Source: class GradCheckReductions, method testReductionGradients1.

/**
 * Gradient-checks each reduction when it is the final op of a graph but not the only function in it.
 *
 * <p>The graph computes a per-example squared error between "in" and "label", averaged along
 * dimension 1 into "msePerEx", and then reduces that again into "loss". Every reduction is tried
 * along dimension 0 and with {@code Integer.MAX_VALUE}; the test asserts a 1x1 output in both
 * cases. Failures (including exceptions) are gathered and reported in a single assertion.
 */
@Test
public void testReductionGradients1() {
    Nd4j.getRandom().setSeed(12345);
    final int nOut = 4;
    final int minibatch = 10;
    // Position in this array selects the reduction built by lossForReduction.
    final String[] reductionNames = { "mean", "sum", "stdev", "min", "max", "variance", "prod", "norm1", "norm2", "normmax" };
    List<String> failures = new ArrayList<>();
    for (int dim : new int[] { 0, Integer.MAX_VALUE }) {
        for (int caseIdx = 0; caseIdx < reductionNames.length; caseIdx++) {
            SameDiff graph = SameDiff.create();
            SDVariable input = graph.var("in", new int[] { -1, nOut });
            SDVariable label = graph.var("label", new int[] { -1, nOut });
            SDVariable delta = input.sub(label);
            SDVariable perExampleMse = graph.mean("msePerEx", delta.mul(delta), 1);
            lossForReduction(graph, perExampleMse, caseIdx, dim);

            String msg = "(test " + caseIdx + " - " + reductionNames[caseIdx] + ", dimension=" + dim + ")";
            log.info("*** Starting test: " + msg);

            INDArray inputArr = Nd4j.randn(minibatch, nOut).muli(100);
            INDArray labelArr = Nd4j.randn(minibatch, nOut).muli(100);
            graph.associateArrayWithVariable(inputArr, input);
            graph.associateArrayWithVariable(labelArr, label);
            try {
                INDArray result = graph.execAndEndResult();
                assertNotNull(result);
                assertArrayEquals(new int[] { 1, 1 }, result.shape());
                if (!GradCheckUtil.checkGradients(graph)) {
                    failures.add(msg);
                }
            } catch (Exception e) {
                // Record and continue so that every reduction/dimension pair is reported.
                e.printStackTrace();
                failures.add(msg + " - EXCEPTION");
            }
        }
    }
    assertEquals("Failed: " + failures, 0, failures.size());
}

/**
 * Adds the reduction selected by {@code which} over {@code x} to {@code sd}, named "loss".
 * The case order matches the reduction-name array in testReductionGradients1.
 */
private static SDVariable lossForReduction(SameDiff sd, SDVariable x, int which, int dim) {
    switch (which) {
        case 0:
            return sd.mean("loss", x, dim);
        case 1:
            return sd.sum("loss", x, dim);
        case 2:
            return sd.standardDeviation("loss", x, true, dim);
        case 3:
            return sd.min("loss", x, dim);
        case 4:
            return sd.max("loss", x, dim);
        case 5:
            return sd.variance("loss", x, true, dim);
        case 6:
            return sd.prod("loss", x, dim);
        case 7:
            return sd.norm1("loss", x, dim);
        case 8:
            return sd.norm2("loss", x, dim);
        case 9:
            return sd.normmax("loss", x, dim);
        default:
            throw new RuntimeException();
    }
}
Also used: SDVariable(org.nd4j.autodiff.samediff.SDVariable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) SameDiff(org.nd4j.autodiff.samediff.SameDiff) ArrayList(java.util.ArrayList) Test(org.junit.Test)

Example 39 with SDVariable

Use of org.nd4j.autodiff.samediff.SDVariable in the project nd4j by deeplearning4j.

Source: class GradCheckMisc, method testStridedSliceGradient.

/**
 * Gradient-checks {@code stridedSlice} over a set of slice specifications.
 *
 * <p>The cases cover plain ranges, non-unit strides, and the begin/end/ellipsis/new-axis/shrink-axis
 * masks. Each sliced result is reduced with {@code standardDeviation}, and the graph is
 * gradient-checked. All failing cases are collected and reported in one assertion.
 */
@Test
public void testStridedSliceGradient() {
    Nd4j.getRandom().setSeed(12345);
    // Order here: original shape, begin, end, strides, then optional masks.
    // -999 is a placeholder for an entry whose value is irrelevant because a mask bit is set
    // for that position. begin, end and strides always have the same length.
    List<SSCase> testCases = new ArrayList<>();
    testCases.add(SSCase.builder().shape(3, 4).begin(0, 0).end(3, 4).strides(1, 1).build());
    testCases.add(SSCase.builder().shape(3, 4).begin(1, 1).end(2, 3).strides(1, 1).build());
    testCases.add(SSCase.builder().shape(3, 4).begin(-999, 0).end(3, 4).strides(1, 1).beginMask(1).build());
    testCases.add(SSCase.builder().shape(3, 4).begin(1, 1).end(3, -999).strides(1, 1).endMask(1 << 1).build());
    testCases.add(SSCase.builder().shape(3, 4).begin(-999, 0).end(-999, 4).strides(1, 1).beginMask(1).endMask(1).build());
    // Three slice entries (new axis + two dims), so strides needs three values as well.
    testCases.add(SSCase.builder().shape(3, 4).begin(-999, 0, 0).end(-999, 3, 4).strides(1, 1, 1).newAxisMask(1).build());
    testCases.add(SSCase.builder().shape(3, 4, 5).begin(0, 0, 0).end(3, 4, 5).strides(1, 1, 1).build());
    testCases.add(SSCase.builder().shape(3, 4, 5).begin(1, 2, 3).end(3, 4, 5).strides(1, 1, 1).build());
    testCases.add(SSCase.builder().shape(3, 4, 5).begin(0, 0, 0).end(3, 3, 5).strides(1, 2, 2).build());
    testCases.add(SSCase.builder().shape(3, 4, 5).begin(1, -999, 1).end(3, 3, 4).strides(1, 1, 1).beginMask(1 << 1).build());
    testCases.add(SSCase.builder().shape(3, 4, 5).begin(1, -999, 1).end(3, 3, -999).strides(1, 1, 1).beginMask(1 << 1).endMask(1 << 2).build());
    // [1:3,...,2:4]
    testCases.add(SSCase.builder().shape(3, 4, 5).begin(1, 2).end(3, 4).strides(1, 1).ellipsisMask(1 << 1).build());
    testCases.add(SSCase.builder().shape(3, 4, 5).begin(1, -999, 1, 2).end(3, -999, 3, 4).strides(1, -999, 1, 2).newAxisMask(1 << 1).build());
    testCases.add(SSCase.builder().shape(3, 4, 5).begin(1, 0, 1).end(3, -999, 4).strides(1, 1, 1).shrinkAxisMask(1 << 1).build());
    testCases.add(SSCase.builder().shape(3, 4, 5).begin(1, 1, 1).end(3, -999, 4).strides(1, 1, 1).shrinkAxisMask(1 << 1).build());

    List<String> failed = new ArrayList<>();
    for (int i = 0; i < testCases.size(); i++) {
        SSCase t = testCases.get(i);
        INDArray arr = Nd4j.rand(t.getShape());
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.var("in", arr);
        SDVariable slice = sd.stridedSlice(in, t.getBegin(), t.getEnd(), t.getStrides(), t.getBeginMask(), t.getEndMask(), t.getEllipsisMask(), t.getNewAxisMask(), t.getShrinkAxisMask());
        // Adds the reduction op to the graph; the returned Java variable itself is not needed.
        sd.standardDeviation(slice, true);
        String msg = "i=" + i + ": " + t;
        log.info("Starting test: " + msg);
        // checkGradients reports the outcome through its return value, so that value must be
        // checked. If it were discarded, gradient mismatches would never fail this test.
        boolean ok = GradCheckUtil.checkGradients(sd);
        if (!ok) {
            failed.add(msg);
        }
    }
    assertEquals("Failed: " + failed, 0, failed.size());
}
Also used: SDVariable(org.nd4j.autodiff.samediff.SDVariable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) SameDiff(org.nd4j.autodiff.samediff.SameDiff) ArrayList(java.util.ArrayList) Test(org.junit.Test)

Example 40 with SDVariable

Use of org.nd4j.autodiff.samediff.SDVariable in the project nd4j by deeplearning4j.

Source: class GradCheckMisc, method testReshapeGradient.

/**
 * Checks the forward output and the gradients of {@code reshape}.
 *
 * <p>Each [3, 4, 5] test array from {@code NDArrayCreationUtil.getAll3dTestArraysWithShape} is
 * scaled by 100, reshaped to several rank-2 shapes, and reduced with {@code standardDeviation}.
 * Both the output-value assertion and the gradient assertion carry a message naming the target
 * shape and source array, so a failure identifies the offending case.
 */
@Test
public void testReshapeGradient() {
    int[] origShape = new int[] { 3, 4, 5 };
    for (int[] toShape : new int[][] { { 3, 4 * 5 }, { 3 * 4, 5 }, { 1, 3 * 4 * 5 }, { 3 * 4 * 5, 1 } }) {
        for (Pair<INDArray, String> p : NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, origShape)) {
            // Built up front so that both assertions below can report which case failed.
            String msg = "toShape=" + Arrays.toString(toShape) + ", source=" + p.getSecond();
            INDArray inArr = p.getFirst().muli(100);
            SameDiff sd = SameDiff.create();
            SDVariable in = sd.var("in", inArr);
            SDVariable reshape = sd.reshape(in, toShape);
            // Using stdev here: mean/sum would backprop the same gradient for each input...
            sd.standardDeviation("out", reshape, true);
            INDArray out = sd.execAndEndResult();
            // Reshape keeps the same set of elements, so the full-array stdev of the input is the
            // expected result.
            INDArray expOut = in.getArr().std(true, Integer.MAX_VALUE);
            assertEquals(msg, expOut, out);
            boolean ok = GradCheckUtil.checkGradients(sd);
            assertTrue(msg, ok);
        }
    }
}
Also used: SDVariable(org.nd4j.autodiff.samediff.SDVariable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) SameDiff(org.nd4j.autodiff.samediff.SameDiff) Test(org.junit.Test)

Aggregations

SDVariable (org.nd4j.autodiff.samediff.SDVariable) 104, SameDiff (org.nd4j.autodiff.samediff.SameDiff) 41, INDArray (org.nd4j.linalg.api.ndarray.INDArray) 38, Test (org.junit.Test) 36, ArrayList (java.util.ArrayList) 18, DynamicCustomOp (org.nd4j.linalg.api.ops.DynamicCustomOp) 10, lombok.val (lombok.val) 7, LossFunctions (org.nd4j.autodiff.loss.LossFunctions) 4, LossInfo (org.nd4j.autodiff.loss.LossInfo) 4, BernoulliDistribution (org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution) 4, Ignore (org.junit.Ignore) 3, DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction) 3, ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException) 3, Triple (org.nd4j.linalg.primitives.Triple) 2, DataOutputStream (java.io.DataOutputStream) 1, FileOutputStream (java.io.FileOutputStream) 1, ByteBuffer (java.nio.ByteBuffer) 1, NoOpNameFoundException (org.nd4j.imports.NoOpNameFoundException) 1, NdIndexIterator (org.nd4j.linalg.api.iter.NdIndexIterator) 1, TruncateDivOp (org.nd4j.linalg.api.ops.impl.transforms.arithmetic.TruncateDivOp) 1