Search in sources:

Example 26 with SDVariable

use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.

In class GradCheckTransforms, method testDynamicPartition.

/**
 * Compares SameDiff's {@code dynamicPartition} against the raw {@code dynamic_partition} custom op.
 *
 * <p>The six input values are split into {@code numPartitions} outputs according to the
 * {@code partitions} index array; each reference output has shape [1, 3] because each partition
 * index occurs three times. The SameDiff partitions are merged back with {@code mergeAdd} and
 * reduced with a mean so the graph ends in a scalar. A forward mismatch is logged rather than
 * failing the test.
 */
@Test
public void testDynamicPartition() {
    SameDiff sd = SameDiff.create();
    INDArray ia = Nd4j.create(new float[] { 4, 3, 5, 7, 8, 0 }, new int[] { 1, 6 });
    INDArray partitions = Nd4j.create(new float[] { 1, 0, 1, 0, 0, 1 });
    int numPartitions = 2;
    SDVariable in = sd.var("in", new int[] { 1, 6 });
    SDVariable sdPartitions = sd.var("partitions", new int[] { 1, 6 });
    // Reference outputs computed directly by the custom op
    INDArray expOut1 = Nd4j.create(new int[] { 1, 3 });
    INDArray expOut2 = Nd4j.create(new int[] { 1, 3 });
    INDArray[] expOut = new INDArray[] { expOut1, expOut2 };
    DynamicCustomOp dynamicPartition = DynamicCustomOp.builder("dynamic_partition").addInputs(ia, partitions).addIntegerArguments(numPartitions).addOutputs(expOut1, expOut2).build();
    Nd4j.getExecutioner().exec(dynamicPartition);
    SDVariable[] parts = sd.dynamicPartition(in, sdPartitions, numPartitions);
    // merge the output partitions together again, to retrieve a single
    // tensor and finally a scalar.
    SDVariable t = sd.mergeAdd(parts);
    sd.mean("loss", t);
    sd.associateArrayWithVariable(ia, in);
    // The graph must partition by the same indices the reference op used;
    // otherwise the "partitions" variable is never given this array.
    sd.associateArrayWithVariable(partitions, sdPartitions);
    sd.exec();
    // Size by the actual number of returned partitions to avoid an out-of-bounds write
    INDArray[] out = new INDArray[parts.length];
    for (int i = 0; i < parts.length; i++) {
        out[i] = parts[i].getArr();
    }
    // Element-wise comparison: Object.equals on two arrays would only compare references
    if (!java.util.Arrays.equals(expOut, out)) {
        log.error("forward failed: expected " + java.util.Arrays.toString(expOut)
                + " but got " + java.util.Arrays.toString(out));
    }
}
Also used : SDVariable(org.nd4j.autodiff.samediff.SDVariable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) SameDiff(org.nd4j.autodiff.samediff.SameDiff) DynamicCustomOp(org.nd4j.linalg.api.ops.DynamicCustomOp) Test(org.junit.Test)

Example 27 with SDVariable

use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.

In class GradCheckTransforms, method testTransforms.

/**
 * Forward- and gradient-checks a table of non-pairwise SameDiff transforms.
 *
 * <p>For each case index {@code i}, the switch builds the SameDiff op {@code t} on variable
 * {@code in}, optionally replaces the default input {@code ia} with a range-appropriate one, and
 * computes the expected output {@code expOut} directly with the matching ND4J op. The forward
 * result is compared against {@code expOut}. Unless {@code skipBackward} is set for the case,
 * {@code GradCheckUtil.checkGradients} is then run on the graph with a mean loss.
 *
 * <p>Forward and backward failures are collected and reported together at the end instead of
 * aborting on the first failing case, so a single run shows every broken transform.
 */
