Example 1 with TFE_Context

Use of org.tensorflow.internal.c_api.TFE_Context in project djl by deepjavalibrary.

In class JavacppUtils, method createEagerSession:

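import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.PointerScope;
import org.tensorflow.internal.c_api.AbstractTFE_Context;
import org.tensorflow.internal.c_api.TFE_Context;
import org.tensorflow.internal.c_api.TFE_ContextOptions;
import org.tensorflow.internal.c_api.TF_Status;
import org.tensorflow.internal.c_api.global.tensorflow;
// ConfigProto (TensorFlow's protobuf session config) must also be imported;
// its package depends on the TensorFlow Java version in use.

// "try" silences the warning that the resource variable `ignored` is never referenced.
@SuppressWarnings({ "unchecked", "try" })
public static TFE_Context createEagerSession(boolean async, int devicePlacementPolicy, ConfigProto config) {
    // Native pointers allocated inside this scope are released when the block exits.
    try (PointerScope ignored = new PointerScope()) {
        TFE_ContextOptions opts = TFE_ContextOptions.newContextOptions();
        TF_Status status = TF_Status.newStatus();
        if (config != null) {
            // Hand the serialized ConfigProto to the C API as a raw byte buffer.
            BytePointer configBytes = new BytePointer(config.toByteArray());
            tensorflow.TFE_ContextOptionsSetConfig(opts, configBytes, configBytes.capacity(), status);
            status.throwExceptionIfNotOK();
        }
        // The C API expects an unsigned char flag, hence the byte cast.
        tensorflow.TFE_ContextOptionsSetAsync(opts, (byte) (async ? 1 : 0));
        tensorflow.TFE_ContextOptionsSetDevicePlacementPolicy(opts, devicePlacementPolicy);
        TFE_Context context = AbstractTFE_Context.newContext(opts, status);
        status.throwExceptionIfNotOK();
        // Take an extra reference so the context is not freed when the PointerScope closes.
        return context.retainReference();
    }
}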
Also used: TFE_Context (org.tensorflow.internal.c_api.TFE_Context), AbstractTFE_Context (org.tensorflow.internal.c_api.AbstractTFE_Context), TF_Status (org.tensorflow.internal.c_api.TF_Status), BytePointer (org.bytedeco.javacpp.BytePointer), PointerScope (org.bytedeco.javacpp.PointerScope), TFE_ContextOptions (org.tensorflow.internal.c_api.TFE_ContextOptions)
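createEagerSession relies on a scope-plus-retain ownership pattern. Everything allocated inside the PointerScope (the options and the status) is released when the try block exits, and the extra reference taken by retainReference() keeps the returned TFE_Context alive past that point. The sketch below models this pattern in plain Java so it can run without native libraries. ScopeDemo, Handle, Scope and createContext are hypothetical names, and the reference counting is a simplified assumption about JavaCPP's behavior, not its actual implementation.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical, pure-Java model of the PointerScope + retainReference() pattern.
// None of these classes belong to JavaCPP or TensorFlow Java.
public class ScopeDemo {

    /** Hypothetical native handle with a simplified reference count. */
    static final class Handle {
        private int refCount = 0;
        private boolean released = false;

        Handle retain() {
            refCount++;
            return this;
        }

        void release() {
            if (released) {
                return;
            }
            refCount--;
            // Freeing happens only when the last reference goes away.
            if (refCount == 0) {
                released = true; // a real wrapper would free native memory here
            }
        }

        boolean isReleased() {
            return released;
        }
    }

    /** Hypothetical scope: each handle it allocates holds one reference until close(). */
    static final class Scope implements AutoCloseable {
        private final List<Handle> owned = new ArrayList<>();

        Handle allocate() {
            Handle h = new Handle().retain(); // the scope's own reference
            owned.add(h);
            return h;
        }

        @Override
        public void close() {
            for (Handle h : owned) {
                h.release();
            }
            owned.clear();
        }
    }

    /**
     * Same shape as createEagerSession: temporaries die with the scope, while the
     * returned handle survives because of the extra retain(). The temporaries are
     * added to {@code temps} only so callers can inspect them.
     */
    static Handle createContext(List<Handle> temps) {
        try (Scope scope = new Scope()) {
            Handle opts = scope.allocate();
            Handle status = scope.allocate();
            Handle context = scope.allocate();
            temps.add(opts);
            temps.add(status);
            return context.retain(); // evaluated before close(), so the caller keeps one reference
        }
    }

    public static void main(String[] args) {
        List<Handle> temps = new ArrayList<>();
        Handle context = createContext(temps);
        System.out.println("options released: " + temps.get(0).isReleased()); // true
        System.out.println("context released: " + context.isReleased());      // false
        context.release(); // the caller drops its reference
        System.out.println("context released: " + context.isReleased());      // true
    }
}
```

The design choice mirrors the original method: scoping the temporaries avoids manual cleanup on every error path (each throwExceptionIfNotOK() call can exit early), while a single explicit retain marks the one object whose lifetime must extend beyond the method.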
