Search in sources:

Example 1 with DataType

use of ai.djl.ndarray.types.DataType in project djl by deepjavalibrary.

Class `BaseNDManager`, method `copyBuffer`:

/**
 * Copies the remaining elements of {@code src} into {@code target}, writing from the start of
 * {@code target}.
 *
 * <p>The element type is obtained from {@code DataType.fromBuffer(src)}. Single-byte types are
 * written directly into {@code target}; wider types go through the matching typed view of
 * {@code target} (which, per the NIO contract, shares {@code target}'s byte order). The bulk
 * {@code put} advances {@code src}'s position. {@code target} is rewound before the copy and
 * again afterwards, so its position is zero on normal return.
 *
 * @param src the source {@code Buffer}
 * @param target the target {@code ByteBuffer}
 * @throws AssertionError if the element type of {@code src} is not handled here
 */
public static void copyBuffer(Buffer src, ByteBuffer target) {
    target.rewind();
    DataType srcType = DataType.fromBuffer(src);
    switch (srcType) {
        case BOOLEAN:
        case INT8:
        case UINT8:
            // one byte per element: no typed view needed
            target.put((ByteBuffer) src);
            break;
        case FLOAT16:
            // FLOAT16 data arrives as a ShortBuffer of raw 16-bit values
            target.asShortBuffer().put((ShortBuffer) src);
            break;
        case INT32:
            target.asIntBuffer().put((IntBuffer) src);
            break;
        case FLOAT32:
            target.asFloatBuffer().put((FloatBuffer) src);
            break;
        case INT64:
            target.asLongBuffer().put((LongBuffer) src);
            break;
        case FLOAT64:
            target.asDoubleBuffer().put((DoubleBuffer) src);
            break;
        default:
            throw new AssertionError("Unsupported datatype: " + srcType);
    }
    // typed views do not move target's position, but the byte path does; normalize to zero
    target.rewind();
}
Also used: DataType (ai.djl.ndarray.types.DataType)

Example 2 with DataType

use of ai.djl.ndarray.types.DataType in project djl by deepjavalibrary.

Class `StackBatchifier`, method `batchify`:

/**
 * {@inheritDoc}
 *
 * <p>Each element of {@code inputs} is one sample. The {@code i}-th array of every sample is
 * combined with {@code NDArrays.stack} to form the {@code i}-th array of the result. The stacked
 * array keeps the name of the first sample's {@code i}-th array. An empty {@code inputs} array,
 * or samples that contain no arrays, yield an empty {@code NDList}.
 *
 * @throws IllegalArgumentException if the samples hold different numbers of arrays, or if arrays
 *     of the same kind differ in shape or data type (detected only after stacking fails)
 */
@Override
public NDList batchify(NDList[] inputs) {
    // Nothing to batch: mirror the "no input kinds" case below instead of failing on inputs[0].
    if (inputs.length == 0) {
        return new NDList();
    }
    // each input as NDList might contain several data or labels
    // so those should be batchified with counterpart
    int batchSize = inputs.length;
    int numInputKinds = inputs[0].size();
    // if the NDList is empty
    if (numInputKinds == 0) {
        return new NDList();
    }
    try {
        // stack all the data and labels together
        NDList result = new NDList(numInputKinds);
        for (int i = 0; i < numInputKinds; i++) {
            NDList inputsOfKind = new NDList(batchSize);
            String inputName = inputs[0].get(i).getName();
            for (NDList input : inputs) {
                inputsOfKind.add(input.get(i));
            }
            // inputsOfKind is already a freshly built list, so stack it directly without copying
            NDArray stacked = NDArrays.stack(inputsOfKind);
            // keep the name for stacked inputs
            stacked.setName(inputName);
            result.add(stacked);
        }
        return result;
    } catch (IndexOutOfBoundsException | EngineException e) {
        // Stacking failed; try to turn the low-level error into a descriptive one.
        // Check if numInputKinds is not consistent for all inputs
        for (NDList input : inputs) {
            if (input.size() != numInputKinds) {
                throw new IllegalArgumentException("You cannot batch data with different numbers of inputs", e);
            }
        }
        // Check if data does not have a consistent shape or type
        for (int i = 0; i < numInputKinds; i++) {
            Shape kindDataShape = inputs[0].get(i).getShape();
            DataType kindDataType = inputs[0].get(i).getDataType();
            for (NDList input : inputs) {
                NDArray currInput = input.get(i);
                if (!currInput.getShape().equals(kindDataShape)) {
                    throw new IllegalArgumentException("You cannot batch data with different input shapes", e);
                }
                if (!currInput.getDataType().equals(kindDataType)) {
                    throw new IllegalArgumentException("You cannot batch data with different input data types", e);
                }
            }
        }
        // Could not clarify cause - rethrow original error.
        throw e;
    }
}
Also used : Shape(ai.djl.ndarray.types.Shape) NDList(ai.djl.ndarray.NDList) EngineException(ai.djl.engine.EngineException) DataType(ai.djl.ndarray.types.DataType) NDArray(ai.djl.ndarray.NDArray)