@Test
public void testTransforms() {
    // Test transforms (non-pairwise)
    Nd4j.getRandom().setSeed(12345);
    // TODO: to speed up integration of new ops for TF import, we sometimes skip "doDiff" implementations.
    // get back to those after release
    List<String> allSkipped = new ArrayList<>();
    List<String> allFailed = new ArrayList<>();
    // Must match the number of cases in the switch below (the default branch throws)
    int numCases = 62;
    for (int i = 0; i < numCases; i++) {
        boolean skipBackward = false;
        SameDiff sd = SameDiff.create();
        int nOut = 4;
        int minibatch = 5;
        SDVariable in = sd.var("in", new int[] { -1, nOut });
        INDArray ia = Nd4j.randn(minibatch, nOut);
        int dim;
        SDVariable t;
        INDArray expOut;
        switch(i) {
            case 0:
                t = in.add(5.0);
                expOut = ia.add(5.0);
                break;
            case 1:
                t = in.sub(5.0);
                expOut = ia.sub(5.0);
                break;
            case 2:
                t = in.mul(2.5);
                expOut = ia.mul(2.5);
                break;
            case 3:
                t = in.div(4.0);
                expOut = ia.div(4.0);
                break;
            case 4:
                t = in.rsub(5.0);
                expOut = ia.rsub(5.0);
                break;
            case 5:
                t = in.rdiv(1.0);
                expOut = ia.rdiv(1.0);
                break;
            case 6:
                t = sd.pow(in, 2.5);
                ia = Nd4j.rand(minibatch, nOut);
                expOut = Transforms.pow(ia, 2.5, true);
                break;
            case 7:
                t = sd.sigmoid(in);
                ia = Nd4j.rand(minibatch, nOut).muli(2).subi(1.0);
                expOut = Transforms.sigmoid(ia, true);
                break;
            case 8:
                t = sd.tanh(in);
                ia = Nd4j.rand(minibatch, nOut).muli(2).subi(1.0);
                expOut = Transforms.tanh(ia, true);
                break;
            case 9:
                t = sd.tan(in);
                ia = Nd4j.rand(minibatch, nOut);
                expOut = Nd4j.getExecutioner().execAndReturn(new Tan(ia.dup()));
                break;
            case 10:
                t = sd.cos(in);
                expOut = Transforms.cos(ia, true);
                break;
            case 11:
                t = sd.sin(in);
                expOut = Transforms.sin(ia, true);
                break;
            case 12:
                t = sd.softplus(in);
                expOut = Transforms.softPlus(ia, true);
                break;
            case 13:
                t = sd.log(in);
                ia = Nd4j.rand(minibatch, nOut);
                expOut = Transforms.log(ia, true);
                skipBackward = true;
                break;
            case 14:
                t = sd.neg(in);
                expOut = ia.neg();
                break;
            case 15:
                t = sd.acos(in);
                ia = Nd4j.rand(minibatch, nOut).muli(1.8).subi(0.9);
                expOut = Transforms.acos(ia, true);
                break;
            case 16:
                t = sd.acosh(in);
                // Only defined for x >= 1
                ia = Nd4j.rand(minibatch, nOut).addi(1.01);
                expOut = Nd4j.getExecutioner().execAndReturn(new ACosh(ia.dup()));
                skipBackward = true;
                break;
            case 17:
                t = sd.asin(in);
                ia = Nd4j.rand(minibatch, nOut).muli(1.8).subi(0.9);
                expOut = Transforms.asin(ia, true);
                break;
            case 18:
                t = sd.atan(in);
                ia = Nd4j.rand(minibatch, nOut).muli(4).subi(2);
                expOut = Transforms.atan(ia, true);
                break;
            case 19:
                t = sd.atanh(in);
                ia = Nd4j.rand(minibatch, nOut).muli(1.8).subi(0.9);
                expOut = Transforms.atanh(ia, true);
                break;
            case 20:
                t = sd.cosh(in);
                expOut = Transforms.cosh(ia, true);
                break;
            case 21:
                t = sd.cube(in);
                expOut = Transforms.pow(ia, 3.0, true);
                break;
            case 22:
                t = sd.elu(in);
                expOut = Transforms.elu(ia, true);
                break;
            case 23:
                // TODO SHOULDN'T THIS HAVE A DIMENSION ARG???
                t = sd.softmax(in);
                ia = Nd4j.rand(minibatch, nOut);
                expOut = Nd4j.getExecutioner().execAndReturn(new OldSoftMax(ia.dup()));
                break;
            case 24:
                t = sd.sqrt(in);
                ia = Nd4j.rand(minibatch, nOut);
                expOut = Transforms.sqrt(ia, true);
                break;
            case 25:
                t = sd.square(in);
                expOut = Transforms.pow(ia, 2.0, true);
                break;
            case 26:
                t = sd.transpose(in);
                expOut = ia.transpose().dup();
                break;
            case 27:
                t = sd.abs(in);
                expOut = Transforms.abs(ia, true);
                break;
            case 28:
                t = sd.sinh(in);
                expOut = Transforms.sinh(ia, true);
                break;
            case 29:
                t = sd.asinh(in);
                expOut = Nd4j.getExecutioner().execAndReturn(new ASinh(ia.dup()));
                skipBackward = true;
                break;
            case 30:
                t = sd.exp(in);
                expOut = Transforms.exp(ia, true);
                break;
            case 31:
                t = sd.floor(in);
                expOut = Transforms.floor(ia, true);
                break;
            case 32:
                t = sd.relu(in, 0.0);
                ia = Nd4j.rand(minibatch, nOut);
                expOut = Transforms.relu(ia, true);
                break;
            case 33:
                t = sd.hardTanh(in);
                ia = Nd4j.rand(minibatch, nOut).muli(2).subi(1.0);
                expOut = Transforms.hardTanh(ia, true);
                break;
            case 34:
                t = sd.logSigmoid(in);
                expOut = Nd4j.getExecutioner().execAndReturn(new LogSigmoid(ia.dup()));
                break;
            case 35:
                t = sd.swish(in);
                expOut = Nd4j.getExecutioner().execAndReturn(new Swish(ia.dup()));
                break;
            case 36:
                t = sd.sign(in);
                expOut = Transforms.sign(ia, true);
                break;
            case 37:
                t = sd.softsign(in);
                expOut = Transforms.softsign(ia, true);
                break;
            case 38:
                t = sd.leakyRelu(in, 0.0);
                ia = Nd4j.rand(minibatch, nOut);
                expOut = Transforms.leakyRelu(ia, true);
                break;
            case 39:
                // TODO DIMENSION ARG???
                // TODO fix me
                t = sd.logSoftmax(in);
                ia = Nd4j.rand(minibatch, nOut);
                expOut = Transforms.log(Transforms.softmax(ia, true));
                skipBackward = true;
                break;
            case 40:
                t = sd.selu(in);
                expOut = Nd4j.getExecutioner().execAndReturn(new SELU(ia.dup()));
                break;
            case 41:
                t = sd.gt(in, 1.0);
                expOut = ia.gt(1.0);
                break;
            case 42:
                t = sd.gte(in, 1.0);
                expOut = ia.gte(1.0);
                break;
            case 43:
                t = sd.lt(in, 1.0);
                expOut = ia.lt(1.0);
                break;
            case 44:
                t = sd.lte(in, 1.0);
                expOut = ia.lte(1.0);
                break;
            case 45:
                t = sd.eq(in, 2.0);
                ia = Nd4j.linspace(1, minibatch * nOut, minibatch * nOut).reshape('c', minibatch, nOut);
                expOut = ia.eq(2.0);
                break;
            case 46:
                t = sd.neq(in, 2.0);
                ia = Nd4j.linspace(1, minibatch * nOut, minibatch * nOut).reshape('c', minibatch, nOut);
                expOut = ia.neq(2.0);
                break;
            case 47:
                t = sd.ceil(in);
                expOut = Transforms.ceil(ia, true);
                break;
            case 48:
                ia = Nd4j.randn(ia.shape()).muli(2);
                t = sd.clipByValue(in, -3, 2);
                expOut = ia.dup();
                BooleanIndexing.replaceWhere(expOut, -3, Conditions.lessThan(-3));
                BooleanIndexing.replaceWhere(expOut, 2, Conditions.greaterThan(2));
                break;
            case 49:
                // Clip by norm, dimension 0, some below threshold, some above
                double clip = 2.0;
                ia = Nd4j.rand(ia.shape());
                // Exactly at threshold...
                ia.muliRowVector(ia.norm2(0).rdiv(clip));
                log.debug("column norms at threshold: " + ia.norm2(0));
                // ...then scale columns by factors in [0.9, 1.1] so some fall below and some above
                ia.muliRowVector(Nd4j.linspace(0.9, 1.1, ia.size(1)));
                log.debug("column norms after scaling: " + ia.norm2(0));
                expOut = Nd4j.create(ia.shape());
                for (int j = 0; j < ia.columns(); j++) {
                    INDArray origCol = ia.getColumn(j);
                    if (origCol.norm2Number().doubleValue() < clip) {
                        expOut.putColumn(j, origCol);
                    } else {
                        expOut.putColumn(j, origCol.mul(clip / origCol.norm2Number().doubleValue()));
                    }
                }
                t = sd.clipByNorm(in, clip);
                break;
            // TODO clip by norm along other dimensions
            case 50:
                dim = 1;
                t = sd.reverse(in, dim);
                expOut = Nd4j.create(ia.shape());
                DynamicCustomOp reverse = DynamicCustomOp.builder("reverse").addIntegerArguments(dim).addInputs(ia).addOutputs(expOut).build();
                Nd4j.getExecutioner().exec(reverse);
                break;
            case 51:
                dim = 0;
                boolean exclusive = false;
                boolean reverseBool = false;
                t = sd.cumsum(in, exclusive, reverseBool, dim);
                expOut = Nd4j.create(ia.shape());
                DynamicCustomOp cumsum = DynamicCustomOp.builder("cumsum").addIntegerArguments((exclusive) ? 1 : 0, (reverseBool) ? 1 : 0, dim).addInputs(ia).addOutputs(expOut).build();
                Nd4j.getExecutioner().exec(cumsum);
                break;
            case 52:
                dim = 0;
                boolean ex = false;
                boolean revBool = false;
                // Must build cumprod (not cumsum) to match the "cumprod" reference op below
                t = sd.cumprod(in, ex, revBool, dim);
                expOut = Nd4j.create(ia.shape());
                DynamicCustomOp cumprodOp = DynamicCustomOp.builder("cumprod").addIntegerArguments((ex) ? 1 : 0, (revBool) ? 1 : 0, dim).addInputs(ia).addOutputs(expOut).build();
                Nd4j.getExecutioner().exec(cumprodOp);
                break;
            case 53:
                ia = Nd4j.create(new float[] { 4, 2 });
                in = sd.var("in", new int[] { 1, 2 });
                expOut = Nd4j.create(new int[] { 2, 2 });
                DynamicCustomOp op = DynamicCustomOp.builder("diag").addInputs(ia).addOutputs(expOut).build();
                Nd4j.getExecutioner().exec(op);
                t = sd.diag(in);
                break;
            case 54:
                expOut = Nd4j.createUninitialized(ia.shape(), ia.ordering());
                Nd4j.getExecutioner().exec(new Erf(ia, expOut));
                t = sd.erf(in);
                skipBackward = true;
                break;
            case 55:
                expOut = Nd4j.createUninitialized(ia.shape(), ia.ordering());
                Nd4j.getExecutioner().exec(new Erfc(ia, expOut));
                t = sd.erfc(in);
                skipBackward = true;
                break;
            case 56:
                t = sd.expm1(in);
                expOut = Transforms.expm1(ia, true);
                break;
            case 57:
                t = sd.log1p(in);
                ia = Nd4j.rand(minibatch, nOut);
                expOut = Transforms.log1p(ia, true);
                skipBackward = true;
                break;
            case 58:
                t = sd.round(in);
                expOut = Transforms.round(ia, true);
                skipBackward = true;
                break;
            case 59:
                ia = Nd4j.create(new float[] { 4, 2 });
                in = sd.var("in", new int[] { 1, 2 });
                t = sd.rsqrt(in);
                expOut = Nd4j.create(ia.shape(), ia.ordering());
                Nd4j.getExecutioner().exec(new RSqrt(ia, expOut));
                skipBackward = true;
                break;
            case 60:
                t = sd.relu6(in, 0);
                ia = Nd4j.rand(minibatch, nOut);
                expOut = Transforms.relu6(ia);
                skipBackward = true;
                break;
            case 61:
                ia = Nd4j.create(new float[] { 2, 2 });
                in = sd.var("in", new int[] { 1, 2 });
                sd.associateArrayWithVariable(ia, in);
                double value = 42;
                expOut = Nd4j.create(new int[] { 2, 2 });
                DynamicCustomOp fillOp = DynamicCustomOp.builder("fill").addInputs(ia).addFloatingPointArguments(value).addOutputs(expOut).build();
                Nd4j.getExecutioner().exec(fillOp);
                skipBackward = true;
                t = sd.fill(in, value);
                break;
            default:
                throw new RuntimeException();
        }
        // The first function registered in the graph is the transform under test
        DifferentialFunction[] funcs = sd.functions();
        String name = funcs.length > 0 ? funcs[0].opName() : "<no op>";
        String msg = "test: " + i + " - " + name;
        log.info("*** Starting test: " + msg);
        // Scalar objective for the gradient check; the returned variable is not needed here
        sd.mean("loss", t);
        sd.associateArrayWithVariable(ia, in);
        sd.exec();
        INDArray out = t.getArr();
        if (!expOut.equals(out)) {
            allFailed.add(msg + " - FAILED ON FORWARD");
            continue;
        }
        boolean ok;
        if (skipBackward) {
            ok = true;
            msg += " - SKIPPED";
            allSkipped.add(msg);
        } else {
            try {
                ok = GradCheckUtil.checkGradients(sd);
            } catch (Exception e) {
                log.error("Gradient check threw for " + msg, e);
                msg += " - EXCEPTION";
                ok = false;
            }
        }
        // Collect instead of asserting here so every failing case is reported at the end
        if (!ok) {
            allFailed.add(msg);
        }
    }
    if (!allSkipped.isEmpty()) {
        log.info("All backward skipped transforms: " + allSkipped);
        log.info(allSkipped.size() + " backward passes were skipped.");
    }
    if (!allFailed.isEmpty()) {
        log.error("All failed transforms: " + allFailed);
        fail(allFailed.size() + " transforms failed");
    }
}
Also used : SameDiff(org.nd4j.autodiff.samediff.SameDiff) ArrayList(java.util.ArrayList) SDVariable(org.nd4j.autodiff.samediff.SDVariable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) DifferentialFunction(org.nd4j.autodiff.functions.DifferentialFunction) DynamicCustomOp(org.nd4j.linalg.api.ops.DynamicCustomOp) Test(org.junit.Test)

