Use of org.tensorflow.internal.c_api.TF_Status in project djl by deepjavalibrary.
Class JavacppUtils, method loadSavedModelBundle.
/**
 * Loads a TensorFlow SavedModel into a new native graph and session.
 *
 * <p>All temporary native allocations are tied to a {@link PointerScope}. The graph only receives
 * an extra reference (so that it outlives the scope) once loading and MetaGraphDef parsing have
 * both succeeded; on any failure path the scope releases it instead of leaking it.
 *
 * @param exportDir the directory of the exported SavedModel
 * @param tags the tags identifying the MetaGraphDef to load; must not be {@code null}
 * @param config optional session configuration; skipped when {@code null}
 * @param runOptions run options converted with {@code TF_Buffer.newBufferFromString}
 * @return a bundle holding the loaded graph, session and parsed {@code MetaGraphDef}
 * @throws TensorFlowException if the MetaGraphDef produced by the loader cannot be parsed;
 *     configuration and loading errors are raised by {@code TF_Status.throwExceptionIfNotOK()}
 */
@SuppressWarnings({"unchecked", "try"})
public static SavedModelBundle loadSavedModelBundle(
        String exportDir, String[] tags, ConfigProto config, RunOptions runOptions) {
    try (PointerScope ignored = new PointerScope()) {
        TF_Status status = TF_Status.newStatus();

        // allocate parameters for TF_LoadSessionFromSavedModel
        TF_SessionOptions opts = TF_SessionOptions.newSessionOptions();
        if (config != null) {
            BytePointer configBytes = new BytePointer(config.toByteArray());
            tensorflow.TF_SetConfig(opts, configBytes, configBytes.capacity(), status);
            status.throwExceptionIfNotOK();
        }
        TF_Buffer runOpts = TF_Buffer.newBufferFromString(runOptions);

        // load the session; the graph is owned only by this scope until everything succeeds
        TF_Graph graphHandle = AbstractTF_Graph.newGraph();
        TF_Buffer metaGraphDef = TF_Buffer.newBuffer();
        TF_Session sessionHandle =
                tensorflow.TF_LoadSessionFromSavedModel(
                        opts,
                        runOpts,
                        new BytePointer(exportDir),
                        new PointerPointer<>(tags),
                        tags.length,
                        graphHandle,
                        metaGraphDef,
                        status);
        status.throwExceptionIfNotOK();

        // handle the result
        MetaGraphDef metaGraph;
        try {
            metaGraph = MetaGraphDef.parseFrom(metaGraphDef.dataAsByteBuffer());
        } catch (InvalidProtocolBufferException e) {
            // the session was created successfully; free it instead of leaking it
            tensorflow.TF_CloseSession(sessionHandle, status);
            tensorflow.TF_DeleteSession(sessionHandle, status);
            throw new TensorFlowException("Cannot parse MetaGraphDef protocol buffer", e);
        }
        // only now take the extra reference that keeps the graph alive after the scope closes
        TF_Graph retainedGraph = graphHandle.retainReference();
        return new SavedModelBundle(retainedGraph, sessionHandle, metaGraph);
    }
}
Use of org.tensorflow.internal.c_api.TF_Status in project djl by deepjavalibrary.
Class JavacppUtils, method getDevice.
/**
 * Returns the DJL {@link Device} on which an eager tensor handle resides.
 *
 * <p>The TensorFlow device name reported by {@code TFE_TensorHandleDeviceName} is decoded as
 * UTF-8 and mapped with {@code fromTfDevice}.
 *
 * @param handle the eager tensor handle to inspect
 * @return the corresponding DJL device
 */
@SuppressWarnings({"unchecked", "try"})
public static Device getDevice(TFE_TensorHandle handle) {
    try (PointerScope ignored = new PointerScope()) {
        TF_Status status = TF_Status.newStatus();
        BytePointer pointer = tensorflow.TFE_TensorHandleDeviceName(handle, status);
        // check before dereferencing: on failure the returned name pointer may be null, and the
        // real TensorFlow error would otherwise be lost behind a NullPointerException
        status.throwExceptionIfNotOK();
        String device = new String(pointer.getStringBytes(), StandardCharsets.UTF_8);
        return fromTfDevice(device);
    }
}
Use of org.tensorflow.internal.c_api.TF_Status in project djl by deepjavalibrary.
Class JavacppUtils, method createTFETensorFromByteBuffer.
/**
 * Creates an eager tensor handle whose native storage is filled from a {@link ByteBuffer}.
 *
 * <p>The bytes from {@code buf}'s position to its limit are copied into a freshly allocated
 * native tensor, which advances {@code buf}'s position. For {@link DataType#STRING} the native
 * allocation is one byte larger than the remaining data; for every other type it is
 * {@code shape.size() * dataType.getNumOfBytes()} bytes.
 *
 * @param buf the source data
 * @param shape the tensor shape
 * @param dataType the DJL data type of the tensor
 * @param eagerSessionHandle the eager context used when moving the tensor to a GPU
 * @param device the target device; a GPU device triggers a {@code toDevice} copy
 * @return the new eager tensor handle, retained beyond the internal pointer scope on CPU
 */
