Usage of org.tensorflow.internal.c_api.TF_Tensor in project djl by deepjavalibrary.
Class TfSymbolBlock, method forwardInternal:
/**
 * {@inheritDoc}
 *
 * <p>Each slot of {@code inputDescriptions} is filled as follows. If the array at the same
 * position in {@code inputs} has no name, an empty name, or a name equal to the slot name, it is
 * used positionally (the default order from the translator). Otherwise the first array in
 * {@code inputs} whose name equals the slot name is used.
 *
 * <p>All native input and output {@code TF_Tensor}s created here are closed before this method
 * returns, including when an exception is thrown.
 *
 * @throws IllegalArgumentException if a slot cannot be matched positionally and no array in
 *     {@code inputs} carries the slot's name
 */
@Override
protected NDList forwardInternal(
        ParameterStore parameterStore,
        NDList inputs,
        boolean training,
        PairList<String, Object> params) {
    TF_Tensor[] inputTensorHandles = new TF_Tensor[inputDescriptions.size()];
    TF_Tensor[] outputs = null;
    try {
        for (int i = 0; i < inputDescriptions.size(); i++) {
            String inputName = inputDescriptions.get(i).getKey();
            TfNDArray positional = (TfNDArray) inputs.get(i);
            String name = positional.getName();
            TfNDArray selected = null;
            if (name == null || name.isEmpty() || name.equals(inputName)) {
                // no name specified, or the input order already matches inputDescriptions
                selected = positional;
            } else {
                // search by name; inputName is non-null, so unnamed arrays are simply skipped
                for (NDArray array : inputs) {
                    if (inputName.equals(array.getName())) {
                        selected = (TfNDArray) array;
                        break; // resolve exactly once, otherwise earlier handles would leak
                    }
                }
                if (selected == null) {
                    throw new IllegalArgumentException(
                            "No input NDArray named '" + inputName + "' was found");
                }
            }
            inputTensorHandles[i] = JavacppUtils.resolveTFETensor(selected.getHandle());
        }
        outputs =
                JavacppUtils.runSession(
                        sessionHandle,
                        null,
                        inputTensorHandles,
                        inputOpHandles,
                        inputOpIndices,
                        outputOpHandles,
                        outputOpIndices,
                        targetOpHandles);
        TfNDManager tfNDManager = (TfNDManager) inputs.head().getManager();
        NDList resultNDList = new NDList();
        for (int i = 0; i < outputs.length; i++) {
            TfNDArray array =
                    new TfNDArray(tfNDManager, JavacppUtils.createTFETensor(outputs[i]));
            array.setName(outputDescriptions.get(i).getKey());
            resultNDList.add(array);
        }
        return resultNDList;
    } finally {
        // free all unused native resources, also on failure paths
        for (TF_Tensor tensor : inputTensorHandles) {
            if (tensor != null) {
                tensor.close();
            }
        }
        if (outputs != null) {
            for (TF_Tensor tensor : outputs) {
                if (tensor != null) {
                    tensor.close();
                }
            }
        }
    }
}
Usage of org.tensorflow.internal.c_api.TF_Tensor in project djl by deepjavalibrary.
Class JavacppUtils, method createStringTensor:
/**
 * Builds a TensorFlow string tensor from raw byte buffers and wraps it in an eager tensor handle.
 *
 * <p>One {@code TF_TString} slot is allocated per entry of {@code src}, and the remaining bytes
 * of each buffer are copied into its slot. Both returned native objects have their reference
 * retained so that they survive the {@code PointerScope} used while building them.
 *
 * @param dims the shape of the tensor to allocate
 * @param src the string payloads, one buffer per tensor element
 * @return the allocated {@code TF_Tensor} paired with the {@code TFE_TensorHandle} wrapping it
 */
@SuppressWarnings({ "unchecked", "try" })
public static Pair<TF_Tensor, TFE_TensorHandle> createStringTensor(long[] dims, ByteBuffer[] src) {
    final int elementCount = src.length;
    final int stringType = TfDataType.toTf(DataType.STRING);
    final long byteSize = elementCount * (long) Loader.sizeof(TF_TString.class);
    try (PointerScope ignored = new PointerScope()) {
        // A string tensor owns separately allocated TF_TString memory, which is deleted when the
        // tensor is closed. Its lifecycle must therefore follow the TFE_TensorHandle. TF_Tensor
        // already manages the TF_TString storage, so callers only need to keep a reference to the
        // returned TF_Tensor (e.g. inside TfNDArray).
        TF_Tensor tensor = AbstractTF_Tensor.allocateTensor(stringType, dims, byteSize);
        Pointer raw = tensorflow.TF_TensorData(tensor).capacity(byteSize);
        TF_TString strings = new TF_TString(raw).capacity(raw.position() + elementCount);
        int index = 0;
        for (ByteBuffer payload : src) {
            tensorflow.TF_TString_Copy(
                    strings.getPointer(index), new BytePointer(payload), payload.remaining());
            index++;
        }
        TF_Status status = TF_Status.newStatus();
        TFE_TensorHandle eagerHandle = AbstractTFE_TensorHandle.newTensor(tensor, status);
        status.throwExceptionIfNotOK();
        // keep both objects alive after the PointerScope closes
        eagerHandle.retainReference();
        tensor.retainReference();
        return new Pair<>(tensor, eagerHandle);
    }
}
Usage of org.tensorflow.internal.c_api.TF_Tensor in project djl by deepjavalibrary.
Class JavacppUtils, method getByteBuffer:
/**
 * Copies the contents of an eager tensor handle into a newly allocated heap {@link ByteBuffer}.
 *
 * <p>The handle is resolved to a temporary {@code TF_Tensor}, which is registered with the
 * enclosing {@code PointerScope} and released when that scope closes. The bytes are therefore
 * copied out, so the returned buffer stays valid after the tensor is gone.
 *
 * @param handle the eager tensor handle to read
 * @return a flipped heap buffer in native byte order holding the tensor's bytes
 */
