Search in sources :

Example 1 with TF_Status

Use of org.tensorflow.internal.c_api.TF_Status in project djl by deepjavalibrary.

From class JavacppUtils, method loadSavedModelBundle.

/**
 * Loads a TensorFlow SavedModel from {@code exportDir} and wraps the resulting graph, session
 * and MetaGraphDef in a {@link SavedModelBundle}.
 *
 * <p>Temporary native allocations are released when this method's {@link PointerScope} closes.
 * The graph handle carries one extra reference so that it outlives the scope and is owned by
 * the returned bundle. If loading or parsing fails, that extra reference is dropped and any
 * session that was created is closed and deleted, so no native resources leak.
 *
 * @param exportDir directory containing the SavedModel
 * @param tags tags selecting the MetaGraphDef to load
 * @param config optional session configuration; skipped when {@code null}
 * @param runOptions run options passed to {@code TF_LoadSessionFromSavedModel}
 * @return the loaded bundle
 * @throws TensorFlowException if the MetaGraphDef cannot be parsed; native errors are raised
 *     through {@code TF_Status.throwExceptionIfNotOK()}
 */
@SuppressWarnings({ "unchecked", "try" })
public static SavedModelBundle loadSavedModelBundle(String exportDir, String[] tags, ConfigProto config, RunOptions runOptions) {
    try (PointerScope ignored = new PointerScope()) {
        TF_Status status = TF_Status.newStatus();
        // allocate parameters for TF_LoadSessionFromSavedModel
        TF_SessionOptions opts = TF_SessionOptions.newSessionOptions();
        if (config != null) {
            BytePointer configBytes = new BytePointer(config.toByteArray());
            tensorflow.TF_SetConfig(opts, configBytes, configBytes.capacity(), status);
            status.throwExceptionIfNotOK();
        }
        TF_Buffer runOpts = TF_Buffer.newBufferFromString(runOptions);
        // load the session; the extra reference keeps the graph alive after the scope closes
        TF_Graph graphHandle = AbstractTF_Graph.newGraph().retainReference();
        TF_Session sessionHandle = null;
        boolean loaded = false;
        try {
            TF_Buffer metaGraphDef = TF_Buffer.newBuffer();
            sessionHandle = tensorflow.TF_LoadSessionFromSavedModel(opts, runOpts, new BytePointer(exportDir), new PointerPointer<>(tags), tags.length, graphHandle, metaGraphDef, status);
            status.throwExceptionIfNotOK();
            // handle the result
            SavedModelBundle bundle;
            try {
                bundle = new SavedModelBundle(graphHandle, sessionHandle, MetaGraphDef.parseFrom(metaGraphDef.dataAsByteBuffer()));
            } catch (InvalidProtocolBufferException e) {
                throw new TensorFlowException("Cannot parse MetaGraphDef protocol buffer", e);
            }
            loaded = true;
            return bundle;
        } finally {
            if (!loaded) {
                // Loading failed: tear down a session that was created (best effort, errors
                // from the cleanup calls are deliberately ignored so the original exception
                // propagates), then drop the extra graph reference so the scope frees it.
                if (sessionHandle != null && !sessionHandle.isNull()) {
                    TF_Status cleanupStatus = TF_Status.newStatus();
                    tensorflow.TF_CloseSession(sessionHandle, cleanupStatus);
                    tensorflow.TF_DeleteSession(sessionHandle, cleanupStatus);
                }
                graphHandle.releaseReference();
            }
        }
    }
}
Also used : TF_Graph(org.tensorflow.internal.c_api.TF_Graph) AbstractTF_Graph(org.tensorflow.internal.c_api.AbstractTF_Graph) TF_Session(org.tensorflow.internal.c_api.TF_Session) TF_Status(org.tensorflow.internal.c_api.TF_Status) BytePointer(org.bytedeco.javacpp.BytePointer) InvalidProtocolBufferException(com.google.protobuf.InvalidProtocolBufferException) TF_Buffer(org.tensorflow.internal.c_api.TF_Buffer) SavedModelBundle(ai.djl.tensorflow.engine.SavedModelBundle) PointerScope(org.bytedeco.javacpp.PointerScope) TensorFlowException(org.tensorflow.exceptions.TensorFlowException) TF_SessionOptions(org.tensorflow.internal.c_api.TF_SessionOptions)

