Search in sources:

Example 1 with PairList

Use of ai.djl.util.PairList in project djl by deepjavalibrary.

Method getInputs of class IValueUtils.

static IValue[] getInputs(NDList ndList) {
    // Pass 1: bucket the flat NDList by array name.
    //   "group.key" -> entry "key" of the Dict-like group "group"
    //   "name[]"    -> element of the List group "name[]"
    //   "name()"    -> element of the Tuple group "name()"
    //   anything else (including unnamed arrays) -> its own standalone group
    List<PairList<String, PtNDArray>> groups = new ArrayList<>();
    Map<String, Integer> groupPositions = new ConcurrentHashMap<>();
    for (NDArray item : ndList) {
        String name = item.getName();
        String groupName = null;
        String slot = null;
        if (name != null) {
            if (name.contains(".")) {
                String[] parts = name.split("\\.", 2);
                groupName = parts[0];
                slot = parts[1];
            } else if (Pattern.matches("\\w+\\[]", name)) {
                groupName = name;
                slot = "[]";
            } else if (Pattern.matches("\\w+\\(\\)", name)) {
                groupName = name;
                slot = "()";
            }
        }
        if (groupName == null) {
            PairList<String, PtNDArray> standalone = new PairList<>();
            standalone.add(null, (PtNDArray) item);
            groups.add(standalone);
        } else {
            // addToMap returns the position of the (possibly newly created) group
            int position = addToMap(groupPositions, groupName, groups);
            groups.get(position).add(slot, (PtNDArray) item);
        }
    }

    // Pass 2: turn each group into a single IValue, choosing the container kind from the
    // key of its first entry.
    IValue[] converted = new IValue[groups.size()];
    int next = 0;
    for (PairList<String, PtNDArray> group : groups) {
        String marker = group.get(0).getKey();
        IValue value;
        if (marker == null) {
            // a plain tensor: not a List, Dict or Tuple input
            value = IValue.from(group.get(0).getValue());
        } else if ("[]".equals(marker)) {
            value = IValue.listFrom(group.values().toArray(new PtNDArray[0]));
        } else if ("()".equals(marker)) {
            List<PtNDArray> members = group.values();
            IValue[] elements = new IValue[members.size()];
            for (int k = 0; k < elements.length; k++) {
                elements[k] = IValue.from(members.get(k));
            }
            value = IValue.tupleFrom(elements);
        } else {
            Map<String, PtNDArray> entries = new ConcurrentHashMap<>();
            for (Pair<String, PtNDArray> entry : group) {
                entries.put(entry.getKey(), entry.getValue());
            }
            value = IValue.stringMapFrom(entries);
        }
        converted[next++] = value;
    }
    return converted;
}
Also used: ArrayList(java.util.ArrayList) PairList(ai.djl.util.PairList) PtNDArray(ai.djl.pytorch.engine.PtNDArray) NDArray(ai.djl.ndarray.NDArray) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap)

Example 2 with PairList

Use of ai.djl.util.PairList in project djl by deepjavalibrary.

Method describeOutput of class BaseModel.

/**
 * {@inheritDoc}
 *
 * <p>A {@link SymbolBlock} describes its own outputs. For any other block the output shapes are
 * discovered by running {@code block.forward} on all-ones placeholder inputs shaped as reported
 * by {@link #describeInput()}; the outputs are named {@code output0}, {@code output1}, ... in
 * order. The placeholder inputs and the forward outputs are closed once their shapes have been
 * read, so repeated calls do not accumulate arrays in {@code manager}.
 */
@Override
public PairList<String, Shape> describeOutput() {
    if (block instanceof SymbolBlock) {
        return ((SymbolBlock) block).describeOutput();
    }
    // create fake input to calculate output shapes
    NDList input = new NDList();
    try {
        for (Pair<String, Shape> pair : describeInput()) {
            input.add(manager.ones(pair.getValue()));
        }
        NDList output = block.forward(new ParameterStore(manager, true), input, false);
        try {
            Shape[] outputShapes = output.stream().map(NDArray::getShape).toArray(Shape[]::new);
            List<String> outputNames = new ArrayList<>(outputShapes.length);
            for (int i = 0; i < outputShapes.length; i++) {
                outputNames.add("output" + i);
            }
            return new PairList<>(outputNames, Arrays.asList(outputShapes));
        } finally {
            // The shapes were already extracted into outputShapes above; the output arrays are
            // only needed for that probe.
            output.close();
        }
    } finally {
        // The placeholder arrays exist only for this probe; release them instead of leaving
        // them attached to `manager`. If the block echoes its inputs as outputs they are closed
        // twice. NOTE(review): this assumes NDArray.close() is idempotent — confirm.
        input.close();
    }
}
Also used: SymbolBlock(ai.djl.nn.SymbolBlock) Shape(ai.djl.ndarray.types.Shape) ParameterStore(ai.djl.training.ParameterStore) NDList(ai.djl.ndarray.NDList) ArrayList(java.util.ArrayList) PairList(ai.djl.util.PairList)

