Search in sources:

Example 1 with PtNDArray

Use of ai.djl.pytorch.engine.PtNDArray in project djl by deepjavalibrary.

In class IValueUtils, method getInputs.

/**
 * Converts a flat {@link NDList} into an array of {@link IValue} module inputs.
 *
 * <p>The name of each array encodes how it is grouped:
 *
 * <ul>
 *   <li>{@code "prefix.key"}: entry {@code key} of a string-keyed map grouped under {@code
 *       prefix} (only the first dot splits, so {@code "a.b.c"} gives prefix {@code "a"}, key
 *       {@code "b.c"})
 *   <li>{@code "name[]"}: element of a list grouped under {@code "name[]"}
 *   <li>{@code "name()"}: element of a tuple grouped under {@code "name()"}
 *   <li>anything else, including a {@code null} name: a standalone tensor input
 * </ul>
 *
 * <p>The container type built for a group is decided by the key of that group's first entry.
 *
 * @param ndList the arrays to convert; every element is cast to {@link PtNDArray}
 * @return one {@link IValue} per standalone array or per group (group slots are assigned by
 *     {@code addToMap})
 */
static IValue[] getInputs(NDList ndList) {
    List<PairList<String, PtNDArray>> groups = new ArrayList<>();
    Map<String, Integer> groupIndex = new ConcurrentHashMap<>();

    for (NDArray nd : ndList) {
        String arrayName = nd.getName();
        String groupName = null;
        String entryKey = null;
        if (arrayName != null) {
            if (arrayName.contains(".")) {
                String[] parts = arrayName.split("\\.", 2);
                groupName = parts[0];
                entryKey = parts[1];
            } else if (Pattern.matches("\\w+\\[]", arrayName)) {
                groupName = arrayName;
                entryKey = "[]";
            } else if (Pattern.matches("\\w+\\(\\)", arrayName)) {
                groupName = arrayName;
                entryKey = "()";
            }
        }

        if (groupName == null) {
            // Not part of any container: gets a slot of its own with a null key.
            PairList<String, PtNDArray> standalone = new PairList<>();
            standalone.add(null, (PtNDArray) nd);
            groups.add(standalone);
        } else {
            // NOTE(review): addToMap is defined elsewhere; presumably it returns the slot of an
            // existing group with this name or appends a new one. Confirm.
            int slot = addToMap(groupIndex, groupName, groups);
            groups.get(slot).add(entryKey, (PtNDArray) nd);
        }
    }

    IValue[] ret = new IValue[groups.size()];
    for (int g = 0; g < ret.length; g++) {
        PairList<String, PtNDArray> group = groups.get(g);
        String firstKey = group.get(0).getKey();
        if (firstKey == null) {
            // Plain tensor, not a list / dict / tuple input.
            ret[g] = IValue.from(group.get(0).getValue());
            continue;
        }
        switch (firstKey) {
            case "[]":
                ret[g] = IValue.listFrom(group.values().toArray(new PtNDArray[0]));
                break;
            case "()":
                ret[g] =
                        IValue.tupleFrom(
                                group.values().stream().map(IValue::from).toArray(IValue[]::new));
                break;
            default:
                {
                    // NOTE(review): ConcurrentHashMap does not keep insertion order. Confirm
                    // that the consumer of stringMapFrom does not rely on entry order.
                    Map<String, PtNDArray> dict = new ConcurrentHashMap<>();
                    for (Pair<String, PtNDArray> entry : group) {
                        dict.put(entry.getKey(), entry.getValue());
                    }
                    ret[g] = IValue.stringMapFrom(dict);
                    break;
                }
        }
    }
    return ret;
}
Also used: ArrayList (java.util.ArrayList), PairList (ai.djl.util.PairList), PtNDArray (ai.djl.pytorch.engine.PtNDArray), NDArray (ai.djl.ndarray.NDArray), ConcurrentHashMap (java.util.concurrent.ConcurrentHashMap)

Example 2 with PtNDArray

Use of ai.djl.pytorch.engine.PtNDArray in project djl by deepjavalibrary.

In class JniUtils, method split.

/**
 * Splits {@code ndArray} along {@code axis} using the native {@code torchSplit} call.
 *
 * @param ndArray the array to split
 * @param indices split specification, passed unchanged to the native library
 * @param axis the axis to split along
 * @return the pieces returned by the native call, in the same order, each wrapped with {@code
 *     ndArray}'s manager
 */
