Search in sources:

Example 11 with DataBuffer

use of org.nd4j.linalg.api.buffer.DataBuffer in project nd4j by deeplearning4j.

From class Nd4j, method createBuffer.

/**
 * Creates a buffer of the given length whose element type follows the current
 * global data type ({@link #dataType()}).
 *
 * @param length the number of elements in the buffer
 * @param offset the offset handed through to the buffer factory
 * @return the created buffer, or {@code null} when the current data type is not one of
 *         FLOAT, INT, DOUBLE or HALF
 */
public static DataBuffer createBuffer(int length, long offset) {
    final DataBuffer.Type type = dataType();
    // NOTE(review): unsupported data types yield null rather than an exception; confirm
    // that callers (and logCreationIfNecessary) are prepared for a null buffer.
    final DataBuffer created =
            type == DataBuffer.Type.FLOAT ? DATA_BUFFER_FACTORY_INSTANCE.createFloat(offset, length)
            : type == DataBuffer.Type.INT ? DATA_BUFFER_FACTORY_INSTANCE.createInt(offset, length)
            : type == DataBuffer.Type.DOUBLE ? DATA_BUFFER_FACTORY_INSTANCE.createDouble(offset, length)
            : type == DataBuffer.Type.HALF ? DATA_BUFFER_FACTORY_INSTANCE.createHalf(offset, length)
            : null;
    logCreationIfNecessary(created);
    return created;
}
Also used : DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) CompressedDataBuffer(org.nd4j.linalg.compression.CompressedDataBuffer)

Example 12 with DataBuffer

use of org.nd4j.linalg.api.buffer.DataBuffer in project nd4j by deeplearning4j.

From class Nd4j, method createBuffer.

/**
 * Creates an INT buffer holding the given values.
 *
 * @param data   the int values to place in the buffer
 * @param offset the offset handed through to the buffer factory
 * @return the created INT buffer
 */
public static DataBuffer createBuffer(int[] data, long offset) {
    final DataBuffer intBuffer = DATA_BUFFER_FACTORY_INSTANCE.createInt(offset, data);
    logCreationIfNecessary(intBuffer);
    return intBuffer;
}
Also used : DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) CompressedDataBuffer(org.nd4j.linalg.compression.CompressedDataBuffer)

Example 13 with DataBuffer

use of org.nd4j.linalg.api.buffer.DataBuffer in project nd4j by deeplearning4j.

From class Nd4j, method read.

/**
 * Reads an ndarray from a data input stream.
 *
 * <p>The shape-information buffer is read from the stream first. The data buffer is read
 * next, with its element count taken from that shape information.
 *
 * @param dis the data input stream to read from
 * @return the deserialized ndarray
 * @throws IOException if reading from {@code dis} fails
 */
public static INDArray read(DataInputStream dis) throws IOException {
    // Placeholder INT buffer that shapeBuffer.read(...) populates from the stream.
    final DataBuffer shapeBuffer = Nd4j.createBufferDetached(new int[1], DataBuffer.Type.INT);
    shapeBuffer.read(dis);
    final DataBuffer dataBuffer = CompressedDataBuffer.readUnknown(dis, Shape.length(shapeBuffer));
    return createArrayFromShapeBuffer(dataBuffer, shapeBuffer);
}
Also used : DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) CompressedDataBuffer(org.nd4j.linalg.compression.CompressedDataBuffer)

Example 14 with DataBuffer

use of org.nd4j.linalg.api.buffer.DataBuffer in project nd4j by deeplearning4j.

From class Nd4j, method createBufferDetached.

/**
 * Creates a buffer holding {@code data}, choosing the element type from the current
 * global data type ({@link #dataType()}).
 *
 * <p>DOUBLE keeps the values as doubles. HALF receives them converted to floats via
 * {@code createHalf}. Every other data type produces a FLOAT buffer.
 *
 * @param data the values to place in the buffer
 * @return the created buffer
 */
public static DataBuffer createBufferDetached(double[] data) {
    final DataBuffer.Type type = dataType();
    final DataBuffer result;
    if (type == DataBuffer.Type.DOUBLE) {
        result = DATA_BUFFER_FACTORY_INSTANCE.createDouble(data);
    } else {
        // Non-double targets are built from single-precision copies of the input.
        final float[] narrowed = ArrayUtil.toFloats(data);
        result = (type == DataBuffer.Type.HALF)
                ? DATA_BUFFER_FACTORY_INSTANCE.createHalf(narrowed)
                : DATA_BUFFER_FACTORY_INSTANCE.createFloat(narrowed);
    }
    logCreationIfNecessary(result);
    return result;
}
Also used : DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) CompressedDataBuffer(org.nd4j.linalg.compression.CompressedDataBuffer)

Example 15 with DataBuffer

use of org.nd4j.linalg.api.buffer.DataBuffer in project nd4j by deeplearning4j.

From class BaseNDArray, method tensorAlongDimension.