Example 3 with PairList

Use of ai.djl.util.PairList in project djl by deepjavalibrary.

Method createCachedOp of class JnaUtils.

/* Need tests
    public static Pointer createSymbolFromJson(String json) {
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXSymbolCreateFromJSON(json, ref));
        return ref.getValue();
    }

    public static Pointer compose(Pointer symbol, String name, String[] keys) {
        PointerByReference ref = new PointerByReference();

        checkCall(LIB.MXSymbolCompose(symbol, name, keys.length, keys, ref));
        return ref.getValue();
    }

    public static Pointer grad(Pointer symbol, String name, int numWrt, String[] wrt) {
        PointerByReference ref = new PointerByReference();

        checkCall(LIB.MXSymbolCompose(symbol, name, numWrt, wrt, ref));
        return ref.getValue();
    }

    public static Shape[] inferShape(Pointer symbol, String[] keys) {
        IntBuffer argIndex = IntBuffer.allocate(1);
        IntBuffer argShapeData = IntBuffer.allocate(1);
        IntBuffer inShapeSize = IntBuffer.allocate(1);
        PointerByReference inShapeNDim = new PointerByReference();
        PointerByReference inShapeData = new PointerByReference();
        IntBuffer outShapeSize = IntBuffer.allocate(1);
        PointerByReference outShapeNDim = new PointerByReference();
        PointerByReference outShapeData = new PointerByReference();
        IntBuffer auxShapeSize = IntBuffer.allocate(1);
        PointerByReference auxShapeNDim = new PointerByReference();
        PointerByReference auxShapeData = new PointerByReference();
        IntBuffer complete = IntBuffer.allocate(1);

        checkCall(
                LIB.MXSymbolInferShape(
                        symbol,
                        keys.length,
                        keys,
                        argIndex.array(),
                        argShapeData.array(),
                        inShapeSize,
                        inShapeNDim,
                        inShapeData,
                        outShapeSize,
                        outShapeNDim,
                        outShapeData,
                        auxShapeSize,
                        auxShapeNDim,
                        auxShapeData,
                        complete));
        if (complete.get() == 1) {
            Shape[] ret = new Shape[keys.length];
            // TODO: add implementation
            return ret; // NOPMD
        }
        return null;
    }

    public static Pointer inferType(Pointer symbol, String[] keys) {
        int[] argTypeData = new int[1];
        IntBuffer inTypeSize = IntBuffer.allocate(1);
        PointerByReference inTypeData = new PointerByReference();
        IntBuffer outTypeSize = IntBuffer.allocate(1);
        PointerByReference outTypeData = new PointerByReference();
        IntBuffer auxTypeSize = IntBuffer.allocate(1);
        PointerByReference auxTypeData = new PointerByReference();
        IntBuffer complete = IntBuffer.allocate(1);

        checkCall(
                LIB.MXSymbolInferType(
                        symbol,
                        keys.length,
                        keys,
                        argTypeData,
                        inTypeSize,
                        inTypeData,
                        outTypeSize,
                        outTypeData,
                        auxTypeSize,
                        auxTypeData,
                        complete));
        if (complete.get() == 1) {
            return outTypeData.getValue();
        }
        return null;
    }

    public static Pointer quantizeSymbol(
            Pointer symbol,
            String[] excludedSymbols,
            String[] offlineParams,
            String quantizedDType,
            byte calibQuantize) {
        PointerByReference ref = new PointerByReference();
        checkCall(
                LIB.MXQuantizeSymbol(
                        symbol,
                        ref,
                        excludedSymbols.length,
                        excludedSymbols,
                        offlineParams.length,
                        offlineParams,
                        quantizedDType,
                        calibQuantize));
        return ref.getValue();
    }

    public static Pointer setCalibTableToQuantizedSymbol(
            Pointer symbol,
            String[] layerNames,
            FloatBuffer lowQuantiles,
            FloatBuffer highQuantiles) {
        PointerByReference ref = new PointerByReference();
        checkCall(
                LIB.MXSetCalibTableToQuantizedSymbol(
                        symbol, layerNames.length, layerNames, lowQuantiles, highQuantiles, ref));
        return ref.getValue();
    }

    public static Pointer genBackendSubgraph(Pointer symbol, String backend) {
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXGenBackendSubgraph(symbol, backend, ref));
        return ref.getValue();
    }
     */
