Search in sources :

Example 21 with DataBuffer

use of org.nd4j.linalg.api.buffer.DataBuffer in project nd4j by deeplearning4j.

The class Nd4j, method createBuffer.

/**
 * Creates a data buffer from raw bytes, selecting the factory method according
 * to the value returned by {@code dataType()}.
 *
 * <p>DOUBLE and HALF are dispatched to their matching factory methods; any
 * other type falls through to the float factory method.
 *
 * @param data   the raw bytes to create the buffer with
 * @param length forwarded unchanged to the factory (presumably the number of
 *               elements rather than bytes - TODO confirm against the factory)
 * @return the created buffer, after it has been passed to
 *         {@code logCreationIfNecessary}
 */
public static DataBuffer createBuffer(byte[] data, int length) {
    final DataBuffer.Type type = dataType();
    final DataBuffer buffer;
    if (type == DataBuffer.Type.DOUBLE) {
        buffer = DATA_BUFFER_FACTORY_INSTANCE.createDouble(data, length);
    } else if (type == DataBuffer.Type.HALF) {
        buffer = DATA_BUFFER_FACTORY_INSTANCE.createHalf(data, length);
    } else {
        // Every remaining data type is treated as float.
        buffer = DATA_BUFFER_FACTORY_INSTANCE.createFloat(data, length);
    }
    logCreationIfNecessary(buffer);
    return buffer;
}
Also used : DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) CompressedDataBuffer(org.nd4j.linalg.compression.CompressedDataBuffer)

Example 22 with DataBuffer

use of org.nd4j.linalg.api.buffer.DataBuffer in project nd4j by deeplearning4j.

The class Nd4j, method createBuffer.

/**
 * Creates a data buffer from raw bytes at the given offset, selecting the
 * factory method according to the value returned by {@code dataType()}.
 *
 * <p>DOUBLE is dispatched to the double factory method; any other type falls
 * through to the float factory method.
 *
 * <p>NOTE(review): unlike {@code createBuffer(byte[], int)} there is no
 * dedicated HALF branch here, so HALF also ends up in the float path. Confirm
 * whether that is intentional.
 *
 * @param data   the raw bytes to create the buffer with
 * @param length forwarded unchanged to the factory (presumably the number of
 *               elements rather than bytes - TODO confirm against the factory)
 * @param offset forwarded unchanged to the factory as its first argument
 * @return the created buffer, after it has been passed to
 *         {@code logCreationIfNecessary}
 */
public static DataBuffer createBuffer(byte[] data, int length, long offset) {
    final DataBuffer buffer;
    if (dataType() == DataBuffer.Type.DOUBLE) {
        buffer = DATA_BUFFER_FACTORY_INSTANCE.createDouble(offset, data, length);
    } else {
        buffer = DATA_BUFFER_FACTORY_INSTANCE.createFloat(offset, data, length);
    }
    logCreationIfNecessary(buffer);
    return buffer;
}
Also used : DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) CompressedDataBuffer(org.nd4j.linalg.compression.CompressedDataBuffer)

Example 23 with DataBuffer

use of org.nd4j.linalg.api.buffer.DataBuffer in project nd4j by deeplearning4j.

The class Nd4j, method createBuffer.

/**
 * Creates a data buffer from a double array at the given offset, selecting the
 * factory method according to the value returned by {@code dataType()}.
 *
 * <p>When the memory manager reports a current workspace, the workspace-aware
 * factory overload is used and the workspace is passed as the last argument;
 * otherwise the plain overload is used.
 *
 * @param data   the values to create the buffer with
 * @param offset forwarded unchanged to the factory as its first argument
 * @return the created buffer, after it has been passed to
 *         {@code logCreationIfNecessary}
 */
public static DataBuffer createBuffer(double[] data, long offset) {
    final DataBuffer.Type type = dataType();
    final boolean noWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace() == null;
    final DataBuffer buffer;
    if (type == DataBuffer.Type.DOUBLE) {
        if (noWorkspace) {
            buffer = DATA_BUFFER_FACTORY_INSTANCE.createDouble(offset, data);
        } else {
            buffer = DATA_BUFFER_FACTORY_INSTANCE.createDouble(offset, data,
                    Nd4j.getMemoryManager().getCurrentWorkspace());
        }
    } else if (type == DataBuffer.Type.HALF) {
        if (noWorkspace) {
            // The non-workspace half overload receives the doubles as they are.
            buffer = DATA_BUFFER_FACTORY_INSTANCE.createHalf(offset, data);
        } else {
            // The workspace half overload is fed floats, so convert first.
            buffer = DATA_BUFFER_FACTORY_INSTANCE.createHalf(offset, ArrayUtil.toFloats(data),
                    Nd4j.getMemoryManager().getCurrentWorkspace());
        }
    } else {
        // Every remaining data type is treated as float.
        if (noWorkspace) {
            buffer = DATA_BUFFER_FACTORY_INSTANCE.createFloat(offset, ArrayUtil.toFloats(data));
        } else {
            buffer = DATA_BUFFER_FACTORY_INSTANCE.createFloat(offset, ArrayUtil.toFloats(data),
                    Nd4j.getMemoryManager().getCurrentWorkspace());
        }
    }
    logCreationIfNecessary(buffer);
    return buffer;
}
Also used : DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) CompressedDataBuffer(org.nd4j.linalg.compression.CompressedDataBuffer)

Example 24 with DataBuffer

use of org.nd4j.linalg.api.buffer.DataBuffer in project nd4j by deeplearning4j.

