Search in sources:

Example 46 with DataBuffer

Use of org.nd4j.linalg.api.buffer.DataBuffer in project nd4j by deeplearning4j.

Class Gzip, method decompress.

/**
 * Inflates a GZIP-compressed buffer back into a plain {@link DataBuffer}.
 *
 * <p>The compressed bytes are taken from the buffer's address pointer and streamed through a
 * {@link GZIPInputStream}. They are then deserialized with {@code DataBuffer.read(DataInputStream)}
 * into a freshly created buffer, sized from the compression descriptor's element count.
 *
 * @param buffer the compressed buffer; it is cast to {@link CompressedDataBuffer}
 * @return a new buffer holding {@code descriptor.getNumberOfElements()} restored elements
 * @throws RuntimeException wrapping any failure, including a failed cast or an I/O error while
 *     inflating
 */
@Override
public DataBuffer decompress(DataBuffer buffer) {
    try {
        CompressedDataBuffer compressed = (CompressedDataBuffer) buffer;
        CompressionDescriptor descriptor = compressed.getCompressionDescriptor();
        BytePointer pointer = (BytePointer) compressed.addressPointer();
        // NOTE(review): per the JavaCPP docs, getStringBytes() may fall back to a null-terminated
        // read when the pointer has no limit set. GZIP output can contain 0x00 bytes, so this
        // relies on addressPointer() preserving the limit. Confirm this, or read exactly
        // descriptor's compressed length instead.
        byte[] compressedBytes = pointer.getStringBytes();
        // Closing the stream chain releases the GZIPInputStream's Inflater right away instead of
        // leaving it for garbage collection.
        try (DataInputStream dis = new DataInputStream(
                new GZIPInputStream(new ByteArrayInputStream(compressedBytes)))) {
            DataBuffer bufferRestored = Nd4j.createBuffer(descriptor.getNumberOfElements());
            bufferRestored.read(dis);
            return bufferRestored;
        }
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}
Also used: GZIPInputStream(java.util.zip.GZIPInputStream) CompressionDescriptor(org.nd4j.linalg.compression.CompressionDescriptor) CompressedDataBuffer(org.nd4j.linalg.compression.CompressedDataBuffer) ByteArrayInputStream(java.io.ByteArrayInputStream) BytePointer(org.bytedeco.javacpp.BytePointer) DataInputStream(java.io.DataInputStream) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer)

Example 47 with DataBuffer

Use of org.nd4j.linalg.api.buffer.DataBuffer in project nd4j by deeplearning4j.

Class BasicTADManager, method getTADOnlyShapeInfo.

/**
 * Computes TAD-only shape information and per-TAD offsets for {@code array} along
 * {@code dimension}, returning them as two freshly allocated CUDA buffers.
 *
 * <p>A {@code null} dimension is replaced by the single sentinel {@code Integer.MAX_VALUE}, and
 * that sentinel selects the scalar path. On the scalar path the shape info is written on the
 * host with hard-coded values and one offset slot is allocated. Otherwise the work is delegated
 * to {@code nativeOps.tadOnlyShapeInfo}.
 *
 * @param array     the array to take tensors-along-dimension from
 * @param dimension dimensions to take TADs along; may be {@code null}. The caller's array is
 *                  never modified, because sorting happens on a private copy.
 * @return a pair of (shape-info buffer of length {@code 2 * targetRank + 4}, offsets buffer of
 *         length {@code array.lengthLong() / tadLength}, or 1 for the scalar case)
 */