// ///////////////////////////////
// MXNet Executors
// ///////////////////////////////
/* Need tests
    public static void freeExecutor(Pointer executor) {
        checkCall(LIB.MXExecutorFree(executor));
    }

    public static String getExecutorDebugString(Pointer executor) {
        String[] ret = new String[1];
        checkCall(LIB.MXExecutorPrint(executor, ret));
        return ret[0];
    }

    public static void forward(Pointer executor, boolean isTrain) {
        checkCall(LIB.MXExecutorForward(executor, isTrain ? 1 : 0));
    }

    public static Pointer backward(Pointer executor, int length) {
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXExecutorBackward(executor, length, ref));
        return ref.getValue();
    }

    public static Pointer backwardEx(Pointer executor, int length, boolean isTrain) {
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXExecutorBackwardEx(executor, length, ref, isTrain ? 1 : 0));
        return ref.getValue();
    }

    public static NDArray[] getExecutorOutputs(MxNDManager manager, Pointer executor) {
        IntBuffer outSize = IntBuffer.allocate(1);
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXExecutorOutputs(executor, outSize, ref));
        int size = outSize.get();
        Pointer[] pointers = ref.getValue().getPointerArray(0, size);
        NDArray[] ndArrays = new NDArray[size];
        for (int i = 0; i < size; ++i) {
            ndArrays[i] = manager.create(pointers[i]);
        }
        return ndArrays;
    }

    public static Pointer bindExecutorSimple(
            Symbol symbol,
            Device device,
            String[] g2cKeys,
            int[] g2cDeviceTypes,
            int[] g2cDeviceIds,
            String[] argParams,
            String[] argParamGradReqs,
            String[] inputArgNames,
            IntBuffer inputShapeData,
            IntBuffer inputShapeIdx,
            String[] inputDataTypeNames,
            int[] inputDataTypes,
            String[] inputStorageTypeNames,
            int[] inputStorageTypes,
            String[] sharedArgParams,
            IntBuffer sharedBufferLen,
            String[] sharedBufferNames,
            PointerByReference sharedBufferHandles,
            PointerByReference updatedSharedBufferNames,
            PointerByReference updatedSharedBufferHandles,
            IntBuffer numInArgs,
            PointerByReference inArgs,
            PointerByReference argGrads,
            IntBuffer numAuxStates,
            PointerByReference auxStates) {
        int deviceId = device.getDeviceId();
        int deviceType = DeviceType.toDeviceType(device);

        PointerByReference ref = new PointerByReference();

        checkCall(
                LIB.MXExecutorSimpleBind(
                        symbol.getHandle(),
                        deviceType,
                        deviceId,
                        g2cKeys == null ? 0 : g2cKeys.length,
                        g2cKeys,
                        g2cDeviceTypes,
                        g2cDeviceIds,
                        argParams.length,
                        argParams,
                        argParamGradReqs,
                        inputArgNames.length,
                        inputArgNames,
                        inputShapeData.array(),
                        inputShapeIdx.array(),
                        inputDataTypeNames.length,
                        inputDataTypeNames,
                        inputDataTypes,
                        inputStorageTypeNames == null ? 0 : inputStorageTypeNames.length,
                        inputStorageTypeNames,
                        inputStorageTypes,
                        sharedArgParams.length,
                        sharedArgParams,
                        sharedBufferLen,
                        sharedBufferNames,
                        sharedBufferHandles,
                        updatedSharedBufferNames,
                        updatedSharedBufferHandles,
                        numInArgs,
                        inArgs,
                        argGrads,
                        numAuxStates,
                        auxStates,
                        null,
                        ref));
        return ref.getValue();
    }

    public static Pointer bindExecutor(
            Pointer executor, Device device, int len, int auxStatesLen) {
        int deviceId = device.getDeviceId();
        int deviceType = DeviceType.toDeviceType(device);
        PointerByReference inArgs = new PointerByReference();
        PointerByReference argGradStore = new PointerByReference();
        IntBuffer gradReqType = IntBuffer.allocate(1);
        PointerByReference auxStates = new PointerByReference();
        PointerByReference ref = new PointerByReference();
        checkCall(
                LIB.MXExecutorBind(
                        executor,
                        deviceType,
                        deviceId,
                        len,
                        inArgs,
                        argGradStore,
                        gradReqType,
                        auxStatesLen,
                        auxStates,
                        ref));
        return ref.getValue();
    }

    public static Pointer bindExecutorX(
            Pointer executor,
            Device device,
            int len,
            int auxStatesLen,
            String[] keys,
            int[] deviceTypes,
            int[] deviceIds) {
        int deviceId = device.getDeviceId();
        int deviceType = DeviceType.toDeviceType(device);
        PointerByReference inArgs = new PointerByReference();
        PointerByReference argGradStore = new PointerByReference();
        IntBuffer gradReqType = IntBuffer.allocate(1);
        PointerByReference auxStates = new PointerByReference();
        PointerByReference ref = new PointerByReference();
        checkCall(
                LIB.MXExecutorBindX(
                        executor,
                        deviceType,
                        deviceId,
                        keys.length,
                        keys,
                        deviceTypes,
                        deviceIds,
                        len,
                        inArgs,
                        argGradStore,
                        gradReqType,
                        auxStatesLen,
                        auxStates,
                        ref));
        return ref.getValue();
    }

    public static Pointer bindExecutorEX(
            Pointer executor,
            Device device,
            int len,
            int auxStatesLen,
            String[] keys,
            int[] deviceTypes,
            int[] deviceIds,
            Pointer sharedExecutor) {
        int deviceId = device.getDeviceId();
        int deviceType = DeviceType.toDeviceType(device);
        PointerByReference inArgs = new PointerByReference();
        PointerByReference argGradStore = new PointerByReference();
        IntBuffer gradReqType = IntBuffer.allocate(1);
        PointerByReference auxStates = new PointerByReference();
        PointerByReference ref = new PointerByReference();
        checkCall(
                LIB.MXExecutorBindEX(
                        executor,
                        deviceType,
                        deviceId,
                        keys.length,
                        keys,
                        deviceTypes,
                        deviceIds,
                        len,
                        inArgs,
                        argGradStore,
                        gradReqType,
                        auxStatesLen,
                        auxStates,
                        sharedExecutor,
                        ref));
        return ref.getValue();
    }

    public static Pointer reshapeExecutor(
            boolean partialShaping,
            boolean allowUpSizing,
            Device device,
            String[] keys,
            int[] deviceTypes,
            int[] deviceIds,
            String[] providedArgShapeNames,
            IntBuffer providedArgShapeData,
            IntBuffer providedArgShapeIdx,
            IntBuffer numInArgs,
            PointerByReference inArgs,
            PointerByReference argGrads,
            IntBuffer numAuxStates,
            PointerByReference auxStates,
            Pointer sharedExecutor) {
        int deviceId = device.getDeviceId();
        int deviceType = DeviceType.toDeviceType(device);
        PointerByReference ref = new PointerByReference();
        checkCall(
                LIB.MXExecutorReshape(
                        partialShaping ? 1 : 0,
                        allowUpSizing ? 1 : 0,
                        deviceType,
                        deviceId,
                        keys.length,
                        keys,
                        deviceTypes,
                        deviceIds,
                        providedArgShapeNames.length,
                        providedArgShapeNames,
                        providedArgShapeData.array(),
                        providedArgShapeIdx.array(),
                        numInArgs,
                        inArgs,
                        argGrads,
                        numAuxStates,
                        auxStates,
                        sharedExecutor,
                        ref));
        return ref.getValue();
    }

    public static Pointer getOptimizedSymbol(Pointer executor) {
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXExecutorGetOptimizedSymbol(executor, ref));
        return ref.getValue();
    }

    public static void setMonitorCallback(
            Pointer executor,
            MxnetLibrary.ExecutorMonitorCallback callback,
            Pointer callbackHandle) {
        checkCall(LIB.MXExecutorSetMonitorCallback(executor, callback, callbackHandle));
    }
     */
