Usage of `org.tensorflow.internal.c_api.TF_Graph` in the DJL project (deepjavalibrary):
class `JavacppUtils`, method `loadSavedModelBundle`.
/**
 * Loads a TensorFlow SavedModel from {@code exportDir} and wraps the resulting graph, session
 * and {@code MetaGraphDef} in a {@link SavedModelBundle}.
 *
 * <p>Temporary native allocations (status, session options, buffers, strings) are created inside
 * a {@link PointerScope} and released when this method returns. The graph handle is retained so
 * that it outlives the scope and can be handed to the returned bundle. If anything fails before
 * the bundle is constructed, that extra graph reference is released again. If the session had
 * already been created, it is closed and deleted. This keeps a failed load from leaking native
 * memory.
 *
 * @param exportDir directory containing the SavedModel
 * @param tags tags selecting the MetaGraphDef to load; must not be {@code null}
 * @param config optional session configuration; not applied when {@code null}
 * @param runOptions run options forwarded to {@code TF_LoadSessionFromSavedModel}
 * @return a bundle holding the retained graph, the loaded session and the parsed MetaGraphDef
 * @throws TensorFlowException if the MetaGraphDef produced by TensorFlow cannot be parsed;
 *     errors reported through the TensorFlow status are raised by
 *     {@code status.throwExceptionIfNotOK()}
 */
@SuppressWarnings({"unchecked", "try"})
public static SavedModelBundle loadSavedModelBundle(
        String exportDir, String[] tags, ConfigProto config, RunOptions runOptions) {
    try (PointerScope ignored = new PointerScope()) {
        TF_Status status = TF_Status.newStatus();

        // allocate parameters for TF_LoadSessionFromSavedModel
        TF_SessionOptions opts = TF_SessionOptions.newSessionOptions();
        if (config != null) {
            BytePointer configBytes = new BytePointer(config.toByteArray());
            tensorflow.TF_SetConfig(opts, configBytes, configBytes.capacity(), status);
            status.throwExceptionIfNotOK();
        }
        TF_Buffer runOpts = TF_Buffer.newBufferFromString(runOptions);

        // The extra reference keeps the graph alive after the scope closes, so that the
        // returned bundle can own it.
        TF_Graph graphHandle = AbstractTF_Graph.newGraph().retainReference();
        boolean handedOff = false;
        try {
            // load the session
            TF_Buffer metaGraphDef = TF_Buffer.newBuffer();
            TF_Session sessionHandle =
                    tensorflow.TF_LoadSessionFromSavedModel(
                            opts,
                            runOpts,
                            new BytePointer(exportDir),
                            new PointerPointer<>(tags),
                            tags.length,
                            graphHandle,
                            metaGraphDef,
                            status);
            status.throwExceptionIfNotOK();

            // handle the result
            MetaGraphDef metaGraph;
            try {
                metaGraph = MetaGraphDef.parseFrom(metaGraphDef.dataAsByteBuffer());
            } catch (InvalidProtocolBufferException e) {
                // Unlike the graph, the session handle is passed to the bundle without
                // retainReference(), so the scope is not what keeps it alive or frees it.
                // Free it explicitly here, before the graph it uses is released. Any error
                // recorded by TF_CloseSession is ignored; TF_DeleteSession discards the
                // session's resources regardless.
                if (sessionHandle != null) {
                    tensorflow.TF_CloseSession(sessionHandle, status);
                    tensorflow.TF_DeleteSession(sessionHandle, status);
                }
                throw new TensorFlowException("Cannot parse MetaGraphDef protocol buffer", e);
            }
            SavedModelBundle bundle = new SavedModelBundle(graphHandle, sessionHandle, metaGraph);
            handedOff = true;
            return bundle;
        } finally {
            if (!handedOff) {
                // Drop the reference taken by retainReference() above, so that closing the
                // scope frees the graph instead of leaking it on a failed load.
                graphHandle.releaseReference();
            }
        }
    }
}
Aggregations