use of org.nd4j.linalg.api.ops.CustomOp in project nd4j by deeplearning4j.
the class CustomOpsTests method testNonInplaceOp2.
/**
 * Verifies that "add" can write its result into one of its own inputs when that input is
 * also registered as the op's output, even though callInplace(...) is never set on the builder.
 * With X filled with 3.0 and Y filled with 1.0, X must afterwards hold 4.0 everywhere.
 */
@Test
public void testNonInplaceOp2() throws Exception {
    val expected = Nd4j.create(10, 10).assign(4.0);

    val arrayX = Nd4j.create(10, 10);
    arrayX.assign(3.0);
    val arrayY = Nd4j.create(10, 10);
    arrayY.assign(1.0);

    // arrayX is passed both as the first input and as the output buffer
    CustomOp addOp = DynamicCustomOp.builder("add")
            .addInputs(arrayX, arrayY)
            .addOutputs(arrayX)
            .build();
    Nd4j.getExecutioner().exec(addOp);

    assertEquals(expected, arrayX);
}
use of org.nd4j.linalg.api.ops.CustomOp in project nd4j by deeplearning4j.
the class CustomOpsTests method testMergeMax1.
/**
 * Runs "mergemax" over five vectors, each holding a single non-zero value at a different
 * position, into a separate output array (callInplace(false)). The expected result is
 * [1, 2, 3, 4, 5], i.e. the per-position maximum across all inputs.
 */
@Test
public void testMergeMax1() throws Exception {
    val expected = Nd4j.create(new double[] {1, 2, 3, 4, 5});

    val in0 = Nd4j.create(new double[] {1, 0, 0, 0, 0});
    val in1 = Nd4j.create(new double[] {0, 2, 0, 0, 0});
    val in2 = Nd4j.create(new double[] {0, 0, 3, 0, 0});
    val in3 = Nd4j.create(new double[] {0, 0, 0, 4, 0});
    val in4 = Nd4j.create(new double[] {0, 0, 0, 0, 5});
    val merged = Nd4j.create(5);

    CustomOp mergeMax = DynamicCustomOp.builder("mergemax")
            .addInputs(in0, in1, in2, in3, in4)
            .addOutputs(merged)
            .callInplace(false)
            .build();
    Nd4j.getExecutioner().exec(mergeMax);

    assertEquals(expected, merged);
}
use of org.nd4j.linalg.api.ops.CustomOp in project nd4j by deeplearning4j.
the class CustomOpsTests method testMergeMaxMixedOrder.
/**
 * Runs "mergemax" over two inputs and writes into an explicitly 'f'-ordered output.
 * The first input is created with 'f' ordering; the second comes from dup(), which
 * NOTE(review): presumably yields the default ordering, making the inputs "mixed order" — confirm.
 * The second input exceeds the first everywhere except at (0, 0), where it is forced to 0,
 * so the expected result is the second input with (0, 0) taken from the first.
 */
@Test
public void testMergeMaxMixedOrder() {
    // random values shifted by +1 so every element is strictly positive
    val lhs = Nd4j.rand('f', 5, 2).addi(1);

    // rhs = lhs + 5 dominates lhs everywhere, except at (0, 0) which is zeroed here
    val rhs = lhs.dup().addi(5);
    rhs.put(0, 0, 0);

    // expected maximum: rhs everywhere, except (0, 0) where lhs is larger
    val expected = rhs.dup();
    expected.putScalar(0, 0, lhs.getDouble(0, 0));

    val out = Nd4j.zeros(lhs.shape(), 'f');
    CustomOp mergeMax = DynamicCustomOp.builder("mergemax")
            .addInputs(lhs, rhs)
            .addOutputs(out)
            .build();
    Nd4j.getExecutioner().exec(mergeMax);

    assertEquals(expected, out);
}
use of org.nd4j.linalg.api.ops.CustomOp in project nd4j by deeplearning4j.
the class CustomOpsTests method testInplaceOp2.
/**
 * With callInplace(true), "add" is expected to overwrite its first input
 * (X = 3.0 + 1.0 = 4.0 everywhere) and to leave the separately registered output
 * array unchanged, i.e. equal to a freshly created 10x10 array.
 */
@Test
public void testInplaceOp2() throws Exception {
    val expectedX = Nd4j.create(10, 10).assign(4.0);
    val expectedZ = Nd4j.create(10, 10);

    val arrayX = Nd4j.create(10, 10);
    val arrayY = Nd4j.create(10, 10);
    val arrayZ = Nd4j.create(10, 10);
    arrayX.assign(3.0);
    arrayY.assign(1.0);

    CustomOp inplaceAdd = DynamicCustomOp.builder("add")
            .addInputs(arrayX, arrayY)
            .addOutputs(arrayZ)
            .callInplace(true)
            .build();
    Nd4j.getExecutioner().exec(inplaceAdd);

    // the result must land in X, while the registered output Z stays untouched
    assertEquals(expectedX, arrayX);
    assertEquals(expectedZ, arrayZ);
}
use of org.nd4j.linalg.api.ops.CustomOp in project nd4j by deeplearning4j.
the class CustomOpsTests method testNoneInplaceOp3.
/**
 * Builds a non-inplace "add" (callInplace(false)) without registering any output array,
 * and expects execution to be rejected with an {@link ND4JIllegalStateException}.
 * NOTE(review): presumably rejected because there is nowhere to write the result — confirm
 * against the executioner's validation.
 */
@Test(expected = ND4JIllegalStateException.class)
public void testNoneInplaceOp3() throws Exception {
    val arrayX = Nd4j.create(10, 10);
    val arrayY = Nd4j.create(10, 10);
    arrayX.assign(4.0);
    arrayY.assign(2.0);

    CustomOp op = DynamicCustomOp.builder("add").addInputs(arrayX, arrayY).callInplace(false).build();

    // This call is the one expected to throw. No assertion follows it on purpose: the previous
    // trailing assertEquals(6.0-filled, arrayX) was dead code. It never ran when the exception
    // was thrown, and when no exception was thrown the test failed regardless of its result.
    // It also wrongly suggested that the op should succeed.
    Nd4j.getExecutioner().exec(op);
}
Aggregations