
Example 1 with TF_Tensor

Use of org.tensorflow.internal.c_api.TF_Tensor in project djl by deepjavalibrary.

In class TfSymbolBlock, method forwardInternal.

/**
 * {@inheritDoc}
 */
@Override
protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, PairList<String, Object> params) {
    TF_Tensor[] inputTensorHandles = new TF_Tensor[inputDescriptions.size()];
    for (int i = 0; i < inputDescriptions.size(); i++) {
        String inputName = inputDescriptions.get(i).getKey();
        TfNDArray currentNDArray = (TfNDArray) inputs.get(i);
        // If the input array has no name, or its name already matches the
        // expected input at this position, keep the translator's default order.
        String name = currentNDArray.getName();
        if (name == null || name.isEmpty() || name.equals(inputName)) {
            inputTensorHandles[i] = JavacppUtils.resolveTFETensor(currentNDArray.getHandle());
            continue;
        }
        // Otherwise, search all inputs for the NDArray with the expected name.
        // Compare null-safely (other arrays may be unnamed) and stop at the first match.
        for (NDArray array : inputs) {
            if (inputName.equals(array.getName())) {
                inputTensorHandles[i] = JavacppUtils.resolveTFETensor(((TfNDArray) array).getHandle());
                break;
            }
        }
    }
    TF_Tensor[] outputs = JavacppUtils.runSession(sessionHandle, null, inputTensorHandles, inputOpHandles, inputOpIndices, outputOpHandles, outputOpIndices, targetOpHandles);
    TfNDManager tfNDManager = (TfNDManager) inputs.head().getManager();
    NDList resultNDList = new NDList();
    for (int i = 0; i < outputs.length; i++) {
        TfNDArray array = new TfNDArray(tfNDManager, JavacppUtils.createTFETensor(outputs[i]));
        array.setName(outputDescriptions.get(i).getKey());
        resultNDList.add(array);
    }
    // free all unused native resources
    Arrays.stream(inputTensorHandles).forEach(Pointer::close);
    Arrays.stream(outputs).forEach(Pointer::close);
    return resultNDList;
}
Also used: TF_Tensor (org.tensorflow.internal.c_api.TF_Tensor), NDList (ai.djl.ndarray.NDList), NDArray (ai.djl.ndarray.NDArray), Pointer (org.bytedeco.javacpp.Pointer)
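The input-matching rule in forwardInternal can be tried out without TensorFlow. The following is a minimal stdlib-only sketch, assuming a hypothetical NamedInput class as a stand-in for TfNDArray (only its name matters here). It keeps an input at its own position when it is unnamed or already correctly named, and otherwise looks up the first input whose name matches, comparing null-safely. Throwing on a missing name is a choice made for this sketch, not something the DJL code above does.

```java
import java.util.ArrayList;
import java.util.List;

/** Sketch of name-based input ordering; NamedInput is hypothetical, not a DJL type. */
public final class InputOrdering {

    /** Hypothetical stand-in for an NDArray that may or may not carry a name. */
    public static final class NamedInput {
        public final String name;    // may be null, like NDArray.getName()
        public final String payload; // placeholder for the tensor data

        public NamedInput(String name, String payload) {
            this.name = name;
            this.payload = payload;
        }
    }

    /** Returns the inputs arranged in the order given by expectedNames. */
    public static List<NamedInput> order(List<String> expectedNames, List<NamedInput> inputs) {
        List<NamedInput> ordered = new ArrayList<>(expectedNames.size());
        for (int i = 0; i < expectedNames.size(); i++) {
            String expected = expectedNames.get(i);
            NamedInput current = inputs.get(i);
            String name = current.name;
            if (name == null || name.isEmpty() || name.equals(expected)) {
                ordered.add(current); // default positional order
                continue;
            }
            NamedInput match = null;
            for (NamedInput candidate : inputs) {
                // expected.equals(...) does not throw when candidate.name is null
                if (expected.equals(candidate.name)) {
                    match = candidate;
                    break;
                }
            }
            if (match == null) {
                throw new IllegalArgumentException("No input named " + expected);
            }
            ordered.add(match);
        }
        return ordered;
    }
}
```

For example, inputs named "b" then "a" with expected names "a", "b" come back swapped, while unnamed inputs keep their original positions.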

Example 2 with TF_Tensor

Use of org.tensorflow.internal.c_api.TF_Tensor in project djl by deepjavalibrary.

In class JavacppUtils, method createStringTensor.

@SuppressWarnings({ "unchecked", "try" })
public static Pair<TF_Tensor, TFE_TensorHandle> createStringTensor(long[] dims, ByteBuffer[] src) {
    int dType = TfDataType.toTf(DataType.STRING);
    long numBytes = (long) Loader.sizeof(TF_TString.class) * src.length;
    try (PointerScope ignored = new PointerScope()) {
        /*
         * A string tensor allocates separate TF_TString memory, which is deleted
         * when the string tensor is closed. We have to track the TF_TString memory
         * ourselves and make sure its lifecycle aligns with the TFE_TensorHandle.
         * Since TF_Tensor already handles TF_TString automatically, we can simply
         * keep a TF_Tensor reference in TfNDArray.
         */
        TF_Tensor tensor = AbstractTF_Tensor.allocateTensor(dType, dims, numBytes);
        Pointer pointer = tensorflow.TF_TensorData(tensor).capacity(numBytes);
        TF_TString data = new TF_TString(pointer).capacity(pointer.position() + src.length);
        for (int i = 0; i < src.length; ++i) {
            TF_TString tstring = data.getPointer(i);
            tensorflow.TF_TString_Copy(tstring, new BytePointer(src[i]), src[i].remaining());
        }
        TF_Status status = TF_Status.newStatus();
        TFE_TensorHandle handle = AbstractTFE_TensorHandle.newTensor(tensor, status);
        status.throwExceptionIfNotOK();
        handle.retainReference();
        tensor.retainReference();
        return new Pair<>(tensor, handle);
    }
}
Also used: TF_Tensor (org.tensorflow.internal.c_api.TF_Tensor), AbstractTF_Tensor (org.tensorflow.internal.c_api.AbstractTF_Tensor), TF_Status (org.tensorflow.internal.c_api.TF_Status), BytePointer (org.bytedeco.javacpp.BytePointer), TFE_TensorHandle (org.tensorflow.internal.c_api.TFE_TensorHandle), AbstractTFE_TensorHandle (org.tensorflow.internal.c_api.AbstractTFE_TensorHandle), Pointer (org.bytedeco.javacpp.Pointer), PointerPointer (org.bytedeco.javacpp.PointerPointer), TF_TString (org.tensorflow.internal.c_api.TF_TString), PointerScope (org.bytedeco.javacpp.PointerScope), Pair (ai.djl.util.Pair)
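createStringTensor allocates everything inside a PointerScope, which releases the pointers attached to it when it closes; that is why the method calls retainReference() on the tensor and handle it returns. The sketch below models that idea with plain reference counting. Handle and Scope are hypothetical classes written for illustration, not JavaCPP's Pointer or PointerScope, and the real deallocation details may differ.

```java
import java.util.ArrayDeque;
import java.util.Deque;

/** Stdlib-only model of "retain before leaving the scope"; not the JavaCPP API. */
public final class ScopeModel {

    /** Hypothetical reference-counted native handle. */
    public static final class Handle {
        private int refCount = 0;
        private boolean freed = false;

        Handle(Scope scope) {
            scope.attach(this); // every new handle is owned by the enclosing scope
        }

        /** Adds an owner, like retainReference(). */
        public Handle retain() {
            refCount++;
            return this;
        }

        /** Drops an owner; frees the handle when no owners remain. */
        public void release() {
            if (refCount > 0) {
                refCount--;
            }
            if (refCount == 0) {
                freed = true; // stands in for freeing native memory
            }
        }

        public boolean isFreed() {
            return freed;
        }
    }

    /** Hypothetical scope that releases every attached handle on close. */
    public static final class Scope implements AutoCloseable {
        private final Deque<Handle> handles = new ArrayDeque<>();

        void attach(Handle h) {
            h.refCount++;
            handles.push(h);
        }

        @Override
        public void close() {
            while (!handles.isEmpty()) {
                handles.pop().release();
            }
        }
    }

    /** Creates two handles in a scope; only the first is retained past it. */
    public static Handle[] createAndKeepFirst() {
        try (Scope scope = new Scope()) {
            Handle kept = new Handle(scope).retain(); // like tensor.retainReference()
            Handle temporary = new Handle(scope);     // like the TF_Status
            return new Handle[] {kept, temporary};    // scope closes after this is evaluated
        }
    }
}
```

After createAndKeepFirst() returns, the retained handle is still alive while the temporary one has been freed by the scope.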

Example 3 with TF_Tensor

Use of org.tensorflow.internal.c_api.TF_Tensor in project djl by deepjavalibrary.

In class JavacppUtils, method getByteBuffer.

@SuppressWarnings({ "unchecked", "try" })
public static ByteBuffer getByteBuffer(TFE_TensorHandle handle) {
    try (PointerScope ignored = new PointerScope()) {
        // convert to TF_Tensor
        TF_Status status = TF_Status.newStatus();
        TF_Tensor tensor = tensorflow.TFE_TensorHandleResolve(handle, status).withDeallocator();
        status.throwExceptionIfNotOK();
        Pointer pointer = tensorflow.TF_TensorData(tensor).capacity(tensorflow.TF_TensorByteSize(tensor));
        ByteBuffer buf = pointer.asByteBuffer();
        // Copy the data so that the returned ByteBuffer stays valid after
        // the tensor is deleted.
        ByteBuffer ret = ByteBuffer.allocate(buf.capacity());
        buf.rewind();
        ret.put(buf);
        ret.flip();
        return ret.order(ByteOrder.nativeOrder());
    }
}
Also used: TF_Status (org.tensorflow.internal.c_api.TF_Status), TF_Tensor (org.tensorflow.internal.c_api.TF_Tensor), AbstractTF_Tensor (org.tensorflow.internal.c_api.AbstractTF_Tensor), BytePointer (org.bytedeco.javacpp.BytePointer), Pointer (org.bytedeco.javacpp.Pointer), PointerPointer (org.bytedeco.javacpp.PointerPointer), PointerScope (org.bytedeco.javacpp.PointerScope), ByteBuffer (java.nio.ByteBuffer)
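The copy step in getByteBuffer (rewind, put, flip, then set native byte order) can be shown with java.nio alone. In this sketch a direct ByteBuffer stands in for the tensor's native memory; the class name CopyOut is hypothetical. Setting the order matters because ByteBuffer.allocate always starts in BIG_ENDIAN order, while the tensor bytes were written in the platform's native order.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

/** Sketch of copying a native-memory view into a heap buffer; CopyOut is hypothetical. */
public final class CopyOut {

    public static ByteBuffer copyToHeap(ByteBuffer nativeView) {
        ByteBuffer src = nativeView.duplicate(); // leave the caller's position/limit untouched
        src.rewind();                            // copy from index 0 up to the limit
        ByteBuffer ret = ByteBuffer.allocate(src.remaining());
        ret.put(src);
        ret.flip();                              // make the copied bytes readable from index 0
        // allocate() starts BIG_ENDIAN; use native order so multi-byte reads
        // (getFloat, getInt, ...) match how native code laid out the data
        return ret.order(ByteOrder.nativeOrder());
    }
}
```

Because the result lives on the Java heap, later changes to (or freeing of) the native memory do not affect it.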

Example 4 with TF_Tensor

Use of org.tensorflow.internal.c_api.TF_Tensor in project djl by deepjavalibrary.

In class JavacppUtils, method getString.

@SuppressWarnings({ "unchecked", "try" })
public static String[] getString(TFE_TensorHandle handle, int count, Charset charset) {
    try (PointerScope ignored = new PointerScope()) {
        // convert to TF_Tensor
        TF_Status status = TF_Status.newStatus();
        // Do not add .withDeallocator() here; otherwise the string data will be destroyed.
        TF_Tensor tensor = tensorflow.TFE_TensorHandleResolve(handle, status);
        status.throwExceptionIfNotOK();
        long tensorSize = tensorflow.TF_TensorByteSize(tensor);
        Pointer pointer = tensorflow.TF_TensorData(tensor).capacity(tensorSize);
        TF_TString data = new TF_TString(pointer).capacity(pointer.position() + count);
        String[] ret = new String[count];
        for (int i = 0; i < count; ++i) {
            TF_TString tstring = data.getPointer(i);
            long size = tensorflow.TF_TString_GetSize(tstring);
            BytePointer bp = tensorflow.TF_TString_GetDataPointer(tstring).capacity(size);
            ret[i] = bp.getString(charset);
        }
        // manually delete tensor
        tensorflow.TF_DeleteTensor(tensor);
        return ret;
    }
}
Also used: TF_Status (org.tensorflow.internal.c_api.TF_Status), TF_Tensor (org.tensorflow.internal.c_api.TF_Tensor), AbstractTF_Tensor (org.tensorflow.internal.c_api.AbstractTF_Tensor), BytePointer (org.bytedeco.javacpp.BytePointer), Pointer (org.bytedeco.javacpp.Pointer), PointerPointer (org.bytedeco.javacpp.PointerPointer), TF_TString (org.tensorflow.internal.c_api.TF_TString), PointerScope (org.bytedeco.javacpp.PointerScope)
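The loop in getString reads a known number of variable-length byte strings (size plus data pointer per element) and decodes each with a caller-supplied Charset. The sketch below reproduces that decode loop over a hypothetical [int length][bytes] record layout; this layout is an assumption for illustration and is not the actual TF_TString memory layout. The class name StringDecode is also hypothetical.

```java
import java.nio.ByteBuffer;
import java.nio.charset.Charset;

/** Sketch of decoding length-prefixed strings; the layout is hypothetical, not TF_TString. */
public final class StringDecode {

    /** Writes each string as an [int length][bytes] record. */
    public static ByteBuffer encode(String[] values, Charset charset) {
        byte[][] encoded = new byte[values.length][];
        int total = 0;
        for (int i = 0; i < values.length; i++) {
            encoded[i] = values[i].getBytes(charset);
            total += Integer.BYTES + encoded[i].length;
        }
        ByteBuffer buf = ByteBuffer.allocate(total);
        for (byte[] bytes : encoded) {
            buf.putInt(bytes.length).put(bytes);
        }
        buf.flip();
        return buf;
    }

    /** Reads count records back, decoding each byte range with the given charset. */
    public static String[] decode(ByteBuffer buf, int count, Charset charset) {
        ByteBuffer in = buf.duplicate(); // do not move the caller's position
        String[] ret = new String[count];
        for (int i = 0; i < count; i++) {
            int size = in.getInt();      // plays the role of TF_TString_GetSize
            byte[] bytes = new byte[size];
            in.get(bytes);               // plays the role of reading the data pointer
            ret[i] = new String(bytes, charset);
        }
        return ret;
    }
}
```

Passing the Charset explicitly, as getString does, avoids depending on the JVM's default encoding for non-ASCII data.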

Example 5 with TF_Tensor

Use of org.tensorflow.internal.c_api.TF_Tensor in project djl by deepjavalibrary.

In class JavacppUtils, method setByteBuffer.

@SuppressWarnings({ "unchecked", "try" })
public static void setByteBuffer(TFE_TensorHandle handle, ByteBuffer data) {
    try (PointerScope ignored = new PointerScope()) {
        // convert to TF_Tensor
        TF_Status status = TF_Status.newStatus();
        TF_Tensor tensor = tensorflow.TFE_TensorHandleResolve(handle, status).withDeallocator();
        status.throwExceptionIfNotOK();
        Pointer pointer = tensorflow.TF_TensorData(tensor).capacity(tensorflow.TF_TensorByteSize(tensor));
        pointer.asByteBuffer().put(data);
    }
}
Also used: TF_Status (org.tensorflow.internal.c_api.TF_Status), TF_Tensor (org.tensorflow.internal.c_api.TF_Tensor), AbstractTF_Tensor (org.tensorflow.internal.c_api.AbstractTF_Tensor), BytePointer (org.bytedeco.javacpp.BytePointer), Pointer (org.bytedeco.javacpp.Pointer), PointerPointer (org.bytedeco.javacpp.PointerPointer), PointerScope (org.bytedeco.javacpp.PointerScope)
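setByteBuffer writes the caller's bytes into the tensor's memory with a single ByteBuffer.put, which advances the caller's buffer position and throws BufferOverflowException if the data is larger than the tensor. The sketch below is one possible variant, not the DJL implementation: a direct ByteBuffer stands in for the tensor memory, the hypothetical CopyIn.copyInto checks the size up front for a clearer error message, and it copies from a duplicate so the caller's buffer is not consumed.

```java
import java.nio.ByteBuffer;

/** Sketch of writing caller data into fixed-size memory; CopyIn is hypothetical. */
public final class CopyIn {

    /** Copies all remaining bytes of data to the start of target. */
    public static void copyInto(ByteBuffer target, ByteBuffer data) {
        ByteBuffer dst = target.duplicate(); // shares memory, separate position/limit
        dst.clear();                         // write from index 0 up to capacity
        if (data.remaining() > dst.remaining()) {
            throw new IllegalArgumentException(
                    "data has " + data.remaining() + " bytes but the tensor holds " + dst.remaining());
        }
        dst.put(data.duplicate()); // duplicate() so the caller's position is not advanced
    }
}
```

Whether leaving the caller's position untouched is desirable depends on how the caller reuses its buffer; the original behaviour of consuming it may be intended.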

Aggregations

TF_Tensor (org.tensorflow.internal.c_api.TF_Tensor): 10
AbstractTF_Tensor (org.tensorflow.internal.c_api.AbstractTF_Tensor): 9
PointerScope (org.bytedeco.javacpp.PointerScope): 8
TF_Status (org.tensorflow.internal.c_api.TF_Status): 8
Pointer (org.bytedeco.javacpp.Pointer): 6
PointerPointer (org.bytedeco.javacpp.PointerPointer): 6
BytePointer (org.bytedeco.javacpp.BytePointer): 5
AbstractTFE_TensorHandle (org.tensorflow.internal.c_api.AbstractTFE_TensorHandle): 3
TFE_TensorHandle (org.tensorflow.internal.c_api.TFE_TensorHandle): 3
TF_TString (org.tensorflow.internal.c_api.TF_TString): 2
NDArray (ai.djl.ndarray.NDArray): 1
NDList (ai.djl.ndarray.NDList): 1
Pair (ai.djl.util.Pair): 1
ByteBuffer (java.nio.ByteBuffer): 1
TF_Buffer (org.tensorflow.internal.c_api.TF_Buffer): 1
TF_Operation (org.tensorflow.internal.c_api.TF_Operation): 1
TF_Output (org.tensorflow.internal.c_api.TF_Output): 1