Usage of org.tensorflow.internal.c_api.TFE_Context in the djl project by deepjavalibrary.
Class JavacppUtils, method createEagerSession.
/**
 * Creates a TensorFlow eager execution context ({@code TFE_Context}).
 *
 * <p>The temporary native objects allocated here (the context options, the status object and, if
 * present, the serialized config bytes) are created inside a {@link PointerScope} that is closed
 * when this method exits. The returned context has {@code retainReference()} called on it so that
 * closing the scope does not release it; presumably the caller is then responsible for releasing
 * it (TODO confirm against callers).
 *
 * @param async whether the context should run asynchronously; forwarded to
 *     {@code TFE_ContextOptionsSetAsync} as {@code 1} or {@code 0}
 * @param devicePlacementPolicy device placement policy value forwarded unchanged to
 *     {@code TFE_ContextOptionsSetDevicePlacementPolicy}
 * @param config optional session configuration; when {@code null}, no config is set on the
 *     context options
 * @return the newly created eager context, with its reference retained
 */
@SuppressWarnings({ "unchecked", "try" })
public static TFE_Context createEagerSession(boolean async, int devicePlacementPolicy, ConfigProto config) {
    try (PointerScope nativeScope = new PointerScope()) {
        TFE_ContextOptions options = TFE_ContextOptions.newContextOptions();
        TF_Status tfStatus = TF_Status.newStatus();

        // Serialize and apply the ConfigProto only when the caller supplied one.
        if (config != null) {
            BytePointer serialized = new BytePointer(config.toByteArray());
            tensorflow.TFE_ContextOptionsSetConfig(options, serialized, serialized.capacity(), tfStatus);
            tfStatus.throwExceptionIfNotOK();
        }

        // The C API takes the async flag as an unsigned char rather than a boolean.
        byte asyncFlag = async ? (byte) 1 : (byte) 0;
        tensorflow.TFE_ContextOptionsSetAsync(options, asyncFlag);
        tensorflow.TFE_ContextOptionsSetDevicePlacementPolicy(options, devicePlacementPolicy);

        TFE_Context created = AbstractTFE_Context.newContext(options, tfStatus);
        tfStatus.throwExceptionIfNotOK();
        // Retain the context so closing nativeScope does not release it along with the temporaries.
        return created.retainReference();
    }
}
Aggregations