@Override
public Pair<DataBuffer, DataBuffer> getTADOnlyShapeInfo(INDArray array, int[] dimension) {
    if (dimension == null) {
        dimension = new int[] { Integer.MAX_VALUE };
    } else if (dimension.length > 1) {
        // Sort a copy: sorting the argument in place would silently reorder the caller's array.
        dimension = dimension.clone();
        Arrays.sort(dimension);
    }
    // Integer.MAX_VALUE stands in for "no dimension given" and selects the scalar path.
    boolean isScalar = dimension.length == 1 && dimension[0] == Integer.MAX_VALUE;
    // FIXME: this is fast triage, remove it later
    // dimensionLength <= 1 ? 2 : dimensionLength;
    int targetRank = isScalar ? 2 : array.rank();
    long offsetLength;
    if (isScalar) {
        offsetLength = 1;
    } else {
        // Elements per TAD = product of the extents along the requested dimensions.
        long tadLength = 1;
        for (int d : dimension) {
            tadLength *= array.shape()[d];
        }
        // Number of TADs = total elements / elements per TAD.
        offsetLength = array.lengthLong() / tadLength;
    }
    DataBuffer outputBuffer = new CudaIntDataBuffer(targetRank * 2 + 4);
    DataBuffer offsetsBuffer = new CudaLongDataBuffer(offsetLength);
    // NOTE(review): tickHostWrite() is called here and again after the buffers are filled.
    // Presumably this marks the host copy as the most recent one; confirm against AllocationPoint.
    AtomicAllocator.getInstance().getAllocationPoint(outputBuffer).tickHostWrite();
    AtomicAllocator.getInstance().getAllocationPoint(offsetsBuffer).tickHostWrite();
    DataBuffer dimensionBuffer = AtomicAllocator.getInstance().getConstantBuffer(dimension);
    Pointer dimensionPointer = AtomicAllocator.getInstance().getHostPointer(dimensionBuffer);
    Pointer xShapeInfo = AddressRetriever.retrieveHostPointer(array.shapeInfoDataBuffer());
    Pointer targetPointer = AddressRetriever.retrieveHostPointer(outputBuffer);
    Pointer offsetsPointer = AddressRetriever.retrieveHostPointer(offsetsBuffer);
    if (!isScalar) {
        nativeOps.tadOnlyShapeInfo((IntPointer) xShapeInfo, (IntPointer) dimensionPointer, dimension.length,
                (IntPointer) targetPointer, new LongPointerWrapper(offsetsPointer));
    } else {
        // Hard-coded scalar shape info filling all 8 (= 2 * 2 + 4) slots; 99 is the char code of 'c'.
        // NOTE(review): presumably [rank, shape x2, stride x2, offset, elementWiseStride, order]; confirm.
        outputBuffer.put(0, 2);
        outputBuffer.put(1, 1);
        outputBuffer.put(2, 1);
        outputBuffer.put(3, 1);
        outputBuffer.put(4, 1);
        outputBuffer.put(5, 0);
        outputBuffer.put(6, 0);
        outputBuffer.put(7, 99);
    }
    AtomicAllocator.getInstance().getAllocationPoint(outputBuffer).tickHostWrite();
    AtomicAllocator.getInstance().getAllocationPoint(offsetsBuffer).tickHostWrite();
    return new Pair<>(outputBuffer, offsetsBuffer);
}
Also used: CudaLongDataBuffer(org.nd4j.linalg.jcublas.buffer.CudaLongDataBuffer) IntPointer(org.bytedeco.javacpp.IntPointer) LongPointerWrapper(org.nd4j.nativeblas.LongPointerWrapper) Pointer(org.bytedeco.javacpp.Pointer) CudaIntDataBuffer(org.nd4j.linalg.jcublas.buffer.CudaIntDataBuffer) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) CudaDoubleDataBuffer(org.nd4j.linalg.jcublas.buffer.CudaDoubleDataBuffer) Pair(org.nd4j.linalg.primitives.Pair)

Example 48 with DataBuffer

Use of org.nd4j.linalg.api.buffer.DataBuffer in project nd4j by deeplearning4j.

Class AbstractCompressor, method decompress.

/**
 * Decompresses the data buffer of {@code array} and wraps the result in a new INDArray that
 * reuses {@code array}'s shape-information buffer.
 *
 * @param array the array whose {@code data()} buffer is decompressed
 * @return a new INDArray built from the decompressed data and the original shape info
 */