Example 2 with TF_Status

Use of org.tensorflow.internal.c_api.TF_Status in project djl by deepjavalibrary.

From class JavacppUtils, method getDevice.

/**
 * Returns the DJL {@link Device} on which the given eager tensor handle resides.
 *
 * @param handle the eager tensor handle to inspect
 * @return the device obtained by passing TensorFlow's device name to {@code fromTfDevice}
 * @throws RuntimeException if TensorFlow reports an error while querying the device name
 *     (raised by {@code TF_Status.throwExceptionIfNotOK()})
 */
@SuppressWarnings({ "unchecked", "try" })
public static Device getDevice(TFE_TensorHandle handle) {
    try (PointerScope ignored = new PointerScope()) {
        TF_Status status = TF_Status.newStatus();
        BytePointer pointer = tensorflow.TFE_TensorHandleDeviceName(handle, status);
        // Check the status before touching the result: on failure the returned pointer may be
        // null, and TensorFlow's error message would otherwise be lost behind an NPE.
        status.throwExceptionIfNotOK();
        String device = new String(pointer.getStringBytes(), StandardCharsets.UTF_8);
        return fromTfDevice(device);
    }
}
Also used : TF_Status(org.tensorflow.internal.c_api.TF_Status) BytePointer(org.bytedeco.javacpp.BytePointer) TF_TString(org.tensorflow.internal.c_api.TF_TString) PointerScope(org.bytedeco.javacpp.PointerScope)

Example 3 with TF_Status

Use of org.tensorflow.internal.c_api.TF_Status in project djl by deepjavalibrary.

From class JavacppUtils, method createTFETensorFromByteBuffer.

/**
 * Copies the remaining bytes of {@code buf} into a freshly allocated native TensorFlow tensor
 * and wraps that tensor in an eager tensor handle.
 *
 * <p>The copy consumes {@code buf}: its position advances by the number of bytes copied.
 *
 * @param buf source bytes; for {@code DataType.STRING} the native buffer is sized as
 *     {@code buf.remaining() + 1}, otherwise as {@code shape.size() * dataType.getNumOfBytes()}
 * @param shape shape of the tensor to create
 * @param dataType DJL data type, converted with {@code TfDataType.toTf}
 * @param eagerSessionHandle eager context forwarded to {@code toDevice} for GPU placement
 * @param device target device; GPU devices go through {@code toDevice}
 * @return the eager tensor handle; on the non-GPU path an extra reference is retained so it
 *     survives this method's pointer scope
 */
