Example 6 with DifferentialFunction

Use of org.nd4j.autodiff.functions.DifferentialFunction in project nd4j by deeplearning4j.

The class BaseGraphMapper, method opTypeForNode:

@Override
public Op.Type opTypeForNode(NODE_TYPE nodeDef) {
    DifferentialFunction opWithTensorflowName = getMappedOp(getOpType(nodeDef));
    if (opWithTensorflowName == null)
        throw new NoOpNameFoundException("No op found with name " + getOpType(nodeDef));
    return opWithTensorflowName.opType();
}
Also used: Op(org.nd4j.linalg.api.ops.Op) DifferentialFunction(org.nd4j.autodiff.functions.DifferentialFunction) NoOpNameFoundException(org.nd4j.imports.NoOpNameFoundException)
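
For context, a minimal sketch of how a mapper's opTypeForNode might be invoked during TensorFlow import. The TFGraphMapper singleton and the protobuf NodeDef parameter are assumptions based on the nd4j import API of this era, not details taken from the example above.

import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.ops.Op;
import org.tensorflow.framework.NodeDef;

public class OpTypeLookupSketch {

    // Resolve the ND4J op type for an imported TensorFlow node; this
    // propagates NoOpNameFoundException when the node's TF op name has
    // no registered ND4J mapping (see the throw in opTypeForNode above).
    public static Op.Type lookup(NodeDef nodeDef) {
        // assumption: TFGraphMapper is the concrete BaseGraphMapper for TF graphs
        return TFGraphMapper.getInstance().opTypeForNode(nodeDef);
    }
}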

Example 7 with DifferentialFunction

Use of org.nd4j.autodiff.functions.DifferentialFunction in project nd4j by deeplearning4j.

The class GradCheckTransforms, method testTransforms:

@Test
public void testTransforms() {
    // Test transforms (non-pairwise)
    Nd4j.getRandom().setSeed(12345);
    // TODO: to speed up integration of new ops for TF import, we sometimes skip "doDiff" implementations.
    // get back to those after release
    List<String> allSkipped = new ArrayList<>();
    List<String> allFailed = new ArrayList<>();
    for (int i = 0; i < 62; i++) {
        boolean skipBackward = false;
        SameDiff sd = SameDiff.create();
        int nOut = 4;
        int minibatch = 5;
        SDVariable in = sd.var("in", new int[] { -1, nOut });
        INDArray ia = Nd4j.randn(minibatch, nOut);
        int dim;
        SDVariable t;
        INDArray expOut;
        switch(i) {
            case 0:
                t = in.add(5.0);
                expOut = ia.add(5.0);
                break;
            case 1:
                t = in.sub(5.0);
                expOut = ia.sub(5.0);
                break;
            case 2:
                t = in.mul(2.5);
                expOut = ia.mul(2.5);
                break;
            case 3:
                t = in.div(4.0);
                expOut = ia.div(4.0);
                break;
            case 4:
                t = in.rsub(5.0);
                expOut = ia.rsub(5.0);
                break;
            case 5:
                t = in.rdiv(1.0);
                expOut = ia.rdiv(1.0);
                break;
            case 6:
                t = sd.pow(in, 2.5);
                ia = Nd4j.rand(minibatch, nOut);
                expOut = Transforms.pow(ia, 2.5, true);
                break;
            case 7:
                t = sd.sigmoid(in);
                ia = Nd4j.rand(minibatch, nOut).muli(2).subi(1.0);
                expOut = Transforms.sigmoid(ia, true);
                break;
            case 8:
                t = sd.tanh(in);
                ia = Nd4j.rand(minibatch, nOut).muli(2).subi(1.0);
                expOut = Transforms.tanh(ia, true);
                break;
            case 9:
                t = sd.tan(in);
                ia = Nd4j.rand(minibatch, nOut);
                expOut = Nd4j.getExecutioner().execAndReturn(new Tan(ia.dup()));
                break;
            case 10:
                t = sd.cos(in);
                expOut = Transforms.cos(ia, true);
                break;
            case 11:
                t = sd.sin(in);
                expOut = Transforms.sin(ia, true);
                break;
            case 12:
                t = sd.softplus(in);
                expOut = Transforms.softPlus(ia, true);
                break;
            case 13:
                t = sd.log(in);
                ia = Nd4j.rand(minibatch, nOut);
                expOut = Transforms.log(ia, true);
                skipBackward = true;
                break;
            case 14:
                t = sd.neg(in);
                expOut = ia.neg();
                break;
            case 15:
                t = sd.acos(in);
                ia = Nd4j.rand(minibatch, nOut).muli(1.8).subi(0.9);
                expOut = Transforms.acos(ia, true);
                break;
            case 16:
                t = sd.acosh(in);
                // Only defined for x >= 1
                ia = Nd4j.rand(minibatch, nOut).addi(1.01);
                expOut = Nd4j.getExecutioner().execAndReturn(new ACosh(ia.dup()));
                skipBackward = true;
                break;
            case 17:
                t = sd.asin(in);
                ia = Nd4j.rand(minibatch, nOut).muli(1.8).subi(0.9);
                expOut = Transforms.asin(ia, true);
                break;
            case 18:
                t = sd.atan(in);
                ia = Nd4j.rand(minibatch, nOut).muli(4).subi(2);
                expOut = Transforms.atan(ia, true);
                break;
            case 19:
                t = sd.atanh(in);
                ia = Nd4j.rand(minibatch, nOut).muli(1.8).subi(0.9);
                expOut = Transforms.atanh(ia, true);
                break;
            case 20:
                t = sd.cosh(in);
                expOut = Transforms.cosh(ia, true);
                break;
            case 21:
                t = sd.cube(in);
                expOut = Transforms.pow(ia, 3.0, true);
                break;
            case 22:
                t = sd.elu(in);
                expOut = Transforms.elu(ia, true);
                break;
            case 23:
                // TODO SHOULDN'T THIS HAVE A DIMENSION ARG???
                t = sd.softmax(in);
                ia = Nd4j.rand(minibatch, nOut);
                expOut = Nd4j.getExecutioner().execAndReturn(new OldSoftMax(ia.dup()));
                break;
            case 24:
                t = sd.sqrt(in);
                ia = Nd4j.rand(minibatch, nOut);
                expOut = Transforms.sqrt(ia, true);
                break;
            case 25:
                t = sd.square(in);
                expOut = Transforms.pow(ia, 2.0, true);
                break;
            case 26:
                t = sd.transpose(in);
                expOut = ia.transpose().dup();
                break;
            case 27:
                t = sd.abs(in);
                expOut = Transforms.abs(ia, true);
                break;
            case 28:
                t = sd.sinh(in);
                expOut = Transforms.sinh(ia, true);
                break;
            case 29:
                t = sd.asinh(in);
                expOut = Nd4j.getExecutioner().execAndReturn(new ASinh(ia.dup()));
                skipBackward = true;
                break;
            case 30:
                t = sd.exp(in);
                expOut = Transforms.exp(ia, true);
                break;
            case 31:
                t = sd.floor(in);
                expOut = Transforms.floor(ia, true);
                break;
            case 32:
                t = sd.relu(in, 0.0);
                ia = Nd4j.rand(minibatch, nOut);
                expOut = Transforms.relu(ia, true);
                break;
            case 33:
                t = sd.hardTanh(in);
                ia = Nd4j.rand(minibatch, nOut).muli(2).subi(1.0);
                expOut = Transforms.hardTanh(ia, true);
                break;
            case 34:
                t = sd.logSigmoid(in);
                expOut = Nd4j.getExecutioner().execAndReturn(new LogSigmoid(ia.dup()));
                break;
            case 35:
                t = sd.swish(in);
                expOut = Nd4j.getExecutioner().execAndReturn(new Swish(ia.dup()));
                break;
            case 36:
                t = sd.sign(in);
                expOut = Transforms.sign(ia, true);
                break;
            case 37:
                t = sd.softsign(in);
                expOut = Transforms.softsign(ia, true);
                break;
            case 38:
                t = sd.leakyRelu(in, 0.0);
                ia = Nd4j.rand(minibatch, nOut);
                expOut = Transforms.leakyRelu(ia, true);
                break;
            case 39:
                // TODO DIMENSION ARG???
                // TODO fix me
                t = sd.logSoftmax(in);
                ia = Nd4j.rand(minibatch, nOut);
                expOut = Transforms.log(Transforms.softmax(ia, true));
                skipBackward = true;
                break;
            case 40:
                t = sd.selu(in);
                expOut = Nd4j.getExecutioner().execAndReturn(new SELU(ia.dup()));
                break;
            case 41:
                t = sd.gt(in, 1.0);
                expOut = ia.gt(1.0);
                break;
            case 42:
                t = sd.gte(in, 1.0);
                expOut = ia.gte(1.0);
                break;
            case 43:
                t = sd.lt(in, 1.0);
                expOut = ia.lt(1.0);
                break;
            case 44:
                t = sd.lte(in, 1.0);
                expOut = ia.lte(1.0);
                break;
            case 45:
                t = sd.eq(in, 2.0);
                ia = Nd4j.linspace(1, minibatch * nOut, minibatch * nOut).reshape('c', minibatch, nOut);
                expOut = ia.eq(2.0);
                break;
            case 46:
                t = sd.neq(in, 2.0);
                ia = Nd4j.linspace(1, minibatch * nOut, minibatch * nOut).reshape('c', minibatch, nOut);
                expOut = ia.neq(2.0);
                break;
            case 47:
                t = sd.ceil(in);
                expOut = Transforms.ceil(ia, true);
                break;
            case 48:
                ia = Nd4j.randn(ia.shape()).muli(2);
                t = sd.clipByValue(in, -3, 2);
                expOut = ia.dup();
                BooleanIndexing.replaceWhere(expOut, -3, Conditions.lessThan(-3));
                BooleanIndexing.replaceWhere(expOut, 2, Conditions.greaterThan(2));
                break;
            case 49:
                // Clip by norm, dimension 0, some below threshold, some above
                double clip = 2.0;
                ia = Nd4j.rand(ia.shape());
                // Exactly at threshold...
                ia.muliRowVector(ia.norm2(0).rdiv(clip));
                System.out.println(ia.norm2(0));
                ia.muliRowVector(Nd4j.linspace(0.9, 1.1, ia.size(1)));
                System.out.println(ia.norm2(0));
                expOut = Nd4j.create(ia.shape());
                for (int j = 0; j < ia.columns(); j++) {
                    INDArray origCol = ia.getColumn(j);
                    if (origCol.norm2Number().doubleValue() < clip) {
                        expOut.putColumn(j, origCol);
                    } else {
                        expOut.putColumn(j, origCol.mul(clip / origCol.norm2Number().doubleValue()));
                    }
                }
                t = sd.clipByNorm(in, clip);
                break;
            // TODO clip by norm along other dimensions
            case 50:
                dim = 1;
                t = sd.reverse(in, dim);
                expOut = Nd4j.create(ia.shape());
                DynamicCustomOp reverse = DynamicCustomOp.builder("reverse").addIntegerArguments(dim).addInputs(ia).addOutputs(expOut).build();
                Nd4j.getExecutioner().exec(reverse);
                break;
            case 51:
                dim = 0;
                boolean exclusive = false;
                boolean reverseBool = false;
                t = sd.cumsum(in, exclusive, reverseBool, dim);
                expOut = Nd4j.create(ia.shape());
                DynamicCustomOp cumsum = DynamicCustomOp.builder("cumsum").addIntegerArguments((exclusive) ? 1 : 0, (reverseBool) ? 1 : 0, dim).addInputs(ia).addOutputs(expOut).build();
                Nd4j.getExecutioner().exec(cumsum);
                break;
            case 52:
                dim = 0;
                boolean ex = false;
                boolean revBool = false;
                t = sd.cumprod(in, ex, revBool, dim);
                expOut = Nd4j.create(ia.shape());
                DynamicCustomOp cumprod = DynamicCustomOp.builder("cumprod").addIntegerArguments((ex) ? 1 : 0, (revBool) ? 1 : 0, dim).addInputs(ia).addOutputs(expOut).build();
                Nd4j.getExecutioner().exec(cumprod);
                break;
            case 53:
                ia = Nd4j.create(new float[] { 4, 2 });
                in = sd.var("in", new int[] { 1, 2 });
                expOut = Nd4j.create(new int[] { 2, 2 });
                DynamicCustomOp op = DynamicCustomOp.builder("diag").addInputs(ia).addOutputs(expOut).build();
                Nd4j.getExecutioner().exec(op);
                t = sd.diag(in);
                break;
            case 54:
                expOut = Nd4j.createUninitialized(ia.shape(), ia.ordering());
                Nd4j.getExecutioner().exec(new Erf(ia, expOut));
                t = sd.erf(in);
                skipBackward = true;
                break;
            case 55:
                expOut = Nd4j.createUninitialized(ia.shape(), ia.ordering());
                Nd4j.getExecutioner().exec(new Erfc(ia, expOut));
                t = sd.erfc(in);
                skipBackward = true;
                break;
            case 56:
                t = sd.expm1(in);
                expOut = Transforms.expm1(ia, true);
                break;
            case 57:
                t = sd.log1p(in);
                ia = Nd4j.rand(minibatch, nOut);
                expOut = Transforms.log1p(ia, true);
                skipBackward = true;
                break;
            case 58:
                t = sd.round(in);
                expOut = Transforms.round(ia, true);
                skipBackward = true;
                break;
            case 59:
                ia = Nd4j.create(new float[] { 4, 2 });
                in = sd.var("in", new int[] { 1, 2 });
                t = sd.rsqrt(in);
                expOut = Nd4j.create(ia.shape(), ia.ordering());
                Nd4j.getExecutioner().exec(new RSqrt(ia, expOut));
                skipBackward = true;
                break;
            case 60:
                t = sd.relu6(in, 0);
                ia = Nd4j.rand(minibatch, nOut);
                expOut = Transforms.relu6(ia);
                skipBackward = true;
                break;
            case 61:
                ia = Nd4j.create(new float[] { 2, 2 });
                in = sd.var("in", new int[] { 1, 2 });
                sd.associateArrayWithVariable(ia, in);
                double value = 42;
                expOut = Nd4j.create(new int[] { 2, 2 });
                DynamicCustomOp fillOp = DynamicCustomOp.builder("fill").addInputs(ia).addFloatingPointArguments(value).addOutputs(expOut).build();
                Nd4j.getExecutioner().exec(fillOp);
                skipBackward = true;
                t = sd.fill(in, value);
                break;
            default:
                throw new RuntimeException();
        }
        DifferentialFunction[] funcs = sd.functions();
        String name = funcs[0].opName();
        String msg = "test: " + i + " - " + name;
        log.info("*** Starting test: " + msg);
        SDVariable loss = sd.mean("loss", t);
        sd.associateArrayWithVariable(ia, in);
        sd.exec();
        INDArray out = t.getArr();
        if (!expOut.equals(out)) {
            allFailed.add(msg + " - FAILED ON FORWARD");
            continue;
        }
        boolean ok;
        if (skipBackward) {
            ok = true;
            msg += " - SKIPPED";
            allSkipped.add(msg);
        } else {
            try {
                ok = GradCheckUtil.checkGradients(sd);
            } catch (Exception e) {
                e.printStackTrace();
                msg += " - EXCEPTION";
                ok = false;
            }
        }
    // record the failure; aggregate results are reported after the loop
    if (!ok) {
        allFailed.add(msg);
    }
    }
    if (allSkipped.size() > 0) {
        log.info("All backward skipped transforms: " + allSkipped);
        log.info(allSkipped.size() + " backward passes were skipped.");
    }
    if (allFailed.size() > 0) {
        log.error("All failed transforms: " + allFailed);
        fail(allFailed.size() + " transforms failed");
    }
}
Also used: SameDiff(org.nd4j.autodiff.samediff.SameDiff) ArrayList(java.util.ArrayList) SDVariable(org.nd4j.autodiff.samediff.SDVariable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) DifferentialFunction(org.nd4j.autodiff.functions.DifferentialFunction) DynamicCustomOp(org.nd4j.linalg.api.ops.DynamicCustomOp) Test(org.junit.Test)
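
Distilled from the harness above, a minimal sketch of what one loop iteration does for a single transform (tanh here). It reuses only calls that already appear in the test, with imports as listed in the Also used line.

SameDiff sd = SameDiff.create();
SDVariable in = sd.var("in", new int[] { -1, 4 });
SDVariable t = sd.tanh(in);
// gradient checking needs a scalar loss, hence the mean reduction
SDVariable loss = sd.mean("loss", t);
INDArray ia = Nd4j.rand(5, 4).muli(2).subi(1.0);     // inputs in (-1, 1)
sd.associateArrayWithVariable(ia, in);
sd.exec();
assertEquals(Transforms.tanh(ia, true), t.getArr()); // forward check
assertTrue(GradCheckUtil.checkGradients(sd));        // finite-difference check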

Example 8 with DifferentialFunction

Use of org.nd4j.autodiff.functions.DifferentialFunction in project nd4j by deeplearning4j.

The class SameDiffTests, method testResultPropagation:

@Test
public void testResultPropagation() {
    SameDiff sameDiff = SameDiff.create();
    INDArray inputs = Nd4j.create(new double[][] { { 0.52, 1.12, 0.77 }, { 0.88, -1.08, 0.15 }, { 0.52, 0.06, -1.30 }, { 0.74, -2.49, 1.39 } });
    INDArray weights = Nd4j.randn(3, 1);
    SDVariable x = sameDiff.var("x", inputs);
    SDVariable w = sameDiff.var("w", weights);
    SDVariable preOutput = sameDiff.mmul(x, w);
    SDVariable outputs = sameDiff.sigmoid(preOutput);
    List<DifferentialFunction> ops = sameDiff.exec().getRight();
    DifferentialFunction firstOp = ops.get(0);
    val firstResult = sameDiff.getVariable(firstOp.outputVariables()[0].getVarName()).getArr();
    // exec() should have propagated a concrete result to the first op's output
    assertNotNull(firstResult);
}
Also used: lombok.val(lombok.val) INDArray(org.nd4j.linalg.api.ndarray.INDArray) DifferentialFunction(org.nd4j.autodiff.functions.DifferentialFunction) Test(org.junit.Test)
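
In this version of the API, exec() returns a pair whose right element is the list of ops that were executed. A small sketch, built only from calls the test above already uses, walks that list and inspects each op's first output:

for (DifferentialFunction op : sameDiff.exec().getRight()) {
    String outName = op.outputVariables()[0].getVarName();
    INDArray result = sameDiff.getVariable(outName).getArr();
    // for this graph: "mmul -> ..." followed by "sigmoid -> ..."
    System.out.println(op.opName() + " -> " + outName + ", shape "
            + java.util.Arrays.toString(result.shape()));
}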

Example 9 with DifferentialFunction

Use of org.nd4j.autodiff.functions.DifferentialFunction in project nd4j by deeplearning4j.

The class SameDiffTests, method testMmulGradient:

@Test
public void testMmulGradient() {
    SameDiff sameDiff = SameDiff.create();
    INDArray sumInput = Nd4j.linspace(1, 4, 4).reshape(2, 2);
    Map<String, INDArray> inputs = new HashMap<>();
    inputs.put("x", sumInput);
    inputs.put("y", sumInput.dup());
    sameDiff.defineFunction("mmulGradient", new SameDiff.SameDiffFunctionDefinition() {

        @Override
        public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) {
            SDVariable input = sameDiff.var("x", inputs.get("x"));
            SDVariable input2 = sameDiff.var("y", inputs.get("y"));
            SDVariable exp = sameDiff.mmul(input, input2);
            SDVariable sum = sameDiff.sum(exp, Integer.MAX_VALUE);
            return new SDVariable[] { sum };
        }
    }, inputs);
    List<DifferentialFunction> ops = sameDiff.getFunction("mmulGradient").execBackwards().getRight();
    String print = sameDiff.asFlatPrint();
    assumeNotNull(sameDiff.getFunction("mmulGradient").getFunction("grad"));
    assumeNotNull(sameDiff.getFunction("mmulGradient").grad("x"));
    assumeNotNull(sameDiff.getFunction("mmulGradient").grad("y"));
    SDVariable gradWrtX = sameDiff.getFunction("mmulGradient").grad("x");
    SDVariable gradWrtY = sameDiff.getFunction("mmulGradient").grad("y");
    assumeNotNull(gradWrtX.getArr());
    assumeNotNull(gradWrtY.getArr());
    INDArray xGradAssertion = Nd4j.create(new double[][] { { 3, 7 }, { 3, 7 } });
    INDArray yGradAssertion = Nd4j.create(new double[][] { { 4, 4 }, { 6, 6 } });
    assertEquals(xGradAssertion, gradWrtX.getArr());
    assertEquals(yGradAssertion, gradWrtY.getArr());
}
Also used: INDArray(org.nd4j.linalg.api.ndarray.INDArray) DifferentialFunction(org.nd4j.autodiff.functions.DifferentialFunction) Test(org.junit.Test)
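
The asserted values follow from the chain rule: with loss = sum(X.mmul(Y)), dLoss/dX = 1.mmul(Y^T) and dLoss/dY = X^T.mmul(1), where 1 is the all-ones matrix shaped like X.mmul(Y). A quick sketch reproduces the two assertion matrices with plain INDArray ops:

INDArray x = Nd4j.linspace(1, 4, 4).reshape(2, 2);  // [[1, 2], [3, 4]]
INDArray y = x.dup();
INDArray ones = Nd4j.ones(2, 2);                    // dLoss/d(x.mmul(y))
INDArray dLdX = ones.mmul(y.transpose());           // [[3, 7], [3, 7]]
INDArray dLdY = x.transpose().mmul(ones);           // [[4, 4], [6, 6]]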

Example 10 with DifferentialFunction

Use of org.nd4j.autodiff.functions.DifferentialFunction in project nd4j by deeplearning4j.

The class OpsMappingTests, method addOperation:

protected void addOperation(Class<? extends DifferentialFunction> clazz, List<Operation> list) {
    if (Modifier.isAbstract(clazz.getModifiers()) || clazz.isInterface())
        return;
    try {
        DifferentialFunction node = clazz.newInstance();
        if (node instanceof DynamicCustomOp) {
            list.add(new Operation(-1L, node.opName().toLowerCase()));
            list.add(new Operation(-1L, node.tensorflowName().toLowerCase()));
        } else {
            val op = new Operation(Long.valueOf(node.opNum()), node.opName());
            list.add(op);
        }
    } catch (UnsupportedOperationException e) {
        // some ops can't be constructed this way; skip them
    } catch (NoOpNameFoundException e) {
        // op has no registered name (e.g. no tensorflowName()); skip it
    } catch (InstantiationException e) {
        // abstract class or missing no-arg constructor; skip it
    } catch (Exception e) {
        log.info("Failed on [{}]", clazz.getSimpleName());
        throw new RuntimeException(e);
    }
}
Also used: lombok.val(lombok.val) DifferentialFunction(org.nd4j.autodiff.functions.DifferentialFunction) NoOpNameFoundException(org.nd4j.imports.NoOpNameFoundException) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException)
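
A hypothetical driver for addOperation, shown only to illustrate how the method is meant to be fed; the org.reflections scanner and the org.nd4j package root are assumptions for illustration and do not come from the test itself.

// Hypothetical: scan the classpath for every DifferentialFunction subclass
// and collect its op name(s); abstract classes and interfaces are filtered
// out inside addOperation. Assumes org.reflections:reflections is available.
Reflections reflections = new Reflections("org.nd4j");
List<Operation> operations = new ArrayList<>();
for (Class<? extends DifferentialFunction> clazz : reflections.getSubTypesOf(DifferentialFunction.class)) {
    addOperation(clazz, operations);
}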

Aggregations

DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction): 18 usages
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 9 usages
Test (org.junit.Test): 7 usages
ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 6 usages
ArrayList (java.util.ArrayList): 3 usages
lombok.val (lombok.val): 3 usages
SDVariable (org.nd4j.autodiff.samediff.SDVariable): 3 usages
SameDiff (org.nd4j.autodiff.samediff.SameDiff): 2 usages
NoOpNameFoundException (org.nd4j.imports.NoOpNameFoundException): 2 usages
DynamicCustomOp (org.nd4j.linalg.api.ops.DynamicCustomOp): 2 usages
Op (org.nd4j.linalg.api.ops.Op): 2 usages
GradientBackwardsMarker (org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker): 2 usages
IntArrayKeyMap (org.nd4j.linalg.collection.IntArrayKeyMap): 2 usages
Set (java.util.Set): 1 usage
AtomicInteger (java.util.concurrent.atomic.AtomicInteger): 1 usage
FlowPath (org.nd4j.autodiff.samediff.flow.FlowPath): 1 usage
NdIndexIterator (org.nd4j.linalg.api.iter.NdIndexIterator): 1 usage
Variance (org.nd4j.linalg.api.ops.impl.accum.Variance): 1 usage
If (org.nd4j.linalg.api.ops.impl.controlflow.If): 1 usage
While (org.nd4j.linalg.api.ops.impl.controlflow.While): 1 usage