
Example 11 with DifferentialFunction

Use of org.nd4j.autodiff.functions.DifferentialFunction in the nd4j project by deeplearning4j.

The class OpsMappingTests, method getOperations.

protected List<Operation> getOperations(@NonNull Op.Type type) {
    val list = new ArrayList<Operation>();
    Reflections f = new Reflections(new ConfigurationBuilder()
            .filterInputsBy(new FilterBuilder()
                    .include(FilterBuilder.prefix("org.nd4j.*"))
                    .exclude("^(?!.*\\.class$).*$"))
            .setUrls(ClasspathHelper.forPackage("org.nd4j"))
            .setScanners(new SubTypesScanner()));
    switch(type) {
        case SUMMARYSTATS:
            {
                Set<Class<? extends Variance>> clazzes = f.getSubTypesOf(Variance.class);
                for (Class<? extends DifferentialFunction> clazz : clazzes) addOperation(clazz, list);
            }
            break;
        case RANDOM:
            {
                Set<Class<? extends BaseRandomOp>> clazzes = f.getSubTypesOf(BaseRandomOp.class);
                for (Class<? extends DifferentialFunction> clazz : clazzes) addOperation(clazz, list);
            }
            break;
        case INDEXREDUCE:
            {
                Set<Class<? extends BaseIndexAccumulation>> clazzes = f.getSubTypesOf(BaseIndexAccumulation.class);
                for (Class<? extends DifferentialFunction> clazz : clazzes) addOperation(clazz, list);
            }
            break;
        case REDUCE3:
        case REDUCE:
            {
                Set<Class<? extends BaseAccumulation>> clazzes = f.getSubTypesOf(BaseAccumulation.class);
                for (Class<? extends DifferentialFunction> clazz : clazzes) addOperation(clazz, list);
            }
            break;
        case BROADCAST:
            {
                Set<Class<? extends BaseBroadcastOp>> clazzes = f.getSubTypesOf(BaseBroadcastOp.class);
                for (Class<? extends DifferentialFunction> clazz : clazzes) addOperation(clazz, list);
            }
            break;
        case SCALAR:
            {
                Set<Class<? extends BaseScalarOp>> clazzes = f.getSubTypesOf(BaseScalarOp.class);
                for (Class<? extends DifferentialFunction> clazz : clazzes) addOperation(clazz, list);
            }
            break;
        case PAIRWISE:
        case TRANSFORM:
            {
                Set<Class<? extends BaseTransformOp>> clazzes = f.getSubTypesOf(BaseTransformOp.class);
                for (Class<? extends DifferentialFunction> clazz : clazzes) addOperation(clazz, list);
            }
            break;
        case CUSTOM:
            {
                Set<Class<? extends DynamicCustomOp>> clazzes = f.getSubTypesOf(DynamicCustomOp.class);
                for (Class<? extends DifferentialFunction> clazz : clazzes) {
                    if (clazz.getSimpleName().equalsIgnoreCase("dynamiccustomop"))
                        continue;
                    addOperation(clazz, list);
                }
            }
            break;
    }
    log.info("Group: {}; List size: {}", type, list.size());
    return list;
}
Also used: java.util.ArrayList, java.util.Set, lombok.val, org.nd4j.autodiff.functions.DifferentialFunction, org.nd4j.linalg.api.ops.impl.accum.Variance, org.nd4j.linalg.api.ops.random.BaseRandomOp, org.reflections.Reflections, org.reflections.scanners.SubTypesScanner, org.reflections.util.ConfigurationBuilder, org.reflections.util.FilterBuilder
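
The scan itself is the interesting part: Reflections indexes every class under org.nd4j, and since each op base class in the switch (Variance, BaseRandomOp, BaseIndexAccumulation, and so on) extends DifferentialFunction, a single subtype query can return the whole op inventory at once. A minimal sketch under that assumption; scanForOps is a hypothetical helper name, and the org.reflections 0.9.x API is assumed to match what the test above uses:

import java.util.Set;

import org.nd4j.autodiff.functions.DifferentialFunction;
import org.reflections.Reflections;
import org.reflections.scanners.SubTypesScanner;
import org.reflections.util.ClasspathHelper;
import org.reflections.util.ConfigurationBuilder;
import org.reflections.util.FilterBuilder;

public class OpScanSketch {

    // Returns every DifferentialFunction subtype on the classpath under org.nd4j.
    static Set<Class<? extends DifferentialFunction>> scanForOps() {
        Reflections reflections = new Reflections(new ConfigurationBuilder()
                // keep only .class entries under the org.nd4j packages
                .filterInputsBy(new FilterBuilder()
                        .include(FilterBuilder.prefix("org.nd4j"))
                        .exclude("^(?!.*\\.class$).*$"))
                .setUrls(ClasspathHelper.forPackage("org.nd4j"))
                .setScanners(new SubTypesScanner()));
        // One query replaces the per-type switch when the full inventory is enough.
        return reflections.getSubTypesOf(DifferentialFunction.class);
    }
}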

Example 12 with DifferentialFunction

Use of org.nd4j.autodiff.functions.DifferentialFunction in the nd4j project by deeplearning4j.

The class GradCheckTransforms, method testPairwiseTransforms.

@Test
public void testPairwiseTransforms() {
    /*
        add, sub, mul, div, rsub, rdiv
        eq, neq, gt, lt, gte, lte, or, and, xor
        min, max
        mmul
        tensormmul
    */
    // Test transforms (pairwise)
    Nd4j.getRandom().setSeed(12345);
    List<String> allSkipped = new ArrayList<>();
    List<String> allFailed = new ArrayList<>();
    for (int i = 0; i < 23; i++) {
        boolean skipBackward = false;
        SameDiff sd = SameDiff.create();
        int nOut = 4;
        int minibatch = 5;
        SDVariable in1 = sd.var("in1", new int[] { -1, nOut });
        SDVariable in2 = sd.var("in2", new int[] { -1, nOut });
        INDArray ia = Nd4j.randn(minibatch, nOut);
        INDArray ib = Nd4j.randn(minibatch, nOut);
        SDVariable t;
        INDArray expOut;
        switch(i) {
            case 0:
                t = in1.add(in2);
                expOut = ia.add(ib);
                break;
            case 1:
                t = in1.sub(in2);
                expOut = ia.sub(ib);
                break;
            case 2:
                t = in1.mul(in2);
                expOut = ia.mul(ib);
                break;
            case 3:
                // div (case 3) is currently skipped entirely
                continue;
            case 4:
                t = in1.rsub(in2);
                expOut = ia.rsub(ib);
                break;
            case 5:
                t = in1.rdiv(in2);
                expOut = ia.rdiv(ib);
                break;
            case 6:
                t = sd.eq(in1, in2);
                expOut = ia.eq(ib);
                break;
            case 7:
                t = sd.neq(in1, in2);
                expOut = ia.neq(ib);
                break;
            case 8:
                t = sd.gt(in1, in2);
                expOut = ia.gt(ib);
                break;
            case 9:
                t = sd.lt(in1, in2);
                expOut = ia.lt(ib);
                break;
            case 10:
                t = sd.gte(in1, in2);
                expOut = ia.dup();
                Nd4j.getExecutioner().exec(new GreaterThanOrEqual(new INDArray[] { ia, ib }, new INDArray[] { expOut }));
                break;
            case 11:
                t = sd.lte(in1, in2);
                expOut = ia.dup();
                Nd4j.getExecutioner().exec(new LessThanOrEqual(new INDArray[] { ia, ib }, new INDArray[] { expOut }));
                break;
            case 12:
                ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5));
                ib = Nd4j.getExecutioner().exec(new BernoulliDistribution(ib, 0.5));
                t = sd.or(in1, in2);
                expOut = Transforms.or(ia, ib);
                break;
            case 13:
                ib = Nd4j.randn(nOut, nOut);
                t = sd.mmul(in1, in2);
                expOut = ia.mmul(ib);
                break;
            case 14:
                t = sd.max(in1, in2);
                expOut = Nd4j.getExecutioner().execAndReturn(new OldMax(ia, ib, ia.dup(), ia.length()));
                break;
            case 15:
                t = sd.min(in1, in2);
                expOut = Nd4j.getExecutioner().execAndReturn(new OldMin(ia, ib, ia.dup(), ia.length()));
                break;
            case 16:
                ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5));
                ib = Nd4j.getExecutioner().exec(new BernoulliDistribution(ib, 0.5));
                t = sd.and(in1, in2);
                expOut = Transforms.and(ia, ib);
                break;
            case 17:
                ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5));
                ib = Nd4j.getExecutioner().exec(new BernoulliDistribution(ib, 0.5));
                t = sd.xor(in1, in2);
                expOut = Transforms.xor(ia, ib);
                break;
            case 18:
                t = sd.assign(in1, in2);
                expOut = ib;
                break;
            case 19:
                t = sd.atan2(in1, in2);
                // Note: y,x order for samediff; x,y order for transforms
                expOut = Transforms.atan2(ib, ia);
                skipBackward = true;
                break;
            case 20:
                t = sd.mergeAdd(in1, in2, in2);
                expOut = ia.add(ib).add(ib);
                break;
            case 21:
                ia = Nd4j.create(new float[] { 2, 4 });
                ib = Nd4j.create(new float[] { 42, 2 });
                in1 = sd.var("in1", new int[] { 1, 2 });
                in2 = sd.var("in2", new int[] { 1, 2 });
                t = in1.truncatedDiv(in2);
                expOut = Nd4j.create(ia.shape(), ia.ordering());
                Nd4j.getExecutioner().exec(new TruncateDivOp(ia, ib, expOut));
                skipBackward = true;
                break;
            case 22:
                t = in1.squaredDifference(in2);
                expOut = Nd4j.create(ia.shape(), ia.ordering());
                DynamicCustomOp squareDiff = DynamicCustomOp.builder("squaredsubtract").addInputs(ia, ib).addOutputs(expOut).build();
                Nd4j.getExecutioner().exec(squareDiff);
                skipBackward = true;
                break;
            default:
                throw new RuntimeException();
        }
        DifferentialFunction[] funcs = sd.functions();
        String name = funcs[0].opName();
        String msg = "test: " + i + " - " + name;
        log.info("*** Starting test: " + msg);
        SDVariable loss = sd.mean("loss", t);
        sd.associateArrayWithVariable(ia, in1);
        sd.associateArrayWithVariable(ib, in2);
        sd.exec();
        INDArray out = t.getArr();
        assertEquals(msg, expOut, out);
        boolean ok;
        if (skipBackward) {
            ok = true;
            msg += " - SKIPPED";
            allSkipped.add(msg);
        } else {
            try {
                ok = GradCheckUtil.checkGradients(sd);
            } catch (Exception e) {
                e.printStackTrace();
                msg += " - EXCEPTION";
                ok = false;
            }
        }
        if (!ok) {
            allFailed.add(msg);
        }
    }
    if (allSkipped.size() > 0) {
        log.info("All backward skipped transforms: " + allSkipped);
        log.info(allSkipped.size() + " backward passes were skipped.");
    }
    if (allFailed.size() > 0) {
        log.error("All failed transforms: " + allFailed);
        fail(allFailed.size() + " transforms failed");
    }
}
Also used: java.util.ArrayList, org.junit.Test, org.nd4j.autodiff.functions.DifferentialFunction, org.nd4j.autodiff.samediff.SDVariable, org.nd4j.autodiff.samediff.SameDiff, org.nd4j.linalg.api.ndarray.INDArray, org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.linalg.api.ops.impl.transforms.arithmetic.TruncateDivOp, org.nd4j.linalg.api.ops.impl.transforms.comparison.GreaterThanOrEqual, org.nd4j.linalg.api.ops.impl.transforms.comparison.LessThanOrEqual, org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMax, org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMin, org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution
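
Each of the 23 cases follows the same skeleton: build the graph, run the forward pass, compare against the equivalent INDArray operation, then run a numerical gradient check. A condensed sketch of the per-case pattern, shown for case 0 (add) using only calls that appear in the test above:

SameDiff sd = SameDiff.create();
SDVariable in1 = sd.var("in1", new int[] { -1, 4 });
SDVariable in2 = sd.var("in2", new int[] { -1, 4 });
SDVariable t = in1.add(in2);                      // op under test
SDVariable loss = sd.mean("loss", t);             // scalar loss so gradients are defined

INDArray ia = Nd4j.randn(5, 4);
INDArray ib = Nd4j.randn(5, 4);
sd.associateArrayWithVariable(ia, in1);
sd.associateArrayWithVariable(ib, in2);

sd.exec();                                        // forward pass
assertEquals(ia.add(ib), t.getArr());             // forward result matches the INDArray op
boolean ok = GradCheckUtil.checkGradients(sd);    // analytic vs. numerical gradients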

Example 13 with DifferentialFunction

Use of org.nd4j.autodiff.functions.DifferentialFunction in the nd4j project by deeplearning4j.

The class SameDiffTests, method testSums.

@Test
public void testSums() {
    SameDiff sameDiff = SameDiff.create();
    INDArray ones = Nd4j.ones(4);
    SDVariable sdVariable = sameDiff.var("ones", ones);
    SDVariable result = sdVariable.addi(1.0);
    SDVariable total = sameDiff.sum(result, Integer.MAX_VALUE);
    List<DifferentialFunction> ops = sameDiff.exec().getRight();
    INDArray output = null;
    for (int i = 0; i < 5; i++) {
        output = sameDiff.execAndEndResult(ops);
        System.out.println("Ones " + ones);
        System.out.println(output);
    }
    assertEquals(Nd4j.valueArrayOf(4, 7), ones);
    assertEquals(28, output.getDouble(0), 1e-1);
}
Also used: org.junit.Test, org.nd4j.autodiff.functions.DifferentialFunction, org.nd4j.linalg.api.ndarray.INDArray
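
The asserted numbers follow from addi being an in-place update on the ones array: the first sameDiff.exec() runs the graph once, and each execAndEndResult call in the loop runs it again, adding 1 per execution. A trace of the in-place updates:

// ones starts as [1, 1, 1, 1]
// after sameDiff.exec():               [2, 2, 2, 2]
// after 5 more executions in the loop: [7, 7, 7, 7]
// final sum over the 4 elements:       4 * 7 = 28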

Example 14 with DifferentialFunction

Use of org.nd4j.autodiff.functions.DifferentialFunction in the nd4j project by deeplearning4j.

The class SameDiffTests, method testExpGradient.

@Test
public void testExpGradient() {
    SameDiff sameDiff = SameDiff.create();
    INDArray sumInput = Nd4j.linspace(1, 4, 4).reshape(2, 2);
    Map<String, INDArray> inputs = new HashMap<>();
    inputs.put("x", sumInput);
    sameDiff.defineFunction("expGradient", new SameDiff.SameDiffFunctionDefinition() {

        @Override
        public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) {
            SDVariable input = sameDiff.var("x", inputs.get("x"));
            SDVariable exp = sameDiff.exp(input);
            SDVariable sum = sameDiff.sum(exp, Integer.MAX_VALUE);
            return new SDVariable[] { sum };
        }
    }, inputs);
    List<DifferentialFunction> ops = sameDiff.getFunction("expGradient").execBackwards().getRight();
    INDArray executions = ops.get(ops.size() - 1).outputVariables()[0].getArr();
    INDArray assertion = Nd4j.create(new double[][] { { 2.7183, 7.3891 }, { 20.0855, 54.5981 } });
    assertArrayEquals(sumInput.shape(), executions.shape());
    assertEquals(assertion, executions);
    System.out.println(executions);
// assertEquals(Nd4j.ones(2,2),executions);
}
Also used: org.junit.Test, org.nd4j.autodiff.functions.DifferentialFunction, org.nd4j.linalg.api.ndarray.INDArray
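
The assertion array is just exp of the input: for loss = sum(exp(x)), the gradient with respect to x is exp(x) elementwise, so the expected values are e^1 ≈ 2.7183, e^2 ≈ 7.3891, e^3 ≈ 20.0855, e^4 ≈ 54.5981. A one-line way to build the expected array, assuming nd4j's org.nd4j.linalg.ops.transforms.Transforms helper:

// d/dx sum(exp(x)) = exp(x), evaluated at x = linspace(1, 4)
INDArray expectedGrad = Transforms.exp(Nd4j.linspace(1, 4, 4).reshape(2, 2));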

Example 15 with DifferentialFunction

Use of org.nd4j.autodiff.functions.DifferentialFunction in the nd4j project by deeplearning4j.

The class SameDiffTests, method testSigmoidBackwards.

@Test
public void testSigmoidBackwards() {
    SameDiff sameDiff = SameDiff.create();
    INDArray sumInput = Nd4j.linspace(1, 4, 4).reshape(2, 2);
    Map<String, INDArray> inputs = new HashMap<>();
    inputs.put("x", sumInput);
    SDVariable input = sameDiff.var("x", inputs.get("x"));
    SDVariable sigmoid = sameDiff.sigmoid(input);
    SDVariable sum = sameDiff.sum(sigmoid, Integer.MAX_VALUE);
    List<DifferentialFunction> backwardsOps = sameDiff.execBackwards().getRight();
    Op finalOp = (Op) backwardsOps.get(backwardsOps.size() - 1);
    assertTrue(Nd4j.create(new double[][] { { 0.1966, 0.1050 }, { 0.0452, 0.0177 } }).equalsWithEps(finalOp.z(), 1e-2));
    System.out.println(backwardsOps);
}
Also used: org.junit.Test, org.nd4j.autodiff.functions.DifferentialFunction, org.nd4j.linalg.api.ndarray.INDArray, org.nd4j.linalg.api.ops.Op
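
The same reasoning applies here with the sigmoid derivative: for loss = sum(sigmoid(x)), the gradient is sigmoid(x) * (1 - sigmoid(x)), which at x = 1, 2, 3, 4 gives roughly 0.1966, 0.1050, 0.0452, 0.0177, matching the equalsWithEps check. A sketch of the expected array, again assuming the Transforms helper:

// d/dx sum(sigmoid(x)) = sigmoid(x) * (1 - sigmoid(x))
INDArray s = Transforms.sigmoid(Nd4j.linspace(1, 4, 4).reshape(2, 2));
INDArray expectedGrad = s.mul(s.rsub(1.0));  // elementwise s * (1 - s)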

Aggregations

DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction): 18
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 9
Test (org.junit.Test): 7
ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 6
ArrayList (java.util.ArrayList): 3
lombok.val (lombok.val): 3
SDVariable (org.nd4j.autodiff.samediff.SDVariable): 3
SameDiff (org.nd4j.autodiff.samediff.SameDiff): 2
NoOpNameFoundException (org.nd4j.imports.NoOpNameFoundException): 2
DynamicCustomOp (org.nd4j.linalg.api.ops.DynamicCustomOp): 2
Op (org.nd4j.linalg.api.ops.Op): 2
GradientBackwardsMarker (org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker): 2
IntArrayKeyMap (org.nd4j.linalg.collection.IntArrayKeyMap): 2
Set (java.util.Set): 1
AtomicInteger (java.util.concurrent.atomic.AtomicInteger): 1
FlowPath (org.nd4j.autodiff.samediff.flow.FlowPath): 1
NdIndexIterator (org.nd4j.linalg.api.iter.NdIndexIterator): 1
Variance (org.nd4j.linalg.api.ops.impl.accum.Variance): 1
If (org.nd4j.linalg.api.ops.impl.controlflow.If): 1
While (org.nd4j.linalg.api.ops.impl.controlflow.While): 1