Search in sources:

Example 1 with TF_Buffer

Use of org.tensorflow.internal.c_api.TF_Buffer in the project djl by deepjavalibrary.

Method loadSavedModelBundle of the class JavacppUtils.

/**
 * Loads a TensorFlow SavedModel from {@code exportDir} into a new graph and session.
 *
 * <p>Temporary native allocations (status, session options, buffers, the export-dir and tag
 * pointers) are created inside a {@link PointerScope} and released when this method returns.
 * The graph handle receives an extra reference so that it outlives that scope and can be handed
 * to the returned {@link SavedModelBundle}.
 *
 * <p>If loading or parsing fails before the bundle is constructed, the extra graph reference is
 * released. If a session had already been created, it is closed and deleted, so the failed call
 * does not leak native resources.
 *
 * @param exportDir directory containing the SavedModel
 * @param tags tags selecting the MetaGraphDef to load
 * @param config optional session configuration; skipped when {@code null}
 * @param runOptions run options forwarded to {@code TF_LoadSessionFromSavedModel}
 * @return a bundle owning the loaded graph, session and parsed MetaGraphDef
 * @throws TensorFlowException if the returned MetaGraphDef cannot be parsed; native errors are
 *     raised by {@code TF_Status.throwExceptionIfNotOK()}
 */
@SuppressWarnings({ "unchecked", "try" })
public static SavedModelBundle loadSavedModelBundle(String exportDir, String[] tags, ConfigProto config, RunOptions runOptions) {
    try (PointerScope ignored = new PointerScope()) {
        TF_Status status = TF_Status.newStatus();
        // allocate parameters for TF_LoadSessionFromSavedModel
        TF_SessionOptions opts = TF_SessionOptions.newSessionOptions();
        if (config != null) {
            BytePointer configBytes = new BytePointer(config.toByteArray());
            tensorflow.TF_SetConfig(opts, configBytes, configBytes.capacity(), status);
            status.throwExceptionIfNotOK();
        }
        TF_Buffer runOpts = TF_Buffer.newBufferFromString(runOptions);
        // The extra reference keeps the graph alive after the PointerScope closes; it must be
        // released explicitly if we fail before handing the graph to SavedModelBundle.
        TF_Graph graphHandle = AbstractTF_Graph.newGraph().retainReference();
        TF_Session sessionHandle = null;
        boolean handedOff = false;
        try {
            TF_Buffer metaGraphDef = TF_Buffer.newBuffer();
            sessionHandle = tensorflow.TF_LoadSessionFromSavedModel(opts, runOpts, new BytePointer(exportDir), new PointerPointer<>(tags), tags.length, graphHandle, metaGraphDef, status);
            status.throwExceptionIfNotOK();
            // handle the result
            MetaGraphDef metaGraph;
            try {
                metaGraph = MetaGraphDef.parseFrom(metaGraphDef.dataAsByteBuffer());
            } catch (InvalidProtocolBufferException e) {
                throw new TensorFlowException("Cannot parse MetaGraphDef protocol buffer", e);
            }
            SavedModelBundle bundle = new SavedModelBundle(graphHandle, sessionHandle, metaGraph);
            handedOff = true;
            return bundle;
        } finally {
            if (!handedOff) {
                // Nothing owns these handles yet, so free them here. Cleanup is best-effort:
                // its status is ignored so the original exception is what propagates.
                if (sessionHandle != null && !sessionHandle.isNull()) {
                    TF_Status cleanupStatus = TF_Status.newStatus();
                    tensorflow.TF_CloseSession(sessionHandle, cleanupStatus);
                    tensorflow.TF_DeleteSession(sessionHandle, cleanupStatus);
                }
                // Delete the session first, then drop the graph reference the session depended on.
                graphHandle.releaseReference();
            }
        }
    }
}
Also used : TF_Graph(org.tensorflow.internal.c_api.TF_Graph) AbstractTF_Graph(org.tensorflow.internal.c_api.AbstractTF_Graph) TF_Session(org.tensorflow.internal.c_api.TF_Session) TF_Status(org.tensorflow.internal.c_api.TF_Status) BytePointer(org.bytedeco.javacpp.BytePointer) InvalidProtocolBufferException(com.google.protobuf.InvalidProtocolBufferException) TF_Buffer(org.tensorflow.internal.c_api.TF_Buffer) SavedModelBundle(ai.djl.tensorflow.engine.SavedModelBundle) PointerScope(org.bytedeco.javacpp.PointerScope) TensorFlowException(org.tensorflow.exceptions.TensorFlowException) TF_SessionOptions(org.tensorflow.internal.c_api.TF_SessionOptions)

Example 2 with TF_Buffer

Use of org.tensorflow.internal.c_api.TF_Buffer in the project djl by deepjavalibrary.

Method runSession of the class JavacppUtils.

