Usage of `org.tensorflow.internal.c_api.TF_Buffer` in the djl project (deepjavalibrary):
class `JavacppUtils`, method `loadSavedModelBundle`.
/**
 * Loads a SavedModel from {@code exportDir} into a newly created graph and session.
 *
 * <p>All temporary native allocations (status, session options, buffers) are owned by a local
 * {@link PointerScope} and released when this method returns. The returned graph carries one
 * extra reference for the caller. That reference is taken only once loading and MetaGraphDef
 * parsing have both succeeded, so a failed load does not leak the native graph.
 *
 * @param exportDir directory containing the SavedModel
 * @param tags tags identifying the MetaGraphDef to load
 * @param config optional session configuration; not applied when {@code null}
 * @param runOptions run options serialized into a {@code TF_Buffer} for the loader
 * @return a bundle holding the graph, the session and the parsed MetaGraphDef
 * @throws TensorFlowException if the MetaGraphDef returned by the loader cannot be parsed
 */
@SuppressWarnings({"unchecked", "try"})
public static SavedModelBundle loadSavedModelBundle(
        String exportDir, String[] tags, ConfigProto config, RunOptions runOptions) {
    try (PointerScope ignored = new PointerScope()) {
        TF_Status status = TF_Status.newStatus();

        // allocate parameters for TF_LoadSessionFromSavedModel
        TF_SessionOptions opts = TF_SessionOptions.newSessionOptions();
        if (config != null) {
            BytePointer configBytes = new BytePointer(config.toByteArray());
            tensorflow.TF_SetConfig(opts, configBytes, configBytes.capacity(), status);
            status.throwExceptionIfNotOK();
        }
        TF_Buffer runOpts = TF_Buffer.newBufferFromString(runOptions);

        // load the session
        // The graph stays owned by the scope only, so any exception thrown before the end
        // of this method lets the scope free it. The caller's reference is retained below,
        // after everything has succeeded.
        TF_Graph graphHandle = AbstractTF_Graph.newGraph();
        TF_Buffer metaGraphDef = TF_Buffer.newBuffer();
        TF_Session sessionHandle =
                tensorflow.TF_LoadSessionFromSavedModel(
                        opts,
                        runOpts,
                        new BytePointer(exportDir),
                        new PointerPointer<>(tags),
                        tags.length,
                        graphHandle,
                        metaGraphDef,
                        status);
        status.throwExceptionIfNotOK();

        // handle the result
        MetaGraphDef parsedMetaGraph;
        try {
            parsedMetaGraph = MetaGraphDef.parseFrom(metaGraphDef.dataAsByteBuffer());
        } catch (InvalidProtocolBufferException e) {
            // The session has not been handed out yet; destroy it here instead of leaking it.
            tensorflow.TF_DeleteSession(sessionHandle, status);
            throw new TensorFlowException("Cannot parse MetaGraphDef protocol buffer", e);
        }
        TF_Graph retainedGraph = graphHandle.retainReference();
        return new SavedModelBundle(retainedGraph, sessionHandle, parsedMetaGraph);
    }
}
Usage of `org.tensorflow.internal.c_api.TF_Buffer` in the djl project (deepjavalibrary):
class `JavacppUtils`, method `runSession`.
/**
 * Runs one step of the given session through {@code TF_SessionRun}.
 *
 * <p>Feed {@code i} binds {@code inputTensorHandles[i]} to output {@code inputOpIndices[i]} of
 * {@code inputOpHandles[i]}. Fetch {@code j} reads output {@code outputOpIndices[j]} of
 * {@code outputOpHandles[j]}. Every operation in {@code targetOpHandles} is executed without
 * being fetched. The temporary native arrays are released when the local {@link PointerScope}
 * closes.
 *
 * @param handle the session to run
 * @param runOptions run options serialized into a {@code TF_Buffer} for the call
 * @param inputTensorHandles tensors to feed
 * @param inputOpHandles operations whose outputs receive the fed tensors
 * @param inputOpIndices output indices on {@code inputOpHandles}
 * @param outputOpHandles operations whose outputs are fetched
 * @param outputOpIndices output indices on {@code outputOpHandles}
 * @param targetOpHandles operations to run without fetching
 * @return the fetched tensors, in the order of {@code outputOpHandles}. Each one has a
 *     deallocator attached and one reference retained beyond the local scope.
 */
@SuppressWarnings({"unchecked", "try"})
public static TF_Tensor[] runSession(
        TF_Session handle,
        RunOptions runOptions,
        TF_Tensor[] inputTensorHandles,
        TF_Operation[] inputOpHandles,
        int[] inputOpIndices,
        TF_Operation[] outputOpHandles,
        int[] outputOpIndices,
        TF_Operation[] targetOpHandles) {
    final int inputCount = inputTensorHandles.length;
    final int outputCount = outputOpHandles.length;
    final int targetCount = targetOpHandles.length;

    try (PointerScope ignored = new PointerScope()) {
        // TODO: check with sig-jvm if TF_Output here is freed
        TF_Output feeds = new TF_Output(inputCount);
        PointerPointer<TF_Tensor> feedValues = new PointerPointer<>(inputCount);
        TF_Output fetches = new TF_Output(outputCount);
        PointerPointer<TF_Tensor> fetchValues = new PointerPointer<>(outputCount);
        PointerPointer<TF_Operation> targetOps = new PointerPointer<>(targetCount);

        // Each feed pairs a tensor with the (operation, output index) it is bound to.
        for (int k = 0; k < inputCount; k++) {
            feedValues.put(k, inputTensorHandles[k]);
            feeds.position(k).oper(inputOpHandles[k]).index(inputOpIndices[k]);
        }
        // TF_SessionRun reads the struct array from its start, so rewind after filling.
        feeds.position(0);

        for (int k = 0; k < outputCount; k++) {
            fetches.position(k).oper(outputOpHandles[k]).index(outputOpIndices[k]);
        }
        fetches.position(0);

        for (int k = 0; k < targetCount; k++) {
            targetOps.put(k, targetOpHandles[k]);
        }

        TF_Status status = TF_Status.newStatus();
        TF_Buffer runOpts = TF_Buffer.newBufferFromString(runOptions);
        tensorflow.TF_SessionRun(
                handle,
                runOpts,
                feeds,
                feedValues,
                inputCount,
                fetches,
                fetchValues,
                outputCount,
                targetOps,
                targetCount,
                null,
                status);
        status.throwExceptionIfNotOK();

        // Attach a deallocator to each fetched tensor and keep one reference past the scope.
        TF_Tensor[] fetched = new TF_Tensor[outputCount];
        for (int k = 0; k < outputCount; k++) {
            TF_Tensor rawTensor = fetchValues.get(TF_Tensor.class, k);
            fetched[k] = rawTensor.withDeallocator().retainReference();
        }
        return fetched;
    }
}
Aggregations