Usage of ai.djl.ndarray.types.DataType in the djl project by deepjavalibrary.
Class BaseNDManager, method copyBuffer.
/**
 * Copies the remaining elements of {@code src} into {@code target}, starting at index zero.
 *
 * <p>The element type is taken from {@link DataType#fromBuffer(Buffer)}. Single-byte types are
 * written into {@code target} directly. Wider types are written through a typed view of
 * {@code target}, and that view uses {@code target}'s byte order. {@code target} is rewound
 * before and after the copy, so it ends at position zero. {@code src} is advanced past the
 * copied elements.
 *
 * @param src the source {@code Buffer}
 * @param target the target {@code ByteBuffer}
 * @throws AssertionError if the data type reported for {@code src} has no copy path here
 */
public static void copyBuffer(Buffer src, ByteBuffer target) {
    target.rewind();
    final DataType srcType = DataType.fromBuffer(src);
    switch (srcType) {
        case BOOLEAN:
        case INT8:
        case UINT8: {
            // one byte per element: no view needed
            ByteBuffer bytes = (ByteBuffer) src;
            target.put(bytes);
            break;
        }
        case INT32: {
            IntBuffer intView = target.asIntBuffer();
            intView.put((IntBuffer) src);
            break;
        }
        case INT64: {
            LongBuffer longView = target.asLongBuffer();
            longView.put((LongBuffer) src);
            break;
        }
        case FLOAT16: {
            // half-precision values are carried as raw 16-bit shorts
            ShortBuffer halfView = target.asShortBuffer();
            halfView.put((ShortBuffer) src);
            break;
        }
        case FLOAT32: {
            FloatBuffer floatView = target.asFloatBuffer();
            floatView.put((FloatBuffer) src);
            break;
        }
        case FLOAT64: {
            DoubleBuffer doubleView = target.asDoubleBuffer();
            doubleView.put((DoubleBuffer) src);
            break;
        }
        default:
            throw new AssertionError("Unsupported datatype: " + srcType);
    }
    target.rewind();
}
Usage of ai.djl.ndarray.types.DataType in the djl project by deepjavalibrary.
Class StackBatchifier, method batchify.
/**
 * {@inheritDoc}
 *
 * <p>Each {@link NDList} in {@code inputs} is treated as one sample, and all samples must
 * contain the same number of arrays. For every index {@code i}, the arrays at index {@code i}
 * of all samples are combined with {@link NDArrays#stack}. Each stacked array takes the name of
 * the first sample's array at that index. An empty {@code inputs} array, or samples that
 * contain no arrays, produce an empty {@link NDList}.
 *
 * @throws IllegalArgumentException if the samples contain different numbers of arrays, or if
 *     stacking fails and the arrays at some index differ in shape or data type
 */
