Usage of org.nd4j.linalg.api.ops.impl.scalar.ScalarSet in project nd4j by deeplearning4j.
From class GridExecutionerTest, method testGridFlowFlush1 (ScalarSet appears only in the commented-out test preceding it).
/*
@Test
public void testGridFlow9() throws Exception {
CudaGridExecutioner executioner = new CudaGridExecutioner();
INDArray arrayX = Nd4j.create(new float[] {0f, 0f, 0f});
INDArray arrayY1 = Nd4j.create(new float[] {-1f, -1f, 1f});
INDArray arrayY2 = Nd4j.create(new float[] {1f, 1f, 1f});
INDArray exp = Nd4j.create(new float[] {1f, 1f, 1f});
Set opA = new Set(arrayX, arrayY1, arrayX, arrayY1.length());
executioner.exec(opA);
assertEquals(1, executioner.getQueueLength());
ScalarSet opB = new ScalarSet(arrayX, 1f);
executioner.exec(opB);
assertEquals(0, executioner.getQueueLength());
assertEquals(1f, arrayX.getFloat(0), 0.1f);
assertEquals(1f, arrayX.getFloat(1), 0.1f);
//assertEquals(exp, arrayX);
}
*/
/**
 * Queues a single pairwise {@code Set} op and then forces execution with
 * {@code flushQueue()}.
 *
 * <p>After the flush, the queue is expected to be empty and {@code arrayX} is
 * expected to hold the values of {@code arrayY}.
 */
@Test
public void testGridFlowFlush1() throws Exception {
    CudaGridExecutioner executioner = new CudaGridExecutioner();

    INDArray arrayX = Nd4j.create(10);
    INDArray arrayY = Nd4j.create(new float[] {1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f});

    // Source is arrayY; arrayX is both the first operand and the result (z) array.
    Set opA = new Set(arrayX, arrayY, arrayX, arrayX.length());
    executioner.exec(opA);

    executioner.flushQueue();

    // Flushing must leave nothing pending in the grid queue.
    assertEquals(0, executioner.getQueueLength());
    assertEquals(arrayY, arrayX);
}
Usage of org.nd4j.linalg.api.ops.impl.scalar.ScalarSet in project nd4j by deeplearning4j.
From class GridExecutionerTest, method testGridFlow8.
/**
 * Queues a pairwise {@code Set} op, then executes a {@code ScalarSet} on the
 * same array.
 *
 * <p>Checks that the first op is held in the queue and that the queue is empty
 * after the second op. It also checks that every element of {@code arrayX}
 * matches {@code exp}.
 */
@Test
public void testGridFlow8() throws Exception {
    CudaGridExecutioner executioner = new CudaGridExecutioner();
    INDArray arrayX = Nd4j.create(new float[] {0f, 0f, 0f});
    INDArray arrayY1 = Nd4j.create(new float[] {-1f, -1f, 1f});
    INDArray exp = Nd4j.create(new float[] {1f, 1f, 1f});

    Set opA = new Set(arrayX, arrayY1, arrayX, arrayY1.length());
    executioner.exec(opA);
    // The first op is expected to be queued rather than executed immediately.
    assertEquals(1, executioner.getQueueLength());

    ScalarSet opB = new ScalarSet(arrayX, 1f);
    executioner.exec(opB);
    // Executing the scalar op is expected to drain the queue.
    assertEquals(0, executioner.getQueueLength());

    // Element-wise comparison with tolerance covers every element, including index 2,
    // in place of the previously disabled whole-array assertEquals(exp, arrayX).
    for (int i = 0; i < exp.length(); i++) {
        assertEquals(exp.getFloat(i), arrayX.getFloat(i), 0.1f);
    }
}