Search in sources :

Example 81 with SDVariable

Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.

Class GraphExecutionerTest, method testEquality2.

/**
 * Builds the graph {@code sum(ones(4) + 1)} and runs it through the native executioner with
 * {@code configImplicit} (a fixture defined elsewhere in this test class). Expects exactly one
 * output array, equal to the scalar 8.0.
 *
 * @throws Exception if building or executing the graph fails
 */
@Test
public void testEquality2() throws Exception {
    // Both executioners are created, but only the native one is exercised.
    // The basic-vs-native comparison is currently disabled.
    GraphExecutioner basicExecutioner = new BasicGraphExecutioner();
    GraphExecutioner nativeExecutioner = new NativeGraphExecutioner();

    SameDiff graph = SameDiff.create();
    SDVariable onesVar = graph.var("ones", Nd4j.ones(4));
    SDVariable oneVar = graph.var("add1", Nd4j.scalar(1.0));
    // In-place add, then a sum with Integer.MAX_VALUE as the dimension.
    // The expected 8.0 (= 4 * (1 + 1)) relies on that meaning "all dimensions".
    SDVariable summed = graph.sum(onesVar.addi(oneVar), Integer.MAX_VALUE);

    INDArray[] outputs = nativeExecutioner.executeGraph(graph, configImplicit);
    assertEquals(1, outputs.length);
    assertEquals(Nd4j.scalar(8.0), outputs[0]);
}
Also used : SDVariable(org.nd4j.autodiff.samediff.SDVariable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) SameDiff(org.nd4j.autodiff.samediff.SameDiff) Test(org.junit.Test)

Example 82 with SDVariable

Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.

Class GradCheckLoss, method testLossWeights3d.

/**
 * Gradient-checks the SameDiff loss functions (mse, l1, l2, mcxent) on rank-3 input of shape
 * [minibatch, nOut, tsLength].
 *
 * <p>Every combination of the following is covered:
 * <ul>
 *   <li>weight-array layout (none, per-example, per-output, per-timestep and their combinations);</li>
 *   <li>binary (Bernoulli 0/1) vs. real-valued weights;</li>
 *   <li>reduction mode.</li>
 * </ul>
 * Each weight shape uses size 1 on the axes the weights do not vary along. The intent is for it
 * to broadcast against the [minibatch, nOut, tsLength] input.
 */