@Override
public INDArray decompress(INDArray array) {
    // Java evaluates arguments left to right, so the data is decompressed before the shape
    // info buffer is fetched, matching a two-step form.
    return Nd4j.createArrayFromShapeBuffer(decompress(array.data()), array.shapeInfoDataBuffer());
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) CompressedDataBuffer(org.nd4j.linalg.compression.CompressedDataBuffer)

Example 49 with DataBuffer

Use of org.nd4j.linalg.api.buffer.DataBuffer in project nd4j by deeplearning4j.

Class AbstractCompressor, method compress.

/**
 * This method creates compressed INDArray from Java float array, skipping usual INDArray instantiation routines
 *
 * @param data
 * @param shape
 * @param order
 * @return
 */
@Override
public INDArray compress(float[] data, int[] shape, char order) {
    FloatPointer pointer = new FloatPointer(data);
    DataBuffer shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(shape, order).getFirst();
    DataBuffer buffer = compressPointer(DataBuffer.TypeEx.FLOAT, pointer, data.length, 4);
    return Nd4j.createArrayFromShapeBuffer(buffer, shapeInfo);
}
Also used : FloatPointer(org.bytedeco.javacpp.FloatPointer) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) CompressedDataBuffer(org.nd4j.linalg.compression.CompressedDataBuffer)

Example 50 with DataBuffer

Use of org.nd4j.linalg.api.buffer.DataBuffer in project nd4j by deeplearning4j.

Class OperationProfilerTests, method testBadTad5.

/**
 * Taking a TAD along axis 4 (the last axis) of a rank-5 'f'-order array should make the
 * profiler report exactly one penalty, TAD_STRIDED_ACCESS.
 */
@Test
public void testBadTad5() throws Exception {
    int[] arrayShape = { 2, 4, 5, 6, 7 };
    INDArray input = Nd4j.create(arrayShape, 'f');
    DataBuffer tadShapeInfo = Nd4j.getExecutioner().getTADManager()
            .getTADOnlyShapeInfo(input, new int[] { 4 }).getFirst();
    OpProfiler.PenaltyCause[] penalties = OpProfiler.getInstance().processTADOperands(tadShapeInfo);
    log.info("TAD: {}", Arrays.toString(tadShapeInfo.asInt()));
    log.info("Causes: {}", Arrays.toString(penalties));
    assertEquals(1, penalties.length);
    assertTrue(ArrayUtils.contains(penalties, OpProfiler.PenaltyCause.TAD_STRIDED_ACCESS));
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) Test(org.junit.Test)

Aggregations

DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer)186 INDArray (org.nd4j.linalg.api.ndarray.INDArray)79 Test (org.junit.Test)47 CompressedDataBuffer (org.nd4j.linalg.compression.CompressedDataBuffer)44 CudaContext (org.nd4j.linalg.jcublas.context.CudaContext)39 CudaPointer (org.nd4j.jita.allocator.pointers.CudaPointer)30 AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint)25 ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException)23 BaseDataBuffer (org.nd4j.linalg.api.buffer.BaseDataBuffer)19 Pointer (org.bytedeco.javacpp.Pointer)18 BaseNd4jTest (org.nd4j.linalg.BaseNd4jTest)16 CudaDoubleDataBuffer (org.nd4j.linalg.jcublas.buffer.CudaDoubleDataBuffer)16 IntPointer (org.bytedeco.javacpp.IntPointer)13 PagedPointer (org.nd4j.linalg.api.memory.pointers.PagedPointer)13 CudaIntDataBuffer (org.nd4j.linalg.jcublas.buffer.CudaIntDataBuffer)13 DoublePointer (org.bytedeco.javacpp.DoublePointer)12 FloatPointer (org.bytedeco.javacpp.FloatPointer)12 GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner)12 LongPointerWrapper (org.nd4j.nativeblas.LongPointerWrapper)11 CUstream_st (org.bytedeco.javacpp.cuda.CUstream_st)10