
Example 1 with SavedModelBundle

Use of ai.djl.tensorflow.engine.SavedModelBundle in project djl by deepjavalibrary.

From the class JavacppUtils, method loadSavedModelBundle.

@SuppressWarnings({ "unchecked", "try" })
public static SavedModelBundle loadSavedModelBundle(String exportDir, String[] tags, ConfigProto config, RunOptions runOptions) {
    try (PointerScope ignored = new PointerScope()) {
        TF_Status status = TF_Status.newStatus();
        // allocate parameters for TF_LoadSessionFromSavedModel
        TF_SessionOptions opts = TF_SessionOptions.newSessionOptions();
        if (config != null) {
            BytePointer configBytes = new BytePointer(config.toByteArray());
            tensorflow.TF_SetConfig(opts, configBytes, configBytes.capacity(), status);
            status.throwExceptionIfNotOK();
        }
        TF_Buffer runOpts = TF_Buffer.newBufferFromString(runOptions);
        // load the session
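        // retainReference() keeps the graph alive after the PointerScope closes,
        // because the returned SavedModelBundle keeps using it
        TF_Graph graphHandle = AbstractTF_Graph.newGraph().retainReference();
        TF_Buffer metaGraphDef = TF_Buffer.newBuffer();
        TF_Session sessionHandle = tensorflow.TF_LoadSessionFromSavedModel(
                opts,
                runOpts,
                new BytePointer(exportDir),
                new PointerPointer<>(tags),
                tags.length,
                graphHandle,
                metaGraphDef, // receives the serialized MetaGraphDef
                status);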
        status.throwExceptionIfNotOK();
        // handle the result
        try {
            return new SavedModelBundle(graphHandle, sessionHandle, MetaGraphDef.parseFrom(metaGraphDef.dataAsByteBuffer()));
        } catch (InvalidProtocolBufferException e) {
            throw new TensorFlowException("Cannot parse MetaGraphDef protocol buffer", e);
        }
    }
}
Also used: TF_Graph (org.tensorflow.internal.c_api.TF_Graph), AbstractTF_Graph (org.tensorflow.internal.c_api.AbstractTF_Graph), TF_Session (org.tensorflow.internal.c_api.TF_Session), TF_Status (org.tensorflow.internal.c_api.TF_Status), BytePointer (org.bytedeco.javacpp.BytePointer), InvalidProtocolBufferException (com.google.protobuf.InvalidProtocolBufferException), TF_Buffer (org.tensorflow.internal.c_api.TF_Buffer), SavedModelBundle (ai.djl.tensorflow.engine.SavedModelBundle), PointerScope (org.bytedeco.javacpp.PointerScope), TensorFlowException (org.tensorflow.exceptions.TensorFlowException), TF_SessionOptions (org.tensorflow.internal.c_api.TF_SessionOptions)
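
The method relies on a scoped-allocation pattern: everything allocated inside the `try (PointerScope ...)` block is released when the block exits, except the graph handle, which calls `retainReference()` so it can be handed to the returned `SavedModelBundle`. The sketch below models that idea with the standard library only. `Handle`, `Scope`, `Status`, and `load` are hypothetical stand-ins; this is not JavaCPP's implementation, only a simplified reference-counting model of how a scope and a retained reference interact.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;

/** Hypothetical stdlib-only model of the PointerScope + retainReference pattern (not JavaCPP). */
public class ScopedLoadSketch {

    /** Stand-in for a native handle guarded by a reference count. */
    static final class Handle {
        final String name;
        private int refCount;
        private boolean released;

        Handle(String name) {
            this.name = name;
        }

        Handle retainReference() {
            refCount++;
            return this;
        }

        void releaseReference() {
            if (refCount > 0 && --refCount == 0) {
                released = true; // native memory would be freed at this point
            }
        }

        boolean isReleased() {
            return released;
        }
    }

    /** Stand-in for PointerScope: holds one reference to every handle allocated inside it. */
    static final class Scope implements AutoCloseable {
        private final Deque<Handle> attached = new ArrayDeque<>();

        Handle allocate(String name, List<Handle> log) {
            Handle h = new Handle(name).retainReference(); // the scope's own reference
            attached.push(h);
            log.add(h);
            return h;
        }

        @Override
        public void close() {
            // Drop the scope's reference on everything it attached.
            while (!attached.isEmpty()) {
                attached.pop().releaseReference();
            }
        }
    }

    /** Stand-in for TF_Status: a null message means OK. */
    static final class Status {
        private final String message;

        Status(String message) {
            this.message = message;
        }

        void throwExceptionIfNotOK() {
            if (message != null) {
                throw new IllegalStateException(message);
            }
        }
    }

    /** Same shape as loadSavedModelBundle: allocate, check status, return the retained handle. */
    static Handle load(String errorOrNull, List<Handle> log) {
        try (Scope scope = new Scope()) {
            scope.allocate("sessionOptions", log); // released when the scope closes
            Handle graph = scope.allocate("graph", log).retainReference(); // outlives the scope
            new Status(errorOrNull).throwExceptionIfNotOK();
            return graph;
        }
    }

    public static void main(String[] args) {
        List<Handle> log = new ArrayList<>();
        load(null, log);
        for (Handle h : log) {
            System.out.println(h.name + " released=" + h.isReleased());
        }
    }
}
```

Running `main` prints `sessionOptions released=true` followed by `graph released=false`: the scope frees the temporary allocation, while the retained handle survives for the caller. In this model, a failed status check also leaves the retained handle with one outstanding reference, because the exception skips the point where ownership is handed over; how the real JavaCPP objects behave on that path is not shown here.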
