
Example 66 with DataBuffer

Use of org.nd4j.linalg.api.buffer.DataBuffer in project nd4j by deeplearning4j.

From class CudaHalfDataBufferTest, method testSerialization2.

@Test
public void testSerialization2() throws Exception {
    DataBuffer bufferOriginal = new CudaFloatDataBuffer(new float[] { 1f, 2f, 3f, 4f, 5f });
    DataBuffer bufferHalfs = Nd4j.getNDArrayFactory().convertDataEx(DataBuffer.TypeEx.FLOAT, bufferOriginal, DataBuffer.TypeEx.FLOAT16);
    DataTypeUtil.setDTypeForContext(DataBuffer.Type.HALF);
    try {
        File tempFile = File.createTempFile("alpha", "11");
        tempFile.deleteOnExit();
        // serialize the half-precision buffer; it is read back as HALF and converted to FLOAT afterwards
        try (DataOutputStream dos = new DataOutputStream(Files.newOutputStream(tempFile.toPath()))) {
            bufferHalfs.write(dos);
        }
        // load the data back from the file into a new buffer, closing the stream when done
        DataBuffer bufferRestored = Nd4j.createBuffer(bufferOriginal.length());
        try (DataInputStream dis = new DataInputStream(new FileInputStream(tempFile))) {
            bufferRestored.read(dis);
        }
        // JUnit's argument order is (expected, actual)
        assertEquals(DataBuffer.Type.HALF, bufferRestored.dataType());
        DataTypeUtil.setDTypeForContext(DataBuffer.Type.FLOAT);
        DataBuffer bufferConverted = Nd4j.getNDArrayFactory().convertDataEx(DataBuffer.TypeEx.FLOAT16, bufferRestored, DataBuffer.TypeEx.FLOAT);
        assertArrayEquals(bufferOriginal.asFloat(), bufferConverted.asFloat(), 0.01f);
    } finally {
        // restore the FLOAT context even if an assertion above fails
        DataTypeUtil.setDTypeForContext(DataBuffer.Type.FLOAT);
    }
}
Also used: DataOutputStream (java.io.DataOutputStream), DataInputStream (java.io.DataInputStream), File (java.io.File), FileInputStream (java.io.FileInputStream), DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer), Test (org.junit.Test)
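The test above writes a FLOAT16 buffer to a stream and reads it back. The sketch below shows the same idea in plain Java without nd4j: floats are packed into IEEE 754 binary16 bits, written with DataOutputStream, and widened back to floats on read. It is a hypothetical illustration, not nd4j's implementation; the stream layout (an int count followed by big-endian 16-bit values) and the method names floatToHalf, halfToFloat, writeHalfs and readAsFloats are assumptions, and the converter rounds ties away from zero and flushes subnormal results to zero for brevity.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Arrays;

public class HalfRoundTrip {

    // float32 -> binary16 bits; rounds to nearest (ties away from zero),
    // flushes results below the half-precision normal range to signed zero
    static short floatToHalf(float f) {
        int bits = Float.floatToIntBits(f);
        int sign = (bits >>> 16) & 0x8000;
        int rawExp = (bits >>> 23) & 0xFF;
        int mant = bits & 0x7FFFFF;
        if (rawExp == 0xFF) { // Inf or NaN
            return (short) (sign | 0x7C00 | (mant != 0 ? 0x200 : 0));
        }
        int exp = rawExp - 127 + 15; // rebias the exponent from float32 to binary16
        if (exp >= 0x1F) {
            return (short) (sign | 0x7C00); // too large: signed infinity
        }
        if (exp <= 0) {
            return (short) sign; // too small: signed zero
        }
        int half = sign | (exp << 10) | (mant >>> 13);
        if ((mant & 0x1000) != 0) {
            half++; // highest dropped bit set: round up; a carry correctly bumps the exponent
        }
        return (short) half;
    }

    // binary16 bits -> float32; exact, since every half value fits in a float
    static float halfToFloat(short h) {
        int bits = h & 0xFFFF;
        int sign = (bits & 0x8000) << 16;
        int exp = (bits >>> 10) & 0x1F;
        int mant = bits & 0x3FF;
        if (exp == 0x1F) { // Inf or NaN
            return Float.intBitsToFloat(sign | 0x7F800000 | (mant << 13));
        }
        if (exp == 0) { // zero or subnormal: value = mant * 2^-24
            float v = mant * 0x1p-24f;
            return sign != 0 ? -v : v;
        }
        return Float.intBitsToFloat(sign | ((exp - 15 + 127) << 23) | (mant << 13));
    }

    // hypothetical layout: int element count, then one 16-bit value per element
    static byte[] writeHalfs(float[] values) {
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        try (DataOutputStream dos = new DataOutputStream(bos)) {
            dos.writeInt(values.length);
            for (float v : values) {
                dos.writeShort(floatToHalf(v));
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bos.toByteArray();
    }

    // reads the layout above and widens each half back to a float
    static float[] readAsFloats(byte[] data) {
        try (DataInputStream dis = new DataInputStream(new ByteArrayInputStream(data))) {
            float[] out = new float[dis.readInt()];
            for (int i = 0; i < out.length; i++) {
                out[i] = halfToFloat(dis.readShort());
            }
            return out;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        float[] original = {1f, 2f, 3f, 4f, 5f};
        byte[] bytes = writeHalfs(original);
        float[] restored = readAsFloats(bytes);
        System.out.println(bytes.length + " bytes: " + Arrays.toString(restored)); // 14 bytes: [1.0, 2.0, 3.0, 4.0, 5.0]
    }
}
```

Storing halves costs 2 bytes per element instead of 4, which is the point of FLOAT16 storage; the price is precision, which is why the test compares with a tolerance.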

Example 67 with DataBuffer

Use of org.nd4j.linalg.api.buffer.DataBuffer in project nd4j by deeplearning4j.

From class CudaHalfDataBufferTest, method testConversion1.

@Test
public void testConversion1() throws Exception {
    DataBuffer bufferOriginal = new CudaFloatDataBuffer(new float[] { 1f, 2f, 3f, 4f, 5f });
    DataBuffer bufferHalfs = Nd4j.getNDArrayFactory().convertDataEx(DataBuffer.TypeEx.FLOAT, bufferOriginal, DataBuffer.TypeEx.FLOAT16);
    DataBuffer bufferRestored = Nd4j.getNDArrayFactory().convertDataEx(DataBuffer.TypeEx.FLOAT16, bufferHalfs, DataBuffer.TypeEx.FLOAT);
    logger.info("Buffer original: {}", Arrays.toString(bufferOriginal.asFloat()));
    logger.info("Buffer restored: {}", Arrays.toString(bufferRestored.asFloat()));
    assertArrayEquals(bufferOriginal.asFloat(), bufferRestored.asFloat(), 0.01f);
}
Also used: DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer), Test (org.junit.Test)
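A short error bound shows why the 0.01f tolerance in testConversion1 is comfortably large. This assumes the FLOAT to FLOAT16 conversion rounds to nearest, which the snippet itself does not confirm:

```latex
\text{binary16 has } 11 \text{ significant bits (10 stored)}, \qquad
x \in [2^{e}, 2^{e+1}) \;\Rightarrow\; \operatorname{ulp}(x) = 2^{e-10}

|x - \operatorname{fl}_{16}(x)| \le \tfrac{1}{2}\operatorname{ulp}(x) = 2^{e-11}

x = 5 \in [2^{2}, 2^{3}) \;\Rightarrow\; |x - \operatorname{fl}_{16}(x)| \le 2^{-9} \approx 0.00195 < 0.01
```

In fact every integer with magnitude up to 2^11 = 2048 is exactly representable in binary16, so the values 1 to 5 should survive the round trip unchanged; the tolerance only matters for non-integer inputs.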

Example 68 with DataBuffer

Use of org.nd4j.linalg.api.buffer.DataBuffer in project nd4j by deeplearning4j.

From class DirectSparseInfoProvider, method createSparseInformation.

@Override
public DataBuffer createSparseInformation(int[] flags, long[] sparseOffsets, int[] hiddenDimensions, int underlyingRank) {
    SparseDescriptor descriptor = new SparseDescriptor(flags, sparseOffsets, hiddenDimensions, underlyingRank);
    if (!sparseCache.containsKey(descriptor)) {
        if (counter.get() < MAX_ENTRIES) {
            // the second containsKey check only prevents duplicate entries when it runs under a lock
            synchronized (this) {
                if (!sparseCache.containsKey(descriptor)) {
                    counter.incrementAndGet();
                    DataBuffer buffer = Shape.createSparseInformation(flags, sparseOffsets, hiddenDimensions, underlyingRank);
                    sparseCache.put(descriptor, buffer);
                    return buffer;
                }
            }
        } else {
            // cache is full: build the sparse info without caching it
            return Shape.createSparseInformation(flags, sparseOffsets, hiddenDimensions, underlyingRank);
        }
    }
    return sparseCache.get(descriptor);
}
Also used: SparseDescriptor (org.nd4j.linalg.api.shape.SparseDescriptor), DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer)
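The provider above is a bounded memoizing cache: hits are served without locking, misses are built and stored under a lock (double-checked), and once MAX_ENTRIES is reached new values are built but not cached. Below is a generic, self-contained sketch of that pattern; the class name BoundedCache, the factory function and the use of ConcurrentHashMap are assumptions for illustration, not nd4j code.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;

public final class BoundedCache<K, V> {
    private final Map<K, V> cache = new ConcurrentHashMap<>(); // factory must not return null
    private final AtomicInteger counter = new AtomicInteger();
    private final int maxEntries;
    private final Function<K, V> factory;

    public BoundedCache(int maxEntries, Function<K, V> factory) {
        this.maxEntries = maxEntries;
        this.factory = factory;
    }

    public V get(K key) {
        V cached = cache.get(key); // fast path: no lock on a hit
        if (cached != null) {
            return cached;
        }
        if (counter.get() >= maxEntries) {
            return factory.apply(key); // cache is full: build a fresh, uncached value
        }
        synchronized (this) {
            cached = cache.get(key); // second check, now under the lock
            if (cached != null) {
                return cached;
            }
            V value = factory.apply(key);
            if (counter.get() < maxEntries) { // re-check capacity under the lock
                cache.put(key, value);
                counter.incrementAndGet();
            }
            return value;
        }
    }

    public int size() {
        return counter.get();
    }

    public static void main(String[] args) {
        AtomicInteger builds = new AtomicInteger();
        BoundedCache<String, Integer> c = new BoundedCache<>(2, k -> {
            builds.incrementAndGet();
            return k.length();
        });
        c.get("a");
        c.get("a");   // hit, no rebuild
        c.get("bb");  // second and last cached entry
        c.get("ccc"); // cache full: built but not stored
        c.get("ccc"); // built again
        System.out.println("size=" + c.size() + " builds=" + builds.get()); // size=2 builds=4
    }
}
```

Rechecking the capacity inside the lock keeps the number of stored entries from exceeding maxEntries, which the unlocked check alone cannot guarantee.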

Example 69 with DataBuffer

Use of org.nd4j.linalg.api.buffer.DataBuffer in project nd4j by deeplearning4j.

From class DirectShapeInfoProvider, method createShapeInformation.

@Override
public Pair<DataBuffer, int[]> createShapeInformation(int[] shape, int[] stride, long offset, int elementWiseStride, char order) {
    // Force the offset to 0 in the shape buffer: this improves cache hit rates, and the native side does not use the offset value anyway
    offset = 0;
    ShapeDescriptor descriptor = new ShapeDescriptor(shape, stride, offset, elementWiseStride, order);
    if (!shapeCache.containsKey(descriptor)) {
        if (counter.get() < MAX_ENTRIES) {
            synchronized (this) {
                if (!shapeCache.containsKey(descriptor)) {
                    counter.incrementAndGet();
                    Pair<DataBuffer, int[]> buffer = super.createShapeInformation(shape, stride, offset, elementWiseStride, order);
                    shapeCache.put(descriptor, buffer);
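                    // length() counts 4-byte int entries; the factor of 2 is kept from the original accounting (presumably host + device copies)
                    bytes.addAndGet(buffer.getFirst().length() * 4 * 2);
                    return buffer;
                } else {
                    return shapeCache.get(descriptor);
                }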
            }
        } else {
            return super.createShapeInformation(shape, stride, offset, elementWiseStride, order);
        }
    }
    return shapeCache.get(descriptor);
}
Also used: ShapeDescriptor (org.nd4j.linalg.api.shape.ShapeDescriptor), DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer)
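The cache works only because the descriptor key compares by content: Java arrays inherit identity-based equals and hashCode, so a key that wraps int[] shape and stride arrays must compare their elements. The sketch below shows such a key; ShapeKey is a hypothetical stand-in and not nd4j's ShapeDescriptor. Zeroing the offset before building the key, as createShapeInformation does, means views that differ only in offset map to the same entry.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

public class ShapeKeyDemo {

    // hypothetical content-based cache key (not nd4j's ShapeDescriptor)
    static final class ShapeKey {
        private final int[] shape;
        private final int[] stride;
        private final long offset;
        private final int elementWiseStride;
        private final char order;

        ShapeKey(int[] shape, int[] stride, long offset, int elementWiseStride, char order) {
            this.shape = shape.clone(); // defensive copies: later mutation cannot change the hash
            this.stride = stride.clone();
            this.offset = offset;
            this.elementWiseStride = elementWiseStride;
            this.order = order;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (!(o instanceof ShapeKey)) return false;
            ShapeKey k = (ShapeKey) o;
            return offset == k.offset && elementWiseStride == k.elementWiseStride && order == k.order
                    && Arrays.equals(shape, k.shape) && Arrays.equals(stride, k.stride);
        }

        @Override
        public int hashCode() {
            int h = Arrays.hashCode(shape); // element-wise hash, unlike int[].hashCode()
            h = 31 * h + Arrays.hashCode(stride);
            h = 31 * h + Long.hashCode(offset);
            h = 31 * h + elementWiseStride;
            return 31 * h + order;
        }
    }

    public static void main(String[] args) {
        Map<ShapeKey, String> cache = new HashMap<>();
        cache.put(new ShapeKey(new int[] {2, 3}, new int[] {3, 1}, 0, 1, 'c'), "shape info");
        // a different array instance with the same contents hits the same entry
        ShapeKey same = new ShapeKey(new int[] {2, 3}, new int[] {3, 1}, 0, 1, 'c');
        System.out.println(cache.get(same));                            // shape info
        System.out.println(new int[] {2, 3}.equals(new int[] {2, 3})); // false: arrays compare by identity
    }
}
```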

Example 70 with DataBuffer

Use of org.nd4j.linalg.api.buffer.DataBuffer in project nd4j by deeplearning4j.

From class ConstantBuffersCache, method getConstantBuffer.

@Override
public DataBuffer getConstantBuffer(double[] array) {
    ArrayDescriptor descriptor = new ArrayDescriptor(array);
    if (!buffersCache.containsKey(descriptor)) {
        DataBuffer buffer = Nd4j.createBufferDetached(array);
        if (counter.get() < MAX_ENTRIES) {
            counter.incrementAndGet();
            buffersCache.put(descriptor, buffer);
            bytes.addAndGet(array.length * Nd4j.sizeOfDataType());
        }
        return buffer;
    }
    return buffersCache.get(descriptor);
}
Also used: ArrayDescriptor (org.nd4j.linalg.cache.ArrayDescriptor), DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer)
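In getConstantBuffer the containsKey check, put, and counter/byte updates are separate steps, so two threads missing on the same array could both insert and both be counted. One way to keep the accounting consistent is to use putIfAbsent and update the counters only when this thread's insert wins. The sketch below is a hypothetical plain-Java model: a cloned double[] stands in for the detached DataBuffer, ArrayKey stands in for ArrayDescriptor, and the byte count uses Double.BYTES rather than Nd4j.sizeOfDataType(). The capacity check is still outside any lock, so MAX_ENTRIES acts as a soft cap under contention.

```java
import java.util.Arrays;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;

public class ConstantCacheSketch {

    // hypothetical content-based key for a double[]
    static final class ArrayKey {
        private final double[] data;

        ArrayKey(double[] data) {
            this.data = data.clone();
        }

        @Override
        public boolean equals(Object o) {
            return o instanceof ArrayKey && Arrays.equals(data, ((ArrayKey) o).data);
        }

        @Override
        public int hashCode() {
            return Arrays.hashCode(data);
        }
    }

    static final int MAX_ENTRIES = 2;
    final ConcurrentMap<ArrayKey, double[]> cache = new ConcurrentHashMap<>();
    final AtomicInteger counter = new AtomicInteger();
    final AtomicLong bytes = new AtomicLong();

    double[] getConstant(double[] array) {
        ArrayKey key = new ArrayKey(array);
        double[] cached = cache.get(key);
        if (cached != null) {
            return cached;
        }
        double[] buffer = array.clone(); // stands in for a detached buffer
        if (counter.get() < MAX_ENTRIES) {
            double[] winner = cache.putIfAbsent(key, buffer);
            if (winner != null) {
                return winner; // another thread inserted first: reuse its buffer, count nothing
            }
            counter.incrementAndGet();
            bytes.addAndGet((long) array.length * Double.BYTES);
        }
        return buffer;
    }

    public static void main(String[] args) {
        ConstantCacheSketch s = new ConstantCacheSketch();
        s.getConstant(new double[] {1, 2, 3});
        s.getConstant(new double[] {1, 2, 3}); // cache hit
        s.getConstant(new double[] {4});
        s.getConstant(new double[] {5});       // cache full: returned but not cached
        System.out.println("entries=" + s.counter.get() + " bytes=" + s.bytes.get()); // entries=2 bytes=32
    }
}
```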

Aggregations

DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer) 186
INDArray (org.nd4j.linalg.api.ndarray.INDArray) 79
Test (org.junit.Test) 47
CompressedDataBuffer (org.nd4j.linalg.compression.CompressedDataBuffer) 44
CudaContext (org.nd4j.linalg.jcublas.context.CudaContext) 39
CudaPointer (org.nd4j.jita.allocator.pointers.CudaPointer) 30
AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint) 25
ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException) 23
BaseDataBuffer (org.nd4j.linalg.api.buffer.BaseDataBuffer) 19
Pointer (org.bytedeco.javacpp.Pointer) 18
BaseNd4jTest (org.nd4j.linalg.BaseNd4jTest) 16
CudaDoubleDataBuffer (org.nd4j.linalg.jcublas.buffer.CudaDoubleDataBuffer) 16
IntPointer (org.bytedeco.javacpp.IntPointer) 13
PagedPointer (org.nd4j.linalg.api.memory.pointers.PagedPointer) 13
CudaIntDataBuffer (org.nd4j.linalg.jcublas.buffer.CudaIntDataBuffer) 13
DoublePointer (org.bytedeco.javacpp.DoublePointer) 12
FloatPointer (org.bytedeco.javacpp.FloatPointer) 12
GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner) 12
LongPointerWrapper (org.nd4j.nativeblas.LongPointerWrapper) 11
CUstream_st (org.bytedeco.javacpp.cuda.CUstream_st) 10