Example 28 with SDVariable

use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.

In class GradCheckTransforms, method testConditions.

/**
 * Checks SameDiff's {@code isFinite}, {@code isInfinite} and {@code isNaN} on a small all-finite input.
 *
 * <p>Uses JUnit assertions rather than the Java {@code assert} keyword, so the checks run whether
 * or not the JVM was started with {@code -ea}.
 */
@Test
public void testConditions() {
    SameDiff sd = SameDiff.create();
    INDArray ia = Nd4j.create(new float[] { 4, 2 });
    SDVariable in = sd.var("in", new int[] { 1, 2 });
    sd.associateArrayWithVariable(ia, in);
    // Both input values are finite: isFinite is all ones, isInfinite and isNaN are all zeros
    INDArray expFinite = Nd4j.create(new float[] { 1, 1 });
    SDVariable finite = sd.isFinite(in);
    INDArray expInfinite = Nd4j.create(new float[] { 0, 0 });
    SDVariable infinite = sd.isInfinite(in);
    INDArray expNaN = Nd4j.create(new float[] { 0, 0 });
    SDVariable isnan = sd.isNaN(in);
    sd.exec();
    assertTrue("isFinite mismatch: " + finite.getArr(), expFinite.equals(finite.getArr()));
    assertTrue("isInfinite mismatch: " + infinite.getArr(), expInfinite.equals(infinite.getArr()));
    assertTrue("isNaN mismatch: " + isnan.getArr(), expNaN.equals(isnan.getArr()));
}
Also used : SDVariable(org.nd4j.autodiff.samediff.SDVariable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) SameDiff(org.nd4j.autodiff.samediff.SameDiff) Test(org.junit.Test)