// ///////////////////////////////
// MXNet DataIter
// ///////////////////////////////
/*
    public static Pointer[] listDataIters() {
        IntBuffer outSize = IntBuffer.allocate(1);
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXListDataIters(outSize, ref));
        return ref.getValue().getPointerArray(0, outSize.get());
    }

    public static Pointer createIter(Pointer iter, String[] keys, String[] values) {
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXDataIterCreateIter(iter, keys.length, keys, values, ref));
        return ref.getValue();
    }

    public static String getIterInfo(Pointer iter) {
        String[] name = new String[1];
        String[] description = new String[1];
        IntBuffer numArgs = IntBuffer.allocate(1);
        PointerByReference argNames = new PointerByReference();
        PointerByReference argTypes = new PointerByReference();
        PointerByReference argDesc = new PointerByReference();
        checkCall(
                LIB.MXDataIterGetIterInfo(
                        iter, name, description, numArgs, argNames, argTypes, argDesc));
        return name[0];
    }

    public static void freeDataIter(Pointer iter) {
        checkCall(LIB.MXDataIterFree(iter));
    }

    public static int next(Pointer iter) {
        IntBuffer ret = IntBuffer.allocate(1);
        checkCall(LIB.MXDataIterNext(iter, ret));
        return ret.get();
    }

    public static void beforeFirst(Pointer iter) {
        checkCall(LIB.MXDataIterBeforeFirst(iter));
    }

    public static Pointer getData(Pointer iter) {
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXDataIterGetData(iter, ref));
        return ref.getValue();
    }

    public static Pointer getIndex(Pointer iter) {
        LongBuffer outSize = LongBuffer.wrap(new long[1]);
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXDataIterGetIndex(iter, ref, outSize));
        return ref.getValue();
    }

    public static int getPadNum(Pointer iter) {
        IntBuffer outSize = IntBuffer.allocate(1);
        checkCall(LIB.MXDataIterGetPadNum(iter, outSize));
        return outSize.get();
    }

    public static String getDataIterLabel(Pointer iter) {
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXDataIterGetLabel(iter, ref));
        return ref.getValue().getString(0, StandardCharsets.UTF_8.name());
    }
     */
