
Example 1 with MxNDManager

Use of ai.djl.mxnet.engine.MxNDManager in project djl by deepjavalibrary.

Class FunctionInfo, method invoke.

/**
 * Calls an operator with the given arguments.
 *
 * @param manager the manager to attach the result to
 * @param src the input NDArray(s) to the operator
 * @param params the non-NDArray arguments to the operator. Should be a {@code PairList<String,
 *     String>}
 * @return the output NDArrays produced by the operator
 */
public NDArray[] invoke(NDManager manager, NDArray[] src, PairList<String, ?> params) {
    checkDevices(src);
    PairList<Pointer, SparseFormat> pairList = JnaUtils.imperativeInvoke(handle, src, null, params);
    final MxNDManager mxManager = (MxNDManager) manager;
    return pairList.stream().map(pair -> {
        if (pair.getValue() != SparseFormat.DENSE) {
            return mxManager.create(pair.getKey(), pair.getValue());
        }
        return mxManager.create(pair.getKey());
    }).toArray(MxNDArray[]::new);
}
Also used : NDManager(ai.djl.ndarray.NDManager) MxNDManager(ai.djl.mxnet.engine.MxNDManager) List(java.util.List) Logger(org.slf4j.Logger) Trainer(ai.djl.training.Trainer) PairList(ai.djl.util.PairList) LoggerFactory(org.slf4j.LoggerFactory) Device(ai.djl.Device) NDArray(ai.djl.ndarray.NDArray) MxNDArray(ai.djl.mxnet.engine.MxNDArray) Pointer(com.sun.jna.Pointer) SparseFormat(ai.djl.ndarray.types.SparseFormat)
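
The mapping step in invoke (each output pair becomes a wrapper, with the format passed along only when it is not DENSE, and the stream is collected into a typed array) can be tried without MXNet. The following is a minimal sketch that assumes hypothetical stand-in types: Format, Wrapped, create and wrapAll are illustrative names and not part of DJL.

```java
import java.util.AbstractMap.SimpleEntry;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

// Hypothetical stand-ins for illustration only; none of these types are DJL classes.
class InvokeMappingSketch {

    enum Format { DENSE, CSR }

    /** Stand-in for an array wrapper that remembers its handle and format. */
    static final class Wrapped {
        final long handle;
        final Format format;

        Wrapped(long handle, Format format) {
            this.handle = handle;
            this.format = format;
        }

        @Override
        public String toString() {
            return format + "@" + handle;
        }
    }

    /** One-argument path, used for dense outputs. */
    static Wrapped create(long handle) {
        return new Wrapped(handle, Format.DENSE);
    }

    /** Two-argument path, used when the output carries a sparse format. */
    static Wrapped create(long handle, Format format) {
        return new Wrapped(handle, format);
    }

    /** Maps each (handle, format) pair to a wrapper and collects a typed array. */
    static Wrapped[] wrapAll(List<Map.Entry<Long, Format>> outputs) {
        return outputs.stream()
                .map(pair -> pair.getValue() != Format.DENSE
                        ? create(pair.getKey(), pair.getValue())
                        : create(pair.getKey()))
                .toArray(Wrapped[]::new);
    }

    public static void main(String[] args) {
        List<Map.Entry<Long, Format>> outputs = Arrays.asList(
                new SimpleEntry<Long, Format>(1L, Format.DENSE),
                new SimpleEntry<Long, Format>(2L, Format.CSR));
        System.out.println(Arrays.toString(wrapAll(outputs)));
    }
}
```

Passing the array constructor reference to toArray gives a correctly typed array directly, which is the same reason the original snippet uses MxNDArray[]::new rather than casting an Object[].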

Example 2 with MxNDManager

use of ai.djl.mxnet.engine.MxNDManager in project djl by deepjavalibrary.

the class MxBackendOptimizationTest method testOptimizedFor.

@Test
public void testOptimizedFor() {
    // TODO: Add Customized plugin test
    try (MxNDManager manager = (MxNDManager) NDManager.newBaseManager()) {
        Symbol symbol = Symbol.load(manager, "../mxnet-model-zoo/src/test/resources/mlrepo/model/cv/image_classification/ai/djl/mxnet/resnet/0.0.1/resnet50_v1-symbol.json");
        Symbol optimized = symbol.optimizeFor("test", manager.getDevice());
        optimized.close();
    }
}
Also used : Symbol(ai.djl.mxnet.engine.Symbol) MxNDManager(ai.djl.mxnet.engine.MxNDManager) Test(org.testng.annotations.Test)
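
The test above scopes the manager with try-with-resources and closes optimized by hand. As a general Java alternative, every closeable can be declared in the try header, in which case they are closed in reverse declaration order even if the body throws. The sketch below assumes a hypothetical Tracked class as a stand-in; it does not use DJL, and whether Symbol is meant to be managed this way is not shown in the snippet.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-ins for illustration only; these are not DJL classes.
class ResourceScopeSketch {

    /** Records the order in which resources are closed. */
    static final List<String> CLOSED = new ArrayList<>();

    /** Stand-in for a native resource that must be released. */
    static final class Tracked implements AutoCloseable {
        private final String name;

        Tracked(String name) {
            this.name = name;
        }

        @Override
        public void close() {
            CLOSED.add(name);
        }
    }

    /** Opens three resources in one try header and returns the close order. */
    static List<String> run() {
        CLOSED.clear();
        try (Tracked manager = new Tracked("manager");
                Tracked symbol = new Tracked("symbol");
                Tracked optimized = new Tracked("optimized")) {
            // Work with the resources here; no explicit close() calls are needed.
        }
        return new ArrayList<>(CLOSED);
    }

    public static void main(String[] args) {
        System.out.println(run());
    }
}
```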

Aggregations

MxNDManager (ai.djl.mxnet.engine.MxNDManager): 2
Device (ai.djl.Device): 1
MxNDArray (ai.djl.mxnet.engine.MxNDArray): 1
Symbol (ai.djl.mxnet.engine.Symbol): 1
NDArray (ai.djl.ndarray.NDArray): 1
NDManager (ai.djl.ndarray.NDManager): 1
SparseFormat (ai.djl.ndarray.types.SparseFormat): 1
Trainer (ai.djl.training.Trainer): 1
PairList (ai.djl.util.PairList): 1
Pointer (com.sun.jna.Pointer): 1
List (java.util.List): 1
Logger (org.slf4j.Logger): 1
LoggerFactory (org.slf4j.LoggerFactory): 1
Test (org.testng.annotations.Test): 1