use of ai.djl.pytorch.engine.PtNDArray in project djl by deepjavalibrary.
the class IValueUtils method getInputs.
/**
 * Converts an {@code NDList} into the {@code IValue} array, grouping arrays by their names.
 *
 * <p>The name of each array decides how it is packed:
 *
 * <ul>
 *   <li>{@code "group.key"} (contains a dot): added under {@code key} to the group named by the
 *       text before the first dot; the group becomes a string map.
 *   <li>{@code "name[]"}: added to the group {@code name[]}; the group becomes a list.
 *   <li>{@code "name()"}: added to the group {@code name()}; the group becomes a tuple.
 *   <li>anything else (including a {@code null} name): passed as a standalone value.
 * </ul>
 *
 * <p>The kind of a group is decided by the key of its first entry only.
 *
 * @param ndList the arrays to convert; every element is cast to {@code PtNDArray}
 * @return one {@code IValue} per standalone array or group
 */
static IValue[] getInputs(NDList ndList) {
    List<PairList<String, PtNDArray>> groups = new ArrayList<>();
    Map<String, Integer> groupIndex = new ConcurrentHashMap<>();

    for (NDArray array : ndList) {
        String name = array.getName();

        // Work out which group the array belongs to and under which key it is stored.
        String groupName = null;
        String entryKey = null;
        if (name != null) {
            if (name.contains(".")) {
                String[] parts = name.split("\\.", 2);
                groupName = parts[0];
                entryKey = parts[1];
            } else if (Pattern.matches("\\w+\\[]", name)) {
                groupName = name;
                entryKey = "[]";
            } else if (Pattern.matches("\\w+\\(\\)", name)) {
                groupName = name;
                entryKey = "()";
            }
        }

        if (groupName == null) {
            // Standalone input: a single-entry list marked by a null key.
            PairList<String, PtNDArray> single = new PairList<>();
            single.add(null, (PtNDArray) array);
            groups.add(single);
            continue;
        }

        // NOTE(review): addToMap presumably returns the position of the group's PairList in
        // `groups`, creating it on first use - confirm against its definition.
        int position = addToMap(groupIndex, groupName, groups);
        groups.get(position).add(entryKey, (PtNDArray) array);
    }

    IValue[] result = new IValue[groups.size()];
    int slot = 0;
    for (PairList<String, PtNDArray> group : groups) {
        String firstKey = group.get(0).getKey();
        IValue value;
        if (firstKey == null) {
            // not List, Dict, Tuple input
            value = IValue.from(group.get(0).getValue());
        } else if ("[]".equals(firstKey)) {
            // list
            value = IValue.listFrom(group.values().toArray(new PtNDArray[0]));
        } else if ("()".equals(firstKey)) {
            // Tuple
            IValue[] elements =
                    group.values().stream().map(nd -> IValue.from(nd)).toArray(IValue[]::new);
            value = IValue.tupleFrom(elements);
        } else {
            // Dict
            Map<String, PtNDArray> dict = new ConcurrentHashMap<>();
            for (Pair<String, PtNDArray> entry : group) {
                dict.put(entry.getKey(), entry.getValue());
            }
            value = IValue.stringMapFrom(dict);
        }
        result[slot++] = value;
    }
    return result;
}
use of ai.djl.pytorch.engine.PtNDArray in project djl by deepjavalibrary.
the class JniUtils method split.
/**
 * Splits {@code ndArray} through the native {@code torchSplit} call and wraps the results.
 *
 * @param ndArray the array to split
 * @param indices forwarded unchanged to the native {@code torchSplit} call
 * @param axis forwarded unchanged to the native {@code torchSplit} call
 * @return an {@code NDList} with one {@code PtNDArray} per returned native handle, each created
 *     with {@code ndArray}'s manager, in the order the handles were returned
 */
public static NDList split(PtNDArray ndArray, long[] indices, long axis) {
    long[] handles = PyTorchLibrary.LIB.torchSplit(ndArray.getHandle(), indices, axis);
    NDList pieces = new NDList();
    for (int i = 0; i < handles.length; ++i) {
        PtNDArray piece = new PtNDArray(ndArray.getManager(), handles[i]);
        pieces.add(piece);
    }
    return pieces;
}
use of ai.djl.pytorch.engine.PtNDArray in project djl by deepjavalibrary.
the class JniUtils method pick.
/**
 * Picks values from {@code ndArray} along {@code dim} using the native {@code torchGather} call.
 *
 * <p>If {@code index} has a different number of dimensions than {@code ndArray}, the index shape
 * is matched against every contiguous run of {@code ndArray}'s dimensions. On the first match
 * the index is reshaped so that the matched dimensions keep their sizes and every other
 * dimension has size 1. The index is converted to {@code INT64} before the native call when it
 * has another data type.
 *
 * @param ndArray the array to gather from
 * @param index the indices to gather
 * @param dim the dimension forwarded to the native gather call
 * @return a new {@code PtNDArray}, created with {@code ndArray}'s manager, wrapping the result
 * @throws IllegalArgumentException if the ranks differ and the index shape does not match any
 *     contiguous run of {@code ndArray}'s dimensions
 */