Example 3 with DataType

use of ai.djl.ndarray.types.DataType in project djl by deepjavalibrary.

Class `MxNDManager`, method `createCSR`:

/**
 * {@inheritDoc}
 *
 * <p>The values, row pointers ({@code indptr}) and column indices are first written into
 * temporary dense {@code MxNDArray}s. They are then copied into the sparse array with
 * {@code JnaUtils.ndArraySyncCopyFromNdArray}: slot -1 receives the values, aux slot 0 the row
 * pointers, and aux slot 1 the indices. The temporaries are closed before this method returns
 * instead of staying attached to this manager. If a copy fails, the partially built sparse array
 * is closed as well.
 */
@Override
public MxNDArray createCSR(Buffer data, long[] indptr, long[] indices, Shape shape) {
    SparseFormat fmt = SparseFormat.CSR;
    DataType dataType = DataType.fromBuffer(data);
    // data.remaining() is read here, before set(data) consumes the buffer
    try (MxNDArray indptrNd = create(new Shape(indptr.length), DataType.INT64);
            MxNDArray indicesNd = create(new Shape(indices.length), DataType.INT64);
            MxNDArray dataNd = create(new Shape(data.remaining()), dataType)) {
        indptrNd.set(indptr);
        indicesNd.set(indices);
        dataNd.set(data);
        Pointer handle = JnaUtils.createSparseNdArray(fmt, device, shape, dataType, new DataType[] { indptrNd.getDataType(), indicesNd.getDataType() }, new Shape[] { indptrNd.getShape(), indicesNd.getShape() }, false);
        MxNDArray sparse = create(handle, fmt);
        try {
            // synchronous copies: the temporaries are no longer needed once these return
            JnaUtils.ndArraySyncCopyFromNdArray(sparse, dataNd, -1);
            JnaUtils.ndArraySyncCopyFromNdArray(sparse, indptrNd, 0);
            JnaUtils.ndArraySyncCopyFromNdArray(sparse, indicesNd, 1);
        } catch (RuntimeException e) {
            // do not leak a half-initialized sparse array
            sparse.close();
            throw e;
        }
        return sparse;
    }
}
Also used : SparseFormat(ai.djl.ndarray.types.SparseFormat) Shape(ai.djl.ndarray.types.Shape) DataType(ai.djl.ndarray.types.DataType) Pointer(com.sun.jna.Pointer)

Example 4 with DataType

use of ai.djl.ndarray.types.DataType in project djl by deepjavalibrary.

Class `MxNDManager`, method `createRowSparse`:

/**
 * {@inheritDoc}
 *
 * <p>The values (laid out as {@code dataShape}) and the row indices are first written into
 * temporary dense {@code MxNDArray}s. They are then copied into the sparse array with
 * {@code JnaUtils.ndArraySyncCopyFromNdArray}: slot -1 receives the values and aux slot 0 the
 * indices. The temporaries are closed before this method returns instead of staying attached to
 * this manager. If a copy fails, the partially built sparse array is closed as well.
 */
@Override
public MxNDArray createRowSparse(Buffer data, Shape dataShape, long[] indices, Shape shape) {
    SparseFormat fmt = SparseFormat.ROW_SPARSE;
    DataType dataType = DataType.fromBuffer(data);
    try (MxNDArray indicesNd = create(new Shape(indices.length), DataType.INT64);
            MxNDArray dataNd = create(dataShape, dataType)) {
        indicesNd.set(indices);
        dataNd.set(data);
        Pointer handle = JnaUtils.createSparseNdArray(fmt, device, shape, dataType, new DataType[] { indicesNd.getDataType() }, new Shape[] { indicesNd.getShape() }, false);
        MxNDArray sparse = create(handle, fmt);
        try {
            // synchronous copies: the temporaries are no longer needed once these return
            JnaUtils.ndArraySyncCopyFromNdArray(sparse, dataNd, -1);
            JnaUtils.ndArraySyncCopyFromNdArray(sparse, indicesNd, 0);
        } catch (RuntimeException e) {
            // do not leak a half-initialized sparse array
            sparse.close();
            throw e;
        }
        return sparse;
    }
}
Also used : SparseFormat(ai.djl.ndarray.types.SparseFormat) Shape(ai.djl.ndarray.types.Shape) DataType(ai.djl.ndarray.types.DataType) Pointer(com.sun.jna.Pointer)

Example 5 with DataType

use of ai.djl.ndarray.types.DataType in project djl by deepjavalibrary.