@Test
public void testLossWeights3d() {
    String[] weightTypes = new String[] { "none", "per-example", "per-output", "per-timestep", "per-example-output", "per-example-timestep", "per-output-timestep", "per-all" };
    Nd4j.getRandom().setSeed(12345);
    int nOut = 4;
    int minibatch = 10;
    int tsLength = 5;
    for (String weightType : weightTypes) {
        for (boolean binary : new boolean[] { true, false }) {
            // Binary mask (like DL4J) or arbitrary weights?
            int[] weightShape;
            switch(weightType) {
                case "none":
                    weightShape = null;
                    break;
                case "per-example":
                    weightShape = new int[] { minibatch, 1, 1 };
                    break;
                case "per-output":
                    weightShape = new int[] { 1, nOut, 1 };
                    break;
                case "per-timestep":
                    weightShape = new int[] { 1, 1, tsLength };
                    break;
                case "per-example-output":
                    weightShape = new int[] { minibatch, nOut, 1 };
                    break;
                case "per-example-timestep":
                    // The last axis is the time axis, so it must be tsLength (previously it was
                    // nOut, which does not match the [minibatch, nOut, tsLength] input).
                    weightShape = new int[] { minibatch, 1, tsLength };
                    break;
                case "per-output-timestep":
                    weightShape = new int[] { 1, nOut, tsLength };
                    break;
                case "per-all":
                    weightShape = new int[] { minibatch, nOut, tsLength };
                    break;
                default:
                    throw new RuntimeException("Unknown type: " + weightType);
            }
            INDArray weightArr = null;
            if (!"none".equals(weightType)) {
                if (binary) {
                    // Bernoulli(0.5) sample: a 0/1 mask
                    weightArr = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(weightShape), 0.5));
                } else {
                    weightArr = Nd4j.rand(weightShape).muli(2.0);
                }
            }
            for (LossFunctions.Reduction reduction : new LossFunctions.Reduction[] { LossFunctions.Reduction.MEAN_BY_COUNT, LossFunctions.Reduction.MEAN_BY_WEIGHT, LossFunctions.Reduction.SUM }) {
                for (String fn : new String[] { "mse", "l1", "l2", "mcxent" }) {
                    SameDiff sd = SameDiff.create();
                    SDVariable input = sd.var("in", new int[] { -1, nOut, -1 });
                    SDVariable labels = sd.var("labels", new int[] { -1, nOut, -1 });
                    SDVariable weight = null;
                    if (!"none".equals(weightType)) {
                        weight = sd.var("weights", weightArr);
                    }
                    INDArray inputArr = Nd4j.randn(new int[] { minibatch, nOut, tsLength }).muli(10);
                    INDArray labelsArr = Nd4j.randn(new int[] { minibatch, nOut, tsLength }).muli(10);
                    LossInfo lossInfo;
                    // The trailing 1, 2 arguments are passed as the (non-minibatch) dimensions of the rank-3 input
                    switch(fn) {
                        case "mse":
                            lossInfo = LossFunctions.mse("out", input, labels, weight, reduction, 1, 2);
                            break;
                        case "l1":
                            // L1 = sum abs error
                            lossInfo = LossFunctions.l1("out", input, labels, weight, reduction, 1, 2);
                            break;
                        case "l2":
                            // L2 = sum squared error
                            lossInfo = LossFunctions.l2("out", input, labels, weight, reduction, 1, 2);
                            break;
                        case "mcxent":
                            // mcxent = sum label * log(prob)
                            lossInfo = LossFunctions.mcxent("out", input, labels, weight, reduction, 1, 2);
                            break;
                        default:
                            throw new RuntimeException();
                    }
                    String msg = "lossFn=" + fn + ", reduction=" + reduction + ", weightType=" + weightType + ", binaryWeight=" + binary;
                    log.info("*** Starting test: {}", msg);
                    sd.associateArrayWithVariable(inputArr, input);
                    sd.associateArrayWithVariable(labelsArr, labels);
                    if (weight != null) {
                        sd.associateArrayWithVariable(weightArr, weight);
                    }
                    // Every reduction mode should collapse the loss to a single value
                    INDArray out = sd.execAndEndResult();
                    assertEquals(1, out.length());
                    boolean ok = GradCheckUtil.checkGradients(sd);
                    assertTrue(msg, ok);
                }
            }
        }
    }
}
Also used : SameDiff(org.nd4j.autodiff.samediff.SameDiff) LossFunctions(org.nd4j.autodiff.loss.LossFunctions) LossInfo(org.nd4j.autodiff.loss.LossInfo) SDVariable(org.nd4j.autodiff.samediff.SDVariable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) BernoulliDistribution(org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution) Test(org.junit.Test)

Example 83 with SDVariable

Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.

Class GradCheckReductions, method testReduce3.

/**
 * Gradient-checks the pairwise distance/similarity reductions on two [3, 4, 5] inputs.
 *
 * <p>The ops covered are manhattan, euclidean, cosine similarity, cosine distance, hamming and
 * jaccard. Each op is checked for every listed dimension set. The op output is summed to a
 * scalar before the check. All failing configurations are collected and reported together in a
 * single assertion at the end.
 */