public static PtNDArray pick(PtNDArray ndArray, PtNDArray index, long dim) {
    Shape indexShape = index.getShape();
    Shape ndShape = ndArray.getShape();
    int shapeDims = indexShape.dimension();
    int ndDims = ndShape.dimension();
    if (shapeDims != ndDims) {
        // Try every offset, including the trailing one (i == ndDims - shapeDims). When the
        // index has more dimensions than the array the loop does not run and we fail below.
        for (int i = 0; i <= ndDims - shapeDims; ++i) {
            // Shape.slice takes (beginIndex, endIndex), so the window is [i, i + shapeDims).
            if (indexShape.equals(ndShape.slice(i, i + shapeDims))) {
                long[] shapes = indexShape.getShape();
                long[] newShape = new long[ndDims];
                // Size-1 padding everywhere, then copy the index dimensions into the window.
                Arrays.fill(newShape, 1L);
                System.arraycopy(shapes, 0, newShape, i, shapes.length);
                indexShape = new Shape(newShape);
                break;
            }
        }
        // An unchanged shape means no window matched.
        if (indexShape.equals(index.getShape())) {
            throw new IllegalArgumentException(
                    "expand shape failed! Cannot expand from " + indexShape + " to " + ndShape);
        }
        index = index.reshape(indexShape);
    }
    if (index.getDataType() != DataType.INT64) {
        index = index.toType(DataType.INT64, true);
    }
    return new PtNDArray(
            ndArray.getManager(),
            PyTorchLibrary.LIB.torchGather(ndArray.getHandle(), index.getHandle(), dim, false));
}
use of ai.djl.pytorch.engine.PtNDArray in project djl by deepjavalibrary.
the class JniUtils method rnn.
/**
 * Runs the native {@code torchNNRnn} call and wraps the returned handles.
 *
 * @param input the input array
 * @param hx the second array handle passed to the native call
 * @param params the parameter arrays; each element is cast to {@code PtNDArray}
 * @param hasBiases forwarded unchanged to the native call
 * @param numLayers forwarded unchanged to the native call
 * @param activation the activation; its {@code ordinal()} is what the native call receives
 * @param dropRate forwarded unchanged to the native call
 * @param training forwarded unchanged to the native call
 * @param bidirectional forwarded unchanged to the native call
 * @param batchFirst forwarded unchanged to the native call
 * @return an {@code NDList} with one {@code PtNDArray} per returned native handle, each created
 *     with {@code input}'s manager
 */
public static NDList rnn(
        PtNDArray input,
        PtNDArray hx,
        NDList params,
        boolean hasBiases,
        int numLayers,
        RNN.Activation activation,
        double dropRate,
        boolean training,
        boolean bidirectional,
        boolean batchFirst) {
    PtNDManager manager = input.getManager();
    long[] paramHandles =
            params.stream()
                    .map(PtNDArray.class::cast)
                    .mapToLong(PtNDArray::getHandle)
                    .toArray();
    long[] resultHandles =
            PyTorchLibrary.LIB.torchNNRnn(
                    input.getHandle(),
                    hx.getHandle(),
                    paramHandles,
                    hasBiases,
                    numLayers,
                    activation.ordinal(),
                    dropRate,
                    training,
                    bidirectional,
                    batchFirst);
    NDList wrapped = new NDList();
    for (int i = 0; i < resultHandles.length; ++i) {
        wrapped.add(new PtNDArray(manager, resultHandles[i]));
    }
    return wrapped;
}
use of ai.djl.pytorch.engine.PtNDArray in project djl by deepjavalibrary.
the class JniUtils method moduleGetParams.
/**
 * Fetches the parameters of a module as named arrays.
 *
 * <p>Handles and names come from two separate native calls and are paired by position.
 *
 * @param block the block whose handle is passed to the native calls
 * @param manager the manager used to create each returned {@code PtNDArray}
 * @return an {@code NDList} with one named {@code PtNDArray} per parameter handle
 */
public static NDList moduleGetParams(PtSymbolBlock block, PtNDManager manager) {
    long[] paramHandles = PyTorchLibrary.LIB.moduleGetParams(block.getHandle());
    String[] paramNames = PyTorchLibrary.LIB.moduleGetParamNames(block.getHandle());
    int count = paramHandles.length;
    NDList params = new NDList(count);
    int pos = 0;
    while (pos < count) {
        PtNDArray param = new PtNDArray(manager, paramHandles[pos]);
        param.setName(paramNames[pos]);
        params.add(param);
        pos++;
    }
    return params;
}
Aggregations