
Example 6 with CustomOp

Use of org.nd4j.linalg.api.ops.CustomOp in project nd4j by deeplearning4j.

Source: class CustomOpsTests, method testNonInplaceOp2.

/**
 * Runs the op in place (arrayX is both an input and the output)
 * without declaring callInplace(true).
 */
@Test
public void testNonInplaceOp2() throws Exception {
    val arrayX = Nd4j.create(10, 10);
    val arrayY = Nd4j.create(10, 10);
    arrayX.assign(3.0);
    arrayY.assign(1.0);
    val exp = Nd4j.create(10, 10).assign(4.0);
    CustomOp op = DynamicCustomOp.builder("add").addInputs(arrayX, arrayY).addOutputs(arrayX).build();
    Nd4j.getExecutioner().exec(op);
    assertEquals(exp, arrayX);
}
Also used: lombok.val, org.nd4j.linalg.api.ops.CustomOp, org.nd4j.linalg.api.ops.DynamicCustomOp, org.junit.Test
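Passing arrayX as the output works here because an element-wise add reads each index before writing it. The plain-Java sketch below illustrates that aliasing argument without ND4J; the class `AliasedAddSketch` and its `add` method are hypothetical names, not part of ND4J, and the sketch does not claim to reproduce ND4J's native kernel.

```java
import java.util.Arrays;

/** Hypothetical plain-Java illustration (not ND4J code) of an add whose output may alias an input. */
public class AliasedAddSketch {

    // Writes x[i] + y[i] into z[i]. Index i is read from x and y before z[i]
    // is written, so passing x itself as z safely performs an in-place update.
    static void add(double[] x, double[] y, double[] z) {
        if (x.length != y.length || x.length != z.length) {
            throw new IllegalArgumentException("length mismatch");
        }
        for (int i = 0; i < x.length; i++) {
            z[i] = x[i] + y[i];
        }
    }

    public static void main(String[] args) {
        double[] x = {3.0, 3.0, 3.0};
        double[] y = {1.0, 1.0, 1.0};
        add(x, y, x); // output aliases input x, as arrayX does in the test
        System.out.println(Arrays.toString(x)); // [4.0, 4.0, 4.0]
    }
}
```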

Example 7 with CustomOp

Use of org.nd4j.linalg.api.ops.CustomOp in project nd4j by deeplearning4j.

Source: class CustomOpsTests, method testMergeMax1.

@Test
public void testMergeMax1() throws Exception {
    val array0 = Nd4j.create(new double[] { 1, 0, 0, 0, 0 });
    val array1 = Nd4j.create(new double[] { 0, 2, 0, 0, 0 });
    val array2 = Nd4j.create(new double[] { 0, 0, 3, 0, 0 });
    val array3 = Nd4j.create(new double[] { 0, 0, 0, 4, 0 });
    val array4 = Nd4j.create(new double[] { 0, 0, 0, 0, 5 });
    val z = Nd4j.create(5);
    val exp = Nd4j.create(new double[] { 1, 2, 3, 4, 5 });
    CustomOp op = DynamicCustomOp.builder("mergemax").addInputs(array0, array1, array2, array3, array4).addOutputs(z).callInplace(false).build();
    Nd4j.getExecutioner().exec(op);
    assertEquals(exp, z);
}
Also used: lombok.val, org.nd4j.linalg.api.ops.CustomOp, org.nd4j.linalg.api.ops.DynamicCustomOp, org.junit.Test
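The expected value in testMergeMax1 follows from an element-wise maximum taken across all inputs. A minimal plain-Java sketch of that reduction, assuming equal-length 1-D inputs, is shown below; `MergeMaxSketch.mergeMax` is a hypothetical helper, not ND4J's `mergemax` implementation.

```java
import java.util.Arrays;

/** Hypothetical plain-Java illustration (not ND4J code) of an element-wise merge-max. */
public class MergeMaxSketch {

    // Returns a new array whose i-th element is max over k of inputs[k][i].
    static double[] mergeMax(double[]... inputs) {
        if (inputs.length == 0) {
            throw new IllegalArgumentException("at least one input required");
        }
        int n = inputs[0].length;
        double[] out = Arrays.copyOf(inputs[0], n); // start from the first input
        for (int k = 1; k < inputs.length; k++) {
            if (inputs[k].length != n) {
                throw new IllegalArgumentException("length mismatch");
            }
            for (int i = 0; i < n; i++) {
                out[i] = Math.max(out[i], inputs[k][i]);
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[] z = mergeMax(
                new double[] {1, 0, 0, 0, 0},
                new double[] {0, 2, 0, 0, 0},
                new double[] {0, 0, 3, 0, 0},
                new double[] {0, 0, 0, 4, 0},
                new double[] {0, 0, 0, 0, 5});
        System.out.println(Arrays.toString(z)); // [1.0, 2.0, 3.0, 4.0, 5.0]
    }
}
```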

Example 8 with CustomOp

Use of org.nd4j.linalg.api.ops.CustomOp in project nd4j by deeplearning4j.

Source: class CustomOpsTests, method testMergeMaxMixedOrder.

@Test
public void testMergeMaxMixedOrder() {
    // random 5x2 array in 'f' order; rand gives values in [0, 1), so adding 1 keeps them strictly positive
    val array0 = Nd4j.rand('f', 5, 2).addi(1);
    val array1 = array0.dup().addi(5);
    // array1 exceeds array0 everywhere except at (0, 0), which is set to 0
    array1.put(0, 0, 0);
    // expected mergemax result: array1 everywhere, array0 at (0, 0)
    val exp = array1.dup();
    exp.putScalar(0, 0, array0.getDouble(0, 0));
    val zF = Nd4j.zeros(array0.shape(), 'f');
    CustomOp op = DynamicCustomOp.builder("mergemax").addInputs(array0, array1).addOutputs(zF).build();
    Nd4j.getExecutioner().exec(op);
    assertEquals(exp, zF);
}
Also used: lombok.val, org.nd4j.linalg.api.ops.CustomOp, org.nd4j.linalg.api.ops.DynamicCustomOp, org.junit.Test
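ND4J's 'c' and 'f' order flags denote row-major and column-major storage. When inputs use different orders, a merge-max has to compare logical elements (i, j), not raw buffer positions. The plain-Java sketch below illustrates this under the assumption of dense 2-D buffers with no strides or offsets; `MixedOrderSketch` and its methods are hypothetical and make no claim about how ND4J's native op handles ordering internally.

```java
import java.util.Arrays;

/** Hypothetical plain-Java illustration (not ND4J code) of merge-max over mixed storage orders. */
public class MixedOrderSketch {

    // Flat buffer offset of logical element (i, j) in a rows x cols matrix:
    // 'c' = row-major, 'f' = column-major.
    static int offset(char order, int rows, int cols, int i, int j) {
        return order == 'c' ? i * cols + j : j * rows + i;
    }

    // Element-wise max of two dense matrices that may use different orders.
    // Values are looked up by logical (i, j); the result is stored in 'f' order.
    static double[] mergeMax(double[] a, char orderA, double[] b, char orderB, int rows, int cols) {
        double[] out = new double[rows * cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                double va = a[offset(orderA, rows, cols, i, j)];
                double vb = b[offset(orderB, rows, cols, i, j)];
                out[offset('f', rows, cols, i, j)] = Math.max(va, vb);
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[] a = {1, 3, 2, 4}; // logical [[1, 2], [3, 4]] stored column-major ('f')
        double[] b = {0, 6, 7, 8}; // logical [[0, 6], [7, 8]] stored row-major ('c')
        double[] z = mergeMax(a, 'f', b, 'c', 2, 2);
        // logical result [[1, 6], [7, 8]] in 'f' order
        System.out.println(Arrays.toString(z)); // [1.0, 7.0, 6.0, 8.0]
    }
}
```

A naive buffer-wise max of `a` and `b` would give `[1.0, 6.0, 7.0, 8.0]`, which read in 'f' order is the wrong matrix. That is the kind of mismatch a mixed-order test is designed to catch.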

Example 9 with CustomOp

Use of org.nd4j.linalg.api.ops.CustomOp in project nd4j by deeplearning4j.

Source: class CustomOpsTests, method testInplaceOp2.

@Test
public void testInplaceOp2() throws Exception {
    val arrayX = Nd4j.create(10, 10);
    val arrayY = Nd4j.create(10, 10);
    val arrayZ = Nd4j.create(10, 10);
    arrayX.assign(3.0);
    arrayY.assign(1.0);
    val exp = Nd4j.create(10, 10).assign(4.0);
    val expZ = Nd4j.create(10, 10);
    // callInplace(true): the result is expected in arrayX, and arrayZ must stay unchanged
    CustomOp op = DynamicCustomOp.builder("add").addInputs(arrayX, arrayY).addOutputs(arrayZ).callInplace(true).build();
    Nd4j.getExecutioner().exec(op);
    assertEquals(exp, arrayX);
    assertEquals(expZ, arrayZ);
}
Also used: lombok.val, org.nd4j.linalg.api.ops.CustomOp, org.nd4j.linalg.api.ops.DynamicCustomOp, org.junit.Test
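The assertions in testInplaceOp2 imply a simple dispatch rule: with the in-place flag set, the sum lands in the first input and the supplied output array is not written. The sketch below models only that observable behaviour in plain Java. `InplaceDispatchSketch.execAdd` is hypothetical and is not how ND4J's executioner is implemented.

```java
import java.util.Arrays;

/** Hypothetical model (not ND4J code) of the behaviour asserted in testInplaceOp2. */
public class InplaceDispatchSketch {

    // inplace == true:  x is overwritten with x + y; z is never touched.
    // inplace == false: the sum is written into z (assumed non-null, same length as x).
    static void execAdd(double[] x, double[] y, double[] z, boolean inplace) {
        double[] target = inplace ? x : z;
        for (int i = 0; i < x.length; i++) {
            target[i] = x[i] + y[i];
        }
    }

    public static void main(String[] args) {
        double[] x = {3.0, 3.0};
        double[] y = {1.0, 1.0};
        double[] z = new double[2];
        execAdd(x, y, z, true);
        System.out.println(Arrays.toString(x)); // [4.0, 4.0]
        System.out.println(Arrays.toString(z)); // [0.0, 0.0]
    }
}
```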

Example 10 with CustomOp

Use of org.nd4j.linalg.api.ops.CustomOp in project nd4j by deeplearning4j.

Source: class CustomOpsTests, method testNoneInplaceOp3.

@Test(expected = ND4JIllegalStateException.class)
public void testNoneInplaceOp3() throws Exception {
    val arrayX = Nd4j.create(10, 10);
    val arrayY = Nd4j.create(10, 10);
    arrayX.assign(4.0);
    arrayY.assign(2.0);
    val exp = Nd4j.create(10, 10).assign(6.0);
    CustomOp op = DynamicCustomOp.builder("add").addInputs(arrayX, arrayY).callInplace(false).build();
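    // no outputs and callInplace(false): exec is expected to throw ND4JIllegalStateException,
    // so the assertion below is never reached
    Nd4j.getExecutioner().exec(op);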
    assertEquals(exp, arrayX);
}
Also used: lombok.val, org.nd4j.linalg.api.ops.CustomOp, org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.linalg.exception.ND4JIllegalStateException, org.junit.Test
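testNoneInplaceOp3 expects execution to fail when an op is neither in place nor given an output array, because the result would have nowhere to go. The plain-Java sketch below shows such a precondition check. It uses the JDK's `IllegalStateException` in place of ND4J's `ND4JIllegalStateException`, and `OutputValidationSketch.validate` is a hypothetical helper, not ND4J's validation code.

```java
/** Hypothetical pre-execution check (not ND4J code) mirroring testNoneInplaceOp3. */
public class OutputValidationSketch {

    // A non-in-place op must be given at least one output array to write into.
    static void validate(boolean inplace, int numOutputs) {
        if (!inplace && numOutputs == 0) {
            throw new IllegalStateException("non-inplace op executed without output arrays");
        }
    }

    public static void main(String[] args) {
        validate(true, 0);  // in place: the first input receives the result
        validate(false, 1); // explicit output supplied
        try {
            validate(false, 0);
            System.out.println("no exception");
        } catch (IllegalStateException e) {
            System.out.println("caught: " + e.getMessage());
        }
    }
}
```

Running `main` prints `caught: non-inplace op executed without output arrays`. In JUnit 4, `@Test(expected = ...)` plays the role of the try/catch here: the test passes only if the exception propagates out of the method.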

Aggregations

CustomOp (org.nd4j.linalg.api.ops.CustomOp): 13
DynamicCustomOp (org.nd4j.linalg.api.ops.DynamicCustomOp): 13
lombok.val (lombok.val): 11
Test (org.junit.Test): 11
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 2
Ignore (org.junit.Ignore): 1
ActivationSoftmax (org.nd4j.linalg.activations.impl.ActivationSoftmax): 1
LogSoftMax (org.nd4j.linalg.api.ops.impl.transforms.LogSoftMax): 1
TimesOneMinus (org.nd4j.linalg.api.ops.impl.transforms.TimesOneMinus): 1