@Test
public void testReduce3() {
    Nd4j.getRandom().setSeed(12345);
    final int d0 = 3;
    final int d1 = 4;
    final int d2 = 5;
    final int numOps = 6;
    final int[][] dimensionSets = new int[][] { { Integer.MAX_VALUE }, { 0, 1, 2 }, { 0 }, { 1 }, { 2 }, { 0, 1 }, { 0, 2 }, { 1, 2 } };
    final List<String> failures = new ArrayList<>();

    for (int[] dims : dimensionSets) {
        for (int op = 0; op < numOps; op++) {
            SameDiff graph = SameDiff.create();
            graph.setLogExecution(false);
            SDVariable x = graph.var("in", new int[] { -1, d1, d2 });
            SDVariable y = graph.var("in2", new int[] { -1, d1, d2 });
            // Draw order matters for reproducibility: first input, then second input
            INDArray xArr = Nd4j.randn(new int[] { d0, d1, d2 }).muli(100);
            INDArray yArr = Nd4j.randn(xArr.shape()).muli(100);

            final String name;
            final SDVariable distance;
            if (op == 0) {
                name = "manhattan";
                distance = graph.manhattanDistance(x, y, dims);
            } else if (op == 1) {
                name = "euclidean";
                distance = graph.euclideanDistance(x, y, dims);
            } else if (op == 2) {
                name = "cosine";
                distance = graph.cosineSimilarity(x, y, dims);
            } else if (op == 3) {
                name = "cosinedistance";
                distance = graph.cosineDistance(x, y, dims);
            } else if (op == 4) {
                name = "hamming";
                distance = graph.hammingDistance(x, y, dims);
            } else if (op == 5) {
                name = "jaccard";
                distance = graph.jaccardDistance(name, x, y, dims);
                // Jaccard is checked on rescaled inputs: divided by 100, then shifted by 0.1
                xArr.divi(100).addi(0.1);
                yArr.divi(100).addi(0.1);
            } else {
                throw new RuntimeException();
            }

            // Collapse to a scalar; for the full-array dimension sets this sum is a no-op
            SDVariable scalarSum = graph.sum(distance, Integer.MAX_VALUE);
            String msg = "(test " + op + " - " + name + ", dimensions=" + Arrays.toString(dims) + ")";
            log.info("*** Starting test: " + msg);
            graph.associateArrayWithVariable(xArr, x);
            graph.associateArrayWithVariable(yArr, y);
            graph.execAndEndResult();
            // FIXME: exceptions from the gradient check are deliberately NOT caught here yet. Once the
            // implementation stabilizes, consider recording them as failures instead.
            if (!GradCheckUtil.checkGradients(graph, 1e-5, 1e-5, 1e-4, true, false)) {
                failures.add(msg);
            }
        }
    }
    assertEquals("Failed: " + failures, 0, failures.size());
}
Also used : SDVariable(org.nd4j.autodiff.samediff.SDVariable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) SameDiff(org.nd4j.autodiff.samediff.SameDiff) ArrayList(java.util.ArrayList) Test(org.junit.Test)

Example 84 with SDVariable

Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.

Class GradCheckReductions, method testReductionGradientsSimple.

/**
 * Tests each full-array reduction when it is the final and only function in the graph.
 *
 * <p>Differentiable reductions get a gradient check. The count ops (countNonZero, countZero) are
 * step functions with no useful gradient. They previously got no check at all; now their forward
 * pass is executed and the count is asserted instead.
 */