@SuppressWarnings({ "unchecked", "try" })
public static ByteBuffer getByteBuffer(TFE_TensorHandle handle) {
    try (PointerScope ignored = new PointerScope()) {
        TF_Status status = TF_Status.newStatus();
        TF_Tensor resolved = tensorflow.TFE_TensorHandleResolve(handle, status).withDeallocator();
        status.throwExceptionIfNotOK();
        long byteSize = tensorflow.TF_TensorByteSize(resolved);
        Pointer data = tensorflow.TF_TensorData(resolved).capacity(byteSize);
        ByteBuffer nativeView = data.asByteBuffer();
        nativeView.rewind();
        ByteBuffer copy = ByteBuffer.allocate(nativeView.capacity());
        copy.put(nativeView);
        copy.flip();
        return copy.order(ByteOrder.nativeOrder());
    }
}
Usage of org.tensorflow.internal.c_api.TF_Tensor in project djl by deepjavalibrary.
Class JavacppUtils, method getString:
/**
 * Decodes the first {@code count} string elements of an eager tensor handle.
 *
 * <p>The handle is resolved to a {@code TF_Tensor}, and its {@code TF_TString} entries are
 * decoded with {@code charset}. The resolved tensor is deleted explicitly with
 * {@code TF_DeleteTensor} instead of through {@code withDeallocator()}. Deletion happens in a
 * {@code finally} block, so the tensor is released even when resolution status or decoding
 * fails.
 *
 * @param handle the eager tensor handle holding string data
 * @param count the number of string elements to read
 * @param charset the charset used to decode each element's bytes
 * @return the decoded strings, in tensor order
 * @throws IllegalArgumentException if {@code count} is negative or larger than the number of
 *     {@code TF_TString} elements held by the tensor
 */
@SuppressWarnings({ "unchecked", "try" })
public static String[] getString(TFE_TensorHandle handle, int count, Charset charset) {
    try (PointerScope ignored = new PointerScope()) {
        // convert to TF_Tensor
        TF_Status status = TF_Status.newStatus();
        // should not add .withDeallocator() here, otherwise string data will be destroyed;
        // the tensor is deleted manually in the finally block below instead
        TF_Tensor tensor = tensorflow.TFE_TensorHandleResolve(handle, status);
        try {
            status.throwExceptionIfNotOK();
            long tensorSize = tensorflow.TF_TensorByteSize(tensor);
            long available = tensorSize / Loader.sizeof(TF_TString.class);
            if (count < 0 || count > available) {
                // guard against reading past the end of the native TF_TString array
                throw new IllegalArgumentException(
                        "Requested "
                                + count
                                + " strings but the tensor holds only "
                                + available);
            }
            Pointer pointer = tensorflow.TF_TensorData(tensor).capacity(tensorSize);
            TF_TString data = new TF_TString(pointer).capacity(pointer.position() + count);
            String[] ret = new String[count];
            for (int i = 0; i < count; ++i) {
                TF_TString tstring = data.getPointer(i);
                long size = tensorflow.TF_TString_GetSize(tstring);
                BytePointer bp = tensorflow.TF_TString_GetDataPointer(tstring).capacity(size);
                ret[i] = bp.getString(charset);
            }
            return ret;
        } finally {
            // manually delete tensor, also when an exception was thrown above
            if (tensor != null && !tensor.isNull()) {
                tensorflow.TF_DeleteTensor(tensor);
            }
        }
    }
}
Usage of org.tensorflow.internal.c_api.TF_Tensor in project djl by deepjavalibrary.
Class JavacppUtils, method setByteBuffer:
/**
 * Copies the remaining bytes of {@code data} into the data region of an eager tensor handle.
 *
 * <p>The handle is resolved to a {@code TF_Tensor}, which is registered with the enclosing
 * {@code PointerScope} and released when that scope closes. The bytes are written into that
 * tensor's data buffer. NOTE(review): this relies on the resolved tensor sharing storage with
 * the handle — confirm for non-CPU devices.
 *
 * <p>As with {@link ByteBuffer#put(ByteBuffer)}, {@code data}'s position is advanced past the
 * copied bytes, and a {@code BufferOverflowException} is thrown if it holds more bytes than the
 * tensor.
 *
 * @param handle the eager tensor handle to write into
 * @param data the source bytes, read from its current position to its limit
 */
@SuppressWarnings({ "unchecked", "try" })
public static void setByteBuffer(TFE_TensorHandle handle, ByteBuffer data) {
    try (PointerScope ignored = new PointerScope()) {
        TF_Status status = TF_Status.newStatus();
        TF_Tensor resolved = tensorflow.TFE_TensorHandleResolve(handle, status).withDeallocator();
        status.throwExceptionIfNotOK();
        long byteSize = tensorflow.TF_TensorByteSize(resolved);
        Pointer target = tensorflow.TF_TensorData(resolved).capacity(byteSize);
        ByteBuffer destination = target.asByteBuffer();
        destination.put(data);
    }
}
Aggregations