public static NDList split(PtNDArray ndArray, long[] indices, long axis) {
    long[] pieceHandles = PyTorchLibrary.LIB.torchSplit(ndArray.getHandle(), indices, axis);
    NDList pieces = new NDList(pieceHandles.length);
    for (int i = 0; i < pieceHandles.length; i++) {
        pieces.add(new PtNDArray(ndArray.getManager(), pieceHandles[i]));
    }
    return pieces;
}
Also used: NDList (ai.djl.ndarray.NDList), PtNDArray (ai.djl.pytorch.engine.PtNDArray)

Example 3 with PtNDArray

Use of ai.djl.pytorch.engine.PtNDArray in project djl by deepjavalibrary.

In class JniUtils, method pick.

/**
 * Gathers values from {@code ndArray} along {@code dim} at the positions given by {@code index},
 * using the native {@code torchGather} call.
 *
 * <p>If {@code index} has a different rank than {@code ndArray}, it is first reshaped to {@code
 * ndArray}'s rank. The first offset {@code i} where {@code index}'s shape equals the contiguous
 * window {@code ndShape[i, i + index.rank)} keeps those sizes, and every other axis becomes 1.
 * For example, index (3, 4) against (2, 3, 4) becomes (1, 3, 4). A non-INT64 index is converted
 * to INT64 before gathering.
 *
 * @param ndArray the source array
 * @param index the positions to pick
 * @param dim the axis along which to gather
 * @return a new array, owned by {@code ndArray}'s manager, holding the gathered values
 * @throws IllegalArgumentException if {@code index}'s shape cannot be aligned with any window of
 *     {@code ndArray}'s shape
 */
public static PtNDArray pick(PtNDArray ndArray, PtNDArray index, long dim) {
    Shape indexShape = index.getShape();
    Shape ndShape = ndArray.getShape();
    int shapeDims = indexShape.dimension();
    int ndDims = ndShape.dimension();
    if (shapeDims != ndDims) {
        Shape expanded = null;
        // "<=" so that an index aligned with the trailing axes of ndArray is also tried.
        // When shapeDims > ndDims the bound is negative and no alignment is attempted.
        for (int i = 0; i <= ndDims - shapeDims; ++i) {
            // Shape.slice is end-exclusive, so the window must end at i + shapeDims.
            if (indexShape.equals(ndShape.slice(i, i + shapeDims))) {
                long[] newShape = new long[ndDims];
                Arrays.fill(newShape, 1L);
                // Copy the index sizes into the matched window; all other axes stay 1.
                System.arraycopy(indexShape.getShape(), 0, newShape, i, shapeDims);
                expanded = new Shape(newShape);
                break;
            }
        }
        if (expanded == null) {
            throw new IllegalArgumentException(
                    "expand shape failed! Cannot expand from " + indexShape + " to " + ndShape);
        }
        index = index.reshape(expanded);
    }
    if (index.getDataType() != DataType.INT64) {
        index = index.toType(DataType.INT64, true);
    }
    return new PtNDArray(
            ndArray.getManager(),
            PyTorchLibrary.LIB.torchGather(ndArray.getHandle(), index.getHandle(), dim, false));
}
Also used: Shape (ai.djl.ndarray.types.Shape), PtNDArray (ai.djl.pytorch.engine.PtNDArray)

Example 4 with PtNDArray

Use of ai.djl.pytorch.engine.PtNDArray in project djl by deepjavalibrary.

In class JniUtils, method rnn.

/**
 * Runs a vanilla RNN through the native {@code torchNNRnn} call.
 *
 * <p>Every array in {@code params} is cast to {@link PtNDArray}, and its native handle is passed
 * in list order. The activation is passed to native code as {@code activation.ordinal()}, so the
 * order of the {@code RNN.Activation} constants presumably has to match the native side. Confirm
 * before reordering the enum.
 *
 * @param input the input sequence
 * @param hx the initial hidden state
 * @param params the RNN weights (and biases when {@code hasBiases} is true)
 * @param hasBiases whether {@code params} contains bias arrays
 * @param numLayers number of stacked layers
 * @param activation the nonlinearity to apply
 * @param dropRate dropout rate, passed through to native code
 * @param training whether to run in training mode
 * @param bidirectional whether the RNN is bidirectional
 * @param batchFirst whether the batch axis comes first in {@code input}
 * @return one array per handle returned by the native call, owned by {@code input}'s manager
 */
