
Example 1 with TF_TString

Use of org.tensorflow.internal.c_api.TF_TString in project djl by deepjavalibrary.

Class JavacppUtils, method getString.

@SuppressWarnings({ "unchecked", "try" })
public static String[] getString(TFE_TensorHandle handle, int count, Charset charset) {
    try (PointerScope ignored = new PointerScope()) {
        // convert to TF_Tensor
        TF_Status status = TF_Status.newStatus();
        // do not add .withDeallocator() here, otherwise the string data will be destroyed
        TF_Tensor tensor = tensorflow.TFE_TensorHandleResolve(handle, status);
        status.throwExceptionIfNotOK();
        long tensorSize = tensorflow.TF_TensorByteSize(tensor);
        Pointer pointer = tensorflow.TF_TensorData(tensor).capacity(tensorSize);
        TF_TString data = new TF_TString(pointer).capacity(pointer.position() + count);
        String[] ret = new String[count];
        for (int i = 0; i < count; ++i) {
            TF_TString tstring = data.getPointer(i);
            long size = tensorflow.TF_TString_GetSize(tstring);
            BytePointer bp = tensorflow.TF_TString_GetDataPointer(tstring).capacity(size);
            ret[i] = bp.getString(charset);
        }
        // manually delete tensor
        tensorflow.TF_DeleteTensor(tensor);
        return ret;
    }
}
Also used: TF_Status (org.tensorflow.internal.c_api.TF_Status), TF_Tensor (org.tensorflow.internal.c_api.TF_Tensor), AbstractTF_Tensor (org.tensorflow.internal.c_api.AbstractTF_Tensor), BytePointer (org.bytedeco.javacpp.BytePointer), Pointer (org.bytedeco.javacpp.Pointer), PointerPointer (org.bytedeco.javacpp.PointerPointer), TF_TString (org.tensorflow.internal.c_api.TF_TString), PointerScope (org.bytedeco.javacpp.PointerScope)
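
The loop in getString reads `count` elements and decodes each one with the supplied charset. The following is a minimal pure-Java sketch of those two ideas, with plain byte arrays standing in for the native TF_TString buffers. The class `StringDecodeSketch` and its methods are hypothetical and not part of djl, and deriving `count` as the product of the shape dimensions is an assumption about how a caller would obtain it, not something the snippet above confirms.

```java
import java.nio.charset.Charset;

// Hypothetical helper for illustration only; not part of djl or TensorFlow.
final class StringDecodeSketch {

    private StringDecodeSketch() {}

    /**
     * Assumption: the element count is the product of the shape dimensions.
     * An empty shape (a scalar) therefore yields one element.
     */
    static int elementCount(long[] shape) {
        long count = 1;
        for (long dim : shape) {
            // multiplyExact throws ArithmeticException instead of overflowing silently
            count = Math.multiplyExact(count, dim);
        }
        // toIntExact throws if the count does not fit into an int
        return Math.toIntExact(count);
    }

    /** Decodes each raw byte array with the given charset, one string per element. */
    static String[] decodeAll(byte[][] raw, Charset charset) {
        String[] ret = new String[raw.length];
        for (int i = 0; i < raw.length; ++i) {
            ret[i] = new String(raw[i], charset);
        }
        return ret;
    }
}
```

A caller could pass `elementCount(shape)` as the `count` argument; the overflow checks matter because the method sizes a Java array from that value.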

Example 2 with TF_TString

Use of org.tensorflow.internal.c_api.TF_TString in project djl by deepjavalibrary.

Class JavacppUtils, method createStringTensor.

@SuppressWarnings({ "unchecked", "try" })
public static Pair<TF_Tensor, TFE_TensorHandle> createStringTensor(long[] dims, ByteBuffer[] src) {
    int dType = TfDataType.toTf(DataType.STRING);
    long numBytes = (long) Loader.sizeof(TF_TString.class) * src.length;
    try (PointerScope ignored = new PointerScope()) {
        /*
         * A string tensor allocates separate TF_TString memory. The TF_TString will
         * be deleted when the string tensor is closed. We have to track TF_TString
         * memory ourselves and make sure the TF_TString lifecycle aligns with
         * TFE_TensorHandle. TF_Tensor already handles TF_TString automatically, so we
         * can just keep a TF_Tensor reference in TfNDArray.
         */
        TF_Tensor tensor = AbstractTF_Tensor.allocateTensor(dType, dims, numBytes);
        Pointer pointer = tensorflow.TF_TensorData(tensor).capacity(numBytes);
        TF_TString data = new TF_TString(pointer).capacity(pointer.position() + src.length);
        for (int i = 0; i < src.length; ++i) {
            TF_TString tstring = data.getPointer(i);
            tensorflow.TF_TString_Copy(tstring, new BytePointer(src[i]), src[i].remaining());
        }
        TF_Status status = TF_Status.newStatus();
        TFE_TensorHandle handle = AbstractTFE_TensorHandle.newTensor(tensor, status);
        status.throwExceptionIfNotOK();
        handle.retainReference();
        tensor.retainReference();
        return new Pair<>(tensor, handle);
    }
}
Also used: TF_Tensor (org.tensorflow.internal.c_api.TF_Tensor), AbstractTF_Tensor (org.tensorflow.internal.c_api.AbstractTF_Tensor), TF_Status (org.tensorflow.internal.c_api.TF_Status), BytePointer (org.bytedeco.javacpp.BytePointer), TFE_TensorHandle (org.tensorflow.internal.c_api.TFE_TensorHandle), AbstractTFE_TensorHandle (org.tensorflow.internal.c_api.AbstractTFE_TensorHandle), Pointer (org.bytedeco.javacpp.Pointer), PointerPointer (org.bytedeco.javacpp.PointerPointer), TF_TString (org.tensorflow.internal.c_api.TF_TString), PointerScope (org.bytedeco.javacpp.PointerScope), Pair (ai.djl.util.Pair)
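
createStringTensor copies `src[i].remaining()` bytes from each buffer, so the caller decides the encoding and the payload length before the call. The sketch below shows one way the `ByteBuffer[] src` argument might be prepared from Java strings, plus a round-trip decode for checking. The class `StringBufferSketch` and its methods are hypothetical and not part of djl; only `java.nio` is used.

```java
import java.nio.ByteBuffer;
import java.nio.charset.Charset;

// Hypothetical helper for illustration only; not part of djl or TensorFlow.
final class StringBufferSketch {

    private StringBufferSketch() {}

    /** Encodes each string into its own buffer, ready to be read from position 0. */
    static ByteBuffer[] encode(String[] values, Charset charset) {
        ByteBuffer[] src = new ByteBuffer[values.length];
        for (int i = 0; i < values.length; ++i) {
            // wrap() sets position 0 and limit = array length,
            // so remaining() equals the encoded byte length of the string
            src[i] = ByteBuffer.wrap(values[i].getBytes(charset));
        }
        return src;
    }

    /** Decodes the remaining bytes of each buffer without moving its position. */
    static String[] decode(ByteBuffer[] src, Charset charset) {
        String[] ret = new String[src.length];
        for (int i = 0; i < src.length; ++i) {
            // duplicate() shares the bytes but has an independent position and limit
            ret[i] = charset.decode(src[i].duplicate()).toString();
        }
        return ret;
    }
}
```

Note that `remaining()` counts bytes, not characters, so a two-character string with one accented letter occupies three bytes in UTF-8.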

Aggregations

BytePointer (org.bytedeco.javacpp.BytePointer): 2
Pointer (org.bytedeco.javacpp.Pointer): 2
PointerPointer (org.bytedeco.javacpp.PointerPointer): 2
PointerScope (org.bytedeco.javacpp.PointerScope): 2
AbstractTF_Tensor (org.tensorflow.internal.c_api.AbstractTF_Tensor): 2
TF_Status (org.tensorflow.internal.c_api.TF_Status): 2
TF_TString (org.tensorflow.internal.c_api.TF_TString): 2
TF_Tensor (org.tensorflow.internal.c_api.TF_Tensor): 2
Pair (ai.djl.util.Pair): 1
AbstractTFE_TensorHandle (org.tensorflow.internal.c_api.AbstractTFE_TensorHandle): 1
TFE_TensorHandle (org.tensorflow.internal.c_api.TFE_TensorHandle): 1