@SuppressWarnings({ "unchecked", "try" })
public static TFE_TensorHandle createTFETensorFromByteBuffer(ByteBuffer buf, Shape shape, DataType dataType, TFE_Context eagerSessionHandle, Device device) {
    int tfType = TfDataType.toTf(dataType);
    long[] dimensions = shape.getShape();
    long byteCount = dataType == DataType.STRING
            ? buf.remaining() + 1
            : shape.size() * dataType.getNumOfBytes();
    try (PointerScope ignored = new PointerScope()) {
        TF_Tensor nativeTensor = AbstractTF_Tensor.allocateTensor(tfType, dimensions, byteCount);
        // View the tensor's native storage as a ByteBuffer and copy the input bytes into it.
        Pointer dataPointer = tensorflow.TF_TensorData(nativeTensor).capacity(byteCount);
        ByteBuffer nativeData = dataPointer.asByteBuffer();
        nativeData.put(buf);
        TF_Status tfStatus = TF_Status.newStatus();
        TFE_TensorHandle eagerHandle = AbstractTFE_TensorHandle.newTensor(nativeTensor, tfStatus);
        tfStatus.throwExceptionIfNotOK();
        if (!device.isGpu()) {
            return eagerHandle.retainReference();
        }
        return toDevice(eagerHandle, eagerSessionHandle, device);
    }
}
Also used : TF_Tensor(org.tensorflow.internal.c_api.TF_Tensor) AbstractTF_Tensor(org.tensorflow.internal.c_api.AbstractTF_Tensor) TF_Status(org.tensorflow.internal.c_api.TF_Status) TFE_TensorHandle(org.tensorflow.internal.c_api.TFE_TensorHandle) AbstractTFE_TensorHandle(org.tensorflow.internal.c_api.AbstractTFE_TensorHandle) BytePointer(org.bytedeco.javacpp.BytePointer) Pointer(org.bytedeco.javacpp.Pointer) PointerPointer(org.bytedeco.javacpp.PointerPointer) PointerScope(org.bytedeco.javacpp.PointerScope)

Example 4 with TF_Status

Use of org.tensorflow.internal.c_api.TF_Status in project djl by deepjavalibrary.

From class JavacppUtils, method getString.

/**
 * Reads {@code count} strings from an eager tensor handle whose data is an array of
 * {@code TF_TString}.
 *
 * @param handle eager tensor handle, resolved to a {@code TF_Tensor} for reading
 * @param count number of strings to read; NOTE(review): not validated against the tensor's
 *     actual element count, so callers must pass the correct value
 * @param charset charset used to decode each string's bytes
 * @return the decoded strings, in tensor order
 */
@SuppressWarnings({ "unchecked", "try" })
public static String[] getString(TFE_TensorHandle handle, int count, Charset charset) {
    try (PointerScope ignored = new PointerScope()) {
        // convert to TF_Tensor
        TF_Status status = TF_Status.newStatus();
        // should not add .withDeallocator() here, otherwise string data will be destroyed
        TF_Tensor tensor = tensorflow.TFE_TensorHandleResolve(handle, status);
        status.throwExceptionIfNotOK();
        // The tensor has no deallocator, so it is deleted manually; doing it in finally
        // guarantees it is released even if decoding throws.
        try {
            String[] ret = new String[count];
            if (count == 0) {
                // nothing to decode; avoid touching the data pointer of an empty tensor
                return ret;
            }
            long tensorSize = tensorflow.TF_TensorByteSize(tensor);
            Pointer pointer = tensorflow.TF_TensorData(tensor).capacity(tensorSize);
            TF_TString data = new TF_TString(pointer).capacity(pointer.position() + count);
            for (int i = 0; i < count; ++i) {
                TF_TString tstring = data.getPointer(i);
                long size = tensorflow.TF_TString_GetSize(tstring);
                BytePointer bp = tensorflow.TF_TString_GetDataPointer(tstring).capacity(size);
                ret[i] = bp.getString(charset);
            }
            return ret;
        } finally {
            tensorflow.TF_DeleteTensor(tensor);
        }
    }
}
Also used : TF_Status(org.tensorflow.internal.c_api.TF_Status) TF_Tensor(org.tensorflow.internal.c_api.TF_Tensor) AbstractTF_Tensor(org.tensorflow.internal.c_api.AbstractTF_Tensor) BytePointer(org.bytedeco.javacpp.BytePointer) Pointer(org.bytedeco.javacpp.Pointer) PointerPointer(org.bytedeco.javacpp.PointerPointer) TF_TString(org.tensorflow.internal.c_api.TF_TString) PointerScope(org.bytedeco.javacpp.PointerScope)

Example 5 with TF_Status

Use of org.tensorflow.internal.c_api.TF_Status in project djl by deepjavalibrary.

From class JavacppUtils, method getByteBuffer.