public static NDList rnn(
        PtNDArray input,
        PtNDArray hx,
        NDList params,
        boolean hasBiases,
        int numLayers,
        RNN.Activation activation,
        double dropRate,
        boolean training,
        boolean bidirectional,
        boolean batchFirst) {
    PtNDManager manager = input.getManager();
    long[] weightHandles = new long[params.size()];
    for (int i = 0; i < weightHandles.length; i++) {
        weightHandles[i] = ((PtNDArray) params.get(i)).getHandle();
    }
    long[] resultHandles =
            PyTorchLibrary.LIB.torchNNRnn(
                    input.getHandle(),
                    hx.getHandle(),
                    weightHandles,
                    hasBiases,
                    numLayers,
                    activation.ordinal(),
                    dropRate,
                    training,
                    bidirectional,
                    batchFirst);
    NDList results = new NDList(resultHandles.length);
    for (long handle : resultHandles) {
        results.add(new PtNDArray(manager, handle));
    }
    return results;
}
Also used: OutputStream (java.io.OutputStream), DataInputStream (java.io.DataInputStream), PtDeviceType (ai.djl.pytorch.engine.PtDeviceType), Arrays (java.util.Arrays), Logger (org.slf4j.Logger), NDList (ai.djl.ndarray.NDList), Shape (ai.djl.ndarray.types.Shape), PtNDArray (ai.djl.pytorch.engine.PtNDArray), LoggerFactory (org.slf4j.LoggerFactory), Set (java.util.Set), Device (ai.djl.Device), IOException (java.io.IOException), ByteBuffer (java.nio.ByteBuffer), HashSet (java.util.HashSet), ByteOrder (java.nio.ByteOrder), PtNDManager (ai.djl.pytorch.engine.PtNDManager), RNN (ai.djl.nn.recurrent.RNN), PtSymbolBlock (ai.djl.pytorch.engine.PtSymbolBlock), DataType (ai.djl.ndarray.types.DataType), SparseFormat (ai.djl.ndarray.types.SparseFormat), Path (java.nio.file.Path), InputStream (java.io.InputStream)

Example 5 with PtNDArray

Use of ai.djl.pytorch.engine.PtNDArray in project djl by deepjavalibrary.

In class JniUtils, method moduleGetParams.

/**
 * Returns the parameters of a TorchScript module as named arrays.
 *
 * <p>Handles and names are fetched through two separate native calls and paired by position.
 * This assumes the name array is at least as long as the handle array; otherwise an {@link
 * ArrayIndexOutOfBoundsException} is thrown. Confirm that the native side guarantees this.
 *
 * @param block the module whose parameters are read
 * @param manager the manager that will own the returned arrays
 * @return one array per native parameter handle, each named after its matching parameter name
 */
public static NDList moduleGetParams(PtSymbolBlock block, PtNDManager manager) {
    long[] paramHandles = PyTorchLibrary.LIB.moduleGetParams(block.getHandle());
    String[] paramNames = PyTorchLibrary.LIB.moduleGetParamNames(block.getHandle());
    NDList params = new NDList(paramHandles.length);
    int position = 0;
    for (long handle : paramHandles) {
        PtNDArray param = new PtNDArray(manager, handle);
        param.setName(paramNames[position++]);
        params.add(param);
    }
    return params;
}
Also used: NDList (ai.djl.ndarray.NDList), PtNDArray (ai.djl.pytorch.engine.PtNDArray)

Aggregations

PtNDArray (ai.djl.pytorch.engine.PtNDArray): 14, NDList (ai.djl.ndarray.NDList): 10, Shape (ai.djl.ndarray.types.Shape): 7, PtNDManager (ai.djl.pytorch.engine.PtNDManager): 6, Path (java.nio.file.Path): 4, Device (ai.djl.Device): 3, DataType (ai.djl.ndarray.types.DataType): 3, SparseFormat (ai.djl.ndarray.types.SparseFormat): 3, RNN (ai.djl.nn.recurrent.RNN): 3, PtDeviceType (ai.djl.pytorch.engine.PtDeviceType): 3, PtSymbolBlock (ai.djl.pytorch.engine.PtSymbolBlock): 3, DataInputStream (java.io.DataInputStream): 3, IOException (java.io.IOException): 3, InputStream (java.io.InputStream): 3, OutputStream (java.io.OutputStream): 3, ByteBuffer (java.nio.ByteBuffer): 3, ByteOrder (java.nio.ByteOrder): 3, Arrays (java.util.Arrays): 3, HashSet (java.util.HashSet): 3, Set (java.util.Set): 3