
Example 31 with SameDiff

Use of org.nd4j.autodiff.samediff.SameDiff in project nd4j by deeplearning4j.

From the class GraphExecutionerTest, method testEquality1.

/**
 * VARIABLE_SPACE output mode should dump everything in the variable space:
 * four variables in our case.
 * @throws Exception
 */
@Test
public void testEquality1() throws Exception {
    // Two executioner implementations; only the native one is exercised in this test
    GraphExecutioner executionerA = new BasicGraphExecutioner();
    GraphExecutioner executionerB = new NativeGraphExecutioner();
    SameDiff sameDiff = SameDiff.create();
    INDArray ones = Nd4j.ones(4);
    SDVariable sdVariable = sameDiff.var("ones", ones);
    SDVariable scalarOne = sameDiff.var("add1", Nd4j.scalar(1.0));
    SDVariable result = sdVariable.addi(scalarOne);
    SDVariable total = sameDiff.sum(result, Integer.MAX_VALUE);
    log.info("TOTAL: {}; Id: {}", total.getVarName(), total);
    // configVarSpace is a class-level ExecutorConfiguration field (see the sketch below)
    INDArray[] resB = executionerB.executeGraph(sameDiff, configVarSpace);
    assertEquals(6, resB.length);
    assertEquals(Nd4j.create(new float[] { 2f, 2f, 2f, 2f }), resB[4]);
    assertEquals(Nd4j.scalar(1), resB[1]);
    assertEquals(Nd4j.scalar(8.0), resB[5]);
}
Also used: SDVariable (org.nd4j.autodiff.samediff.SDVariable), INDArray (org.nd4j.linalg.api.ndarray.INDArray), SameDiff (org.nd4j.autodiff.samediff.SameDiff), Test (org.junit.Test)
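
The tests in this example and the next reference configVarSpace and configImplicit without defining them; in the original GraphExecutionerTest they are class-level fields. A minimal sketch of how such fields can be built, assuming the ExecutorConfiguration builder and OutputMode enum from org.nd4j.autodiff.execution.conf; the exact builder chain here is an assumption, not verbatim source:

// Sketch of the class-level fields assumed by testEquality1/testEquality2:
// VARIABLE_SPACE dumps every variable; IMPLICIT returns only terminal outputs.
protected ExecutorConfiguration configVarSpace =
        ExecutorConfiguration.builder().outputMode(OutputMode.VARIABLE_SPACE).build();
protected ExecutorConfiguration configImplicit =
        ExecutorConfiguration.builder().outputMode(OutputMode.IMPLICIT).build();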

Example 32 with SameDiff

Use of org.nd4j.autodiff.samediff.SameDiff in project nd4j by deeplearning4j.

From the class GraphExecutionerTest, method testEquality2.

/**
 * IMPLICIT output mode should return only the terminal nodes of the graph:
 * a single variable in this case.
 * @throws Exception
 */
@Test
public void testEquality2() throws Exception {
    GraphExecutioner executionerA = new BasicGraphExecutioner();
    GraphExecutioner executionerB = new NativeGraphExecutioner();
    SameDiff sameDiff = SameDiff.create();
    INDArray ones = Nd4j.ones(4);
    SDVariable sdVariable = sameDiff.var("ones", ones);
    SDVariable scalarOne = sameDiff.var("add1", Nd4j.scalar(1.0));
    SDVariable result = sdVariable.addi(scalarOne);
    SDVariable total = sameDiff.sum(result, Integer.MAX_VALUE);
    // log.info("ID: {}",sameDiff.getGraph().getVertex(1).getValue().getId());
    // configImplicit is a class-level ExecutorConfiguration field (see the sketch after Example 31)
    INDArray[] resB = executionerB.executeGraph(sameDiff, configImplicit);
    assertEquals(1, resB.length);
    assertEquals(Nd4j.scalar(8.0), resB[0]);
// INDArray resA = executionerA.executeGraph(sameDiff)[0];
// assertEquals(resA, resB);
}
Also used: SDVariable (org.nd4j.autodiff.samediff.SDVariable), INDArray (org.nd4j.linalg.api.ndarray.INDArray), SameDiff (org.nd4j.autodiff.samediff.SameDiff), Test (org.junit.Test)
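
The commented-out lines at the end of the test hint at a cross-check between the two executioners. A minimal sketch of that check, assuming the single-argument executeGraph overload used in the commented-out code is available:

// Hypothetical cross-check mirroring the commented-out assertion above:
INDArray resA = executionerA.executeGraph(sameDiff)[0];
assertEquals(resA, resB[0]);   // both executioners should produce the scalar 8.0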

Example 33 with SameDiff

Use of org.nd4j.autodiff.samediff.SameDiff in project nd4j by deeplearning4j.

From the class GradCheckLoss, method testLossWeights3d.

@Test
public void testLossWeights3d() {
    String[] weightTypes = new String[] { "none", "per-example", "per-output", "per-timestep",
            "per-example-output", "per-example-timestep", "per-output-timestep", "per-all" };
    Nd4j.getRandom().setSeed(12345);
    int nOut = 4;
    int minibatch = 10;
    int tsLength = 5;
    for (String weightType : weightTypes) {
        for (boolean binary : new boolean[] { true, false }) {
            // Binary mask (like DL4J) or arbitrary weights?
            int[] weightShape;
            switch(weightType) {
                case "none":
                    weightShape = null;
                    break;
                case "per-example":
                    weightShape = new int[] { minibatch, 1, 1 };
                    break;
                case "per-output":
                    weightShape = new int[] { 1, nOut, 1 };
                    break;
                case "per-timestep":
                    weightShape = new int[] { 1, 1, tsLength };
                    break;
                case "per-example-output":
                    weightShape = new int[] { minibatch, nOut, 1 };
                    break;
                case "per-example-timestep":
                    weightShape = new int[] { minibatch, 1, tsLength };
                    break;
                case "per-output-timestep":
                    weightShape = new int[] { 1, nOut, tsLength };
                    break;
                case "per-all":
                    weightShape = new int[] { minibatch, nOut, tsLength };
                    break;
                default:
                    throw new RuntimeException("Unknown type: " + weightType);
            }
            INDArray weightArr = null;
            if (!"none".equals(weightType)) {
                if (binary) {
                    weightArr = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(weightShape), 0.5));
                } else {
                    weightArr = Nd4j.rand(weightShape).muli(2.0);
                }
            }
            for (LossFunctions.Reduction reduction : new LossFunctions.Reduction[] { LossFunctions.Reduction.MEAN_BY_COUNT, LossFunctions.Reduction.MEAN_BY_WEIGHT, LossFunctions.Reduction.SUM }) {
                for (String fn : new String[] { "mse", "l1", "l2", "mcxent" }) {
                    SameDiff sd = SameDiff.create();
                    SDVariable input = sd.var("in", new int[] { -1, nOut, -1 });
                    SDVariable labels = sd.var("labels", new int[] { -1, nOut, -1 });
                    SDVariable weight = null;
                    if (!"none".equals(weightType)) {
                        weight = sd.var("weights", weightArr);
                    }
                    INDArray inputArr = Nd4j.randn(new int[] { minibatch, nOut, tsLength }).muli(10);
                    INDArray labelsArr = Nd4j.randn(new int[] { minibatch, nOut, tsLength }).muli(10);
                    LossInfo lossInfo;
                    switch(fn) {
                        case "mse":
                            lossInfo = LossFunctions.mse("out", input, labels, weight, reduction, 1, 2);
                            break;
                        case "l1":
                            lossInfo = LossFunctions.l1("out", input, labels, weight, reduction, 1, 2);
                            // L1 = sum abs error
                            break;
                        case "l2":
                            lossInfo = LossFunctions.l2("out", input, labels, weight, reduction, 1, 2);
                            // L2 = sum squared error
                            break;
                        case "mcxent":
                            lossInfo = LossFunctions.mcxent("out", input, labels, weight, reduction, 1, 2);
                            // mcxent = sum label * log(prob)
                            break;
                        default:
                            throw new RuntimeException();
                    }
                    String msg = "lossFn=" + fn + ", reduction=" + reduction + ", weightType=" + weightType + ", binaryWeight=" + binary;
                    log.info("*** Starting test: " + msg);
                    sd.associateArrayWithVariable(inputArr, input);
                    sd.associateArrayWithVariable(labelsArr, labels);
                    if (weight != null) {
                        sd.associateArrayWithVariable(weightArr, weight);
                    }
                    INDArray out = sd.execAndEndResult();
                    assertEquals(1, out.length());
                    boolean ok = GradCheckUtil.checkGradients(sd);
                    assertTrue(msg, ok);
                }
            }
        }
    }
}
Also used: SameDiff (org.nd4j.autodiff.samediff.SameDiff), LossFunctions (org.nd4j.autodiff.loss.LossFunctions), LossInfo (org.nd4j.autodiff.loss.LossInfo), SDVariable (org.nd4j.autodiff.samediff.SDVariable), INDArray (org.nd4j.linalg.api.ndarray.INDArray), BernoulliDistribution (org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution), Test (org.junit.Test)
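
Every non-null weight shape in the test is broadcast against the per-element loss of shape [minibatch, nOut, tsLength]. A small standalone sketch of that broadcasting idea in plain ND4J; the array names and the explicit broadcast call are illustrative, not part of the test:

// Per-example weights of shape [minibatch, 1, 1] scale every output and timestep
// of their example once broadcast up to [minibatch, nOut, tsLength]:
INDArray perElementLoss = Nd4j.rand(new int[] { 10, 4, 5 });
INDArray weights = Nd4j.rand(new int[] { 10, 1, 1 });
INDArray weighted = perElementLoss.mul(weights.broadcast(10, 4, 5));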

Example 34 with SameDiff

Use of org.nd4j.autodiff.samediff.SameDiff in project nd4j by deeplearning4j.

From the class GradCheckReductions, method testReduce3.

@Test
public void testReduce3() {
    Nd4j.getRandom().setSeed(12345);
    int d0 = 3;
    int d1 = 4;
    int d2 = 5;
    List<String> allFailed = new ArrayList<>();
    for (int[] reduceDims : new int[][] { { Integer.MAX_VALUE }, { 0, 1, 2 }, { 0 }, { 1 }, { 2 }, { 0, 1 }, { 0, 2 }, { 1, 2 } }) {
        for (int i = 0; i < 6; i++) {
            SameDiff sd = SameDiff.create();
            sd.setLogExecution(false);
            SDVariable in = sd.var("in", new int[] { -1, d1, d2 });
            SDVariable in2 = sd.var("in2", new int[] { -1, d1, d2 });
            INDArray inArr = Nd4j.randn(new int[] { d0, d1, d2 }).muli(100);
            INDArray in2Arr = Nd4j.randn(inArr.shape()).muli(100);
            SDVariable reduced;
            String name;
            switch(i) {
                case 0:
                    reduced = sd.manhattanDistance(in, in2, reduceDims);
                    name = "manhattan";
                    break;
                case 1:
                    reduced = sd.euclideanDistance(in, in2, reduceDims);
                    name = "euclidean";
                    break;
                case 2:
                    reduced = sd.cosineSimilarity(in, in2, reduceDims);
                    name = "cosine";
                    break;
                case 3:
                    reduced = sd.cosineDistance(in, in2, reduceDims);
                    name = "cosinedistance";
                    break;
                case 4:
                    reduced = sd.hammingDistance(in, in2, reduceDims);
                    name = "hamming";
                    break;
                case 5:
                    name = "jaccard";
                    reduced = sd.jaccardDistance(name, in, in2, reduceDims);
                    // Jaccard distance expects non-negative inputs: rescale and shift away from zero
                    inArr.divi(100).addi(0.1);
                    in2Arr.divi(100).addi(0.1);
                    break;
                default:
                    throw new RuntimeException("Unknown test index: " + i);
            }
            // Sum: note that this should be a no-op for the full array cases
            SDVariable sum = sd.sum(reduced, Integer.MAX_VALUE);
            String msg = "(test " + i + " - " + name + ", dimensions=" + Arrays.toString(reduceDims) + ")";
            log.info("*** Starting test: " + msg);
            sd.associateArrayWithVariable(inArr, in);
            sd.associateArrayWithVariable(in2Arr, in2);
            sd.execAndEndResult();
            // FIXME: we can't swallow exceptions here for now; once the release is out and things have stabilized, we can
            // try {
            boolean ok = GradCheckUtil.checkGradients(sd, 1e-5, 1e-5, 1e-4, true, false);
            if (!ok) {
                allFailed.add(msg);
            }
        /*
                } catch (Exception e) {
                    e.printStackTrace();
                    allFailed.add(msg + " - EXCEPTION");
                }
                */
        }
    }
    assertEquals("Failed: " + allFailed, 0, allFailed.size());
}
Also used: SDVariable (org.nd4j.autodiff.samediff.SDVariable), INDArray (org.nd4j.linalg.api.ndarray.INDArray), SameDiff (org.nd4j.autodiff.samediff.SameDiff), ArrayList (java.util.ArrayList), Test (org.junit.Test)
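
Passing Integer.MAX_VALUE as the reduction dimension is SameDiff's convention for "reduce over all dimensions", which is why the trailing sum is a no-op for the full-array cases. A tiny standalone sketch of that convention (input values illustrative):

SameDiff sd = SameDiff.create();
SDVariable x = sd.var("x", Nd4j.linspace(1, 6, 6).reshape(2, 3));
// Integer.MAX_VALUE as the dimension argument reduces over every dimension:
SDVariable full = sd.sum(x, Integer.MAX_VALUE);
INDArray out = sd.execAndEndResult();   // scalar holding 21.0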

Example 35 with SameDiff

Use of org.nd4j.autodiff.samediff.SameDiff in project nd4j by deeplearning4j.

From the class GradCheckReductions, method testReductionGradientsSimple.

@Test
public void testReductionGradientsSimple() {
    // Test reductions as the final (and only) function in the graph
    Nd4j.getRandom().setSeed(12345);
    for (int i = 0; i < 12; i++) {
        SameDiff sd = SameDiff.create();
        boolean skipBackward = false;
        int nOut = 4;
        int minibatch = 10;
        SDVariable input = sd.var("in", new int[] { -1, nOut });
        SDVariable loss;
        String name;
        switch(i) {
            case 0:
                loss = sd.mean("loss", input);
                name = "mean";
                break;
            case 1:
                loss = sd.sum("loss", input);
                name = "sum";
                break;
            case 2:
                loss = sd.standardDeviation("loss", input, true);
                name = "stdev";
                break;
            case 3:
                loss = sd.min("loss", input);
                name = "min";
                break;
            case 4:
                loss = sd.max("loss", input);
                name = "max";
                break;
            case 5:
                loss = sd.variance("loss", input, true);
                name = "variance";
                break;
            case 6:
                loss = sd.prod("loss", input);
                name = "prod";
                break;
            case 7:
                loss = sd.norm1("loss", input);
                name = "norm1";
                break;
            case 8:
                loss = sd.norm2("loss", input);
                name = "norm2";
                break;
            case 9:
                loss = sd.normmax("loss", input);
                name = "normmax";
                break;
            case 10:
                loss = sd.countNonZero("loss", input);
                name = "countNonZero";
                // count ops have no meaningful gradient, so the backward check is skipped
                skipBackward = true;
                break;
            case 11:
                loss = sd.countZero("loss", input);
                name = "countZero";
                skipBackward = true;
                break;
            default:
                throw new RuntimeException("Unknown reduction index: " + i);
        }
        String msg = "test: " + i + " - " + name;
        log.info("*** Starting test: " + msg);
        INDArray inputArr = Nd4j.randn(minibatch, nOut).muli(100);
        sd.associateArrayWithVariable(inputArr, input);
        if (!skipBackward) {
            boolean ok = GradCheckUtil.checkGradients(sd);
            assertTrue(msg, ok);
        }
    }
}
Also used: SDVariable (org.nd4j.autodiff.samediff.SDVariable), INDArray (org.nd4j.linalg.api.ndarray.INDArray), SameDiff (org.nd4j.autodiff.samediff.SameDiff), Test (org.junit.Test)
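
Since cases 10 and 11 skip the backward check entirely, a forward-only sanity check can still be run with execAndEndResult. A minimal sketch using the mean reduction from case 0 (input values illustrative):

SameDiff sd = SameDiff.create();
SDVariable in = sd.var("in", Nd4j.create(new float[] { 1, 2, 3, 4 }));
SDVariable loss = sd.mean("loss", in);
INDArray out = sd.execAndEndResult();   // forward pass only; yields the scalar 2.5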

Aggregations

SameDiff (org.nd4j.autodiff.samediff.SameDiff): 50
Test (org.junit.Test): 42
SDVariable (org.nd4j.autodiff.samediff.SDVariable): 41
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 37
ArrayList (java.util.ArrayList): 10
DynamicCustomOp (org.nd4j.linalg.api.ops.DynamicCustomOp): 10
Ignore (org.junit.Ignore): 7
ClassPathResource (org.nd4j.linalg.io.ClassPathResource): 6
lombok.val (lombok.val): 4
LossFunctions (org.nd4j.autodiff.loss.LossFunctions): 4
LossInfo (org.nd4j.autodiff.loss.LossInfo): 4
BernoulliDistribution (org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution): 4
DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction): 2
Triple (org.nd4j.linalg.primitives.Triple): 2
ZeroInitScheme (org.nd4j.weightinit.impl.ZeroInitScheme): 2
DataOutputStream (java.io.DataOutputStream): 1
FileOutputStream (java.io.FileOutputStream): 1
ByteBuffer (java.nio.ByteBuffer): 1
HashMap (java.util.HashMap): 1
LinkedHashMap (java.util.LinkedHashMap): 1