/**
 * Copies the contents of an eager tensor handle into a new heap {@link ByteBuffer}.
 *
 * <p>The bytes are copied so that the returned buffer stays valid after the native tensor
 * is deleted when this method's pointer scope closes.
 *
 * @param handle eager tensor handle to read
 * @return a buffer with position 0 and limit equal to the tensor's byte size, in native byte
 *     order; empty if the tensor has no data
 * @throws RuntimeException if TensorFlow fails to resolve the handle (raised by
 *     {@code TF_Status.throwExceptionIfNotOK()})
 */
@SuppressWarnings({ "unchecked", "try" })
public static ByteBuffer getByteBuffer(TFE_TensorHandle handle) {
    try (PointerScope ignored = new PointerScope()) {
        // convert to TF_Tensor
        TF_Status status = TF_Status.newStatus();
        TF_Tensor resolved = tensorflow.TFE_TensorHandleResolve(handle, status);
        // Check the status before using the result: on failure the resolved tensor may be
        // null, and calling withDeallocator() on it would hide TensorFlow's error behind an NPE.
        status.throwExceptionIfNotOK();
        TF_Tensor tensor = resolved.withDeallocator();
        long byteSize = tensorflow.TF_TensorByteSize(tensor);
        if (byteSize == 0) {
            // Nothing to copy; skip the data pointer, which may be null for an empty tensor.
            return ByteBuffer.allocate(0).order(ByteOrder.nativeOrder());
        }
        Pointer pointer = tensorflow.TF_TensorData(tensor).capacity(byteSize);
        ByteBuffer buf = pointer.asByteBuffer();
        // do the copy since we should make sure the returned ByteBuffer is still valid after
        // the tensor is deleted
        ByteBuffer ret = ByteBuffer.allocate(buf.capacity());
        ret.put(buf);
        ret.flip();
        return ret.order(ByteOrder.nativeOrder());
    }
}
Also used : TF_Status(org.tensorflow.internal.c_api.TF_Status) TF_Tensor(org.tensorflow.internal.c_api.TF_Tensor) AbstractTF_Tensor(org.tensorflow.internal.c_api.AbstractTF_Tensor) BytePointer(org.bytedeco.javacpp.BytePointer) Pointer(org.bytedeco.javacpp.Pointer) PointerPointer(org.bytedeco.javacpp.PointerPointer) PointerScope(org.bytedeco.javacpp.PointerScope) ByteBuffer(java.nio.ByteBuffer)

Aggregations

- PointerScope (org.bytedeco.javacpp.PointerScope): 18
- TF_Status (org.tensorflow.internal.c_api.TF_Status): 18
- BytePointer (org.bytedeco.javacpp.BytePointer): 9
- PointerPointer (org.bytedeco.javacpp.PointerPointer): 8
- AbstractTF_Tensor (org.tensorflow.internal.c_api.AbstractTF_Tensor): 8
- TF_Tensor (org.tensorflow.internal.c_api.TF_Tensor): 8
- TFE_TensorHandle (org.tensorflow.internal.c_api.TFE_TensorHandle): 7
- Pointer (org.bytedeco.javacpp.Pointer): 6
- AbstractTFE_TensorHandle (org.tensorflow.internal.c_api.AbstractTFE_TensorHandle): 5
- TF_TString (org.tensorflow.internal.c_api.TF_TString): 4
- EngineException (ai.djl.engine.EngineException): 2
- IntPointer (org.bytedeco.javacpp.IntPointer): 2
- TFE_Context (org.tensorflow.internal.c_api.TFE_Context): 2
- TF_Buffer (org.tensorflow.internal.c_api.TF_Buffer): 2
- Device (ai.djl.Device): 1
- NDArray (ai.djl.ndarray.NDArray): 1
- DataType (ai.djl.ndarray.types.DataType): 1
- Shape (ai.djl.ndarray.types.Shape): 1
- SavedModelBundle (ai.djl.tensorflow.engine.SavedModelBundle): 1
- Pair (ai.djl.util.Pair): 1