Use of org.tensorflow.internal.c_api.TF_TString in project djl by deepjavalibrary.
In class JavacppUtils, method getString.
/**
 * Reads the elements of a TensorFlow string tensor as Java strings.
 *
 * <p>The eager handle is resolved into a {@code TF_Tensor}. Each {@code TF_TString} element is
 * decoded with {@code charset}, and the resolved tensor is deleted afterwards, including when
 * decoding fails.
 *
 * <p>NOTE(review): {@code count} is not checked against the tensor's byte size here; callers are
 * assumed to pass the tensor's element count.
 *
 * @param handle the eager tensor handle to read
 * @param count the number of string elements to read
 * @param charset the charset used to decode each element's bytes
 * @return an array of {@code count} decoded strings
 */
@SuppressWarnings({ "unchecked", "try" })
public static String[] getString(TFE_TensorHandle handle, int count, Charset charset) {
    try (PointerScope ignored = new PointerScope()) {
        // convert to TF_Tensor
        TF_Status status = TF_Status.newStatus();
        // should not add .withDeallocator() here, otherwise string data will be destroyed
        TF_Tensor tensor = tensorflow.TFE_TensorHandleResolve(handle, status);
        status.throwExceptionIfNotOK();
        try {
            long tensorSize = tensorflow.TF_TensorByteSize(tensor);
            Pointer pointer = tensorflow.TF_TensorData(tensor).capacity(tensorSize);
            TF_TString data = new TF_TString(pointer).capacity(pointer.position() + count);
            String[] ret = new String[count];
            for (int i = 0; i < count; ++i) {
                TF_TString tstring = data.getPointer(i);
                long size = tensorflow.TF_TString_GetSize(tstring);
                BytePointer bp = tensorflow.TF_TString_GetDataPointer(tstring).capacity(size);
                ret[i] = bp.getString(charset);
            }
            return ret;
        } finally {
            // No deallocator is attached to the resolved tensor, so it must be deleted manually;
            // doing it in finally avoids leaking it if decoding throws.
            tensorflow.TF_DeleteTensor(tensor);
        }
    }
}
Use of org.tensorflow.internal.c_api.TF_TString in project djl by deepjavalibrary.
In class JavacppUtils, method createStringTensor.
/**
 * Allocates a TensorFlow string tensor filled from the given byte buffers and wraps it in an
 * eager tensor handle.
 *
 * @param dims the shape of the tensor to allocate
 * @param src one buffer per string element; each element receives the buffer's remaining bytes
 * @return the backing {@code TF_Tensor} paired with its {@code TFE_TensorHandle}, both retained
 *     before returning
 */
@SuppressWarnings({ "unchecked", "try" })
public static Pair<TF_Tensor, TFE_TensorHandle> createStringTensor(long[] dims, ByteBuffer[] src) {
    int stringType = TfDataType.toTf(DataType.STRING);
    int elementCount = src.length;
    long byteCount = (long) Loader.sizeof(TF_TString.class) * elementCount;
    try (PointerScope ignored = new PointerScope()) {
        /*
         * A string tensor owns separate TF_TString memory, which is deleted when the string
         * tensor is closed. That memory has to be tracked here so that the TF_TString lifecycle
         * matches the TFE_TensorHandle's. TF_Tensor already manages its TF_TString elements, so
         * keeping a TF_Tensor reference in TfNDArray is sufficient.
         */
        TF_Tensor stringTensor = AbstractTF_Tensor.allocateTensor(stringType, dims, byteCount);
        Pointer base = tensorflow.TF_TensorData(stringTensor).capacity(byteCount);
        TF_TString elements = new TF_TString(base).capacity(base.position() + elementCount);
        for (int idx = 0; idx < elementCount; idx++) {
            ByteBuffer buffer = src[idx];
            TF_TString slot = elements.getPointer(idx);
            tensorflow.TF_TString_Copy(slot, new BytePointer(buffer), buffer.remaining());
        }

        TF_Status status = TF_Status.newStatus();
        TFE_TensorHandle tensorHandle = AbstractTFE_TensorHandle.newTensor(stringTensor, status);
        status.throwExceptionIfNotOK();

        // Retain both pointers so they stay valid for the caller after this PointerScope closes.
        tensorHandle.retainReference();
        stringTensor.retainReference();
        return new Pair<>(stringTensor, tensorHandle);
    }
}
Aggregations