@Override
public NDList batchify(NDList[] inputs) {
    // An empty batch has nothing to stack. Return an empty list, as in the no-arrays case
    // below, instead of failing on inputs[0] with an unexplained index error.
    if (inputs.length == 0) {
        return new NDList();
    }
    // each input as NDList might contain several data or labels
    // so those should be batchified with counterpart
    int batchSize = inputs.length;
    int numInputKinds = inputs[0].size();
    // Check the array count up front. Relying on an index error during stacking only caught
    // samples with fewer arrays than inputs[0]; samples with more were silently truncated.
    for (NDList input : inputs) {
        if (input.size() != numInputKinds) {
            throw new IllegalArgumentException(
                    "You cannot batch data with different numbers of inputs");
        }
    }
    // if the NDList is empty
    if (numInputKinds == 0) {
        return new NDList();
    }
    try {
        // stack all the data and labels together
        NDList result = new NDList(numInputKinds);
        for (int i = 0; i < numInputKinds; i++) {
            NDList inputsOfKind = new NDList(batchSize);
            String inputName = inputs[0].get(i).getName();
            for (NDList input : inputs) {
                inputsOfKind.add(input.get(i));
            }
            // inputsOfKind is already a fresh list, so it can be passed without copying
            NDArray stacked = NDArrays.stack(inputsOfKind);
            // keep the name for stacked inputs
            stacked.setName(inputName);
            result.add(stacked);
        }
        return result;
    } catch (IndexOutOfBoundsException | EngineException e) {
        // Array counts were validated above, so look for a shape or data type mismatch
        for (int i = 0; i < numInputKinds; i++) {
            Shape kindDataShape = inputs[0].get(i).getShape();
            DataType kindDataType = inputs[0].get(i).getDataType();
            for (NDList input : inputs) {
                NDArray currInput = input.get(i);
                if (!currInput.getShape().equals(kindDataShape)) {
                    throw new IllegalArgumentException(
                            "You cannot batch data with different input shapes", e);
                }
                if (!currInput.getDataType().equals(kindDataType)) {
                    throw new IllegalArgumentException(
                            "You cannot batch data with different input data types", e);
                }
            }
        }
        // Could not clarify cause - rethrow original error.
        throw e;
    }
}
Usage of ai.djl.ndarray.types.DataType in the djl project by deepjavalibrary.
Class MxNDManager, method createCSR.
/**
 * {@inheritDoc}
 *
 * <p>The values, row pointers and column indices are first staged in temporary dense arrays.
 * They are then copied with blocking sync copies into the sparse array: index -1 targets the
 * values, aux index 0 the row pointers ({@code indptr}) and aux index 1 the column indices.
 * The temporaries are closed once the copies return, instead of holding native memory until
 * this manager is closed.
 */
@Override
public MxNDArray createCSR(Buffer data, long[] indptr, long[] indices, Shape shape) {
    SparseFormat fmt = SparseFormat.CSR;
    DataType dataType = DataType.fromBuffer(data);
    try (MxNDArray indptrNd = create(new Shape(indptr.length), DataType.INT64)) {
        indptrNd.set(indptr);
        try (MxNDArray indicesNd = create(new Shape(indices.length), DataType.INT64)) {
            indicesNd.set(indices);
            Pointer handle =
                    JnaUtils.createSparseNdArray(
                            fmt,
                            device,
                            shape,
                            dataType,
                            new DataType[] {indptrNd.getDataType(), indicesNd.getDataType()},
                            new Shape[] {indptrNd.getShape(), indicesNd.getShape()},
                            false);
            MxNDArray sparse = create(handle, fmt);
            try (MxNDArray dataNd = create(new Shape(data.remaining()), dataType)) {
                dataNd.set(data);
                // the copies block until complete, so the temporaries can be freed afterwards
                JnaUtils.ndArraySyncCopyFromNdArray(sparse, dataNd, -1);
                JnaUtils.ndArraySyncCopyFromNdArray(sparse, indptrNd, 0);
                JnaUtils.ndArraySyncCopyFromNdArray(sparse, indicesNd, 1);
            }
            return sparse;
        }
    }
}
Usage of ai.djl.ndarray.types.DataType in the djl project by deepjavalibrary.
Class MxNDManager, method createRowSparse.
/**
 * {@inheritDoc}
 *
 * <p>The values (laid out as {@code dataShape}) and the row indices are staged in temporary
 * dense arrays. They are then copied with blocking sync copies into the sparse array: index -1
 * targets the values and aux index 0 the row indices. The temporaries are closed once the
 * copies return, instead of holding native memory until this manager is closed.
 */