The class OnnxGraphMapper, method mapTensorProto.

/**
 * Converts an ONNX tensor proto into an {@link INDArray}.
 *
 * <p>The tensor's raw data is copied into a freshly allocated direct byte
 * buffer in native byte order, wrapped in a {@link DataBuffer} of the mapped
 * nd4j type, and reshaped to the shape read from the tensor.
 *
 * @param tensor the ONNX tensor to convert; may be {@code null}
 * @return the resulting array, or {@code null} when {@code tensor} is {@code null}
 */
public INDArray mapTensorProto(OnnxProto3.TensorProto tensor) {
    if (tensor == null) {
        return null;
    }
    DataBuffer.Type nd4jType = nd4jTypeFromOnnxType(tensor.getDataType());
    ByteBuffer rawView = tensor.getRawData().asReadOnlyByteBuffer().order(ByteOrder.nativeOrder());
    // Copy the read-only view into a direct buffer before handing it to Nd4j.
    ByteBuffer direct = ByteBuffer.allocateDirect(rawView.capacity()).order(ByteOrder.nativeOrder());
    direct.put(rawView);
    // Reset the position so the buffer is read from its start.
    direct.rewind();
    int[] dims = getShapeFromTensor(tensor);
    DataBuffer wrapped = Nd4j.createBuffer(direct, nd4jType, ArrayUtil.prod(dims));
    return Nd4j.create(wrapped).reshape(dims);
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) ByteString(com.github.os72.protobuf351.ByteString) ByteBuffer(java.nio.ByteBuffer) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer)

Example 25 with DataBuffer

use of org.nd4j.linalg.api.buffer.DataBuffer in project nd4j by deeplearning4j.

The class OnnxGraphMapper, method getNDArrayFromTensor.

/**
 * Looks up the graph initializer whose name equals {@code tensorName} and
 * converts its raw data into an {@link INDArray}.
 *
 * <p>The data type and the shape are taken from {@code tensorProto}; the bytes
 * come from the matching initializer. The bytes are copied into a freshly
 * allocated direct byte buffer in native byte order before being wrapped.
 *
 * @param tensorName  name of the initializer to search for in {@code graph}
 * @param tensorProto tensor type description supplying the data type and shape
 * @param graph       the graph whose initializers are searched
 * @return the resulting array, or {@code null} when no initializer has the
 *         requested name
 * @throws ND4JIllegalStateException if {@code tensorProto.isInitialized()} is false
 */
@Override
public INDArray getNDArrayFromTensor(String tensorName, OnnxProto3.TypeProto.Tensor tensorProto, OnnxProto3.GraphProto graph) {
    DataBuffer.Type nd4jType = dataTypeForTensor(tensorProto);
    if (!tensorProto.isInitialized()) {
        throw new ND4JIllegalStateException("Unable to retrieve ndarray. Tensor was not initialized");
    }

    // Linear scan over the initializers; the first name match wins.
    OnnxProto3.TensorProto match = null;
    for (int idx = 0; idx < graph.getInitializerCount(); idx++) {
        OnnxProto3.TensorProto candidate = graph.getInitializer(idx);
        if (candidate.getName().equals(tensorName)) {
            match = candidate;
            break;
        }
    }
    if (match == null) {
        return null;
    }

    ByteBuffer rawView = match.getRawData().asReadOnlyByteBuffer().order(ByteOrder.nativeOrder());
    // Copy the read-only view into a direct buffer before handing it to Nd4j.
    ByteBuffer direct = ByteBuffer.allocateDirect(rawView.capacity()).order(ByteOrder.nativeOrder());
    direct.put(rawView);
    // Reset the position so the buffer is read from its start.
    direct.rewind();
    int[] dims = getShapeFromTensor(tensorProto);
    DataBuffer wrapped = Nd4j.createBuffer(direct, nd4jType, ArrayUtil.prod(dims));
    return Nd4j.create(wrapped).reshape(dims);
}
Also used : lombok.val(lombok.val) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ByteString(com.github.os72.protobuf351.ByteString) OnnxProto3(onnx.OnnxProto3) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) ByteBuffer(java.nio.ByteBuffer) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer)

Aggregations

DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer)186 INDArray (org.nd4j.linalg.api.ndarray.INDArray)79 Test (org.junit.Test)47 CompressedDataBuffer (org.nd4j.linalg.compression.CompressedDataBuffer)44 CudaContext (org.nd4j.linalg.jcublas.context.CudaContext)39 CudaPointer (org.nd4j.jita.allocator.pointers.CudaPointer)30 AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint)25 ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException)23 BaseDataBuffer (org.nd4j.linalg.api.buffer.BaseDataBuffer)19 Pointer (org.bytedeco.javacpp.Pointer)18 BaseNd4jTest (org.nd4j.linalg.BaseNd4jTest)16 CudaDoubleDataBuffer (org.nd4j.linalg.jcublas.buffer.CudaDoubleDataBuffer)16 IntPointer (org.bytedeco.javacpp.IntPointer)13 PagedPointer (org.nd4j.linalg.api.memory.pointers.PagedPointer)13 CudaIntDataBuffer (org.nd4j.linalg.jcublas.buffer.CudaIntDataBuffer)13 DoublePointer (org.bytedeco.javacpp.DoublePointer)12 FloatPointer (org.bytedeco.javacpp.FloatPointer)12 GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner)12 LongPointerWrapper (org.nd4j.nativeblas.LongPointerWrapper)11 CUstream_st (org.bytedeco.javacpp.cuda.CUstream_st)10