/*



    int MXRecordIOWriterCreate(String uri, PointerByReference out);


    int MXRecordIOWriterFree(Pointer handle);


    int MXRecordIOWriterWriteRecord(Pointer handle, String buf, NativeSize size);


    int MXRecordIOWriterTell(Pointer handle, NativeSizeByReference pos);


    int MXRecordIOReaderCreate(String uri, PointerByReference out);


    int MXRecordIOReaderFree(Pointer handle);


    int MXRecordIOReaderReadRecord(Pointer handle, String buf[], NativeSizeByReference size);


    int MXRecordIOReaderSeek(Pointer handle, NativeSize pos);


    int MXRecordIOReaderTell(Pointer handle, NativeSizeByReference pos);


    int MXRtcCreate(ByteBuffer name, int num_input, int num_output, PointerByReference input_names,
                    PointerByReference output_names, PointerByReference inputs,
                    PointerByReference outputs, ByteBuffer kernel, PointerByReference out);


    int MXRtcPush(Pointer handle, int num_input, int num_output, PointerByReference inputs,
                  PointerByReference outputs, int gridDimX, int gridDimY, int gridDimZ,
                  int blockDimX, int blockDimY, int blockDimZ);


    int MXRtcFree(Pointer handle);


    int MXCustomOpRegister(String op_type, MxnetLibrary.CustomOpPropCreator creator);


    int MXCustomFunctionRecord(int num_inputs, PointerByReference inputs, int num_outputs,
                               PointerByReference outputs, MXCallbackList callbacks);


    int MXRtcCudaModuleCreate(String source, int num_options, String options[], int num_exports,
                              String exports[], PointerByReference out);


    int MXRtcCudaModuleFree(Pointer handle);


    int MXRtcCudaKernelCreate(Pointer handle, String name, int num_args, IntBuffer is_ndarray,
                              IntBuffer is_const, IntBuffer arg_types, PointerByReference out);


    int MXRtcCudaKernelFree(Pointer handle);


    int MXRtcCudaKernelCall(Pointer handle, int dev_id, PointerByReference args, int grid_dim_x,
                            int grid_dim_y, int grid_dim_z, int block_dim_x, int block_dim_y,
                            int block_dim_z, int shared_mem);


    int MXNDArrayGetSharedMemHandle(Pointer handle, IntBuffer shared_pid, IntBuffer shared_id);


    int MXNDArrayCreateFromSharedMem(int shared_pid, int shared_id, IntBuffer shape, int ndim,
                                     int dtype, PointerByReference out);
    */