/**
 * Executes one step of the session {@code handle} through {@code TF_SessionRun}.
 *
 * <p>Feeds, fetches and targets are described by parallel arrays:
 * <ul>
 *   <li>Feed {@code i} supplies {@code inputTensorHandles[i]} to output
 *       {@code inputOpIndices[i]} of {@code inputOpHandles[i]}.</li>
 *   <li>Fetch {@code i} reads output {@code outputOpIndices[i]} of {@code outputOpHandles[i]}.</li>
 *   <li>Every operation in {@code targetOpHandles} is run without being fetched.</li>
 * </ul>
 * The counts are taken from {@code inputTensorHandles}, {@code outputOpHandles} and
 * {@code targetOpHandles} respectively.
 *
 * <p>All temporary native arrays are created inside a {@link PointerScope}. Each fetched tensor
 * is given a deallocator plus an extra reference, so it survives the scope and is returned to the
 * caller.
 *
 * @return the fetched tensors, in the same order as {@code outputOpHandles}
 */
@SuppressWarnings({ "unchecked", "try" })
public static TF_Tensor[] runSession(TF_Session handle, RunOptions runOptions, TF_Tensor[] inputTensorHandles, TF_Operation[] inputOpHandles, int[] inputOpIndices, TF_Operation[] outputOpHandles, int[] outputOpIndices, TF_Operation[] targetOpHandles) {
    final int feedCount = inputTensorHandles.length;
    final int fetchCount = outputOpHandles.length;
    final int targetCount = targetOpHandles.length;
    try (PointerScope ignored = new PointerScope()) {
        // TODO: check with sig-jvm if TF_Output here is freed
        TF_Output feedEndpoints = new TF_Output(feedCount);
        PointerPointer<TF_Tensor> feedTensors = new PointerPointer<>(feedCount);
        TF_Output fetchEndpoints = new TF_Output(fetchCount);
        PointerPointer<TF_Tensor> fetchedTensors = new PointerPointer<>(fetchCount);
        PointerPointer<TF_Operation> targetOps = new PointerPointer<>(targetCount);

        // Fill both the feed-tensor array and the feed endpoint descriptors in one pass.
        for (int f = 0; f < feedCount; f++) {
            feedTensors.put(f, inputTensorHandles[f]);
            feedEndpoints.position(f).oper(inputOpHandles[f]).index(inputOpIndices[f]);
        }
        // Rewind so the native call sees the descriptor array starting at element 0.
        feedEndpoints.position(0);

        for (int f = 0; f < fetchCount; f++) {
            fetchEndpoints.position(f).oper(outputOpHandles[f]).index(outputOpIndices[f]);
        }
        fetchEndpoints.position(0);

        for (int t = 0; t < targetCount; t++) {
            targetOps.put(t, targetOpHandles[t]);
        }

        TF_Status status = TF_Status.newStatus();
        TF_Buffer runOptionsBuffer = TF_Buffer.newBufferFromString(runOptions);
        tensorflow.TF_SessionRun(handle, runOptionsBuffer, feedEndpoints, feedTensors, feedCount, fetchEndpoints, fetchedTensors, fetchCount, targetOps, targetCount, null, status);
        status.throwExceptionIfNotOK();

        TF_Tensor[] fetched = new TF_Tensor[fetchCount];
        for (int f = 0; f < fetchCount; f++) {
            // Attach a deallocator, then take an extra reference so closing the scope
            // does not free the tensor being returned.
            fetched[f] = fetchedTensors.get(TF_Tensor.class, f).withDeallocator().retainReference();
        }
        return fetched;
    }
}
Also used : TF_Output(org.tensorflow.internal.c_api.TF_Output) PointerPointer(org.bytedeco.javacpp.PointerPointer) TF_Tensor(org.tensorflow.internal.c_api.TF_Tensor) AbstractTF_Tensor(org.tensorflow.internal.c_api.AbstractTF_Tensor) TF_Status(org.tensorflow.internal.c_api.TF_Status) TF_Buffer(org.tensorflow.internal.c_api.TF_Buffer) TF_Operation(org.tensorflow.internal.c_api.TF_Operation) PointerScope(org.bytedeco.javacpp.PointerScope)

Aggregations

PointerScope (org.bytedeco.javacpp.PointerScope)2 TF_Buffer (org.tensorflow.internal.c_api.TF_Buffer)2 TF_Status (org.tensorflow.internal.c_api.TF_Status)2 SavedModelBundle (ai.djl.tensorflow.engine.SavedModelBundle)1 InvalidProtocolBufferException (com.google.protobuf.InvalidProtocolBufferException)1 BytePointer (org.bytedeco.javacpp.BytePointer)1 PointerPointer (org.bytedeco.javacpp.PointerPointer)1 TensorFlowException (org.tensorflow.exceptions.TensorFlowException)1 AbstractTF_Graph (org.tensorflow.internal.c_api.AbstractTF_Graph)1 AbstractTF_Tensor (org.tensorflow.internal.c_api.AbstractTF_Tensor)1 TF_Graph (org.tensorflow.internal.c_api.TF_Graph)1 TF_Operation (org.tensorflow.internal.c_api.TF_Operation)1 TF_Output (org.tensorflow.internal.c_api.TF_Output)1 TF_Session (org.tensorflow.internal.c_api.TF_Session)1 TF_SessionOptions (org.tensorflow.internal.c_api.TF_SessionOptions)1 TF_Tensor (org.tensorflow.internal.c_api.TF_Tensor)1