Class `XgbNDManager`, method `create`:

/**
 * {@inheritDoc}
 *
 * <p>Only two-dimensional {@code FLOAT32} data is turned into an XGBoost DMatrix. Other input is
 * handled in one of three ways:
 *
 * <ul>
 *   <li>a {@link ByteBuffer} is wrapped as an output-only {@code XgbNDArray};
 *   <li>anything else is delegated to the alternative manager, when one is configured;
 *   <li>otherwise it is rejected.
 * </ul>
 *
 * <p>A direct {@code ByteBuffer} is handed to the DMatrix as-is. A {@link FloatBuffer} is first
 * copied into a freshly allocated direct buffer.
 *
 * @throws UnsupportedOperationException if the data cannot be represented by this manager
 * @throws IllegalArgumentException if a float32 buffer does not hold exactly
 *     {@code shape.size()} remaining elements
 */
@Override
public NDArray create(Buffer data, Shape shape, DataType dataType) {
    if (shape.dimension() != 2) {
        return createWithoutDMatrix(data, shape, dataType, "XgbNDArray shape must be in two dimension.");
    }
    if (dataType != DataType.FLOAT32) {
        return createWithoutDMatrix(data, shape, dataType, "XgbNDArray only supports float32.");
    }
    if (data.isDirect() && data instanceof ByteBuffer) {
        // TODO: allow user to set missing value
        long handle = JniUtils.createDMatrix(data, shape, missingValue);
        return new XgbNDArray(this, alternativeManager, handle, shape, SparseFormat.DENSE);
    }
    DataType inputType = DataType.fromBuffer(data);
    if (inputType != DataType.FLOAT32) {
        throw new UnsupportedOperationException("Only Float32 data type supported, actual " + inputType);
    }
    long numElements = shape.size();
    // Reject size mismatches up front. Otherwise a short buffer would silently leave zeros in
    // the DMatrix and a long one would fail with an opaque BufferOverflowException.
    if (data.remaining() != numElements) {
        throw new IllegalArgumentException("Expected " + numElements + " float32 elements for shape " + shape + ", but the buffer has " + data.remaining());
    }
    int size = Math.toIntExact(numElements * DataType.FLOAT32.getNumOfBytes());
    ByteBuffer buf = allocateDirect(size);
    buf.asFloatBuffer().put((FloatBuffer) data);
    buf.rewind();
    long handle = JniUtils.createDMatrix(buf, shape, missingValue);
    return new XgbNDArray(this, alternativeManager, handle, shape, SparseFormat.DENSE);
}

/**
 * Creates an array for data that cannot back a DMatrix.
 *
 * <p>A {@link ByteBuffer} is wrapped as an output-only {@code XgbNDArray}. Other buffers go to
 * the alternative manager when one is set; otherwise the call is rejected.
 *
 * @param data the input buffer
 * @param shape the requested shape
 * @param dataType the requested data type
 * @param message the error message used when no fallback is available
 * @return the created array
 * @throws UnsupportedOperationException if no fallback applies
 */
private NDArray createWithoutDMatrix(Buffer data, Shape shape, DataType dataType, String message) {
    if (data instanceof ByteBuffer) {
        // output only NDArray
        return new XgbNDArray(this, alternativeManager, (ByteBuffer) data, shape, dataType);
    }
    if (alternativeManager != null) {
        return alternativeManager.create(data, shape, dataType);
    }
    throw new UnsupportedOperationException(message);
}
Also used : DataType(ai.djl.ndarray.types.DataType) ByteBuffer(java.nio.ByteBuffer)

Aggregations

DataType (ai.djl.ndarray.types.DataType)17 Shape (ai.djl.ndarray.types.Shape)11 ByteBuffer (java.nio.ByteBuffer)8 Pointer (com.sun.jna.Pointer)4 SparseFormat (ai.djl.ndarray.types.SparseFormat)3 EngineException (ai.djl.engine.EngineException)2 NDList (ai.djl.ndarray.NDList)2 DataInputStream (java.io.DataInputStream)2 IntBuffer (java.nio.IntBuffer)2 Matcher (java.util.regex.Matcher)2 CommandLine (org.apache.commons.cli.CommandLine)2 DefaultParser (org.apache.commons.cli.DefaultParser)2 Options (org.apache.commons.cli.Options)2 Device (ai.djl.Device)1 NDArray (ai.djl.ndarray.NDArray)1 NDManager (ai.djl.ndarray.NDManager)1 NDIndex (ai.djl.ndarray.index.NDIndex)1 NDArrayEx (ai.djl.ndarray.internal.NDArrayEx)1 NDFormat (ai.djl.ndarray.internal.NDFormat)1 Float16Utils (ai.djl.util.Float16Utils)1