// ////////////////////////////////
// cached Op
// ////////////////////////////////
/**
 * Creates a {@link CachedOp} for the symbol loaded in the given block.
 *
 * <p>Every parameter returned by {@code block.getAllParameters()} is classified by its position
 * in that list. Initialized parameters are recorded in {@code param_indices}. Uninitialized ones
 * are assumed to be data inputs and recorded, together with their names, in
 * {@code data_indices}. For example, data_indices [0, 2, 4] label the input locations and
 * param_indices [1, 3] label the parameter locations. The CachedOp is created with
 * {@code static_alloc} and {@code static_shape} both set to "1".
 *
 * @param block the {@link MxSymbolBlock} that loaded in the backend
 * @param manager the NDManager used to create NDArray
 * @param training true if the CachedOp is created to forward in training, otherwise false.
 *     NOTE: this flag is currently not read by this method.
 * @return a CachedOp wrapping the native handle created for the block's symbol
 */
public static CachedOp createCachedOp(MxSymbolBlock block, MxNDManager manager, boolean training) {
    Symbol symbol = block.getSymbol();
    List<Parameter> parameters = block.getAllParameters();
    // record data index in all inputs
    PairList<String, Integer> dataIndices = new PairList<>();
    // record parameter index in all inputs
    List<Integer> paramIndices = new ArrayList<>();
    int index = 0;
    for (Parameter parameter : parameters) {
        // We assume uninitialized parameters are data inputs
        if (parameter.isInitialized()) {
            paramIndices.add(index);
        } else {
            dataIndices.add(parameter.getName(), index);
        }
        ++index;
    }
    // Creating CachedOp
    Pointer symbolHandle = symbol.getHandle();
    // static_alloc and static_shape are enabled by default
    String[] keys = {"data_indices", "param_indices", "static_alloc", "static_shape"};
    // List#toString renders the indices as "[0, 2, 4]", the format passed to the native call
    String[] values = {dataIndices.values().toString(), paramIndices.toString(), "1", "1"};
    PointerByReference ref = REFS.acquire();
    Pointer pointer;
    try {
        checkCall(LIB.MXCreateCachedOpEx(symbolHandle, keys.length, keys, values, ref));
        pointer = ref.getValue();
    } finally {
        // return the pooled reference even when checkCall throws
        REFS.recycle(ref);
    }
    return new CachedOp(pointer, manager, parameters, paramIndices, dataIndices);
}
Also used : CachedOp(ai.djl.mxnet.engine.CachedOp) Symbol(ai.djl.mxnet.engine.Symbol) ArrayList(java.util.ArrayList) PairList(ai.djl.util.PairList) Pointer(com.sun.jna.Pointer) PointerByReference(com.sun.jna.ptr.PointerByReference) Parameter(ai.djl.nn.Parameter)

Example 4 with PairList

Use of ai.djl.util.PairList in project djl by deepjavalibrary.

Method getFunctionByName of class JnaUtils.

/**
 * Queries the native atomic-symbol info for {@code handle} and wraps it as a {@link FunctionInfo}.
 *
 * <p>Only the argument names and argument types reported by
 * {@code MXSymbolGetAtomicSymbolInfo} are kept, as (name, type) pairs in reporting order. The
 * description, argument descriptions, key-var-args and return type outputs are requested but
 * not used.
 *
 * @param name the name passed to the native query
 * @param functionName the name stored in the returned {@link FunctionInfo}
 * @param handle the native atomic-symbol handle
 * @return a {@link FunctionInfo} holding {@code handle}, {@code functionName} and the arguments
 */