/**
 * Returns the tensor along dimension (TAD) with the given index. The TAD is backed by this
 * array's data buffer.
 *
 * <p>Negative entries in {@code dimension} are wrapped by adding {@link #rank()}. Duplicates
 * are removed and the dimensions are sorted ascending. The caller's {@code dimension} array
 * is never modified.
 *
 * @param index     index of the TAD to return; must be less than the number of TADs
 * @param dimension the dimension(s) that define the TAD
 * @return {@code this} if {@code dimension} has at least {@link #rank()} entries or is the
 *         single value {@link Integer#MAX_VALUE}; the transpose for a column vector along
 *         dimension 0; {@code this} for a row vector along dimension 1; otherwise a TAD view
 *         whose shape, stride and offset come from the executioner's TAD manager
 * @throws IllegalArgumentException if {@code dimension} is null or empty, or if {@code index}
 *         is not less than the number of TADs
 */
@Override
public INDArray tensorAlongDimension(int index, int... dimension) {
    if (dimension == null || dimension.length == 0)
        throw new IllegalArgumentException("Invalid input: dimensions not specified (null or length 0)");
    if (dimension.length >= rank() || dimension.length == 1 && dimension[0] == Integer.MAX_VALUE)
        return this;

    // Normalize into a private copy. Previously the negative axes were wrapped in place,
    // which silently rewrote an array the caller might reuse (e.g. a shared constant or an
    // array later applied to an ndarray of a different rank).
    final int rank = rank();
    int[] dims = new int[dimension.length];
    for (int i = 0; i < dimension.length; i++)
        dims[i] = dimension[i] < 0 ? dimension[i] + rank : dimension[i];
    // drop duplicates and sort ascending
    if (dims.length > 1)
        dims = Arrays.stream(dims).distinct().sorted().toArray();
    dimension = dims;

    int tads = tensorssAlongDimension(dimension);
    if (index >= tads)
        throw new IllegalArgumentException("Illegal index " + index + " out of tads " + tads);
    if (dimension.length == 1) {
        if (dimension[0] == 0 && isColumnVector()) {
            return this.transpose();
        } else if (dimension[0] == 1 && isRowVector()) {
            return this;
        }
    }
    Pair<DataBuffer, DataBuffer> tadInfo = Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(this, dimension);
    DataBuffer shapeInfo = tadInfo.getFirst();
    int[] shape = Shape.shape(shapeInfo);
    int[] stride = Shape.stride(shapeInfo).asInt();
    // the second TAD-manager buffer holds per-TAD offsets relative to this array's offset
    long offset = offset() + tadInfo.getSecond().getLong(index);
    INDArray toTad = Nd4j.create(data(), shape, stride, offset);
    BaseNDArray baseNDArray = (BaseNDArray) toTad;
    // order of the TAD view, derived from its shape and stride
    char newOrder = Shape.getOrder(shape, stride, 1);
    // element-wise stride as stored at position length - 2 of the view's shape info
    int ews = baseNDArray.shapeInfoDataBuffer().getInt(baseNDArray.shapeInfoDataBuffer().length() - 2);
    // keep the element-wise stride only for row-vector shapes; otherwise mark it as -1
    if (!Shape.isRowVectorShape(baseNDArray.shapeInfoDataBuffer()))
        ews = -1;
    // we create new shapeInfo with possibly new ews & order
    /**
     * NOTE HERE THAT ZERO IS PRESET FOR THE OFFSET AND SHOULD STAY LIKE THAT.
     * Zero is preset for caching purposes.
     * We don't actually use the offset defined in the
     * shape info data buffer.
     * We calculate and cache the offsets separately.
     */
    baseNDArray.setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, 0, ews, newOrder));
    return toTad;
}
Also used : ND4JIllegalArgumentException(org.nd4j.linalg.exception.ND4JIllegalArgumentException) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer)

Aggregations

DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer)186 INDArray (org.nd4j.linalg.api.ndarray.INDArray)79 Test (org.junit.Test)47 CompressedDataBuffer (org.nd4j.linalg.compression.CompressedDataBuffer)44 CudaContext (org.nd4j.linalg.jcublas.context.CudaContext)39 CudaPointer (org.nd4j.jita.allocator.pointers.CudaPointer)30 AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint)25 ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException)23 BaseDataBuffer (org.nd4j.linalg.api.buffer.BaseDataBuffer)19 Pointer (org.bytedeco.javacpp.Pointer)18 BaseNd4jTest (org.nd4j.linalg.BaseNd4jTest)16 CudaDoubleDataBuffer (org.nd4j.linalg.jcublas.buffer.CudaDoubleDataBuffer)16 IntPointer (org.bytedeco.javacpp.IntPointer)13 PagedPointer (org.nd4j.linalg.api.memory.pointers.PagedPointer)13 CudaIntDataBuffer (org.nd4j.linalg.jcublas.buffer.CudaIntDataBuffer)13 DoublePointer (org.bytedeco.javacpp.DoublePointer)12 FloatPointer (org.bytedeco.javacpp.FloatPointer)12 GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner)12 LongPointerWrapper (org.nd4j.nativeblas.LongPointerWrapper)11 CUstream_st (org.bytedeco.javacpp.cuda.CUstream_st)10