Example 6 with SameDiff

Use of org.nd4j.autodiff.samediff.SameDiff in project nd4j by deeplearning4j.

Source: class GradCheckTransforms, method testConditions.

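import static org.junit.Assert.assertEquals;

import org.junit.Test;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

@Test
public void testConditions() {
    SameDiff sd = SameDiff.create();
    // Both inputs are ordinary finite values
    INDArray ia = Nd4j.create(new float[] { 4, 2 });
    SDVariable in = sd.var("in", new int[] { 1, 2 });
    sd.associateArrayWithVariable(ia, in);
    // Each condition op yields 1 where the condition holds and 0 elsewhere
    INDArray expFinite = Nd4j.create(new float[] { 1, 1 });
    SDVariable finite = sd.isFinite(in);
    INDArray expInfinite = Nd4j.create(new float[] { 0, 0 });
    SDVariable infinite = sd.isInfinite(in);
    INDArray expNaN = Nd4j.create(new float[] { 0, 0 });
    SDVariable isnan = sd.isNaN(in);
    sd.exec();
    // JUnit assertions always run; the plain Java assert statement is skipped unless the JVM uses -ea
    assertEquals(expFinite, finite.getArr());
    assertEquals(expInfinite, infinite.getArr());
    assertEquals(expNaN, isnan.getArr());
}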
Also used: SDVariable (org.nd4j.autodiff.samediff.SDVariable), INDArray (org.nd4j.linalg.api.ndarray.INDArray), SameDiff (org.nd4j.autodiff.samediff.SameDiff), Test (org.junit.Test)
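The test feeds only finite values, so the infinite and NaN masks come out as all zeros. As a minimal plain-Java sketch (not ND4J code), the three condition ops can be mirrored with Float.isFinite, Float.isInfinite and Float.isNaN, assuming ND4J uses the same definitions; the class ConditionMasks and its methods are hypothetical. Adding infinity and NaN to the input shows that isFinite and isInfinite are not complements: a NaN element gets 0 in both.

```java
import java.util.Arrays;

// Hypothetical plain-Java reference for the element-wise condition ops in the test above.
// Each method returns 1.0f where the condition holds and 0.0f elsewhere, matching the
// 1/0 float arrays the test uses as expected values.
class ConditionMasks {
    // Finite means neither infinite nor NaN
    static float[] isFinite(float[] x) {
        float[] out = new float[x.length];
        for (int i = 0; i < x.length; i++) out[i] = Float.isFinite(x[i]) ? 1f : 0f;
        return out;
    }

    // Positive or negative infinity
    static float[] isInfinite(float[] x) {
        float[] out = new float[x.length];
        for (int i = 0; i < x.length; i++) out[i] = Float.isInfinite(x[i]) ? 1f : 0f;
        return out;
    }

    // Not-a-number; NaN is neither finite nor infinite
    static float[] isNaN(float[] x) {
        float[] out = new float[x.length];
        for (int i = 0; i < x.length; i++) out[i] = Float.isNaN(x[i]) ? 1f : 0f;
        return out;
    }

    public static void main(String[] args) {
        // The test only uses {4, 2}; infinity and NaN are added here to exercise the other masks
        float[] in = {4f, 2f, Float.POSITIVE_INFINITY, Float.NaN};
        System.out.println(Arrays.toString(isFinite(in)));   // [1.0, 1.0, 0.0, 0.0]
        System.out.println(Arrays.toString(isInfinite(in))); // [0.0, 0.0, 1.0, 0.0]
        System.out.println(Arrays.toString(isNaN(in)));      // [0.0, 0.0, 0.0, 1.0]
    }
}
```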

Example 7 with SameDiff

Use of org.nd4j.autodiff.samediff.SameDiff in project nd4j by deeplearning4j.

Source: class GradCheckTransforms, method testSpaceToDepth.

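import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import org.junit.Test;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;

@Test
public void testSpaceToDepth() throws Exception {
    Nd4j.getRandom().setSeed(1337);
    int miniBatch = 128;
    int blockSize = 4;
    String dataFormat = "NHWC";
    int isNHWC = dataFormat.equals("NHWC") ? 1 : 0;
    // NHWC input [128, 8, 8, 1]; block size 4 turns it into [128, 2, 2, 16]
    int[] inputShape = new int[] { miniBatch, 2 * blockSize, 2 * blockSize, 1 };
    INDArray input = Nd4j.randn(inputShape);
    SameDiff sd = SameDiff.create();
    SDVariable sdInput = sd.var("in", inputShape);
    // Reference output from the raw "space_to_depth" custom op
    INDArray expOut = Nd4j.create(miniBatch, 2, 2, blockSize * blockSize);
    DynamicCustomOp op = DynamicCustomOp.builder("space_to_depth").addInputs(input).addIntegerArguments(blockSize, isNHWC).addOutputs(expOut).build();
    Nd4j.getExecutioner().exec(op);
    // Same transform through the SameDiff graph
    sd.associateArrayWithVariable(input, sdInput);
    SDVariable t = sd.spaceToDepth(sdInput, blockSize, dataFormat);
    // Scalar loss required by the gradient check
    SDVariable loss = sd.mean("loss", t);
    sd.exec();
    INDArray out = t.getArr();
    // Fail the test instead of only logging (the old message also said "depth to space" by mistake)
    assertEquals("space to depth failed on forward", expOut, out);
    // Check the result of the gradient check instead of ignoring it
    assertTrue("space to depth failed gradient check", GradCheckUtil.checkGradients(sd));
}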
Also used: SDVariable (org.nd4j.autodiff.samediff.SDVariable), INDArray (org.nd4j.linalg.api.ndarray.INDArray), SameDiff (org.nd4j.autodiff.samediff.SameDiff), DynamicCustomOp (org.nd4j.linalg.api.ops.DynamicCustomOp), Test (org.junit.Test)
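space_to_depth moves each b x b spatial block into the channel dimension, so [N, H, W, C] becomes [N, H/b, W/b, C*b*b]; in the test, [128, 8, 8, 1] with block size 4 becomes [128, 2, 2, 16]. The sketch below is a hypothetical plain-Java reference (class and method names are illustrative), assuming ND4J follows TensorFlow's NHWC channel ordering: row offset inside the block, then column offset, then original channel. Because the op only rearranges values, the gradient of the test's loss (the mean over all outputs) is 1 / (number of inputs) for every input. The sketch confirms this with a central-difference check, which shows the general idea behind a gradient check, not GradCheckUtil's exact procedure or tolerances.

```java
import java.util.Arrays;

// Hypothetical plain-Java reference (not ND4J code): space_to_depth on an NHWC tensor
// stored as a flat row-major array, plus a central-difference gradient check of
// loss = mean(spaceToDepth(x)), the same loss the test builds with sd.mean.
class SpaceToDepthRef {
    static double[] spaceToDepth(double[] in, int n, int h, int w, int c, int b) {
        int oh = h / b, ow = w / b, oc = c * b * b;
        double[] out = new double[n * oh * ow * oc];
        for (int ni = 0; ni < n; ni++)
            for (int hi = 0; hi < h; hi++)
                for (int wi = 0; wi < w; wi++)
                    for (int ci = 0; ci < c; ci++) {
                        // The pixel's offset inside its b x b block becomes part of the channel index
                        int depth = ((hi % b) * b + (wi % b)) * c + ci;
                        int outIdx = ((ni * oh + hi / b) * ow + wi / b) * oc + depth;
                        int inIdx = ((ni * h + hi) * w + wi) * c + ci;
                        out[outIdx] = in[inIdx];
                    }
        return out;
    }

    static double mean(double[] a) {
        double s = 0;
        for (double v : a) s += v;
        return s / a.length;
    }

    // Largest relative error between central-difference and analytic gradients of
    // mean(spaceToDepth(x)); every analytic partial derivative equals 1 / x.length.
    static double maxRelGradError(double[] x, int n, int h, int w, int c, int b, double eps) {
        double analytic = 1.0 / x.length, worst = 0;
        for (int i = 0; i < x.length; i++) {
            double saved = x[i];
            x[i] = saved + eps;
            double up = mean(spaceToDepth(x, n, h, w, c, b));
            x[i] = saved - eps;
            double down = mean(spaceToDepth(x, n, h, w, c, b));
            x[i] = saved; // restore the perturbed element
            double numeric = (up - down) / (2 * eps);
            worst = Math.max(worst, Math.abs(numeric - analytic) / analytic);
        }
        return worst;
    }

    public static void main(String[] args) {
        double[] x = new double[16];
        for (int i = 0; i < x.length; i++) x[i] = i; // one 4x4 single-channel image holding 0..15
        // Block size 2: shape [1, 4, 4, 1] -> [1, 2, 2, 4]
        System.out.println(Arrays.toString(spaceToDepth(x, 1, 4, 4, 1, 2)));
        // [0.0, 1.0, 4.0, 5.0, 2.0, 3.0, 6.0, 7.0, 8.0, 9.0, 12.0, 13.0, 10.0, 11.0, 14.0, 15.0]
        System.out.println("max relative gradient error: " + maxRelGradError(x, 1, 4, 4, 1, 2, 1e-4));
    }
}
```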

Example 8 with SameDiff

Use of org.nd4j.autodiff.samediff.SameDiff in project nd4j by deeplearning4j.

Source: class GradCheckTransforms, method testBatchToSpace.

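import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import org.junit.Test;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;

@Test
public void testBatchToSpace() throws Exception {
    Nd4j.getRandom().setSeed(1337);
    int miniBatch = 4;
    // Four 1x1 single-channel images are recombined into one 2x2 image
    int[] inputShape = new int[] { miniBatch, 1, 1, 1 };
    int M = 2;
    int[] blockShape = new int[] { M, 1 };
    int[] cropShape = new int[] { M, 2 };
    INDArray input = Nd4j.randn(inputShape);
    // Block size 2 along both spatial dimensions, no cropping
    INDArray blocks = Nd4j.create(new float[] { 2, 2 }, blockShape);
    INDArray crops = Nd4j.create(new float[] { 0, 0, 0, 0 }, cropShape);
    SameDiff sd = SameDiff.create();
    SDVariable sdInput = sd.var("in", inputShape);
    // Reference output from the raw "batch_to_space" custom op
    INDArray expOut = Nd4j.create(1, 2, 2, 1);
    DynamicCustomOp op = DynamicCustomOp.builder("batch_to_space").addInputs(input, blocks, crops).addOutputs(expOut).build();
    Nd4j.getExecutioner().exec(op);
    // Same transform through the SameDiff graph
    sd.associateArrayWithVariable(input, sdInput);
    SDVariable t = sd.batchToSpace(sdInput, new int[] { 2, 2 }, new int[][] { { 0, 0 }, { 0, 0 } });
    // Scalar loss required by the gradient check
    SDVariable loss = sd.mean("loss", t);
    sd.exec();
    INDArray out = t.getArr();
    // Fail the test instead of only logging or printing stack traces
    assertEquals("batch to space failed on forward", expOut, out);
    assertTrue("batch to space failed gradient check", GradCheckUtil.checkGradients(sd));
}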
Also used: SDVariable (org.nd4j.autodiff.samediff.SDVariable), INDArray (org.nd4j.linalg.api.ndarray.INDArray), SameDiff (org.nd4j.autodiff.samediff.SameDiff), DynamicCustomOp (org.nd4j.linalg.api.ops.DynamicCustomOp), Test (org.junit.Test)
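batch_to_space works in the opposite direction: groups of bh*bw batch entries are interleaved into one larger image, which is then cropped. In the test, four [1, 1, 1] inputs become a single [1, 2, 2, 1] output. The hypothetical plain-Java sketch below assumes TensorFlow's batch_to_space layout, where input batch entry (bi * bw + bj) * n + ni fills offset (bi, bj) of every block in output image ni, and it treats the crops as [top, bottom] and [left, right] pairs to match the 2x2 crops array the test passes. Both points are assumptions about ND4J's op, not taken from its source.

```java
import java.util.Arrays;

// Hypothetical plain-Java reference (not ND4J code): batch_to_space for NHWC data
// stored as a flat row-major array, with optional cropping of the assembled image.
class BatchToSpaceRef {
    static float[] batchToSpace(float[] in, int inBatch, int h, int w, int c,
                                int bh, int bw, int cropTop, int cropBottom,
                                int cropLeft, int cropRight) {
        int n = inBatch / (bh * bw);              // output batch size
        int oh = h * bh - cropTop - cropBottom;   // cropped output height
        int ow = w * bw - cropLeft - cropRight;   // cropped output width
        float[] out = new float[n * oh * ow * c];
        for (int bi = 0; bi < bh; bi++)
            for (int bj = 0; bj < bw; bj++)
                for (int ni = 0; ni < n; ni++)
                    for (int hi = 0; hi < h; hi++)
                        for (int wi = 0; wi < w; wi++) {
                            int y = hi * bh + bi - cropTop;
                            int x = wi * bw + bj - cropLeft;
                            if (y < 0 || y >= oh || x < 0 || x >= ow) continue; // cropped away
                            int src = (bi * bw + bj) * n + ni; // input batch entry for this offset
                            for (int ci = 0; ci < c; ci++)
                                out[((ni * oh + y) * ow + x) * c + ci] =
                                        in[((src * h + hi) * w + wi) * c + ci];
                        }
        return out;
    }

    public static void main(String[] args) {
        // Same shapes as the test: input [4, 1, 1, 1], block 2x2, no cropping -> [1, 2, 2, 1]
        float[] small = {0.1f, 0.2f, 0.3f, 0.4f};
        System.out.println(Arrays.toString(batchToSpace(small, 4, 1, 1, 1, 2, 2, 0, 0, 0, 0)));
        // [0.1, 0.2, 0.3, 0.4]

        // Input [4, 2, 2, 1]: the four batch entries interleave into one 4x4 image
        float[] big = {1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6, 8, 14, 16};
        System.out.println(Arrays.toString(batchToSpace(big, 4, 2, 2, 1, 2, 2, 0, 0, 0, 0)));
        // prints 1.0 through 16.0 in order
    }
}
```

Cropping one row from the top and one column from the left of that 4x4 result (cropTop = 1, cropLeft = 1) leaves the 3x3 image 6, 7, 8, 10, 11, 12, 14, 15, 16.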

Example 9 with SameDiff

Use of org.nd4j.autodiff.samediff.SameDiff in project nd4j by deeplearning4j.

Source: class GradCheckTransforms, method invertPermutation.

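import static org.junit.Assert.assertEquals;

import org.junit.Test;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

@Test
public void invertPermutation() {
    SameDiff sd = SameDiff.create();
    // Permutation p and its inverse q, where q[p[i]] = i
    INDArray ia = Nd4j.create(new float[] { 3, 4, 0, 2, 1 });
    INDArray expOut = Nd4j.create(new float[] { 2, 4, 3, 0, 1 });
    SDVariable input = sd.var("in", new int[] { 1, 5 });
    sd.associateArrayWithVariable(ia, input);
    SDVariable out = sd.invertPermutation(input);
    sd.exec();
    // JUnit assertion instead of the assert keyword, which is disabled without -ea
    assertEquals(expOut, out.getArr());
}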
Also used: SDVariable (org.nd4j.autodiff.samediff.SDVariable), INDArray (org.nd4j.linalg.api.ndarray.INDArray), SameDiff (org.nd4j.autodiff.samediff.SameDiff), Test (org.junit.Test)
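The expected output follows from the definition of an inverse permutation, inv[p[i]] = i: for p = [3, 4, 0, 2, 1], p[2] = 0 gives inv[0] = 2, p[4] = 1 gives inv[1] = 4, and so on, yielding [2, 4, 3, 0, 1]. A plain-Java sketch of that rule follows; the class name is hypothetical, and rejecting out-of-range or repeated entries is this sketch's own choice, not documented ND4J behavior.

```java
import java.util.Arrays;

// Hypothetical plain-Java reference (not ND4J code) for invertPermutation:
// if p sends position i to p[i], the inverse satisfies inv[p[i]] = i.
class InvertPermutationRef {
    static int[] invertPermutation(int[] p) {
        int[] inv = new int[p.length];
        boolean[] seen = new boolean[p.length];
        for (int i = 0; i < p.length; i++) {
            int target = p[i];
            // The input must be a permutation of 0..n-1: no out-of-range or repeated entries
            if (target < 0 || target >= p.length || seen[target]) {
                throw new IllegalArgumentException("not a permutation: " + Arrays.toString(p));
            }
            seen[target] = true;
            inv[target] = i;
        }
        return inv;
    }

    public static void main(String[] args) {
        int[] p = {3, 4, 0, 2, 1}; // input from the test
        System.out.println(Arrays.toString(invertPermutation(p))); // [2, 4, 3, 0, 1]
    }
}
```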

Example 10 with SameDiff

Use of org.nd4j.autodiff.samediff.SameDiff in project nd4j by deeplearning4j.

Source: class TensorFlowImportTest, method importGraph2.

@Test
@Ignore
public void importGraph2() throws Exception {
    SameDiff graph = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/tensorflow_inception_graph.pb").getInputStream());
    assertNotNull(graph);
}
Also used: SameDiff (org.nd4j.autodiff.samediff.SameDiff), ClassPathResource (org.nd4j.linalg.io.ClassPathResource), Ignore (org.junit.Ignore), Test (org.junit.Test)
