Example 31 with SDVariable

Use of org.nd4j.autodiff.samediff.SDVariable in the nd4j project by deeplearning4j.

From the class GradCheckTransforms, the method invertPermutation:

@Test
public void invertPermutation() {
    SameDiff sd = SameDiff.create();
    INDArray ia = Nd4j.create(new float[] { 3, 4, 0, 2, 1 });
    INDArray expOut = Nd4j.create(new float[] { 2, 4, 3, 0, 1 });
    SDVariable input = sd.var("in", new int[] { 1, 5 });
    sd.associateArrayWithVariable(ia, input);
    SDVariable out = sd.invertPermutation(input);
    sd.exec();
    // A plain Java assert is a no-op unless the JVM runs with -ea; a JUnit
    // assertion always executes
    assertEquals(expOut, out.getArr());
}
Also used : SDVariable(org.nd4j.autodiff.samediff.SDVariable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) SameDiff(org.nd4j.autodiff.samediff.SameDiff) Test(org.junit.Test)
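
For context: the inverse q of a permutation p satisfies q[p[i]] = i for every index i, which is why [3, 4, 0, 2, 1] inverts to [2, 4, 3, 0, 1] above. A plain-Java sketch of the same computation (a hypothetical helper, not part of the nd4j API):

static int[] invert(int[] p) {
    int[] q = new int[p.length];
    for (int i = 0; i < p.length; i++) {
        // the value p[i] was taken from position i, so it maps back there
        q[p[i]] = i;
    }
    return q;
}
// invert(new int[] { 3, 4, 0, 2, 1 }) returns { 2, 4, 3, 0, 1 }, matching expOut.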

Example 32 with SDVariable

Use of org.nd4j.autodiff.samediff.SDVariable in the nd4j project by deeplearning4j.

From the class TensorFlowImportTest, the method importGraph1:

@Test
@Ignore
public void importGraph1() throws Exception {
    SameDiff graph = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/max_add_2.pb.txt").getInputStream());
    assertNotNull(graph);
    assertEquals(2, graph.variableMap().size());
    SDVariable var0 = graph.variableMap().get("zeros");
    SDVariable var1 = graph.variableMap().get("ones");
    assertNotNull(var0);
    assertNotNull(var1);
    assertNotNull(var0.getArr());
    assertNotNull(var1.getArr());
    // "zeros" should sum to 0.0; "ones" sums to 12.0, i.e. the array holds 12 ones
    assertEquals(0.0, var0.getArr().sumNumber().doubleValue(), 1e-5);
    assertEquals(12.0, var1.getArr().sumNumber().doubleValue(), 1e-5);
}
Also used : SDVariable(org.nd4j.autodiff.samediff.SDVariable) SameDiff(org.nd4j.autodiff.samediff.SameDiff) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) Ignore(org.junit.Ignore) Test(org.junit.Test)
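
The test above only checks the two named variables. Below is a minimal sketch for dumping every variable of an imported graph, reusing only the calls already shown plus INDArray.shape() (java.util.Map and java.util.Arrays imports assumed):

SameDiff graph = TFGraphMapper.getInstance()
        .importGraph(new ClassPathResource("tf_graphs/max_add_2.pb.txt").getInputStream());
for (Map.Entry<String, SDVariable> e : graph.variableMap().entrySet()) {
    INDArray arr = e.getValue().getArr();
    // print each variable's name and shape; arrays may be null before execution
    System.out.println(e.getKey() + " -> " + (arr == null ? "no array" : Arrays.toString(arr.shape())));
}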

Example 33 with SDVariable

Use of org.nd4j.autodiff.samediff.SDVariable in the nd4j project by deeplearning4j.

From the class GradCheckLoss, the method testLossWeights2d:

@Test
public void testLossWeights2d() {
    String[] weightTypes = new String[] { "none", "per-example", "per-output", "per-example-output" };
    Nd4j.getRandom().setSeed(12345);
    int nOut = 4;
    int minibatch = 10;
    for (String weightType : weightTypes) {
        for (boolean binary : new boolean[] { true, false }) {
            // Binary mask (like DL4J) or arbitrary weights?
            int[] weightShape;
            switch(weightType) {
                case "none":
                    weightShape = null;
                    break;
                case "per-example":
                    weightShape = new int[] { minibatch, 1 };
                    break;
                case "per-output":
                    weightShape = new int[] { 1, nOut };
                    break;
                case "per-example-output":
                    weightShape = new int[] { minibatch, nOut };
                    break;
                default:
                    throw new RuntimeException("Unknown type: " + weightType);
            }
            INDArray weightArr = null;
            if (!"none".equals(weightType)) {
                if (binary) {
                    weightArr = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(weightShape), 0.5));
                } else {
                    weightArr = Nd4j.rand(weightShape).muli(2.0);
                }
            }
            for (LossFunctions.Reduction reduction : new LossFunctions.Reduction[] { LossFunctions.Reduction.MEAN_BY_COUNT, LossFunctions.Reduction.MEAN_BY_WEIGHT, LossFunctions.Reduction.SUM }) {
                for (String fn : new String[] { "mse", "l1", "l2", "mcxent" }) {
                    SameDiff sd = SameDiff.create();
                    SDVariable input = sd.var("in", new int[] { -1, nOut });
                    SDVariable labels = sd.var("labels", new int[] { -1, nOut });
                    SDVariable weight = null;
                    if (!"none".equals(weightType)) {
                        weight = sd.var("weights", weightArr);
                    }
                    INDArray inputArr = Nd4j.randn(minibatch, nOut).muli(100);
                    INDArray labelsArr = Nd4j.randn(minibatch, nOut).muli(100);
                    LossInfo lossInfo;
                    switch(fn) {
                        case "mse":
                            // MSE = mean squared error
                            lossInfo = LossFunctions.mse("out", input, labels, weight, reduction, 1);
                            break;
                        case "l1":
                            // L1 = sum abs error
                            lossInfo = LossFunctions.l1("out", input, labels, weight, reduction, 1);
                            break;
                        case "l2":
                            // L2 = sum squared error
                            lossInfo = LossFunctions.l2("out", input, labels, weight, reduction, 1);
                            break;
                        case "mcxent":
                            // mcxent = sum label * log(prob)
                            lossInfo = LossFunctions.mcxent("out", input, labels, weight, reduction, 1);
                            break;
                        default:
                            throw new RuntimeException("Unknown loss function: " + fn);
                    }
                    String msg = "lossFn=" + fn + ", reduction=" + reduction + ", weightType=" + weightType + ", binaryWeight=" + binary;
                    log.info("*** Starting test: " + msg);
                    sd.associateArrayWithVariable(inputArr, input);
                    sd.associateArrayWithVariable(labelsArr, labels);
                    if (weight != null) {
                        sd.associateArrayWithVariable(weightArr, weight);
                    }
                    INDArray out = sd.execAndEndResult();
                    assertEquals(1, out.length());
                    boolean ok = GradCheckUtil.checkGradients(sd);
                    assertTrue(msg, ok);
                }
            }
        }
    }
}
Also used : SameDiff(org.nd4j.autodiff.samediff.SameDiff) LossFunctions(org.nd4j.autodiff.loss.LossFunctions) LossInfo(org.nd4j.autodiff.loss.LossInfo) SDVariable(org.nd4j.autodiff.samediff.SDVariable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) BernoulliDistribution(org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution) Test(org.junit.Test)
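
The weight shapes above work by broadcasting against the [minibatch, nOut] per-element loss. A rough ND4J illustration of what each shape scales, using standard row/column-vector ops (illustrative only; not necessarily how LossFunctions applies the weights internally):

INDArray err = Nd4j.rand(10, 4);               // [minibatch, nOut] per-element error
INDArray perExample = Nd4j.rand(10, 1);        // "per-example": one weight per row
INDArray perOutput = Nd4j.rand(1, 4);          // "per-output": one weight per column
// scales every output of example i by perExample.getDouble(i, 0)
INDArray a = err.mulColumnVector(perExample);
// scales output j of every example by perOutput.getDouble(0, j)
INDArray b = err.mulRowVector(perOutput);
// "per-example-output": plain elementwise multiply with a full [10, 4] weight array
INDArray c = err.mul(Nd4j.rand(10, 4));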

Example 34 with SDVariable

Use of org.nd4j.autodiff.samediff.SDVariable in the nd4j project by deeplearning4j.

From the class GradCheckLoss, the method testLossSimple2d:

@Test
public void testLossSimple2d() {
    Nd4j.getRandom().setSeed(12345);
    for (String fn : new String[] { "mse", "l1", "l2", "mcxent" }) {
        for (LossFunctions.Reduction reduction : new LossFunctions.Reduction[] { LossFunctions.Reduction.MEAN_BY_COUNT, LossFunctions.Reduction.MEAN_BY_WEIGHT, LossFunctions.Reduction.SUM }) {
            SameDiff sd = SameDiff.create();
            int nOut = 4;
            int minibatch = 10;
            SDVariable input = sd.var("in", new int[] { -1, nOut });
            SDVariable labels = sd.var("labels", new int[] { -1, nOut });
            INDArray inputArr = Nd4j.randn(minibatch, nOut).muli(100);
            INDArray labelsArr = Nd4j.randn(minibatch, nOut).muli(100);
            LossInfo lossInfo;
            INDArray expOut;
            switch(fn) {
                case "mse":
                    lossInfo = LossFunctions.mse("out", input, labels, null, reduction, 1);
                    expOut = inputArr.sub(labelsArr);
                    expOut.muli(expOut);
                    expOut = expOut.mean(Integer.MAX_VALUE);
                    break;
                case "l1":
                    lossInfo = LossFunctions.l1("out", input, labels, null, reduction, 1);
                    // L1 = sum abs error
                    expOut = Transforms.abs(inputArr.sub(labelsArr)).sum(1);
                    expOut = expOut.mean(Integer.MAX_VALUE);
                    break;
                case "l2":
                    lossInfo = LossFunctions.l2("out", input, labels, null, reduction, 1);
                    // L2 = sum squared error
                    expOut = Transforms.pow(inputArr.sub(labelsArr), 2.0).sum(1).mean(Integer.MAX_VALUE);
                    break;
                case "mcxent":
                    lossInfo = LossFunctions.mcxent("out", input, labels, null, reduction, 1);
                    // mcxent = sum label * log(prob)
                    expOut = labelsArr.mul(Transforms.log(inputArr)).sum(1).mean(Integer.MAX_VALUE);
                    break;
                default:
                    throw new RuntimeException("Unknown loss function: " + fn);
            }
            String msg = "test: " + lossInfo.getLossName() + ", reduction=" + reduction;
            log.info("*** Starting test: " + msg);
            sd.associateArrayWithVariable(inputArr, input);
            sd.associateArrayWithVariable(labelsArr, labels);
            // System.out.println(sd.asFlatPrint());
            INDArray out = sd.execAndEndResult();
            assertEquals(msg, expOut, out);
            System.out.println("STARTING GRADIENT CHECK");
            boolean ok = GradCheckUtil.checkGradients(sd);
            assertTrue(msg, ok);
        }
    }
}
Also used : LossInfo(org.nd4j.autodiff.loss.LossInfo) SDVariable(org.nd4j.autodiff.samediff.SDVariable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) SameDiff(org.nd4j.autodiff.samediff.SameDiff) LossFunctions(org.nd4j.autodiff.loss.LossFunctions) Test(org.junit.Test)
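
In formula form, with N = minibatch, M = nOut, predictions x, and labels y, the expected values computed above are (this simply restates the expOut expressions in the test; note that its mcxent carries no leading minus sign):

\[
\begin{aligned}
\text{mse:} \quad & \frac{1}{NM} \sum_{i,j} (x_{ij} - y_{ij})^2 \\
\text{l1:} \quad & \frac{1}{N} \sum_{i} \sum_{j} \lvert x_{ij} - y_{ij} \rvert \\
\text{l2:} \quad & \frac{1}{N} \sum_{i} \sum_{j} (x_{ij} - y_{ij})^2 \\
\text{mcxent:} \quad & \frac{1}{N} \sum_{i} \sum_{j} y_{ij} \log x_{ij}
\end{aligned}
\]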

Example 35 with SDVariable

Use of org.nd4j.autodiff.samediff.SDVariable in the nd4j project by deeplearning4j.

From the class GradCheckLoss, the method testLossWeights4d:

@Test
public void testLossWeights4d() {
    String[] weightTypes = new String[] { "none", "per-example", "per-depth", "per-height", "per-width", "per-height-width", "per-depth-height", "per-depth-width", "per-example-depth", "per-example-height", "per-example-height-width", "per-all" };
    Nd4j.getRandom().setSeed(12345);
    // Assume NCHW format here
    int minibatch = 10;
    int depth = 4;
    int h = 5;
    int w = 6;
    for (String weightType : weightTypes) {
        for (boolean binary : new boolean[] { true, false }) {
            // Binary mask (like DL4J) or arbitrary weights?
            int[] weightShape;
            switch(weightType) {
                case "none":
                    weightShape = null;
                    break;
                case "per-example":
                    weightShape = new int[] { minibatch, 1, 1, 1 };
                    break;
                case "per-depth":
                    weightShape = new int[] { 1, depth, 1, 1 };
                    break;
                case "per-height":
                    weightShape = new int[] { 1, 1, h, 1 };
                    break;
                case "per-width":
                    weightShape = new int[] { 1, 1, 1, w };
                    break;
                case "per-height-width":
                    weightShape = new int[] { 1, 1, h, w };
                    break;
                case "per-depth-height":
                    weightShape = new int[] { 1, depth, h, 1 };
                    break;
                case "per-depth-width":
                    weightShape = new int[] { 1, depth, 1, w };
                    break;
                case "per-example-depth":
                    weightShape = new int[] { minibatch, depth, 1, 1 };
                    break;
                case "per-example-height":
                    weightShape = new int[] { minibatch, 1, h, 1 };
                    break;
                case "per-example-height-width":
                    weightShape = new int[] { minibatch, 1, h, w };
                    break;
                case "per-all":
                    weightShape = new int[] { minibatch, depth, h, w };
                    break;
                default:
                    throw new RuntimeException("Unknown type: " + weightType);
            }
            INDArray weightArr = null;
            if (!"none".equals(weightType)) {
                if (binary) {
                    weightArr = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(weightShape), 0.5));
                } else {
                    weightArr = Nd4j.rand(weightShape).muli(2.0);
                }
            }
            for (LossFunctions.Reduction reduction : new LossFunctions.Reduction[] { LossFunctions.Reduction.MEAN_BY_COUNT, LossFunctions.Reduction.MEAN_BY_WEIGHT, LossFunctions.Reduction.SUM }) {
                for (String fn : new String[] { "mse", "l1", "l2", "mcxent" }) {
                    SameDiff sd = SameDiff.create();
                    SDVariable input = sd.var("in", new int[] { -1, depth, -1, -1 });
                    SDVariable labels = sd.var("labels", new int[] { -1, depth, -1, -1 });
                    SDVariable weight = null;
                    if (!"none".equals(weightType)) {
                        weight = sd.var("weights", weightArr);
                    }
                    INDArray inputArr = Nd4j.randn(new int[] { minibatch, depth, h, w }).muli(10);
                    INDArray labelsArr = Nd4j.randn(new int[] { minibatch, depth, h, w }).muli(10);
                    LossInfo lossInfo;
                    switch(fn) {
                        case "mse":
                            // MSE = mean squared error
                            lossInfo = LossFunctions.mse("out", input, labels, weight, reduction, 1, 2, 3);
                            break;
                        case "l1":
                            // L1 = sum abs error
                            lossInfo = LossFunctions.l1("out", input, labels, weight, reduction, 1, 2, 3);
                            break;
                        case "l2":
                            // L2 = sum squared error
                            lossInfo = LossFunctions.l2("out", input, labels, weight, reduction, 1, 2, 3);
                            break;
                        case "mcxent":
                            // mcxent = sum label * log(prob)
                            lossInfo = LossFunctions.mcxent("out", input, labels, weight, reduction, 1, 2, 3);
                            break;
                        default:
                            throw new RuntimeException("Unknown loss function: " + fn);
                    }
                    String msg = "lossFn=" + fn + ", reduction=" + reduction + ", weightType=" + weightType + ", binaryWeight=" + binary;
                    log.info("*** Starting test: " + msg);
                    sd.associateArrayWithVariable(inputArr, input);
                    sd.associateArrayWithVariable(labelsArr, labels);
                    if (weight != null) {
                        sd.associateArrayWithVariable(weightArr, weight);
                    }
                    INDArray out = sd.execAndEndResult();
                    assertEquals(1, out.length());
                    boolean ok = GradCheckUtil.checkGradients(sd);
                    assertTrue(msg, ok);
                }
            }
        }
    }
}
Also used : SameDiff(org.nd4j.autodiff.samediff.SameDiff) LossFunctions(org.nd4j.autodiff.loss.LossFunctions) LossInfo(org.nd4j.autodiff.loss.LossInfo) SDVariable(org.nd4j.autodiff.samediff.SDVariable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) BernoulliDistribution(org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution) Test(org.junit.Test)
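
As in the 2d case, every weight shape here broadcasts its size-1 axes against the [minibatch, depth, h, w] loss array. A sketch of the "per-depth" case using ND4J's BroadcastMulOp (org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp); this is an illustration, not necessarily how LossFunctions applies the weights internally:

INDArray err = Nd4j.rand(new int[] { 10, 4, 5, 6 });  // [minibatch, depth, h, w]
INDArray perDepth = Nd4j.rand(1, 4);                  // one weight per depth channel
INDArray weighted = Nd4j.createUninitialized(err.shape());
// broadcast-multiply along dimension 1 (depth): every (n, h, w) entry of
// channel d is scaled by perDepth.getDouble(d)
Nd4j.getExecutioner().exec(new BroadcastMulOp(err, perDepth, weighted, 1));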

Aggregations

SDVariable (org.nd4j.autodiff.samediff.SDVariable): 104 usages
SameDiff (org.nd4j.autodiff.samediff.SameDiff): 41 usages
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 38 usages
Test (org.junit.Test): 36 usages
ArrayList (java.util.ArrayList): 18 usages
DynamicCustomOp (org.nd4j.linalg.api.ops.DynamicCustomOp): 10 usages
lombok.val (lombok.val): 7 usages
LossFunctions (org.nd4j.autodiff.loss.LossFunctions): 4 usages
LossInfo (org.nd4j.autodiff.loss.LossInfo): 4 usages
BernoulliDistribution (org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution): 4 usages
Ignore (org.junit.Ignore): 3 usages
DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction): 3 usages
ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 3 usages
Triple (org.nd4j.linalg.primitives.Triple): 2 usages
DataOutputStream (java.io.DataOutputStream): 1 usage
FileOutputStream (java.io.FileOutputStream): 1 usage
ByteBuffer (java.nio.ByteBuffer): 1 usage
NoOpNameFoundException (org.nd4j.imports.NoOpNameFoundException): 1 usage
NdIndexIterator (org.nd4j.linalg.api.iter.NdIndexIterator): 1 usage
TruncateDivOp (org.nd4j.linalg.api.ops.impl.transforms.arithmetic.TruncateDivOp): 1 usage