private static FunctionInfo getFunctionByName(String name, String functionName, Pointer handle) {
    String[] nameRef = {name};
    String[] description = new String[1];
    IntBuffer numArgs = IntBuffer.allocate(1);
    String[] keyVarArgs = new String[1];
    String[] returnType = new String[1];
    PairList<String, String> arguments = new PairList<>();
    PointerByReference argNameRef = REFS.acquire();
    PointerByReference argTypeRef = REFS.acquire();
    PointerByReference argDescRef = REFS.acquire();
    try {
        checkCall(
                LIB.MXSymbolGetAtomicSymbolInfo(
                        handle,
                        nameRef,
                        description,
                        numArgs,
                        argNameRef,
                        argTypeRef,
                        argDescRef,
                        keyVarArgs,
                        returnType));
        int count = numArgs.get();
        if (count != 0) {
            String charset = StandardCharsets.UTF_8.name();
            String[] argNames = argNameRef.getValue().getStringArray(0, count, charset);
            String[] argTypes = argTypeRef.getValue().getStringArray(0, count, charset);
            for (int i = 0; i < argNames.length; i++) {
                arguments.add(argNames[i], argTypes[i]);
            }
        }
    } finally {
        // return the pooled references even when checkCall throws
        REFS.recycle(argNameRef);
        REFS.recycle(argTypeRef);
        REFS.recycle(argDescRef);
    }
    return new FunctionInfo(handle, functionName, arguments);
}
Also used : IntBuffer(java.nio.IntBuffer) PointerByReference(com.sun.jna.ptr.PointerByReference) PairList(ai.djl.util.PairList)

Example 5 with PairList

Use of ai.djl.util.PairList in project djl by deepjavalibrary.

Method invoke of class FunctionInfo.

/**
 * Calls an operator with the given arguments.
 *
 * <p>The inputs are first passed to {@code checkDevices}, then the operator is invoked through
 * {@code JnaUtils.imperativeInvoke}. Each returned native handle is wrapped by {@code manager}:
 * handles whose format is not {@link SparseFormat#DENSE} are created with that sparse format,
 * and dense ones are created without a format argument.
 *
 * @param manager the manager to attach the result to; it is cast to {@link MxNDManager}, so a
 *     manager of any other type causes a {@link ClassCastException}
 * @param src the input NDArray(s) to the operator
 * @param params the non-NDArray arguments to the operator. Should be a {@code PairList<String,
 *     String>}
 * @return the operator's output arrays, in the order returned by the native call
 */
public NDArray[] invoke(NDManager manager, NDArray[] src, PairList<String, ?> params) {
    checkDevices(src);
    PairList<Pointer, SparseFormat> pairList = JnaUtils.imperativeInvoke(handle, src, null, params);
    final MxNDManager mxManager = (MxNDManager) manager;
    return pairList.stream()
            .map(
                    pair -> {
                        if (pair.getValue() != SparseFormat.DENSE) {
                            return mxManager.create(pair.getKey(), pair.getValue());
                        }
                        return mxManager.create(pair.getKey());
                    })
            .toArray(MxNDArray[]::new);
}
Also used: NDManager(ai.djl.ndarray.NDManager) MxNDManager(ai.djl.mxnet.engine.MxNDManager) List(java.util.List) Logger(org.slf4j.Logger) Trainer(ai.djl.training.Trainer) PairList(ai.djl.util.PairList) LoggerFactory(org.slf4j.LoggerFactory) Device(ai.djl.Device) NDArray(ai.djl.ndarray.NDArray) MxNDArray(ai.djl.mxnet.engine.MxNDArray) Pointer(com.sun.jna.Pointer) SparseFormat(ai.djl.ndarray.types.SparseFormat)

Aggregations

PairList (ai.djl.util.PairList)9 NDArray (ai.djl.ndarray.NDArray)3 Shape (ai.djl.ndarray.types.Shape)3 Pointer (com.sun.jna.Pointer)3 PointerByReference (com.sun.jna.ptr.PointerByReference)3 ArrayList (java.util.ArrayList)3 NDManager (ai.djl.ndarray.NDManager)2 SparseFormat (ai.djl.ndarray.types.SparseFormat)2 IntBuffer (java.nio.IntBuffer)2 Device (ai.djl.Device)1 Point (ai.djl.modality.cv.output.Point)1 Rectangle (ai.djl.modality.cv.output.Rectangle)1 CachedOp (ai.djl.mxnet.engine.CachedOp)1 MxNDArray (ai.djl.mxnet.engine.MxNDArray)1 MxNDManager (ai.djl.mxnet.engine.MxNDManager)1 Symbol (ai.djl.mxnet.engine.Symbol)1 NDList (ai.djl.ndarray.NDList)1 DataType (ai.djl.ndarray.types.DataType)1 Parameter (ai.djl.nn.Parameter)1 SymbolBlock (ai.djl.nn.SymbolBlock)1