Example 29 with SDVariable

use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.

In class GradCheckTransforms, method testSpaceToDepth.

/**
 * Compares SameDiff's {@code spaceToDepth} against the raw {@code space_to_depth} custom op (NHWC)
 * and runs a gradient check on the resulting graph.
 *
 * <p>Forward mismatches, gradient-check failures and gradient-check exceptions are logged rather
 * than failing the test.
 */
@Test
public void testSpaceToDepth() {
    Nd4j.getRandom().setSeed(1337);
    int miniBatch = 128;
    int blockSize = 4;
    String dataFormat = "NHWC";
    int isNHWC = dataFormat.equals("NHWC") ? 1 : 0;
    // Spatial dims of 2 * blockSize give a 2x2 spatial output with blockSize^2 channels
    int[] inputShape = new int[] { miniBatch, 2 * blockSize, 2 * blockSize, 1 };
    INDArray input = Nd4j.randn(inputShape);
    SameDiff sd = SameDiff.create();
    SDVariable sdInput = sd.var("in", inputShape);
    // Reference output from the raw custom op
    INDArray expOut = Nd4j.create(miniBatch, 2, 2, blockSize * blockSize);
    DynamicCustomOp op = DynamicCustomOp.builder("space_to_depth").addInputs(input).addIntegerArguments(blockSize, isNHWC).addOutputs(expOut).build();
    Nd4j.getExecutioner().exec(op);
    sd.associateArrayWithVariable(input, sdInput);
    SDVariable t = sd.spaceToDepth(sdInput, blockSize, dataFormat);
    // Scalar objective for the gradient check
    sd.mean("loss", t);
    sd.exec();
    INDArray out = t.getArr();
    if (!expOut.equals(out)) {
        log.error("space to depth failed on forward");
    }
    try {
        // checkGradients reports failure through its return value, which was previously ignored
        if (!GradCheckUtil.checkGradients(sd)) {
            log.error("space to depth failed gradient check");
        }
    } catch (Exception e) {
        log.error("space to depth gradient check threw an exception", e);
    }
}
Also used : SDVariable(org.nd4j.autodiff.samediff.SDVariable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) SameDiff(org.nd4j.autodiff.samediff.SameDiff) DynamicCustomOp(org.nd4j.linalg.api.ops.DynamicCustomOp) Test(org.junit.Test)

