Search in sources :

Example 1 with DataType

use of ai.djl.ndarray.types.DataType in project djl by deepjavalibrary.

Class BaseNDManager, method copyBuffer.

 * Copies data from the source {@code Buffer} to the target {@code ByteBuffer}.
 * @param src the source {@code Buffer}
 * @param target the target {@code ByteBuffer}
public static void copyBuffer(Buffer src, ByteBuffer target) {
    DataType inputType = DataType.fromBuffer(src);
    switch(inputType) {
        case FLOAT16:
            target.asShortBuffer().put((ShortBuffer) src);
        case FLOAT32:
            target.asFloatBuffer().put((FloatBuffer) src);
        case FLOAT64:
            target.asDoubleBuffer().put((DoubleBuffer) src);
        case UINT8:
        case INT8:
        case BOOLEAN:
            target.put((ByteBuffer) src);
        case INT32:
            target.asIntBuffer().put((IntBuffer) src);
        case INT64:
            target.asLongBuffer().put((LongBuffer) src);
            throw new AssertionError("Unsupported datatype: " + inputType);
Also used : DataType(ai.djl.ndarray.types.DataType)

Example 2 with DataType

use of ai.djl.ndarray.types.DataType in project djl by deepjavalibrary.

Class StackBatchifier, method batchify.

 * {@inheritDoc}
public NDList batchify(NDList[] inputs) {
    // each input as NDList might contain several data or labels
    // so those should be batchified with counterpart
    int batchSize = inputs.length;
    int numInputKinds = inputs[0].size();
    // if the NDList is empty
    if (numInputKinds == 0) {
        return new NDList();
    try {
        // stack all the data and labels together
        NDList result = new NDList(numInputKinds);
        for (int i = 0; i < numInputKinds; i++) {
            NDList inputsOfKind = new NDList(batchSize);
            String inputName = inputs[0].get(i).getName();
            for (NDList input : inputs) {
            NDArray stacked = NDArrays.stack(new NDList(inputsOfKind));
            // keep the name for stacked inputs
        return result;
    } catch (IndexOutOfBoundsException | EngineException e) {
        // Check if numInputKinds is not consistant for all inputs
        for (NDList input : inputs) {
            if (input.size() != numInputKinds) {
                throw new IllegalArgumentException("You cannot batch data with different numbers of inputs", e);
        // Check if data does not have a consistent shape or type
        for (int i = 0; i < numInputKinds; i++) {
            Shape kindDataShape = inputs[0].get(i).getShape();
            DataType kindDataType = inputs[0].get(i).getDataType();
            for (NDList input : inputs) {
                NDArray currInput = input.get(i);
                if (!currInput.getShape().equals(kindDataShape)) {
                    throw new IllegalArgumentException("You cannot batch data with different input shapes", e);
                if (!currInput.getDataType().equals(kindDataType)) {
                    throw new IllegalArgumentException("You cannot batch data with different input data types", e);
        // Could not clarify cause - rethrow original error.
        throw e;
Also used : Shape(ai.djl.ndarray.types.Shape) NDList(ai.djl.ndarray.NDList) EngineException(ai.djl.engine.EngineException) DataType(ai.djl.ndarray.types.DataType) NDArray(ai.djl.ndarray.NDArray)

Example 3 with DataType

use of ai.djl.ndarray.types.DataType in project djl by deepjavalibrary.

Class MxNDManager, method createCSR.

 * {@inheritDoc}
public MxNDArray createCSR(Buffer data, long[] indptr, long[] indices, Shape shape) {
    SparseFormat fmt = SparseFormat.CSR;
    DataType dataType = DataType.fromBuffer(data);
    MxNDArray indptrNd = create(new Shape(indptr.length), DataType.INT64);
    MxNDArray indicesNd = create(new Shape(indices.length), DataType.INT64);
    Pointer handle = JnaUtils.createSparseNdArray(fmt, device, shape, dataType, new DataType[] { indptrNd.getDataType(), indicesNd.getDataType() }, new Shape[] { indptrNd.getShape(), indicesNd.getShape() }, false);
    MxNDArray sparse = create(handle, fmt);
    MxNDArray dataNd = create(new Shape(data.remaining()), dataType);
    JnaUtils.ndArraySyncCopyFromNdArray(sparse, dataNd, -1);
    JnaUtils.ndArraySyncCopyFromNdArray(sparse, indptrNd, 0);
    JnaUtils.ndArraySyncCopyFromNdArray(sparse, indicesNd, 1);
    return sparse;
Also used : SparseFormat(ai.djl.ndarray.types.SparseFormat) Shape(ai.djl.ndarray.types.Shape) DataType(ai.djl.ndarray.types.DataType) Pointer(com.sun.jna.Pointer)

Example 4 with DataType

use of ai.djl.ndarray.types.DataType in project djl by deepjavalibrary.

Class MxNDManager, method createRowSparse.

 * {@inheritDoc}
public MxNDArray createRowSparse(Buffer data, Shape dataShape, long[] indices, Shape shape) {
    SparseFormat fmt = SparseFormat.ROW_SPARSE;
    DataType dataType = DataType.fromBuffer(data);
    MxNDArray indicesNd = create(new Shape(indices.length), DataType.INT64);
    Pointer handle = JnaUtils.createSparseNdArray(fmt, device, shape, dataType, new DataType[] { indicesNd.getDataType() }, new Shape[] { indicesNd.getShape() }, false);
    MxNDArray sparse = create(handle, fmt);
    MxNDArray dataNd = create(dataShape, dataType);
    JnaUtils.ndArraySyncCopyFromNdArray(sparse, dataNd, -1);
    JnaUtils.ndArraySyncCopyFromNdArray(sparse, indicesNd, 0);
    return sparse;
Also used : SparseFormat(ai.djl.ndarray.types.SparseFormat) Shape(ai.djl.ndarray.types.Shape) DataType(ai.djl.ndarray.types.DataType) Pointer(com.sun.jna.Pointer)

Example 5 with DataType

use of ai.djl.ndarray.types.DataType in project djl by deepjavalibrary.

Class XgbNDManager, method create.

 * {@inheritDoc}
public NDArray create(Buffer data, Shape shape, DataType dataType) {
    if (shape.dimension() != 2) {
        if (data instanceof ByteBuffer) {
            // output only NDArray
            return new XgbNDArray(this, alternativeManager, (ByteBuffer) data, shape, dataType);
        if (alternativeManager != null) {
            return alternativeManager.create(data, shape, dataType);
        throw new UnsupportedOperationException("XgbNDArray shape must be in two dimension.");
    if (dataType != DataType.FLOAT32) {
        if (data instanceof ByteBuffer) {
            // output only NDArray
            return new XgbNDArray(this, alternativeManager, (ByteBuffer) data, shape, dataType);
        if (alternativeManager != null) {
            return alternativeManager.create(data, shape, dataType);
        throw new UnsupportedOperationException("XgbNDArray only supports float32.");
    if (data.isDirect() && data instanceof ByteBuffer) {
        // TODO: allow user to set missing value
        long handle = JniUtils.createDMatrix(data, shape, missingValue);
        return new XgbNDArray(this, alternativeManager, handle, shape, SparseFormat.DENSE);
    DataType inputType = DataType.fromBuffer(data);
    if (inputType != DataType.FLOAT32) {
        throw new UnsupportedOperationException("Only Float32 data type supported, actual " + inputType);
    int size = Math.toIntExact(shape.size() * DataType.FLOAT32.getNumOfBytes());
    ByteBuffer buf = allocateDirect(size);
    buf.asFloatBuffer().put((FloatBuffer) data);
    long handle = JniUtils.createDMatrix(buf, shape, missingValue);
    return new XgbNDArray(this, alternativeManager, handle, shape, SparseFormat.DENSE);
Also used : DataType(ai.djl.ndarray.types.DataType) ByteBuffer(java.nio.ByteBuffer)


DataType (ai.djl.ndarray.types.DataType) 17, Shape (ai.djl.ndarray.types.Shape) 11, ByteBuffer (java.nio.ByteBuffer) 8, Pointer (com.sun.jna.Pointer) 4, SparseFormat (ai.djl.ndarray.types.SparseFormat) 3, EngineException (ai.djl.engine.EngineException) 2, NDList (ai.djl.ndarray.NDList) 2, DataInputStream (java.io.DataInputStream), IntBuffer (java.nio.IntBuffer) 2, Matcher (java.util.regex.Matcher) 2, CommandLine (org.apache.commons.cli.CommandLine) 2, DefaultParser (org.apache.commons.cli.DefaultParser) 2, Options (org.apache.commons.cli.Options) 2, Device (ai.djl.Device) 1, NDArray (ai.djl.ndarray.NDArray) 1, NDManager (ai.djl.ndarray.NDManager) 1, NDIndex (ai.djl.ndarray.index.NDIndex) 1, NDArrayEx (ai.djl.ndarray.internal.NDArrayEx) 1, NDFormat (ai.djl.ndarray.internal.NDFormat) 1, Float16Utils (ai.djl.util.Float16Utils) 1