Example 1 with PtSymbolBlock

Use of ai.djl.pytorch.engine.PtSymbolBlock in project djl by deepjavalibrary.

Source: class IValueTest, method testIValueModel.

@Test
public void testIValueModel() throws IOException, ModelException {
    // The NDList types are nominal here: the test calls the block directly with IValues
    Criteria<NDList, NDList> criteria =
            Criteria.builder()
                    .setTypes(NDList.class, NDList.class)
                    .optModelUrls("https://resources.djl.ai/test-models/ivalue_jit.zip")
                    .optProgress(new ProgressBar())
                    .build();
    try (ZooModel<NDList, NDList> model = criteria.loadModel()) {
        // The PyTorch engine exposes the loaded TorchScript module as a PtSymbolBlock
        PtSymbolBlock block = (PtSymbolBlock) model.getBlock();
        IValue tokens = IValue.listFrom(1, 2, 3);
        IValue cls = IValue.from(0);
        IValue sep = IValue.from(4);
        IValue ret = block.forward(tokens, cls, sep);
        // The expected output shows the tokens wrapped by cls in front and sep at the end
        long[] actual = ret.toLongArray();
        Assert.assertEquals(actual, new long[] { 0, 1, 2, 3, 4 });
        // Each IValue holds a native handle, so every one is closed explicitly
        tokens.close();
        cls.close();
        sep.close();
        ret.close();
    }
}
Also used : IValue(ai.djl.pytorch.jni.IValue) PtSymbolBlock(ai.djl.pytorch.engine.PtSymbolBlock) NDList(ai.djl.ndarray.NDList) ProgressBar(ai.djl.training.util.ProgressBar) Test(org.testng.annotations.Test)
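The assertion in Example 1 implies that the scripted module in ivalue_jit.zip returns the class token, then the input tokens, then the separator token. The model source is not shown here, so the following is only a plain-Java sketch of that inferred behavior; SpecialTokens and addSpecialTokens are hypothetical names and are not part of DJL.

```java
import java.util.Arrays;

/**
 * Hypothetical plain-Java equivalent of what the ivalue_jit test model appears
 * to compute, inferred only from the test's expected output [0, 1, 2, 3, 4].
 */
final class SpecialTokens {

    private SpecialTokens() {}

    /** Returns [cls, tokens..., sep] as a new array; the input array is not modified. */
    static long[] addSpecialTokens(long[] tokens, long cls, long sep) {
        long[] out = new long[tokens.length + 2];
        out[0] = cls;                                        // leading class token
        System.arraycopy(tokens, 0, out, 1, tokens.length);  // original tokens, in order
        out[out.length - 1] = sep;                           // trailing separator token
        return out;
    }

    public static void main(String[] args) {
        long[] ids = addSpecialTokens(new long[] {1, 2, 3}, 0, 4);
        System.out.println(Arrays.toString(ids)); // prints [0, 1, 2, 3, 4]
    }
}
```

The helper allocates a new array rather than mutating tokens, so the input can be reused after the call.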

Example 2 with PtSymbolBlock

Use of ai.djl.pytorch.engine.PtSymbolBlock in project djl by deepjavalibrary.

Source: class TorchScriptTest, method testInputOutput.

@Test
public void testInputOutput() throws IOException, ModelException {
    Criteria<NDList, NDList> criteria =
            Criteria.builder()
                    .setTypes(NDList.class, NDList.class)
                    .optModelUrls("djl://ai.djl.pytorch/resnet/0.0.1/traced_resnet18")
                    .optProgress(new ProgressBar())
                    .build();
    try (ZooModel<NDList, NDList> model = criteria.loadModel()) {
        PtNDManager manager = (PtNDManager) model.getNDManager();
        Path modelFile = model.getModelPath().resolve("traced_resnet18.pt");
        try (InputStream is = Files.newInputStream(modelFile)) {
            // Load the raw .pt file from a stream; the last flag (hasSize) is false
            PtSymbolBlock block = JniUtils.loadModule(manager, is, true, false);
            // Serialize the module to memory, this time with writeSize = true
            ByteArrayOutputStream os = new ByteArrayOutputStream();
            JniUtils.writeModule(block, os, true);
            // Reload the serialized bytes with hasSize = true to match how they were written;
            // the test passes if this round trip completes without throwing
            ByteArrayInputStream bis = new ByteArrayInputStream(os.toByteArray());
            JniUtils.loadModule(manager, bis, true, true);
            bis.close();
            os.close();
        }
    }
}
Also used : Path(java.nio.file.Path) ByteArrayInputStream(java.io.ByteArrayInputStream) InputStream(java.io.InputStream) PtSymbolBlock(ai.djl.pytorch.engine.PtSymbolBlock) NDList(ai.djl.ndarray.NDList) PtNDManager(ai.djl.pytorch.engine.PtNDManager) ByteArrayOutputStream(java.io.ByteArrayOutputStream) ProgressBar(ai.djl.training.util.ProgressBar) Test(org.testng.annotations.Test)
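Example 2 reads the .pt file with the final loadModule flag set to false, writes the module back with writeModule(..., true), and reloads the bytes with the flag set to true. One plausible reading is that the flag controls a size header placed in front of the serialized module. The sketch below illustrates generic length-prefixed framing using only java.io; the 8-byte big-endian header and the SizedFraming class are assumptions for illustration, not DJL's documented format.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

/**
 * Hypothetical illustration of length-prefixed framing: an optional 8-byte size
 * header followed by the payload. An assumption about what "writeSize"/"hasSize"
 * might mean, not DJL's actual wire format.
 */
final class SizedFraming {

    private SizedFraming() {}

    /** Writes the payload, optionally preceded by its length as a big-endian long. */
    static void write(byte[] payload, OutputStream os, boolean writeSize) throws IOException {
        DataOutputStream out = new DataOutputStream(os);
        if (writeSize) {
            out.writeLong(payload.length);
        }
        out.write(payload);
        out.flush();
    }

    /** Reads one payload: exactly the announced byte count if hasSize, else up to end of stream. */
    static byte[] read(InputStream is, boolean hasSize) throws IOException {
        DataInputStream in = new DataInputStream(is);
        if (hasSize) {
            long size = in.readLong();
            byte[] buf = new byte[Math.toIntExact(size)];
            in.readFully(buf); // throws EOFException if the stream is truncated
            return buf;
        }
        return in.readAllBytes(); // no header: the payload runs to end of stream
    }

    public static void main(String[] args) throws IOException {
        byte[] module = "serialized-module".getBytes(StandardCharsets.UTF_8);
        ByteArrayOutputStream os = new ByteArrayOutputStream();
        write(module, os, true);
        byte[] roundTrip = read(new ByteArrayInputStream(os.toByteArray()), true);
        System.out.println(Arrays.equals(module, roundTrip)); // prints true
        System.out.println(os.size()); // prints 25 (8 header bytes + 17 payload bytes)
    }
}
```

A size header lets the reader stop exactly at the end of one payload, so several framed payloads could share a stream; without it, the reader simply consumes everything up to the end of the stream.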

Example 3 with PtSymbolBlock

Use of ai.djl.pytorch.engine.PtSymbolBlock in project djl by deepjavalibrary.

Source: class JniUtils, method loadModule.

public static PtSymbolBlock loadModule(PtNDManager manager, Path path, boolean mapLocation,
        String[] extraFileKeys, String[] extraFileValues) {
    Device device = manager.getDevice();
    // The native loader takes the target device as a {device type, device id} pair
    int[] deviceArgs = new int[] { PtDeviceType.toDeviceType(device), device.getDeviceId() };
    long handle = PyTorchLibrary.LIB.moduleLoad(
            path.toString(), deviceArgs, mapLocation, extraFileKeys, extraFileValues);
    // Wrap the native module handle in a block owned by the given manager
    return new PtSymbolBlock(manager, handle);
}
Also used : Device(ai.djl.Device) PtSymbolBlock(ai.djl.pytorch.engine.PtSymbolBlock)
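In Example 3, loadModule forwards two kinds of array arguments to the native moduleLoad call: the target device packed as an int pair of device type and device id, and the extra files as parallel key and value arrays. The following sketch shows both shapes in plain Java. The numeric codes assume PyTorch's c10::DeviceType numbering (CPU = 0, CUDA = 1) rather than the confirmed output of PtDeviceType.toDeviceType, and ModuleLoadArgs, encodeDevice and zipExtraFiles are hypothetical helpers. Whether extraFileValues is supplied by the caller or filled in by the native loader is not visible in this snippet, so the helper only pairs the arrays up once both are populated.

```java
import java.util.LinkedHashMap;
import java.util.Map;

/**
 * Hypothetical helpers mirroring the array arguments that loadModule passes to
 * PyTorchLibrary.LIB.moduleLoad. Names and numeric codes are illustrative assumptions.
 */
final class ModuleLoadArgs {

    private ModuleLoadArgs() {}

    /**
     * Packs a device into a {typeCode, deviceId} pair. Assumes c10::DeviceType
     * numbering (CPU = 0, CUDA = 1); the real mapping lives in PtDeviceType.toDeviceType.
     */
    static int[] encodeDevice(String deviceType, int deviceId) {
        int code;
        switch (deviceType) {
            case "cpu":
                code = 0;
                break;
            case "gpu":
                code = 1;
                break;
            default:
                throw new IllegalArgumentException("unsupported device type: " + deviceType);
        }
        return new int[] {code, deviceId};
    }

    /** Zips parallel key/value arrays into an insertion-ordered map, rejecting length mismatches. */
    static Map<String, String> zipExtraFiles(String[] keys, String[] values) {
        if (keys.length != values.length) {
            throw new IllegalArgumentException(
                    "extra file keys and values differ in length: " + keys.length + " vs " + values.length);
        }
        Map<String, String> files = new LinkedHashMap<>();
        for (int i = 0; i < keys.length; i++) {
            files.put(keys[i], values[i]);
        }
        return files;
    }

    public static void main(String[] args) {
        int[] dev = encodeDevice("gpu", 0);
        System.out.println(dev[0] + "," + dev[1]); // prints 1,0
        System.out.println(zipExtraFiles(new String[] {"a.txt"}, new String[] {"hello"})); // prints {a.txt=hello}
    }
}
```

Checking that both arrays have the same length up front yields a clear Java exception instead of a failure deeper in native code.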

Aggregations

PtSymbolBlock (ai.djl.pytorch.engine.PtSymbolBlock): 3
NDList (ai.djl.ndarray.NDList): 2
ProgressBar (ai.djl.training.util.ProgressBar): 2
Test (org.testng.annotations.Test): 2
Device (ai.djl.Device): 1
PtNDManager (ai.djl.pytorch.engine.PtNDManager): 1
IValue (ai.djl.pytorch.jni.IValue): 1
ByteArrayInputStream (java.io.ByteArrayInputStream): 1
ByteArrayOutputStream (java.io.ByteArrayOutputStream): 1
InputStream (java.io.InputStream): 1
Path (java.nio.file.Path): 1