@Test
public void testReductionGradientsSimple() {
    // Test reductions: final and only function
    Nd4j.getRandom().setSeed(12345);
    // Loop-invariant input dimensions
    int nOut = 4;
    int minibatch = 10;
    for (int i = 0; i < 12; i++) {
        SameDiff sd = SameDiff.create();
        boolean skipBackward = false;
        SDVariable input = sd.var("in", new int[] { -1, nOut });
        SDVariable loss;
        String name;
        switch(i) {
            case 0:
                loss = sd.mean("loss", input);
                name = "mean";
                break;
            case 1:
                loss = sd.sum("loss", input);
                name = "sum";
                break;
            case 2:
                loss = sd.standardDeviation("loss", input, true);
                name = "stdev";
                break;
            case 3:
                loss = sd.min("loss", input);
                name = "min";
                break;
            case 4:
                loss = sd.max("loss", input);
                name = "max";
                break;
            case 5:
                loss = sd.variance("loss", input, true);
                name = "variance";
                break;
            case 6:
                loss = sd.prod("loss", input);
                name = "prod";
                break;
            case 7:
                loss = sd.norm1("loss", input);
                name = "norm1";
                break;
            case 8:
                loss = sd.norm2("loss", input);
                name = "norm2";
                break;
            case 9:
                loss = sd.normmax("loss", input);
                name = "normmax";
                break;
            case 10:
                loss = sd.countNonZero("loss", input);
                name = "countNonZero";
                skipBackward = true;
                break;
            case 11:
                loss = sd.countZero("loss", input);
                name = "countZero";
                skipBackward = true;
                break;
            default:
                throw new RuntimeException();
        }
        String msg = "test: " + i + " - " + name;
        log.info("*** Starting test: " + msg);
        INDArray inputArr = Nd4j.randn(minibatch, nOut).muli(100);
        sd.associateArrayWithVariable(inputArr, input);
        if (skipBackward) {
            // No gradient check for the count ops; verify the forward result instead.
            // Scaled Gaussian samples are (with overwhelming probability) never exactly zero,
            // so every element counts as non-zero.
            INDArray out = sd.execAndEndResult();
            assertEquals(msg, 1, out.length());
            double expectedCount = "countNonZero".equals(name) ? minibatch * nOut : 0;
            assertEquals(msg, expectedCount, out.getDouble(0), 0.0);
        } else {
            boolean ok = GradCheckUtil.checkGradients(sd);
            assertTrue(msg, ok);
        }
    }
}
Also used : SDVariable(org.nd4j.autodiff.samediff.SDVariable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) SameDiff(org.nd4j.autodiff.samediff.SameDiff) Test(org.junit.Test)

Example 85 with SDVariable

Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.

Class GradCheckReductions, method testReductionGradients2.

/**
 * Gradient-checks single-dimension reductions used as an intermediate (non-final) function.
 *
 * <p>The graph is {@code mean((label - (reduce(2 * in, dim) + 1))^2)}, built for each reduction op
 * and each reduce dimension of a [3, 4, 5] input. Failures and exceptions are collected rather
 * than failing fast, and everything is reported in one assertion at the end.
 */