Example 30 with SDVariable

use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.

In class GradCheckTransforms, method testBatchToSpace.

/**
 * Compares SameDiff's {@code batchToSpace} against the raw {@code batch_to_space} custom op and runs
 * a gradient check on the resulting graph.
 *
 * <p>Four 1x1 single-channel inputs are rearranged with block sizes {2, 2} and zero crops into one
 * 2x2 output. Forward mismatches, gradient-check failures and gradient-check exceptions are logged
 * rather than failing the test.
 */
@Test
public void testBatchToSpace() {
    Nd4j.getRandom().setSeed(1337);
    int miniBatch = 4;
    int[] inputShape = new int[] { miniBatch, 1, 1, 1 };
    int M = 2;
    int[] blockShape = new int[] { M, 1 };
    int[] cropShape = new int[] { M, 2 };
    INDArray input = Nd4j.randn(inputShape);
    // Block sizes and crops for the reference op; they must match the literals passed to sd.batchToSpace
    INDArray blocks = Nd4j.create(new float[] { 2, 2 }, blockShape);
    INDArray crops = Nd4j.create(new float[] { 0, 0, 0, 0 }, cropShape);
    SameDiff sd = SameDiff.create();
    SDVariable sdInput = sd.var("in", inputShape);
    INDArray expOut = Nd4j.create(1, 2, 2, 1);
    DynamicCustomOp op = DynamicCustomOp.builder("batch_to_space").addInputs(input, blocks, crops).addOutputs(expOut).build();
    Nd4j.getExecutioner().exec(op);
    sd.associateArrayWithVariable(input, sdInput);
    SDVariable t = sd.batchToSpace(sdInput, new int[] { 2, 2 }, new int[][] { { 0, 0 }, { 0, 0 } });
    // Scalar objective for the gradient check
    sd.mean("loss", t);
    sd.exec();
    INDArray out = t.getArr();
    if (!expOut.equals(out)) {
        log.error("batch to space failed on forward");
    }
    try {
        // checkGradients reports failure through its return value, which was previously ignored
        if (!GradCheckUtil.checkGradients(sd)) {
            log.error("batch to space failed gradient check");
        }
    } catch (Exception e) {
        log.error("batch to space gradient check threw an exception", e);
    }
}
Also used : SDVariable(org.nd4j.autodiff.samediff.SDVariable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) SameDiff(org.nd4j.autodiff.samediff.SameDiff) DynamicCustomOp(org.nd4j.linalg.api.ops.DynamicCustomOp) Test(org.junit.Test)

Aggregations

SDVariable (org.nd4j.autodiff.samediff.SDVariable)104 SameDiff (org.nd4j.autodiff.samediff.SameDiff)41 INDArray (org.nd4j.linalg.api.ndarray.INDArray)38 Test (org.junit.Test)36 ArrayList (java.util.ArrayList)18 DynamicCustomOp (org.nd4j.linalg.api.ops.DynamicCustomOp)10 lombok.val (lombok.val)7 LossFunctions (org.nd4j.autodiff.loss.LossFunctions)4 LossInfo (org.nd4j.autodiff.loss.LossInfo)4 BernoulliDistribution (org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution)4 Ignore (org.junit.Ignore)3 DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction)3 ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException)3 Triple (org.nd4j.linalg.primitives.Triple)2 DataOutputStream (java.io.DataOutputStream)1 FileOutputStream (java.io.FileOutputStream)1 ByteBuffer (java.nio.ByteBuffer)1 NoOpNameFoundException (org.nd4j.imports.NoOpNameFoundException)1 NdIndexIterator (org.nd4j.linalg.api.iter.NdIndexIterator)1 TruncateDivOp (org.nd4j.linalg.api.ops.impl.transforms.arithmetic.TruncateDivOp)1