Usage of `org.nd4j.jita.allocator.tad.DeviceTADManager` in the nd4j project by deeplearning4j.
Example taken from the method `testDelayedTAD1` of the class `DelayedMemoryTest`.
/**
 * Checks the allocation status of the two buffers returned by
 * {@code TADManager.getTADOnlyShapeInfo} for a freshly created array.
 *
 * <p>The pair is requested along dimension 0 of an array created with
 * {@code Nd4j.create(128, 256)}. The first buffer of the pair (presumably the TAD
 * shape info; confirm against {@code DeviceTADManager}) is expected to be in
 * {@code AllocationStatus.CONSTANT}. The second buffer (presumably the TAD offsets;
 * confirm) is expected to be in {@code AllocationStatus.DEVICE}.
 */
@Test
public void testDelayedTAD1() throws Exception {
    // Construct the manager before the array, matching the original setup order.
    TADManager manager = new DeviceTADManager();
    INDArray matrix = Nd4j.create(128, 256);
    int[] alongDimension = {0};

    Pair<DataBuffer, DataBuffer> shapeAndOffsets = manager.getTADOnlyShapeInfo(matrix, alongDimension);

    // Query the allocator for each buffer's allocation point: first buffer, then second.
    AllocationPoint shapeInfoPoint = AtomicAllocator.getInstance().getAllocationPoint(shapeAndOffsets.getFirst());
    AllocationPoint offsetsPoint = AtomicAllocator.getInstance().getAllocationPoint(shapeAndOffsets.getSecond());

    assertEquals(AllocationStatus.CONSTANT, shapeInfoPoint.getAllocationStatus());
    assertEquals(AllocationStatus.DEVICE, offsetsPoint.getAllocationStatus());
}
Aggregations