@Test
public void testReductionGradients2() {
    // Test reductions: NON-final function
    Nd4j.getRandom().setSeed(12345);
    int d0 = 3;
    int d1 = 4;
    int d2 = 5;
    List<String> allFailed = new ArrayList<>();
    for (int reduceDim : new int[] { 0, 1, 2 }) {
        // Shape left after removing reduceDim from [d0, d1, d2]. It depends only on reduceDim,
        // so compute it once per dimension rather than once per op.
        int[] outShape;
        switch(reduceDim) {
            case 0:
                outShape = new int[] { d1, d2 };
                break;
            case 1:
                outShape = new int[] { d0, d2 };
                break;
            case 2:
                outShape = new int[] { d0, d1 };
                break;
            default:
                throw new RuntimeException();
        }
        for (int i = 0; i < 12; i++) {
            SameDiff sd = SameDiff.create();
            sd.setLogExecution(false);
            SDVariable in = sd.var("in", new int[] { -1, d1, d2 });
            SDVariable label = sd.var("label", outShape);
            SDVariable second = in.mul(2);
            double maxRelError = 1e-5;
            double minAbsError = 1e-4;
            INDArray inputArr = Nd4j.randn(new int[] { d0, d1, d2 }).muli(1000);
            INDArray labelArr = Nd4j.randn(outShape).muli(1000);
            SDVariable reduced;
            String name;
            switch(i) {
                case 0:
                    reduced = sd.mean("reduced", second, reduceDim);
                    name = "mean";
                    break;
                case 1:
                    reduced = sd.sum("reduced", second, reduceDim);
                    name = "sum";
                    break;
                case 2:
                    reduced = sd.standardDeviation("reduced", second, true, reduceDim);
                    name = "stdev";
                    break;
                case 3:
                    reduced = sd.min("reduced", second, reduceDim);
                    name = "min";
                    break;
                case 4:
                    reduced = sd.max("reduced", second, reduceDim);
                    name = "max";
                    break;
                case 5:
                    // Variance is finicky for gradient checks because of the huge score/output,
                    // so use looser tolerances.
                    maxRelError = 1e-3;
                    // Most gradients are in the range 1k to >100k
                    minAbsError = 1;
                    inputArr.divi(10);
                    labelArr.divi(100);
                    // Replace entries with |x| < 1 by random values in [100, 200)
                    BooleanIndexing.replaceWhere(inputArr, Nd4j.rand(inputArr.shape()).muli(100).addi(100), Conditions.absLessThan(1.0));
                    reduced = sd.variance("reduced", second, true, reduceDim);
                    name = "variance";
                    break;
                case 6:
                    // Scale down to keep the product within a reasonable range
                    inputArr.divi(1000);
                    labelArr.divi(1000);
                    reduced = sd.prod("reduced", second, reduceDim);
                    name = "prod";
                    break;
                case 7:
                    reduced = sd.norm1("reduced", second, reduceDim);
                    name = "norm1";
                    break;
                case 8:
                    reduced = sd.norm2("reduced", second, reduceDim);
                    name = "norm2";
                    break;
                case 9:
                    inputArr = Nd4j.rand(new int[] { d0, d1, d2 });
                    labelArr = Nd4j.rand(outShape);
                    reduced = sd.normmax("reduced", second, reduceDim);
                    name = "normmax";
                    break;
                case 10:
                    reduced = sd.argmax("reduced", second, reduceDim);
                    name = "argmax";
                    break;
                case 11:
                    reduced = sd.argmin("reduced", second, reduceDim);
                    name = "argmin";
                    break;
                default:
                    throw new RuntimeException();
            }
            // MSE loss on top of the reduction, so the reduction is not the final op
            SDVariable add = reduced.add(1.0);
            SDVariable diff = label.sub(add);
            SDVariable sqDiff = diff.mul(diff);
            SDVariable mseLoss = sd.mean("loss", sqDiff);
            String msg = "(test " + i + " - " + name + ", dimension=" + reduceDim + ")";
            log.info("*** Starting test: " + msg);
            sd.associateArrayWithVariable(inputArr, in);
            sd.associateArrayWithVariable(labelArr, label);
            try {
                boolean ok = GradCheckUtil.checkGradients(sd, 1e-5, maxRelError, minAbsError, true, false);
                if (!ok) {
                    allFailed.add(msg);
                }
            } catch (Exception e) {
                // Record the failure and keep going so every configuration is reported;
                // log through the test logger (with the cause) instead of printStackTrace.
                log.error("*** Exception during gradient check: " + msg, e);
                allFailed.add(msg + " - EXCEPTION");
            }
        }
    }
    assertEquals("Failed: " + allFailed, 0, allFailed.size());
}
Also used : SDVariable(org.nd4j.autodiff.samediff.SDVariable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) SameDiff(org.nd4j.autodiff.samediff.SameDiff) ArrayList(java.util.ArrayList) Test(org.junit.Test)

Aggregations

SDVariable (org.nd4j.autodiff.samediff.SDVariable)104 SameDiff (org.nd4j.autodiff.samediff.SameDiff)41 INDArray (org.nd4j.linalg.api.ndarray.INDArray)38 Test (org.junit.Test)36 ArrayList (java.util.ArrayList)18 DynamicCustomOp (org.nd4j.linalg.api.ops.DynamicCustomOp)10 lombok.val (lombok.val)7 LossFunctions (org.nd4j.autodiff.loss.LossFunctions)4 LossInfo (org.nd4j.autodiff.loss.LossInfo)4 BernoulliDistribution (org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution)4 Ignore (org.junit.Ignore)3 DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction)3 ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException)3 Triple (org.nd4j.linalg.primitives.Triple)2 DataOutputStream (java.io.DataOutputStream)1 FileOutputStream (java.io.FileOutputStream)1 ByteBuffer (java.nio.ByteBuffer)1 NoOpNameFoundException (org.nd4j.imports.NoOpNameFoundException)1 NdIndexIterator (org.nd4j.linalg.api.iter.NdIndexIterator)1 TruncateDivOp (org.nd4j.linalg.api.ops.impl.transforms.arithmetic.TruncateDivOp)1