
Example 1 with ScalarSet

Use of org.nd4j.linalg.api.ops.impl.scalar.ScalarSet in project nd4j by deeplearning4j.

From class GridExecutionerTest, method testGridFlowFlush1. ScalarSet appears only in the commented-out testGridFlow9 that precedes the method.

/*
    @Test
    public void testGridFlow9() throws Exception {
        CudaGridExecutioner executioner = new CudaGridExecutioner();

        INDArray arrayX = Nd4j.create(new float[] {0f, 0f, 0f});
        INDArray arrayY1 = Nd4j.create(new float[] {-1f, -1f, 1f});
        INDArray arrayY2 = Nd4j.create(new float[] {1f, 1f, 1f});
        INDArray exp = Nd4j.create(new float[] {1f, 1f, 1f});

        Set opA = new Set(arrayX, arrayY1, arrayX, arrayY1.length());

        executioner.exec(opA);

        assertEquals(1, executioner.getQueueLength());

        ScalarSet opB = new ScalarSet(arrayX, 1f);
        executioner.exec(opB);

        assertEquals(0, executioner.getQueueLength());
        assertEquals(1f, arrayX.getFloat(0), 0.1f);
        assertEquals(1f, arrayX.getFloat(1), 0.1f);
        //assertEquals(exp, arrayX);
    }
*/
@Test
public void testGridFlowFlush1() throws Exception {
    CudaGridExecutioner executioner = new CudaGridExecutioner();
    // arrayX starts as a zero vector of length 10; arrayY holds ten ones
    INDArray arrayX = Nd4j.create(10);
    INDArray arrayY = Nd4j.create(new float[] { 1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f });
    // Set writes the values of arrayY into arrayX (x and z are both arrayX)
    Set opA = new Set(arrayX, arrayY, arrayX, arrayX.length());
    // The grid executioner can hold opA in its queue instead of running it at once...
    executioner.exec(opA);
    // ...so flush the queue to force execution before checking the result
    executioner.flushQueue();
    assertEquals(arrayY, arrayX);
}
Also used: ScalarSet(org.nd4j.linalg.api.ops.impl.scalar.ScalarSet) Set(org.nd4j.linalg.api.ops.impl.transforms.Set) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Test(org.junit.Test)
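The queue-then-flush behaviour that testGridFlowFlush1 relies on can be sketched without a GPU. The plain-Java toy below is a hypothetical model, not nd4j code: `ToyQueueExecutioner`, `execSet` and the `float[]` arrays are illustrative stand-ins for `CudaGridExecutioner`, `Set` and `INDArray`. It only shows an op being queued, having no effect yet, and taking effect once `flushQueue()` runs.

```java
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Deque;

// Hypothetical toy model, NOT the nd4j API. It mimics only the deferred
// "queue, then flush" behaviour asserted by testGridFlowFlush1.
class ToyQueueExecutioner {
    // Pending operations; each one mutates arrays only when it finally runs.
    private final Deque<Runnable> queue = new ArrayDeque<>();

    // Stand-in for a Set op (z = y, elementwise), deferred until flushQueue().
    void execSet(float[] y, float[] z) {
        queue.addLast(() -> System.arraycopy(y, 0, z, 0, z.length));
    }

    int getQueueLength() {
        return queue.size();
    }

    // Run every pending op in submission order, leaving the queue empty.
    void flushQueue() {
        while (!queue.isEmpty()) {
            queue.pollFirst().run();
        }
    }
}

class ToyFlushDemo {
    public static void main(String[] args) {
        ToyQueueExecutioner ex = new ToyQueueExecutioner();
        float[] x = new float[10];   // zeros, like Nd4j.create(10)
        float[] y = new float[10];
        Arrays.fill(y, 1f);          // ones, like arrayY

        ex.execSet(y, x);
        System.out.println(ex.getQueueLength() + " " + x[0]);               // 1 0.0
        ex.flushQueue();
        System.out.println(ex.getQueueLength() + " " + Arrays.equals(x, y)); // 0 true
    }
}
```

The real executioner decides internally when to run queued work; the toy makes that decision explicit so the effect of the flush is visible.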

Example 2 with ScalarSet

Use of org.nd4j.linalg.api.ops.impl.scalar.ScalarSet in project nd4j by deeplearning4j.

From class GridExecutionerTest, method testGridFlow8.

@Test
public void testGridFlow8() throws Exception {
    CudaGridExecutioner executioner = new CudaGridExecutioner();
    INDArray arrayX = Nd4j.create(new float[] { 0f, 0f, 0f });
    INDArray arrayY1 = Nd4j.create(new float[] { -1f, -1f, 1f });
    INDArray exp = Nd4j.create(new float[] { 1f, 1f, 1f });
    // Set (arrayX = arrayY1) is deferred: it waits in the executioner's queue
    Set opA = new Set(arrayX, arrayY1, arrayX, arrayY1.length());
    executioner.exec(opA);
    assertEquals(1, executioner.getQueueLength());
    // ScalarSet overwrites every element of arrayX with 1
    ScalarSet opB = new ScalarSet(arrayX, 1f);
    executioner.exec(opB);
    // After the second op nothing is pending, and arrayX holds ones
    assertEquals(0, executioner.getQueueLength());
    assertEquals(1f, arrayX.getFloat(0), 0.1f);
    assertEquals(1f, arrayX.getFloat(1), 0.1f);
    // assertEquals(exp, arrayX);
}
Also used: ScalarSet(org.nd4j.linalg.api.ops.impl.scalar.ScalarSet) Set(org.nd4j.linalg.api.ops.impl.transforms.Set) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Test(org.junit.Test)
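testGridFlow8 asserts that one op is pending after the Set and none after the ScalarSet, with arrayX ending up as ones. The plain-Java toy below is a hypothetical model of that observable behaviour only, not nd4j's implementation: `ToyOneSlotExecutioner`, `set` and `scalarSet` are illustrative names. It assumes the executioner keeps at most one op pending and, when a second op arrives, runs the pending op and then the new one back to back; the real CudaGridExecutioner may instead fuse or flush the two ops in its own way.

```java
import java.util.Arrays;

// Hypothetical toy model, NOT nd4j's CudaGridExecutioner. Assumption: at most
// one op waits in the queue; a second op forces the waiting op to run first,
// then runs itself immediately, so the queue is empty afterwards.
class ToyOneSlotExecutioner {
    private Runnable pending; // null when the queue is empty

    int getQueueLength() {
        return pending == null ? 0 : 1;
    }

    void exec(Runnable op) {
        if (pending == null) {
            pending = op;    // first op: defer it
        } else {
            pending.run();   // the earlier op must take effect first
            pending = null;
            op.run();        // then the new op runs right away
        }
    }

    // Stand-in for Set: elementwise copy z = y.
    static Runnable set(float[] y, float[] z) {
        return () -> System.arraycopy(y, 0, z, 0, z.length);
    }

    // Stand-in for ScalarSet: every element of x becomes value.
    static Runnable scalarSet(float[] x, float value) {
        return () -> Arrays.fill(x, value);
    }
}

class ToyScalarSetDemo {
    public static void main(String[] args) {
        ToyOneSlotExecutioner ex = new ToyOneSlotExecutioner();
        float[] x = {0f, 0f, 0f};
        float[] y1 = {-1f, -1f, 1f};

        ex.exec(ToyOneSlotExecutioner.set(y1, x));
        System.out.println(ex.getQueueLength());                             // 1
        ex.exec(ToyOneSlotExecutioner.scalarSet(x, 1f));
        System.out.println(ex.getQueueLength() + " " + Arrays.toString(x)); // 0 [1.0, 1.0, 1.0]
    }
}
```

Running the ops in submission order matters: the Set writes [-1, -1, 1] first and the ScalarSet then overwrites it, which is why every element ends up as 1.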

Aggregations

Test (org.junit.Test): 2
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 2
ScalarSet (org.nd4j.linalg.api.ops.impl.scalar.ScalarSet): 2
Set (org.nd4j.linalg.api.ops.impl.transforms.Set): 2