@SuppressWarnings({"unchecked", "try"})
public static TFE_TensorHandle createTFETensorFromByteBuffer(
        ByteBuffer buf,
        Shape shape,
        DataType dataType,
        TFE_Context eagerSessionHandle,
        Device device) {
    int tfType = TfDataType.toTf(dataType);
    long[] tfDims = shape.getShape();
    long byteCount =
            dataType == DataType.STRING
                    ? buf.remaining() + 1
                    : shape.size() * dataType.getNumOfBytes();

    try (PointerScope ignored = new PointerScope()) {
        TF_Tensor nativeTensor = AbstractTF_Tensor.allocateTensor(tfType, tfDims, byteCount);
        // expose the tensor's native storage as a ByteBuffer and copy the input into it
        Pointer storage = tensorflow.TF_TensorData(nativeTensor).capacity(byteCount);
        ByteBuffer target = storage.asByteBuffer();
        target.put(buf);

        TF_Status status = TF_Status.newStatus();
        TFE_TensorHandle cpuHandle = AbstractTFE_TensorHandle.newTensor(nativeTensor, status);
        status.throwExceptionIfNotOK();

        if (!device.isGpu()) {
            // keep the handle alive after the pointer scope closes
            return cpuHandle.retainReference();
        }
        // GPU: the CPU-side handle is not retained here, so it is presumably released by the scope
        return toDevice(cpuHandle, eagerSessionHandle, device);
    }
}
Use of org.tensorflow.internal.c_api.TF_Status in project djl by deepjavalibrary.
Class JavacppUtils, method getString.
/**
 * Decodes the first {@code count} {@code TF_TString} elements of an eager tensor into Java
 * strings.
 *
 * <p>The handle is resolved to a {@code TF_Tensor} without a deallocator. That tensor is deleted
 * manually in a {@code finally} block, so it is freed even if decoding fails.
 *
 * @param handle the eager tensor handle to read; presumably a STRING tensor (TODO confirm at
 *     call sites)
 * @param count the number of string elements to decode
 * @param charset the charset used to decode each element's bytes
 * @return the decoded strings, in element order
 */
@SuppressWarnings({"unchecked", "try"})
public static String[] getString(TFE_TensorHandle handle, int count, Charset charset) {
    try (PointerScope ignored = new PointerScope()) {
        // convert to TF_Tensor
        TF_Status status = TF_Status.newStatus();
        // should not add .withDeallocator() here, otherwise string data will be destroyed
        TF_Tensor tensor = tensorflow.TFE_TensorHandleResolve(handle, status);
        status.throwExceptionIfNotOK();
        try {
            long tensorSize = tensorflow.TF_TensorByteSize(tensor);
            Pointer pointer = tensorflow.TF_TensorData(tensor).capacity(tensorSize);
            TF_TString data = new TF_TString(pointer).capacity(pointer.position() + count);
            String[] ret = new String[count];
            for (int i = 0; i < count; ++i) {
                TF_TString tstring = data.getPointer(i);
                long size = tensorflow.TF_TString_GetSize(tstring);
                BytePointer bp = tensorflow.TF_TString_GetDataPointer(tstring).capacity(size);
                ret[i] = bp.getString(charset);
            }
            return ret;
        } finally {
            // the resolved tensor has no deallocator: delete it manually on every exit path
            tensorflow.TF_DeleteTensor(tensor);
        }
    }
}
Use of org.tensorflow.internal.c_api.TF_Status in project djl by deepjavalibrary.
Class JavacppUtils, method getByteBuffer.
/**
 * Copies the raw contents of an eager tensor into a heap {@link ByteBuffer}.
 *
 * <p>The native tensor is released when the internal pointer scope closes. The data is therefore
 * copied into a heap buffer that stays valid afterwards.
 *
 * @param handle the eager tensor handle to read
 * @return a heap buffer containing the tensor bytes, flipped for reading and set to native byte
 *     order
 */
@SuppressWarnings({"unchecked", "try"})
public static ByteBuffer getByteBuffer(TFE_TensorHandle handle) {
    try (PointerScope ignored = new PointerScope()) {
        // convert to TF_Tensor
        TF_Status status = TF_Status.newStatus();
        TF_Tensor resolved = tensorflow.TFE_TensorHandleResolve(handle, status);
        // check before attaching the deallocator: on failure the resolved tensor may be null and
        // calling withDeallocator() on it would hide the TensorFlow error behind an NPE
        status.throwExceptionIfNotOK();
        TF_Tensor tensor = resolved.withDeallocator();
        Pointer pointer =
                tensorflow.TF_TensorData(tensor).capacity(tensorflow.TF_TensorByteSize(tensor));
        ByteBuffer buf = pointer.asByteBuffer();
        // do the copy since we should make sure the returned ByteBuffer is still valid after
        // the tensor is deleted
        ByteBuffer ret = ByteBuffer.allocate(buf.capacity());
        buf.rewind();
        ret.put(buf);
        ret.flip();
        return ret.order(ByteOrder.nativeOrder());
    }
}
Aggregations