@Override
public MxNDArray createRowSparse(Buffer data, Shape dataShape, long[] indices, Shape shape) {
    SparseFormat fmt = SparseFormat.ROW_SPARSE;
    DataType dataType = DataType.fromBuffer(data);
    try (MxNDArray indicesNd = create(new Shape(indices.length), DataType.INT64)) {
        indicesNd.set(indices);
        Pointer handle =
                JnaUtils.createSparseNdArray(
                        fmt,
                        device,
                        shape,
                        dataType,
                        new DataType[] {indicesNd.getDataType()},
                        new Shape[] {indicesNd.getShape()},
                        false);
        MxNDArray sparse = create(handle, fmt);
        try (MxNDArray dataNd = create(dataShape, dataType)) {
            dataNd.set(data);
            // the copies block until complete, so the temporaries can be freed afterwards
            JnaUtils.ndArraySyncCopyFromNdArray(sparse, dataNd, -1);
            JnaUtils.ndArraySyncCopyFromNdArray(sparse, indicesNd, 0);
        }
        return sparse;
    }
}
Usage of ai.djl.ndarray.types.DataType in the djl project by deepjavalibrary.
Class XgbNDManager, method create.
/**
 * {@inheritDoc}
 *
 * <p>Only two-dimensional {@code float32} data is turned into a DMatrix. For any other shape or
 * data type, a {@link ByteBuffer} is wrapped as an output-only {@link XgbNDArray}. Any other
 * buffer is delegated to the alternative manager when one is set.
 *
 * <p>A direct {@link ByteBuffer} is passed to the DMatrix as-is. A heap {@link ByteBuffer}
 * (raw float32 bytes) or a {@link FloatBuffer} is first copied into a newly allocated direct
 * buffer.
 *
 * @throws UnsupportedOperationException if the shape is not two-dimensional or the data type is
 *     not {@code float32} and no fallback applies, or if {@code data} is neither a {@link
 *     ByteBuffer} nor a float32 buffer
 */
@Override
public NDArray create(Buffer data, Shape shape, DataType dataType) {
    if (shape.dimension() != 2) {
        return createNonDMatrix(
                data, shape, dataType, "XgbNDArray shape must be in two dimension.");
    }
    if (dataType != DataType.FLOAT32) {
        return createNonDMatrix(data, shape, dataType, "XgbNDArray only supports float32.");
    }
    ByteBuffer buf;
    if (data instanceof ByteBuffer) {
        if (data.isDirect()) {
            // TODO: allow user to set missing value
            long handle = JniUtils.createDMatrix(data, shape, missingValue);
            return new XgbNDArray(this, alternativeManager, handle, shape, SparseFormat.DENSE);
        }
        // A heap ByteBuffer holds the same raw float32 bytes a direct one would; previously it
        // was rejected as "int8". Copy it into a direct buffer like the FloatBuffer path.
        int size = Math.toIntExact(shape.size() * DataType.FLOAT32.getNumOfBytes());
        buf = allocateDirect(size);
        buf.put((ByteBuffer) data);
    } else {
        DataType inputType = DataType.fromBuffer(data);
        if (inputType != DataType.FLOAT32) {
            throw new UnsupportedOperationException(
                    "Only Float32 data type supported, actual " + inputType);
        }
        int size = Math.toIntExact(shape.size() * DataType.FLOAT32.getNumOfBytes());
        buf = allocateDirect(size);
        buf.asFloatBuffer().put((FloatBuffer) data);
    }
    buf.rewind();
    long handle = JniUtils.createDMatrix(buf, shape, missingValue);
    return new XgbNDArray(this, alternativeManager, handle, shape, SparseFormat.DENSE);
}

/**
 * Creates an array for data that cannot become a DMatrix.
 *
 * <p>A {@link ByteBuffer} is wrapped as an output-only {@link XgbNDArray}. Otherwise the
 * request is delegated to the alternative manager, if one is set.
 *
 * @param data the input buffer
 * @param shape the requested shape
 * @param dataType the requested data type
 * @param reason the message used when no fallback applies
 * @return the fallback array
 * @throws UnsupportedOperationException with {@code reason} if no fallback applies
 */
private NDArray createNonDMatrix(Buffer data, Shape shape, DataType dataType, String reason) {
    if (data instanceof ByteBuffer) {
        // output only NDArray
        return new XgbNDArray(this, alternativeManager, (ByteBuffer) data, shape, dataType);
    }
    if (alternativeManager != null) {
        return alternativeManager.create(data, shape, dataType);
    }
    throw new UnsupportedOperationException(reason);
}
Aggregations